def _normalize_logits(self, unnorm: torch.tensor, dim: int = -1): unnorm = unnorm.clamp(-100, 100) max_u = unnorm.max(dim, keepdim=True).values logsumexp = (unnorm - max_u).exp().sum(dim, keepdim=True).log() norm = unnorm - (max_u + logsumexp) if not ((norm.exp().sum(dim) - 1).abs() < 1e-5).all(): import pdb pdb.set_trace() assert ((norm.exp().sum(dim) - 1).abs() < 1e-5).all( ), f'{unnorm.min()}, {unnorm.max()}, {unnorm.mean()}, {unnorm.std()}' return norm
def forward(self, x: Tensor) -> (Tensor, Tensor, Tensor):
    """Fake-quantize ``x`` with the learnable step size ``self.s``.

    The step is passed through ``self.grad_scale`` together with the factor
    ``1 / sqrt(x.numel() * self.Qp)``. That looks like the LSQ gradient
    scaling; confirm against the definition of ``grad_scale``. ``x`` is
    divided by the step and clamped to ``[self.Qn, self.Qp]``. It is then
    rounded with ``self.round_pass`` (presumably a straight-through round;
    verify) and rescaled.

    Args:
        x: tensor to quantize.

    Returns:
        ``(x_hat, scale, bits)``: the de-quantized tensor, the (gradient
        scaled) step size, and a one-element ``torch.Tensor`` holding
        ``self.bit_width``.
    """
    grad_factor = 1 / (x.numel() * self.Qp) ** 0.5
    step = self.grad_scale(self.s, grad_factor)

    # Map onto the integer grid [Qn, Qp], then round.
    clipped = (x / step).clamp(self.Qn, self.Qp)
    levels = self.round_pass(clipped)

    return levels * step, step, torch.Tensor([self.bit_width])
def gem(x: torch.Tensor, p: float = 3, eps: float = 1e-6) -> torch.Tensor:
    """Generalized Mean (GeM) pooling over the two trailing (spatial) dims.

    Computes ``(mean_{h,w} clamp(x, min=eps) ** p) ** (1 / p)`` per channel.
    ``p = 1`` is plain average pooling. As ``p`` grows, the power mean
    approaches the maximum.

    Args:
        x: input features, expected shape ``B x C x H x W``.
        p: pooling exponent; non-integer (e.g. learnable) values are fine.
            Defaults to ``3``.
        eps: lower bound applied to ``x`` before exponentiation so that
            fractional powers of non-positive values are avoided. Defaults
            to ``1e-6``.

    Returns:
        Tensor of shape ``B x C x 1 x 1``.
    """
    # A kernel covering the full H x W window averages each channel map
    # down to a single value.
    pooled = F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1)))
    return pooled.pow(1.0 / p)
def __call__(self, distance_matrix: torch.Tensor):
    """Discretize ``distance_matrix`` into ``self.n_bins`` distance bins.

    The dtype of ``distance_matrix`` selects the strategy:

    * floating point (continuous distances, e.g. personalized page rank):
      bin boundaries come from ``hist_by_area``. If ``UNREACHABLE`` occurs
      in the input, it is prepended as an extra boundary.
    * integer / bool (discrete distances, e.g. shortest paths): values are
      clamped to ``[-1000, 1000]``. Entries whose clamped absolute value
      equals 1000 are set to ``UNREACHABLE`` and assigned bin 0. The other
      entries are binned separately for negative and non-negative values.

    Args:
        distance_matrix: tensor of distances.

    Returns:
        ``(indices, possible_distances)``:

        * ``indices`` has the shape of ``distance_matrix`` and holds one bin
          index per entry.
        * ``possible_distances`` is a float32 tensor whose length is asserted
          to equal ``self.n_bins``. In the discrete branch it starts with
          ``UNREACHABLE`` and is padded with ``BIN_PADDING`` when there are
          fewer distinct values than bins.

    Raises:
        NotImplementedError: for dtypes not handled below.
    """
    # continuous distances (personalized page rank)
    if distance_matrix.dtype in [torch.float32, torch.float16, torch.float64]:
        dist_values = distance_matrix.reshape(-1).cpu().numpy()
        if UNREACHABLE in dist_values:
            # We fix UNREACHABLE to be the bin with index 0
            possible_distances = torch.tensor(
                hist_by_area(dist_values[dist_values < UNREACHABLE], self.n_bins - 2),
                dtype=torch.float32)
            possible_distances = torch.cat(
                [torch.tensor([UNREACHABLE], dtype=torch.float32), possible_distances])
            dist_values = torch.tensor(dist_values, dtype=torch.float32)
            indices = (dist_values[:, None] > possible_distances[None, :]).sum(-1).reshape(distance_matrix.shape)
            # Shift all indices by 1, as UNREACHABLE will be bin 0
            indices = (indices + 1) % self.n_bins
        else:
            # We assume there are already many 0 values in the distance matrix, which will be in bin 0
            possible_distances = torch.tensor(hist_by_area(dist_values, self.n_bins - 1), dtype=torch.float32)
            dist_values = torch.tensor(dist_values, dtype=torch.float32)
            # It is very important that both dist_values and possible_distances are float32.
            # Otherwise a > b can hold even when a == b, just because a.dtype is float64 and b.dtype is float32.
            indices = (dist_values[:, None] > possible_distances[None, :]).sum(-1).reshape(distance_matrix.shape)
    # discrete distances (shortest paths)
    elif distance_matrix.dtype in [torch.long, torch.int32, torch.int16, torch.int8, torch.int, torch.bool]:
        max_distance = 1000
        # clamp returns a new tensor, so the in-place assignment below does
        # not touch the caller's distance_matrix.
        dist_values = distance_matrix.clamp(-max_distance, max_distance)
        unreachable_ixs = abs(dist_values) == max_distance
        dist_values[unreachable_ixs] = UNREACHABLE
        unique_values = dist_values.unique()
        num_unique_values = len(unique_values)
        # Reserve one bin for UNREACHABLE even if it does not occur.
        num_unique_values += 1 if UNREACHABLE not in dist_values else 0
        num_unique_values = min(num_unique_values, self.n_bins)
        dist_values = dist_values.reshape(-1).cpu().numpy()
        values = dist_values[abs(dist_values) < max_distance]
        if num_unique_values < self.n_bins:
            # If there are fewer unique values than bins, we do not need sophisticated binning.
            possible_distances = torch.sort(unique_values)[0]
            # NOTE(review): dropping the last element assumes UNREACHABLE sorts
            # after every reachable distance; confirm against its definition.
            if UNREACHABLE in possible_distances:
                possible_distances = possible_distances[:-1]
        else:
            if isinstance(self.trans_func, EqualBinning) and self.n_fixed == 0:
                # Also resort to regular binning
                possible_distances = hist_by_area(values, num_unique_values - 1)
            else:
                # Calculate bins where the area of each bin is governed by trans_func.
                # When using exponential binning, this means that bin area will grow exponentially in distance
                # to the bin containing the value 0
                possible_distances = calculate_bins(values, num_unique_values - 1, self.n_fixed, hist_by_area,
                                                    self.trans_func)
        if isinstance(possible_distances, torch.Tensor):
            possible_distances = possible_distances.type(torch.float32)
        else:
            possible_distances = torch.tensor(possible_distances, dtype=torch.float32)
        dist_values = torch.tensor(dist_values)
        neg_bins, pos_bins = possible_distances[possible_distances < 0], possible_distances[possible_distances >= 0]
        neg_vals, pos_vals = dist_values[dist_values < 0], dist_values[dist_values >= 0]
        neg_ixs = (neg_vals[:, None] >= neg_bins[None]).sum(-1) - 1
        pos_ixs = len(pos_bins) - (pos_vals[:, None] <= pos_bins[None]).sum(-1)
        # Scatter per-sign bin indices back to their original positions;
        # non-negative bins come after all negative ones.
        indices = torch.zeros_like(dist_values)
        indices[dist_values < 0] = neg_ixs
        indices[dist_values >= 0] = pos_ixs + len(neg_bins)
        # Shift all indices by 1, as UNREACHABLE will be bin 0
        indices += 1
        indices = indices.reshape(distance_matrix.shape)
        indices[unreachable_ixs] = 0
        possible_distances = torch.cat([torch.tensor([UNREACHABLE], dtype=torch.float32), possible_distances])
        if num_unique_values < self.n_bins:
            bin_padding_tensor = torch.full((self.n_bins - num_unique_values,), BIN_PADDING, dtype=torch.float32)
            possible_distances = torch.cat([possible_distances, bin_padding_tensor])
    else:
        raise NotImplementedError(f"Binning for tensors of type {distance_matrix.dtype} is not implemented")
    # NOTE(review): indices index a tensor of length n_bins, so the upper
    # bound arguably should be ``< self.n_bins``; kept as-is pending a check.
    assert 0 <= indices.min() and indices.max() <= self.n_bins, f"Indices have to be in [0, {self.n_bins}] but got [{indices.min()}, {indices.max()}]"
    assert len(
        possible_distances) == self.n_bins, f"Calculated amount of bins ({len(possible_distances)}) differs from requested amount ({self.n_bins})"
    return indices, possible_distances