def swap_axes(x: ep.TensorType, dim0: int, dim1: int) -> ep.TensorType:
    """Return ``x`` with axes ``dim0`` and ``dim1`` interchanged.

    Both dimensions may be negative, counting from the last axis as in
    ordinary Python indexing (``-1`` is the last axis).

    Args:
        x: Tensor whose axes are swapped.
        dim0: First axis, in ``[-x.ndim, x.ndim)``.
        dim1: Second axis, in ``[-x.ndim, x.ndim)``.

    Returns:
        The tensor produced by ``ep.transpose`` with the swapped permutation.

    Raises:
        AssertionError: If either dimension is out of range.
    """
    ndim = x.ndim
    # Check both ends of the range: previously a dim below -ndim passed the
    # check and only failed later with an unrelated IndexError.
    assert -ndim <= dim0 < ndim, f"dim0={dim0} out of range for {ndim} dims"
    assert -ndim <= dim1 < ndim, f"dim1={dim1} out of range for {ndim} dims"

    # Normalize negative indices so the permutation handed to the backend
    # only ever contains non-negative axis numbers.
    dim0 %= ndim
    dim1 %= ndim

    axes = list(range(ndim))
    axes[dim0], axes[dim1] = axes[dim1], axes[dim0]

    return ep.transpose(x, tuple(axes))
# Example #2
# 0
def test_transpose_1d(dummy: Tensor) -> None:
    # Transposing a one-dimensional tensor must leave every element unchanged.
    vector = ep.arange(dummy, 8).float32()
    transposed = ep.transpose(vector)
    assert (transposed == vector).all()
# Example #3
# 0
def test_transpose_axes(dummy: Tensor) -> Tensor:
    # Build a 3x4x5 tensor holding 0..59 and permute its axes as (1, 2, 0).
    values = ep.arange(dummy, 60).float32()
    cube = values.reshape((3, 4, 5))
    permutation = (1, 2, 0)
    return ep.transpose(cube, axes=permutation)
# Example #4
# 0
def test_transpose(dummy: Tensor) -> Tensor:
    # Transpose a 2x4 matrix using the default axis order of ep.transpose.
    flat = ep.arange(dummy, 8).float32()
    matrix = flat.reshape((2, 4))
    return ep.transpose(matrix)
def draw_proposals(
    bounds: Bounds,
    originals: ep.Tensor,
    perturbed: ep.Tensor,
    unnormalized_source_directions: ep.Tensor,
    source_directions: ep.Tensor,
    source_norms: ep.Tensor,
    spherical_steps: ep.Tensor,
    source_steps: ep.Tensor,
    mask: ep.Tensor,
    perlin_noise: ep.Tensor,
) -> Tuple[ep.Tensor, ep.Tensor]:
    """Draw one candidate per batch element from a Perlin-noise direction.

    The noise is made orthogonal to ``source_directions`` and then masked. It
    is rescaled row-wise to norm ``spherical_steps * source_norms``, combined
    with ``unnormalized_source_directions`` and divided by
    ``sqrt(spherical_steps**2 + 1)``. The result is added to ``originals``
    and clipped to ``bounds``, which gives the spherical candidates. Each
    spherical candidate is then moved towards ``originals`` by a step derived
    from ``source_steps * source_norms`` and clipped again.

    Args:
        bounds: ``(min, max)`` pair used to clip both candidate sets.
        originals: Reference inputs; defines the shape ``shape`` and, after
            flattening, ``(N, D)``.
        perturbed: Must have shape ``shape``; only its shape is checked here.
        unnormalized_source_directions: Shape ``shape``; subtracted from the
            rescaled noise before projecting.
        source_directions: Shape ``shape``; assumed to be normalized per row
            once flattened (the orthogonalization relies on it).
        source_norms: Shape ``(N,)``.
        spherical_steps: Shape ``(N,)``; relative size of the orthogonal step.
        source_steps: Shape ``(N,)``; relative size of the step towards
            ``originals``.
        mask: Multiplied element-wise into the noise. Once flattened it must
            broadcast so that the result is ``(N, D)``.
        perlin_noise: Noise shared across the whole batch. Once flattened it
            must be a single row of ``D`` values.

    Returns:
        ``(candidates, spherical_candidates)``, both reshaped to ``shape``.

    Raises:
        AssertionError: If any of the shape checks fail.
    """
    # remember the actual shape so it can be restored at the end
    shape = originals.shape
    assert perturbed.shape == shape
    assert unnormalized_source_directions.shape == shape
    assert source_directions.shape == shape

    # flatten everything that is used below to (batch, size)
    originals = flatten(originals)
    unnormalized_source_directions = flatten(unnormalized_source_directions)
    source_directions = flatten(source_directions)
    N, D = originals.shape

    assert source_norms.shape == (N,)
    assert spherical_steps.shape == (N,)
    assert source_steps.shape == (N,)

    # The Perlin noise replaces the previous iid Gaussian draw
    # (ep.normal(perturbed, (D, 1))) and, like it, is shared across the whole
    # batch as one (D, 1) column. Any other shape breaks the broadcasting
    # below, so fail early with a clear message.
    eta = ep.transpose(flatten(perlin_noise), (1, 0))
    assert eta.shape == (D, 1), "perlin_noise must flatten to one row of D values"

    # make orthogonal (source_directions are normalized):
    # (1, D) - (N, 1) * (N, D) -> (N, D)
    eta = eta.T - ep.matmul(source_directions, eta) * source_directions
    # NOTE(review): masking after the projection generally leaves eta no
    # longer exactly orthogonal to source_directions, which the Pythagoras
    # step below assumes. Confirm this ordering is intended.
    eta *= flatten(mask)
    assert eta.shape == (N, D)

    # rescale each row to L2 norm spherical_steps * source_norms
    # (a row that is entirely zero, e.g. fully masked, gives a zero norm here
    # and non-finite values afterwards)
    norms = ep.norms.l2(eta, axis=-1)
    assert norms.shape == (N,)
    eta = eta * atleast_kd(spherical_steps * source_norms / norms, eta.ndim)

    # project on the sphere using Pythagoras
    distances = atleast_kd((spherical_steps.square() + 1).sqrt(), eta.ndim)
    directions = eta - unnormalized_source_directions
    spherical_candidates = originals + directions / distances

    # clip
    min_, max_ = bounds
    spherical_candidates = spherical_candidates.clip(min_, max_)

    # step towards the original inputs
    new_source_directions = originals - spherical_candidates
    assert new_source_directions.ndim == 2
    new_source_directions_norms = ep.norms.l2(
        flatten(new_source_directions), axis=-1
    )

    # length if spherical_candidates would be exactly on the sphere
    lengths = source_steps * source_norms

    # length including correction for numerical deviation from sphere
    lengths = lengths + new_source_directions_norms - source_norms

    # make sure the step size is positive
    lengths = ep.maximum(lengths, 0)

    # normalize the length
    lengths = lengths / new_source_directions_norms
    lengths = atleast_kd(lengths, new_source_directions.ndim)

    candidates = spherical_candidates + lengths * new_source_directions

    # clip
    candidates = candidates.clip(min_, max_)

    # restore shape
    candidates = candidates.reshape(shape)
    spherical_candidates = spherical_candidates.reshape(shape)
    return candidates, spherical_candidates