Segmentación de Atlas en imágenes DWI usando Deep Learning

Nov 30 2022
Escrito por Asmina Barkhandinova, Cerebra.ai Ltd.
Tabla de contenido: Una de las tareas que se pueden resolver con Deep Learning es la segmentación del atlas, con el fin de ayudar a los médicos a detectar anomalías en una determinada área del cerebro y reducir los factores humanos. Dado que existen diferentes tipos de exploración (como MRI, CTA, CT, etc.

Tabla de contenidos:

  1. Introducción
  2. Datos
  3. preprocesamiento
  4. Capacitación
  5. Edificio de servicios
  6. Resultados

Una de las tareas que se pueden resolver con Deep Learning es la segmentación del atlas, con el fin de ayudar a los médicos a detectar anomalías en una determinada área del cerebro y reducir los factores humanos. Dado que existen diferentes tipos de exploración (como MRI, CTA, CT, etc.), también habrá varios tipos de modelos para resolver la tarea de cada tipo de exploración. En este artículo, cubriremos el proceso de creación de DWI Atlas Service.

Antes de sumergirnos en el proceso, describamos algunas definiciones:

Atlas de ASPECTOS : las regiones del cerebro. Atlas consta de 10 regiones: m1, m2, m3, m4, m5, m6, i, l, c, ic.

Fuente: radiopopedia.org

ASPECTOS — la medición de atlas fallecidos. Cuanto mayor sea el número, más zonas de ASPECTOS permanecerán ilesas. Por ejemplo, ASPECTOS = 7 significa que 7 de cada 10 atlas permanecieron normales, no afectados por un accidente cerebrovascular.

DWI : un tipo de resonancia magnética. En la resonancia magnética habitual, hay dos imanes que cambian el giro de los átomos de hidrógeno y reciben la energía reflejada para construir la imagen del cerebro. Además de eso, DWI también detecta el flujo del líquido (agua) en el cerebro y lo refleja en la imagen.

El artículo consta de 2 partes principales:

  1. Capacitación del modelo de segmentación de Atlas en imágenes y marcas DWI.
  2. Creación de un servicio que acepte un archivo nifti y devuelva contornos de aspectos (predicción del modelo).

En el conjunto de datos, usamos imágenes DWI b0. Cada archivo tenía formato [512, 512, 20], es decir, 20 cortes de un modelo 3D del cerebro. Lo cortamos en 20 imágenes 2D del cerebro con marcas para cada corte etiquetadas por los médicos. En general, teníamos 170 imágenes de niftis en rodajas y 170 marcas. Las imágenes DWI fueron la entrada y las marcas fueron los objetivos.

Segmento de imagen DWI y el marcado del segmento actual.

preprocesamiento

Cada 'objetivo' fue preprocesado por One-Hot Encoding y almacenado como [11, 512, 512] 3D ndarray, donde cada una de las 11 capas representaba una determinada clase (fondo y atlas). Para el aumento de imágenes, aplicamos la normalización de imágenes ImageNet y la convertimos a Tensores.

def make_onehot_markup(im: Image) -> np.ndarray:

  '''Returns ndarray of (11, 512, 512) in one-hot encoding format
  
  Args:
      im (Image): Image object
  Returns:
      one_hot (np.ndarray): 3D one-hot encoding tensor
  
  '''
  red = process_mask_red(im)
  brown = process_mask_brown(im)
  blue = process_mask_blue(im)
  yellow = process_mask_yellow(im)
  cyan = process_mask_cyan(im)
  mag = process_mask_magenta(im)
  dblue = process_mask_dblue(im)
  green = process_mask_green(im)
  orange = process_mask_orange(im)
  purple = process_mask_purple(im)
  
  background = np.logical_not(np.sum(matrix, axis=0))

  matrix = np.stack([background, red, brown, blue, yellow, cyan,
                    mag, dblue, green, orange, purple])
  
  return matrix

Entrenamos el modelo de segmentación de Unet con 11 clases (10 atlas y 1 clase de fondo). Para el ajuste de hiperparámetros, utilizamos Optuna, que sugirió un optimizador y un programador y trató de maximizar la puntuación de latido de validación. Las devoluciones de llamada guardaron los 3 mejores modelos en cada pliegue.

def objective(trial: optuna.Trial, dataloaders: dict, evaluator: Evaluator, summury_writer: SummaryWrite) -> float:
  
  '''Chooses one of three optimizers and trains with this variant.
  '''

  model = segmentation_models_pytorch.Unet(encoder_name="efficientnet-b3", classes=11, encoder_weights="imagenet")

  optimizer_name = trial.suggest_categorical('optimizer', ['AdamW', 'Adam', 'Adagrad'])
  optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=0.001, weight_decay=0.000005)
  
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(15, 0.000001, True)

  MODEL_DIR = 'logs/optuna_pl'
  os.makedirs(MODEL_DIR, exist_ok=True)
  
  checkpoint_callback = ModelCheckpoint(save_top_k=3, verbose=True, monitor="val_best_score", mode="max")

  logger = TensorBoardLogger(save_dir=MODEL_DIR, name=f'version{lr}-{optimizer_name}')

  # Learner is a custom class inherited from LightningModule in pytorch_lightning
  lightning_model = Learner(
          dataloaders=dataloaders,
          model=model,
          evaluator=evaluator,
          summury_writer=summury_writer,
          optimizer=optimizer,
          scheduler=scheduler,
          config=myconfig)

  trainer = pl.Trainer(
          gpus=[0],
          max_epochs=10,
          num_sanity_val_steps=0,
          log_every_n_steps=2,
          logger=logger,
          gradient_clip_val=0.1,
          accumulate_grad_batches=4,
          precision=16,
          callbacks=checkpoint_callback)

  trainer.fit(lightning_model)
  return trainer.callback_metrics["val_best_score"].item()

def start_optuna(dataloaders: dict, evaluator: Evaluator, summury_writer: SummaryWrite):

  study = optuna.create_study(study_name='optuna_test',direction="maximize")
  study.optimize(lambda trial: objective(trial, dataloaders, evaluator, summury_writer), n_trials=10, gc_after_trial=True)
  
  df = study.trials_dataframe(attrs=("number", "value", "params", "state"))
  df.to_csv('optuna_test.csv')

  print("Number of finished trials: {}".format(len(study.trials)))
  print("Best trial:")
  trial = study.best_trial
  print("Value: ", trial.value)
  print("Parameteres: ", trial.params)

El servicio tenía la siguiente estructura:

  1. Acepta el mensaje del backend.
  2. Descarga el archivo nifti del servidor s3 y lo almacena temporalmente en una carpeta local.
  3. Divide el archivo nifti 3D en segmentos 2D.
  4. def build_3d_slices(filename: str) -> List[np.ndarray]:
    
      ''' Loads nifti from filename and returns list of 3-channel slices.
      '''
    
      nifti_data = nib.load(filename).get_fdata()
      nifti_data = np.rot90(nifti_data)
      list_of_slices = []
      for idx in range(nifti_data.shape[-1]):
          slice_img = nifti_data[:, :, idx]
          pil_slice_img = Image.fromarray(slice_img).convert("RGB")
          res = np.asarray(pil_slice_img)
          list_of_slices.append(res)
      return list_of_slices
    

    def preprocess(self, filename: str) -> Dict[str, torch.Tensor]:
      
      ''' Returns a dict with augmented images. 
      '''
    
      list_of_slices = build_3d_slices(filename)
      augmentations = get_valid_transforms()
      result = {}
      for i, img in enumerate(list_of_slices):
          sample = augmentations(image=img)
          preprocessed_image = sample['image']
          result[i] = preprocessed_image
      return result
    

    def process(self, data: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
    
      ''' Returns masks of each slice of the nifti.
      '''
    
      result = {}
      for i in range(len(data)):
          preprocessed_image = data[i]
          with torch.no_grad():
              preprocessed_image = preprocessed_image.to(self.device)
              mask = self.model(preprocessed_image.unsqueeze(0))
          mask_soft = torch.softmax(mask, dim=1)
          pr_mask = mask_soft
          pr_mask = pr_mask > self.threshold_prob
          if pr_mask.sum() < self.threshold_area:
              pr_mask = torch.zeros_like(pr_mask)
          pr_mask = pr_mask.type(torch.uint8)
          pr_mask = (pr_mask.squeeze().cpu().numpy())
          pr_mask = pr_mask.astype(np.uint8)
          pr_mask = pr_mask[1:, :, :]
          pr_mask = pr_mask.astype(np.uint8)*255
          result[i] = pr_mask
          
      return result
    

    '''Dict for each slice, the result stored in JSON'''
      conts = {
                  'm1': {'r': [], 'l': []},
                  'm2': {'r': [], 'l': []},
                  'm3': {'r': [], 'l': []},
                  'm4': {'r': [], 'l': []},
                  'm5': {'r': [], 'l': []},
                  'm6': {'r': [], 'l': []},
                  'i': {'r': [], 'l': []},
                  'l': {'r': [], 'l': []},
                  'c': {'r': [], 'l': []},
                  'ic': {'r': [], 'l': []}
              }