from math import sqrt
from typing import Optional, Union

import torch


def negative_norm(
    x: torch.FloatTensor,
    p: Union[str, int] = 2,
    power_norm: bool = False,
) -> torch.FloatTensor:
    """Evaluate negative norm of a vector.

    :param x: shape: (batch_size, num_heads, num_relations, num_tails, dim)
        The vectors.
    :param p:
        The p for the norm. cf. torch.norm.
    :param power_norm:
        Whether to return the powered norm $|x|_p^p$ instead of the norm $|x|_p$,
        cf. https://github.com/pytorch/pytorch/issues/28119
    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    if power_norm:
        assert not isinstance(p, str)
        return -(x.abs() ** p).sum(dim=-1)

    if torch.is_complex(x):
        assert not isinstance(p, str)
        # workaround for complex numbers: manually compute norm
        return -(x.abs() ** p).sum(dim=-1) ** (1 / p)

    return -x.norm(p=p, dim=-1)
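# Usage sketch (not from the original code base): the tensor sizes below are
# assumptions for demonstration only.
x = torch.randn(2, 3, 5, 7, 16)  # (batch_size, num_heads, num_relations, num_tails, dim)
scores = negative_norm(x, p=2)  # -||x||_2 along the last dim, shape (2, 3, 5, 7)
powered = negative_norm(x, p=2, power_norm=True)  # -sum(|x|^2), same shape, skips the root
assert scores.shape == powered.shape == (2, 3, 5, 7)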
def powersum_norm(x: torch.FloatTensor, p: float, dim: Optional[int], normalize: bool) -> torch.FloatTensor:
    """Return the power sum norm."""
    value = x.abs().pow(p).sum(dim=dim)
    if not normalize:
        return value
    # normalize by the size of the last dimension
    norm_dim = torch.as_tensor(x.shape[-1], dtype=torch.float, device=x.device)
    return value / norm_dim
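# Usage sketch (assumed, for illustration): ``powersum_norm`` with and without
# normalization by the size of the last dimension.
x = torch.randn(4, 16)
raw = powersum_norm(x, p=3.0, dim=-1, normalize=False)  # sum_i |x_i|^3 per row, shape (4,)
scaled = powersum_norm(x, p=3.0, dim=-1, normalize=True)  # divided by the dim size, 16
assert torch.allclose(scaled, raw / 16)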
def product_normalize(x: torch.FloatTensor, dim: int = -1) -> torch.FloatTensor:
    r"""Normalize a tensor along a given dimension so that the geometric mean is 1.0.

    :param x: shape: s
        An input tensor
    :param dim:
        the dimension along which to normalize the tensor
    :return: shape: s
        An output tensor where the given dimension is normalized to have a geometric mean of 1.0.
    """
    # divide by the geometric mean of the absolute values, i.e. exp(mean(log|x|)),
    # clamping with at_least_eps to avoid log(0) and division by zero
    return x / at_least_eps(at_least_eps(x.abs()).log().mean(dim=dim, keepdim=True).exp())
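# Usage sketch (assumed; relies on the ``at_least_eps`` helper referenced above):
# after ``product_normalize``, the geometric mean along ``dim`` is 1.0.
x = torch.rand(4, 16) + 0.1  # keep entries away from zero so the eps clamps are inert
y = product_normalize(x, dim=-1)
geo_mean = y.abs().log().mean(dim=-1).exp()
assert torch.allclose(geo_mean, torch.ones(4), atol=1e-5)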
def compute_mask(self, param: torch.FloatTensor, default_mask: torch.Tensor) -> torch.Tensor:
    """Assume the parameter is a matrix."""
    if self.prune_weights:
        norms = param.abs()
    else:
        # scale-invariant row norms: L2 norm divided by sqrt(row length)
        norms = param.norm(dim=1, p=2) / sqrt(param.size(1))
        norms = norms.unsqueeze(dim=1)
    if self.prune_mode == LEAST:
        return default_mask * (norms > self.threshold)
    elif self.prune_mode == MOST:
        return default_mask * (norms <= self.threshold)
    raise NotImplementedError(f"Unknown prune mode: {self.prune_mode}")
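# Usage sketch (assumed): ``compute_mask`` is written as a method, so a simple
# namespace stands in for ``self`` here. The values of the ``LEAST``/``MOST``
# constants and the attribute names are assumptions for demonstration only.
from types import SimpleNamespace

LEAST, MOST = "least", "most"  # hypothetical values; defined elsewhere in the real code

pruner = SimpleNamespace(prune_weights=False, prune_mode=LEAST, threshold=0.5)
param = torch.randn(4, 8)
default_mask = torch.ones(4, 1)
mask = compute_mask(pruner, param, default_mask)  # keeps rows whose scaled L2 norm exceeds 0.5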
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:  # noqa: D102
    value = x.abs().pow(self.p).sum(dim=self.dim).mean()
    if not self.normalize:
        return value
    # normalize by the size of the last dimension
    norm_dim = torch.as_tensor(x.shape[-1], dtype=torch.float, device=x.device)
    return value / norm_dim
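# Usage sketch (assumed): ``forward`` mirrors ``powersum_norm`` but averages the
# per-row sums into a scalar; a namespace again stands in for the owning module.
reg = SimpleNamespace(p=3.0, dim=-1, normalize=True)
x = torch.randn(4, 16)
value = forward(reg, x)
assert torch.allclose(value, powersum_norm(x, p=3.0, dim=-1, normalize=True).mean())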
def add_order_cost(C: torch.FloatTensor,
                   a: torch.FloatTensor,
                   b: torch.FloatTensor,
                   order_lambda: float) -> torch.FloatTensor:
    """
    Adds diagonal order preserving cost to a cost matrix.

    :param C: FloatTensor (n x m or num_batches x n x m) with costs.
              Note: The device and dtype of C are used for all other variables.
    :param a: Row marginals (num_batches x n).
    :param b: Column marginals (num_batches x m).
    :param order_lambda: Weight for diagonal order preserving cost
                         (0 = no weight, 1 = all weight).
    :return: The cost matrix with the order preserving cost added in.
    """
    # Checks
    assert len(C.shape) in [2, 3]
    batched = len(C.shape) == 3
    assert len(a.shape) == len(b.shape) == (2 if batched else 1)
    assert 0.0 <= order_lambda <= 1.0

    # Return C if not adding order cost
    if order_lambda == 0.0:
        return C

    # Setup
    dtype = C.dtype
    device = C.device

    # Compute order preserving cost
    I = torch.arange(C.size(-2), dtype=dtype, device=device).unsqueeze(dim=1)
    J = torch.arange(C.size(-1), dtype=dtype, device=device).unsqueeze(dim=0)

    # Compute masks
    mask, mask_n, mask_m = compute_masks(C=C, a=a, b=b)

    # Compute N and M
    N = mask_n.sum(dim=-1, keepdim=True)
    M = mask_m.sum(dim=-1, keepdim=True)

    if batched:
        I = I.unsqueeze(dim=0).repeat(C.size(0), 1, C.size(2))
        J = J.unsqueeze(dim=0).repeat(C.size(0), C.size(1), 1)
        N = N.unsqueeze(dim=-1)
        M = M.unsqueeze(dim=-1)
    else:
        I = I.repeat(1, C.size(1))
        J = J.repeat(C.size(0), 1)

    D = torch.abs(I / N - J / M)

    # Match average magnitudes so costs are on the same scale
    mask_sum = mask.sum(dim=-1).sum(dim=-1)
    C_magnitude = (mask * C.abs()).sum(dim=-1).sum(dim=-1) / mask_sum
    D_magnitude = (mask * D.abs()).sum(dim=-1).sum(dim=-1) / mask_sum
    D_magnitude[D_magnitude == 0] = 1  # prevent divide by 0
    D *= (C_magnitude / D_magnitude).unsqueeze(dim=-1).unsqueeze(dim=-1)

    # Add order preserving cost
    C = (1 - order_lambda) * C + order_lambda * D
    C *= mask

    return C
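# Usage sketch (assumed): ``add_order_cost`` depends on a ``compute_masks``
# helper that is not shown here. The stand-in below assumes the dense case
# where every entry, row, and column is valid; the real helper presumably
# derives these masks from padding in ``a`` and ``b``.
def compute_masks(C: torch.FloatTensor, a: torch.FloatTensor, b: torch.FloatTensor):
    mask = torch.ones_like(C)  # entry validity, (n x m) or (num_batches x n x m)
    mask_n = torch.ones(*C.shape[:-1], dtype=C.dtype, device=C.device)  # row validity
    mask_m = torch.ones(*C.shape[:-2], C.size(-1), dtype=C.dtype, device=C.device)  # column validity
    return mask, mask_n, mask_m

B, n, m = 2, 4, 5
C = torch.rand(B, n, m)
a = torch.full((B, n), 1.0 / n)  # uniform row marginals
b = torch.full((B, m), 1.0 / m)  # uniform column marginals
C_ordered = add_order_cost(C, a, b, order_lambda=0.5)
assert C_ordered.shape == C.shape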