コード例 #1
0
                random_outputs = outputs_.data.cpu().numpy()[0,logging_frames,:,:,:]
            else:
                random_inputs = inputs_.data.numpy()[0,logging_frames,:,:,:]
                random_outputs = outputs_.data.numpy()[0,logging_frames,:,:,:]
            # Collect, for each of the first 10 selected frames, three images to
            # log: the input frame, a heat map of the raw output, and the input
            # overlaid with the thresholded mask.
            # NOTE: indentation here is spaces only; the previous mix of tabs
            # and spaces is rejected by Python 3 (TabError).
            for imageIndex in range(10):
                # Move axis 0 to the end and reverse the order along it
                # (presumably CHW -> HWC plus a BGR/RGB swap -- TODO confirm).
                inputImage = random_inputs[imageIndex, :, :, :]
                inputImage = np.transpose(inputImage, (1, 2, 0))
                inputImage = inputImage[:, :, ::-1]
                images_list.append(inputImage)

                # The heat map is built from the raw output values, i.e. before
                # the sigmoid applied below.
                mask = random_outputs[imageIndex, 0, :, :]
                heatMap = getHeatMapFrom2DArray(mask)
                images_list.append(heatMap)

                # Logistic sigmoid, then threshold at 0.5 to get a 0/1 float mask.
                mask = 1 / (1 + np.exp(-mask))
                print('max: ' + str(np.max(mask)) + 'min: ' + str(np.min(mask)))
                mask_ = np.greater(mask, 0.5).astype(np.float32)
                overlayedImage = helpers.overlay_mask(inputImage, mask_)
                overlayedImage = overlayedImage * 255
                images_list.append(overlayedImage)
            tboardLogger.image_summary('image_{}'.format(epoch), images_list, epoch)

        # Compute one loss per network output, accumulate the running totals,
        # and combine them into the single training loss.

        losses = []
        for i, side_output in enumerate(outputs):
            current_loss = class_balanced_cross_entropy_loss(side_output, gts, size_average=True)
            losses.append(current_loss)
            # NOTE(review): `.data[0]` on a scalar loss relies on pre-0.4 PyTorch
            # behaviour -- confirm the targeted PyTorch version.
            running_loss_tr[i] += current_loss.data[0]

        # Every output except the last is scaled by (1 - epoch / nEpochs); the
        # last one is added unscaled.
        # NOTE(review): under Python 2 without `from __future__ import division`
        # `epoch / nEpochs` is integer division -- confirm this is intended.
        side_weight = 1 - epoch / nEpochs
        loss = side_weight * sum(losses[:-1]) + losses[-1]

        # Print stuff
        if ii % num_img_tr == num_img_tr - 1:
            running_loss_tr = [x / num_img_tr for x in running_loss_tr]
コード例 #2
0
ファイル: pascal.py プロジェクト: 4rshdeep/dextr-api

if __name__ == '__main__':
    # Visual check: show the ground-truth mask overlaid on the image for the
    # first four samples of the train+val split, one figure per sample.
    import matplotlib.pyplot as plt
    import dataloaders.helpers as helpers
    import torch
    import dataloaders.custom_transforms as tr
    from torchvision import transforms

    to_tensor_only = transforms.Compose([tr.ToTensor()])
    voc = VOCSegmentation(split=['train', 'val'], transform=to_tensor_only, retname=True)
    loader = torch.utils.data.DataLoader(voc, batch_size=1, shuffle=False, num_workers=0)

    last_index = 3  # stop after the sample with this index has been drawn
    for idx, batch in enumerate(loader):
        plt.figure()

        # Image is divided by 255 before overlaying; the mask is squeezed.
        image = helpers.tens2image(batch["image"]) / 255.
        gt_mask = np.squeeze(helpers.tens2image(batch["gt"]))
        plt.imshow(helpers.overlay_mask(image, gt_mask))

        # Title the figure with the category name of the (single) batch item.
        category = batch["meta"]["category"][0]
        plt.title(voc.category_names[category])

        if idx == last_index:
            break

    # Block until all figures are closed.
    plt.show(block=True)