def visualize(args, path):
    """Run a trained Probabilistic U-Net on the test split and save image grids.

    For each test sample (until ``testset.iteration`` reaches
    ``args.test_iteration``) the input, the ground truth and the rounded
    network output are concatenated along the batch dimension and written to
    ``./save/<task_dir>prediction<iteration>.png``.

    Args:
        args: parsed CLI namespace; reads ``args.task`` (dataset task
            directory) and ``args.test_iteration`` (number of test samples).
        path: path of a saved ``state_dict`` for the model.
    """
    model = ProbabilisticUnet(input_channels=1, num_classes=1,
                              num_filters=[32, 64, 128, 192], latent_dim=2,
                              no_convs_fcomb=4, beta=10.0)
    model.to(device)
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
    model.load_state_dict(torch.load(path, map_location=device))
    task_dir = args.task
    testset = MedicalDataset(task_dir=task_dir, mode='test')
    model.eval()
    with torch.no_grad():
        while testset.iteration < args.test_iteration:
            x, y = testset.next()
            # Add a leading batch axis and put the inputs on the model's device
            # (previously hard-coded .cuda(), which mismatched a CPU `device`).
            x = torch.from_numpy(x).unsqueeze(0).to(device)
            y = torch.from_numpy(y).unsqueeze(0).to(device)
            # NOTE(review): this relies on forward() returning a segmentation
            # tensor, whereas train() calls forward() only for its side effects
            # and takes the loss from elbo(). Confirm forward() returns the
            # prediction in this ProbabilisticUnet (and whether training=False
            # plus sample() was intended for test-time visualisation).
            output = model.forward(x, y, training=True)
            output = torch.round(output)
            print(x.size(), y.size(), output.size())
            # Rows of the saved grid: input, ground truth, prediction.
            grid = torch.cat((x, y, output), dim=0)
            torchvision.utils.save_image(
                grid,
                './save/' + testset.task_dir + 'prediction' + str(testset.iteration) + '.png',
                nrow=8, padding=2, pad_value=1)
# Build an evaluation loader over a subset of `dataset` and load a Probabilistic U-Net for inference.
# NOTE(review): fragment — `dataset_size`, `shuffle`, `num_test_samples`,
# `dataset`, `LOAD_MODEL_FROM` and `device` are defined before this chunk.
indices = list(range(dataset_size))
split = int(np.floor(0.1 * dataset_size))  # NOTE(review): only used by the commented-out split below
if shuffle:
    np.random.shuffle(indices)
# train_indices, test_indices = indices[split:], indices[:split]
# Evaluate on the first `num_test_samples` indices (after the optional shuffle).
eval_indices = indices[:num_test_samples]
# eval_indices = indices
# SubsetRandomSampler visits the chosen indices in random order.
eval_sampler = SubsetRandomSampler(eval_indices)
eval_loader = DataLoader(dataset, batch_size=1, sampler=eval_sampler)
print("Number of test patches:", (len(eval_indices)))
# model
net = ProbabilisticUnet(input_channels=1, num_classes=1, num_filters=[32, 64, 128, 192], latent_dim=2,
                        no_convs_fcomb=4, beta=10.0)
if LOAD_MODEL_FROM is not None:
    import os
    # Checkpoint file name is resolved relative to ./saved_checkpoints/.
    # NOTE(review): torch.load without map_location fails for a GPU-saved
    # checkpoint on a CPU-only host — consider map_location=device.
    net.load_state_dict(
        torch.load(os.path.join("./saved_checkpoints/", LOAD_MODEL_FROM)))
net.to(device)
net.eval()


def energy_distance(seg_samples, gt_seg_modes, num_samples=2):
    # NOTE(review): body continues beyond this chunk; presumably computes the
    # generalised energy distance between sampled segmentations and the
    # ground-truth modes — confirm against the full implementation.
    num_modes = 4  # fixed for LIDC
# Set up an unshuffled 90/10 train/validation split, the network, the optimizer and the LR schedule.
# NOTE(review): fragment — `dataset`, `batch_size_train`, `batch_size_val`,
# `lr`, `l2_reg`, `lr_decay_every`, `lr_decay` and `epochs` are defined
# earlier; the epoch loop body continues past this chunk.
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(0.1 * dataset_size))  # 10% of the samples are held out
#np.random.shuffle(indices)
# NOTE(review): this message says the initial portion is used for training,
# but the split below assigns the FIRST 10% (indices[:split]) to validation
# and the remaining 90% to training — confirm which is intended.
print('There is no random shuffle: initial portion of the dataset is used for train and the last portion for validation')
train_indices, test_indices = indices[split:], indices[:split]
# SubsetRandomSampler still shuffles *within* each subset on every pass.
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)
train_loader = DataLoader(dataset, batch_size=batch_size_train, sampler=train_sampler)
test_loader = DataLoader(dataset, batch_size=batch_size_val, sampler=test_sampler)
print("Number of training/test patches:", (len(train_indices),len(test_indices)))
# network
net = ProbabilisticUnet(input_channels=1, num_classes=1, num_filters=[32,64,128,192], latent_dim=2, no_convs_fcomb=4, beta=10.0)
net.cuda()
# optimizer
# Adam's weight_decay adds an L2 penalty of strength `l2_reg`.
optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=l2_reg)
# StepLR multiplies the LR by `lr_decay` every `lr_decay_every` scheduler steps.
# NOTE(review): "secheduler" is a misspelling of "scheduler"; kept because
# code beyond this chunk may reference the name.
secheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_every, gamma=lr_decay)
# logging
train_loss = []
test_loss = []
best_val_loss = 999.0  # initial sentinel; presumably lowered by the first validation pass
for epoch in range(epochs):
    net.train()
    loss_train = 0
    loss_segmentation = 0
# Split the dataset randomly 90/10, load a trained Probabilistic U-Net and start iterating the test split.
# NOTE(review): fragment — `dataset_size`, `dataset`, `device` and
# `load_model` are defined outside this chunk; the loop body continues past it.
indices = list(range(dataset_size))
split = int(np.floor(0.1 * dataset_size))
# NOTE(review): no seed is visible here, so this held-out split will not
# match the split used during training unless seeded elsewhere — confirm.
np.random.shuffle(indices)
train_indices, test_indices = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)
train_loader = DataLoader(dataset, batch_size=5, sampler=train_sampler)
test_loader = DataLoader(dataset, batch_size=1, sampler=test_sampler)
print("Number of training/test patches:", (len(train_indices), len(test_indices)))
# Load the already-trained network for prediction
model = ProbabilisticUnet(input_channels=1, num_classes=1, num_filters=[32, 64, 128, 192], latent_dim=2,
                          no_convs_fcomb=4, beta=10.0)
net = load_model(model=model, path='model/unet_1.pt', device=device)
# NOTE(review): the optimizer is not used within this chunk.
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)
predict_time = 20  # number of predictions per image
### Show results
for step, (patch, mask, _) in enumerate(test_loader):
    if step == 0:
        results = []  # keeps the result of each prediction
        ## show the image
        label_np = mask.numpy()[0]
# Set up LIDC-IDRI with a random 90/10 train/test split and compute the Probabilistic U-Net training loss.
dataset = LIDC_IDRI(dataset_location='data/')
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(0.1 * dataset_size))
np.random.shuffle(indices)
train_indices, test_indices = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)
train_loader = DataLoader(dataset, batch_size=5, sampler=train_sampler)
test_loader = DataLoader(dataset, batch_size=1, sampler=test_sampler)
print("Number of training/test patches:", (len(train_indices), len(test_indices)))
net = ProbabilisticUnet(input_channels=1, num_classes=1, num_filters=[32, 64, 128, 192], latent_dim=2,
                        no_convs_fcomb=4, beta=10.0)
net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)
epochs = 10
for epoch in range(epochs):
    for step, (patch, mask, _) in enumerate(train_loader):
        patch = patch.to(device)
        mask = mask.to(device)
        mask = torch.unsqueeze(mask, 1)  # insert a new axis at dim 1 of the target
        net.forward(patch, mask, training=True)
        elbo = net.elbo(mask)
        reg_loss = l2_regularisation(net.posterior) + l2_regularisation(
            net.prior) + l2_regularisation(net.fcomb.layers)
        # Objective: negative ELBO plus a small L2 penalty on posterior, prior and fcomb.
        loss = -elbo + 1e-5 * reg_loss
        # NOTE(review): no optimizer.zero_grad() / loss.backward() /
        # optimizer.step() is visible in this chunk; if the snippet really ends
        # here the network is never updated — confirm the rest of the loop.
def func_2(save_inputs=False):
    """Save sampled segmentation predictions for one BrainS18 slice.

    Loads the slices in ``data/BrainS18/1_img`` in sequential order, restores
    ``model/unet_e100_p6_c9_ld6.pt`` and, for the slice at loader position 14,
    draws 20 predictions from the network. The per-pixel argmax class map of
    each prediction is saved as ``picture/a_func2_pre{i}.png``.

    Args:
        save_inputs: when True, also save the input image
            (``picture/a_func2_orgin.png``) and the ground-truth label map
            (``picture/a_func2_gt.png``). These plots were commented out in
            the original code; the default False keeps that behaviour.
    """
    # Parameters
    class_num = param.class_num  # number of segmentation classes
    latent_dim = 6  # latent-space dimension; must match the checkpoint
    test_batch_size = 1
    model_name = 'unet_e100_p6_c9_ld6.pt'  # checkpoint file under model/
    device = param.device
    target_step = 14  # loader position of the slice to visualise
    n_samples = 20  # number of predictions drawn for that slice

    # Dataset (alternative: folders=['1_Brats17_CBICA_AAB_1_img'])
    dataset = BrainS18Dataset(root_dir='data/BrainS18',
                              folders=['1_img'],
                              class_num=class_num,
                              file_names=['_reg_T1.png', '_segm.png'])
    # Sequential sampler: fixed order, so loader step == dataset index.
    dataset_size = len(dataset)
    test_indices = list(range(dataset_size))
    test_sampler = SequentialSampler(test_indices)
    test_loader = DataLoader(
        dataset, batch_size=test_batch_size, sampler=test_sampler)
    print("Number of test patches: {}".format(len(test_indices)))

    # Load the trained network for prediction.
    model = ProbabilisticUnet(input_channels=1, num_classes=class_num,
                              num_filters=[32, 64, 128, 192],
                              latent_dim=latent_dim, no_convs_fcomb=4,
                              beta=10.0)
    net = load_model(model=model, path='model/{}'.format(model_name),
                     device=device)

    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap('tab10', 10)  # 10 discrete colours

    def _save_borderless(array, filename, **imshow_kwargs):
        # Save `array` as a 1x1-inch, tick-free, margin-free PNG at 300 dpi.
        plt.figure(figsize=(1, 1))
        plt.imshow(array, aspect="auto", **imshow_kwargs)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.margins(0, 0)
        plt.savefig(filename, format='png', transparent=True, dpi=300,
                    pad_inches=0)
        plt.close()

    with torch.no_grad():
        for step, (patch, mask, series_uid) in enumerate(test_loader):
            if step != target_step:
                continue
            if save_inputs:
                # (batch_size, 1, 240, 240) -> (240, 240)
                image_np = patch.cpu().numpy().reshape(240, 240)
                # Labels are 1-10; shift to 0-9 into a new array. numpy()
                # shares memory with the tensor, so an in-place `-=` would
                # mutate `mask` (and the old code did it 20 times).
                label_np = mask.cpu().numpy().reshape(240, 240) - 1
                _save_borderless(image_np, 'picture/a_func2_orgin.png',
                                 cmap="gray")
                _save_borderless(label_np, 'picture/a_func2_gt.png',
                                 cmap=cmap, vmin=0, vmax=9)
            patch = patch.to(device)  # loop-invariant; moved once
            for i in range(n_samples):
                print("Picture {} (patient {} - slice {})...".format(
                    step, series_uid[0][0], series_uid[1][0]))
                net.forward(patch, None, training=False)
                # One sampled prediction; reshaped below to (class_num, 240, 240).
                mask_pre = net.sample(testing=True)
                mask_pre_np = mask_pre.cpu().numpy().reshape(
                    (class_num, 240, 240))
                # Predicted class per pixel = channel with the largest value (0-9).
                mask_pro = mask_pre_np.argmax(axis=0)
                _save_borderless(mask_pro,
                                 'picture/a_func2_pre{}.png'.format(i),
                                 cmap=cmap, vmin=0, vmax=9)
            break  # only one slice is needed; skip loading the rest
def func_4():
    """Plot 8 MRI slices with ground truth, predictive entropy and variance.

    For each selected slice of ``data/BrainS18/1_img`` the model
    ``model/punet_e128_c9_ld6_f1.pt`` is sampled ``predict_time`` times and
    ``cal_variance`` turns the samples into entropy and variance maps. The
    4 x 8 figure (image / ground truth / entropy / variance) is saved to
    ``picture/a_func_4_v4.png``.
    """
    selected_steps = (14, 15, 16, 17, 18, 19, 20, 21)
    # selected_steps = (0, 6, 12, 18, 24, 30, 36, 42)
    n_show = len(selected_steps)
    imgs = np.zeros((n_show, 240, 240))  # input images
    labels = np.zeros((n_show, 240, 240))  # ground-truth labels
    entropies = np.zeros((n_show, 240, 240))  # entropy maps
    variances = np.zeros((n_show, 240, 240))  # variance maps
    count = 0  # number of slices processed so far
    class_num = param.class_num  # number of segmentation classes
    predict_time = 16  # predictions per image (1, 4, 8 or 16)
    latent_dim = 6  # latent-space dimension; must match the checkpoint
    test_batch_size = 1
    model_name = 'punet_e128_c9_ld6_f1.pt'  # checkpoint file under model/
    device = param.device

    # Dataset (alternative: folders=['1_Brats17_CBICA_AAB_1_img'])
    dataset = BrainS18Dataset(root_dir='data/BrainS18',
                              folders=['1_img'],
                              class_num=class_num,
                              file_names=['_reg_T1.png', '_segm.png'])
    # Sequential sampler: fixed order, so loader step == dataset index.
    dataset_size = len(dataset)
    test_indices = list(range(dataset_size))
    test_sampler = SequentialSampler(test_indices)
    test_loader = DataLoader(dataset, batch_size=test_batch_size,
                             sampler=test_sampler)
    print("Number of test patches: {}".format(len(test_indices)))

    # Load the trained network for prediction.
    model = ProbabilisticUnet(input_channels=1, num_classes=class_num,
                              num_filters=[32, 64, 128, 192],
                              latent_dim=latent_dim, no_convs_fcomb=4,
                              beta=10.0)
    net = load_model(model=model, path='model/{}'.format(model_name),
                     device=device)

    with torch.no_grad():
        for step, (patch, mask, series_uid) in enumerate(test_loader):
            if step not in selected_steps:
                continue
            print("Picture {} (patient {} - slice {})...".format(
                step, series_uid[0][0], series_uid[1][0]))
            # (batch_size, 1, 240, 240) -> (240, 240)
            image_np = patch.numpy().reshape(240, 240)
            # Labels are 1-10; shift to 0-9 into a new array instead of
            # mutating `mask` through the memory numpy() shares with it.
            label_np = mask.numpy().reshape(240, 240) - 1
            imgs[count] = image_np
            labels[count] = label_np

            patch = patch.to(device)  # loop-invariant; moved once
            mask_pros = []  # argmax class map of each prediction
            mask_pres = []  # softmax probabilities of each prediction
            for _ in range(predict_time):
                net.forward(patch, None, training=False)
                mask_pre = net.sample(testing=True)
                # Softmax over the class axis (dim=1), then drop the batch axis.
                p_value = F.softmax(mask_pre, dim=1)
                mask_pres.append(
                    p_value.cpu().numpy().reshape((class_num, 240, 240)))
                # Predicted class per pixel = channel with the largest value (0-9).
                mask_pre_np = mask_pre.cpu().numpy().reshape(
                    (class_num, 240, 240))
                mask_pros.append(mask_pre_np.argmax(axis=0))

            # Compute entropy / variance maps (cal_variance also receives the
            # image and label; what it writes to disk is defined elsewhere).
            entropy, variance_result = cal_variance(
                image_np, label_np, mask_pros, mask_pres, class_num,
                series_uid)
            entropies[count] = entropy
            variances[count] = variance_result
            count += 1
            if count == n_show:
                break  # all selected slices done; skip the rest of the data

    fig, ax = plt.subplots(4, n_show, sharey=True, figsize=(20, 10))
    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap('tab10', 10)  # 10 discrete colours
    for i in range(n_show):
        ax[0][i].imshow(imgs[i], aspect="auto", cmap="gray")
        ax[1][i].imshow(labels[i], cmap=cmap, aspect="auto", vmin=0, vmax=9)
        ax[2][i].imshow(entropies[i], aspect="auto", cmap="jet", vmin=0, vmax=2)
        ax[3][i].imshow(variances[i], aspect="auto", cmap="jet")
    # Hide ticks on every subplot; plt.gca() only reached the last one.
    for axis in ax.flat:
        axis.xaxis.set_major_locator(plt.NullLocator())
        axis.yaxis.set_major_locator(plt.NullLocator())
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    fig.savefig('picture/a_func_4_v4.png', format='png', transparent=True,
                dpi=300, pad_inches=0)
    plt.close(fig)
def func_3():
    """Plot 8 MRI slices with their ground truth and one sampled prediction.

    Uses the BrainS18 folder ``1_Brats17_CBICA_AAB_1_img`` and the model
    ``model/punet_e128_c9_ld6_f1.pt``. For the slices at loader positions
    0, 6, ..., 42 a 3 x 8 figure (image / ground truth / argmax of one network
    sample) is saved to ``picture/a_func_3_v2.png``.
    """
    # Parameters
    class_num = param.class_num  # number of segmentation classes
    selected_steps = (0, 6, 12, 18, 24, 30, 36, 42)
    # selected_steps = (14, 15, 16, 17, 18, 19, 20, 21)
    n_show = len(selected_steps)
    model_name = 'punet_e128_c9_ld6_f1.pt'  # checkpoint file under model/
    device = param.device

    # Dataset (alternative: folders=['1_img'])
    dataset = BrainS18Dataset(root_dir='data/BrainS18',
                              folders=['1_Brats17_CBICA_AAB_1_img'],
                              class_num=class_num,
                              file_names=['_reg_T1.png', '_segm.png'])
    # Sequential sampler: fixed order, so loader step == dataset index.
    dataset_size = len(dataset)
    test_indices = list(range(dataset_size))
    test_sampler = SequentialSampler(test_indices)
    test_loader = DataLoader(dataset, batch_size=1, sampler=test_sampler)

    model = ProbabilisticUnet(input_channels=1, num_classes=class_num,
                              num_filters=[32, 64, 128, 192], latent_dim=6,
                              no_convs_fcomb=4, beta=10.0)
    net = load_model(model=model, path='model/{}'.format(model_name),
                     device=device)

    imgs = np.zeros((n_show, 240, 240))
    labels = np.zeros((n_show, 240, 240))
    predicts = np.zeros((n_show, 240, 240))
    count = 0
    with torch.no_grad():
        for step, (patch, mask, series_uid) in enumerate(test_loader):
            if step not in selected_steps:
                continue
            print("Picture {} (patient {} - slice {})...".format(
                step, series_uid[0][0], series_uid[1][0]))
            # Record image and label from the same batch that is predicted.
            # The old code re-read them through dataset.__getitem__ in a second
            # loop over a hard-coded range(48). (batch, 1, 240, 240) -> (240, 240)
            imgs[count] = patch.cpu().numpy().reshape(240, 240)
            # Labels are 1-10; shift to 0-9 without mutating the source tensor.
            labels[count] = mask.cpu().numpy().reshape(240, 240) - 1
            net.forward(patch.to(device), None, training=False)
            # One sampled prediction; reshaped below to (class_num, 240, 240).
            mask_pre = net.sample(testing=True)
            mask_pre_np = mask_pre.cpu().numpy().reshape((class_num, 240, 240))
            # Predicted class per pixel = channel with the largest value (0-9).
            predicts[count] = mask_pre_np.argmax(axis=0)
            count += 1
            if count == n_show:
                break  # all selected slices done; skip the rest of the data

    fig, ax = plt.subplots(3, n_show, sharey=True, figsize=(20, 7.5))
    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap('tab10', 10)  # 10 discrete colours
    for i in range(n_show):
        ax[0][i].imshow(imgs[i], aspect="auto", cmap="gray")
        ax[1][i].imshow(labels[i], cmap=cmap, aspect="auto", vmin=0, vmax=9)
        ax[2][i].imshow(predicts[i], cmap=cmap, aspect="auto", vmin=0, vmax=9)
    # Hide ticks on every subplot; plt.gca() only reached the last one.
    for axis in ax.flat:
        axis.xaxis.set_major_locator(plt.NullLocator())
        axis.yaxis.set_major_locator(plt.NullLocator())
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    fig.savefig('picture/a_func_3_v2.png', format='png', transparent=True,
                dpi=300, pad_inches=0)
    plt.close(fig)
# Train a Probabilistic U-Net on LIDC-IDRI for a fixed number of epochs,
# holding out a random 10% of the patches as a test split.
from utils import l2_regularisation

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

dataset = LIDC_IDRI(dataset_location='data/')
dataset_size = len(dataset)
indices = [*range(dataset_size)]
split = int(np.floor(dataset_size * 0.1))  # size of the held-out split
np.random.shuffle(indices)  # in place, using NumPy's global RNG
test_indices = indices[:split]
train_indices = indices[split:]

train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)
train_loader = DataLoader(dataset, sampler=train_sampler, batch_size=5)
test_loader = DataLoader(dataset, sampler=test_sampler, batch_size=1)
print("Number of training/test patches:",
      (len(train_indices), len(test_indices)))

net = ProbabilisticUnet(
    input_channels=1,
    num_classes=1,
    num_filters=[32, 64, 128, 192],
    latent_dim=2,
    no_convs_fcomb=4,
    beta=10.0,
)
net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)

epochs = 10
for epoch in range(epochs):
    for step, batch in enumerate(train_loader):
        patch, mask, _ = batch
        patch = patch.to(device)
        # Insert a new axis at dim 1 of the segmentation target.
        mask = mask.to(device).unsqueeze(1)
        net.forward(patch, mask, training=True)
        elbo = net.elbo(mask)
        reg_loss = (l2_regularisation(net.posterior)
                    + l2_regularisation(net.prior)
                    + l2_regularisation(net.fcomb.layers))
        # Objective: negative ELBO plus a light L2 weight penalty.
        loss = -elbo + 1e-5 * reg_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
# Build train/eval loaders and a per-layer Probabilistic U-Net, then resume training from the latest saved epoch.
# NOTE(review): fragment — `train_dataset`, `GAN_DATASET`, `LAYER`,
# `latent_dims_layer`, `fcomb_layer`, `MODELS_DEST` and the checkpoint helpers
# are defined outside this chunk; the epoch loop continues past it.
eval_dataset = GanDataset(GAN_DATASET['EVAL']['INPUT'], GAN_DATASET['EVAL']['GT'], LAYER)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)
eval_loader = DataLoader(eval_dataset, batch_size=1, shuffle=True, num_workers=0)
# INITIALISE NETWORKS
# Latent size and number of fcomb convolutions are looked up per feature layer.
net = ProbabilisticUnet(input_channels=256, num_classes=256, num_filters=[256, 512, 1024, 2048], latent_dim=latent_dims_layer[LAYER], no_convs_fcomb=fcomb_layer[LAYER], beta=10.0, layer=LAYER).cuda()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)
currEpoch = 0
max_epochs = 200
# If a saved epoch is found, load it into net/optimizer and start from the following epoch.
savedEpoch = getLatestModelEpoch(MODELS_DEST, LAYER)
# NOTE(review): if getLatestModelEpoch can return 0 for a real checkpoint,
# this truthiness test ignores it — consider `if savedEpoch is not None:`.
if savedEpoch:
    loadModel(net, optimizer, getModelFilePath(MODELS_DEST, LAYER, savedEpoch))
    currEpoch = savedEpoch + 1
for epoch in range(currEpoch, max_epochs):
    # TRAINING
    net.train()
# NOTE(review): fragment — this sys.exit() closes a block that starts before
# this chunk; its enclosing condition/indentation is not visible here.
sys.exit()
# Run inference for one FPN feature layer: load a checkpoint, sample an output for a saved input tensor, optionally save it.
# Per-layer hyperparameters: latent dimension and number of fcomb convolutions.
latent_dims_layer = {
    'fpn_res5_2_sum': 10,
    'fpn_res4_5_sum': 20,
    'fpn_res3_3_sum': 10,
    'fpn_res2_2_sum': 100
}
fcomb_layer = {
    'fpn_res5_2_sum': 8,
    'fpn_res4_5_sum': 4,
    'fpn_res3_3_sum': 8,
    'fpn_res2_2_sum': 4
}
LAYER = args.layer
net = ProbabilisticUnet(input_channels=256, num_classes=256, num_filters=[256, 512, 1024, 2048], latent_dim=latent_dims_layer[LAYER], no_convs_fcomb=fcomb_layer[LAYER], beta=10.0, layer=LAYER)
# The optimizer is only passed to loadModel here; it is never stepped in this chunk.
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)
loadModel(net, optimizer, args.model)
print("Reading input from:", args.input)
# map_location="cpu" loads the tensor onto the CPU whatever device it was saved from.
inp = torch.load(args.input, map_location="cpu")
# NOTE(review): no net.eval() / torch.no_grad() is visible; unless loadModel
# sets them, any dropout/batch-norm stays in training mode and autograd
# records the graph during inference.
net.forward(inp, training=False)
out = net.sample(testing=True)
if(args.output):
    print("Output saved to:", args.output)
    torch.save(out, args.output)
print(inp.shape, out.shape)
def train(args):
    """Train a Probabilistic U-Net on ``MedicalDataset`` and save its weights.

    Each epoch runs training steps until ``trainset.iteration`` reaches
    ``args.iteration``, then validation steps until ``validset.iteration``
    reaches ``args.test_iteration``; both counters are reset afterwards.
    The objective is ``-ELBO + 1e-5 * L2``, where the L2 term covers the
    posterior, prior and fcomb layers. After every epoch the mean training
    ELBO / L2 / loss and the mean validation loss are printed, and the
    ``state_dict`` is written to ``./save/<task_dir>model.pth``.

    Args:
        args: parsed CLI namespace; reads ``epoch``, ``learning_rate``,
            ``task``, ``iteration`` and ``test_iteration``.
    """
    num_epoch = args.epoch
    learning_rate = args.learning_rate
    task_dir = args.task
    trainset = MedicalDataset(task_dir=task_dir, mode='train')
    validset = MedicalDataset(task_dir=task_dir, mode='valid')
    model = ProbabilisticUnet(input_channels=1, num_classes=1,
                              num_filters=[32, 64, 128, 192], latent_dim=2,
                              no_convs_fcomb=4, beta=10.0)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                 weight_decay=0)

    def _to_batch(array):
        # numpy sample -> tensor with a leading batch axis on the model's
        # device (was hard-coded .cuda(), which breaks a CPU `device`).
        return torch.from_numpy(array).unsqueeze(0).to(device)

    def _objective(y):
        # Return (-ELBO + 1e-5 * L2, ELBO, L2) for the batch last passed to forward().
        elbo = model.elbo(y)
        reg_loss = (l2_regularisation(model.posterior)
                    + l2_regularisation(model.prior)
                    + l2_regularisation(model.fcomb.layers))
        return -elbo + 1e-5 * reg_loss, elbo, reg_loss

    for epoch in range(num_epoch):
        model.train()
        sum_loss = sum_elbo = sum_reg = 0.0
        n_train = 0
        while trainset.iteration < args.iteration:
            x, y = trainset.next()
            x, y = _to_batch(x), _to_batch(y)
            model.forward(x, y, training=True)
            loss, elbo, reg_loss = _objective(y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            sum_loss += loss.item()
            sum_elbo += elbo.item()
            sum_reg += reg_loss.item()
            n_train += 1
        trainset.iteration = 0

        model.eval()
        sum_valid = 0.0
        n_valid = 0
        with torch.no_grad():
            while validset.iteration < args.test_iteration:
                x, y = validset.next()
                x, y = _to_batch(x), _to_batch(y)
                # Same call as in training, so elbo(y) can be evaluated.
                model.forward(x, y, training=True)
                valid_loss, _, _ = _objective(y)
                sum_valid += valid_loss.item()
                n_valid += 1
        validset.iteration = 0

        # Report epoch means. The old code printed the elbo/regloss of the
        # last *validation* batch under training labels, and raised NameError
        # when either loop ran zero times; max(..., 1) guards that case.
        n_train, n_valid = max(n_train, 1), max(n_valid, 1)
        print('Epoch: {}, elbo: {:.4f}, regloss: {:.4f}, loss: {:.4f}, valid loss: {:.4f}'.format(
            epoch + 1, sum_elbo / n_train, sum_reg / n_train,
            sum_loss / n_train, sum_valid / n_valid))

        # Checkpoint every epoch so an interrupted run keeps its latest weights;
        # the file after the final epoch is the same as before.
        torch.save(model.state_dict(),
                   './save/' + trainset.task_dir + 'model.pth')
# Evaluate a pretrained Probabilistic U-Net on the unshuffled first 10% of LIDC-IDRI.
# NOTE(review): fragment — `batch_size_val`, `cpk_directory`,
# `save_batches_n` and `os` come from outside this chunk; the loop body
# continues past it.
dataset = LIDC_IDRI(dataset_location='data/')
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(0.1 * dataset_size))
# No shuffle here: the test set is the first 10% of the indices.
train_indices, test_indices = indices[split:], indices[:split]
test_sampler = SubsetRandomSampler(test_indices)
test_loader = DataLoader(dataset, batch_size=batch_size_val, sampler=test_sampler)
print("Number of test patches:", len(test_indices))
# network
net = ProbabilisticUnet(input_channels=1, num_classes=1, num_filters=[32, 64, 128, 192], latent_dim=2,
                        no_convs_fcomb=4, beta=10.0)
net.cuda()
# load pretrained model
cpk_name = os.path.join(cpk_directory, 'model_dict.pth')
# NOTE(review): no map_location; a checkpoint saved from another GPU index
# or device type may need map_location to load.
net.load_state_dict(torch.load(cpk_name))
net.eval()
with torch.no_grad():
    for step, (patch, mask, _) in enumerate(test_loader):
        # Only the first `save_batches_n` batches are processed.
        if step >= save_batches_n:
            break
        patch = patch.cuda()
        mask = mask.cuda()