def init_encoder(self):
        """Create the encoder selected by ``self.hparams.encoder``.

        When the name is ``'amdim_encoder'`` an ``AMDIMEncoder`` is built from
        the hyperparameters and ``init_weights()`` is called on it. Any other
        name is forwarded to ``torchvision_ssl_encoder`` with
        ``return_all_feature_maps=True``.

        Returns:
            The constructed encoder object.
        """
        hp = self.hparams

        # Two all-zero square images; handed to AMDIMEncoder as its first
        # positional argument. Built up front regardless of the chosen encoder.
        dummy_batch = torch.zeros(
            (2, hp.image_channels, hp.image_height, hp.image_height))
        encoder_name = hp.encoder

        # Anything that is not the AMDIM encoder is delegated by name.
        if encoder_name != 'amdim_encoder':
            return torchvision_ssl_encoder(encoder_name, return_all_feature_maps=True)

        amdim_kwargs = dict(
            num_channels=hp.image_channels,
            encoder_feature_dim=hp.encoder_feature_dim,
            embedding_fx_dim=hp.embedding_fx_dim,
            conv_block_depth=hp.conv_block_depth,
            encoder_size=hp.image_height,
            use_bn=hp.use_bn,
        )
        amdim = AMDIMEncoder(dummy_batch, **amdim_kwargs)
        amdim.init_weights()
        return amdim
예제 #2
0
    def __init__(self, hparams):
        """Set up the encoder, the NCE loss and the dataset-split slots.

        Args:
            hparams: configuration object. This constructor reads
                ``image_height``, ``ndf``, ``n_rkhs``, ``n_depth``,
                ``use_bn`` and ``tclip`` from it and stores the object on
                ``self.hparams``.
        """
        super().__init__()
        self.hparams = hparams

        # AMDIMEncoder receives a sample batch at construction time: two
        # all-zero, 3-channel, square images of the configured height.
        side = hparams.image_height
        blank_images = torch.zeros((2, 3, side, side))

        encoder_kwargs = dict(
            num_channels=3,
            ndf=hparams.ndf,
            n_rkhs=hparams.n_rkhs,
            n_depth=hparams.n_depth,
            encoder_size=side,
            use_bn=hparams.use_bn,
        )
        self.encoder = AMDIMEncoder(blank_images, **encoder_kwargs)
        self.encoder.init_weights()

        # the loss has learnable parameters
        self.nce_loss = AMDIMLossNCE(tclip=self.hparams.tclip)

        # Dataset splits start out unset.
        self.tng_split = self.val_split = None
예제 #3
0
class AMDIM(pl.LightningModule):
    """Lightning module that trains an ``AMDIMEncoder`` with ``AMDIMLossNCE``.

    Each batch is expected to unpack as ``([img_1, img_2], labels)``; both
    images are encoded and the resulting feature maps are fed to the NCE loss.

    Hyperparameters read from ``hparams``: ``image_height``, ``ndf``,
    ``n_rkhs``, ``n_depth``, ``use_bn``, ``tclip``, ``learning_rate``,
    ``dataset_name``, ``data_dir``, ``nb_classes`` and ``batch_size``.
    """

    # Keys used to ship the six feature maps from training_step to training_end.
    _FEATURE_KEYS = ('r1_x1', 'r5_x1', 'r7_x1', 'r1_x2', 'r5_x2', 'r7_x2')

    # Dataset names the dataloader hooks know how to build.
    _SUPPORTED_DATASETS = ('CIFAR10', 'stl_10', 'imagenet_128')

    def __init__(self, hparams):
        """Build the encoder and the NCE loss from ``hparams``.

        Args:
            hparams: configuration object (see the class docstring for the
                attributes that are read).
        """
        super().__init__()

        self.hparams = hparams

        # AMDIMEncoder takes a sample batch at construction time: two
        # all-zero, 3-channel, square images of the configured height.
        dummy_batch = torch.zeros(
            (2, 3, hparams.image_height, hparams.image_height))

        self.encoder = AMDIMEncoder(dummy_batch,
                                    num_channels=3,
                                    ndf=hparams.ndf,
                                    n_rkhs=hparams.n_rkhs,
                                    n_depth=hparams.n_depth,
                                    encoder_size=hparams.image_height,
                                    use_bn=hparams.use_bn)
        self.encoder.init_weights()

        # the loss has learnable parameters
        self.nce_loss = AMDIMLossNCE(tclip=self.hparams.tclip)

        # STL-10 train/val splits; filled lazily by _ensure_stl_splits().
        self.tng_split = None
        self.val_split = None

    def forward(self, img_1, img_2):
        """Encode both images and return their feature maps.

        Returns:
            A 6-tuple ``(r1_x1, r5_x1, r7_x1, r1_x2, r5_x2, r7_x2)``. The
            ``_x1`` entries come from ``img_1`` and the ``_x2`` entries from
            ``img_2``.
        """
        # feats for img 1
        # r1 = last layer out
        # r5 = last layer with (b, c, 5, 5) size
        # r7 = last layer with (b, c, 7, 7) size
        r1_x1, r5_x1, r7_x1 = self.encoder(img_1)

        # feats for img 2
        r1_x2, r5_x2, r7_x2 = self.encoder(img_2)

        # first number = resnet block. second = image 1 or 2
        return r1_x1, r5_x1, r7_x1, r1_x2, r5_x2, r7_x2

    def _unsupervised_loss(self, r1_x1, r5_x1, r7_x1, r1_x2, r5_x2, r7_x2):
        """Return the sum of the four terms produced by ``self.nce_loss``."""
        loss_1t5, loss_1t7, loss_5t5, lgt_reg = self.nce_loss(
            r1_x1, r5_x1, r7_x1, r1_x2, r5_x2, r7_x2)
        return loss_1t5 + loss_1t7 + loss_5t5 + lgt_reg

    def training_step(self, batch, batch_nb):
        """Extract features for one batch; the loss is computed in ``training_end``.

        Returns:
            Dict mapping each name in ``_FEATURE_KEYS`` to its feature map.
        """
        [img_1, img_2], _ = batch

        # _x1 features are from image 1, _x2 features from image 2
        feats = self.forward(img_1, img_2)
        return dict(zip(self._FEATURE_KEYS, feats))

    def training_end(self, outputs):
        """Compute the NCE loss from the features returned by ``training_step``.

        Returns:
            Dict with the total loss under ``'loss'`` and a ``'log'`` dict
            holding the same value as ``'train_nce_loss'``.
        """
        feats = [outputs[key] for key in self._FEATURE_KEYS]

        # The full loss currently consists of the unsupervised NCE loss only.
        total_loss = self._unsupervised_loss(*feats)

        tensorboard_logs = {'train_nce_loss': total_loss}
        return {'loss': total_loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):
        """Compute the NCE loss for one validation batch.

        Returns:
            Dict with the batch loss under ``'val_nce'``.
        """
        [img_1, img_2], _ = batch

        feats = self.forward(img_1, img_2)
        return {'val_nce': self._unsupervised_loss(*feats)}

    def validation_epoch_end(self, outputs):
        """Average ``'val_nce'`` over all validation step outputs.

        Returns:
            Dict with the mean under ``'val_loss'`` and, inside ``'log'``,
            under ``'val_nce'``.
        """
        val_nce = sum(output['val_nce'] for output in outputs) / len(outputs)
        tensorboard_logs = {'val_nce': val_nce}
        return {'val_loss': val_nce, 'log': tensorboard_logs}

    def configure_optimizers(self):
        """Create the Adam optimizer used for training.

        Returns:
            The optimizer only (no scheduler is returned).
        """
        opt = optim.Adam(params=self.parameters(),
                         lr=self.hparams.learning_rate,
                         betas=(0.8, 0.999),
                         weight_decay=1e-5,
                         eps=1e-7)

        # NOTE(review): the scheduler below is constructed but never returned,
        # so it is not handed back to the caller. The scheduler-returning form
        # was left commented out on the return line; confirm whether it should
        # be re-enabled before changing this.
        if self.hparams.dataset_name in ['CIFAR10', 'stl_10', 'CIFAR100']:
            lr_scheduler = MultiStepLR(opt, milestones=[250, 280], gamma=0.2)
        else:
            lr_scheduler = MultiStepLR(opt, milestones=[30, 45], gamma=0.2)

        return opt  # [opt], [lr_scheduler]

    def _unsupported_dataset(self):
        """Build the error raised for a ``dataset_name`` with no loader."""
        return ValueError(
            'Unsupported dataset_name {!r}; expected one of: {}'.format(
                self.hparams.dataset_name,
                ', '.join(self._SUPPORTED_DATASETS)))

    def _ensure_stl_splits(self):
        """Compute the STL-10 train/val splits once and cache them on ``self``.

        Both loader hooks go through here, so the validation split exists no
        matter which hook runs first, and both loaders use the result of the
        same ``stl_train`` call.
        """
        if self.tng_split is None or self.val_split is None:
            self.tng_split, self.val_split = AMDIMPretraining.stl_train(
                self.hparams.data_dir)

    def _build_loader(self, dataset):
        """Wrap ``dataset`` in a ``DataLoader`` with the shared loader settings."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            pin_memory=True,
            drop_last=True,
            num_workers=16,
        )

    def train_dataloader(self):
        """Return the training ``DataLoader`` for ``hparams.dataset_name``.

        Raises:
            ValueError: if ``dataset_name`` is not one of
                ``_SUPPORTED_DATASETS``.
        """
        name = self.hparams.dataset_name

        if name == 'CIFAR10':
            dataset = AMDIMPretraining.cifar10_train(self.hparams.data_dir)
        elif name == 'stl_10':
            self._ensure_stl_splits()
            dataset = self.tng_split
        elif name == 'imagenet_128':
            dataset = AMDIMPretraining.imagenet_train(self.hparams.data_dir,
                                                      self.hparams.nb_classes)
        else:
            raise self._unsupported_dataset()

        return self._build_loader(dataset)

    def val_dataloader(self):
        """Return the validation ``DataLoader`` for ``hparams.dataset_name``.

        Raises:
            ValueError: if ``dataset_name`` is not one of
                ``_SUPPORTED_DATASETS``.
        """
        name = self.hparams.dataset_name

        if name == 'CIFAR10':
            dataset = AMDIMPretraining.cifar10_val(self.hparams.data_dir)
        elif name == 'stl_10':
            # The split may not exist yet if this hook runs before
            # train_dataloader().
            self._ensure_stl_splits()
            dataset = self.val_split
        elif name == 'imagenet_128':
            dataset = AMDIMPretraining.imagenet_val(self.hparams.data_dir,
                                                    self.hparams.nb_classes)
        else:
            raise self._unsupported_dataset()

        return self._build_loader(dataset)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Extend ``parent_parser`` with the model's command-line options.

        Defaults are taken from the CIFAR-10 preset defined below.

        Args:
            parent_parser: an ``ArgumentParser`` whose arguments are inherited.

        Returns:
            A new ``ArgumentParser`` (created with ``add_help=False``).
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        def lr_ladder(root_lr):
            # root_lr scaled from x32 down to x1/32 in powers of two.
            scaled_up = [root_lr * factor for factor in (32, 16, 8, 4, 2)]
            scaled_down = [root_lr * 1 / divisor
                           for divisor in (2, 4, 8, 16, 32)]
            return scaled_up + [root_lr] + scaled_down

        cf_root_lr = 2e-4
        stl_root_lr = 2e-4
        imagenet_root_lr = 2e-4

        # Per-dataset presets. 'lr_options' is kept for reference only; no
        # parser argument below reads it.
        presets = {
            'cifar_10': {
                'dataset_name': 'CIFAR10',
                'ndf': 320,
                'n_rkhs': 1280,
                'depth': 10,
                'image_height': 32,
                'batch_size': 200,
                'nb_classes': 10,
                'lr_options': lr_ladder(cf_root_lr),
            },
            'stl_10': {
                'dataset_name': 'stl_10',
                'ndf': 192,
                'n_rkhs': 1536,
                'depth': 8,
                'image_height': 64,
                'batch_size': 200,
                'nb_classes': 10,
                'lr_options': lr_ladder(stl_root_lr),
            },
            'imagenet_128': {
                'dataset_name': 'imagenet_128',
                'ndf': 320,
                'n_rkhs': 2560,
                'depth': 10,
                'image_height': 128,
                'batch_size': 200,
                'nb_classes': 1000,
                'lr_options': lr_ladder(imagenet_root_lr),
            },
            'imagenet_128_large': {
                'dataset_name': 'imagenet_128',
                'ndf': 320,
                'n_rkhs': 2560,
                'depth': 10,
                'image_height': 128,
                'batch_size': 200,
                'nb_classes': 1000,
                'lr_options': lr_ladder(imagenet_root_lr),
            },
        }

        # Swap the key here to change which preset supplies the defaults.
        dataset = presets['cifar_10']

        # dataset options
        parser.add_argument('--nb_classes',
                            default=dataset['nb_classes'],
                            type=int)

        # network params
        parser.add_argument('--tclip',
                            type=float,
                            default=20.0,
                            help='soft clipping range for NCE scores')
        parser.add_argument('--use_bn', type=int, default=0)
        parser.add_argument('--ndf',
                            type=int,
                            default=dataset['ndf'],
                            help='feature width for encoder')
        parser.add_argument(
            '--n_rkhs',
            type=int,
            default=dataset['n_rkhs'],
            help='number of dimensions in fake RKHS embeddings')
        parser.add_argument('--n_depth', type=int, default=dataset['depth'])
        parser.add_argument('--image_height',
                            type=int,
                            default=dataset['image_height'])

        # training params
        parser.add_argument('--dataset_name',
                            type=str,
                            default=dataset['dataset_name'])
        parser.add_argument('--batch_size',
                            type=int,
                            default=dataset['batch_size'],
                            help='input batch size (default: 200)')
        parser.add_argument('--learning_rate', type=float, default=0.0002)

        # data
        parser.add_argument('--data_dir', default=os.getcwd(), type=str)
        return parser