Segmentasi Atlas pada Gambar DWI menggunakan Deep Learning
Daftar Isi:
- pengantar
- Data
- Preprocessing
- Pelatihan
- Gedung Layanan
- Hasil
Salah satu tugas yang dapat diselesaikan dengan Deep Learning adalah segmentasi atlas, untuk membantu dokter mendeteksi anomali di area otak tertentu dan mengurangi faktor manusia. Karena ada berbagai jenis pemindaian (seperti MRI, CTA, CT, dll.), akan ada juga beberapa jenis model untuk menyelesaikan tugas untuk setiap jenis pemindaian. Pada artikel ini, kami akan membahas proses pembangunan DWI Atlas Service.
Sebelum kita menyelami prosesnya, mari kita jelaskan beberapa definisi:
Atlas ASPEK — wilayah otak. Atlas terdiri dari 10 wilayah: m1, m2, m3, m4, m5, m6, i, l, c, ic.

ASPEK — pengukuran atlas yang telah meninggal. Semakin tinggi angkanya, semakin banyak zona ASPECTS yang tidak terluka. Misalnya ASPECTS = 7 berarti 7 dari 10 atlas tetap normal, tidak tersentuh stroke.
DWI — sejenis pemindaian MRI. Pada MRI biasa terdapat dua magnet yang mengubah putaran atom hidrogen dan menerima energi pantulan untuk membangun citra otak. Selain itu, DWI juga mendeteksi aliran cairan (air) di otak dan memantulkannya ke dalam gambar.
Artikel ini terdiri dari 2 bagian utama:
- Pelatihan model Segmentasi Atlas pada gambar dan markup DWI.
- Membangun layanan yang akan menerima file nifti dan mengembalikan kontur aspek (prediksi model).
Dalam dataset, kami menggunakan gambar DWI b0. Setiap file memiliki format [512, 512, 20], yaitu 20 irisan model 3D otak. Kami mengirisnya menjadi 20 gambar otak 2D dengan markup untuk setiap irisan yang diberi label oleh dokter. Secara keseluruhan kami memiliki 170 gambar irisan niftis dan 170 markup. Gambar DWI adalah masukan dan markup adalah target.


Preprocessing
Setiap 'target' diproses sebelumnya oleh One-Hot Encoding dan disimpan sebagai [11, 512, 512] ndarray 3D, di mana masing-masing dari 11 lapisan mewakili kelas tertentu (latar belakang dan atlas). Untuk augmentasi gambar, kami menerapkan normalisasi gambar ImageNet dan mengonversinya menjadi Tensor.
def make_onehot_markup(im: Image) -> np.ndarray:
'''Returns ndarray of (11, 512, 512) in one-hot encoding format
Args:
im (Image): Image object
Returns:
one_hot (np.ndarray): 3D one-hot encoding tensor
'''
red = process_mask_red(im)
brown = process_mask_brown(im)
blue = process_mask_blue(im)
yellow = process_mask_yellow(im)
cyan = process_mask_cyan(im)
mag = process_mask_magenta(im)
dblue = process_mask_dblue(im)
green = process_mask_green(im)
orange = process_mask_orange(im)
purple = process_mask_purple(im)
background = np.logical_not(np.sum(matrix, axis=0))
matrix = np.stack([background, red, brown, blue, yellow, cyan,
mag, dblue, green, orange, purple])
return matrix
Kami melatih model segmentasi Unet dengan 11 kelas (10 atlas dan 1 kelas untuk latar belakang). Untuk penyetelan hyperparameter, kami menggunakan Optuna yang menyarankan pengoptimal dan penjadwal dan mencoba memaksimalkan skor ketukan validasi. Panggilan balik menyimpan 3 model terbaik teratas di setiap lipatan.
def objective(trial: optuna.Trial, dataloaders: dict, evaluator: Evaluator, summury_writer: SummaryWrite) -> float:
'''Chooses one of three optimizers and trains with this variant.
'''
model = segmentation_models_pytorch.Unet(encoder_name="efficientnet-b3", classes=11, encoder_weights="imagenet")
optimizer_name = trial.suggest_categorical('optimizer', ['AdamW', 'Adam', 'Adagrad'])
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=0.001, weight_decay=0.000005)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(15, 0.000001, True)
MODEL_DIR = 'logs/optuna_pl'
os.makedirs(MODEL_DIR, exist_ok=True)
checkpoint_callback = ModelCheckpoint(save_top_k=3, verbose=True, monitor="val_best_score", mode="max")
logger = TensorBoardLogger(save_dir=MODEL_DIR, name=f'version{lr}-{optimizer_name}')
# Learner is a custom class inherited from LightningModule in pytorch_lightning
lightning_model = Learner(
dataloaders=dataloaders,
model=model,
evaluator=evaluator,
summury_writer=summury_writer,
optimizer=optimizer,
scheduler=scheduler,
config=myconfig)
trainer = pl.Trainer(
gpus=[0],
max_epochs=10,
num_sanity_val_steps=0,
log_every_n_steps=2,
logger=logger,
gradient_clip_val=0.1,
accumulate_grad_batches=4,
precision=16,
callbacks=checkpoint_callback)
trainer.fit(lightning_model)
return trainer.callback_metrics["val_best_score"].item()
def start_optuna(dataloaders: dict, evaluator: Evaluator, summury_writer: SummaryWrite):
study = optuna.create_study(study_name='optuna_test',direction="maximize")
study.optimize(lambda trial: objective(trial, dataloaders, evaluator, summury_writer), n_trials=10, gc_after_trial=True)
df = study.trials_dataframe(attrs=("number", "value", "params", "state"))
df.to_csv('optuna_test.csv')
print("Number of finished trials: {}".format(len(study.trials)))
print("Best trial:")
trial = study.best_trial
print("Value: ", trial.value)
print("Parameteres: ", trial.params)
Layanan memiliki struktur berikut:
- Menerima pesan dari backend.
- Unduh file nifti dari server s3 dan simpan sementara di folder lokal.
- Membagi file nifti 3D menjadi irisan 2D.
def build_3d_slices(filename: str) -> List[np.ndarray]:
''' Loads nifti from filename and returns list of 3-channel slices.
'''
nifti_data = nib.load(filename).get_fdata()
nifti_data = np.rot90(nifti_data)
list_of_slices = []
for idx in range(nifti_data.shape[-1]):
slice_img = nifti_data[:, :, idx]
pil_slice_img = Image.fromarray(slice_img).convert("RGB")
res = np.asarray(pil_slice_img)
list_of_slices.append(res)
return list_of_slices
def preprocess(self, filename: str) -> Dict[str, torch.Tensor]:
''' Returns a dict with augmented images.
'''
list_of_slices = build_3d_slices(filename)
augmentations = get_valid_transforms()
result = {}
for i, img in enumerate(list_of_slices):
sample = augmentations(image=img)
preprocessed_image = sample['image']
result[i] = preprocessed_image
return result
def process(self, data: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
''' Returns masks of each slice of the nifti.
'''
result = {}
for i in range(len(data)):
preprocessed_image = data[i]
with torch.no_grad():
preprocessed_image = preprocessed_image.to(self.device)
mask = self.model(preprocessed_image.unsqueeze(0))
mask_soft = torch.softmax(mask, dim=1)
pr_mask = mask_soft
pr_mask = pr_mask > self.threshold_prob
if pr_mask.sum() < self.threshold_area:
pr_mask = torch.zeros_like(pr_mask)
pr_mask = pr_mask.type(torch.uint8)
pr_mask = (pr_mask.squeeze().cpu().numpy())
pr_mask = pr_mask.astype(np.uint8)
pr_mask = pr_mask[1:, :, :]
pr_mask = pr_mask.astype(np.uint8)*255
result[i] = pr_mask
return result
'''Dict for each slice, the result stored in JSON'''
conts = {
'm1': {'r': [], 'l': []},
'm2': {'r': [], 'l': []},
'm3': {'r': [], 'l': []},
'm4': {'r': [], 'l': []},
'm5': {'r': [], 'l': []},
'm6': {'r': [], 'l': []},
'i': {'r': [], 'l': []},
'l': {'r': [], 'l': []},
'c': {'r': [], 'l': []},
'ic': {'r': [], 'l': []}
}
