Example #1
def NTK_torch(X, d_max, fix_dep=0):
    K = torch.zeros((d_max, X.shape[0], X.shape[0]))
    S = torch.matmul(X, X.T)
    H = torch.zeros_like(S)
    for dep in range(d_max):
        if fix_dep <= dep:
            H += S
        K[dep] = H
        L = torch.diag(S)
        P = torch.clip(torch.sqrt(torch.outer(L, L)), min=1e-9)
        Sn = torch.clip(S / P, min=-1, max=1)
        S = (Sn * (torch.pi - torch.arccos(Sn)) +
             torch.sqrt(1.0 - Sn * Sn)) * P / 2.0 / torch.pi
        H = H * (torch.pi - torch.arccos(Sn)) / 2.0 / torch.pi
    return K[d_max - 1]
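A minimal usage sketch (an assumption about intended use: X is an (n, d) batch of row-normalized inputs, and the returned tensor is the depth-d_max NTK Gram matrix):

import torch

X = torch.nn.functional.normalize(torch.randn(5, 3), dim=1)  # five unit-norm inputs in R^3
K = NTK_torch(X, d_max=3)  # (5, 5) NTK Gram matrix at depth 3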
def test_interpolate_rotation():
    R1 = euler_angles_to_matrix(torch.tensor([0.0, 0.0, 0.0]),
                                convention="ZYX")
    R2 = euler_angles_to_matrix(torch.tensor([np.pi / 2., 0., 0.]),
                                convention="ZYX")

    R = interp_rotation(R1, R2, interp_factor=0.)
    assert torch.allclose(R, R1), (R, R1)

    R = interp_rotation(R1, R2, interp_factor=1.)
    assert torch.allclose(R, R2), (R, R2)

    R = interp_rotation(R2, R2, interp_factor=0.75)
    assert torch.allclose(R, R2), (R, R2)

    R = interp_rotation(R1, R2, interp_factor=0.5)
    expected = euler_angles_to_matrix(torch.tensor([np.pi / 4., 0., 0.]),
                                      convention="ZYX")
    assert torch.allclose(R, expected), (R, expected)

    R3 = euler_angles_to_matrix(torch.tensor([np.pi, 0., 0.]),
                                convention="ZYX")
    R = interp_rotation(R1, R3, interp_factor=0.5)
    angle_distance = torch.arccos(
        (torch.trace(torch.matmul(R.transpose(1, 0), R3)) - 1) / 2.)
    assert torch.isclose(angle_distance,
                         torch.tensor(np.pi / 2.)), (angle_distance,
                                                     np.pi / 2.)

    R4 = euler_angles_to_matrix(torch.tensor([3. * np.pi / 2., 0., 0.]),
                                convention="ZYX")
    R = interp_rotation(R1, R4, interp_factor=0.5)
    expected = euler_angles_to_matrix(torch.tensor([-np.pi / 4., 0., 0.]),
                                      convention="ZYX")
    assert torch.allclose(R, expected), (R, expected)
Example #3
def slerp(v0, v1, t, DOT_THRESHHOLD=0.9995):
    r"""Spherical interpolation between two tensors
    Arguments:
        v0 (tensor): The first point to be interpolated from.
        v1 (tensor): The second point to be interpolated to.
        t (float): The interpolation ratio between the two points.
        DOT_THRESHHOLD (float): How close the absolute dot product must be to 1
                                (i.e. how nearly parallel the inputs are) before
                                falling back to linear interpolation.
    Returns:
        Tensor of a single step on the interpolated path between v0 and v1
        at ratio t.
    """
    v0_copy = torch.clone(v0)
    v1_copy = torch.clone(v1)

    v0 = v0 / torch.norm(v0)
    v1 = v1 / torch.norm(v1)

    dot = torch.sum(v0 * v1)

    if torch.abs(dot) > DOT_THRESHHOLD:
        return torch.lerp(v0_copy, v1_copy, t)
    
    theta_0 = torch.arccos(dot)
    sin_theta_0 = torch.sin(theta_0)

    theta_t = theta_0 * t
    sin_theta_t = torch.sin(theta_t)

    s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0
    v2 = s0 * v0_copy + s1 * v1_copy
    return v2
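A hedged usage example, assuming v0 and v1 are arbitrary 1-D latent vectors (nearly parallel inputs fall back to torch.lerp as in the branch above):

import torch

v0 = torch.randn(512)
v1 = torch.randn(512)
v_mid = slerp(v0, v1, 0.5)  # point halfway along the great-circle arc from v0 towards v1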
def safe_arccos(x):
    mask = (torch.abs(x) < 1).float()
    x_clip = torch.clamp(x, min=-1, max=1)
    output_arccos = torch.arccos(x_clip)
    output_linear = (1 - x) * math.pi / 2
    output = mask * output_arccos + (1 - mask) * output_linear
    return output
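A small sanity-check sketch of the blended behaviour (arccos inside the open interval (-1, 1), the linear extension (1 - x) * pi / 2 outside it; math is imported here because safe_arccos above references math.pi):

import math
import torch

x = torch.tensor([-1.5, -0.5, 0.5, 1.5])
safe_arccos(x)  # the two outer entries use the linear branch, the two inner ones plain arccos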
Example #5
def sam(noise: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    """
    Measure spectral similarity using spectral angle mapper. Result is in radians.

    Parameters
    ----------
    noise: torch.Tensor
    reference: torch.Tensor

    Returns
    -------
    SAM: torch.Tensor
        spectral similarity in radians. If inputs are matrices, matrices are returned.
        i.e. a MxN is returned if a MxNxC is given
    """
    if len(noise.shape) == 1:
        # case 1: vectors -> easy
        numer = torch.dot(noise, reference)
        denom = torch.linalg.norm(noise) * torch.linalg.norm(reference)
    else:
        # case 2: matrices -> return a MxN if MxNxC is given
        numer = torch.sum(noise * reference, dim=-1)
        denom = torch.linalg.norm(noise, dim=-1) * torch.linalg.norm(reference, dim=-1)
    eps = torch.finfo(denom.dtype).eps
    return torch.arccos(numer / (denom + eps))
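A hypothetical usage on a small (H, W, C) image pair; per the docstring, matrix inputs give back a per-pixel matrix of angles:

import torch

noisy = torch.rand(4, 4, 8)
clean = torch.rand(4, 4, 8)
angles = sam(noisy, clean)  # shape (4, 4), spectral angles in radians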
Example #6
def slerp(val, low, high):
    """Spherical interpolation. val has a range of 0 to 1."""
    if val <= 0:
        return low
    elif val >= 1:
        return high
    elif torch.allclose(low, high):
        return low
    omega = torch.arccos(torch.dot(low/torch.norm(low), high/torch.norm(high)))
    so = torch.sin(omega)
    return torch.sin((1.0-val)*omega) / so * low + torch.sin(val*omega)/so * high
def get_opposite_angles(vs, edge_points, side):
    edges_a = vs[:, edge_points[:, side //
                                2]] - vs[:, edge_points[:, side // 2 + 2]]
    edges_b = vs[:, edge_points[:, 1 - side //
                                2]] - vs[:, edge_points[:, side // 2 + 2]]

    edges_a /= fixed_division(torch.norm(edges_a, dim=-1, keepdim=True),
                              epsilon=0.1)
    edges_b /= fixed_division(torch.norm(edges_b, dim=-1, keepdim=True),
                              epsilon=0.1)
    dot = torch.sum(edges_a * edges_b, dim=-1).clip(-1, 1)
    return torch.arccos(dot)
    def intersection_angles(self, x0, x1) -> torch.Tensor:
        """ Compute all of the up to 2M intersections of the ellipse and the linear constraints """
        g1 = self.A.matmul(x0)
        g2 = self.A.matmul(x1)

        r = torch.sqrt(g1**2 + g2**2)
        phi = 2 * torch.atan(g2 / (r + g1)).squeeze()

        # two solutions per linear constraint, shape of theta: (M, 2)
        arg = -(self.b / r.squeeze(-1)).squeeze()
        theta = torch.zeros((self.A.shape[0], 2),
                            dtype=self.A.dtype,
                            device=self.A.device)

        # write NaNs if there is no intersection
        arg[torch.absolute(arg) > 1] = torch.tensor(float("nan"))
        theta[:, 0] = torch.arccos(arg) + phi
        theta[:, 1] = -torch.arccos(arg) + phi
        theta = theta[torch.isfinite(theta)]

        return torch.sort(theta +
                          (theta < 0.) * 2. * math.pi)[0]  # in [0, 2*pi]
Example #9
def get_sda(xyz, M, triple_mask):
    """
    Input: (xyz) in list type, (M) = get_edge_matrix, (triple_mask) = mask of 0's and 1's
    Output: SDA of angles specified in triple_mask in degrees
    """
    xyz = torch.tensor(xyz)
    edges = torch.matmul(M, xyz)
    edges = F.normalize(edges, dim=1)  # [42 x 3]
    gram = torch.matmul(edges, edges.T)  # [42 x 42]
    gram = torch.clamp(gram, -1., 1.)  # [42 x 42]
    angles = torch.masked_select(gram, triple_mask > 0.)
    angles = torch.rad2deg(torch.arccos(angles))
    return torch.std(angles)
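A hedged, self-contained example with a hypothetical edge matrix M and triple_mask: the three edges meeting at one corner of a right tetrahedron are mutually perpendicular, so every selected angle is 90 degrees and the standard deviation is 0:

import torch
import torch.nn.functional as F

xyz = [[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]  # four vertices
M = torch.tensor([[-1., 1., 0., 0.],   # edge 0 -> 1
                  [-1., 0., 1., 0.],   # edge 0 -> 2
                  [-1., 0., 0., 1.]])  # edge 0 -> 3
triple_mask = torch.triu(torch.ones(3, 3), diagonal=1)  # angles between distinct edge pairs
get_sda(xyz, M, triple_mask)  # tensor(0.) -- all selected angles are 90 degrees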
Example #10
def get_symmetry_values(b, sym_mask, gram):
    """
    Inputs: b = number of batches, sym_mask = symmetry_mask, gram = gram matrix
    Output: list of deviations from symmetry, angles selected in sym_mask 
    """
    symmetry = torch.zeros(b)
    for i in range(b):
        sym_max = torch.max(sym_mask[i, :, :]).int()
        for sm in range(sym_max):
            cos_angles_S = torch.masked_select(gram[i, :, :],
                                               sym_mask[i, :, :] == sm + 1)
            angles_S = torch.rad2deg(torch.arccos(cos_angles_S))
            #print(angles_S[0] , angles_S[1])
            symmetry[i] += torch.abs(angles_S[0] - angles_S[1])
    return symmetry
def torch_angle_between_vectors(vec0, vec1):
    """
    Returns the angle between two vectors
    Parameters:
        vec0 [torch vector] with shape [vector_length, 1]
        vec1 [torch vector] with shape [vector_length, 1]
    Outputs:
        angle [float] angle between the two vectors, in radians
    """
    assert vec0.shape == vec1.shape, (f'vec0.shape = {vec0.shape} must equal vec1.shape={vec1.shape}')
    assert vec0.ndim == 2, (f'vec0 should have ndim==2 with secondary dimension=1, not {vec0.shape}')
    inner_products = torch.matmul(torch_l2_normalize(vec0).T, torch_l2_normalize(vec1))
    inner_products = torch.clip(inner_products, -1.0, 1.0)
    angle = torch.arccos(inner_products)
    return angle
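A usage sketch; torch_l2_normalize is not shown above, so a hypothetical stand-in (division by the L2 norm) is defined here purely for illustration:

import torch

def torch_l2_normalize(v):  # hypothetical helper, assumed to rescale to unit length
    return v / torch.norm(v)

vec_a = torch.randn(16, 1)
vec_b = torch.randn(16, 1)
torch_angle_between_vectors(vec_a, vec_b)  # 1x1 tensor holding the angle in radians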
Example #12
def get_planarity_values(b, face_mask, gram, ps):
    """
    Inputs: b = number of batches, face_mask = face mask, gram = gram matrix, ps = total internal angles
    of object, usually 1080
    Output: List of deviations from planarity for each object 
    """
    planarity = torch.zeros(b)
    P = torch.sum(face_mask > 0, dim=(1, 2))
    P = torch.cumsum(P, dim=0)
    P = torch.hstack((torch.tensor(0), P)).int()
    cos_angles_P = torch.masked_select(gram, face_mask > 0.)
    angles_P = torch.rad2deg(torch.arccos(cos_angles_P))
    for i in range(b):
        planarity[i] = torch.abs(ps[i] - torch.sum(angles_P[P[i]:P[i + 1]]))
    return planarity
Example #13
    def get_vertex_angles(self, vrt):
        bs = vrt.size(0)
        angle_pts = vrt.index_select(1, self.angle_vrt_idx.view(-1))
        angle_pts = angle_pts.reshape(bs, -1, 6, 3, 3)
        a = angle_pts[:, :, :, 0]
        b = angle_pts[:, :, :, 1]
        c = angle_pts[:, :, :, 2]

        ba = a - b
        bc = c - b

        ba_nrm = torch.norm(ba, dim=-1).unsqueeze(-1)
        bc_nrm = torch.norm(bc, dim=-1).unsqueeze(-1)
        one = torch.tensor(1.).to(vrt.device)
        ba_nrm = torch.where(ba_nrm > 0, ba_nrm, one)
        bc_nrm = torch.where(bc_nrm > 0, bc_nrm, one)

        ba_normed = ba / ba_nrm
        bc_normed = bc / bc_nrm
        dot_bac = (ba_normed * bc_normed).sum(dim=-1).unsqueeze(-1)
        # clamp to the valid arccos domain to guard against rounding error
        dot_bac = torch.clamp(dot_bac, -1.0, 1.0)
        angles = torch.arccos(dot_bac) * self.angle_vrt_wt
        return angles
Example #14
def best_k_candidates_for_each_trigger_token(trigger_token_ids, trigger_mask,
                                             trigger_length,
                                             embedding_matrices,
                                             num_candidates):
    '''
    equation 2: (embedding_matrix - trigger embedding)^T @ trigger_grad
    '''
    trigger_grad_shape = [max(trigger_mask.shape[0], 1), trigger_length, -1]
    trigger_grads = tools.EXTRACTED_GRADS[0][trigger_mask].reshape(trigger_grad_shape)\
                        .mean(0).unsqueeze(0)
    clean_trigger_grads = torch.stack(tools.EXTRACTED_CLEAN_GRADS)\
                            [trigger_mask.unsqueeze(0).repeat([len(tools.EXTRACTED_CLEAN_GRADS), 1, 1])]\
                            .reshape([len(tools.EXTRACTED_CLEAN_GRADS)]+trigger_grad_shape)\
                            .mean([0,1]).unsqueeze(0)

    trigger_token_embeds = torch.nn.functional.embedding(
        trigger_token_ids.to(DEVICE),
        embedding_matrices[0]).detach().unsqueeze(1)
    gradient_dot_embedding_matrix = tools.BETA * tools.SIGN * torch.einsum(
        "bij,ikj->bik", (trigger_grads, embedding_matrices[0].unsqueeze(0)))[0]

    trigger_token_embeds = torch.nn.functional.embedding(
        trigger_token_ids.to(DEVICE),
        embedding_matrices[1]).detach().unsqueeze(1)
    clean_gradient_dot_embedding_matrix = tools.LAMBDA * tools.SIGN * torch.einsum(
        "bij,ikj->bik",
        (clean_trigger_grads, embedding_matrices[1].unsqueeze(0)))[0]

    gradient_dot_embedding_matrix += clean_gradient_dot_embedding_matrix

    _, best_k_ids = torch.topk(gradient_dot_embedding_matrix,
                               num_candidates,
                               dim=1)
    theta = torch.arccos(_[0] /
                         (embedding_matrices[1][best_k_ids[0]].norm(dim=1) *
                          trigger_grads.norm()))

    return best_k_ids, _
Example #15
def diffusion_kernel(a, tmpt, dim, use_cons=False):
    cons = (4 * np.pi * tmpt)**(-dim / 2) if use_cons else 1
    return cons * torch.exp(-torch.square(torch.arccos(a)) / tmpt)
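A hedged example of evaluating this kernel, assuming `a` holds cosine similarities in [-1, 1] so that arccos(a) plays the role of an angular distance:

import numpy as np
import torch

cos_sim = torch.tensor([1.0, 0.5, 0.0, -0.5])
diffusion_kernel(cos_sim, tmpt=0.1, dim=3)                  # without the normalizing constant
diffusion_kernel(cos_sim, tmpt=0.1, dim=3, use_cons=True)   # with the (4*pi*t)^(-dim/2) prefactor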
Example #16
def rod2cubo(ro):
    """ Rodrigues-Frank vector to cubochoric vector"""
    """ Step 1: Rodrigues-Frank vector to homochoric vector."""

    f = torch.where(
        torch.isfinite(ro[..., 3:4]), 2.0 * torch.arctan(ro[..., 3:4]) -
        torch.sin(2.0 * torch.arctan(ro[..., 3:4])),
        _precision_check(math.pi, ro.dtype, ro.device))

    ho = torch.where(
        torch.lt(torch.sum(ro[..., 0:3]**2.0, -1, keepdim=True),
                 _precision_check(1.e-8, ro.dtype)).expand(ro[..., 0:3].shape),
        _precision_check([0, 0, 0], ro.dtype, ro.device),
        ro[..., 0:3] * _cbrt(0.75 * f))

    # return ho # here for homochoric vector
    """
        Step 2: Homochoric vector to cubochoric vector.

        References
        ----------
        D. Roşca et al., Modelling and Simulation in Materials Science and Engineering 22:075013, 2014
        https://doi.org/10.1088/0965-0393/22/7/075013

        """
    rs = torch.linalg.norm(ho, dim=-1, keepdim=True)

    xyz3 = torch.gather(ho, -1, _get_tensor_pyramid_order(ho, 'forward'))

    xyz2 = xyz3[..., 0:2] * torch.sqrt(2.0 * rs /
                                       (rs + torch.abs(xyz3[..., 2:3])))
    qxy = torch.sum(xyz2**2, -1, keepdim=True)

    q2 = qxy + torch.amax(torch.abs(xyz2), -1, keepdim=True)**2
    sq2 = torch.sqrt(q2)
    q = (beta / math.sqrt(2.0) / R1) * torch.sqrt(
        q2 * qxy / (q2 - torch.amax(torch.abs(xyz2), -1, keepdim=True) * sq2))
    tt = torch.clip((torch.amin(torch.abs(xyz2), -1, keepdim=True) ** 2 \
                  + torch.amax(torch.abs(xyz2), -1, keepdim=True) * sq2) / math.sqrt(2.0) / qxy, -1.0, 1.0)
    T_inv = torch.where(
        torch.le(torch.abs(xyz2[..., 1:2]), torch.abs(xyz2[..., 0:1])),
        torch.cat(
            (torch.ones_like(tt), torch.arccos(tt) / math.pi * 12), dim=-1),
        torch.cat((torch.arccos(tt) / math.pi * 12, torch.ones_like(tt)),
                  dim=-1)) * q
    T_inv[xyz2 < 0.0] *= -1.0  # warning

    T_inv[(torch.isclose(qxy,
                         _precision_check(0.0, qxy.dtype),
                         rtol=0.0,
                         atol=1.0e-12)).expand(T_inv.shape)] = 0.0
    cu = torch.cat((T_inv, torch.where(torch.lt(xyz3[..., 2:3],0.0), -torch.ones_like(xyz3[..., 2:3]), torch.ones_like(xyz3[..., 2:3])) \
                    * rs / math.sqrt(6.0 / math.pi)), dim=-1) / sc

    cu[torch.isclose(torch.sum(torch.abs(ho), -1),
                     _precision_check(0.0, ho.dtype),
                     rtol=0.0,
                     atol=1.0e-16)] = 0.0
    cu = torch.gather(cu, -1, _get_tensor_pyramid_order(ho, 'backward'))

    return cu
Example #17
    def __call__(
        self,
        evaluation_embeddings: torch.Tensor,
        bias_direction1: torch.Tensor,
        bias_direction2: torch.Tensor,
    ):
        """

        # Parameters

        evaluation_embeddings : `torch.Tensor`
            A tensor of size (batch_size, ..., dim) of embeddings for which to mitigate bias.
        bias_direction1 : `torch.Tensor`
            A unit tensor of size (dim, ) representing a concept subspace (e.g. gender).
        bias_direction2 : `torch.Tensor`
            A unit tensor of size (dim, ) representing another concept subspace from
            which bias_direction1 should be dissociated (e.g. occupation).

        !!! Note
            All tensors are expected to be on the same device.

        # Returns

        bias_mitigated_embeddings : `torch.Tensor`
            A tensor of the same size as evaluation_embeddings.
        """
        # Some sanity checks
        if evaluation_embeddings.ndim < 2:
            raise ConfigurationError("evaluation_embeddings must have at least two dimensions.")
        if bias_direction1.ndim != 1 or bias_direction2.ndim != 1:
            raise ConfigurationError("bias_direction1 and bias_direction2 must be one-dimensional.")
        if evaluation_embeddings.size(-1) != bias_direction1.size(-1) or evaluation_embeddings.size(
            -1
        ) != bias_direction2.size(-1):
            raise ConfigurationError(
                "All embeddings, bias_direction1, and bias_direction2 must have the same dimensionality."
            )
        if bias_direction1.size(-1) < 2:
            raise ConfigurationError(
                "Dimensionality of all embeddings, bias_direction1, and bias_direction2 must \
                be >= 2."
            )

        with torch.set_grad_enabled(self.requires_grad):
            bias_direction1 = bias_direction1 / torch.linalg.norm(bias_direction1)
            bias_direction2 = bias_direction2 / torch.linalg.norm(bias_direction2)

            bias_direction2_orth = self._remove_component(
                bias_direction2.reshape(1, -1), bias_direction1
            )[0]
            bias_direction2_orth = bias_direction2_orth / torch.linalg.norm(bias_direction2_orth)

            # Create rotation matrix as orthonormal basis
            # with v1 and v2'
            init_orth_matrix = torch.eye(
                bias_direction1.size(0),
                device=evaluation_embeddings.device,
                requires_grad=self.requires_grad,
            )
            rotation_matrix = torch.zeros(
                (bias_direction1.size(0), bias_direction1.size(0)),
                device=evaluation_embeddings.device,
                requires_grad=self.requires_grad,
            )
            rotation_matrix = torch.cat(
                [
                    bias_direction1.reshape(1, -1),
                    bias_direction2_orth.reshape(1, -1),
                    rotation_matrix[2:],
                ]
            )
            # Apply Gram-Schmidt
            for i in range(len(rotation_matrix) - 2):
                subspace_proj = torch.sum(
                    self._proj(
                        rotation_matrix[: i + 2].clone(), init_orth_matrix[i], normalize=True
                    ),
                    dim=0,
                )
                rotation_matrix[i + 2] = (init_orth_matrix[i] - subspace_proj) / torch.linalg.norm(
                    init_orth_matrix[i] - subspace_proj
                )

            mask = ~(evaluation_embeddings == 0).all(dim=-1)
            # Transform all evaluation embeddings
            # using orthonormal basis computed above
            rotated_evaluation_embeddings = torch.matmul(
                evaluation_embeddings[mask], rotation_matrix.t()
            )
            # Want to adjust first 2 coordinates and leave d - 2
            # other orthogonal components fixed
            fixed_rotated_evaluation_embeddings = rotated_evaluation_embeddings[..., 2:]
            # Restrict attention to subspace S spanned by bias directions
            # which we hope to make orthogonal
            restricted_rotated_evaluation_embeddings = torch.cat(
                [
                    torch.matmul(rotated_evaluation_embeddings, bias_direction1.reshape(-1, 1)),
                    torch.matmul(
                        rotated_evaluation_embeddings, bias_direction2_orth.reshape(-1, 1)
                    ),
                ],
                dim=-1,
            )

            # Transform and restrict bias directions
            restricted_bias_direction1 = torch.tensor(
                [1.0, 0.0], device=evaluation_embeddings.device, requires_grad=self.requires_grad
            )
            bias_direction_inner_prod = torch.dot(bias_direction1, bias_direction2)
            restricted_bias_direction2 = torch.tensor(
                [
                    bias_direction_inner_prod,
                    torch.sqrt(1 - torch.square(bias_direction_inner_prod)),
                ],
                device=evaluation_embeddings.device,
                requires_grad=self.requires_grad,
            )
            restricted_bias_direction2_orth = torch.tensor(
                [0.0, 1.0], device=evaluation_embeddings.device, requires_grad=self.requires_grad
            )

            restricted_bias_direction_inner_prod = torch.dot(
                restricted_bias_direction1, restricted_bias_direction2
            )
            theta = torch.abs(torch.arccos(restricted_bias_direction_inner_prod))
            theta_proj = np.pi / 2 - theta
            phi = torch.arccos(
                torch.matmul(
                    restricted_rotated_evaluation_embeddings
                    / torch.linalg.norm(
                        restricted_rotated_evaluation_embeddings, dim=-1, keepdim=True
                    ),
                    restricted_bias_direction1,
                )
            )
            d = torch.matmul(
                restricted_rotated_evaluation_embeddings
                / torch.linalg.norm(restricted_rotated_evaluation_embeddings, dim=-1, keepdim=True),
                restricted_bias_direction2_orth,
            )

            # Add noise to avoid DivideByZero
            theta_x = torch.zeros_like(phi, requires_grad=self.requires_grad)
            theta_x = torch.where(
                (d > 0) & (phi < theta_proj),
                theta * (phi / (theta_proj + 1e-10)),
                theta_x,
            )
            theta_x = torch.where(
                (d > 0) & (phi > theta_proj),
                theta * ((np.pi - phi) / (np.pi - theta_proj + 1e-10)),
                theta_x,
            )
            theta_x = torch.where(
                (d < 0) & (phi >= np.pi - theta_proj),
                theta * ((np.pi - phi) / (theta_proj + 1e-10)),
                theta_x,
            )
            theta_x = torch.where(
                (d < 0) & (phi < np.pi - theta_proj),
                theta * (phi / (np.pi - theta_proj + 1e-10)),
                theta_x,
            )

            f_matrix = torch.cat(
                [
                    torch.cos(theta_x).unsqueeze(-1),
                    -torch.sin(theta_x).unsqueeze(-1),
                    torch.sin(theta_x).unsqueeze(-1),
                    torch.cos(theta_x).unsqueeze(-1),
                ],
                dim=-1,
            )
            f_matrix = f_matrix.reshape(f_matrix.size()[:-1] + (2, 2))

            evaluation_embeddings_clone = evaluation_embeddings.clone()
            evaluation_embeddings_clone[mask] = torch.cat(
                [
                    torch.bmm(
                        f_matrix,
                        restricted_rotated_evaluation_embeddings.unsqueeze(-1),
                    ).squeeze(-1),
                    fixed_rotated_evaluation_embeddings,
                ],
                dim=-1,
            )
            return torch.matmul(evaluation_embeddings_clone, rotation_matrix)
Example #18
def dist_geod(X, Y):
    u, s, v = torch.linalg.svd(multiprod_torch(multitransp_torch(X), Y))
    s = torch.clamp(s, -1.0, 1.0)
    s = torch.arccos(s)
    return torch.linalg.norm(s)
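A sketch under explicit assumptions: multiprod_torch and multitransp_torch are not shown above, so they are taken here to be plain (batched) matmul and transpose, and X, Y are matrices with orthonormal columns:

import torch

def multitransp_torch(A):  # assumed behaviour: (batched) transpose
    return A.transpose(-2, -1)

def multiprod_torch(A, B):  # assumed behaviour: (batched) matrix product
    return A @ B

X = torch.linalg.qr(torch.randn(5, 2)).Q  # orthonormal columns
Y = torch.linalg.qr(torch.randn(5, 2)).Q
dist_geod(X, Y)  # geodesic distance between the two subspaces via their principal angles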
Example #19
def quat2cubo(qu):
    """Quaternion to cubochoric vector.
        Quaternion must be in form s, <x,y,z>
        where s is real component, and <x,y,z>
        is imaginary vector component
    """
    """ Step 1: Quaternion to homochoric vector."""

    omega = 2.0 * torch.arccos(torch.clip(qu[..., 0:1], -1.0, 1.0))

    ho = torch.where(torch.lt(torch.abs(omega), 1.0e-12),
                  _precision_check([0, 0, 0], qu.dtype, qu.device),
                  qu[..., 1:4] / torch.linalg.norm(qu[..., 1:4], dim=-1, keepdim=True) \
                  * _cbrt(0.75 * (omega - torch.sin(omega))))

    # return ho # inserted here gives back the homochoric coordinates
    """
        Step 2: Homochoric vector to cubochoric vector.

        References
        ----------
        D. Roşca et al., Modelling and Simulation in Materials Science and Engineering 22:075013, 2014
        https://doi.org/10.1088/0965-0393/22/7/075013

        """
    rs = torch.linalg.norm(ho, dim=-1, keepdim=True)

    xyz3 = torch.gather(ho, -1, _get_tensor_pyramid_order(ho, 'forward'))

    xyz2 = xyz3[..., 0:2] * torch.sqrt(2.0 * rs /
                                       (rs + torch.abs(xyz3[..., 2:3])))
    qxy = torch.sum(xyz2**2, -1, keepdim=True)

    q2 = qxy + torch.amax(torch.abs(xyz2), -1, keepdim=True)**2
    sq2 = torch.sqrt(q2)
    q = (beta / math.sqrt(2.0) / R1) * torch.sqrt(
        q2 * qxy / (q2 - torch.amax(torch.abs(xyz2), -1, keepdim=True) * sq2))
    tt = torch.clip((torch.amin(torch.abs(xyz2), -1, keepdim=True) ** 2 \
                  + torch.amax(torch.abs(xyz2), -1, keepdim=True) * sq2) / math.sqrt(2.0) / qxy, -1.0, 1.0)
    T_inv = torch.where(
        torch.le(torch.abs(xyz2[..., 1:2]), torch.abs(xyz2[..., 0:1])),
        torch.cat(
            (torch.ones_like(tt), torch.arccos(tt) / math.pi * 12), dim=-1),
        torch.cat((torch.arccos(tt) / math.pi * 12, torch.ones_like(tt)),
                  dim=-1)) * q
    T_inv[xyz2 < 0.0] *= -1.0

    T_inv[(torch.isclose(qxy,
                         _precision_check(0.0, qxy.dtype),
                         rtol=0.0,
                         atol=1.0e-12)).expand(T_inv.shape)] = 0.0
    cu = torch.cat((T_inv, torch.where(torch.lt(xyz3[..., 2:3],0.0), -torch.ones_like(xyz3[..., 2:3]), torch.ones_like(xyz3[..., 2:3])) \
                    * rs / math.sqrt(6.0 / math.pi)), dim=-1) / sc

    cu[torch.isclose(torch.sum(torch.abs(ho), -1),
                     _precision_check(0.0, ho.dtype),
                     rtol=0.0,
                     atol=1.0e-16)] = 0.0
    cu = torch.gather(cu, -1, _get_tensor_pyramid_order(ho, 'backward'))

    return cu
Example #20
def angular_torch(values: Tensor, targets: Tensor) -> Tensor:
    return 2 * torch.arccos(cosine_torch(values, targets)) / np.pi
Example #21
def angular_distance(values: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    cosine = cosine_distance(values, targets)
    angular = 2 * torch.arccos(cosine) / np.pi
    return angular
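A hedged stand-in example; cosine_distance is not defined in this snippet, so it is assumed here to return the cosine similarity along the last dimension:

import numpy as np
import torch

def cosine_distance(values, targets):  # hypothetical stand-in for the missing helper
    return torch.nn.functional.cosine_similarity(values, targets, dim=-1)

values = torch.randn(10, 3)
targets = torch.randn(10, 3)
angular_distance(values, targets)  # 0 for parallel, 1 for orthogonal, 2 for opposite vectors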
Example #22
# flake8: noqa
import torch
import math

a = torch.randn(4)
b = torch.randn(4)
t = torch.tensor([-1, -2, 3], dtype=torch.int8)

# abs/absolute
torch.abs(torch.tensor([-1, -2, 3]))
torch.absolute(torch.tensor([-1, -2, 3]))

# acos/arccos
torch.acos(a)
torch.arccos(a)

# acosh/arccosh
torch.acosh(a.uniform_(1, 2))

# add
torch.add(a, 20)
torch.add(a, torch.randn(4, 1), alpha=10)

# addcdiv
torch.addcdiv(torch.randn(1, 3),
              torch.randn(3, 1),
              torch.randn(1, 3),
              value=0.1)

# addcmul
torch.addcmul(torch.randn(1, 3),
Example #23
    def run_inference(self, observations, grad_calculations, do_binding,
                      do_rotation, do_translation, order, reorder):
        [grad_calc_binding, grad_calc_rotation,
         grad_calc_translation] = grad_calculations

        if reorder is not None:
            reorder = reorder.to(self.device)

        at_final_predictions = torch.tensor([]).to(self.device)
        at_final_inputs = torch.tensor([]).to(self.device)

        ###########################  BINDING  #################################
        if do_binding:
            ## Binding matrices
            # Init binding entries
            bm = self.binder.init_binding_matrix_det_()
            # bm = binder.init_binding_matrix_rand_()
            # print(bm)
            dummie_line = torch.ones(1, self.num_observations).to(
                self.device) * self.dummie_init

            for i in range(self.tuning_length + 1):
                matrix = bm.clone().to(self.device)
                if self.nxm:
                    matrix = torch.cat([matrix, dummie_line])
                matrix.requires_grad_()
                self.Bs.append(matrix)

        ###########################  ROTATION  ################################
        if do_rotation:
            if self.rotation_type == 'qrotate':
                ## Rotation quaternion
                rq = self.perspective_taker.init_quaternion()
                # print(rq)

                for i in range(self.tuning_length + 1):
                    quat = rq.clone().to(self.device)
                    quat.requires_grad_()
                    self.Rs.append(quat)

            elif self.rotation_type == 'eulrotate':
                ## Rotation euler angles
                # ra = perspective_taker.init_angles_()
                # ra = torch.Tensor([[309.89], [82.234], [95.765]])
                ra = torch.Tensor([[75.0], [6.0], [128.0]])
                # print(ra)

                for i in range(self.tuning_length + 1):
                    angles = []
                    for j in range(self.num_spatial_dimensions):
                        angle = ra[j].clone().to(self.device)
                        angle.requires_grad_()
                        angles.append(angle)
                    self.Rs.append(angles)

            else:
                print('ERROR: Received unknown rotation type!')
                exit()

        ###########################  TRANSLATION  #############################
        if do_translation:
            tb = self.perspective_taker.init_translation_bias_()
            # print(tb)

            for i in range(self.tuning_length + 1):
                transba = tb.clone().to(self.device)
                transba.requires_grad = True
                self.Cs.append(transba)

        #######################################################################

        ## Core state
        # define scaler
        state_scaler = 0.95

        # init state
        at_h = torch.zeros(1, self.core_model.hidden_size).to(self.device)
        at_c = torch.zeros(1, self.core_model.hidden_size).to(self.device)

        at_h.requires_grad = True
        at_c.requires_grad = True

        init_state = (at_h, at_c)
        state = (init_state[0], init_state[1])

        ############################################################################
        ##########  FORWARD PASS  ##################################################

        for i in range(self.tuning_length):
            o = observations[self.obs_count].to(self.device)
            self.at_inputs = torch.cat((self.at_inputs,
                                        o.reshape(1, self.num_observations,
                                                  self.num_input_dimensions)),
                                       0)
            self.obs_count += 1

            ###########################  BINDING  #################################
            if do_binding:
                bm = self.binder.scale_binding_matrix(self.Bs[i],
                                                      self.scale_mode,
                                                      self.scale_combo,
                                                      self.nxm_enhance,
                                                      self.nxm_last_line_scale)
                if self.nxm:
                    bm = bm[:-1]
                x_B = self.binder.bind(o, bm)
            else:
                x_B = o

            if self.gestalten:
                if self.dir_mag_gest:
                    mag = x_B[:, -1].view(self.num_observations, 1)
                    x_B = x_B[:, :-1]
                x_B = torch.cat([
                    x_B[:, :self.num_spatial_dimensions],
                    x_B[:, self.num_spatial_dimensions:]
                ])
            ###########################  ROTATION  ################################
            if do_rotation:
                if self.rotation_type == 'qrotate':
                    x_R = self.perspective_taker.qrotate(x_B, self.Rs[i])
                else:
                    rotmat = self.perspective_taker.compute_rotation_matrix_(
                        self.Rs[i][0], self.Rs[i][1], self.Rs[i][2])
                    x_R = self.perspective_taker.rotate(x_B, rotmat)
            else:
                x_R = x_B

            if self.gestalten:
                dir = x_R[-self.num_observations:, :]
                x_R = x_R[:-self.num_observations, :]
            ###########################  TRANSLATION  #############################
            if do_translation:
                x_C = self.perspective_taker.translate(x_R, self.Cs[i])
            else:
                x_C = x_R

            if self.gestalten:
                if self.dir_mag_gest:
                    x_C = torch.cat([x_C, dir, mag], dim=1)
                else:
                    x_C = torch.cat([x_C, dir], dim=1)
            #######################################################################

            x = self.preprocessor.convert_data_AT_to_LSTM(x_C)

            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)
            self.at_states.append(state)
            self.at_predictions = torch.cat(
                (self.at_predictions,
                 new_prediction.reshape(1, self.input_per_frame)), 0)

        ############################################################################
        ##########  ACTIVE TUNING ##################################################

        while self.obs_count < self.num_frames:
            # TODO: possibly factor the following out into a function
            o = observations[self.obs_count].to(self.device)
            self.obs_count += 1

            ###########################  BINDING  #################################
            if do_binding:
                bm = self.binder.scale_binding_matrix(self.Bs[-1],
                                                      self.scale_mode,
                                                      self.scale_combo,
                                                      self.nxm_enhance,
                                                      self.nxm_last_line_scale)
                if self.nxm:
                    bm = bm[:-1]
                x_B = self.binder.bind(o, bm)
            else:
                x_B = o

            if self.gestalten:
                if self.dir_mag_gest:
                    mag = x_B[:, -1].view(self.num_observations, 1)
                    x_B = x_B[:, :-1]
                x_B = torch.cat([
                    x_B[:, :self.num_spatial_dimensions],
                    x_B[:, self.num_spatial_dimensions:]
                ])
            ###########################  ROTATION  ################################
            if do_rotation:
                if self.rotation_type == 'qrotate':
                    x_R = self.perspective_taker.qrotate(x_B, self.Rs[-1])
                else:
                    rotmat = self.perspective_taker.compute_rotation_matrix_(
                        self.Rs[-1][0], self.Rs[-1][1], self.Rs[-1][2])
                    x_R = self.perspective_taker.rotate(x_B, rotmat)
            else:
                x_R = x_B

            if self.gestalten:
                dir = x_R[-self.num_observations:, :]
                x_R = x_R[:-self.num_observations, :]
            ###########################  TRANSLATION  #############################
            if do_translation:
                x_C = self.perspective_taker.translate(x_R, self.Cs[-1])
            else:
                x_C = x_R

            if self.gestalten:
                if self.dir_mag_gest:
                    x_C = torch.cat([x_C, dir, mag], dim=1)
                else:
                    x_C = torch.cat([x_C, dir], dim=1)
            #######################################################################

            x = self.preprocessor.convert_data_AT_to_LSTM(x_C)

            ## Generate current prediction
            with torch.no_grad():
                state = self.at_states[-1]
                state = (state[0] * state_scaler, state[1] * state_scaler)
                new_prediction, state = self.core_model(x, state)

            ## For #tuning_cycles
            for cycle in range(self.tuning_cycles):
                print('----------------------------------------------')

                # Get prediction
                p = self.at_predictions[-1]

                # Calculate error
                loss = self.at_loss(p, x[0])

                # Propagate error back through tuning horizon
                loss.backward(retain_graph=True)

                # Update parameters
                with torch.no_grad():
                    # self.at_losses.append(loss.clone().detach())
                    self.at_losses.append(loss.clone().detach().cpu())
                    # self.at_losses.append(loss.clone().detach().cpu().numpy())
                    print(
                        f'frame: {self.obs_count} cycle: {cycle} loss: {loss}')

                    ###########################  BINDING  #################################
                    if do_binding:
                        # Calculate gradients with respect to the entries
                        for i in range(self.tuning_length + 1):
                            self.B_grads[i] = self.Bs[i].grad
                        # print(B_grads[tuning_length])

                        # Calculate overall gradients
                        if grad_calc_binding == 'lastOfTunHor':
                            ### version 1
                            grad_B = self.B_grads[0]
                        elif grad_calc_binding == 'meanOfTunHor':
                            ### version 2 / 3
                            grad_B = torch.mean(torch.stack(self.B_grads),
                                                dim=0)
                        elif grad_calc_binding == 'weightedInTunHor':
                            ### version 4
                            weighted_grads_B = [None
                                                ] * (self.tuning_length + 1)
                            for i in range(self.tuning_length + 1):
                                weighted_grads_B[i] = np.power(
                                    self.grad_bias_binding,
                                    i) * self.B_grads[i]
                            grad_B = torch.mean(torch.stack(weighted_grads_B),
                                                dim=0)
                        # print(f'grad_B: {grad_B}')

                        # Update parameters in time step t-H with saved gradients
                        grad_B = grad_B.to(self.device)
                        # upd_B = self.binder.update_binding_matrix_(
                        upd_B = self.binder.decay_update_binding_matrix_(
                            self.Bs[0], grad_B, self.at_learning_rate_binding,
                            self.bm_momentum)

                        # Compare binding matrix to ideal matrix
                        # NOTE: ideal matrix is always the identity, because then the FBE and determinant can be computed => provide reorder
                        c_bm = self.binder.scale_binding_matrix(
                            upd_B, self.scale_mode, self.scale_combo)
                        if order is not None:
                            c_bm = c_bm.gather(
                                1,
                                reorder.unsqueeze(0).expand(c_bm.shape))

                        if self.nxm:
                            self.oc_grads.append(grad_B[-1])
                            FBE = self.evaluator.FBE_nxm_additional_features(
                                c_bm, self.ideal_binding,
                                self.additional_features)
                            c_bm = self.evaluator.clear_nxm_binding_matrix(
                                c_bm, self.additional_features)

                        mat_loss = self.evaluator.FBE(c_bm, self.ideal_binding)

                        if self.nxm:
                            mat_loss = torch.stack(
                                [mat_loss, FBE, mat_loss + FBE])
                        self.bm_losses.append(mat_loss)
                        print(f'loss of binding matrix (FBE): {mat_loss}')

                        # Compute determinant of binding matrix
                        det = torch.det(c_bm)
                        self.bm_dets.append(det)
                        print(f'determinant of binding matrix: {det}')

                        # Zero out gradients for all parameters in all time steps of tuning horizon
                        for i in range(self.tuning_length + 1):
                            self.Bs[i].requires_grad = False
                            self.Bs[i].grad.data.zero_()

                        # Update all parameters for all time steps
                        for i in range(self.tuning_length + 1):
                            self.Bs[i].data = upd_B.clone().data
                            self.Bs[i].requires_grad = True

                    ###########################  ROTATION  ################################
                    if do_rotation:
                        ## get gradients
                        if self.rotation_type == 'qrotate':
                            for i in range(self.tuning_length + 1):
                                # save grads for all parameters in all time steps of tuning horizon
                                self.R_grads[i] = self.Rs[i].grad
                        else:
                            for i in range(self.tuning_length + 1):
                                # save grads for all parameters in all time steps of tuning horizon
                                grad = []
                                for j in range(self.num_input_dimensions):
                                    grad.append(self.Rs[i][j].grad)
                                self.R_grads[i] = torch.stack(grad)
                        # print(self.R_grads[self.tuning_length])

                        # Calculate overall gradients
                        if grad_calc_rotation == 'lastOfTunHor':
                            ### version 1
                            grad_R = self.R_grads[0]
                        elif grad_calc_rotation == 'meanOfTunHor':
                            ### version 2 / 3
                            grad_R = torch.mean(torch.stack(self.R_grads),
                                                dim=0)
                        elif grad_calc_rotation == 'weightedInTunHor':
                            ### version 4
                            weighted_grads_R = [None
                                                ] * (self.tuning_length + 1)
                            for i in range(self.tuning_length + 1):
                                weighted_grads_R[i] = np.power(
                                    self.grad_bias_rotation,
                                    i) * self.R_grads[i]
                            grad_R = torch.mean(torch.stack(weighted_grads_R),
                                                dim=0)
                        # print(f'grad_R: {grad_R}')

                        grad_R = grad_R.to(self.device)
                        if self.rotation_type == 'qrotate':
                            # Update parameters in time step t-H with saved gradients
                            upd_R = self.perspective_taker.update_quaternion(
                                self.Rs[0], grad_R,
                                self.at_learning_rate_rotation,
                                self.r_momentum)
                            print(f'updated quaternion: {upd_R}')

                            # Compare quaternion values
                            # quat_loss = torch.sum(self.perspective_taker.qmul(self.ideal_quat, upd_R))
                            quat_loss = 2 * torch.arccos(
                                torch.abs(
                                    torch.sum(torch.mul(
                                        self.ideal_quat, upd_R))))
                            quat_loss = torch.rad2deg(quat_loss)
                            print(f'loss of quaternion: {quat_loss}')
                            self.rv_losses.append(quat_loss)
                            # Compute rotation matrix
                            rotmat = self.perspective_taker.quaternion2rotmat(
                                upd_R)

                            # Zero out gradients for all parameters in all time steps of tuning horizon
                            for i in range(self.tuning_length + 1):
                                self.Rs[i].requires_grad = False
                                self.Rs[i].grad.data.zero_()

                            # Update all parameters for all time steps
                            for i in range(self.tuning_length + 1):
                                quat = upd_R.clone()
                                quat.requires_grad_()
                                self.Rs[i] = quat

                        else:
                            # Update parameters in time step t-H with saved gradients
                            upd_R = self.perspective_taker.update_rotation_angles_(
                                self.Rs[0], grad_R,
                                self.at_learning_rate_rotation)
                            print(f'updated angles: {upd_R}')

                            # Save rotation angles
                            rotang = torch.stack(upd_R)
                            # angles:
                            ang_diff = rotang - self.ideal_angle
                            ang_loss = 2 - (
                                torch.cos(torch.deg2rad(ang_diff)) + 1)
                            print(
                                f'loss of rotation angles: \n  {ang_loss}, \n  with norm {torch.norm(ang_loss)}'
                            )
                            self.rv_losses.append(torch.norm(ang_loss))
                            # Compute rotation matrix
                            rotmat = self.perspective_taker.compute_rotation_matrix_(
                                upd_R[0], upd_R[1], upd_R[2])[0]

                            # Zero out gradients for all parameters in all time steps of tuning horizon
                            for i in range(self.tuning_length + 1):
                                for j in range(self.num_input_dimensions):
                                    self.Rs[i][j].requires_grad = False
                                    self.Rs[i][j].grad.data.zero_()

                            # Update all parameters for all time steps
                            for i in range(self.tuning_length + 1):
                                angles = []
                                for j in range(3):
                                    angle = upd_R[j].clone()
                                    angle.requires_grad_()
                                    angles.append(angle)
                                self.Rs[i] = angles

                        # Calculate and save rotation losses
                        # matrix:
                        # mat_loss = self.mse(
                        #     (torch.mm(self.ideal_rotation, torch.transpose(rotmat, 0, 1))),
                        #     self.identity_matrix
                        # )
                        dif_R = torch.mm(self.ideal_rotation,
                                         torch.transpose(rotmat, 0, 1))
                        mat_loss = torch.arccos(0.5 * (torch.trace(dif_R) - 1))
                        mat_loss = torch.rad2deg(mat_loss)
                        print(f'loss of rotation matrix: {mat_loss}')
                        self.rm_losses.append(mat_loss)

                    ###########################  TRANSLATION  #############################
                    if do_translation:
                        ## Get gradients
                        for i in range(self.tuning_length + 1):
                            # save grads for all parameters in all time steps of tuning horizon
                            self.C_grads[i] = self.Cs[i].grad

                        # print(self.C_grads[self.tuning_length])

                        # Calculate overall gradients
                        if grad_calc_translation == 'lastOfTunHor':
                            ### version 1
                            grad_C = self.C_grads[0]
                        elif grad_calc_translation == 'meanOfTunHor':
                            ### version 2 / 3
                            grad_C = torch.mean(torch.stack(self.C_grads),
                                                dim=0)
                        elif grad_calc_translation == 'weightedInTunHor':
                            ### version 4
                            weighted_grads_C = [None
                                                ] * (self.tuning_length + 1)
                            for i in range(self.tuning_length + 1):
                                weighted_grads_C[i] = np.power(
                                    self.grad_bias_translation,
                                    i) * self.C_grads[i]
                            grad_C = torch.mean(torch.stack(weighted_grads_C),
                                                dim=0)

                        # Update parameters in time step t-H with saved gradients
                        grad_C = grad_C.to(self.device)
                        upd_C = self.perspective_taker.update_translation_bias_(
                            self.Cs[0], grad_C,
                            self.at_learning_rate_translation, self.c_momentum)

                        # Compare translation bias to ideal bias
                        trans_loss = self.mse(self.ideal_translation, upd_C)
                        self.c_losses.append(trans_loss)
                        print(f'loss of translation bias (MSE): {trans_loss}')

                        # Zero out gradients for all parameters in all time steps of tuning horizon
                        for i in range(self.tuning_length + 1):
                            self.Cs[i].requires_grad = False
                            self.Cs[i].grad.data.zero_()

                        # Update all parameters for all time steps
                        for i in range(self.tuning_length + 1):
                            translation = upd_C.clone()
                            translation.requires_grad_()
                            self.Cs[i] = translation

                    #######################################################################

                    # Initial state
                    g_h = at_h.grad.to(self.device)
                    g_c = at_c.grad.to(self.device)

                    upd_h = init_state[0] - self.at_learning_rate_state * g_h
                    upd_c = init_state[1] - self.at_learning_rate_state * g_c

                    at_h.data = upd_h.clone().detach().requires_grad_()
                    at_c.data = upd_c.clone().detach().requires_grad_()

                    at_h.grad.data.zero_()
                    at_c.grad.data.zero_()

                    # print(f'updated init_state: {init_state}')

                # forward pass from t-H to t with new parameters
                init_state = (at_h, at_c)
                state = (init_state[0], init_state[1])
                self.at_predictions = torch.tensor([]).to(self.device)
                for i in range(self.tuning_length):

                    ###########################  BINDING  #################################
                    if do_binding:
                        bm = self.binder.scale_binding_matrix(
                            self.Bs[i], self.scale_mode, self.scale_combo,
                            self.nxm_enhance, self.nxm_last_line_scale)
                        if self.nxm:
                            bm = bm[:-1]
                        x_B = self.binder.bind(self.at_inputs[i], bm)
                    else:
                        x_B = self.at_inputs[i]

                    if self.gestalten:
                        if self.dir_mag_gest:
                            mag = x_B[:, -1].view(self.num_observations, 1)
                            x_B = x_B[:, :-1]
                        x_B = torch.cat([
                            x_B[:, :self.num_spatial_dimensions],
                            x_B[:, self.num_spatial_dimensions:]
                        ])
                    ###########################  ROTATION  ################################
                    if do_rotation:
                        if self.rotation_type == 'qrotate':
                            x_R = self.perspective_taker.qrotate(
                                x_B, self.Rs[i])
                        else:
                            rotmat = self.perspective_taker.compute_rotation_matrix_(
                                self.Rs[i][0], self.Rs[i][1], self.Rs[i][2])
                            x_R = self.perspective_taker.rotate(x_B, rotmat)
                    else:
                        x_R = x_B

                    if self.gestalten:
                        dir = x_R[-self.num_observations:, :]
                        x_R = x_R[:-self.num_observations, :]
                    ###########################  TRANSLATION  #############################
                    if do_translation:
                        x_C = self.perspective_taker.translate(x_R, self.Cs[i])
                    else:
                        x_C = x_R

                    if self.gestalten:
                        if self.dir_mag_gest:
                            x_C = torch.cat([x_C, dir, mag], dim=1)
                        else:
                            x_C = torch.cat([x_C, dir], dim=1)
                    #######################################################################

                    x = self.preprocessor.convert_data_AT_to_LSTM(x_C)

                    state = (state[0] * state_scaler, state[1] * state_scaler)
                    upd_prediction, state = self.core_model(x, state)
                    self.at_predictions = torch.cat(
                        (self.at_predictions,
                         upd_prediction.reshape(1, self.input_per_frame)), 0)

                    # for last tuning cycle update initial state to track gradients
                    if cycle == (self.tuning_cycles - 1) and i == 0:
                        with torch.no_grad():
                            final_prediction = self.at_predictions[0].clone(
                            ).detach().to(self.device)
                            final_input = x.clone().detach().to(self.device)

                        at_h = state[0].clone().detach().requires_grad_().to(
                            self.device)
                        at_c = state[1].clone().detach().requires_grad_().to(
                            self.device)
                        init_state = (at_h, at_c)
                        state = (init_state[0], init_state[1])

                    self.at_states[i] = state

                # Update current input
                ###########################  BINDING  #################################
                if do_binding:
                    bm = self.binder.scale_binding_matrix(
                        self.Bs[-1], self.scale_mode, self.scale_combo,
                        self.nxm_enhance, self.nxm_last_line_scale)
                    if self.nxm:
                        bm = bm[:-1]
                    x_B = self.binder.bind(o, bm)
                else:
                    x_B = o

                if self.gestalten:
                    if self.dir_mag_gest:
                        mag = x_B[:, -1].view(self.num_observations, 1)
                        x_B = x_B[:, :-1]
                    x_B = torch.cat([
                        x_B[:, :self.num_spatial_dimensions],
                        x_B[:, self.num_spatial_dimensions:]
                    ])
                ###########################  ROTATION  ################################
                if do_rotation:
                    if self.rotation_type == 'qrotate':
                        x_R = self.perspective_taker.qrotate(x_B, self.Rs[-1])
                    else:
                        rotmat = self.perspective_taker.compute_rotation_matrix_(
                            self.Rs[-1][0], self.Rs[-1][1], self.Rs[-1][2])
                        x_R = self.perspective_taker.rotate(x_B, rotmat)
                else:
                    x_R = x_B

                if self.gestalten:
                    dir = x_R[-self.num_observations:, :]
                    x_R = x_R[:-self.num_observations, :]
                ###########################  TRANSLATION  #############################
                if do_translation:
                    x_C = self.perspective_taker.translate(x_R, self.Cs[-1])
                else:
                    x_C = x_R

                if self.gestalten:
                    if self.dir_mag_gest:
                        x_C = torch.cat([x_C, dir, mag], dim=1)
                    else:
                        x_C = torch.cat([x_C, dir], dim=1)
                #######################################################################

                x = self.preprocessor.convert_data_AT_to_LSTM(x_C)

            # END tuning cycle

            ## Generate updated prediction
            state = self.at_states[-1]
            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)

            ## Reorganize storage variables
            # observations
            self.at_inputs = torch.cat((self.at_inputs[1:],
                                        o.reshape(1, self.num_observations,
                                                  self.num_input_dimensions)),
                                       0)

            # predictions
            at_final_inputs = torch.cat(
                (at_final_inputs, final_input.reshape(
                    1, self.input_per_frame)), 0)
            at_final_predictions = torch.cat(
                (at_final_predictions,
                 final_prediction.reshape(1, self.input_per_frame)), 0)
            self.at_predictions = torch.cat(
                (self.at_predictions[1:],
                 new_prediction.reshape(1, self.input_per_frame)), 0)

        # END active tuning

        # store rest of predictions in at_final_predictions
        for i in range(self.tuning_length):
            at_final_predictions = torch.cat(
                (at_final_predictions, self.at_predictions[i].reshape(
                    1, self.input_per_frame)), 0)

            inp_i = self.at_inputs[i]
            if do_binding:
                x_B = self.binder.bind(inp_i, bm)
            else:
                x_B = inp_i

            if self.gestalten:
                if self.dir_mag_gest:
                    mag = x_B[:, -1].view(self.num_observations, 1)
                    x_B = x_B[:, :-1]
                x_B = torch.cat([
                    x_B[:, :self.num_spatial_dimensions],
                    x_B[:, self.num_spatial_dimensions:]
                ])
            ###########################  ROTATION  ################################
            if do_rotation:
                if self.rotation_type == 'qrotate':
                    x_R = self.perspective_taker.qrotate(x_B, self.Rs[-1])
                else:
                    x_R = self.perspective_taker.rotate(x_B, rotmat)
            else:
                x_R = x_B

            if self.gestalten:
                dir = x_R[-self.num_observations:, :]
                x_R = x_R[:-self.num_observations, :]
            ###########################  TRANSLATION  #############################
            if do_translation:
                x_C = self.perspective_taker.translate(x_R, self.Cs[-1])
            else:
                x_C = x_R

            if self.gestalten:
                if self.dir_mag_gest:
                    x_i = torch.cat([x_C, dir, mag], dim=1)
                else:
                    x_i = torch.cat([x_C, dir], dim=1)
            else:
                x_i = x_C
            #######################################################################

            at_final_inputs = torch.cat(
                (at_final_inputs, x_i.reshape(1, self.input_per_frame)), 0)

        ###########################  BINDING  #################################
        # get final binding matrix
        if do_binding:
            final_binding_matrix = self.binder.scale_binding_matrix(
                self.Bs[-1].clone().detach(), self.scale_mode,
                self.scale_combo)
            print(f'final binding matrix: {final_binding_matrix}')
            final_binding_entries = self.Bs[-1].clone().detach()
            print(f'final binding entries: {final_binding_entries}')

        else:
            final_binding_entries, final_binding_matrix = None, None

        ###########################  ROTATION  ################################
        # get final rotation matrix
        if do_rotation:
            if self.rotation_type == 'qrotate':
                final_rotation_values = self.Rs[0].clone().detach()
                # get final quaternion
                print(f'final quaternion: {final_rotation_values}')
                final_rotation_matrix = self.perspective_taker.quaternion2rotmat(
                    final_rotation_values)
            else:
                final_rotation_values = [
                    self.Rs[0][i].clone().detach()
                    for i in range(self.num_input_dimensions)
                ]
                print(f'final euler angles: {final_rotation_values}')
                final_rotation_matrix = self.perspective_taker.compute_rotation_matrix_(
                    final_rotation_values[0], final_rotation_values[1],
                    final_rotation_values[2])

            print(f'final rotation matrix: \n{final_rotation_matrix}')

        else:
            final_rotation_matrix, final_rotation_values = None, None

        ###########################  TRANSLATION  #############################
        # get final translation bias
        if do_translation:
            final_translation_values = self.Cs[0].clone().detach()
            print(f'final translation bias: {final_translation_values}')

        else:
            final_translation_values = None

        #######################################################################

        return [
            at_final_inputs, at_final_predictions, final_binding_matrix,
            final_binding_entries, final_rotation_values,
            final_rotation_matrix, final_translation_values
        ]

    def run_inference(self, observations, grad_calculation):

        at_final_predictions = torch.tensor([]).to(self.device)
        at_final_inputs = torch.tensor([]).to(self.device)

        if self.rotation_type == 'qrotate':
            ## Rotation quaternion
            rq = self.perspective_taker.init_quaternion()
            # print(rq)

            for i in range(self.tuning_length + 1):
                quat = rq.clone().to(self.device)
                quat.requires_grad_()
                self.Rs.append(quat)

        elif self.rotation_type == 'eulrotate':
            ## Rotation euler angles
            # ra = perspective_taker.init_angles_()
            # ra = torch.Tensor([[309.89], [82.234], [95.765]])
            ra = torch.Tensor([[75.0], [6.0], [128.0]])
            print(ra)

            for i in range(self.tuning_length + 1):
                angles = []
                for j in range(self.num_input_dimensions):
                    angle = ra[j].clone()
                    angle.requires_grad_()
                    angles.append(angle)
                self.Rs.append(angles)

        else:
            print('ERROR: Received unknown rotation type!')
            exit()

        ## Core state
        # define scaler: the hidden and cell state are decayed by this factor before each LSTM step
        state_scaler = 0.95

        # init state
        at_h = torch.zeros(1, self.core_model.hidden_size).to(self.device)
        at_c = torch.zeros(1, self.core_model.hidden_size).to(self.device)

        at_h.requires_grad = True
        at_c.requires_grad = True

        init_state = (at_h, at_c)
        state = (init_state[0], init_state[1])

        ############################################################################
        ##########  FORWARD PASS  ##################################################

        for i in range(self.tuning_length):
            o = observations[self.obs_count].to(self.device)
            self.at_inputs = torch.cat((self.at_inputs,
                                        o.reshape(1, self.num_input_features,
                                                  self.num_input_dimensions)),
                                       0)
            self.obs_count += 1

            if self.rotation_type == 'qrotate':
                x_R = self.perspective_taker.qrotate(o, self.Rs[i])
            else:
                rotmat = self.perspective_taker.compute_rotation_matrix_(
                    self.Rs[i][0], self.Rs[i][1], self.Rs[i][2])
                x_R = self.perspective_taker.rotate(o, rotmat)

            x = self.preprocessor.convert_data_AT_to_LSTM(x_R)

            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)
            self.at_states.append(state)
            self.at_predictions = torch.cat(
                (self.at_predictions,
                 new_prediction.reshape(1, self.input_per_frame)), 0)

        ############################################################################
        ##########  ACTIVE TUNING ##################################################

        while self.obs_count < self.num_frames:
            # TODO: possibly move the following into a separate function
            o = observations[self.obs_count].to(self.device)
            self.obs_count += 1

            if self.rotation_type == 'qrotate':
                x_R = self.perspective_taker.qrotate(o, self.Rs[-1])

            else:
                rotmat = self.perspective_taker.compute_rotation_matrix_(
                    self.Rs[-1][0], self.Rs[-1][1], self.Rs[-1][2])
                x_R = self.perspective_taker.rotate(o, rotmat)

            x = self.preprocessor.convert_data_AT_to_LSTM(x_R)

            ## Generate current prediction
            with torch.no_grad():
                state = self.at_states[-1]
                state = (state[0] * state_scaler, state[1] * state_scaler)
                new_prediction, state = self.core_model(x, state)

            ## For #tuning_cycles
            for cycle in range(self.tuning_cycles):
                print('----------------------------------------------')

                # Get prediction
                p = self.at_predictions[-1]

                # Calculate error
                loss = self.at_loss(p, x[0])

                # Propagate error back through tuning horizon
                loss.backward(retain_graph=True)

                self.at_losses.append(loss.clone().detach().cpu().numpy())
                print(f'frame: {self.obs_count} cycle: {cycle} loss: {loss}')

                # Update parameters
                with torch.no_grad():

                    ## get gradients
                    if self.rotation_type == 'qrotate':
                        for i in range(self.tuning_length + 1):
                            # save grads for all parameters in all time steps of tuning horizon
                            self.R_grads[i] = self.Rs[i].grad
                    else:
                        for i in range(self.tuning_length + 1):
                            # save grads for all parameters in all time steps of tuning horizon
                            grad = []
                            for j in range(self.num_input_dimensions):
                                grad.append(self.Rs[i][j].grad)
                            self.R_grads[i] = torch.stack(grad)

                    # print(self.R_grads[self.tuning_length])

                    # Calculate overall gradients
                    if grad_calculation == 'lastOfTunHor':
                        ### version 1
                        grad_R = self.R_grads[0]
                    elif grad_calculation == 'meanOfTunHor':
                        ### version 2 / 3
                        grad_R = torch.mean(torch.stack(self.R_grads), dim=0)
                    elif grad_calculation == 'weightedInTunHor':
                        ### version 4
                        weighted_grads_R = [None] * (self.tuning_length + 1)
                        for i in range(self.tuning_length + 1):
                            weighted_grads_R[i] = np.power(self.grad_bias,
                                                           i) * self.R_grads[i]
                        grad_R = torch.mean(torch.stack(weighted_grads_R),
                                            dim=0)
                    else:
                        raise ValueError(
                            f'Unknown grad_calculation: {grad_calculation}')

                    # print(f'grad_R: {grad_R}')

                    grad_R = grad_R.to(self.device)
                    if self.rotation_type == 'qrotate':
                        # Update parameters in time step t-H with saved gradients
                        upd_R = self.perspective_taker.update_quaternion(
                            self.Rs[0], grad_R, self.at_learning_rate,
                            self.r_momentum)
                        print(f'updated quaternion: {upd_R}')

                        # Compare quaternion values
                        # quat_loss = torch.sum(self.perspective_taker.qmul(self.ideal_quat, upd_R))
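                        # angular distance between unit quaternions:
                        # 2 * arccos(|<ideal_quat, upd_R>|), converted to degrees below;
                        # clamping the dot product to [-1, 1] would guard against rounding error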
                        quat_loss = 2 * torch.arccos(
                            torch.abs(
                                torch.sum(torch.mul(self.ideal_quat, upd_R))))
                        quat_loss = torch.rad2deg(quat_loss)
                        print(f'loss of quaternion: {quat_loss}')
                        self.rv_losses.append(quat_loss)

                        # Compare quaternion angles
                        ang = torch.rad2deg(
                            self.perspective_taker.qeuler(upd_R, 'zyx'))
                        ang_diff = ang - self.ideal_angle
                        ang_loss = 2 - (torch.cos(torch.deg2rad(ang_diff)) + 1)
                        print(
                            f'loss of quaternion angles: {ang_loss} \nwith norm: {torch.norm(ang_loss)}'
                        )
                        self.ra_losses.append(torch.norm(ang_loss))

                        # Compute rotation matrix
                        rotmat = self.perspective_taker.quaternion2rotmat(
                            upd_R)

                        # Zero out gradients for all parameters in all time steps of tuning horizon
                        for i in range(self.tuning_length + 1):
                            self.Rs[i].requires_grad = False
                            self.Rs[i].grad.data.zero_()

                        # Update all parameters for all time steps
                        for i in range(self.tuning_length + 1):
                            quat = upd_R.clone()
                            quat.requires_grad_()
                            self.Rs[i] = quat

                    else:
                        # Update parameters in time step t-H with saved gradients
                        upd_R = self.perspective_taker.update_rotation_angles_(
                            self.Rs[0], grad_R, self.at_learning_rate)
                        print(f'updated angles: {upd_R}')

                        # Save rotation angles
                        rotang = torch.stack(upd_R)
                        # angles:
                        ang_diff = rotang - self.ideal_angle
                        ang_loss = 2 - (torch.cos(torch.deg2rad(ang_diff)) + 1)
                        print(
                            f'loss of rotation angles: \n  {ang_loss}, \n  with norm {torch.norm(ang_loss)}'
                        )
                        self.rv_losses.append(torch.norm(ang_loss))
                        # Compute rotation matrix
                        rotmat = self.perspective_taker.compute_rotation_matrix_(
                            upd_R[0], upd_R[1], upd_R[2])[0]

                        # Zero out gradients for all parameters in all time steps of tuning horizon
                        for i in range(self.tuning_length + 1):
                            for j in range(self.num_input_dimensions):
                                self.Rs[i][j].requires_grad = False
                                self.Rs[i][j].grad.data.zero_()

                        # print(Rs[0])
                        # Update all parameters for all time steps
                        for i in range(self.tuning_length + 1):
                            angles = []
                            for j in range(3):
                                angle = upd_R[j].clone()
                                angle.requires_grad_()
                                angles.append(angle)
                            self.Rs[i] = angles
                        # print(Rs[0])

                    # Calculate and save rotation losses
                    # matrix:
                    # mat_loss = self.mse(
                    #     (torch.mm(self.ideal_rotation, torch.transpose(rotmat, 0, 1))),
                    #     self.identity_matrix
                    # )
                    dif_R = torch.mm(self.ideal_rotation,
                                     torch.transpose(rotmat, 0, 1))
                    # geodesic angle: arccos((trace(R_ideal R^T) - 1) / 2);
                    # the argument is clamped to [-1, 1] to guard against rounding error
                    mat_loss = torch.arccos(
                        torch.clamp(0.5 * (torch.trace(dif_R) - 1), -1.0, 1.0))
                    mat_loss = torch.rad2deg(mat_loss)

                    print(f'loss of rotation matrix: {mat_loss}')
                    self.rm_losses.append(mat_loss)

                    # print(Rs[0])

                    # Initial state
                    g_h = at_h.grad.to(self.device)
                    g_c = at_c.grad.to(self.device)

                    upd_h = init_state[0] - self.at_learning_rate_state * g_h
                    upd_c = init_state[1] - self.at_learning_rate_state * g_c

                    at_h.data = upd_h.clone().detach().requires_grad_()
                    at_c.data = upd_c.clone().detach().requires_grad_()

                    at_h.grad.data.zero_()
                    at_c.grad.data.zero_()

                    # state_optimizer.step()
                    # print(f'updated init_state: {init_state}')

                ## REORGANIZE FOR MULTIPLE CYCLES!!!!!!!!!!!!!

                # forward pass from t-H to t with new parameters
                # Update init state???
                init_state = (at_h, at_c)
                state = (init_state[0], init_state[1])
                self.at_predictions = torch.tensor([]).to(self.device)
                for i in range(self.tuning_length):

                    if self.rotation_type == 'qrotate':
                        x_R = self.perspective_taker.qrotate(
                            self.at_inputs[i], self.Rs[i])
                    else:
                        rotmat = self.perspective_taker.compute_rotation_matrix_(
                            self.Rs[i][0], self.Rs[i][1], self.Rs[i][2])
                        x_R = self.perspective_taker.rotate(
                            self.at_inputs[i], rotmat)

                    x = self.preprocessor.convert_data_AT_to_LSTM(x_R)

                    state = (state[0] * state_scaler, state[1] * state_scaler)
                    upd_prediction, state = self.core_model(x, state)
                    self.at_predictions = torch.cat(
                        (self.at_predictions,
                         upd_prediction.reshape(1, self.input_per_frame)), 0)

                    # in the last tuning cycle, update the initial state so that its gradients are tracked
                    if cycle == (self.tuning_cycles - 1) and i == 0:
                        with torch.no_grad():
                            final_prediction = self.at_predictions[0].clone(
                            ).detach().to(self.device)
                            final_input = x.clone().detach().to(self.device)

                        at_h = state[0].clone().detach().requires_grad_()
                        at_c = state[1].clone().detach().requires_grad_()
                        init_state = (at_h, at_c)
                        state = (init_state[0], init_state[1])

                    self.at_states[i] = state

                # Update current rotation
                if self.rotation_type == 'qrotate':
                    x_R = self.perspective_taker.qrotate(o, self.Rs[-1])
                else:
                    rotmat = self.perspective_taker.compute_rotation_matrix_(
                        self.Rs[-1][0], self.Rs[-1][1], self.Rs[-1][2])
                    x_R = self.perspective_taker.rotate(o, rotmat)

                x = self.preprocessor.convert_data_AT_to_LSTM(x_R)

            # END tuning cycle

            ## Generate updated prediction
            state = self.at_states[-1]
            state = (state[0] * state_scaler, state[1] * state_scaler)
            new_prediction, state = self.core_model(x, state)

            ## Reorganize storage variables
            # observations
            at_final_inputs = torch.cat(
                (at_final_inputs, final_input.reshape(
                    1, self.input_per_frame)), 0)
            self.at_inputs = torch.cat((self.at_inputs[1:],
                                        o.reshape(1, self.num_input_features,
                                                  self.num_input_dimensions)),
                                       0)

            # predictions
            at_final_predictions = torch.cat(
                (at_final_predictions,
                 final_prediction.reshape(1, self.input_per_frame)), 0)
            self.at_predictions = torch.cat(
                (self.at_predictions[1:],
                 new_prediction.reshape(1, self.input_per_frame)), 0)

        # END active tuning

        # store rest of predictions in at_final_predictions
        for i in range(self.tuning_length):
            at_final_predictions = torch.cat(
                (at_final_predictions, self.at_predictions[i].reshape(
                    1, self.input_per_frame)), 0)
            if self.rotation_type == 'qrotate':
                x_i = self.perspective_taker.qrotate(self.at_inputs[i],
                                                     self.Rs[-1])
            else:
                x_i = self.perspective_taker.rotate(self.at_inputs[i], rotmat)
            at_final_inputs = torch.cat(
                (at_final_inputs, x_i.reshape(1, self.input_per_frame)), 0)

        # get final rotation matrix
        if self.rotation_type == 'qrotate':
            final_rotation_values = self.Rs[0].clone().detach()
            # get final quaternion
            print(f'final quaternion: {final_rotation_values}')
            final_rotation_matrix = self.perspective_taker.quaternion2rotmat(
                final_rotation_values)

        else:
            final_rotation_values = [
                self.Rs[0][i].clone().detach()
                for i in range(self.num_input_dimensions)
            ]
            print(f'final euler angles: {final_rotation_values}')
            final_rotation_matrix = self.perspective_taker.compute_rotation_matrix_(
                final_rotation_values[0], final_rotation_values[1],
                final_rotation_values[2])

        print(f'final rotation matrix: \n{final_rotation_matrix}')

        return at_final_inputs, at_final_predictions, final_rotation_values, final_rotation_matrix
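The tuning loop above tracks two rotation-error measures built on torch.arccos: the angle between the updated and ideal quaternion, 2 * arccos(|<q_est, q_ideal>|), and the geodesic angle between rotation matrices, arccos((trace(R_ideal R_est^T) - 1) / 2). Below is a minimal standalone sketch of both measures; the helper names are ours, not taken from the snippet, and the arccos arguments are clamped to [-1, 1], assuming unit quaternions and proper rotation matrices as inputs.

import torch

def quaternion_angle_deg(q_est, q_ideal):
    # angular distance between two unit quaternions, in degrees
    dot = torch.abs(torch.sum(q_est * q_ideal))
    return torch.rad2deg(2.0 * torch.arccos(torch.clamp(dot, -1.0, 1.0)))

def rotation_matrix_angle_deg(R_est, R_ideal):
    # geodesic distance between two 3x3 rotation matrices, in degrees
    cos_theta = 0.5 * (torch.trace(R_ideal @ R_est.T) - 1.0)
    return torch.rad2deg(torch.arccos(torch.clamp(cos_theta, -1.0, 1.0)))

For identical inputs both helpers return 0, e.g. rotation_matrix_angle_deg(torch.eye(3), torch.eye(3)).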
Beispiel #25
0
def cubo2rod(cu):
    """Cubochoric vector to Rodrigues-Frank vector.
        Quaternion returned in form s, <x,y,z>
        where s is real component, and <x,y,z>
        is imaginary vector component
    """
    """
        Step 1: Cubochoric vector to homochoric vector.

        References
        ----------
        D. Roşca et al., Modelling and Simulation in Materials Science and Engineering 22:075013, 2014
        https://doi.org/10.1088/0965-0393/22/7/075013

        """

    # get pyramid and scale by grid parameter ratio
    XYZ = torch.gather(cu, -1, _get_tensor_pyramid_order(cu, 'forward')) * sc
    order = torch.le(torch.abs(XYZ[..., 1:2]), torch.abs(XYZ[..., 0:1]))
    q = math.pi / 12.0 * torch.where(order, XYZ[..., 1:2], XYZ[..., 0:1]) \
        / torch.where(order, XYZ[..., 0:1], XYZ[..., 1:2])
    c = torch.cos(q)
    s = torch.sin(q)
    q = R1 * 2.0 ** 0.25 / beta / torch.sqrt(math.sqrt(2.0) - c) \
        * torch.where(order, XYZ[..., 0:1], XYZ[..., 1:2])

    T = torch.cat(((math.sqrt(2.0) * c - 1.0), math.sqrt(2.0) * s), dim=-1) * q

    # transform to sphere grid (inverse Lambert)
    c = torch.sum(T**2, -1, keepdim=True)
    s = c * math.pi / 24.0 / XYZ[..., 2:3]**2
    c = c * math.sqrt(math.pi / 24.0) / XYZ[..., 2:3]
    q = torch.sqrt(1.0 - s)

    ho = torch.where(
        torch.isclose(torch.sum(torch.abs(XYZ[..., 0:2]), -1, keepdim=True),
                      _precision_check(0.0, XYZ.dtype),
                      rtol=0.0,
                      atol=1.0e-16),
        torch.cat((torch.zeros_like(
            XYZ[..., 0:2]), math.sqrt(6.0 / math.pi) * XYZ[..., 2:3]),
                  dim=-1),
        torch.cat((torch.where(order, T[..., 0:1], T[..., 1:2]) * q,
                   torch.where(order, T[..., 1:2], T[..., 0:1]) * q,
                   math.sqrt(6.0 / math.pi) * XYZ[..., 2:3] - c),
                  dim=-1))

    ho[torch.isclose(torch.sum(torch.abs(cu), -1),
                     _precision_check(0.0, cu.dtype),
                     rtol=0.0,
                     atol=1.0e-16)] = 0.0  # warning
    ho = torch.gather(ho, -1, _get_tensor_pyramid_order(cu, 'backward'))

    # return ho # here for homochoric
    """Step 2: Homochoric vector to axis angle pair."""
    tfit = [
        +1.0000000000018852, -0.5000000002194847, -0.024999992127593126,
        -0.003928701544781374, -0.0008152701535450438, -0.0002009500426119712,
        -0.00002397986776071756, -0.00008202868926605841,
        +0.00012448715042090092, -0.0001749114214822577,
        +0.0001703481934140054, -0.00012062065004116828,
        +0.000059719705868660826, -0.00001980756723965647,
        +0.000003953714684212874, -0.00000036555001439719544
    ]
    hmag_squared = torch.sum(ho**2., -1, keepdim=True)

    hm = torch.clone(
        hmag_squared)  # use detach() for decoupled autograd relationship

    s = tfit[0] + tfit[1] * hmag_squared
    for i in range(2, 16):
        hm *= hmag_squared
        s += tfit[i] * hm

    # with np.errstate(invalid='ignore'):
    ax = torch.where(
        torch.lt(torch.abs(hmag_squared),
                 torch.tensor(1.e-8)).expand(ho.shape[:-1] + (4, )),
        _precision_check([0.0, 0.0, 1.0, 0.0], ho.dtype, ho.device),
        torch.cat((ho / torch.sqrt(hmag_squared),
                   2.0 * torch.arccos(torch.clip(s, -1.0, 1.0))),
                  dim=-1))

    # return ax # here for axis angle pair
    """Step 3: Axis angle pair to Rodrigues-Frank vector."""
    ro = torch.cat((ax[..., :3],
                    torch.where(
                        torch.isclose(ax[..., 3:4],
                                      _precision_check(math.pi, ax.dtype),
                                      atol=1.e-15,
                                      rtol=.0),
                        _precision_check(float('inf'), ax.dtype, ax.device),
                        torch.tan(ax[..., 3:4] * 0.5))),
                   dim=-1)
    ro[torch.lt(torch.abs(ax[..., 3]),
                1.e-6)] = _precision_check([.0, .0, P, .0], ax.dtype,
                                           ax.device)

    return ro
Beispiel #26
0
def dihedral_angle(mesh, edge_points):
    normals_a = get_normals(mesh, edge_points, 0)
    normals_b = get_normals(mesh, edge_points, 3)
    dot = torch.sum(normals_a * normals_b, dim=-1).clip(-1, 1)
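    # dihedral angle between the two adjacent faces: pi minus the angle between their normals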
    angles = np.pi - torch.arccos(dot)
    return angles.unsqueeze(1)
Beispiel #27
0
 def pointwise_ops(self):
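     # smoke test: exercises a broad range of torch pointwise ops and returns their results as a tuple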
     a = torch.randn(4)
     b = torch.randn(4)
     t = torch.tensor([-1, -2, 3], dtype=torch.int8)
     r = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
     t = torch.tensor([-1, -2, 3], dtype=torch.int8)
     s = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
     f = torch.zeros(3)
     g = torch.tensor([-1, 0, 1])
     w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
     return (
         torch.abs(torch.tensor([-1, -2, 3])),
         torch.absolute(torch.tensor([-1, -2, 3])),
         torch.acos(a),
         torch.arccos(a),
         torch.acosh(a.uniform_(1.0, 2.0)),
         torch.add(a, 20),
         torch.add(a, torch.randn(4, 1), alpha=10),
         torch.addcdiv(torch.randn(1, 3),
                       torch.randn(3, 1),
                       torch.randn(1, 3),
                       value=0.1),
         torch.addcmul(torch.randn(1, 3),
                       torch.randn(3, 1),
                       torch.randn(1, 3),
                       value=0.1),
         torch.angle(a),
         torch.asin(a),
         torch.arcsin(a),
         torch.asinh(a),
         torch.arcsinh(a),
         torch.atan(a),
         torch.arctan(a),
         torch.atanh(a.uniform_(-1.0, 1.0)),
         torch.arctanh(a.uniform_(-1.0, 1.0)),
         torch.atan2(a, a),
         torch.bitwise_not(t),
         torch.bitwise_and(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.bitwise_or(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.bitwise_xor(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
         torch.ceil(a),
         torch.clamp(a, min=-0.5, max=0.5),
         torch.clamp(a, min=0.5),
         torch.clamp(a, max=0.5),
         torch.clip(a, min=-0.5, max=0.5),
         torch.conj(a),
         torch.copysign(a, 1),
         torch.copysign(a, b),
         torch.cos(a),
         torch.cosh(a),
         torch.deg2rad(
             torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0,
                                                              -90.0]])),
         torch.div(a, b),
         torch.divide(a, b, rounding_mode="trunc"),
         torch.divide(a, b, rounding_mode="floor"),
         torch.digamma(torch.tensor([1.0, 0.5])),
         torch.erf(torch.tensor([0.0, -1.0, 10.0])),
         torch.erfc(torch.tensor([0.0, -1.0, 10.0])),
         torch.erfinv(torch.tensor([0.0, 0.5, -1.0])),
         torch.exp(torch.tensor([0.0, math.log(2.0)])),
         torch.exp2(torch.tensor([0.0, math.log(2.0), 3.0, 4.0])),
         torch.expm1(torch.tensor([0.0, math.log(2.0)])),
         torch.fake_quantize_per_channel_affine(
             torch.randn(2, 2, 2),
             (torch.randn(2) + 1) * 0.05,
             torch.zeros(2),
             1,
             0,
             255,
         ),
         torch.fake_quantize_per_tensor_affine(a, 0.1, 0, 0, 255),
         torch.float_power(torch.randint(10, (4, )), 2),
         torch.float_power(torch.arange(1, 5), torch.tensor([2, -3, 4,
                                                             -5])),
         torch.floor(a),
         # torch.floor_divide(torch.tensor([4.0, 3.0]), torch.tensor([2.0, 2.0])),
         # torch.floor_divide(torch.tensor([4.0, 3.0]), 1.4),
         torch.fmod(torch.tensor([-3, -2, -1, 1, 2, 3]), 2),
         torch.fmod(torch.tensor([1, 2, 3, 4, 5]), 1.5),
         torch.frac(torch.tensor([1.0, 2.5, -3.2])),
         torch.randn(4, dtype=torch.cfloat).imag,
         torch.ldexp(torch.tensor([1.0]), torch.tensor([1])),
         torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])),
         torch.lerp(torch.arange(1.0, 5.0),
                    torch.empty(4).fill_(10), 0.5),
         torch.lerp(
             torch.arange(1.0, 5.0),
             torch.empty(4).fill_(10),
             torch.full_like(torch.arange(1.0, 5.0), 0.5),
         ),
         torch.lgamma(torch.arange(0.5, 2, 0.5)),
         torch.log(torch.arange(5) + 10),
         torch.log10(torch.rand(5)),
         torch.log1p(torch.randn(5)),
         torch.log2(torch.rand(5)),
         torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
         torch.logaddexp(torch.tensor([-100.0, -200.0, -300.0]),
                         torch.tensor([-1, -2, -3])),
         torch.logaddexp(torch.tensor([1.0, 2000.0, 30000.0]),
                         torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([-100.0, -200.0, -300.0]),
                          torch.tensor([-1, -2, -3])),
         torch.logaddexp2(torch.tensor([1.0, 2000.0, 30000.0]),
                          torch.tensor([-1, -2, -3])),
         torch.logical_and(r, s),
         torch.logical_and(r.double(), s.double()),
         torch.logical_and(r.double(), s),
         torch.logical_and(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)),
         torch.logical_not(
             torch.tensor([0.0, 1.5, -10.0], dtype=torch.double)),
         torch.logical_not(
             torch.tensor([0.0, 1.0, -10.0], dtype=torch.double),
             out=torch.empty(3, dtype=torch.int16),
         ),
         torch.logical_or(r, s),
         torch.logical_or(r.double(), s.double()),
         torch.logical_or(r.double(), s),
         torch.logical_or(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logical_xor(r, s),
         torch.logical_xor(r.double(), s.double()),
         torch.logical_xor(r.double(), s),
         torch.logical_xor(r, s, out=torch.empty(4, dtype=torch.bool)),
         torch.logit(torch.rand(5), eps=1e-6),
         torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])),
         torch.i0(torch.arange(5, dtype=torch.float32)),
         torch.igamma(a, b),
         torch.igammac(a, b),
         torch.mul(torch.randn(3), 100),
         torch.multiply(torch.randn(4, 1), torch.randn(1, 4)),
         torch.mvlgamma(torch.empty(2, 3).uniform_(1.0, 2.0), 2),
         torch.tensor([float("nan"),
                       float("inf"), -float("inf"), 3.14]),
         torch.nan_to_num(w),
         torch.nan_to_num(w, nan=2.0),
         torch.nan_to_num(w, nan=2.0, posinf=1.0),
         torch.neg(torch.randn(5)),
         # torch.nextafter(torch.tensor([1, 2]), torch.tensor([2, 1])) == torch.tensor([eps + 1, 2 - eps]),
         torch.polygamma(1, torch.tensor([1.0, 0.5])),
         torch.polygamma(2, torch.tensor([1.0, 0.5])),
         torch.polygamma(3, torch.tensor([1.0, 0.5])),
         torch.polygamma(4, torch.tensor([1.0, 0.5])),
         torch.pow(a, 2),
         torch.pow(torch.arange(1.0, 5.0), torch.arange(1.0, 5.0)),
         torch.rad2deg(
             torch.tensor([[3.142, -3.142], [6.283, -6.283],
                           [1.570, -1.570]])),
         torch.randn(4, dtype=torch.cfloat).real,
         torch.reciprocal(a),
         torch.remainder(torch.tensor([-3.0, -2.0]), 2),
         torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5),
         torch.round(a),
         torch.rsqrt(a),
         torch.sigmoid(a),
         torch.sign(torch.tensor([0.7, -1.2, 0.0, 2.3])),
         torch.sgn(a),
         torch.signbit(torch.tensor([0.7, -1.2, 0.0, 2.3])),
         torch.sin(a),
         torch.sinc(a),
         torch.sinh(a),
         torch.sqrt(a),
         torch.square(a),
         torch.sub(torch.tensor((1, 2)), torch.tensor((0, 1)), alpha=2),
         torch.tan(a),
         torch.tanh(a),
         torch.trunc(a),
         torch.xlogy(f, g),
         torch.xlogy(f, g),
         torch.xlogy(f, 4),
         torch.xlogy(2, g),
     )
Beispiel #28
0
def arccos(a: Numeric):
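    # 'Numeric' is assumed to be a type alias (e.g. Tensor, float or int) defined in the enclosing library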
    return torch.arccos(a)
Beispiel #29
0
def slerp(a, b, t):
    omega = torch.arccos(torch.dot(a / a.norm(), b / b.norm()))
    so = torch.sin(omega)
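    # note: if a and b are (anti)parallel, so is ~0 and the division below is unstable;
    # clamping the dot product to [-1, 1] before arccos also guards against NaNs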
    return (((1.0 - t) * omega).sin() * a + (t * omega).sin() * b) / so
Beispiel #30
0
def q_to_axisangle(q):
    """Converts a Quaternion to axis-angle representation."""
    w, v = q[0], q[1:]
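    # q is assumed to be a unit quaternion (w, x, y, z);
    # clamping w to [-1, 1] before arccos would guard against rounding error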
    theta = torch.arccos(w) * 2.0

    return v / torch.norm(v), theta