Example #1
0
def test_calculate_metric(epoch_num,
                          patch_size=(128, 128, 64),
                          stride_xy=64,
                          stride_z=32,
                          device='cuda'):
    """Load the checkpoint saved at iteration ``epoch_num`` and evaluate it.

    Builds a single-channel VNet with ``num_classes`` outputs, restores its
    weights from ``<snapshot_path>/iter_<epoch_num>.pth`` and runs
    ``test_all_case`` over ``image_list`` with sliding-window inference,
    saving the predictions under ``test_save_path``.

    Args:
        epoch_num: iteration number embedded in the checkpoint file name.
        patch_size: sliding-window patch size passed to ``test_all_case``.
        stride_xy: sliding-window stride passed as ``stride_xy``.
        stride_z: sliding-window stride passed as ``stride_z``.
        device: torch device the network and the loaded weights are put on.

    Returns:
        Whatever ``test_all_case`` returns (the evaluation metrics).

    Note:
        Relies on the module-level names ``num_classes``, ``name_classes``,
        ``snapshot_path``, ``image_list`` and ``test_save_path``.
    """
    net = VNet(n_channels=1,
               n_classes=num_classes,
               normalization='batchnorm',
               has_dropout=False).to(device)
    save_mode_path = os.path.join(snapshot_path,
                                  'iter_' + str(epoch_num) + '.pth')
    print(save_mode_path)
    # Without map_location, torch.load restores tensors onto the device they
    # were saved from, so a GPU checkpoint could not be evaluated with
    # device='cpu' (or on a different GPU) even though `device` is exposed.
    net.load_state_dict(torch.load(save_mode_path, map_location=device))
    print("init weight from {}".format(save_mode_path))
    net.eval()

    metrics = test_all_case(net,
                            image_list,
                            num_classes=num_classes,
                            name_classes=name_classes,
                            patch_size=patch_size,
                            stride_xy=stride_xy,
                            stride_z=stride_z,
                            save_result=True,
                            test_save_path=test_save_path,
                            device=device)

    return metrics
Example #2
0
def main():
    """Evaluate a trained segmentation model on the ABUS test split.

    Reads options via ``get_args()``, builds the network selected by
    ``args.arch``, restores it from the ``args.resume`` checkpoint and runs
    ``test_all_case`` on the test loader. Predictions are saved into a fresh
    directory named after the checkpoint, placed under ``args.save`` if given
    or next to the checkpoint otherwise.

    Raises:
        ValueError: if no checkpoint is given in ``args.resume``.
        FileNotFoundError: if ``args.resume`` is not an existing file.
        NotImplementedError: if ``args.arch`` is neither 'vnet' nor 'd2unet'.
    """
    args = get_args()

    # Fail fast: evaluation needs trained weights, and the output directory
    # is derived from the checkpoint. Previously a missing or non-existent
    # checkpoint only surfaced at the very end as an undefined `save_path`.
    if not args.resume:
        raise ValueError('no checkpoint given: set args.resume to a .pth file')
    if not os.path.isfile(args.resume):
        raise FileNotFoundError(
            "=> no checkpoint found at '{}'".format(args.resume))

    # dataset
    db_test = ABUS(base_dir=args.root_path, split='test')
    testloader = DataLoader(db_test,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1,
                            pin_memory=True)
    args.testloader = testloader

    # network
    if args.arch == 'vnet':
        model = VNet(n_channels=1,
                     n_classes=2,
                     normalization='batchnorm',
                     has_dropout=True,
                     use_tm=args.use_tm)
    elif args.arch == 'd2unet':
        model = D2UNet()
    else:
        raise NotImplementedError('model {} not implement'.format(args.arch))
    model = model.cuda()

    print("=> loading checkpoint '{}'".format(args.resume))
    checkpoint = torch.load(args.resume)
    args.start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
    # The VNet above is built with dropout and batchnorm; switch to inference
    # mode so dropout is off and batchnorm uses its running statistics (the
    # other evaluation entry points also call .eval() before test_all_case).
    model.eval()

    # --- saving path ---
    # NOTE(review): the substring checks look at the whole path, so a
    # directory name containing 'best' or 'check' also matches.
    if 'best' in args.resume:
        file_name = 'model_best_' + str(checkpoint['epoch'])
    elif 'check' in args.resume:
        file_name = 'checkpoint_{}_result'.format(checkpoint['epoch'])
    else:
        # Previously `file_name` was left undefined here; fall back to the
        # checkpoint's own base name.
        stem = os.path.splitext(os.path.basename(args.resume))[0]
        file_name = '{}_result'.format(stem)

    if args.save is not None:
        save_path = os.path.join(args.save, file_name)
    else:
        save_path = os.path.join(os.path.dirname(args.resume), file_name)
    # Start from an empty output directory so predictions from an earlier
    # run of the same checkpoint are not mixed with the new ones.
    if os.path.exists(save_path):
        shutil.rmtree(save_path)
    os.makedirs(save_path, exist_ok=True)

    test_all_case(model,
                  args.testloader,
                  num_classes=args.num_classes,
                  patch_size=(64, 128, 128),
                  stride_xy=64,
                  stride_z=64,
                  save_result=True,
                  test_save_path=save_path)
def test_calculate_metric(epoch_num,
                          patch_size=(112, 112, 80),
                          stride_xy=18,
                          stride_z=4):
    """Load the checkpoint saved at iteration ``epoch_num`` and evaluate it.

    Builds a single-channel VNet with ``num_classes`` outputs on the GPU,
    restores its weights from ``<snapshot_path>/iter_<epoch_num>.pth`` and
    runs ``test_all_case`` over ``image_list``, saving the predictions under
    ``test_save_path``.

    Args:
        epoch_num: iteration number embedded in the checkpoint file name.
        patch_size: sliding-window patch size (default keeps the previously
            hard-coded value).
        stride_xy: sliding-window stride passed as ``stride_xy``.
        stride_z: sliding-window stride passed as ``stride_z``.

    Returns:
        Whatever ``test_all_case`` returns (the averaged metric).

    Note:
        Relies on the module-level names ``num_classes``, ``snapshot_path``,
        ``image_list`` and ``test_save_path``.
    """
    net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=False).cuda()
    save_mode_path = os.path.join(snapshot_path, 'iter_' + str(epoch_num) + '.pth')
    net.load_state_dict(torch.load(save_mode_path))
    print("init weight from {}".format(save_mode_path))
    net.eval()

    avg_metric = test_all_case(net, image_list, num_classes=num_classes,
                               patch_size=patch_size, stride_xy=stride_xy, stride_z=stride_z,
                               save_result=True, test_save_path=test_save_path)

    return avg_metric
Example #4
0
def test_calculate_metric(args):
    """Restore the VNet checkpoint for ``args.start_epoch`` and score it.

    The weights are read from
    ``<args.snapshot_path>/iter_<args.start_epoch>.pth``. Inference runs over
    ``args.testloader`` with a fixed (128, 64, 128) window, xy stride 18 and
    z stride 4. Predictions are written under ``args.test_save_path``.

    Returns:
        The value produced by ``test_all_case``.
    """
    model = VNet(n_channels=1,
                 n_classes=args.num_classes,
                 normalization='batchnorm',
                 has_dropout=False)
    model = model.cuda()

    ckpt_file = os.path.join(args.snapshot_path,
                             f'iter_{args.start_epoch!s}.pth')
    weights = torch.load(ckpt_file)
    model.load_state_dict(weights)
    print(f"init weight from {ckpt_file}")
    model.eval()

    # Sliding-window settings are fixed for this entry point.
    eval_kwargs = dict(
        num_classes=args.num_classes,
        patch_size=(128, 64, 128),
        stride_xy=18,
        stride_z=4,
        save_result=True,
        test_save_path=args.test_save_path,
    )
    return test_all_case(model, args.testloader, **eval_kwargs)
Example #5
0
def debugger():
    """Smoke-test ``train_epoch`` on one freshly augmented sample of data.

    Draws a random permutation of ``LABELLED_INDEX`` plus an equally sized
    random subset of ``UNLABELLED_INDEX`` from the training split. Each sample
    is randomly cropped to ``patch_size`` and randomly rotated/flipped. Then
    one ``train_epoch`` pass runs twice, each from saved weights:

    1. supervised only, starting from ``../saved/0_supervised.pth``;
    2. with unlabelled data, starting from
       ``../saved/8_expected_supervised.pth``.

    Both runs use batch size 2, SGD (lr 0.01, momentum 0.9, weight decay
    1e-4), ``mixup_mode="__"``, ``Lambda=0`` and ``aug_factor=0``.

    Returns:
        tuple: ``(supervised_loss, semi_supervised_loss)``, the values
        ``train_epoch`` returned for the two runs (previously discarded).
    """
    patch_size = (112, 112, 80)
    training_data = data_loader(split='train')

    x_criterion = soft_cross_entropy  # supervised loss is 0.5*(x_criterion + dice_loss)
    u_criterion = nn.MSELoss()  # unsupervised loss

    labelled_index = np.random.permutation(LABELLED_INDEX)
    unlabelled_index = np.random.permutation(
        UNLABELLED_INDEX)[:len(labelled_index)]

    def _augment(indices):
        """Random-crop then random rotate/flip each indexed training sample."""
        return [
            shape_transform(
                RandomRotFlip()(RandomCrop(patch_size)(training_data[i])))
            for i in indices
        ]

    labelled_data = _augment(labelled_index)
    unlabelled_data = _augment(unlabelled_index)

    def _train_from(model_path, supervised_only):
        """Load ``model_path`` into a fresh VNet and run one train_epoch."""
        net = VNet(n_channels=1,
                   n_classes=2,
                   normalization='batchnorm',
                   has_dropout=True).cuda()
        net.load_state_dict(torch.load(model_path))
        optimizer = optim.SGD(net.parameters(),
                              lr=0.01,
                              momentum=0.9,
                              weight_decay=0.0001)
        return train_epoch(net=net,
                           labelled_data=labelled_data,
                           unlabelled_data=unlabelled_data,
                           batch_size=2,
                           supervised_only=supervised_only,
                           optimizer=optimizer,
                           x_criterion=x_criterion,
                           u_criterion=u_criterion,
                           K=1,
                           T=1,
                           alpha=1,
                           mixup_mode="__",
                           Lambda=0,
                           aug_factor=0)

    supervised_loss = _train_from("../saved/0_supervised.pth",
                                  supervised_only=True)
    semi_supervised_loss = _train_from("../saved/8_expected_supervised.pth",
                                       supervised_only=False)
    return supervised_loss, semi_supervised_loss
Example #6
0
def experiment(exp_identifier,
               max_epoch,
               training_data,
               testing_data,
               batch_size=2,
               supervised_only=False,
               K=2,
               T=0.5,
               alpha=1,
               mixup_mode='all',
               Lambda=1,
               Lambda_ramp=None,
               base_lr=0.01,
               change_lr=None,
               aug_factor=1,
               from_saved=None,
               always_do_validation=True,
               decay=0):
    '''
    Train a 2-class VNet, validate it, and save its final weights.

    exp_identifier: name of the run; weights are saved to ../saved/{exp_identifier}.pth.
    max_epoch: epochs to run. Going through labeled data once is one epoch.
    batch_size: batch size of labeled data. Unlabeled data is of the same size.
    training_data: data for train_epoch, list of dicts of numpy array, indexed by
        LABELLED_INDEX / UNLABELLED_INDEX.
    testing_data: data for validation, list of dicts of numpy array (center-cropped here).
    supervised_only: if True, only do supervised training on LABELLED_INDEX; otherwise, use both LABELLED_INDEX and UNLABELLED_INDEX
    from_saved: optional path of a state_dict used to initialise the trained network.
    always_do_validation: if False, validation runs only on epochs 0, 50, 100, ...;
        on the other epochs the most recent validation result is recorded and printed again.
    decay: forwarded unchanged to train_epoch.

    Hyperparameters
    ---------------
    K: repeats of each unlabelled data
    T: temperature of sharpening
    alpha: mixup hyperparameter of beta distribution
    mixup_mode: how mixup is performed --
        '__': no mix up
        'ww': x and u both mixed up with w(x+u)
        'xx': both with x
        'xu': x with x, u with u
        'uu': both with u
        ... _ means no, x means with x, u means with u, w means with w(x+u)
    Lambda: loss = loss_x + Lambda * loss_u, relative weight for unsupervised loss
    base_lr: initial learning rate

    Lambda_ramp: callable or None. Lambda is ignored if this is not None. In this case,  Lambda = Lambda_ramp(epoch).
    change_lr: dict, {epoch: change_multiplier}; the current lr is multiplied at that epoch.

    Returns
    -------
    (training_losses, testing_losses, testing_accuracy): lists with one entry per epoch.
    '''
    import os  # local import: only needed to create the output directory

    # Note the trailing space in the first fragment: the two f-strings are
    # concatenated, and previously rendered as "...supervised_only = False,K = ...".
    print(
        f"Experiment {exp_identifier}: max_epoch = {max_epoch}, batch_size = {batch_size}, supervised_only = {supervised_only}, "
        f"K = {K}, T = {T}, alpha = {alpha}, mixup_mode = {mixup_mode}, Lambda = {Lambda}, Lambda_ramp = {Lambda_ramp}, base_lr = {base_lr}, aug_factor = {aug_factor}."
    )

    net = VNet(n_channels=1,
               n_classes=2,
               normalization='batchnorm',
               has_dropout=True)
    eval_net = VNet(n_channels=1,
                    n_classes=2,
                    normalization='batchnorm',
                    has_dropout=True)

    if from_saved is not None:
        net.load_state_dict(torch.load(from_saved))

    if GPU:
        net = net.cuda()
        eval_net = eval_net.cuda()

    ## eval_net is not updating
    for param in eval_net.parameters():
        param.detach_()

    net.train()
    eval_net.train()

    optimizer = optim.SGD(net.parameters(),
                          lr=base_lr,
                          momentum=0.9,
                          weight_decay=0.0001)
    x_criterion = soft_cross_entropy  # supervised loss is 0.5*(x_criterion + dice_loss)
    u_criterion = nn.MSELoss()  # unsupervised loss

    training_losses = []
    testing_losses = []
    testing_accuracy = []  # dice accuracy

    patch_size = (112, 112, 80)

    testing_data = [
        shape_transform(CenterCrop(patch_size)(sample))
        for sample in testing_data
    ]
    t0 = time.time()

    lr = base_lr

    for epoch in range(max_epoch):
        labelled_index = np.random.permutation(LABELLED_INDEX)
        unlabelled_index = np.random.permutation(
            UNLABELLED_INDEX)[:len(labelled_index)]
        labelled_data = [training_data[i] for i in labelled_index]
        unlabelled_data = [training_data[i]
                           for i in unlabelled_index]  # size = 16

        ## data transformation: rotation, flip, random_crop
        labelled_data = [
            shape_transform(RandomRotFlip()(RandomCrop(patch_size)(sample)))
            for sample in labelled_data
        ]
        unlabelled_data = [
            shape_transform(RandomRotFlip()(RandomCrop(patch_size)(sample)))
            for sample in unlabelled_data
        ]

        if Lambda_ramp is not None:
            Lambda = Lambda_ramp(epoch)
            print(f"Lambda ramp: Lambda = {Lambda}")

        if change_lr is not None and epoch in change_lr:
            lr_ = lr * change_lr[epoch]
            print(
                f"Learning rate decay at epoch {epoch}, from {lr} to {lr_}"
            )
            lr = lr_
            # change learning rate.
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_

        # Training mode was previously set only once, before the loop.
        # NOTE(review): if `validation` switches `net` to eval mode, later
        # epochs would otherwise train with dropout off and frozen batchnorm
        # statistics; re-entering train mode each epoch is harmless otherwise.
        net.train()
        training_loss = train_epoch(net=net,
                                    eval_net=eval_net,
                                    labelled_data=labelled_data,
                                    unlabelled_data=unlabelled_data,
                                    batch_size=batch_size,
                                    supervised_only=supervised_only,
                                    optimizer=optimizer,
                                    x_criterion=x_criterion,
                                    u_criterion=u_criterion,
                                    K=K,
                                    T=T,
                                    alpha=alpha,
                                    mixup_mode=mixup_mode,
                                    Lambda=Lambda,
                                    aug_factor=aug_factor,
                                    decay=decay)

        training_losses.append(training_loss)

        # Epoch 0 always validates (0 % 50 == 0), so these names are bound
        # before first use. On skipped epochs the previous validation result
        # is carried forward, keeping one list entry per epoch.
        if always_do_validation or epoch % 50 == 0:
            testing_dice_loss, accuracy = validation(net=net,
                                                     testing_data=testing_data,
                                                     x_criterion=x_criterion)

        testing_losses.append(testing_dice_loss)
        testing_accuracy.append(accuracy)
        print(
            f"Epoch {epoch+1}/{max_epoch}, time used: {time.time()-t0:.2f},  training loss: {training_loss:.6f}, testing dice_loss: {testing_dice_loss:.6f}, testing accuracy: {100.0*accuracy:.2f}% "
        )

    save_path = f"../saved/{exp_identifier}.pth"
    # Make sure the target directory exists; otherwise torch.save fails only
    # after the whole training run has finished.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(net.state_dict(), save_path)
    print(f"Experiment {exp_identifier} finished. Model saved as {save_path}")
    return training_losses, testing_losses, testing_accuracy