def linked_batch_augmentation(input_batch, preds_unsup):
    # Teacher transform: the same augmentation is applied jointly to the
    # unlabeled input and to the teacher predictions so they stay aligned.
    teacher_transform = tv.transforms.Compose([
        mt_transforms.ToPIL(labeled=False),
        mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                       sigma_range=(3.5, 4.0),
                                       p=0.3, labeled=False),
        mt_transforms.RandomAffine(degrees=4.6,
                                   scale=(0.98, 1.02),
                                   translate=(0.03, 0.03),
                                   labeled=False),
        mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
        mt_transforms.ToTensor(labeled=False),
    ])

    input_batch_size = input_batch.size(0)

    input_batch_cpu = input_batch.cpu().detach()
    input_batch_cpu = input_batch_cpu.numpy()

    preds_unsup_cpu = preds_unsup.cpu().detach()
    preds_unsup_cpu = preds_unsup_cpu.numpy()

    samples_linked_aug = []
    for sample_idx in range(input_batch_size):
        sample_linked_aug = {
            'input': [input_batch_cpu[sample_idx],
                      preds_unsup_cpu[sample_idx]]
        }
        out = teacher_transform(sample_linked_aug)
        samples_linked_aug.append(out)

    samples_linked_aug = mt_datasets.mt_collate(samples_linked_aug)
    return samples_linked_aug
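# Minimal usage sketch (illustrative only): `model_ema` and `unlabeled_batch`
# are assumed names for the EMA teacher network and a CUDA batch of unlabeled
# target images; they are not defined in this snippet.
# with torch.no_grad():
#     teacher_preds = model_ema(unlabeled_batch)
# linked = linked_batch_augmentation(unlabeled_batch, teacher_preds)
# student_input = linked['input'][0].cuda()      # augmented input for the student
# consistency_target = linked['input'][1].cuda() # identically augmented teacher preds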
def run_main():
    train_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                       sigma_range=(3.5, 4.0),
                                       p=0.3),
        mt_transforms.RandomAffine(degrees=4.6,
                                   scale=(0.98, 1.02),
                                   translate=(0.03, 0.03)),
        mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    val_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Here we assume that the SC GM Challenge data is inside the folder
    # "../data" and it was previously resampled.
    gmdataset_train = mt_datasets.SCGMChallenge2DTrain(
        root_dir="../data",
        subj_ids=range(1, 9),
        transform=train_transform,
        slice_filter_fn=mt_filters.SliceFilter())

    # Here we assume that the SC GM Challenge data is inside the folder
    # "../data" and it was previously resampled.
    gmdataset_val = mt_datasets.SCGMChallenge2DTrain(root_dir="../data",
                                                     subj_ids=range(9, 11),
                                                     transform=val_transform)

    train_loader = DataLoader(gmdataset_train, batch_size=16,
                              shuffle=True, pin_memory=True,
                              collate_fn=mt_datasets.mt_collate,
                              num_workers=1)

    val_loader = DataLoader(gmdataset_val, batch_size=16,
                            shuffle=True, pin_memory=True,
                            collate_fn=mt_datasets.mt_collate,
                            num_workers=1)

    model = mt_models.Unet(drop_rate=0.4, bn_momentum=0.1)
    model.cuda()

    num_epochs = 200
    initial_lr = 0.001

    optimizer = optim.Adam(model.parameters(), lr=initial_lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

    writer = SummaryWriter(log_dir="log_exp")

    for epoch in tqdm(range(1, num_epochs + 1)):
        start_time = time.time()

        scheduler.step()
        lr = scheduler.get_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)

        model.train()
        train_loss_total = 0.0
        num_steps = 0

        for i, batch in enumerate(train_loader):
            input_samples, gt_samples = batch["input"], batch["gt"]

            var_input = input_samples.cuda()
            var_gt = gt_samples.cuda(non_blocking=True)

            preds = model(var_input)

            loss = mt_losses.dice_loss(preds, var_gt)
            train_loss_total += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            num_steps += 1

            if epoch % 5 == 0:
                grid_img = vutils.make_grid(input_samples,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Input', grid_img, epoch)

                grid_img = vutils.make_grid(preds.data.cpu(),
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Predictions', grid_img, epoch)

                grid_img = vutils.make_grid(gt_samples,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Ground Truth', grid_img, epoch)

        train_loss_total_avg = train_loss_total / num_steps

        model.eval()
        val_loss_total = 0.0
        num_steps = 0

        metric_fns = [mt_metrics.dice_score,
                      mt_metrics.hausdorff_score,
                      mt_metrics.precision_score,
                      mt_metrics.recall_score,
                      mt_metrics.specificity_score,
                      mt_metrics.intersection_over_union,
                      mt_metrics.accuracy_score]

        metric_mgr = mt_metrics.MetricManager(metric_fns)

        for i, batch in enumerate(val_loader):
            input_samples, gt_samples = batch["input"], batch["gt"]

            with torch.no_grad():
                var_input = input_samples.cuda()
                var_gt = gt_samples.cuda(non_blocking=True)

                preds = model(var_input)
                loss = mt_losses.dice_loss(preds, var_gt)
                val_loss_total += loss.item()

            # Metrics computation
            gt_npy = gt_samples.numpy().astype(np.uint8)
            gt_npy = gt_npy.squeeze(axis=1)

            preds = preds.data.cpu().numpy()
            preds = threshold_predictions(preds)
            preds = preds.astype(np.uint8)
            preds = preds.squeeze(axis=1)

            metric_mgr(preds, gt_npy)

            num_steps += 1

        metrics_dict = metric_mgr.get_results()
        metric_mgr.reset()
        writer.add_scalars('metrics', metrics_dict, epoch)

        val_loss_total_avg = val_loss_total / num_steps

        writer.add_scalars('losses', {
            'val_loss': val_loss_total_avg,
            'train_loss': train_loss_total_avg
        }, epoch)

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))

        writer.add_scalars('losses', {'train_loss': train_loss_total_avg},
                           epoch)
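# The validation loop above calls a `threshold_predictions` helper that is not
# shown in this excerpt. A minimal sketch of what it presumably does
# (binarizing the soft segmentation outputs at a fixed cutoff); the 0.5
# threshold is an assumption.
def threshold_predictions(predictions, thr=0.5):
    """Binarize soft segmentation outputs at `thr` (assumed to be 0.5)."""
    thresholded_preds = predictions.copy()
    thresholded_preds[thresholded_preds < thr] = 0.0
    thresholded_preds[thresholded_preds >= thr] = 1.0
    return thresholded_preds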
def cmd_train(context):
    """Main command to train the network.

    :param context: this is a dictionary with all data from the
                    configuration file:
                    - 'command': run the specified command (e.g. train, test)
                    - 'gpu': ID of the used GPU
                    - 'bids_path_train': list of relative paths of the BIDS
                      folders of each training center
                    - 'bids_path_validation': list of relative paths of the
                      BIDS folders of each validation center
                    - 'bids_path_test': list of relative paths of the BIDS
                      folders of each test center
                    - 'batch_size'
                    - 'dropout_rate'
                    - 'batch_norm_momentum'
                    - 'num_epochs'
                    - 'initial_lr': initial learning rate
                    - 'log_directory': folder name where log files are saved
    """
    # Set the GPU
    gpu_number = context["gpu"]
    torch.cuda.set_device(gpu_number)

    # These are the training transformations
    train_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((128, 128)),
        mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                       sigma_range=(3.5, 4.0),
                                       p=0.3),
        mt_transforms.RandomAffine(degrees=4.6,
                                   scale=(0.98, 1.02),
                                   translate=(0.03, 0.03)),
        mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # These are the validation/testing transformations
    val_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((128, 128)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # This code will iterate over the folders and load the data, filtering
    # the slices without labels and then concatenating all the datasets
    # together
    train_datasets = []
    for bids_ds in tqdm(context["bids_path_train"],
                        desc="Loading training set"):
        ds_train = loader.BidsDataset(bids_ds,
                                      transform=train_transform,
                                      slice_filter_fn=loader.SliceFilter())
        train_datasets.append(ds_train)

    ds_train = ConcatDataset(train_datasets)
    print(f"Loaded {len(ds_train)} axial slices for the training set.")
    train_loader = DataLoader(ds_train, batch_size=context["batch_size"],
                              shuffle=True, pin_memory=True,
                              collate_fn=mt_datasets.mt_collate,
                              num_workers=1)

    # Validation dataset ------------------------------------------------------
    validation_datasets = []
    for bids_ds in tqdm(context["bids_path_validation"],
                        desc="Loading validation set"):
        ds_val = loader.BidsDataset(bids_ds,
                                    transform=val_transform,
                                    slice_filter_fn=loader.SliceFilter())
        validation_datasets.append(ds_val)

    ds_val = ConcatDataset(validation_datasets)
    print(f"Loaded {len(ds_val)} axial slices for the validation set.")
    val_loader = DataLoader(ds_val, batch_size=context["batch_size"],
                            shuffle=True, pin_memory=True,
                            collate_fn=mt_datasets.mt_collate,
                            num_workers=1)

    model = M.Classifier(drop_rate=context["dropout_rate"],
                         bn_momentum=context["batch_norm_momentum"])
    model.cuda()

    num_epochs = context["num_epochs"]
    initial_lr = context["initial_lr"]

    # Using SGD with cosine annealing learning rate
    optimizer = optim.SGD(model.parameters(), lr=initial_lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

    # Write the metrics, images, etc to TensorBoard format
    writer = SummaryWriter(log_dir=context["log_directory"])

    # Cross Entropy Loss
    criterion = nn.CrossEntropyLoss()

    # Training loop -----------------------------------------------------------
    best_validation_loss = float("inf")

    lst_train_loss = []
    lst_val_loss = []
    lst_accuracy = []

    for epoch in tqdm(range(1, num_epochs + 1), desc="Training"):
        start_time = time.time()

        scheduler.step()
        lr = scheduler.get_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)

        model.train()
        train_loss_total = 0.0
        num_steps = 0

        for i, batch in enumerate(train_loader):
            input_samples = batch["input"]
            input_labels = get_modality(batch)

            var_input = input_samples.cuda()
            var_labels = torch.cuda.LongTensor(input_labels).cuda(
                non_blocking=True)

            outputs = model(var_input)

            loss = criterion(outputs, var_labels)
            train_loss_total += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            num_steps += 1

        train_loss_total_avg = train_loss_total / num_steps
        lst_train_loss.append(train_loss_total_avg)
        tqdm.write(f"Epoch {epoch} training loss: {train_loss_total_avg:.4f}.")

        # Validation loop -----------------------------------------------------
        model.eval()
        val_loss_total = 0.0
        num_steps = 0
        val_accuracy = 0
        num_samples = 0

        for i, batch in enumerate(val_loader):
            input_samples = batch["input"]
            input_labels = get_modality(batch)

            with torch.no_grad():
                var_input = input_samples.cuda()
                var_labels = torch.cuda.LongTensor(input_labels).cuda(
                    non_blocking=True)

                outputs = model(var_input)
                _, preds = torch.max(outputs, 1)

                loss = criterion(outputs, var_labels)
                val_loss_total += loss.item()
                val_accuracy += int((var_labels == preds).sum())

            num_steps += 1
            num_samples += context['batch_size']

        val_loss_total_avg = val_loss_total / num_steps
        lst_val_loss.append(val_loss_total_avg)
        tqdm.write(f"Epoch {epoch} validation loss: {val_loss_total_avg:.4f}.")

        val_accuracy_avg = 100 * val_accuracy / num_samples
        lst_accuracy.append(val_accuracy_avg)
        tqdm.write(f"Epoch {epoch} accuracy: {val_accuracy_avg:.4f}.")

        # add metrics for tensorboard
        writer.add_scalars('validation metrics', {
            'accuracy': val_accuracy_avg,
        }, epoch)

        writer.add_scalars('losses', {
            'train_loss': train_loss_total_avg,
            'val_loss': val_loss_total_avg,
        }, epoch)

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))

        if val_loss_total_avg < best_validation_loss:
            best_validation_loss = val_loss_total_avg
            torch.save(model, "./" + context["log_directory"] + "/best_model.pt")

    # save final model
    torch.save(model, "./" + context["log_directory"] + "/final_model.pt")

    # save the metrics
    # NOTE: `parameters` contains '/' characters, so the savefig path below
    # only works if the corresponding directories already exist.
    parameters = "CrossEntropyLoss/batchsize=" + str(context['batch_size'])
    parameters += "/initial_lr=" + str(context['initial_lr'])
    parameters += "/dropout=" + str(context['dropout_rate'])

    plt.subplot(2, 1, 1)
    plt.title(parameters)
    plt.plot(lst_train_loss, color='red', label='Training')
    plt.plot(lst_val_loss, color='blue', label='Validation')
    plt.legend(loc='upper right')
    plt.ylabel('Loss')

    plt.subplot(2, 1, 2)
    plt.plot(lst_accuracy)
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.savefig(parameters + '.png')

    return
import numpy as np

from batchgenerators.transforms.color_transforms import GammaTransform
from batchgenerators.transforms.spatial_transforms import MirrorTransform, SpatialTransform, ZoomTransform
from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform

from medicaltorch import transforms as mt_transforms
from medicaltorch import losses as mt_losses

from torchvision import transforms

train_transform = transforms.Compose([
    # mt_transforms.CenterCrop2D((200, 200)),
    mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                   sigma_range=(3.5, 4.0),
                                   p=0.3),
    mt_transforms.RandomAffine(degrees=4.6,
                               scale=(0.98, 1.02),
                               translate=(0.03, 0.03)),
    mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
    mt_transforms.ToTensor()
    # mt_transforms.NormalizeInstance(),
])

gamma_t = GammaTransform(data_key="img", gamma_range=(0.1, 10))
mirror_t = MirrorTransform(data_key="img", label_key="seg")
spatial_t = SpatialTransform(patch_size=(8, 8, 8), data_key="img",
                             label_key="seg")
gauss_noise_t = GaussianNoiseTransform(data_key="img",
                                       noise_variance=(0, 1))
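# Minimal sketch of how one of the batchgenerators transforms above could be
# applied: they take keyword dictionaries of NumPy arrays keyed by
# data_key/label_key (hence "img" and "seg") and return the same dict. The
# shapes below are assumptions for illustration only.
# dummy_batch = {
#     "img": np.random.rand(2, 1, 16, 16, 16).astype(np.float32),
#     "seg": np.random.randint(0, 2, size=(2, 1, 16, 16, 16)).astype(np.float32),
# }
# out = gauss_noise_t(**dummy_batch)
# print(out["img"].shape, out["seg"].shape)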
def __init__(self, tcia_folder=None, brats_folder=None, lgg_folder=None,
             hgg_folder=None, type_str='T1', stage='lgg', flag_3d=False,
             mode='train', channel_size_3d=32, mri_slice_dim=128, aug=False):
    assert stage in ['lgg', 'hgg']
    assert type_str in ['T1', 'T2', 'FLAIR']
    Dataset.__init__(self)

    self.flag_3d = flag_3d
    self.tcia_folder = tcia_folder
    self.brats_folder = brats_folder
    self.lgg_folder = lgg_folder
    self.hgg_folder = hgg_folder

    if stage == 'lgg':
        self.folders = glob.glob(self.tcia_folder + '/*')
        self.dataset_types = ['tcia' for _ in range(len(self.folders))]

        brats = glob.glob(self.brats_folder + '/LGG/*')
        self.folders.extend(brats)
        self.dataset_types.extend(['brats' for _ in range(len(brats))])

        lgg = [x for x in glob.glob(self.lgg_folder + '/*')
               if os.path.isdir(x)]
        self.folders.extend(lgg)
        self.dataset_types.extend(['lgg' for _ in range(len(lgg))])
    else:
        self.folders = glob.glob(self.brats_folder + '/HGG/*')
        self.dataset_types = ['brats' for _ in range(len(self.folders))]

        hgg = [x for x in glob.glob(self.hgg_folder + '/*')
               if os.path.isdir(x)]
        self.folders.extend(hgg)
        self.dataset_types.extend(['hgg' for _ in range(len(hgg))])

    self.type = type_str
    self.stage = stage
    self.channel_size_3d = channel_size_3d
    self.seg_mapping = {
        'tcia': 'Segmentation',
        'brats': 'seg',
        'lgg': ['GlistrBoost_ManuallyCorrected', 'GlistrBoost*'],
        'hgg': ['GlistrBoost_ManuallyCorrected', 'GlistrBoost*']
    }
    self.type_mapping = {
        'tcia': self.type + '*',
        'brats': self.type.lower(),
        'lgg': self.type.lower(),
        'hgg': self.type.lower()
    }

    self.segmentation_pairs = []
    for idx in range(len(self.folders)):
        if isinstance(self.seg_mapping[self.dataset_types[idx]], list):
            # Prefer the manually corrected segmentation when it exists,
            # otherwise fall back to the automatic one.
            try:
                seg_fname = glob.glob(
                    self.folders[idx] + '/*' +
                    self.seg_mapping[self.dataset_types[idx]][0] +
                    '.nii.gz')[0]
            except Exception:
                seg_fname = glob.glob(
                    self.folders[idx] + '/*' +
                    self.seg_mapping[self.dataset_types[idx]][1] +
                    '.nii.gz')[0]
        else:
            seg_fname = glob.glob(
                self.folders[idx] + '/*' +
                self.seg_mapping[self.dataset_types[idx]] + '.nii.gz')[0]

        vox_fname_list = glob.glob(
            self.folders[idx] + '/*' +
            self.type_mapping[self.dataset_types[idx]] + '.nii.gz')
        if not vox_fname_list:
            continue
        vox_fname = vox_fname_list[0]

        self.segmentation_pairs.append([vox_fname, seg_fname])

    # Train/val/test split pointers (80% / 10% / 10%)
    spl = [.8, .1, .1]
    train_ptr = int(spl[0] * len(self.segmentation_pairs))
    val_ptr = train_ptr + int(spl[1] * len(self.segmentation_pairs))

    if not flag_3d:
        if aug:
            train_transforms = transforms.Compose([
                MTResize((mri_slice_dim, mri_slice_dim)),
                transforms.RandomChoice([
                    mt_transforms.RandomRotation(30),
                    mt_transforms.ElasticTransform(alpha=2000, sigma=50),
                    mt_transforms.AdditiveGaussianNoise(mean=0.05, std=0.01),
                    mt_transforms.RandomAffine(degrees=30,  # assumed value; `degrees` was an undefined name in the original
                                               translate=0.2,
                                               scale=(0.8, 1.2),
                                               shear=0.2)
                ]),
                mt_transforms.ToTensor(),
                MTNormalize()
            ])
        else:
            train_transforms = transforms.Compose([
                MTResize((mri_slice_dim, mri_slice_dim)),
                mt_transforms.ToTensor(),
                MTNormalize()
            ])
        val_transforms = transforms.Compose([
            transforms.Resize((mri_slice_dim, mri_slice_dim)),
            mt_transforms.ToTensor(),
            MTNormalize()
        ])
        train_unnormalized = train_transforms
    else:
        if aug:
            train_transforms = transforms.Compose([
                ToPILImage3D(),
                Resize3D((mri_slice_dim, mri_slice_dim)),
                transforms.RandomChoice([
                    RandomHorizontalFlip3D(),
                    RandomVerticalFlip3D(),
                    RandomRotation3D(30),
                    RandomShear3D(45, translate=.4, scale=(.7, 1.3), shear=.2)
                ]),
                ToTensor3D(),
                Normalize3D('min_max')
            ])
            train_unnormalized = transforms.Compose([
                ToPILImage3D(),
                Resize3D((mri_slice_dim, mri_slice_dim)),
                transforms.RandomChoice([
                    RandomHorizontalFlip3D(),
                    RandomVerticalFlip3D(),
                    RandomRotation3D(30)
                ]),
                ToTensor3D(),
            ])
        else:
            train_transforms = transforms.Compose([
                ToPILImage3D(),
                Resize3D((mri_slice_dim, mri_slice_dim)),
                ToTensor3D(),
                Normalize3D('min_max')
            ])
            train_unnormalized = transforms.Compose([
                ToPILImage3D(),
                Resize3D((mri_slice_dim, mri_slice_dim)),
                ToTensor3D(),
            ])
        val_transforms = transforms.Compose([
            ToPILImage3D(),
            Resize3D((mri_slice_dim, mri_slice_dim)),
            ToTensor3D(),
            IndividualNormalize3D(),
        ])

    if mode == 'train':
        self.segmentation_pairs = self.segmentation_pairs[:train_ptr]
        self.transforms = train_transforms
        self.seg_transforms = train_unnormalized
    elif mode == 'val':
        self.segmentation_pairs = self.segmentation_pairs[train_ptr:val_ptr]
        self.transforms = val_transforms
        self.seg_transforms = train_unnormalized
    else:
        self.segmentation_pairs = self.segmentation_pairs[val_ptr:]
        self.transforms = val_transforms
        self.seg_transforms = train_unnormalized

    if not flag_3d:
        self.twod_slices_dataset = mt_datasets.MRI2DSegmentationDataset(
            self.segmentation_pairs, transform=self.transforms)
from medicaltorch import transforms as mt_transforms
from medicaltorch import losses as mt_losses

from torchvision import transforms

packed_transforms = [
    mt_transforms.RandomRotation(degrees=(90, 180)),
    mt_transforms.ElasticTransform(),
    mt_transforms.AdditiveGaussianNoise(mean=0.0, std=0.05),
    mt_transforms.RandomAffine(),
    mt_transforms.ToTensor()
]
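# Minimal sketch of how such a list is typically wrapped and applied:
# medicaltorch transforms operate on a sample dict with 'input' and 'gt'
# entries rather than raw tensors. Note that ElasticTransform and RandomAffine
# normally require explicit parameters (alpha_range/sigma_range and degrees),
# so they are omitted here; the dummy PIL sample and values are assumptions.
# import numpy as np
# from PIL import Image
#
# composed = transforms.Compose([
#     mt_transforms.RandomRotation(degrees=(90, 180)),
#     mt_transforms.AdditiveGaussianNoise(mean=0.0, std=0.05),
#     mt_transforms.ToTensor(),
# ])
#
# sample = {
#     'input': Image.fromarray(np.random.rand(128, 128).astype(np.float32), mode='F'),
#     'gt': Image.fromarray(np.zeros((128, 128), dtype=np.float32), mode='F'),
# }
# augmented = composed(sample)  # augmented['input'] is now a torch.Tensor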
def cmd_train(ctx):
    global_step = 0

    num_workers = ctx["num_workers"]
    num_epochs = ctx["num_epochs"]
    experiment_name = ctx["experiment_name"]
    cons_weight = ctx["cons_weight"]
    initial_lr = ctx["initial_lr"]
    consistency_rampup = ctx["consistency_rampup"]
    weight_decay = ctx["weight_decay"]
    rootdir_gmchallenge_train = ctx["rootdir_gmchallenge_train"]
    rootdir_gmchallenge_test = ctx["rootdir_gmchallenge_test"]
    supervised_only = ctx["supervised_only"]

    # Optional experiment_name suffix with the main hyper-parameters
    # experiment_name += '-e%s-cw%s-lr%s-cr%s-lramp%s-wd%s-cl%s-sc%s-ac%s-vc%s' % \
    #     (num_epochs, cons_weight, initial_lr, consistency_rampup,
    #      ctx["initial_lr_rampup"], weight_decay, ctx["consistency_loss"],
    #      ctx["source_centers"], ctx["adapt_centers"], ctx["val_centers"])

    # Decay for learning rate
    if "constant" in ctx["decay_lr"]:
        decay_lr_fn = decay_constant_lr
    if "poly" in ctx["decay_lr"]:
        decay_lr_fn = decay_poly_lr
    if "cosine" in ctx["decay_lr"]:
        decay_lr_fn = cosine_lr

    # Consistency loss
    #
    # mse = Mean Squared Error
    # dice = Dice loss
    # cross_entropy = Cross Entropy
    # mse_confident = MSE with Confidence Threshold
    if ctx["consistency_loss"] == "dice":
        consistency_loss_fn = mt_losses.dice_loss
    if ctx["consistency_loss"] == "mse":
        consistency_loss_fn = F.mse_loss
    if ctx["consistency_loss"] == "cross_entropy":
        consistency_loss_fn = F.binary_cross_entropy
    if ctx["consistency_loss"] == "mse_confident":
        confidence_threshold = ctx["confidence_threshold"]
        consistency_loss_fn = mt_losses.ConfidentMSELoss(confidence_threshold)

    # Xs, Ys = Source input and source label, train
    # Xt1, Xt2 = Target, domain adaptation, no label, different aug (same sample), train
    # Xv, Yv = Target input and target label, validation

    # Sample Xs and Ys from this
    source_train = mt_datasets.SCGMChallenge2DTrain(
        rootdir_gmchallenge_train,
        slice_filter_fn=mt_filters.SliceFilter(),
        site_ids=ctx["source_centers"],
        # Test = 1,2,3, train = 1,2
        subj_ids=range(1, 11))

    # Sample Xt1, Xt2 from this
    unlabeled_filter = mt_filters.SliceFilter(filter_empty_mask=False)
    target_adapt_train = mt_datasets.SCGMChallenge2DTest(
        rootdir_gmchallenge_test,
        slice_filter_fn=unlabeled_filter,
        site_ids=ctx["adapt_centers"],
        # 3 = train, 4 = test
        subj_ids=range(11, 21))

    # Sample Xv, Yv from this
    validation_centers = []
    for center in ctx["val_centers"]:
        validation_centers.append(
            mt_datasets.SCGMChallenge2DTrain(
                rootdir_gmchallenge_train,
                slice_filter_fn=mt_filters.SliceFilter(),
                site_ids=[center],
                # 3 = train, 4 = test
                subj_ids=range(1, 11)))

    # Training source data augmentation
    source_transform = tv.transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                       sigma_range=(3.5, 4.0),
                                       p=0.3),
        mt_transforms.RandomAffine(degrees=4.6,
                                   scale=(0.98, 1.02),
                                   translate=(0.03, 0.03)),
        mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Target adaptation data augmentation
    target_adapt_transform = tv.transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200), labeled=False),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Target validation data augmentation
    target_val_adapt_transform = tv.transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    source_train.set_transform(source_transform)
    target_adapt_train.set_transform(target_adapt_transform)

    for center in validation_centers:
        center.set_transform(target_val_adapt_transform)

    source_train_loader = DataLoader(source_train,
                                     batch_size=ctx["source_batch_size"],
                                     shuffle=True, drop_last=True,
                                     num_workers=num_workers,
                                     collate_fn=mt_datasets.mt_collate,
                                     pin_memory=True)

    target_adapt_train_loader = DataLoader(target_adapt_train,
                                           batch_size=ctx["target_batch_size"],
                                           shuffle=True, drop_last=True,
                                           num_workers=num_workers,
                                           collate_fn=mt_datasets.mt_collate,
                                           pin_memory=True)

    validation_centers_loaders = []
    for center in validation_centers:
        validation_centers_loaders.append(
            DataLoader(center,
                       batch_size=ctx["target_batch_size"],
                       shuffle=False, drop_last=False,
                       num_workers=num_workers,
                       collate_fn=mt_datasets.mt_collate,
                       pin_memory=True))

    model = create_model(ctx)

    if not supervised_only:
        model_ema = create_model(ctx, ema=True)
    else:
        model_ema = None

    optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr,
                                 weight_decay=weight_decay)

    writer = SummaryWriter(log_dir="log_{}".format(experiment_name))

    # Training loop
    for epoch in tqdm(range(1, num_epochs + 1), desc="Epochs"):
        start_time = time.time()

        # Rampup -----
        initial_lr_rampup = ctx["initial_lr_rampup"]

        if initial_lr_rampup > 0:
            if epoch <= initial_lr_rampup:
                lr = initial_lr * sigmoid_rampup(epoch, initial_lr_rampup)
            else:
                lr = decay_lr_fn(epoch - initial_lr_rampup,
                                 num_epochs - initial_lr_rampup,
                                 initial_lr)
        else:
            lr = decay_lr_fn(epoch, num_epochs, initial_lr)

        writer.add_scalar('learning_rate', lr, epoch)

        for param_group in optimizer.param_groups:
            tqdm.write("Learning Rate: {:.6f}".format(lr))
            param_group['lr'] = lr

        consistency_weight = get_current_consistency_weight(
            cons_weight, epoch, consistency_rampup)
        writer.add_scalar('consistency_weight', consistency_weight, epoch)

        # Train mode
        model.train()
        if not supervised_only:
            model_ema.train()

        composite_loss_total = 0.0
        class_loss_total = 0.0
        consistency_loss_total = 0.0
        num_steps = 0

        target_adapt_train_iter = iter(target_adapt_train_loader)

        for i, train_batch in enumerate(source_train_loader):
            # Keys: 'input', 'gt', 'input_metadata', 'gt_metadata'

            # Supervised component --------------------------------------------
            train_input, train_gt = train_batch["input"], train_batch["gt"]
            train_input = train_input.cuda()
            train_gt = train_gt.cuda(non_blocking=True)

            preds_supervised = model(train_input)
            class_loss = mt_losses.dice_loss(preds_supervised, train_gt)

            if not supervised_only:
                # Unsupervised component --------------------------------------
                try:
                    target_adapt_batch = next(target_adapt_train_iter)
                except StopIteration:
                    target_adapt_train_iter = iter(target_adapt_train_loader)
                    target_adapt_batch = next(target_adapt_train_iter)

                target_adapt_input = target_adapt_batch["input"]
                target_adapt_input = target_adapt_input.cuda()

                # Teacher forward
                with torch.no_grad():
                    teacher_preds_unsup = model_ema(target_adapt_input)

                linked_aug_batch = \
                    linked_batch_augmentation(target_adapt_input,
                                              teacher_preds_unsup)

                adapt_input_batch = linked_aug_batch['input'][0].cuda()
                teacher_preds_unsup_aug = linked_aug_batch['input'][1].cuda()

                # Student forward
                student_preds_unsup = model(adapt_input_batch)

                consistency_loss = consistency_weight * consistency_loss_fn(
                    student_preds_unsup, teacher_preds_unsup_aug)
            else:
                consistency_loss = torch.FloatTensor([0.]).cuda()

            composite_loss = class_loss + consistency_loss

            optimizer.zero_grad()
            composite_loss.backward()
            optimizer.step()

            composite_loss_total += composite_loss.item()
            consistency_loss_total += consistency_loss.item()
            class_loss_total += class_loss.item()

            num_steps += 1
            global_step += 1

            if not supervised_only:
                if epoch <= ctx["ema_late_epoch"]:
                    update_ema_variables(model, model_ema,
                                         ctx["ema_alpha"], global_step)
                else:
                    update_ema_variables(model, model_ema,
                                         ctx["ema_alpha_late"], global_step)

        # Write histogram of the probs
        if not supervised_only:
            npy_teacher_preds = teacher_preds_unsup.detach().cpu().numpy()
            writer.add_histogram("Teacher Preds Hist", npy_teacher_preds, epoch)

            npy_student_preds = student_preds_unsup.detach().cpu().numpy()
            writer.add_histogram("Student Preds Hist", npy_student_preds, epoch)

        npy_supervised_preds = preds_supervised.detach().cpu().numpy()
        writer.add_histogram("Supervised Preds Hist", npy_supervised_preds, epoch)

        composite_loss_avg = composite_loss_total / num_steps
        class_loss_avg = class_loss_total / num_steps
        consistency_loss_avg = consistency_loss_total / num_steps

        tqdm.write("Steps p/ Epoch: {}".format(num_steps))
        tqdm.write("Consistency Weight: {:.6f}".format(consistency_weight))
        tqdm.write("Composite Loss: {:.6f}".format(composite_loss_avg))
        tqdm.write("Class Loss: {:.6f}".format(class_loss_avg))
        tqdm.write("Consistency Loss: {:.6f}".format(consistency_loss_avg))

        # Write sample images
        if ctx["write_images"] and epoch % ctx["write_images_interval"] == 0:
            try:
                plot_img = vutils.make_grid(preds_supervised,
                                            normalize=True, scale_each=True)
                writer.add_image('Train Source Prediction', plot_img, epoch)

                plot_img = vutils.make_grid(train_input,
                                            normalize=True, scale_each=True)
                writer.add_image('Train Source Input', plot_img, epoch)

                plot_img = vutils.make_grid(train_gt,
                                            normalize=True, scale_each=True)
                writer.add_image('Train Source Ground Truth', plot_img, epoch)

                # Unsupervised component viz (the teacher sees the raw target
                # input, the student sees the augmented one)
                if not supervised_only:
                    plot_img = vutils.make_grid(target_adapt_input,
                                                normalize=True, scale_each=True)
                    writer.add_image('Train Target Teacher Input', plot_img, epoch)

                    plot_img = vutils.make_grid(teacher_preds_unsup,
                                                normalize=True, scale_each=True)
                    writer.add_image('Train Target Teacher Preds', plot_img, epoch)

                    plot_img = vutils.make_grid(adapt_input_batch,
                                                normalize=True, scale_each=True)
                    writer.add_image('Train Target Student Input', plot_img, epoch)

                    plot_img = vutils.make_grid(student_preds_unsup,
                                                normalize=True, scale_each=True)
                    writer.add_image('Train Target Student Preds', plot_img, epoch)

                    plot_img = vutils.make_grid(teacher_preds_unsup_aug,
                                                normalize=True, scale_each=True)
                    writer.add_image('Train Target Teacher Preds (augmented)',
                                     plot_img, epoch)
            except Exception:
                tqdm.write("*** Error writing images ***")

        writer.add_scalars('losses', {
            'composite_loss': composite_loss_avg,
            'class_loss': class_loss_avg,
            'consistency_loss': consistency_loss_avg
        }, epoch)

        # Evaluation mode
        model.eval()
        if not supervised_only:
            model_ema.eval()

        metric_fns = [mt_metrics.dice_score,
                      mt_metrics.jaccard_score,
                      mt_metrics.hausdorff_score,
                      mt_metrics.precision_score,
                      mt_metrics.recall_score,
                      mt_metrics.specificity_score,
                      mt_metrics.intersection_over_union,
                      mt_metrics.accuracy_score]

        for center, loader in enumerate(validation_centers_loaders):
            validation(model, model_ema, loader, writer, metric_fns, epoch,
                       ctx, 'val_%s' % ctx["val_centers"][center])

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))
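# The loop above relies on helpers such as `update_ema_variables`,
# `sigmoid_rampup`, and `get_current_consistency_weight` that are not shown in
# this excerpt. A minimal sketch of the usual mean-teacher versions of these
# helpers (following Tarvainen & Valpola's formulation); the exact ramp-up
# schedule and decay handling are assumptions.
import numpy as np
import torch


def update_ema_variables(model, model_ema, alpha, global_step):
    # Ramp the effective decay up from 0 so the teacher tracks the student
    # closely at the very start of training, then settles at `alpha`.
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(model_ema.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)


def sigmoid_rampup(current, rampup_length):
    # Exponential, sigmoid-shaped ramp-up from 0 to 1 over `rampup_length` epochs.
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))


def get_current_consistency_weight(weight, epoch, rampup):
    # Scale the consistency weight by the ramp-up factor for the current epoch.
    return weight * sigmoid_rampup(epoch, rampup)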