Example #1
0
class LocalAggregationLoss(nn.Module):
    """Loss that contrasts background neighbours with close neighbours.

    For every code vector in a batch two subsets of the memory bank are
    determined: the k-nearest neighbours ("set B" in the original comments)
    and the members sharing a k-means cluster with the vector in at least one
    of the repeated clusterings ("set C"). The loss is the batch mean of
    ``log(density(B)) - log(density(B & C))``.

    The ``memory_bank`` object must provide what this class reads from it:

    * ``vectors`` -- array handed to the neighbour finder and the clusterers,
      and sliced row-wise with ``np.compress``.
    * ``update_memory(code_data, indices)`` -- called once per forward pass.
    * ``mask(indices)`` -- its result is used as one selection mask per batch
      item (presumably boolean masks over the rows of ``vectors``; confirm
      against the memory bank implementation).

    Args:
        temperature: Divisor applied to the dot products before ``exp``.
        knns: Number of nearest neighbours; the finder is asked for
            ``knns + 1`` (presumably because a stored vector is returned as
            its own nearest neighbour -- confirm).
        clustering_repeats: Number of independent ``KMeans`` instances.
        n_centroids: Number of clusters for each ``KMeans`` instance.
        memory_bank: Object with the interface described above.
        kmeans_n_init: ``n_init`` forwarded to every ``KMeans`` instance.
        nn_metric: ``metric`` forwarded to ``NearestNeighbors``.
        nn_metric_params: ``metric_params`` forwarded to
            ``NearestNeighbors``; ``None`` (the default) means no parameters.
    """

    def __init__(self,
                 temperature,
                 knns,
                 clustering_repeats,
                 n_centroids,
                 memory_bank,
                 kmeans_n_init=1,
                 nn_metric=cosine_distance,
                 nn_metric_params=None):
        super(LocalAggregationLoss, self).__init__()

        self.temperature = temperature
        self.memory_bank = memory_bank

        # A fresh dict per instance: a mutable default argument would be
        # shared between all instances and with the object it is passed to.
        if nn_metric_params is None:
            metric_params = {}
        else:
            metric_params = dict(nn_metric_params)

        ## 1. Distance: efficiently compute nearest neighbours << set B in alg
        self.neighbour_finder = NearestNeighbors(
            n_neighbors=knns + 1,
            algorithm='ball_tree',
            metric=nn_metric,
            metric_params=metric_params)

        ## 2. Clusters: efficiently compute clusters << set C in alg
        self.clusterer = [
            KMeans(n_clusters=n_centroids,
                   init='random',
                   n_init=kmeans_n_init)
            for _ in range(clustering_repeats)
        ]

    def forward(self, codes, indices):
        """Update the memory bank with ``codes`` and return the scalar loss.

        Args:
            codes: 2-D tensor with one code vector per row.
            indices: Memory bank indices of the rows of ``codes``; must have
                as many entries as ``codes`` has rows.

        Returns:
            Scalar tensor: sum over the batch of ``log(d1) - log(d2)`` divided
            by the batch size, where ``d1`` is the density over the nearest
            neighbours and ``d2`` the density over their intersection with
            the cluster co-members.
        """
        assert codes.shape[0] == len(indices)
        codes = codes.type(torch.DoubleTensor)
        code_data = normalize(codes.detach().numpy(), axis=1)

        ## constants in the loss function; no gradients in the backward pass
        self.memory_bank.update_memory(code_data, indices)

        bg_neighbours = self._nearest_neighbours(code_data, indices)
        close_neighbours = self._close_grouper(indices)
        neighbour_intersect = self._intersecter(bg_neighbours,
                                                close_neighbours)

        ## compute pdf; this is the only part that carries gradients
        v = F.normalize(codes, p=2, dim=1)
        d1 = self._prob_density(v, bg_neighbours)
        d2 = self._prob_density(v, neighbour_intersect)

        return torch.sum(torch.log(d1) - torch.log(d2)) / codes.shape[0]

    def _nearest_neighbours(self, codes_data, indices):
        """Return the memory bank mask of the nearest neighbours of each code.

        The finder is refitted on the current memory bank vectors on every
        call. ``indices`` is accepted for signature symmetry but not used.
        """
        self.neighbour_finder.fit(self.memory_bank.vectors)
        # scikit-learn spells the query method ``kneighbors``.
        indices_nearest = self.neighbour_finder.kneighbors(
            codes_data, return_distance=False)
        return self.memory_bank.mask(indices_nearest)

    def _close_grouper(self, indices):
        """Return the memory bank mask of the cluster co-members of each item.

        Every clusterer is refitted on the current memory bank vectors; the
        co-members of an item are the union, over all clusterers, of the
        memory bank rows that share its cluster label.
        """
        # One independent accumulator per batch item.
        memberships = [np.array([], dtype=int) for _ in indices]

        for clusterer in self.clusterer:
            clusterer.fit(self.memory_bank.vectors)
            labels = clusterer.labels_
            for k_idx, cluster_idx in enumerate(labels[indices]):
                other_members = np.where(labels == cluster_idx)[0]
                merged = np.union1d(memberships[k_idx], other_members)
                memberships[k_idx] = merged.astype(int)

        # dtype=object because the member lists generally differ in length.
        return self.memory_bank.mask(np.array(memberships, dtype=object))

    def _intersecter(self, n1, n2):
        """Return the element-wise ``and`` of two equally shaped mask batches."""
        return np.array([[v1 and v2 for v1, v2 in zip(n1_x, n2_x)]
                         for n1_x, n2_x in zip(n1, n2)])

    def _prob_density(self, codes, indices):
        """Return the unnormalised, differentiable density per batch item.

        For each row of ``codes`` the memory bank vectors selected by the
        corresponding mask in ``indices`` are dotted with the row, divided by
        the temperature, exponentiated and summed.

        Args:
            codes: 2-D tensor with one code vector per row.
            indices: One selection mask over the memory bank rows per row of
                ``codes``.

        Returns:
            1-D tensor with one density value per row of ``codes``.
        """
        bank_vectors = self.memory_bank.vectors
        subset_sizes = {np.count_nonzero(mask) for mask in indices}

        # All subsets have the same size (always the case for the k-nearest
        # neighbour density): handle the batch dimension by broadcasting.
        if len(subset_sizes) == 1:
            # Stack in numpy first; building a tensor from a list of arrays
            # is a slow path in torch.
            selected = np.stack([
                np.compress(mask, bank_vectors, axis=0) for mask in indices
            ])
            vals = torch.tensor(selected, requires_grad=False)
            v_dots = torch.matmul(vals, codes.unsqueeze(-1))
            exp_values = torch.exp(torch.div(v_dots, self.temperature))
            return torch.sum(exp_values, dim=1).squeeze(-1)

        # Subsets differ in size, so broadcasting is not possible: loop over
        # the batch dimension and stack the per-item results.
        per_item = []
        for k_item, code in enumerate(codes):
            vals = torch.tensor(np.compress(indices[k_item],
                                            bank_vectors,
                                            axis=0),
                                requires_grad=False)
            v_dots = torch.mv(vals, code)
            exp_values = torch.exp(torch.div(v_dots, self.temperature))
            per_item.append(torch.sum(exp_values, dim=0))
        return torch.stack(per_item, dim=0)