def _meshgrid(height, width, coord):
    """Evaluate the thin-plate-spline basis on a regular output grid.

    Grid coordinates span ``[-1, 1]`` along both axes. For every grid point
    the result stacks, along dim 1, the constant 1, the x coordinate, the y
    coordinate and one radial term ``d2 * log(d2 + 1e-6)`` per control point,
    where ``d2`` is the squared distance between the grid point and that
    control point.

    Args:
        height: Number of grid rows.
        width: Number of grid columns.
        coord: Control points. ``coord[:, :, 0]`` is compared against the
            grid's x coordinates and ``coord[:, :, 1]`` against its y
            coordinates, i.e. the expected layout is ``[bn, pn, 2]``.

    Returns:
        Tensor laid out as ``[bn, 3 + pn, height * width]``.
    """
    # The final concatenation is only valid when the tiled grid has the same
    # batch size as ``coord``, so read it from ``coord`` rather than relying
    # on an outer-scope ``num_batch``.
    num_batch = coord.shape[0]
    device = coord.device

    x_t = ptcompat.torch_tile_nd(
        ptcompat.torch_reshape(
            torch.linspace(-1.0, 1.0, width, device=device), [1, width]
        ),
        [height, 1],
    )
    y_t = ptcompat.torch_tile_nd(
        ptcompat.torch_reshape(
            torch.linspace(-1.0, 1.0, height, device=device), [height, 1]
        ),
        [1, width],
    )

    x_t_flat = ptcompat.torch_reshape(x_t, (1, 1, -1))
    y_t_flat = ptcompat.torch_reshape(y_t, (1, 1, -1))

    px = torch.unsqueeze(coord[:, :, 0], 2)  # [bn, pn, 1]
    py = torch.unsqueeze(coord[:, :, 1], 2)  # [bn, pn, 1]
    d2 = (x_t_flat - px) ** 2 + (y_t_flat - py) ** 2
    # Radial kernel; the epsilon keeps log() finite where d2 == 0.
    r = d2 * torch.log(d2 + 1.0e-6)  # [bn, pn, h*w]

    x_t_flat_g = ptcompat.torch_tile_nd(x_t_flat, [num_batch, 1, 1])  # [bn, 1, h*w]
    y_t_flat_g = ptcompat.torch_tile_nd(y_t_flat, [num_batch, 1, 1])  # [bn, 1, h*w]
    ones = torch.ones_like(x_t_flat_g)  # [bn, 1, h*w], same device/dtype

    grid = torch.cat([ones, x_t_flat_g, y_t_flat_g, r], 1)  # [bn, 3+pn, h*w]
    return grid
def _solve_system(coord, vector):
    """Solve the linear system for the thin-plate-spline parameters.

    Builds the ``[pn + 3, pn + 3]`` system matrix from the control points
    (affine part ``[1, coord]`` plus the radial kernel between all pairs of
    control points), inverts it and applies it to the displaced control
    points ``coord + vector`` (zero-padded by three rows).

    Args:
        coord: Control points, layout ``[bn, pn, 2]``.
        vector: Displacement added to ``coord``; must be addable to it.

    Returns:
        Tensor ``T`` laid out as ``[bn, 2, pn + 3]``.
    """
    # ``ones`` is concatenated with ``coord`` along dim 2, which requires the
    # first two sizes to match ``coord`` anyway; derive them here instead of
    # depending on outer-scope ``num_batch`` / ``num_point``.
    num_batch, num_point = coord.shape[0], coord.shape[1]
    device = coord.device

    ones = torch.ones([num_batch, num_point, 1], dtype=torch.float32, device=device)
    p = torch.cat([ones, coord], 2)  # [bn, pn, 3]

    p_1 = ptcompat.torch_reshape(p, [num_batch, -1, 1, 3])  # [bn, pn, 1, 3]
    p_2 = ptcompat.torch_reshape(p, [num_batch, 1, -1, 3])  # [bn, 1, pn, 3]
    # Pairwise squared distances; the leading "1" column cancels out.
    d2 = torch.sum((p_1 - p_2) ** 2, 3)  # [bn, pn, pn]
    r = d2 * torch.log(d2 + 1.0e-6)  # Kernel [bn, pn, pn]

    zeros = torch.zeros([num_batch, 3, 3], dtype=torch.float32, device=device)
    W_0 = torch.cat([p, r], 2)  # [bn, pn, 3+pn]
    W_1 = torch.cat([zeros, p.permute((0, 2, 1))], 2)  # [bn, 3, pn+3]
    W = torch.cat([W_0, W_1], 1)  # [bn, pn+3, pn+3]
    W_inv = torch.inverse(W)

    # Append three zero rows along dim 1 so the target matches W's size.
    tp = torch.nn.functional.pad(
        coord + vector, (0, 0, 0, 3, 0, 0), mode="constant"
    )  # [bn, pn+3, 2]
    T = torch.matmul(W_inv, tp)
    T = T.permute([0, 2, 1])
    return T
def _repeat(x, n_repeats):
    """Repeat every element of ``x`` ``n_repeats`` times, consecutively.

    ``x`` is flattened first, so ``[a, b]`` with ``n_repeats=3`` yields
    ``[a, a, a, b, b, b]``.

    The previous implementation multiplied a column vector with a row of
    int32 ones, which tied it to CPU tensors whose dtype matched the helper
    matrix. ``torch.repeat_interleave`` gives the same result for that case
    and keeps the dtype and device of ``x`` in general.

    Args:
        x: Tensor of any shape.
        n_repeats: Number of consecutive copies of each element.

    Returns:
        1-D tensor with ``x.numel() * n_repeats`` elements.
    """
    return torch.repeat_interleave(torch.reshape(x, (-1,)), int(n_repeats))
def static_param_2d(param):
    """Rebuild a 2-D ``param`` from its even-indexed rows only.

    Rows ``0, 2, 4, ...`` are kept, tiled with multiples ``[1, 2]`` and
    reshaped back to the original ``[bn, d_1]`` shape.

    NOTE(review): presumably this makes each consecutive pair of rows share
    the values of the even row; that depends on the element order produced
    by ``ptcompat.torch_tile_nd`` -- confirm against the helper.

    Args:
        param: 2-D tensor (its shape is unpacked into exactly two sizes).

    Returns:
        Tensor reshaped to the original ``[bn, d_1]`` shape.
    """
    batch_size, feature_dim = ptnn.shape_as_list(param)
    even_rows = param[::2]
    tiled = ptcompat.torch_tile_nd(even_rows, [1, 2])
    return ptcompat.torch_reshape(tiled, [batch_size, feature_dim])
def grams(self, fs):
    """Compute one batched Gram matrix per feature map in ``fs``.

    Each feature map is flattened over its two trailing (spatial) axes,
    multiplied with its own transpose and divided by ``4 * h * w``.

    Args:
        fs: Iterable of 4-D tensors whose shape unpacks as ``(bs, c, h, w)``.

    Returns:
        List with one ``[bs, c, c]`` tensor per entry of ``fs``, in order.
    """

    def _gram(feat):
        # Flatten the spatial axes, then form feat @ feat^T per batch item.
        batch, chans, rows, cols = list(feat.shape)
        flat = ptcompat.torch_reshape(feat, [batch, chans, rows * cols])
        product = torch.matmul(flat, flat.permute([0, 2, 1]))
        return product / (4.0 * rows * cols)

    return [_gram(feat) for feat in fs]
def _interpolate(im, y, x):
    """Bilinearly sample ``im`` at the normalised coordinates ``(y, x)``.

    Besides its arguments this function reads ``height``, ``width``,
    ``height_f``, ``width_f``, ``num_batch``, ``out_height``, ``out_width``,
    ``channels`` and ``_repeat`` from the surrounding scope.

    Args:
        im: Image batch that is flattened to ``[-1, channels]`` for lookup
            (assumes NHWC layout -- TODO confirm against the caller).
        y: Sampling coordinates along the height axis, approximately in
            ``[-1, 1]``.
        x: Sampling coordinates along the width axis, approximately in
            ``[-1, 1]``.

    Returns:
        Tensor with one row of ``channels`` interpolated values per sampling
        location.
    """
    # constants
    y = ptcompat.torch_astype(y, torch.float32)
    x = ptcompat.torch_astype(x, torch.float32)
    max_y = int(height - 1)
    max_x = int(width - 1)

    # scale indices from approx [-1, 1] to [0, width/height]
    y = (y + 1) * height_f / 2.0
    x = (x + 1) * width_f / 2.0
    y = ptcompat.torch_reshape(y, [-1])
    x = ptcompat.torch_reshape(x, [-1])

    # do sampling: the four integer neighbours, clamped to the image bounds
    y0 = ptcompat.torch_astype(torch.floor(y), torch.int32)
    y1 = y0 + 1
    x0 = ptcompat.torch_astype(torch.floor(x), torch.int32)
    x1 = x0 + 1

    y0 = y0.clamp(0, max_y)
    y1 = y1.clamp(0, max_y)
    x0 = x0.clamp(0, max_x)
    x1 = x1.clamp(0, max_x)

    # Offset of every batch item in the flattened image, repeated once per
    # output pixel. torch.arange replaces the deprecated, end-inclusive
    # torch.range and yields the same 0 .. num_batch - 1 sequence.
    base = _repeat(
        torch.arange(num_batch, dtype=torch.int32) * width * height,
        out_height * out_width,
    )
    base = base.to(im.device)
    base_y0 = base + y0 * width
    base_y1 = base + y1 * width
    idx_a = base_y0 + x0
    idx_b = base_y1 + x0
    idx_c = base_y0 + x1
    idx_d = base_y1 + x1

    # use indices to lookup pixels in the flat image and restore
    # channels dim
    im_flat = ptcompat.torch_reshape(im, [-1, channels])
    im_flat = ptcompat.torch_astype(im_flat, torch.float32)
    Ia = ptcompat.torch_gather(im_flat, idx_a)
    Ib = ptcompat.torch_gather(im_flat, idx_b)
    Ic = ptcompat.torch_gather(im_flat, idx_c)
    Id = ptcompat.torch_gather(im_flat, idx_d)

    # and finally calculate interpolated values
    x0_f = ptcompat.torch_astype(x0, torch.float32)
    x1_f = ptcompat.torch_astype(x1, torch.float32)
    y0_f = ptcompat.torch_astype(y0, torch.float32)
    y1_f = ptcompat.torch_astype(y1, torch.float32)
    wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1)
    wb = torch.unsqueeze(((x1_f - x) * (y - y0_f)), 1)
    wc = torch.unsqueeze(((x - x0_f) * (y1_f - y)), 1)
    wd = torch.unsqueeze(((x - x0_f) * (y - y0_f)), 1)
    output = wa * Ia + wb * Ib + wc * Ic + wd * Id
    return output
def ThinPlateSpline(U, coord, vector, out_size, n_c, move=None, scal=None):
    """Warp an image batch with a thin-plate-spline transformation.

    The spline is fitted so that the control points ``coord`` are mapped to
    ``coord + vector``; the fitted map is evaluated on a regular
    ``out_size`` x ``out_size`` grid spanning ``[-1, 1]`` and ``U`` is
    sampled bilinearly at the resulting locations.

    Args:
        U: Input images in NCHW layout.
        coord: Control points, ``[bn, pn, 2]``. The last axis is reversed
            with ``ptnn.flip`` before use.
        vector: Control point displacements, same layout as ``coord``; the
            last axis is reversed as well.
        out_size: Height and width of the output.
        n_c: Number of image channels.
        move: Optional offset added to the sampling coordinates, indexed as
            ``move[:, :, 0]`` (y) and ``move[:, :, 1]`` (x). Presumably
            ``[bn, 1, 2]`` -- confirm against callers.
        scal: Optional per-batch scale of the sampling coordinates, indexed
            as ``scal[:, 0]`` (y) and ``scal[:, 1]`` (x).

    Returns:
        Tuple ``(output, t_arr)``: the warped images as
        ``[bn, n_c, out_size, out_size]`` and the (y, x) sampling coordinates
        as ``[bn, 2, out_size, out_size]``.

    Raises:
        AssertionError: If only one of ``move`` / ``scal`` is given.
    """
    # https://github.com/agrimgupta92/sgan/issues/22
    U = U.permute((0, 2, 3, 1)).contiguous()  # NCHW -> NHWC
    coord = ptnn.flip(coord, -1)
    vector = ptnn.flip(vector, -1)

    num_batch, height, width, _ = ptnn.shape_as_list(U)
    channels = n_c
    out_height = out_size
    out_width = out_size
    height_f = float(height)
    width_f = float(width)
    num_point = ptnn.shape_as_list(coord)[1]

    def _repeat(x, n_repeats):
        # Flatten x and repeat each element n_repeats times consecutively.
        return torch.repeat_interleave(torch.reshape(x, (-1,)), int(n_repeats))

    def _interpolate(im, y, x):
        # Bilinear lookup of im (NHWC) at normalised coordinates (y, x).
        # constants
        y = ptcompat.torch_astype(y, torch.float32)
        x = ptcompat.torch_astype(x, torch.float32)
        max_y = int(height - 1)
        max_x = int(width - 1)

        # scale indices from approx [-1, 1] to [0, width/height]
        y = (y + 1) * height_f / 2.0
        x = (x + 1) * width_f / 2.0
        y = ptcompat.torch_reshape(y, [-1])
        x = ptcompat.torch_reshape(x, [-1])

        # do sampling
        y0 = ptcompat.torch_astype(torch.floor(y), torch.int32)
        y1 = y0 + 1
        x0 = ptcompat.torch_astype(torch.floor(x), torch.int32)
        x1 = x0 + 1

        y0 = y0.clamp(0, max_y)
        y1 = y1.clamp(0, max_y)
        x0 = x0.clamp(0, max_x)
        x1 = x1.clamp(0, max_x)

        # torch.arange replaces the deprecated, end-inclusive torch.range
        # and yields the same 0 .. num_batch - 1 sequence.
        base = _repeat(
            torch.arange(num_batch, dtype=torch.int32) * width * height,
            out_height * out_width,
        )
        base = base.to(im.device)
        base_y0 = base + y0 * width
        base_y1 = base + y1 * width
        idx_a = base_y0 + x0
        idx_b = base_y1 + x0
        idx_c = base_y0 + x1
        idx_d = base_y1 + x1

        # use indices to lookup pixels in the flat image and restore
        # channels dim
        im_flat = ptcompat.torch_reshape(im, [-1, channels])
        im_flat = ptcompat.torch_astype(im_flat, torch.float32)
        Ia = ptcompat.torch_gather(im_flat, idx_a)
        Ib = ptcompat.torch_gather(im_flat, idx_b)
        Ic = ptcompat.torch_gather(im_flat, idx_c)
        Id = ptcompat.torch_gather(im_flat, idx_d)

        # and finally calculate interpolated values
        x0_f = ptcompat.torch_astype(x0, torch.float32)
        x1_f = ptcompat.torch_astype(x1, torch.float32)
        y0_f = ptcompat.torch_astype(y0, torch.float32)
        y1_f = ptcompat.torch_astype(y1, torch.float32)
        wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1)
        wb = torch.unsqueeze(((x1_f - x) * (y - y0_f)), 1)
        wc = torch.unsqueeze(((x - x0_f) * (y1_f - y)), 1)
        wd = torch.unsqueeze(((x - x0_f) * (y - y0_f)), 1)
        output = wa * Ia + wb * Ib + wc * Ic + wd * Id
        return output

    def _meshgrid(height, width, coord):
        # Spline basis (1, x, y, r_1 .. r_pn) for every output grid point.
        x_t = ptcompat.torch_tile_nd(
            ptcompat.torch_reshape(
                torch.linspace(-1.0, 1.0, width, device=coord.device), [1, width]
            ),
            [height, 1],
        )
        y_t = ptcompat.torch_tile_nd(
            ptcompat.torch_reshape(
                torch.linspace(-1.0, 1.0, height, device=coord.device), [height, 1]
            ),
            [1, width],
        )

        x_t_flat = ptcompat.torch_reshape(x_t, (1, 1, -1))
        y_t_flat = ptcompat.torch_reshape(y_t, (1, 1, -1))

        px = torch.unsqueeze(coord[:, :, 0], 2)  # [bn, pn, 1]
        py = torch.unsqueeze(coord[:, :, 1], 2)  # [bn, pn, 1]
        d2 = (x_t_flat - px) ** 2 + (y_t_flat - py) ** 2
        r = d2 * torch.log(d2 + 1.0e-6)  # [bn, pn, h*w]

        x_t_flat_g = ptcompat.torch_tile_nd(
            x_t_flat, [num_batch, 1, 1]
        )  # [bn, 1, h*w]
        y_t_flat_g = ptcompat.torch_tile_nd(
            y_t_flat, [num_batch, 1, 1]
        )  # [bn, 1, h*w]
        ones = torch.ones_like(x_t_flat_g, device=x_t_flat_g.device)  # [bn, 1, h*w]

        grid = torch.cat([ones, x_t_flat_g, y_t_flat_g, r], 1)  # [bn, 3+pn, h*w]
        return grid

    def _transform(T, coord, move, scal):
        # grid of (x_t, y_t, 1), eq (1) in ref [1]
        grid = _meshgrid(out_height, out_width, coord)  # [bn, 3+pn, h*w]

        # transform A x (1, x_t, y_t, r1, r2, ..., rn) -> (x_s, y_s)
        # [bn, 2, pn+3] x [bn, pn+3, h*w] -> [bn, 2, h*w]
        T_g = torch.matmul(T, grid)
        x_s = T_g[:, 0, :]
        y_s = T_g[:, 1, :]

        if move is not None and scal is not None:
            off_y = torch.unsqueeze(move[:, :, 0], dim=-1)
            # Was ``dims=-1``, which torch.unsqueeze rejects with a TypeError.
            off_x = torch.unsqueeze(move[:, :, 1], dim=-1)
            scal_y = torch.unsqueeze(torch.unsqueeze(scal[:, 0], dim=-1), dim=-1)
            scal_x = torch.unsqueeze(torch.unsqueeze(scal[:, 1], dim=-1), dim=-1)
            # Insert a singleton axis so the [bn, 1, 1] scale applies per
            # batch item; multiplying the [bn, h*w] coordinates directly
            # would broadcast across batch entries.
            y = torch.unsqueeze(y_s, 1) * scal_y + off_y
            x = torch.unsqueeze(x_s, 1) * scal_x + off_x
        else:
            assert move is None and scal is None
            y = y_s
            x = x_s

        return y, x

    def _solve_system(coord, vector):
        # Fit the spline parameters that map coord to coord + vector.
        ones = torch.ones(
            [num_batch, num_point, 1], dtype=torch.float32, device=coord.device
        )
        p = torch.cat([ones, coord], 2)  # [bn, pn, 3]

        p_1 = ptcompat.torch_reshape(p, [num_batch, -1, 1, 3])  # [bn, pn, 1, 3]
        p_2 = ptcompat.torch_reshape(p, [num_batch, 1, -1, 3])  # [bn, 1, pn, 3]
        d2 = torch.sum((p_1 - p_2) ** 2, 3)  # [bn, pn, pn]
        r = d2 * torch.log(d2 + 1.0e-6)  # Kernel [bn, pn, pn]

        zeros = torch.zeros(
            [num_batch, 3, 3], dtype=torch.float32, device=coord.device
        )
        W_0 = torch.cat([p, r], 2)  # [bn, pn, 3+pn]
        W_1 = torch.cat([zeros, p.permute((0, 2, 1))], 2)  # [bn, 3, pn+3]
        W = torch.cat([W_0, W_1], 1)  # [bn, pn+3, pn+3]
        W_inv = torch.inverse(W)

        tp = torch.nn.functional.pad(
            coord + vector, (0, 0, 0, 3, 0, 0), mode="constant"
        )  # [bn, pn+3, 2]
        T = torch.matmul(W_inv, tp)
        T = T.permute([0, 2, 1])
        return T

    T = _solve_system(coord, vector)
    y, x = _transform(T, coord, move, scal)
    input_transformed = _interpolate(U, y, x)
    output = ptcompat.torch_reshape(
        input_transformed, [num_batch, out_height, out_width, channels]
    )
    y = ptcompat.torch_reshape(y, [num_batch, out_height, out_width, 1])
    x = ptcompat.torch_reshape(x, [num_batch, out_height, out_width, 1])
    t_arr = torch.cat([y, x], dim=-1)

    output = output.permute((0, 3, 1, 2))  # NHWC --> NCHW
    t_arr = t_arr.permute((0, 3, 1, 2))  # NHWC --> NCHW
    return output, t_arr