Phân đoạn bản đồ trên hình ảnh DWI bằng Deep Learning
Mục lục:
- Giới thiệu
- Dữ liệu
- Sơ chế
- Tập huấn
- tòa nhà dịch vụ
- Kết quả
Một trong những nhiệm vụ có thể giải quyết bằng Deep Learning là phân đoạn bản đồ, nhằm hỗ trợ các bác sĩ phát hiện những bất thường ở một khu vực nhất định của não bộ và giảm thiểu yếu tố con người. Vì có nhiều loại quét khác nhau (chẳng hạn như MRI, CTA, CT, v.v.), nên cũng sẽ có nhiều loại mô hình để giải quyết nhiệm vụ cho từng loại quét. Trong bài viết này, chúng tôi sẽ đề cập đến quá trình xây dựng DWI Atlas Service.
Trước khi đi sâu vào quy trình, hãy mô tả một số định nghĩa:
Atlas of ASPECTS — các vùng của não bộ. Atlas gồm 10 vùng: m1, m2, m3, m4, m5, m6, i, l, c, ic.
NHIỆM VỤ - phép đo các tập bản đồ đã chết. Con số càng cao, càng nhiều vùng ASPECTS không hề hấn gì. Ví dụ: ASPECTS = 7 có nghĩa là 7 trong số 10 tập bản đồ vẫn bình thường, không bị đột quỵ.
DWI — một loại quét MRI. Trong MRI thông thường, có hai nam châm làm thay đổi spin của các nguyên tử hydro và nhận năng lượng phản xạ để tạo ra hình ảnh của não. Ngoài ra, DWI còn phát hiện dòng chảy của chất lỏng (nước) trong não và phản ánh nó trên hình ảnh.
Bài viết gồm 2 phần chính:
- Đào tạo mô hình Atlas Segmentation trên hình ảnh và đánh dấu DWI.
- Xây dựng một dịch vụ sẽ chấp nhận tệp nifti và trả về các đường nét của các khía cạnh (dự đoán mô hình).
Trong tập dữ liệu, chúng tôi đã sử dụng hình ảnh DWI b0. Mỗi tệp có định dạng [512, 512, 20], tức là 20 lát mô hình 3D của bộ não. Chúng tôi đã cắt nó thành 20 hình ảnh 2D của bộ não với đánh dấu cho mỗi lát do các bác sĩ dán nhãn. Nhìn chung, chúng tôi có 170 hình ảnh của niftis cắt lát và 170 đánh dấu. Hình ảnh DWI là đầu vào và phần đánh dấu là mục tiêu.
Sơ chế
Mỗi 'mục tiêu' được xử lý trước bằng Mã hóa một lần nóng và được lưu trữ dưới dạng [11, 512, 512] ndarray 3D, trong đó mỗi lớp trong số 11 lớp đại diện cho một lớp nhất định (nền và tập bản đồ). Để tăng cường hình ảnh, chúng tôi đã áp dụng chuẩn hóa hình ảnh ImageNet và chuyển đổi nó thành Tensors.
def make_onehot_markup(im: Image) -> np.ndarray:
'''Returns ndarray of (11, 512, 512) in one-hot encoding format
Args:
im (Image): Image object
Returns:
one_hot (np.ndarray): 3D one-hot encoding tensor
'''
red = process_mask_red(im)
brown = process_mask_brown(im)
blue = process_mask_blue(im)
yellow = process_mask_yellow(im)
cyan = process_mask_cyan(im)
mag = process_mask_magenta(im)
dblue = process_mask_dblue(im)
green = process_mask_green(im)
orange = process_mask_orange(im)
purple = process_mask_purple(im)
background = np.logical_not(np.sum(matrix, axis=0))
matrix = np.stack([background, red, brown, blue, yellow, cyan,
mag, dblue, green, orange, purple])
return matrix
Chúng tôi huấn luyện mô hình phân đoạn Unet gồm 11 lớp (10 tập bản đồ và 1 lớp nền). Để điều chỉnh siêu tham số, chúng tôi đã sử dụng Optuna đề xuất trình tối ưu hóa và trình lập lịch trình, đồng thời cố gắng tối đa hóa điểm nhịp xác thực. Các cuộc gọi lại đã lưu 3 mô hình tốt nhất trong mỗi lần.
def objective(trial: optuna.Trial, dataloaders: dict, evaluator: Evaluator, summury_writer: SummaryWrite) -> float:
'''Chooses one of three optimizers and trains with this variant.
'''
model = segmentation_models_pytorch.Unet(encoder_name="efficientnet-b3", classes=11, encoder_weights="imagenet")
optimizer_name = trial.suggest_categorical('optimizer', ['AdamW', 'Adam', 'Adagrad'])
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=0.001, weight_decay=0.000005)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(15, 0.000001, True)
MODEL_DIR = 'logs/optuna_pl'
os.makedirs(MODEL_DIR, exist_ok=True)
checkpoint_callback = ModelCheckpoint(save_top_k=3, verbose=True, monitor="val_best_score", mode="max")
logger = TensorBoardLogger(save_dir=MODEL_DIR, name=f'version{lr}-{optimizer_name}')
# Learner is a custom class inherited from LightningModule in pytorch_lightning
lightning_model = Learner(
dataloaders=dataloaders,
model=model,
evaluator=evaluator,
summury_writer=summury_writer,
optimizer=optimizer,
scheduler=scheduler,
config=myconfig)
trainer = pl.Trainer(
gpus=[0],
max_epochs=10,
num_sanity_val_steps=0,
log_every_n_steps=2,
logger=logger,
gradient_clip_val=0.1,
accumulate_grad_batches=4,
precision=16,
callbacks=checkpoint_callback)
trainer.fit(lightning_model)
return trainer.callback_metrics["val_best_score"].item()
def start_optuna(dataloaders: dict, evaluator: Evaluator, summury_writer: SummaryWrite):
study = optuna.create_study(study_name='optuna_test',direction="maximize")
study.optimize(lambda trial: objective(trial, dataloaders, evaluator, summury_writer), n_trials=10, gc_after_trial=True)
df = study.trials_dataframe(attrs=("number", "value", "params", "state"))
df.to_csv('optuna_test.csv')
print("Number of finished trials: {}".format(len(study.trials)))
print("Best trial:")
trial = study.best_trial
print("Value: ", trial.value)
print("Parameteres: ", trial.params)
Dịch vụ này có cấu trúc như sau:
- Chấp nhận tin nhắn từ phụ trợ.
- Tải xuống tệp nifti từ máy chủ s3 và lưu trữ tạm thời tệp đó trong một thư mục cục bộ.
- Chia tệp nifti 3D thành các lát 2D.
def build_3d_slices(filename: str) -> List[np.ndarray]:
''' Loads nifti from filename and returns list of 3-channel slices.
'''
nifti_data = nib.load(filename).get_fdata()
nifti_data = np.rot90(nifti_data)
list_of_slices = []
for idx in range(nifti_data.shape[-1]):
slice_img = nifti_data[:, :, idx]
pil_slice_img = Image.fromarray(slice_img).convert("RGB")
res = np.asarray(pil_slice_img)
list_of_slices.append(res)
return list_of_slices
def preprocess(self, filename: str) -> Dict[str, torch.Tensor]:
''' Returns a dict with augmented images.
'''
list_of_slices = build_3d_slices(filename)
augmentations = get_valid_transforms()
result = {}
for i, img in enumerate(list_of_slices):
sample = augmentations(image=img)
preprocessed_image = sample['image']
result[i] = preprocessed_image
return result
def process(self, data: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
''' Returns masks of each slice of the nifti.
'''
result = {}
for i in range(len(data)):
preprocessed_image = data[i]
with torch.no_grad():
preprocessed_image = preprocessed_image.to(self.device)
mask = self.model(preprocessed_image.unsqueeze(0))
mask_soft = torch.softmax(mask, dim=1)
pr_mask = mask_soft
pr_mask = pr_mask > self.threshold_prob
if pr_mask.sum() < self.threshold_area:
pr_mask = torch.zeros_like(pr_mask)
pr_mask = pr_mask.type(torch.uint8)
pr_mask = (pr_mask.squeeze().cpu().numpy())
pr_mask = pr_mask.astype(np.uint8)
pr_mask = pr_mask[1:, :, :]
pr_mask = pr_mask.astype(np.uint8)*255
result[i] = pr_mask
return result
'''Dict for each slice, the result stored in JSON'''
conts = {
'm1': {'r': [], 'l': []},
'm2': {'r': [], 'l': []},
'm3': {'r': [], 'l': []},
'm4': {'r': [], 'l': []},
'm5': {'r': [], 'l': []},
'm6': {'r': [], 'l': []},
'i': {'r': [], 'l': []},
'l': {'r': [], 'l': []},
'c': {'r': [], 'l': []},
'ic': {'r': [], 'l': []}
}

![Dù sao thì một danh sách được liên kết là gì? [Phần 1]](https://post.nghiatu.com/assets/images/m/max/724/1*Xokk6XOjWyIGCBujkJsCzQ.jpeg)



































