class finetuneSIMCLR(pl.LightningModule):
    def __init__(self, encoder, DATA_PATH, withhold, batch_size, val_split,
                 hidden_dims, train_transform, val_transform, num_workers,
                 **kwargs):
        super().__init__()
        self.DATA_PATH = DATA_PATH
        self.val_split = val_split
        self.batch_size = batch_size
        self.hidden_dims = hidden_dims
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.num_workers = num_workers
        self.withhold = withhold

        # data stuff: split the dataset into train/val/withhold folders if it
        # has not already been split
        shutil.rmtree('split_data', ignore_errors=True)
        if not (path.isdir(f"{self.DATA_PATH}/train")
                and path.isdir(f"{self.DATA_PATH}/val")):
            splitfolders.ratio(self.DATA_PATH,
                               output="split_data",
                               ratio=(1 - self.val_split - self.withhold,
                                      self.val_split, self.withhold),
                               seed=10)
            self.DATA_PATH = 'split_data'
            print(f'automatically splitting data into train and validation '
                  f'data {self.val_split} and withhold {self.withhold}')
        self.num_classes = len(os.listdir(f'{self.DATA_PATH}/train'))

        # model stuff: encoder plus a linear evaluation head
        self.train_acc = Accuracy()
        self.val_acc = Accuracy(compute_on_step=False)
        print('KWARGS:', kwargs)
        self.encoder, self.embedding_size = load_encoder(encoder, kwargs)
        self.linear_layer = SSLEvaluator(n_input=self.embedding_size,
                                         n_classes=self.num_classes,
                                         p=0.1,
                                         n_hidden=self.hidden_dims)

    # def forward(self, x):
    #     x = self.encoder(x)[0]
    #     x = F.log_softmax(self.fc1(x), dim=1)
    #     return x

    def shared_step(self, batch):
        x, y = batch
        feats = self.encoder(x)[-1]
        feats = feats.view(feats.size(0), -1)
        logits = self.linear_layer(feats)
        loss = self.loss_fn(logits, y)
        return loss, logits, y

    def training_step(self, batch, batch_idx):
        loss, logits, y = self.shared_step(batch)
        acc = self.train_acc(logits, y)
        self.log('tloss', loss, prog_bar=True)
        self.log('tastep', acc, prog_bar=True)
        self.log('ta_epoch', self.train_acc)
        return loss

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            loss, logits, y = self.shared_step(batch)
            self.val_acc(logits, y)
        self.log('vloss', loss, prog_bar=True, sync_dist=True)
        self.log('val_acc_epoch', self.val_acc, prog_bar=True)
        return loss

    def loss_fn(self, logits, labels):
        return F.cross_entropy(logits, labels)

    def configure_optimizers(self):
        # small learning rate for the encoder, larger one for the linear head
        opt = SGD([{
            'params': self.encoder.parameters()
        }, {
            'params': self.linear_layer.parameters(),
            'lr': 0.1
        }],
                  lr=1e-4,
                  momentum=0.9)
        return [opt]

    def prepare_data(self):
        train_pipeline = self.train_transform(
            DATA_PATH=f"{self.DATA_PATH}/train",
            input_height=256,
            batch_size=self.batch_size,
            num_threads=self.num_workers,
            device_id=0)
        print(f"{self.DATA_PATH}/train")
        val_pipeline = self.val_transform(DATA_PATH=f"{self.DATA_PATH}/val",
                                          input_height=256,
                                          batch_size=self.batch_size,
                                          num_threads=self.num_workers,
                                          device_id=0)

        class LightningWrapper(DALIClassificationIterator):
            """Unpacks DALI's output dict into [image, label] with the label squeezed."""

            def __init__(self, *kargs, **kvargs):
                super().__init__(*kargs, **kvargs)

            def __next__(self):
                out = super().__next__()
                out = out[0]
                return [
                    out[k] if k != "label" else torch.squeeze(out[k])
                    for k in self.output_map
                ]

        self.train_loader = LightningWrapper(train_pipeline,
                                             fill_last_batch=False,
                                             auto_reset=True,
                                             reader_name="Reader")
        self.val_loader = LightningWrapper(val_pipeline,
                                           fill_last_batch=False,
                                           auto_reset=True,
                                           reader_name="Reader")

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader
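# A minimal, illustrative sketch of driving finetuneSIMCLR end to end. It is
# not part of the module above: the encoder name, data directory and
# hyperparameter values are assumptions, and the caller must supply DALI-style
# pipeline classes whose constructors accept DATA_PATH, input_height,
# batch_size, num_threads and device_id, as prepare_data() expects.


def _example_finetune_run(train_transform, val_transform):
    import pytorch_lightning as pl

    model = finetuneSIMCLR(encoder='minicnn32',       # any name load_encoder() understands
                           DATA_PATH='./UC_Merced',   # folder of class-named subdirectories
                           withhold=0.0,
                           batch_size=64,
                           val_split=0.2,
                           hidden_dims=128,
                           train_transform=train_transform,
                           val_transform=val_transform,
                           num_workers=4)
    trainer = pl.Trainer(gpus=1, max_epochs=30)
    trainer.fit(model)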
class CLASSIFIER(pl.LightningModule):  # SSLFineTuner
    def __init__(self,
                 encoder,
                 DATA_PATH,
                 VAL_PATH,
                 hidden_dim,
                 image_size,
                 seed,
                 cpus,
                 transform=SimCLRTransform,
                 **classifier_hparams):
        super().__init__()
        self.DATA_PATH = DATA_PATH
        self.VAL_PATH = VAL_PATH
        self.transform = transform
        self.image_size = image_size
        self.cpus = cpus
        self.seed = seed
        self.batch_size = classifier_hparams['batch_size']
        self.classifier_hparams = classifier_hparams

        self.linear_layer = SSLEvaluator(
            n_input=encoder.embedding_size,
            n_classes=self.classifier_hparams['num_classes'],
            p=self.classifier_hparams['dropout'],
            n_hidden=hidden_dim)

        self.train_acc = Accuracy()
        self.val_acc = Accuracy(compute_on_step=False)
        self.encoder = encoder

        # optional per-class weights for the cross-entropy penalty
        self.weights = None
        print(classifier_hparams)
        if classifier_hparams['weights'] is not None:
            self.weights = torch.tensor([
                float(item)
                for item in classifier_hparams['weights'].split(',')
            ])
            self.weights = self.weights.cuda()

        self.save_hyperparameters()

    # override optimizer to allow modification of the encoder learning rate
    def configure_optimizers(self):
        # the encoder parameter group gets lr=0, so only the linear head is updated
        optimizer = SGD([{
            'params': self.encoder.parameters(),
            'lr': 0
        }, {
            'params': self.linear_layer.parameters(),
            'lr': self.classifier_hparams['linear_lr']
        }],
                        lr=self.classifier_hparams['learning_rate'],
                        momentum=self.classifier_hparams['momentum'])

        if self.classifier_hparams['scheduler_type'] == "step":
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                self.classifier_hparams['decay_epochs'],
                gamma=self.classifier_hparams['gamma'])
        elif self.classifier_hparams['scheduler_type'] == "cosine":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                self.classifier_hparams['epochs'],  # total epochs to run
                eta_min=self.classifier_hparams['final_lr'])
        else:
            raise ValueError(
                f"Unknown scheduler_type: {self.classifier_hparams['scheduler_type']}"
            )

        return [optimizer], [scheduler]

    def forward(self, x):
        feats = self.encoder(x)[-1]
        feats = feats.view(feats.size(0), -1)
        logits = self.linear_layer(feats)
        return logits

    def shared_step(self, batch):
        x, y = batch
        logits = self.forward(x)
        loss = self.loss_fn(logits, y)
        return loss, logits, y

    def training_step(self, batch, batch_idx):
        loss, logits, y = self.shared_step(batch)
        acc = self.train_acc(logits, y)
        self.log('tloss', loss, prog_bar=True)
        self.log('tastep', acc, prog_bar=True)
        self.log('ta_epoch', self.train_acc)
        return loss

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            loss, logits, y = self.shared_step(batch)
            self.val_acc(logits, y)
        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        self.log('val_acc_epoch', self.val_acc, prog_bar=True)
        return loss

    def loss_fn(self, logits, labels):
        return F.cross_entropy(logits, labels, weight=self.weights)

    def setup(self, stage='inference'):
        Options = Enum('Loader', 'fit test inference')
        if stage == Options.fit.name:
            train = self.transform(self.DATA_PATH,
                                   batch_size=self.batch_size,
                                   input_height=self.image_size,
                                   copies=1,
                                   stage='train',
                                   num_threads=self.cpus,
                                   device_id=self.local_rank,
                                   seed=self.seed)
            val = self.transform(self.VAL_PATH,
                                 batch_size=self.batch_size,
                                 input_height=self.image_size,
                                 copies=1,
                                 stage='validation',
                                 num_threads=self.cpus,
                                 device_id=self.local_rank,
                                 seed=self.seed)
            self.train_loader = ClassifierWrapper(transform=train)
            self.val_loader = ClassifierWrapper(transform=val)
        elif stage == Options.inference.name:
            self.test_dataloader = ClassifierWrapper(
                transform=self.transform(self.DATA_PATH,
                                         batch_size=self.batch_size,
                                         input_height=self.image_size,
                                         copies=1,
                                         stage='inference',
                                         num_threads=2 * self.cpus,
                                         device_id=self.local_rank,
                                         seed=self.seed))
            self.inference_dataloader = self.test_dataloader

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    # Let the user add extra model-specific arguments (particularly for the
    # SIMSIAM model). These must not share names with any parameters in train.py.
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        # training params
        parser.add_argument("--linear_lr",
                            default=1e-1,
                            type=float,
                            help="learning rate for classification head.")
        parser.add_argument("--dropout",
                            default=0.1,
                            type=float,
                            help="dropout of neurons during training [0-1].")
        parser.add_argument("--nesterov",
                            default=False,
                            type=bool,
                            help="Use nesterov during training.")
        parser.add_argument("--scheduler_type",
                            default='cosine',
                            type=str,
                            help="learning rate scheduler: ['cosine' or 'step']")
        parser.add_argument("--gamma",
                            default=0.1,
                            type=float,
                            help="gamma param for learning rate.")
        parser.add_argument("--decay_epochs",
                            default=[60, 80],
                            nargs='+',
                            type=int,
                            help="epochs at which to decay the learning rate")
        parser.add_argument("--weight_decay",
                            default=1e-6,
                            type=float,
                            help="weight decay")
        parser.add_argument("--final_lr",
                            type=float,
                            default=1e-6,
                            help="final learning rate")
        parser.add_argument("--momentum",
                            type=float,
                            default=0.9,
                            help="momentum for learning rate")
        parser.add_argument(
            '--weights',
            type=str,
            help='delimited list of weights for penalty during classification')
        return parser
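# A short sketch of assembling CLASSIFIER from the CLI flags defined in
# add_model_specific_args(). The data paths and the values passed directly
# (batch_size, num_classes, learning_rate, epochs) are illustrative
# assumptions; they simply match the keys that __init__() and
# configure_optimizers() read from classifier_hparams. The encoder argument
# is expected to be a pretrained module exposing .embedding_size.


def _example_classifier_run(encoder):
    import pytorch_lightning as pl
    from argparse import ArgumentParser

    parent = ArgumentParser(add_help=False)
    parser = CLASSIFIER.add_model_specific_args(parent)
    args, _ = parser.parse_known_args()

    model = CLASSIFIER(encoder=encoder,
                       DATA_PATH='./data/train',
                       VAL_PATH='./data/val',
                       hidden_dim=128,
                       image_size=256,
                       seed=42,
                       cpus=4,
                       batch_size=64,
                       num_classes=21,
                       learning_rate=1e-3,
                       epochs=30,
                       **vars(args))
    trainer = pl.Trainer(gpus=1, max_epochs=30)
    trainer.fit(model)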
class SSLOnlineEvaluator(Callback):
    def __init__(self,
                 data_dir,
                 z_dim,
                 max_epochs=10,
                 check_val_every_n_epoch=1,
                 batch_size=1024,
                 num_workers=32):
        super().__init__()
        self.z_dim = z_dim
        self.max_epochs = max_epochs
        self.check_val_every_n_epoch = check_val_every_n_epoch

        self.datamodule = BigearthnetDataModule(data_dir=data_dir,
                                                train_frac=0.01,
                                                val_frac=0.01,
                                                lmdb=True,
                                                batch_size=batch_size,
                                                num_workers=num_workers)
        self.datamodule.setup()

        self.criterion = nn.MultiLabelSoftMarginLoss()
        # micro-averaged average precision, reported as a percentage
        self.metric = lambda output, target: average_precision_score(
            target, output, average='micro') * 100.0

    def on_pretrain_routine_start(self, trainer, pl_module):
        self.classifier = SSLEvaluator(
            n_input=self.z_dim,
            n_classes=self.datamodule.num_classes,
            n_hidden=None).to(pl_module.device)
        self.optimizer = torch.optim.Adam(self.classifier.parameters(),
                                          lr=1e-3)

    def on_epoch_end(self, trainer, pl_module):
        if (trainer.current_epoch + 1) % self.check_val_every_n_epoch != 0:
            return

        encoder = pl_module.encoder_q

        # train the linear probe on frozen representations
        self.classifier.train()
        for _ in range(self.max_epochs):
            for inputs, targets in self.datamodule.train_dataloader():
                inputs = inputs.to(pl_module.device)
                targets = targets.to(pl_module.device)

                with torch.no_grad():
                    representations = encoder(inputs)
                representations = representations.detach()

                logits = self.classifier(representations)
                loss = self.criterion(logits, targets)
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

        # evaluate the probe on the validation split
        self.classifier.eval()
        accuracies = []
        for inputs, targets in self.datamodule.val_dataloader():
            inputs = inputs.to(pl_module.device)
            with torch.no_grad():
                representations = encoder(inputs)
            representations = representations.detach()

            logits = self.classifier(representations)
            preds = torch.sigmoid(logits).detach().cpu()
            acc = self.metric(preds, targets)
            accuracies.append(acc)

        acc = torch.mean(torch.tensor(accuracies))
        metrics = {'online_val_acc': acc}
        trainer.logger_connector.log_metrics(metrics, {})
        trainer.logger_connector.add_progress_bar_metrics(metrics)
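# A brief sketch of attaching SSLOnlineEvaluator to a pretraining run. The
# moco_module argument stands for a MoCo-style LightningModule exposing
# encoder_q, as the callback above assumes; the LMDB directory and the
# numeric settings are illustrative, and z_dim must match the dimensionality
# of the representations produced by encoder_q.


def _example_online_eval_run(moco_module):
    import pytorch_lightning as pl

    online_eval = SSLOnlineEvaluator(data_dir='./bigearthnet_lmdb',  # hypothetical path
                                     z_dim=2048,
                                     max_epochs=10,
                                     check_val_every_n_epoch=5,
                                     batch_size=256,
                                     num_workers=8)
    trainer = pl.Trainer(gpus=1, max_epochs=100, callbacks=[online_eval])
    trainer.fit(moco_module)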