Example 1

import torch

# Latent-space dimensionality used for each FPN output layer.
latent_dims_layer = {
    'fpn_res5_2_sum': 10,
    'fpn_res4_5_sum': 20,
    'fpn_res3_3_sum': 10,
    'fpn_res2_2_sum': 100
}

# Number of convs in the f_comb block for each FPN output layer.
fcomb_layer = {
    'fpn_res5_2_sum': 8,
    'fpn_res4_5_sum': 4,
    'fpn_res3_3_sum': 8,
    'fpn_res2_2_sum': 4
}
LAYER = args.layer  # args comes from an argparse setup; see the sketch after this example.

net = ProbabilisticUnet(
    input_channels=256,
    num_classes=256,
    num_filters=[256, 512, 1024, 2048],
    latent_dim=latent_dims_layer[LAYER],
    no_convs_fcomb=fcomb_layer[LAYER],
    beta=10.0,
    layer=LAYER,
)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)
loadModel(net, optimizer, args.model)
print("Reading input from:", args.input)
inp = torch.load(args.input, map_location="cpu")
net.forward(inp, training=False)
out = net.sample(testing=True)
if(args.output):
	print("Output saved to:", args.output)
	torch.save(out, args.output)

print(inp.shape, out.shape)
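
The snippet assumes an `args` namespace and a `loadModel` helper that are defined elsewhere (`ProbabilisticUnet` would come from the project's own module). A minimal sketch of what they might look like: the flag names match the attributes used above, while the checkpoint layout with 'model' and 'optimizer' keys is purely an assumption.

import argparse

import torch

def loadModel(net, optimizer, path):
    # Hypothetical checkpoint layout: {'model': state_dict, 'optimizer': state_dict}.
    checkpoint = torch.load(path, map_location='cpu')
    net.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])

parser = argparse.ArgumentParser()
parser.add_argument('--layer', choices=list(latent_dims_layer))
parser.add_argument('--model', help='checkpoint path')
parser.add_argument('--input', help='input tensor serialized with torch.save')
parser.add_argument('--output', help='optional path for the sampled output')
args = parser.parse_args()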
Example 2
import time

import torch
import torch.nn.functional as F

train_loss, val_loss = [], []
dice_score_train, dice_score_val = [], []


for epoch in range(epochs):
    running_train_loss = []
    running_train_score = []
    print('Epoch {}/{}'.format(epoch + 1, epochs))
    started = time.time()
          
    for batch_idx, (data, target) in enumerate(train_loader):
        elbo = net.elbo(target.to(device), data.to(device))
        reg_loss = (l2_regularisation(net._prior)
                    + l2_regularisation(net._posterior)
                    + l2_regularisation(net._f_comb))
        loss = -elbo + 1e-5 * reg_loss
        # dim=1 softmaxes over the class channel; batch_dice is sketched
        # after this example.
        score = batch_dice(F.softmax(net.sample(data, mean=False), dim=1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_train_loss.append(loss.item())
        running_train_score.append(score.item())
        print('batch loss: {}, batch score: {}, batch_idx: {}'.format(
            loss.item(), score.item(), batch_idx))
    else:
        # The for/else body runs once the training loop finishes without break.
        running_val_loss = []
        running_val_score = []

        with torch.no_grad():
            for data, target in val_loader:
                # The snippet is truncated here; validation mirrors the
                # training computation, without the backward pass.
                elbo = net.elbo(target.to(device), data.to(device))
                loss = -elbo
                score = batch_dice(F.softmax(net.sample(data, mean=False), dim=1))
                running_val_loss.append(loss.item())
                running_val_score.append(score.item())
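
`l2_regularisation` and `batch_dice` are project helpers the snippet does not show. Below are hedged sketches of what such helpers typically look like in Probabilistic U-Net training code; note the snippet calls `batch_dice` with predictions only, so the real helper presumably captures the ground truth elsewhere, while this sketch takes both tensors explicitly.

def l2_regularisation(module):
    # Sum of L2 norms of all parameters of a sub-network.
    l2_reg = None
    for w in module.parameters():
        l2_reg = w.norm(2) if l2_reg is None else l2_reg + w.norm(2)
    return l2_reg

def batch_dice(probs, targets, eps=1e-6):
    # Soft Dice score averaged over the batch; assumes probs and targets
    # share shape (N, C, H, W), with targets one-hot along dim 1.
    dims = (1, 2, 3)
    intersection = (probs * targets).sum(dim=dims)
    union = probs.sum(dim=dims) + targets.sum(dim=dims)
    return ((2 * intersection + eps) / (union + eps)).mean()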
Example 3
import os

import matplotlib.pyplot as plt
import torch

cpk_name = os.path.join(cpk_directory, 'model_dict.pth')
net.load_state_dict(torch.load(cpk_name))

net.eval()
with torch.no_grad():
    for step, (patch, mask, _) in enumerate(test_loader):
        if step >= save_batches_n:
            break
        patch = patch.cuda()
        mask = mask.cuda()
        mask = torch.unsqueeze(mask, 1)
        output_samples = []
        for i in range(samples_per_example):
            # forward() with training=True conditions the latent space on the
            # ground-truth mask (posterior); each sample() draws a new latent.
            net.forward(patch, mask, training=True)
            output_samples.append(
                torch.sigmoid(net.sample()).detach().cpu().numpy())

        for k in range(patch.shape[0]):  # for every item in the batch
            patch_out = patch[k, 0, :, :].detach().cpu().numpy()
            mask_out = mask[k, 0, :, :].detach().cpu().numpy()
            plt.figure()

            plt.subplot(3, 2, 1)
            plt.imshow(patch_out)
            plt.title('patch')
            plt.axis('off')
            plt.subplot(3, 2, 2)
            plt.imshow(mask_out)
            plt.title('GT Mask')
            plt.axis('off')
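
The example is cut off before the remaining subplot slots are filled. A plausible continuation, purely as a sketch (the mean/variance panels and the save path are assumptions, not part of the original; it also assumes numpy is imported as np), aggregates the collected output_samples for item k inside the same loop:

            # Stack to (samples_per_example, N, 1, H, W), then select item k.
            samples = np.stack(output_samples)[:, k, 0, :, :]

            plt.subplot(3, 2, 3)
            plt.imshow(samples.mean(axis=0))
            plt.title('mean prediction')
            plt.axis('off')

            plt.subplot(3, 2, 4)
            plt.imshow(samples.var(axis=0))
            plt.title('pixel-wise variance')
            plt.axis('off')

            plt.savefig(os.path.join(cpk_directory, 'sample_{}_{}.png'.format(step, k)))
            plt.close()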