Segmentasi Atlas pada Gambar DWI menggunakan Deep Learning

Nov 30 2022
Ditulis oleh Asmina Barkhandinova, Cerebra.ai Ltd.
Daftar isi: Salah satu tugas yang dapat diselesaikan dengan Deep Learning adalah segmentasi atlas, untuk membantu dokter mendeteksi anomali di area otak tertentu dan mengurangi faktor manusia. Karena ada berbagai jenis pemindaian (seperti MRI, CTA, CT, dll.

Daftar Isi:

  1. pengantar
  2. Data
  3. Preprocessing
  4. Pelatihan
  5. Gedung Layanan
  6. Hasil

Salah satu tugas yang dapat diselesaikan dengan Deep Learning adalah segmentasi atlas, untuk membantu dokter mendeteksi anomali di area otak tertentu dan mengurangi faktor manusia. Karena ada berbagai jenis pemindaian (seperti MRI, CTA, CT, dll.), akan ada juga beberapa jenis model untuk menyelesaikan tugas untuk setiap jenis pemindaian. Pada artikel ini, kami akan membahas proses pembangunan DWI Atlas Service.

Sebelum kita menyelami prosesnya, mari kita jelaskan beberapa definisi:

Atlas ASPEK — wilayah otak. Atlas terdiri dari 10 wilayah: m1, m2, m3, m4, m5, m6, i, l, c, ic.

Sumber: radipopedia.org

ASPEK — pengukuran atlas yang telah meninggal. Semakin tinggi angkanya, semakin banyak zona ASPECTS yang tidak terluka. Misalnya ASPECTS = 7 berarti 7 dari 10 atlas tetap normal, tidak tersentuh stroke.

DWI — sejenis pemindaian MRI. Pada MRI biasa terdapat dua magnet yang mengubah putaran atom hidrogen dan menerima energi pantulan untuk membangun citra otak. Selain itu, DWI juga mendeteksi aliran cairan (air) di otak dan memantulkannya ke dalam gambar.

Artikel ini terdiri dari 2 bagian utama:

  1. Pelatihan model Segmentasi Atlas pada gambar dan markup DWI.
  2. Membangun layanan yang akan menerima file nifti dan mengembalikan kontur aspek (prediksi model).

Dalam dataset, kami menggunakan gambar DWI b0. Setiap file memiliki format [512, 512, 20], yaitu 20 irisan model 3D otak. Kami mengirisnya menjadi 20 gambar otak 2D dengan markup untuk setiap irisan yang diberi label oleh dokter. Secara keseluruhan kami memiliki 170 gambar irisan niftis dan 170 markup. Gambar DWI adalah masukan dan markup adalah target.

Irisan Gambar DWI dan markup dari irisan sekarang.

Preprocessing

Setiap 'target' diproses sebelumnya oleh One-Hot Encoding dan disimpan sebagai [11, 512, 512] ndarray 3D, di mana masing-masing dari 11 lapisan mewakili kelas tertentu (latar belakang dan atlas). Untuk augmentasi gambar, kami menerapkan normalisasi gambar ImageNet dan mengonversinya menjadi Tensor.

def make_onehot_markup(im: Image) -> np.ndarray:

  '''Returns ndarray of (11, 512, 512) in one-hot encoding format
  
  Args:
      im (Image): Image object
  Returns:
      one_hot (np.ndarray): 3D one-hot encoding tensor
  
  '''
  red = process_mask_red(im)
  brown = process_mask_brown(im)
  blue = process_mask_blue(im)
  yellow = process_mask_yellow(im)
  cyan = process_mask_cyan(im)
  mag = process_mask_magenta(im)
  dblue = process_mask_dblue(im)
  green = process_mask_green(im)
  orange = process_mask_orange(im)
  purple = process_mask_purple(im)
  
  background = np.logical_not(np.sum(matrix, axis=0))

  matrix = np.stack([background, red, brown, blue, yellow, cyan,
                    mag, dblue, green, orange, purple])
  
  return matrix

Kami melatih model segmentasi Unet dengan 11 kelas (10 atlas dan 1 kelas untuk latar belakang). Untuk penyetelan hyperparameter, kami menggunakan Optuna yang menyarankan pengoptimal dan penjadwal dan mencoba memaksimalkan skor ketukan validasi. Panggilan balik menyimpan 3 model terbaik teratas di setiap lipatan.

def objective(trial: optuna.Trial, dataloaders: dict, evaluator: Evaluator, summury_writer: SummaryWrite) -> float:
  
  '''Chooses one of three optimizers and trains with this variant.
  '''

  model = segmentation_models_pytorch.Unet(encoder_name="efficientnet-b3", classes=11, encoder_weights="imagenet")

  optimizer_name = trial.suggest_categorical('optimizer', ['AdamW', 'Adam', 'Adagrad'])
  optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=0.001, weight_decay=0.000005)
  
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(15, 0.000001, True)

  MODEL_DIR = 'logs/optuna_pl'
  os.makedirs(MODEL_DIR, exist_ok=True)
  
  checkpoint_callback = ModelCheckpoint(save_top_k=3, verbose=True, monitor="val_best_score", mode="max")

  logger = TensorBoardLogger(save_dir=MODEL_DIR, name=f'version{lr}-{optimizer_name}')

  # Learner is a custom class inherited from LightningModule in pytorch_lightning
  lightning_model = Learner(
          dataloaders=dataloaders,
          model=model,
          evaluator=evaluator,
          summury_writer=summury_writer,
          optimizer=optimizer,
          scheduler=scheduler,
          config=myconfig)

  trainer = pl.Trainer(
          gpus=[0],
          max_epochs=10,
          num_sanity_val_steps=0,
          log_every_n_steps=2,
          logger=logger,
          gradient_clip_val=0.1,
          accumulate_grad_batches=4,
          precision=16,
          callbacks=checkpoint_callback)

  trainer.fit(lightning_model)
  return trainer.callback_metrics["val_best_score"].item()

def start_optuna(dataloaders: dict, evaluator: Evaluator, summury_writer: SummaryWrite):

  study = optuna.create_study(study_name='optuna_test',direction="maximize")
  study.optimize(lambda trial: objective(trial, dataloaders, evaluator, summury_writer), n_trials=10, gc_after_trial=True)
  
  df = study.trials_dataframe(attrs=("number", "value", "params", "state"))
  df.to_csv('optuna_test.csv')

  print("Number of finished trials: {}".format(len(study.trials)))
  print("Best trial:")
  trial = study.best_trial
  print("Value: ", trial.value)
  print("Parameteres: ", trial.params)

Layanan memiliki struktur berikut:

  1. Menerima pesan dari backend.
  2. Unduh file nifti dari server s3 dan simpan sementara di folder lokal.
  3. Membagi file nifti 3D menjadi irisan 2D.
  4. def build_3d_slices(filename: str) -> List[np.ndarray]:
    
      ''' Loads nifti from filename and returns list of 3-channel slices.
      '''
    
      nifti_data = nib.load(filename).get_fdata()
      nifti_data = np.rot90(nifti_data)
      list_of_slices = []
      for idx in range(nifti_data.shape[-1]):
          slice_img = nifti_data[:, :, idx]
          pil_slice_img = Image.fromarray(slice_img).convert("RGB")
          res = np.asarray(pil_slice_img)
          list_of_slices.append(res)
      return list_of_slices
    

    def preprocess(self, filename: str) -> Dict[str, torch.Tensor]:
      
      ''' Returns a dict with augmented images. 
      '''
    
      list_of_slices = build_3d_slices(filename)
      augmentations = get_valid_transforms()
      result = {}
      for i, img in enumerate(list_of_slices):
          sample = augmentations(image=img)
          preprocessed_image = sample['image']
          result[i] = preprocessed_image
      return result
    

    def process(self, data: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
    
      ''' Returns masks of each slice of the nifti.
      '''
    
      result = {}
      for i in range(len(data)):
          preprocessed_image = data[i]
          with torch.no_grad():
              preprocessed_image = preprocessed_image.to(self.device)
              mask = self.model(preprocessed_image.unsqueeze(0))
          mask_soft = torch.softmax(mask, dim=1)
          pr_mask = mask_soft
          pr_mask = pr_mask > self.threshold_prob
          if pr_mask.sum() < self.threshold_area:
              pr_mask = torch.zeros_like(pr_mask)
          pr_mask = pr_mask.type(torch.uint8)
          pr_mask = (pr_mask.squeeze().cpu().numpy())
          pr_mask = pr_mask.astype(np.uint8)
          pr_mask = pr_mask[1:, :, :]
          pr_mask = pr_mask.astype(np.uint8)*255
          result[i] = pr_mask
          
      return result
    

    '''Dict for each slice, the result stored in JSON'''
      conts = {
                  'm1': {'r': [], 'l': []},
                  'm2': {'r': [], 'l': []},
                  'm3': {'r': [], 'l': []},
                  'm4': {'r': [], 'l': []},
                  'm5': {'r': [], 'l': []},
                  'm6': {'r': [], 'l': []},
                  'i': {'r': [], 'l': []},
                  'l': {'r': [], 'l': []},
                  'c': {'r': [], 'l': []},
                  'ic': {'r': [], 'l': []}
              }