def __init__(self, encoder, DATA_PATH, VAL_PATH, hidden_dim, image_size, seed, cpus,
             transform=SimCLRTransform, **classifier_hparams):
    super().__init__()
    self.DATA_PATH = DATA_PATH
    self.VAL_PATH = VAL_PATH
    self.transform = transform
    self.image_size = image_size
    self.cpus = cpus
    self.seed = seed
    self.batch_size = classifier_hparams['batch_size']
    self.classifier_hparams = classifier_hparams
    self.linear_layer = SSLEvaluator(
        n_input=encoder.embedding_size,
        n_classes=self.classifier_hparams['num_classes'],
        p=self.classifier_hparams['dropout'],
        n_hidden=hidden_dim)
    self.train_acc = Accuracy()
    self.val_acc = Accuracy(compute_on_step=False)
    self.encoder = encoder
    self.save_hyperparameters()
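# For reference, a minimal sketch of the SSLEvaluator probe these snippets keep
# instantiating, assuming the pl_bolts-style interface described in the comments
# below (flatten -> dropout -> linear, with an optional BN+ReLU hidden layer when
# n_hidden is set). This is an illustrative reconstruction, not the authoritative
# implementation.
import torch.nn as nn

class SSLEvaluatorSketch(nn.Module):
    def __init__(self, n_input, n_classes, p=0.1, n_hidden=512):
        super().__init__()
        if n_hidden is None:
            # plain linear probe on top of frozen features
            self.block_forward = nn.Sequential(
                nn.Flatten(),
                nn.Dropout(p=p),
                nn.Linear(n_input, n_classes, bias=True),
            )
        else:
            # single-hidden-layer MLP probe
            self.block_forward = nn.Sequential(
                nn.Flatten(),
                nn.Dropout(p=p),
                nn.Linear(n_input, n_hidden, bias=False),
                nn.BatchNorm1d(n_hidden),
                nn.ReLU(inplace=True),
                nn.Dropout(p=p),
                nn.Linear(n_hidden, n_classes, bias=True),
            )

    def forward(self, x):
        return self.block_forward(x)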
def on_pretrain_routine_start(self, trainer, pl_module):
    self.classifier = SSLEvaluator(
        n_input=self.z_dim,
        n_classes=self.datamodule.num_classes,
        n_hidden=None
    ).to(pl_module.device)
    self.optimizer = torch.optim.Adam(self.classifier.parameters(), lr=1e-3)
def __init__(self, hparams):
    super().__init__()
    self.hparams = hparams
    self.online_evaluator = self.hparams.online_ft
    self.dataset = self.get_dataset(hparams.dataset)

    # encoder network (Z vectors)
    dummy_batch = torch.zeros((2, 3, hparams.patch_size, hparams.patch_size))
    self.encoder = CPCResNet101(dummy_batch)

    # info nce loss
    c, h = self.__compute_final_nb_c(hparams.patch_size)
    self.info_nce = InfoNCE(num_input_channels=c, target_dim=64, embed_scale=0.1)

    if self.online_evaluator:
        z_dim = c * h * h
        num_classes = self.dataset.num_classes
        self.non_linear_evaluator = SSLEvaluator(n_input=z_dim,
                                                 n_classes=num_classes,
                                                 p=0.2,
                                                 n_hidden=1024)
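# Hedged sketch of what the private __compute_final_nb_c helper above likely
# does: push a dummy patch batch through the encoder and read the channel count
# and spatial size off the output. The helper itself is not shown in this
# snippet, so treat this as an assumption about its behavior.
import torch

def compute_final_nb_c(encoder, patch_size):
    dummy_batch = torch.zeros((2, 3, patch_size, patch_size))
    with torch.no_grad():
        out = encoder(dummy_batch)  # expected shape: (batch, c, h, w)
    _, c, h, _ = out.shape
    return c, h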
def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
    """
    Moves metrics and the online evaluator to the correct GPU.
    If training happens via DDP, SyncBatchNorm is enabled for the online evaluator,
    and it is converted to a DDP module.
    """
    for prefix, metrics in [("train", self.train_metrics), ("val", self.val_metrics)]:
        add_submodules_to_same_device(pl_module, metrics, prefix=prefix)

    self.evaluator = SSLEvaluator(n_input=self.z_dim,
                                  n_classes=self.num_classes,
                                  p=self.drop_p,
                                  n_hidden=self.hidden_dim)
    self.evaluator.to(pl_module.device)

    if hasattr(trainer, "accelerator_connector"):
        # This works with Lightning 1.3.8
        accelerator = trainer.accelerator_connector
    elif hasattr(trainer, "_accelerator_connector"):
        # This works with Lightning 1.5.5
        accelerator = trainer._accelerator_connector
    else:
        raise ValueError("Unable to retrieve the accelerator information")

    if accelerator.is_distributed:
        if accelerator.use_ddp:
            self.evaluator = SyncBatchNorm.convert_sync_batchnorm(self.evaluator)
            self.evaluator = DistributedDataParallel(self.evaluator, device_ids=[pl_module.device])  # type: ignore
        else:
            rank_zero_warn("This type of distributed accelerator is not supported. "
                           "The online evaluator will not synchronize across GPUs.")

    self.optimizer = torch.optim.Adam(self.evaluator.parameters(),
                                      lr=self.learning_rate,
                                      weight_decay=self.weight_decay)

    if self.evaluator_state is not None:
        self._wrapped_evaluator().load_state_dict(self.evaluator_state)
    if self.optimizer_state is not None:
        self.optimizer.load_state_dict(self.optimizer_state)
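# The callback above calls self._wrapped_evaluator(), which is not shown in this
# snippet. A plausible sketch of that method, assuming its job is to reach
# through the DDP wrapper so a state dict saved from the bare module loads
# cleanly:
from torch.nn.parallel import DistributedDataParallel

def _wrapped_evaluator(self):
    # unwrap DDP (if distributed training wrapped the evaluator) before
    # touching state_dict / load_state_dict
    if isinstance(self.evaluator, DistributedDataParallel):
        return self.evaluator.module
    return self.evaluator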
def __init__(self, dataset, data_dir, lr, wd, input_height, batch_size,
             online_ft=False, num_workers=0, optimizer='adam',
             step=30, gamma=0.5, temperature=0.5, **kwargs):
    super().__init__()
    self.hparams = Namespace(**{
        'lr': lr,
        'step': step,
        'gamma': gamma,
        'temperature': temperature,
        'dataset': dataset,
        'data_dir': data_dir,
        'wd': wd,
        'input_height': input_height,
        'batch_size': batch_size,
        'online_ft': online_ft,
        'num_workers': num_workers,
        'optimizer': optimizer
    })
    self.online_evaluator = online_ft
    self.batch_size = batch_size
    self.input_height = input_height
    self.gamma = gamma
    self.step = step
    self.optimizer = optimizer
    self.wd = wd
    self.lr = lr
    self.temp = temperature
    self.data_dir = data_dir
    self.num_workers = num_workers
    self.dataset_name = dataset
    self.dataset = self.get_dataset(dataset)
    self.loss_func = self.init_loss()
    self.encoder = self.init_encoder()
    self.projection = self.init_projection()

    if self.online_evaluator:
        z_dim = self.projection.output_dim
        num_classes = self.dataset.num_classes
        self.non_linear_evaluator = SSLEvaluator(n_input=z_dim,
                                                 n_classes=num_classes,
                                                 p=0.2,
                                                 n_hidden=1024)
def __init__(self, encoder, DATA_PATH, VAL_PATH, hidden_dim, image_size, seed, cpus,
             transform=SimCLRTransform, **classifier_hparams):
    super().__init__()
    self.DATA_PATH = DATA_PATH
    self.VAL_PATH = VAL_PATH
    self.transform = transform
    self.image_size = image_size
    self.cpus = cpus
    self.seed = seed
    self.batch_size = classifier_hparams['batch_size']
    self.classifier_hparams = classifier_hparams
    self.linear_layer = SSLEvaluator(
        n_input=encoder.embedding_size,
        n_classes=self.classifier_hparams['num_classes'],
        p=self.classifier_hparams['dropout'],
        n_hidden=hidden_dim)
    self.train_acc = Accuracy()
    self.val_acc = Accuracy(compute_on_step=False)
    self.encoder = encoder
    self.weights = None
    print(classifier_hparams)
    if classifier_hparams['weights'] is not None:
        self.weights = torch.tensor([
            float(item) for item in classifier_hparams['weights'].split(',')
        ])
        self.weights = self.weights.cuda()
    self.save_hyperparameters()
def on_pretrain_routine_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
    from pl_bolts.models.self_supervised.evaluator import SSLEvaluator

    pl_module.non_linear_evaluator = SSLEvaluator(
        n_input=self.z_dim,
        n_classes=self.num_classes,
        p=self.drop_p,
        n_hidden=self.hidden_dim,
    ).to(pl_module.device)
    self.optimizer = torch.optim.Adam(pl_module.non_linear_evaluator.parameters(), lr=1e-4)
def on_pretrain_routine_start(self, trainer, pl_module):
    from pl_bolts.models.self_supervised.evaluator import SSLEvaluator

    # takes an input --> flattens it --> adds dropout --> puts a linear layer on top
    # (a simple linear mapping)
    pl_module.non_linear_evaluator = SSLEvaluator(
        n_input=self.z_dim,
        n_classes=self.num_classes,
        p=self.drop_p,
    ).to(pl_module.device)

    # pass the non_linear_evaluator's parameters to SGD: an optimizer is needed
    # to learn the MLP weights against the mlp_loss
    self.optimizer = torch.optim.SGD(pl_module.non_linear_evaluator.parameters(), lr=1e-3)
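# Hedged sketch of the companion training-batch hook these callbacks pair with:
# representations are detached so the mlp_loss trains only the probe, never the
# encoder. The hook name/signature follows older Lightning versions, and pulling
# features via pl_module(x) is an assumption about the module's forward.
import torch
import torch.nn.functional as F

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
    x, y = batch
    with torch.no_grad():
        representations = pl_module(x).flatten(start_dim=1)
    representations = representations.detach()  # block gradients into the encoder
    logits = pl_module.non_linear_evaluator(representations)
    mlp_loss = F.cross_entropy(logits, y)
    mlp_loss.backward()
    self.optimizer.step()
    self.optimizer.zero_grad()
    pl_module.log('online_train_loss', mlp_loss)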
def __init__(self, encoder, DATA_PATH, withhold, batch_size, val_split, hidden_dims,
             train_transform, val_transform, num_workers, **kwargs):
    super().__init__()
    self.DATA_PATH = DATA_PATH
    self.val_split = val_split
    self.batch_size = batch_size
    self.hidden_dims = hidden_dims
    self.train_transform = train_transform
    self.val_transform = val_transform
    self.num_workers = num_workers
    self.withhold = withhold

    # data stuff
    shutil.rmtree('split_data', ignore_errors=True)
    if not (path.isdir(f"{self.DATA_PATH}/train") and path.isdir(f"{self.DATA_PATH}/val")):
        splitfolders.ratio(self.DATA_PATH,
                           output="split_data",
                           ratio=(1 - self.val_split - self.withhold, self.val_split, self.withhold),
                           seed=10)
        self.DATA_PATH = 'split_data'
        print(f'automatically splitting data into train and validation data {self.val_split} '
              f'and withhold {self.withhold}')
    self.num_classes = len(os.listdir(f'{self.DATA_PATH}/train'))

    # model stuff
    self.train_acc = Accuracy()
    self.val_acc = Accuracy(compute_on_step=False)
    print('KWARGS:', kwargs)
    self.encoder, self.embedding_size = load_encoder(encoder, kwargs)
    self.linear_layer = SSLEvaluator(n_input=self.embedding_size,
                                     n_classes=self.num_classes,
                                     p=0.1,
                                     n_hidden=self.hidden_dims)
def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
    """
    Initializes modules and moves metrics and class weights to the module device.
    """
    for metric in [*self.train_metrics, *self.val_metrics]:
        metric.to(device=pl_module.device)  # type: ignore

    pl_module.non_linear_evaluator = SSLEvaluator(
        n_input=self.z_dim,
        n_classes=self.num_classes,
        p=self.drop_p,
        n_hidden=self.hidden_dim).to(pl_module.device)
    assert isinstance(pl_module.non_linear_evaluator, torch.nn.Module)

    self.optimizer = torch.optim.Adam(
        pl_module.non_linear_evaluator.parameters(),
        lr=self.learning_rate,
        weight_decay=self.weight_decay)
def on_pretrain_routine_start(self, trainer, pl_module):
    from pl_bolts.models.self_supervised.evaluator import SSLEvaluator

    # attach the evaluator to the module
    if hasattr(pl_module, 'z_dim'):
        self.z_dim = pl_module.z_dim
    if hasattr(pl_module, 'num_classes'):
        self.num_classes = pl_module.num_classes

    pl_module.non_linear_evaluator = SSLEvaluator(
        n_input=self.z_dim,
        n_classes=self.num_classes,
        p=self.drop_p,
        n_hidden=self.hidden_dim
    ).to(pl_module.device)
    self.optimizer = torch.optim.SGD(pl_module.non_linear_evaluator.parameters(), lr=1e-3)
def __init__(self, hparams):
    super().__init__()
    self.hparams = hparams
    self.online_evaluator = self.hparams.online_ft
    self.dataset = self.get_dataset(hparams.dataset)
    self.encoder = self.init_encoder()

    # info nce loss
    c, h = self.__compute_final_nb_c(hparams.patch_size)
    self.info_nce = InfoNCE(num_input_channels=c, target_dim=64, embed_scale=0.1)

    if self.online_evaluator:
        z_dim = c * h * h
        num_classes = self.dataset.num_classes
        self.non_linear_evaluator = SSLEvaluator(n_input=z_dim,
                                                 n_classes=num_classes,
                                                 p=0.2,
                                                 n_hidden=1024)
def on_pretrain_routine_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
    # must move to device after setup, as during setup, pl_module is still on cpu
    self.online_evaluator = SSLEvaluator(
        n_input=self.z_dim,
        n_classes=self.num_classes,
        p=self.drop_p,
        n_hidden=self.hidden_dim,
    ).to(pl_module.device)

    # switch for PL compatibility reasons
    accel = (trainer.accelerator_connector
             if hasattr(trainer, "accelerator_connector")
             else trainer._accelerator_connector)
    if accel.is_distributed:
        if accel.use_ddp:
            from torch.nn.parallel import DistributedDataParallel as DDP
            self.online_evaluator = DDP(self.online_evaluator, device_ids=[pl_module.device])
        elif accel.use_dp:
            from torch.nn.parallel import DataParallel as DP
            self.online_evaluator = DP(self.online_evaluator, device_ids=[pl_module.device])
        else:
            rank_zero_warn("Does not support this type of distributed accelerator. "
                           "The online evaluator will not sync.")

    self.optimizer = torch.optim.Adam(self.online_evaluator.parameters(), lr=1e-4)

    if self._recovered_callback_state is not None:
        self.online_evaluator.load_state_dict(self._recovered_callback_state["state_dict"])
        self.optimizer.load_state_dict(self._recovered_callback_state["optimizer_state"])
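# Where _recovered_callback_state presumably comes from: a matching save/load
# pair on the same callback. The key names mirror the usage above, but the exact
# hook signatures vary between Lightning versions, so this is only a sketch.
def on_save_checkpoint(self, trainer, pl_module, checkpoint):
    return {
        'state_dict': self.online_evaluator.state_dict(),
        'optimizer_state': self.optimizer.state_dict(),
    }

def on_load_checkpoint(self, trainer, pl_module, callback_state):
    self._recovered_callback_state = callback_state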
class SSLOnlineEvaluator(Callback):

    def __init__(self, data_dir, z_dim, max_epochs=10, check_val_every_n_epoch=1,
                 batch_size=1024, num_workers=32):
        self.z_dim = z_dim
        self.max_epochs = max_epochs
        self.check_val_every_n_epoch = check_val_every_n_epoch
        self.datamodule = BigearthnetDataModule(
            data_dir=data_dir,
            train_frac=0.01,
            val_frac=0.01,
            lmdb=True,
            batch_size=batch_size,
            num_workers=num_workers
        )
        self.datamodule.setup()
        self.criterion = nn.MultiLabelSoftMarginLoss()
        self.metric = lambda output, target: average_precision_score(target, output, average='micro') * 100.0

    def on_pretrain_routine_start(self, trainer, pl_module):
        self.classifier = SSLEvaluator(
            n_input=self.z_dim,
            n_classes=self.datamodule.num_classes,
            n_hidden=None
        ).to(pl_module.device)
        self.optimizer = torch.optim.Adam(self.classifier.parameters(), lr=1e-3)

    def on_epoch_end(self, trainer, pl_module):
        if (trainer.current_epoch + 1) % self.check_val_every_n_epoch != 0:
            return

        encoder = pl_module.encoder_q

        self.classifier.train()
        for _ in range(self.max_epochs):
            for inputs, targets in self.datamodule.train_dataloader():
                inputs = inputs.to(pl_module.device)
                targets = targets.to(pl_module.device)
                with torch.no_grad():
                    representations = encoder(inputs)
                representations = representations.detach()
                logits = self.classifier(representations)
                loss = self.criterion(logits, targets)
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

        self.classifier.eval()
        accuracies = []
        for inputs, targets in self.datamodule.val_dataloader():
            inputs = inputs.to(pl_module.device)
            with torch.no_grad():
                representations = encoder(inputs)
            representations = representations.detach()
            logits = self.classifier(representations)
            preds = torch.sigmoid(logits).detach().cpu()
            acc = self.metric(preds, targets)
            accuracies.append(acc)
        acc = torch.mean(torch.tensor(accuracies))

        metrics = {'online_val_acc': acc}
        trainer.logger_connector.log_metrics(metrics, {})
        trainer.logger_connector.add_progress_bar_metrics(metrics)
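# Hedged usage sketch for the callback above. The only requirements visible from
# the code are that the LightningModule expose `encoder_q` (a MoCo-style query
# encoder) and that `data_dir` point at a BigEarthNet root; the path, z_dim, and
# `model` here are placeholders.
import pytorch_lightning as pl

online_eval = SSLOnlineEvaluator(data_dir='/path/to/bigearthnet', z_dim=128)
trainer = pl.Trainer(max_epochs=100, callbacks=[online_eval])
trainer.fit(model)  # `model` is assumed to be a LightningModule with .encoder_q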
class CLASSIFIER(pl.LightningModule):  # SSLFineTuner

    def __init__(self, encoder, DATA_PATH, VAL_PATH, hidden_dim, image_size, seed, cpus,
                 transform=SimCLRTransform, **classifier_hparams):
        super().__init__()
        self.DATA_PATH = DATA_PATH
        self.VAL_PATH = VAL_PATH
        self.transform = transform
        self.image_size = image_size
        self.cpus = cpus
        self.seed = seed
        self.batch_size = classifier_hparams['batch_size']
        self.classifier_hparams = classifier_hparams
        self.linear_layer = SSLEvaluator(
            n_input=encoder.embedding_size,
            n_classes=self.classifier_hparams['num_classes'],
            p=self.classifier_hparams['dropout'],
            n_hidden=hidden_dim)
        self.train_acc = Accuracy()
        self.val_acc = Accuracy(compute_on_step=False)
        self.encoder = encoder
        self.weights = None
        print(classifier_hparams)
        if classifier_hparams['weights'] is not None:
            self.weights = torch.tensor([
                float(item) for item in classifier_hparams['weights'].split(',')
            ])
            self.weights = self.weights.cuda()
        self.save_hyperparameters()

    # override the optimizer to allow modification of the encoder learning rate
    def configure_optimizers(self):
        optimizer = SGD([{
            'params': self.encoder.parameters(),
            'lr': 0
        }, {
            'params': self.linear_layer.parameters(),
            'lr': self.classifier_hparams['linear_lr']
        }],
            lr=self.classifier_hparams['learning_rate'],
            momentum=self.classifier_hparams['momentum'])

        if self.classifier_hparams['scheduler_type'] == "step":
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                self.classifier_hparams['decay_epochs'],
                gamma=self.classifier_hparams['gamma'])
        elif self.classifier_hparams['scheduler_type'] == "cosine":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                self.classifier_hparams['epochs'],  # total epochs to run
                eta_min=self.classifier_hparams['final_lr'])
        return [optimizer], [scheduler]

    def forward(self, x):
        feats = self.encoder(x)[-1]
        feats = feats.view(feats.size(0), -1)
        logits = self.linear_layer(feats)
        return logits

    def shared_step(self, batch):
        x, y = batch
        logits = self.forward(x)
        loss = self.loss_fn(logits, y)
        return loss, logits, y

    def training_step(self, batch, batch_idx):
        loss, logits, y = self.shared_step(batch)
        acc = self.train_acc(logits, y)
        self.log('tloss', loss, prog_bar=True)
        self.log('tastep', acc, prog_bar=True)
        self.log('ta_epoch', self.train_acc)
        return loss

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            loss, logits, y = self.shared_step(batch)
            acc = self.val_acc(logits, y)
        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        self.log('val_acc_epoch', self.val_acc, prog_bar=True)
        return loss

    def loss_fn(self, logits, labels):
        return F.cross_entropy(logits, labels, weight=self.weights)

    def setup(self, stage='inference'):
        Options = Enum('Loader', 'fit test inference')
        if stage == Options.fit.name:
            train = self.transform(self.DATA_PATH,
                                   batch_size=self.batch_size,
                                   input_height=self.image_size,
                                   copies=1,
                                   stage='train',
                                   num_threads=self.cpus,
                                   device_id=self.local_rank,
                                   seed=self.seed)
            val = self.transform(self.VAL_PATH,
                                 batch_size=self.batch_size,
                                 input_height=self.image_size,
                                 copies=1,
                                 stage='validation',
                                 num_threads=self.cpus,
                                 device_id=self.local_rank,
                                 seed=self.seed)
            self.train_loader = ClassifierWrapper(transform=train)
            self.val_loader = ClassifierWrapper(transform=val)
        elif stage == Options.inference.name:
            self.test_dataloader = ClassifierWrapper(
                transform=self.transform(self.DATA_PATH,
                                         batch_size=self.batch_size,
                                         input_height=self.image_size,
                                         copies=1,
                                         stage='inference',
                                         num_threads=2 * self.cpus,
                                         device_id=self.local_rank,
                                         seed=self.seed))
            self.inference_dataloader = self.test_dataloader

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    # allow the user to add extra arguments, particularly for the SIMSIAM model;
    # these must not share a name with any parameter from train.py
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        # training params
        parser.add_argument("--linear_lr",
                            default=1e-1,
                            type=float,
                            help="learning rate for the classification head.")
        parser.add_argument("--dropout",
                            default=0.1,
                            type=float,
                            help="dropout of neurons during training [0-1].")
        parser.add_argument("--nesterov",
                            default=False,
                            type=bool,
                            help="use nesterov momentum during training.")
        parser.add_argument("--scheduler_type",
                            default='cosine',
                            type=str,
                            help="learning rate scheduler: ['cosine' or 'step']")
        parser.add_argument("--gamma",
                            default=0.1,
                            type=float,
                            help="gamma param for the learning rate scheduler.")
        parser.add_argument("--decay_epochs",
                            default=[60, 80],
                            type=list,
                            help="epochs at which to decay the learning rate")
        parser.add_argument("--weight_decay",
                            default=1e-6,
                            type=float,
                            help="weight decay")
        parser.add_argument("--final_lr",
                            type=float,
                            default=1e-6,
                            help="final learning rate")
        parser.add_argument("--momentum",
                            type=float,
                            default=0.9,
                            help="momentum for the optimizer")
        parser.add_argument("--weights",
                            type=str,
                            help="delimited list of class weights for the loss penalty during classification")
        return parser
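# Hedged usage sketch for the argparse plumbing above; the train.py side is
# assumed, but the flag names come straight from add_model_specific_args.
from argparse import ArgumentParser

parent = ArgumentParser(add_help=False)
parser = CLASSIFIER.add_model_specific_args(parent)
args = parser.parse_args(['--linear_lr', '0.05', '--weights', '1.0,2.5'])
# --weights is a comma-delimited string; __init__ splits it and moves the
# resulting tensor to CUDA for use as class weights in cross_entropy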
def __init__(
        self,
        datamodule: pl.LightningDataModule = None,
        encoder: Union[str, torch.nn.Module, pl.LightningModule] = 'cpc_encoder',
        patch_size: int = 8,
        patch_overlap: int = 4,
        online_ft: bool = True,
        task: str = 'cpc',
        num_workers: int = 4,
        learning_rate: float = 1e-4,
        data_dir: str = '',
        batch_size: int = 32,
        pretrained: str = None,
        **kwargs,
):
    """
    PyTorch Lightning implementation of `Data-Efficient Image Recognition with
    Contrastive Predictive Coding <https://arxiv.org/abs/1905.09272>`_

    Paper authors: (Olivier J. Hénaff, Aravind Srinivas, Jeffrey De Fauw, Ali Razavi,
    Carl Doersch, S. M. Ali Eslami, Aaron van den Oord).

    Model implemented by:

        - `William Falcon <https://github.com/williamFalcon>`_
        - `Tullie Murrell <https://github.com/tullie>`_

    Example:

        >>> from pl_bolts.models.self_supervised import CPCV2
        ...
        >>> model = CPCV2()

    Train::

        trainer = Trainer()
        trainer.fit(model)

    CLI command::

        # cifar10
        python cpc_module.py --gpus 1

        # imagenet
        python cpc_module.py --gpus 8 --dataset imagenet2012 --data_dir /path/to/imagenet/ --meta_dir /path/to/folder/with/meta.bin/ --batch_size 32

    Some uses::

        # load resnet18 pretrained using CPC on imagenet
        model = CPCV2(encoder='resnet18', pretrained=True)
        resnet18 = model.encoder
        resnet18.freeze()

        # it supports any torchvision resnet
        model = CPCV2(encoder='resnet50', pretrained=True)

        # use it as a feature extractor
        x = torch.rand(2, 3, 224, 224)
        out = model(x)

    Args:
        datamodule: A Datamodule (optional). Otherwise set the dataloaders directly
        encoder: A string for any of the resnets in torchvision, or the original CPC encoder,
            or a custom nn.Module encoder
        patch_size: How big to make the image patches
        patch_overlap: How much overlap each patch should have
        online_ft: Enable a 1024-unit MLP to fine-tune online
        task: Which self-supervised task to use ('cpc', 'amdim', etc...)
        num_workers: number of dataloader workers
        learning_rate: what learning rate to use
        data_dir: where to store data
        batch_size: batch size
        pretrained: If true, will use the weights pretrained (using CPC) on Imagenet
    """
    super().__init__()
    self.save_hyperparameters()

    self.online_evaluator = self.hparams.online_ft

    if pretrained:
        self.hparams.dataset = pretrained
        self.online_evaluator = True

    # link data
    if datamodule is None:
        datamodule = CIFAR10DataModule(
            self.hparams.data_dir,
            num_workers=self.hparams.num_workers,
            batch_size=batch_size)
        datamodule.train_transforms = CPCTrainTransformsCIFAR10()
        datamodule.val_transforms = CPCEvalTransformsCIFAR10()
    self.datamodule = datamodule

    # init encoder
    self.encoder = encoder
    if isinstance(encoder, str):
        self.encoder = self.init_encoder()

    # info nce loss
    c, h = self.__compute_final_nb_c(self.hparams.patch_size)
    self.contrastive_task = CPCTask(num_input_channels=c, target_dim=64, embed_scale=0.1)

    if self.online_evaluator:
        z_dim = c * h * h
        num_classes = self.datamodule.num_classes
        self.non_linear_evaluator = SSLEvaluator(n_input=z_dim,
                                                 n_classes=num_classes,
                                                 p=0.2,
                                                 n_hidden=1024)

    if pretrained:
        self.load_pretrained(encoder)
class finetuneSIMCLR(pl.LightningModule):

    def __init__(self, encoder, DATA_PATH, withhold, batch_size, val_split, hidden_dims,
                 train_transform, val_transform, num_workers, **kwargs):
        super().__init__()
        self.DATA_PATH = DATA_PATH
        self.val_split = val_split
        self.batch_size = batch_size
        self.hidden_dims = hidden_dims
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.num_workers = num_workers
        self.withhold = withhold

        # data stuff
        shutil.rmtree('split_data', ignore_errors=True)
        if not (path.isdir(f"{self.DATA_PATH}/train") and path.isdir(f"{self.DATA_PATH}/val")):
            splitfolders.ratio(self.DATA_PATH,
                               output="split_data",
                               ratio=(1 - self.val_split - self.withhold, self.val_split, self.withhold),
                               seed=10)
            self.DATA_PATH = 'split_data'
            print(f'automatically splitting data into train and validation data {self.val_split} '
                  f'and withhold {self.withhold}')
        self.num_classes = len(os.listdir(f'{self.DATA_PATH}/train'))

        # model stuff
        self.train_acc = Accuracy()
        self.val_acc = Accuracy(compute_on_step=False)
        print('KWARGS:', kwargs)
        self.encoder, self.embedding_size = load_encoder(encoder, kwargs)
        self.linear_layer = SSLEvaluator(n_input=self.embedding_size,
                                         n_classes=self.num_classes,
                                         p=0.1,
                                         n_hidden=self.hidden_dims)

    # def forward(self, x):
    #     x = self.encoder(x)[0]
    #     x = F.log_softmax(self.fc1(x), dim=1)
    #     return x

    def shared_step(self, batch):
        x, y = batch
        feats = self.encoder(x)[-1]
        feats = feats.view(feats.size(0), -1)
        logits = self.linear_layer(feats)
        loss = self.loss_fn(logits, y)
        return loss, logits, y

    def training_step(self, batch, batch_idx):
        loss, logits, y = self.shared_step(batch)
        acc = self.train_acc(logits, y)
        self.log('tloss', loss, prog_bar=True)
        self.log('tastep', acc, prog_bar=True)
        self.log('ta_epoch', self.train_acc)
        return loss

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            loss, logits, y = self.shared_step(batch)
            acc = self.val_acc(logits, y)
        self.log('vloss', loss, prog_bar=True, sync_dist=True)
        self.log('val_acc_epoch', self.val_acc, prog_bar=True)
        return loss

    def loss_fn(self, logits, labels):
        return F.cross_entropy(logits, labels)

    def configure_optimizers(self):
        opt = SGD([{
            'params': self.encoder.parameters()
        }, {
            'params': self.linear_layer.parameters(),
            'lr': 0.1
        }],
            lr=1e-4,
            momentum=0.9)
        return [opt]

    def prepare_data(self):
        train_pipeline = self.train_transform(DATA_PATH=f"{self.DATA_PATH}/train",
                                              input_height=256,
                                              batch_size=self.batch_size,
                                              num_threads=self.num_workers,
                                              device_id=0)
        print(f"{self.DATA_PATH}/train")
        val_pipeline = self.val_transform(DATA_PATH=f"{self.DATA_PATH}/val",
                                          input_height=256,
                                          batch_size=self.batch_size,
                                          num_threads=self.num_workers,
                                          device_id=0)

        class LightningWrapper(DALIClassificationIterator):
            def __init__(self, *kargs, **kvargs):
                super().__init__(*kargs, **kvargs)

            def __next__(self):
                out = super().__next__()
                out = out[0]
                return [
                    out[k] if k != "label" else torch.squeeze(out[k])
                    for k in self.output_map
                ]

        self.train_loader = LightningWrapper(train_pipeline,
                                             fill_last_batch=False,
                                             auto_reset=True,
                                             reader_name="Reader")
        self.val_loader = LightningWrapper(val_pipeline,
                                           fill_last_batch=False,
                                           auto_reset=True,
                                           reader_name="Reader")

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader
def __init__(self,
             datamodule: pl_bolts.datamodules.LightningDataModule = None,
             data_dir: str = '',
             learning_rate: float = 0.00006,
             weight_decay: float = 0.0005,
             input_height: int = 32,
             batch_size: int = 128,
             online_ft: bool = False,
             num_workers: int = 4,
             optimizer: str = 'lars',
             lr_sched_step: float = 30.0,
             lr_sched_gamma: float = 0.5,
             lars_momentum: float = 0.9,
             lars_eta: float = 0.001,
             loss_temperature: float = 0.5,
             **kwargs):
    """
    PyTorch Lightning implementation of `SimCLR <https://arxiv.org/abs/2002.05709>`_

    Paper authors: Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton.

    Model implemented by:

        - `William Falcon <https://github.com/williamFalcon>`_
        - `Tullie Murrell <https://github.com/tullie>`_

    Example:

        >>> from pl_bolts.models.self_supervised import SimCLR
        ...
        >>> model = SimCLR()

    Train::

        trainer = Trainer()
        trainer.fit(model)

    CLI command::

        # cifar10
        python simclr_module.py --gpus 1

        # imagenet
        python simclr_module.py --gpus 8 --dataset imagenet2012 --data_dir /path/to/imagenet/ --meta_dir /path/to/folder/with/meta.bin/ --batch_size 32

    Args:
        datamodule: The datamodule
        data_dir: directory to store data
        learning_rate: the learning rate
        weight_decay: optimizer weight decay
        input_height: image input height
        batch_size: the batch size
        online_ft: whether to tune online or not
        num_workers: number of workers
        optimizer: optimizer name
        lr_sched_step: step for learning rate scheduler
        lr_sched_gamma: gamma for learning rate scheduler
        lars_momentum: the momentum param for the lars optimizer
        lars_eta: eta param for the lars optimizer
        loss_temperature: the temperature used in the contrastive loss
    """
    super().__init__()
    self.save_hyperparameters()
    self.online_evaluator = online_ft

    # init default datamodule
    if datamodule is None:
        datamodule = CIFAR10DataModule(data_dir, num_workers=num_workers)
        datamodule.train_transforms = SimCLRTrainDataTransform(input_height)
        datamodule.val_transforms = SimCLREvalDataTransform(input_height)
    self.datamodule = datamodule

    self.loss_func = self.init_loss()
    self.encoder = self.init_encoder()
    self.projection = self.init_projection()

    if self.online_evaluator:
        z_dim = self.projection.output_dim
        num_classes = self.datamodule.num_classes
        self.non_linear_evaluator = SSLEvaluator(n_input=z_dim,
                                                 n_classes=num_classes,
                                                 p=0.2,
                                                 n_hidden=1024)
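# Hedged sketch of the projection head init_projection() is expected to return:
# the code above only relies on it having an `output_dim` attribute, so the
# SimCLR-style two-layer MLP and the layer sizes here are assumptions.
import torch.nn as nn

class ProjectionSketch(nn.Module):
    def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128):
        super().__init__()
        self.output_dim = output_dim  # read by the online evaluator as z_dim
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.model(x)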