Example no. 1
def visualize(args, path):
    model = ProbabilisticUnet(input_channels=1, num_classes=1, num_filters=[32,64,128,192], latent_dim=2, no_convs_fcomb=4, beta=10.0)
    model.to(device)
    model.load_state_dict(torch.load(path))
    task_dir = args.task
    
    testset = MedicalDataset(task_dir=task_dir, mode='test')
    testloader = data.DataLoader(testset, batch_size=1, shuffle=False)
    
    model.eval()
    with torch.no_grad():
        while testset.iteration < args.test_iteration:
            x, y = testset.next()
            x, y = torch.from_numpy(x).unsqueeze(0).cuda(), torch.from_numpy(y).unsqueeze(0).cuda()
            # forward() only builds the latent-space nets; the prediction itself
            # is drawn with sample() and thresholded via sigmoid + round
            model.forward(x, y, training=False)
            output = torch.round(torch.sigmoid(model.sample(testing=True)))
            print(x.size(), y.size(), output.size())

            grid = torch.cat((x,y,output), dim=0)
            torchvision.utils.save_image(grid, './save/'+testset.task_dir+'prediction'+str(testset.iteration)+'.png', nrow=8, padding=2, pad_value=1)
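Because the Probabilistic U-Net is a generative model, repeated calls to sample() draw different plausible segmentations for the same input. A hypothetical variant of the loop body above that saves several samples per image to show that diversity (the sample count of 4 and the 'samples' filename are illustrative, not from the original):

            # draw several masks from the prior; each sample() uses a fresh latent z
            model.forward(x, y, training=False)
            samples = [torch.round(torch.sigmoid(model.sample(testing=True)))
                       for _ in range(4)]
            grid = torch.cat([x, y] + samples, dim=0)
            torchvision.utils.save_image(
                grid,
                './save/' + testset.task_dir + 'samples' + str(testset.iteration) + '.png',
                nrow=6, padding=2, pad_value=1)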
Example no. 2
# logging
train_loss = []
test_loss = []
best_val_loss = float('inf')  # sentinel: any real validation loss will be lower

for epoch in range(epochs):
    net.train()
    loss_train = 0
    loss_segmentation = 0
    # training loop
    for step, (patch, mask, _) in enumerate(train_loader): 
        patch = patch.cuda()
        mask = mask.cuda()
        mask = torch.unsqueeze(mask,1)
        net.forward(patch, mask, training=True)
        elbo = net.elbo(mask)
        loss = -elbo
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        loss_train += loss.detach().cpu().item()
        
        if step % 100 == 0:
            print(f'[Ep {epoch+1}, step {step+1} of {len(train_loader)}] train loss: {loss_train/(step+1):.4f}')
        
    # end of training loop
    loss_train /= len(train_loader)
    
    # validation loop
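The example breaks off at this comment. A minimal sketch of how the validation loop plausibly continues, mirroring the training loop without gradients; val_loader and the checkpoint path are assumptions, while train_loss, test_loss, and best_val_loss are the trackers declared at the top of the example:

    net.eval()
    loss_val = 0
    with torch.no_grad():
        for patch, mask, _ in val_loader:  # val_loader: assumed analogue of train_loader
            patch = patch.cuda()
            mask = torch.unsqueeze(mask.cuda(), 1)
            net.forward(patch, mask, training=True)  # posterior is needed for elbo()
            loss_val += -net.elbo(mask).item()
    loss_val /= len(val_loader)

    # bookkeeping with the trackers declared at the top
    train_loss.append(loss_train)
    test_loss.append(loss_val)
    if loss_val < best_val_loss:
        best_val_loss = loss_val
        torch.save(net.state_dict(), 'best_model.pth')  # illustrative checkpoint path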
Example no. 3
train_losses = {
    'rec': [],
    'kl': [],
    'l2pos': [],
    'l2pri': [],
    'l2fcom': [],
    'total': []
}
train_targetLosses = []
train_count = 0
for idx, data in enumerate(train_loader):
    # print("Epoch:", epoch, "idx:", idx)
    inp = data["input"][0].cuda()
    gt = data["gt"][0].cuda()
    targetLoss = torch.nn.L1Loss()(inp, gt)
    print("Target Loss:", targetLoss.item())
    # Skip NaN batches; extremely important to protect against initial KL collapse
    if torch.isnan(targetLoss):
        continue
    net.forward(inp, gt, training=True)
    reconLoss, klLoss = net.elbo(gt)
    elbo = -(reconLoss + 10.0 * klLoss)
    l2posterior = l2_regularisation(net.posterior)
    l2prior = l2_regularisation(net.prior)
    l2fcomb = l2_regularisation(net.fcomb.layers)
    reg_loss = l2posterior + l2prior + l2fcomb
    loss = -elbo + 1e-5 * reg_loss
    # Skip batches whose loss explodes
    if loss.item() > 100000:
        continue
    print("Total Loss: ", loss.item())
    train_losses['rec'].append(reconLoss.item())
    train_losses['kl'].append(klLoss.item())
    train_losses['l2pos'].append(l2posterior.item())
    train_losses['l2pri'].append(l2prior.item())
    train_losses['l2fcom'].append(l2fcomb.item())
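The excerpt stops before any optimization happens. Assuming the usual Adam optimizer from the other examples, the loop plausibly continues with the backward pass and the bookkeeping for the 'total' key, train_targetLosses, and train_count declared but unused above:

    # sketch: backpropagate the combined loss and finish the per-batch bookkeeping
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    train_losses['total'].append(loss.item())
    train_targetLosses.append(targetLoss.item())
    train_count += 1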
Example no. 4
	sys.exit()


latent_dims_layer = {
    'fpn_res5_2_sum': 10,
    'fpn_res4_5_sum': 20,
    'fpn_res3_3_sum': 10,
    'fpn_res2_2_sum': 100
}

fcomb_layer = {
    'fpn_res5_2_sum': 8,
    'fpn_res4_5_sum': 4,
    'fpn_res3_3_sum': 8,
    'fpn_res2_2_sum': 4
}
LAYER = args.layer

net = ProbabilisticUnet(input_channels=256, num_classes=256, num_filters=[256, 512, 1024, 2048], latent_dim=latent_dims_layer[LAYER], no_convs_fcomb=fcomb_layer[LAYER], beta=10.0, layer=LAYER)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)
loadModel(net, optimizer, args.model)
print("Reading input from:", args.input)
inp = torch.load(args.input, map_location="cpu")
net.forward(inp, training=False)
out = net.sample(testing=True)
if args.output:
    torch.save(out, args.output)
    print("Output saved to:", args.output)

print(inp.shape, out.shape)
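For context, a minimal sketch of the argument parsing this script assumes (in the script it would run before LAYER = args.layer); the flag names are inferred from the args.layer, args.model, args.input, and args.output accesses above, not taken from the original source:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--layer', choices=list(latent_dims_layer), required=True,
                    help='FPN feature map whose activations are modelled')
parser.add_argument('--model', required=True, help='checkpoint restored via loadModel()')
parser.add_argument('--input', required=True, help='torch-saved feature tensor to feed the net')
parser.add_argument('--output', default=None, help='optional path for the sampled output tensor')
args = parser.parse_args()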
Example no. 5
def train(args):
    num_epoch = args.epoch
    learning_rate = args.learning_rate
    task_dir = args.task
    
    trainset = MedicalDataset(task_dir=task_dir, mode='train')
    validset = MedicalDataset(task_dir=task_dir, mode='valid')

    model = ProbabilisticUnet(input_channels=1, num_classes=1, num_filters=[32,64,128,192], latent_dim=2, no_convs_fcomb=4, beta=10.0)
    model.to(device)
    #summary(model, (1,320,320))

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)
    criterion = torch.nn.BCELoss()  # only used by the commented-out BCE baseline below

    for epoch in range(num_epoch):
        model.train()
        while trainset.iteration < args.iteration:
            x, y = trainset.next()
            x, y = torch.from_numpy(x).unsqueeze(0).cuda(), torch.from_numpy(y).unsqueeze(0).cuda()
            #print(x.size(), y.size())
            #output = torch.nn.Sigmoid()(model(x))
            model.forward(x, y, training=True)
            elbo = model.elbo(y)

            reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + l2_regularisation(model.fcomb.layers)
            loss = -elbo + 1e-5 * reg_loss
            #loss = criterion(output, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        trainset.iteration = 0

        model.eval()
        with torch.no_grad():
            while validset.iteration < args.test_iteration:
                x, y = validset.next()
                x, y = torch.from_numpy(x).unsqueeze(0).cuda(), torch.from_numpy(y).unsqueeze(0).cuda()
                #output = torch.nn.Sigmoid()(model(x, y))
                model.forward(x, y, training=True)  # training=True builds the posterior needed by elbo()
                elbo = model.elbo(y)

                reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + l2_regularisation(model.fcomb.layers)
                valid_loss = -elbo + 1e-5 * reg_loss
            validset.iteration = 0
                
        print('Epoch: {}, elbo: {:.4f}, regloss: {:.4f}, loss: {:.4f}, valid loss: {:.4f}'.format(epoch+1, elbo.item(), reg_loss.item(), loss.item(), valid_loss.item()))
        """
        #Logger
         # 1. Log scalar values (scalar summary)
        info = { 'loss': loss.item(), 'accuracy': valid_loss.item() }

        for tag, value in info.items():
            Logger.scalar_summary(tag, value, epoch+1)

        # 2. Log values and gradients of the parameters (histogram summary)
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            Logger.histo_summary(tag, value.data.cpu().numpy(), epoch+1)
            Logger.histo_summary(tag+'/grad', value.grad.data.cpu().numpy(), epoch+1)
        """
    torch.save(model.state_dict(), './save/'+trainset.task_dir+'model.pth')
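All of the examples lean on an l2_regularisation helper that is never shown. In the reference Probabilistic U-Net PyTorch implementation it sums the L2 norms of a module's parameters; a sketch along those lines:

def l2_regularisation(m):
    # accumulate the L2 norm of every parameter tensor in the module
    l2_reg = None
    for W in m.parameters():
        if l2_reg is None:
            l2_reg = W.norm(2)
        else:
            l2_reg = l2_reg + W.norm(2)
    return l2_reg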