Tutorial 5#

June 4, 2024#

In the previous tutorials, you familiarized yourself with PyTorch, MONAI, and Weights & Biases. In last week’s lecture, you learned about reconstruction. In this tutorial, you will develop, train, and evaluate a CNN for reconstructing (synthetic) CT images.

First, let’s take care of the necessities:

  • If you’re using Google Colab, make sure to select a GPU Runtime.

  • Connect to Weights & Biases using the code below.

  • Install a few libraries that we will use in this tutorial.

import os
import wandb

os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
wandb.login()
!pip install dival
!pip install kornia

Reconstruction#

In this tutorial, you will reconstruct CT images. To avoid using too much disk storage, we will synthesize images on the fly using the Deep Inversion Validation Library (dival). These are 2D images of \(128\times 128\) pixels that contain a random number of ellipses with random sizes and random intensities.

First, make a dataset of ellipses. This creates an object from which we can request images via a generator. Next, we take a look at what this dataset contains by using the generator to ask for a sample. Each sample contains a sinogram and a ground-truth (original) synthetic image that we can visualize. You may recall from the lecture that the sinogram is made up of line integrals along the projection rays. The horizontal axis in the sinogram corresponds to the location \(s\) along the detector, the vertical axis to the projection angle \(\theta\).

import dival

dataset = dival.get_standard_dataset('ellipses', impl='skimage')
dat_gen = dataset.generator(part='train')

Run the cell below to show a sinogram and image in the dataset.

import numpy as np
import matplotlib.pyplot as plt

# Get a sample from the generator
sinogram, ground_truth = next(dat_gen)
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

# Show the sinogram
axs[0].imshow(sinogram, cmap='gray', extent=[0, 183, -90, 90])
axs[0].set_title('Sinogram')
axs[0].set_xlabel('$s$')
axs[0].set_ylabel(r'$\theta$')

# Show the ground truth image
axs[1].imshow(ground_truth, cmap='gray')
axs[1].set_title('Ground truth')
axs[1].set_xlabel('$x$')
axs[1].set_ylabel('$y$')
plt.show()   

Exercise

What kind of CT reconstruction problem is this? Limited-view or sparse-angle CT? Why?

Answer key

This is a sparse-angle CT reconstruction problem. The view spans 180 degrees, but the number of projection angles is low.
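
If you want to verify this numerically, the short check below (using the generator defined above) prints the sinogram dimensions; the angle axis should contain far fewer entries than the detector axis.

# The first axis of the sinogram corresponds to the projection angles and the
# second axis to the detector bins, so a small first dimension means sparse angles.
sino, gt = next(dat_gen)
print('Sinogram shape (angles, detector bins):', np.asarray(sino).shape)
print('Image shape:', np.asarray(gt).shape)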

Not only does the sinogram contain few angles, it also contains added white noise. If we simply backproject the sinogram to the image domain we end up with a low-quality image. Let’s give it a try using the standard Filtered Backprojection (FBP) algorithm for CT and its implementation in scikit-image.

import skimage.transform as sktr

# Get a sample from the generator
sinogram, ground_truth = next(dat_gen)
sinogram = np.asarray(sinogram).transpose()

# This defines the projection angles
theta = np.linspace(-90., 90., sinogram.shape[1], endpoint=True)

# Perform FBP
fbp_recon = sktr.iradon(sinogram, theta=theta, filter_name='ramp')[28:-27, 28:-27]
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
axs[0].imshow(sinogram.transpose(), cmap='gray', extent=[0, 183, -90, 90])
axs[0].set_title('Sinogram')
axs[0].set_xlabel('$s$')
axs[0].set_ylabel(r'$\theta$')
axs[1].imshow(ground_truth, cmap='gray', clim=[0, 1])
axs[1].set_title('Ground truth')
axs[1].set_xlabel('$x$')
axs[1].set_ylabel('$y$')
axs[2].imshow(fbp_recon, cmap='gray', clim=[0, 1])
axs[2].set_title('FBP')
axs[2].set_xlabel('$x$')
axs[2].set_ylabel('$y$')
plt.show()

Exercise

What do you think of the quality of the FBP reconstruction? Use the cell below to quantify the similarity between the images using the structural similarity index (SSIM). Does this reflect your intuition? Also compute the PSNR using the peak_signal_noise_ratio function in scikit-image.

Answer key

import skimage.metrics as skme

print('SSIM = {:.2f}'.format(skme.structural_similarity(np.asarray(ground_truth), fbp_recon, data_range=np.max(ground_truth)-np.min(ground_truth))))
print('PSNR = {:.2f}'.format(skme.peak_signal_noise_ratio(np.asarray(ground_truth), fbp_recon)))

Datasets and dataloaders#

Your goal now is to obtain higher-quality reconstructed images based on the sinogram measurements. As you have seen in the lecture, this can be done in four ways:

  1. Train a reconstruction method that directly maps from the measurement (sinogram) domain to the image domain.

  2. Preprocessing: clean up the sinogram using a neural network, then backproject to the image domain.

  3. Postprocessing: first backproject to the image domain, then improve the reconstruction using a neural network.

  4. Iterative methods that integrate data consistency.

Here, we will follow the third approach: postprocessing. We create reconstructions from the generated sinograms using filtered backprojection and use a neural network to learn corrections to this FBP image that improve the reconstruction, as shown in the image below. The data that we need to train this network are the FBP reconstructions and the corresponding ground-truth reconstructions from the dival dataset.

We will make a training dataset of 512 samples from the dival ellipses dataset and store it in a MONAI CacheDataset. The code below does this in four steps:

  1. Create a dival generator that creates sinograms and ground-truth reconstructions.

  2. For each sample, make a dictionary (like we did in the previous tutorial) that contains the ground-truth reconstruction and the FBP reconstruction under separate keys, and collect these dictionaries in a list.

  3. Define the transforms for the data (also like the previous tutorial). In this case we require an additional ‘channels’ dimension, as that is what the neural network expects. We will not make use of extra data augmentation.

  4. Construct the dataset using the dictionary and the defined transform.

import tqdm
import monai

theta = np.linspace(-90., 90., sinogram.shape[1], endpoint=True)

# Make a generator for the training part of the dataset
train_gen = dataset.generator(part='train')
train_samples = []

# Make a list of (in this case) 512 random training samples. We store the filtered backprojection (FBP) and ground truth image
# in a dictionary for each sample, and add these to a list.
for ns in tqdm.tqdm(range(512)):
    sinogram, ground_truth = next(train_gen)
    sinogram = np.asarray(sinogram).transpose()
    fbp_recon = sktr.iradon(sinogram, theta=theta, filter_name='ramp')[28:-27, 28:-27]
    train_samples.append({'fbp': fbp_recon, 'ground_truth': np.asarray(ground_truth)})

# You can add or remove transforms here
train_transform = monai.transforms.Compose([
    monai.transforms.AddChanneld(keys=['fbp', 'ground_truth'])
])    

# Use the list of dictionaries and the transform to initialize a MONAI CacheDataset
train_dataset = monai.data.CacheDataset(train_samples, transform=train_transform)    
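
As a quick sanity check (assuming the cell above has run), you can inspect one cached sample to confirm that the transform added the channel dimension the network expects.

# Both entries should now have a leading channel dimension, i.e. shape (1, 128, 128)
sample = train_dataset[0]
print('fbp:         ', sample['fbp'].shape)
print('ground_truth:', sample['ground_truth'].shape)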

Exercise

Also make a validation dataset and call it val_dataset. This dataset can be smaller, e.g., 64 or 128 samples.

Answer key

val_gen = dataset.generator(part='validation')
val_samples = []
val_transform = monai.transforms.Compose([
    monai.transforms.AddChanneld(keys=['fbp', 'ground_truth'])
])
for ns in tqdm.tqdm(range(128)):
    sinogram, ground_truth = next(val_gen)
    sinogram = np.asarray(sinogram).transpose()
    fbp_recon = sktr.iradon(sinogram, theta=theta, filter_name='ramp')[28:-27, 28:-27]
    val_samples.append({'fbp': fbp_recon, 'ground_truth': np.asarray(ground_truth)})
val_dataset = monai.data.CacheDataset(val_samples, transform=val_transform)  

Exercise

Now, make a dataloader for both the validation and training data, called train_loader and validation_loader, that we can use for sampling batches during training of the network. Give them a reasonable batch size, e.g., 16.

Answer key

# Shuffle the training data in each epoch; the validation data can be left in order
train_loader = monai.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
validation_loader = monai.data.DataLoader(val_dataset, batch_size=16)
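
You can also draw a single batch to confirm that the loaders produce tensors of the expected size (with a batch size of 16, both should have shape (16, 1, 128, 128)).

# Fetch one batch from the training loader and print the tensor shapes
batch = next(iter(train_loader))
print('fbp batch:         ', batch['fbp'].shape)
print('ground_truth batch:', batch['ground_truth'].shape)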

Model#

Now that we have datasets and dataloaders, the next step is to define a model, optimizer and criterion. Because we want to improve the FBP-reconstructed image, we are dealing with an image-to-image task. A standard U-Net as implemented in MONAI is therefore a good starting point. First, make sure that you are using the GPU (CUDA), otherwise training will be extremely slow.

import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f'The used device is {device}')

Exercise

Initialize a U-Net with the correct settings, e.g. channels and dimensions, and call it model. Here, it’s convenient to use the BasicUNet as implemented in MONAI.

Answer key

model = monai.networks.nets.BasicUNet(
    spatial_dims=2,
    out_channels=1
).to(device)

# model = monai.networks.nets.SegResNet(
#     spatial_dims=2,
#     out_channels=1
# ).to(device)
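
A quick way to check that the channel and dimension settings are correct is to push a dummy batch through the network (a sanity check only; it assumes the model cell above has run).

# Forward a dummy single-channel 128x128 image; for this image-to-image task
# the output should have the same shape as the input.
with torch.no_grad():
    dummy = torch.zeros(1, 1, 128, 128, device=device)
    print(model(dummy).shape)  # expected: torch.Size([1, 1, 128, 128])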

Loss function#

An important aspect is the loss function that you will use to optimize the model. The problem that we are trying to solve using a neural network is a regression problem, which differs from the classification approach we covered in the segmentation tutorial. Instead of classifying each pixel as a certain class, we alter their intensities to obtain a better overall reconstruction of the image.

Because this task is substantially different, we need to change our loss function. In the previous tutorial we used the Dice loss, which measures the overlap for each of the classes to segment. In this case, an L2 (mean squared error) or L1 (mean absolute error) loss suits our objective. Alternatively, we can use a loss that aims to maximize the structural similarity (SSIM). For this, we use the kornia library.
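
Concretely, for a network output \(\hat{y}\) and ground truth \(y\) with \(N\) pixels each, the pixel-wise losses are \(\mathcal{L}_{\mathrm{MSE}} = \frac{1}{N}\sum_{i}(\hat{y}_i - y_i)^2\) and \(\mathcal{L}_{\mathrm{L1}} = \frac{1}{N}\sum_{i}|\hat{y}_i - y_i|\), while the SSIM loss compares local means, variances, and covariances between the two images rather than individual pixel differences.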

import kornia 

# Three loss functions, turn them on or off by commenting

loss_function = torch.nn.MSELoss()
# loss_function = torch.nn.L1Loss()
# loss_function = kornia.losses.SSIMLoss(window_size=3)
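
To get a feel for the relative scale of these losses, you can evaluate all three on a random prediction/target pair (purely illustrative; the values carry no meaning for real data).

# Compare the three losses on a random prediction/target pair.
# All three are distances, so lower is better in each case.
pred = torch.rand(1, 1, 128, 128)
target = torch.rand(1, 1, 128, 128)
print('MSE :', torch.nn.MSELoss()(pred, target).item())
print('L1  :', torch.nn.L1Loss()(pred, target).item())
print('SSIM:', kornia.losses.SSIMLoss(window_size=3)(pred, target).item())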

As in previous tutorials, we use an adaptive variant of SGD (Adam) to train our network. In this tutorial, we add a learning rate scheduler. This scheduler lowers the learning rate every step_size epochs, meaning that the optimizer takes smaller steps in the direction of the gradient after a set number of epochs. This can help the optimizer settle into a better local minimum for the weights of the neural network.

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)
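
The snippet below illustrates the effect of this schedule using a separate throwaway optimizer and scheduler (so the real ones defined above are not advanced); with step_size=50 and gamma=0.1, the learning rate drops by a factor of 10 after 50 epochs.

# Illustration only: simulate 75 epochs with throwaway objects so the real
# optimizer and scheduler above stay untouched.
demo_opt = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=1e-3)
demo_sched = torch.optim.lr_scheduler.StepLR(demo_opt, step_size=50, gamma=0.1)
for ep in range(75):
    if ep in (0, 49, 50, 74):
        print(f'epoch {ep:2d}: lr = {demo_sched.get_last_lr()[0]:.0e}')
    demo_opt.step()
    demo_sched.step()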

Exercise

Complete the code below and train the U-Net.

What does the model learn? Look carefully at how we determine the output of the model. Can you describe what happens in the following line: outputs = model(batch_data['fbp'].float().to(device)) + batch_data["fbp"].float().to(device)?

Answer key

The model learns a residual correction: the U-Net’s output is added to the input FBP image, so the network only has to predict the difference between the FBP reconstruction and the ground truth instead of the full image.

from tqdm.notebook import tqdm
import wandb
from skimage.metrics import structural_similarity as ssim


run = wandb.init(
    project='tutorial4_reconstruction',
    config={
        'loss function': str(loss_function), 
        'lr': optimizer.param_groups[0]["lr"],
        'batch_size': train_loader.batch_size,
    }
)
# Do not hesitate to enrich this list of settings to be able to correctly keep track of your experiments!
# For example you should include information on your model architecture

run_id = run.id # We remember here the run ID to be able to write the evaluation metrics

def log_to_wandb(epoch, train_loss, val_loss, batch_data, outputs):
    """ Function that logs ongoing training variables to W&B """

    # Compute the mean SSIM between ground truth and model output for this batch
    # (you could additionally log example reconstructions as wandb.Image objects here)
    val_ssim = []
    for im_id in range(batch_data['ground_truth'].shape[0]):
        val_ssim.append(ssim(batch_data['ground_truth'].detach().cpu().numpy()[im_id, 0, :, :].squeeze(), 
                             outputs.detach().cpu().numpy()[im_id, 0, :, :].squeeze() ))
    val_ssim = np.mean(np.asarray(val_ssim))
    # Send epoch, losses and images to W&B
    wandb.log({'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss, 'val_ssim': val_ssim}) 
    
for epoch in tqdm(range(75)):
    model.train()    
    epoch_loss = 0
    step = 0
    for batch_data in train_loader: 
        step += 1
        optimizer.zero_grad()
        outputs = model(batch_data["fbp"].float().to(device)) + batch_data["fbp"].float().to(device)
        loss = loss_function(outputs, batch_data["ground_truth"].to(device))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    train_loss = epoch_loss/step
    # Validation part: no gradient computation is needed here
    model.eval()
    step = 0
    val_loss = 0
    with torch.no_grad():
        for batch_data in validation_loader:
            step += 1
            outputs = model(batch_data['fbp'].float().to(device)) + batch_data["fbp"].float().to(device)
            loss = loss_function(outputs, batch_data['ground_truth'].to(device))
            val_loss += loss.item()
    val_loss = val_loss / step
    log_to_wandb(epoch, train_loss, val_loss, batch_data, outputs)
    scheduler.step()

# Store the network parameters        
torch.save(model.state_dict(), r'trainedUNet.pt')
run.finish()

Exercise

Now make a DataSet and DataLoader for the test set. Just a handful of images should be enough.

Answer key

import tqdm

test_gen = dataset.generator(part='test')
test_samples = []
test_transform = monai.transforms.Compose([
    monai.transforms.AddChanneld(keys=['fbp', 'ground_truth'])
])
for ns in tqdm.tqdm(range(4)):
    sinogram, ground_truth = next(test_gen)
    sinogram = np.asarray(sinogram).transpose()
    fbp_recon = sktr.iradon(sinogram, theta=theta, filter_name='ramp')[28:-27, 28:-27]
    test_samples.append({'sinogram': sinogram, 'fbp': fbp_recon, 'ground_truth': np.asarray(ground_truth)})
test_dataset = monai.data.CacheDataset(test_samples, transform=test_transform)

test_loader = monai.data.DataLoader(test_dataset, batch_size=1)

Exercise

Visualize a number of reconstructions from the neural network and compare them to the FBP-reconstructed images, using the code below. The performance of the network is evaluated using the structural similarity function in scikit-image. Does the neural network improve this metric a lot compared to the filtered backprojection?

model.eval()

for test_sample in test_loader:
    output = model(test_sample['fbp'].float().to(device)) + test_sample['fbp'].float().to(device)
    output = output.detach().cpu().numpy()[0, 0, :, :].squeeze()
    ground_truth = test_sample['ground_truth'][0, 0, :, :].squeeze()
    fbp_recon = test_sample['fbp'][0, 0, :, :].squeeze()
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    axs[0].imshow(fbp_recon, cmap='gray', clim=[0, 1])
    axs[0].set_title('FBP SSIM={:.2f}'.format(ssim(ground_truth.cpu().numpy(), fbp_recon.cpu().numpy())))
    axs[0].set_xlabel('$x$')
    axs[0].set_ylabel('$y$')
    axs[1].imshow(ground_truth, cmap='gray', clim=[0, 1])
    axs[1].set_title('Ground truth')
    axs[1].set_xlabel('$x$')
    axs[1].set_ylabel('$y$')
    axs[2].imshow(output, cmap='gray', clim=[0, 1])
    axs[2].set_title('CNN SSIM={:.2f}'.format(ssim(ground_truth.cpu().numpy(), output)))
    axs[2].set_xlabel('$x$')
    axs[2].set_ylabel('$y$')
    plt.show()   

Answer key

Some observations that you could make:

  • The SSIM is definitely improved compared to the standard filtered backprojection (FBP). CNN results should be on the order of ~0.8 SSIM.

  • The output images of the CNN are less noisy than the FBP reconstructions. However, they also look a bit more blotchy/cartoonish.

Exercise

Instead of a U-Net, try a different model, e.g., a SegResNet in MONAI. Evaluate how the different loss functions affect the performance of the network. Note that the SSIM on the validation set is also written to Weights & Biases during training. Which loss leads to the best SSIM scores? Which loss results in the worst SSIM scores?

Answer key

In general, using an SSIM loss will lead to better SSIM scores. The L1 loss is also expected to give better results than the MSE loss, as it is less susceptible to outliers and smooths the resulting images less.

From post-processing to pre-processing#

So far, you have used a post-processing approach for reconstruction. In the lecture, we discussed an alternative pre-processing approach, in which the sinogram is improved before FBP. This additional exercise is entirely optional, but you could try to turn the current model into such a model and see whether the results you get are better or worse than those obtained so far. Good luck!
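
If you want a starting point, the sketch below shows one possible way to build sinogram-domain training pairs: the noisy dival sinogram is the input, and a 'clean' target sinogram is approximated by forward projecting the ground-truth image with scikit-image's radon transform, after padding the image back to the grid size that was cropped away after iradon. This is only a sketch: the names are suggestions, and you should verify that the forward projection geometry and intensity scaling actually match the dival sinograms before training a network on these pairs.

import numpy as np
import tqdm
import monai
import skimage.transform as sktr

# Sketch of a pre-processing dataset: pairs of (noisy sinogram, clean sinogram).
# The target sinogram is approximated by forward projecting the padded ground truth;
# check that this matches dival's projection geometry and scaling before relying on it.
preproc_samples = []
for ns in tqdm.tqdm(range(512)):
    sinogram, ground_truth = next(train_gen)
    noisy_sino = np.asarray(sinogram).transpose()                  # (detector, angles)
    gt_padded = np.pad(np.asarray(ground_truth), ((28, 27), (28, 27)))
    clean_sino = sktr.radon(gt_padded, theta=theta, circle=True)   # (detector, angles)
    preproc_samples.append({'noisy': noisy_sino, 'clean': clean_sino})

preproc_transform = monai.transforms.Compose([
    monai.transforms.AddChanneld(keys=['noisy', 'clean'])
])
preproc_dataset = monai.data.CacheDataset(preproc_samples, transform=preproc_transform)

# After training a sinogram-to-sinogram network on these pairs (you may need to pad
# the sinograms so their spatial dimensions suit your architecture), reconstruct by
# denoising the sinogram first and applying FBP afterwards, e.g.:
# denoised = sino_model(noisy_tensor.float().to(device)).detach().cpu().numpy()[0, 0]
# recon = sktr.iradon(denoised, theta=theta, filter_name='ramp')[28:-27, 28:-27]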