# Example no. 1 (score: 0)
    def _meshgrid(height, width, coord):
        """Build the per-batch thin-plate-spline design grid over a regular mesh.

        Args:
            height: Number of mesh rows; y is sampled with ``linspace(-1, 1, height)``.
            width: Number of mesh columns; x is sampled with ``linspace(-1, 1, width)``.
            coord: Control points; ``coord[:, :, 0]`` is read as x and
                ``coord[:, :, 1]`` as y (presumably shape [bn, pn, 2] — TODO confirm).

        Returns:
            Tensor of shape [bn, 3 + pn, height * width] on ``coord.device``:
            a row of ones, the flattened mesh x and y coordinates, then one row
            per control point holding ``d2 * log(d2 + 1e-6)``, where ``d2`` is the
            squared distance from each mesh location to that control point.
        """
        device = coord.device
        # The batch size is derived from ``coord`` instead of an enclosing-scope
        # variable: the radial term below is built from ``coord`` and torch.cat
        # requires every piece to agree on dim 0, so no other value could work.
        batch_size = coord.shape[0]

        x_t = ptcompat.torch_tile_nd(
            ptcompat.torch_reshape(
                torch.linspace(-1.0, 1.0, width, device=device), [1, width]
            ),
            [height, 1],
        )  # x varies along columns
        y_t = ptcompat.torch_tile_nd(
            ptcompat.torch_reshape(
                torch.linspace(-1.0, 1.0, height, device=device), [height, 1]
            ),
            [1, width],
        )  # y varies along rows
        x_t_flat = ptcompat.torch_reshape(x_t, (1, 1, -1))  # [1, 1, h*w]
        y_t_flat = ptcompat.torch_reshape(y_t, (1, 1, -1))  # [1, 1, h*w]

        px = torch.unsqueeze(coord[:, :, 0], 2)  # [bn, pn, 1]
        py = torch.unsqueeze(coord[:, :, 1], 2)  # [bn, pn, 1]

        # Broadcasts to [bn, pn, h*w]: squared distance from every mesh point to
        # every control point.
        d2 = (x_t_flat - px) ** 2 + (y_t_flat - py) ** 2
        # r^2 * log(r^2) kernel; the epsilon keeps log finite when d2 == 0.
        r = d2 * torch.log(d2 + 1.0e-6)  # [bn, pn, h*w]

        x_t_flat_g = ptcompat.torch_tile_nd(x_t_flat, [batch_size, 1, 1])  # [bn, 1, h*w]
        y_t_flat_g = ptcompat.torch_tile_nd(y_t_flat, [batch_size, 1, 1])  # [bn, 1, h*w]
        # ones_like already matches the dtype and device of its input.
        ones = torch.ones_like(x_t_flat_g)  # [bn, 1, h*w]

        grid = torch.cat([ones, x_t_flat_g, y_t_flat_g, r], 1)  # [bn, 3+pn, h*w]
        return grid
# Example no. 2 (score: 0)
def static_param_2d(param):
    """Make each consecutive pair of batch rows share the first row's values.

    Row ``i`` of the result is row ``2 * (i // 2)`` of ``param``, i.e. the rows
    become ``[p0, p0, p2, p2, ...]``. When the batch size is odd, the last row
    has no partner and keeps its own values; an empty batch returns an empty
    tensor.

    Args:
        param: 2-D tensor of shape [bn, d_1].

    Returns:
        Tensor of shape [bn, d_1] with the same dtype and device as ``param``.

    Raises:
        ValueError: If ``param`` is not 2-D.
    """
    if param.dim() != 2:
        raise ValueError(
            f"static_param_2d expects a 2-D tensor, got shape {tuple(param.shape)}"
        )
    bn = param.shape[0]
    # Duplicate every even-indexed row, then trim the extra copy an odd batch
    # produces. A tile + reshape to [bn, d_1] cannot handle an odd bn.
    return param[::2].repeat_interleave(2, dim=0)[:bn]
# Example no. 3 (score: 0)
def tps_parameters(
    batch_size, scal, tps_scal, rot_scal, off_scal, scal_var, rescal=1, augm_scal=1.0
):
    """Sample one random set of thin-plate-spline / affine augmentation parameters per batch item.

    Args:
        batch_size: Number of independent parameter sets to draw.
        scal: Centre of the uniform per-axis scale range.
        tps_scal: Half-width of the uniform range for the control-point
            displacement ("vector") entries.
        rot_scal: Half-width of the uniform range of the per-item value handed
            to ``pt_rotation_matrix``.
        off_scal: Half-width of the uniform range for both offsets.
        scal_var: Relative spread of the scale range,
            ``[scal * (1 - scal_var), scal * (1 + scal_var)]``.
        rescal: Multiplier applied to the sampled scales.
        augm_scal: Stored unchanged in the returned dict.

    Returns:
        dict with keys "coord", "vector", "offset", "offset_2", "t_scal",
        "rot_mat" and "augm_scal". "coord" holds 8 jittered control points per
        batch item; "rot_mat" stacks one ``pt_rotation_matrix`` result per item
        along dim 0. Every tensor entry is moved to CUDA when it is available.
    """
    # Eight base control points: the corners of an outer square (+-0.5) and of
    # an inner square (+-0.2), shared by all batch items before jittering.
    anchors = torch.tensor(
        [
            [-0.5, -0.5],
            [0.5, -0.5],
            [-0.5, 0.5],
            [0.5, 0.5],
            [0.2, -0.2],
            [-0.2, 0.2],
            [0.2, 0.2],
            [-0.2, -0.2],
        ],
        dtype=torch.float32,
    ).unsqueeze(0)
    anchors = ptcompat.torch_tile_nd(anchors, [batch_size, 1, 1])
    point_shape = ptnn.shape_as_list(anchors)

    def _uniform(shape, low, high):
        # Thin wrapper so every draw uses float32, like the rest of the block.
        return ptcompat.torch_random_uniform(shape, low, high, dtype=torch.float32)

    # NOTE: the order of the random draws below is deliberate; keep it fixed so
    # a given RNG seed yields the same parameters.
    jittered = anchors + _uniform(point_shape, -0.2, 0.2)
    displacements = _uniform(point_shape, -tps_scal, tps_scal)
    first_offset = _uniform([batch_size, 1, 2], -off_scal, off_scal)
    second_offset = _uniform([batch_size, 1, 2], -off_scal, off_scal)
    scale_low = scal * (1.0 - scal_var)
    scale_high = scal * (1.0 + scal_var)
    scales = _uniform([batch_size, 2], scale_low, scale_high) * rescal
    angles = _uniform([batch_size, 1], -rot_scal, rot_scal)
    rotations = torch.stack([pt_rotation_matrix(angle) for angle in angles], dim=0)

    sampled = {
        "coord": jittered,
        "vector": displacements,
        "offset": first_offset,
        "offset_2": second_offset,
        "t_scal": scales,
        "rot_mat": rotations,
        "augm_scal": augm_scal,
    }
    if not torch.cuda.is_available():
        return sampled
    return {
        key: (value.cuda() if isinstance(value, torch.Tensor) else value)
        for key, value in sampled.items()
    }
# Example no. 4 (score: 0)
def test_torch_gather_nd():
    """Check ``ptcompat.torch_gather_nd`` against direct slicing on a real image.

    A 5-D ``params`` tensor (bs x nk copies of skimage's astronaut image) is
    gathered with a full-rank index grid that covers the top-left patch x patch
    region of every copy. Each index tuple addresses a single element, so the
    output must equal the corresponding slice of ``params``. The test checks
    values as well as the shape.
    """
    import skimage

    bs = 4
    nk = 16
    patch = 128
    image_t = torch.from_numpy(skimage.data.astronaut())  # [H, W, C]
    channels = image_t.shape[-1]
    # Add two leading singleton dims rather than hard-coding the image size.
    params = ptcompat.torch_tile_nd(
        image_t[None, None], [bs, nk, 1, 1, 1]
    )  # batch of stack of images
    indices = torch.stack(
        torch.meshgrid(
            torch.arange(bs),
            torch.arange(nk),
            torch.arange(patch),
            torch.arange(patch),
            torch.arange(channels),
        ),
        dim=-1,
    )  # last dim holds a full 5-D index into params
    out = ptcompat.torch_gather_nd(params, indices)
    assert out.shape == (bs, nk, patch, patch, channels)
    # A shape-only check would also pass for a gather that returns wrong values.
    assert (out == params[:, :, :patch, :patch, :]).all()