def sync_state(self):
    # Communications
    self._curr_correct_predictions_k = all_reduce_sum(self._curr_correct_predictions_k)
    self._curr_sample_count = all_reduce_sum(self._curr_sample_count)

    # Store results
    self._total_correct_predictions_k += self._curr_correct_predictions_k
    self._total_sample_count += self._curr_sample_count

    # Reset values until next sync
    self._curr_correct_predictions_k.zero_()
    self._curr_sample_count.zero_()
def sync_state(self):
    # Communications
    self._curr_correct_predictions_k = all_reduce_sum(self._curr_correct_predictions_k)
    self._curr_correct_targets = all_reduce_sum(self._curr_correct_targets)

    # Store results
    self._total_correct_predictions_k += self._curr_correct_predictions_k
    self._total_correct_targets += self._curr_correct_targets

    # Reset values until next sync
    self._curr_correct_predictions_k.zero_()
    self._curr_correct_targets.zero_()
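# Both sync_state() methods above rely on all_reduce_sum to sum the per-rank
# counters across trainers. A minimal sketch of such a helper on top of
# torch.distributed is shown below; it assumes a process group is already
# initialized, and the name is illustrative, not the library's actual helper.
import torch
import torch.distributed as dist


def all_reduce_sum_sketch(tensor: torch.Tensor) -> torch.Tensor:
    # Sum `tensor` element-wise across all ranks (in place) and return it.
    # Falls back to a no-op for single-process runs.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor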
def distributed_sinkhornknopp(self, Q: torch.Tensor):
    """
    Apply the distributed Sinkhorn-Knopp optimization on the scores matrix
    to find the assignments
    """
    with torch.no_grad():
        sum_Q = torch.sum(Q, dtype=Q.dtype)
        all_reduce_sum(sum_Q)
        Q /= sum_Q

        k = Q.shape[0]
        n = Q.shape[1]
        N = get_world_size() * Q.shape[1]

        # we follow the u, r, c and Q notations from
        # https://arxiv.org/abs/1911.05371
        r = torch.ones(k) / k
        c = torch.ones(n) / N
        if self.use_gpu:
            r = r.cuda(non_blocking=True)
            c = c.cuda(non_blocking=True)

        curr_sum = torch.sum(Q, dim=1, dtype=Q.dtype)
        all_reduce_sum(curr_sum)

        for _ in range(self.loss_config.num_iters):
            u = curr_sum
            Q *= (r / u).unsqueeze(1)
            Q *= (c / torch.sum(Q, dim=0, dtype=Q.dtype)).unsqueeze(0)
            curr_sum = torch.sum(Q, dim=1, dtype=Q.dtype)
            all_reduce_sum(curr_sum)
        return (Q / torch.sum(Q, dim=0, keepdim=True, dtype=Q.dtype)).t().float()
def sync_state(self): """ Globally syncing the state of each meter across all the trainers. We gather scores, targets, total sampled """ # Communications self._curr_sample_count = all_reduce_sum(self._curr_sample_count) self._scores = self.gather_scores(self._scores) self._targets = self.gather_targets(self._targets) # Store results self._total_sample_count += self._curr_sample_count # Reset values until next sync self._curr_sample_count.zero_()
def cluster_memory(self):
    self.start_idx = 0
    j = 0
    with torch.no_grad():
        for i_K, K in enumerate(self.num_clusters):
            # run distributed k-means

            # init centroids with elements from memory bank of rank 0
            centroids = torch.empty(K, self.embedding_dim).cuda(non_blocking=True)
            if get_rank() == 0:
                random_idx = torch.randperm(len(self.local_memory_embeddings[j]))[:K]
                assert len(random_idx) >= K, "please reduce the number of centroids"
                centroids = self.local_memory_embeddings[j][random_idx]
            dist.broadcast(centroids, 0)

            for n_iter in range(self.nmb_kmeans_iters + 1):
                # E step
                dot_products = torch.mm(self.local_memory_embeddings[j], centroids.t())
                _, assignments = dot_products.max(dim=1)

                # finish
                if n_iter == self.nmb_kmeans_iters:
                    break

                # M step
                where_helper = get_indices_sparse(assignments.cpu().numpy())
                counts = torch.zeros(K).cuda(non_blocking=True).int()
                emb_sums = torch.zeros(K, self.embedding_dim).cuda(non_blocking=True)
                for k in range(len(where_helper)):
                    if len(where_helper[k][0]) > 0:
                        emb_sums[k] = torch.sum(
                            self.local_memory_embeddings[j][where_helper[k][0]],
                            dim=0,
                        )
                        counts[k] = len(where_helper[k][0])
                all_reduce_sum(counts)
                mask = counts > 0
                all_reduce_sum(emb_sums)
                centroids[mask] = emb_sums[mask] / counts[mask].unsqueeze(1)

                # normalize centroids
                centroids = nn.functional.normalize(centroids, dim=1, p=2)

            getattr(self, "centroids" + str(i_K)).copy_(centroids)

            # gather the assignments
            assignments_all = gather_from_all(assignments)
            indexes_all = gather_from_all(self.local_memory_index)
            self.assignments[i_K] = -100
            self.assignments[i_K][indexes_all] = assignments_all

            j = (j + 1) % self.nmb_mbs

    logging.info(f"Rank: {get_rank()}, clustering of the memory bank done")
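# The M step above uses get_indices_sparse to group sample indices by the
# cluster they were assigned to, so per-cluster sums and counts can be
# accumulated. A straightforward NumPy equivalent is sketched below; the
# function name and signature are illustrative, not the actual helper.
import numpy as np


def indices_by_cluster_sketch(assignments: np.ndarray, num_clusters: int):
    # For each cluster k, return np.where's output: a tuple whose first
    # element is the array of sample indices currently assigned to k,
    # matching the where_helper[k][0] access pattern used above.
    return [np.where(assignments == k) for k in range(num_clusters)]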
def distributed_sinkhornknopp(self, Q: torch.Tensor):
    """
    Apply the distributed Sinkhorn-Knopp optimization on the scores matrix
    to find the assignments
    """
    eps_num_stab = 1e-12
    with torch.no_grad():
        # remove potential infs in Q
        # replace the inf entries with the max of the finite entries in Q
        mask = torch.isinf(Q)
        ind = torch.nonzero(mask)
        if len(ind) > 0:
            for i in ind:
                Q[i[0], i[1]] = 0
            m = torch.max(Q)
            for i in ind:
                Q[i[0], i[1]] = m

        sum_Q = torch.sum(Q, dtype=Q.dtype)
        all_reduce_sum(sum_Q)
        Q /= sum_Q

        k = Q.shape[0]
        n = Q.shape[1]
        N = self.world_size * Q.shape[1]

        # we follow the u, r, c and Q notations from
        # https://arxiv.org/abs/1911.05371
        r = torch.ones(k) / k
        c = torch.ones(n) / N
        if self.use_double_prec:
            r, c = r.double(), c.double()
        if self.use_gpu:
            r = r.cuda(non_blocking=True)
            c = c.cuda(non_blocking=True)

        for _ in range(self.nmb_sinkhornknopp_iters):
            u = torch.sum(Q, dim=1, dtype=Q.dtype)
            all_reduce_sum(u)

            # for numerical stability, add a small epsilon to Q
            # when some rows of Q sum to zero across all ranks
            if len(torch.nonzero(u == 0)) > 0:
                Q += eps_num_stab
                u = torch.sum(Q, dim=1, dtype=Q.dtype)
                all_reduce_sum(u)
            u = r / u

            # remove potential infs in "u"
            # replace the inf entries with the max of the finite entries in "u"
            mask = torch.isinf(u)
            ind = torch.nonzero(mask)
            if len(ind) > 0:
                for i in ind:
                    u[i[0]] = 0
                m = torch.max(u)
                for i in ind:
                    u[i[0]] = m

            Q *= u.unsqueeze(1)
            Q *= (c / torch.sum(Q, dim=0, dtype=Q.dtype)).unsqueeze(0)

        Q = (Q / torch.sum(Q, dim=0, keepdim=True, dtype=Q.dtype)).t().float()

        # hard assignment
        if self.num_iteration < self.temp_hard_assignment_iters:
            index_max = torch.max(Q, dim=1)[1]
            Q.zero_()
            Q.scatter_(1, index_max.unsqueeze(1), 1)
        return Q
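# For reference, a minimal single-process sketch of the same Sinkhorn-Knopp
# normalization, stripped of the all-reduce calls and the inf/zero safeguards
# used above. The function name and the default iteration count are
# illustrative; scores are assumed positive (e.g. exponentiated similarities).
import torch


def sinkhornknopp_sketch(scores: torch.Tensor, num_iters: int = 3) -> torch.Tensor:
    # scores: (num_prototypes, batch_size) matrix of positive scores
    with torch.no_grad():
        Q = scores / scores.sum()
        k, n = Q.shape
        r = torch.ones(k) / k  # target row marginals: uniform over prototypes
        c = torch.ones(n) / n  # target column marginals: uniform over samples
        for _ in range(num_iters):
            Q *= (r / Q.sum(dim=1)).unsqueeze(1)  # scale rows toward r
            Q *= (c / Q.sum(dim=0)).unsqueeze(0)  # scale columns toward c
        # each column sums to 1: a soft assignment over prototypes per sample
        return (Q / Q.sum(dim=0, keepdim=True)).t()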