def forward(self, indices, outputs, gpu_idx):
    """Compute the NCE loss for one batch and the refreshed bank entries.

    :param indices: (batch_size,) dataset indices of the current batch
    :param outputs: (batch_size, dim) embeddings; L2-normalized here
    :param gpu_idx: which replica in ``memory_bank_broadcast`` to read from
    :return: tuple of (loss tensor of shape (1,), new_data_memory rows to
        write back into the memory bank)
    """
    self.indices = indices.detach()
    self.outputs = l2_normalize(outputs, dim=1)
    self._bank = self.memory_bank_broadcast[gpu_idx]

    batch_size = self.indices.size(0)
    data_prob = self.compute_data_prob()
    noise_prob = self.compute_noise_prob()

    assert data_prob.size(0) == batch_size
    assert noise_prob.size(0) == batch_size
    assert noise_prob.size(1) == self.k

    # Noise distribution Pn is uniform over the dataset: 1 / data_len.
    base_prob = 1.0 / self.data_len
    eps = 1e-7

    ## Pmt: P(origin = data | z) = p / (p + k * Pn)
    data_div = data_prob + (self.k * base_prob + eps)
    # BUGFIX: eps was only in the divisor; data_prob can underflow to 0,
    # in which case torch.log(data_prob) produced -inf and poisoned the
    # loss with inf/NaN. Guard the numerator log the same way.
    ln_data = torch.log(data_prob + eps) - torch.log(data_div)

    ## Pon: P(origin = noise | z') = (k * Pn) / (p' + k * Pn)
    noise_div = noise_prob + (self.k * base_prob + eps)
    ln_noise = math.log(self.k * base_prob) - torch.log(noise_div)

    # Negative log-likelihood over data and noise terms, batch-averaged.
    curr_loss = -(torch.sum(ln_data) + torch.sum(ln_noise))
    curr_loss = curr_loss / batch_size

    new_data_memory = self.updated_new_data_memory(self.indices, self.outputs)

    return curr_loss.unsqueeze(0), new_data_memory
def _create(self):
    """Build the initial memory bank: random rows on the unit sphere.

    :return: detached (size, dim) tensor of L2-normalized random vectors
    """
    # Uniform samples in [-bound, bound), matching a std-dev heuristic.
    bound = 1. / np.sqrt(self.dim / 3)
    bank = torch.rand(self.size, self.dim, device=self.device)
    bank = bank * (2 * bound) - bound
    # Project every row onto the unit sphere so each entry has norm 1.
    bank = l2_normalize(bank, dim=1)
    # Detach: the bank is running state, not a trainable parameter.
    return bank.detach()
def __init__(self, indices, outputs, memory_bank, k=4096, t=0.07, m=0.5):
    """Cache the batch tensors, memory bank, and NCE hyper-parameters.

    :param indices: dataset indices for the batch (detached here)
    :param outputs: batch embeddings, L2-normalized on assignment
    :param memory_bank: bank object; its ``size`` gives the dataset length
    :param k: number of noise samples per data point
    :param t: temperature
    :param m: memory-bank momentum coefficient
    """
    self.k = k
    self.t = t
    self.m = m
    self.indices = indices.detach()
    self.outputs = l2_normalize(outputs)
    self.memory_bank = memory_bank
    self.data_len = memory_bank.size
    self.device = indices.device
def forward(self, indices, outputs, gpu_idx):
    """Compute a neighbour-based contrastive loss and refreshed bank rows.

    Finds the top-k "background" neighbours of each embedding in the
    memory bank, then keeps only the probability mass on neighbours that
    share a cluster assignment (under any of the k-means clusterings).

    :param back_nei_idxs: shape (batch_size, 4096)
    :param all_close_nei: shape (batch_size, _size_of_dataset) in byte
    """
    self.indices = indices.detach()
    self.outputs = l2_normalize(outputs, dim=1)
    self._bank = self.memory_bank_broadcast[
        gpu_idx]  # select a mem bank based on gpu device
    self._cluster_labels = self.cluster_label_broadcast[gpu_idx]
    k = self.k

    # Similarities of each output against every memory-bank entry;
    # keep the k largest as the background neighbourhood.
    all_dps = self._get_all_dot_products(self.outputs)
    back_nei_dps, back_nei_idxs = torch.topk(all_dps, k=k, sorted=False, dim=1)
    back_nei_probs = self._softmax(back_nei_dps)

    all_close_nei_in_back = None
    # NOTE(review): first dimension of cluster labels appears to index the
    # independent k-means runs — confirm against where it is built.
    no_kmeans = self._cluster_labels.size(0)
    with torch.no_grad():
        for each_k_idx in range(no_kmeans):
            # Mask of background neighbours in the same cluster as the
            # anchor under clustering `each_k_idx`.
            curr_close_nei = self.__get_close_nei_in_back(
                each_k_idx, self._cluster_labels, back_nei_idxs, k)
            if all_close_nei_in_back is None:
                all_close_nei_in_back = curr_close_nei
            else:
                # assuming all_close_nei and curr_close_nei are byte tensors:
                # union across clusterings — close under ANY of them counts.
                all_close_nei_in_back = all_close_nei_in_back | curr_close_nei

    # Fraction of background probability mass on close neighbours;
    # eps keeps the log finite when that mass is zero.
    relative_probs = self.__get_relative_prob(all_close_nei_in_back, back_nei_probs)
    loss = -torch.mean(torch.log(relative_probs + 1e-7)).unsqueeze(0)

    # compute new data memory
    new_data_memory = self.updated_new_data_memory(self.indices, self.outputs)

    return loss, new_data_memory
def updated_new_data_memory(self):
    """Momentum-blend the batch embeddings into their bank rows.

    :return: L2-normalized replacement rows for ``self.indices``
    """
    old_rows = self.memory_bank.at_idxs(self.indices)
    # Keep a fraction m of the stored entry, (1 - m) of the new output.
    blended = self.m * old_rows + (1 - self.m) * self.outputs
    return l2_normalize(blended, dim=1)
def updated_new_data_memory(self, indices, outputs):
    """Momentum-blend ``outputs`` into the bank rows at ``indices``.

    :param indices: row indices into the currently selected bank
    :param outputs: embeddings to merge in (normalized here first)
    :return: L2-normalized replacement rows
    """
    normed = l2_normalize(outputs)
    current = torch.index_select(self._bank, 0, indices)
    # EMA: weight m on the stored row, (1 - m) on the fresh embedding.
    updated = self.m * current + (1 - self.m) * normed
    return l2_normalize(updated, dim=1)