Ejemplo n.º 1
0
def test_vshift():
    """Check VShift(shift=1) against a fixed expected output, then its inverse."""
    shift = 1
    # Expected forward output for INPUT shifted by ``shift`` (rows spelled out).
    # fmt: off
    rows = [
        [12, 13, 14, 15],
        [0, 1, 2, 3],
        [4, 5, 6, 7],
        [8, 9, 10, 11],
    ]
    # fmt: on
    expected = torch.Tensor([[rows]])

    shifted = F.VShift().forward(INPUT, shift)
    assert expected.allclose(shifted)

    # backward with the same shift must recover the original input.
    restored = F.VShift().backward(shifted, shift)
    assert INPUT.allclose(restored)
Ejemplo n.º 2
0
class Augmentation(object):
    """Apply every combination of the requested transforms, and undo them.

    Each keyword argument of the constructor selects a transform from
    ``Augmentation.transforms``. Its value is passed to that transform's
    ``prepare`` method, and the prepared per-transform parameter lists are
    expanded with ``itertools.product`` into one parameter tuple per
    augmentation.

    ``forward`` runs every parameter tuple over the input batch and
    concatenates the results along dim 0 (augmentation-major).
    ``backward`` splits its input back into one chunk per augmentation.
    It then calls the transforms' ``backward`` methods in reverse order,
    with each parameter tuple reversed to match.
    """

    # Available transforms, keyed by the keyword names ``__init__`` accepts.
    # This dict, and the transform instances in it, is shared by every
    # Augmentation object.
    transforms = {
        "h_flip": F.HFlip(),
        "v_flip": F.VFlip(),
        "rotation": F.Rotate(),
        "h_shift": F.HShift(),
        "v_shift": F.VShift(),
        # 'contrast': F.Contrast(),
        "add": F.Add(),
        "mul": F.Multiply(),
    }

    def __init__(self, **params):
        """Build the forward/backward pipelines for the selected transforms.

        Args:
            **params: maps a key of ``Augmentation.transforms`` to the
                parameter spec for that transform, in whatever form its
                ``prepare`` method accepts.

        Raises:
            KeyError: if a keyword does not name a known transform.
        """
        super().__init__()

        selected = [Augmentation.transforms[name] for name in params]

        # ``prepare`` turns each user spec into the sequence of parameter
        # values to try. Per the original author it also adds the identity
        # parameter; that logic lives in ``F`` and is not visible here.
        prepared = [
            t.prepare(spec) for t, spec in zip(selected, params.values())
        ]

        # One tuple per augmentation: every combination of per-transform values.
        transform_params = list(itertools.product(*prepared))

        self.forward_aug = [t.forward for t in selected]
        self.forward_params = transform_params

        # Undo in reverse order: the last-applied transform is inverted first,
        # so both the callables and each parameter tuple are reversed.
        self.backward_aug = [t.backward for t in selected[::-1]]
        self.backward_params = [p[::-1] for p in transform_params]

        self.n_transforms = len(transform_params)
        print(f"Will merge {self.n_transforms} augmentations for each image.")

    def forward(self, x):
        """Apply every augmentation to ``x`` and concatenate the results.

        Records ``x.shape[0]`` as the batch size so ``backward`` can split
        its input again.

        Args:
            x: batch whose leading dimension is the batch size. The original
                comment describes it as B x C x H x W; this is assumed, not
                checked.

        Returns:
            ``torch.cat`` of the augmented batches along dim 0. Its leading
            dimension is ``n_transforms * B``, ordered one augmentation
            after another.
        """
        self.bs = x.shape[0]
        transformed_batches = []
        for args in self.forward_params:
            batch = x
            for f, arg in zip(self.forward_aug, args):
                batch = f(batch, arg)
            transformed_batches.append(batch)
        # returns shape B*Aug x C x H x W
        return torch.cat(transformed_batches, 0)

    def backward(self, x):
        """Invert each augmentation on its slice of ``x`` and concatenate.

        Args:
            x: tensor laid out like the output of ``forward``, with leading
                dimension ``n_transforms * bs`` in augmentation-major order.

        Returns:
            ``torch.cat`` of the de-augmented chunks along dim 0.

        Raises:
            RuntimeError: if called before ``forward``, when the batch size
                is unknown.
            ValueError: if ``x.shape[0]`` is not ``n_transforms * bs``.
        """
        bs = getattr(self, "bs", None)
        if bs is None:
            raise RuntimeError(
                "Augmentation.backward() called before forward(); "
                "batch size is unknown."
            )
        # Validate explicitly: reshaping with -1 would silently drop extra
        # chunks, or fail later with an IndexError, on a mismatched input.
        expected = self.n_transforms * bs
        if x.shape[0] != expected:
            raise ValueError(
                f"Expected leading dimension {expected} "
                f"({self.n_transforms} augmentations x batch size {bs}), "
                f"got {x.shape[0]}."
            )
        # Split back into one chunk per augmentation: (n_transforms, bs, ...).
        chunks = x.reshape([self.n_transforms, bs, *x.shape[1:]])
        restored = []
        for chunk, args in zip(chunks, self.backward_params):
            for f, arg in zip(self.backward_aug, args):
                chunk = f(chunk, arg)
            restored.append(chunk)
        return torch.cat(restored, 0)