Segmentacja atlasu na obrazach DWI przy użyciu głębokiego uczenia

Nov 30 2022
Napisane przez Asminę Barkhandinovą, Cerebra.ai Ltd.
Spis treści: Jednym z zadań, które można rozwiązać za pomocą głębokiego uczenia, jest segmentacja atlasu, aby pomóc lekarzom wykryć anomalie w określonym obszarze mózgu i zredukować czynnik ludzki. Ponieważ istnieją różne rodzaje skanowania (takie jak MRI, CTA, CT itp.

Spis treści:

  1. Wstęp
  2. Dane
  3. Przetwarzanie wstępne
  4. Trening
  5. Budynek usługowy
  6. Wyniki

Jednym z zadań, które można rozwiązać za pomocą głębokiego uczenia się, jest segmentacja atlasu, aby pomóc lekarzom wykryć anomalie w określonym obszarze mózgu i zredukować czynnik ludzki. Ponieważ istnieją różne rodzaje skanowania (takie jak MRI, CTA, CT itp.), będzie również wiele typów modeli do rozwiązania zadania dla każdego typu skanowania. W tym artykule omówimy proces budowania usługi Atlas DWI.

Zanim zagłębimy się w ten proces, opiszmy kilka definicji:

Atlas ASPEKTÓW — regionów mózgu. Atlas składa się z 10 regionów: m1, m2, m3, m4, m5, m6, i, l, c, ic.

Źródło: radipopedia.org

ASPEKTY — pomiar atlasów zmarłych. Im wyższa liczba, tym więcej stref ASPEKTÓW pozostało nietkniętych. Np. ASPEKT = 7 oznacza, że ​​7 z 10 atlasów pozostało normalnych, nietkniętych uderzeniem.

DWI — rodzaj skanowania MRI. W zwykłym rezonansie magnetycznym są dwa magnesy, które zmieniają spin atomów wodoru i odbierają odbitą energię, aby zbudować obraz mózgu. Oprócz tego DWI wykrywa również przepływ płynu (wody) w mózgu i odzwierciedla go na obrazie.

Artykuł składa się z 2 głównych części:

  1. Szkolenie modelu segmentacji Atlasu na obrazach i znacznikach DWI.
  2. Zbudowanie usługi, która zaakceptuje plik nifti i zwróci kontury aspektów (predykcja modelu).

W zbiorze danych wykorzystaliśmy obrazy DWI b0. Każdy plik miał format [512, 512, 20], czyli 20 wycinków trójwymiarowego modelu mózgu. Pokroiliśmy go na 20 obrazów 2D mózgu ze znacznikami dla każdego wycinka oznaczonymi przez lekarzy. Ogółem mieliśmy 170 zdjęć pokrojonych nifti i 170 znaczników. Obrazy DWI były danymi wejściowymi, a znaczniki były celami.

Wycinek obrazu DWI i znaczniki obecnego wycinka.

Przetwarzanie wstępne

Każdy „cel” został wstępnie przetworzony przez One-Hot Encoding i zapisany jako [11, 512, 512] 3D ndarray, gdzie każda z 11 warstw reprezentowała określoną klasę (tło i atlasy). Do powiększenia obrazu zastosowaliśmy normalizację obrazu ImageNet i przekonwertowaliśmy go na tensory.

def make_onehot_markup(im: Image) -> np.ndarray:

  '''Returns ndarray of (11, 512, 512) in one-hot encoding format
  
  Args:
      im (Image): Image object
  Returns:
      one_hot (np.ndarray): 3D one-hot encoding tensor
  
  '''
  red = process_mask_red(im)
  brown = process_mask_brown(im)
  blue = process_mask_blue(im)
  yellow = process_mask_yellow(im)
  cyan = process_mask_cyan(im)
  mag = process_mask_magenta(im)
  dblue = process_mask_dblue(im)
  green = process_mask_green(im)
  orange = process_mask_orange(im)
  purple = process_mask_purple(im)
  
  background = np.logical_not(np.sum(matrix, axis=0))

  matrix = np.stack([background, red, brown, blue, yellow, cyan,
                    mag, dblue, green, orange, purple])
  
  return matrix

Przeszkoliliśmy model segmentacji Unet z 11 klasami (10 atlasów i 1 klasa dla tła). Do strojenia hiperparametrów użyliśmy Optuny, która zasugerowała optymalizator i harmonogram i próbowała zmaksymalizować wynik beatu walidacji. Wywołania zwrotne uratowały 3 najlepsze modele w każdej zakładce.

def objective(trial: optuna.Trial, dataloaders: dict, evaluator: Evaluator, summury_writer: SummaryWrite) -> float:
  
  '''Chooses one of three optimizers and trains with this variant.
  '''

  model = segmentation_models_pytorch.Unet(encoder_name="efficientnet-b3", classes=11, encoder_weights="imagenet")

  optimizer_name = trial.suggest_categorical('optimizer', ['AdamW', 'Adam', 'Adagrad'])
  optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=0.001, weight_decay=0.000005)
  
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(15, 0.000001, True)

  MODEL_DIR = 'logs/optuna_pl'
  os.makedirs(MODEL_DIR, exist_ok=True)
  
  checkpoint_callback = ModelCheckpoint(save_top_k=3, verbose=True, monitor="val_best_score", mode="max")

  logger = TensorBoardLogger(save_dir=MODEL_DIR, name=f'version{lr}-{optimizer_name}')

  # Learner is a custom class inherited from LightningModule in pytorch_lightning
  lightning_model = Learner(
          dataloaders=dataloaders,
          model=model,
          evaluator=evaluator,
          summury_writer=summury_writer,
          optimizer=optimizer,
          scheduler=scheduler,
          config=myconfig)

  trainer = pl.Trainer(
          gpus=[0],
          max_epochs=10,
          num_sanity_val_steps=0,
          log_every_n_steps=2,
          logger=logger,
          gradient_clip_val=0.1,
          accumulate_grad_batches=4,
          precision=16,
          callbacks=checkpoint_callback)

  trainer.fit(lightning_model)
  return trainer.callback_metrics["val_best_score"].item()

def start_optuna(dataloaders: dict, evaluator: Evaluator, summury_writer: SummaryWrite):

  study = optuna.create_study(study_name='optuna_test',direction="maximize")
  study.optimize(lambda trial: objective(trial, dataloaders, evaluator, summury_writer), n_trials=10, gc_after_trial=True)
  
  df = study.trials_dataframe(attrs=("number", "value", "params", "state"))
  df.to_csv('optuna_test.csv')

  print("Number of finished trials: {}".format(len(study.trials)))
  print("Best trial:")
  trial = study.best_trial
  print("Value: ", trial.value)
  print("Parameteres: ", trial.params)

Usługa miała następującą strukturę:

  1. Akceptuje wiadomość z zaplecza.
  2. Pobiera plik nifti z serwera s3 i tymczasowo zapisuje go w folderze lokalnym.
  3. Dzieli plik 3D nifti na plasterki 2D.
  4. def build_3d_slices(filename: str) -> List[np.ndarray]:
    
      ''' Loads nifti from filename and returns list of 3-channel slices.
      '''
    
      nifti_data = nib.load(filename).get_fdata()
      nifti_data = np.rot90(nifti_data)
      list_of_slices = []
      for idx in range(nifti_data.shape[-1]):
          slice_img = nifti_data[:, :, idx]
          pil_slice_img = Image.fromarray(slice_img).convert("RGB")
          res = np.asarray(pil_slice_img)
          list_of_slices.append(res)
      return list_of_slices
    

    def preprocess(self, filename: str) -> Dict[str, torch.Tensor]:
      
      ''' Returns a dict with augmented images. 
      '''
    
      list_of_slices = build_3d_slices(filename)
      augmentations = get_valid_transforms()
      result = {}
      for i, img in enumerate(list_of_slices):
          sample = augmentations(image=img)
          preprocessed_image = sample['image']
          result[i] = preprocessed_image
      return result
    

    def process(self, data: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
    
      ''' Returns masks of each slice of the nifti.
      '''
    
      result = {}
      for i in range(len(data)):
          preprocessed_image = data[i]
          with torch.no_grad():
              preprocessed_image = preprocessed_image.to(self.device)
              mask = self.model(preprocessed_image.unsqueeze(0))
          mask_soft = torch.softmax(mask, dim=1)
          pr_mask = mask_soft
          pr_mask = pr_mask > self.threshold_prob
          if pr_mask.sum() < self.threshold_area:
              pr_mask = torch.zeros_like(pr_mask)
          pr_mask = pr_mask.type(torch.uint8)
          pr_mask = (pr_mask.squeeze().cpu().numpy())
          pr_mask = pr_mask.astype(np.uint8)
          pr_mask = pr_mask[1:, :, :]
          pr_mask = pr_mask.astype(np.uint8)*255
          result[i] = pr_mask
          
      return result
    

    '''Dict for each slice, the result stored in JSON'''
      conts = {
                  'm1': {'r': [], 'l': []},
                  'm2': {'r': [], 'l': []},
                  'm3': {'r': [], 'l': []},
                  'm4': {'r': [], 'l': []},
                  'm5': {'r': [], 'l': []},
                  'm6': {'r': [], 'l': []},
                  'i': {'r': [], 'l': []},
                  'l': {'r': [], 'l': []},
                  'c': {'r': [], 'l': []},
                  'ic': {'r': [], 'l': []}
              }