Practical image segmentation with Unet



In this post we will learn how Unet works, what it is used for and how to implement it. To do so, we will use the original Unet paper, PyTorch and a Kaggle competition where Unet was widely used.
If you don't know anything about PyTorch, are afraid of implementing a deep learning paper by yourself, or have never participated in a Kaggle competition, this is the right post for you.
We won't follow the paper 100% here: we will only implement the parts we need and adapt the rest to our problem.

Problem presentation

Our problem is the one from the Kaggle Carvana Image Masking Challenge.
Basically, given images of cars and their masks, we want to create a model that automatically extracts the car from its background with a pixel-wise accuracy over 99%. See below:

The image on the left is the car, the one in the middle is its mask, and the one on the right shows the mask applied to the car. For the training set, we are given both the car images and their masks. We will use a Unet neural network that learns to create the masks automatically:

  1. By feeding the car images into the neural net
  2. By using a loss function that compares the output of the neural net with the corresponding masks, and backpropagating the error so that the network learns.
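The two steps above boil down to a standard supervised training step. Here is a minimal sketch of one such step, assuming a tiny stand-in network (a single 3x3 convolution instead of a Unet), random tensors instead of real car images, and binary cross-entropy as the loss. The project's actual training loop and loss are discussed later in the post.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the Unet: RGB image in, one logit per pixel out.
model = nn.Conv2d(3, 1, kernel_size=3, padding=1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.99)
loss_fn = nn.BCEWithLogitsLoss()  # compares per-pixel logits with 0/1 masks

images = torch.rand(2, 3, 64, 64)              # batch of 2 fake RGB images
masks = (torch.rand(2, 64, 64) > 0.5).float()  # matching fake binary masks


def train_step(images, masks):
    """One optimization step: forward pass, loss against the masks, backprop."""
    model.train()
    optimizer.zero_grad()
    logits = model(images).squeeze(1)  # (N, 1, H, W) -> (N, H, W), like the masks
    loss = loss_fn(logits, masks)
    loss.backward()                    # backpropagate the error
    optimizer.step()                   # update the weights
    return loss.item()


first_loss = train_step(images, masks)
```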

Code structure

The code has been simplified as much as possible so that you can understand how it works just by reading the main script. We will go through each line of the code to explain how everything is glued together. Note that we are working on a special branch of the GitHub repository (original_unet). It is not the branch that scored best in the competition for me; its goal is to follow the original Unet paper as closely as possible.

The code

We will go back and forth between the code and the paper. Don't worry about the details hidden in the other files of the project; we will bring up some of them along the way.
Let's start at the beginning:

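import os
from multiprocessing import cpu_count

import torch
# Project-specific imports (callbacks, helpers, datasets, ...) are omitted from this excerpt

def main():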
    # Hyperparameters
    input_img_resize = (572, 572)  # The resize size of the input images of the neural net
    output_img_resize = (388, 388)  # The resize size of the output images of the neural net
    batch_size = 3
    epochs = 50
    threshold = 0.5
    validation_size = 0.2
    sample_size = None

    # -- Optional parameters
    threads = cpu_count()
    use_cuda = torch.cuda.is_available()
    script_dir = os.path.dirname(os.path.abspath(__file__))
    # Training callbacks
    tb_viz_cb = TensorboardVisualizerCallback(os.path.join(script_dir, '../logs/tb_viz'))
    tb_logs_cb = TensorboardLoggerCallback(os.path.join(script_dir, '../logs/tb_logs'))
    model_saver_cb = ModelSaverCallback(os.path.join(script_dir, '../output/models/model_' + helpers.get_model_timestamp()), verbose=True)

The first section is where you define your hyperparameters; you can adjust them as you like, for instance depending on your GPU memory. The Optional parameters section defines some useful parameters and callbacks. TensorboardVisualizerCallback saves the predictions to TensorBoard at each epoch of the training step, and TensorboardLoggerCallback saves the losses and pixel-wise "accuracies" to TensorBoard. Finally, ModelSaverCallback saves your model once the training step is finished.

# Download the datasets
ds_fetcher = DatasetFetcher()

This section automatically downloads and extracts the dataset from Kaggle. Note that for this command to succeed you need a Kaggle account, whose login and password you put in the KAGGLE_USER and KAGGLE_PASSWD environment variables before running the script. You also need to have accepted the rules of the competition, by clicking on the download button of one of its datasets on the competition page. If you prefer to download the data manually, put it in the project's input folder and it won't be downloaded from Kaggle.
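Since the download fails if the credentials are missing, it can help to check them up front. A minimal sketch, assuming only the two variable names from the text; the helper name is hypothetical and the project's DatasetFetcher may read the variables differently.

```python
import os


def get_kaggle_credentials():
    """Return (user, password) from the environment, failing early if missing.

    Hypothetical helper: only the names KAGGLE_USER and KAGGLE_PASSWD come
    from the post; DatasetFetcher may handle them in its own way.
    """
    user = os.environ.get("KAGGLE_USER")
    password = os.environ.get("KAGGLE_PASSWD")
    missing = [name for name, value in
               (("KAGGLE_USER", user), ("KAGGLE_PASSWD", password)) if not value]
    if missing:
        raise RuntimeError("Missing environment variable(s): " + ", ".join(missing))
    return user, password
```

In a shell, the variables would typically be set with `export KAGGLE_USER=...` and `export KAGGLE_PASSWD=...` before launching the script.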

# Get the path to the files for the neural net
X_train, y_train, X_valid, y_valid = ds_fetcher.get_train_files(sample_size=sample_size, validation_size=validation_size)
full_x_test = ds_fetcher.get_test_files(sample_size)

Here we just split the training set into train/validation and retrieve the test set.
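The key point of such a split is that each image must stay paired with its mask. A minimal sketch of a train/validation split on file paths, assuming plain lists of paths and a fixed seed; the function name is hypothetical and get_train_files may split differently.

```python
import random


def split_train_valid(image_paths, mask_paths, validation_size=0.2, seed=42):
    """Shuffle (image, mask) pairs together and hold out a validation fraction.

    Hypothetical helper, not the project's get_train_files.
    """
    pairs = list(zip(image_paths, mask_paths))  # keep each image with its mask
    random.Random(seed).shuffle(pairs)
    n_valid = int(len(pairs) * validation_size)
    valid, train = pairs[:n_valid], pairs[n_valid:]
    X_train, y_train = [p[0] for p in train], [p[1] for p in train]
    X_valid, y_valid = [p[0] for p in valid], [p[1] for p in valid]
    return X_train, y_train, X_valid, y_valid


# Usage with made-up file names: 10 pairs -> 8 for training, 2 for validation
images = [f"img_{i}.jpg" for i in range(10)]
masks = [f"img_{i}_mask.png" for i in range(10)]
X_train, y_train, X_valid, y_valid = split_train_valid(images, masks)
```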

# Testing callbacks
# origin_img_size: the original (un-resized) image size, defined elsewhere in the full script
pred_saver_cb = PredictionsSaverCallback(os.path.join(script_dir, '../output/submit.csv.gz'), origin_img_size, threshold)

This line is a callback for the test (or predict) pass. It stores the predictions in a gzip file each time a new batch of predictions is made, so the predictions, which are very large, are never all kept in memory.
You can submit the resulting submit.csv.gz from the output folder to Kaggle once the predictions are finished.
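The idea of streaming predictions to disk can be sketched with the standard library's gzip and csv modules. This is not the project's PredictionsSaverCallback: the class name, the column names and the run-length encoding below are assumptions for illustration, so check the competition's submission instructions for the exact format expected.

```python
import csv
import gzip

import numpy as np


def rle_encode(mask):
    """Run-length encode a binary mask as 'start length' pairs (1-indexed starts)."""
    pixels = np.concatenate([[0], mask.flatten(), [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1  # positions where the value changes
    runs[1::2] -= runs[::2]                              # turn run ends into run lengths
    return " ".join(str(x) for x in runs)


class GzipPredictionsSaver:
    """Hypothetical saver: appends each batch of thresholded masks to a gzip CSV."""

    def __init__(self, path, threshold=0.5):
        self.threshold = threshold
        self.file = gzip.open(path, "wt", newline="")
        self.writer = csv.writer(self.file)
        self.writer.writerow(["img", "rle_mask"])

    def add_batch(self, names, probs):
        # probs: array of shape (N, H, W) holding per-pixel probabilities
        for name, prob in zip(names, probs):
            self.writer.writerow([name, rle_encode(prob > self.threshold)])

    def close(self):
        self.file.close()  # flushes the compressed stream to disk
```

Only one batch of rows is ever held in memory; everything else already sits compressed on disk.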

# -- Define our neural net architecture
# The original paper has 1 input channel, in our case we have 3 (RGB)
net = unet_origin.UNetOriginal((3, *input_img_resize))
classifier = nn.classifier.CarvanaClassifier(net, epochs)
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.99)

train_ds = TrainImageDataset(X_train, y_train, input_img_resize, output_img_resize, X_transform=aug.augment_img)
train_loader = DataLoader(train_ds, batch_size,
                          shuffle=True,         # assumed: shuffle the training set each epoch
                          num_workers=threads,  # assumed: use the threads defined above
                          pin_memory=use_cuda)

valid_ds = TrainImageDataset(X_valid, y_valid, input_img_resize, output_img_resize, threshold=threshold)
valid_loader = DataLoader(valid_ds, batch_size,
                          shuffle=False,        # assumed: no need to shuffle for validation
                          num_workers=threads,
                          pin_memory=use_cuda)

Here we define our neural network (net) and our optimizer (optimizer, more on that later), then create a data loader for both the training and the validation sets, which loads the data in batches.
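TrainImageDataset receives two sizes because the original Unet's output (388x388) is smaller than its input (572x572), so the mask must be resized to the output size. A minimal sketch of that idea as a PyTorch Dataset, assuming images and masks are already loaded as tensors (the real class reads files from disk) and using bilinear resizing for images and nearest-neighbour resizing for masks; the class name is hypothetical.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


class ResizedPairDataset(Dataset):
    """Hypothetical stand-in for TrainImageDataset: (image, mask) pairs resized separately."""

    def __init__(self, images, masks, input_size, output_size):
        self.images = images  # list of (3, H, W) float tensors
        self.masks = masks    # list of (H, W) float tensors with 0/1 values
        self.input_size = input_size
        self.output_size = output_size

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # interpolate expects a batch dimension, hence unsqueeze(0) and [0]
        image = F.interpolate(self.images[idx].unsqueeze(0), size=self.input_size,
                              mode="bilinear", align_corners=False)[0]
        # nearest-neighbour resizing keeps the mask strictly binary
        mask = F.interpolate(self.masks[idx][None, None], size=self.output_size,
                             mode="nearest")[0, 0]
        return image, mask


images = [torch.rand(3, 128, 192) for _ in range(4)]
masks = [(torch.rand(128, 192) > 0.5).float() for _ in range(4)]
ds = ResizedPairDataset(images, masks, (572, 572), (388, 388))
loader = DataLoader(ds, batch_size=3, shuffle=True)
x, y = next(iter(loader))  # x: (3, 3, 572, 572), y: (3, 388, 388)
```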

print("Training on {} samples and validating on {} samples "
          .format(len(train_loader.dataset), len(valid_loader.dataset)))

# Train the classifier
classifier.train(train_loader, valid_loader, epochs, callbacks=[tb_viz_cb, tb_logs_cb, model_saver_cb])

Now we launch the training by passing in the train/valid data loaders and the training callbacks we defined. We will go into some details of the implementation of this method later.

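test_ds = TestImageDataset(full_x_test, input_img_resize)
test_loader = DataLoader(test_ds, batch_size,
                         shuffle=False,        # assumed: keep the test order for the submission
                         num_workers=threads,
                         pin_memory=use_cuda)

# Predict & save
classifier.predict(test_loader, callbacks=[pred_saver_cb])
pred_saver_cb.close_saver()  # flush and close the predictions file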

Finally, we do the same for the prediction pass, then call pred_saver_cb.close_saver() to flush and close the file containing the predictions.
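The prediction pass differs from training in two ways: the model is switched to evaluation mode (which matters for BatchNorm) and gradients are disabled. A minimal sketch, assuming a model that outputs one logit per pixel and a loader that yields image batches only; the generator below is hypothetical and is not classifier.predict.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


@torch.no_grad()  # disables gradient tracking during inference
def predict_masks(model, loader, threshold=0.5):
    """Yield one batch of binary masks at a time, so predictions never pile up in memory."""
    model.eval()  # BatchNorm uses its running statistics
    for images in loader:
        probs = torch.sigmoid(model(images).squeeze(1))
        yield (probs > threshold).cpu().numpy()


model = nn.Sequential(nn.Conv2d(3, 1, kernel_size=3, padding=1), nn.BatchNorm2d(1))
test_images = torch.rand(5, 3, 32, 32)
batches = list(predict_masks(model, DataLoader(test_images, batch_size=2)))
```

In a real run, each yielded batch would be handed to a saver such as the gzip callback instead of being collected in a list.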

Designing the neural net

The Unet paper presents itself as a way to do image segmentation for biomedical data. It turns out you can use it for various image segmentation problems, such as the one we will work on.

Before going forward, you should read the paper entirely at least once. Don't worry if you don't follow the mathematical formulas; you can skip them as well as the "Experiments" chapter, the idea being for you to get the big picture.

Since the original paper tackles a different problem, we'll need to adapt some parts to our needs.
At the time the paper was written, two things that now make training much faster were missing:

  1. BatchNorm
  2. Powerful GPUs

The first one was published only about 3 months before Unet, which was probably too early for the Unet authors to include it in their paper.
Today BatchNorm is used pretty much everywhere. You can remove it from the code if you want to follow the paper exactly, but convergence will take ages.

As for the GPUs, the paper states:

To minimize the overhead and make maximum use of the GPU memory, we favor large input tiles over a large batch size and hence reduce the batch to a single image

They were using a GPU with 6 GB of VRAM, but nowadays GPUs have more memory, so we can fit more images into a single batch. The current batch size of 3 works on a GPU with at least 8 GB of VRAM. If you don't have such a GPU, try lowering the batch size to 2 or 1.
As for data augmentation, we will use our own methods rather than the paper's, since our images are very different from biomedical images.
Now let's start at the beginning: designing the neural network architecture.


This is what a Unet looks like. You can find the equivalent PyTorch implementation in the module imported as unet_origin.
All the classes in this file have at least two methods:

  • __init__() where we will initialize our neural network layers
  • forward() which is the method called when the neural network is receiving an input

Lets go into the details of the implementation:

  • ConvBnRelu is a shortcut containing a Conv2D, a BatchNorm and a ReLU operation. Instead of writing these 3 operations out for each encoder stack (the group of operations going down) and each decoder stack (the group of operations going up), we group them into this object and reuse it as needed.
  • StackEncoder encapsulates an entire "stack" of down operations, including the ConvBnRelu operations and a MaxPool, as illustrated below:


We keep track of the output of the last ConvBnRelu operation in x_trace and return it because we will concatenate this output with the decoder stacks.

  • StackDecoder is the same as StackEncoder, but for the decoding operations, outlined in red below:


Notice that it also handles the cropping/concatenation operation (outlined in orange) by taking a down_tensor, which is nothing other than the x_trace tensor returned by the corresponding StackEncoder.

  • UNetOriginal is where the magic happens. It is our neural network, which assembles all the little bricks presented above. The __init__ and forward methods are really straightforward: they add a series of StackEncoders, a center piece and finally a few StackDecoders. We then take the output of the last StackDecoder and add a 1x1 convolution to it, as in the paper, but instead of defining 2 filters as output, we only define 1, which will be our grayscale mask prediction. Finally, we "squeeze" the output to remove the channel dimension (there is only 1 channel, so we don't need to keep it).
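A minimal sketch of the three building blocks, assuming the original paper's unpadded 3x3 convolutions, 2x2 max pooling and 2x2 transposed convolutions for upsampling, plus BatchNorm as discussed above. The class names mirror the post, but this is an illustration rather than the project's code: its constructor arguments and upsampling method may differ.

```python
import torch
import torch.nn as nn


class ConvBnRelu(nn.Module):
    """3x3 unpadded convolution -> BatchNorm -> ReLU (shrinks H and W by 2)."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=0),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class StackEncoder(nn.Module):
    """Two ConvBnRelu then 2x2 max pooling; also returns the pre-pooling tensor."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.convs = nn.Sequential(ConvBnRelu(in_ch, out_ch), ConvBnRelu(out_ch, out_ch))
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x_trace = self.convs(x)  # kept for the skip connection
        return self.pool(x_trace), x_trace


class StackDecoder(nn.Module):
    """Upsample, crop and concatenate the encoder's x_trace, then two ConvBnRelu."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.convs = nn.Sequential(ConvBnRelu(2 * out_ch, out_ch), ConvBnRelu(out_ch, out_ch))

    def forward(self, x, down_tensor):
        x = self.up(x)
        # center-crop the (larger) encoder output to the upsampled size
        dh = (down_tensor.size(2) - x.size(2)) // 2
        dw = (down_tensor.size(3) - x.size(3)) // 2
        down = down_tensor[:, :, dh:dh + x.size(2), dw:dw + x.size(3)]
        return self.convs(torch.cat([down, x], dim=1))


# One encoder level, a center piece and one decoder level on a 60x60 input
enc, dec = StackEncoder(3, 8), StackDecoder(16, 8)
center = nn.Sequential(ConvBnRelu(8, 16), ConvBnRelu(16, 16))
pooled, trace = enc(torch.rand(1, 3, 60, 60))  # trace: 56x56, pooled: 28x28
out = dec(center(pooled), trace)               # center: 24x24 -> up: 48x48 -> out: 44x44
```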

If you want to go into the details of each forward pass, put a debugging breakpoint in the forward method of each class to inspect the objects in detail. You can also print the shape of your output tensors between the layers with print(x.size()).
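The shapes you would print can also be derived by hand, which explains why input_img_resize and output_img_resize differ. A small sketch that replays the original paper's size arithmetic, assuming 4 encoder stacks, a center piece and 4 decoder stacks, where each unpadded 3x3 convolution removes 2 pixels, 2x2 pooling halves the size and each 2x2 up-convolution doubles it. The function name is hypothetical.

```python
def unet_output_size(size, depth=4):
    """Spatial size after the original (unpadded) Unet; raises if a pooling step gets an odd size."""
    for _ in range(depth):  # encoder: two 3x3 convs then 2x2 max pool
        size -= 4
        if size % 2:
            raise ValueError("input size not compatible with 2x2 pooling")
        size //= 2
    size -= 4               # center piece: two 3x3 convs
    for _ in range(depth):  # decoder: 2x2 up-conv then two 3x3 convs
        size = size * 2 - 4
    return size


print(unet_output_size(572))  # 388
```

A 572x572 input goes 568, 284, 280, 140, 136, 68, 64, 32 down the encoder, 28 in the center, then 52, 100, 196 and finally 388 up the decoder.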

Training the neural net

1. The loss function

Now onto the real stuff. According to the paper:

The energy function is computed by a pixel-wise soft-max over the final feature map combined with the cross entropy loss function.

Thing is, in our case we want to use the dice coefficient instead of what they call the "energy function", as it is the metric used in the Kaggle competition. It is defined by: $$\frac{2 |X \cap Y|}{|X| + |Y|}$$ Since it is a score to maximize, the loss we minimize is 1 minus the dice coefficient.

With \(X\) being our prediction matrix and \(Y\) our target matrix. \(|X|\) stands for the cardinality of the set \(X\) (the number of elements in this set) and \(\cap\) for the intersection between \(X\) and \(Y\).
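To make the formula concrete, here is a tiny worked example on two made-up binary 3x3 masks, computing the dice coefficient straight from the definition: the intersection counts the pixels set in both masks, and |X|, |Y| count the pixels set in each.

```python
import numpy as np

X = np.array([[1, 1, 0],
              [0, 1, 0],
              [0, 0, 0]])  # prediction: 3 pixels set
Y = np.array([[1, 0, 0],
              [0, 1, 1],
              [0, 0, 0]])  # target: 3 pixels set

intersection = np.logical_and(X, Y).sum()  # pixels set in both: (0, 0) and (1, 1)
dice = 2 * intersection / (X.sum() + Y.sum())
print(dice)  # 2 * 2 / (3 + 3), about 0.667
```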

The code for the dice loss can be found in nn.losses.SoftDiceLoss.

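import torch
import torch.nn as nn


class SoftDiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(SoftDiceLoss, self).__init__()

    def forward(self, logits, targets):
        smooth = 1
        num = targets.size(0)
        probs = torch.sigmoid(logits)  # F.sigmoid is deprecated
        m1 = probs.view(num, -1)       # flatten each prediction to one row
        m2 = targets.view(num, -1)     # flatten each target mask the same way
        intersection = (m1 * m2)

        # smooth is added once to the numerator so that a perfect match scores exactly 1
        score = (2. * intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
        score = 1 - score.sum() / num  # average dice over the batch, turned into a loss
        return score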

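The reason why the intersection is implemented as a multiplication and the cardinality as a sum() over axis 1 (all the pixels of each image, once flattened with view(num, -1)) is that the targets are binary masks (each pixel is 0 or 1) and the predictions are probabilities between 0 and 1.
For example, let's say the prediction on pixel (0, 0) is 0.567 and the target is 1: we get 0.567 * 1 = 0.567. If the target were 0, we would get 0 at that pixel position.
We also add a smoothing factor of 1 so the score stays defined when both the prediction and the target are empty. The loss is "soft" because it uses probabilities rather than predictions hard-thresholded to 0 and 1: a hard threshold has a zero gradient almost everywhere, so the dice loss could not be backpropagated through it.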

We then combine this dice loss with the cross-entropy to get our total loss function, which you can find in the _criterion method of nn.classifier.CarvanaClassifier. According to the paper, they also use a weight map in the cross-entropy loss function to give some pixels more importance during training. In our case we don't need such a thing, so we just use the cross-entropy without any weight map.
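A sketch of how such a combined criterion could look. It assumes the network outputs one logit per pixel (as UNetOriginal does), so the cross-entropy becomes binary cross-entropy via BCEWithLogitsLoss, and it assumes an equal-weight sum of the two terms; the class name is hypothetical and the project's _criterion may weight or compute them differently. The dice term uses sigmoid probabilities with the smoothing placed as (2|X∩Y| + smooth) / (|X| + |Y| + smooth).

```python
import torch
import torch.nn as nn


class BceDiceLoss(nn.Module):
    """Hypothetical combined loss: binary cross-entropy + (1 - soft dice)."""

    def __init__(self, smooth=1.0):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.smooth = smooth

    def forward(self, logits, targets):
        num = targets.size(0)
        probs = torch.sigmoid(logits).reshape(num, -1)  # one row of probabilities per image
        flat_targets = targets.reshape(num, -1)
        intersection = (probs * flat_targets).sum(1)
        dice = (2.0 * intersection + self.smooth) / (
            probs.sum(1) + flat_targets.sum(1) + self.smooth)
        return self.bce(logits, targets) + (1.0 - dice.mean())


criterion = BceDiceLoss()
targets = (torch.rand(2, 16, 16) > 0.5).float()
good = criterion(targets * 20 - 10, targets)  # confident, correct logits: loss near 0
bad = criterion(10 - targets * 20, targets)   # confident, wrong logits: large loss
```

Cross-entropy gives a smooth per-pixel signal, while the dice term directly targets the competition metric, which is a common reason to sum the two.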

2. The optimizer

Here we'll follow the paper by using the SGD optimizer with a momentum of 0.99. You can find the optimizer in the main method:

optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.99)

That's all we need to do for the optimizer.

3. Augmentations

As we are not dealing with biomedical images, we'll use our own augmentations. You can find the code in img.augmentation.augment_img, which applies random shifting, rotation, flipping and scaling.
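Whatever the augmentation, a geometric transform has to be applied identically to the image and its mask, otherwise the mask no longer lines up with the car. A minimal NumPy sketch of that idea with a random horizontal flip and a random shift; the function name, its parameters and the shift range are assumptions, and the project's augment_img (which also rotates and scales) may be implemented quite differently.

```python
import numpy as np


def augment_pair(image, mask, rng, max_shift=0.1):
    """Apply the same random flip and shift to an (H, W, 3) image and its (H, W) mask."""
    if rng.random() < 0.5:  # horizontal flip, 50% of the time
        image, mask = image[:, ::-1], mask[:, ::-1]
    h, w = mask.shape
    dy = int(rng.integers(-int(max_shift * h), int(max_shift * h) + 1))
    dx = int(rng.integers(-int(max_shift * w), int(max_shift * w) + 1))
    # np.roll wraps pixels around the border; real pipelines usually pad instead
    image = np.roll(image, (dy, dx), axis=(0, 1))
    mask = np.roll(mask, (dy, dx), axis=(0, 1))
    return image, mask


rng = np.random.default_rng(0)
image = np.zeros((50, 80, 3))
image[10:20, 30:40] = 1.0
mask = np.zeros((50, 80))
mask[10:20, 30:40] = 1.0
aug_img, aug_mask = augment_pair(image, mask, rng)
```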

Launching the training

Now we can launch the training. As the epochs go by, you'll be able to visualize how well your model segments the validation set.
To do so, run TensorBoard from the folder that contains the logs folder:

tensorboard --logdir=./logs

Here is a preview of what you can see on Tensorboard at epoch 1:

and at epoch 50:

After training for 50 epochs we get a pixel-wise accuracy of about 95-96%. That's much better than after the first epoch, but it's still imperfect, and we can't rely on it to automate an image segmentation task done by humans.

Of course, I promised you a pixel-wise accuracy over 99%, and we didn't get much past 95% here. Let's see how to close that gap.

What we will do now is use a custom Unet that you can find in nn.unet.UNet1024. I won't go into the details of this architecture, as it is pretty similar to our original Unet with some modifications. All you need to do to use it is make the following changes to the main script.

Change:

input_img_resize = (572, 572)
output_img_resize = (388, 388)

to:

input_img_resize = (1024, 1024)
output_img_resize = (1024, 1024)

and:

net = unet_origin.UNetOriginal((3, *input_img_resize))

to:

net = unet_custom.UNet1024((3, *input_img_resize))

You can also change the optimizer from SGD to RMSprop by replacing:

optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.99)

with:

optimizer = optim.RMSprop(net.parameters(), lr=0.0002)

Then launch the training again for 50 epochs with a batch size of 2 (or 1 if you don't have enough VRAM).
This new architecture takes higher-resolution inputs, which lets the Unet learn from finer details of the images.
For instance, here is what epoch 24 looks like after these changes:

At a 0.995 pixel-wise accuracy it looks much better, doesn't it? Now you can try to submit the resulting submit.csv.gz on the competition page to see how you score on the leaderboard. Of course, you can try to tweak the model yourself to reach even better performance (and less variance) by changing the optimizer, the number of epochs or the architecture. You can find the winners' solutions on the competition's discussion page.


In this tutorial we saw how to create a Unet for image segmentation. Although there are plenty of other methods for this, Unet is very powerful for this kind of task. We'll probably explore more techniques for image segmentation in the future, stay tuned!
