def test_inverse_inferred_seg(self):
    """Invert the transforms recorded for a segmentation produced by a network.

    Twenty synthetic 2D image/label pairs are padded and centre-cropped, the
    labels of the first batch are passed through a small ``UNet``, and the
    transform trace stored next to the label is re-attached to the prediction.
    Inverting one decollated prediction must give back the spatial shape of
    the original label, and the trace must contain one entry per invertible
    transform.
    """
    test_data = []
    for _ in range(20):
        img, seg = create_test_image_2d(100, 101)
        test_data.append({"image": img, "label": seg.astype(np.float32)})

    batch_size = 10
    # no worker processes on mac
    num_workers = 0 if sys.platform == "darwin" else 2

    transforms = Compose(
        [
            AddChanneld(KEYS),
            SpatialPadd(KEYS, (150, 153)),
            CenterSpatialCropd(KEYS, (110, 99)),
        ]
    )
    num_invertible_transforms = len(
        [t for t in transforms.transforms if isinstance(t, InvertibleTransform)]
    )

    dataset = CacheDataset(test_data, transform=transforms, progress=False)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(2, 4),
        strides=(2,),
    )
    model = net.to(device)

    data = first(loader)
    labels = data["label"].to(device)
    segs = model(labels).detach().cpu()

    # pair the network output with the transform trace recorded for the label
    label_transform_key = "label" + InverseKeys.KEY_SUFFIX.value
    segs_dict = {"label": segs, label_transform_key: data[label_transform_key]}
    segs_dict_decollated = decollate_batch(segs_dict)

    # inverse of individual segmentation
    seg_dict = first(segs_dict_decollated)
    with allow_missing_keys_mode(transforms):
        inverted = transforms.inverse(seg_dict)
        inv_seg = inverted["label"]

    self.assertEqual(len(data["label_transforms"]), num_invertible_transforms)
    self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms)
    self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)
def test_inverse_inferred_seg(self, extra_transform):
    """Invert a network prediction that carries its transform history.

    The pipeline is ``AddChanneld`` -> ``SpatialPadd`` -> ``extra_transform``.
    The labels of the first batch go through a small ``UNet``; the test checks
    that batch and decollated items are ``MetaTensor`` instances, undoes the
    last recorded operation by hand, inverts the rest through the ``Compose``
    and compares the resulting spatial shape with the original label.

    Args:
        extra_transform: transform appended at the end of the pipeline.
    """
    test_data = []
    for _ in range(20):
        img, seg = create_test_image_2d(100, 101)
        test_data.append({"image": img, "label": seg.astype(np.float32)})

    batch_size = 10
    # worker processes are only used on linux
    num_workers = 2 if sys.platform == "linux" else 0

    transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), extra_transform])

    dataset = CacheDataset(test_data, transform=transforms, progress=False)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(1,))
    model = net.to(device)

    data = first(loader)
    self.assertEqual(data["image"].shape[0], batch_size * NUM_SAMPLES)

    labels = data["label"].to(device)
    self.assertIsInstance(labels, MetaTensor)

    segs = model(labels).detach().cpu()
    segs_decollated = decollate_batch(segs)
    self.assertIsInstance(segs_decollated[0], MetaTensor)

    # inverse of individual segmentation
    seg_metatensor = first(segs_decollated)
    # test to convert interpolation mode for 1 data of model output batch
    convert_applied_interp_mode(seg_metatensor.applied_operations, mode="nearest", align_corners=None)

    # manually invert the last crop samples: drop its record and resize back
    # to the size stored under "orig_size", without tracing this extra resize
    last_op = seg_metatensor.applied_operations.pop(-1)
    shape_before_extra_xform = last_op["orig_size"]
    resizer = ResizeWithPadOrCrop(spatial_size=shape_before_extra_xform)
    with resizer.trace_transform(False):
        seg_metatensor = resizer(seg_metatensor)

    with allow_missing_keys_mode(transforms):
        inverted = transforms.inverse({"label": seg_metatensor})
        inv_seg = inverted["label"]
    self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)
def test_inverse_inferred_seg(self, extra_transform):
    """Invert a network prediction both per item and per batch.

    The pipeline is ``AddChanneld`` -> ``SpatialPadd`` -> ``extra_transform``.
    The labels of the first batch go through a small ``UNet``; the prediction
    is paired with the transform trace stored for the label, then inverted
    once for a single decollated item and once for the whole batch through
    ``BatchInverseTransform``. Both results must have the spatial shape of the
    original label.

    Args:
        extra_transform: transform appended at the end of the pipeline.
    """
    test_data = []
    for _ in range(20):
        img, seg = create_test_image_2d(100, 101)
        test_data.append({"image": img, "label": seg.astype(np.float32)})

    batch_size = 10
    # worker processes are only used on linux
    num_workers = 2 if sys.platform == "linux" else 0

    transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), extra_transform])
    num_invertible_transforms = len(
        [t for t in transforms.transforms if isinstance(t, InvertibleTransform)]
    )

    dataset = CacheDataset(test_data, transform=transforms, progress=False)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(2,))
    model = net.to(device)

    data = first(loader)
    self.assertEqual(len(data["label_transforms"]), num_invertible_transforms)
    self.assertEqual(data["image"].shape[0], batch_size * NUM_SAMPLES)

    labels = data["label"].to(device)
    segs = model(labels).detach().cpu()

    # pair the network output with the transform trace recorded for the label
    label_transform_key = "label" + InverseKeys.KEY_SUFFIX
    segs_dict = {"label": segs, label_transform_key: data[label_transform_key]}
    segs_dict_decollated = decollate_batch(segs_dict)

    # inverse of individual segmentation
    seg_dict = first(segs_dict_decollated)
    # test to convert interpolation mode for 1 data of model output batch
    convert_inverse_interp_mode(seg_dict, mode="nearest", align_corners=None)

    with allow_missing_keys_mode(transforms):
        inverted = transforms.inverse(seg_dict)
        inv_seg = inverted["label"]
    self.assertEqual(len(data["label_transforms"]), num_invertible_transforms)
    self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms)
    self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)

    # Inverse of batch
    batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation, detach=True)
    with allow_missing_keys_mode(transforms):
        inv_batch = batch_inverter(segs_dict)
    first_inverted = inv_batch[0]
    self.assertEqual(first_inverted["label"].shape[1:], test_data[0]["label"].shape)
def first_key(self, data: Dict[Hashable, Any]):
    """
    Return the first key produced by `self.key_iterator(data)`.

    Args:
        data: data that the transform will be applied to.

    Returns:
        the first key yielded by the key iterator for `data`, or an empty
        list `[]` when the iterator yields nothing.

    """
    available_keys = self.key_iterator(data)
    return first(available_keys, [])
def test_with_dataloader(self, file_path, level, expected_spatial_shape, expected_shape):
    """Load one whole-slide image through a ``DataLoader`` and check its shapes.

    Args:
        file_path: path of the image to load.
        level: value forwarded to ``LoadImaged`` as ``level``.
        expected_spatial_shape: value every entry of the ``spatial_shape``
            metadata must be close to.
        expected_shape: expected shape of the batched ``image`` entry.
    """
    load_image = LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level)
    to_tensor = ToTensord(keys=["image"])
    train_transform = Compose([load_image, to_tensor])

    dataset = Dataset([{"image": file_path}], transform=train_transform)
    data_loader = DataLoader(dataset)
    data: dict = first(data_loader)

    image_meta = data[PostFix.meta("image")]
    for spatial_shape in image_meta["spatial_shape"]:
        torch.testing.assert_allclose(spatial_shape, expected_spatial_shape)
    self.assertTupleEqual(data["image"].shape, expected_shape)
def __call__(self, data: Dict[str, Any]) -> Any:
    """
    Decollate `data`, wrap the items in a `_BatchInverseDataset` and return
    the first batch produced by a `DataLoader` over that dataset.

    Args:
        data: batch dictionary to be decollated and inverted.

    Raises:
        RuntimeError: re-raised from the loader; when the message contains
            "equal size" a hint about `collate_fn` is appended.

    """
    items = decollate_batch(data, detach=self.detach, pad=self.pad_batch, fill_value=self.fill_value)
    inverse_dataset = _BatchInverseDataset(items, self.transform, self.pad_collation_used)
    inverse_loader = DataLoader(
        inverse_dataset,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
        collate_fn=self.collate_fn,
    )
    try:
        return first(inverse_loader)
    except RuntimeError as err:
        message = str(err)
        if "equal size" in message:
            message = message + "\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`."
        raise RuntimeError(message) from err
def test_with_dataloader_batch(self, file_path, level, expected_spatial_shape, expected_shape):
    """Load the same whole-slide image twice as one batch and check its shapes.

    Args:
        file_path: path of the image to load (used for both dataset items).
        level: value forwarded to ``LoadImaged`` as ``level``.
        expected_spatial_shape: value every entry of the ``spatial_shape``
            metadata must be close to.
        expected_shape: expected shape of a single loaded item; its first
            dimension is replaced by the batch size for the comparison.
    """
    pipeline = [
        LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level),
        FromMetaTensord(keys=["image"]),
        ToTensord(keys=["image"]),
    ]
    train_transform = Compose(pipeline)

    items = [{"image": file_path}, {"image": file_path}]
    dataset = Dataset(items, transform=train_transform)
    batch_size = 2
    data_loader = DataLoader(dataset, batch_size=batch_size)
    data: dict = first(data_loader)

    image_meta = data[PostFix.meta("image")]
    for spatial_shape in image_meta["spatial_shape"]:
        assert_allclose(spatial_shape, expected_spatial_shape, type_test=False)
    self.assertTupleEqual(data["image"].shape, (batch_size, *expected_shape[1:]))
def __call__(self, data: Mapping[Hashable, np.ndarray]):
    """
    Iterate `self.patch_iter` over every entry of `self.keys` in lockstep and
    yield one dictionary per patch position.

    Args:
        data: dictionary holding the arrays to be split into patches.

    Yields:
        a `(patch_dict, coords)` tuple. `patch_dict` holds the patch of every
        key in `self.keys`, a deep copy of every other entry of `data`, and
        the coordinates, original spatial shape and start position stored
        under `self.coords_key`, `self.original_spatial_shape_key` and
        `self.start_pos_key`. `coords` are those of the first key.

    """
    source = dict(data)
    # shape of the first key without its leading dimension
    original_spatial_shape = source[first(self.keys)].shape[1:]

    patch_iterators = [self.patch_iter(source[key]) for key in self.keys]
    for per_key_patches in zip(*patch_iterators):
        # use the coordinate of the first item
        coords = per_key_patches[0][1]

        result = {}
        for key, patch_item in zip(self.keys, per_key_patches):
            result[key] = patch_item[0]

        # fill in the extra keys with unmodified data
        for extra_key in set(source.keys()).difference(set(self.keys)):
            result[extra_key] = deepcopy(source[extra_key])

        # also store the `coordinate`, `spatial shape of original image`, `start position`
        result[self.coords_key] = coords
        result[self.original_spatial_shape_key] = original_spatial_shape
        result[self.start_pos_key] = self.patch_iter.start_pos
        yield result, coords
def is_import_statement(cls, config: Union[Dict, List, str]) -> bool:
    """
    Tell whether `config` is an expression whose first statement is an import.

    The text following `cls.prefix` is parsed with `ast`, and the first child
    node of the parsed module is tested against `ast.Import` / `ast.ImportFrom`.

    Args:
        config: input config content to check.

    """
    if not cls.is_expression(config) or "import" not in config:
        return False
    expression_body = f"{config[len(cls.prefix) :]}"
    first_node = first(ast.iter_child_nodes(ast.parse(expression_body)))
    return isinstance(first_node, (ast.Import, ast.ImportFrom))
def dense_patch_slices(
    image_size: Sequence[int],
    patch_size: Sequence[int],
    scan_interval: Sequence[int],
) -> List[Tuple[slice, ...]]:
    """
    Enumerate all slices defining ND patches of size `patch_size` from an `image_size` input image.

    Along every dimension patches start at multiples of the scan interval;
    a patch that would run past the image border is shifted back so that it
    ends exactly at the border. A scan interval of 0 gives a single patch
    along that dimension.

    Args:
        image_size: dimensions of image to iterate over
        patch_size: size of patches to generate slices
        scan_interval: dense patch sampling interval

    Returns:
        a list of slice objects defining each patch

    """
    num_spatial_dims = len(image_size)
    patch_size = get_valid_patch_size(image_size, patch_size)
    scan_interval = ensure_tuple_size(scan_interval, num_spatial_dims)

    starts = []
    for dim in range(num_spatial_dims):
        size, patch, interval = image_size[dim], patch_size[dim], scan_interval[dim]

        # number of patch positions along this dimension
        if interval == 0:
            count = 1
        else:
            upper = int(math.ceil(float(size) / interval))
            # index of the first position whose patch reaches the image border
            last_pos = first(d for d in range(upper) if d * interval + patch >= size)
            count = 1 if last_pos is None else last_pos + 1

        # shift back any position whose patch would cross the border
        dim_starts = []
        for idx in range(count):
            begin = idx * interval
            dim_starts.append(begin - max(begin + patch - size, 0))
        starts.append(dim_starts)

    # every combination of per-dimension starts, one row per patch
    grids = np.meshgrid(*starts, indexing="ij")
    combos = np.asarray([grid.flatten() for grid in grids]).T
    return [tuple(slice(s, s + patch_size[d]) for d, s in enumerate(row)) for row in combos]
def _parse_import_string(self, import_string: str):
    """
    Parse a single import statement such as "from monai.transforms import Resize",
    import the named object with `optional_import` and register it in
    `self.globals` under its alias (or its own name when no alias is given).

    Only the first imported name is handled; a warning is issued when the
    statement lists several.

    Returns:
        the object stored in `self.globals`, or `None` when `import_string`
        is not an import statement.
    """
    node = first(ast.iter_child_nodes(ast.parse(import_string)))
    if not isinstance(node, (ast.Import, ast.ImportFrom)) or len(node.names) < 1:
        return None
    if len(node.names) > 1:
        warnings.warn(f"ignoring multiple import alias '{import_string}'.")

    alias = node.names[0]
    name = f"{alias.name}"
    asname = name if alias.asname is None else f"{alias.asname}"

    if isinstance(node, ast.ImportFrom):
        # "from <module> import <name>"
        imported, _ = optional_import(f"{node.module}", name=f"{name}")
    else:
        # plain "import <name>"
        imported, _ = optional_import(f"{name}")
    self.globals[asname] = imported
    return self.globals[asname]
def main():
    """
    Train a ``ViTAutoEnc`` with a combined reconstruction and contrastive loss.

    The function reads the training/validation file lists from a JSON file,
    builds a transform pipeline that yields two independently corrupted views
    (``image`` and ``image_2``) plus a copy (``gt_image``) that the later
    augmentations do not touch, and trains for ``max_epochs`` epochs. Every
    ``val_interval`` epochs the reconstruction loss on the validation set is
    computed; the checkpoint with the lowest validation loss is written to
    ``best_model.pt`` and the loss curves to ``loss_plots.png`` in
    ``logdir_path``.

    All paths and hyper-parameters are hard-coded in the body.

    Returns:
        None
    """
    # TODO Defining file paths & output directory path
    json_Path = os.path.normpath('/scratch/data_2021/tcia_covid19/dataset_split_debug.json')
    data_Root = os.path.normpath('/scratch/data_2021/tcia_covid19')
    logdir_path = os.path.normpath('/home/vishwesh/monai_tutorial_testing/issue_467')

    # creates missing parent directories too and is a no-op when the directory exists
    os.makedirs(logdir_path, exist_ok=True)

    # Load Json & Append Root Path
    with open(json_Path, 'r') as json_f:
        json_Data = json.load(json_f)

    train_Data = json_Data['training']
    val_Data = json_Data['validation']

    for entry in train_Data:
        entry['image'] = os.path.join(data_Root, entry['image'])

    for entry in val_Data:
        entry['image'] = os.path.join(data_Root, entry['image'])

    print('Total Number of Training Data Samples: {}'.format(len(train_Data)))
    print(train_Data)
    print('#' * 10)
    print('Total Number of Validation Data Samples: {}'.format(len(val_Data)))
    print(val_Data)
    print('#' * 10)

    # Set Determinism
    set_determinism(seed=123)

    # Define Training Transforms
    train_Transforms = Compose(
        [
            LoadImaged(keys=["image"]),
            EnsureChannelFirstd(keys=["image"]),
            Spacingd(keys=["image"], pixdim=(2.0, 2.0, 2.0), mode=("bilinear")),
            ScaleIntensityRanged(
                keys=["image"], a_min=-57, a_max=164,
                b_min=0.0, b_max=1.0, clip=True,
            ),
            CropForegroundd(keys=["image"], source_key="image"),
            SpatialPadd(keys=["image"], spatial_size=(96, 96, 96)),
            RandSpatialCropSamplesd(keys=["image"], roi_size=(96, 96, 96), random_size=False, num_samples=2),
            CopyItemsd(keys=["image"], times=2, names=["gt_image", "image_2"], allow_missing_keys=False),
            OneOf(transforms=[
                RandCoarseDropoutd(keys=["image"], prob=1.0, holes=6, spatial_size=5, dropout_holes=True,
                                   max_spatial_size=32),
                RandCoarseDropoutd(keys=["image"], prob=1.0, holes=6, spatial_size=20, dropout_holes=False,
                                   max_spatial_size=64),
            ]),
            RandCoarseShuffled(keys=["image"], prob=0.8, holes=10, spatial_size=8),
            # Please note that if image, image_2 are called via the same transform call because of the
            # determinism they will get augmented the exact same way which is not the required case here,
            # hence two calls are made
            OneOf(transforms=[
                RandCoarseDropoutd(keys=["image_2"], prob=1.0, holes=6, spatial_size=5, dropout_holes=True,
                                   max_spatial_size=32),
                RandCoarseDropoutd(keys=["image_2"], prob=1.0, holes=6, spatial_size=20, dropout_holes=False,
                                   max_spatial_size=64),
            ]),
            RandCoarseShuffled(keys=["image_2"], prob=0.8, holes=10, spatial_size=8),
        ]
    )

    # Sanity check: run one sample through the pipeline and report its shape
    check_ds = Dataset(data=train_Data, transform=train_Transforms)
    check_loader = DataLoader(check_ds, batch_size=1)
    check_data = first(check_loader)
    image = (check_data["image"][0][0])
    print(f"image shape: {image.shape}")

    # Define Network ViT backbone & Loss & Optimizer
    device = torch.device("cuda:0")
    model = ViTAutoEnc(
        in_channels=1,
        img_size=(96, 96, 96),
        patch_size=(16, 16, 16),
        pos_embed='conv',
        hidden_size=768,
        mlp_dim=3072,
    )

    model = model.to(device)

    # Define Hyper-parameters for training loop
    max_epochs = 500
    val_interval = 2
    batch_size = 4
    lr = 1e-4
    epoch_loss_values = []
    step_loss_values = []
    epoch_cl_loss_values = []
    epoch_recon_loss_values = []
    val_loss_values = []
    best_val_loss = 1000.0

    recon_loss = L1Loss()
    contrastive_loss = ContrastiveLoss(batch_size=batch_size*2, temperature=0.05)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Define DataLoader using MONAI, CacheDataset needs to be used
    train_ds = Dataset(data=train_Data, transform=train_Transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

    val_ds = Dataset(data=val_Data, transform=train_Transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True, num_workers=4)

    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        epoch_cl_loss = 0
        epoch_recon_loss = 0
        step = 0

        for batch_data in train_loader:
            step += 1
            start_time = time.time()

            inputs, inputs_2, gt_input = (
                batch_data["image"].to(device),
                batch_data["image_2"].to(device),
                batch_data["gt_image"].to(device),
            )
            optimizer.zero_grad()
            outputs_v1, hidden_v1 = model(inputs)
            outputs_v2, hidden_v2 = model(inputs_2)

            flat_out_v1 = outputs_v1.flatten(start_dim=1, end_dim=4)
            flat_out_v2 = outputs_v2.flatten(start_dim=1, end_dim=4)

            r_loss = recon_loss(outputs_v1, gt_input)
            cl_loss = contrastive_loss(flat_out_v1, flat_out_v2)

            # Adjust the CL loss by Recon Loss
            total_loss = r_loss + cl_loss * r_loss

            total_loss.backward()
            optimizer.step()
            epoch_loss += total_loss.item()
            step_loss_values.append(total_loss.item())

            # CL & Recon Loss Storage of Value
            epoch_cl_loss += cl_loss.item()
            epoch_recon_loss += r_loss.item()

            end_time = time.time()
            print(
                f"{step}/{len(train_ds) // train_loader.batch_size}, "
                f"train_loss: {total_loss.item():.4f}, "
                f"time taken: {end_time-start_time}s")

        epoch_loss /= step
        epoch_cl_loss /= step
        epoch_recon_loss /= step

        epoch_loss_values.append(epoch_loss)
        epoch_cl_loss_values.append(epoch_cl_loss)
        epoch_recon_loss_values.append(epoch_recon_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if epoch % val_interval == 0:
            print('Entering Validation for epoch: {}'.format(epoch+1))
            total_val_loss = 0
            val_step = 0
            model.eval()
            # no gradients are needed for validation; skipping the autograd graph saves memory
            with torch.no_grad():
                for val_batch in val_loader:
                    val_step += 1
                    start_time = time.time()
                    inputs, gt_input = (
                        val_batch["image"].to(device),
                        val_batch["gt_image"].to(device),
                    )
                    print('Input shape: {}'.format(inputs.shape))
                    outputs, outputs_v2 = model(inputs)
                    val_loss = recon_loss(outputs, gt_input)
                    total_val_loss += val_loss.item()
                    end_time = time.time()

            total_val_loss /= val_step
            val_loss_values.append(total_val_loss)
            # the reported time is that of the last validation batch
            print(f"epoch {epoch + 1} Validation average loss: {total_val_loss:.4f}, "
                  f"time taken: {end_time-start_time}s")

            if total_val_loss < best_val_loss:
                print(f"Saving new model based on validation loss {total_val_loss:.4f}")
                best_val_loss = total_val_loss
                # record the (1-based) epoch that produced this model, not the planned epoch count
                checkpoint = {'epoch': epoch + 1,
                              'state_dict': model.state_dict(),
                              'optimizer': optimizer.state_dict()
                              }
                torch.save(checkpoint, os.path.join(logdir_path, 'best_model.pt'))

            # NOTE(review): the flattened source does not show whether the plots belong to the
            # validation branch or follow the epoch loop; they are refreshed after every
            # validation pass here -- confirm against the original script.
            plt.figure(1, figsize=(8, 8))
            plt.subplot(2, 2, 1)
            plt.plot(epoch_loss_values)
            plt.grid()
            plt.title('Training Loss')

            plt.subplot(2, 2, 2)
            plt.plot(val_loss_values)
            plt.grid()
            plt.title('Validation Loss')

            plt.subplot(2, 2, 3)
            plt.plot(epoch_cl_loss_values)
            plt.grid()
            plt.title('Training Contrastive Loss')

            plt.subplot(2, 2, 4)
            plt.plot(epoch_recon_loss_values)
            plt.grid()
            plt.title('Training Recon Loss')

            plt.savefig(os.path.join(logdir_path, 'loss_plots.png'))
            plt.close(1)

    print('Done')
    return None
def dense_patch_slices(
    image_size: Sequence[int],
    patch_size: Sequence[int],
    scan_interval: Sequence[int],
) -> List[Tuple[slice, ...]]:
    """
    Enumerate all slices defining 2D/3D patches of size `patch_size` from an `image_size` input image.

    Along every dimension patches start at multiples of the scan interval;
    a patch that would run past the image border is shifted back so that it
    ends exactly at the border. A scan interval of 0 gives a single patch
    along that dimension, and so does a dimension in which no scan position
    lets the patch reach the border.

    Args:
        image_size: dimensions of image to iterate over
        patch_size: size of patches to generate slices
        scan_interval: dense patch sampling interval

    Raises:
        ValueError: When ``image_size`` length is not one of [2, 3].

    Returns:
        a list of slice objects defining each patch, ordered with the last
        dimension varying fastest

    """
    num_spatial_dims = len(image_size)
    if num_spatial_dims not in (2, 3):
        raise ValueError(
            f"Unsupported image_size length: {len(image_size)}, available options are [2, 3]"
        )
    patch_size = get_valid_patch_size(image_size, patch_size)
    scan_interval = ensure_tuple_size(scan_interval, num_spatial_dims)

    scan_num = []
    for i in range(num_spatial_dims):
        if scan_interval[i] == 0:
            scan_num.append(1)
        else:
            num = int(math.ceil(float(image_size[i]) / scan_interval[i]))
            scan_dim = first(d for d in range(num) if d * scan_interval[i] + patch_size[i] >= image_size[i])
            # `first` gives None when no position qualifies; fall back to a single patch
            # instead of failing on `None + 1`
            scan_num.append(1 if scan_dim is None else scan_dim + 1)

    def _dim_slices(dim: int) -> List[slice]:
        """Slices of all patch positions along dimension `dim`, clamped to the image border."""
        result = []
        for idx in range(scan_num[dim]):
            start = idx * scan_interval[dim]
            start -= max(start + patch_size[dim] - image_size[dim], 0)
            result.append(slice(start, start + patch_size[dim]))
        return result

    per_dim = [_dim_slices(dim) for dim in range(num_spatial_dims)]
    if num_spatial_dims == 3:
        return [(s_i, s_j, s_k) for s_i in per_dim[0] for s_j in per_dim[1] for s_k in per_dim[2]]
    return [(s_i, s_j) for s_i in per_dim[0] for s_j in per_dim[1]]
ScaleIntensityRanged( keys=["image"], a_min=-300, a_max=300, b_min=0.0, b_max=1.0, clip=True, ), #CropForegroundd(keys=["image", "label"], source_key="image"), ToTensord(keys=["image", "label"]), ]) """## Check transforms in DataLoader""" check_ds = Dataset(data=val_files, transform=val_transforms) check_loader = DataLoader(check_ds, batch_size=1) check_data = first(check_loader) image, label = (check_data["image"][0][0], check_data["label"][0][0]) print(f"image shape: {image.shape}, label shape: {label.shape}") # plot the slice [:, :, 80] #fig = plt.figure("check", (12, 6)) #figure size #plt.subplot(1, 2, 1) #plt.title("image") #plt.imshow(image[:, :, 80], cmap="gray") #plt.subplot(1, 2, 2) #plt.title("label") #plt.imshow(label[:, :, 80]) #plt.show() #fig.savefig('my_figure.png') """## Define CacheDataset and DataLoader for training and validation Here we use CacheDataset to accelerate training and validation process, it's 10x faster than the regular Dataset.
def train(n_feat, crop_size, bs, ep, optimizer="rmsprop", lr=5e-4, pretrain=None):
    """
    Train a ``UNetPipe`` model, partitioned over the visible GPUs with ``GPipe``.

    Every 10 epochs the per-class mean Dice on the validation set is computed
    and, for each foreground class whose score improved, the model weights are
    saved to ``<model_name><class index>``.

    Args:
        n_feat: forwarded to ``UNetPipe`` as ``n_feat``; also part of the saved model name.
        crop_size: comma-separated integers (a string); if any entry is ``-1`` full
            images are used, otherwise windows of this size are sampled.
        bs: batch size for full-image training, or number of windows sampled per image.
        ep: number of epochs.
        optimizer: ``"rmsprop"`` or ``"momentum"`` (case-insensitive).
        lr: learning rate.
        pretrain: optional checkpoint path; its ``"weight"`` entry is used to
            initialise the parameters that also exist in the model.

    Raises:
        ValueError: When ``optimizer`` is neither ``"rmsprop"`` nor ``"momentum"``.
    """
    model_name = f"./HaN_{n_feat}_{bs}_{ep}_{crop_size}_{lr}_"
    print(f"save the best model as '{model_name}' during training.")

    crop_size = [int(cz) for cz in crop_size.split(",")]
    print(f"input image crop_size: {crop_size}")

    # starting training set loader
    train_images = ImageLabelDataset(path=TRAIN_PATH, n_class=N_CLASSES)

    # the same elastic deformation is used by both loading strategies
    elastic = Rand3DElasticd(
        keys=("image", "label"),
        spatial_size=crop_size,
        sigma_range=(10, 50),  # 30
        magnitude_range=(600, 1200),  # 1000
        prob=0.8,
        rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
        shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
        translate_range=tuple(sz * 0.05 for sz in crop_size),
        scale_range=(0.2, 0.2, 0.2),
        mode=("bilinear", "nearest"),
        padding_mode=("border", "zeros"),
    )

    if np.any([cz == -1 for cz in crop_size]):  # using full image
        train_transform = Compose([AddChannelDict(keys="image"), elastic])
        train_dataset = Dataset(train_images, transform=train_transform)
        # when bs > 1, the loader assumes that the full image sizes are the same across the dataset
        train_dataloader = torch.utils.data.DataLoader(train_dataset, num_workers=4, batch_size=bs, shuffle=True)
    else:
        # draw balanced foreground/background window samples according to the ground truth label
        train_transform = Compose([
            AddChannelDict(keys="image"),
            SpatialPadd(keys=("image", "label"), spatial_size=crop_size),  # ensure image size >= crop_size
            RandCropByPosNegLabeld(keys=("image", "label"), label_key="label", spatial_size=crop_size,
                                   num_samples=bs),
            elastic,
        ])
        train_dataset = Dataset(train_images, transform=train_transform)  # each dataset item is a list of windows
        train_dataloader = torch.utils.data.DataLoader(  # stack each dataset item into a single tensor
            train_dataset, num_workers=4, batch_size=1, shuffle=True, collate_fn=list_data_collate)

    first_sample = first(train_dataloader)
    print(first_sample["image"].shape)

    # starting validation set loader
    val_transform = Compose([AddChannelDict(keys="image")])
    val_dataset = Dataset(ImageLabelDataset(VAL_PATH, n_class=N_CLASSES), transform=val_transform)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, num_workers=1, batch_size=1)
    print(val_dataset[0]["image"].shape)
    print(f"training images: {len(train_dataloader)}, validation images: {len(val_dataloader)}")

    model = UNetPipe(spatial_dims=3, in_channels=1, out_channels=N_CLASSES, n_feat=n_feat)
    model = flatten_sequential(model)
    # per-class weights applied to both loss terms
    lossweight = torch.from_numpy(
        np.array([2.22, 1.31, 1.99, 1.13, 1.93, 1.93, 1.0, 1.0, 1.90, 1.98], np.float32))

    if optimizer.lower() == "rmsprop":
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)  # lr = 5e-4
    elif optimizer.lower() == "momentum":
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)  # lr = 1e-4 for finetuning
    else:
        raise ValueError(
            f"Unknown optimizer type {optimizer}. (options are 'rmsprop' and 'momentum')."
        )

    # config GPipe: the first training sample is used to balance the partitions
    x = first_sample["image"].float()
    x = torch.autograd.Variable(x.cuda())
    partitions = torch.cuda.device_count()
    print(f"partition: {partitions}, input: {x.size()}")
    balance = balance_by_size(partitions, model, x)
    model = GPipe(model, balance, chunks=4, checkpoint="always")

    # config loss functions
    dice_loss_func = DiceLoss(softmax=True, reduction="none")
    # use the same pipeline and loss in
    # AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy,
    # Medical Physics, 2018.
    focal_loss_func = FocalLoss(reduction="none")

    if pretrain:
        print(f"loading from {pretrain}.")
        pretrained_dict = torch.load(pretrain)["weight"]
        model_dict = model.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        # load the merged dict so parameters missing from the checkpoint keep their current
        # values; loading only the filtered checkpoint fails whenever it lacks a model key
        model.load_state_dict(model_dict)

    b_time = time.time()
    best_val_loss = [0] * (N_CLASSES - 1)  # foreground
    for epoch in range(ep):
        model.train()
        trainloss = 0
        for b_idx, data_dict in enumerate(train_dataloader):
            x_train = data_dict["image"]
            y_train = data_dict["label"]
            flagvec = data_dict["with_complete_groundtruth"]

            x_train = torch.autograd.Variable(x_train.cuda())
            y_train = torch.autograd.Variable(y_train.cuda().float())
            optimizer.zero_grad()
            o = model(x_train).to(0, non_blocking=True).float()

            loss = (dice_loss_func(o, y_train.to(o)) * flagvec.to(o) * lossweight.to(o)).mean()
            loss += 0.5 * (focal_loss_func(o, y_train.to(o)) * flagvec.to(o) * lossweight.to(o)).mean()
            loss.backward()
            optimizer.step()
            trainloss += loss.item()

            if b_idx % 20 == 0:
                print(f"Train Epoch: {epoch} [{b_idx}/{len(train_dataloader)}] \tLoss: {loss.item()}")
        print(f"epoch {epoch} TRAIN loss {trainloss / len(train_dataloader)}")

        if epoch % 10 == 0:
            model.eval()
            # check validation dice
            val_loss = [0] * (N_CLASSES - 1)
            n_val = [0] * (N_CLASSES - 1)
            for data_dict in val_dataloader:
                x_val = data_dict["image"]
                y_val = data_dict["label"]
                with torch.no_grad():
                    x_val = torch.autograd.Variable(x_val.cuda())
                    o = model(x_val).to(0, non_blocking=True)
                    loss = compute_meandice(o, y_val.to(o), mutually_exclusive=True, include_background=False)
                    # `score == score` is False only for NaN: such entries are left out of the average
                    val_loss = [
                        score.item() + total if score == score else total
                        for score, total in zip(loss[0], val_loss)
                    ]
                    n_val = [count + 1 if score == score else count for score, count in zip(loss[0], n_val)]
            # a class without any valid score gets 0.0 instead of a division by zero
            val_loss = [total / count if count else 0.0 for total, count in zip(val_loss, n_val)]
            print("validation scores %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f" % tuple(val_loss))

            # save one checkpoint per foreground class whose score improved
            for c in range(1, 10):
                if best_val_loss[c - 1] < val_loss[c - 1]:
                    best_val_loss[c - 1] = val_loss[c - 1]
                    state = {"epoch": epoch, "weight": model.state_dict(), "score_" + str(c): best_val_loss[c - 1]}
                    torch.save(state, f"{model_name}" + str(c))
            print("best validation scores %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f" % tuple(best_val_loss))

    print("total time", time.time() - b_time)
def train(cfg):
    """
    Train a binary patch classifier with an ignite-based trainer and evaluator.

    The function builds the pre-processing pipelines, creates ``PatchWSIDataset``
    datasets from the decathlon-style data list, prints information about the
    first training batch, and runs a ``SupervisedTrainer`` that validates with a
    ``SupervisedEvaluator`` after every epoch.

    Args:
        cfg: configuration dictionary. Keys read here: ``dataset_json``,
            ``data_root``, ``region_size``, ``grid_shape``, ``patch_size``,
            ``use_openslide``, ``num_workers``, ``batch_size``, ``pretrain``,
            ``novograd``, ``lr``, ``amp``, ``n_epochs`` and ``logdir``. It is also
            passed to ``create_log_dir`` and ``set_device``. ``cfg["amp"]`` is
            overwritten with the effective boolean AMP setting.

    Raises:
        ValueError: When the training loader yields no first batch.
    """
    log_dir = create_log_dir(cfg)
    device = set_device(cfg)
    # --------------------------------------------------------------------------
    # Data Loading and Preprocessing
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # Build MONAI preprocessing
    train_preprocess = Compose([
        ToTensorD(keys="image"),
        TorchVisionD(keys="image", name="ColorJitter", brightness=64.0 / 255.0, contrast=0.75,
                     saturation=0.25, hue=0.04),
        ToNumpyD(keys="image"),
        RandFlipD(keys="image", prob=0.5),
        RandRotate90D(keys="image", prob=0.5),
        CastToTypeD(keys="image", dtype=np.float32),
        RandZoomD(keys="image", prob=0.5, min_zoom=0.9, max_zoom=1.1),
        ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    valid_preprocess = Compose([
        CastToTypeD(keys="image", dtype=np.float32),
        ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    # __________________________________________________________________________
    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )

    train_dataset = PatchWSIDataset(
        train_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        train_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        valid_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        valid_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )

    # __________________________________________________________________________
    # DataLoaders
    train_dataloader = DataLoader(train_dataset, num_workers=cfg["num_workers"], batch_size=cfg["batch_size"],
                                  pin_memory=True)
    valid_dataloader = DataLoader(valid_dataset, num_workers=cfg["num_workers"], batch_size=cfg["batch_size"],
                                  pin_memory=True)

    # __________________________________________________________________________
    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("First sample is None!")

    print("image: ")
    print("    shape", first_sample["image"].shape)
    print("    type: ", type(first_sample["image"]))
    print("    dtype: ", first_sample["image"].dtype)
    print("labels: ")
    print("    shape", first_sample["label"].shape)
    print("    type: ", type(first_sample["label"]))
    print("    dtype: ", first_sample["label"].dtype)
    print(f"batch size: {cfg['batch_size']}")
    print(f"train number of batches: {len(train_dataloader)}")
    print(f"valid number of batches: {len(valid_dataloader)}")

    # --------------------------------------------------------------------------
    # Deep Learning Classification Model
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # initialize model
    model = TorchVisionFCModel("resnet18", num_classes=1, use_conv=True, pretrained=cfg["pretrain"])
    model = model.to(device)

    # loss function
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    # optimizer
    if cfg["novograd"]:
        optimizer = Novograd(model.parameters(), cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # AMP scaler: only enabled when requested and the torch version is at least 1.6
    cfg["amp"] = bool(cfg["amp"]) and monai.utils.get_torch_version_tuple() >= (1, 6)

    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["n_epochs"])

    # --------------------------------------------
    # Ignite Trainer/Evaluator
    # --------------------------------------------
    # Evaluator
    val_handlers = [
        CheckpointSaver(save_dir=log_dir, save_dict={"net": model}, save_key_metric=True),
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=log_dir, output_transform=lambda x: None),
    ]
    val_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5),
    ])
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=valid_dataloader,
        network=model,
        postprocessing=val_postprocessing,
        key_val_metric={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        val_handlers=val_handlers,
        amp=cfg["amp"],
    )

    # Trainer
    # NOTE(review): the trainer handlers write to cfg["logdir"] while the evaluator handlers
    # write to the directory returned by create_log_dir -- confirm this split is intended.
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        CheckpointSaver(save_dir=cfg["logdir"], save_dict={"net": model, "opt": optimizer}, save_interval=1,
                        epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        TensorBoardStatsHandler(log_dir=cfg["logdir"], tag_name="train_loss",
                                output_transform=from_engine(["loss"], first=True)),
    ]
    train_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5),
    ])

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=cfg["n_epochs"],
        train_data_loader=train_dataloader,
        network=model,
        optimizer=optimizer,
        loss_function=loss_func,
        postprocessing=train_postprocessing,
        key_train_metric={"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        train_handlers=train_handlers,
        amp=cfg["amp"],
    )
    trainer.run()
def main(cfg):
    """Run the full training (and optional validation) pipeline described by `cfg`.

    The function creates a log directory, configures file logging and a
    TensorBoard writer, builds the pre-processing pipelines for the selected
    backend, creates the datasets and data loaders, and then trains the model
    for ``cfg["n_epochs"]`` epochs, validating after each epoch when
    ``cfg["validate"]`` is truthy.

    Args:
        cfg: dictionary of settings. Keys read here: ``benchmark``, ``backend``
            ("cucim" or "numpy"), ``prob``, ``dataset_json``, ``data_root``,
            ``region_size``, ``grid_shape``, ``patch_size``, ``use_openslide``,
            ``num_workers``, ``batch_size``, ``pin``, ``pretrain``, ``novograd``,
            ``lr``, ``amp``, ``cos``, ``n_epochs``, ``print_step``, ``save`` and
            ``validate``. ``cfg["amp"]`` is overwritten in place with the
            effective AMP setting.

    Raises:
        ValueError: if ``cfg["backend"]`` is neither "cucim" nor "numpy", or if
            the training data loader yields no first batch.
    """
    # -------------------------------------------------------------------------
    # Configs
    # -------------------------------------------------------------------------
    # Create log/model dir
    log_dir = create_log_dir(cfg)

    # Set the logger
    logging.basicConfig(
        format="%(asctime)s %(levelname)2s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    log_name = os.path.join(log_dir, "logs.txt")
    logger = logging.getLogger()
    fh = logging.FileHandler(log_name)
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    # Set TensorBoard summary writer
    writer = SummaryWriter(log_dir)

    # Save configs
    logging.info(json.dumps(cfg))
    with open(os.path.join(log_dir, "config.json"), "w") as fp:
        json.dump(cfg, fp, indent=4)

    # Set device cuda/cpu
    device = set_device(cfg)

    # Set cudnn benchmark/deterministic
    if cfg["benchmark"]:
        torch.backends.cudnn.benchmark = True
    else:
        set_determinism(seed=0)

    def _checkpoint_path(epoch):
        # Single source of truth for per-epoch checkpoint names, so that the
        # files written during training are the ones copied at the end.
        return os.path.join(log_dir, f"model_epoch_{epoch}.pth")

    # -------------------------------------------------------------------------
    # Transforms and Datasets
    # -------------------------------------------------------------------------
    # Pre-processing
    preprocess_cpu_train = None
    preprocess_gpu_train = None
    preprocess_cpu_valid = None
    preprocess_gpu_valid = None
    if cfg["backend"] == "cucim":
        preprocess_cpu_train = Compose([ToTensorD(keys="label")])
        preprocess_gpu_train = Compose([
            Range()(ToCupy()),
            Range("ColorJitter")(RandCuCIM(name="color_jitter",
                                           brightness=64.0 / 255.0,
                                           contrast=0.75,
                                           saturation=0.25,
                                           hue=0.04)),
            Range("RandomFlip")(RandCuCIM(name="image_flip",
                                          apply_prob=cfg["prob"],
                                          spatial_axis=-1)),
            Range("RandomRotate90")(RandCuCIM(name="rand_image_rotate_90",
                                              prob=cfg["prob"],
                                              max_k=3,
                                              spatial_axis=(-2, -1))),
            Range()(CastToType(dtype=np.float32)),
            Range("RandomZoom")(RandCuCIM(name="rand_zoom",
                                          min_zoom=0.9,
                                          max_zoom=1.1)),
            Range("ScaleIntensity")(CuCIM(name="scale_intensity_range",
                                          a_min=0.0,
                                          a_max=255.0,
                                          b_min=-1.0,
                                          b_max=1.0)),
            Range()(ToTensor(device=device)),
        ])
        preprocess_cpu_valid = Compose([ToTensorD(keys="label")])
        preprocess_gpu_valid = Compose([
            Range("ValidToCupyAndCast")(ToCupy(dtype=np.float32)),
            Range("ValidScaleIntensity")(CuCIM(name="scale_intensity_range",
                                               a_min=0.0,
                                               a_max=255.0,
                                               b_min=-1.0,
                                               b_max=1.0)),
            Range("ValidToTensor")(ToTensor(device=device)),
        ])
    elif cfg["backend"] == "numpy":
        preprocess_cpu_train = Compose([
            Range()(ToTensorD(keys=("image", "label"))),
            Range("ColorJitter")(TorchVisionD(
                keys="image",
                name="ColorJitter",
                brightness=64.0 / 255.0,
                contrast=0.75,
                saturation=0.25,
                hue=0.04,
            )),
            Range()(ToNumpyD(keys="image")),
            Range("RandomFlip")(RandFlipD(keys="image",
                                          prob=cfg["prob"],
                                          spatial_axis=-1)),
            Range("RandomRotate90")(RandRotate90D(keys="image",
                                                  prob=cfg["prob"])),
            Range()(CastToTypeD(keys="image", dtype=np.float32)),
            Range("RandomZoom")(RandZoomD(keys="image",
                                          prob=cfg["prob"],
                                          min_zoom=0.9,
                                          max_zoom=1.1)),
            Range("ScaleIntensity")(ScaleIntensityRangeD(keys="image",
                                                         a_min=0.0,
                                                         a_max=255.0,
                                                         b_min=-1.0,
                                                         b_max=1.0)),
            Range()(ToTensorD(keys="image")),
        ])
        preprocess_cpu_valid = Compose([
            Range("ValidCastType")(CastToTypeD(keys="image",
                                               dtype=np.float32)),
            Range("ValidScaleIntensity")(ScaleIntensityRangeD(keys="image",
                                                              a_min=0.0,
                                                              a_max=255.0,
                                                              b_min=-1.0,
                                                              b_max=1.0)),
            Range("ValidToTensor")(ToTensorD(keys=("image", "label"))),
        ])
    else:
        raise ValueError(
            f"Backend should be either numpy or cucim! ['{cfg['backend']}' is provided.]"
        )

    # Post-processing
    postprocess = Compose([
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ])

    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )
    train_dataset = PatchWSIDataset(
        data=train_json_info_list,
        region_size=cfg["region_size"],
        grid_shape=cfg["grid_shape"],
        patch_size=cfg["patch_size"],
        transform=preprocess_cpu_train,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        data=valid_json_info_list,
        region_size=cfg["region_size"],
        grid_shape=cfg["grid_shape"],
        patch_size=cfg["patch_size"],
        transform=preprocess_cpu_valid,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )

    # DataLoaders
    train_dataloader = DataLoader(train_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=cfg["pin"])
    valid_dataloader = DataLoader(valid_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=cfg["pin"])

    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("First sample is None!")
    for d in ["image", "label"]:
        logging.info(f"[{d}] \n"
                     f" {d} shape: {first_sample[d].shape}\n"
                     f" {d} type: {type(first_sample[d])}\n"
                     f" {d} dtype: {first_sample[d].dtype}")
    logging.info(f"Batch size: {cfg['batch_size']}")
    logging.info(f"[Training] number of batches: {len(train_dataloader)}")
    logging.info(f"[Validation] number of batches: {len(valid_dataloader)}")

    # -------------------------------------------------------------------------
    # Deep Learning Model and Configurations
    # -------------------------------------------------------------------------
    # Initialize model
    # NOTE(review): recent MONAI releases name this argument `num_classes`;
    # `n_classes` is kept here because the installed MONAI version is not
    # visible from this function -- confirm and rename if needed.
    model = TorchVisionFCModel("resnet18",
                               n_classes=1,
                               use_conv=True,
                               pretrained=cfg["pretrain"])
    model = model.to(device)

    # Loss function
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    # Optimizer
    if cfg["novograd"] is True:
        optimizer = Novograd(model.parameters(), lr=cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # AMP scaler (AMP is only enabled when the torch version is at least 1.6)
    cfg["amp"] = cfg["amp"] and monai.utils.get_torch_version_tuple() >= (1, 6)
    if cfg["amp"] is True:
        scaler = GradScaler()
    else:
        scaler = None

    # Learning rate scheduler
    if cfg["cos"] is True:
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   T_max=cfg["n_epochs"])
    else:
        scheduler = None

    # -------------------------------------------------------------------------
    # Training/Evaluating
    # -------------------------------------------------------------------------
    train_counter = {"n_epochs": cfg["n_epochs"], "epoch": 1, "step": 1}
    total_valid_time, total_train_time = 0.0, 0.0
    t_start = time.perf_counter()
    # `np.inf` instead of the `np.Inf` alias, which NumPy 2.0 removed.
    metric_summary = {"loss": np.inf, "accuracy": 0, "best_epoch": 1}

    # Training/Validation Loop
    for _ in range(cfg["n_epochs"]):
        t_epoch = time.perf_counter()
        logging.info(
            f"[Training] learning rate: {optimizer.param_groups[0]['lr']}")

        # Training
        with Range("Training Epoch"):
            train_counter = training(
                train_counter,
                model,
                loss_func,
                optimizer,
                scaler,
                cfg["amp"],
                train_dataloader,
                preprocess_gpu_train,
                postprocess,
                device,
                writer,
                cfg["print_step"],
            )
        if scheduler is not None:
            scheduler.step()
        if cfg["save"]:
            torch.save(model.state_dict(),
                       _checkpoint_path(train_counter["epoch"]))
        t_train = time.perf_counter()
        train_time = t_train - t_epoch
        total_train_time += train_time

        # Validation
        if cfg["validate"]:
            with Range("Validation"):
                valid_loss, valid_acc = validation(
                    model,
                    loss_func,
                    cfg["amp"],
                    valid_dataloader,
                    preprocess_gpu_valid,
                    postprocess,
                    device,
                    cfg["print_step"],
                )
            t_valid = time.perf_counter()
            valid_time = t_valid - t_train
            total_valid_time += valid_time
            if valid_loss < metric_summary["loss"]:
                metric_summary["loss"] = min(valid_loss,
                                             metric_summary["loss"])
                metric_summary["accuracy"] = max(valid_acc,
                                                 metric_summary["accuracy"])
                metric_summary["best_epoch"] = train_counter["epoch"]
            writer.add_scalar("valid/loss", valid_loss,
                              train_counter["epoch"])
            writer.add_scalar("valid/accuracy", valid_acc,
                              train_counter["epoch"])
            logging.info(
                f"[Epoch: {train_counter['epoch']}/{cfg['n_epochs']}] loss: {valid_loss:.3f}, accuracy: {valid_acc:.2f}, "
                f"time: {t_valid - t_epoch:.1f}s (train: {train_time:.1f}s, valid: {valid_time:.1f}s)"
            )
        else:
            logging.info(
                f"[Epoch: {train_counter['epoch']}/{cfg['n_epochs']}] Train time: {train_time:.1f}s"
            )
        writer.flush()
    t_end = time.perf_counter()

    # Save final metrics
    metric_summary["train_time_per_epoch"] = total_train_time / cfg["n_epochs"]
    metric_summary["total_time"] = t_end - t_start
    writer.add_hparams(hparam_dict=cfg,
                       metric_dict=metric_summary,
                       run_name=log_dir)
    writer.close()
    logging.info(f"Metric Summary: {metric_summary}")

    # Save the best and final model. Per-epoch checkpoints only exist when
    # cfg["save"] is set, so copying is skipped otherwise instead of failing
    # with FileNotFoundError at the very end of the run.
    if cfg["save"]:
        if cfg["validate"]:
            copyfile(
                _checkpoint_path(metric_summary["best_epoch"]),
                os.path.join(log_dir, "model_best.pth"),
            )
        copyfile(
            _checkpoint_path(cfg["n_epochs"]),
            os.path.join(log_dir, "model_final.pth"),
        )

    # Final prints
    logging.info(
        f"[Completed] {train_counter['epoch']} epochs -- time: {t_end - t_start:.1f}s "
        f"(training: {total_train_time:.1f}s, validation: {total_valid_time:.1f}s)",
    )
    logging.info(f"Logs and model was saved at: {log_dir}")
def decollate_batch(batch, detach: bool = True):
    """De-collate a batch of data (for example, as produced by a `DataLoader`).

    Returns a list of structures with the original tensor's 0-th dimension sliced into elements using `torch.unbind`.

    Images originally stored as (B,C,H,W,[D]) will be returned as (C,H,W,[D]). Other information,
    such as metadata, may have been stored in a list (or a list inside nested dictionaries). In
    this case we return the element of the list corresponding to the batch idx.

    Return types aren't guaranteed to be the same as the original, since numpy arrays will have been
    converted to torch.Tensor, sequences may be converted to lists of tensors,
    mappings may be converted into dictionaries.

    For example:

    .. code-block:: python

        batch_data = {
            "image": torch.rand((2,1,10,10)),
            "image_meta_dict": {"scl_slope": torch.Tensor([0.0, 0.0])}
        }
        out = decollate_batch(batch_data)
        print(len(out))
        >>> 2

        print(out[0])
        >>> {'image': tensor([[[4.3549e-01...43e-01]]]), 'image_meta_dict': {'scl_slope': 0.0}}

        batch_data = [torch.rand((2,1,10,10)), torch.rand((2,3,5,5))]
        out = decollate_batch(batch_data)
        print(out[0])
        >>> [tensor([[[4.3549e-01...43e-01]]], tensor([[[5.3435e-01...45e-01]]])]

        batch_data = torch.rand((2,1,10,10))
        out = decollate_batch(batch_data)
        print(out[0])
        >>> tensor([[[4.3549e-01...43e-01]]])

    Args:
        batch: data to be de-collated.
        detach: whether to detach the tensors. Scalar tensors will be detached into number types
            instead of torch tensors.

    Returns:
        ``None`` and plain ``float``/``int``/``str``/``bytes`` values are returned unchanged;
        a 0-d tensor is returned as a Python number (or unchanged when ``detach`` is False);
        every other supported input yields a list with one entry per batch item. A tensor
        whose 0-th dimension is empty yields an empty list.

    Raises:
        NotImplementedError: if `batch` is of a type that cannot be de-collated.
    """
    if batch is None:
        return batch
    if isinstance(batch, (float, int, str, bytes)):
        return batch
    if isinstance(batch, torch.Tensor):
        if detach:
            batch = batch.detach()
        if batch.ndim == 0:
            return batch.item() if detach else batch
        if batch.shape[0] == 0:
            # An empty batch has no items; return early so the `out_list[0]`
            # inspection below cannot raise IndexError.
            return []
        out_list = torch.unbind(batch, dim=0)
        if out_list[0].ndim == 0 and detach:
            # A 1-d batch of scalars becomes a list of Python numbers.
            return [t.item() for t in out_list]
        return list(out_list)
    if isinstance(batch, Mapping):
        # De-collate every value, then regroup into one dict per batch item.
        _dict_list = {
            key: decollate_batch(batch[key], detach)
            for key in batch
        }
        return [
            dict(zip(_dict_list, item))
            for item in zip(*_dict_list.values())
        ]
    if isinstance(batch, Iterable):
        item_0 = first(batch)
        if (not isinstance(item_0, Iterable)
                or isinstance(item_0, (str, bytes))
                or (isinstance(item_0, torch.Tensor) and item_0.ndim == 0)):
            # Not running the usual list decollate here:
            # don't decollate ['test', 'test'] into [['t', 't'], ['e', 'e'], ['s', 's'], ['t', 't']]
            # torch.tensor(0) is iterable but iter(torch.tensor(0)) raises TypeError: iteration over a 0-d tensor
            return [decollate_batch(b, detach) for b in batch]
        return [
            list(item)
            for item in zip(*(decollate_batch(b, detach) for b in batch))
        ]
    raise NotImplementedError(
        f"Unable to de-collate: {batch}, type: {type(batch)}.")
val_image_trans = Compose([ScaleIntensity(), AddChannel(), ToTensor(),]) val_seg_trans = Compose([AddChannel(), ToTensor()]) val_ds = ArrayDataset(test_images, val_image_trans, test_segs, val_seg_trans) val_loader = DataLoader( dataset=val_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=torch.cuda.is_available(), ) # %timeit first(loader) batch = first(loader) im = batch["img"] seg = batch["seg"] print(im.shape, im.min(), im.max(), seg.shape) plt.imshow(im[0, 0].numpy() + seg[0, 0].numpy(), cmap="gray") #%% #Using UNet for the segmentation network #Training scheme: For each epoch the training is made on each batch of iamges from the training set, assim the training is made with each image once, and then is evaluated with the validation set. net = monai.networks.nets.UNet( dimensions=2, in_channels=1,