class Augmentation(object):
    """Apply every combination of transform parameters to a batch, then undo them.

    Each keyword passed to ``__init__`` selects a transform from
    ``Augmentation.transforms`` and supplies its parameter(s). The cartesian
    product of the per-transform parameter lists (as returned by each
    transform's ``prepare``) defines the set of augmentations applied by
    :meth:`forward` and reverted by :meth:`backward`.
    """

    # Registry of available transforms, keyed by the keyword name accepted by
    # ``__init__``.
    transforms = {
        "h_flip": F.HFlip(),
        "v_flip": F.VFlip(),
        "rotation": F.Rotate(),
        "h_shift": F.HShift(),
        "v_shift": F.VShift(),
        # 'contrast': F.Contrast(),
        "add": F.Add(),
        "mul": F.Multiply(),
    }

    def __init__(self, **params):
        """Build the forward/backward transform chains and their parameters.

        Args:
            **params: mapping of transform name (a key of
                ``Augmentation.transforms``) to that transform's parameter(s);
                each value is passed to the transform's ``prepare`` method.

        Raises:
            KeyError: if a keyword is not a known transform name.
        """
        super().__init__()
        transforms = [Augmentation.transforms[k] for k in params]
        raw_params = [params[k] for k in params]

        # add identity parameters for all transforms and convert to list
        # (per the original intent; the exact behaviour lives in each
        # transform's ``prepare`` — NOTE(review): confirm in F)
        prepared = [t.prepare(p) for t, p in zip(transforms, raw_params)]

        # get all combinations of transforms params
        transform_params = list(itertools.product(*prepared))

        self.forward_aug = [t.forward for t in transforms]
        self.forward_params = transform_params

        # Undo happens in reverse order: reverse the list of inverse
        # functions and, correspondingly, each parameter tuple.
        self.backward_aug = [t.backward for t in transforms[::-1]]
        self.backward_params = [p[::-1] for p in transform_params]

        self.n_transforms = len(transform_params)
        print(f"Will merge {self.n_transforms} augmentations for each image.")

    def forward(self, x):
        """Apply every augmentation to ``x`` and concatenate the results.

        Args:
            x: input batch whose leading dimension is the batch size B
                (assumed B x C x H x W — TODO confirm for non-image inputs).

        Returns:
            Tensor with leading dimension ``n_transforms * B``: one block of B
            per augmentation, concatenated along dim 0 in the order of
            ``self.forward_params``.
        """
        # Kept for backward compatibility with callers that read ``self.bs``.
        self.bs = x.shape[0]
        transformed_batches = []
        for args in self.forward_params:
            batch = x
            for f, arg in zip(self.forward_aug, args):
                batch = f(batch, arg)
            transformed_batches.append(batch)
        # returns shape B*Aug x C x H x W
        return torch.cat(transformed_batches, 0)

    def backward(self, x):
        """Undo each augmentation on its corresponding block of ``x``.

        Args:
            x: tensor with leading dimension ``n_transforms * B``, laid out as
                produced by :meth:`forward` (one block of B per augmentation).

        Returns:
            Tensor with the same leading dimension, each block de-augmented and
            concatenated along dim 0 in the same order.
        """
        # Split by the number of augmentations (fixed in __init__) instead of
        # the batch size cached by forward(): this works even if backward() is
        # called on a fresh instance or forward() last saw a different batch.
        x = x.reshape([self.n_transforms, -1, *x.shape[1:]])
        transformed_batches = []
        for chunk, args in zip(x, self.backward_params):
            batch = chunk
            for f, arg in zip(self.backward_aug, args):
                batch = f(batch, arg)
            transformed_batches.append(batch)
        return torch.cat(transformed_batches, 0)
def test_multiply():
    """F.Multiply().forward should scale the input element-wise by the factor."""
    factor = 1.1
    # Compute the reference first, before forward() touches INPUT.
    expected = INPUT * factor
    actual = F.Multiply().forward(INPUT, factor)
    assert expected.allclose(actual)