Tutorial 7 - Explainability & Geometric DL - April 2, 2026

This tutorial is divided into two parts:

  1. In the first part, the goal will be to train a network to detect Covid-19 from chest X-rays and to explain its predictions.

  2. In the second part you will look at the equivariance properties of a U-Net for segmentation.

Part 1 - Explain Covid-19 diagnosis

As Covid-19 became a hot topic in healthcare research, a large number of medical images were made publicly available to study the disease, and many machine learning studies using these images were published. This tutorial was inspired by the methodology described in (DeGrave et al., 2021), in which the authors combined several public data sets to train and explain a network that detects Covid-19 from chest X-rays.

More precisely, the data set for this part combines two data sets:

  1. Covid-19 images are from a public data set available on GitHub,

  2. For normal participants, we reused the images of Tutorial 3 on rib segmentation.

All images were resized to the same shape, which means that the original aspect ratio of each image (and thus the apparent shape of each participant's chest) has been modified.

The main goal of this tutorial is to interpret a deep learning network trained to detect Covid-19 from chest X-rays. The network is therefore a classifier that learns to predict the correct diagnosis ("normal" or "covid") associated with an input image.

from os import path, listdir, remove
import monai
import numpy as np
import pandas as pd
import torch

Data set management

Download the data, unzip it, and set the data path:

!wget https://surfdrive.surf.nl/files/index.php/s/nku51iYxnr8ue7c/download -O Tutorial_6.zip
!unzip -qo Tutorial_6.zip
data_path = "Tutorial_6"

In the same way as in Tutorial 3, we will use monai.data.CacheDataset to build our data set from:

  • the list of samples built by build_sample_list (choose the mode to access train, validation or test data),

  • a composition of transforms that will be applied to our image. The first transform that must be applied is LoadChestData, which loads the images from their paths.

The main difference from Tutorial 3 is the structure of our samples. Previously, a sample included an image ('img') and a mask ('mask'). As we are now performing a classification task, our sample (after the application of LoadChestData) will instead contain an image ('img') and a label ('label'). This label is an integer value:

  • 0 corresponds to “normal” diagnosis,

  • 1 corresponds to “covid” diagnosis.

def build_sample_list(data_path, mode="train"):
    """
    This function creates a list containing all the samples of a subset.        
        
    Args:
        data_path (str): path to the root folder of the data set.
        mode (str): subset that must be loaded. Must be chosen in ["train", "val", "test"].
        
    Returns:
        (List[Dict[str, str]]) list of all samples of the data set. 
        One sample is a dictionary with the following keys:
            - img_path (str): path to the image file.
            - idx (str): unique index allowing to identify an individual image (associated to the diagnosis).
            - diagnosis (str): value of the diagnosis, "covid" or "normal".
    """
    
    possible_modes = ["train", "val", "test"]
    
    if mode not in possible_modes:
        raise ValueError(f"Please choose a mode in {possible_modes}.\n"
                         f"Current mode is {mode}.")
    
    data_path = path.join(data_path, mode)
    file_name_list = [file_name for file_name in listdir(data_path) if not file_name.startswith(".")]
    sample_dict_list = list()
    
    for file_name in file_name_list:
        keys_str = path.splitext(file_name)[0]
        keys = {
            pair_str.split("-")[0]: pair_str.split("-")[1] 
            for pair_str in keys_str.split("_")
        }
        keys["img_path"] = path.join(data_path, file_name)
        sample_dict_list.append(keys)
    return sample_dict_list


class LoadChestData(monai.transforms.Transform):
    """
    This transform loads the image and computes the label corresponding to a sample computed by `build_sample_list`.
    After the transform, the sample includes three new keys:
        - img (Tensor): corresponds to the chest X-ray image.
        - label (int): is the code corresponding to the diagnosis.
        - img_meta_dict (dict): includes meta-data that may be useful to apply some transforms of Monai.
    """
    def __init__(self):
        self.label_code = {"normal": 0, "covid": 1}

    def __call__(self, sample):
        from PIL import Image
        
        image = Image.open(sample["img_path"]).convert('L') # import as grayscale image
        image = np.array(image, dtype=np.uint8)
        label = self.label_code[sample["diagnosis"]]
        sample.update({
            "img": torch.from_numpy(image).unsqueeze(0).float(), 
            "label": label,
            "img_meta_dict": {"affine": np.eye(2)},
            
        })
        return sample
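build_sample_list depends on a file-naming convention. The file name, without its extension, is a sequence of `key-value` pairs separated by underscores. The snippet below isolates that parsing step so you can see which keys end up in each sample. The file name in it is a hypothetical example built to match the convention, and the real names in Tutorial_6 may differ in detail.

```python
from os import path

def parse_sample_name(file_name):
    """Parse 'key1-value1_key2-value2.ext' into a dictionary, as build_sample_list does."""
    keys_str = path.splitext(file_name)[0]  # drop the extension
    return {
        pair_str.split("-")[0]: pair_str.split("-")[1]
        for pair_str in keys_str.split("_")
    }

# Hypothetical file name following the assumed convention
print(parse_sample_name("idx-0042_diagnosis-covid.png"))  # {'idx': '0042', 'diagnosis': 'covid'}
```

A value that itself contains a `-` or `_` would be split incorrectly. The convention therefore only works for simple keys and values.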

Exercise

Use build_sample_list and LoadChestData to build the training set.

Exercise

Describe your data. Are the classes balanced? Are there differences between train, validation and test data? Add any relevant information to your answer.

You can use visualize_sample to see the images in a data set.

def visualize_sample(sample):
    """
    Plot the chest X-ray image included in a sample transformed by `LoadChestData`.
    """
    import matplotlib.pyplot as plt
    
    if not isinstance(sample, dict):
        raise ValueError(f"Sample should be a dictionary. Current type is {type(sample)}")
    
    # Visualize the x-ray and describe the sample in title
    image = np.squeeze(sample['img'])
    plt.imshow(image, 'gray')
    plt.title(f"Image #{sample['idx']} associated with {sample['diagnosis']} diagnosis")
    
    plt.show()
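For the class-balance part of the exercise, you only need to count the `diagnosis` key of each sample, so no image has to be loaded. A minimal sketch follows. The helper names are hypothetical, and the toy list only stands in for the real output of build_sample_list. Running the helpers on "train", "val" and "test" lets you compare the subsets.

```python
from collections import Counter

def count_diagnoses(sample_list):
    """Count how many samples of each diagnosis a list from build_sample_list contains."""
    return Counter(sample["diagnosis"] for sample in sample_list)

def class_proportions(sample_list):
    """Return the fraction of samples per diagnosis (the values sum to 1)."""
    counts = count_diagnoses(sample_list)
    total = sum(counts.values())
    return {diagnosis: n / total for diagnosis, n in counts.items()}

# Toy list standing in for build_sample_list(data_path, mode="train")
toy_samples = [{"diagnosis": "covid"}] * 3 + [{"diagnosis": "normal"}] * 1
print(count_diagnoses(toy_samples))    # Counter({'covid': 3, 'normal': 1})
print(class_proportions(toy_samples))  # {'covid': 0.75, 'normal': 0.25}
```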

After this analysis you may want to add more transforms to your CacheDataset to perform data augmentation. Use monai.transforms.Compose to add all the transforms you want!

Exercise

Complete the cell below and add additional transformations to your training data to perform data augmentation.

Resizing: Your teachers already cropped or resized the images, so they all have the same size (512x512). At the end of the tutorial, this information is hard-coded in some functions, so if you change the size of the images now, TERRIBLE THINGS MAY HAPPEN LATER.

You may change the batch size if you want. However, even though a larger batch size speeds up training, beyond some point it may also degrade performance.

train_transforms = monai.transforms.Compose(
    [
        LoadChestData(),
        # Add other transforms here if you want
    ]
)

train_dataset = monai.data.CacheDataset(build_sample_list(data_path, mode="train"), transform=train_transforms)
train_loader = monai.data.DataLoader(train_dataset, batch_size=8, shuffle=True)

validation_dataset = monai.data.CacheDataset(build_sample_list(data_path, mode="val"), transform=LoadChestData())
validation_loader = monai.data.DataLoader(validation_dataset, batch_size=16, shuffle=False)

Train a classifier

It is now time to train a classifier to perform a binary classification task: covid vs. normal images. Check that you are working on a GPU by running the following cell:

  • if the device is “cuda” you are working on a GPU,

  • if the device is “cpu” call a teacher.

import random  # used to pick a random GPU index

# Check whether we're using a GPU
if torch.cuda.is_available():
    n_gpus = torch.cuda.device_count()  # Total number of GPUs
    gpu_idx = random.randint(0, n_gpus - 1)  # Random GPU index
    device = torch.device(f'cuda:{gpu_idx}')
    print('Using GPU: {}'.format(device))
else:
    device = torch.device('cpu')
    print('GPU not found. Using CPU.')

We will again use Weights & Biases to log all our results. We want to log our transforms, but Weights & Biases only logs simple objects, so we again provide from_compose_to_list to write your transforms into your config.

import wandb

wandb.login()

def from_compose_to_list(transform_compose):
    """
    Convert a monai.transforms.Compose object into a list describing each transform and its arguments.
    Warning: the random seed is not saved, so reproducibility is not guaranteed.
    """
    from copy import deepcopy
        
    if not isinstance(transform_compose, monai.transforms.Compose):
        raise TypeError("transform_compose should be a monai.transforms.Compose object.")
    
    output_list = list()
    for transform in transform_compose.transforms:
        kwargs = deepcopy(vars(transform))
        
        # Remove attributes which are not arguments
        args = list(transform.__init__.__code__.co_varnames[1: transform.__init__.__code__.co_argcount])
        for key, obj in vars(transform).items():
            if key not in args:
                del kwargs[key]

        output_list.append({"class": transform.__class__, "kwargs": kwargs})
    return output_list
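from_compose_to_list finds the constructor arguments of each transform by reading the code object of its `__init__`. In a code object, `co_varnames` lists the arguments first, and `co_argcount` counts the positional arguments, including `self`. The function then keeps only the attributes whose names match an argument. The toy class below is hypothetical. It shows the mechanism outside of MONAI, including the fact that internal attributes are dropped.

```python
def init_arguments(obj):
    """Names of the positional arguments of obj.__init__ (excluding self)."""
    code = obj.__init__.__code__
    # co_varnames starts with the arguments, followed by local variables;
    # co_argcount counts positional arguments, including self.
    return list(code.co_varnames[1:code.co_argcount])

class HypotheticalTransform:
    """Toy stand-in for a transform: stores its arguments plus an internal attribute."""
    def __init__(self, prob, spatial_axis=0):
        self.prob = prob
        self.spatial_axis = spatial_axis
        self._cache = {}  # internal state, not a constructor argument

t = HypotheticalTransform(prob=0.5)
print(init_arguments(t))  # ['prob', 'spatial_axis']
kwargs = {key: value for key, value in vars(t).items() if key in init_arguments(t)}
print(kwargs)  # {'prob': 0.5, 'spatial_axis': 0}
```

This approach silently misses keyword-only arguments, which `co_argcount` excludes. It also misses arguments that a transform stores under a different attribute name. The logged config may therefore be incomplete for some transforms.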

In this tutorial, we use the Classifier network of Monai. This network includes a series of convolutional layers and ends with a fully-connected layer computing two values:

  • the first value corresponds to the prediction for “normal”,

  • the second value corresponds to the prediction for “covid”.

The final prediction of the network is the class whose node has the highest value.
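As a small illustration with made-up numbers, the sketch below assumes that the Classifier is configured without a final activation (we believe `last_act=None` is MONAI's default). In that case, its outputs are raw scores (logits), which is also what torch.nn.CrossEntropyLoss expects. Taking the argmax gives the predicted class. A softmax is only needed if you want probabilities, and it never changes the argmax.

```python
import numpy as np

def predict_from_logits(logits):
    """Turn raw classifier outputs of shape (batch, 2) into predictions and probabilities."""
    logits = np.asarray(logits, dtype=float)
    prediction = logits.argmax(axis=1)  # index of the highest node: 0 = normal, 1 = covid
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract the max for numerical stability
    probabilities = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)  # softmax
    return prediction, probabilities

# Two made-up outputs: the first favours "normal", the second "covid"
prediction, probabilities = predict_from_logits([[2.0, -1.0], [0.3, 0.8]])
print(prediction)  # [0 1]
```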

model = monai.networks.nets.Classifier(
    in_shape=train_dataset[0]["img"].shape,
    classes=2,
    channels=[16, 32, 64, 128, 128, 128],
    strides=[2, 2, 2, 2, 2],
    num_res_units=0
).to(device)

As the goal of this tutorial is to interpret a neural network, the training loop is provided ready to use. You can of course try to improve your results by changing the settings!

from tqdm.notebook import tqdm

# Set your parameters here
learning_rate = 1e-4
epochs = 20

# Set the loss function and the optimizer
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

run = wandb.init(
    project='tutorial6_explainability',
    name='covid_detection',
    config={
        'loss function': str(loss_function), 
        'lr': learning_rate,
        'transform': from_compose_to_list(train_transforms),
        'batch_size': train_loader.batch_size,
        'epochs': epochs,
        'n_conv': len(model.net)
    }
)

run_id = run.id # We remember here the run ID to be able to write the evaluation metrics later

for epoch in tqdm(range(epochs)):
    model.train()    
    epoch_loss = 0
    for batch_data in train_loader: 
        optimizer.zero_grad()
        outputs = model(batch_data["img"].to(device))
        loss = loss_function(outputs, batch_data["label"].to(device))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    train_loss = epoch_loss / len(train_loader)
    
    model.eval()
    val_loss = 0
    with torch.no_grad():  # no gradients are needed for validation
        for batch_data in validation_loader:
            outputs = model(batch_data["img"].to(device))
            loss = loss_function(outputs, batch_data["label"].to(device))
            val_loss += loss.item()
    val_loss = val_loss / len(validation_loader)
    
    wandb.log({'train_loss': train_loss, 'val_loss': val_loss})

# Log trained model in W&B
torch.save(model.state_dict(), r'covid_classifier.pt')
artifact = wandb.Artifact(name=f"covid_classifier", type="model")
artifact.add_file(r'covid_classifier.pt')
wandb.log_artifact(artifact)
remove(r'covid_classifier.pt')

run.finish()

Evaluate your trained classifier

If you are satisfied with the performance of your model, you can now evaluate it (and log its performance in W&B).

We provide two tools to evaluate the performance of the network on a data set:

  • compute_prediction outputs a DataFrame with the individual prediction of the model on each image,

  • compute_confusion_matrix computes the confusion matrix, which shows how the images of each class were predicted.

def compute_prediction(dataloader, model):
    """
    Computes a DataFrame whose rows correspond to images in the data set wrapped by dataloader.
    
    Args:
        dataloader: a DataLoader wrapping a DataSet.
        model: a torch or monai network.
        
    Returns:
        A pandas DataFrame with 4 columns:
        - img_path (str): path to the image file,
        - diagnosis (str): diagnosis ("covid" or "normal")
        - label (int): ground truth label corresponding to the diagnosis (0 is "normal" and 1 is "covid")
        - prediction (int): prediction of the network, corresponding to the node with the highest value.
            Can be directly compared to "label".
    """
    
    model.eval()
    model_device = next(model.parameters()).device
    rows = list()
    
    with torch.no_grad():  # inference only, no gradients needed
        for batch_dict in dataloader:
            image_paths, diagnoses, labels = batch_dict["img_path"], batch_dict["diagnosis"], batch_dict["label"]
            outputs = model(batch_dict["img"].to(model_device))
            prediction = torch.argmax(outputs, dim=1)
            for idx in range(len(prediction)):
                rows.append([image_paths[idx], diagnoses[idx], labels[idx].item(), prediction[idx].item()])
    
    # Building the DataFrame once avoids repeated (and slow) row-by-row concatenation
    return pd.DataFrame(rows, columns=["img_path", "diagnosis", "label", "prediction"])


def compute_confusion_matrix(dataloader, model):
    """
    Computes the confusion matrix for the labels and predictions "normal" and "covid"
    
    Args:
        dataloader (DataLoader): a DataLoader wrapping the evaluated data set.
        model (Module): a torch or monai network.
        
    Returns:
        (pd.DataFrame) the confusion matrix, with ground-truth diagnoses as rows and predictions as columns.
    """
    
    prediction_df = compute_prediction(dataloader, model)
    confusion_df = pd.DataFrame(index=["covid", "normal"], columns=["covid", "normal"])
    confusion_df.loc["normal", "normal"] = len(prediction_df[(prediction_df.label == 0) & (prediction_df.prediction == 0)])
    confusion_df.loc["normal", "covid"] = len(prediction_df[(prediction_df.label == 0) & (prediction_df.prediction == 1)])
    confusion_df.loc["covid", "covid"] = len(prediction_df[(prediction_df.label == 1) & (prediction_df.prediction == 1)])
    confusion_df.loc["covid", "normal"] = len(prediction_df[(prediction_df.label == 1) & (prediction_df.prediction == 0)])
    return confusion_df
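The confusion matrix has ground-truth diagnoses as rows and predictions as columns, and "covid" is the positive class. This matches the TP/FP/TN/FN naming used in log_confusion_matrix. The minimal sketch below computes the usual summary metrics from plain counts, so it does not depend on pandas. The function name is hypothetical.

```python
def summary_metrics(tp, fp, tn, fn):
    """Summary metrics for a binary confusion matrix where "covid" is the positive class."""
    sensitivity = tp / (tp + fn) if tp + fn else float("nan")  # fraction of covid images detected
    specificity = tn / (tn + fp) if tn + fp else float("nan")  # fraction of normal images recognised
    return {
        "accuracy": (tp + tn) / (tp + fp + tn + fn),
        "sensitivity": sensitivity,
        "specificity": specificity,
        # Mean of both recalls: less misleading than accuracy when classes are imbalanced
        "balanced_accuracy": (sensitivity + specificity) / 2,
    }

# Made-up counts, only to show the call. With a real matrix you would pass e.g.
# tp=confusion_df.loc["covid", "covid"], fp=confusion_df.loc["normal", "covid"], ...
print(summary_metrics(tp=45, fp=5, tn=40, fn=10))
```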

Exercise

Complete the cell below to evaluate your network on the test set. Log your confusion matrix with log_confusion_matrix.

def log_confusion_matrix(run_id, confusion_df, mode="test"):
    """
    Saves the values of a confusion matrix to W&B interface.
    
    Args:
        run_id (str): ID of the run you want to log to.
            In the training cell, this value was assigned to the variable "run_id".
            You can also retrieve it in the log of the cell or on the W&B interface.
        confusion_df (pd.DataFrame): output of `compute_confusion_matrix`.
        mode (str): name of the subset used to compute `confusion_df`.
            May correspond to "train", "validation" or "test".
    """
    print(f"Logging the results on {mode} set of run {run_id}")
    api = wandb.Api()
    run = api.run(f"tutorial6_explainability/{run_id}")
    run.summary[f"{mode}_TP"] = confusion_df.loc["covid", "covid"]
    run.summary[f"{mode}_FP"] = confusion_df.loc["normal", "covid"]
    run.summary[f"{mode}_TN"] = confusion_df.loc["normal", "normal"]
    run.summary[f"{mode}_FN"] = confusion_df.loc["covid", "normal"]
    run.save()

Explain your trained classifier

You may have obtained very good results with your trained classifier on test data, i.e., data that the classifier never saw during training. That's the first step to validate your network!

In this section, we are going to check which parts of the image the network focuses on to compute its prediction. We will use Grad-CAM, the algorithm described in (Selvaraju et al., 2017).

Grad-CAM algorithm

The Grad-CAM map is the weighted sum of the feature maps produced by the last convolutional layer of your network:

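  1. The feature maps of the last convolutional layer are computed during a forward pass.

  2. The gradients of the target class score with respect to these feature maps are computed by back-propagation. They have the same size as the feature maps (in our example 128x16x16). The gradients are then averaged over the spatial dimensions to obtain one weight per feature map (in our example 128 scalars).

  3. Grad-CAM is the sum of the feature maps multiplied by their respective weights. This results in one low-resolution map (in our case 16x16). The original paper also applies a ReLU to keep only the regions that increase the class score. The GradCam class below keeps the sign of the map, so that negative contributions can also be displayed, and it uses a mean instead of a sum, which only differs by a constant factor.

  4. The map is upsampled to the size of the input image (in our case 512x512).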

Because of the upsampling step, Grad-CAM maps look very smooth, which partly explains their popularity in the community. Keep in mind, however, that the underlying map still has a low spatial resolution, even after resizing. You should therefore choose your architecture carefully if you want to use Grad-CAM. If the feature maps you use have too low a resolution (for example 4x4), you will not be able to localize anything in the end.
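The four Grad-CAM steps can be checked on plain arrays before reading the PyTorch implementation. In this minimal NumPy sketch, the feature maps and gradients are random placeholders with the shapes quoted above (128x16x16). In the real network, they come from the forward and backward passes.

Two simplifications are assumptions of the sketch:
- Upsampling is nearest-neighbour (via np.kron) instead of bilinear.
- The ReLU of the original paper is an optional flag.

```python
import numpy as np

def grad_cam_from_arrays(feature_maps, gradients, out_size, relu=True):
    """Grad-CAM from feature maps and gradients, both of shape (K, h, w)."""
    # Step 2: average-pool the gradients -> one weight per feature map, shape (K,)
    weights = gradients.mean(axis=(1, 2))
    # Step 3: weighted sum of the feature maps -> one low-resolution map, shape (h, w)
    cam = np.einsum("k,khw->hw", weights, feature_maps)
    if relu:
        cam = np.maximum(cam, 0)  # keep only regions that increase the target score
    # Step 4: nearest-neighbour upsampling, assuming out_size is a multiple of (h, w)
    factor = (out_size[0] // cam.shape[0], out_size[1] // cam.shape[1])
    return np.kron(cam, np.ones(factor))

rng = np.random.default_rng(0)
feature_maps = rng.random((128, 16, 16))        # placeholder for step 1 (forward pass)
gradients = rng.standard_normal((128, 16, 16))  # placeholder for the backward pass
cam = grad_cam_from_arrays(feature_maps, gradients, out_size=(512, 512))
print(cam.shape)  # (512, 512)
```

Each 32x32 block of `cam` is constant, which makes the low resolution discussed above explicit. Bilinear interpolation only hides it.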

class GradCam:
    """
    Produces Grad-CAM maps for a monai.networks.nets.Classifier.
    """
    def __init__(self, model):
        self.model = model
        self.model.eval()
        self.device = next(model.parameters()).device

    def generate_gradients(self, input_batch, target_class=None):
        """
        Generate the Grad-CAM maps corresponding to a batch of images.
        
        Args:
            input_batch (dict): batch produced by a DataLoader, with keys "img" and "label".
            target_class (int): allows to choose from which node the gradients are back-propagated.
                Default will back-propagate from the node corresponding to the true class of the image.
            
        Returns:
            (Tensor): Grad-CAM maps of shape (batch_size, 1, H, W), upsampled to the input size.
        """
        input_tensor = input_batch["img"].to(self.device)        
        # Dissect model
        conv_part = self.model.net
        final_part = self.model.final
        
        # Get last conv feature map
        feature_maps = conv_part(input_tensor).detach()
        feature_maps.requires_grad = True
        model_output = final_part(feature_maps)
        # Target for backprop
        one_hot_output = torch.zeros_like(model_output)
        if target_class is not None:
            one_hot_output[:, target_class] = 1
        else:
            labels = input_batch["label"]
            for i, target_class in enumerate(labels):
                one_hot_output[i, target_class] = 1
        one_hot_output = one_hot_output.to(self.device)
        # Backward pass
        model_output.backward(gradient=one_hot_output)
        # Gradients of the target scores with respect to the feature maps
        gradients = feature_maps.grad
        pooled_gradients = torch.mean(gradients, dim=[2, 3]).unsqueeze(2).unsqueeze(3)
        
        # Weight feature maps according to pooled gradients
        feature_maps.requires_grad = False
        feature_maps *= pooled_gradients
        # Take the mean of all weighted feature maps
        grad_cam = torch.mean(feature_maps, dim=1).cpu()
        resize_transform = monai.transforms.Resize(input_tensor.shape[-2::], mode="bilinear")
        
        return resize_transform(grad_cam).unsqueeze(1)

Exercise

Visualize Grad-CAM maps obtained on the test set with the function visualize_grad_cam.

def visualize_grad_cam(batch_dict, model, target_class=None, v_display=None):
    """
    Plots chest X-rays images with their corresponding grad-CAM maps.
    
    Args:
        batch_dict (dict): batch of samples produced by a DataLoader.
        model (Classifier): a monai Classifier with two output classes.
        target_class (int): allows to choose from which node the gradients are back-propagated.
            Default will back-propagate from the node corresponding to the true class of the image.
        v_display (float): sets the colour range [-v_display, v_display] of the Grad-CAM overlay.
            Default uses the largest absolute value of the maps in the batch.
    """
    import matplotlib.pyplot as plt
    
    gradients_transform = GradCam(model)
    gradients = gradients_transform.generate_gradients(batch_dict, target_class)
    with torch.no_grad():  # predictions only, no gradients needed
        outputs = model(batch_dict["img"].to(device))
    prediction = torch.argmax(outputs, dim=1)
    for i in range(len(gradients)):
        plt.imshow(batch_dict["img"][i, 0], cmap="gray")
        if v_display is None:
            v = float(gradients.abs().max())
        else:
            v = v_display
        plt.imshow(gradients[i, 0], alpha=0.5, vmin=-v, vmax=v, cmap="bwr")
        plt.title(f"Label={int(batch_dict['label'][i])}, prediction={prediction[i].item()}")
        plt.show()

Exercise

What parts of the images does the network mostly use to complete its task? Is this clinically relevant? Why did this happen?

Answer: …

Experimental confirmation of learnt shortcuts

In this last section, the goal is to fool the network by generating data that artificially turns a "normal" image into a "covid" image from the network's point of view.

These images are generated from existing images of your data set with custom transforms. Your teachers designed these transforms manually to reproduce biases that your network may have learnt from the data set.

Resizing: If you changed the size of the images, TERRIBLE THINGS WILL HAPPEN NOW.

The first two classes, MaskingTransform and CropResizeTransform, are not meant to be used directly. They are utilities used to create the other transforms in a cell below.

from copy import copy

class MaskingTransform(monai.transforms.Transform):
    """This transform applies a binary mask of the same size as the image it is applied to."""
    def __init__(self, mask_pt, label=None, value=0):
        """
        Args:
            mask_pt (Tensor): a binary mask that will be applied to occlude an image.
            label (int): if given, transform will only be performed on images with the given label.
                Default will transform all images.
            value (float): constant value used to perturb the image.
        """
        self.label = label
        self.value = value
        self.mask_pt = mask_pt.float()
        self.invert_mask_pt = self.invert_mask(self.mask_pt)
        
    def __call__(self, sample):
        sample = copy(sample)
        
        if self.label is None or self.label == sample["label"]:
            image = sample["img"] * self.invert_mask_pt + self.mask_pt * self.value
            sample["img"] = image

        return sample
    
    @staticmethod
    def invert_mask(pt):
        negative_image = -pt + 1
        return (negative_image - negative_image.min()) / (negative_image.max() - negative_image.min())

    
class CropResizeTransform(monai.transforms.Transform):
    """This transform crops a region of interest and resizes it to the initial size of the image."""
    def __init__(self, roi_center, roi_size, label=None):
        """
        Args:
            roi_center (Tuple[int, int]): coordinates of the center of the region of interest.
            roi_size (Tuple[int, int]): size of the region of interest.
            label (int): if given, transform will only be performed on images with the given label.
                Default will transform all images.
        """
        self.label = label
        self.crop_transform = monai.transforms.SpatialCrop(roi_center=roi_center, roi_size=roi_size)
        self.resize_transform = monai.transforms.Resize((512, 512), mode="bilinear")
        
    def __call__(self, sample):        
        sample = copy(sample)
        
        if self.label is None or self.label == sample["label"]:
            image = self.resize_transform(self.crop_transform(sample["img"]))
            sample["img"] = image

        return sample
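For a binary mask, MaskingTransform computes `image * (1 - mask) + mask * value`. Pixels inside the mask are replaced by `value`, and all others are left untouched. In that case, invert_mask reduces to `1 - mask`, because its min-max rescaling maps {0, 1} onto itself. For a non-binary mask, such as the anti-aliased letter of RWriter, the result is a blend between image and value. The toy NumPy version below is for illustration only, and the function name is hypothetical.

```python
import numpy as np

def apply_mask(image, mask, value=0.0):
    """Replace masked pixels by `value` and keep the others (same formula as MaskingTransform)."""
    return image * (1 - mask) + mask * value

image = np.arange(16, dtype=float).reshape(4, 4)
mask = np.zeros((4, 4))
mask[:1, :2] = 1  # top-left corner, in the spirit of RemoveShoulders
print(apply_mask(image, mask))
```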

In the cell below we provide 5 transforms which can be easily applied to your samples.

Use the argument label if you want to apply a transform to images of one label only (0 or 1)!

class RemoveShoulders(MaskingTransform):
    def __init__(self, label=None):
        mask_pt = torch.zeros((1, 512, 512))
        mask_pt[:, :100, :200] = 1
        mask_pt[:, :100, 312:] = 1
        super().__init__(mask_pt, label)
    

class AddSideBackground(MaskingTransform):
    def __init__(self, label=None):
        mask_pt = torch.zeros((1, 512, 512))
        mask_pt[:, 150:, :50] = 1
        mask_pt[:, 150:, 512 - 50:] = 1
        super().__init__(mask_pt, label)
    
    
class CropSideBackground(CropResizeTransform):
    def __init__(self, label=None):
        roi_center = (256, 256)
        roi_size = (512, 412)
        super().__init__(roi_center, roi_size, label)
    

class CropShouldersUp(CropResizeTransform):
    def __init__(self, label=None):
        roi_center = (312, 256)
        roi_size = (412, 512)
        super().__init__(roi_center, roi_size, label)

    
class RWriter(MaskingTransform):
    def __init__(self, label=None, upsampling=3):
        from PIL import Image, ImageDraw, ImageFont
        
        size = 512
        h_offset = 120
        v_offset = 60

        # Create black image with white R letter
        image = Image.new("L", (size // upsampling, size // upsampling)) # As the size of the font cannot be easily chosen, a smaller image is created and the resizing will increase the size of the font
        draw = ImageDraw.Draw(image)
        draw.text((h_offset // upsampling, v_offset // upsampling), "R", fill="white")
        image = image.resize((size, size))
        
        # Convert to Tensor
        mask_pt = torch.from_numpy(np.asarray(image)).float() / 255
        super().__init__(mask_pt, label, value=1)

Exercise

Visualize the images produced by each transform on a sample to understand what they do.

Exercise

Compute and compare the confusion matrices obtained on the test images with and without transforms of your choice.
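Beyond comparing the two confusion matrices, it can help to count how many individual images change prediction when a transform is applied. The minimal sketch below assumes you have the labels and the predictions without and with the transform as plain lists. These could be, for instance, the "label" and "prediction" columns of two compute_prediction outputs, which keep the same row order as long as the test loader is not shuffled. The function name is hypothetical.

```python
def prediction_flips(labels, predictions_before, predictions_after, label_of_interest=0):
    """Count images of one class whose prediction changes after a transform.

    Returns the number of images with `label_of_interest`, and how many of them
    switched from 0 ("normal") to 1 ("covid") and from 1 to 0.
    """
    to_covid = to_normal = 0
    for label, before, after in zip(labels, predictions_before, predictions_after):
        if label != label_of_interest:
            continue
        if before == 0 and after == 1:
            to_covid += 1
        elif before == 1 and after == 0:
            to_normal += 1
    n_images = sum(1 for label in labels if label == label_of_interest)
    return {"n_images": n_images, "normal_to_covid": to_covid, "covid_to_normal": to_normal}

# Made-up example: four normal images, two of which become "covid" after the transform
print(prediction_flips([0, 0, 0, 0, 1], [0, 0, 0, 1, 1], [1, 1, 0, 1, 1]))
```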

Exercise

Does this experiment highlight the same biases you found with Grad-CAM? What are the limitations of this procedure? Are some transforms more prone to limitations than others?

Answer: …

Take-home messages

  1. Data curation and data analysis prior to training are not fun, but they are absolutely essential. Knowing your data set and its biases will avoid bad surprises after months spent choosing hyperparameters and training hundreds of networks…

  2. We always need more data for deep learning. However, mixing different sources to increase the size of your data set may not always be a good idea. You should always assess whether data sets are compatible, or even better, use one for training and the other for testing to assess the generalizability to other cohorts!

  3. Be fair and state your limitations precisely. Modest but honest results are better than breakthroughs relying on lies.

Part 2 - Equivariance

In this part, we are going to look at the equivariance properties of a neural network architecture that you should by now be very familiar with: the U-Net. We will use the same problem as in Tutorial 3: chest X-ray segmentation. Because training a network is not the focus here, we have pretrained a network that you can use for these experiments. First, set the data path as before and have it point to where you have stored the data for Tutorial 3. If you don’t have that data anymore, you can download it, unzip it, and set the data path as follows.

!wget https://surfdrive.surf.nl/files/index.php/s/Y4psc2pQnfkJuoT/download -O Tutorial_3.zip
!unzip -qo Tutorial_3.zip
data_path = "ribs"

Data loading

Next, we will use the same utility functions as in Tutorial 3 to build a dictionary of files and load rib data.

import os
import numpy as np
import matplotlib.pyplot as plt
import glob
import monai
from PIL import Image
import torch

def build_dict_ribs(data_path, mode='train'):
    """
    This function returns a list of dictionaries, each containing the keys 'img' and 'mask',
    which map to the paths of an X-ray image and of its corresponding mask.
    
    Args:
        data_path (str): path to the root folder of the data set.
        mode (str): subset used. Must correspond to 'train', 'val' or 'test'.
        
    Returns:
        (List[Dict[str, str]]) list of the dictionaries containing the paths of X-ray images and masks.
    """
    # test if mode is correct
    if mode not in ["train", "val", "test"]:
        raise ValueError(f"Please choose a mode in ['train', 'val', 'test']. Current mode is {mode}.")
    
    # define an empty list of dictionaries
    dicts = []
    # list all .png files in directory, including the path
    paths_xray = glob.glob(os.path.join(data_path, mode, 'img', '*.png'))
    # make a corresponding list for all the mask files
    for xray_path in paths_xray:
        if mode == 'test':
            suffix = 'val'
        else:
            suffix = mode
        # find the binary mask that belongs to the original image, based on indexing in the filename
        image_index = os.path.split(xray_path)[1].split('_')[-1].split('.')[0]
        # define path to mask file based on this index and add to list of mask paths
        mask_path = os.path.join(data_path, mode, 'mask', f'VinDr_RibCXR_{suffix}_{image_index}.png')
        if os.path.exists(mask_path):
            dicts.append({'img': xray_path, 'mask': mask_path})
    return dicts

class LoadRibData(monai.transforms.Transform):
    """
    This custom Monai transform loads the data from the rib segmentation dataset.
    Defining a custom transform is simple; just override the __init__ and __call__ methods.
    """
    def __init__(self, keys=None):
        pass

    def __call__(self, sample):
        image = Image.open(sample['img']).convert('L') # import as grayscale image
        image = np.array(image, dtype=np.uint8)
        mask = Image.open(sample['mask']).convert('L') # import as grayscale image
        mask = np.array(mask, dtype=np.uint8)
        # mask has value 255 on rib pixels. Convert to binary array
        mask[np.where(mask==255)] = 1
        return {'img': image, 'mask': mask, 'img_meta_dict': {'affine': np.eye(2)}, 
                'mask_meta_dict': {'affine': np.eye(2)}}

Use the cell below to make a validation loader with a single image. This is sufficient for the small experiment that you will perform.

validation_dict_list = build_dict_ribs(data_path, mode='val')
validation_transform = monai.transforms.Compose(
    [
        LoadRibData(),
        monai.transforms.EnsureChannelFirstd(keys=['img', 'mask'], channel_dim="no_channel"),
        monai.transforms.HistogramNormalized(keys=['img']),     
        monai.transforms.ScaleIntensityd(keys=['img'], minv=0, maxv=1),
        monai.transforms.Zoomd(keys=['img', 'mask'], zoom=0.25, mode=['bilinear', 'nearest'], keep_size=False),
        # monai.transforms.RandSpatialCropd(keys=['img', 'mask'], roi_size=[384, 384], random_size=False)
        monai.transforms.SpatialCropd(keys=['img', 'mask'], roi_center=[300, 300], roi_size=[384 + 64, 384])        
    ]
)
validation_data = monai.data.CacheDataset([validation_dict_list[3]], transform=validation_transform)
validation_loader = monai.data.DataLoader(validation_data, batch_size=1, shuffle=False)

Loading a pretrained model

We have already trained a model for you, the parameters of which were shared in JupyterLab as well. Note: if you downloaded the data set yourself, the model should be in the same folder as the images. If you already downloaded the data set but not the model, the model file is available here.

pretrained_file = os.path.join(data_path, "trainedUNet.pt")

Next, we initialize a standard U-Net architecture and load the parameters of the pretrained network using the load_state_dict function.

import random  # used to pick a random GPU index
import torch
import monai

# Check whether we're using a GPU
if torch.cuda.is_available():
    n_gpus = torch.cuda.device_count()  # Total number of GPUs
    gpu_idx = random.randint(0, n_gpus - 1)  # Random GPU index
    device = torch.device(f'cuda:{gpu_idx}')
    print('Using GPU: {}'.format(device))
else:
    device = torch.device('cpu')
    print('GPU not found. Using CPU.')

model = monai.networks.nets.UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels = (8, 16, 32, 64, 128),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    dropout=0.5
).to(device)

model.load_state_dict(torch.load(pretrained_file))
model.eval()
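With `strides=(2, 2, 2, 2)`, the encoder halves the spatial resolution four times. The sketch below computes the overall downsampling factor and the feature-map sizes for the 384 × 384 crops used later. It is plain arithmetic on the configuration above, not a call into MONAI. We assume, without having checked MONAI's code, that choosing an input size divisible by this factor keeps the encoder and decoder feature maps the same size at every skip connection.

```python
from math import prod

strides = (2, 2, 2, 2)  # as passed to monai.networks.nets.UNet above
input_size = 384        # side length of the crops fed to the model

# Overall downsampling factor between the input and the bottleneck
factor = prod(strides)

# Spatial size after each strided stage (exact as long as sizes stay divisible)
sizes = [input_size]
for s in strides:
    sizes.append(sizes[-1] // s)

print(factor)                    # 16
print(sizes)                     # [384, 192, 96, 48, 24]
print(input_size % factor == 0)  # True
```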

Let’s use the pretrained network to segment (part of) our image. Run the cell below.

for sample in validation_loader:

    img = sample['img'][:, :, :384, :384]    
    mask = sample['mask'][:, :, :384, :384]
    output_noshift = torch.sigmoid(model(img.to(device))).detach().cpu().numpy().squeeze()   
    
    fig, ax = plt.subplots(1,2, figsize = [12, 10])    
    # Plot X-ray image
    ax[0].imshow(img.squeeze(), 'gray')
    # Plot ground truth
    mask = np.squeeze(mask)
    overlay_mask = np.ma.masked_where(mask == 0, mask == 1)
    ax[0].imshow(overlay_mask, 'Greens', alpha = 0.7, clim=[0,1], interpolation='nearest')
    ax[0].set_title('Ground truth')
    # Plot output
    overlay_output = np.ma.masked_where(output_noshift < 0.1, output_noshift > 0.99)
    ax[1].imshow(img.squeeze(), 'gray')
    ax[1].imshow(overlay_output.squeeze(), 'Reds', alpha = 0.7, clim=[0,1])
    ax[1].set_title('Prediction')
    plt.show()      
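The overlays above rely on NumPy masked arrays. `np.ma.masked_where(condition, data)` hides every pixel where `condition` is true, so `imshow` draws nothing there and the X-ray shows through. The toy example below uses a made-up 2 × 3 probability map to show what the prediction overlay actually displays:

- Pixels below 0.1 are hidden.
- The remaining pixels show the boolean `output > 0.99`, which is 1 for confident foreground and 0 otherwise.

```python
import numpy as np

# Made-up sigmoid outputs for a tiny 2 x 3 "image"
output = np.array([[0.05, 0.50, 0.995],
                   [0.00, 0.98, 1.00]])

# Same call as in the cells above: hide low-probability pixels and
# display the thresholded prediction everywhere else
overlay = np.ma.masked_where(output < 0.1, output > 0.99)

# -1 marks hidden pixels, 0/1 are the displayed boolean values
shown = np.where(overlay.mask, -1, overlay.data.astype(int))
print(overlay.mask.tolist())  # [[True, False, False], [True, False, False]]
print(shown.tolist())         # [[-1, 0, 1], [-1, 0, 1]]
```

Pixels with intermediate probabilities, such as 0.5 or 0.98 here, are therefore drawn with the lowest colour of the `Reds` map rather than hidden.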

As you can see, the segmentation isn’t perfect, but that’s also not the goal of this exercise. What we are going to look into is the translation equivariance (Lecture 8) of the U-Net. That is: if you translate the image by \(d\) pixels, is the output simply translated by \(d\) pixels as well? This is a nice property for a segmentation network to have: in principle, we’d want the network to give us the same label for a pixel regardless of where the image was cropped. The image below visualizes this principle: for segmentation of the pixels in the orange square, it shouldn’t matter whether we provide the red square or the green square as input to the U-Net.
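In symbols, write \(T_d\) for a translation by \(d\) pixels and \(f\) for the network. Translation equivariance means \(f(T_d x) = T_d f(x)\), whereas translation invariance means \(f(T_d x) = f(x)\). The NumPy sketch below checks both definitions on two toy 1D operations. It uses circular shifts (`np.roll`) so that there are no border effects:

- a circular convolution, which is equivariant;
- a global sum, which is invariant.

The signal and kernel are made up for illustration and say nothing about the U-Net itself.

```python
import numpy as np

def circular_conv(x, w):
    """1D circular convolution: y[i] = sum_k w[k] * x[i - k] (indices mod len(x))."""
    return sum(w_k * np.roll(x, k) for k, w_k in enumerate(w))

def global_sum(x):
    """Sum over all positions, discarding where the signal is."""
    return x.sum()

rng = np.random.default_rng(0)
x = rng.normal(size=32)          # made-up 1D "image"
w = np.array([0.25, 0.5, 0.25])  # made-up smoothing kernel
d = 5                            # translation in pixels

# Equivariance: shifting the input shifts the output by the same amount
equivariant = np.allclose(circular_conv(np.roll(x, d), w),
                          np.roll(circular_conv(x, w), d))

# Invariance: shifting the input leaves the output unchanged
invariant = np.isclose(global_sum(np.roll(x, d)), global_sum(x))

print(equivariant, invariant)  # True True
```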

Exercise

What do you think will happen to the U-Net’s prediction if we give it a slightly shifted version of the image as input?

Now we write a small script that performs the above experiment. First, we segment the image in the red box and call the result output_noshift. Then we move the box down by offset pixels to obtain the green box, and segment the image in this box with the same model. We start small, with an offset of just a single pixel.
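Moving the crop window down by offset rows is the same as moving the image content up inside a fixed window. Row \(r\) of the shifted crop is row \(r\) + offset of the unshifted crop. The toy example below checks this with the same slicing pattern as the next cell. It uses a made-up 8 × 4 array and a window height of 5 instead of 384.

```python
import numpy as np

# Made-up "image" with spare rows, like the 448-row crop
image = np.arange(8 * 4).reshape(1, 1, 8, 4)  # (batch, channel, rows, cols)
height, offset = 5, 2                          # window height and shift

crop = image[:, :, :height, :]                         # "red" window
crop_shifted = image[:, :, offset:offset + height, :]  # "green" window

# In the overlap, row r of the shifted crop equals row r + offset of the original crop
overlap_ok = np.array_equal(crop_shifted[:, :, :height - offset, :],
                            crop[:, :, offset:, :])
print(crop_shifted.shape, overlap_ok)  # (1, 1, 5, 4) True
```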

Exercise

Run the cell below and observe the outputs. Can you spot differences between the two segmentation masks?

offset = 1

for sample in validation_loader:

    # Original image
    img = sample['img'][:, :, :384, :384]    
    mask = sample['mask'][:, :, :384, :384]
    output_noshift = torch.sigmoid(model(img.to(device))).detach().cpu().numpy().squeeze()   

    # Plot X-ray image
    fig, ax = plt.subplots(1,2, figsize = [12, 10])    
    ax[0].imshow(img.squeeze(), 'gray')
    # Plot ground truth
    mask = np.squeeze(mask)
    overlay_mask = np.ma.masked_where(mask == 0, mask == 1)
    ax[0].imshow(overlay_mask, 'Greens', alpha = 0.7, clim=[0,1], interpolation='nearest')
    ax[0].set_title('Ground truth')
    # Plot output
    overlay_output = np.ma.masked_where(output_noshift < 0.1, output_noshift >0.99)
    ax[1].imshow(img.squeeze(), 'gray')
    ax[1].imshow(overlay_output.squeeze(), 'Reds', alpha = 0.7, clim=[0,1])
    ax[1].set_title('Prediction')
    plt.show()
    
    # Shifted image
    img = sample['img'][:, :, offset:offset+384, :384]
    mask = sample['mask'][:, :, offset:offset+384, :384]
    output = torch.sigmoid(model(img.to(device))).detach().cpu().numpy().squeeze()

    # Plot X-ray image
    fig, ax = plt.subplots(1,2, figsize = [12, 10])
    ax[0].imshow(img.squeeze(), 'gray')
    # Plot ground truth
    mask = np.squeeze(mask)
    overlay_mask = np.ma.masked_where(mask == 0, mask == 1)
    ax[0].imshow(overlay_mask, 'Greens', alpha = 0.7, clim=[0,1], interpolation='nearest')
    ax[0].set_title('Ground truth shifted')
    # Plot output
    overlay_output = np.ma.masked_where(output < 0.1, output >0.99)
    ax[1].imshow(img.squeeze(), 'gray')
    ax[1].imshow(overlay_output.squeeze(), 'Reds', alpha = 0.7, clim=[0,1])
    ax[1].set_title('Prediction shifted')
    plt.show()

To highlight the differences between the two segmentation masks, we compute a difference image. Before subtracting, we undo the shift so that we’re not comparing apples and oranges. Row \(r\) of the shifted prediction corresponds to row \(r + d\) of the original prediction, where \(d\) is the offset. The next cell shows the difference between the prediction for the original crop and the prediction for the crop shifted by one pixel.

Exercise

Given these results, is a U-Net translation equivariant, invariant, or neither?

plt.figure(figsize=(6, 6))
diffout = output_noshift[offset:, :384] - output[:-offset, :384]
plt.imshow(diffout, cmap='seismic', clim=[-1, 1])
plt.title('Offset {}'.format(offset))
plt.colorbar()
plt.show()
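The slicing `output_noshift[offset:, :384] - output[:-offset, :384]` aligns the two predictions, but it would silently fail for `offset = 0`. In that case `a[:-0]` is the same as `a[:0]`, which is an empty array. This does not happen in this tutorial because offsets start at 1. Should you want to include offset 0, a small helper such as the hypothetical `aligned_difference` below avoids the problem by computing the end index explicitly. It assumes both predictions are 2D arrays of the same shape, shifted along the first axis as in the cells above.

```python
import numpy as np

def aligned_difference(out_ref, out_shifted, offset):
    """Difference between two predictions after undoing a vertical shift.

    out_shifted was computed on a window starting `offset` rows lower than
    out_ref, so row r of out_shifted corresponds to row r + offset of out_ref.
    """
    n_rows = out_ref.shape[0] - offset
    return out_ref[offset:] - out_shifted[:n_rows]

# Made-up check: a perfectly equivariant "model" gives an all-zero difference
field = np.random.default_rng(0).random((10, 3))
ref, shifted = field[:6], field[2:8]  # two 6-row windows, offset 2
print(np.abs(aligned_difference(ref, shifted, 2)).max())  # 0.0
print(aligned_difference(ref, ref, 0).shape)              # (6, 3), not (0, 3)
```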

We can repeat this for larger offsets. For every offset from 1 to 64 pixels, we compute the difference between the original and the shifted prediction. We restrict it to a 184 × 184 sub-image away from the crop borders, which should be unaffected by the shift. We store the L1 norm of each difference image (the sum of its absolute values) in the list norms and plot these norms as a function of the offset.

Exercise

The resulting plot shows that the U-Net is not exactly equivariant for any of the translations: the difference never becomes zero. This is due to a combination of border effects and downsampling layers. However, the plot also shows a particular pattern: the norm dips every 16 pixels of offset. Can you explain this based on the U-Net architecture?

norms = []
offsets = []
plot_differences = False  # Set to True to plot difference images for every offset

sample = next(iter(validation_loader))  # The single validation image
img = sample['img'][:, :, :384, :384]
output_noshift = torch.sigmoid(model(img.to(device))).detach().cpu().numpy().squeeze()

for offset in range(1, 65):
    for sample in validation_loader:
        img = sample['img'][:, :, offset:offset+384, :384]
        mask = sample['mask'][:, :, offset:offset+384, :384]

        output = torch.sigmoid(model(img.to(device))).detach().cpu().numpy().squeeze()  

        diffout = (output_noshift[offset:, :384] - output[:-offset, :384])[100:284, 100:284]
        offsets.append(offset)
        norms.append(np.sum(np.abs(diffout)))
        if plot_differences:
            plt.figure()
            plt.imshow(diffout, cmap='seismic', clim=[-1, 1])
            plt.title(f"Offset {offset}")
            plt.colorbar()
            plt.show()

plt.figure()
plt.plot(offsets, norms)
plt.xlabel('Offset')
plt.ylabel('Difference')
plt.show()
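To read off where the dips in the plot occur, you can list the offsets at which the difference is smaller than at both neighboring offsets. The helper `local_minima` below is a hypothetical sketch of that idea. It only assumes that `offsets` and `norms` are equal-length sequences, as built in the cell above, and it is demonstrated on made-up numbers.

```python
def local_minima(offsets, norms):
    """Return the offsets whose norm is strictly lower than both neighbors."""
    return [offsets[i]
            for i in range(1, len(norms) - 1)
            if norms[i] < norms[i - 1] and norms[i] < norms[i + 1]]

# Made-up values with dips at offsets 4 and 8, for illustration only
demo_offsets = list(range(1, 11))
demo_norms = [5.0, 6.0, 4.0, 1.0, 5.0, 6.0, 4.0, 0.5, 3.0, 6.0]
print(local_minima(demo_offsets, demo_norms))  # [4, 8]
```

On the real data, `local_minima(offsets, norms)` lists candidate dip positions. Small fluctuations in the norms may add extra, shallow minima, so compare the result with the plot.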

This is the end of this course!#