def unet3d_loss(preds, label_batch, bce_weight):
    # BCE is computed on the raw logits, Dice on the sigmoid probabilities;
    # the two terms are blended by bce_weight.
    pred = torch.sigmoid(preds)
    bce = F.binary_cross_entropy_with_logits(preds, label_batch)
    dice = mt_losses.dice_loss(pred, label_batch)
    loss = bce * bce_weight + dice * (1 - bce_weight)
    return loss
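The listings below reference several modules without showing their imports. A plausible import block, assuming the medicaltorch package layout (datasets, models, transforms, losses, metrics and filters submodules) and tensorboardX for logging, would be:

# Assumed imports for the listings in this section; module names follow the
# medicaltorch package layout and may need adjusting to the local project.
import time
from collections import defaultdict

import numpy as np
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader

import torchvision as tv
import torchvision.utils as vutils
from torchvision import transforms

from tqdm import tqdm
from tensorboardX import SummaryWriter

from medicaltorch import datasets as mt_datasets
from medicaltorch import filters as mt_filters
from medicaltorch import losses as mt_losses
from medicaltorch import metrics as mt_metrics
from medicaltorch import models as mt_models
from medicaltorch import transforms as mt_transforms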
def run_main():
    train_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                       sigma_range=(3.5, 4.0),
                                       p=0.3),
        mt_transforms.RandomAffine(degrees=4.6,
                                   scale=(0.98, 1.02),
                                   translate=(0.03, 0.03)),
        mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    val_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Here we assume that the SC GM Challenge data is inside the folder
    # "../data" and that it was previously resampled.
    gmdataset_train = mt_datasets.SCGMChallenge2DTrain(
        root_dir="../data",
        subj_ids=range(1, 9),
        transform=train_transform,
        slice_filter_fn=mt_filters.SliceFilter())

    # Subjects 9 and 10 are held out for validation.
    gmdataset_val = mt_datasets.SCGMChallenge2DTrain(
        root_dir="../data",
        subj_ids=range(9, 11),
        transform=val_transform)

    train_loader = DataLoader(gmdataset_train, batch_size=16,
                              shuffle=True, pin_memory=True,
                              collate_fn=mt_datasets.mt_collate,
                              num_workers=1)

    val_loader = DataLoader(gmdataset_val, batch_size=16,
                            shuffle=True, pin_memory=True,
                            collate_fn=mt_datasets.mt_collate,
                            num_workers=1)

    model = mt_models.Unet(drop_rate=0.4, bn_momentum=0.1)
    model.cuda()

    num_epochs = 200
    initial_lr = 0.001

    optimizer = optim.Adam(model.parameters(), lr=initial_lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

    writer = SummaryWriter(log_dir="log_exp")

    for epoch in tqdm(range(1, num_epochs + 1)):
        start_time = time.time()

        lr = scheduler.get_last_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)

        model.train()
        train_loss_total = 0.0
        num_steps = 0

        for i, batch in enumerate(train_loader):
            input_samples, gt_samples = batch["input"], batch["gt"]

            var_input = input_samples.cuda()
            var_gt = gt_samples.cuda(non_blocking=True)

            preds = model(var_input)

            loss = mt_losses.dice_loss(preds, var_gt)
            train_loss_total += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            num_steps += 1

            if epoch % 5 == 0:
                grid_img = vutils.make_grid(input_samples,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Input', grid_img, epoch)

                grid_img = vutils.make_grid(preds.data.cpu(),
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Predictions', grid_img, epoch)

                grid_img = vutils.make_grid(gt_samples,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Ground Truth', grid_img, epoch)

        train_loss_total_avg = train_loss_total / num_steps

        # Since PyTorch 1.1, the scheduler should be stepped after the
        # optimizer updates, hence it is advanced once per epoch here.
        scheduler.step()

        model.eval()
        val_loss_total = 0.0
        num_steps = 0

        metric_fns = [mt_metrics.dice_score,
                      mt_metrics.hausdorff_score,
                      mt_metrics.precision_score,
                      mt_metrics.recall_score,
                      mt_metrics.specificity_score,
                      mt_metrics.intersection_over_union,
                      mt_metrics.accuracy_score]

        metric_mgr = mt_metrics.MetricManager(metric_fns)

        for i, batch in enumerate(val_loader):
            input_samples, gt_samples = batch["input"], batch["gt"]

            with torch.no_grad():
                var_input = input_samples.cuda()
                var_gt = gt_samples.cuda(non_blocking=True)

                preds = model(var_input)
                loss = mt_losses.dice_loss(preds, var_gt)
                val_loss_total += loss.item()

            # Metrics computation
            gt_npy = gt_samples.numpy().astype(np.uint8)
            gt_npy = gt_npy.squeeze(axis=1)

            preds = preds.data.cpu().numpy()
            preds = threshold_predictions(preds)
            preds = preds.astype(np.uint8)
            preds = preds.squeeze(axis=1)

            metric_mgr(preds, gt_npy)
            num_steps += 1

        metrics_dict = metric_mgr.get_results()
        metric_mgr.reset()
        writer.add_scalars('metrics', metrics_dict, epoch)

        val_loss_total_avg = val_loss_total / num_steps

        writer.add_scalars('losses', {
            'val_loss': val_loss_total_avg,
            'train_loss': train_loss_total_avg
        }, epoch)

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))
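The helper threshold_predictions is used above (and again in the validation function further down) but is not defined in this listing. A minimal sketch that binarizes the network probabilities at a fixed cutoff could look like the following; the 0.5 threshold and the exact signature are assumptions, not the original implementation:

def threshold_predictions(predictions, thr=0.5):
    # Sketch only: binarize sigmoid outputs at a fixed threshold (assumed 0.5).
    thresholded_preds = predictions.copy()
    thresholded_preds[thresholded_preds < thr] = 0.0
    thresholded_preds[thresholded_preds >= thr] = 1.0
    return thresholded_preds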
def cmd_train(ctx):
    global_step = 0

    num_workers = ctx["num_workers"]
    num_epochs = ctx["num_epochs"]
    experiment_name = ctx["experiment_name"]
    cons_weight = ctx["cons_weight"]
    initial_lr = ctx["initial_lr"]
    consistency_rampup = ctx["consistency_rampup"]
    weight_decay = ctx["weight_decay"]
    rootdir_gmchallenge_train = ctx["rootdir_gmchallenge_train"]
    rootdir_gmchallenge_test = ctx["rootdir_gmchallenge_test"]
    supervised_only = ctx["supervised_only"]

    # Optionally, the main hyperparameters can be encoded in the experiment name:
    # experiment_name += '-e%s-cw%s-lr%s-cr%s-lramp%s-wd%s-cl%s-sc%s-ac%s-vc%s' % \
    #     (num_epochs, cons_weight, initial_lr, consistency_rampup,
    #      ctx["initial_lr_rampup"], weight_decay, ctx["consistency_loss"],
    #      ctx["source_centers"], ctx["adapt_centers"], ctx["val_centers"])

    # Decay for the learning rate
    if "constant" in ctx["decay_lr"]:
        decay_lr_fn = decay_constant_lr
    if "poly" in ctx["decay_lr"]:
        decay_lr_fn = decay_poly_lr
    if "cosine" in ctx["decay_lr"]:
        decay_lr_fn = cosine_lr

    # Consistency loss
    #
    # mse           = Mean Squared Error
    # dice          = Dice loss
    # cross_entropy = Cross Entropy
    # mse_confident = MSE with Confidence Threshold
    if ctx["consistency_loss"] == "dice":
        consistency_loss_fn = mt_losses.dice_loss

    if ctx["consistency_loss"] == "mse":
        consistency_loss_fn = F.mse_loss

    if ctx["consistency_loss"] == "cross_entropy":
        consistency_loss_fn = F.binary_cross_entropy

    if ctx["consistency_loss"] == "mse_confident":
        confidence_threshold = ctx["confidence_threshold"]
        consistency_loss_fn = mt_losses.ConfidentMSELoss(confidence_threshold)

    # Xs, Ys   = Source input and source label, train
    # Xt1, Xt2 = Target, domain adaptation, no label, different aug (same sample), train
    # Xv, Yv   = Target input and target label, validation

    # Sample Xs and Ys from this
    source_train = mt_datasets.SCGMChallenge2DTrain(
        rootdir_gmchallenge_train,
        slice_filter_fn=mt_filters.SliceFilter(),
        site_ids=ctx["source_centers"],  # test = 1, 2, 3; train = 1, 2
        subj_ids=range(1, 11))

    # Sample Xt1, Xt2 from this
    unlabeled_filter = mt_filters.SliceFilter(filter_empty_mask=False)
    target_adapt_train = mt_datasets.SCGMChallenge2DTest(
        rootdir_gmchallenge_test,
        slice_filter_fn=unlabeled_filter,
        site_ids=ctx["adapt_centers"],  # 3 = train, 4 = test
        subj_ids=range(11, 21))

    # Sample Xv, Yv from this
    validation_centers = []
    for center in ctx["val_centers"]:
        validation_centers.append(
            mt_datasets.SCGMChallenge2DTrain(
                rootdir_gmchallenge_train,
                slice_filter_fn=mt_filters.SliceFilter(),
                site_ids=[center],  # 3 = train, 4 = test
                subj_ids=range(1, 11)))

    # Training source data augmentation
    source_transform = tv.transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                       sigma_range=(3.5, 4.0),
                                       p=0.3),
        mt_transforms.RandomAffine(degrees=4.6,
                                   scale=(0.98, 1.02),
                                   translate=(0.03, 0.03)),
        mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Target adaptation data augmentation (unlabeled)
    target_adapt_transform = tv.transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200), labeled=False),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Target validation data augmentation
    target_val_adapt_transform = tv.transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    source_train.set_transform(source_transform)
    target_adapt_train.set_transform(target_adapt_transform)

    for center in validation_centers:
        center.set_transform(target_val_adapt_transform)

    source_train_loader = DataLoader(source_train,
                                     batch_size=ctx["source_batch_size"],
                                     shuffle=True,
                                     drop_last=True,
                                     num_workers=num_workers,
                                     collate_fn=mt_datasets.mt_collate,
                                     pin_memory=True)

    target_adapt_train_loader = DataLoader(target_adapt_train,
                                           batch_size=ctx["target_batch_size"],
                                           shuffle=True,
                                           drop_last=True,
                                           num_workers=num_workers,
                                           collate_fn=mt_datasets.mt_collate,
                                           pin_memory=True)

    validation_centers_loaders = []
    for center in validation_centers:
        validation_centers_loaders.append(
            DataLoader(center,
                       batch_size=ctx["target_batch_size"],
                       shuffle=False,
                       drop_last=False,
                       num_workers=num_workers,
                       collate_fn=mt_datasets.mt_collate,
                       pin_memory=True))

    model = create_model(ctx)

    if not supervised_only:
        model_ema = create_model(ctx, ema=True)
    else:
        model_ema = None

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=initial_lr,
                                 weight_decay=weight_decay)

    writer = SummaryWriter(log_dir="log_{}".format(experiment_name))

    # Training loop
    for epoch in tqdm(range(1, num_epochs + 1), desc="Epochs"):
        start_time = time.time()

        # Learning rate ramp-up followed by decay
        initial_lr_rampup = ctx["initial_lr_rampup"]

        if initial_lr_rampup > 0:
            if epoch <= initial_lr_rampup:
                lr = initial_lr * sigmoid_rampup(epoch, initial_lr_rampup)
            else:
                lr = decay_lr_fn(epoch - initial_lr_rampup,
                                 num_epochs - initial_lr_rampup,
                                 initial_lr)
        else:
            lr = decay_lr_fn(epoch, num_epochs, initial_lr)

        writer.add_scalar('learning_rate', lr, epoch)

        for param_group in optimizer.param_groups:
            tqdm.write("Learning Rate: {:.6f}".format(lr))
            param_group['lr'] = lr

        consistency_weight = get_current_consistency_weight(
            cons_weight, epoch, consistency_rampup)
        writer.add_scalar('consistency_weight', consistency_weight, epoch)

        # Train mode
        model.train()
        if not supervised_only:
            model_ema.train()

        composite_loss_total = 0.0
        class_loss_total = 0.0
        consistency_loss_total = 0.0

        num_steps = 0
        target_adapt_train_iter = iter(target_adapt_train_loader)

        for i, train_batch in enumerate(source_train_loader):
            # Keys: 'input', 'gt', 'input_metadata', 'gt_metadata'

            # Supervised component --------------------------------------------
            train_input, train_gt = train_batch["input"], train_batch["gt"]
            train_input = train_input.cuda()
            train_gt = train_gt.cuda(non_blocking=True)

            preds_supervised = model(train_input)
            class_loss = mt_losses.dice_loss(preds_supervised, train_gt)

            if not supervised_only:
                # Unsupervised component --------------------------------------
                try:
                    target_adapt_batch = next(target_adapt_train_iter)
                except StopIteration:
                    target_adapt_train_iter = iter(target_adapt_train_loader)
                    target_adapt_batch = next(target_adapt_train_iter)

                target_adapt_input = target_adapt_batch["input"]
                target_adapt_input = target_adapt_input.cuda()

                # Teacher forward
                with torch.no_grad():
                    teacher_preds_unsup = model_ema(target_adapt_input)

                linked_aug_batch = \
                    linked_batch_augmentation(target_adapt_input,
                                              teacher_preds_unsup)

                adapt_input_batch = linked_aug_batch['input'][0].cuda()
                teacher_preds_unsup_aug = linked_aug_batch['input'][1].cuda()

                # Student forward
                student_preds_unsup = model(adapt_input_batch)

                consistency_loss = consistency_weight * consistency_loss_fn(
                    student_preds_unsup, teacher_preds_unsup_aug)
            else:
                consistency_loss = torch.FloatTensor([0.]).cuda()

            composite_loss = class_loss + consistency_loss

            optimizer.zero_grad()
            composite_loss.backward()
            optimizer.step()

            composite_loss_total += composite_loss.item()
            consistency_loss_total += consistency_loss.item()
            class_loss_total += class_loss.item()

            num_steps += 1
            global_step += 1

            if not supervised_only:
                if epoch <= ctx["ema_late_epoch"]:
                    update_ema_variables(model, model_ema,
                                         ctx["ema_alpha"], global_step)
                else:
                    update_ema_variables(model, model_ema,
                                         ctx["ema_alpha_late"], global_step)

        # Write histograms of the predicted probabilities
        if not supervised_only:
            npy_teacher_preds = teacher_preds_unsup.detach().cpu().numpy()
            writer.add_histogram("Teacher Preds Hist", npy_teacher_preds, epoch)

            npy_student_preds = student_preds_unsup.detach().cpu().numpy()
            writer.add_histogram("Student Preds Hist", npy_student_preds, epoch)

        npy_supervised_preds = preds_supervised.detach().cpu().numpy()
        writer.add_histogram("Supervised Preds Hist", npy_supervised_preds, epoch)

        composite_loss_avg = composite_loss_total / num_steps
        class_loss_avg = class_loss_total / num_steps
        consistency_loss_avg = consistency_loss_total / num_steps

        tqdm.write("Steps p/ Epoch: {}".format(num_steps))
        tqdm.write("Consistency Weight: {:.6f}".format(consistency_weight))
        tqdm.write("Composite Loss: {:.6f}".format(composite_loss_avg))
        tqdm.write("Class Loss: {:.6f}".format(class_loss_avg))
        tqdm.write("Consistency Loss: {:.6f}".format(consistency_loss_avg))

        # Write sample images
        if ctx["write_images"] and epoch % ctx["write_images_interval"] == 0:
            try:
                plot_img = vutils.make_grid(preds_supervised,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Train Source Prediction', plot_img, epoch)

                plot_img = vutils.make_grid(train_input,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Train Source Input', plot_img, epoch)

                plot_img = vutils.make_grid(train_gt,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Train Source Ground Truth', plot_img, epoch)

                # Unsupervised component visualization
                if not supervised_only:
                    plot_img = vutils.make_grid(target_adapt_input,
                                                normalize=True,
                                                scale_each=True)
                    writer.add_image('Train Target Teacher Input',
                                     plot_img, epoch)

                    plot_img = vutils.make_grid(teacher_preds_unsup,
                                                normalize=True,
                                                scale_each=True)
                    writer.add_image('Train Target Teacher Preds',
                                     plot_img, epoch)

                    plot_img = vutils.make_grid(adapt_input_batch,
                                                normalize=True,
                                                scale_each=True)
                    writer.add_image('Train Target Student Input (augmented)',
                                     plot_img, epoch)

                    plot_img = vutils.make_grid(student_preds_unsup,
                                                normalize=True,
                                                scale_each=True)
                    writer.add_image('Train Target Student Preds',
                                     plot_img, epoch)

                    plot_img = vutils.make_grid(teacher_preds_unsup_aug,
                                                normalize=True,
                                                scale_each=True)
                    writer.add_image('Train Target Teacher Preds (augmented)',
                                     plot_img, epoch)
            except Exception:
                tqdm.write("*** Error writing images ***")

        writer.add_scalars('losses', {
            'composite_loss': composite_loss_avg,
            'class_loss': class_loss_avg,
            'consistency_loss': consistency_loss_avg
        }, epoch)

        # Evaluation mode
        model.eval()
        if not supervised_only:
            model_ema.eval()

        metric_fns = [mt_metrics.dice_score,
                      mt_metrics.jaccard_score,
                      mt_metrics.hausdorff_score,
                      mt_metrics.precision_score,
                      mt_metrics.recall_score,
                      mt_metrics.specificity_score,
                      mt_metrics.intersection_over_union,
                      mt_metrics.accuracy_score]

        for center, loader in enumerate(validation_centers_loaders):
            validation(model, model_ema, loader, writer, metric_fns, epoch,
                       ctx, 'val_%s' % ctx["val_centers"][center])

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))
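cmd_train relies on sigmoid_rampup, get_current_consistency_weight, and update_ema_variables, which are not defined in this listing. The sketches below follow the usual mean-teacher formulation (exponential ramp-up of the consistency weight and an exponential moving average of the student weights into the teacher); the exact constants and the EMA warm-up are assumptions, not the original implementation:

def sigmoid_rampup(current, rampup_length):
    # Exponential ramp-up from 0 to 1 over `rampup_length` epochs,
    # as commonly used in the mean-teacher literature (sketch only).
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))


def get_current_consistency_weight(weight, epoch, consistency_rampup):
    # Scale the target consistency weight by the ramp-up factor.
    return weight * sigmoid_rampup(epoch, consistency_rampup)


def update_ema_variables(model, model_ema, alpha, global_step):
    # Exponential moving average of the student weights into the teacher.
    # Warming alpha up over the first steps is an assumption.
    alpha = min(1.0 - 1.0 / (global_step + 1), alpha)
    for ema_param, param in zip(model_ema.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1.0 - alpha)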
def validation(model, model_ema, loader, writer, metric_fns, epoch, ctx, prefix):
    val_loss = 0.0
    ema_val_loss = 0.0

    num_samples = 0
    num_steps = 0

    result_dict = defaultdict(float)
    result_ema_dict = defaultdict(float)

    for i, batch in enumerate(loader):
        input_data, gt_data = batch["input"], batch["gt"]

        input_data_gpu = input_data.cuda()
        gt_data_gpu = gt_data.cuda(non_blocking=True)

        with torch.no_grad():
            model_out = model(input_data_gpu)
            val_class_loss = mt_losses.dice_loss(model_out, gt_data_gpu)
            val_loss += val_class_loss.item()

            if not ctx["supervised_only"]:
                model_ema_out = model_ema(input_data_gpu)
                ema_val_class_loss = mt_losses.dice_loss(model_ema_out,
                                                         gt_data_gpu)
                ema_val_loss += ema_val_class_loss.item()

        gt_masks = gt_data_gpu.cpu().numpy().astype(np.uint8)
        gt_masks = gt_masks.squeeze(axis=1)

        preds = model_out.cpu().numpy()
        preds = threshold_predictions(preds)
        preds = preds.astype(np.uint8)
        preds = preds.squeeze(axis=1)

        for metric_fn in metric_fns:
            for prediction, ground_truth in zip(preds, gt_masks):
                res = metric_fn(prediction, ground_truth)
                dict_key = 'val_{}'.format(metric_fn.__name__)
                result_dict[dict_key] += res

        if not ctx["supervised_only"]:
            preds_ema = model_ema_out.cpu().numpy()
            preds_ema = threshold_predictions(preds_ema)
            preds_ema = preds_ema.astype(np.uint8)
            preds_ema = preds_ema.squeeze(axis=1)

            for metric_fn in metric_fns:
                for prediction, ground_truth in zip(preds_ema, gt_masks):
                    res = metric_fn(prediction, ground_truth)
                    dict_key = 'val_ema_{}'.format(metric_fn.__name__)
                    result_ema_dict[dict_key] += res

        num_samples += len(preds)
        num_steps += 1

    val_loss_avg = val_loss / num_steps

    for key, val in result_dict.items():
        result_dict[key] = val / num_samples

    if not ctx["supervised_only"]:
        for key, val in result_ema_dict.items():
            result_ema_dict[key] = val / num_samples

        ema_val_loss_avg = ema_val_loss / num_steps

        writer.add_scalars(prefix + '_ema_metrics', result_ema_dict, epoch)
        writer.add_scalars(prefix + '_losses', {
            prefix + '_loss': val_loss_avg,
            prefix + '_ema_loss': ema_val_loss_avg
        }, epoch)
    else:
        writer.add_scalars(prefix + '_losses', {
            prefix + '_loss': val_loss_avg,
        }, epoch)

    writer.add_scalars(prefix + '_metrics', result_dict, epoch)
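cmd_train also selects one of decay_constant_lr, decay_poly_lr, or cosine_lr for the learning-rate decay, none of which appear in this listing. Minimal sketches with the call signature used above (current epoch, total epochs, initial learning rate) are given below; the polynomial exponent and the cosine form are assumptions:

def decay_constant_lr(epoch, num_epochs, initial_lr):
    # No decay: keep the initial learning rate throughout training.
    return initial_lr


def decay_poly_lr(epoch, num_epochs, initial_lr, power=0.9):
    # Polynomial decay to zero; the 0.9 exponent is an assumption.
    return initial_lr * (1.0 - float(epoch) / num_epochs) ** power


def cosine_lr(epoch, num_epochs, initial_lr):
    # Cosine annealing from initial_lr down to zero over num_epochs.
    return initial_lr * 0.5 * (1.0 + np.cos(np.pi * float(epoch) / num_epochs))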