Example #1
    def forward(self, indices, outputs, gpu_idx):
        self.indices = indices.detach()
        self.outputs = l2_normalize(outputs, dim=1)
        self._bank = self.memory_bank_broadcast[gpu_idx]

        batch_size = self.indices.size(0)
        data_prob = self.compute_data_prob()
        noise_prob = self.compute_noise_prob()

        assert data_prob.size(0) == batch_size
        assert noise_prob.size(0) == batch_size
        assert noise_prob.size(1) == self.k

        base_prob = 1.0 / self.data_len
        eps = 1e-7

        ## Pmt: probability that the positive pair comes from the data (model) distribution
        data_div = data_prob + (self.k * base_prob + eps)

        ln_data = torch.log(data_prob) - torch.log(data_div)

        ## Pon: probability that a noise sample comes from the noise distribution
        noise_div = noise_prob + (self.k * base_prob + eps)
        ln_noise = math.log(self.k * base_prob) - torch.log(noise_div)

        curr_loss = -(torch.sum(ln_data) + torch.sum(ln_noise))
        curr_loss = curr_loss / batch_size

        new_data_memory = self.updated_new_data_memory(self.indices,
                                                       self.outputs)

        return curr_loss.unsqueeze(0), new_data_memory
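
None of these examples define l2_normalize; a minimal sketch of such a helper, assuming it simply divides each row by its L2 norm (a dim=1 default would match the calls above that omit it), could be:

import torch

def l2_normalize(x, dim=1, eps=1e-12):
    # divide each row by its L2 norm so embeddings lie on the unit sphere;
    # eps guards against division by zero for an all-zero row
    return x / (x.norm(p=2, dim=dim, keepdim=True) + eps)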
Example #2
    def _create(self):
        # initialize with uniform random weights in [-std_dev, std_dev]
        mb_init = torch.rand(self.size, self.dim, device=self.device)
        std_dev = 1. / np.sqrt(self.dim / 3)
        mb_init = mb_init * (2 * std_dev) - std_dev
        # L2-normalise each row so every memory vector has unit norm
        mb_init = l2_normalize(mb_init, dim=1)
        return mb_init.detach()  # detach so it is not trainable
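
The scale 1 / sqrt(dim / 3) makes a row drawn uniformly from [-std_dev, std_dev] have expected squared norm of roughly 1 before the final normalisation (each coordinate has variance std_dev**2 / 3). A quick standalone check, with a hypothetical dim value:

import torch

dim = 128  # hypothetical embedding dimension
std_dev = 1. / (dim / 3) ** 0.5
x = (torch.rand(100000, dim) * 2 - 1) * std_dev  # uniform in [-std_dev, std_dev]
print(x.pow(2).sum(dim=1).mean())  # ≈ 1.0, i.e. unit expected squared norm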
Example #3
    def __init__(self, indices, outputs, memory_bank, k=4096, t=0.07, m=0.5):
        self.k, self.t, self.m = k, t, m

        self.indices = indices.detach()
        self.outputs = l2_normalize(outputs)

        self.memory_bank = memory_bank
        self.data_len = memory_bank.size
        self.device = indices.device
Example #4
    def forward(self, indices, outputs, gpu_idx):
        """
        :param back_nei_idxs: shape (batch_size, 4096)
        :param all_close_nei: shape (batch_size, _size_of_dataset) in byte
        """
        self.indices = indices.detach()
        self.outputs = l2_normalize(outputs, dim=1)
        self._bank = self.memory_bank_broadcast[gpu_idx]  # memory bank replica for this GPU
        self._cluster_labels = self.cluster_label_broadcast[gpu_idx]

        k = self.k

        all_dps = self._get_all_dot_products(self.outputs)
        back_nei_dps, back_nei_idxs = torch.topk(all_dps,
                                                 k=k,
                                                 sorted=False,
                                                 dim=1)
        back_nei_probs = self._softmax(back_nei_dps)

        all_close_nei_in_back = None
        no_kmeans = self._cluster_labels.size(0)
        with torch.no_grad():
            for each_k_idx in range(no_kmeans):
                curr_close_nei = self.__get_close_nei_in_back(
                    each_k_idx, self._cluster_labels, back_nei_idxs, k)

                if all_close_nei_in_back is None:
                    all_close_nei_in_back = curr_close_nei
                else:
                    # assuming all_close_nei and curr_close_nei are byte tensors
                    all_close_nei_in_back = all_close_nei_in_back | curr_close_nei

        relative_probs = self.__get_relative_prob(all_close_nei_in_back,
                                                  back_nei_probs)
        loss = -torch.mean(torch.log(relative_probs + 1e-7)).unsqueeze(0)

        # compute new data memory
        new_data_memory = self.updated_new_data_memory(self.indices,
                                                       self.outputs)

        return loss, new_data_memory
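
The private helpers are not part of this listing; a plausible sketch of what __get_relative_prob might compute, under the assumption that all_close_nei is a byte/bool mask over the k background neighbours, is the fraction of background probability mass that falls on the close neighbours:

import torch

def get_relative_prob(all_close_nei, back_nei_probs, eps=1e-7):
    # hypothetical sketch: probability mass assigned to close neighbours,
    # normalised by the total mass over all k background neighbours
    close_mass = torch.sum(all_close_nei.float() * back_nei_probs, dim=1)
    total_mass = torch.sum(back_nei_probs, dim=1)
    return close_mass / (total_mass + eps)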
Example #5
    def updated_new_data_memory(self):
        # momentum update: blend the stored embedding with the new output,
        # then re-normalise so the memory entry stays on the unit sphere
        data_memory = self.memory_bank.at_idxs(self.indices)
        new_data_memory = data_memory * self.m + (1 - self.m) * self.outputs
        return l2_normalize(new_data_memory, dim=1)
Example #6
    def updated_new_data_memory(self, indices, outputs):
        outputs = l2_normalize(outputs)
        # momentum update of the selected rows of the memory bank
        data_memory = torch.index_select(self._bank, 0, indices)
        new_data_memory = data_memory * self.m + (1 - self.m) * outputs
        return l2_normalize(new_data_memory, dim=1)
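
These methods return the updated rows but do not write them back; a sketch of how a caller might apply the update to the bank (assuming the bank is a plain tensor and the batch indices are unique) is:

import torch

def apply_memory_update(bank, indices, new_data_memory):
    # hypothetical write-back: overwrite the selected rows of the bank in place
    with torch.no_grad():
        bank.index_copy_(0, indices, new_data_memory)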