    def sync_state(self):
        # Communications
        self._curr_correct_predictions_k = all_reduce_sum(
            self._curr_correct_predictions_k)
        self._curr_sample_count = all_reduce_sum(self._curr_sample_count)

        # Store results
        self._total_correct_predictions_k += self._curr_correct_predictions_k
        self._total_sample_count += self._curr_sample_count

        # Reset values until next sync
        self._curr_correct_predictions_k.zero_()
        self._curr_sample_count.zero_()
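
For reference, here is a minimal sketch of what an all_reduce_sum helper like the one used above could look like, assuming it simply wraps torch.distributed.all_reduce and degrades to a no-op outside distributed training (the helper actually shipped with the library may differ):

    import torch
    import torch.distributed as dist

    def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
        # Sum `tensor` over all ranks in place and return it.
        # No-op when torch.distributed is not initialized (single-process runs).
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        return tensor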
Example #2
    def sync_state(self):
        # Communications
        self._curr_correct_predictions_k = all_reduce_sum(
            self._curr_correct_predictions_k)
        self._curr_correct_targets = all_reduce_sum(self._curr_correct_targets)

        # Store results
        self._total_correct_predictions_k += self._curr_correct_predictions_k
        self._total_correct_targets += self._curr_correct_targets

        # Reset values until next sync
        self._curr_correct_predictions_k.zero_()
        self._curr_correct_targets.zero_()
Example #3
    def distributed_sinkhornknopp(self, Q: torch.Tensor):
        """
        Apply the distributed Sinkhorn-Knopp optimization on the scores
        matrix to find the assignments
        """
        with torch.no_grad():
            sum_Q = torch.sum(Q, dtype=Q.dtype)
            all_reduce_sum(sum_Q)
            Q /= sum_Q

            k = Q.shape[0]  # number of prototypes
            n = Q.shape[1]  # number of samples on this rank
            N = get_world_size() * Q.shape[1]  # total number of samples across all ranks

            # we follow the u, r, c and Q notations from
            # https://arxiv.org/abs/1911.05371
            r = torch.ones(k) / k
            c = torch.ones(n) / N

            if self.use_gpu:
                r = r.cuda(non_blocking=True)
                c = c.cuda(non_blocking=True)

            curr_sum = torch.sum(Q, dim=1, dtype=Q.dtype)
            all_reduce_sum(curr_sum)

            for _ in range(self.loss_config.num_iters):
                u = curr_sum
                Q *= (r / u).unsqueeze(1)
                Q *= (c / torch.sum(Q, dim=0, dtype=Q.dtype)).unsqueeze(0)
                curr_sum = torch.sum(Q, dim=1, dtype=Q.dtype)
                all_reduce_sum(curr_sum)
            return (
                Q / torch.sum(Q, dim=0, keepdim=True, dtype=Q.dtype)
            ).t().float()
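
To see what the routine above computes without a distributed setup, the following is a single-process sketch of the same Sinkhorn-Knopp iteration (a hypothetical helper, not part of the library). The returned matrix has one row per sample whose entries sum to 1 over the k prototypes:

    import torch

    def sinkhornknopp_single(Q: torch.Tensor, num_iters: int = 3) -> torch.Tensor:
        # Q: (num_prototypes, num_samples) score matrix on a single process
        with torch.no_grad():
            Q = Q / Q.sum()
            k, n = Q.shape
            r = torch.ones(k) / k   # target row marginals (uniform over prototypes)
            c = torch.ones(n) / n   # target column marginals (uniform over samples)
            for _ in range(num_iters):
                Q *= (r / Q.sum(dim=1)).unsqueeze(1)   # match row marginals
                Q *= (c / Q.sum(dim=0)).unsqueeze(0)   # match column marginals
            # renormalize each column (sample) to sum to 1, then transpose
            return (Q / Q.sum(dim=0, keepdim=True)).t()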
Example #4
    def sync_state(self):
        """
        Globally syncing the state of each meter across all the trainers.
        We gather the scores, targets and total sample count.
        """
        # Communications
        self._curr_sample_count = all_reduce_sum(self._curr_sample_count)
        self._scores = self.gather_scores(self._scores)
        self._targets = self.gather_targets(self._targets)

        # Store results
        self._total_sample_count += self._curr_sample_count

        # Reset values until next sync
        self._curr_sample_count.zero_()
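
For reference, a hypothetical gather helper illustrating what gather_scores / gather_targets could do under the hood, assuming equal tensor shapes on every rank (the library's own helpers may handle uneven shapes differently):

    import torch
    import torch.distributed as dist

    def gather_tensor_from_all(tensor: torch.Tensor) -> torch.Tensor:
        # Collect the per-rank copies of `tensor` and concatenate along dim 0.
        if not (dist.is_available() and dist.is_initialized()):
            return tensor
        gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, tensor)
        return torch.cat(gathered, dim=0)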
Example #5
    def cluster_memory(self):
        self.start_idx = 0
        j = 0
        with torch.no_grad():
            for i_K, K in enumerate(self.num_clusters):
                # run distributed k-means

                # init centroids with elements from memory bank of rank 0
                centroids = torch.empty(
                    K, self.embedding_dim).cuda(non_blocking=True)
                if get_rank() == 0:
                    random_idx = torch.randperm(
                        len(self.local_memory_embeddings[j]))[:K]
                    # randperm(...)[:K] only has K elements if the local
                    # memory bank holds at least K embeddings
                    assert len(random_idx) >= K, \
                        "please reduce the number of centroids"
                    centroids = self.local_memory_embeddings[j][random_idx]
                dist.broadcast(centroids, 0)

                for n_iter in range(self.nmb_kmeans_iters + 1):

                    # E step
                    dot_products = torch.mm(self.local_memory_embeddings[j],
                                            centroids.t())
                    _, assignments = dot_products.max(dim=1)

                    # after the final E step, keep the assignments and
                    # skip the centroid update
                    if n_iter == self.nmb_kmeans_iters:
                        break

                    # M step
                    where_helper = get_indices_sparse(
                        assignments.cpu().numpy())
                    counts = torch.zeros(K).cuda(non_blocking=True).int()
                    emb_sums = torch.zeros(
                        K, self.embedding_dim).cuda(non_blocking=True)
                    for k in range(len(where_helper)):
                        if len(where_helper[k][0]) > 0:
                            emb_sums[k] = torch.sum(
                                self.local_memory_embeddings[j][where_helper[k][0]],
                                dim=0,
                            )
                            counts[k] = len(where_helper[k][0])
                    all_reduce_sum(counts)
                    mask = counts > 0
                    all_reduce_sum(emb_sums)
                    centroids[mask] = emb_sums[mask] / counts[mask].unsqueeze(1)

                    # normalize centroids
                    centroids = nn.functional.normalize(centroids, dim=1, p=2)

                getattr(self, "centroids" + str(i_K)).copy_(centroids)
                # gather the assignments
                assignments_all = gather_from_all(assignments)
                indexes_all = gather_from_all(self.local_memory_index)
                # -100 marks samples with no assignment; the gathered
                # assignments are filled in at their memory-bank indexes
                self.assignments[i_K] = -100
                self.assignments[i_K][indexes_all] = assignments_all

                j = (j + 1) % self.nmb_mbs

        logging.info(f"Rank: {get_rank()}, clustering of the memory bank done")
Example #6
    def distributed_sinkhornknopp(self, Q: torch.Tensor):
        """
        Apply the distributed Sinkhorn-Knopp optimization on the scores
        matrix to find the assignments
        """
        eps_num_stab = 1e-12
        with torch.no_grad():
            # remove potential infs in Q
            # replace the inf entries with the max of the finite entries in Q
            mask = torch.isinf(Q)
            ind = torch.nonzero(mask)
            if len(ind) > 0:
                for i in ind:
                    Q[i[0], i[1]] = 0
                m = torch.max(Q)
                for i in ind:
                    Q[i[0], i[1]] = m
            sum_Q = torch.sum(Q, dtype=Q.dtype)
            all_reduce_sum(sum_Q)
            Q /= sum_Q

            k = Q.shape[0]  # number of prototypes
            n = Q.shape[1]  # number of samples on this rank
            N = self.world_size * Q.shape[1]  # total number of samples across all ranks

            # we follow the u, r, c and Q notations from
            # https://arxiv.org/abs/1911.05371
            r = torch.ones(k) / k
            c = torch.ones(n) / N
            if self.use_double_prec:
                r, c = r.double(), c.double()

            if self.use_gpu:
                r = r.cuda(non_blocking=True)
                c = c.cuda(non_blocking=True)

            for _ in range(self.nmb_sinkhornknopp_iters):
                u = torch.sum(Q, dim=1, dtype=Q.dtype)
                all_reduce_sum(u)

                # for numerical stability: if any row sum is zero, add a
                # small epsilon to Q and recompute the row sums so we never
                # divide by zero below
                if len(torch.nonzero(u == 0)) > 0:
                    Q += eps_num_stab
                    u = torch.sum(Q, dim=1, dtype=Q.dtype)
                    all_reduce_sum(u)
                u = r / u
                # remove potential infs in "u"
                # replace the inf entries with the max of the finite entries in "u"
                mask = torch.isinf(u)
                ind = torch.nonzero(mask)
                if len(ind) > 0:
                    for i in ind:
                        u[i[0]] = 0
                    m = torch.max(u)
                    for i in ind:
                        u[i[0]] = m

                Q *= u.unsqueeze(1)
                Q *= (c / torch.sum(Q, dim=0, dtype=Q.dtype)).unsqueeze(0)
            Q = (
                Q / torch.sum(Q, dim=0, keepdim=True, dtype=Q.dtype)
            ).t().float()

            # hard assignment
            if self.num_iteration < self.temp_hard_assignment_iters:
                index_max = torch.max(Q, dim=1)[1]
                Q.zero_()
                Q.scatter_(1, index_max.unsqueeze(1), 1)
            return Q
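
A tiny illustration (with made-up values) of the hard-assignment step at the end: torch.max picks the best prototype per sample and scatter_ turns each row of Q into a one-hot vector.

    import torch

    Q = torch.tensor([[0.2, 0.7, 0.1],
                      [0.6, 0.3, 0.1]])
    index_max = torch.max(Q, dim=1)[1]        # tensor([1, 0])
    Q.zero_()
    Q.scatter_(1, index_max.unsqueeze(1), 1)  # rows become one-hot: [[0, 1, 0], [1, 0, 0]]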