Example 1
def visualize(args, path):
    model = ProbabilisticUnet(input_channels=1, num_classes=1, num_filters=[32,64,128,192], latent_dim=2, no_convs_fcomb=4, beta=10.0)
    model.to(device)
    model.load_state_dict(torch.load(path))
    task_dir = args.task
    
    testset = MedicalDataset(task_dir=task_dir, mode='test')
    testloader = data.DataLoader(testset, batch_size=1, shuffle=False)
    
    model.eval()
    with torch.no_grad():
        while testset.iteration < args.test_iteration:
            x, y = testset.next()
            x, y = torch.from_numpy(x).unsqueeze(0).cuda(), torch.from_numpy(y).unsqueeze(0).cuda()
            # forward() only encodes in this API; draw the prediction from the
            # prior and binarise it, as in the test-time pattern of Examples 6-8
            model.forward(x, None, training=False)
            output = torch.round(torch.sigmoid(model.sample(testing=True)))
            print(x.size(), y.size(), output.size())

            grid = torch.cat((x,y,output), dim=0)
            torchvision.utils.save_image(grid, './save/'+testset.task_dir+'prediction'+str(testset.iteration)+'.png', nrow=8, padding=2, pad_value=1)
Example 2
indices = list(range(dataset_size))
split = int(np.floor(0.1 * dataset_size))
if shuffle:
    np.random.shuffle(indices)
# train_indices, test_indices = indices[split:], indices[:split]

eval_indices = indices[:num_test_samples]
# eval_indices = indices
eval_sampler = SubsetRandomSampler(eval_indices)
eval_loader = DataLoader(dataset, batch_size=1, sampler=eval_sampler)
print("Number of test patches:", (len(eval_indices)))

# model
net = ProbabilisticUnet(input_channels=1,
                        num_classes=1,
                        num_filters=[32, 64, 128, 192],
                        latent_dim=2,
                        no_convs_fcomb=4,
                        beta=10.0)

if LOAD_MODEL_FROM is not None:
    import os
    net.load_state_dict(
        torch.load(os.path.join("./saved_checkpoints/", LOAD_MODEL_FROM)))

net.to(device)
net.eval()


def energy_distance(seg_samples, gt_seg_modes, num_samples=2):
    num_modes = 4  # fixed for LIDC
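
The energy_distance function above is cut off. Below is a minimal, self-contained sketch of the generalized energy distance it presumably computes, D2_GED = 2*E[d(S,Y)] - E[d(S,S')] - E[d(Y,Y')] with d = 1 - IoU; the iou_distance helper, the sketch's name, and the binary-mask inputs are assumptions, not part of the original:

import numpy as np

def iou_distance(a, b):
    # d(a, b) = 1 - IoU between two binary masks; two empty masks count as identical
    a, b = np.asarray(a, dtype=bool), np.asarray(b, dtype=bool)
    union = np.logical_or(a, b).sum()
    if union == 0:
        return 0.0
    return 1.0 - np.logical_and(a, b).sum() / union

def energy_distance_sketch(seg_samples, gt_seg_modes):
    # squared generalized energy distance between sampled and ground-truth segmentations
    d_sy = np.mean([iou_distance(s, y) for s in seg_samples for y in gt_seg_modes])
    d_ss = np.mean([iou_distance(s, s2) for s in seg_samples for s2 in seg_samples])
    d_yy = np.mean([iou_distance(y, y2) for y in gt_seg_modes for y2 in gt_seg_modes])
    return 2 * d_sy - d_ss - d_yy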
Example 3
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(0.1 * dataset_size))

#np.random.shuffle(indices)
print('No random shuffle: the first portion of the dataset is used for validation and the remainder for training')

train_indices, test_indices = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)
train_loader = DataLoader(dataset, batch_size=batch_size_train, sampler=train_sampler)
test_loader = DataLoader(dataset, batch_size=batch_size_val, sampler=test_sampler)
print("Number of training/test patches:", (len(train_indices),len(test_indices)))

# network
net = ProbabilisticUnet(input_channels=1, num_classes=1, num_filters=[32,64,128,192], latent_dim=2, no_convs_fcomb=4, beta=10.0)
net.cuda()

# optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=l2_reg)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_every, gamma=lr_decay)

# logging
train_loss = []
test_loss = []
best_val_loss = float('inf')  # sentinel; any real validation loss beats it

for epoch in range(epochs):
    net.train()
    loss_train = 0
    loss_segmentation = 0
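
Example 3's epoch loop is truncated above. A plausible body, following the training step of Example 9 and wired to the optimizer, scheduler, and loss_train accumulator declared earlier (a hedged sketch, not the original; what loss_segmentation accumulated is not recoverable, so it is left untouched):

    for step, (patch, mask, _) in enumerate(train_loader):
        patch = patch.cuda()
        mask = torch.unsqueeze(mask.cuda(), 1)
        net.forward(patch, mask, training=True)
        elbo = net.elbo(mask)
        reg_loss = l2_regularisation(net.posterior) + l2_regularisation(net.prior) + l2_regularisation(net.fcomb.layers)
        loss = -elbo + 1e-5 * reg_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_train += loss.item()
    scheduler.step()  # StepLR decay once per epoch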
Example 4
indices = list(range(dataset_size))
split = int(np.floor(0.1 * dataset_size))
np.random.shuffle(indices)
train_indices, test_indices = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)

train_loader = DataLoader(dataset, batch_size=5, sampler=train_sampler)
test_loader = DataLoader(dataset, batch_size=1, sampler=test_sampler)
print("Number of training/test patches:",
      (len(train_indices), len(test_indices)))

# load a trained network for prediction
model = ProbabilisticUnet(input_channels=1,
                          num_classes=1,
                          num_filters=[32, 64, 128, 192],
                          latent_dim=2,
                          no_convs_fcomb=4,
                          beta=10.0)
net = load_model(model=model, path='model/unet_1.pt', device=device)

optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)

predict_time = 20  # number of predictions per image

### display the results
for step, (patch, mask, _) in enumerate(test_loader):
    if step == 0:
        results = []  # store the result of each prediction

        ## show the image
        label_np = mask.numpy()[0]
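Example 5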
dataset = LIDC_IDRI(dataset_location='data/')
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(0.1 * dataset_size))
np.random.shuffle(indices)
train_indices, test_indices = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)
train_loader = DataLoader(dataset, batch_size=5, sampler=train_sampler)
test_loader = DataLoader(dataset, batch_size=1, sampler=test_sampler)
print("Number of training/test patches:",
      (len(train_indices), len(test_indices)))

net = ProbabilisticUnet(input_channels=1,
                        num_classes=1,
                        num_filters=[32, 64, 128, 192],
                        latent_dim=2,
                        no_convs_fcomb=4,
                        beta=10.0)
net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)
epochs = 10
for epoch in range(epochs):
    for step, (patch, mask, _) in enumerate(train_loader):
        patch = patch.to(device)
        mask = mask.to(device)
        mask = torch.unsqueeze(mask, 1)
        net.forward(patch, mask, training=True)
        elbo = net.elbo(mask)
        reg_loss = l2_regularisation(net.posterior) + l2_regularisation(
            net.prior) + l2_regularisation(net.fcomb.layers)
        loss = -elbo + 1e-5 * reg_loss
        # the snippet is truncated here; the optimizer step below is taken
        # from the identical training loop in Example 9
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Example 6
def func_2():
    """Plot the original image, the label, and the prediction."""
    # parameters
    class_num = param.class_num  # number of segmentation classes
    predict_time = 16  # predictions per image (1, 4, 8, or 16)
    latent_dim = 6  # latent-space dimension
    train_batch_size = 1  # batch size for prediction
    test_batch_size = 1  # batch size for prediction
    model_name = 'unet_e100_p6_c9_ld6.pt'  # checkpoint to load
    device = param.device  # GPU selection

    # choose the dataset
    dataset = BrainS18Dataset(root_dir='data/BrainS18',
                              folders=['1_img'],
                              class_num=class_num,
                              file_names=['_reg_T1.png', '_segm.png'])
    # dataset = BrainS18Dataset(root_dir='data/BrainS18', folders=['1_Brats17_CBICA_AAB_1_img'],
    #                           class_num=class_num,
    #                           file_names=['_reg_T1.png', '_segm.png'])

    # split the data and set up the sampler (fixed train/test split)
    dataset_size = len(dataset)  # dataset size
    test_indices = list(range(dataset_size))
    test_sampler = SequentialSampler(test_indices)
    # data loader
    test_loader = DataLoader(
        dataset, batch_size=test_batch_size, sampler=test_sampler)
    print("Number of test patches: {}".format(len(test_indices)))
    # load a trained network for prediction
    model = ProbabilisticUnet(input_channels=1,
                              num_classes=class_num,
                              num_filters=[32, 64, 128, 192],
                              latent_dim=latent_dim,
                              no_convs_fcomb=4,
                              beta=10.0)
    net = load_model(model=model,
                     path='model/{}'.format(model_name),
                     device=device)
    # prediction
    with torch.no_grad():
        for step, (patch, mask, series_uid) in enumerate(test_loader):
            if step == 14:
                for i in range(20):
                    print("Picture {} (patient {} - slice {})...".format(step,
                                                                        series_uid[0][0], series_uid[1][0]))
                    # keep numpy copies
                    # (batch_size,1,240,240) -> (1,240,240)
                    image_np = patch.cpu().numpy().reshape(240, 240)
                    label_np = mask.cpu().numpy().reshape(240, 240)  # (batch_size,1,240,240), values 1-10
                    label_np -= 1  # shift values from 1-10 to 0-9
                    # predict
                    patch = patch.to(device)
                    net.forward(patch, None, training=False)
                    # prediction, (batch_size,class_num,240,240)
                    mask_pre = net.sample(testing=True)
                    # torch tensor to numpy, (batch_size,class_num,240,240)
                    mask_pre_np = mask_pre.cpu().detach().numpy()
                    mask_pre_np = mask_pre_np.reshape((class_num, 240, 240))  # drop the batch dimension
                    ## for each pixel, the channel with the maximum value is the predicted class
                    # argmax over channels; single channel, values 0-9
                    mask_pro = mask_pre_np.argmax(axis=0)
                    # print(label_np.shape, image_np.shape, mask_pro.shape)
                    # original image
                    # plt.figure(figsize=(1, 1))
                    # plt.imshow(image_np, aspect="auto", cmap="gray")
                    # plt.gca().xaxis.set_major_locator(plt.NullLocator())
                    # plt.gca().yaxis.set_major_locator(plt.NullLocator())
                    # # plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0)
                    # plt.margins(0,0)
                    # plt.savefig('picture/a_func2_orgin.png', format='png', transparent=True, dpi=300, pad_inches = 0)
                    # plt.close()
                    # ground truth
                    # plt.figure(figsize=(1, 1))
                    # # 10 discrete colors,tab10,Paired
                    # cmap = plt.cm.get_cmap('tab10', 10)
                    # plt.imshow(label_np, cmap=cmap, aspect="auto", vmin=0, vmax=9)
                    # plt.gca().xaxis.set_major_locator(plt.NullLocator())
                    # plt.gca().yaxis.set_major_locator(plt.NullLocator())
                    # # plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0)
                    # plt.margins(0,0)
                    # plt.savefig('picture/a_func2_gt.png', format='png', transparent=True, dpi=300, pad_inches = 0)
                    # plt.close()
                    # prediction
                    plt.figure(figsize=(1, 1))
                    # 10 discrete colors (tab10, Paired)
                    cmap = plt.cm.get_cmap('tab10', 10)
                    plt.imshow(mask_pro, cmap=cmap, aspect="auto", vmin=0, vmax=9)
                    plt.gca().xaxis.set_major_locator(plt.NullLocator())
                    plt.gca().yaxis.set_major_locator(plt.NullLocator())
                    # plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
                    plt.margins(0, 0)
                    plt.savefig('picture/a_func2_pre{}.png'.format(i), format='png', transparent=True, dpi=300, pad_inches=0)
                    plt.close()
Example 7
def func_4():
    """Plot a subset of the MRI images, labels, and uncertainty results."""
    imgs = np.zeros((8, 240, 240))  # original images
    labels = np.zeros((8, 240, 240))  # labels
    entropies = np.zeros((8, 240, 240))  # entropy
    variances = np.zeros((8, 240, 240))  # variance
    count = 0  # counter

    class_num = param.class_num  # number of segmentation classes
    predict_time = 16  # predictions per image (1, 4, 8, or 16)
    latent_dim = 6  # latent-space dimension
    train_batch_size = 1  # batch size for prediction
    test_batch_size = 1  # batch size for prediction
    model_name = 'punet_e128_c9_ld6_f1.pt'  # checkpoint to load
    device = param.device  # GPU selection
    # choose the dataset
    dataset = BrainS18Dataset(root_dir='data/BrainS18',
                              folders=['1_img'],
                              class_num=class_num,
                              file_names=['_reg_T1.png', '_segm.png'])
    # dataset = BrainS18Dataset(root_dir='data/BrainS18', folders=['1_Brats17_CBICA_AAB_1_img'],
    #                         class_num=class_num,
    #                         file_names=['_reg_T1.png', '_segm.png'])

    # split the data and set up the sampler (fixed train/test split)
    dataset_size = len(dataset)  # dataset size
    test_indices = list(range(dataset_size))
    test_sampler = SequentialSampler(test_indices)

    # data loader
    test_loader = DataLoader(dataset, batch_size=test_batch_size, sampler=test_sampler)
    print("Number of test patches: {}".format(len(test_indices)))

    # load a trained network for prediction
    model = ProbabilisticUnet(input_channels=1,
                              num_classes=class_num,
                              num_filters=[32, 64, 128, 192],
                              latent_dim=latent_dim,
                              no_convs_fcomb=4,
                              beta=10.0)
    net = load_model(model=model,
                     path='model/{}'.format(model_name),
                     device=device)

    # prediction
    with torch.no_grad():
        for step, (patch, mask, series_uid) in enumerate(test_loader):
            # if step in (0,6,12,18,24,30,36,42):
            if step in (14,15,16,17,18,19,20,21):
                print("Picture {} (patient {} - slice {})...".format(step, series_uid[0][0], series_uid[1][0]))
                mask_pros = []  # per-prediction results (after the argmax)
                mask_pres = []  # per-prediction softmax probabilities
                # keep numpy copies
                image_np = patch.numpy().reshape(240, 240)  # (batch_size,1,240,240) -> (1,240,240)
                label_np = mask.numpy().reshape(240, 240)  # (batch_size,1,240,240), values 1-10
                label_np -= 1  # shift values from 1-10 to 0-9
                imgs[count] = image_np
                labels[count] = label_np

                # predict predict_time times to estimate the variance
                for i in range(predict_time):
                    patch = patch.to(device)
                    net.forward(patch, None, training=False)
                    mask_pre = net.sample(testing=True)  # prediction, (batch_size,class_num,240,240)

                    # record the softmaxed values
                    p_value = F.softmax(mask_pre, dim=1)
                    p_value = p_value.cpu().numpy().reshape((class_num, 240, 240))  # drop the batch dimension
                    mask_pres.append(p_value)

                    # torch tensor to numpy, (batch_size,class_num,240,240)
                    mask_pre_np = mask_pre.cpu().detach().numpy()
                    mask_pre_np = mask_pre_np.reshape((class_num, 240, 240))  # drop the batch dimension

                    ## for each pixel, the channel with the maximum value is the predicted class
                    mask_pro = mask_pre_np.argmax(axis=0)  # argmax over channels; single channel, values 0-9
                    mask_pros.append(mask_pro)

                # compute the entropy and variance, and save the corresponding images
                entropy, variance_result = cal_variance(image_np, label_np, mask_pros, mask_pres, class_num, series_uid)
                entropies[count] = entropy
                variances[count] = variance_result
                count += 1

    
    fig, ax = plt.subplots(4, 8, sharey=True, figsize=(20, 10))
    cmap = plt.cm.get_cmap('tab10', 10)  # 10 discrete colors (tab10, Paired)

    for i in range(8):
        ax[0][i].imshow(imgs[i], aspect="auto", cmap="gray")
        ax[1][i].imshow(labels[i], cmap=cmap, aspect="auto", vmin=0, vmax=9)
        ax[2][i].imshow(entropies[i], aspect="auto", cmap="jet", vmin=0, vmax=2)
        ax[3][i].imshow(variances[i], aspect="auto", cmap="jet")

    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    fig.savefig('picture/a_func_4_v4.png', format='png', transparent=True, dpi=300, pad_inches=0)
    plt.close()
Example 8
def func_3():
    """Plot a subset of the MRI images, labels, and a single prediction."""
    # parameters
    class_num = param.class_num  # number of segmentation classes

    # choose the dataset
    # dataset = BrainS18Dataset(root_dir='data/BrainS18', 
    #                         folders=['1_img'], 
    #                         class_num=class_num, 
    #                         file_names=['_reg_T1.png', '_segm.png'])
    dataset = BrainS18Dataset(root_dir='data/BrainS18', folders=['1_Brats17_CBICA_AAB_1_img'],
                              class_num=class_num,
                              file_names=['_reg_T1.png', '_segm.png'])

    # split the data and set up the sampler (fixed train/test split)
    model_name = 'punet_e128_c9_ld6_f1.pt'  # checkpoint to load
    device = param.device  # GPU selection
    dataset_size = len(dataset)  # dataset size
    test_indices = list(range(dataset_size))
    test_sampler = SequentialSampler(test_indices)
    test_loader = DataLoader(dataset, batch_size=1, sampler=test_sampler)
    model = ProbabilisticUnet(input_channels=1,
                              num_classes=class_num,
                              num_filters=[32, 64, 128, 192],
                              latent_dim=6,
                              no_convs_fcomb=4,
                              beta=10.0)
    net = load_model(model=model,
                     path='model/{}'.format(model_name),
                     device=device)


    imgs = np.zeros((8, 240, 240))
    labels = np.zeros((8, 240, 240))
    predicts = np.zeros((8, 240, 240))

    with torch.no_grad():
        count = 0
        for step, (patch, mask, series_uid) in enumerate(test_loader):
            if step in (0,6,12,18,24,30,36,42):
                print("Picture {} (patient {} - slice {})...".format(step,
                                                    series_uid[0][0], series_uid[1][0]))
                # keep numpy copies
                # (batch_size,1,240,240) -> (1,240,240)
                image_np = patch.cpu().numpy().reshape(240, 240)
                label_np = mask.cpu().numpy().reshape(240, 240)  # (batch_size,1,240,240), values 1-10
                label_np -= 1  # shift values from 1-10 to 0-9
                # predict
                patch = patch.to(device)
                net.forward(patch, None, training=False)
                # prediction, (batch_size,class_num,240,240)
                mask_pre = net.sample(testing=True)
                # torch tensor to numpy, (batch_size,class_num,240,240)
                mask_pre_np = mask_pre.cpu().detach().numpy()
                mask_pre_np = mask_pre_np.reshape((class_num, 240, 240))  # drop the batch dimension
                ## for each pixel, the channel with the maximum value is the predicted class
                # argmax over channels; single channel, values 0-9
                mask_pro = mask_pre_np.argmax(axis=0)
                predicts[count] = mask_pro
                count += 1

    count = 0
    for i in range(48):
        if i in (0,6,12,18,24,30,36,42):
        # if i in (14,15,16,17,18,19,20,21):
            image, label, series_uid = dataset[i]
            image = image.numpy().reshape(240, 240)
            label = label.numpy().reshape(240, 240) - 1  # shift values from 1-10 to 0-9
            imgs[count] = image
            labels[count] = label
            count += 1

    fig, ax = plt.subplots(3, 8, sharey=True, figsize=(20, 7.5))
    cmap = plt.cm.get_cmap('tab10', 10)  # 10 discrete colors (tab10, Paired)

    for i in range(8):
        ax[0][i].imshow(imgs[i], aspect="auto", cmap="gray")
        ax[1][i].imshow(labels[i], cmap=cmap, aspect="auto", vmin=0, vmax=9)
        ax[2][i].imshow(predicts[i], cmap=cmap, aspect="auto", vmin=0, vmax=9)

    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    fig.savefig('picture/a_func_3_v2.png', format='png', transparent=True, dpi=300, pad_inches=0)
    plt.close()
Example 9
from utils import l2_regularisation

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = LIDC_IDRI(dataset_location='data/')
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(0.1 * dataset_size))
np.random.shuffle(indices)
train_indices, test_indices = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)
train_loader = DataLoader(dataset, batch_size=5, sampler=train_sampler)
test_loader = DataLoader(dataset, batch_size=1, sampler=test_sampler)
print("Number of training/test patches:", (len(train_indices),len(test_indices)))

net = ProbabilisticUnet(input_channels=1, num_classes=1, num_filters=[32,64,128,192], latent_dim=2, no_convs_fcomb=4, beta=10.0)
net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)
epochs = 10
for epoch in range(epochs):
    for step, (patch, mask, _) in enumerate(train_loader): 
        patch = patch.to(device)
        mask = mask.to(device)
        mask = torch.unsqueeze(mask, 1)
        net.forward(patch, mask, training=True)
        elbo = net.elbo(mask)
        reg_loss = l2_regularisation(net.posterior) + l2_regularisation(net.prior) + l2_regularisation(net.fcomb.layers)
        loss = -elbo + 1e-5 * reg_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
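
Example 9 trains for ten epochs but never persists the weights; a closing save in the style of Example 12's torch.save call would look like this (the path is illustrative, not from the original):

torch.save(net.state_dict(), './save/lidc_punet_model.pth')  # assumed path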
Example 10
eval_dataset = GanDataset(GAN_DATASET['EVAL']['INPUT'],
                          GAN_DATASET['EVAL']['GT'], LAYER)
train_loader = DataLoader(train_dataset,
                          batch_size=1,
                          shuffle=True,
                          num_workers=0)
eval_loader = DataLoader(eval_dataset,
                         batch_size=1,
                         shuffle=True,
                         num_workers=0)

# INITIALISE NETWORKS
net = ProbabilisticUnet(input_channels=256,
                        num_classes=256,
                        num_filters=[256, 512, 1024, 2048],
                        latent_dim=latent_dims_layer[LAYER],
                        no_convs_fcomb=fcomb_layer[LAYER],
                        beta=10.0,
                        layer=LAYER).cuda()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)

currEpoch = 0
max_epochs = 200
savedEpoch = getLatestModelEpoch(MODELS_DEST, LAYER)
if savedEpoch:
    loadModel(net, optimizer, getModelFilePath(MODELS_DEST, LAYER, savedEpoch))
    currEpoch = savedEpoch + 1

for epoch in range(currEpoch, max_epochs):
    # TRAINING
    net.train()
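
Example 10 breaks off at the top of its training loop; the body presumably mirrors the ELBO plus L2-regularisation step of Example 9. A hedged sketch follows; that train_loader yields (input, gt) feature pairs and that l2_regularisation is in scope here are assumptions:

    for step, (inp, gt) in enumerate(train_loader):
        inp, gt = inp.cuda(), gt.cuda()
        net.forward(inp, gt, training=True)
        elbo = net.elbo(gt)
        reg_loss = l2_regularisation(net.posterior) + l2_regularisation(net.prior) + l2_regularisation(net.fcomb.layers)
        loss = -elbo + 1e-5 * reg_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()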
Example 11
	sys.exit()


latent_dims_layer = {
    'fpn_res5_2_sum': 10,
    'fpn_res4_5_sum': 20,
    'fpn_res3_3_sum': 10,
    'fpn_res2_2_sum': 100
}

fcomb_layer = {
    'fpn_res5_2_sum': 8,
    'fpn_res4_5_sum': 4,
    'fpn_res3_3_sum': 8,
    'fpn_res2_2_sum': 4
}
LAYER = args.layer

net = ProbabilisticUnet(input_channels=256, num_classes=256, num_filters=[256, 512, 1024, 2048], latent_dim=latent_dims_layer[LAYER], no_convs_fcomb=fcomb_layer[LAYER], beta=10.0, layer=LAYER)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)
loadModel(net, optimizer, args.model)
print("Reading input from:", args.input)
inp = torch.load(args.input, map_location="cpu")
net.forward(inp, None, training=False)  # no ground-truth mask at inference time, as in Examples 6-8
out = net.sample(testing=True)
if args.output:
	print("Output saved to:", args.output)
	torch.save(out, args.output)

print(inp.shape, out.shape)
Example 12
def train(args):
    num_epoch = args.epoch
    learning_rate = args.learning_rate
    task_dir = args.task
    
    trainset = MedicalDataset(task_dir=task_dir, mode='train')
    validset = MedicalDataset(task_dir=task_dir, mode='valid')

    model = ProbabilisticUnet(input_channels=1, num_classes=1, num_filters=[32,64,128,192], latent_dim=2, no_convs_fcomb=4, beta=10.0)
    model.to(device)
    #summary(model, (1,320,320))

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)
    criterion = torch.nn.BCELoss()

    for epoch in range(num_epoch):
        model.train()
        while trainset.iteration < args.iteration:
            x, y = trainset.next()
            x, y = torch.from_numpy(x).unsqueeze(0).cuda(), torch.from_numpy(y).unsqueeze(0).cuda()
            #print(x.size(), y.size())
            #output = torch.nn.Sigmoid()(model(x))
            model.forward(x,y,training=True)
            elbo = model.elbo(y)

            reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + l2_regularisation(model.fcomb.layers)
            loss = -elbo + 1e-5 * reg_loss
            #loss = criterion(output, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        trainset.iteration = 0

        model.eval()
        with torch.no_grad():
            while validset.iteration < args.test_iteration:
                x, y = validset.next()
                x, y = torch.from_numpy(x).unsqueeze(0).cuda(), torch.from_numpy(y).unsqueeze(0).cuda()
                #output = torch.nn.Sigmoid()(model(x, y))
                model.forward(x,y,training=True)
                elbo = model.elbo(y)

                reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + l2_regularisation(model.fcomb.layers)
                valid_loss = -elbo + 1e-5 * reg_loss
            validset.iteration = 0
                
        print('Epoch: {}, elbo: {:.4f}, regloss: {:.4f}, loss: {:.4f}, valid loss: {:.4f}'.format(epoch+1, elbo.item(), reg_loss.item(), loss.item(), valid_loss.item()))
        """
        #Logger
         # 1. Log scalar values (scalar summary)
        info = { 'loss': loss.item(), 'accuracy': valid_loss.item() }

        for tag, value in info.items():
            Logger.scalar_summary(tag, value, epoch+1)

        # 2. Log values and gradients of the parameters (histogram summary)
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            Logger.histo_summary(tag, value.data.cpu().numpy(), epoch+1)
            Logger.histo_summary(tag+'/grad', value.grad.data.cpu().numpy(), epoch+1)
        """
    torch.save(model.state_dict(), './save/'+trainset.task_dir+'model.pth')
Example 13
dataset = LIDC_IDRI(dataset_location='data/')
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(0.1 * dataset_size))

train_indices, test_indices = indices[split:], indices[:split]
test_sampler = SubsetRandomSampler(test_indices)
test_loader = DataLoader(dataset,
                         batch_size=batch_size_val,
                         sampler=test_sampler)
print("Number of test patches:", len(test_indices))

# network
net = ProbabilisticUnet(input_channels=1,
                        num_classes=1,
                        num_filters=[32, 64, 128, 192],
                        latent_dim=2,
                        no_convs_fcomb=4,
                        beta=10.0)
net.cuda()

# load pretrained model
cpk_name = os.path.join(cpk_directory, 'model_dict.pth')
net.load_state_dict(torch.load(cpk_name))

net.eval()
with torch.no_grad():
    for step, (patch, mask, _) in enumerate(test_loader):
        if step >= save_batches_n:
            break
        patch = patch.cuda()
        mask = mask.cuda()
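        # The snippet breaks off here. A plausible continuation, following the
        # prior-sampling pattern of Examples 6-8 and the image dump of Example 1
        # (a hedged sketch: the torchvision import and output path are assumptions)
        net.forward(patch, None, training=False)
        pred = torch.round(torch.sigmoid(net.sample(testing=True)))  # binarise the logits
        grid = torch.cat((patch, mask.unsqueeze(1), pred), dim=0)
        torchvision.utils.save_image(grid, 'prediction_{}.png'.format(step), nrow=batch_size_val, padding=2, pad_value=1)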