class finetuneSIMCLR(pl.LightningModule):
    def __init__(self, encoder, DATA_PATH, withhold, batch_size, val_split,
                 hidden_dims, train_transform, val_transform, num_workers,
                 **kwargs):
        super().__init__()

        self.DATA_PATH = DATA_PATH
        self.val_split = val_split
        self.batch_size = batch_size
        self.hidden_dims = hidden_dims
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.num_workers = num_workers
        self.withhold = withhold

        # data: split the dataset into train/val(/withheld) folders if it is not already split
        shutil.rmtree('split_data', ignore_errors=True)
        if not (path.isdir(f"{self.DATA_PATH}/train")
                and path.isdir(f"{self.DATA_PATH}/val")):
            splitfolders.ratio(self.DATA_PATH,
                               output="split_data",
                               ratio=(1 - self.val_split - self.withhold,
                                      self.val_split, self.withhold),
                               seed=10)
            self.DATA_PATH = 'split_data'
            print(
                f'automatically splitting data into train/validation (val fraction {self.val_split}) '
                f'with withheld fraction {self.withhold}'
            )

        self.num_classes = len(os.listdir(f'{self.DATA_PATH}/train'))

        # model: metrics, pretrained encoder, and linear evaluation head
        self.train_acc = Accuracy()
        self.val_acc = Accuracy(compute_on_step=False)
        print('KWARGS:', kwargs)
        self.encoder, self.embedding_size = load_encoder(encoder, kwargs)

        self.linear_layer = SSLEvaluator(n_input=self.embedding_size,
                                         n_classes=self.num_classes,
                                         p=0.1,
                                         n_hidden=self.hidden_dims)


#   def forward(self, x):
#       x = self.encoder(x)[0]
#       x = F.log_softmax(self.fc1(x), dim = 1)
#       return x

    def shared_step(self, batch):
        x, y = batch
        feats = self.encoder(x)[-1]
        feats = feats.view(feats.size(0), -1)
        logits = self.linear_layer(feats)
        loss = self.loss_fn(logits, y)
        return loss, logits, y

    def training_step(self, batch, batch_idx):
        loss, logits, y = self.shared_step(batch)
        acc = self.train_acc(logits, y)
        self.log('tloss', loss, prog_bar=True)
        self.log('tastep', acc, prog_bar=True)
        self.log('ta_epoch', self.train_acc)

        return loss

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            loss, logits, y = self.shared_step(batch)
            self.val_acc(logits, y)  # update the epoch-level accuracy metric once per batch

        self.log('vloss', loss, prog_bar=True, sync_dist=True)
        self.log('val_acc_epoch', self.val_acc, prog_bar=True)

        return loss

    def loss_fn(self, logits, labels):
        return F.cross_entropy(logits, labels)

    def configure_optimizers(self):
        opt = SGD([{
            'params': self.encoder.parameters()
        }, {
            'params': self.linear_layer.parameters(),
            'lr': 0.1
        }],
                  lr=1e-4,
                  momentum=0.9)

        return [opt]

    def prepare_data(self):

        train_pipeline = self.train_transform(
            DATA_PATH=f"{self.DATA_PATH}/train",
            input_height=256,
            batch_size=self.batch_size,
            num_threads=self.num_workers,
            device_id=0)
        print(f"{self.DATA_PATH}/train")
        val_pipeline = self.val_transform(DATA_PATH=f"{self.DATA_PATH}/val",
                                          input_height=256,
                                          batch_size=self.batch_size,
                                          num_threads=self.num_workers,
                                          device_id=0)

        class LightningWrapper(DALIClassificationIterator):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

            def __next__(self):
                # DALI yields a list with one output dict per pipeline; take the
                # first dict and squeeze the label tensor so the batch becomes
                # [images, labels] as Lightning expects
                out = super().__next__()
                out = out[0]
                return [
                    out[k] if k != "label" else torch.squeeze(out[k])
                    for k in self.output_map
                ]

        self.train_loader = LightningWrapper(train_pipeline,
                                             fill_last_batch=False,
                                             auto_reset=True,
                                             reader_name="Reader")
        self.val_loader = LightningWrapper(val_pipeline,
                                           fill_last_batch=False,
                                           auto_reset=True,
                                           reader_name="Reader")

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader
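
# A minimal usage sketch (not from the original source): the encoder key and the
# DALI pipeline factories below are hypothetical placeholders for whatever
# `load_encoder` and the surrounding project actually provide; `pl` is the
# pytorch_lightning import used elsewhere in this file, and the `gpus` argument
# follows the older Trainer API this code targets.
def _finetune_simclr_example(data_path, train_pipeline_cls, val_pipeline_cls):
    model = finetuneSIMCLR(encoder='resnet18',           # placeholder key for load_encoder
                           DATA_PATH=data_path,          # folder of class subfolders
                           withhold=0.0,
                           batch_size=64,
                           val_split=0.2,
                           hidden_dims=128,
                           train_transform=train_pipeline_cls,  # DALI pipeline factory
                           val_transform=val_pipeline_cls,      # DALI pipeline factory
                           num_workers=4)
    trainer = pl.Trainer(gpus=1, max_epochs=10)
    trainer.fit(model)
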
class CLASSIFIER(pl.LightningModule):  #SSLFineTuner
    def __init__(self,
                 encoder,
                 DATA_PATH,
                 VAL_PATH,
                 hidden_dim,
                 image_size,
                 seed,
                 cpus,
                 transform=SimCLRTransform,
                 **classifier_hparams):
        super().__init__()

        self.DATA_PATH = DATA_PATH
        self.VAL_PATH = VAL_PATH
        self.transform = transform
        self.image_size = image_size
        self.cpus = cpus
        self.seed = seed

        self.batch_size = classifier_hparams['batch_size']
        self.classifier_hparams = classifier_hparams

        self.linear_layer = SSLEvaluator(
            n_input=encoder.embedding_size,
            n_classes=self.classifier_hparams['num_classes'],
            p=self.classifier_hparams['dropout'],
            n_hidden=hidden_dim)

        self.train_acc = Accuracy()
        self.val_acc = Accuracy(compute_on_step=False)
        self.encoder = encoder

        self.weights = None

        print(classifier_hparams)
        if classifier_hparams['weights'] is not None:
            self.weights = torch.tensor([
                float(item)
                for item in classifier_hparams['weights'].split(',')
            ])
            self.weights = self.weights.cuda()

        self.save_hyperparameters()

    #override configure_optimizers so the encoder and the classification head can use different learning rates
    def configure_optimizers(self):
        optimizer = SGD([{
            'params': self.encoder.parameters(),
            'lr': 0
        }, {
            'params': self.linear_layer.parameters(),
            'lr': self.classifier_hparams['linear_lr']
        }],
                        lr=self.classifier_hparams['learning_rate'],
                        momentum=self.classifier_hparams['momentum'])

        if self.classifier_hparams['scheduler_type'] == "step":
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                self.classifier_hparams['decay_epochs'],
                gamma=self.classifier_hparams['gamma'])
        elif self.classifier_hparams['scheduler_type'] == "cosine":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                self.classifier_hparams['epochs'],  # total epochs to run
                eta_min=self.classifier_hparams['final_lr'])
        else:
            raise ValueError(
                f"unknown scheduler_type: {self.classifier_hparams['scheduler_type']}")

        return [optimizer], [scheduler]

    def forward(self, x):
        feats = self.encoder(x)[-1]
        feats = feats.view(feats.size(0), -1)
        logits = self.linear_layer(feats)
        return logits

    def shared_step(self, batch):
        x, y = batch
        logits = self.forward(x)
        loss = self.loss_fn(logits, y)
        return loss, logits, y

    def training_step(self, batch, batch_idx):

        loss, logits, y = self.shared_step(batch)
        acc = self.train_acc(logits, y)
        self.log('tloss', loss, prog_bar=True)
        self.log('tastep', acc, prog_bar=True)
        self.log('ta_epoch', self.train_acc)

        return loss

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            loss, logits, y = self.shared_step(batch)
            self.val_acc(logits, y)  # update the epoch-level accuracy metric once per batch

        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        self.log('val_acc_epoch', self.val_acc, prog_bar=True)
        return loss

    def loss_fn(self, logits, labels):
        return F.cross_entropy(logits, labels, weight=self.weights)

    def setup(self, stage='inference'):
        Options = Enum('Loader', 'fit test inference')
        if stage == Options.fit.name:
            train = self.transform(self.DATA_PATH,
                                   batch_size=self.batch_size,
                                   input_height=self.image_size,
                                   copies=1,
                                   stage='train',
                                   num_threads=self.cpus,
                                   device_id=self.local_rank,
                                   seed=self.seed)
            val = self.transform(self.VAL_PATH,
                                 batch_size=self.batch_size,
                                 input_height=self.image_size,
                                 copies=1,
                                 stage='validation',
                                 num_threads=self.cpus,
                                 device_id=self.local_rank,
                                 seed=self.seed)
            self.train_loader = ClassifierWrapper(transform=train)
            self.val_loader = ClassifierWrapper(transform=val)

        elif stage == Options.inference.name:
            self.test_dataloader = ClassifierWrapper(
                transform=self.transform(self.DATA_PATH,
                                         batch_size=self.batch_size,
                                         input_height=self.image_size,
                                         copies=1,
                                         stage='inference',
                                         num_threads=2 * self.cpus,
                                         device_id=self.local_rank,
                                         seed=self.seed))
            self.inference_dataloader = self.test_dataloader

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    #allow the user to add extra model-specific arguments (originally for the SIMSIAM model); these cannot share a name with any parameter defined in train.py
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        # training params
        parser.add_argument("--linear_lr",
                            default=1e-1,
                            type=float,
                            help="learning rate for classification head.")
        parser.add_argument("--dropout",
                            default=0.1,
                            type=float,
                            help="dropout of neurons during training [0-1].")
        parser.add_argument("--nesterov",
                            default=False,
                            type=bool,
                            help="Use nesterov during training.")
        parser.add_argument(
            "--scheduler_type",
            default='cosine',
            type=str,
            help="learning rate scheduler: ['cosine' or 'step']")
        parser.add_argument("--gamma",
                            default=0.1,
                            type=float,
                            help="gamma param for learning rate.")
        parser.add_argument("--decay_epochs",
                            default=[60, 80],
                            type=list,
                            help="epochs to do optimizer decay")
        parser.add_argument("--weight_decay",
                            default=1e-6,
                            type=float,
                            help="weight decay")
        parser.add_argument("--final_lr",
                            type=float,
                            default=1e-6,
                            help="final learning rate")
        parser.add_argument("--momentum",
                            type=float,
                            default=0.9,
                            help="momentum for learning rate")
        parser.add_argument(
            '--weights',
            type=str,
            help='delimited list of weights for penalty during classification')
        return parser
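
# A minimal usage sketch (not from the original source): the hparam values and the
# `encoder` object are placeholders; the encoder is assumed to expose an
# `embedding_size` attribute and to return its features as the last element of
# its output, matching CLASSIFIER.forward above.
def _classifier_example(encoder, train_path, val_path):
    classifier_hparams = {
        'batch_size': 64,
        'num_classes': 10,          # placeholder class count
        'dropout': 0.1,
        'weights': None,
        'linear_lr': 1e-1,
        'learning_rate': 1e-3,
        'momentum': 0.9,
        'scheduler_type': 'cosine',
        'epochs': 30,
        'final_lr': 1e-6,
        'gamma': 0.1,
        'decay_epochs': [60, 80],
    }
    model = CLASSIFIER(encoder=encoder,
                       DATA_PATH=train_path,
                       VAL_PATH=val_path,
                       hidden_dim=512,
                       image_size=256,
                       seed=42,
                       cpus=4,
                       **classifier_hparams)
    trainer = pl.Trainer(gpus=1, max_epochs=classifier_hparams['epochs'])
    trainer.fit(model)
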
class SSLOnlineEvaluator(Callback):

    def __init__(self, data_dir, z_dim, max_epochs=10, check_val_every_n_epoch=1,
                 batch_size=1024, num_workers=32):
        super().__init__()
        self.z_dim = z_dim
        self.max_epochs = max_epochs
        self.check_val_every_n_epoch = check_val_every_n_epoch

        self.datamodule = BigearthnetDataModule(
            data_dir=data_dir,
            train_frac=0.01,
            val_frac=0.01,
            lmdb=True,
            batch_size=batch_size,
            num_workers=num_workers
        )
        self.datamodule.setup()

        self.criterion = nn.MultiLabelSoftMarginLoss()
        self.metric = lambda output, target: average_precision_score(target, output, average='micro') * 100.0

    def on_pretrain_routine_start(self, trainer, pl_module):
        self.classifier = SSLEvaluator(
            n_input=self.z_dim,
            n_classes=self.datamodule.num_classes,
            n_hidden=None
        ).to(pl_module.device)

        self.optimizer = torch.optim.Adam(self.classifier.parameters(), lr=1e-3)

    def on_epoch_end(self, trainer, pl_module):
        if (trainer.current_epoch + 1) % self.check_val_every_n_epoch != 0:
            return

        encoder = pl_module.encoder_q

        self.classifier.train()
        for _ in range(self.max_epochs):
            for inputs, targets in self.datamodule.train_dataloader():
                inputs = inputs.to(pl_module.device)
                targets = targets.to(pl_module.device)

                with torch.no_grad():
                    representations = encoder(inputs)
                representations = representations.detach()

                logits = self.classifier(representations)
                loss = self.criterion(logits, targets)

                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

        self.classifier.eval()
        accuracies = []
        for inputs, targets in self.datamodule.val_dataloader():
            inputs = inputs.to(pl_module.device)

            with torch.no_grad():
                representations = encoder(inputs)
            representations = representations.detach()

            logits = self.classifier(representations)
            preds = torch.sigmoid(logits).detach().cpu()
            acc = self.metric(preds, targets)
            accuracies.append(acc)
        acc = torch.mean(torch.tensor(accuracies))

        metrics = {'online_val_acc': acc}
        trainer.logger_connector.log_metrics(metrics, {})
        trainer.logger_connector.add_progress_bar_metrics(metrics)
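
# A minimal usage sketch (not from the original source): `ssl_module` is a
# placeholder for the MoCo-style LightningModule being pre-trained; the callback
# requires it to expose an `encoder_q` attribute, and `data_dir` must point at a
# BigEarthNet directory readable by BigearthnetDataModule. `pl` is assumed to be
# the pytorch_lightning import used earlier in this listing.
def _online_evaluator_example(ssl_module, data_dir):
    online_eval = SSLOnlineEvaluator(data_dir=data_dir,
                                     z_dim=2048,  # placeholder encoder output dimension
                                     max_epochs=10,
                                     check_val_every_n_epoch=5,
                                     batch_size=256,
                                     num_workers=8)
    trainer = pl.Trainer(gpus=1, max_epochs=100, callbacks=[online_eval])
    trainer.fit(ssl_module)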