def swap_axes(x: ep.TensorType, dim0: int, dim1: int) -> ep.TensorType:
    """Return ``x`` with axes ``dim0`` and ``dim1`` exchanged.

    Builds the identity permutation over ``x.ndim`` axes, swaps the two
    requested entries and hands the permutation to ``ep.transpose``.

    Both dimensions are checked with ``assert`` to be less than ``x.ndim``.
    These checks disappear under ``python -O``. Negative values are not
    rejected by them; they are used directly as Python list indices.
    """
    rank = x.ndim
    assert dim0 < rank
    assert dim1 < rank
    permutation = list(range(rank))
    # Same order of writes as two separate assignments: position dim0 first,
    # then position dim1.
    permutation[dim0], permutation[dim1] = dim1, dim0
    return ep.transpose(x, tuple(permutation))
def test_transpose_1d(dummy: Tensor) -> None:
    """Transposing a 1-D float32 arange tensor leaves it elementwise equal."""
    vector = ep.arange(dummy, 8).float32()
    transposed = ep.transpose(vector)
    assert (transposed == vector).all()
def test_transpose_axes(dummy: Tensor) -> Tensor:
    """Return a (3, 4, 5) float32 arange tensor permuted with axes (1, 2, 0)."""
    permutation = (1, 2, 0)
    cube = ep.arange(dummy, 60).float32().reshape((3, 4, 5))
    return ep.transpose(cube, axes=permutation)
def test_transpose(dummy: Tensor) -> Tensor:
    """Return ``ep.transpose`` (no explicit axes) of a (2, 4) float32 arange tensor."""
    matrix = ep.arange(dummy, 8).float32().reshape((2, 4))
    result = ep.transpose(matrix)
    return result
def draw_proposals(
    bounds: Bounds,
    originals: ep.Tensor,
    perturbed: ep.Tensor,
    unnormalized_source_directions: ep.Tensor,
    source_directions: ep.Tensor,
    source_norms: ep.Tensor,
    spherical_steps: ep.Tensor,
    source_steps: ep.Tensor,
    mask: ep.Tensor,
    perlin_noise: ep.Tensor,
) -> Tuple[ep.Tensor, ep.Tensor]:
    """Draw candidate inputs using (masked) Perlin noise as the perturbation.

    All per-sample tensors are flattened to ``(N, D)``. The noise is made
    orthogonal to the source directions, masked and rescaled. The result is
    then projected onto the sphere around ``originals``, and finally a step
    is taken back towards ``originals``. Both results are clipped to
    ``bounds``.

    Args:
        bounds: ``(min_, max_)`` pair used to clip the candidates.
        originals: Reference inputs; defines ``shape`` and ``(N, D)``.
        perturbed: Current adversarials. Only its shape is checked here
            (it must equal ``originals.shape``).
        unnormalized_source_directions: Same shape as ``originals``.
            Subtracted from the noise before the spherical projection.
        source_directions: Same shape as ``originals``. Rows are expected to
            be unit-norm (the projection below relies on this).
        source_norms: Shape ``(N,)``.
        spherical_steps: Shape ``(N,)``. Relative size of the orthogonal
            (spherical) step.
        source_steps: Shape ``(N,)``. Relative size of the step towards
            ``originals``.
        mask: After flattening, must broadcast against ``(N, D)``. It is
            multiplied into the noise.
        perlin_noise: After flattening, must be ``(1, D)`` (one pattern
            shared by the whole batch) or ``(N, D)`` (one pattern per sample).

    Returns:
        ``(candidates, spherical_candidates)``, both reshaped to
        ``originals.shape``.
    """
    # Smallest norm used as a divisor. Avoids 0/0 -> NaN for degenerate rows.
    eps = 1e-12

    # remember the actual shape
    shape = originals.shape
    assert perturbed.shape == shape
    assert unnormalized_source_directions.shape == shape
    assert source_directions.shape == shape

    # flatten everything to (batch, size)
    originals = flatten(originals)
    unnormalized_source_directions = flatten(unnormalized_source_directions)
    source_directions = flatten(source_directions)
    N, D = originals.shape

    assert source_norms.shape == (N,)
    assert spherical_steps.shape == (N,)
    assert source_steps.shape == (N,)

    # Perlin noise takes the place of the iid Gaussian draw.
    eta = flatten(perlin_noise)
    assert eta.shape[-1] == D
    assert eta.shape[0] in (1, N)

    # Make the noise orthogonal to the (unit-norm) source directions by
    # removing the row-wise projection <eta_i, s_i> * s_i. For a shared (1, D)
    # pattern this equals eta.T - (S @ eta.T) * S. The row-wise form also
    # handles one pattern per sample, where the matmul form would build an
    # (N, N) matrix and fail to broadcast.
    projections = (source_directions * eta).sum(axis=-1, keepdims=True)
    eta = eta - projections * source_directions

    # Restrict the perturbation to the mask.
    # NOTE(review): masking after the projection means eta is only exactly
    # orthogonal to source_directions when the mask is all ones. Confirm that
    # the Pythagoras step below is still meant to be approximate here.
    eta = eta * flatten(mask)
    assert eta.shape == (N, D)

    # rescale each row to norm spherical_steps * source_norms
    norms = ep.norms.l2(eta, axis=-1)
    assert norms.shape == (N,)
    # An all-zero row (e.g. the mask removes every feature) would otherwise
    # divide by zero and turn the whole candidate into NaN.
    norms = ep.maximum(norms, eps)
    eta = eta * atleast_kd(spherical_steps * source_norms / norms, eta.ndim)

    # project on the sphere using Pythagoras
    distances = atleast_kd((spherical_steps.square() + 1).sqrt(), eta.ndim)
    directions = eta - unnormalized_source_directions
    spherical_candidates = originals + directions / distances

    # clip
    min_, max_ = bounds
    spherical_candidates = spherical_candidates.clip(min_, max_)

    # step towards the original inputs
    new_source_directions = originals - spherical_candidates
    assert new_source_directions.ndim == 2
    new_source_directions_norms = ep.norms.l2(new_source_directions, axis=-1)

    # length if spherical_candidates would be exactly on the sphere
    lengths = source_steps * source_norms

    # length including correction for numerical deviation from sphere
    lengths = lengths + new_source_directions_norms - source_norms

    # make sure the step size is non-negative
    lengths = ep.maximum(lengths, 0)

    # Normalize the length. Guard the divisor for rows where the spherical
    # candidate coincides with the original (zero direction). Their step is
    # then lengths * 0 = 0 instead of NaN.
    lengths = lengths / ep.maximum(new_source_directions_norms, eps)
    lengths = atleast_kd(lengths, new_source_directions.ndim)

    candidates = spherical_candidates + lengths * new_source_directions

    # clip
    candidates = candidates.clip(min_, max_)

    # restore shape
    candidates = candidates.reshape(shape)
    spherical_candidates = spherical_candidates.reshape(shape)
    return candidates, spherical_candidates