def get_losses_for_batch(self, batch, train=True):
    """Compute the contrastive loss for one batch and refresh stored state.

    Args:
        batch: a 4-tuple ``(indices, img1, img2, _)``; the last element is
            not used here.
        train: when true, the MoCo queue (MoCo objective only) and the
            memory bank are updated with this batch's embeddings.

    Returns:
        Whatever the selected loss object's ``get_loss()`` returns.

    Raises:
        Exception: if ``self.loss_name`` is neither ``'SimCLR'`` nor
            ``'MoCo'``.
    """
    indices, img1, img2, _ = batch
    outputs1 = self.forward(img1)

    if self.loss_name == 'SimCLR':
        outputs2 = self.forward(img2)
        loss_fn = SimCLR(outputs1, outputs2, t=self.config.loss_params.t)
        loss = loss_fn.get_loss()
    elif self.loss_name == 'MoCo':
        distributed = self.use_ddp or self.use_ddp2
        # The key branch runs without gradient tracking; only ``outputs1``
        # carries gradients into the loss.
        with torch.no_grad():
            self._momentum_update_key_encoder()
            if distributed:
                img2, idx_unshuffle = self._batch_shuffle_ddp(img2)
            outputs2 = self.model_k(img2)
            if distributed:
                # Restore the original sample order of the key embeddings.
                # The previous code passed the not-yet-assigned local
                # ``outputs_k`` here, which raised UnboundLocalError and
                # never fed the unshuffled keys into the loss.
                outputs2 = self._batch_unshuffle_ddp(outputs2, idx_unshuffle)

        loss_fn = MoCo(outputs1, outputs2,
                       self.moco_queue.clone().detach(),
                       t=self.config.loss_params.t)
        loss = loss_fn.get_loss()

        if train:
            # Queue the normalised keys of this batch for later batches.
            outputs_k = utils.l2_normalize(outputs2, dim=1)
            self._dequeue_and_enqueue(outputs_k)
    else:
        raise Exception(f'Loss {self.loss_name} not supported.')

    if train:
        with torch.no_grad():
            # No ``dim`` is passed here, so this relies on the default of
            # ``utils.l2_normalize``.
            new_data_memory = utils.l2_normalize(outputs1)
            self.memory_bank.update(indices, new_data_memory)

    return loss
def get_losses_for_batch(self, emb_dict, train=True):
    """Evaluate the configured objective and, when training, refresh the
    memory bank.

    Args:
        emb_dict: mapping with ``'indices'`` and ``'img_embs_1'``; the
            ``'simclr'`` objective additionally needs ``'img_embs_2'``.
        train: when true, the memory bank rows for ``emb_dict['indices']``
            are overwritten after the loss is computed.

    Returns:
        The result of the objective's ``get_loss()``.

    Raises:
        ValueError: ``'simclr'`` was selected but ``'img_embs_2'`` is absent.
        Exception: ``self.loss_name`` is neither ``'nce'`` nor ``'simclr'``.
    """
    objective = self.loss_name

    if objective == 'nce':
        loss_fn = NoiseConstrastiveEstimation(
            emb_dict['indices'],
            emb_dict['img_embs_1'],
            self.memory_bank,
            k=self.config.loss_params.k,
            t=self.t,
            m=self.config.loss_params.m,
        )
    elif objective == 'simclr':
        if 'img_embs_2' not in emb_dict:
            raise ValueError('img_embs_2 is required for SimCLR loss')
        loss_fn = SimCLRObjective(
            emb_dict['img_embs_1'],
            emb_dict['img_embs_2'],
            t=self.t,
        )
    else:
        raise Exception(f'Objective {objective} is not supported.')

    loss = loss_fn.get_loss()
    if not train:
        return loss

    with torch.no_grad():
        if objective == 'nce':
            # The NCE objective supplies its own replacement rows.
            refreshed = loss_fn.updated_new_data_memory()
        else:
            # Only 'simclr' can reach this point: store the mean of the
            # two normalised views.
            view_a = utils.l2_normalize(emb_dict['img_embs_1'], dim=1)
            view_b = utils.l2_normalize(emb_dict['img_embs_2'], dim=1)
            refreshed = (view_a + view_b) / 2.
        self.memory_bank.update(emb_dict['indices'], refreshed)

    return loss
def __init__(self, outputs1, outputs2, queue, t=0.07):
    """Store the inputs of the MoCo objective.

    Args:
        outputs1: first set of embeddings; kept L2-normalised along dim 1.
        outputs2: second set of embeddings; kept L2-normalised along dim 1.
        queue: tensor of queued embeddings; stored detached, and its first
            dimension is recorded as ``self.k``.
        t: scalar stored as ``self.t`` (presumably a softmax temperature;
            confirm against ``get_loss``).
    """
    super().__init__()

    normed_first = l2_normalize(outputs1, dim=1)
    normed_second = l2_normalize(outputs2, dim=1)
    self.outputs1 = normed_first
    self.outputs2 = normed_second

    # Keep the queue out of the autograd graph.
    self.queue = queue.detach()
    self.t = t
    self.k = queue.size(0)
    self.device = normed_first.device
def get_losses_for_batch(self, batch):
    """Compute the SimCLR loss for a batch and update the memory bank.

    Args:
        batch: a 4-tuple ``(indices, inputs1, inputs2, _)``; the last
            element is ignored.

    Returns:
        The result of ``SimCLRObjective(...).get_loss()``.
    """
    indices, inputs1, inputs2, _ = batch

    outputs1 = self.forward(inputs1)
    outputs2 = self.forward(inputs2)

    temperature = self.config.loss_params.t
    loss = SimCLRObjective(outputs1, outputs2, t=temperature).get_loss()

    # The memory bank is refreshed on every call (there is no train flag)
    # with the mean of the two normalised views; the original comment says
    # this is kept for nearest-neighbour lookups.
    with torch.no_grad():
        normed_first = l2_normalize(outputs1, dim=1)
        normed_second = l2_normalize(outputs2, dim=1)
        new_data_memory = (normed_first + normed_second) / 2.
        self.memory_bank.update(indices, new_data_memory)

    return loss
def _create(self):
    """Build the initial memory-bank tensor.

    Returns:
        A ``(self.size, self.dim)`` tensor whose entries are drawn uniformly
        from ``[-a, a)`` with ``a = 1 / sqrt(self.dim / 3)`` and then passed
        through ``l2_normalize(..., dim=1)``.
    """
    # Half-width of the uniform interval.
    half_width = 1. / np.sqrt(self.dim / 3)

    # torch.rand samples from [0, 1); stretch and shift to [-a, a).
    uniform = torch.rand(self.size, self.dim, requires_grad=False)
    centred = uniform * (2 * half_width) - half_width

    # Rescale so each row has unit L2 norm.
    return l2_normalize(centred, dim=1)
def _create(self):
    """Build the initial memory-bank tensor on ``self.device``.

    Returns:
        A detached ``(self.size, self.dim)`` tensor whose entries are drawn
        uniformly from ``[-a, a)`` with ``a = 1 / sqrt(self.dim / 3)`` and
        then passed through ``l2_normalize(..., dim=1)``.
    """
    # Half-width of the uniform interval.
    half_width = 1. / np.sqrt(self.dim / 3)

    # torch.rand samples from [0, 1); stretch and shift to [-a, a).
    uniform = torch.rand(self.size, self.dim, device=self.device)
    centred = uniform * (2 * half_width) - half_width

    # Rescale so each row has unit L2 norm.
    unit_rows = l2_normalize(centred, dim=1)

    # Detached so the bank is not trainable.
    return unit_rows.detach()
def __init__(self, indices, outputs, memory_bank, k=4096, t=0.07, m=0.5, **kwargs):
    """Store the inputs of the noise-contrastive objective.

    Args:
        indices: indices of the batch items; stored detached.
        outputs: embeddings for the batch; kept L2-normalised along dim 1.
        memory_bank: bank object; its ``size`` attribute is recorded as
            ``self.data_len``.
        k: stored as ``self.k``.
        t: stored as ``self.t``.
        m: stored as ``self.m``; used as the blend weight in
            ``updated_new_data_memory`` where that method is defined.
        **kwargs: accepted and ignored.
    """
    self.k = k
    self.t = t
    self.m = m

    # Indices never need gradients.
    self.indices = indices.detach()
    self.outputs = l2_normalize(outputs, dim=1)

    self.memory_bank = memory_bank
    self.device = outputs.device
    self.data_len = memory_bank.size
def get_losses_for_batch(self, emb_dict):
    """Compute encoder and view-maker losses and refresh the memory bank.

    Args:
        emb_dict: mapping providing ``'view1_embs'``, ``'view2_embs'`` and
            ``'indices'``.

    Returns:
        The pair ``(encoder_loss, view_maker_loss)`` returned by
        ``AdversarialSimCLRLoss.get_loss()``.
    """
    loss_params = self.config.loss_params
    first_view = emb_dict['view1_embs']

    loss_function = AdversarialSimCLRLoss(
        embs1=first_view,
        embs2=emb_dict['view2_embs'],
        t=loss_params.t,
        view_maker_loss_weight=loss_params.view_maker_loss_weight,
    )
    encoder_loss, view_maker_loss = loss_function.get_loss()

    # Only the first view is written to the memory bank, and it is written
    # on every call.
    with torch.no_grad():
        new_data_memory = l2_normalize(first_view.detach(), dim=1)
        self.memory_bank.update(emb_dict['indices'], new_data_memory)

    return encoder_loss, view_maker_loss
def __init__(self, config):
    """Initialise the system and, for the MoCo objective, its extra state.

    Args:
        config: passed to the parent constructor. ``self.config`` is read
            afterwards, so the parent is assumed to set it (confirm in the
            base class).

    When ``self.config.loss_params.name`` is ``'MoCo'`` this also creates
    ``self.model_k`` and registers the ``moco_queue`` and ``moco_queue_ptr``
    buffers.
    """
    super().__init__(config)
    self.loss_name = self.config.loss_params.name

    # Every other objective needs no additional state.
    if self.loss_name != 'MoCo':
        return

    # Second encoder: starts as an exact copy of ``self.model`` and is
    # excluded from gradient updates.
    self.model_k = self.create_encoder()
    paired_params = zip(self.model.parameters(), self.model_k.parameters())
    for query_param, key_param in paired_params:
        key_param.data.copy_(query_param.data)
        key_param.requires_grad = False

    # Queue of shape (k x out_dim), random at start.
    queue_rows = self.config.loss_params.k
    queue_cols = self.config.model_params.out_dim
    moco_queue = torch.randn(queue_rows, queue_cols)
    self.register_buffer("moco_queue", moco_queue)
    # Re-assign the buffer name to the row-normalised queue.
    self.moco_queue = utils.l2_normalize(moco_queue, dim=1)

    # Single-element integer buffer, initialised to zero.
    queue_ptr = torch.zeros(1, dtype=torch.long)
    self.register_buffer("moco_queue_ptr", queue_ptr)
def get_losses_for_batch(self, emb_dict, train=True):
    """Compute encoder and view-maker losses for the configured objective.

    Args:
        emb_dict: mapping providing ``'view1_embs'`` and ``'indices'``; the
            ``'AdversarialSimCLRLoss'`` objective also reads
            ``'view2_embs'``.
        train: when true, the memory bank rows for ``emb_dict['indices']``
            are overwritten after the losses are computed.

    Returns:
        The pair ``(encoder_loss, view_maker_loss)`` from the objective's
        ``get_loss()``.

    Raises:
        Exception: ``self.loss_name`` is not one of the two supported
            objectives.
    """
    objective = self.loss_name
    if objective not in ('AdversarialSimCLRLoss', 'AdversarialNCELoss'):
        raise Exception(f'Objective {objective} is not supported.')

    is_nce = objective == 'AdversarialNCELoss'
    view_maker_loss_weight = self.config.loss_params.view_maker_loss_weight

    if is_nce:
        loss_function = AdversarialNCELoss(
            emb_dict['indices'],
            emb_dict['view1_embs'],
            self.memory_bank,
            k=self.config.loss_params.k,
            t=self.t,
            m=self.config.loss_params.m,
            view_maker_loss_weight=view_maker_loss_weight,
        )
    else:
        loss_function = AdversarialSimCLRLoss(
            embs1=emb_dict['view1_embs'],
            embs2=emb_dict['view2_embs'],
            t=self.t,
            view_maker_loss_weight=view_maker_loss_weight,
        )
    encoder_loss, view_maker_loss = loss_function.get_loss()

    # Update memory bank.
    if train:
        with torch.no_grad():
            if is_nce:
                # The NCE objective supplies its own replacement rows.
                new_data_memory = loss_function.updated_new_data_memory()
            else:
                # Otherwise store the normalised first view.
                new_data_memory = utils.l2_normalize(emb_dict['view1_embs'], dim=1)
            self.memory_bank.update(emb_dict['indices'], new_data_memory)

    return encoder_loss, view_maker_loss
def updated_new_data_memory(self):
    """Return replacement memory-bank rows for this batch.

    The rows currently stored at ``self.indices`` are blended with
    ``self.outputs`` as ``old * m + (1 - m) * outputs`` and the result is
    passed through ``l2_normalize(..., dim=1)``.
    """
    momentum = self.m
    stored_rows = self.memory_bank.at_idxs(self.indices)
    blended = stored_rows * momentum + (1 - momentum) * self.outputs
    return l2_normalize(blended, dim=1)
def __init__(self, outputs1, outputs2, t, push_only=False):
    """Store the inputs of the SimCLR objective.

    Args:
        outputs1: embeddings of the first view; kept L2-normalised along
            dim 1.
        outputs2: embeddings of the second view; kept L2-normalised along
            dim 1.
        t: stored as ``self.t``.
        push_only: flag stored as ``self.push_only``; its effect is defined
            wherever the loss is computed, not here.
    """
    super().__init__()

    # Configuration.
    self.t = t
    self.push_only = push_only

    # Normalised embeddings of both views.
    normed_first = l2_normalize(outputs1, dim=1)
    normed_second = l2_normalize(outputs2, dim=1)
    self.outputs1 = normed_first
    self.outputs2 = normed_second
def __init__(self, outputs1, outputs2, t=0.07):
    """Store the inputs of the SimCLR objective.

    Args:
        outputs1: embeddings of the first view.
        outputs2: embeddings of the second view.
        t: stored as ``self.t``; defaults to 0.07.

    Both embedding tensors are kept L2-normalised along dim 1.
    """
    super().__init__()
    self.outputs1, self.outputs2 = (
        l2_normalize(view, dim=1) for view in (outputs1, outputs2)
    )
    self.t = t
def noise(self, batch_size):
    """Sample a batch of normalised noise vectors.

    Args:
        batch_size: number of noise vectors to draw.

    Returns:
        The result of ``utils.l2_normalize`` applied to a
        ``(batch_size, noise_dim)`` tensor of uniform samples shifted to
        ``[-0.5, 0.5)``, where ``noise_dim`` comes from
        ``self.pretrain_config.model_params``.
    """
    noise_dim = self.pretrain_config.model_params.noise_dim

    # torch.rand samples from [0, 1); subtracting 0.5 centres it at zero.
    centred = torch.rand((batch_size, noise_dim)) - 0.5

    # No ``dim`` is passed, so this relies on the default of
    # ``utils.l2_normalize``.
    return utils.l2_normalize(centred)
def normalize_embeddings(self):
    """Replace ``self.embs1`` and ``self.embs2`` with their
    ``l2_normalize`` results (default arguments), in that order."""
    for attr_name in ('embs1', 'embs2'):
        setattr(self, attr_name, l2_normalize(getattr(self, attr_name)))