def test_calculate_metric(epoch_num, patch_size=(128, 128, 64), stride_xy=64,
                          stride_z=32, device='cuda'):
    """Evaluate the VNet checkpoint saved at iteration ``epoch_num``.

    Builds a VNet, restores its weights from
    ``<snapshot_path>/iter_<epoch_num>.pth``, switches it to eval mode and
    runs ``test_all_case`` over ``image_list`` with the given sliding-window
    settings. Predictions are saved under ``test_save_path``.

    Relies on module-level ``num_classes``, ``name_classes``,
    ``snapshot_path``, ``image_list`` and ``test_save_path``.

    Args:
        epoch_num: Iteration number embedded in the checkpoint filename.
        patch_size: Sliding-window patch size forwarded to ``test_all_case``.
        stride_xy: In-plane stride forwarded to ``test_all_case``.
        stride_z: z stride forwarded to ``test_all_case``.
        device: Torch device for the network; the checkpoint tensors are
            mapped onto it as well.

    Returns:
        The value returned by ``test_all_case``.
    """
    net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm',
               has_dropout=False).to(device)
    save_mode_path = os.path.join(snapshot_path, 'iter_' + str(epoch_num) + '.pth')
    print(save_mode_path)
    # map_location keeps a checkpoint saved from GPU loadable when `device`
    # is something else (e.g. 'cpu'); without it torch.load restores tensors
    # onto the device they were saved from.
    net.load_state_dict(torch.load(save_mode_path, map_location=device))
    print("init weight from {}".format(save_mode_path))
    net.eval()
    metrics = test_all_case(net, image_list, num_classes=num_classes,
                            name_classes=name_classes, patch_size=patch_size,
                            stride_xy=stride_xy, stride_z=stride_z,
                            save_result=True, test_save_path=test_save_path,
                            device=device)
    return metrics
def test_calculate_metric(epoch_num):
    """Score the VNet weights saved at iteration ``epoch_num`` on ``image_list``.

    Loads ``<snapshot_path>/iter_<epoch_num>.pth`` into a freshly built VNet
    on the GPU, puts the network in eval mode and runs ``test_all_case`` with
    a fixed (112, 112, 80) window, xy-stride 18 and z-stride 4. Predictions
    are saved under ``test_save_path``. The ``detail`` and ``nms`` options
    are read from ``FLAGS``.

    Depends on module-level ``num_classes``, ``snapshot_path``,
    ``image_list``, ``test_save_path`` and ``FLAGS``.

    Returns:
        The value returned by ``test_all_case``.
    """
    # NOTE(review): the network is built with num_classes - 1 output channels
    # while test_all_case still receives num_classes -- presumably a
    # single-channel foreground head; confirm against the training script.
    model = VNet(
        n_channels=1,
        n_classes=num_classes - 1,
        normalization='batchnorm',
        has_dropout=False,
    ).cuda()
    checkpoint = os.path.join(snapshot_path, 'iter_' + str(epoch_num) + '.pth')
    weights = torch.load(checkpoint)
    model.load_state_dict(weights)
    print("init weight from {}".format(checkpoint))
    model.eval()

    window = {'patch_size': (112, 112, 80), 'stride_xy': 18, 'stride_z': 4}
    return test_all_case(
        model,
        image_list,
        num_classes=num_classes,
        save_result=True,
        test_save_path=test_save_path,
        metric_detail=FLAGS.detail,
        nms=FLAGS.nms,
        **window
    )
def test_calculate_metric(model_path, patch_size=(128, 128, 64), stride_xy=64,
                          stride_z=32, device='cuda'):
    """Evaluate a fully pickled model stored at ``model_path``.

    The file is expected to contain a whole ``nn.Module``, not just a state
    dict, because the loaded object is used directly as the network. The
    model is switched to eval mode and passed to ``test_all_case`` over
    ``image_list`` with the given sliding-window settings. Predictions are
    saved under ``test_save_path``.

    Relies on module-level ``num_classes``, ``name_classes``, ``image_list``
    and ``test_save_path``.

    Args:
        model_path: Path to the serialized model.
        patch_size: Sliding-window patch size forwarded to ``test_all_case``.
        stride_xy: In-plane stride forwarded to ``test_all_case``.
        stride_z: z stride forwarded to ``test_all_case``.
        device: Torch device the loaded model's tensors are mapped onto.

    Returns:
        The value returned by ``test_all_case``.
    """
    # NOTE(review): this unpickles an entire model, which can execute
    # arbitrary code -- only load files from trusted sources. On PyTorch
    # versions where torch.load defaults to weights_only=True this call
    # will need weights_only=False.
    # map_location makes the previously unused `device` argument take effect.
    net = torch.load(model_path, map_location=device)
    net.eval()
    # NOTE(review): `device` is not forwarded to test_all_case here; confirm
    # that helper moves its inputs to the same device as the model.
    metrics = test_all_case(net, image_list, num_classes=num_classes,
                            name_classes=name_classes, patch_size=patch_size,
                            stride_xy=stride_xy, stride_z=stride_z,
                            save_result=True, test_save_path=test_save_path)
    return metrics
def test_calculate_metric():
    """Evaluate the VNet weights whose file path is ``FLAGS.model``.

    ``FLAGS.model`` is used directly as the path of a state-dict file. That
    file is loaded into a GPU VNet, the net is switched to eval mode, and it
    is passed to ``test_all_case`` with a (112, 112, 80) window, xy-stride 18
    and z-stride 4. Predictions are saved under ``test_save_path``.

    Depends on module-level ``FLAGS``, ``num_classes``, ``image_list`` and
    ``test_save_path``.

    Returns:
        The value returned by ``test_all_case``.
    """
    vnet_config = dict(n_channels=1, n_classes=num_classes,
                       normalization='batchnorm', has_dropout=False)
    model = VNet(**vnet_config).cuda()
    weights_file = FLAGS.model
    model.load_state_dict(torch.load(weights_file))
    print("init weight from {}".format(weights_file))
    model.eval()
    return test_all_case(model, image_list, num_classes=num_classes,
                         patch_size=(112, 112, 80), stride_xy=18, stride_z=4,
                         save_result=True, test_save_path=test_save_path)
def Inference(FLAGS):
    """Evaluate an experiment's best 3D U-Net checkpoint on the test list.

    Restores ``../model/<exp>/<model>/<model>_best_model.pth`` into a
    ``unet_3D`` (1 input channel, 2 classes) on the GPU and recreates an
    empty prediction directory ``../model/<exp>/<model>_Prediction``. It then
    runs ``test_all_case`` over ``test.txt`` under ``FLAGS.root_path`` with
    a (96, 96, 96) window and stride 64.

    Args:
        FLAGS: Parsed options; reads ``exp``, ``model`` and ``root_path``.

    Returns:
        The value returned by ``test_all_case``.
    """
    snapshot_path = "../model/{}/{}".format(FLAGS.exp, FLAGS.model)
    num_classes = 2
    # The prediction folder is keyed on FLAGS.exp, like snapshot_path, so
    # the rmtree below never wipes the outputs of a different experiment.
    test_save_path = "../model/{}/{}_Prediction".format(FLAGS.exp, FLAGS.model)
    # Start from an empty directory so stale predictions are not mixed in.
    if os.path.exists(test_save_path):
        shutil.rmtree(test_save_path)
    os.makedirs(test_save_path)

    net = unet_3D(n_classes=num_classes, in_channels=1).cuda()
    save_mode_path = os.path.join(
        snapshot_path, '{}_best_model.pth'.format(FLAGS.model))
    net.load_state_dict(torch.load(save_mode_path))
    print("init weight from {}".format(save_mode_path))
    net.eval()
    avg_metric = test_all_case(net, base_dir=FLAGS.root_path, method=FLAGS.model,
                               test_list="test.txt", num_classes=num_classes,
                               patch_size=(96, 96, 96), stride_xy=64, stride_z=64,
                               test_save_path=test_save_path)
    return avg_metric