# Semantic Segmentation with tf.data in TensorFlow 2 and ADE20K dataset

In this article, I cover how to use TensorFlow 2 and tf.data with a popular 2D semantic segmentation dataset: ADE20K. The code can now be run on Google Colab.

Update 20/04/26: Fixed a bug in the Google Colab version (thanks to Agapetos!) and added a few external links.
Update 20/04/25: Updated the whole article to make the code easier to run. Also added a link to run it on Google Colab.

## 1. Introduction

In this notebook, we are going to cover the usage of TensorFlow 2 and tf.data on a popular 2D semantic segmentation dataset: ADE20K.

The data we are going to manipulate consists of:

• a JPG image with 3 channels (RGB)
• a PNG mask with 1 channel (for each pixel, 1 true class out of 150 possible)

You can also find all the information by reading the official tensorflow tutorials:

This notebook assumes that you have already downloaded ADE20K and extracted the archive into ./data/ADEChallengeData2016/, where ./ is the directory in which this notebook is running.
Also, if you run this notebook, a GPU is almost mandatory since the computations take A LOT of time on a CPU.

## 2. Preparing the Environment

from glob import glob

import IPython.display as display
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import datetime, os
from tensorflow.keras.layers import *
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from IPython.display import clear_output

# https://www.tensorflow.org/guide/data_performance#prefetching
AUTOTUNE = tf.data.experimental.AUTOTUNE
print(f"Tensorflow ver. {tf.__version__}")

Tensorflow ver. 2.1.0

# important for reproducibility
# this allows to generate the same random numbers
SEED = 42

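# you can change this path to reflect your own settings if necessary
# (assumes the archive layout described above, with the JPG images under images/)
dataset_path = "data/ADEChallengeData2016/images/"
training_data = "training/"
val_data = "validation/"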


By default, TensorFlow reserves nearly all of the available GPU memory. This allows contiguous memory allocation, which is potentially faster. You can deactivate this default behavior. Personally, I find it useful to be able to check how much memory is actually used for a given batch size.

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        print(e)

1 Physical GPUs, 1 Logical GPUs


# Image size that we are going to use
IMG_SIZE = 128
# Our images are RGB (3 channels)
N_CHANNELS = 3
# Scene Parsing has 150 classes + not labeled
N_CLASSES = 151


### 3.1. Creating a source dataset

TRAINSET_SIZE = len(glob(dataset_path + training_data + "*.jpg"))
print(f"The Training Dataset contains {TRAINSET_SIZE} images.")

VALSET_SIZE = len(glob(dataset_path + val_data + "*.jpg"))
print(f"The Validation Dataset contains {VALSET_SIZE} images.")

The Training Dataset contains 20210 images.
The Validation Dataset contains 2000 images.


For each image of our dataset, we will apply some operations wrapped in a function. Then we will map this function over the whole dataset.

So let’s write this function:

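def parse_image(img_path: str) -> dict:
    """Load an image and its annotation (mask) and return
    a dictionary.

    Parameters
    ----------
    img_path : str
        Location of the image (not the mask).

    Returns
    -------
    dict
        Dictionary mapping an image and its annotation.
    """
    image = tf.io.read_file(img_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.convert_image_dtype(image, tf.uint8)

    # For one image path under .../images/..., the corresponding
    # annotation is assumed to live under .../annotations/... as a PNG
    mask_path = tf.strings.regex_replace(img_path, "images", "annotations")
    mask_path = tf.strings.regex_replace(mask_path, "jpg", "png")
    mask = tf.io.read_file(mask_path)
    # The masks contain a class index for each pixel
    mask = tf.image.decode_png(mask, channels=1)
    # In scene parsing, "not labeled" = 255
    # But 255 is not a valid class index for our model,
    # so we map it to 0
    # Note that the new value (0) must have
    # the same dtype as the tensor itself
    mask = tf.where(mask == 255, tf.zeros_like(mask), mask)

    # The dictionary keys are names chosen for this reconstruction
    return {'image': image, 'segmentation_mask': mask}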
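The path convention described in the comments above can be checked without TensorFlow. The following is a minimal plain-Python sketch, assuming the ADE20K layout in which annotations mirror the `images/` tree under `annotations/` and use the `.png` extension. `annotation_path_for` is a hypothetical helper name used for illustration only.

```python
import re


def annotation_path_for(img_path: str) -> str:
    """Map an image path to its assumed annotation path."""
    # Replace only the "images/" directory component and the trailing
    # ".jpg" extension, so other occurrences of these substrings
    # elsewhere in the path are left untouched.
    mask_path = re.sub(r"(^|/)images/", r"\1annotations/", img_path, count=1)
    return re.sub(r"\.jpg$", ".png", mask_path)


example = "data/ADEChallengeData2016/images/training/ADE_train_00000001.jpg"
print(annotation_path_for(example))
```

Anchoring the replacement on the directory separator and the file extension is slightly stricter than a plain substring replacement, which would also rewrite a parent folder that happens to contain "jpg" or "images" in its name.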



If you would like to load any other image format, you should modify the parse_image function. TensorFlow I/O provides additional tools that might help you. For example, to load TIFF images, you can use:

import tensorflow as tf
import tensorflow_io as tfio
...
def parse_image(img_path: str) -> dict:
    ...
    image = tfio.experimental.image.decode_tiff(image)
    ...


In this case, don’t forget to modify the number of channels when implementing the model later.

train_dataset = tf.data.Dataset.list_files(dataset_path + training_data + "*.jpg", seed=SEED)
train_dataset = train_dataset.map(parse_image)

val_dataset = tf.data.Dataset.list_files(dataset_path + val_data + "*.jpg", seed=SEED)
val_dataset = val_dataset.map(parse_image)


### 3.2. Applying some transformations to our dataset

# Here we are using the decorator @tf.function
# if you want to know more about it:
# https://www.tensorflow.org/api_docs/python/tf/function

@tf.function
def normalize(input_image: tf.Tensor, input_mask: tf.Tensor) -> tuple:
    """Rescale the pixel values of the images between 0.0 and 1.0
    compared to [0,255] originally.

    Parameters
    ----------
    input_image : tf.Tensor
        Tensorflow tensor containing an image of size [SIZE,SIZE,3].
    input_mask : tf.Tensor
        Tensorflow tensor containing an annotation of size [SIZE,SIZE,1].

    Returns
    -------
    tuple
        Normalized image and its annotation.
    """
    input_image = tf.cast(input_image, tf.float32) / 255.0
    # The mask holds class indices, so it is returned unchanged
    return input_image, input_mask

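@tf.function
def load_image_train(datapoint: dict) -> tuple:
    """Apply some transformations to an input dictionary
    containing a train image and its annotation.

    Notes
    -----
    An annotation is a regular single-channel image.
    If a transformation such as rotation is applied to the image,
    the same transformation has to be applied on the annotation also.

    Parameters
    ----------
    datapoint : dict
        A dict containing an image and its annotation.

    Returns
    -------
    tuple
        A modified image and its annotation.
    """
    # The function name and the 'segmentation_mask' key are assumed names
    input_image = tf.image.resize(datapoint['image'], (IMG_SIZE, IMG_SIZE))
    # NOTE: the default (bilinear) resize returns float32 and can blend
    # neighbouring class indices; method='nearest' would avoid that
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (IMG_SIZE, IMG_SIZE))

    # The same random decision drives the flip of the image and of the mask
    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        input_mask = tf.image.flip_left_right(input_mask)

    input_image, input_mask = normalize(input_image, input_mask)

    return input_image, input_mask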
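The note in the docstring says that any geometric transformation applied to the image must also be applied to its annotation. The sketch below illustrates that idea in plain Python on toy nested lists. It is not the TensorFlow code of this article, and `flip_left_right` and `random_flip_pair` are hypothetical helper names.

```python
import random


def flip_left_right(rows):
    """Mirror a 2D grid (a list of rows) horizontally."""
    return [list(reversed(row)) for row in rows]


def random_flip_pair(image, mask, rng):
    """Flip image and mask together, driven by a single random draw."""
    # One draw decides for both, so they can never get out of sync.
    if rng.random() > 0.5:
        return flip_left_right(image), flip_left_right(mask)
    return image, mask


image = [[10, 20, 30], [40, 50, 60]]  # toy single-channel "image"
mask = [[1, 1, 2], [3, 3, 2]]         # toy class indices, same layout

rng = random.Random(42)
flipped_image, flipped_mask = random_flip_pair(image, mask, rng)
print(flipped_image)
print(flipped_mask)
```

Drawing the random number once and reusing the decision is the important design choice: two independent draws would flip the image and the mask separately about half of the time, silently corrupting the labels.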

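@tf.function
def load_image_test(datapoint: dict) -> tuple:
    """Normalize and resize a test image and its annotation.

    Notes
    -----
    Since this is for the test set, we don't need to apply
    any data augmentation technique.

    Parameters
    ----------
    datapoint : dict
        A dict containing an image and its annotation.

    Returns
    -------
    tuple
        A modified image and its annotation.
    """
    # The function name and the 'segmentation_mask' key are assumed names
    input_image = tf.image.resize(datapoint['image'], (IMG_SIZE, IMG_SIZE))
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (IMG_SIZE, IMG_SIZE))

    input_image, input_mask = normalize(input_image, input_mask)

    return input_image, input_mask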
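When an annotation is resized together with its image, the interpolation method matters, because the mask stores class indices rather than intensities. The following plain-Python sketch compares nearest-neighbour and linear interpolation on a single row of labels. It is a simplified 1D model under the assumption of half-pixel sampling centres, not a reimplementation of `tf.image.resize`.

```python
def resize_nearest(labels, new_len):
    """Resize a 1D row of class indices by nearest-neighbour sampling."""
    scale = len(labels) / new_len
    last = len(labels) - 1
    return [labels[min(int((i + 0.5) * scale), last)] for i in range(new_len)]


def resize_linear(labels, new_len):
    """Resize a 1D row by linear interpolation between neighbours."""
    scale = len(labels) / new_len
    last = len(labels) - 1
    out = []
    for i in range(new_len):
        # Position of the output sample in input coordinates, clamped.
        pos = min(max((i + 0.5) * scale - 0.5, 0.0), float(last))
        lo = int(pos)
        hi = min(lo + 1, last)
        frac = pos - lo
        out.append(labels[lo] * (1 - frac) + labels[hi] * frac)
    return out


row = [3, 3, 7, 7]  # two regions: class 3 on the left, class 7 on the right
print(resize_nearest(row, 8))
print(resize_linear(row, 8))
```

Nearest-neighbour sampling only ever returns labels that were present in the input, while linear interpolation produces intermediate values at the boundary between the two regions, and those values correspond to unrelated classes.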



Manipulating datasets in TensorFlow can be complicated. You can read the official documentation to understand how they work: https://www.tensorflow.org/guide/data#training_workflows

BATCH_SIZE = 32

# for reference about the BUFFER_SIZE in shuffle:
# https://stackoverflow.com/questions/46444018/meaning-of-buffer-size-in-dataset-map-dataset-prefetch-and-dataset-shuffle
BUFFER_SIZE = 1000

dataset = {"train": train_dataset, "val": val_dataset}

# load_image_train / load_image_test refer to the train and test
# preprocessing functions of section 3.2 (the names are assumed here)

# -- Train Dataset --#
dataset['train'] = dataset['train'].map(load_image_train, num_parallel_calls=AUTOTUNE)
dataset['train'] = dataset['train'].shuffle(buffer_size=BUFFER_SIZE, seed=SEED)
dataset['train'] = dataset['train'].repeat()
dataset['train'] = dataset['train'].batch(BATCH_SIZE)
dataset['train'] = dataset['train'].prefetch(buffer_size=AUTOTUNE)

#-- Validation Dataset --#
dataset['val'] = dataset['val'].map(load_image_test)
dataset['val'] = dataset['val'].repeat()
dataset['val'] = dataset['val'].batch(BATCH_SIZE)
dataset['val'] = dataset['val'].prefetch(buffer_size=AUTOTUNE)

print(dataset['train'])
print(dataset['val'])

# how shuffle works: https://stackoverflow.com/a/53517848

<PrefetchDataset shapes: ((None, 128, 128, 3), (None, 128, 128, 1)), types: (tf.float32, tf.float32)>
<PrefetchDataset shapes: ((None, 128, 128, 3), (None, 128, 128, 1)), types: (tf.float32, tf.float32)>
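To build some intuition for the `buffer_size` argument of `shuffle`, here is a simplified plain-Python model of a shuffle buffer. It is only a sketch of the documented behaviour (fill a buffer, then repeatedly emit a random element and replace it with the next input), not TensorFlow's implementation. `buffered_shuffle` is a hypothetical name.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in the order produced by a fixed-size shuffle buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        # Buffer is full: emit a random element, put the new item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Source exhausted: drain the remaining elements in random order.
    while buffer:
        yield buffer.pop(rng.randrange(len(buffer)))


shuffled = list(buffered_shuffle(range(100), buffer_size=10, seed=42))
print(shuffled[:10])
```

In this model the first emitted element always comes from the first `buffer_size` inputs, and a buffer of size 1 does not shuffle at all. That is why a buffer much smaller than the dataset only gives a local shuffle.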


## 4. Visualizing the Dataset

Everything seems fine. It can be very hard to build a model when there are bugs in your dataset. This makes the development process painful, since potential bugs in your model add up to potential bugs in your dataloaders. Therefore, it is recommended to make sure that the data is what you expect.
For that, we are going to write a simple function to visualize the content of our dataloaders.

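def display_sample(display_list):
    """Show side-by-side an input image,
    the ground truth and the prediction.
    """
    plt.figure(figsize=(18, 18))

    # one title per displayed element, in the order of the docstring
    title = ['Input Image', 'True Mask', 'Predicted Mask']

    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i+1)
        plt.title(title[i])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
        plt.axis('off')
    plt.show()

# keep one batch around as a sample (sample_mask is an assumed name)
for image, mask in dataset['train'].take(1):
    sample_image, sample_mask = image, mask

display_sample([sample_image[0], sample_mask[0]])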



The dimensions are the ones we expect and it shows up properly. We can start the development of the model itself.

## 5. Developing the Model (UNet) Using Keras Functional API

For this example, we are going to implement a popular architecture: UNet. In a sense, it is not the best choice for a tutorial since this model is very heavy. But I found the exercise interesting, especially because we are going to use the functional API provided by Keras.

This architecture was introduced in the paper U-Net: Convolutional Networks for Biomedical Image Segmentation, which you can read here: https://arxiv.org/abs/1505.04597

You can also read my notes on this paper here: https://yann-leguilly.gitlab.io/post/2019-12-11-unet-biomedical-images/
Basically, we are going to reproduce the architecture described in the paper.

### 5.1. Implementation

# -- Keras Functional API -- #
# -- UNet Implementation -- #
# Everything here is from tensorflow.keras.layers
# I imported tensorflow.keras.layers * to make it easier to read
dropout_rate = 0.5
input_size = (IMG_SIZE, IMG_SIZE, N_CHANNELS)

# If you want to know more about why we are using he_normal:
# https://stats.stackexchange.com/questions/319323/whats-the-difference-between-variance-scaling-initializer-and-xavier-initialize/319849#319849
# Or the excellent fastai course:
# https://github.com/fastai/course-v3/blob/master/nbs/dl2/02b_initializing.ipynb
initializer = 'he_normal'

# -- Encoder -- #
# Block encoder 1
inputs = Input(shape=input_size)
conv_enc_1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer=initializer)(inputs)
conv_enc_1 = Conv2D(64, 3, activation = 'relu', padding='same', kernel_initializer=initializer)(conv_enc_1)

# Block encoder 2
max_pool_enc_2 = MaxPooling2D(pool_size=(2, 2))(conv_enc_1)
conv_enc_2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = initializer)(max_pool_enc_2)
conv_enc_2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = initializer)(conv_enc_2)

# Block  encoder 3
max_pool_enc_3 = MaxPooling2D(pool_size=(2, 2))(conv_enc_2)
conv_enc_3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = initializer)(max_pool_enc_3)
conv_enc_3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = initializer)(conv_enc_3)

# Block  encoder 4
max_pool_enc_4 = MaxPooling2D(pool_size=(2, 2))(conv_enc_3)
conv_enc_4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = initializer)(max_pool_enc_4)
conv_enc_4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = initializer)(conv_enc_4)
# -- Encoder -- #

# ----------- #
maxpool = MaxPooling2D(pool_size=(2, 2))(conv_enc_4)
conv = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = initializer)(maxpool)
conv = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = initializer)(conv)
# ----------- #

# -- Decoder -- #
# Block decoder 1
up_dec_1 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = initializer)(UpSampling2D(size = (2,2))(conv))
merge_dec_1 = concatenate([conv_enc_4, up_dec_1], axis = 3)
conv_dec_1 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = initializer)(merge_dec_1)
conv_dec_1 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = initializer)(conv_dec_1)

# Block decoder 2
up_dec_2 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = initializer)(UpSampling2D(size = (2,2))(conv_dec_1))
merge_dec_2 = concatenate([conv_enc_3, up_dec_2], axis = 3)
conv_dec_2 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = initializer)(merge_dec_2)
conv_dec_2 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = initializer)(conv_dec_2)

# Block decoder 3
up_dec_3 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = initializer)(UpSampling2D(size = (2,2))(conv_dec_2))
merge_dec_3 = concatenate([conv_enc_2, up_dec_3], axis = 3)
conv_dec_3 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = initializer)(merge_dec_3)
conv_dec_3 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = initializer)(conv_dec_3)

# Block decoder 4
up_dec_4 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = initializer)(UpSampling2D(size = (2,2))(conv_dec_3))
merge_dec_4 = concatenate([conv_enc_1, up_dec_4], axis = 3)
conv_dec_4 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = initializer)(merge_dec_4)
conv_dec_4 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = initializer)(conv_dec_4)
conv_dec_4 = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = initializer)(conv_dec_4)
# -- Decoder -- #

output = Conv2D(N_CLASSES, 1, activation = 'softmax')(conv_dec_4)
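To see why the skip connections line up, it can help to trace the spatial size and channel count through the network. The sketch below does that arithmetic in plain Python for the layout used above (four 2x2 poolings, filters doubling from 64). `unet_feature_map_sizes` is a hypothetical helper, and the divisibility check is an assumption that follows from halving the size at each pooling step.

```python
def unet_feature_map_sizes(img_size, depth=4, base_filters=64):
    """Return (stage, spatial_size, channels) for each block of the network."""
    if img_size % (2 ** depth) != 0:
        raise ValueError("img_size should be divisible by 2**depth "
                         "so that encoder and decoder sizes match")
    stages = []
    size, filters = img_size, base_filters
    for level in range(1, depth + 1):
        stages.append((f"encoder {level}", size, filters))
        size //= 2     # effect of MaxPooling2D(pool_size=(2, 2))
        filters *= 2
    stages.append(("bottleneck", size, filters))
    for level in range(1, depth + 1):
        size *= 2      # effect of UpSampling2D(size=(2, 2))
        filters //= 2
        stages.append((f"decoder {level}", size, filters))
    return stages


for name, size, channels in unet_feature_map_sizes(128):
    print(f"{name:<10} {size:>3} x {size:<3} {channels} channels")
```

With IMG_SIZE = 128 the bottleneck works on 8 x 8 feature maps with 1024 channels, and each decoder block returns to the spatial size of the encoder block it is concatenated with.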


Now we can build and compile the model to make sure that there are no bugs.

model = tf.keras.Model(inputs=inputs, outputs=output)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=['accuracy'])
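A useful reference point for the loss values reported during training is the cross-entropy of a model that assigns the same probability to every class. The sketch below computes it in plain Python for a single pixel. It only illustrates the formula (minus the log of the probability given to the true class), not the Keras implementation, and the function name is reused here purely for readability.

```python
import math


def sparse_categorical_crossentropy(probabilities, true_class):
    """Cross-entropy for one pixel: minus log of the true-class probability."""
    return -math.log(probabilities[true_class])


n_classes = 151
uniform = [1.0 / n_classes] * n_classes
baseline = sparse_categorical_crossentropy(uniform, true_class=7)
print(f"loss of a uniform guess over {n_classes} classes: {baseline:.4f}")
```

For 151 classes this baseline is ln(151), roughly 5.02. A training loss that stays around this value suggests that the model has not learned anything beyond guessing uniformly.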


### 5.2. Sanity Check

We can visualize a sample prediction to be sure that we see what we expect.

def create_mask(pred_mask: tf.Tensor) -> tf.Tensor:
    """Return a filter mask with the top 1 predictions
    only.

    Parameters
    ----------
    pred_mask : tf.Tensor
        A [IMG_SIZE, IMG_SIZE, N_CLASS] tensor. For each pixel we have
        N_CLASS values (vector) which represents the probability of the pixel
        being these classes. Example: A pixel with the vector [0.0, 0.0, 1.0]
        has been predicted class 2 with a probability of 100%.

    Returns
    -------
    tf.Tensor
        A [IMG_SIZE, IMG_SIZE, 1] mask with top 1 predictions
        for each pixel.
    """
    # pred_mask -> [IMG_SIZE, IMG_SIZE, N_CLASS]
    # 1 prediction for each class but we want the highest score only
    # so we use argmax
    pred_mask = tf.argmax(pred_mask, axis=-1)
    # pred_mask becomes [IMG_SIZE, IMG_SIZE]
    # but matplotlib needs [IMG_SIZE, IMG_SIZE, 1]
    pred_mask = tf.expand_dims(pred_mask, axis=-1)
    return pred_mask

def show_predictions(dataset=None, num=1):
    """Show a sample prediction.

    Parameters
    ----------
    dataset : tf.data.Dataset, optional
        Input dataset, by default None
    num : int, optional
        Number of samples to show, by default 1
    """
    if dataset is not None:
        for image, mask in dataset.take(num):
            pred_mask = create_mask(model.predict(image))
            display_sample([image[0], mask[0], pred_mask[0]])
    else:
        # The model is expecting a tensor of the size
        # [BATCH_SIZE, IMG_SIZE, IMG_SIZE, 3]
        # but sample_image[0] is [IMG_SIZE, IMG_SIZE, 3]
        # and we want only 1 inference to be faster
        one_img_batch = sample_image[0][tf.newaxis, ...]
        # one_img_batch -> [1, IMG_SIZE, IMG_SIZE, 3]
        inference = model.predict(one_img_batch)
        # inference -> [1, IMG_SIZE, IMG_SIZE, N_CLASS]
        pred_mask = create_mask(inference)
        # pred_mask -> [1, IMG_SIZE, IMG_SIZE, 1]
        display_sample([sample_image[0], sample_mask[0], pred_mask[0]])

for image, mask in dataset['train'].take(1):
    sample_image, sample_mask = image, mask

show_predictions()
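The logic of `create_mask` can be followed on a tiny example without TensorFlow. The sketch below is a plain-Python equivalent working on nested lists, with made-up probabilities for a 2 x 2 image and 3 classes. `create_mask_py` is a hypothetical name used only for this illustration.

```python
def create_mask_py(pixel_probabilities):
    """Turn [H][W][N_CLASS] scores into an [H][W][1] grid of class indices."""
    mask = []
    for row in pixel_probabilities:
        mask_row = []
        for scores in row:
            # argmax: index of the highest score for this pixel
            best_class = max(range(len(scores)), key=scores.__getitem__)
            mask_row.append([best_class])  # keep a trailing channel axis
        mask.append(mask_row)
    return mask


probabilities = [
    [[0.0, 0.0, 1.0], [0.7, 0.2, 0.1]],
    [[0.1, 0.8, 0.1], [0.3, 0.3, 0.4]],
]
print(create_mask_py(probabilities))
```

Each pixel keeps only the index of its most likely class, and the trailing axis of length 1 plays the role of the channel dimension expected by the plotting code.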


The far right image is a prediction with random weights. We are ready to start the training!

## 6. Training our Model

### 6.1. Simpler training loop

Let’s run a simple training loop with only 1 epoch first.

EPOCHS = 1

STEPS_PER_EPOCH = TRAINSET_SIZE // BATCH_SIZE
VALIDATION_STEPS = VALSET_SIZE // BATCH_SIZE

# sometimes it can be very interesting to run some batches on cpu
# because the tracing is way better than on GPU
# you will have more obvious error message
# but in our case, it takes A LOT of time

# On CPU
# with tf.device("/cpu:0"):
#     model_history = model.fit(dataset['train'], epochs=EPOCHS,
#                               steps_per_epoch=STEPS_PER_EPOCH,
#                               validation_steps=VALIDATION_STEPS,
#                               validation_data=dataset['val'])

# On GPU
model_history = model.fit(dataset['train'], epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=dataset['val'])

Train for 631 steps, validate for 62 steps
631/631 [==============================]
- 184s 292ms/step
- loss: 4.5188
- accuracy: 0.0957
- val_loss: 3.9329
- val_accuracy: 0.1289
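The step counts shown in the log follow directly from the integer divisions used for STEPS_PER_EPOCH and VALIDATION_STEPS. The short sketch below redoes that arithmetic with the dataset sizes printed earlier. `steps_and_leftover` is a hypothetical helper.

```python
def steps_and_leftover(n_samples, batch_size):
    """Full batches per epoch, and samples not covered by those batches."""
    steps = n_samples // batch_size
    return steps, n_samples - steps * batch_size


train_steps, train_leftover = steps_and_leftover(20210, 32)
val_steps, val_leftover = steps_and_leftover(2000, 32)
print(train_steps, train_leftover)
print(val_steps, val_leftover)
```

With a batch size of 32, one epoch is 631 training steps and 62 validation steps. The floor division leaves 18 training images and 16 validation images outside of a given epoch's step count.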


Everything looks great! We can start our ‘real’ training.

### 6.2. More Advanced Training Loop

Keras implements what we call ‘callbacks’. We can use them to run custom functions at any step of the training. Here, we are going to show the output of the model next to the original image and the ground truth after each epoch. We are also going to collect some useful metrics with TensorBoard to make sure our training is going well.

class DisplayCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        clear_output(wait=True)
        show_predictions()
        print('\nSample Prediction after epoch {}\n'.format(epoch+1))

EPOCHS = 5

logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)

callbacks = [
    # to show samples after each epoch
    DisplayCallback(),
    # to collect some useful metrics and visualize them in tensorboard
    tensorboard_callback,
    # stop the training if val_loss (the default monitored value)
    # has not improved for 10 epochs
    tf.keras.callbacks.EarlyStopping(patience=10, verbose=1),
    # to save checkpoints
    tf.keras.callbacks.ModelCheckpoint('best_model_unet.h5', verbose=1, save_best_only=True, save_weights_only=True)
]

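model = tf.keras.Model(inputs=inputs, outputs=output)

# here I'm using a new optimizer: https://arxiv.org/abs/1908.03265
# (Rectified Adam; using the TensorFlow Addons implementation and this
# learning rate are assumptions, adjust them to your own setup)
import tensorflow_addons as tfa
optimizer = tfa.optimizers.RectifiedAdam(learning_rate=1e-3)

loss = tf.keras.losses.SparseCategoricalCrossentropy()

model.compile(optimizer=optimizer, loss=loss,
              metrics=['accuracy'])

model_history = model.fit(dataset['train'], epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=dataset['val'],
                          callbacks=callbacks)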


Sample Prediction after epoch 5

Epoch 00005: val_loss improved from 2.76228 to 2.71097, saving model to best_model_unet.h5
631/631 [==============================]
- 189s 299ms/step
- loss: 2.7291
- accuracy: 0.3685
- val_loss: 2.7110
- val_accuracy: 0.3773
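The `patience` argument of the early stopping callback can be understood with a small simulation. The sketch below is a simplified plain-Python model of the rule "stop once the monitored loss has not improved for `patience` consecutive epochs". It is not the Keras implementation, and the loss values are toy numbers, not results from this training.

```python
def early_stopping_epoch(val_losses, patience):
    """Return the 1-based epoch at which training would stop, or None."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


toy_losses = [1.0, 0.8, 0.9, 0.85, 0.81]  # toy values for illustration
print(early_stopping_epoch(toy_losses, patience=2))
print(early_stopping_epoch(toy_losses, patience=10))
```

Note that with EPOCHS = 5 and patience=10, as configured above, the stopping condition cannot be reached within the run. The callback only becomes relevant when training for more epochs than the patience value.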


### 6.3. Monitoring the training

While your model is training, you want to monitor some metrics. TensorBoard is a common tool for this (I will write another article about this topic later). I also find it useful to monitor the GPU(s)/CPU(s). I usually use nvtop to monitor my GPU and htop for the CPU.

## 7. Conclusion

This was a simple example of how to

• Simple training loop using model.fit()