Atlassegmentierung auf DWI-Bildern mit Deep Learning

Nov 30 2022
Geschrieben von Asmina Barkhandinova, Cerebra.ai Ltd.
Inhaltsverzeichnis: Eine der Aufgaben, die mit Deep Learning gelöst werden können, ist die Atlassegmentierung, um Ärzten zu helfen, Anomalien in einem bestimmten Bereich des Gehirns zu erkennen und menschliche Faktoren zu reduzieren. Da es verschiedene Untersuchungsarten gibt (wie MRT, CTA, CT etc.

Inhaltsverzeichnis:

  1. Einführung
  2. Daten
  3. Vorverarbeitung
  4. Ausbildung
  5. Servicegebäude
  6. Ergebnisse

Eine der Aufgaben, die mit Deep Learning gelöst werden können, ist die Atlassegmentierung, um Ärzten zu helfen, Anomalien in einem bestimmten Bereich des Gehirns zu erkennen und menschliche Faktoren zu reduzieren. Da es verschiedene Arten von Scans gibt (wie MRT, CTA, CT usw.), gibt es auch mehrere Arten von Modellen, um die Aufgabe für jede Art von Scan zu lösen. In diesem Artikel behandeln wir den Aufbau des DWI Atlas Service.

Bevor wir in den Prozess eintauchen, lassen Sie uns einige Definitionen beschreiben:

Atlas der ASPEKTE – die Regionen des Gehirns. Atlas besteht aus 10 Regionen: m1, m2, m3, m4, m5, m6, i, l, c, ic.

Quelle: radiopedia.org

ASPEKTE — die Vermessung von Atlanten von Verstorbenen. Je höher die Zahl, desto mehr ASPECTS-Zonen blieben unbeschädigt. ASPEKTE = 7 bedeutet zB, dass 7 von 10 Atlanten normal blieben, nicht vom Schlaganfall berührt wurden.

DWI – eine Art von MRT-Untersuchung. In der üblichen MRT gibt es zwei Magnete, die den Spin von Wasserstoffatomen ändern und die reflektierte Energie empfangen, um das Bild des Gehirns zu erstellen. Darüber hinaus erkennt DWI auch den Fluss der Flüssigkeit (Wasser) im Gehirn und spiegelt ihn auf dem Bild wider.

Der Artikel besteht aus 2 Hauptteilen:

  1. Training des Atlas-Segmentierungsmodells auf DWI-Bildern und -Markups.
  2. Erstellen eines Dienstes, der eine Nifti-Datei akzeptiert und Aspektkonturen zurückgibt (Modellvorhersage).

Im Datensatz haben wir DWI b0-Bilder verwendet. Jede Datei hatte das Format [512, 512, 20], dh 20 Scheiben eines 3D-Modells des Gehirns. Wir haben es in 20 2D-Bilder des Gehirns geschnitten, mit Markierungen für jede Schicht, die von Ärzten beschriftet wurde. Insgesamt hatten wir 170 Bilder von geschnittenen Niftis und 170 Markups. Die DWI-Bilder waren die Eingabe und die Markups waren Ziele.

DWI Image Slice und das Markup des aktuellen Slice.

Vorverarbeitung

Jedes „Ziel“ wurde durch One-Hot Encoding vorverarbeitet und als [11, 512, 512] 3D-ndarray gespeichert, wobei jede der 11 Schichten eine bestimmte Klasse darstellte (Hintergrund und die Atlanten). Für die Bilderweiterung haben wir die ImageNet-Bildnormalisierung angewendet und sie in Tensors konvertiert.

def make_onehot_markup(im: Image) -> np.ndarray:

  '''Returns ndarray of (11, 512, 512) in one-hot encoding format
  
  Args:
      im (Image): Image object
  Returns:
      one_hot (np.ndarray): 3D one-hot encoding tensor
  
  '''
  red = process_mask_red(im)
  brown = process_mask_brown(im)
  blue = process_mask_blue(im)
  yellow = process_mask_yellow(im)
  cyan = process_mask_cyan(im)
  mag = process_mask_magenta(im)
  dblue = process_mask_dblue(im)
  green = process_mask_green(im)
  orange = process_mask_orange(im)
  purple = process_mask_purple(im)
  
  background = np.logical_not(np.sum(matrix, axis=0))

  matrix = np.stack([background, red, brown, blue, yellow, cyan,
                    mag, dblue, green, orange, purple])
  
  return matrix

Wir haben das Unet-Segmentierungsmodell mit 11 Klassen (10 Atlanten und 1 Klasse für den Hintergrund) trainiert. Für das Hyperparameter-Tuning haben wir Optuna verwendet, das einen Optimierer und einen Planer vorschlug und versuchte, den Validierungs-Beat-Score zu maximieren. Die Rückrufe speicherten die 3 besten Modelle in jeder Falte.

def objective(trial: optuna.Trial, dataloaders: dict, evaluator: Evaluator, summury_writer: SummaryWrite) -> float:
  
  '''Chooses one of three optimizers and trains with this variant.
  '''

  model = segmentation_models_pytorch.Unet(encoder_name="efficientnet-b3", classes=11, encoder_weights="imagenet")

  optimizer_name = trial.suggest_categorical('optimizer', ['AdamW', 'Adam', 'Adagrad'])
  optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=0.001, weight_decay=0.000005)
  
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(15, 0.000001, True)

  MODEL_DIR = 'logs/optuna_pl'
  os.makedirs(MODEL_DIR, exist_ok=True)
  
  checkpoint_callback = ModelCheckpoint(save_top_k=3, verbose=True, monitor="val_best_score", mode="max")

  logger = TensorBoardLogger(save_dir=MODEL_DIR, name=f'version{lr}-{optimizer_name}')

  # Learner is a custom class inherited from LightningModule in pytorch_lightning
  lightning_model = Learner(
          dataloaders=dataloaders,
          model=model,
          evaluator=evaluator,
          summury_writer=summury_writer,
          optimizer=optimizer,
          scheduler=scheduler,
          config=myconfig)

  trainer = pl.Trainer(
          gpus=[0],
          max_epochs=10,
          num_sanity_val_steps=0,
          log_every_n_steps=2,
          logger=logger,
          gradient_clip_val=0.1,
          accumulate_grad_batches=4,
          precision=16,
          callbacks=checkpoint_callback)

  trainer.fit(lightning_model)
  return trainer.callback_metrics["val_best_score"].item()

def start_optuna(dataloaders: dict, evaluator: Evaluator, summury_writer: SummaryWrite):

  study = optuna.create_study(study_name='optuna_test',direction="maximize")
  study.optimize(lambda trial: objective(trial, dataloaders, evaluator, summury_writer), n_trials=10, gc_after_trial=True)
  
  df = study.trials_dataframe(attrs=("number", "value", "params", "state"))
  df.to_csv('optuna_test.csv')

  print("Number of finished trials: {}".format(len(study.trials)))
  print("Best trial:")
  trial = study.best_trial
  print("Value: ", trial.value)
  print("Parameteres: ", trial.params)

Der Dienst hatte folgende Struktur:

  1. Akzeptiert die Nachricht vom Backend.
  2. Lädt die nifti-Datei vom s3-Server herunter und speichert sie vorübergehend in einem lokalen Ordner.
  3. Teilt die 3D-Nifti-Datei in 2D-Schnitte.
  4. def build_3d_slices(filename: str) -> List[np.ndarray]:
    
      ''' Loads nifti from filename and returns list of 3-channel slices.
      '''
    
      nifti_data = nib.load(filename).get_fdata()
      nifti_data = np.rot90(nifti_data)
      list_of_slices = []
      for idx in range(nifti_data.shape[-1]):
          slice_img = nifti_data[:, :, idx]
          pil_slice_img = Image.fromarray(slice_img).convert("RGB")
          res = np.asarray(pil_slice_img)
          list_of_slices.append(res)
      return list_of_slices
    

    def preprocess(self, filename: str) -> Dict[str, torch.Tensor]:
      
      ''' Returns a dict with augmented images. 
      '''
    
      list_of_slices = build_3d_slices(filename)
      augmentations = get_valid_transforms()
      result = {}
      for i, img in enumerate(list_of_slices):
          sample = augmentations(image=img)
          preprocessed_image = sample['image']
          result[i] = preprocessed_image
      return result
    

    def process(self, data: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
    
      ''' Returns masks of each slice of the nifti.
      '''
    
      result = {}
      for i in range(len(data)):
          preprocessed_image = data[i]
          with torch.no_grad():
              preprocessed_image = preprocessed_image.to(self.device)
              mask = self.model(preprocessed_image.unsqueeze(0))
          mask_soft = torch.softmax(mask, dim=1)
          pr_mask = mask_soft
          pr_mask = pr_mask > self.threshold_prob
          if pr_mask.sum() < self.threshold_area:
              pr_mask = torch.zeros_like(pr_mask)
          pr_mask = pr_mask.type(torch.uint8)
          pr_mask = (pr_mask.squeeze().cpu().numpy())
          pr_mask = pr_mask.astype(np.uint8)
          pr_mask = pr_mask[1:, :, :]
          pr_mask = pr_mask.astype(np.uint8)*255
          result[i] = pr_mask
          
      return result
    

    '''Dict for each slice, the result stored in JSON'''
      conts = {
                  'm1': {'r': [], 'l': []},
                  'm2': {'r': [], 'l': []},
                  'm3': {'r': [], 'l': []},
                  'm4': {'r': [], 'l': []},
                  'm5': {'r': [], 'l': []},
                  'm6': {'r': [], 'l': []},
                  'i': {'r': [], 'l': []},
                  'l': {'r': [], 'l': []},
                  'c': {'r': [], 'l': []},
                  'ic': {'r': [], 'l': []}
              }