示例#1
0
    def _meshgrid(height, width, coord):
        """Build the TPS basis grid evaluated at every output pixel.

        The grid spans ``linspace(-1, 1)`` along both axes.  The result is
        ``[bn, 3 + pn, height * width]``.  Its rows are, in order: ones, the x
        coordinates, the y coordinates, and one radial-basis row
        ``d2 * log(d2 + 1e-6)`` per control point.  ``num_batch`` is read from
        the enclosing scope.
        """
        device = coord.device
        row_of_x = ptcompat.torch_reshape(
            torch.linspace(-1.0, 1.0, width, device=device), [1, width]
        )
        col_of_y = ptcompat.torch_reshape(
            torch.linspace(-1.0, 1.0, height, device=device), [height, 1]
        )
        # Flattened pixel coordinates, each [1, 1, h*w].
        xs = ptcompat.torch_reshape(
            ptcompat.torch_tile_nd(row_of_x, [height, 1]), (1, 1, -1)
        )
        ys = ptcompat.torch_reshape(
            ptcompat.torch_tile_nd(col_of_y, [1, width]), (1, 1, -1)
        )

        # Control-point coordinates, each [bn, pn, 1].
        ctrl_x = coord[:, :, 0].unsqueeze(2)
        ctrl_y = coord[:, :, 1].unsqueeze(2)

        # Squared distance of every pixel to every control point: [bn, pn, h*w].
        sq_dist = (xs - ctrl_x) ** 2 + (ys - ctrl_y) ** 2
        radial = sq_dist * torch.log(sq_dist + 1.0e-6)

        xs_batched = ptcompat.torch_tile_nd(xs, [num_batch, 1, 1])  # [bn, 1, h*w]
        ys_batched = ptcompat.torch_tile_nd(ys, [num_batch, 1, 1])  # [bn, 1, h*w]
        ones = torch.ones_like(xs_batched, device=xs_batched.device)
        return torch.cat([ones, xs_batched, ys_batched, radial], 1)
示例#2
0
    def _solve_system(coord, vector):
        """Solve the thin-plate-spline system for its coefficients.

        Builds the ``[bn, pn + 3, pn + 3]`` matrix ``[[P, K], [0, P^T]]``.
        ``P`` is the control points with a leading column of ones, and ``K``
        is the radial kernel ``d2 * log(d2 + 1e-6)``.  The system is solved
        against the targets ``coord + vector`` padded with three zero rows.
        Returns the coefficients transposed to ``[bn, 2, pn + 3]``.
        ``num_batch`` and ``num_point`` are read from the enclosing scope.
        """
        dev = coord.device
        homog = torch.cat(
            [
                torch.ones([num_batch, num_point, 1], dtype=torch.float32, device=dev),
                coord,
            ],
            2,
        )  # [bn, pn, 3]

        rows = ptcompat.torch_reshape(homog, [num_batch, -1, 1, 3])  # [bn, pn, 1, 3]
        cols = ptcompat.torch_reshape(homog, [num_batch, 1, -1, 3])  # [bn, 1, pn, 3]
        sq_dist = ((rows - cols) ** 2).sum(3)  # [bn, pn, pn]
        kernel = sq_dist * torch.log(sq_dist + 1.0e-6)

        top = torch.cat([homog, kernel], 2)  # [bn, pn, 3+pn]
        bottom = torch.cat(
            [
                torch.zeros([num_batch, 3, 3], dtype=torch.float32, device=dev),
                homog.permute((0, 2, 1)),
            ],
            2,
        )  # [bn, 3, pn+3]
        system = torch.cat([top, bottom], 1)  # [bn, pn+3, pn+3]

        # The targets get three zero rows to match the bottom block of the system.
        targets = torch.nn.functional.pad(
            coord + vector, (0, 0, 0, 3, 0, 0), mode="constant"
        )  # [bn, pn+3, 2]
        coeffs = torch.matmul(torch.inverse(system), targets)
        return coeffs.permute([0, 2, 1])
示例#3
0
 def _repeat(x, n_repeats):
     """Repeat every element of ``x`` ``n_repeats`` times, consecutively.

     ``x`` is flattened first, so ``[a, b]`` with ``n_repeats=3`` becomes
     ``[a, a, a, b, b, b]``.  The dtype and device of ``x`` are preserved.

     Args:
         x: tensor of any shape and dtype.
         n_repeats: non-negative repeat count (int or single-element tensor).

     Returns:
         1-D tensor with ``x.numel() * n_repeats`` elements.
     """
     # repeat_interleave expresses the TF-style "column @ row of ones" trick
     # directly and avoids an integer matmul, which ties the computation to
     # int32 inputs on a backend that supports integer matmul.
     return torch.repeat_interleave(torch.reshape(x, (-1,)), int(n_repeats))
示例#4
0
def static_param_2d(param):
    """Overwrite every odd row of a 2-D ``param`` with the even row before it.

    Rows 0, 2, 4, ... are kept and each is duplicated into the next slot.
    Assuming ``ptcompat.torch_tile_nd`` follows ``tf.tile`` semantics, rows
    ``[r0, r1, r2, r3]`` become ``[r0, r0, r2, r2]``.  The output has the
    input's shape.

    NOTE(review): with an odd number of rows the tiled tensor has one row
    too many, so the final reshape presumably fails.
    """
    batch, dim = ptnn.shape_as_list(param)
    doubled = ptcompat.torch_tile_nd(param[::2], [1, 2])
    return ptcompat.torch_reshape(doubled, [batch, dim])
示例#5
0
 def grams(self, fs):
     """Return the normalized Gram matrix of each feature map in ``fs``.

     Each tensor is unpacked as ``(bs, c, h, w)``.  Its spatial dims are
     flattened to ``[bs, c, h*w]``, and ``f @ f^T`` (shape ``[bs, c, c]``) is
     divided by ``4 * h * w``.  The results come back as a list in input order.
     """
     def _gram(feat):
         bs, c, h, w = list(feat.shape)
         flat = ptcompat.torch_reshape(feat, [bs, c, h * w])
         return torch.matmul(flat, flat.permute([0, 2, 1])) / (4.0 * h * w)

     return [_gram(feat) for feat in fs]
示例#6
0
    def _interpolate(im, y, x):
        """Bilinearly sample ``im`` at normalized coordinates ``(y, x)``.

        ``im`` is addressed as NHWC: it is flattened to ``[-1, channels]`` and
        pixel ``(b, row, col)`` is read from flat row
        ``b*height*width + row*width + col``.  ``y`` and ``x`` are flattened
        and mapped from roughly ``[-1, 1]`` to ``[0, height]`` / ``[0, width]``.
        The four neighbouring integer indices are clamped into the image.

        Returns the weighted sum of the four gathered neighbours as float32.
        This is presumably ``[num_points, channels]``, depending on
        ``ptcompat.torch_gather``.

        Reads ``height``, ``width``, ``height_f``, ``width_f``, ``num_batch``,
        ``out_height``, ``out_width`` and ``channels`` from the enclosing scope.
        """
        y = ptcompat.torch_astype(y, torch.float32)
        x = ptcompat.torch_astype(x, torch.float32)

        max_y = int(height - 1)
        max_x = int(width - 1)

        # Scale indices from approx [-1, 1] to [0, height] / [0, width].
        y = (y + 1) * height_f / 2.0
        x = (x + 1) * width_f / 2.0

        y = ptcompat.torch_reshape(y, [-1])
        x = ptcompat.torch_reshape(x, [-1])

        # Integer corner indices; the "+1" corners are derived before clamping.
        y0 = ptcompat.torch_astype(torch.floor(y), torch.int32)
        x0 = ptcompat.torch_astype(torch.floor(x), torch.int32)
        y1 = (y0 + 1).clamp(0, max_y)
        x1 = (x0 + 1).clamp(0, max_x)
        y0 = y0.clamp(0, max_y)
        x0 = x0.clamp(0, max_x)

        # Flat offset of each sample's image (b * width * height), repeated
        # once per output pixel.  It is built directly on im's device, which
        # replaces the deprecated torch.range and a CPU round-trip.
        base = torch.repeat_interleave(
            torch.arange(num_batch, dtype=torch.int32, device=im.device)
            * width
            * height,
            out_height * out_width,
        )
        base_y0 = base + y0 * width
        base_y1 = base + y1 * width
        idx_a = base_y0 + x0
        idx_b = base_y1 + x0
        idx_c = base_y0 + x1
        idx_d = base_y1 + x1

        # Look up the four neighbours in the flattened image (one row per pixel).
        im_flat = ptcompat.torch_reshape(im, [-1, channels])
        im_flat = ptcompat.torch_astype(im_flat, torch.float32)
        Ia = ptcompat.torch_gather(im_flat, idx_a)
        Ib = ptcompat.torch_gather(im_flat, idx_b)
        Ic = ptcompat.torch_gather(im_flat, idx_c)
        Id = ptcompat.torch_gather(im_flat, idx_d)

        # Bilinear weights from the distances to the opposite corners.
        x0_f = ptcompat.torch_astype(x0, torch.float32)
        x1_f = ptcompat.torch_astype(x1, torch.float32)
        y0_f = ptcompat.torch_astype(y0, torch.float32)
        y1_f = ptcompat.torch_astype(y1, torch.float32)

        wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1)
        wb = torch.unsqueeze(((x1_f - x) * (y - y0_f)), 1)
        wc = torch.unsqueeze(((x - x0_f) * (y1_f - y)), 1)
        wd = torch.unsqueeze(((x - x0_f) * (y - y0_f)), 1)
        return wa * Ia + wb * Ib + wc * Ic + wd * Id
示例#7
0
def ThinPlateSpline(U, coord, vector, out_size, n_c, move=None, scal=None):
    """Warp a batch of images with a thin-plate spline (TPS).

    The spline maps the control points ``coord`` onto ``coord + vector``.  It
    is evaluated on an ``out_size`` x ``out_size`` grid spanning ``[-1, 1]``,
    and ``U`` is bilinearly sampled at the transformed positions.

    Args:
        U: image batch in NCHW layout (permuted to NHWC internally).
        coord: control points, indexed as ``[bn, pn, 2]``.  The last axis is
            reversed before use and the returned grid is ordered (y, x), so
            the expected order appears to be (y, x) in normalized ``[-1, 1]``
            coordinates.  TODO confirm against callers.
        vector: displacement of each control point, same layout as ``coord``.
        out_size: height and width of the output.
        n_c: number of channels of ``U``.
        move: optional per-sample offset added to the sampling grid.  It is
            read as ``move[:, :, 0]`` (y) and ``move[:, :, 1]`` (x), so it is
            presumably ``[bn, 1, 2]``.
        scal: optional per-sample scale of the sampling grid, read as
            ``scal[:, 0]`` (y) and ``scal[:, 1]`` (x).  ``move`` and ``scal``
            must be given together or both omitted.

    Returns:
        ``(output, t_arr)``.  ``output`` is the warped images,
        ``[bn, n_c, out_size, out_size]``.  ``t_arr`` is the sampling grid,
        ``[bn, 2, out_size, out_size]``, with channel 0 = y and channel 1 = x.
    """
    # https://github.com/agrimgupta92/sgan/issues/22
    U = U.permute((0, 2, 3, 1)).contiguous()  # NCHW -> NHWC

    # Reverse the last axis: internally index 0 is paired with x, index 1 with y.
    coord = ptnn.flip(coord, -1)
    vector = ptnn.flip(vector, -1)
    num_batch, height, width, _ = ptnn.shape_as_list(U)
    channels = n_c
    out_height = out_size
    out_width = out_size
    height_f = float(height)
    width_f = float(width)
    num_point = ptnn.shape_as_list(coord)[1]

    def _repeat(x, n_repeats):
        # Flatten x and repeat each element n_repeats times: [a, b] -> [a, a, b, b].
        return torch.repeat_interleave(torch.reshape(x, (-1,)), int(n_repeats))

    def _interpolate(im, y, x):
        # Bilinearly sample NHWC `im` at normalized (y, x) in approx [-1, 1].
        y = ptcompat.torch_astype(y, torch.float32)
        x = ptcompat.torch_astype(x, torch.float32)

        max_y = int(height - 1)
        max_x = int(width - 1)

        # Scale indices from approx [-1, 1] to [0, height] / [0, width].
        y = (y + 1) * height_f / 2.0
        x = (x + 1) * width_f / 2.0

        y = ptcompat.torch_reshape(y, [-1])
        x = ptcompat.torch_reshape(x, [-1])

        # Integer corner indices; the "+1" corners are derived before clamping.
        y0 = ptcompat.torch_astype(torch.floor(y), torch.int32)
        x0 = ptcompat.torch_astype(torch.floor(x), torch.int32)
        y1 = (y0 + 1).clamp(0, max_y)
        x1 = (x0 + 1).clamp(0, max_x)
        y0 = y0.clamp(0, max_y)
        x0 = x0.clamp(0, max_x)

        # Flat offset of each sample's image, repeated once per output pixel.
        base = _repeat(
            torch.arange(num_batch, dtype=torch.int32, device=im.device)
            * width
            * height,
            out_height * out_width,
        )
        base_y0 = base + y0 * width
        base_y1 = base + y1 * width
        idx_a = base_y0 + x0
        idx_b = base_y1 + x0
        idx_c = base_y0 + x1
        idx_d = base_y1 + x1

        # Look up the four neighbours in the flattened image (one row per pixel).
        im_flat = ptcompat.torch_reshape(im, [-1, channels])
        im_flat = ptcompat.torch_astype(im_flat, torch.float32)
        Ia = ptcompat.torch_gather(im_flat, idx_a)
        Ib = ptcompat.torch_gather(im_flat, idx_b)
        Ic = ptcompat.torch_gather(im_flat, idx_c)
        Id = ptcompat.torch_gather(im_flat, idx_d)

        # Bilinear weights from the distances to the opposite corners.
        x0_f = ptcompat.torch_astype(x0, torch.float32)
        x1_f = ptcompat.torch_astype(x1, torch.float32)
        y0_f = ptcompat.torch_astype(y0, torch.float32)
        y1_f = ptcompat.torch_astype(y1, torch.float32)

        wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1)
        wb = torch.unsqueeze(((x1_f - x) * (y - y0_f)), 1)
        wc = torch.unsqueeze(((x - x0_f) * (y1_f - y)), 1)
        wd = torch.unsqueeze(((x - x0_f) * (y - y0_f)), 1)
        return wa * Ia + wb * Ib + wc * Ic + wd * Id

    def _meshgrid(height, width, coord):
        # TPS basis at every output pixel: rows (1, x, y, r_1..r_pn) -> [bn, 3+pn, h*w].
        x_t = ptcompat.torch_tile_nd(
            ptcompat.torch_reshape(
                torch.linspace(-1.0, 1.0, width, device=coord.device), [1, width]
            ),
            [height, 1],
        )
        y_t = ptcompat.torch_tile_nd(
            ptcompat.torch_reshape(
                torch.linspace(-1.0, 1.0, height, device=coord.device), [height, 1]
            ),
            [1, width],
        )
        x_t_flat = ptcompat.torch_reshape(x_t, (1, 1, -1))
        y_t_flat = ptcompat.torch_reshape(y_t, (1, 1, -1))

        px = torch.unsqueeze(coord[:, :, 0], 2)  # [bn, pn, 1]
        py = torch.unsqueeze(coord[:, :, 1], 2)  # [bn, pn, 1]

        d2 = (x_t_flat - px) ** 2 + (y_t_flat - py) ** 2
        r = d2 * torch.log(d2 + 1.0e-6)  # [bn, pn, h*w]

        x_t_flat_g = ptcompat.torch_tile_nd(x_t_flat, [num_batch, 1, 1])  # [bn, 1, h*w]
        y_t_flat_g = ptcompat.torch_tile_nd(y_t_flat, [num_batch, 1, 1])  # [bn, 1, h*w]
        ones = torch.ones_like(x_t_flat_g)  # [bn, 1, h*w], same device/dtype

        return torch.cat([ones, x_t_flat_g, y_t_flat_g, r], 1)  # [bn, 3+pn, h*w]

    def _transform(T, coord, move, scal):
        # Basis grid (1, x_t, y_t, r_1..r_pn), eq (1) in ref [1].
        grid = _meshgrid(out_height, out_width, coord)  # [bn, 3+pn, h*w]

        # [bn, 2, pn+3] x [bn, pn+3, h*w] -> [bn, 2, h*w]
        T_g = torch.matmul(T, grid)
        # Slice with 0:1 / 1:2 to keep the singleton axis ([bn, 1, h*w]).
        # Otherwise the per-sample scale/offset ([bn, 1, 1]) below would
        # broadcast across the batch to [bn, bn, h*w].
        x_s = T_g[:, 0:1, :]
        y_s = T_g[:, 1:2, :]

        if move is not None and scal is not None:
            off_y = torch.unsqueeze(move[:, :, 0], dim=-1)
            off_x = torch.unsqueeze(move[:, :, 1], dim=-1)
            scal_y = torch.unsqueeze(torch.unsqueeze(scal[:, 0], dim=-1), dim=-1)
            scal_x = torch.unsqueeze(torch.unsqueeze(scal[:, 1], dim=-1), dim=-1)
            y = y_s * scal_y + off_y
            x = x_s * scal_x + off_x
        else:
            assert move is None and scal is None, "move and scal must be given together"
            y = y_s
            x = x_s

        return y, x

    def _solve_system(coord, vector):
        # Solve [[P, K], [0, P^T]] T = [coord + vector; 0] -> T as [bn, 2, pn+3].
        ones = torch.ones(
            [num_batch, num_point, 1], dtype=torch.float32, device=coord.device
        )
        p = torch.cat([ones, coord], 2)  # [bn, pn, 3]

        p_1 = ptcompat.torch_reshape(p, [num_batch, -1, 1, 3])  # [bn, pn, 1, 3]
        p_2 = ptcompat.torch_reshape(p, [num_batch, 1, -1, 3])  # [bn, 1, pn, 3]
        d2 = torch.sum((p_1 - p_2) ** 2, 3)  # [bn, pn, pn]
        r = d2 * torch.log(d2 + 1.0e-6)  # Kernel [bn, pn, pn]

        zeros = torch.zeros([num_batch, 3, 3], dtype=torch.float32, device=coord.device)
        W_0 = torch.cat([p, r], 2)  # [bn, pn, 3+pn]
        W_1 = torch.cat([zeros, p.permute((0, 2, 1))], 2)  # [bn, 3, pn+3]
        W = torch.cat([W_0, W_1], 1)  # [bn, pn+3, pn+3]
        W_inv = torch.inverse(W)

        tp = torch.nn.functional.pad(
            coord + vector, (0, 0, 0, 3, 0, 0), mode="constant"
        )  # [bn, pn+3, 2]
        T = torch.matmul(W_inv, tp)
        return T.permute([0, 2, 1])

    T = _solve_system(coord, vector)
    y, x = _transform(T, coord, move, scal)
    input_transformed = _interpolate(U, y, x)
    output = ptcompat.torch_reshape(
        input_transformed, [num_batch, out_height, out_width, channels]
    )
    y = ptcompat.torch_reshape(y, [num_batch, out_height, out_width, 1])
    x = ptcompat.torch_reshape(x, [num_batch, out_height, out_width, 1])
    t_arr = torch.cat([y, x], dim=-1)
    output = output.permute((0, 3, 1, 2))  # NHWC --> NCHW
    t_arr = t_arr.permute((0, 3, 1, 2))  # NHWC --> NCHW
    return output, t_arr