d2f9d5b1696d723607b469880b3d5616ee5b225a5296662d3b494a9f93762c27
Browse files- .gitignore +0 -0
- 3DUnet_Like/__pycache__/trainer.cpython-39.pyc +0 -0
- 3DUnet_Like/dataset/__pycache__/utils.cpython-39.pyc +0 -0
- 3DUnet_Like/dataset/brats.py +31 -0
- 3DUnet_Like/dataset/utils.py +100 -0
- 3DUnet_Like/logs/SegTransVAE/version_0/events.out.tfevents.1710047381.speech-demo.148199.0 +0 -0
- 3DUnet_Like/logs/SegTransVAE/version_0/hparams.yaml +1 -0
- 3DUnet_Like/logs/SegTransVAE/version_0/metric_log.csv +2 -0
- 3DUnet_Like/loss/__init__.py +0 -0
- 3DUnet_Like/loss/__pycache__/__init__.cpython-39.pyc +0 -0
- 3DUnet_Like/loss/__pycache__/loss.cpython-39.pyc +0 -0
- 3DUnet_Like/loss/loss.py +55 -0
- 3DUnet_Like/models/SegTranVAE/SegTranVAE.py +538 -0
- 3DUnet_Like/models/SegTranVAE/__init__.py +0 -0
- 3DUnet_Like/models/SegTranVAE/__pycache__/SegTranVAE.cpython-39.pyc +0 -0
- 3DUnet_Like/models/SegTranVAE/__pycache__/__init__.cpython-39.pyc +0 -0
- 3DUnet_Like/train.py +69 -0
- 3DUnet_Like/trainer.py +233 -0
- brats_2021_task1/BraTS2021_Training_Data/.DS_Store +0 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_flair.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_seg.nii.gz +0 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_t1.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_t1ce.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_t2.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_flair.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_seg.nii.gz +0 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_t1.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_t1ce.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_t2.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_flair.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_seg.nii.gz +0 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_t1.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_t1ce.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_t2.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_flair.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_seg.nii.gz +0 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_t1.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_t1ce.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_t2.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_flair.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_seg.nii.gz +0 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_t1.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_t1ce.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_t2.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_flair.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_seg.nii.gz +0 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_t1.nii.gz +3 -0
- brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_t1ce.nii.gz +3 -0
.gitignore
ADDED
|
File without changes
|
3DUnet_Like/__pycache__/trainer.cpython-39.pyc
ADDED
|
Binary file (7.55 kB). View file
|
|
|
3DUnet_Like/dataset/__pycache__/utils.cpython-39.pyc
ADDED
|
Binary file (2.66 kB). View file
|
|
|
3DUnet_Like/dataset/brats.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from monai.transforms import MapTransform
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
|
| 6 |
+
"""
|
| 7 |
+
Convert labels to multi channels based on brats classes:
|
| 8 |
+
label 1 is the necrotic and non-enhancing tumor core
|
| 9 |
+
label 2 is the peritumoral edema
|
| 10 |
+
label 4 is the GD-enhancing tumor
|
| 11 |
+
The possible classes are TC (Tumor core), WT (Whole tumor)
|
| 12 |
+
and ET (Enhancing tumor).
|
| 13 |
+
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __call__(self, data):
|
| 17 |
+
d = dict(data)
|
| 18 |
+
for key in self.keys:
|
| 19 |
+
result = []
|
| 20 |
+
# merge label 1 and label 4 to construct TC
|
| 21 |
+
result.append(np.logical_or(d[key] == 1, d[key] == 4))
|
| 22 |
+
# merge labels 1, 2 and 4 to construct WT
|
| 23 |
+
result.append(
|
| 24 |
+
np.logical_or(
|
| 25 |
+
np.logical_or(d[key] == 1, d[key] == 4), d[key] == 2
|
| 26 |
+
)
|
| 27 |
+
)
|
| 28 |
+
# label 4 is ET
|
| 29 |
+
result.append(d[key] == 4)
|
| 30 |
+
d[key] = np.stack(result, axis=0).astype(np.float32)
|
| 31 |
+
return d
|
3DUnet_Like/dataset/utils.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from sklearn.model_selection import train_test_split
|
| 4 |
+
|
| 5 |
+
from monai.data import DataLoader, Dataset
|
| 6 |
+
from monai import transforms
|
| 7 |
+
|
| 8 |
+
def datafold_read(datalist, basedir, fold=0, key="training"):
|
| 9 |
+
with open(datalist) as f:
|
| 10 |
+
json_data = json.load(f)
|
| 11 |
+
|
| 12 |
+
json_data = json_data[key]
|
| 13 |
+
|
| 14 |
+
for d in json_data:
|
| 15 |
+
for k in d:
|
| 16 |
+
if isinstance(d[k], list):
|
| 17 |
+
d[k] = [os.path.join(basedir, iv) for iv in d[k]]
|
| 18 |
+
elif isinstance(d[k], str):
|
| 19 |
+
d[k] = os.path.join(basedir, d[k]) if len(d[k]) > 0 else d[k]
|
| 20 |
+
|
| 21 |
+
tr = []
|
| 22 |
+
val = []
|
| 23 |
+
for d in json_data:
|
| 24 |
+
if "fold" in d and d["fold"] == fold:
|
| 25 |
+
val.append(d)
|
| 26 |
+
else:
|
| 27 |
+
tr.append(d)
|
| 28 |
+
|
| 29 |
+
return tr, val
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def split_train_test(datalist, basedir, fold,test_size = 0.2, volume : float = None) :
|
| 33 |
+
train_files, _ = datafold_read(datalist=datalist, basedir=basedir, fold=fold)
|
| 34 |
+
if volume != None :
|
| 35 |
+
train_files, _ = train_test_split(train_files,test_size=volume,random_state=42)
|
| 36 |
+
|
| 37 |
+
train_files,validation_files = train_test_split(train_files,test_size=test_size, random_state=42)
|
| 38 |
+
|
| 39 |
+
validation_files,test_files = train_test_split(validation_files,test_size=test_size, random_state=42)
|
| 40 |
+
return train_files, validation_files, test_files
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_loader(batch_size, data_dir, json_list, fold, roi,volume :float = None,test_size = 0.2):
|
| 44 |
+
train_files,validation_files,test_files = split_train_test(datalist = json_list,basedir = data_dir,test_size=test_size,fold = fold,volume= volume)
|
| 45 |
+
|
| 46 |
+
train_transform = transforms.Compose(
|
| 47 |
+
[
|
| 48 |
+
transforms.LoadImaged(keys=["image", "label"]),
|
| 49 |
+
transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
|
| 50 |
+
transforms.CropForegroundd(
|
| 51 |
+
keys=["image", "label"],
|
| 52 |
+
source_key="image",
|
| 53 |
+
k_divisible=[roi[0], roi[1], roi[2]],
|
| 54 |
+
),
|
| 55 |
+
transforms.RandSpatialCropd(
|
| 56 |
+
keys=["image", "label"],
|
| 57 |
+
roi_size=[roi[0], roi[1], roi[2]],
|
| 58 |
+
random_size=False,
|
| 59 |
+
),
|
| 60 |
+
transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
|
| 61 |
+
transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
|
| 62 |
+
transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
|
| 63 |
+
transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
|
| 64 |
+
transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
|
| 65 |
+
transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
|
| 66 |
+
]
|
| 67 |
+
)
|
| 68 |
+
val_transform = transforms.Compose(
|
| 69 |
+
[
|
| 70 |
+
transforms.LoadImaged(keys=["image", "label"]),
|
| 71 |
+
transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
|
| 72 |
+
transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
|
| 73 |
+
]
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
train_ds = Dataset(data=train_files, transform=train_transform)
|
| 77 |
+
train_loader = DataLoader(
|
| 78 |
+
train_ds,
|
| 79 |
+
batch_size=batch_size,
|
| 80 |
+
shuffle=True,
|
| 81 |
+
num_workers=2,
|
| 82 |
+
pin_memory=True,
|
| 83 |
+
)
|
| 84 |
+
val_ds = Dataset(data=validation_files, transform=val_transform)
|
| 85 |
+
val_loader = DataLoader(
|
| 86 |
+
val_ds,
|
| 87 |
+
batch_size=1,
|
| 88 |
+
shuffle=False,
|
| 89 |
+
num_workers=2,
|
| 90 |
+
pin_memory=True,
|
| 91 |
+
)
|
| 92 |
+
test_ds = Dataset(data=test_files, transform=val_transform)
|
| 93 |
+
test_loader = DataLoader(
|
| 94 |
+
test_ds,
|
| 95 |
+
batch_size=1,
|
| 96 |
+
shuffle=False,
|
| 97 |
+
num_workers=2,
|
| 98 |
+
pin_memory=True,
|
| 99 |
+
)
|
| 100 |
+
return train_loader, val_loader,test_loader
|
3DUnet_Like/logs/SegTransVAE/version_0/events.out.tfevents.1710047381.speech-demo.148199.0
ADDED
|
Binary file (117 kB). View file
|
|
|
3DUnet_Like/logs/SegTransVAE/version_0/hparams.yaml
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{}
|
3DUnet_Like/logs/SegTransVAE/version_0/metric_log.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Epoch,Mean Dice Score,Dice TC,Dice WT,Dice ET
|
| 2 |
+
0,0.004601036664098501,0.0006361556006595492,0.012770041823387146,0.0003969123645219952
|
3DUnet_Like/loss/__init__.py
ADDED
|
File without changes
|
3DUnet_Like/loss/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (124 Bytes). View file
|
|
|
3DUnet_Like/loss/__pycache__/loss.cpython-39.pyc
ADDED
|
Binary file (2.16 kB). View file
|
|
|
3DUnet_Like/loss/loss.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
class Loss_VAE(nn.Module):
|
| 5 |
+
def __init__(self):
|
| 6 |
+
super().__init__()
|
| 7 |
+
self.mse = nn.MSELoss(reduction='sum')
|
| 8 |
+
|
| 9 |
+
def forward(self, recon_x, x, mu, log_var):
|
| 10 |
+
mse = self.mse(recon_x, x)
|
| 11 |
+
kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
|
| 12 |
+
loss = mse + kld
|
| 13 |
+
return loss
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def DiceScore(
|
| 17 |
+
y_pred: torch.Tensor,
|
| 18 |
+
y: torch.Tensor,
|
| 19 |
+
include_background: bool = True,
|
| 20 |
+
) -> torch.Tensor:
|
| 21 |
+
"""Computes Dice score metric from full size Tensor and collects average.
|
| 22 |
+
Args:
|
| 23 |
+
y_pred: input data to compute, typical segmentation model output.
|
| 24 |
+
It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
|
| 25 |
+
should be binarized.
|
| 26 |
+
y: ground truth to compute mean dice metric. It must be one-hot format and first dim is batch.
|
| 27 |
+
The values should be binarized.
|
| 28 |
+
include_background: whether to skip Dice computation on the first channel of
|
| 29 |
+
the predicted output. Defaults to True.
|
| 30 |
+
Returns:
|
| 31 |
+
Dice scores per batch and per class, (shape [batch_size, num_classes]).
|
| 32 |
+
Raises:
|
| 33 |
+
ValueError: when `y_pred` and `y` have different shapes.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
y = y.float()
|
| 37 |
+
y_pred = y_pred.float()
|
| 38 |
+
|
| 39 |
+
if y.shape != y_pred.shape:
|
| 40 |
+
raise ValueError("y_pred and y should have same shapes.")
|
| 41 |
+
|
| 42 |
+
# reducing only spatial dimensions (not batch nor channels)
|
| 43 |
+
n_len = len(y_pred.shape)
|
| 44 |
+
reduce_axis = list(range(2, n_len))
|
| 45 |
+
intersection = torch.sum(y * y_pred, dim=reduce_axis)
|
| 46 |
+
|
| 47 |
+
y_o = torch.sum(y, reduce_axis)
|
| 48 |
+
y_pred_o = torch.sum(y_pred, dim=reduce_axis)
|
| 49 |
+
denominator = y_o + y_pred_o
|
| 50 |
+
|
| 51 |
+
return torch.where(
|
| 52 |
+
denominator > 0,
|
| 53 |
+
(2.0 * intersection) / denominator,
|
| 54 |
+
torch.tensor(float("1"), device=y_o.device),
|
| 55 |
+
)
|
3DUnet_Like/models/SegTranVAE/SegTranVAE.py
ADDED
|
@@ -0,0 +1,538 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
import torch
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
###########Resnet Block############
|
| 9 |
+
def normalization(planes, norm = 'instance'):
|
| 10 |
+
if norm == 'bn':
|
| 11 |
+
m = nn.BatchNorm3d(planes)
|
| 12 |
+
elif norm == 'gn':
|
| 13 |
+
m = nn.GroupNorm(8, planes)
|
| 14 |
+
elif norm == 'instance':
|
| 15 |
+
m = nn.InstanceNorm3d(planes)
|
| 16 |
+
else:
|
| 17 |
+
raise ValueError("Does not support this kind of norm.")
|
| 18 |
+
return m
|
| 19 |
+
class ResNetBlock(nn.Module):
|
| 20 |
+
def __init__(self, in_channels, norm = 'instance'):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.resnetblock = nn.Sequential(
|
| 23 |
+
normalization(in_channels, norm = norm),
|
| 24 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 25 |
+
nn.Conv3d(in_channels, in_channels, kernel_size = 3, padding = 1),
|
| 26 |
+
normalization(in_channels, norm = norm),
|
| 27 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 28 |
+
nn.Conv3d(in_channels, in_channels, kernel_size = 3, padding = 1)
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
def forward(self, x):
|
| 32 |
+
y = self.resnetblock(x)
|
| 33 |
+
return y + x
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
##############VAE###############
|
| 37 |
+
def calculate_total_dimension(a):
|
| 38 |
+
res = 1
|
| 39 |
+
for x in a:
|
| 40 |
+
res *= x
|
| 41 |
+
return res
|
| 42 |
+
|
| 43 |
+
class VAE(nn.Module):
|
| 44 |
+
def __init__(self, input_shape, latent_dim, num_channels):
|
| 45 |
+
super().__init__()
|
| 46 |
+
self.input_shape = input_shape
|
| 47 |
+
self.in_channels = input_shape[1] #input_shape[0] is batch size
|
| 48 |
+
self.latent_dim = latent_dim
|
| 49 |
+
self.encoder_channels = self.in_channels // 16
|
| 50 |
+
|
| 51 |
+
#Encoder
|
| 52 |
+
self.VAE_reshape = nn.Conv3d(self.in_channels, self.encoder_channels,
|
| 53 |
+
kernel_size = 3, stride = 2, padding=1)
|
| 54 |
+
# self.VAE_reshape = nn.Sequential(
|
| 55 |
+
# nn.GroupNorm(8, self.in_channels),
|
| 56 |
+
# nn.ReLU(),
|
| 57 |
+
# nn.Conv3d(self.in_channels, self.encoder_channels,
|
| 58 |
+
# kernel_size = 3, stride = 2, padding=1),
|
| 59 |
+
# )
|
| 60 |
+
|
| 61 |
+
flatten_input_shape = calculate_total_dimension(input_shape)
|
| 62 |
+
flatten_input_shape_after_vae_reshape = \
|
| 63 |
+
flatten_input_shape * self.encoder_channels // (8 * self.in_channels)
|
| 64 |
+
|
| 65 |
+
#Convert from total dimension to latent space
|
| 66 |
+
self.to_latent_space = nn.Linear(
|
| 67 |
+
flatten_input_shape_after_vae_reshape // self.in_channels, 1)
|
| 68 |
+
|
| 69 |
+
self.mean = nn.Linear(self.in_channels, self.latent_dim)
|
| 70 |
+
self.logvar = nn.Linear(self.in_channels, self.latent_dim)
|
| 71 |
+
# self.epsilon = nn.Parameter(torch.randn(1, latent_dim))
|
| 72 |
+
|
| 73 |
+
#Decoder
|
| 74 |
+
self.to_original_dimension = nn.Linear(self.latent_dim, flatten_input_shape_after_vae_reshape)
|
| 75 |
+
self.Reconstruct = nn.Sequential(
|
| 76 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 77 |
+
nn.Conv3d(
|
| 78 |
+
self.encoder_channels, self.in_channels,
|
| 79 |
+
stride = 1, kernel_size = 1),
|
| 80 |
+
nn.Upsample(scale_factor=2, mode = 'nearest'),
|
| 81 |
+
|
| 82 |
+
nn.Conv3d(
|
| 83 |
+
self.in_channels, self.in_channels // 2,
|
| 84 |
+
stride = 1, kernel_size = 1),
|
| 85 |
+
nn.Upsample(scale_factor=2, mode = 'nearest'),
|
| 86 |
+
ResNetBlock(self.in_channels // 2),
|
| 87 |
+
|
| 88 |
+
nn.Conv3d(
|
| 89 |
+
self.in_channels // 2, self.in_channels // 4,
|
| 90 |
+
stride = 1, kernel_size = 1),
|
| 91 |
+
nn.Upsample(scale_factor=2, mode = 'nearest'),
|
| 92 |
+
ResNetBlock(self.in_channels // 4),
|
| 93 |
+
|
| 94 |
+
nn.Conv3d(
|
| 95 |
+
self.in_channels // 4, self.in_channels // 8,
|
| 96 |
+
stride = 1, kernel_size = 1),
|
| 97 |
+
nn.Upsample(scale_factor=2, mode = 'nearest'),
|
| 98 |
+
ResNetBlock(self.in_channels // 8),
|
| 99 |
+
|
| 100 |
+
nn.InstanceNorm3d(self.in_channels // 8),
|
| 101 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 102 |
+
nn.Conv3d(
|
| 103 |
+
self.in_channels // 8, num_channels,
|
| 104 |
+
kernel_size = 3, padding = 1),
|
| 105 |
+
# nn.Sigmoid()
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def forward(self, x): #x has shape = input_shape
|
| 110 |
+
#Encoder
|
| 111 |
+
# print(x.shape)
|
| 112 |
+
x = self.VAE_reshape(x)
|
| 113 |
+
shape = x.shape
|
| 114 |
+
|
| 115 |
+
x = x.view(self.in_channels, -1)
|
| 116 |
+
x = self.to_latent_space(x)
|
| 117 |
+
x = x.view(1, self.in_channels)
|
| 118 |
+
|
| 119 |
+
mean = self.mean(x)
|
| 120 |
+
logvar = self.logvar(x)
|
| 121 |
+
# sigma = torch.exp(0.5 * logvar)
|
| 122 |
+
# Reparameter
|
| 123 |
+
epsilon = torch.randn_like(logvar)
|
| 124 |
+
sample = mean + epsilon * torch.exp(0.5*logvar)
|
| 125 |
+
|
| 126 |
+
#Decoder
|
| 127 |
+
y = self.to_original_dimension(sample)
|
| 128 |
+
y = y.view(*shape)
|
| 129 |
+
return self.Reconstruct(y), mean, logvar
|
| 130 |
+
def total_params(self):
|
| 131 |
+
total = sum(p.numel() for p in self.parameters())
|
| 132 |
+
return format(total, ',')
|
| 133 |
+
|
| 134 |
+
def total_trainable_params(self):
|
| 135 |
+
total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 136 |
+
return format(total_trainable, ',')
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# x = torch.rand((1, 256, 16, 16, 16))
|
| 140 |
+
# vae = VAE(input_shape = x.shape, latent_dim = 256, num_channels = 4)
|
| 141 |
+
# y = vae(x)
|
| 142 |
+
# print(y[0].shape, y[1].shape, y[2].shape)
|
| 143 |
+
# print(vae.total_trainable_params())
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
### Decoder ####
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class Upsample(nn.Module):
|
| 151 |
+
def __init__(self, in_channel, out_channel):
|
| 152 |
+
super().__init__()
|
| 153 |
+
self.conv1 = nn.Conv3d(in_channel, out_channel, kernel_size = 1)
|
| 154 |
+
self.deconv = nn.ConvTranspose3d(out_channel, out_channel, kernel_size = 2, stride = 2)
|
| 155 |
+
self.conv2 = nn.Conv3d(out_channel * 2, out_channel, kernel_size = 1)
|
| 156 |
+
|
| 157 |
+
def forward(self, prev, x):
|
| 158 |
+
x = self.deconv(self.conv1(x))
|
| 159 |
+
y = torch.cat((prev, x), dim = 1)
|
| 160 |
+
return self.conv2(y)
|
| 161 |
+
|
| 162 |
+
class FinalConv(nn.Module): # Input channels are equal to output channels
|
| 163 |
+
def __init__(self, in_channels, out_channels=32, norm="instance"):
|
| 164 |
+
super(FinalConv, self).__init__()
|
| 165 |
+
if norm == "batch":
|
| 166 |
+
norm_layer = nn.BatchNorm3d(num_features=in_channels)
|
| 167 |
+
elif norm == "group":
|
| 168 |
+
norm_layer = nn.GroupNorm(num_groups=8, num_channels=in_channels)
|
| 169 |
+
elif norm == 'instance':
|
| 170 |
+
norm_layer = nn.InstanceNorm3d(in_channels)
|
| 171 |
+
|
| 172 |
+
self.layer = nn.Sequential(
|
| 173 |
+
norm_layer,
|
| 174 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 175 |
+
nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
|
| 176 |
+
)
|
| 177 |
+
def forward(self, x):
|
| 178 |
+
return self.layer(x)
|
| 179 |
+
|
| 180 |
+
class Decoder(nn.Module):
|
| 181 |
+
def __init__(self, img_dim, patch_dim, embedding_dim, num_classes = 3):
|
| 182 |
+
super().__init__()
|
| 183 |
+
self.img_dim = img_dim
|
| 184 |
+
self.patch_dim = patch_dim
|
| 185 |
+
self.embedding_dim = embedding_dim
|
| 186 |
+
|
| 187 |
+
self.decoder_upsample_1 = Upsample(128, 64)
|
| 188 |
+
self.decoder_block_1 = ResNetBlock(64)
|
| 189 |
+
|
| 190 |
+
self.decoder_upsample_2 = Upsample(64, 32)
|
| 191 |
+
self.decoder_block_2 = ResNetBlock(32)
|
| 192 |
+
|
| 193 |
+
self.decoder_upsample_3 = Upsample(32, 16)
|
| 194 |
+
self.decoder_block_3 = ResNetBlock(16)
|
| 195 |
+
|
| 196 |
+
self.endconv = FinalConv(16, num_classes)
|
| 197 |
+
# self.normalize = nn.Sigmoid()
|
| 198 |
+
|
| 199 |
+
def forward(self, x1, x2, x3, x):
|
| 200 |
+
x = self.decoder_upsample_1(x3, x)
|
| 201 |
+
x = self.decoder_block_1(x)
|
| 202 |
+
|
| 203 |
+
x = self.decoder_upsample_2(x2, x)
|
| 204 |
+
x = self.decoder_block_2(x)
|
| 205 |
+
|
| 206 |
+
x = self.decoder_upsample_3(x1, x)
|
| 207 |
+
x = self.decoder_block_3(x)
|
| 208 |
+
|
| 209 |
+
y = self.endconv(x)
|
| 210 |
+
return y
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
###############Encoder##############
|
| 215 |
+
class InitConv(nn.Module):
|
| 216 |
+
def __init__(self, in_channels = 4, out_channels = 16, dropout = 0.2):
|
| 217 |
+
super().__init__()
|
| 218 |
+
self.layer = nn.Sequential(
|
| 219 |
+
nn.Conv3d(in_channels, out_channels, kernel_size = 3, padding = 1),
|
| 220 |
+
nn.Dropout3d(dropout)
|
| 221 |
+
)
|
| 222 |
+
def forward(self, x):
|
| 223 |
+
y = self.layer(x)
|
| 224 |
+
return y
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class DownSample(nn.Module):
|
| 228 |
+
def __init__(self, in_channels, out_channels):
|
| 229 |
+
super().__init__()
|
| 230 |
+
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 3, stride = 2, padding = 1)
|
| 231 |
+
def forward(self, x):
|
| 232 |
+
return self.conv(x)
|
| 233 |
+
|
| 234 |
+
class Encoder(nn.Module):
|
| 235 |
+
def __init__(self, in_channels, base_channels, dropout = 0.2):
|
| 236 |
+
super().__init__()
|
| 237 |
+
|
| 238 |
+
self.init_conv = InitConv(in_channels, base_channels, dropout = dropout)
|
| 239 |
+
self.encoder_block1 = ResNetBlock(in_channels = base_channels)
|
| 240 |
+
self.encoder_down1 = DownSample(base_channels, base_channels * 2)
|
| 241 |
+
|
| 242 |
+
self.encoder_block2_1 = ResNetBlock(base_channels * 2)
|
| 243 |
+
self.encoder_block2_2 = ResNetBlock(base_channels * 2)
|
| 244 |
+
self.encoder_down2 = DownSample(base_channels * 2, base_channels * 4)
|
| 245 |
+
|
| 246 |
+
self.encoder_block3_1 = ResNetBlock(base_channels * 4)
|
| 247 |
+
self.encoder_block3_2 = ResNetBlock(base_channels * 4)
|
| 248 |
+
self.encoder_down3 = DownSample(base_channels * 4, base_channels * 8)
|
| 249 |
+
|
| 250 |
+
self.encoder_block4_1 = ResNetBlock(base_channels * 8)
|
| 251 |
+
self.encoder_block4_2 = ResNetBlock(base_channels * 8)
|
| 252 |
+
self.encoder_block4_3 = ResNetBlock(base_channels * 8)
|
| 253 |
+
self.encoder_block4_4 = ResNetBlock(base_channels * 8)
|
| 254 |
+
# self.encoder_down3 = EncoderDown(base_channels * 8, base_channels * 16)
|
| 255 |
+
def forward(self, x):
|
| 256 |
+
x = self.init_conv(x) #(1, 16, 128, 128, 128)
|
| 257 |
+
|
| 258 |
+
x1 = self.encoder_block1(x)
|
| 259 |
+
x1_down = self.encoder_down1(x1) #(1, 32, 64, 64, 64)
|
| 260 |
+
|
| 261 |
+
x2 = self.encoder_block2_2(self.encoder_block2_1(x1_down))
|
| 262 |
+
x2_down = self.encoder_down2(x2) #(1, 64, 32, 32, 32)
|
| 263 |
+
|
| 264 |
+
x3 = self.encoder_block3_2(self.encoder_block3_1(x2_down))
|
| 265 |
+
x3_down = self.encoder_down3(x3) #(1, 128, 16, 16, 16)
|
| 266 |
+
|
| 267 |
+
output = self.encoder_block4_4(
|
| 268 |
+
self.encoder_block4_3(
|
| 269 |
+
self.encoder_block4_2(
|
| 270 |
+
self.encoder_block4_1(x3_down)))) #(1, 256, 16, 16, 16)
|
| 271 |
+
return x1, x2, x3, output
|
| 272 |
+
|
| 273 |
+
# x = torch.rand((1, 4, 128, 128, 128))
|
| 274 |
+
# Enc = Encoder(4, 32)
|
| 275 |
+
# _, _, _, y = Enc(x)
|
| 276 |
+
# print(y.shape) (1,256,16,16)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
###############FeatureMapping###############
|
| 280 |
+
|
| 281 |
+
class FeatureMapping(nn.Module):
|
| 282 |
+
def __init__(self, in_channel, out_channel, norm = 'instance'):
|
| 283 |
+
super().__init__()
|
| 284 |
+
if norm == 'bn':
|
| 285 |
+
norm_layer_1 = nn.BatchNorm3d(out_channel)
|
| 286 |
+
norm_layer_2 = nn.BatchNorm3d(out_channel)
|
| 287 |
+
elif norm == 'gn':
|
| 288 |
+
norm_layer_1 = nn.GroupNorm(8, out_channel)
|
| 289 |
+
norm_layer_2 = nn.GroupNorm(8, out_channel)
|
| 290 |
+
elif norm == 'instance':
|
| 291 |
+
norm_layer_1 = nn.InstanceNorm3d(out_channel)
|
| 292 |
+
norm_layer_2 = nn.InstanceNorm3d(out_channel)
|
| 293 |
+
self.feature_mapping = nn.Sequential(
|
| 294 |
+
nn.Conv3d(in_channel, out_channel, kernel_size = 3, padding = 1),
|
| 295 |
+
norm_layer_1,
|
| 296 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 297 |
+
nn.Conv3d(out_channel, out_channel, kernel_size = 3, padding = 1),
|
| 298 |
+
norm_layer_2,
|
| 299 |
+
nn.LeakyReLU(0.2, inplace=True)
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
def forward(self, x):
|
| 303 |
+
return self.feature_mapping(x)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class FeatureMapping1(nn.Module):
|
| 307 |
+
def __init__(self, in_channel, norm = 'instance'):
|
| 308 |
+
super().__init__()
|
| 309 |
+
if norm == 'bn':
|
| 310 |
+
norm_layer_1 = nn.BatchNorm3d(in_channel)
|
| 311 |
+
norm_layer_2 = nn.BatchNorm3d(in_channel)
|
| 312 |
+
elif norm == 'gn':
|
| 313 |
+
norm_layer_1 = nn.GroupNorm(8, in_channel)
|
| 314 |
+
norm_layer_2 = nn.GroupNorm(8, in_channel)
|
| 315 |
+
elif norm == 'instance':
|
| 316 |
+
norm_layer_1 = nn.InstanceNorm3d(in_channel)
|
| 317 |
+
norm_layer_2 = nn.InstanceNorm3d(in_channel)
|
| 318 |
+
self.feature_mapping1 = nn.Sequential(
|
| 319 |
+
nn.Conv3d(in_channel, in_channel, kernel_size = 3, padding = 1),
|
| 320 |
+
norm_layer_1,
|
| 321 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 322 |
+
nn.Conv3d(in_channel, in_channel, kernel_size = 3, padding = 1),
|
| 323 |
+
norm_layer_2,
|
| 324 |
+
nn.LeakyReLU(0.2, inplace=True)
|
| 325 |
+
)
|
| 326 |
+
def forward(self, x):
|
| 327 |
+
y = self.feature_mapping1(x)
|
| 328 |
+
return x + y #Resnet Like
|
| 329 |
+
|
| 330 |
+
################Transformer#######################
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def pair(t):
|
| 334 |
+
return t if isinstance(t, tuple) else (t, t)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
class PreNorm(nn.Module):
|
| 338 |
+
def __init__(self, dim, function):
|
| 339 |
+
super().__init__()
|
| 340 |
+
self.norm = nn.LayerNorm(dim)
|
| 341 |
+
self.function = function
|
| 342 |
+
|
| 343 |
+
def forward(self, x):
|
| 344 |
+
return self.function(self.norm(x))
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
class FeedForward(nn.Module):
|
| 348 |
+
def __init__(self, dim, hidden_dim, dropout = 0.0):
|
| 349 |
+
super().__init__()
|
| 350 |
+
self.net = nn.Sequential(
|
| 351 |
+
nn.Linear(dim, hidden_dim),
|
| 352 |
+
nn.GELU(),
|
| 353 |
+
nn.Dropout(dropout),
|
| 354 |
+
nn.Linear(hidden_dim, dim),
|
| 355 |
+
nn.Dropout(dropout)
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
def forward(self, x):
|
| 359 |
+
return self.net(x)
|
| 360 |
+
|
| 361 |
+
class Attention(nn.Module):
|
| 362 |
+
def __init__(self, dim, heads, dim_head, dropout = 0.0):
|
| 363 |
+
super().__init__()
|
| 364 |
+
all_head_size = heads * dim_head
|
| 365 |
+
project_out = not (heads == 1 and dim_head == dim)
|
| 366 |
+
|
| 367 |
+
self.heads = heads
|
| 368 |
+
self.scale = dim_head ** -0.5
|
| 369 |
+
|
| 370 |
+
self.softmax = nn.Softmax(dim = -1)
|
| 371 |
+
self.to_qkv = nn.Linear(dim, all_head_size * 3, bias = False)
|
| 372 |
+
|
| 373 |
+
self.to_out = nn.Sequential(
|
| 374 |
+
nn.Linear(all_head_size, dim),
|
| 375 |
+
nn.Dropout(dropout)
|
| 376 |
+
) if project_out else nn.Identity()
|
| 377 |
+
|
| 378 |
+
def forward(self, x):
|
| 379 |
+
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
| 380 |
+
#(batch, heads * dim_head) -> (batch, all_head_size)
|
| 381 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
| 382 |
+
|
| 383 |
+
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
| 384 |
+
|
| 385 |
+
atten = self.softmax(dots)
|
| 386 |
+
|
| 387 |
+
out = torch.matmul(atten, v)
|
| 388 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
| 389 |
+
return self.to_out(out)
|
| 390 |
+
|
| 391 |
+
class Transformer(nn.Module):
|
| 392 |
+
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.0):
|
| 393 |
+
super().__init__()
|
| 394 |
+
self.layers = nn.ModuleList([])
|
| 395 |
+
for _ in range(depth):
|
| 396 |
+
self.layers.append(nn.ModuleList([
|
| 397 |
+
PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
|
| 398 |
+
PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
|
| 399 |
+
]))
|
| 400 |
+
def forward(self, x):
|
| 401 |
+
for attention, feedforward in self.layers:
|
| 402 |
+
x = attention(x) + x
|
| 403 |
+
x = feedforward(x) + x
|
| 404 |
+
return x
|
| 405 |
+
|
| 406 |
+
class FixedPositionalEncoding(nn.Module):
|
| 407 |
+
def __init__(self, embedding_dim, max_length=768):
|
| 408 |
+
super(FixedPositionalEncoding, self).__init__()
|
| 409 |
+
|
| 410 |
+
pe = torch.zeros(max_length, embedding_dim)
|
| 411 |
+
position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
|
| 412 |
+
div_term = torch.exp(
|
| 413 |
+
torch.arange(0, embedding_dim, 2).float()
|
| 414 |
+
* (-torch.log(torch.tensor(10000.0)) / embedding_dim)
|
| 415 |
+
)
|
| 416 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 417 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 418 |
+
pe = pe.unsqueeze(0).transpose(0, 1)
|
| 419 |
+
self.register_buffer('pe', pe)
|
| 420 |
+
|
| 421 |
+
def forward(self, x):
|
| 422 |
+
x = x + self.pe[: x.size(0), :]
|
| 423 |
+
return x
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
class LearnedPositionalEncoding(nn.Module):
|
| 427 |
+
def __init__(self, embedding_dim, seq_length):
|
| 428 |
+
super(LearnedPositionalEncoding, self).__init__()
|
| 429 |
+
self.seq_length = seq_length
|
| 430 |
+
self.position_embeddings = nn.Parameter(torch.zeros(1, seq_length, embedding_dim)) #8x
|
| 431 |
+
|
| 432 |
+
def forward(self, x, position_ids=None):
|
| 433 |
+
position_embeddings = self.position_embeddings
|
| 434 |
+
# print(x.shape, self.position_embeddings.shape)
|
| 435 |
+
return x + position_embeddings
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
###############Main model#################
|
| 442 |
+
|
| 443 |
+
class SegTransVAE(nn.Module):
|
| 444 |
+
def __init__(self, img_dim, patch_dim, num_channels, num_classes,
|
| 445 |
+
embedding_dim, num_heads, num_layers, hidden_dim, in_channels_vae,
|
| 446 |
+
dropout = 0.0, attention_dropout = 0.0,
|
| 447 |
+
conv_patch_representation = True, positional_encoding = 'learned',
|
| 448 |
+
use_VAE = False):
|
| 449 |
+
|
| 450 |
+
super().__init__()
|
| 451 |
+
assert embedding_dim % num_heads == 0
|
| 452 |
+
assert img_dim[0] % patch_dim == 0 and img_dim[1] % patch_dim == 0 and img_dim[2] % patch_dim == 0
|
| 453 |
+
|
| 454 |
+
self.img_dim = img_dim
|
| 455 |
+
self.embedding_dim = embedding_dim
|
| 456 |
+
self.num_heads = num_heads
|
| 457 |
+
self.num_classes = num_classes
|
| 458 |
+
self.patch_dim = patch_dim
|
| 459 |
+
self.num_channels = num_channels
|
| 460 |
+
self.in_channels_vae = in_channels_vae
|
| 461 |
+
self.dropout = dropout
|
| 462 |
+
self.attention_dropout = attention_dropout
|
| 463 |
+
self.conv_patch_representation = conv_patch_representation
|
| 464 |
+
self.use_VAE = use_VAE
|
| 465 |
+
|
| 466 |
+
self.num_patches = int((img_dim[0] // patch_dim) * (img_dim[1] // patch_dim) * (img_dim[2] // patch_dim))
|
| 467 |
+
self.seq_length = self.num_patches
|
| 468 |
+
self.flatten_dim = 128 * num_channels
|
| 469 |
+
|
| 470 |
+
self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim)
|
| 471 |
+
if positional_encoding == "learned":
|
| 472 |
+
self.position_encoding = LearnedPositionalEncoding(
|
| 473 |
+
self.embedding_dim, self.seq_length
|
| 474 |
+
)
|
| 475 |
+
elif positional_encoding == "fixed":
|
| 476 |
+
self.position_encoding = FixedPositionalEncoding(
|
| 477 |
+
self.embedding_dim,
|
| 478 |
+
)
|
| 479 |
+
self.pe_dropout = nn.Dropout(self.dropout)
|
| 480 |
+
|
| 481 |
+
self.transformer = Transformer(
|
| 482 |
+
embedding_dim, num_layers, num_heads, embedding_dim // num_heads, hidden_dim, dropout
|
| 483 |
+
)
|
| 484 |
+
self.pre_head_ln = nn.LayerNorm(embedding_dim)
|
| 485 |
+
|
| 486 |
+
if self.conv_patch_representation:
|
| 487 |
+
self.conv_x = nn.Conv3d(128, self.embedding_dim, kernel_size=3, stride=1, padding=1)
|
| 488 |
+
self.encoder = Encoder(self.num_channels, 16)
|
| 489 |
+
self.bn = nn.InstanceNorm3d(128)
|
| 490 |
+
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
| 491 |
+
self.FeatureMapping = FeatureMapping(in_channel = self.embedding_dim, out_channel= self.in_channels_vae)
|
| 492 |
+
self.FeatureMapping1 = FeatureMapping1(in_channel = self.in_channels_vae)
|
| 493 |
+
self.decoder = Decoder(self.img_dim, self.patch_dim, self.embedding_dim, num_classes)
|
| 494 |
+
|
| 495 |
+
self.vae_input = (1, self.in_channels_vae, img_dim[0] // 8, img_dim[1] // 8, img_dim[2] // 8)
|
| 496 |
+
if use_VAE:
|
| 497 |
+
self.vae = VAE(input_shape = self.vae_input , latent_dim= 256, num_channels= self.num_channels)
|
| 498 |
+
def encode(self, x):
|
| 499 |
+
if self.conv_patch_representation:
|
| 500 |
+
x1, x2, x3, x = self.encoder(x)
|
| 501 |
+
x = self.bn(x)
|
| 502 |
+
x = self.relu(x)
|
| 503 |
+
x = self.conv_x(x)
|
| 504 |
+
x = x.permute(0, 2, 3, 4, 1).contiguous()
|
| 505 |
+
x = x.view(x.size(0), -1, self.embedding_dim)
|
| 506 |
+
x = self.position_encoding(x)
|
| 507 |
+
x = self.pe_dropout(x)
|
| 508 |
+
x = self.transformer(x)
|
| 509 |
+
x = self.pre_head_ln(x)
|
| 510 |
+
|
| 511 |
+
return x1, x2, x3, x
|
| 512 |
+
|
| 513 |
+
def decode(self, x1, x2, x3, x):
|
| 514 |
+
#x: (1, 4096, 512) -> (1, 16, 16, 16, 512)
|
| 515 |
+
# print("In decode...")
|
| 516 |
+
# print(" x1: {} \n x2: {} \n x3: {} \n x: {}".format( x1.shape, x2.shape, x3.shape, x.shape))
|
| 517 |
+
# break
|
| 518 |
+
return self.decoder(x1, x2, x3, x)
|
| 519 |
+
|
| 520 |
+
def forward(self, x, is_validation = True):
|
| 521 |
+
x1, x2, x3, x = self.encode(x)
|
| 522 |
+
x = x.view( x.size(0),
|
| 523 |
+
self.img_dim[0] // self.patch_dim,
|
| 524 |
+
self.img_dim[1] // self.patch_dim,
|
| 525 |
+
self.img_dim[2] // self.patch_dim,
|
| 526 |
+
self.embedding_dim)
|
| 527 |
+
x = x.permute(0, 4, 1, 2, 3).contiguous()
|
| 528 |
+
x = self.FeatureMapping(x)
|
| 529 |
+
x = self.FeatureMapping1(x)
|
| 530 |
+
if self.use_VAE and not is_validation:
|
| 531 |
+
vae_out, mu, sigma = self.vae(x)
|
| 532 |
+
y = self.decode(x1, x2, x3, x)
|
| 533 |
+
if self.use_VAE and not is_validation:
|
| 534 |
+
return y, vae_out, mu, sigma
|
| 535 |
+
else:
|
| 536 |
+
return y
|
| 537 |
+
|
| 538 |
+
|
3DUnet_Like/models/SegTranVAE/__init__.py
ADDED
|
File without changes
|
3DUnet_Like/models/SegTranVAE/__pycache__/SegTranVAE.cpython-39.pyc
ADDED
|
Binary file (16.3 kB). View file
|
|
|
3DUnet_Like/models/SegTranVAE/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (137 Bytes). View file
|
|
|
3DUnet_Like/train.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
from monai.utils import set_determinism
|
| 4 |
+
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
| 5 |
+
import os
|
| 6 |
+
from pytorch_lightning.loggers import TensorBoardLogger
|
| 7 |
+
from trainer import BRATS
|
| 8 |
+
from dataset.utils import get_loader
|
| 9 |
+
import pytorch_lightning as pl
|
| 10 |
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
| 11 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 12 |
+
|
| 13 |
+
set_determinism(seed=0)
|
| 14 |
+
|
| 15 |
+
os.system('cls||clear')
|
| 16 |
+
print("Training ...")
|
| 17 |
+
|
| 18 |
+
data_dir = "/app/brats_2021_task1"
|
| 19 |
+
json_list = "/app/info.json"
|
| 20 |
+
roi = (128, 128, 128)
|
| 21 |
+
batch_size = 1
|
| 22 |
+
fold = 1
|
| 23 |
+
max_epochs = 500
|
| 24 |
+
val_every = 10
|
| 25 |
+
train_loader, val_loader,test_loader = get_loader(batch_size, data_dir, json_list, fold, roi, volume=1, test_size=0.2)
|
| 26 |
+
print("Done initialize dataloader !! ")
|
| 27 |
+
|
| 28 |
+
model = BRATS(use_VAE = True, train_loader = train_loader,val_loader = val_loader, test_loader=test_loader )
|
| 29 |
+
checkpoint_callback = ModelCheckpoint(
|
| 30 |
+
monitor='val/MeanDiceScore',
|
| 31 |
+
dirpath='./checkpoints/{}'.format("SegTransVAE"),
|
| 32 |
+
filename='Epoch{epoch:3d}-MeanDiceScore{val/MeanDiceScore:.4f}',
|
| 33 |
+
save_top_k=3,
|
| 34 |
+
mode='max',
|
| 35 |
+
save_last= True,
|
| 36 |
+
auto_insert_metric_name=False
|
| 37 |
+
)
|
| 38 |
+
early_stop_callback = EarlyStopping(
|
| 39 |
+
monitor='val/MeanDiceScore',
|
| 40 |
+
min_delta=0.0001,
|
| 41 |
+
patience=15,
|
| 42 |
+
verbose=False,
|
| 43 |
+
mode='max'
|
| 44 |
+
)
|
| 45 |
+
tensorboardlogger = TensorBoardLogger(
|
| 46 |
+
'logs',
|
| 47 |
+
name = "SegTransVAE",
|
| 48 |
+
default_hp_metric = None
|
| 49 |
+
)
|
| 50 |
+
trainer = pl.Trainer(#fast_dev_run = 10,
|
| 51 |
+
# accelerator='ddp',
|
| 52 |
+
#overfit_batches=5,
|
| 53 |
+
devices = [0],
|
| 54 |
+
precision=16,
|
| 55 |
+
max_epochs = max_epochs,
|
| 56 |
+
enable_progress_bar=True,
|
| 57 |
+
callbacks=[checkpoint_callback, early_stop_callback],
|
| 58 |
+
# auto_lr_find=True,
|
| 59 |
+
num_sanity_val_steps=1,
|
| 60 |
+
logger = tensorboardlogger,
|
| 61 |
+
check_val_every_n_epoch = 10,
|
| 62 |
+
# limit_train_batches=0.01,
|
| 63 |
+
# limit_val_batches=0.01
|
| 64 |
+
)
|
| 65 |
+
# trainer.tune(model)
|
| 66 |
+
trainer.fit(model)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
3DUnet_Like/trainer.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pytorch_lightning as pl
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import csv
|
| 5 |
+
import torch
|
| 6 |
+
from monai.transforms import AsDiscrete, Activations, Compose, EnsureType
|
| 7 |
+
from models.SegTranVAE.SegTranVAE import SegTransVAE
|
| 8 |
+
from loss.loss import Loss_VAE, DiceScore
|
| 9 |
+
from monai.losses import DiceLoss
|
| 10 |
+
import pytorch_lightning as pl
|
| 11 |
+
from monai.inferers import sliding_window_inference
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class BRATS(pl.LightningModule):
|
| 18 |
+
def __init__(self,train_loader,val_loader,test_loader, use_VAE = True, lr = 1e-4 ):
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.train_loader = train_loader
|
| 21 |
+
self.val_loader = val_loader
|
| 22 |
+
self.test_loader = test_loader
|
| 23 |
+
self.use_vae = use_VAE
|
| 24 |
+
self.lr = lr
|
| 25 |
+
self.model = SegTransVAE((128, 128, 128), 8, 4, 3, 768, 8, 4, 3072, in_channels_vae=128, use_VAE = use_VAE)
|
| 26 |
+
|
| 27 |
+
self.loss_vae = Loss_VAE()
|
| 28 |
+
self.dice_loss = DiceLoss(to_onehot_y=False, sigmoid=True, squared_pred=True)
|
| 29 |
+
self.post_trans_images = Compose(
|
| 30 |
+
[EnsureType(),
|
| 31 |
+
Activations(sigmoid=True),
|
| 32 |
+
AsDiscrete(threshold_values=True),
|
| 33 |
+
]
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
self.best_val_dice = 0
|
| 37 |
+
|
| 38 |
+
self.training_step_outputs = []
|
| 39 |
+
self.val_step_loss = []
|
| 40 |
+
self.val_step_dice = []
|
| 41 |
+
self.val_step_dice_tc = []
|
| 42 |
+
self.val_step_dice_wt = []
|
| 43 |
+
self.val_step_dice_et = []
|
| 44 |
+
self.test_step_loss = []
|
| 45 |
+
self.test_step_dice = []
|
| 46 |
+
self.test_step_dice_tc = []
|
| 47 |
+
self.test_step_dice_wt = []
|
| 48 |
+
self.test_step_dice_et = []
|
| 49 |
+
|
| 50 |
+
def forward(self, x, is_validation = True):
|
| 51 |
+
return self.model(x, is_validation)
|
| 52 |
+
def training_step(self, batch, batch_index):
|
| 53 |
+
inputs, labels = (batch['image'], batch['label'])
|
| 54 |
+
|
| 55 |
+
if not self.use_vae:
|
| 56 |
+
outputs = self.forward(inputs, is_validation=False)
|
| 57 |
+
loss = self.dice_loss(outputs, labels)
|
| 58 |
+
else:
|
| 59 |
+
outputs, recon_batch, mu, sigma = self.forward(inputs, is_validation=False)
|
| 60 |
+
|
| 61 |
+
vae_loss = self.loss_vae(recon_batch, inputs, mu, sigma)
|
| 62 |
+
dice_loss = self.dice_loss(outputs, labels)
|
| 63 |
+
loss = dice_loss + 1/(4 * 128 * 128 * 128) * vae_loss
|
| 64 |
+
self.training_step_outputs.append(loss)
|
| 65 |
+
self.log('train/vae_loss', vae_loss)
|
| 66 |
+
self.log('train/dice_loss', dice_loss)
|
| 67 |
+
if batch_index == 10:
|
| 68 |
+
|
| 69 |
+
tensorboard = self.logger.experiment
|
| 70 |
+
fig, ax = plt.subplots(nrows=1, ncols=6, figsize=(10, 5))
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
ax[0].imshow(inputs.detach().cpu()[0][0][:, :, 80], cmap='gray')
|
| 74 |
+
ax[0].set_title("Input")
|
| 75 |
+
|
| 76 |
+
ax[1].imshow(recon_batch.detach().cpu().float()[0][0][:,:, 80], cmap='gray')
|
| 77 |
+
ax[1].set_title("Reconstruction")
|
| 78 |
+
|
| 79 |
+
ax[2].imshow(labels.detach().cpu().float()[0][0][:,:, 80], cmap='gray')
|
| 80 |
+
ax[2].set_title("Labels TC")
|
| 81 |
+
|
| 82 |
+
ax[3].imshow(outputs.sigmoid().detach().cpu().float()[0][0][:,:, 80], cmap='gray')
|
| 83 |
+
ax[3].set_title("TC")
|
| 84 |
+
|
| 85 |
+
ax[4].imshow(labels.detach().cpu().float()[0][2][:,:, 80], cmap='gray')
|
| 86 |
+
ax[4].set_title("Labels ET")
|
| 87 |
+
|
| 88 |
+
ax[5].imshow(outputs.sigmoid().detach().cpu().float()[0][2][:,:, 80], cmap='gray')
|
| 89 |
+
ax[5].set_title("ET")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
tensorboard.add_figure('train_visualize', fig, self.current_epoch)
|
| 93 |
+
|
| 94 |
+
self.log('train/loss', loss)
|
| 95 |
+
|
| 96 |
+
return loss
|
| 97 |
+
|
| 98 |
+
def on_train_epoch_end(self):
|
| 99 |
+
## F1 Macro all epoch saving outputs and target per batch
|
| 100 |
+
|
| 101 |
+
# free up the memory
|
| 102 |
+
# --> HERE STEP 3 <--
|
| 103 |
+
epoch_average = torch.stack(self.training_step_outputs).mean()
|
| 104 |
+
self.log("training_epoch_average", epoch_average)
|
| 105 |
+
self.training_step_outputs.clear() # free memory
|
| 106 |
+
|
| 107 |
+
def validation_step(self, batch, batch_index):
|
| 108 |
+
inputs, labels = (batch['image'], batch['label'])
|
| 109 |
+
roi_size = (128, 128, 128)
|
| 110 |
+
sw_batch_size = 1
|
| 111 |
+
outputs = sliding_window_inference(
|
| 112 |
+
inputs, roi_size, sw_batch_size, self.model, overlap = 0.5)
|
| 113 |
+
loss = self.dice_loss(outputs, labels)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
val_outputs = self.post_trans_images(outputs)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
metric_tc = DiceScore(y_pred=val_outputs[:, 0:1], y=labels[:, 0:1], include_background = True)
|
| 120 |
+
metric_wt = DiceScore(y_pred=val_outputs[:, 1:2], y=labels[:, 1:2], include_background = True)
|
| 121 |
+
metric_et = DiceScore(y_pred=val_outputs[:, 2:3], y=labels[:, 2:3], include_background = True)
|
| 122 |
+
mean_val_dice = (metric_tc + metric_wt + metric_et)/3
|
| 123 |
+
self.val_step_loss.append(loss)
|
| 124 |
+
self.val_step_dice.append(mean_val_dice)
|
| 125 |
+
self.val_step_dice_tc.append(metric_tc)
|
| 126 |
+
self.val_step_dice_wt.append(metric_wt)
|
| 127 |
+
self.val_step_dice_et.append(metric_et)
|
| 128 |
+
return {'val_loss': loss, 'val_mean_dice': mean_val_dice, 'val_dice_tc': metric_tc,
|
| 129 |
+
'val_dice_wt': metric_wt, 'val_dice_et': metric_et}
|
| 130 |
+
|
| 131 |
+
def on_validation_epoch_end(self):
|
| 132 |
+
|
| 133 |
+
loss = torch.stack(self.val_step_loss).mean()
|
| 134 |
+
mean_val_dice = torch.stack(self.val_step_dice).mean()
|
| 135 |
+
metric_tc = torch.stack(self.val_step_dice_tc).mean()
|
| 136 |
+
metric_wt = torch.stack(self.val_step_dice_wt).mean()
|
| 137 |
+
metric_et = torch.stack(self.val_step_dice_et).mean()
|
| 138 |
+
self.log('val/Loss', loss)
|
| 139 |
+
self.log('val/MeanDiceScore', mean_val_dice)
|
| 140 |
+
self.log('val/DiceTC', metric_tc)
|
| 141 |
+
self.log('val/DiceWT', metric_wt)
|
| 142 |
+
self.log('val/DiceET', metric_et)
|
| 143 |
+
os.makedirs(self.logger.log_dir, exist_ok=True)
|
| 144 |
+
if self.current_epoch == 0:
|
| 145 |
+
with open('{}/metric_log.csv'.format(self.logger.log_dir), 'w') as f:
|
| 146 |
+
writer = csv.writer(f)
|
| 147 |
+
writer.writerow(['Epoch', 'Mean Dice Score', 'Dice TC', 'Dice WT', 'Dice ET'])
|
| 148 |
+
with open('{}/metric_log.csv'.format(self.logger.log_dir), 'a') as f:
|
| 149 |
+
writer = csv.writer(f)
|
| 150 |
+
writer.writerow([self.current_epoch, mean_val_dice.item(), metric_tc.item(), metric_wt.item(), metric_et.item()])
|
| 151 |
+
|
| 152 |
+
if mean_val_dice > self.best_val_dice:
|
| 153 |
+
self.best_val_dice = mean_val_dice
|
| 154 |
+
self.best_val_epoch = self.current_epoch
|
| 155 |
+
print(
|
| 156 |
+
f"\n Current epoch: {self.current_epoch} Current mean dice: {mean_val_dice:.4f}"
|
| 157 |
+
f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
|
| 158 |
+
f"\n Best mean dice: {self.best_val_dice}"
|
| 159 |
+
f" at epoch: {self.best_val_epoch}"
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
self.val_step_loss.clear()
|
| 163 |
+
self.val_step_dice.clear()
|
| 164 |
+
self.val_step_dice_tc.clear()
|
| 165 |
+
self.val_step_dice_wt.clear()
|
| 166 |
+
self.val_step_dice_et.clear()
|
| 167 |
+
return {'val_MeanDiceScore': mean_val_dice}
|
| 168 |
+
def test_step(self, batch, batch_index):
|
| 169 |
+
inputs, labels = (batch['image'], batch['label'])
|
| 170 |
+
|
| 171 |
+
roi_size = (128, 128, 128)
|
| 172 |
+
sw_batch_size = 1
|
| 173 |
+
test_outputs = sliding_window_inference(
|
| 174 |
+
inputs, roi_size, sw_batch_size, self.forward, overlap = 0.5)
|
| 175 |
+
loss = self.dice_loss(test_outputs, labels)
|
| 176 |
+
test_outputs = self.post_trans_images(test_outputs)
|
| 177 |
+
metric_tc = DiceScore(y_pred=test_outputs[:, 0:1], y=labels[:, 0:1], include_background = True)
|
| 178 |
+
metric_wt = DiceScore(y_pred=test_outputs[:, 1:2], y=labels[:, 1:2], include_background = True)
|
| 179 |
+
metric_et = DiceScore(y_pred=test_outputs[:, 2:3], y=labels[:, 2:3], include_background = True)
|
| 180 |
+
mean_test_dice = (metric_tc + metric_wt + metric_et)/3
|
| 181 |
+
|
| 182 |
+
self.test_step_loss.append(loss)
|
| 183 |
+
self.test_step_dice.append(mean_test_dice)
|
| 184 |
+
self.test_step_dice_tc.append(metric_tc)
|
| 185 |
+
self.test_step_dice_wt.append(metric_wt)
|
| 186 |
+
self.test_step_dice_et.append(metric_et)
|
| 187 |
+
|
| 188 |
+
return {'test_loss': loss, 'test_mean_dice': mean_test_dice, 'test_dice_tc': metric_tc,
|
| 189 |
+
'test_dice_wt': metric_wt, 'test_dice_et': metric_et}
|
| 190 |
+
|
| 191 |
+
def test_epoch_end(self):
|
| 192 |
+
loss = torch.stack(self.test_step_loss).mean()
|
| 193 |
+
mean_test_dice = torch.stack(self.test_step_dice).mean()
|
| 194 |
+
metric_tc = torch.stack(self.test_step_dice_tc).mean()
|
| 195 |
+
metric_wt = torch.stack(self.test_step_dice_wt).mean()
|
| 196 |
+
metric_et = torch.stack(self.test_step_dice_et).mean()
|
| 197 |
+
self.log('test/Loss', loss)
|
| 198 |
+
self.log('test/MeanDiceScore', mean_test_dice)
|
| 199 |
+
self.log('test/DiceTC', metric_tc)
|
| 200 |
+
self.log('test/DiceWT', metric_wt)
|
| 201 |
+
self.log('test/DiceET', metric_et)
|
| 202 |
+
|
| 203 |
+
with open('{}/test_log.csv'.format(self.logger.log_dir), 'w') as f:
|
| 204 |
+
writer = csv.writer(f)
|
| 205 |
+
writer.writerow(["Mean Test Dice", "Dice TC", "Dice WT", "Dice ET"])
|
| 206 |
+
writer.writerow([mean_test_dice, metric_tc, metric_wt, metric_et])
|
| 207 |
+
|
| 208 |
+
self.test_step_loss.clear()
|
| 209 |
+
self.test_step_dice.clear()
|
| 210 |
+
self.test_step_dice_tc.clear()
|
| 211 |
+
self.test_step_dice_wt.clear()
|
| 212 |
+
self.test_step_dice_et.clear()
|
| 213 |
+
return {'test_MeanDiceScore': mean_test_dice}
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def configure_optimizers(self):
|
| 217 |
+
optimizer = torch.optim.Adam(
|
| 218 |
+
self.model.parameters(), self.lr, weight_decay=1e-5, amsgrad=True
|
| 219 |
+
)
|
| 220 |
+
# optimizer = AdaBelief(self.model.parameters(),
|
| 221 |
+
# lr=self.lr, eps=1e-16,
|
| 222 |
+
# betas=(0.9,0.999), weight_decouple = True,
|
| 223 |
+
# rectify = False)
|
| 224 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 200)
|
| 225 |
+
return [optimizer], [scheduler]
|
| 226 |
+
|
| 227 |
+
def train_dataloader(self):
|
| 228 |
+
return self.train_loader
|
| 229 |
+
def val_dataloader(self):
|
| 230 |
+
return self.val_loader
|
| 231 |
+
|
| 232 |
+
def test_dataloader(self):
|
| 233 |
+
return self.test_loader
|
brats_2021_task1/BraTS2021_Training_Data/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_flair.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bb899a83627591e55cada00b2c6d5402199832b717c8b9f90bb550fe35d971ff
|
| 3 |
+
size 2532638
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_seg.nii.gz
ADDED
|
Binary file (57.9 kB). View file
|
|
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_t1.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e92b72ee221624c36cf89ac826deceeee7097f46dd66a5d218d18b7916ebd67d
|
| 3 |
+
size 2332393
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_t1ce.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:735c27fd7a17b1702875837bcc843eedc88ced6ac2cb0e73cdf995e3e64ba82f
|
| 3 |
+
size 2643179
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00000/BraTS2021_00000_t2.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:955ff59d053e87153bb7c809235743ec904817727ec02c630f3141e191d6f452
|
| 3 |
+
size 2432699
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_flair.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:208e4e8fadbdf2b1c1a87f8bad8854bc7d4becd604bc01eefad725d19b43c6ef
|
| 3 |
+
size 2331912
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_seg.nii.gz
ADDED
|
Binary file (78 kB). View file
|
|
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_t1.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4b76c5e6326f0f89e2f6ee243473060390982fc83e9bceb27a4b94899b2b0df1
|
| 3 |
+
size 2170543
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_t1ce.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4f49aa266cd907ce890bcc9c96f82534810f85bb98a76f8a092a5b529d3b6b6e
|
| 3 |
+
size 2486326
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00002/BraTS2021_00002_t2.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b5c9da53d4a573fe37fa8a8075b8220d4dfe3bbc377c3e1943fbd3ef86d3b118
|
| 3 |
+
size 2303833
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_flair.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8f68b6d8026f2c6ac097ea26a4dd727571025e26b69b485f9a1a84e244222721
|
| 3 |
+
size 2719582
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_seg.nii.gz
ADDED
|
Binary file (63.4 kB). View file
|
|
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_t1.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:90997c21753a6d49bdb8629f5b2c6ddb61ef7c2f86f977b18001ce6c5e3161ed
|
| 3 |
+
size 2488450
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_t1ce.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a211cd8c26dd8d396341399d0a13fb888d5df6252bd9c5aa01156680e97c5577
|
| 3 |
+
size 2834759
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00003/BraTS2021_00003_t2.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5c440a9ef1ef3b494d74e647fab9e41eb48f1fc40c8eabffd55863f46779ed5d
|
| 3 |
+
size 2635293
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_flair.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c5c3ed117922199afb26cae01b090400d6ca90f06825c12f5a5a3ae7ced098e3
|
| 3 |
+
size 2265964
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_seg.nii.gz
ADDED
|
Binary file (70.7 kB). View file
|
|
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_t1.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:54f109386253cb60694dec35e10e1e8f2ee02a2671faa860ac212cce8997b434
|
| 3 |
+
size 2085481
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_t1ce.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2e074611e7388f7f7fefb5bbe39abb09f16f0dbeb0a35fbd661a0bbcd14d4b5a
|
| 3 |
+
size 2323871
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00005/BraTS2021_00005_t2.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8d1689213fb1bb46330f16d30489a9875e2b5e9f30c8c5659a6b689011bd69e9
|
| 3 |
+
size 2144371
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_flair.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0f42d02ba6637b8d98c0c44869ccfa444935dedb8f5ada29a3db65d33c878c0b
|
| 3 |
+
size 2588071
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_seg.nii.gz
ADDED
|
Binary file (70.5 kB). View file
|
|
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_t1.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5f9e9ff16cdffd375f31a969994e31725c329950e7c3e64789238afc3faadead
|
| 3 |
+
size 2386395
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_t1ce.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:20847fc162fe9193fb2d5a08e7b9009d16603eae268f769bf6d7bbffb0d79c42
|
| 3 |
+
size 2705917
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00006/BraTS2021_00006_t2.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0a18fb6644ca828ef43067264a631d297d5aeaad4ebae05bce4ea09c9c76f898
|
| 3 |
+
size 2479204
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_flair.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c5dbec780a624127d829b4b7165ff6afe6258e9224a529abf8ddaf25474888c5
|
| 3 |
+
size 2452120
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_seg.nii.gz
ADDED
|
Binary file (45.5 kB). View file
|
|
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_t1.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f46e84f3f81a10b11b994fca26ae7f2bf537a6dfc0cb60853c097243924e39e8
|
| 3 |
+
size 2397406
|
brats_2021_task1/BraTS2021_Training_Data/BraTS2021_00008/BraTS2021_00008_t1ce.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3b92f907a585a13ae99b0c79f315d396a3dd4524f1a35001212509957e264abf
|
| 3 |
+
size 2639147
|