Пример #1
0
class Augmentation(object):
    """Test-time augmentation over every combination of the requested transforms.

    Each keyword argument given to the constructor selects a transform from
    ``Augmentation.transforms`` by name. Its value is passed to that
    transform's ``prepare`` method, which adds identity parameters and
    converts the value to a list of parameter values. The Cartesian product
    of those lists defines the augmentations applied to each input batch.

    ``forward`` applies every augmentation to a batch and concatenates the
    results along dim 0. ``backward`` inverts the augmentations on a tensor
    laid out the same way.
    """

    # Registry of available transforms, keyed by the name used as a
    # constructor keyword argument.
    transforms = {
        "h_flip": F.HFlip(),
        "v_flip": F.VFlip(),
        "rotation": F.Rotate(),
        "h_shift": F.HShift(),
        "v_shift": F.VShift(),
        # 'contrast': F.Contrast(),  # currently disabled
        "add": F.Add(),
        "mul": F.Multiply(),
    }

    def __init__(self, **params):
        """Build the list of augmentations from ``name=value`` keyword arguments.

        Args:
            **params: mapping from transform name (a key of
                ``Augmentation.transforms``) to the value given to that
                transform's ``prepare`` method. Transforms are applied in
                the keyword order given.

        Raises:
            KeyError: if any name is not a registered transform.
        """
        super().__init__()

        unknown = [name for name in params if name not in Augmentation.transforms]
        if unknown:
            raise KeyError(
                f"Unknown transform(s) {unknown}; "
                f"available: {sorted(Augmentation.transforms)}"
            )

        names = list(params)
        transforms = [Augmentation.transforms[name] for name in names]

        # add identity parameters for all transforms and convert to list
        per_transform_params = [
            t.prepare(params[name]) for t, name in zip(transforms, names)
        ]

        # One tuple per augmentation: every combination of per-transform params.
        # With no transforms, product() yields a single empty tuple (identity).
        combos = list(itertools.product(*per_transform_params))

        self.forward_aug = [t.forward for t in transforms]
        self.forward_params = combos

        # Undo in reverse order: the last transform applied is inverted first,
        # and each param tuple is reversed to stay aligned with backward_aug.
        self.backward_aug = [t.backward for t in reversed(transforms)]
        self.backward_params = [p[::-1] for p in combos]

        self.n_transforms = len(combos)
        print(f"Will merge {self.n_transforms} augmentations for each image.")

    def forward(self, x):
        """Apply every augmentation to batch ``x`` and concatenate the results.

        Args:
            x: batch tensor whose dim 0 is the batch dimension
                (presumably B x C x H x W, per the shape note below; confirm).

        Returns:
            Tensor concatenated along dim 0, augmentation-major: assuming each
            transform preserves the batch size, rows ``[i*B:(i+1)*B]`` hold
            augmentation ``i``.
        """
        # Remembered so that ``backward`` can split the concatenated tensor.
        self.bs = x.shape[0]
        transformed_batches = []
        for args in self.forward_params:
            batch = x
            for f, arg in zip(self.forward_aug, args):
                batch = f(batch, arg)
            transformed_batches.append(batch)
        # returns shape B*Aug x C x H x W
        return torch.cat(transformed_batches, 0)

    def backward(self, x):
        """Invert each augmentation on a tensor laid out like ``forward``'s output.

        Args:
            x: tensor whose dim 0 has size ``n_transforms * bs``, where ``bs``
                is the batch size seen by the last ``forward`` call.

        Returns:
            The de-augmented batches concatenated along dim 0, in the same
            augmentation-major order.

        Raises:
            RuntimeError: if called before ``forward`` (batch size unknown).
            ValueError: if dim 0 of ``x`` is not ``n_transforms * bs``.
        """
        bs = getattr(self, "bs", None)
        if bs is None:
            raise RuntimeError(
                "Augmentation.backward called before forward; batch size is unknown."
            )
        expected = self.n_transforms * bs
        if x.shape[0] != expected:
            raise ValueError(
                f"Expected dim 0 of size {expected} "
                f"({self.n_transforms} augmentations x batch {bs}), "
                f"got {x.shape[0]}."
            )
        # Separate batches: (n_transforms, bs, ...). Using n_transforms instead
        # of -1 keeps this well-defined when bs == 0.
        x = x.reshape([self.n_transforms, bs, *x.shape[1:]])
        transformed_batches = []
        for batch, args in zip(x, self.backward_params):
            for f, arg in zip(self.backward_aug, args):
                batch = f(batch, arg)
            transformed_batches.append(batch)
        return torch.cat(transformed_batches, 0)
Пример #2
0
def test_add():
    """``F.Add().forward(x, v)`` must match ``x + v`` elementwise."""
    offset = 2
    target = INPUT + offset
    actual = F.Add().forward(INPUT, offset)
    assert target.allclose(actual)