Ejemplo n.º 1
0
    def __getitem__(self, index):
        """Return an image together with a shuffled grid of its tiles.

        The sample at ``index`` is fetched through ``ImageFolder.__getitem__``,
        converted to a PIL image and cut into a ``self.slice`` x ``self.slice``
        grid. Every tile is passed through ``self.__augment_tile`` and the tiles
        are then reordered by a random permutation.

        Tile width is derived from the image width and tile height from the
        image height, so non-square images are covered by the grid too. For
        square images the crop boxes are identical to using the width for
        both axes.

        Args:
            index: Index of the sample in the underlying ``ImageFolder``.

        Returns:
            A tuple ``(origin, data, order, tiles)``:

            - ``origin``: the image element returned by
              ``ImageFolder.__getitem__``, before conversion to PIL.
            - ``data``: ``torch.stack`` of the tiles in shuffled order, i.e.
              ``data[t]`` is ``tiles[order[t]]``.
            - ``order``: the ``np.random.permutation(self.slice ** 2)`` used
              for shuffling.
            - ``tiles``: ``torch.stack`` of the tiles in row-major grid order.
        """
        # The class label is not part of the returned tuple.
        origin, _target = ImageFolder.__getitem__(self, index)
        # NOTE(review): assumes the dataset's transform yields something
        # ToPILImage accepts (tensor / ndarray) -- confirm against the setup.
        img = transforms.ToPILImage()(origin)

        width, height = img.size  # PIL reports (width, height)
        n_tiles = self.slice ** 2
        tile_w = float(width) / self.slice
        tile_h = float(height) / self.slice
        half_w = tile_w / 2
        half_h = tile_h / 2

        tiles = []
        for n in range(n_tiles):
            row, col = divmod(n, self.slice)
            # Tile centre; same float expressions as the square-only version so
            # square inputs produce identical crop boxes.
            cy = half_h * row * 2 + half_h
            cx = half_w * col * 2 + half_w
            # Box is (left, upper, right, lower).
            # NOTE(review): the +1 on right/lower makes each crop one pixel
            # larger than the nominal tile; on the last row/column it reaches
            # one pixel beyond the image edge, which Pillow pads instead of
            # raising -- confirm this border is intended.
            box = np.array([cx - half_w, cy - half_h,
                            cx + half_w + 1, cy + half_h + 1]).astype(int)
            tile = img.crop(box.tolist())
            # NOTE: per-tile channel normalization (to avoid low-level feature
            # shortcuts) is currently disabled; tiles are used exactly as
            # returned by __augment_tile, which must yield tensors for the
            # torch.stack calls below.
            tiles.append(self.__augment_tile(tile))

        order = np.random.permutation(n_tiles)
        data = torch.stack([tiles[k] for k in order], 0)
        tiles = torch.stack(tiles)

        return origin, data, order, tiles
Ejemplo n.º 2
0
class Faces(Dataset):
    """Dataset that forwards every lookup to an internal ``ImageFolder``.

    Args:
        root: Directory handed to ``ImageFolder``.
        loader: Callable ``ImageFolder`` uses to open an image file;
            defaults to ``default_loader``.
        transform: Optional transform forwarded to ``ImageFolder``.
        target_transform: Optional target transform forwarded to
            ``ImageFolder``.
    """

    def __init__(self,
                 root,
                 loader=default_loader,
                 transform=None,
                 target_transform=None):
        super(Faces, self).__init__()
        # Sample discovery and loading are entirely delegated to this folder.
        self.folder = ImageFolder(
            root,
            loader=loader,
            transform=transform,
            target_transform=target_transform,
        )

    def __getitem__(self, index):
        """Return the wrapped folder's sample at ``index`` unchanged."""
        return self.folder[index]

    def __len__(self):
        """Return how many samples the wrapped folder holds."""
        return len(self.folder)