Пример #1
0
 def __init__(self, config):
     """Build datasets, encoder and memory bank from ``config``.

     Reads ``config.optim_params.batch_size`` and
     ``config.model_params.out_dim`` directly; dataset and encoder
     construction are delegated to ``create_datasets`` and
     ``create_encoder``.
     """
     super().__init__()
     self.config = config
     self.batch_size = config.optim_params.batch_size

     # The datasets have to exist first: the memory bank below is sized
     # by the number of training examples.
     train_set, val_set = self.create_datasets()
     self.train_dataset = train_set
     self.val_dataset = val_set

     self.model = self.create_encoder()
     self.memory_bank = MemoryBank(
         len(self.train_dataset),
         self.config.model_params.out_dim,
     )
     # Speaker ids as exposed by the training dataset.
     self.train_ordered_labels = self.train_dataset.all_speaker_ids
Пример #2
0
 def __init__(self, config):
     """Build datasets, encoder and the two memory banks from ``config``.

     One bank stores embeddings (``config.model_params.out_dim`` values
     per training example); the other stores one integer per training
     example.
     """
     super().__init__()
     self.config = config
     self.batch_size = config.optim_params.batch_size

     train_set, val_set = self.create_datasets()
     self.train_dataset = train_set
     self.val_dataset = val_set

     self.model = self.create_encoder()

     # Both banks have one row per training example.
     self.memory_bank = MemoryBank(
         len(self.train_dataset),
         self.config.model_params.out_dim,
     )
     self.memory_bank_labels = MemoryBank(
         len(self.train_dataset),
         1,
         dtype=int,
     )
Пример #3
0
    def __init__(self, config):
        """Build image datasets, encoder and memory bank from ``config``.

        Note that ``super(PretrainViewMakerSystem, self)`` starts the MRO
        lookup *after* ``PretrainViewMakerSystem``, so that class's own
        ``__init__`` is not run.
        """
        super(PretrainViewMakerSystem, self).__init__()
        self.config = config
        self.batch_size = config.optim_params.batch_size
        self.loss_name = self.config.loss_params.name
        self.t = self.config.loss_params.t

        # A missing config entry comes back as an empty DotMap; treat that
        # as "use all default augmentations".
        augmentations = self.config.data_params.default_augmentations
        if augmentations == DotMap():
            augmentations = 'all'

        train_set, val_set = datasets.get_image_datasets(
            config.data_params.dataset,
            default_augmentations=augmentations,
        )
        self.train_dataset = train_set
        self.val_dataset = val_set

        # Training labels as an array, indexable by example position.
        self.train_ordered_labels = np.array(self.train_dataset.dataset.targets)

        self.model = self.create_encoder()
        self.memory_bank = MemoryBank(
            len(self.train_dataset),
            self.config.model_params.out_dim,
        )
Пример #4
0
    def __init__(self, config):
        """Build image datasets, encoder, viewmaker and memory bank.

        ``config.data_params.default_augmentations`` falls back to
        ``'none'`` when it is falsy.
        """
        super().__init__()
        self.config = config
        self.batch_size = config.optim_params.batch_size
        self.loss_name = self.config.loss_params.objective
        self.t = self.config.loss_params.t

        dataset_name = config.data_params.dataset
        augmentations = config.data_params.default_augmentations or 'none'
        train_set, val_set = datasets.get_image_datasets(
            dataset_name,
            augmentations,
        )
        self.train_dataset = train_set
        self.val_dataset = val_set

        # Training labels, used for computing knn validation accuracy.
        self.train_ordered_labels = np.array(self.train_dataset.dataset.targets)

        self.model = self.create_encoder()
        self.viewmaker = self.create_viewmaker()

        # Memory bank, also used for computing knn validation accuracy.
        self.memory_bank = MemoryBank(
            len(self.train_dataset),
            self.config.model_params.out_dim,
        )
Пример #5
0
class PretrainExpertInstDiscSystem(pl.LightningModule):
    """Instance-discrimination pretraining on LibriSpeech.

    The encoder is trained with a noise-contrastive-estimation loss
    against a memory bank. Validation reports a 1-nearest-neighbour
    accuracy obtained by looking validation embeddings up in that bank.
    """

    def __init__(self, config):
        """Build datasets, encoder and memory bank from ``config``."""
        super().__init__()
        self.config = config
        self.batch_size = config.optim_params.batch_size

        train_set, val_set = self.create_datasets()
        self.train_dataset = train_set
        self.val_dataset = val_set

        self.model = self.create_encoder()
        # One row per training example, ``out_dim`` values per row.
        self.memory_bank = MemoryBank(
            len(self.train_dataset),
            self.config.model_params.out_dim,
        )
        self.train_ordered_labels = self.train_dataset.all_speaker_ids

    def create_datasets(self):
        """Return ``(train_dataset, val_dataset)`` LibriSpeech instances.

        The training set uses either spectral or waveform transforms,
        chosen by ``config.data_params.spectral_transforms``; the
        validation set uses neither.
        """
        print('Initializing train dataset.')
        data_params = self.config.data_params
        train_set = LibriSpeech(
            train=True,
            spectral_transforms=data_params.spectral_transforms,
            wavform_transforms=not data_params.spectral_transforms,
            small=data_params.small,
            input_size=data_params.input_size,
        )

        print('Initializing validation dataset.')
        val_set = LibriSpeech(
            train=False,
            spectral_transforms=False,
            wavform_transforms=False,
            small=data_params.small,
            test_url=data_params.test_url,
            input_size=data_params.input_size,
        )
        return train_set, val_set

    def create_encoder(self):
        """Create the encoder, optionally with an MLP projection head.

        Uses ``resnet_small.ResNet18`` when
        ``config.model_params.resnet_small`` is set, otherwise the
        torchvision model named by ``config.model_params.resnet_version``
        with its first convolution replaced by a single-input-channel one.
        """
        model_params = self.config.model_params

        if model_params.resnet_small:
            encoder = resnet_small.ResNet18(
                model_params.out_dim,
                num_channels=1,
                input_size=64,
            )
        else:
            resnet_cls = getattr(torchvision.models, model_params.resnet_version)
            encoder = resnet_cls(
                pretrained=False,
                num_classes=model_params.out_dim,
            )
            # Swap the stem so the network accepts one input channel.
            encoder.conv1 = nn.Conv2d(
                1, 64, kernel_size=7, stride=2, padding=3, bias=False)

        if not model_params.projection_head:
            return encoder

        # Prepend Linear + ReLU to the existing final layer.
        hidden_dim = encoder.fc.weight.size(1)
        encoder.fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            encoder.fc,
        )
        return encoder

    def configure_optimizers(self):
        """Return a single SGD optimizer over the encoder, no schedulers."""
        optim_params = self.config.optim_params
        sgd = torch.optim.SGD(
            self.model.parameters(),
            lr=optim_params.learning_rate,
            momentum=optim_params.momentum,
            weight_decay=optim_params.weight_decay,
        )
        return [sgd], []

    def forward(self, inputs):
        """Run ``inputs`` through the encoder."""
        return self.model(inputs)

    def get_losses_for_batch(self, batch):
        """Compute the NCE loss for ``batch`` and refresh the memory bank.

        ``batch`` is unpacked as ``(indices, inputs, _)``.
        """
        indices, inputs, _ = batch
        embeddings = self.forward(inputs)

        loss_params = self.config.loss_params
        nce = NoiseConstrastiveEstimation(
            indices,
            embeddings,
            self.memory_bank,
            k=loss_params.k,
            t=loss_params.t,
            m=loss_params.m,
        )
        loss = nce.get_loss()

        # The bank update is bookkeeping only; keep it out of autograd.
        with torch.no_grad():
            self.memory_bank.update(indices, nce.updated_new_data_memory())

        return loss

    def training_step(self, batch, batch_idx):
        """Return the loss for one batch, also logged under ``'loss'``."""
        loss = self.get_losses_for_batch(batch)
        return {'loss': loss, 'log': {'loss': loss}}

    def get_nearest_neighbor_label(self, embs, labels):
        """
        NOTE: ONLY TO BE USED FOR VALIDATION.

        Look up, for every embedding, the single most similar memory-bank
        entry (by dot product) and take that entry's training label as the
        prediction. Returns ``(number of correct predictions, batch size)``.
        """
        similarities = self.memory_bank.get_all_dot_products(embs)
        top_idx = torch.topk(similarities, k=1, sorted=False, dim=1)[1]
        top_idx = top_idx.squeeze(1).cpu().numpy()

        predicted = torch.from_numpy(self.train_ordered_labels[top_idx]).long()
        hits = predicted.cpu() == labels.cpu()

        return torch.sum(hits).item(), embs.size(0)

    def validation_step(self, batch, batch_idx):
        """Return correct/total counts for one validation batch."""
        _, inputs, speaker_ids = batch
        embeddings = self.model(inputs)
        correct, total = self.get_nearest_neighbor_label(
            embeddings, speaker_ids)
        return OrderedDict([
            ('val_num_correct',
             torch.tensor(correct, dtype=float, device=self.device)),
            ('val_num_total',
             torch.tensor(total, dtype=float, device=self.device)),
        ])

    def validation_epoch_end(self, outputs):
        """Aggregate per-batch outputs into epoch metrics and ``val_acc``."""
        # Mean of every key returned by validation_step.
        metrics = {
            key: torch.stack([step_out[key] for step_out in outputs]).mean()
            for key in outputs[0].keys()
        }

        total_correct = torch.stack(
            [step_out['val_num_correct'] for step_out in outputs]).sum()
        total_seen = torch.stack(
            [step_out['val_num_total'] for step_out in outputs]).sum()

        val_acc = total_correct / float(total_seen)
        metrics['val_acc'] = val_acc
        return {'log': metrics, 'val_acc': val_acc}

    def train_dataloader(self):
        """Dataloader over the training set."""
        loader = create_dataloader(
            self.train_dataset, self.config, self.batch_size)
        return loader

    def val_dataloader(self):
        """Dataloader over the validation set, without shuffling."""
        loader = create_dataloader(
            self.val_dataset, self.config, self.batch_size, shuffle=False)
        return loader
Пример #6
0
class PretrainViewMakerSystem(pl.LightningModule):
    '''Pytorch Lightning System for self-supervised pretraining
    with adversarially generated views.

    Two optimizers are used: index 0 updates the encoder, index 1 updates
    the viewmaker. A memory bank of training embeddings is maintained for
    the NCE objective and for the online kNN validation accuracy.
    '''

    def __init__(self, config):
        '''Build datasets, encoder, viewmaker and memory bank from config.'''
        super().__init__()
        self.config = config
        self.batch_size = config.optim_params.batch_size
        self.loss_name = self.config.loss_params.objective
        self.t = self.config.loss_params.t

        self.train_dataset, self.val_dataset = datasets.get_image_datasets(
            config.data_params.dataset,
            config.data_params.default_augmentations or 'none',
        )
        # Used for computing knn validation accuracy
        train_labels = self.train_dataset.dataset.targets
        self.train_ordered_labels = np.array(train_labels)

        self.model = self.create_encoder()
        self.viewmaker = self.create_viewmaker()

        # Used for computing knn validation accuracy.
        self.memory_bank = MemoryBank(
            len(self.train_dataset),
            self.config.model_params.out_dim,
        )

    def view(self, imgs):
        '''Generate normalized views of `imgs` with the viewmaker.

        Raises:
            RuntimeError: if the configured system name contains 'Expert'.
        '''
        if 'Expert' in self.config.system:
            raise RuntimeError('Cannot call self.view() with Expert system')
        views = self.viewmaker(imgs)
        views = self.normalize(views)
        return views

    def create_encoder(self):
        '''Create the encoder model.

        Optionally wraps the final `fc` layer in a Linear + ReLU
        projection head when `config.model_params.projection_head` is set.
        '''
        if self.config.model_params.resnet_small:
            # ResNet variant for smaller inputs (e.g. CIFAR-10).
            encoder_model = resnet_small.ResNet18(self.config.model_params.out_dim)
        else:
            resnet_class = getattr(
                torchvision.models,
                self.config.model_params.resnet_version,
            )
            encoder_model = resnet_class(
                pretrained=False,
                num_classes=self.config.model_params.out_dim,
            )
        if self.config.model_params.projection_head:
            mlp_dim = encoder_model.fc.weight.size(1)
            encoder_model.fc = nn.Sequential(
                nn.Linear(mlp_dim, mlp_dim),
                nn.ReLU(),
                encoder_model.fc,
            )
        return encoder_model

    def create_viewmaker(self):
        '''Create the viewmaker network from `config.model_params`.

        Falsy config entries fall back to the defaults written below
        ('relu', False, False, 5).
        '''
        view_model = viewmaker.Viewmaker(
            num_channels=self.train_dataset.NUM_CHANNELS,
            distortion_budget=self.config.model_params.view_bound_magnitude,
            activation=self.config.model_params.generator_activation or 'relu',
            clamp=self.config.model_params.clamp_views,
            frequency_domain=self.config.model_params.spectral or False,
            downsample_to=self.config.model_params.viewmaker_downsample or False,
            num_res_blocks=self.config.model_params.num_res_blocks or 5,
        )
        return view_model

    def noise(self, batch_size, device):
        '''Sample a `(batch_size, noise_dim)` batch of l2-normalized noise.'''
        shape = (batch_size, self.config.model_params.noise_dim)
        # Center noise at 0 then project to unit sphere.
        noise = utils.l2_normalize(torch.rand(shape, device=device) - 0.5)
        return noise

    def get_repr(self, img):
        '''Get the representation for a given image.'''
        if 'Expert' not in self.config.system:
            # The Expert system datasets are normalized already.
            img = self.normalize(img)
        return self.model(img)

    def normalize(self, imgs):
        '''Standardize `imgs` per channel with dataset statistics.

        The broadcasting below indexes the channel on dim 1 of a 4-D
        tensor. Only datasets whose name contains 'cifar' are supported.

        Raises:
            ValueError: for any other dataset name.
        '''
        # These numbers were computed using compute_image_dset_stats.py
        if 'cifar' in self.config.data_params.dataset:
            mean = torch.tensor([0.491, 0.482, 0.446], device=imgs.device)
            std = torch.tensor([0.247, 0.243, 0.261], device=imgs.device)
        else:
            raise ValueError(f'Dataset normalizer for {self.config.data_params.dataset} not implemented')
        imgs = (imgs - mean[None, :, None, None]) / std[None, :, None, None]
        return imgs

    def forward(self, batch, train=True):
        '''Generate views for a batch and embed them.

        Args:
            batch: 5-tuple unpacked as `(indices, img, img2, neg_img, _)`.
            train: only used in the caption of the logged example views.

        Returns:
            dict with 'indices' and 'view1_embs', plus 'view2_embs' for
            the AdversarialSimCLRLoss objective.

        Raises:
            ValueError: if `self.loss_name` is not a supported objective.
        '''
        indices, img, img2, neg_img, _, = batch
        if self.loss_name == 'AdversarialNCELoss':
            view1 = self.view(img)
            view1_embs = self.model(view1)
            emb_dict = {
                'indices': indices,
                'view1_embs': view1_embs,
            }
        elif self.loss_name == 'AdversarialSimCLRLoss':
            if self.config.model_params.double_viewmaker:
                view1, view2 = self.view(img)
            else:
                view1 = self.view(img)
                view2 = self.view(img2)
            emb_dict = {
                'indices': indices,
                'view1_embs': self.model(view1),
                'view2_embs': self.model(view2),
            }
        else:
            raise ValueError(f'Unimplemented loss_name {self.loss_name}.')

        if self.global_step % 200 == 0:
            # Log some example views.
            # NOTE(review): the permute presumably converts NCHW to NHWC
            # for wandb.Image - confirm against the viewmaker output.
            views_to_log = view1.permute(0,2,3,1).detach().cpu().numpy()[:10]
            wandb.log({"examples": [wandb.Image(view, caption=f"Epoch: {self.current_epoch}, Step {self.global_step}, Train {train}") for view in views_to_log]})

        return emb_dict

    def get_losses_for_batch(self, emb_dict, train=True):
        '''Compute encoder and viewmaker losses from the embeddings.

        When `train` is true the memory bank is also refreshed: with the
        NCE objective's own updated memory, or otherwise with the
        l2-normalized first-view embeddings.

        Returns:
            `(encoder_loss, view_maker_loss)`.

        Raises:
            Exception: if `self.loss_name` is not a supported objective.
        '''
        if self.loss_name == 'AdversarialSimCLRLoss':
            view_maker_loss_weight = self.config.loss_params.view_maker_loss_weight
            loss_function = AdversarialSimCLRLoss(
                embs1=emb_dict['view1_embs'],
                embs2=emb_dict['view2_embs'],
                t=self.t,
                view_maker_loss_weight=view_maker_loss_weight
            )
            encoder_loss, view_maker_loss = loss_function.get_loss()
            img_embs = emb_dict['view1_embs']
        elif self.loss_name == 'AdversarialNCELoss':
            view_maker_loss_weight = self.config.loss_params.view_maker_loss_weight
            loss_function = AdversarialNCELoss(
                emb_dict['indices'],
                emb_dict['view1_embs'],
                self.memory_bank,
                k=self.config.loss_params.k,
                t=self.t,
                m=self.config.loss_params.m,
                view_maker_loss_weight=view_maker_loss_weight
            )
            encoder_loss, view_maker_loss = loss_function.get_loss()
            img_embs = emb_dict['view1_embs']
        else:
            raise Exception(f'Objective {self.loss_name} is not supported.')

        # Update memory bank.
        if train:
            with torch.no_grad():
                if self.loss_name == 'AdversarialNCELoss':
                    new_data_memory = loss_function.updated_new_data_memory()
                    self.memory_bank.update(emb_dict['indices'], new_data_memory)
                else:
                    new_data_memory = utils.l2_normalize(img_embs, dim=1)
                    self.memory_bank.update(emb_dict['indices'], new_data_memory)

        return encoder_loss, view_maker_loss

    def get_nearest_neighbor_label(self, img_embs, labels):
        '''
        Used for online kNN classifier.
        For each image in validation, find the nearest image in the
        training dataset using the memory bank. Assume its label as
        the predicted label.

        Returns:
            `(num_correct, batch_size)` as a Python number and an int.
        '''
        batch_size = img_embs.size(0)
        all_dps = self.memory_bank.get_all_dot_products(img_embs)
        _, neighbor_idxs = torch.topk(all_dps, k=1, sorted=False, dim=1)
        neighbor_idxs = neighbor_idxs.squeeze(1)
        neighbor_idxs = neighbor_idxs.cpu().numpy()

        neighbor_labels = self.train_ordered_labels[neighbor_idxs]
        neighbor_labels = torch.from_numpy(neighbor_labels).long()

        num_correct = torch.sum(neighbor_labels.cpu() == labels.cpu()).item()
        return num_correct, batch_size

    def training_step(self, batch, batch_idx, optimizer_idx):
        '''Embed the batch; the loss itself is computed in training_step_end.

        The optimizer index is carried along as a tensor inside the
        returned dict so that `training_step_end` can pick the right loss.
        '''
        emb_dict = self.forward(batch)
        emb_dict['optimizer_idx'] = torch.tensor(optimizer_idx, device=self.device)
        return emb_dict

    def training_step_end(self, emb_dict):
        '''Return the loss matching the active optimizer.

        Optimizer 0 gets the encoder loss, any other index the viewmaker
        loss.
        '''
        encoder_loss, view_maker_loss = self.get_losses_for_batch(emb_dict, train=True)

        # Handle Tensor (dp) and int (ddp) cases
        raw_idx = emb_dict['optimizer_idx']
        if isinstance(raw_idx, int) or raw_idx.dim() == 0:
            optimizer_idx = raw_idx
        else:
            # Non-scalar tensor: every entry carries the same index.
            optimizer_idx = raw_idx[0]
        if optimizer_idx == 0:
            metrics = {
                'encoder_loss': encoder_loss, 'temperature': self.t
            }
            return {'loss': encoder_loss, 'log': metrics}
        else:
            metrics = {
                'view_maker_loss': view_maker_loss,
            }
            return {'loss': view_maker_loss, 'log': metrics}

    def validation_step(self, batch, batch_idx):
        '''Compute validation losses and kNN hit counts for one batch.'''
        emb_dict = self.forward(batch, train=False)
        if 'img_embs' in emb_dict:
            img_embs = emb_dict['img_embs']
        else:
            _, img, _, _, _ = batch
            img_embs = self.get_repr(img)  # Need encoding of image without augmentations (only normalization).
        labels = batch[-1]
        encoder_loss, view_maker_loss = self.get_losses_for_batch(emb_dict, train=False)

        num_correct, batch_size = self.get_nearest_neighbor_label(img_embs, labels)
        output = OrderedDict({
            'val_loss': encoder_loss + view_maker_loss,
            'val_encoder_loss': encoder_loss,
            'val_view_maker_loss': view_maker_loss,
            'val_num_correct': torch.tensor(num_correct, dtype=float, device=self.device),
            'val_num_total': torch.tensor(batch_size, dtype=float, device=self.device),
        })

        return output

    def validation_epoch_end(self, outputs):
        '''Aggregate per-batch validation outputs into epoch metrics.

        Every key is averaged on a best-effort basis: a key whose values
        cannot be stacked or averaged is skipped rather than aborting the
        epoch. `val_acc` is total correct over total seen.
        '''
        metrics = {}
        for key in outputs[0].keys():
            try:
                metrics[key] = torch.stack([elem[key] for elem in outputs]).mean()
            except (RuntimeError, TypeError):
                # torch.stack / mean reject non-tensors, mismatched shapes
                # and unsupported dtypes with these; anything else (e.g.
                # KeyboardInterrupt) must not be swallowed.
                continue

        num_correct = torch.stack([out['val_num_correct'] for out in outputs]).sum()
        num_total = torch.stack([out['val_num_total'] for out in outputs]).sum()
        val_acc = num_correct / float(num_total)
        metrics['val_acc'] = val_acc
        progress_bar = {'acc': val_acc}
        return {'val_loss': metrics['val_loss'],
                'log': metrics,
                'val_acc': val_acc,
                'progress_bar': progress_bar}

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx,
                       second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False):
        '''Step the optimizers, optionally freezing the viewmaker.

        Without `config.optim_params.viewmaker_freeze_epoch` this defers
        to the parent implementation. Otherwise the encoder optimizer
        (index 0) always steps, and the viewmaker optimizer only steps
        before the freeze epoch. The extra keyword arguments are accepted
        for signature compatibility but not forwarded.
        '''
        if not self.config.optim_params.viewmaker_freeze_epoch:
            super().optimizer_step(current_epoch, batch_nb, optimizer, optimizer_idx)
            return

        if optimizer_idx == 0:
            optimizer.step()
            optimizer.zero_grad()
        elif current_epoch < self.config.optim_params.viewmaker_freeze_epoch:
            # Optionally freeze the viewmaker at a certain pretraining epoch.
            optimizer.step()
            optimizer.zero_grad()

    def configure_optimizers(self):
        '''Return `[encoder_optim, view_optim]` and no schedulers.

        Raises:
            ValueError: if `config.optim_params.viewmaker_optim` is set to
                something other than 'adam' or 'sgd'.
        '''
        # Optimize temperature with encoder.
        # A plain number is a fixed temperature; anything else is handed
        # to the optimizer alongside the encoder weights.
        if isinstance(self.t, (int, float)):
            encoder_params = self.model.parameters()
        else:
            encoder_params = list(self.model.parameters()) + [self.t]

        encoder_optim = torch.optim.SGD(
            encoder_params,
            lr=self.config.optim_params.learning_rate,
            momentum=self.config.optim_params.momentum,
            weight_decay=self.config.optim_params.weight_decay,
        )
        view_optim_name = self.config.optim_params.viewmaker_optim
        view_parameters = self.viewmaker.parameters()
        if view_optim_name == 'adam':
            view_optim = torch.optim.Adam(
                view_parameters, lr=self.config.optim_params.viewmaker_learning_rate or 0.001)
        elif not view_optim_name or view_optim_name == 'sgd':
            view_optim = torch.optim.SGD(
                view_parameters,
                lr=self.config.optim_params.viewmaker_learning_rate or self.config.optim_params.learning_rate,
                momentum=self.config.optim_params.momentum,
                weight_decay=self.config.optim_params.weight_decay,
            )
        else:
            raise ValueError(f'Optimizer {view_optim_name} not implemented')

        return [encoder_optim, view_optim], []

    def train_dataloader(self):
        '''Dataloader over the training set.'''
        return create_dataloader(self.train_dataset, self.config, self.batch_size)

    def val_dataloader(self):
        '''Dataloader over the validation set, unshuffled, keeping the last batch.'''
        return create_dataloader(self.val_dataset, self.config, self.batch_size,
                                 shuffle=False, drop_last=False)
Пример #7
0
class PretrainExpertSystem(PretrainViewMakerSystem):
    '''Pytorch Lightning System for self-supervised pretraining
    with expert image views as described in Instance Discrimination
    or SimCLR.
    '''

    def __init__(self, config):
        '''Build image datasets, encoder and memory bank from config.

        The parent's own `__init__` is skipped on purpose:
        `super(PretrainViewMakerSystem, self)` resumes the MRO lookup
        after `PretrainViewMakerSystem`, so no viewmaker is created here.
        '''
        super(PretrainViewMakerSystem, self).__init__()
        self.config = config
        self.batch_size = config.optim_params.batch_size
        self.loss_name = self.config.loss_params.name
        self.t = self.config.loss_params.t

        # A missing config entry comes back as an empty DotMap; treat that
        # as "use all default augmentations".
        augmentations = self.config.data_params.default_augmentations
        if augmentations == DotMap():
            augmentations = 'all'

        train_set, val_set = datasets.get_image_datasets(
            config.data_params.dataset,
            default_augmentations=augmentations,
        )
        self.train_dataset = train_set
        self.val_dataset = val_set
        self.train_ordered_labels = np.array(self.train_dataset.dataset.targets)

        self.model = self.create_encoder()
        self.memory_bank = MemoryBank(
            len(self.train_dataset),
            self.config.model_params.out_dim,
        )

    def forward(self, img):
        '''Embed `img` with the encoder.'''
        embedding = self.model(img)
        return embedding

    def _embed_batch(self, batch):
        '''Unpack a 5-tuple batch and embed the views the objective needs.

        Returns `(emb_dict, labels)`; `emb_dict` always carries 'indices'
        and 'labels', plus 'img_embs_1' (and 'img_embs_2' for simclr).
        '''
        indices, img, img2, neg_img, labels, = batch
        emb_dict = {}
        if self.loss_name == 'nce':
            emb_dict['img_embs_1'] = self.forward(img)
        elif 'simclr' in self.loss_name:
            emb_dict['img_embs_1'] = self.forward(img)
            emb_dict['img_embs_2'] = self.forward(img2)

        emb_dict['indices'] = indices
        emb_dict['labels'] = labels
        return emb_dict, labels

    def get_losses_for_batch(self, emb_dict, train=True):
        '''Compute the objective's loss; refresh the memory bank if training.

        Raises:
            ValueError: simclr objective without 'img_embs_2'.
            Exception: unsupported objective name.
        '''
        objective = self.loss_name

        if objective == 'nce':
            loss_fn = NoiseConstrastiveEstimation(
                emb_dict['indices'],
                emb_dict['img_embs_1'],
                self.memory_bank,
                k=self.config.loss_params.k,
                t=self.t,
                m=self.config.loss_params.m,
            )
            loss = loss_fn.get_loss()
        elif objective == 'simclr':
            if 'img_embs_2' not in emb_dict:
                raise ValueError(f'img_embs_2 is required for SimCLR loss')
            loss_fn = SimCLRObjective(
                emb_dict['img_embs_1'], emb_dict['img_embs_2'], t=self.t)
            loss = loss_fn.get_loss()
        else:
            raise Exception(f'Objective {self.loss_name} is not supported.')

        if not train:
            return loss

        with torch.no_grad():
            if objective == 'nce':
                refreshed = loss_fn.updated_new_data_memory()
                self.memory_bank.update(emb_dict['indices'], refreshed)
            elif 'simclr' in objective:
                # Store the mean of the two normalized view embeddings.
                first = utils.l2_normalize(emb_dict['img_embs_1'], dim=1)
                second = utils.l2_normalize(emb_dict['img_embs_2'], dim=1)
                averaged = (first + second) / 2.
                self.memory_bank.update(emb_dict['indices'], averaged)
            else:
                raise Exception(f'Objective {self.loss_name} is not supported.')

        return loss

    def configure_optimizers(self):
        '''Return one optimizer over the encoder and no schedulers.

        AdamW with its default hyperparameters when
        `config.optim_params.adam` is set, otherwise SGD configured from
        `config.optim_params`.
        '''
        encoder_params = self.model.parameters()
        optim_params = self.config.optim_params

        if optim_params.adam:
            return [torch.optim.AdamW(encoder_params)], []

        sgd = torch.optim.SGD(
            encoder_params,
            lr=optim_params.learning_rate,
            momentum=optim_params.momentum,
            weight_decay=optim_params.weight_decay,
        )
        return [sgd], []

    def training_step(self, batch, batch_idx):
        '''Embed the batch; the loss is computed in training_step_end.'''
        emb_dict, _ = self._embed_batch(batch)
        return emb_dict

    def training_step_end(self, emb_dict):
        '''Compute the training loss and log it with the temperature.'''
        loss = self.get_losses_for_batch(emb_dict, train=True)
        return {
            'loss': loss,
            'log': {'loss': loss, 'temperature': self.t},
        }

    def validation_step(self, batch, batch_idx):
        '''Compute validation loss and kNN hit counts for one batch.'''
        emb_dict, labels = self._embed_batch(batch)
        img_embs = emb_dict['img_embs_1']

        loss = self.get_losses_for_batch(emb_dict, train=False)

        correct, total = self.get_nearest_neighbor_label(img_embs, labels)
        return OrderedDict([
            ('val_loss', loss),
            ('val_num_correct',
             torch.tensor(correct, dtype=float, device=self.device)),
            ('val_num_total',
             torch.tensor(total, dtype=float, device=self.device)),
        ])
Пример #8
0
class PretrainExpertInstDiscSystem(pl.LightningModule):
    '''Pretraining with Instance Discrimination

    NOTE: only the SimCLR model was used for PAMAP2 in the paper.

    Two memory banks are kept, both indexed by training-example index:
    `memory_bank` holds embeddings and `memory_bank_labels` holds the
    label of each example, so validation can run a 1-nearest-neighbour
    lookup.
    '''
    def __init__(self, config):
        '''Build datasets, encoder and both memory banks from config.'''
        super().__init__()
        self.config = config
        self.batch_size = config.optim_params.batch_size
        self.train_dataset, self.val_dataset = self.create_datasets()
        self.model = self.create_encoder()
        self.memory_bank = MemoryBank(len(self.train_dataset),
                                      self.config.model_params.out_dim)
        # One integer label per training example.
        self.memory_bank_labels = MemoryBank(len(self.train_dataset),
                                             1,
                                             dtype=int)

    def create_datasets(self):
        '''Return `(train_dataset, val_dataset)` PAMAP2 instances.

        With `config.quick` set, the validation dataset is reused as the
        training dataset instead of building a separate one.
        '''
        print('Initializing validation dataset.')
        # We use a larger default value of 50k examples for validation to reduce variance.
        val_dataset = PAMAP2(
            mode='val',
            examples_per_epoch=self.config.data_params.val_examples_per_epoch
            or 50000)

        if not self.config.quick:
            print('Initializing train dataset.')
            train_dataset = PAMAP2(mode='train',
                                   examples_per_epoch=self.config.data_params.
                                   train_examples_per_epoch or 10000)
            if not self.config.data_params.train_examples_per_epoch:
                print(
                    'WARNING: self.config.data_params.train_examples_per_epoch not specified. Using default value of 10k'
                )
        else:
            train_dataset = val_dataset
        return train_dataset, val_dataset

    def create_encoder(self):
        '''Create the encoder, optionally with an MLP projection head.

        The projection head wraps the encoder's final layer, which is
        looked up as `linear` first and `fc` second, since torchvision
        ResNets name that layer `fc`.
        '''
        if self.config.model_params.resnet_small:
            encoder_model = resnet_small.ResNet18(
                self.config.model_params.out_dim,
                num_channels=52,  # 52 feature spectrograms
                input_size=32,
            )
        else:
            resnet_class = getattr(
                torchvision.models,
                self.config.model_params.resnet_version,
            )
            encoder_model = resnet_class(
                pretrained=False,
                num_classes=self.config.model_params.out_dim,
            )
            # NOTE(review): this stem takes 1 input channel while the
            # small ResNet above is built for 52 - confirm which one the
            # PAMAP2 inputs actually need.
            encoder_model.conv1 = nn.Conv2d(1,
                                            64,
                                            kernel_size=7,
                                            stride=2,
                                            padding=3,
                                            bias=False)

        if self.config.model_params.projection_head:
            # Prefer `linear` (the name this code always used) and fall
            # back to `fc` so the torchvision branch does not raise
            # AttributeError.
            head_attr = 'linear' if hasattr(encoder_model, 'linear') else 'fc'
            head = getattr(encoder_model, head_attr)
            mlp_dim = head.weight.size(1)
            setattr(encoder_model, head_attr, nn.Sequential(
                nn.Linear(mlp_dim, mlp_dim),
                nn.ReLU(),
                head,
            ))
        return encoder_model

    def configure_optimizers(self):
        '''Return a single SGD optimizer over the encoder, no schedulers.'''
        optim = torch.optim.SGD(
            self.model.parameters(),
            lr=self.config.optim_params.learning_rate,
            momentum=self.config.optim_params.momentum,
            weight_decay=self.config.optim_params.weight_decay)
        return [optim], []

    def forward(self, inputs):
        '''Run `inputs` through the encoder.'''
        return self.model(inputs)

    def get_losses_for_batch(self, batch):
        '''Compute the NCE loss for a batch and refresh both memory banks.

        `batch` is unpacked as `(indices, inputs1, inputs2, labels)`;
        only `inputs1` is embedded.
        '''
        indices, inputs1, inputs2, labels = batch
        outputs = self.forward(inputs1)
        loss_fn = NoiseConstrastiveEstimation(indices,
                                              outputs,
                                              self.memory_bank,
                                              k=self.config.loss_params.k,
                                              t=self.config.loss_params.t,
                                              m=self.config.loss_params.m)
        loss = loss_fn.get_loss()

        with torch.no_grad():
            new_data_memory = loss_fn.updated_new_data_memory()
            self.memory_bank.update(indices, new_data_memory)
            # Store the labels next to the embeddings. Without this the
            # label bank is never written and the nearest-neighbour
            # lookup in validation only sees its initial contents.
            # Shaped (batch, 1) to match the bank's single column.
            self.memory_bank_labels.update(indices,
                                           labels.view(-1, 1).long())

        return loss

    def training_step(self, batch, batch_idx):
        '''Return the loss for one batch, also logged under 'loss'.'''
        loss = self.get_losses_for_batch(batch)
        metrics = {'loss': loss}
        return {'loss': loss, 'log': metrics}

    def get_nearest_neighbor_label(self, embs, labels):
        """
        NOTE: ONLY TO BE USED FOR VALIDATION.

        For each example in validation, find the nearest example in the
        training dataset using the memory bank. Assume its label as
        the predicted label.

        Returns `(num_correct, batch_size)`.
        """
        all_dps = self.memory_bank.get_all_dot_products(embs)
        _, neighbor_idxs = torch.topk(all_dps, k=1, sorted=False, dim=1)
        neighbor_idxs = neighbor_idxs.squeeze(1)

        # Labels come from the label bank, indexed like the embeddings.
        neighbor_labels = self.memory_bank_labels.at_idxs(
            neighbor_idxs).squeeze(-1)
        num_correct = torch.sum(neighbor_labels.cpu() == labels.cpu()).item()

        return num_correct, embs.size(0)

    def validation_step(self, batch, batch_idx):
        '''Return correct/total counts for one validation batch.'''
        _, inputs1, inputs2, labels = batch
        outputs = self.model(inputs1)
        num_correct, batch_size = self.get_nearest_neighbor_label(
            outputs, labels)
        num_correct = torch.tensor(num_correct,
                                   dtype=float,
                                   device=self.device)
        batch_size = torch.tensor(batch_size, dtype=float, device=self.device)
        return OrderedDict({
            'val_num_correct': num_correct,
            'val_num_total': batch_size
        })

    def validation_epoch_end(self, outputs):
        '''Aggregate per-batch outputs into epoch metrics and `val_acc`.'''
        metrics = {}
        for key in outputs[0].keys():
            metrics[key] = torch.stack([elem[key] for elem in outputs]).mean()
        num_correct = torch.stack([out['val_num_correct']
                                   for out in outputs]).sum()
        num_total = torch.stack([out['val_num_total']
                                 for out in outputs]).sum()
        val_acc = num_correct / float(num_total)
        metrics['val_acc'] = val_acc
        progress_bar = {'acc': val_acc}
        return {
            'log': metrics,
            'val_acc': val_acc,
            'progress_bar': progress_bar
        }

    def train_dataloader(self):
        '''Dataloader over the training set.'''
        return create_dataloader(self.train_dataset, self.config,
                                 self.batch_size)

    def val_dataloader(self):
        '''Dataloader over the validation set, without shuffling.'''
        return create_dataloader(self.val_dataset,
                                 self.config,
                                 self.batch_size,
                                 shuffle=False)