Segmentação de Atlas em Imagens DWI usando Deep Learning

Nov 30 2022
Escrito por Asmina Barkhandinova, Cerebra.ai Ltd.
Tabela de conteúdo: Uma das tarefas que podem ser resolvidas com o Deep Learning é a segmentação do atlas, a fim de auxiliar os médicos a detectar anomalias em uma determinada área do cérebro e reduzir os fatores humanos. Uma vez que existem diferentes tipos de digitalização (como ressonância magnética, CTA, CT, etc.

Índice:

  1. Introdução
  2. Dados
  3. Pré-processando
  4. Treinamento
  5. edifício de serviços
  6. Resultados

Uma das tarefas que podem ser resolvidas com o Deep Learning é a segmentação do atlas, para auxiliar os médicos a detectar anomalias em uma determinada área do cérebro e reduzir os fatores humanos. Como existem diferentes tipos de varredura (como ressonância magnética, CTA, tomografia computadorizada, etc.), também haverá vários tipos de modelos para resolver a tarefa de cada tipo de varredura. Neste artigo, abordaremos o processo de criação do DWI Atlas Service.

Antes de mergulharmos no processo, vamos descrever algumas definições:

Atlas de ASPECTOS — as regiões do cérebro. O Atlas consiste em 10 regiões: m1, m2, m3, m4, m5, m6, i, l, c, ic.

Fonte: radipopedia.org

ASPECTOS — a medição de atlas falecidos. Quanto maior o número, mais zonas de ASPECTS permaneceram ilesas. Por exemplo, ASPECTOS = 7 significa que 7 de 10 atlas permaneceram normais, não tocados por acidente vascular cerebral.

DWI — um tipo de varredura de ressonância magnética. Na ressonância magnética usual, existem dois ímãs que alteram o spin dos átomos de hidrogênio e recebem a energia refletida para construir a imagem do cérebro. Além disso, o DWI também detecta o fluxo do líquido (água) no cérebro e o reflete na imagem.

O artigo é composto por 2 partes principais:

  1. Modelo de segmentação Atlas de treinamento em imagens e marcações DWI.
  2. Construir um serviço que aceitará um arquivo nifti e retornará contornos de aspectos (previsão de modelo).

No conjunto de dados, usamos imagens DWI b0. Cada arquivo tinha formato [512, 512, 20], ou seja, 20 fatias de um modelo 3D do cérebro. Nós o cortamos em 20 imagens 2D do cérebro com marcação para cada fatia rotulada pelos médicos. No geral, tivemos 170 imagens de niftis fatiadas e 170 marcações. As imagens DWI eram a entrada e as marcações eram os alvos.

Fatia de imagem DWI e a marcação da fatia atual.

Pré-processando

Cada 'alvo' foi pré-processado por One-Hot Encoding e armazenado como [11, 512, 512] 3D ndarray, onde cada uma das 11 camadas representava uma determinada classe (fundo e os atlas). Para o aumento da imagem, aplicamos a normalização da imagem ImageNet e a convertemos em Tensores.

def make_onehot_markup(im: Image) -> np.ndarray:

  '''Returns ndarray of (11, 512, 512) in one-hot encoding format
  
  Args:
      im (Image): Image object
  Returns:
      one_hot (np.ndarray): 3D one-hot encoding tensor
  
  '''
  red = process_mask_red(im)
  brown = process_mask_brown(im)
  blue = process_mask_blue(im)
  yellow = process_mask_yellow(im)
  cyan = process_mask_cyan(im)
  mag = process_mask_magenta(im)
  dblue = process_mask_dblue(im)
  green = process_mask_green(im)
  orange = process_mask_orange(im)
  purple = process_mask_purple(im)
  
  background = np.logical_not(np.sum(matrix, axis=0))

  matrix = np.stack([background, red, brown, blue, yellow, cyan,
                    mag, dblue, green, orange, purple])
  
  return matrix

Treinamos o modelo de segmentação Unet com 11 aulas (10 atlas e 1 aula de fundo). Para o ajuste de hiperparâmetros, usamos o Optuna, que sugeriu um otimizador e um agendador e tentamos maximizar a pontuação do batimento de validação. Os callbacks salvaram os 3 melhores modelos em cada dobra.

def objective(trial: optuna.Trial, dataloaders: dict, evaluator: Evaluator, summury_writer: SummaryWrite) -> float:
  
  '''Chooses one of three optimizers and trains with this variant.
  '''

  model = segmentation_models_pytorch.Unet(encoder_name="efficientnet-b3", classes=11, encoder_weights="imagenet")

  optimizer_name = trial.suggest_categorical('optimizer', ['AdamW', 'Adam', 'Adagrad'])
  optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=0.001, weight_decay=0.000005)
  
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(15, 0.000001, True)

  MODEL_DIR = 'logs/optuna_pl'
  os.makedirs(MODEL_DIR, exist_ok=True)
  
  checkpoint_callback = ModelCheckpoint(save_top_k=3, verbose=True, monitor="val_best_score", mode="max")

  logger = TensorBoardLogger(save_dir=MODEL_DIR, name=f'version{lr}-{optimizer_name}')

  # Learner is a custom class inherited from LightningModule in pytorch_lightning
  lightning_model = Learner(
          dataloaders=dataloaders,
          model=model,
          evaluator=evaluator,
          summury_writer=summury_writer,
          optimizer=optimizer,
          scheduler=scheduler,
          config=myconfig)

  trainer = pl.Trainer(
          gpus=[0],
          max_epochs=10,
          num_sanity_val_steps=0,
          log_every_n_steps=2,
          logger=logger,
          gradient_clip_val=0.1,
          accumulate_grad_batches=4,
          precision=16,
          callbacks=checkpoint_callback)

  trainer.fit(lightning_model)
  return trainer.callback_metrics["val_best_score"].item()

def start_optuna(dataloaders: dict, evaluator: Evaluator, summury_writer: SummaryWrite):

  study = optuna.create_study(study_name='optuna_test',direction="maximize")
  study.optimize(lambda trial: objective(trial, dataloaders, evaluator, summury_writer), n_trials=10, gc_after_trial=True)
  
  df = study.trials_dataframe(attrs=("number", "value", "params", "state"))
  df.to_csv('optuna_test.csv')

  print("Number of finished trials: {}".format(len(study.trials)))
  print("Best trial:")
  trial = study.best_trial
  print("Value: ", trial.value)
  print("Parameteres: ", trial.params)

O serviço tinha a seguinte estrutura:

  1. Aceita a mensagem do back-end.
  2. Baixa o arquivo nifti do servidor s3 e o armazena temporariamente em uma pasta local.
  3. Divide o arquivo nifti 3D em fatias 2D.
  4. def build_3d_slices(filename: str) -> List[np.ndarray]:
    
      ''' Loads nifti from filename and returns list of 3-channel slices.
      '''
    
      nifti_data = nib.load(filename).get_fdata()
      nifti_data = np.rot90(nifti_data)
      list_of_slices = []
      for idx in range(nifti_data.shape[-1]):
          slice_img = nifti_data[:, :, idx]
          pil_slice_img = Image.fromarray(slice_img).convert("RGB")
          res = np.asarray(pil_slice_img)
          list_of_slices.append(res)
      return list_of_slices
    

    def preprocess(self, filename: str) -> Dict[str, torch.Tensor]:
      
      ''' Returns a dict with augmented images. 
      '''
    
      list_of_slices = build_3d_slices(filename)
      augmentations = get_valid_transforms()
      result = {}
      for i, img in enumerate(list_of_slices):
          sample = augmentations(image=img)
          preprocessed_image = sample['image']
          result[i] = preprocessed_image
      return result
    

    def process(self, data: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
    
      ''' Returns masks of each slice of the nifti.
      '''
    
      result = {}
      for i in range(len(data)):
          preprocessed_image = data[i]
          with torch.no_grad():
              preprocessed_image = preprocessed_image.to(self.device)
              mask = self.model(preprocessed_image.unsqueeze(0))
          mask_soft = torch.softmax(mask, dim=1)
          pr_mask = mask_soft
          pr_mask = pr_mask > self.threshold_prob
          if pr_mask.sum() < self.threshold_area:
              pr_mask = torch.zeros_like(pr_mask)
          pr_mask = pr_mask.type(torch.uint8)
          pr_mask = (pr_mask.squeeze().cpu().numpy())
          pr_mask = pr_mask.astype(np.uint8)
          pr_mask = pr_mask[1:, :, :]
          pr_mask = pr_mask.astype(np.uint8)*255
          result[i] = pr_mask
          
      return result
    

    '''Dict for each slice, the result stored in JSON'''
      conts = {
                  'm1': {'r': [], 'l': []},
                  'm2': {'r': [], 'l': []},
                  'm3': {'r': [], 'l': []},
                  'm4': {'r': [], 'l': []},
                  'm5': {'r': [], 'l': []},
                  'm6': {'r': [], 'l': []},
                  'i': {'r': [], 'l': []},
                  'l': {'r': [], 'l': []},
                  'c': {'r': [], 'l': []},
                  'ic': {'r': [], 'l': []}
              }