Segmentação de Atlas em Imagens DWI usando Deep Learning
Índice:
- Introdução
- Dados
- Pré-processando
- Treinamento
- edifício de serviços
- Resultados
Uma das tarefas que podem ser resolvidas com o Deep Learning é a segmentação do atlas, para auxiliar os médicos a detectar anomalias em uma determinada área do cérebro e reduzir os fatores humanos. Como existem diferentes tipos de varredura (como ressonância magnética, CTA, tomografia computadorizada, etc.), também haverá vários tipos de modelos para resolver a tarefa de cada tipo de varredura. Neste artigo, abordaremos o processo de criação do DWI Atlas Service.
Antes de mergulharmos no processo, vamos descrever algumas definições:
Atlas de ASPECTOS — as regiões do cérebro. O Atlas consiste em 10 regiões: m1, m2, m3, m4, m5, m6, i, l, c, ic.

ASPECTOS — a medição de atlas falecidos. Quanto maior o número, mais zonas de ASPECTS permaneceram ilesas. Por exemplo, ASPECTOS = 7 significa que 7 de 10 atlas permaneceram normais, não tocados por acidente vascular cerebral.
DWI — um tipo de varredura de ressonância magnética. Na ressonância magnética usual, existem dois ímãs que alteram o spin dos átomos de hidrogênio e recebem a energia refletida para construir a imagem do cérebro. Além disso, o DWI também detecta o fluxo do líquido (água) no cérebro e o reflete na imagem.
O artigo é composto por 2 partes principais:
- Modelo de segmentação Atlas de treinamento em imagens e marcações DWI.
- Construir um serviço que aceitará um arquivo nifti e retornará contornos de aspectos (previsão de modelo).
No conjunto de dados, usamos imagens DWI b0. Cada arquivo tinha formato [512, 512, 20], ou seja, 20 fatias de um modelo 3D do cérebro. Nós o cortamos em 20 imagens 2D do cérebro com marcação para cada fatia rotulada pelos médicos. No geral, tivemos 170 imagens de niftis fatiadas e 170 marcações. As imagens DWI eram a entrada e as marcações eram os alvos.


Pré-processando
Cada 'alvo' foi pré-processado por One-Hot Encoding e armazenado como [11, 512, 512] 3D ndarray, onde cada uma das 11 camadas representava uma determinada classe (fundo e os atlas). Para o aumento da imagem, aplicamos a normalização da imagem ImageNet e a convertemos em Tensores.
def make_onehot_markup(im: Image) -> np.ndarray:
'''Returns ndarray of (11, 512, 512) in one-hot encoding format
Args:
im (Image): Image object
Returns:
one_hot (np.ndarray): 3D one-hot encoding tensor
'''
red = process_mask_red(im)
brown = process_mask_brown(im)
blue = process_mask_blue(im)
yellow = process_mask_yellow(im)
cyan = process_mask_cyan(im)
mag = process_mask_magenta(im)
dblue = process_mask_dblue(im)
green = process_mask_green(im)
orange = process_mask_orange(im)
purple = process_mask_purple(im)
background = np.logical_not(np.sum(matrix, axis=0))
matrix = np.stack([background, red, brown, blue, yellow, cyan,
mag, dblue, green, orange, purple])
return matrix
Treinamos o modelo de segmentação Unet com 11 aulas (10 atlas e 1 aula de fundo). Para o ajuste de hiperparâmetros, usamos o Optuna, que sugeriu um otimizador e um agendador e tentamos maximizar a pontuação do batimento de validação. Os callbacks salvaram os 3 melhores modelos em cada dobra.
def objective(trial: optuna.Trial, dataloaders: dict, evaluator: Evaluator, summury_writer: SummaryWrite) -> float:
'''Chooses one of three optimizers and trains with this variant.
'''
model = segmentation_models_pytorch.Unet(encoder_name="efficientnet-b3", classes=11, encoder_weights="imagenet")
optimizer_name = trial.suggest_categorical('optimizer', ['AdamW', 'Adam', 'Adagrad'])
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=0.001, weight_decay=0.000005)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(15, 0.000001, True)
MODEL_DIR = 'logs/optuna_pl'
os.makedirs(MODEL_DIR, exist_ok=True)
checkpoint_callback = ModelCheckpoint(save_top_k=3, verbose=True, monitor="val_best_score", mode="max")
logger = TensorBoardLogger(save_dir=MODEL_DIR, name=f'version{lr}-{optimizer_name}')
# Learner is a custom class inherited from LightningModule in pytorch_lightning
lightning_model = Learner(
dataloaders=dataloaders,
model=model,
evaluator=evaluator,
summury_writer=summury_writer,
optimizer=optimizer,
scheduler=scheduler,
config=myconfig)
trainer = pl.Trainer(
gpus=[0],
max_epochs=10,
num_sanity_val_steps=0,
log_every_n_steps=2,
logger=logger,
gradient_clip_val=0.1,
accumulate_grad_batches=4,
precision=16,
callbacks=checkpoint_callback)
trainer.fit(lightning_model)
return trainer.callback_metrics["val_best_score"].item()
def start_optuna(dataloaders: dict, evaluator: Evaluator, summury_writer: SummaryWrite):
study = optuna.create_study(study_name='optuna_test',direction="maximize")
study.optimize(lambda trial: objective(trial, dataloaders, evaluator, summury_writer), n_trials=10, gc_after_trial=True)
df = study.trials_dataframe(attrs=("number", "value", "params", "state"))
df.to_csv('optuna_test.csv')
print("Number of finished trials: {}".format(len(study.trials)))
print("Best trial:")
trial = study.best_trial
print("Value: ", trial.value)
print("Parameteres: ", trial.params)
O serviço tinha a seguinte estrutura:
- Aceita a mensagem do back-end.
- Baixa o arquivo nifti do servidor s3 e o armazena temporariamente em uma pasta local.
- Divide o arquivo nifti 3D em fatias 2D.
def build_3d_slices(filename: str) -> List[np.ndarray]:
''' Loads nifti from filename and returns list of 3-channel slices.
'''
nifti_data = nib.load(filename).get_fdata()
nifti_data = np.rot90(nifti_data)
list_of_slices = []
for idx in range(nifti_data.shape[-1]):
slice_img = nifti_data[:, :, idx]
pil_slice_img = Image.fromarray(slice_img).convert("RGB")
res = np.asarray(pil_slice_img)
list_of_slices.append(res)
return list_of_slices
def preprocess(self, filename: str) -> Dict[str, torch.Tensor]:
''' Returns a dict with augmented images.
'''
list_of_slices = build_3d_slices(filename)
augmentations = get_valid_transforms()
result = {}
for i, img in enumerate(list_of_slices):
sample = augmentations(image=img)
preprocessed_image = sample['image']
result[i] = preprocessed_image
return result
def process(self, data: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
''' Returns masks of each slice of the nifti.
'''
result = {}
for i in range(len(data)):
preprocessed_image = data[i]
with torch.no_grad():
preprocessed_image = preprocessed_image.to(self.device)
mask = self.model(preprocessed_image.unsqueeze(0))
mask_soft = torch.softmax(mask, dim=1)
pr_mask = mask_soft
pr_mask = pr_mask > self.threshold_prob
if pr_mask.sum() < self.threshold_area:
pr_mask = torch.zeros_like(pr_mask)
pr_mask = pr_mask.type(torch.uint8)
pr_mask = (pr_mask.squeeze().cpu().numpy())
pr_mask = pr_mask.astype(np.uint8)
pr_mask = pr_mask[1:, :, :]
pr_mask = pr_mask.astype(np.uint8)*255
result[i] = pr_mask
return result
'''Dict for each slice, the result stored in JSON'''
conts = {
'm1': {'r': [], 'l': []},
'm2': {'r': [], 'l': []},
'm3': {'r': [], 'l': []},
'm4': {'r': [], 'l': []},
'm5': {'r': [], 'l': []},
'm6': {'r': [], 'l': []},
'i': {'r': [], 'l': []},
'l': {'r': [], 'l': []},
'c': {'r': [], 'l': []},
'ic': {'r': [], 'l': []}
}
