def _meshgrid(height, width, coord):
    """Build the thin-plate-spline design grid for a ``height`` x ``width`` image.

    Every pixel of a regular grid spanning [-1, 1] on both axes is paired with
    the radial-basis response ``d2 * log(d2 + 1e-6)`` to each control point in
    ``coord``, where ``d2`` is the squared distance from the pixel to that
    control point.

    Args:
        height: number of grid rows.
        width: number of grid columns.
        coord: control points, read as ``coord[:, :, 0]`` (x) and
            ``coord[:, :, 1]`` (y); presumably shape ``[bn, pn, 2]``.

    Returns:
        Tensor of shape ``[bn, 3 + pn, height * width]`` whose channels are,
        in order: ones, pixel x coordinates, pixel y coordinates, then one
        radial-basis row per control point.
    """
    # The batch size must match `r` below (which is batched like `coord`),
    # otherwise the final concatenation cannot succeed.
    num_batch = coord.shape[0]

    xs = torch.linspace(-1.0, 1.0, width, device=coord.device)
    ys = torch.linspace(-1.0, 1.0, height, device=coord.device)
    x_t = xs.reshape(1, width).repeat(height, 1)  # [h, w], varies along columns
    y_t = ys.reshape(height, 1).repeat(1, width)  # [h, w], varies along rows

    x_t_flat = x_t.reshape(1, 1, -1)  # [1, 1, h*w]
    y_t_flat = y_t.reshape(1, 1, -1)  # [1, 1, h*w]

    px = torch.unsqueeze(coord[:, :, 0], 2)  # [bn, pn, 1]
    py = torch.unsqueeze(coord[:, :, 1], 2)  # [bn, pn, 1]
    d2 = (x_t_flat - px) ** 2 + (y_t_flat - py) ** 2  # [bn, pn, h*w]
    # The 1e-6 keeps log finite (and the product 0) when a pixel coincides
    # with a control point, instead of 0 * -inf = nan.
    r = d2 * torch.log(d2 + 1.0e-6)  # [bn, pn, h*w]

    x_t_flat_g = x_t_flat.repeat(num_batch, 1, 1)  # [bn, 1, h*w]
    y_t_flat_g = y_t_flat.repeat(num_batch, 1, 1)  # [bn, 1, h*w]
    ones = torch.ones_like(x_t_flat_g)  # [bn, 1, h*w]
    grid = torch.cat([ones, x_t_flat_g, y_t_flat_g, r], 1)  # [bn, 3+pn, h*w]
    return grid
def static_param_2d(param):
    """Make consecutive row pairs of a 2-D ``param`` share the even row's values.

    The input is unpacked as ``[bn, d_1]``. Rows ``0, 2, 4, ...`` are kept and
    each kept row is laid out twice side by side. Reshaping back to
    ``[bn, d_1]`` then yields, assuming ``ptcompat.torch_tile_nd`` follows
    tf.tile semantics, rows ``2k`` and ``2k + 1`` both equal to input row
    ``2k``.

    NOTE(review): for an odd ``bn`` the tiled tensor has
    ``ceil(bn / 2) * 2 * d_1`` elements, which cannot be reshaped to
    ``[bn, d_1]``. The caller presumably always passes an even batch size;
    confirm.
    """
    n_rows, n_cols = ptnn.shape_as_list(param)
    even_rows = param[::2]
    doubled = ptcompat.torch_tile_nd(even_rows, [1, 2])  # [ceil(bn/2), 2*d_1]
    return ptcompat.torch_reshape(doubled, [n_rows, n_cols])
def tps_parameters(
    batch_size, scal, tps_scal, rot_scal, off_scal, scal_var, rescal=1, augm_scal=1.0
):
    """Sample a batch of random thin-plate-spline / affine augmentation parameters.

    Args:
        batch_size: number of parameter sets drawn (leading dimension of
            every sampled tensor).
        scal: centre of the scale range for ``t_scal``.
        tps_scal: half-width of the symmetric range for ``vector``.
        rot_scal: half-width of the symmetric range of the per-sample value
            handed to ``pt_rotation_matrix``.
        off_scal: half-width of the symmetric range for ``offset`` and
            ``offset_2``.
        scal_var: relative spread of ``t_scal`` around ``scal``.
        rescal: factor multiplied into the sampled ``t_scal``.
        augm_scal: stored in the result as given.

    Returns:
        dict with keys ``"coord"`` ``[batch_size, 8, 2]``, ``"vector"``
        (same shape as ``coord``), ``"offset"`` and ``"offset_2"``
        ``[batch_size, 1, 2]``, ``"t_scal"`` ``[batch_size, 2]``,
        ``"rot_mat"`` (one ``pt_rotation_matrix`` result per sample, stacked)
        and ``"augm_scal"``. When CUDA is available every tensor value is
        moved to the GPU. Non-tensor values, such as the default float
        ``augm_scal``, are left unchanged.
    """

    def uniform(size, low, high):
        # All draws here are float32 samples from ptcompat's uniform sampler.
        return ptcompat.torch_random_uniform(size, low, high, dtype=torch.float32)

    # Eight fixed control points: four at +-0.5 and four at +-0.2. They are
    # shared by every batch element before the per-sample jitter is added.
    base_points = torch.tensor(
        [
            [
                [-0.5, -0.5],
                [0.5, -0.5],
                [-0.5, 0.5],
                [0.5, 0.5],
                [0.2, -0.2],
                [-0.2, 0.2],
                [0.2, 0.2],
                [-0.2, -0.2],
            ]
        ],
        dtype=torch.float32,
    )
    coord = ptcompat.torch_tile_nd(base_points, [batch_size, 1, 1])
    points_shape = ptnn.shape_as_list(coord)

    # The random draws below keep their historical order so that a seeded RNG
    # reproduces the same parameters.
    coord = coord + uniform(points_shape, -0.2, 0.2)
    vector = uniform(points_shape, -tps_scal, tps_scal)
    offset = uniform([batch_size, 1, 2], -off_scal, off_scal)
    offset_2 = uniform([batch_size, 1, 2], -off_scal, off_scal)

    low_scale = scal * (1.0 - scal_var)
    high_scale = scal * (1.0 + scal_var)
    t_scal = uniform([batch_size, 2], low_scale, high_scale) * rescal

    angles = uniform([batch_size, 1], -rot_scal, rot_scal)
    rot_mat = torch.stack([pt_rotation_matrix(angle) for angle in angles], dim=0)

    params = {
        "coord": coord,
        "vector": vector,
        "offset": offset,
        "offset_2": offset_2,
        "t_scal": t_scal,
        "rot_mat": rot_mat,
        "augm_scal": augm_scal,
    }
    if not torch.cuda.is_available():
        return params
    return {
        key: value.cuda() if isinstance(value, torch.Tensor) else value
        for key, value in params.items()
    }
def test_torch_gather_nd():
    """Check ``ptcompat.torch_gather_nd`` on a batch of stacked images.

    The 512x512x3 skimage "astronaut" image is tiled into a
    ``[bs, nk, 512, 512, 3]`` tensor. The top-left 128x128 window of every
    image is then gathered with fully specified 5-D indices. Both the shape
    and the gathered values are checked.
    """
    # `from skimage import data` works whether or not `import skimage`
    # exposes the `data` submodule as an attribute.
    from skimage import data

    bs = 4
    nk = 16
    image_t = torch.from_numpy(data.astronaut())
    # batch of stack of images: [bs, nk, 512, 512, 3]
    params = ptcompat.torch_tile_nd(
        image_t.view((1, 1, 512, 512, 3)), [bs, nk, 1, 1, 1]
    )
    # torch.meshgrid's default ("ij") indexing keeps the axes in argument
    # order, so `indices` is [bs, nk, 128, 128, 3, 5]. Each trailing
    # 5-vector names one element of `params`.
    indices = torch.stack(
        torch.meshgrid(
            torch.arange(bs),
            torch.arange(nk),
            torch.arange(128),
            torch.arange(128),
            torch.arange(3),
        ),
        dim=-1,
    )
    # get 128 x 128 image slice from each item
    out = ptcompat.torch_gather_nd(params, indices)
    assert out.shape == (4, 16, 128, 128, 3)
    # The shape alone would not catch gathering from the wrong positions.
    # The result must be exactly the top-left crop of every image.
    assert torch.equal(out, params[:, :, :128, :128, :])