def create_model(ema=False):
    # Network definition
    net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True)
    model = net.cuda()
    if ema:
        # The EMA (teacher) model is never updated by backprop, so detach its
        # parameters from the autograd graph.
        for param in model.parameters():
            param.detach_()
    return model
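When create_model(ema=True) builds a mean-teacher style network, its detached parameters are typically refreshed from the student with an exponential moving average after each optimizer step. A minimal sketch of that update, assuming the usual mean-teacher convention (the helper name and the alpha ramp are illustrative, not taken from this code):

def update_ema_variables(model, ema_model, alpha, global_step):
    # Sketch (assumption): standard mean-teacher EMA update.
    # Ramp up the decay so the teacher follows the student closely early on.
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        # teacher = alpha * teacher + (1 - alpha) * student, in place
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)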
Example #2
def main():
    args = get_args()

    # dataset
    db_test = ABUS(base_dir=args.root_path, split='test')
    testloader = DataLoader(db_test,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1,
                            pin_memory=True)
    args.testloader = testloader

    # network
    if args.arch == 'vnet':
        model = VNet(n_channels=1,
                     n_classes=2,
                     normalization='batchnorm',
                     has_dropout=True,
                     use_tm=args.use_tm)
    elif args.arch == 'd2unet':
        model = D2UNet()
    else:
        raise NotImplementedError('model {} not implemented'.format(args.arch))
    model = model.cuda()

    save_path = None  # only set below when a checkpoint is resumed
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_pre = checkpoint['best_pre']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))

            # --- saving path ---
            if 'best' in args.resume:
                file_name = 'model_best_' + str(checkpoint['epoch'])
            elif 'check' in args.resume:
                file_name = 'checkpoint_{}_result'.format(checkpoint['epoch'])

            if args.save is not None:
                save_path = os.path.join(args.save, file_name)
            else:
                save_path = os.path.join(os.path.dirname(args.resume),
                                         file_name)
            if os.path.exists(save_path):
                shutil.rmtree(save_path)
            os.makedirs(save_path, exist_ok=True)

    test_all_case(model,
                  args.testloader,
                  num_classes=args.num_classes,
                  patch_size=(64, 128, 128),
                  stride_xy=64,
                  stride_z=64,
                  save_result=True,
                  test_save_path=save_path)
Example #3
def test_calculate_metric(epoch_num,
                          patch_size=(128, 128, 64),
                          stride_xy=64,
                          stride_z=32,
                          device='cuda'):
    net = VNet(n_channels=1,
               n_classes=num_classes,
               normalization='batchnorm',
               has_dropout=False).to(device)
    save_mode_path = os.path.join(snapshot_path,
                                  'iter_' + str(epoch_num) + '.pth')
    print(save_mode_path)
    net.load_state_dict(torch.load(save_mode_path, map_location=device))
    print("init weight from {}".format(save_mode_path))
    net.eval()

    metrics = test_all_case(net,
                            image_list,
                            num_classes=num_classes,
                            name_classes=name_classes,
                            patch_size=patch_size,
                            stride_xy=stride_xy,
                            stride_z=stride_z,
                            save_result=True,
                            test_save_path=test_save_path,
                            device=device)

    return metrics
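For reference, patch_size, stride_xy and stride_z jointly determine how many overlapping windows a test_all_case-style sliding-window inference evaluates per volume. A rough sketch of the usual counting convention (an assumption, not this repository's implementation):

import math

def sliding_window_counts(volume_shape, patch_size, stride_xy, stride_z):
    # Sketch (assumption): number of window positions along each axis.
    w, h, d = volume_shape
    sx = max(math.ceil((w - patch_size[0]) / stride_xy) + 1, 1)
    sy = max(math.ceil((h - patch_size[1]) / stride_xy) + 1, 1)
    sz = max(math.ceil((d - patch_size[2]) / stride_z) + 1, 1)
    return sx, sy, sz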
Example #4
def net_factory_3d(net_type="unet_3D", in_chns=1, class_num=2):
    if net_type == "unet_3D":
        net = unet_3D(n_classes=class_num, in_channels=in_chns).cuda()
    elif net_type == "attention_unet":
        net = Attention_UNet(n_classes=class_num, in_channels=in_chns).cuda()
    elif net_type == "voxresnet":
        net = VoxResNet(in_chns=in_chns, feature_chns=64,
                        class_num=class_num).cuda()
    elif net_type == "vnet":
        net = VNet(n_channels=in_chns,
                   n_classes=class_num,
                   normalization='batchnorm',
                   has_dropout=True).cuda()
    else:
        net = None
    return net
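A short, hypothetical usage example for net_factory_3d; the guard against the silent None return for unknown net_type values is added for illustration:

model = net_factory_3d(net_type="vnet", in_chns=1, class_num=2)
if model is None:
    raise ValueError("unsupported net_type")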
def test_calculate_metric(epoch_num):
    net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=False).cuda()
    save_mode_path = os.path.join(snapshot_path, 'iter_' + str(epoch_num) + '.pth')
    net.load_state_dict(torch.load(save_mode_path))
    print("init weight from {}".format(save_mode_path))
    net.eval()

    avg_metric = test_all_case(net, image_list, num_classes=num_classes,
                               patch_size=(112, 112, 80), stride_xy=18, stride_z=4,
                               save_result=True, test_save_path=test_save_path)

    return avg_metric
Example #6
def test_calculate_metric(args):
    net = VNet(n_channels=1,
               n_classes=args.num_classes,
               normalization='batchnorm',
               has_dropout=False).cuda()
    save_mode_path = os.path.join(args.snapshot_path,
                                  'iter_' + str(args.start_epoch) + '.pth')
    net.load_state_dict(torch.load(save_mode_path))
    print("init weight from {}".format(save_mode_path))
    net.eval()

    avg_metric = test_all_case(net,
                               args.testloader,
                               num_classes=args.num_classes,
                               patch_size=(128, 64, 128),
                               stride_xy=18,
                               stride_z=4,
                               save_result=True,
                               test_save_path=args.test_save_path)

    return avg_metric
Example #7
def main():
    ###################
    # init parameters #
    ###################
    args = get_args()
    # training path
    train_data_path = args.root_path
    # writer
    idx = args.save.rfind('/')
    log_dir = args.writer_dir + args.save[idx:]
    writer = SummaryWriter(log_dir)

    batch_size = args.batch_size * args.ngpu 
    max_iterations = args.max_iterations
    base_lr = args.base_lr

    #patch_size = (112, 112, 112)
    #patch_size = (160, 160, 160)
    patch_size = (64, 128, 128)
    num_classes = 2


    # random
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    ## make logger file
    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save, exist_ok=True)
    snapshot_path = args.save
    logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True, use_tm=args.use_tm)
    net = net.cuda()

    #db_train = LAHeart(base_dir=train_data_path,
    #                   split='train',
    #                   transform = transforms.Compose([
    #                      RandomRotFlip(),
    #                      RandomCrop(patch_size),
    #                      ToTensor(),
    #                      ]))

    db_train = ABUS(base_dir=args.root_path,
                       split='train',
                       use_dismap=args.use_dismap,
                       transform = transforms.Compose([RandomRotFlip(use_dismap=args.use_dismap), RandomCrop(patch_size, use_dismap=args.use_dismap), ToTensor(use_dismap=args.use_dismap)]))
    def worker_init_fn(worker_id):
        random.seed(args.seed+worker_id)
    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True,  num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)

    net.train()
    optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
    #gdl = GeneralizedDiceLoss()

    logging.info("{} itertations per epoch".format(len(trainloader)))

    iter_num = 0
    alpha = 1.0
    max_epoch = max_iterations//len(trainloader)+1
    lr_ = base_lr
    net.train()
    for epoch_num in tqdm(range(max_epoch), ncols=70):
        time1 = time.time()
        for i_batch, sampled_batch in enumerate(trainloader):
            time2 = time.time()
            # print('fetch data cost {}'.format(time2-time1))
            volume_batch, label_batch, dis_map_batch = sampled_batch['image'], sampled_batch['label'], sampled_batch['dis_map']
            volume_batch, label_batch, dis_map_batch = volume_batch.cuda(), label_batch.cuda(), dis_map_batch.cuda()
            #print('volume_batch.shape: ', volume_batch.shape)
            if args.use_tm:
                outputs, tm = net(volume_batch)
                tm = torch.sigmoid(tm)
            else:
                outputs = net(volume_batch)
            #print('volume_batch.shape: ', volume_batch.shape)
            #print('outputs.shape, ', outputs.shape)

            loss_seg = F.cross_entropy(outputs, label_batch)
            outputs_soft = F.softmax(outputs, dim=1)
            #print(outputs_soft.shape)
            #print(label_batch.shape)
            #loss_seg_dice = gdl(outputs_soft, label_batch)
            loss_seg_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1)
            #with torch.no_grad():
            #    # default: use compute_sdf; however, compute_sdf1_1 is also worth trying;
            #    gt_sdf_npy = compute_sdf(label_batch.cpu().numpy(), outputs_soft.shape)
            #    gt_sdf = torch.from_numpy(gt_sdf_npy).float().cuda(outputs_soft.device.index)
            #    print('gt_sdf.shape: ', gt_sdf.shape)
            #loss_boundary = boundary_loss(outputs_soft, gt_sdf)

            #print('dis_map.shape: ', dis_map_batch.shape)
            loss_boundary = boundary_loss(outputs_soft, dis_map_batch)

            if args.use_tm:
                loss_threshold = threshold_loss(outputs_soft[:, 1, :, :, :], tm[:, 0, ...], label_batch == 1)
                loss_th = (0.1 * loss_seg + 0.9 * loss_seg_dice) + 3 * loss_threshold
                loss = alpha*(loss_th) + (1 - alpha) * loss_boundary
            else:
                loss = alpha * loss_seg_dice + (1-alpha) * loss_boundary

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            out = outputs_soft.max(1)[1]
            dice = GeneralizedDiceLoss.dice_coeficient(out, label_batch)

            iter_num = iter_num + 1
            writer.add_scalar('train/lr', lr_, iter_num)
            writer.add_scalar('train/loss_seg', loss_seg, iter_num)
            writer.add_scalar('train/loss_seg_dice', loss_seg_dice, iter_num)
            writer.add_scalar('train/alpha', alpha, iter_num)
            writer.add_scalar('train/loss', loss, iter_num)
            writer.add_scalar('train/dice', dice, iter_num)
            if args.use_tm:
                writer.add_scalar('train/loss_threshold', loss_threshold, iter_num)
            if args.use_dismap:
                writer.add_scalar('train/loss_dis', loss_boundary, iter_num)

            logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))
            logging.info('iteration %d : alpha : %f' % (iter_num, alpha))

            if iter_num % 50 == 0:
                image = volume_batch[0, 0:1, :, 30:71:10, :].permute(2,0,1,3)
                image = (image + 0.5) * 0.5
                grid_image = make_grid(image, 5)
                writer.add_image('train/Image', grid_image, iter_num)

                #outputs_soft = F.softmax(outputs, 1) #batchsize x num_classes x w x h x d
                image = outputs_soft[0, 1:2, :, 30:71:10, :].permute(2,0,1,3)
                grid_image = make_grid(image, 5, normalize=False)
                grid_image = grid_image.cpu().detach().numpy().transpose((1,2,0))

                gt = label_batch[0, :, 30:71:10, :].unsqueeze(0).permute(2,0,1,3)
                grid_gt = make_grid(gt, 5, normalize=False)
                grid_gt = grid_gt.cpu().detach().numpy().transpose((1,2,0))

                image_tm = dis_map_batch[0, :, :, 30:71:10, :].permute(2,0,1,3)
                #image_tm = tm[0, :, :, 30:71:10, :].permute(2,0,1,3)
                grid_tm = make_grid(image_tm, 5, normalize=False)
                grid_tm = grid_tm.cpu().detach().numpy().transpose((1,2,0))


                fig = plt.figure()
                ax = fig.add_subplot(311)
                ax.imshow(grid_gt[:, :, 0], 'gray')
                ax = fig.add_subplot(312)
                cs = ax.imshow(grid_image[:, :, 0], 'hot', vmin=0., vmax=1.)
                fig.colorbar(cs, ax=ax, shrink=0.9)
                ax = fig.add_subplot(313)
                cs = ax.imshow(grid_tm[:, :, 0], 'hot', vmin=0, vmax=1.)
                fig.colorbar(cs, ax=ax, shrink=0.9)
                writer.add_figure('train/prediction_results', fig, iter_num)
                fig.clear()

            ## change lr
            if iter_num % 5000 == 0:
                lr_ = base_lr * 0.1 ** (iter_num // 5000)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_
            if iter_num % 1000 == 0:
                save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth')
                torch.save(net.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num > max_iterations:
                break
            time1 = time.time()
        alpha -= 0.005
        if alpha <= 0.01:
            alpha = 0.01
        if iter_num > max_iterations:
            break
    save_mode_path = os.path.join(snapshot_path, 'iter_'+str(max_iterations+1)+'.pth')
    torch.save(net.state_dict(), save_mode_path)
    logging.info("save model to {}".format(save_mode_path))
    writer.close()
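The supervised objective above combines cross-entropy with a soft Dice term on the foreground channel. A minimal sketch of a soft Dice loss compatible with the call dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1); this is an assumed implementation, not this repository's dice_loss:

import torch

def soft_dice_loss(score, target, smooth=1e-5):
    # Sketch (assumption): score holds foreground probabilities, target a 0/1 mask.
    target = target.float()
    intersect = torch.sum(score * target)
    denominator = torch.sum(score * score) + torch.sum(target * target) + smooth
    return 1.0 - (2.0 * intersect + smooth) / denominator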
Example #8
def main():
    ###################
    # init parameters #
    ###################
    args = get_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # training path
    train_data_path = args.root_path

    # writer
    idx = args.save.rfind('/')
    log_dir = args.writer_dir + args.save[idx:]
    writer = SummaryWriter(log_dir)

    # random
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    ## make logger file
    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save, exist_ok=True)
    snapshot_path = args.save
    logging.basicConfig(filename=snapshot_path + "/log.txt",
                        level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s',
                        datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    # network
    if args.arch == 'vnet':
        model = VNet(n_channels=1,
                     n_classes=2,
                     normalization='batchnorm',
                     has_dropout=True,
                     use_tm=args.use_tm)
    elif args.arch == 'd2unet':
        model = D2UNet()
    else:
        raise NotImplementedError('model {} not implemented'.format(args.arch))
    model = model.cuda()

    # dataset
    patch_size = (64, 128, 128)
    batch_size = args.ngpu * args.batch_size

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    db_train = ABUS(base_dir=args.root_path,
                    split='train',
                    fold=args.fold,
                    transform=transforms.Compose(
                        [RandomRotFlip(),
                         RandomCrop(patch_size),
                         ToTensor()]))
    db_val = ABUS(base_dir=args.root_path,
                  split='val',
                  fold=args.fold,
                  transform=transforms.Compose(
                      [CenterCrop(patch_size),
                       ToTensor()]))
    train_loader = DataLoader(db_train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True,
                              worker_init_fn=worker_init_fn)
    val_loader = DataLoader(db_val,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1,
                            pin_memory=True,
                            worker_init_fn=worker_init_fn)

    # optimizer
    lr = args.lr
    optimizer = optim.Adam(model.parameters(),
                           lr=lr,
                           weight_decay=args.weight_decay)
    lr_scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epoch,
                                len(train_loader))

    # training
    logging.info('--- start training ---')

    best_pre = 0.
    nTrain = len(db_train)
    for epoch in range(args.start_epoch, args.epoch + 1):
        train(args, epoch, model, train_loader, optimizer, writer,
              lr_scheduler)
        dice = val(args, epoch, model, val_loader, writer)
        is_best = False
        if dice > best_pre:
            is_best = True
            best_pre = dice
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_pre': best_pre
            }, is_best, args.save, args.arch)
    writer.close()
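save_checkpoint above receives a state dict, an is_best flag, the save directory and the architecture name. A conventional sketch following the standard PyTorch checkpointing pattern (assumed; the filenames are illustrative):

import os
import shutil
import torch

def save_checkpoint_sketch(state, is_best, save_dir, arch):
    # Sketch (assumption): the project's actual save_checkpoint may differ.
    filename = os.path.join(save_dir, '{}_checkpoint.pth.tar'.format(arch))
    torch.save(state, filename)
    if is_best:
        # Keep a separate copy of the best-performing weights.
        shutil.copyfile(filename,
                        os.path.join(save_dir, '{}_model_best.pth.tar'.format(arch)))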
Example #9
def main():
    ###################
    # init parameters #
    ###################
    args = get_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # training path
    train_data_path = args.root_path

    # writer
    idx = args.save.rfind('/')
    log_dir = args.writer_dir + args.save[idx:]
    writer = SummaryWriter(log_dir)

    batch_size = args.batch_size * args.ngpu 
    max_iterations = args.max_iterations
    base_lr = args.base_lr

    patch_size = (64, 128, 128)
    num_classes = 2

    # random
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    ## make logger file
    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save, exist_ok=True)
    snapshot_path = args.save
    logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    # network
    if args.arch == 'vnet':
        net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True, use_tm=args.use_tm)
    elif args.arch == 'd2unet':
        net = D2UNet()
    else:
        raise NotImplementedError('model {} not implemented'.format(args.arch))
    net = net.cuda()

    # dataset 
    def worker_init_fn(worker_id):
        random.seed(args.seed+worker_id)
    db_train = ABUS(base_dir=args.root_path,
                       split='train',
                       fold=args.fold,
                       transform = transforms.Compose([RandomRotFlip(), RandomCrop(patch_size), ToTensor()]))
    db_val = ABUS(base_dir=args.root_path,
                       split='val',
                       fold=args.fold,
                       transform = transforms.Compose([CenterCrop(patch_size), ToTensor()]))
    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True,  num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)

    # optimizer
    optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
    gdl = GeneralizedDiceLoss()

    logging.info("{} itertations per epoch".format(len(trainloader)))

    # training
    iter_num = 0
    max_epoch = max_iterations//len(trainloader)+1
    lr_ = base_lr
    net.train()
    for epoch_num in tqdm(range(max_epoch), ncols=70):
        for i_batch, sampled_batch in enumerate(trainloader):
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            if args.use_tm: 
                outputs, tm = net(volume_batch)
                tm = torch.sigmoid(tm)
            else:
                outputs = net(volume_batch)

            loss_seg = F.cross_entropy(outputs, label_batch)
            outputs_soft = F.softmax(outputs, dim=1)
            loss_seg_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1)
            if args.use_tm: 
                loss_threshold = threshold_loss(outputs_soft[:, 1, :, :, :], tm[:, 0, ...], label_batch == 1)
                loss = loss_seg_dice + 3 * loss_threshold
            else:
                loss = loss_seg_dice

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


            # visualization on tensorboard
            out = outputs_soft.max(1)[1]
            dice = GeneralizedDiceLoss.dice_coeficient(out, label_batch)

            iter_num = iter_num + 1
            writer.add_scalar('train/lr', lr_, iter_num)
            writer.add_scalar('train/loss_seg', loss_seg, iter_num)
            writer.add_scalar('train/loss_seg_dice', loss_seg_dice, iter_num)
            writer.add_scalar('train/loss', loss, iter_num)
            writer.add_scalar('train/dice', dice, iter_num)
            if args.use_tm:
                writer.add_scalar('train/loss_threshold', loss_threshold, iter_num)

            logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))

            if iter_num % 50 == 0:
                nrow = 5
                image = volume_batch[0, 0:1, :, 30:71:10, :].permute(2,0,1,3)
                image = (image + 0.5) * 0.5
                grid_image = make_grid(image, nrow=nrow)
                writer.add_image('train/Image', grid_image, iter_num)

                #outputs_soft = F.softmax(outputs, 1) #batchsize x num_classes x w x h x d
                image = outputs_soft[0, 1:2, :, 30:71:10, :].permute(2,0,1,3)
                grid_image = make_grid(image, nrow=nrow, normalize=False)
                grid_image = grid_image.cpu().detach().numpy().transpose((1,2,0))

                gt = label_batch[0, :, 30:71:10, :].unsqueeze(0).permute(2,0,1,3)
                grid_gt = make_grid(gt, nrow=nrow, normalize=False)
                grid_gt = grid_gt.cpu().detach().numpy().transpose((1,2,0))

                if args.use_tm:
                    image_tm = tm[0, :, :, 30:71:10, :].permute(2,0,1,3)
                else:
                    image_tm = gt
                grid_tm = make_grid(image_tm, nrow=nrow, normalize=False)
                grid_tm = grid_tm.cpu().detach().numpy().transpose((1,2,0))

                fig = plt.figure()
                ax = fig.add_subplot(311)
                ax.imshow(grid_gt[:, :, 0], 'gray')
                ax = fig.add_subplot(312)
                cs = ax.imshow(grid_image[:, :, 0], 'hot', vmin=0., vmax=1.)
                fig.colorbar(cs, ax=ax, shrink=0.9)
                ax = fig.add_subplot(313)
                cs = ax.imshow(grid_tm[:, :, 0], 'hot', vmin=0, vmax=1.)
                fig.colorbar(cs, ax=ax, shrink=0.9)
                writer.add_figure('train/prediction_results', fig, iter_num)
                fig.clear()

            ## change lr
            if iter_num % 2500 == 0:
                lr_ = base_lr * 0.1 ** (iter_num // 2500)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_
            if iter_num % 1000 == 0 and iter_num > 5000:
                save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth')
                torch.save(net.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num > max_iterations:
                break
            time1 = time.time()
        if iter_num > max_iterations:
            break
    save_mode_path = os.path.join(snapshot_path, 'iter_'+str(max_iterations+1)+'.pth')
    torch.save(net.state_dict(), save_mode_path)
    logging.info("save model to {}".format(save_mode_path))
    writer.close()
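The monitoring value GeneralizedDiceLoss.dice_coeficient(out, label_batch) used above is a hard Dice overlap between the argmax prediction and the labels. A minimal sketch of such a coefficient (an assumption for illustration, not the repository's implementation):

import torch

def hard_dice_coefficient(pred, target, eps=1e-5):
    # Sketch (assumption): pred and target are 0/1 volumes of the same shape.
    pred = pred.float().contiguous().view(-1)
    target = target.float().contiguous().view(-1)
    intersection = (pred * target).sum()
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)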
Example #10
def debugger():
    patch_size = (112, 112, 80)
    training_data = data_loader(split='train')
    testing_data = data_loader(split='test')

    x_criterion = soft_cross_entropy  #supervised loss is 0.5*(x_criterion + dice_loss)
    u_criterion = nn.MSELoss()  #unsupervised loss

    labelled_index = np.random.permutation(LABELLED_INDEX)
    unlabelled_index = np.random.permutation(
        UNLABELLED_INDEX)[:len(labelled_index)]
    labelled_data = [training_data[i] for i in labelled_index]
    unlabelled_data = [training_data[i] for i in unlabelled_index]  #size = 16

    ##data transformation: rotation, flip, random_crop
    labelled_data = [
        shape_transform(RandomRotFlip()(RandomCrop(patch_size)(sample)))
        for sample in labelled_data
    ]
    unlabelled_data = [
        shape_transform(RandomRotFlip()(RandomCrop(patch_size)(sample)))
        for sample in unlabelled_data
    ]

    net = VNet(n_channels=1,
               n_classes=2,
               normalization='batchnorm',
               has_dropout=True).cuda()

    model_path = "../saved/0_supervised.pth"
    net.load_state_dict(torch.load(model_path))

    optimizer = optim.SGD(net.parameters(),
                          lr=0.01,
                          momentum=0.9,
                          weight_decay=0.0001)
    training_loss = train_epoch(net=net,
                                labelled_data=labelled_data,
                                unlabelled_data=unlabelled_data,
                                batch_size=2,
                                supervised_only=True,
                                optimizer=optimizer,
                                x_criterion=x_criterion,
                                u_criterion=u_criterion,
                                K=1,
                                T=1,
                                alpha=1,
                                mixup_mode="__",
                                Lambda=0,
                                aug_factor=0)

    net = VNet(n_channels=1,
               n_classes=2,
               normalization='batchnorm',
               has_dropout=True).cuda()
    model_path = "../saved/8_expected_supervised.pth"
    net.load_state_dict(torch.load(model_path))

    optimizer = optim.SGD(net.parameters(),
                          lr=0.01,
                          momentum=0.9,
                          weight_decay=0.0001)
    training_loss = train_epoch(net=net,
                                labelled_data=labelled_data,
                                unlabelled_data=unlabelled_data,
                                batch_size=2,
                                supervised_only=False,
                                optimizer=optimizer,
                                x_criterion=x_criterion,
                                u_criterion=u_criterion,
                                K=1,
                                T=1,
                                alpha=1,
                                mixup_mode="__",
                                Lambda=0,
                                aug_factor=0)
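Both this debugger and the experiment function below use soft_cross_entropy as the supervised criterion, i.e. a cross-entropy that accepts soft (probabilistic) targets such as mixed-up labels. A minimal sketch under that assumption:

import torch
import torch.nn.functional as F

def soft_cross_entropy_sketch(logits, soft_targets):
    # Sketch (assumption): soft_targets are per-voxel class distributions with the
    # same shape as logits.
    log_probs = F.log_softmax(logits, dim=1)
    return torch.mean(torch.sum(-soft_targets * log_probs, dim=1))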
Example #11
num_classes = 2

if __name__ == "__main__":
    ## make logger file
    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)
    if os.path.exists(snapshot_path + '/code'):
        shutil.rmtree(snapshot_path + '/code')
    shutil.copytree('.', snapshot_path + '/code', ignore=shutil.ignore_patterns('.git', '__pycache__'))

    logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True)
    net = net.cuda()

    db_train = LiverTumor(base_dir=train_data_path,
                       split='train',
                       transform = transforms.Compose([
                          RandomRotFlip(),
                          RandomCrop(patch_size),
                          ToTensor(),
                          ]))


    def worker_init_fn(worker_id):
        random.seed(args.seed+worker_id)
    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True,  num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)
def main():
    ###################
    # init parameters #
    ###################
    args = get_args()
    # training path
    train_data_path = args.root_path
    # writer
    idx = args.save.rfind('/')
    log_dir = args.writer_dir + args.save[idx:]
    writer = SummaryWriter(log_dir)

    batch_size = args.batch_size * args.ngpu 
    max_iterations = args.max_iterations
    base_lr = args.base_lr

    patch_size = (112, 112, 80)
    num_classes = 2

    # random
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    ## make logger file
    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save, exist_ok=True)
    snapshot_path = args.save
    logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True)
    net = net.cuda()

    db_train = LAHeart(base_dir=train_data_path,
                       split='train',
                       num=16,
                       transform = transforms.Compose([
                          RandomRotFlip(),
                          RandomCrop(patch_size),
                          ToTensor(),
                          ]))
    db_test = LAHeart(base_dir=train_data_path,
                       split='test',
                       transform = transforms.Compose([
                           CenterCrop(patch_size),
                           ToTensor()
                       ]))
    def worker_init_fn(worker_id):
        random.seed(args.seed+worker_id)
    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True,  num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)

    net.train()
    optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)

    logging.info("{} itertations per epoch".format(len(trainloader)))

    iter_num = 0
    max_epoch = max_iterations//len(trainloader)+1
    lr_ = base_lr
    net.train()
    for epoch_num in tqdm(range(max_epoch), ncols=70):
        time1 = time.time()
        for i_batch, sampled_batch in enumerate(trainloader):
            time2 = time.time()
            # print('fetch data cost {}'.format(time2-time1))
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            outputs = net(volume_batch)

            loss_seg = F.cross_entropy(outputs, label_batch)
            outputs_soft = F.softmax(outputs, dim=1)
            loss_seg_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1)
            loss = 0.5*(loss_seg+loss_seg_dice)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iter_num = iter_num + 1
            writer.add_scalar('lr', lr_, iter_num)
            writer.add_scalar('loss/loss_seg', loss_seg, iter_num)
            writer.add_scalar('loss/loss_seg_dice', loss_seg_dice, iter_num)
            writer.add_scalar('loss/loss', loss, iter_num)
            logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))
            if iter_num % 50 == 0:
                image = volume_batch[0, 0:1, :, :, 20:61:10].permute(3,0,1,2).repeat(1,3,1,1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/Image', grid_image, iter_num)

                outputs_soft = F.softmax(outputs, 1)
                image = outputs_soft[0, 1:2, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Predicted_label', grid_image, iter_num)

                image = label_batch[0, :, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Groundtruth_label', grid_image, iter_num)

            ## change lr
            if iter_num % 2500 == 0:
                lr_ = base_lr * 0.1 ** (iter_num // 2500)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_
            if iter_num % 1000 == 0:
                save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth')
                torch.save(net.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num > max_iterations:
                break
            time1 = time.time()
        if iter_num > max_iterations:
            break
    save_mode_path = os.path.join(snapshot_path, 'iter_'+str(max_iterations+1)+'.pth')
    torch.save(net.state_dict(), save_mode_path)
    logging.info("save model to {}".format(save_mode_path))
    writer.close()
def main():
    ###################
    # init parameters #
    ###################
    args = get_args()
    # training path
    train_data_path = args.root_path
    # writer
    idx = args.save.rfind('/')
    log_dir = args.writer_dir + args.save[idx:]
    writer = SummaryWriter(log_dir)

    batch_size = args.batch_size * args.ngpu
    max_iterations = args.max_iterations
    base_lr = args.base_lr

    patch_size = (112, 112, 80)
    num_classes = 2

    # random
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    ## make logger file
    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save, exist_ok=True)
    snapshot_path = args.save
    logging.basicConfig(filename=snapshot_path + "/log.txt",
                        level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s',
                        datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    #training set
    db_train = LAHeart(base_dir=train_data_path,
                       split='train',
                       num=16,
                       transform=transforms.Compose([
                           RandomRotFlip(),
                           RandomCrop(patch_size),
                           ToTensor()
                       ]))

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    trainloader = DataLoader(db_train,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True,
                             worker_init_fn=worker_init_fn)

    net = VNet(n_channels=1,
               n_classes=num_classes,
               normalization='batchnorm',
               has_dropout=True)
    net = net.cuda()
    net.train()
    optimizer = optim.SGD(net.parameters(),
                          lr=base_lr,
                          momentum=0.9,
                          weight_decay=0.0001)

    logging.info("{} itertations per epoch".format(len(trainloader)))

    iter_num = 0
    alpha = 1.0
    max_epoch = max_iterations // len(trainloader) + 1
    lr_ = base_lr
    net.train()
    for epoch_num in tqdm(range(max_epoch), ncols=70):
        time1 = time.time()
        for i_batch, sampled_batch in enumerate(trainloader):
            time2 = time.time()
            # print('fetch data cost {}'.format(time2-time1))
            # volume_batch.shape=(b,1,x,y,z) label_patch.shape=(b,x,y,z)
            volume_batch, label_batch = sampled_batch['image'], sampled_batch[
                'label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            outputs = net(volume_batch)

            loss_seg = F.cross_entropy(outputs, label_batch)
            outputs_soft = F.softmax(outputs, dim=1)
            loss_seg_dice = dice_loss(outputs_soft[:, 1, :, :, :],
                                      label_batch == 1)
            # compute gt_signed distance function and boundary loss
            with torch.no_grad():
                # default: use compute_sdf; however, compute_sdf1_1 is also worth trying;
                gt_sdf_npy = compute_sdf(label_batch.cpu().numpy(),
                                         outputs_soft.shape)
                gt_sdf = torch.from_numpy(gt_sdf_npy).float().cuda(
                    outputs_soft.device.index)
                # show signed distance map for debug
                # import matplotlib.pyplot as plt
                # plt.figure()
                # plt.subplot(121), plt.imshow(gt_sdf_npy[0,1,:,:,40]), plt.colorbar()
                # plt.subplot(122), plt.imshow(np.uint8(label_batch.cpu().numpy()[0,:,:,40]>0)), plt.colorbar()
                # plt.show()
            loss_boundary = boundary_loss(outputs_soft, gt_sdf)
            loss = alpha * (loss_seg + loss_seg_dice) + (1 -
                                                         alpha) * loss_boundary

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iter_num = iter_num + 1
            writer.add_scalar('lr', lr_, iter_num)
            writer.add_scalar('loss/loss_seg', loss_seg, iter_num)
            writer.add_scalar('loss/loss_seg_dice', loss_seg_dice, iter_num)
            writer.add_scalar('loss/loss_boundary', loss_boundary, iter_num)
            writer.add_scalar('loss/loss', loss, iter_num)
            writer.add_scalar('loss/alpha', alpha, iter_num)
            logging.info('iteration %d : alpha : %f' % (iter_num, alpha))
            logging.info('iteration %d : loss_seg_dice : %f' %
                         (iter_num, loss_seg_dice.item()))
            logging.info('iteration %d : loss_boundary : %f' %
                         (iter_num, loss_boundary.item()))
            logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))
            if iter_num % 2 == 0:
                image = volume_batch[0, 0:1, :, :,
                                     20:61:10].permute(3, 0, 1,
                                                       2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/Image', grid_image, iter_num)

                image = outputs_soft[0, 1:2, :, :,
                                     20:61:10].permute(3, 0, 1,
                                                       2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Predicted_label', grid_image, iter_num)

                image = label_batch[0, :, :, 20:61:10].unsqueeze(0).permute(
                    3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Groundtruth_label', grid_image,
                                 iter_num)

                image = gt_sdf[0, 1:2, :, :,
                               20:61:10].permute(3, 0, 1,
                                                 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/gt_sdf', grid_image, iter_num)

            ## change lr
            if iter_num % 2500 == 0:
                lr_ = base_lr * 0.1**(iter_num // 2500)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_
            if iter_num % 1000 == 0:
                save_mode_path = os.path.join(snapshot_path,
                                              'iter_' + str(iter_num) + '.pth')
                torch.save(net.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num > max_iterations:
                break
            time1 = time.time()
        alpha -= 0.01
        if alpha <= 0.01:
            alpha = 0.01
        if iter_num > max_iterations:
            break
    save_mode_path = os.path.join(snapshot_path,
                                  'iter_' + str(max_iterations + 1) + '.pth')
    torch.save(net.state_dict(), save_mode_path)
    logging.info("save model to {}".format(save_mode_path))
    writer.close()
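The boundary loss above pairs the softmax output with a signed distance transform of the ground truth. A compact sketch of the standard formulation (Kervadec et al.), assuming the signed distance is negative inside the object and positive outside; this is not this repository's exact compute_sdf or boundary_loss:

import numpy as np
import torch
from scipy.ndimage import distance_transform_edt

def boundary_loss_sketch(outputs_soft, gt_sdf):
    # Mean of foreground probability weighted by the signed distance map.
    return torch.mean(outputs_soft[:, 1:, ...] * gt_sdf[:, 1:, ...])

def compute_sdf_sketch(mask):
    # Signed distance for one binary mask: negative inside, positive outside.
    posmask = mask.astype(bool)
    if not posmask.any():
        return np.zeros_like(mask, dtype=np.float32)
    return distance_transform_edt(~posmask) - distance_transform_edt(posmask)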
Example #14
def test_calculate_metric(iter_nums):
    if args.net == 'vnet':
        net = VNet(n_channels=1,
                   num_classes=args.num_classes,
                   normalization='batchnorm',
                   has_dropout=False)
    elif args.net == 'unet':
        net = UNet3D(in_channels=1, num_classes=args.num_classes)
    elif args.net == 'segtran':
        get_default(args, 'num_modes', default_settings, -1,
                    [args.net, 'num_modes', args.in_fpn_layers])
        if args.segtran_type == '25d':
            set_segtran25d_config(args)
            net = Segtran25d(config25d)
        else:
            set_segtran3d_config(args)
            net = Segtran3d(config3d)

    net.cuda()
    net.eval()
    preproc_fn = None

    if not args.checkpoint_dir:
        if args.vis_mode is not None:
            visualize_model(net, args.vis_mode)
            return

        if args.eval_robustness:
            eval_robustness(net, testloader, args.aug_degree)
            return

    for iter_num in iter_nums:
        if args.checkpoint_dir:
            checkpoint_path = os.path.join(args.checkpoint_dir,
                                           'iter_' + str(iter_num) + '.pth')
            load_model(net, args, checkpoint_path)

            if args.vis_mode is not None:
                visualize_model(net, args.vis_mode)
                continue

            if args.eval_robustness:
                eval_robustness(net, testloader, args.aug_degree)
                continue

        save_result = not args.test_interp

        if save_result:
            test_save_paths = []
            test_save_dirs = []
            test_save_dir = "%s-%s-%s-%d" % (args.net, args.job_name,
                                             timestamp, iter_num)
            test_save_path = "../prediction/%s" % (test_save_dir)
            if not os.path.exists(test_save_path):
                os.makedirs(test_save_path)
            test_save_dirs.append(test_save_dir)
            test_save_paths.append(test_save_path)
        else:
            test_save_paths = [None]
            test_save_dirs = [None]

        # No need to use dataloader to pass data,
        # as one 3D image is split into many patches to do segmentation.
        allcls_avg_metric = test_all_cases(
            net,
            db_test,
            task_name=args.task_name,
            net_type=args.net,
            num_classes=args.num_classes,
            batch_size=args.batch_size,
            orig_patch_size=args.orig_patch_size,
            input_patch_size=args.input_patch_size,
            stride_xy=args.orig_patch_size[0] // 2,
            stride_z=args.orig_patch_size[2] // 2,
            save_result=save_result,
            test_save_path=test_save_paths[0],
            preproc_fn=preproc_fn,
            test_interp=args.test_interp,
            has_mask=has_mask)

        print("%d scores:" % iter_num)
        for cls in range(1, args.num_classes):
            dice, jc, hd, asd = allcls_avg_metric[cls - 1]
            print('%d: dice: %.3f, jc: %.3f, hd: %.3f, asd: %.3f' %
                  (cls, dice, jc, hd, asd))

        if save_result:
            FNULL = open(os.devnull, 'w')
            # Currently only save hard predictions.
            for pred_type, test_save_dir, test_save_path in zip(
                ('hard', ), test_save_dirs, test_save_paths):
                do_tar = subprocess.run(
                    ["tar", "cvf",
                     "%s.tar" % test_save_dir, test_save_dir],
                    cwd="../prediction",
                    stdout=FNULL,
                    stderr=subprocess.STDOUT)
                # print(do_tar)
                print("{} tarball:\n{}.tar".format(
                    pred_type, os.path.abspath(test_save_path)))

    return allcls_avg_metric
Example #15
    if args.translayer_compress_ratios is not None:
        args.translayer_compress_ratios = [
            float(r) for r in args.translayer_compress_ratios.split(",")
        ]
    else:
        args.translayer_compress_ratios = [
            1 for layer in range(args.num_translayers + 1)
        ]

    logging.info(str(args))
    base_lr = args.lr

    if args.net == 'vnet':
        net = VNet(n_channels=1,
                   n_classes=args.num_classes,
                   normalization='batchnorm',
                   has_dropout=True)
    elif args.net == 'unet':
        net = UNet3D(in_channels=1, n_classes=args.num_classes)
    elif args.net == 'segtran':
        if args.segtran_type == '3d':
            set_segtran3d_config(args)
            net = Segtran3d(config3d)
        else:
            set_segtran25d_config(args)
            net = Segtran25d(config25d)
    else:
        breakpoint()

    net.cuda()
    if args.opt == 'sgd':
Example #16
    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)
    if os.path.exists(snapshot_path + '/code'):
        shutil.rmtree(snapshot_path + '/code')
    shutil.copytree('.', snapshot_path + '/code',
                    ignore=shutil.ignore_patterns('.git', '__pycache__'))

    logging.basicConfig(filename=snapshot_path + "/log.txt",
                        level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s',
                        datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    net = VNet(n_channels=1,
               n_classes=num_classes,
               normalization='batchnorm',
               has_dropout=False)
    net = net.cuda()

    db_train = LiverTumor(base_dir=train_data_path,
                          split='train',
                          transform=transforms.Compose([
                              RandomRotFlip(),
                              RandomCrop(patch_size),
                              ToTensor(),
                          ]))

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    trainloader = DataLoader(db_train,
Example #17
def experiment(exp_identifier,
               max_epoch,
               training_data,
               testing_data,
               batch_size=2,
               supervised_only=False,
               K=2,
               T=0.5,
               alpha=1,
               mixup_mode='all',
               Lambda=1,
               Lambda_ramp=None,
               base_lr=0.01,
               change_lr=None,
               aug_factor=1,
               from_saved=None,
               always_do_validation=True,
               decay=0):
    '''
    max_epoch: epochs to run. Going through labeled data once is one epoch.
    batch_size: batch size of labeled data. Unlabeled data is of the same size.
    training_data: data for train_epoch, list of dicts of numpy arrays.
    testing_data: data for validation, list of dicts of numpy arrays.
    supervised_only: if True, only do supervised training on LABELLED_INDEX; otherwise, use both LABELLED_INDEX and UNLABELLED_INDEX
    
    Hyperparameters
    ---------------
    K: repeats of each unlabelled data
    T: temperature of sharpening
    alpha: mixup hyperparameter of beta distribution
    mixup_mode: how mixup is performed --
        '__': no mix up
        'ww': x and u both mixed up with w(x+u)
        'xx': both with x
        'xu': x with x, u with u
        'uu': both with u
        ... _ means no, x means with x, u means with u, w means with w(x+u)
    Lambda: loss = loss_x + Lambda * loss_u, relative weight for unsupervised loss
    base_lr: initial learning rate

    Lambda_ramp: callable or None. Lambda is ignored if this is not None. In this case,  Lambda = Lambda_ramp(epoch).
    change_lr: dict, {epoch: change_multiplier}


    '''
    print(
        f"Experiment {exp_identifier}: max_epoch = {max_epoch}, batch_size = {batch_size}, supervised_only = {supervised_only},"
        f"K = {K}, T = {T}, alpha = {alpha}, mixup_mode = {mixup_mode}, Lambda = {Lambda}, Lambda_ramp = {Lambda_ramp}, base_lr = {base_lr}, aug_factor = {aug_factor}."
    )

    net = VNet(n_channels=1,
               n_classes=2,
               normalization='batchnorm',
               has_dropout=True)
    eval_net = VNet(n_channels=1,
                    n_classes=2,
                    normalization='batchnorm',
                    has_dropout=True)

    if from_saved is not None:
        net.load_state_dict(torch.load(from_saved))

    if GPU:
        net = net.cuda()
        eval_net.cuda()

    ## eval_net is not updating
    for param in eval_net.parameters():
        param.detach_()

    net.train()
    eval_net.train()

    optimizer = optim.SGD(net.parameters(),
                          lr=base_lr,
                          momentum=0.9,
                          weight_decay=0.0001)
    x_criterion = soft_cross_entropy  #supervised loss is 0.5*(x_criterion + dice_loss)
    u_criterion = nn.MSELoss()  #unsupervised loss

    training_losses = []
    testing_losses = []
    testing_accuracy = []  #dice accuracy

    patch_size = (112, 112, 80)

    testing_data = [
        shape_transform(CenterCrop(patch_size)(sample))
        for sample in testing_data
    ]
    t0 = time.time()

    lr = base_lr

    for epoch in range(max_epoch):
        labelled_index = np.random.permutation(LABELLED_INDEX)
        unlabelled_index = np.random.permutation(
            UNLABELLED_INDEX)[:len(labelled_index)]
        labelled_data = [training_data[i] for i in labelled_index]
        unlabelled_data = [training_data[i]
                           for i in unlabelled_index]  #size = 16

        ##data transformation: rotation, flip, random_crop
        labelled_data = [
            shape_transform(RandomRotFlip()(RandomCrop(patch_size)(sample)))
            for sample in labelled_data
        ]
        unlabelled_data = [
            shape_transform(RandomRotFlip()(RandomCrop(patch_size)(sample)))
            for sample in unlabelled_data
        ]

        if Lambda_ramp is not None:
            Lambda = Lambda_ramp(epoch)
            print(f"Lambda ramp: Lambda = {Lambda}")

        if change_lr is not None:
            if epoch in change_lr:
                lr_ = lr * change_lr[epoch]
                print(
                    f"Learning rate decay at epoch {epoch}, from {lr} to {lr_}"
                )
                lr = lr_
                #change learning rate.
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_

        training_loss = train_epoch(net=net,
                                    eval_net=eval_net,
                                    labelled_data=labelled_data,
                                    unlabelled_data=unlabelled_data,
                                    batch_size=batch_size,
                                    supervised_only=supervised_only,
                                    optimizer=optimizer,
                                    x_criterion=x_criterion,
                                    u_criterion=u_criterion,
                                    K=K,
                                    T=T,
                                    alpha=alpha,
                                    mixup_mode=mixup_mode,
                                    Lambda=Lambda,
                                    aug_factor=aug_factor,
                                    decay=decay)

        training_losses.append(training_loss)

        if always_do_validation or epoch % 50 == 0:
            testing_dice_loss, accuracy = validation(net=net,
                                                     testing_data=testing_data,
                                                     x_criterion=x_criterion)
            # Record metrics only for epochs where validation actually ran.
            testing_losses.append(testing_dice_loss)
            testing_accuracy.append(accuracy)
        print(
            f"Epoch {epoch+1}/{max_epoch}, time used: {time.time()-t0:.2f},  training loss: {training_loss:.6f}, testing dice_loss: {testing_dice_loss:.6f}, testing accuracy: {100.0*accuracy:.2f}% "
        )

    save_path = f"../saved/{exp_identifier}.pth"
    torch.save(net.state_dict(), save_path)
    print(f"Experiment {exp_identifier} finished. Model saved as {save_path}")
    return training_losses, testing_losses, testing_accuracy
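As noted in the docstring, Lambda_ramp is a callable mapping the epoch to the unsupervised loss weight. A typical choice is the sigmoid ramp-up used in mean-teacher style training; a sketch under that assumption (the helper name and ramp-up length are illustrative):

import numpy as np

def make_sigmoid_lambda_ramp(max_lambda=1.0, rampup_epochs=40):
    # Sketch (assumption): Lambda grows from ~0 to max_lambda over rampup_epochs.
    def lambda_ramp(epoch):
        if rampup_epochs == 0:
            return max_lambda
        phase = 1.0 - min(float(epoch), rampup_epochs) / rampup_epochs
        return max_lambda * float(np.exp(-5.0 * phase * phase))
    return lambda_ramp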