    def get_losses_for_batch(self, batch, train=True):
        indices, img1, img2, _ = batch
        outputs1 = self.forward(img1)

        if self.loss_name == 'SimCLR':
            outputs2 = self.forward(img2)
            loss_fn = SimCLR(outputs1, outputs2, 
                             t=self.config.loss_params.t)
            loss = loss_fn.get_loss()
        elif self.loss_name == 'MoCo':
            with torch.no_grad():
                self._momentum_update_key_encoder()
                if self.use_ddp or self.use_ddp2:
                    img2, idx_unshuffle = self._batch_shuffle_ddp(img2)
                outputs2 = self.model_k(img2)
                if self.use_ddp or self.use_ddp2:
                    outputs2 = self._batch_unshuffle_ddp(outputs2, idx_unshuffle)

            loss_fn = MoCo(outputs1, outputs2, 
                           self.moco_queue.clone().detach(),
                           t=self.config.loss_params.t)
            loss = loss_fn.get_loss()

            if train:
                outputs_k = utils.l2_normalize(outputs2, dim=1)
                self._dequeue_and_enqueue(outputs_k)
        else:
            raise Exception(f'Loss {self.loss_name} not supported.')

        if train:
            with torch.no_grad():
                new_data_memory = utils.l2_normalize(outputs1, dim=1)
                self.memory_bank.update(indices, new_data_memory)

        return loss
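All of the examples in this listing lean on an `l2_normalize` helper (called directly or through `utils`). Its definition is not shown here; a minimal sketch consistent with how it is called above (the `eps` guard is an assumption) would be:

    import torch

    def l2_normalize(x, dim=1, eps=1e-8):
        # Divide along `dim` by the L2 norm, guarding against zero vectors.
        return x / (x.norm(p=2, dim=dim, keepdim=True) + eps)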
Example #2
    def get_losses_for_batch(self, emb_dict, train=True):
        if self.loss_name == 'nce':
            loss_fn = NoiseConstrastiveEstimation(emb_dict['indices'], emb_dict['img_embs_1'], self.memory_bank,
                                                  k=self.config.loss_params.k,
                                                  t=self.t,
                                                  m=self.config.loss_params.m)
            loss = loss_fn.get_loss()
        elif self.loss_name == 'simclr':
            if 'img_embs_2' not in emb_dict:
                raise ValueError('img_embs_2 is required for SimCLR loss')
            loss_fn = SimCLRObjective(emb_dict['img_embs_1'], emb_dict['img_embs_2'], t=self.t)
            loss = loss_fn.get_loss()
        else:
            raise Exception(f'Objective {self.loss_name} is not supported.')

        if train:
            with torch.no_grad():
                if self.loss_name == 'nce':
                    new_data_memory = loss_fn.updated_new_data_memory()
                    self.memory_bank.update(emb_dict['indices'], new_data_memory)
                elif 'simclr' in self.loss_name:
                    outputs_avg = (utils.l2_normalize(emb_dict['img_embs_1'], dim=1) + 
                                   utils.l2_normalize(emb_dict['img_embs_2'], dim=1)) / 2.
                    indices = emb_dict['indices']
                    self.memory_bank.update(indices, outputs_avg)
                else:
                    raise Exception(f'Objective {self.loss_name} is not supported.')

        return loss
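For context, `emb_dict` is assumed to collect the batch indices and the encoder outputs for each augmented view. A hypothetical call site, mirroring the batch layout of Example #1:

    # Hypothetical wiring; the key names follow the checks above.
    indices, img1, img2, _ = batch
    emb_dict = {
        'indices': indices,
        'img_embs_1': self.forward(img1),
        'img_embs_2': self.forward(img2),  # only required for the simclr branch
    }
    loss = self.get_losses_for_batch(emb_dict, train=True)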
Example #3
    def __init__(self, outputs1, outputs2, queue, t=0.07):
        super().__init__()
        self.outputs1 = l2_normalize(outputs1, dim=1)
        self.outputs2 = l2_normalize(outputs2, dim=1)
        self.queue = queue.detach()
        self.t = t
        self.k = queue.size(0)
        self.device = self.outputs1.device
Example #4
    def get_losses_for_batch(self, batch):
        indices, inputs1, inputs2, _ = batch
        outputs1 = self.forward(inputs1)
        outputs2 = self.forward(inputs2)
        loss_fn = SimCLRObjective(outputs1,
                                  outputs2,
                                  t=self.config.loss_params.t)
        loss = loss_fn.get_loss()

        with torch.no_grad():  # for nearest neighbor
            new_data_memory = (l2_normalize(outputs1, dim=1) +
                               l2_normalize(outputs2, dim=1)) / 2.
            self.memory_bank.update(indices, new_data_memory)

        return loss
Example #5
    def _create(self):
        # initialize random weights uniformly in [-std_dev, std_dev]
        mb_init = torch.rand(self.size, self.dim, requires_grad=False)
        std_dev = 1. / np.sqrt(self.dim / 3)
        mb_init = mb_init * (2 * std_dev) - std_dev
        # L2-normalize so that each row has norm 1
        mb_init = l2_normalize(mb_init, dim=1)
        return mb_init
Example #6
    def _create(self):
        # initialize random weights uniformly in [-std_dev, std_dev]
        mb_init = torch.rand(self.size, self.dim, device=self.device)
        std_dev = 1. / np.sqrt(self.dim / 3)
        mb_init = mb_init * (2 * std_dev) - std_dev
        # L2-normalize so that each row has norm 1
        mb_init = l2_normalize(mb_init, dim=1)
        return mb_init.detach()  # detach so it is not trainable
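The scaling in both `_create` variants gives each vector unit expected squared norm even before the explicit normalization: a uniform variable on [-a, a] has variance a^2/3, and with a = 1/sqrt(dim/3) = sqrt(3/dim) each coordinate has variance 1/dim, so the dim coordinates sum to 1 in expectation. A quick check:

    import torch

    dim = 128
    std_dev = (3.0 / dim) ** 0.5                          # same as 1 / np.sqrt(dim / 3)
    v = torch.rand(10000, dim) * (2 * std_dev) - std_dev  # uniform on [-std_dev, std_dev]
    print(v.pow(2).sum(dim=1).mean())                     # ~1.0 before l2_normalize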
Example #7
    def __init__(self, indices, outputs, memory_bank, k=4096, t=0.07, m=0.5, **kwargs):
        self.k, self.t, self.m = k, t, m

        self.indices = indices.detach()
        self.outputs = l2_normalize(outputs, dim=1)

        self.memory_bank = memory_bank
        self.device = outputs.device
        self.data_len = memory_bank.size
Example #8
    def get_losses_for_batch(self, emb_dict):
        loss_function = AdversarialSimCLRLoss(
            embs1=emb_dict['view1_embs'],
            embs2=emb_dict['view2_embs'],
            t=self.config.loss_params.t,
            view_maker_loss_weight=self.config.loss_params.view_maker_loss_weight)
        encoder_loss, view_maker_loss = loss_function.get_loss()

        with torch.no_grad():
            new_data_memory = l2_normalize(emb_dict['view1_embs'].detach(),
                                           dim=1)
            self.memory_bank.update(emb_dict['indices'], new_data_memory)

        return encoder_loss, view_maker_loss
Example #9
    def __init__(self, config):
        super().__init__(config)

        self.loss_name = self.config.loss_params.name
        if self.loss_name == 'MoCo':
            self.model_k = self.create_encoder()

            for param_q, param_k in zip(self.model.parameters(), self.model_k.parameters()):
                param_k.data.copy_(param_q.data)  # initialize
                param_k.requires_grad = False     # do not update

            # create queue (k x out_dim), L2-normalized before registration
            moco_queue = torch.randn(
                self.config.loss_params.k,
                self.config.model_params.out_dim,
            )
            moco_queue = utils.l2_normalize(moco_queue, dim=1)
            self.register_buffer("moco_queue", moco_queue)
            self.register_buffer("moco_queue_ptr", torch.zeros(1, dtype=torch.long))
Example #10
    def get_losses_for_batch(self, emb_dict, train=True):
        if self.loss_name == 'AdversarialSimCLRLoss':
            view_maker_loss_weight = self.config.loss_params.view_maker_loss_weight
            loss_function = AdversarialSimCLRLoss(
                embs1=emb_dict['view1_embs'],
                embs2=emb_dict['view2_embs'],
                t=self.t,
                view_maker_loss_weight=view_maker_loss_weight
            )
            encoder_loss, view_maker_loss = loss_function.get_loss()
            img_embs = emb_dict['view1_embs'] 
        elif self.loss_name == 'AdversarialNCELoss':
            view_maker_loss_weight = self.config.loss_params.view_maker_loss_weight
            loss_function = AdversarialNCELoss(
                emb_dict['indices'],
                emb_dict['view1_embs'],
                self.memory_bank,
                k=self.config.loss_params.k,
                t=self.t,
                m=self.config.loss_params.m,
                view_maker_loss_weight=view_maker_loss_weight
            )
            encoder_loss, view_maker_loss = loss_function.get_loss()
            img_embs = emb_dict['view1_embs'] 
        else:
            raise Exception(f'Objective {self.loss_name} is not supported.') 
        
        # Update memory bank.
        if train:
            with torch.no_grad():
                if self.loss_name == 'AdversarialNCELoss':
                    new_data_memory = loss_function.updated_new_data_memory()
                    self.memory_bank.update(emb_dict['indices'], new_data_memory)
                else:
                    new_data_memory = utils.l2_normalize(img_embs, dim=1)
                    self.memory_bank.update(emb_dict['indices'], new_data_memory)

        return encoder_loss, view_maker_loss
Example #11
    def updated_new_data_memory(self):
        data_memory = self.memory_bank.at_idxs(self.indices)
        new_data_memory = data_memory * self.m + (1 - self.m) * self.outputs
        return l2_normalize(new_data_memory, dim=1)
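`updated_new_data_memory` assumes a `MemoryBank` that exposes `at_idxs` (gather rows) and `update` (scatter rows); the `_create` variants above presumably build its storage. A minimal self-contained sketch of that interface, under those assumptions:

    import torch

    class MemoryBank:
        def __init__(self, size, dim):
            self.size, self.dim = size, dim
            bank = torch.randn(size, dim)  # Examples #5/#6 use a uniform init instead
            self._bank = bank / bank.norm(dim=1, keepdim=True)

        def at_idxs(self, idxs):
            # Gather the stored embeddings for the given dataset indices.
            return self._bank[idxs]

        def update(self, idxs, new_vectors):
            # Overwrite the rows at `idxs` with the new (already normalized) embeddings.
            self._bank[idxs] = new_vectors.detach()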
Example #12
    def __init__(self, outputs1, outputs2, t, push_only=False):
        super().__init__()
        self.outputs1 = l2_normalize(outputs1, dim=1)
        self.outputs2 = l2_normalize(outputs2, dim=1)
        self.t = t
        self.push_only = push_only
Example #13
    def __init__(self, outputs1, outputs2, t=0.07):
        super().__init__()
        self.outputs1 = l2_normalize(outputs1, dim=1)
        self.outputs2 = l2_normalize(outputs2, dim=1)
        self.t = t
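Only the constructors of these SimCLR objectives appear in this listing. A sketch of a matching `get_loss`, simplified to use cross-view negatives only (full SimCLR also contrasts within each view), written as a method on the class above:

    def get_loss(self):
        # Rows of outputs1/outputs2 with the same index form the positive pairs;
        # every other row of outputs2 serves as a negative.
        witness = self.outputs1 @ self.outputs2.t() / self.t  # (B, B) cosine / t
        pos = witness.diagonal()
        loss = -pos + torch.logsumexp(witness, dim=1)         # -log softmax on the diagonal
        return loss.mean()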
Example #14
    def noise(self, batch_size):
        shape = (batch_size, self.pretrain_config.model_params.noise_dim)
        # Center noise at 0, then project to the unit sphere.
        noise = utils.l2_normalize(torch.rand(shape) - 0.5)
        return noise
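Since `l2_normalize` projects each row onto the unit sphere, every sampled noise vector has unit norm; a quick hypothetical check:

    z = self.noise(batch_size=32)
    print(z.norm(dim=1))  # all ~1.0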
Example #15
    def normalize_embeddings(self):
        self.embs1 = l2_normalize(self.embs1)
        self.embs2 = l2_normalize(self.embs2)