Esempio n. 1
0
def predict(trainer, test_loader, save_dir=None):
    """Run the trainer's model on every test batch and plot each sample.

    For each batch the model is applied once, the output is passed through a
    sigmoid, and one figure per sample is shown with three panels: the raw
    data, the ground truth and the prediction.

    Args:
        trainer: Object providing ``eval_mode()`` and ``apply_model(batch)``.
        test_loader: Iterable yielding ``(image, target)`` batches.
        save_dir: Accepted for interface compatibility; not used by this
            function at present.
    """
    trainer.eval_mode()
    for image, target in test_loader:

        # transfer image to gpu
        image = image.cuda() if USE_CUDA else image

        # get batch size from image
        batch_size = image.size()[0]

        # Run inference once per batch. Doing this inside the per-sample loop
        # re-evaluated the whole batch for every sample and, from the second
        # sample on, handed the model the already-converted numpy array.
        prediction = torch.sigmoid(trainer.apply_model(image))

        # Convert everything to numpy exactly once, after inference.
        image = unwrap(image, as_numpy=True, to_cpu=True)
        prediction = unwrap(prediction, as_numpy=True, to_cpu=True)
        target = unwrap(target, as_numpy=True, to_cpu=True)

        for b in range(batch_size):
            fig = plt.figure()

            ax = fig.add_subplot(2, 2, 1)
            ax.imshow(image[b, 0, ...])
            ax.set_title('raw data')

            ax = fig.add_subplot(2, 2, 2)
            ax.imshow(target[b, ...])
            ax.set_title('ground truth')

            # Slot 3 of the 2x2 grid is intentionally left empty.
            ax = fig.add_subplot(2, 2, 4)
            ax.imshow(prediction[b, ...])
            ax.set_title('prediction')

            fig.tight_layout()
            plt.show()
Esempio n. 2
0
 def save_losses(self, losses):
     """Store the mean of each loss in the trainer state.

     Returns without doing anything when no trainer is attached, or when
     ``logging_enabled`` is false and ``enable_logging`` is false as well.
     When ``logging_enabled`` is false but ``enable_logging`` is true, the
     trainer's logger is registered before the losses are stored.

     Args:
         losses: Iterable of objects supporting ``.detach().mean()``.
     """
     if self.trainer is None:
         return
     if not self.logging_enabled:
         if not self.enable_logging:
             return
         self.register_logger(self.trainer.logger)

     # Reduce every loss first, so nothing is written if a reduction fails.
     mean_losses = [item.detach().mean() for item in losses]
     for position, mean_loss in enumerate(mean_losses):
         state_key = self.get_loss_name(position)
         self.trainer.update_state(state_key, thu.unwrap(mean_loss))
Esempio n. 3
0
# Bind the train and validation loaders to the trainer under the names
# 'train' and 'validate', then call eval_mode() before predicting.
trainer.bind_loader('train', train_loader)
trainer.bind_loader('validate', validate_loader)
trainer.eval_mode()

# Only call trainer.cuda() when CUDA use is switched on
# (presumably moves the model to the GPU -- confirm against the trainer API).
if USE_CUDA:
    trainer.cuda()

# look at an example
# Walk over the test batches; each batch yields the input and its target.
for img,target in test_loader:
    # Keep the input on the same device as the trainer.
    if USE_CUDA:
        img = img.cuda()

    # softmax on each of the prediction
    # apply_model's result is iterated here, i.e. it is treated as a
    # collection of predictions; softmax is taken along dim=1 of each one.
    preds = trainer.apply_model(img)
    preds = [nn.functional.softmax(pred,dim=1)        for pred in preds]
    # Convert predictions, input and target to numpy arrays for plotting.
    preds = [unwrap(pred, as_numpy=True, to_cpu=True) for pred in preds]
    img    = unwrap(img,  as_numpy=True, to_cpu=True)
    target  = unwrap(target, as_numpy=True, to_cpu=True)

    # Number of predictions plus two
    # (presumably one panel each for the image and the target -- confirm).
    n_plots = len(preds) + 2
    # Batch size is read from the leading axis of the first prediction.
    batch_size = preds[0].shape[0]

    # One figure per sample in the batch.
    for b in range(batch_size):

        fig = pylab.figure()

        # First slot of a 2x4 grid: the input, index 0 on the second axis
        # (presumably the first channel -- confirm).
        ax1 = fig.add_subplot(2,4,1)
        ax1.set_title('image')
        ax1.imshow(img[b,0,...])

        # Second slot of the 2x4 grid.
        ax2 = fig.add_subplot(2,4,2)