예제 #1
0
def unet3d_loss(preds, label_batch, bce_weight):
    """Blend binary cross-entropy and Dice loss for a segmentation output.

    Args:
        preds: Raw network outputs (logits); the sigmoid is applied here.
        label_batch: Target tensor passed to both loss terms
            (presumably the same shape as ``preds`` -- confirm with caller).
        bce_weight: Weight of the BCE term; the Dice term is weighted by
            ``1 - bce_weight``.

    Returns:
        ``bce * bce_weight + dice * (1 - bce_weight)``.
    """
    # BCE-with-logits applies the sigmoid internally, so it must receive the
    # raw logits. Feeding it probabilities would apply the sigmoid twice and
    # squash the BCE term's gradients.
    bce = F.binary_cross_entropy_with_logits(preds, label_batch)

    # The Dice term works on probabilities. The tensor method replaces the
    # deprecated F.sigmoid.
    probs = preds.sigmoid()
    dice = mt_losses.dice_loss(probs, label_batch)

    return bce * bce_weight + dice * (1 - bce_weight)
예제 #2
0
def run_main():
    """Train a 2D U-Net on the SC GM Challenge data and log to TensorBoard.

    Builds the training (subjects 1-8) and validation (subjects 9-10)
    datasets from ``../data``, trains for 200 epochs with a Dice loss, Adam
    and a cosine-annealing learning-rate schedule, and writes the learning
    rate, losses, validation metrics and (every 5th epoch) image grids to
    the ``log_exp`` directory. The model and batches are moved to CUDA, so
    a GPU is required.
    """
    train_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                       sigma_range=(3.5, 4.0),
                                       p=0.3),
        mt_transforms.RandomAffine(degrees=4.6,
                                   scale=(0.98, 1.02),
                                   translate=(0.03, 0.03)),
        mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    val_transform = transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Here we assume that the SC GM Challenge data is inside the folder
    # "../data" and it was previously resampled.
    gmdataset_train = mt_datasets.SCGMChallenge2DTrain(
        root_dir="../data",
        subj_ids=range(1, 9),
        transform=train_transform,
        slice_filter_fn=mt_filters.SliceFilter())

    gmdataset_val = mt_datasets.SCGMChallenge2DTrain(root_dir="../data",
                                                     subj_ids=range(9, 11),
                                                     transform=val_transform)

    train_loader = DataLoader(gmdataset_train,
                              batch_size=16,
                              shuffle=True,
                              pin_memory=True,
                              collate_fn=mt_datasets.mt_collate,
                              num_workers=1)

    val_loader = DataLoader(gmdataset_val,
                            batch_size=16,
                            shuffle=True,
                            pin_memory=True,
                            collate_fn=mt_datasets.mt_collate,
                            num_workers=1)

    model = mt_models.Unet(drop_rate=0.4, bn_momentum=0.1)
    model.cuda()

    num_epochs = 200
    initial_lr = 0.001

    optimizer = optim.Adam(model.parameters(), lr=initial_lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

    # The metric list does not change between epochs, so build it once.
    metric_fns = [
        mt_metrics.dice_score, mt_metrics.hausdorff_score,
        mt_metrics.precision_score, mt_metrics.recall_score,
        mt_metrics.specificity_score, mt_metrics.intersection_over_union,
        mt_metrics.accuracy_score
    ]

    writer = SummaryWriter(log_dir="log_exp")
    for epoch in tqdm(range(1, num_epochs + 1)):
        start_time = time.time()

        # Read the learning rate actually in effect straight from the
        # optimizer; this works regardless of the scheduler API version.
        lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('learning_rate', lr, epoch)

        # ---- Training ----------------------------------------------------
        model.train()
        train_loss_total = 0.0
        num_steps = 0
        for batch in train_loader:
            input_samples, gt_samples = batch["input"], batch["gt"]

            var_input = input_samples.cuda()
            var_gt = gt_samples.cuda(non_blocking=True)

            preds = model(var_input)

            loss = mt_losses.dice_loss(preds, var_gt)
            train_loss_total += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            num_steps += 1

            if epoch % 5 == 0:
                grid_img = vutils.make_grid(input_samples,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Input', grid_img, epoch)

                grid_img = vutils.make_grid(preds.detach().cpu(),
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Predictions', grid_img, epoch)

                grid_img = vutils.make_grid(gt_samples,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Ground Truth', grid_img, epoch)

        train_loss_total_avg = train_loss_total / num_steps

        # ---- Validation --------------------------------------------------
        model.eval()
        val_loss_total = 0.0
        num_steps = 0

        metric_mgr = mt_metrics.MetricManager(metric_fns)

        for batch in val_loader:
            input_samples, gt_samples = batch["input"], batch["gt"]

            with torch.no_grad():
                var_input = input_samples.cuda()
                # `async` is a reserved word since Python 3.7; the keyword
                # argument is spelled `non_blocking`.
                var_gt = gt_samples.cuda(non_blocking=True)

                preds = model(var_input)
                loss = mt_losses.dice_loss(preds, var_gt)
                val_loss_total += loss.item()

            # Metrics computation on binarised masks with the channel axis
            # removed.
            gt_npy = gt_samples.numpy().astype(np.uint8)
            gt_npy = gt_npy.squeeze(axis=1)

            preds = preds.detach().cpu().numpy()
            preds = threshold_predictions(preds)
            preds = preds.astype(np.uint8)
            preds = preds.squeeze(axis=1)

            metric_mgr(preds, gt_npy)

            num_steps += 1

        metrics_dict = metric_mgr.get_results()
        metric_mgr.reset()

        writer.add_scalars('metrics', metrics_dict, epoch)

        val_loss_total_avg = val_loss_total / num_steps

        # Both losses are logged once, under the same chart.
        writer.add_scalars('losses', {
            'val_loss': val_loss_total_avg,
            'train_loss': train_loss_total_avg
        }, epoch)

        # Advance the schedule only after this epoch's optimizer steps.
        # Stepping first skips the initial learning-rate value on
        # PyTorch >= 1.1.
        scheduler.step()

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))
예제 #3
0
    # Learning rate currently reported by the scheduler (first entry of the
    # list it returns).
    lr = scheduler.get_lr()[0]

    # Put the model in training mode and reset the per-epoch accumulators.
    model.train()
    train_loss_total = 0.0
    num_steps = 0

    ### Training
    for i, batch in enumerate(train_loader):
        # Each batch is indexed by key: "input" and "gt" (ground truth).
        input_samples, gt_samples = batch["input"], batch["gt"]

        # Move both tensors to the GPU.
        var_input = input_samples.cuda()
        var_gt = gt_samples.cuda()

        preds = model(var_input)

        # Dice loss between predictions and ground truth; its Python-number
        # value is accumulated for the epoch total.
        loss = mt_losses.dice_loss(preds, var_gt)
        train_loss_total += loss.item()

        # Clear stale gradients, backpropagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        num_steps += 1

        # Disabled image-grid visualisation (every 5th epoch):
        # if epoch % 5 == 0:
        #     grid_img = vutils.make_grid(input_samples,
        #                                 normalize=True,
        #                                 scale_each=True)
        #
        #     grid_img = vutils.make_grid(preds.data.cpu(),
        #                                 normalize=True,
        #                                 scale_each=True)
예제 #4
0
def cmd_train(ctx):
    """Run the training loop described by the ``ctx`` configuration mapping.

    A model is trained on labelled source-centre slices with a Dice loss.
    Unless ``ctx["supervised_only"]`` is set, a second (EMA) model predicts
    on unlabelled target-centre slices and a weighted consistency loss
    between its predictions and the main model's predictions on the linked
    augmented batch is added to the objective. After every epoch each
    validation centre is evaluated through ``validation``. Scalars,
    histograms and optional image grids are written to TensorBoard under
    ``log_<experiment_name>``. Tensors are moved to CUDA, so a GPU is
    required.

    Args:
        ctx: Mapping with the run configuration. Keys read directly here:
            ``num_workers``, ``num_epochs``, ``experiment_name``,
            ``cons_weight``, ``initial_lr``, ``initial_lr_rampup``,
            ``consistency_rampup``, ``weight_decay``, ``decay_lr``,
            ``consistency_loss``, ``confidence_threshold`` (only for the
            ``mse_confident`` loss), ``rootdir_gmchallenge_train``,
            ``rootdir_gmchallenge_test``, ``supervised_only``,
            ``source_centers``, ``adapt_centers``, ``val_centers``,
            ``source_batch_size``, ``target_batch_size``,
            ``ema_late_epoch``, ``ema_alpha``, ``ema_alpha_late``,
            ``write_images`` and ``write_images_interval``. The mapping is
            also forwarded to ``create_model`` and ``validation``.

    Raises:
        ValueError: If ``ctx["decay_lr"]`` matches no known schedule, or if
            the consistency loss is needed (not supervised-only) and
            ``ctx["consistency_loss"]`` is not recognised.
    """
    global_step = 0

    num_workers = ctx["num_workers"]
    num_epochs = ctx["num_epochs"]
    # experiment_name is used verbatim as the suffix of the log directory.
    experiment_name = ctx["experiment_name"]
    cons_weight = ctx["cons_weight"]
    initial_lr = ctx["initial_lr"]
    initial_lr_rampup = ctx["initial_lr_rampup"]
    consistency_rampup = ctx["consistency_rampup"]
    weight_decay = ctx["weight_decay"]
    rootdir_gmchallenge_train = ctx["rootdir_gmchallenge_train"]
    rootdir_gmchallenge_test = ctx["rootdir_gmchallenge_test"]
    supervised_only = ctx["supervised_only"]

    # Decay for learning rate. The checks are deliberately independent:
    # if several names occur in ctx["decay_lr"], the last match wins.
    decay_lr_fn = None
    if "constant" in ctx["decay_lr"]:
        decay_lr_fn = decay_constant_lr

    if "poly" in ctx["decay_lr"]:
        decay_lr_fn = decay_poly_lr

    if "cosine" in ctx["decay_lr"]:
        decay_lr_fn = cosine_lr

    if decay_lr_fn is None:
        # Fail early instead of hitting an unbound name inside the loop.
        raise ValueError(
            "Unknown decay_lr setting: {!r}".format(ctx["decay_lr"]))

    # Consistency loss
    #
    # mse = Mean Squared Error
    # dice = Dice loss
    # cross_entropy = Cross Entropy
    # mse_confident = MSE with Confidence Threshold
    consistency_loss_name = ctx["consistency_loss"]
    consistency_loss_fn = None
    if consistency_loss_name == "dice":
        consistency_loss_fn = mt_losses.dice_loss
    elif consistency_loss_name == "mse":
        consistency_loss_fn = F.mse_loss
    elif consistency_loss_name == "cross_entropy":
        consistency_loss_fn = F.binary_cross_entropy
    elif consistency_loss_name == "mse_confident":
        confidence_threshold = ctx["confidence_threshold"]
        consistency_loss_fn = mt_losses.ConfidentMSELoss(confidence_threshold)

    # The consistency loss is only evaluated when the EMA model is in use.
    if consistency_loss_fn is None and not supervised_only:
        raise ValueError(
            "Unknown consistency_loss setting: {!r}".format(
                consistency_loss_name))

    # Xs, Ys = Source input and source label, train
    # Xt1, Xt2 = Target, domain adaptation, no label, different aug (same sample), train
    # Xv, Yv = Target input and target label, validation

    # Sample Xs and Ys from this
    source_train = mt_datasets.SCGMChallenge2DTrain(
        rootdir_gmchallenge_train,
        slice_filter_fn=mt_filters.SliceFilter(),
        site_ids=ctx["source_centers"],  # Test = 1,2,3, train = 1,2
        subj_ids=range(1, 11))

    # Sample Xt1, Xt2 from this
    unlabeled_filter = mt_filters.SliceFilter(filter_empty_mask=False)
    target_adapt_train = mt_datasets.SCGMChallenge2DTest(
        rootdir_gmchallenge_test,
        slice_filter_fn=unlabeled_filter,
        site_ids=ctx["adapt_centers"],  # 3 = train, 4 = test
        subj_ids=range(11, 21))

    # Sample Xv, Yv from this: one dataset per validation centre.
    validation_centers = [
        mt_datasets.SCGMChallenge2DTrain(
            rootdir_gmchallenge_train,
            slice_filter_fn=mt_filters.SliceFilter(),
            site_ids=[center],  # 3 = train, 4 = test
            subj_ids=range(1, 11))
        for center in ctx["val_centers"]
    ]

    # Training source data augmentation
    source_transform = tv.transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                       sigma_range=(3.5, 4.0),
                                       p=0.3),
        mt_transforms.RandomAffine(degrees=4.6,
                                   scale=(0.98, 1.02),
                                   translate=(0.03, 0.03)),
        mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Target adaptation data augmentation (unlabelled samples)
    target_adapt_transform = tv.transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200), labeled=False),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    # Validation data transform (labelled samples, no augmentation)
    target_val_adapt_transform = tv.transforms.Compose([
        mt_transforms.CenterCrop2D((200, 200)),
        mt_transforms.ToTensor(),
        mt_transforms.NormalizeInstance(),
    ])

    source_train.set_transform(source_transform)

    target_adapt_train.set_transform(target_adapt_transform)
    for center in validation_centers:
        center.set_transform(target_val_adapt_transform)

    source_train_loader = DataLoader(source_train,
                                     batch_size=ctx["source_batch_size"],
                                     shuffle=True,
                                     drop_last=True,
                                     num_workers=num_workers,
                                     collate_fn=mt_datasets.mt_collate,
                                     pin_memory=True)

    target_adapt_train_loader = DataLoader(target_adapt_train,
                                           batch_size=ctx["target_batch_size"],
                                           shuffle=True,
                                           drop_last=True,
                                           num_workers=num_workers,
                                           collate_fn=mt_datasets.mt_collate,
                                           pin_memory=True)

    validation_centers_loaders = [
        DataLoader(center,
                   batch_size=ctx["target_batch_size"],
                   shuffle=False,
                   drop_last=False,
                   num_workers=num_workers,
                   collate_fn=mt_datasets.mt_collate,
                   pin_memory=True)
        for center in validation_centers
    ]

    model = create_model(ctx)

    if not supervised_only:
        model_ema = create_model(ctx, ema=True)
    else:
        model_ema = None

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=initial_lr,
                                 weight_decay=weight_decay)

    writer = SummaryWriter(log_dir="log_{}".format(experiment_name))

    # The metric list is the same for every epoch, so build it once.
    metric_fns = [
        mt_metrics.dice_score, mt_metrics.jaccard_score,
        mt_metrics.hausdorff_score, mt_metrics.precision_score,
        mt_metrics.recall_score, mt_metrics.specificity_score,
        mt_metrics.intersection_over_union, mt_metrics.accuracy_score
    ]

    # Training loop
    for epoch in tqdm(range(1, num_epochs + 1), desc="Epochs"):
        start_time = time.time()

        # Rampup -----
        # During the first `initial_lr_rampup` epochs the learning rate is
        # ramped up; afterwards the decay schedule runs over the remaining
        # epochs.
        if initial_lr_rampup > 0:
            if epoch <= initial_lr_rampup:
                lr = initial_lr * sigmoid_rampup(epoch, initial_lr_rampup)
            else:
                lr = decay_lr_fn(epoch - initial_lr_rampup,
                                 num_epochs - initial_lr_rampup, initial_lr)
        else:
            lr = decay_lr_fn(epoch, num_epochs, initial_lr)

        writer.add_scalar('learning_rate', lr, epoch)

        # The optimizer was built from a single parameter iterable, so it
        # has exactly one parameter group and the rate is reported once.
        tqdm.write("Learning Rate: {:.6f}".format(lr))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        consistency_weight = get_current_consistency_weight(
            cons_weight, epoch, consistency_rampup)
        writer.add_scalar('consistency_weight', consistency_weight, epoch)

        # Train mode
        model.train()

        if not supervised_only:
            model_ema.train()

        composite_loss_total = 0.0
        class_loss_total = 0.0
        consistency_loss_total = 0.0

        num_steps = 0
        target_adapt_train_iter = iter(target_adapt_train_loader)

        for train_batch in source_train_loader:
            # Keys: 'input', 'gt', 'input_metadata', 'gt_metadata'

            # Supervised component --------------------------------------------
            train_input, train_gt = train_batch["input"], train_batch["gt"]
            train_input = train_input.cuda()
            # `async` is a reserved word since Python 3.7; the keyword
            # argument is spelled `non_blocking`.
            train_gt = train_gt.cuda(non_blocking=True)
            preds_supervised = model(train_input)
            class_loss = mt_losses.dice_loss(preds_supervised, train_gt)

            if not supervised_only:

                # Unsupervised component ------------------------------------------
                # The target loader is cycled independently of the source
                # loader: restart it whenever it runs out.
                try:
                    target_adapt_batch = next(target_adapt_train_iter)
                except StopIteration:
                    target_adapt_train_iter = iter(target_adapt_train_loader)
                    target_adapt_batch = next(target_adapt_train_iter)

                target_adapt_input = target_adapt_batch["input"]
                target_adapt_input = target_adapt_input.cuda()

                # Teacher forward (no gradients flow into the EMA model)
                with torch.no_grad():
                    teacher_preds_unsup = model_ema(target_adapt_input)

                # Apply the same augmentation to the input and to the
                # teacher prediction so they stay spatially aligned.
                linked_aug_batch = \
                    linked_batch_augmentation(target_adapt_input, teacher_preds_unsup)

                adapt_input_batch = linked_aug_batch['input'][0].cuda()
                teacher_preds_unsup_aug = linked_aug_batch['input'][1].cuda()

                # Student forward
                student_preds_unsup = model(adapt_input_batch)

                consistency_loss = consistency_weight * consistency_loss_fn(
                    student_preds_unsup, teacher_preds_unsup_aug)
            else:
                consistency_loss = torch.FloatTensor([0.]).cuda()

            composite_loss = class_loss + consistency_loss

            optimizer.zero_grad()
            composite_loss.backward()

            optimizer.step()

            composite_loss_total += composite_loss.item()
            consistency_loss_total += consistency_loss.item()
            class_loss_total += class_loss.item()

            num_steps += 1
            global_step += 1

            if not supervised_only:
                # A different EMA smoothing factor is used after
                # `ema_late_epoch`.
                if epoch <= ctx["ema_late_epoch"]:
                    update_ema_variables(model, model_ema, ctx["ema_alpha"],
                                         global_step)
                else:
                    update_ema_variables(model, model_ema,
                                         ctx["ema_alpha_late"], global_step)

        # Write histogram of the probs (taken from the last batch of the epoch)
        if not supervised_only:
            npy_teacher_preds = teacher_preds_unsup.detach().cpu().numpy()
            writer.add_histogram("Teacher Preds Hist", npy_teacher_preds,
                                 epoch)

            npy_student_preds = student_preds_unsup.detach().cpu().numpy()
            writer.add_histogram("Student Preds Hist", npy_student_preds,
                                 epoch)

        npy_supervised_preds = preds_supervised.detach().cpu().numpy()
        writer.add_histogram("Supervised Preds Hist", npy_supervised_preds,
                             epoch)

        composite_loss_avg = composite_loss_total / num_steps
        class_loss_avg = class_loss_total / num_steps
        consistency_loss_avg = consistency_loss_total / num_steps

        tqdm.write("Steps p/ Epoch: {}".format(num_steps))
        tqdm.write("Consistency Weight: {:.6f}".format(consistency_weight))
        tqdm.write("Composite Loss: {:.6f}".format(composite_loss_avg))
        tqdm.write("Class Loss: {:.6f}".format(class_loss_avg))
        tqdm.write("Consistency Loss: {:.6f}".format(consistency_loss_avg))

        # Write sample images (best effort: a failure here must not abort
        # the training run).
        if ctx["write_images"] and epoch % ctx["write_images_interval"] == 0:
            try:
                plot_img = vutils.make_grid(preds_supervised,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Train Source Prediction', plot_img, epoch)

                plot_img = vutils.make_grid(train_input,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Train Source Input', plot_img, epoch)

                plot_img = vutils.make_grid(train_gt,
                                            normalize=True,
                                            scale_each=True)
                writer.add_image('Train Source Ground Truth', plot_img, epoch)

                # Unsupervised component viz
                # NOTE(review): the tags below do not match the tensors they
                # plot. target_adapt_input / teacher_preds_unsup belong to
                # the teacher pass and adapt_input_batch /
                # student_preds_unsup to the student pass, and
                # student_preds_unsup is written twice. Tags are kept as-is
                # so existing dashboards keep working; confirm the intended
                # mapping before renaming.
                if not supervised_only:
                    plot_img = vutils.make_grid(target_adapt_input,
                                                normalize=True,
                                                scale_each=True)
                    writer.add_image('Train Target Student Input', plot_img,
                                     epoch)

                    plot_img = vutils.make_grid(teacher_preds_unsup,
                                                normalize=True,
                                                scale_each=True)
                    writer.add_image('Train Target Student Preds', plot_img,
                                     epoch)

                    plot_img = vutils.make_grid(adapt_input_batch,
                                                normalize=True,
                                                scale_each=True)
                    writer.add_image('Train Target Teacher Input', plot_img,
                                     epoch)

                    plot_img = vutils.make_grid(student_preds_unsup,
                                                normalize=True,
                                                scale_each=True)
                    writer.add_image('Train Target Teacher Preds', plot_img,
                                     epoch)

                    plot_img = vutils.make_grid(student_preds_unsup,
                                                normalize=True,
                                                scale_each=True)
                    writer.add_image('Train Target Student Preds (augmented)',
                                     plot_img, epoch)
            except Exception as err:
                # Only ordinary errors are swallowed, so Ctrl-C
                # (KeyboardInterrupt) and SystemExit still stop the run.
                tqdm.write("*** Error writing images ***")
                tqdm.write(repr(err))

        writer.add_scalars(
            'losses', {
                'composite_loss': composite_loss_avg,
                'class_loss': class_loss_avg,
                'consistency_loss': consistency_loss_avg
            }, epoch)

        # Evaluation mode
        model.eval()

        if not supervised_only:
            model_ema.eval()

        # One validation pass per centre, logged under 'val_<centre id>'.
        for center, loader in enumerate(validation_centers_loaders):
            validation(model, model_ema, loader, writer, metric_fns, epoch,
                       ctx, 'val_%s' % ctx["val_centers"][center])

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))
예제 #5
0
def _to_binary_masks(model_output):
    """Threshold a model output tensor into uint8 masks without the channel axis."""
    masks = threshold_predictions(model_output.cpu().numpy())
    return masks.astype(np.uint8).squeeze(axis=1)


def _accumulate_metrics(metric_fns, preds, gt_masks, key_format, totals):
    """Add every metric, summed over the (prediction, mask) pairs, to ``totals``."""
    for metric_fn in metric_fns:
        dict_key = key_format.format(metric_fn.__name__)
        for prediction, ground_truth in zip(preds, gt_masks):
            totals[dict_key] += metric_fn(prediction, ground_truth)


def validation(model, model_ema, loader, writer, metric_fns, epoch, ctx,
               prefix):
    """Evaluate the model (and the EMA model, if used) on one data loader.

    For each batch the Dice loss is accumulated and every function in
    ``metric_fns`` is evaluated per sample on the thresholded predictions.
    Losses are averaged over batches and metrics over samples, then written
    to TensorBoard under tags derived from ``prefix``.

    Args:
        model: Model to evaluate.
        model_ema: EMA model; only used when ``ctx["supervised_only"]`` is
            falsy.
        loader: Iterable of batches indexed by ``"input"`` and ``"gt"``.
        writer: Summary writer providing ``add_scalars``.
        metric_fns: Callables taking ``(prediction, ground_truth)``; their
            ``__name__`` is used to build the metric keys.
        epoch: Global step passed to the writer.
        ctx: Configuration mapping; only ``"supervised_only"`` is read.
        prefix: Tag prefix for the logged scalars.

    Raises:
        ZeroDivisionError: If ``loader`` yields no batches.
    """
    use_ema = not ctx["supervised_only"]

    val_loss = 0.0
    ema_val_loss = 0.0

    num_samples = 0
    num_steps = 0

    result_dict = defaultdict(float)
    result_ema_dict = defaultdict(float)

    for batch in loader:
        input_data, gt_data = batch["input"], batch["gt"]

        input_data_gpu = input_data.cuda()
        # `async` is a reserved word since Python 3.7; the keyword argument
        # is spelled `non_blocking`.
        gt_data_gpu = gt_data.cuda(non_blocking=True)

        with torch.no_grad():
            model_out = model(input_data_gpu)
            val_class_loss = mt_losses.dice_loss(model_out, gt_data_gpu)
            val_loss += val_class_loss.item()

            if use_ema:
                model_ema_out = model_ema(input_data_gpu)
                ema_val_class_loss = mt_losses.dice_loss(
                    model_ema_out, gt_data_gpu)
                ema_val_loss += ema_val_class_loss.item()

        # Build the masks from the batch tensor itself rather than copying
        # the GPU tensor back; .cpu() is a no-op for a CPU tensor.
        gt_masks = gt_data.cpu().numpy().astype(np.uint8)
        gt_masks = gt_masks.squeeze(axis=1)

        preds = _to_binary_masks(model_out)
        _accumulate_metrics(metric_fns, preds, gt_masks, 'val_{}',
                            result_dict)

        if use_ema:
            preds_ema = _to_binary_masks(model_ema_out)
            _accumulate_metrics(metric_fns, preds_ema, gt_masks,
                                'val_ema_{}', result_ema_dict)

        num_samples += len(preds)
        num_steps += 1

    # Losses are averaged per batch, metrics per sample.
    val_loss_avg = val_loss / num_steps

    metrics_avg = {
        key: total / num_samples
        for key, total in result_dict.items()
    }

    if use_ema:
        ema_metrics_avg = {
            key: total / num_samples
            for key, total in result_ema_dict.items()
        }

        ema_val_loss_avg = ema_val_loss / num_steps
        writer.add_scalars(prefix + '_ema_metrics', ema_metrics_avg, epoch)
        writer.add_scalars(prefix + '_losses', {
            prefix + '_loss': val_loss_avg,
            prefix + '_ema_loss': ema_val_loss_avg
        }, epoch)
    else:
        writer.add_scalars(prefix + '_losses', {
            prefix + '_loss': val_loss_avg,
        }, epoch)

    writer.add_scalars(prefix + '_metrics', metrics_avg, epoch)