Image Segmentation with Python

Introduction to image segmentation

In this article we look at an interesting data problem – making decisions about the algorithms used for image segmentation, or separating one qualitatively different part of an image from another.

Example code for this article may be found at the Kite Github repository. We have provided tips on how to use the code throughout.

As our example, we work through the process of differentiating vascular tissue in images produced by Knife-edge Scanning Microscopy (KESM). While this may seem like a specialized use case, there are far-reaching implications, especially regarding preparatory steps for statistical analysis and machine learning.

Data scientists and medical researchers alike could use this approach as a template for any complex, image-based data set (such as astronomical data), or even large sets of non-image data. After all, images are ultimately matrices of values, and we’re lucky to have an expert-sorted data set to use as ground truth. In this process, we’re going to expose and describe several tools available via image processing and scientific Python packages (opencv, scikit-image, and scikit-learn). We’ll also make heavy use of the numpy library to ensure consistent storage of values in memory.

The procedures we’ll explore could be used for any number of statistical or supervised machine learning problems, as there are a large number of ground truth data points. In order to choose our image segmentation algorithm and approach, we will demonstrate how to visualize the confusion matrix, using matplotlib to colorize where the algorithm was right and where it was wrong. In early stages, it’s more useful for a human to be able to clearly visualize the results than to aggregate them into a few abstract numerals.

Approach

Cleaning

To remove noise, we use a simple median filter to remove the outliers, but one can use a different noise removal approach or artifact removal approach. The artifacts vary across acquisition systems (microscopy techniques) and may require complicated algorithms to restore the missing data. Artifacts commonly fall into two categories:

  • blurry or out-of-focus areas
  • imbalanced foreground and background (correct with histogram modification)

Segmentation

For this article, we limit segmentation to Otsu’s approach, applied after smoothing an image using a median filter, and follow it with validation of the results. You can use the same validation approach for any segmentation algorithm, as long as the segmentation result is binary. These algorithms include, but are not limited to, various circular thresholding approaches that consider different color spaces.

Some examples are:

  • Li Thresholding
  • An adaptive thresholding method that is dependent on local intensity
  • Deep learning algorithms like UNet used commonly in biomedical image segmentation
  • Deep learning approaches that semantically segment an image

Validation

We begin with a ground truth data set, which has already been manually segmented. To quantify the performance of a segmentation algorithm, we compare ground truth with the predicted binary segmentation, showing accuracy alongside more effective metrics. Accuracy can be abnormally high despite a low number of true positives (TP) or a high number of false negatives (FN). In such cases, F1 score and MCC are better quantification metrics for the binary classification. We’ll go into detail on the pros and cons of these metrics later.

For qualitative validation, we overlay the confusion matrix results (i.e. exactly where the true positive, true negative, false positive, and false negative pixels are) onto the grayscale image. This validation can also be applied to a binary segmentation of a color image, although the data we use in this article is a grayscale image. In the end, we will present the whole process so that you can see the results for yourself. Now, let’s look at the data and the tools used to process that data.

Loading and visualizing data

We will use the below modules to load, visualize, and transform the data. These are useful for image processing and computer vision algorithms, with simple and complex array mathematics. The module names in parentheses will help if installing individually.

Module                    Reason
numpy                     Histogram calculation, array math, and equality testing
matplotlib                Graph plotting and image visualization
scipy                     Median filter (and image reading via scipy.misc.imread in SciPy versions before 1.2)
cv2 (opencv-python)       Alpha compositing to combine two images
skimage (scikit-image)    Image thresholding
sklearn (scikit-learn)    Binary classifier confusion matrix
nose                      Testing

Displaying Plots Sidebar: If you are running the example code in sections from the command line, or experience issues with the matplotlib backend, disable interactive mode by removing the plt.ion() call, and instead call plt.show() at the end of each section by uncommenting the suggested calls in the example code. ‘TkAgg’ will serve as a backend for on-screen display; ‘Agg’ is non-interactive and only renders to files (for example through plt.savefig()). Plots will be displayed as they appear in the article.

Importing modules

import cv2
import matplotlib.pyplot as plt
import numpy as np
import scipy.misc
import scipy.ndimage
import skimage.filters
import sklearn.metrics

# Turn on interactive mode. Turn off with plt.ioff()
plt.ion()

In this section, we load and visualize the data. The data is an image of mouse brain tissue stained with India ink, generated by Knife-Edge Scanning Microscopy (KESM). This 512 x 512 image is a subset, referred to as a tile. The full data set is 17480 x 8026 pixels, 799 slices in depth, and 10 GB in size. We therefore write our algorithms against a single 512 x 512 tile, which is only 150 KB.

Individual tiles can be mapped onto multiple processes or threads (or onto distributed infrastructure), and then stitched back together to obtain the full segmented image. The specific stitching method is not demonstrated here. Briefly, stitching involves indexing the full matrix and putting the tiles back together according to this index. For combining numerical values, you can use map-reduce: compute a metric such as the F1 score for each tile, sum the scores across all tiles, and divide by the number of tiles to get the average. In practice, you can simply append the per-tile results to a list and then perform your own statistical summary.
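
The tiling idea can be sketched with plain numpy slicing. This is only an illustration under stated assumptions, not the stitching method used for the full KESM data set: the helper names split_into_tiles and stitch_tiles, the synthetic 1024 x 1024 image, and the fixed threshold of 127 are all hypothetical stand-ins.

```python
import numpy as np


def split_into_tiles(image, tile_size):
    """Return {(row, col): tile}, keyed by each tile's top-left corner."""
    tiles = {}
    for row in range(0, image.shape[0], tile_size):
        for col in range(0, image.shape[1], tile_size):
            tiles[(row, col)] = image[row:row + tile_size, col:col + tile_size]
    return tiles


def stitch_tiles(tiles, shape, dtype):
    """Place every tile back at the position stored in its key."""
    stitched = np.zeros(shape, dtype=dtype)
    for (row, col), tile in tiles.items():
        stitched[row:row + tile.shape[0], col:col + tile.shape[1]] = tile
    return stitched


# Hypothetical 1024 x 1024 stand-in for one large slice
image = (np.arange(1024 * 1024).reshape(1024, 1024) % 256).astype(np.uint8)
tiles = split_into_tiles(image, tile_size=512)

# "Map": process each tile independently (a fixed threshold as a placeholder)
segmented = {key: np.uint8(tile > 127) * 255 for key, tile in tiles.items()}

# "Reduce": collect one number per tile in a list, then summarize
foreground_fractions = [np.mean(tile == 255) for tile in segmented.values()]
mean_foreground_fraction = float(np.mean(foreground_fractions))

# Stitch the per-tile results back into one full-size image
stitched = stitch_tiles(segmented, image.shape, np.uint8)
```

Keying each tile by its top-left corner is what makes stitching a simple assignment. The same list-and-summarize pattern would apply to per-tile F1 or MCC values.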

The dark circular/elliptical disks on the left are vessels and the rest is the tissue. So, our two classes in this dataset are:

  • foreground (vessels) – labeled as 255
  • background (tissue) – labeled as 0

The last image on the right below is the ground truth image. A board-certified pathologist obtained the ground truth by manually tracing the vessels, drawing contours and filling them. We can use several examples like these from experts to train supervised deep learning networks and validate them on a larger scale. We can also augment the data by giving these examples to crowdsourced platforms and training annotators there to manually trace a different set of images on a larger scale for validation and training. The image in the middle is just an inverted grayscale image, which corresponds with the ground truth binary image.

Loading and visualizing images in figure above

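# scipy.misc.imread was removed in SciPy 1.2; cv2.imread reads the same files
grayscale = cv2.imread('grayscale.png', cv2.IMREAD_GRAYSCALE)
grayscale = 255 - grayscale
groundtruth = cv2.imread('groundtruth.png', cv2.IMREAD_GRAYSCALE)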
plt.subplot(1, 3, 1)
plt.imshow(255 - grayscale, cmap='gray')
plt.title('grayscale')
plt.axis('off')
plt.subplot(1, 3, 2)
plt.imshow(grayscale, cmap='gray')
plt.title('inverted grayscale')
plt.axis('off')
plt.subplot(1, 3, 3)
plt.imshow(groundtruth, cmap='gray')
plt.title('groundtruth binary')
plt.axis('off')

Pre-processing

Before segmenting the data, you should go through the dataset thoroughly to determine if there are any artifacts due to the imaging system. In this example, we only have one image in question. By looking at the image, we can see that there aren’t any noticeable artifacts that would interfere with the segmentation. However, you can remove outlier noise and smooth an image using a median filter. A median filter replaces the outliers with the median (within a kernel of a given size).

Median filter of kernel size 3

median_filtered = scipy.ndimage.median_filter(grayscale, size=3)
plt.imshow(median_filtered, cmap='gray')
plt.axis('off')
plt.title('median filtered image')
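
To see what the filter does to a single outlier, here is a numpy-only sketch of a 3 x 3 median filter. It is not the scipy implementation used above: the function name median_filter_3x3 and the choice to pad borders by repeating edge pixels are assumptions made for illustration.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def median_filter_3x3(image):
    """3 x 3 median filter; borders are padded by repeating edge pixels."""
    padded = np.pad(image, 1, mode='edge')
    windows = sliding_window_view(padded, (3, 3))  # shape: (rows, cols, 3, 3)
    return np.median(windows, axis=(-2, -1)).astype(image.dtype)


# A flat 5 x 5 patch with one bright outlier in the middle
noisy = np.full((5, 5), 10, dtype=np.uint8)
noisy[2, 2] = 255

cleaned = median_filter_3x3(noisy)
```

Each 3 x 3 window contains the outlier at most once among nine values, so the median ignores it and the patch comes back flat.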

To determine which thresholding technique is best for segmentation, you could start by plotting a histogram of the pixel intensities to see whether a distinct intensity separates the two classes. If so, you can use the intensity obtained by visual inspection to binarize the image. In our case, there seem to be a lot of pixels with intensities of less than 50, which correspond to the background class in the inverted grayscale image.

Although the distribution of the classes is not bimodal (having two distinct peaks), it still has a distinction between foreground and background, which is where the lower intensity pixels peak and then hit a valley. This exact value can be obtained by various thresholding techniques. The segmentation section examines one such method in detail.

Visualize histogram of the pixel intensities

# 257 bin edges give one bin per intensity 0..255 (256 edges would merge 254 and 255)
counts, vals = np.histogram(grayscale, bins=range(2 ** 8 + 1))
plt.plot(range(2 ** 8), counts)
plt.title('Grayscale image histogram')
plt.xlabel('Pixel intensity')
plt.ylabel('Count')

Segmentation

After removing noise, you can use the skimage filters module to try several thresholding methods at once and see which fare well. The histogram of an image’s pixel intensities is not always bimodal, so another method may do better, such as an adaptive thresholding method that thresholds on local pixel intensities within a kernel. It’s good to see what the different thresholding methods produce, and skimage.filters.try_all_threshold() is handy for that.

Try all thresholding methods

fig, ax = skimage.filters.try_all_threshold(median_filtered)

The simplest thresholding approach uses a manually set threshold for an image. An automated thresholding method, on the other hand, computes the threshold value from the image data, which is more precise than judging it by eye and is easy to replicate. For our image in this example, it seems like Otsu, Yen, and the Triangle method are performing well. The other results for this case are noticeably worse.

We’ll use Otsu thresholding to segment our image into a binary image for this article. Otsu picks the threshold value that maximizes inter-class variance (variance between foreground and background) and minimizes intra-class variance (variance within foreground or variance within background). It does well if the histogram is bimodal (with two distinct peaks) or if there is otherwise a threshold value that separates the classes well.

Otsu thresholding and visualization

threshold = skimage.filters.threshold_otsu(median_filtered)
print('Threshold value is {}'.format(threshold))
predicted = np.uint8(median_filtered > threshold) * 255
plt.imshow(predicted, cmap='gray')
plt.axis('off')
plt.title('otsu predicted binary image')
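
To make the criterion concrete, the inter-class variance can be computed directly from the histogram for every candidate threshold. The following is a simplified numpy sketch of that idea, not the skimage implementation. The function name otsu_threshold_sketch and the synthetic two-cluster image are assumptions for illustration, and it expects 8-bit integer pixel values.

```python
import numpy as np


def otsu_threshold_sketch(image):
    """Return the threshold t that maximizes inter-class variance.

    Pixels <= t form one class and pixels > t form the other.
    """
    counts = np.bincount(image.ravel(), minlength=256).astype(np.float64)
    levels = np.arange(256)
    total = counts.sum()
    best_threshold, best_variance = 0, -1.0
    for t in range(255):
        weight_low = counts[:t + 1].sum()
        weight_high = total - weight_low
        if weight_low == 0 or weight_high == 0:
            continue  # one class is empty, so there is nothing to separate
        mean_low = (counts[:t + 1] * levels[:t + 1]).sum() / weight_low
        mean_high = (counts[t + 1:] * levels[t + 1:]).sum() / weight_high
        variance = ((weight_low / total) * (weight_high / total)
                    * (mean_low - mean_high) ** 2)
        if variance > best_variance:
            best_threshold, best_variance = t, variance
    return best_threshold


# Synthetic image: dark background (20, 30) and bright foreground (200, 210)
image = np.array([20] * 50 + [30] * 50 + [200] * 25 + [210] * 25,
                 dtype=np.uint8).reshape(10, 15)
threshold = otsu_threshold_sketch(image)
binary = np.uint8(image > threshold) * 255
```

On this synthetic image the chosen threshold falls in the gap between the two clusters, so the binary result marks exactly the bright pixels as foreground.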

If the above simple techniques don’t serve the purpose for binary segmentation of the image, then one can use UNet, ResNet with FCN, or various other supervised deep learning techniques to segment the images. To remove small objects caused by noise in the segmented foreground, you may also consider trying skimage.morphology.remove_small_objects().

Validation

In every case, we need ground truth that was manually generated by a human with expertise in the image type, so that we can compute accuracy and other metrics and see how well the image is segmented.

The confusion matrix

We use sklearn.metrics.confusion_matrix() to get the confusion matrix elements as shown below. Given two lists of binary labels, the function returns a 2 x 2 matrix whose four elements are TN, FP, FN, and TP. In the edge cases where every element in both lists has the same binary value (all 0 or all 1), sklearn returns only one element. We wrap the sklearn confusion matrix function in our own function that covers these edge cases:

get_confusion_matrix_elements()

def get_confusion_matrix_elements(groundtruth_list, predicted_list):
    """returns confusion matrix elements i.e TN, FP, FN, TP as floats
    See example code for helper function definitions
    """
    _assert_valid_lists(groundtruth_list, predicted_list)

    if _all_class_1_predicted_as_class_1(groundtruth_list, predicted_list) is True:
        tn, fp, fn, tp = 0, 0, 0, np.float64(len(groundtruth_list))

    elif _all_class_0_predicted_as_class_0(groundtruth_list, predicted_list) is True:
        tn, fp, fn, tp = np.float64(len(groundtruth_list)), 0, 0, 0

    else:
        tn, fp, fn, tp = sklearn.metrics.confusion_matrix(groundtruth_list, predicted_list).ravel()
        tn, fp, fn, tp = np.float64(tn), np.float64(fp), np.float64(fn), np.float64(tp)

    return tn, fp, fn, tp
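
For readers who do not have the repository’s helper functions at hand, the four counts can also be obtained with boolean masks, which always yields four values, including in the all-0 and all-1 edge cases. This is an alternative sketch rather than the article’s helper code; the function name count_confusion_elements is hypothetical.

```python
import numpy as np


def count_confusion_elements(groundtruth, predicted):
    """Return (tn, fp, fn, tp) as floats; any non-zero label counts as class 1."""
    truth = np.asarray(groundtruth).astype(bool).ravel()
    pred = np.asarray(predicted).astype(bool).ravel()
    if truth.shape != pred.shape:
        raise ValueError('groundtruth and predicted must have the same size')
    tp = np.float64(np.count_nonzero(truth & pred))
    tn = np.float64(np.count_nonzero(~truth & ~pred))
    fp = np.float64(np.count_nonzero(~truth & pred))
    fn = np.float64(np.count_nonzero(truth & ~pred))
    return tn, fp, fn, tp


mixed = count_confusion_elements([0, 255, 0, 255], [0, 0, 255, 255])
all_background = count_confusion_elements([0, 0, 0], [0, 0, 0])
all_foreground = count_confusion_elements([255, 255], [255, 255])
```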

Accuracy

Accuracy is a common validation metric in case of binary classification. It is calculated as

$$\text{accuracy} = \frac{TP + TN}{TP + TN + FP + FN}$$

where TP = True Positive, TN = True Negative, FP = False Positive, FN = False Negative

get_accuracy()

def get_accuracy(groundtruth_list, predicted_list):

    tn, fp, fn, tp = get_confusion_matrix_elements(groundtruth_list, predicted_list)
    
    total = tp + fp + fn + tn
    accuracy = (tp + tn) / total
    
    return accuracy

Accuracy varies from 0 to 1, with 0 being the worst and 1 being the best. When the classes are imbalanced, however, an algorithm that labels every pixel as the majority class still achieves a high accuracy. Hence we need a metric that accounts for the imbalance in class count, especially since the current image has far more background pixels (class 0) than foreground pixels (class 1).
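
A toy example shows the problem. The labels below are made up for illustration (95 background pixels, 5 vessel pixels), and the "prediction" simply calls everything background.

```python
# Hypothetical, heavily imbalanced labels: 0 = background, 1 = vessel
groundtruth = [0] * 95 + [1] * 5
predicted = [0] * 100  # a useless predictor that never finds a vessel

pairs = list(zip(groundtruth, predicted))
tp = sum(1 for g, p in pairs if g == 1 and p == 1)
tn = sum(1 for g, p in pairs if g == 0 and p == 0)
fp = sum(1 for g, p in pairs if g == 0 and p == 1)
fn = sum(1 for g, p in pairs if g == 1 and p == 0)

accuracy = (tp + tn) / (tp + tn + fp + fn)
vessels_found = tp / (tp + fn)  # fraction of true vessel pixels detected

print('accuracy:', accuracy)
print('fraction of vessel pixels found:', vessels_found)
```

The accuracy looks excellent even though not a single vessel pixel was detected, which is why accuracy alone is not enough here.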

F1 score

The F1 score varies from 0 to 1 and is calculated as:

$$F_1 = \frac{2\,TP}{2\,TP + FP + FN}$$

with 0 being the worst and 1 being the best prediction. Now let’s handle F1 score calculation considering edge cases.

get_f1_score()

def get_f1_score(groundtruth_list, predicted_list):
    """Return f1 score covering edge cases"""

    tn, fp, fn, tp = get_confusion_matrix_elements(groundtruth_list, predicted_list)
    
    if _all_class_0_predicted_as_class_0(groundtruth_list, predicted_list) is True:
        f1_score = 1
    elif _all_class_1_predicted_as_class_1(groundtruth_list, predicted_list) is True:
        f1_score = 1
    else:
        f1_score = (2 * tp) / ((2 * tp) + fp + fn)

    return f1_score

As a rule of thumb, an F1 score above 0.8 indicates that the prediction is doing well.

MCC

MCC stands for Matthews Correlation Coefficient, and is calculated as:

$$\text{MCC} = \frac{TP \cdot TN - FP \cdot FN}{\sqrt{(TP + FP)(TP + FN)(TN + FP)(TN + FN)}}$$

It lies between -1 and +1. A value of -1 means the prediction is the exact opposite of the ground truth, 0 means the prediction is no better than random (some predictions match by chance), and +1 means the prediction matches the ground truth everywhere. Because MCC uses all four confusion matrix elements, it remains informative when the classes are imbalanced, which makes it a better validation metric than accuracy.

In the MCC calculation, the numerator consists of just the four inner cells of the confusion matrix (a cross product of the elements), while the denominator is built from the four outer cells (the row and column totals). When the denominator is 0, MCC is undefined (i.e. numpy.nan), which signals that the classifier has gone wrong, for example by assigning every pixel to a single class. For the purpose of getting valid values, and being able to average the MCC over different images if necessary, we set the MCC to -1 in that case, the worst possible value within the range. Other edge cases include all elements correctly detected as foreground or as background, for which both MCC and F1 score are set to 1. Otherwise, MCC is set to -1 and F1 score is 0.

To learn more about MCC and the edge cases, this is a good article. To understand in more detail why MCC is better than accuracy or F1 score, Wikipedia does good work here.

get_mcc()

def get_mcc(groundtruth_list, predicted_list):
    """Return mcc covering edge cases"""   

    tn, fp, fn, tp = get_confusion_matrix_elements(groundtruth_list, predicted_list)
    
    if _all_class_0_predicted_as_class_0(groundtruth_list, predicted_list) is True:
        mcc = 1
    elif _all_class_1_predicted_as_class_1(groundtruth_list, predicted_list) is True:
        mcc = 1
    elif _all_class_1_predicted_as_class_0(groundtruth_list, predicted_list) is True:
        mcc = -1
    elif _all_class_0_predicted_as_class_1(groundtruth_list, predicted_list) is True :
        mcc = -1

    elif _mcc_denominator_zero(tn, fp, fn, tp) is True:
        mcc = -1

    # Finally calculate MCC
    else:
        mcc = ((tp * tn) - (fp * fn)) / (
            np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))
    
    return mcc

Finally, we can compare the validation metrics by result, side-by-side.

>>> validation_metrics = get_validation_metrics(groundtruth, predicted)

{'mcc': 0.8533910225863214, 'f1_score': 0.8493358633776091, 'tp': 5595.0, 'fn': 1863.0, 'fp': 122.0, 'accuracy': 0.9924278259277344, 'tn': 254564.0}

Accuracy is close to 1, as we have a lot of background pixels in our example image that are correctly detected as background (i.e. true negatives are naturally higher). This shows why accuracy alone isn’t a good measure for binary classification when the classes are imbalanced.

The F1 score is 0.85. So, in this case, we probably don’t need a more sophisticated thresholding algorithm for binary segmentation. If all the images in the stack had similar histogram distribution and noise, then we could use Otsu and have satisfactory prediction results.

The MCC of 0.85 is high, also indicating that the ground truth and the predicted image have a high correlation, as can be seen from the predicted image in the previous section.
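
As a sanity check, the three metrics can be recomputed by hand from the four counts printed above, using the same formulas as get_accuracy(), get_f1_score(), and get_mcc(). The snippet below only plugs in the reported numbers; it does not rerun the segmentation.

```python
import math

# Confusion matrix counts reported above for the Otsu prediction
tn, fp, fn, tp = 254564.0, 122.0, 1863.0, 5595.0

accuracy = (tp + tn) / (tp + tn + fp + fn)
f1_score = (2 * tp) / ((2 * tp) + fp + fn)
mcc = ((tp * tn) - (fp * fn)) / math.sqrt(
    (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))

print('accuracy: {:.4f}'.format(accuracy))
print('f1_score: {:.4f}'.format(f1_score))
print('mcc:      {:.4f}'.format(mcc))
```

Note how the 1863 missed vessel pixels (FN) barely dent the accuracy but pull both the F1 score and the MCC well below 1.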

Now, let’s visualize and see where the confusion matrix elements TP, FP, FN, TN are distributed along the image. It shows us where the threshold is picking up foreground (vessels) when they are not present (FP) and where true vessels are not detected (FN), and vice-versa.

Validation visualization

To visualize confusion matrix elements, we figure out exactly where in the image each element falls. For example, we find the TP array (i.e. pixels correctly detected as foreground) by taking the logical “and” of the ground truth and the predicted array. Similarly, we use boolean operations, commonly called bit blit, to find the FP, FN, and TN arrays.

get_confusion_matrix_intersection_mats()

def get_confusion_matrix_intersection_mats(groundtruth, predicted):
    """ Returns dict of 4 boolean numpy arrays with True at TP, FP, FN, TN
    """

    confusion_matrix_arrs = {}

    groundtruth_inverse = np.logical_not(groundtruth)
    predicted_inverse = np.logical_not(predicted)

    confusion_matrix_arrs['tp'] = np.logical_and(groundtruth, predicted)
    confusion_matrix_arrs['tn'] = np.logical_and(groundtruth_inverse, predicted_inverse)
    confusion_matrix_arrs['fp'] = np.logical_and(groundtruth_inverse, predicted)
    confusion_matrix_arrs['fn'] = np.logical_and(groundtruth, predicted_inverse)

    return confusion_matrix_arrs
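
On a tiny made-up 2 x 2 example, the same boolean operations give one mask per confusion matrix element, and every pixel lands in exactly one of them. The arrays below are illustrative only.

```python
import numpy as np

# Hypothetical 2 x 2 ground truth and prediction (255 = vessel, 0 = tissue)
groundtruth = np.array([[255, 255], [0, 0]], dtype=np.uint8)
predicted = np.array([[255, 0], [255, 0]], dtype=np.uint8)

tp = np.logical_and(groundtruth, predicted)
fp = np.logical_and(np.logical_not(groundtruth), predicted)
fn = np.logical_and(groundtruth, np.logical_not(predicted))
tn = np.logical_and(np.logical_not(groundtruth), np.logical_not(predicted))

# Each pixel should be covered by exactly one of the four masks
coverage = tp.astype(int) + fp.astype(int) + fn.astype(int) + tn.astype(int)
```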

Then, we can map pixels in each of these arrays to different colors. For the figure below we mapped TP, FP, FN, TN to the CMYK colors (Cyan, Magenta, Yellow, Black). One could similarly map them to (Green, Red, Red, Green). We would then get an image where everything in red signifies the incorrect predictions. The CMYK colors also allow us to distinguish TP from TN.

get_confusion_matrix_overlaid_mask()

def get_confusion_matrix_overlaid_mask(image, groundtruth, predicted, alpha, colors):
    """
    Return the 'image' overlaid with a color mask where TP, FP, FN, TN are
    each a color given by the 'colors' dictionary
    """
    image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    masks = get_confusion_matrix_intersection_mats(groundtruth, predicted)
    color_mask = np.zeros_like(image)
    for label, mask in masks.items():
        color = colors[label]
        mask_rgb = np.zeros_like(image)
        mask_rgb[mask != 0] = color
        color_mask += mask_rgb
    return cv2.addWeighted(image, alpha, color_mask, 1 - alpha, 0)

alpha = 0.5
confusion_matrix_colors = {
   'tp': (0, 255, 255),  #cyan
   'fp': (255, 0, 255),  #magenta
   'fn': (255, 255, 0),  #yellow
   'tn': (0, 0, 0)     #black
   }
validation_mask = get_confusion_matrix_overlaid_mask(255 - grayscale, groundtruth, predicted, alpha, confusion_matrix_colors)
print('Cyan - TP')
print('Magenta - FP')
print('Yellow - FN')
print('Black - TN')
plt.imshow(validation_mask)
plt.axis('off')
plt.title('confusion matrix overlay mask')

We use opencv here to overlay this color mask onto the original (non-inverted) grayscale image as a transparent layer. This is called alpha compositing.
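
The blend itself is a per-pixel weighted sum, alpha * image + (1 - alpha) * color_mask. A numpy-only sketch of that arithmetic is shown below, assuming 8-bit RGB inputs. The function name blend and the tiny 2 x 2 test image are hypothetical, and the rounding here is not guaranteed to match opencv’s exactly.

```python
import numpy as np


def blend(image, color_mask, alpha):
    """Weighted sum alpha * image + (1 - alpha) * color_mask, clipped to uint8."""
    mixed = (alpha * image.astype(np.float64)
             + (1 - alpha) * color_mask.astype(np.float64))
    return np.clip(np.rint(mixed), 0, 255).astype(np.uint8)


image = np.full((2, 2, 3), 100, dtype=np.uint8)  # uniform gray RGB image
color_mask = np.zeros_like(image)                # black everywhere (TN)
color_mask[0, 0] = (0, 255, 255)                 # one cyan (TP) pixel

blended = blend(image, color_mask, alpha=0.5)
```

With alpha = 0.5, the gray image and the color mask contribute equally, so the cyan pixel keeps its tint while the underlying tissue stays visible.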

Final notes

The last two examples in the repository are testing the edge cases and a random prediction scenario on a small array (fewer than 10 elements), by calling the test functions. It is important to test for edge cases and potential issues if we are writing production level code, or just to test the simple logic of an algorithm.

Travis CI is very useful for testing whether your code works on the module versions described in your requirements, and whether all the tests pass as new changes are merged into master. Keeping your code clean, well documented, and fully covered by unit tests is a best practice. These habits limit the need to chase down bugs when a complex algorithm is built on top of simple functional pieces that could have been unit tested. Generally, documentation and unit testing help others stay informed about your intentions for a function. Linting helps improve the readability of the code, and flake8 is a good Python package for that.

Here are the important takeaways from this article:

  1. Tiling and stitching approach for data that doesn’t fit in memory
  2. Trying different thresholding techniques
  3. Subtleties of Validation Metrics
  4. Validation visualization
  5. Best Practices

There are many directions you could go from here with your work or projects. Applying the same strategy to different data sets, or automating the validation selection approach, would be excellent places to start. Further, imagine you needed to analyze a database with many of these 10 GB files. How could you automate the process? How could you validate and justify the results to human beings? How does better analysis improve the outcomes of real-world scenarios (like the development of surgical procedures and medicine)? Asking questions like these will allow continued improvements in Statistics, Data Science, and Machine Learning.

Finally, thanks to Navid Farahani for annotations, Katherine Scott for the guidance, Allen Teplitsky for the motivation, and all of the 3Scan team for the data.


This article was originally published at Kite’s blog and republished here as part of a content partnership program.
