Example #1
0
def visualize(args, path):
    """Run a trained Probabilistic U-Net over the test split and save image grids.

    For each test sample until ``testset.iteration`` reaches
    ``args.test_iteration``, the input, the ground-truth mask and the rounded
    network output are concatenated along dim 0 and written as one PNG under
    ``./save/`` (the directory is not created here).

    Args:
        args: parsed command-line namespace; reads ``args.task`` (dataset task
            directory) and ``args.test_iteration`` (iteration limit).
        path: filesystem path of a ``state_dict`` checkpoint for the model.
    """
    model = ProbabilisticUnet(input_channels=1, num_classes=1,
                              num_filters=[32, 64, 128, 192], latent_dim=2,
                              no_convs_fcomb=4, beta=10.0)
    model.to(device)
    # map_location lets a checkpoint saved on a GPU be restored onto whatever
    # ``device`` is configured (e.g. a CPU-only machine).
    model.load_state_dict(torch.load(path, map_location=device))
    task_dir = args.task

    testset = MedicalDataset(task_dir=task_dir, mode='test')

    model.eval()
    with torch.no_grad():
        while testset.iteration < args.test_iteration:
            x, y = testset.next()
            # Add a leading batch dimension and move to the same device as the
            # model (hard-coding .cuda() broke runs where device is not CUDA).
            x = torch.from_numpy(x).unsqueeze(0).to(device)
            y = torch.from_numpy(y).unsqueeze(0).to(device)
            # NOTE(review): training=True conditions on the ground truth ``y``;
            # confirm that is intended for visualisation and that ``forward``
            # returns a tensor here — TODO confirm against ProbabilisticUnet.
            output = model.forward(x, y, training=True)
            # NOTE(review): rounding only thresholds at 0.5 if ``output`` is
            # already a probability; an earlier version applied Sigmoid first.
            output = torch.round(output)
            print(x.size(), y.size(), output.size())

            grid = torch.cat((x, y, output), dim=0)
            out_file = ('./save/' + testset.task_dir + 'prediction'
                        + str(testset.iteration) + '.png')
            torchvision.utils.save_image(grid, out_file, nrow=8, padding=2,
                                         pad_value=1)
Example #2
0
# model: Probabilistic U-Net with the same hyper-parameters used for training
net = ProbabilisticUnet(input_channels=1,
                        num_classes=1,
                        num_filters=[32, 64, 128, 192],
                        latent_dim=2,
                        no_convs_fcomb=4,
                        beta=10.0)

if LOAD_MODEL_FROM is not None:
    import os
    # map_location lets a checkpoint saved on a GPU load on the configured
    # ``device`` (including CPU-only machines); the net is moved there below.
    net.load_state_dict(
        torch.load(os.path.join("./saved_checkpoints/", LOAD_MODEL_FROM),
                   map_location=device))

net.to(device)
# Inference only: put the module (and submodules) into evaluation mode.
net.eval()


def energy_distance(seg_samples, gt_seg_modes, num_samples=2):
    num_modes = 4  # fixed for LIDC

    # if num_samples != len(seg_samples) or num_samples != len(gt_seg_modes):
    #     raise ValueError

    d_matrix_YS = np.zeros(shape=(num_modes, num_samples), dtype=np.float32)
    d_matrix_YY = np.zeros(shape=(num_modes, num_modes), dtype=np.float32)
    d_matrix_SS = np.zeros(shape=(num_samples, num_samples), dtype=np.float32)

    # iterate all ground-truth modes
    for mode in range(num_modes):
Example #3
0
def _pu_batch_to_device(x, y):
    """Turn a numpy ``(x, y)`` pair into tensors with a leading batch dim on ``device``."""
    return (torch.from_numpy(x).unsqueeze(0).to(device),
            torch.from_numpy(y).unsqueeze(0).to(device))


def _pu_loss(model, x, y):
    """Run one Probabilistic U-Net pass and return ``(elbo, reg_loss, loss)``.

    ``loss`` is the negative ELBO plus a 1e-5-weighted L2 penalty on the
    posterior, prior and fcomb layers; shared by training and validation.
    """
    model.forward(x, y, training=True)
    elbo = model.elbo(y)
    reg_loss = (l2_regularisation(model.posterior)
                + l2_regularisation(model.prior)
                + l2_regularisation(model.fcomb.layers))
    return elbo, reg_loss, -elbo + 1e-5 * reg_loss


def _epoch_mean(values):
    """Arithmetic mean of ``values``, or NaN when the list is empty."""
    return sum(values) / len(values) if values else float('nan')


def train(args):
    """Train a Probabilistic U-Net and save its weights.

    Each epoch runs training batches until ``trainset.iteration`` reaches
    ``args.iteration``, then validation batches (without gradients) until
    ``validset.iteration`` reaches ``args.test_iteration``. It prints the
    per-epoch mean training ELBO, regularisation term and loss, plus the mean
    validation loss (NaN if a phase ran no batches). Finally, the
    ``state_dict`` is written to ``'./save/' + trainset.task_dir + 'model.pth'``.

    Args:
        args: parsed command-line namespace; reads ``epoch``,
            ``learning_rate``, ``task``, ``iteration`` and ``test_iteration``.
    """
    num_epoch = args.epoch
    learning_rate = args.learning_rate
    task_dir = args.task

    trainset = MedicalDataset(task_dir=task_dir, mode='train')
    validset = MedicalDataset(task_dir=task_dir, mode='valid')

    model = ProbabilisticUnet(input_channels=1, num_classes=1,
                              num_filters=[32, 64, 128, 192], latent_dim=2,
                              no_convs_fcomb=4, beta=10.0)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)

    for epoch in range(num_epoch):
        model.train()
        train_elbos, train_regs, train_losses = [], [], []
        while trainset.iteration < args.iteration:
            x, y = _pu_batch_to_device(*trainset.next())
            elbo, reg_loss, loss = _pu_loss(model, x, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_elbos.append(elbo.item())
            train_regs.append(reg_loss.item())
            train_losses.append(loss.item())
        trainset.iteration = 0

        model.eval()
        valid_losses = []
        with torch.no_grad():
            while validset.iteration < args.test_iteration:
                x, y = _pu_batch_to_device(*validset.next())
                _, _, valid_loss = _pu_loss(model, x, y)
                valid_losses.append(valid_loss.item())
        validset.iteration = 0

        # Previously "elbo"/"regloss" came from the last validation batch while
        # "loss" came from the last training batch; report consistent epoch means.
        print('Epoch: {}, elbo: {:.4f}, regloss: {:.4f}, loss: {:.4f}, valid loss: {:.4f}'.format(
            epoch + 1, _epoch_mean(train_elbos), _epoch_mean(train_regs),
            _epoch_mean(train_losses), _epoch_mean(valid_losses)))
        # TODO: scalar and parameter/gradient histogram logging (Logger) was
        # disabled here; re-add it using the epoch means above if needed.

    torch.save(model.state_dict(), './save/' + trainset.task_dir + 'model.pth')