Example #1
def negative_norm(
    x: torch.FloatTensor,
    p: Union[str, int] = 2,
    power_norm: bool = False,
) -> torch.FloatTensor:
    """Evaluate negative norm of a vector.

    :param x: shape: (batch_size, num_heads, num_relations, num_tails, dim)
        The vectors.
    :param p:
        The p for the norm. cf. torch.norm.
    :param power_norm:
        Whether to return the powered norm $|x|_p^p$ instead of $|x|_p$, cf. https://github.com/pytorch/pytorch/issues/28119

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    if power_norm:
        assert not isinstance(p, str)
        return -(x.abs()**p).sum(dim=-1)

    if torch.is_complex(x):
        assert not isinstance(p, str)
        # workaround for complex numbers: manually compute norm
        return -(x.abs()**p).sum(dim=-1)**(1 / p)

    return -x.norm(p=p, dim=-1)
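A minimal usage sketch for negative_norm; the tensor shape below is illustrative and assumes torch is imported and the function above is in scope.

import torch

# Hypothetical shapes: batch of 1, 2 heads, 1 relation, 3 tails, embedding dim 64.
x = torch.randn(1, 2, 1, 3, 64)
scores = negative_norm(x, p=2)                       # -||x||_2 along the last dim, shape (1, 2, 1, 3)
scores_pow = negative_norm(x, p=2, power_norm=True)  # -(|x|^2).sum(-1), skips the final root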
Example #2
def powersum_norm(x: torch.FloatTensor, p: float, dim: Optional[int], normalize: bool) -> torch.FloatTensor:
    """Return the power sum norm."""
    value = x.abs().pow(p).sum(dim=dim)
    if not normalize:
        return value
    # normalize by the size of the last dimension
    dim = torch.as_tensor(x.shape[-1], dtype=torch.float, device=x.device)
    return value / dim
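A small usage sketch for powersum_norm; shapes and the choice of p are illustrative, and it assumes torch and the function above are available.

import torch

x = torch.randn(4, 16)
raw = powersum_norm(x, p=3.0, dim=-1, normalize=False)  # sum of |x|^3 per row, shape (4,)
avg = powersum_norm(x, p=3.0, dim=-1, normalize=True)   # the same values divided by the last-dimension size (16)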
Example #3
def product_normalize(x: torch.FloatTensor, dim: int = -1) -> torch.FloatTensor:
    r"""Normalize a tensor along a given dimension so that the geometric mean is 1.0.

    :param x: shape: s
        An input tensor
    :param dim:
        the dimension along which to normalize the tensor

    :return: shape: s
        An output tensor where the given dimension is normalized to have a geometric mean of 1.0.
    """
    return x / at_least_eps(at_least_eps(x.abs()).log().mean(dim=dim, keepdim=True).exp())
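An illustrative check for product_normalize, assuming the function and its at_least_eps helper are in scope: after normalization, the geometric mean of the absolute values along dim should be close to 1.0.

import torch

x = torch.rand(3, 10) + 0.1
y = product_normalize(x, dim=-1)
geo_mean = y.abs().log().mean(dim=-1).exp()  # expected to be approximately 1.0 for each row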
Example #4
    def compute_mask(self, param: torch.FloatTensor, default_mask: torch.Tensor) -> torch.Tensor:
        """Assume the parameter is a matrix."""
        if self.prune_weights:
            norms = param.abs()
        else:
            norms = param.norm(dim=1, p=2) / sqrt(param.size(1))
            norms = norms.unsqueeze(dim=1)

        if self.prune_mode == LEAST:
            return default_mask * (norms > self.threshold)
        elif self.prune_mode == MOST:
            return default_mask * (norms <= self.threshold)
        else:
            raise NotImplementedError(f"Unknown prune_mode: {self.prune_mode}")
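The masking logic can be checked standalone. This sketch mirrors the row-norm branch above (prune_weights=False) with LEAST-mode pruning; the parameter shape and threshold are illustrative.

import torch
from math import sqrt

param = torch.randn(5, 8)
threshold = 0.8
norms = param.norm(dim=1, p=2) / sqrt(param.size(1))  # scaled L2 norm per row
norms = norms.unsqueeze(dim=1)                        # shape (5, 1) for broadcasting
default_mask = torch.ones_like(param)
mask = default_mask * (norms > threshold)             # rows below the threshold are zeroed out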
Example #5
    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:  # noqa: D102
        value = x.abs().pow(self.p).sum(dim=self.dim).mean()
        if not self.normalize:
            return value
        dim = torch.as_tensor(x.shape[-1], dtype=torch.float, device=x.device)
        return value / dim
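For reference, the same computation written out directly, assuming p=2, dim=-1, and normalize=True on the module; the attribute names follow the snippet above and are not a confirmed API.

import torch

x = torch.randn(4, 16)
value = x.abs().pow(2).sum(dim=-1).mean()  # mean power sum over the batch
value = value / x.shape[-1]                # normalized variant divides by the last-dimension size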
Example #6
def add_order_cost(C: torch.FloatTensor, a: torch.FloatTensor,
                   b: torch.FloatTensor,
                   order_lambda: float) -> torch.FloatTensor:
    """
    Adds diagonal order preserving cost to a cost matrix.

    :param C: FloatTensor (n x m or num_batches x n x m) with costs.
        Note: The device and dtype of C are used for all other variables.
    :param a: Row marginals (n or num_batches x n).
    :param b: Column marginals (m or num_batches x m).
    :param order_lambda: Weight for the diagonal order preserving cost (0 = no weight, 1 = all weight).
    :return: The cost matrix with the order preserving cost added in.
    """
    # Checks
    assert len(C.shape) in [2, 3]
    batched = len(C.shape) == 3
    assert len(a.shape) == len(b.shape) == (2 if batched else 1)
    assert 0.0 <= order_lambda <= 1.0

    # Return C if not adding order cost
    if order_lambda == 0.0:
        return C

    # Setup
    dtype = C.dtype
    device = C.device

    # Compute order preserving cost
    I = torch.arange(C.size(-2), dtype=dtype, device=device).unsqueeze(dim=1)
    J = torch.arange(C.size(-1), dtype=dtype, device=device).unsqueeze(dim=0)

    # Compute masks
    mask, mask_n, mask_m = compute_masks(C=C, a=a, b=b)

    # Compute N and M
    N = mask_n.sum(dim=-1, keepdim=True)
    M = mask_m.sum(dim=-1, keepdim=True)

    if batched:
        I = I.unsqueeze(dim=0).repeat(C.size(0), 1, C.size(2))
        J = J.unsqueeze(dim=0).repeat(C.size(0), C.size(1), 1)

        N = N.unsqueeze(dim=-1)
        M = M.unsqueeze(dim=-1)
    else:
        I = I.repeat(1, C.size(1))
        J = J.repeat(C.size(0), 1)

    D = torch.abs(I / N - J / M)

    # Match average magnitudes so costs are on the same scale
    mask_sum = mask.sum(dim=-1).sum(dim=-1)
    C_magnitude = (mask * C.abs()).sum(dim=-1).sum(dim=-1) / mask_sum
    D_magnitude = (mask * D.abs()).sum(dim=-1).sum(dim=-1) / mask_sum
    D_magnitude[D_magnitude == 0] = 1  # prevent divide by 0
    D *= (C_magnitude / D_magnitude).unsqueeze(dim=-1).unsqueeze(dim=-1)

    # Add order preserving cost
    C = (1 - order_lambda) * C + order_lambda * D
    C *= mask

    return C
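A hypothetical call sketch for add_order_cost; it assumes the compute_masks helper used above is in scope, and the matrix size and order_lambda value are illustrative. With uniform marginals, order_lambda blends the original costs with the diagonal distance |i/N - j/M|, rescaled to the same average magnitude as C.

import torch

C = torch.rand(4, 6)                  # unbatched cost matrix (n=4, m=6)
a = torch.full((4,), 1.0 / 4)         # uniform row marginals
b = torch.full((6,), 1.0 / 6)         # uniform column marginals
C_ordered = add_order_cost(C, a, b, order_lambda=0.5)  # half original cost, half order cost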