def gather_targets(self, targets: torch.Tensor):
    """
    Collect the targets of every worker so the loss can be computed on all
    of them.

    When a torch.distributed process group is initialized, the result is
    whatever ``gather_from_all`` returns for ``targets`` (expected shape:
    (batch_size * num_gpus) x embedding_dim). Otherwise ``targets`` is
    returned as-is.
    """
    distributed_ready = (
        torch.distributed.is_available() and torch.distributed.is_initialized()
    )
    if not distributed_ready:
        # Single-process run: there is nothing to gather.
        return targets
    return gather_from_all(targets)
def update_memory(self, embedding, y):
    """
    Momentum-update the memory-bank rows selected by ``y``.

    For every index in ``y`` the matching row of ``self.memory`` becomes
    ``normalize(momentum * old_row + (1 - momentum) * new_embedding)``,
    where ``momentum`` is read from ``self.params[3]``. When a
    torch.distributed process group is initialized, embeddings and indices
    are first gathered from all workers with ``gather_from_all``.

    Args:
        embedding: new embeddings, one row per entry of ``y``
            (assumed to be 2-D with the same width as ``self.memory`` --
            TODO confirm against callers).
        y: integer indices into the rows of ``self.memory``. Any shape is
            accepted; the indices are flattened before use.

    Raises:
        AssertionError: if the largest index does not fit in ``self.memory``.
    """
    momentum = self.params[3].item()
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        # gather all embeddings.
        embedding_gathered = gather_from_all(embedding)
        y_gathered = gather_from_all(y)
    else:
        embedding_gathered = embedding
        y_gathered = y

    # update memory
    with torch.no_grad():
        # index_copy_ only accepts a 0-D/1-D index, so flatten once and use
        # the same flat index for both the lookup and the write-back.
        # reshape (rather than view) also copes with non-contiguous labels.
        flat_idx = y_gathered.reshape(-1)

        # Requirement: every index must be strictly below the memory size.
        assert flat_idx.max() < self.memory.shape[0], (
            f"Memory bank {self.memory.shape} is not sufficient "
            f"to hold index: {flat_idx.max()}")

        # index_select returns a fresh tensor, so the in-place ops below do
        # not touch self.memory until the final index_copy_.
        l_pos = torch.index_select(self.memory, 0, flat_idx)
        l_pos.mul_(momentum)
        l_pos.add_(torch.mul(embedding_gathered, 1 - momentum))
        updated_l = nn.functional.normalize(l_pos, p=2, dim=1)
        self.memory.index_copy_(0, flat_idx, updated_l)
def compute_partition_function(self, out):
    """
    Set the normalization constant Z stored in ``self.params[2]``.

    Z is the mean of ``out`` (averaged again over the per-worker means when
    a torch.distributed process group is initialized) multiplied by the
    number of rows in ``self.memory``. The value is then logged together
    with the rank reported by ``get_rank()``.
    """
    memory_rows = self.memory.size(0)
    use_distributed = (
        torch.distributed.is_available() and torch.distributed.is_initialized()
    )

    with torch.no_grad():
        local_mean = out.mean()
        # NOTE: this relies on the "mean" computation being stable and
        # deterministic across all nodes. Could be replaced with smarter ways.
        if use_distributed:
            global_mean = gather_from_all(local_mean).mean().item()
        else:
            global_mean = local_mean.item()

    self.params[2] = global_mean * memory_rows

    # Read the value back from the parameter slot so the log shows exactly
    # what was stored.
    Z = self.params[2].clone().detach().item()
    rank = get_rank()
    logging.info(f"Rank: {rank}; Normalization constant Z is set to {Z}")
def cluster_memory(self):
    """
    Cluster the local memory bank with distributed k-means, once for every
    entry of ``self.num_clusters``.

    For each requested number of clusters the method:
      * seeds the centroids from random rows of rank 0's memory bank and
        broadcasts them to the other ranks,
      * alternates E steps (assign each local embedding to the centroid
        with the largest dot product) and M steps (recompute centroids
        from embedding sums and counts reduced across ranks, then
        L2-normalize) for ``self.nmb_kmeans_iters`` rounds,
      * copies the result into the ``centroids<i>`` attribute and scatters
        the gathered assignments into ``self.assignments[i]``.

    Requires CUDA and an initialized process group (``dist.broadcast`` is
    called unconditionally).
    """
    self.start_idx = 0
    bank_id = 0
    with torch.no_grad():
        for head_id, n_centroids in enumerate(self.num_clusters):
            # run distributed k-means
            bank = self.local_memory_embeddings[bank_id]

            # Seed the centroids from rank 0's memory bank; the other ranks
            # only allocate a receive buffer for the broadcast.
            centroids = torch.empty(
                n_centroids, self.embedding_dim
            ).cuda(non_blocking=True)
            if get_rank() == 0:
                seed_rows = torch.randperm(len(bank))[:n_centroids]
                assert len(seed_rows) >= n_centroids, \
                    "please reduce the number of centroids"
                centroids = bank[seed_rows]
            dist.broadcast(centroids, 0)

            last_round = self.nmb_kmeans_iters
            for round_id in range(last_round + 1):
                # E step: nearest centroid by dot product.
                similarities = torch.mm(bank, centroids.t())
                assignments = similarities.max(dim=1)[1]

                # The extra final round only refreshes the assignments.
                if round_id == last_round:
                    break

                # M step: per-cluster local sums and counts.
                members_per_cluster = get_indices_sparse(
                    assignments.cpu().numpy()
                )
                counts = torch.zeros(n_centroids).cuda(non_blocking=True).int()
                emb_sums = torch.zeros(
                    n_centroids, self.embedding_dim
                ).cuda(non_blocking=True)
                for cluster_id, members in enumerate(members_per_cluster):
                    member_rows = members[0]
                    if len(member_rows) > 0:
                        emb_sums[cluster_id] = torch.sum(bank[member_rows], dim=0)
                        counts[cluster_id] = len(member_rows)

                # Combine the statistics of all ranks. Clusters that are
                # empty everywhere keep their previous centroid.
                all_reduce_sum(counts)
                non_empty = counts > 0
                all_reduce_sum(emb_sums)
                centroids[non_empty] = (
                    emb_sums[non_empty] / counts[non_empty].unsqueeze(1)
                )

                # normalize centroids
                centroids = nn.functional.normalize(centroids, dim=1, p=2)

            getattr(self, "centroids" + str(head_id)).copy_(centroids)

            # gather the assignments
            assignments_all = gather_from_all(assignments)
            indexes_all = gather_from_all(self.local_memory_index)
            # -100 marks entries that received no assignment.
            self.assignments[head_id] = -100
            self.assignments[head_id][indexes_all] = assignments_all

            # Move on to the next memory bank, wrapping around.
            bank_id = (bank_id + 1) % self.nmb_mbs

    logging.info(f"Rank: {get_rank()}, clustering of the memory bank done")