def create_inputs(rng, batch_size, channels, size_out, size_inp, align_corners):
    """Build a random input tensor and a sampling grid for 2D or 3D tests.

    The spatial rank is taken from ``len(size_out)`` and must be 2 or 3.

    Args:
        rng: Random generator. ``rng.randn`` draws the input values, and
            ``rng`` is also passed to ``generate_transformation_{2,3}d``.
            Presumably a ``numpy.random.RandomState``; confirm against callers.
        batch_size: Leading dimension of the generated input; it is also
            passed to the transformation generator.
        channels: Second dimension of the generated input.
        size_out: Spatial output size forwarded to ``affine_grid_{2,3}d``.
            Its length (2 or 3) selects the 2D or 3D path.
        size_inp: Spatial input size. Only the first ``len(size_out)``
            entries are used.
        align_corners: Forwarded unchanged to ``affine_grid_{2,3}d``.

    Returns:
        ``(inp, grid_s)``:
            ``inp`` is a float32 array of shape
            ``(batch_size, channels, *size_inp[:len(size_out)])``.
            ``grid_s`` is whatever ``affine_grid_{2,3}d`` returns; its shape
            is not visible from here (TODO confirm).

    Raises:
        ValueError: If ``len(size_out)`` is not 2 or 3, or if ``size_inp``
            has fewer entries than ``size_out``.
    """
    ndim = len(size_out)
    if ndim not in (2, 3):
        # Without this check the function would fall through to an opaque
        # UnboundLocalError at the return statement.
        raise ValueError(
            f"create_inputs supports 2D or 3D sizes, got len(size_out)={ndim}"
        )
    spatial_inp = tuple(size_inp[:ndim])
    if len(spatial_inp) != ndim:
        raise ValueError(
            f"size_inp needs at least {ndim} entries, got {len(spatial_inp)}"
        )

    # Draw the input first and the transformation second. This keeps the RNG
    # consumption order identical to the original implementation.
    inp = rng.randn(batch_size, channels, *spatial_inp).astype(np.float32)
    if ndim == 2:
        affine = generate_transformation_2d(rng, batch_size)
        grid_s = affine_grid_2d(affine, size_out, align_corners)
    else:
        affine = generate_transformation_3d(rng, batch_size)
        grid_s = affine_grid_3d(affine, size_out, align_corners)
    return inp, grid_s
def ref_affine_grid(theta, size, align_corners):
    """Compute a reference affine sampling grid, dispatching on spatial rank.

    Args:
        theta: Affine transformation(s), forwarded unchanged to
            ``affine_grid_2d`` / ``affine_grid_3d``.
        size: Spatial output size. Its length (2 or 3) selects the 2D or
            3D helper.
        align_corners: Forwarded unchanged to the selected helper.

    Returns:
        The result of ``affine_grid_2d`` or ``affine_grid_3d``. Its shape is
        defined by those helpers, which are not visible here.

    Raises:
        ValueError: If ``len(size)`` is not 2 or 3.
    """
    ndim = len(size)
    if ndim == 2:
        return affine_grid_2d(theta, size, align_corners)
    if ndim == 3:
        return affine_grid_3d(theta, size, align_corners)
    # Previously this fell through to an UnboundLocalError on ``grid_s``.
    raise ValueError(
        f"ref_affine_grid supports 2D or 3D sizes, got len(size)={ndim}"
    )