def test_calculate_metric(epoch_num, patch_size=(128, 128, 64),
                          stride_xy=64, stride_z=32, device='cuda'):
    net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm',
               has_dropout=False).to(device)
    save_mode_path = os.path.join(snapshot_path, 'iter_' + str(epoch_num) + '.pth')
    print(save_mode_path)
    # map_location keeps the checkpoint load consistent with the requested device
    net.load_state_dict(torch.load(save_mode_path, map_location=device))
    print("init weight from {}".format(save_mode_path))
    net.eval()

    metrics = test_all_case(net, image_list,
                            num_classes=num_classes, name_classes=name_classes,
                            patch_size=patch_size,
                            stride_xy=stride_xy, stride_z=stride_z,
                            save_result=True, test_save_path=test_save_path,
                            device=device)
    return metrics
def main():
    args = get_args()

    # dataset
    db_test = ABUS(base_dir=args.root_path, split='test')
    testloader = DataLoader(db_test, batch_size=1, shuffle=False,
                            num_workers=1, pin_memory=True)
    args.testloader = testloader

    # network
    if args.arch == 'vnet':
        model = VNet(n_channels=1, n_classes=2, normalization='batchnorm',
                     has_dropout=True, use_tm=args.use_tm)
    elif args.arch == 'd2unet':
        model = D2UNet()
    else:
        raise NotImplementedError('model {} not implemented'.format(args.arch))
    model = model.cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_pre = checkpoint['best_pre']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))

            # --- saving path ---
            if 'best' in args.resume:
                file_name = 'model_best_' + str(checkpoint['epoch'])
            elif 'check' in args.resume:
                file_name = 'checkpoint_{}_result'.format(checkpoint['epoch'])
            else:
                # fallback so file_name is always defined
                file_name = 'epoch_{}_result'.format(checkpoint['epoch'])

            if args.save is not None:
                save_path = os.path.join(args.save, file_name)
            else:
                save_path = os.path.join(os.path.dirname(args.resume), file_name)
            if os.path.exists(save_path):
                shutil.rmtree(save_path)
            os.makedirs(save_path, exist_ok=True)

            test_all_case(model, args.testloader, num_classes=args.num_classes,
                          patch_size=(64, 128, 128), stride_xy=64, stride_z=64,
                          save_result=True, test_save_path=save_path)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
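# A minimal entry-point guard (an assumption: the original file may already
# have one below this excerpt; this script is meant to be run directly).
if __name__ == '__main__':
    main()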
def test_calculate_metric(epoch_num):
    net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm',
               has_dropout=False).cuda()
    save_mode_path = os.path.join(snapshot_path, 'iter_' + str(epoch_num) + '.pth')
    net.load_state_dict(torch.load(save_mode_path))
    print("init weight from {}".format(save_mode_path))
    net.eval()

    avg_metric = test_all_case(net, image_list, num_classes=num_classes,
                               patch_size=(112, 112, 80), stride_xy=18, stride_z=4,
                               save_result=True, test_save_path=test_save_path)
    return avg_metric
def test_calculate_metric(args):
    net = VNet(n_channels=1, n_classes=args.num_classes, normalization='batchnorm',
               has_dropout=False).cuda()
    save_mode_path = os.path.join(args.snapshot_path,
                                  'iter_' + str(args.start_epoch) + '.pth')
    net.load_state_dict(torch.load(save_mode_path))
    print("init weight from {}".format(save_mode_path))
    net.eval()

    avg_metric = test_all_case(net, args.testloader, num_classes=args.num_classes,
                               patch_size=(128, 64, 128), stride_xy=18, stride_z=4,
                               save_result=True, test_save_path=args.test_save_path)
    return avg_metric
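# For reference, a minimal sketch of how a ceil-based sliding-window tiling
# derives the number of patch positions from patch_size and the strides.
# This is an assumption about test_all_case's tiling scheme, not its actual
# implementation; count_sliding_windows is a hypothetical helper.
import math

def count_sliding_windows(image_shape, patch_size, stride_xy, stride_z):
    w, h, d = image_shape
    pw, ph, pd = patch_size
    # One window at the origin plus one per stride along each axis, with the
    # final window clamped to the volume border.
    sx = math.ceil((w - pw) / stride_xy) + 1
    sy = math.ceil((h - ph) / stride_xy) + 1
    sz = math.ceil((d - pd) / stride_z) + 1
    return sx * sy * sz

# e.g. a (160, 160, 96) volume with patch_size=(112, 112, 80),
# stride_xy=18, stride_z=4 yields 4 * 4 * 5 = 80 windows.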
def debugger():
    patch_size = (112, 112, 80)
    training_data = data_loader(split='train')
    testing_data = data_loader(split='test')

    x_criterion = soft_cross_entropy  # supervised loss is 0.5 * (x_criterion + dice_loss)
    u_criterion = nn.MSELoss()        # unsupervised loss

    labelled_index = np.random.permutation(LABELLED_INDEX)
    unlabelled_index = np.random.permutation(UNLABELLED_INDEX)[:len(labelled_index)]

    labelled_data = [training_data[i] for i in labelled_index]
    unlabelled_data = [training_data[i] for i in unlabelled_index]  # size = 16

    # data transformation: rotation, flip, random crop
    labelled_data = [
        shape_transform(RandomRotFlip()(RandomCrop(patch_size)(sample)))
        for sample in labelled_data
    ]
    unlabelled_data = [
        shape_transform(RandomRotFlip()(RandomCrop(patch_size)(sample)))
        for sample in unlabelled_data
    ]

    # supervised-only pass from a purely supervised checkpoint
    net = VNet(n_channels=1, n_classes=2, normalization='batchnorm',
               has_dropout=True).cuda()
    model_path = "../saved/0_supervised.pth"
    net.load_state_dict(torch.load(model_path))
    optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001)
    training_loss = train_epoch(net=net, labelled_data=labelled_data,
                                unlabelled_data=unlabelled_data, batch_size=2,
                                supervised_only=True, optimizer=optimizer,
                                x_criterion=x_criterion, u_criterion=u_criterion,
                                K=1, T=1, alpha=1, mixup_mode="__",
                                Lambda=0, aug_factor=0)

    # semi-supervised pass from a later checkpoint
    net = VNet(n_channels=1, n_classes=2, normalization='batchnorm',
               has_dropout=True).cuda()
    model_path = "../saved/8_expected_supervised.pth"
    net.load_state_dict(torch.load(model_path))
    optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001)
    training_loss = train_epoch(net=net, labelled_data=labelled_data,
                                unlabelled_data=unlabelled_data, batch_size=2,
                                supervised_only=False, optimizer=optimizer,
                                x_criterion=x_criterion, u_criterion=u_criterion,
                                K=1, T=1, alpha=1, mixup_mode="__",
                                Lambda=0, aug_factor=0)
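# The debugger above relies on the repo's soft_cross_entropy. A minimal sketch
# of what such a criterion typically computes (an assumption; the actual
# implementation in this codebase may differ in reduction or shape handling):
import torch
import torch.nn.functional as F

def soft_cross_entropy_sketch(logits, soft_targets):
    # Cross-entropy against soft (probabilistic) targets along the class
    # dimension: -sum_c q_c * log p_c, averaged over the remaining positions.
    log_probs = F.log_softmax(logits, dim=1)
    return torch.mean(torch.sum(-soft_targets * log_probs, dim=1))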
def experiment(exp_identifier, max_epoch, training_data, testing_data,
               batch_size=2, supervised_only=False, K=2, T=0.5, alpha=1,
               mixup_mode='all', Lambda=1, Lambda_ramp=None, base_lr=0.01,
               change_lr=None, aug_factor=1, from_saved=None,
               always_do_validation=True, decay=0):
    '''
    max_epoch: epochs to run. Going through the labelled data once is one epoch.
    batch_size: batch size of labelled data. Unlabelled data is of the same size.
    training_data: data for train_epoch, list of dicts of numpy arrays.
    testing_data: data for validation, list of dicts of numpy arrays.
    supervised_only: if True, only do supervised training on LABELLED_INDEX;
        otherwise, use both LABELLED_INDEX and UNLABELLED_INDEX.

    Hyperparameters
    ---------------
    K: repeats of each unlabelled data point
    T: temperature of sharpening
    alpha: mixup hyperparameter of the beta distribution
    mixup_mode: how mixup is performed --
        '__': no mixup
        'ww': x and u both mixed up with w(x+u)
        'xx': both with x
        'xu': x with x, u with u
        'uu': both with u
        ... (_ means no, x means with x, u means with u, w means with w(x+u))
    Lambda: loss = loss_x + Lambda * loss_u, relative weight of the unsupervised loss
    Lambda_ramp: callable or None. Lambda is ignored if this is not None;
        in that case, Lambda = Lambda_ramp(epoch).
    base_lr: initial learning rate
    change_lr: dict, {epoch: change_multiplier}; multiply the learning rate at the given epochs
    '''
    print(f"Experiment {exp_identifier}: max_epoch = {max_epoch}, batch_size = {batch_size}, "
          f"supervised_only = {supervised_only}, K = {K}, T = {T}, alpha = {alpha}, "
          f"mixup_mode = {mixup_mode}, Lambda = {Lambda}, Lambda_ramp = {Lambda_ramp}, "
          f"base_lr = {base_lr}, aug_factor = {aug_factor}.")

    net = VNet(n_channels=1, n_classes=2, normalization='batchnorm', has_dropout=True)
    eval_net = VNet(n_channels=1, n_classes=2, normalization='batchnorm', has_dropout=True)
    if from_saved is not None:
        net.load_state_dict(torch.load(from_saved))
    if GPU:
        net = net.cuda()
        eval_net.cuda()

    # eval_net is not updated by the optimizer
    for param in eval_net.parameters():
        param.detach_()

    net.train()
    eval_net.train()

    optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)

    x_criterion = soft_cross_entropy  # supervised loss is 0.5 * (x_criterion + dice_loss)
    u_criterion = nn.MSELoss()        # unsupervised loss

    training_losses = []
    testing_losses = []
    testing_accuracy = []  # dice accuracy

    patch_size = (112, 112, 80)
    testing_data = [
        shape_transform(CenterCrop(patch_size)(sample)) for sample in testing_data
    ]

    t0 = time.time()
    lr = base_lr
    for epoch in range(max_epoch):
        labelled_index = np.random.permutation(LABELLED_INDEX)
        unlabelled_index = np.random.permutation(UNLABELLED_INDEX)[:len(labelled_index)]

        labelled_data = [training_data[i] for i in labelled_index]
        unlabelled_data = [training_data[i] for i in unlabelled_index]  # size = 16

        # data transformation: rotation, flip, random crop
        labelled_data = [
            shape_transform(RandomRotFlip()(RandomCrop(patch_size)(sample)))
            for sample in labelled_data
        ]
        unlabelled_data = [
            shape_transform(RandomRotFlip()(RandomCrop(patch_size)(sample)))
            for sample in unlabelled_data
        ]

        if Lambda_ramp is not None:
            Lambda = Lambda_ramp(epoch)
            print(f"Lambda ramp: Lambda = {Lambda}")

        if change_lr is not None and epoch in change_lr:
            lr_ = lr * change_lr[epoch]
            print(f"Learning rate decay at epoch {epoch}, from {lr} to {lr_}")
            lr = lr_
            # change the learning rate in place
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_

        training_loss = train_epoch(net=net, eval_net=eval_net,
                                    labelled_data=labelled_data,
                                    unlabelled_data=unlabelled_data,
                                    batch_size=batch_size,
                                    supervised_only=supervised_only,
                                    optimizer=optimizer,
                                    x_criterion=x_criterion,
                                    u_criterion=u_criterion,
                                    K=K, T=T, alpha=alpha,
                                    mixup_mode=mixup_mode, Lambda=Lambda,
                                    aug_factor=aug_factor, decay=decay)
        training_losses.append(training_loss)

        if always_do_validation or epoch % 50 == 0:
            testing_dice_loss, accuracy = validation(net=net,
                                                     testing_data=testing_data,
                                                     x_criterion=x_criterion)
            testing_losses.append(testing_dice_loss)
            testing_accuracy.append(accuracy)
            print(f"Epoch {epoch+1}/{max_epoch}, time used: {time.time()-t0:.2f}, "
                  f"training loss: {training_loss:.6f}, "
                  f"testing dice_loss: {testing_dice_loss:.6f}, "
                  f"testing accuracy: {100.0*accuracy:.2f}%")

    save_path = f"../saved/{exp_identifier}.pth"
    torch.save(net.state_dict(), save_path)
    print(f"Experiment {exp_identifier} finished. Model saved as {save_path}")

    return training_losses, testing_losses, testing_accuracy
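# Hedged usage sketch for experiment(): the hyperparameter values below are
# illustrative only, and data_loader is assumed to behave as in debugger()
# above. Lambda_ramp linearly ramps the unsupervised weight over the first
# 40 epochs, and change_lr multiplies the learning rate by 0.1 at epochs
# 100 and 150, following the docstring's {epoch: change_multiplier} convention.
train_data = data_loader(split='train')
test_data = data_loader(split='test')
losses_train, losses_test, acc = experiment(
    exp_identifier='demo_mixup',
    max_epoch=200,
    training_data=train_data,
    testing_data=test_data,
    batch_size=2,
    K=2, T=0.5, alpha=1,
    mixup_mode='xu',
    Lambda_ramp=lambda epoch: min(1.0, epoch / 40.0),
    change_lr={100: 0.1, 150: 0.1},
    base_lr=0.01)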