Ejemplo n.º 1
0
# Instantiate the VGGNet model and hand it to FCN8s.
vgg_model = VGGNet(freeze_max=False)
net = FCN8s(vgg_model)

# Pick the target device *before* reading the checkpoint: first CUDA GPU when
# available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# map_location remaps the tensors stored in the checkpoint onto `device`.
# Without it, a checkpoint saved from a CUDA run cannot be deserialized on a
# CPU-only machine, which would defeat the CPU fallback above.
checkpoint = torch.load('baseline.pth', map_location=device)

# The weights are stored under the 'model_state_dict' key of the checkpoint.
net.load_state_dict(checkpoint['model_state_dict'])

net.to(device)

# Training data: transform stages, the composed pipeline, and the dataset.
#
# Only `norm` and `tupled_with_roi_align` are consumed by the pipeline built
# in this section; `scale`, `crop`, `rotate` and `tupled` are constructed but
# not referenced in the code visible here.
scale = Rescale(int(1.5 * 230))
crop = RandomCrop(224)
rotate = RandomRotate(20.0)
norm = Normalize()
tupled = ToTupleTensor()
tupled_with_roi_align = ToRoIAlignTensor()

# Tracking pipeline: a fresh Rescale(224), then the shared `norm` and
# `tupled_with_roi_align` stages, in that list order.
tracking_stages = [Rescale(224), norm, tupled_with_roi_align]
composed_for_tracking = transforms.Compose(tracking_stages)

# The dataset root is the parent of the current working directory.
dataset = SISSDataset(
    num_slices=153,
    num_scans=2,
    root_dir=Path.cwd().parents[0],
    transform=composed_for_tracking,
    train=True,
)

dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=1,
Ejemplo n.º 2
0
    fig.suptitle('Sample %d' % idx)

    for slice_, scan in enumerate(['dwi', 'flair', 't1', 't2', 'label']):
        ax = plt.subplot(1, 5, slice_ + 1)
        show_single_img(sample[:, :, slice_], scan == 'label')
        plt.tight_layout()
        ax.set_title(scan)
        ax.axis('off')

    plt.show()


# Training data: individual transform stages.
#
# `tupled` and `tupled_with_roialign` are constructed here but are not used by
# either pipeline visible in this section.
scale = Rescale(int(1.5 * 230))
crop = RandomCrop(224)
rotate = RandomRotate(20.0)
norm = Normalize()
tupled = ToTupleTensor()
tupled_multiscaled_masks = ToMultiScaleMasks()
tupled_with_roialign = ToRoIAlignTensor()

# Main pipeline: the shared stage instances above, in this list order.
composed = transforms.Compose([
    scale,
    rotate,
    crop,
    norm,
    tupled_multiscaled_masks,
])

# Pipeline for testing how the transforms couple together. It uses its own
# stage instances (note RandomRotate(50.0) rather than 20.0) and has no
# final tensor-conversion stage.
composed_wo_tupling_norm = transforms.Compose(
    [Rescale(int(1.5 * 230)), RandomRotate(50.0), RandomCrop(224), Normalize()]
)