def get_datasets(initial_pool, path):
    """Build the active-learning datasets for Pascal VOC segmentation.

    Args:
        initial_pool (int): Number of samples to label randomly before training.
        path (str): Root directory of the Pascal VOC dataset.

    Returns:
        tuple: ``(active_set, test_set)`` where ``active_set`` is the pool-based
        active-learning dataset (with ``initial_pool`` items already labelled)
        and ``test_set`` is the held-out evaluation dataset.
    """
    IM_SIZE = 224
    # ImageNet channel statistics, matching a pretrained backbone.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    # TODO add better data augmentation scheme.
    # NOTE: the train and test pipelines were previously two byte-identical
    # Compose objects; since there is no augmentation yet, build the
    # (deterministic) pipeline once and reuse it for both.
    transform = transforms.Compose(
        [
            transforms.Resize(512),
            transforms.CenterCrop(IM_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )
    # Targets must be resized with nearest-neighbour interpolation so that
    # class ids are never blended, then mapped to a long tensor of class indices.
    target_transform = transforms.Compose(
        [
            transforms.Resize(512, interpolation=Image.NEAREST),
            transforms.CenterCrop(IM_SIZE),
            PILToLongTensor(pascal_voc_ids),
        ]
    )
    active_set, test_set = active_pascal(
        path=path,
        transform=transform,
        test_transform=transform,
        target_transform=target_transform,
    )
    active_set.label_randomly(initial_pool)
    return active_set, test_set
def test_pil_to_long_tensor(img):
    """PILToLongTensor should accept both ndarray and PIL input and agree.

    Args:
        img: An RGB image as a numpy array (fixture-provided).
    """
    transformer = PILToLongTensor(
        classes=[np.array([100, 100, 100]), np.array([101, 102, 104])]
    )
    # test with numpy:
    long_img = transformer(img)
    assert isinstance(long_img, torch.Tensor)
    # test with PIL
    img = Image.fromarray(img)
    long_img_2 = transformer(img)
    # Fix: previously re-asserted on `long_img`, leaving the PIL-path result
    # (`long_img_2`) type-unchecked.
    assert isinstance(long_img_2, torch.Tensor)
    # Both input formats must map to the same label tensor.
    assert (long_img == long_img_2).all()
def test_segmentation_pipeline(self):
    """End-to-end target pipeline: the mask is a long tensor of class ids
    whose spatial size matches the transformed image."""

    class DrawSquare:
        # Paints a side x side square of the class value in the canvas corner.
        def __init__(self, side):
            self.side = side

        def __call__(self, x, **kwargs):
            value, canvas = x  # x is a [int, ndarray]
            canvas[: self.side, : self.side] = value
            return canvas

    mask_pipeline = BaaLCompose(
        [
            GetCanvas(),
            DrawSquare(3),
            ToPILImage(mode=None),
            Resize(60, interpolation=0),
            RandomRotation(10, resample=NEAREST, fill=0.0),
            PILToLongTensor(),
        ]
    )
    dataset = FileDataset(
        self.paths, [1] * len(self.paths), self.transform, mask_pipeline
    )
    img, mask = dataset[0]
    # Only the background (0) and the drawn class (1) may appear in the mask,
    # and its spatial dimensions must match the image's.
    assert np.allclose(np.unique(mask), [0, 1])
    assert mask.shape[1:] == img.shape[1:]