def __init__(self, sigma, b, canopy, sl, sr):
     """Compose a resize, a noise and a brightness-shift transform.

     Args:
         sigma: noise parameter passed to ``transforms.Noise``; stored as
             ``self.sigma``.
         b: parameter passed to ``transforms.BrightnessShift``; stored as
             ``self.b``.
         canopy: template input handed to ``transforms.Resize``
             (presumably a sample image -- confirm against callers).
         sl, sr: bounds handed to ``transforms.Resize``
             (presumably lower/upper scale factors).
     """
     super(ResizeBrightnessNoiseTransformer, self).__init__()
     # Component transforms; construction order is kept stable.
     self.resize_adder = transforms.Resize(canopy, sl, sr)
     self.noise_adder = transforms.Noise(sigma)
     self.brightness_adder = transforms.BrightnessShift(b)
     # Raw hyper-parameters, kept for later reference.
     self.sigma, self.b = sigma, b
 def __init__(self, sigma_k, sigma_b):
     """Compose a brightness-scale and a brightness-shift transform.

     Args:
         sigma_k: parameter for ``transforms.BrightnessScale``; stored as
             ``self.sigma_k``.
         sigma_b: parameter for ``transforms.BrightnessShift``; stored as
             ``self.sigma_b``.
     """
     super(BrightnessTransformer, self).__init__()
     self.sigma_k, self.sigma_b = sigma_k, sigma_b
     self.scaler = transforms.BrightnessScale(sigma_k)
     self.brighter = transforms.BrightnessShift(sigma_b)
     # Both start at zero; presumably adjusted by other methods (not visible here).
     self.k_l = 0
     self.k_r = 0
 def __init__(self, sigma, b, canopy, rotation_angle=180.0):
     """Compose a noise, a brightness-shift and a rotation transform.

     Args:
         sigma: noise parameter, stored as ``self.sigma`` and passed to
             ``transforms.Noise``.
         b: brightness-shift parameter, stored as ``self.b`` and passed to
             ``transforms.BrightnessShift``.
         canopy: template input handed to ``transforms.Rotation``.
         rotation_angle: angle bound handed to ``transforms.Rotation``
             (default 180.0).
     """
     super(RotationBrightnessNoiseTransformer, self).__init__()
     self.sigma, self.b = sigma, b
     # Transforms are built from the stored attributes.
     self.noise_adder = transforms.Noise(self.sigma)
     self.brightness_adder = transforms.BrightnessShift(self.b)
     self.rotation_adder = transforms.Rotation(canopy, rotation_angle)
     # Fixed settings for this transformer: round = 2, masking enabled.
     self.round, self.masking = 2, True
 def __init__(self, sigma, b, k, canopy, rotation_angle=180.0):
     """Compose noise, brightness-scale, brightness-shift and rotation transforms.

     Args:
         sigma: noise parameter (``self.sigma``) for ``transforms.Noise``.
         b: brightness-shift parameter, stored as ``self.sigma_b``.
         k: brightness-scale parameter, stored as ``self.sigma_k``.
         canopy: template input for ``transforms.Rotation``; its element
             count is recorded as ``self.input_dim``.
         rotation_angle: angle bound for ``transforms.Rotation``
             (default 180.0).
     """
     super(RotationBrightnessContrastNoiseTransformer, self).__init__()
     self.sigma = sigma
     self.sigma_b, self.sigma_k = b, k
     # Component transforms, built from the stored attributes.
     self.noise_adder = transforms.Noise(self.sigma)
     self.scaler = transforms.BrightnessScale(self.sigma_k)
     self.brightness_adder = transforms.BrightnessShift(self.sigma_b)
     self.rotation_adder = transforms.Rotation(canopy, rotation_angle)
     # Total number of elements in the template input.
     self.input_dim = canopy.numel()
     self.round, self.masking = 2, True
     # Both start at zero; presumably adjusted by other methods (not visible here).
     self.k_l = 0
     self.k_r = 0
    def __init__(self, sigma, sigma_k, sigma_b, lamb, sigma_trans, sl, sr,
                 rotation_angle, canopy):
        """Build every component transform used by this transformer.

        Args:
            sigma: noise parameter (``self.sigma``) for ``transforms.Noise``.
            sigma_k: parameter for ``transforms.BrightnessScale``.
            sigma_b: parameter for ``transforms.BrightnessShift``.
            lamb: parameter for ``transforms.ExpGaussian``.
            sigma_trans: parameter for ``transforms.Translational``.
            sl, sr: bounds for ``transforms.Resize``.
            rotation_angle: angle bound for ``transforms.Rotation``.
            canopy: template input shared by the translation, resize and
                rotation transforms.
        """
        super(UniversalTransformer, self).__init__()
        self.sigma = sigma
        self.sigma_k, self.sigma_b = sigma_k, sigma_b

        # Brightness transforms.
        self.scaler = transforms.BrightnessScale(sigma_k)
        self.brighter = transforms.BrightnessShift(sigma_b)
        self.gaussian_adder = transforms.ExpGaussian(lamb)
        # Transforms that need the template input ``canopy``.
        self.translation_adder = transforms.Translational(canopy, sigma_trans)
        self.resize_adder = transforms.Resize(canopy, sl, sr)
        self.rotation_adder = transforms.Rotation(canopy, rotation_angle)
        self.noise_adder = transforms.Noise(self.sigma)
        # Fixed settings for this transformer: round = 1, masking enabled.
        self.round, self.masking = 1, True
    model.eval()

    # iterate through the dataset
    dataset = get_dataset(args.dataset, args.split)

    cells = 1
    for x in get_dataset_shape(args.dataset):
        cells *= x

    # init transformers
    tinst = None
    tfunc = None
    tinst2 = None
    tfunc2 = None
    if args.transtype == 'brightness':
        tinst = T.BrightnessShift(0.0)
        tfunc = T.BrightnessShift.proc
    elif args.transtype == 'contrast':
        # note: contrast is in exponential scale
        tinst = T.BrightnessScale(0.0)
        tfunc = T.BrightnessScale.proc
    elif args.transtype == 'brightness-contrast':
        tinst = T.BrightnessShift(0.0)
        tfunc = T.BrightnessShift.proc
        tinst2 = T.BrightnessScale(0.0)
        tfunc2 = T.BrightnessScale.proc
    elif args.transtype == 'rotation':
        tinst = T.Rotation(dataset[0][0], 0.0)
        # tfunc = T.Rotation.proc
        tfunc = T.Rotation.raw_proc
    elif args.transtype == 'resize':
# Indices of the test-split samples to visualize: 0, 20, ..., 100.
nums = list(range(0, 101, 20))

if __name__ == '__main__':

    for dataset in datasets:

        ds = get_dataset(dataset, 'test')

        canopy = ds[0][0]

        # init transformers
        noiseT = transforms.Noise(sigma=0.5)
        rotationT = transforms.Rotation(canopy, rotation_angle=180.0)
        translationT = transforms.Translational(canopy, sigma=5.0)
        blackTranslationT = transforms.BlackTranslational(canopy, sigma=5.0)
        brightnessShiftT = transforms.BrightnessShift(sigma=0.1)
        brightnessScaleT = transforms.BrightnessScale(sigma=0.1)
        sizeScaleT = transforms.Resize(canopy, sl=0.5, sr=5.0)
        gaussianT = transforms.Gaussian(sigma=5.0)

        for num in nums:
            print(dataset, num)
            transforms.visualize(ds[num][0], f'visualize/{dataset}/{num}.png')
            # rotation
            angles = [-10, 30, 70]
            for angle in angles:
                transforms.visualize(
                    rotationT.masking(rotationT.raw_proc(ds[num][0], angle)),
                    f'visualize/{dataset}/{num}_rot_{angle}.png')
            transforms.visualize(
                rotationT.masking(
def gen_transform_and_params(args, canopy):

    # init transformers
    tinst1 = None
    tfunc1 = None
    tinst2 = None
    tfunc2 = None
    if args.transtype == 'gaussian':
        tinst1 = TR.Gaussian(0.0)
        tfunc1 = TR.Gaussian.proc
    elif args.transtype == 'translation':
        tinst1 = TR.Translational(canopy, 0.0)
        tfunc1 = TR.Translational.proc
    elif args.transtype == 'brightness':
        tinst1 = TR.BrightnessShift(0.0)
        tfunc1 = TR.BrightnessShift.proc
    elif args.transtype == 'brightness-contrast':
        # note: contrast is in exponential scale
        tinst1 = TR.BrightnessShift(0.0)
        tfunc1 = TR.BrightnessShift.proc
        tinst2 = TR.BrightnessScale(0.0)
        tfunc2 = TR.BrightnessScale.proc
    elif args.transtype == 'rotation':
        tinst1 = TR.Rotation(canopy, 0.0)
        tfunc1 = TR.Rotation.raw_proc
    elif args.transtype == 'scaling':
        # note: resize is in original scale
        tinst1 = TR.Resize(canopy, 1.0, 1.0)
        tfunc1 = TR.Resize.proc
    elif args.transtype == 'rotation-brightness' or args.transtype == 'rotation-brightness-l2':
        tinst1 = TR.Rotation(canopy, 0.0)
        tfunc1 = TR.Rotation.raw_proc
        tinst2 = TR.BrightnessShift(0.0)
        tfunc2 = TR.BrightnessShift.proc
    elif args.transtype == 'scaling-brightness' or args.transtype == 'scaling-brightness-l2':
        # note: resize is in original scale
        tinst1 = TR.Resize(canopy, 1.0, 1.0)
        tfunc1 = TR.Resize.proc
        tinst2 = TR.BrightnessShift(0.0)
        tfunc2 = TR.BrightnessShift.proc

    # random generator
    param1l, param1r, param2l, param2r, candidates = None, None, None, None, None
    if args.transtype == 'gaussian':
        param1l, param1r = 0.0, args.blur_alpha
    elif args.transtype == 'translation':
        candidates = torch.tensor(
            list(
                set([(x, y) for x in range(int(args.displacement) + 1)
                     for y in range(int(args.displacement) + 1)
                     if float(x * x + y * y) <= args.displacement *
                     args.displacement] +
                    [(x, -y) for x in range(int(args.displacement) + 1)
                     for y in range(int(args.displacement) + 1)
                     if float(x * x + y * y) <= args.displacement *
                     args.displacement] +
                    [(-x, y) for x in range(int(args.displacement) + 1)
                     for y in range(int(args.displacement) + 1)
                     if float(x * x + y * y) <= args.displacement *
                     args.displacement] +
                    [(-x, -y) for x in range(int(args.displacement) + 1)
                     for y in range(int(args.displacement) + 1)
                     if float(x * x + y * y) <= args.displacement *
                     args.displacement])))
        c_len = candidates.shape[0]
        param1l, param1r = 0.0, c_len
    elif args.transtype == 'brightness':
        param1l, param1r = -args.b, +args.b
    elif args.transtype == 'brightness-contrast':
        param1l, param1r = -args.b, +args.b
        param2l, param2r = math.log(1.0 - args.k), math.log(1.0 + args.k)
    elif args.transtype == 'rotation':
        param1l, param1r = -args.r, +args.r
    elif args.transtype == 'scaling':
        param1l, param1r = 1.0 - args.s, 1.0 + args.s
    elif args.transtype == 'rotation-brightness' or args.transtype == 'rotation-brightness-l2':
        param1l, param1r = -args.r, +args.r
        param2l, param2r = -args.b, +args.b
    elif args.transtype == 'scaling-brightness' or args.transtype == 'scaling-brightness-l2':
        param1l, param1r = 1.0 - args.s, 1.0 + args.s
        param2l, param2r = -args.b, +args.b
    else:
        raise Exception(f'Unknown transtype: {args.transtype}')

    print(f"""
    param1: [{param1l}, {param1r}]
    param2: [{param2l}, {param2r}]
""")
    return tinst1, tfunc1, tinst2, tfunc2, param1l, param1r, param2l, param2r, candidates