def __init__(
    self,
    *args,
    **kwargs,
):
    self._threshold = kwargs.get("threshold", None)
    self.__instantiate_transform(kwargs)
    BaseDatasetSamplerMixin.__init__(self, *args, **kwargs)
    BaseTasksMixin.__init__(self, *args, **kwargs)
    self.clean_kwargs(kwargs)
    LightningDataModule.__init__(self, *args, **kwargs)
    self.dataset_train = None
    self.dataset_val = None
    self.dataset_test = None
    self._seed = 42
    self._num_workers = 2
    self._shuffle = True
    self._drop_last = False
    self._pin_memory = True
    self._follow_batch = []
    self._hyper_parameters = {}
def test_v1_7_0_datamodule_transform_properties(tmpdir):
    dm = MNISTDataModule()
    with pytest.deprecated_call(match=r"DataModule property `train_transforms` was deprecated in v1.5"):
        dm.train_transforms = "a"
    with pytest.deprecated_call(match=r"DataModule property `val_transforms` was deprecated in v1.5"):
        dm.val_transforms = "b"
    with pytest.deprecated_call(match=r"DataModule property `test_transforms` was deprecated in v1.5"):
        dm.test_transforms = "c"
    with pytest.deprecated_call(match=r"DataModule property `train_transforms` was deprecated in v1.5"):
        _ = LightningDataModule(train_transforms="a")
    with pytest.deprecated_call(match=r"DataModule property `val_transforms` was deprecated in v1.5"):
        _ = LightningDataModule(val_transforms="b")
    with pytest.deprecated_call(match=r"DataModule property `test_transforms` was deprecated in v1.5"):
        _ = LightningDataModule(test_transforms="c")
    with pytest.deprecated_call(match=r"DataModule property `test_transforms` was deprecated in v1.5"):
        _ = LightningDataModule(test_transforms="c", dims=(1, 1, 1))
def test_dm_init_from_datasets(tmpdir):
    train_ds = DummyDS()
    valid_ds = DummyDS()
    test_ds = DummyDS()
    valid_dss = [DummyDS(), DummyDS()]
    test_dss = [DummyDS(), DummyDS()]

    dm = LightningDataModule.from_datasets(train_ds, batch_size=4, num_workers=0)
    assert torch.all(next(iter(dm.train_dataloader())) == torch.ones(4))
    assert dm.val_dataloader() is None
    assert dm.test_dataloader() is None

    dm = LightningDataModule.from_datasets(train_ds, valid_ds, test_ds, batch_size=4, num_workers=0)
    assert torch.all(next(iter(dm.val_dataloader())) == torch.ones(4))
    assert torch.all(next(iter(dm.test_dataloader())) == torch.ones(4))

    dm = LightningDataModule.from_datasets(train_ds, valid_dss, test_dss, batch_size=4, num_workers=0)
    assert torch.all(next(iter(dm.val_dataloader()[0])) == torch.ones(4))
    assert torch.all(next(iter(dm.val_dataloader()[1])) == torch.ones(4))
    assert torch.all(next(iter(dm.test_dataloader()[0])) == torch.ones(4))
    assert torch.all(next(iter(dm.test_dataloader()[1])) == torch.ones(4))
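# `DummyDS` (and the iterable `DummyIDS` used in the tests below) are not
# defined in this section. A minimal sketch consistent with the assertions
# above: every sample is 1, so a batch of 4 collates to torch.ones(4).
class DummyDS(torch.utils.data.Dataset):
    def __getitem__(self, index):
        return 1

    def __len__(self):
        return 100


class DummyIDS(torch.utils.data.IterableDataset):
    def __iter__(self):
        yield 1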
def test_dm_init_from_datasets_dataloaders(iterable):
    ds = DummyIDS if iterable else DummyDS

    train_ds = ds()
    dm = LightningDataModule.from_datasets(train_ds, batch_size=4, num_workers=0)
    with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
        dm.train_dataloader()
        dl_mock.assert_called_once_with(train_ds, batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True)
    with pytest.raises(MisconfigurationException, match="`val_dataloader` must be implemented"):
        _ = dm.val_dataloader()
    with pytest.raises(MisconfigurationException, match="`test_dataloader` must be implemented"):
        _ = dm.test_dataloader()

    train_ds_sequence = [ds(), ds()]
    dm = LightningDataModule.from_datasets(train_ds_sequence, batch_size=4, num_workers=0)
    with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
        dm.train_dataloader()
        dl_mock.assert_has_calls(
            [
                call(train_ds_sequence[0], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True),
                call(train_ds_sequence[1], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True),
            ]
        )
    with pytest.raises(MisconfigurationException, match="`val_dataloader` must be implemented"):
        _ = dm.val_dataloader()
    with pytest.raises(MisconfigurationException, match="`test_dataloader` must be implemented"):
        _ = dm.test_dataloader()

    valid_ds = ds()
    test_ds = ds()
    dm = LightningDataModule.from_datasets(val_dataset=valid_ds, test_dataset=test_ds, batch_size=2, num_workers=0)
    with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
        dm.val_dataloader()
        dl_mock.assert_called_with(valid_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
        dm.test_dataloader()
        dl_mock.assert_called_with(test_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
    with pytest.raises(MisconfigurationException, match="`train_dataloader` must be implemented"):
        _ = dm.train_dataloader()

    valid_dss = [ds(), ds()]
    test_dss = [ds(), ds()]
    predict_dss = [ds(), ds()]
    dm = LightningDataModule.from_datasets(train_ds, valid_dss, test_dss, predict_dss, batch_size=4, num_workers=0)
    with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
        dm.val_dataloader()
        dm.test_dataloader()
        dm.predict_dataloader()
        dl_mock.assert_has_calls(
            [
                call(valid_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
                call(valid_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
                call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
                call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
                call(predict_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
                call(predict_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
            ]
        )
def model_cases():
    class TestHparamsNamespace(LightningModule):
        learning_rate = 1

        def __contains__(self, item):
            return item == "learning_rate"

    TestHparamsDict = {"learning_rate": 2}

    class TestModel1(LightningModule):  # test for namespace
        learning_rate = 0

    model1 = TestModel1()

    class TestModel2(LightningModule):  # test for hparams namespace
        hparams = TestHparamsNamespace()

    model2 = TestModel2()

    class TestModel3(LightningModule):  # test for hparams dict
        hparams = TestHparamsDict

    model3 = TestModel3()

    class TestModel4(LightningModule):  # fail case
        batch_size = 1

    model4 = TestModel4()

    trainer = Trainer()
    datamodule = LightningDataModule()
    datamodule.batch_size = 8
    trainer.datamodule = datamodule

    model5 = LightningModule()
    model5.trainer = trainer

    class TestModel6(LightningModule):  # test for datamodule w/ hparams w/o attribute (should use datamodule)
        hparams = TestHparamsDict

    model6 = TestModel6()
    model6.trainer = trainer

    TestHparamsDict2 = {"batch_size": 2}

    class TestModel7(LightningModule):  # test for datamodule w/ hparams w/ attribute (should use datamodule)
        hparams = TestHparamsDict2

    model7 = TestModel7()
    model7.trainer = trainer

    return model1, model2, model3, model4, model5, model6, model7
def predict_test(trainer: Trainer, model: LightningModule, dm: LightningDataModule):
    """Checks that the trained model reaches a reasonable accuracy on the test set."""
    trainer.fit(model, datamodule=dm)

    dm.setup(stage="test")
    test_loader = dm.test_dataloader()
    acc = pl.metrics.Accuracy()
    for batch in test_loader:
        x, y = batch
        with torch.no_grad():
            y_hat = model(x)
        y_hat = y_hat.cpu()
        acc.update(y_hat, y)
    average_acc = acc.compute()
    assert average_acc >= 0.5, f"This model is expected to get >= 0.5 accuracy on the test set (it got {average_acc})"
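# Hedged usage sketch for `predict_test`; `LitClassifier` and `MNISTDataModule`
# are hypothetical stand-ins for a model and datamodule, not names from this section.
trainer = Trainer(max_epochs=3)
predict_test(trainer, LitClassifier(), MNISTDataModule())  # fits, then asserts test accuracy >= 0.5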
def gaussian_noise_pig_dl(
    dm: pl.LightningDataModule,
    batch_size: int,
    N_hat_multiplier: float = 1,
    sigma_multiplier: float = 1,
) -> DataLoader:
    x, _ = unpack_dataloader(dm.train_dataloader())
    N = x.shape[0]
    # `sigma_multiplier` was unused in the original; applying it to the noise scale appears to be the intent
    noise_std = torch.std(x) * sigma_multiplier

    # Does not exactly respect N_hat, but easier to do it this way
    if N_hat_multiplier >= 1:
        n_repeats = int(N_hat_multiplier)
        x = x.repeat(1, n_repeats).view(-1, x.shape[1])
    else:
        N_hat = int(N * N_hat_multiplier)
        idx = torch.randperm(x.shape[0])[:N_hat]
        x = x[idx]

    x_hat = x + noise_std * torch.randn_like(x)

    dl = DataLoader(TensorDataset(x_hat), batch_size=batch_size, shuffle=True)
    return dl
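# A minimal sketch of the `unpack_dataloader` helper assumed above and in
# `mean_shift_pig_dl` below (it is not defined in this section): stack every
# batch of a loader into a single (inputs, targets) pair of tensors.
from typing import Tuple

def unpack_dataloader(dl: DataLoader) -> Tuple[torch.Tensor, torch.Tensor]:
    xs, ys = zip(*[(x, y) for x, y in dl])
    return torch.cat(xs), torch.cat(ys)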
def run_model_test_without_loggers(
    trainer_options: dict, model: LightningModule, data: LightningDataModule = None, min_acc: float = 0.50
):
    reset_seed()

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=data)

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    model2 = load_model_from_checkpoint(trainer.logger, trainer.checkpoint_callback.best_model_path, type(model))

    # test new model accuracy
    test_loaders = model2.test_dataloader() if not data else data.test_dataloader()
    if not isinstance(test_loaders, list):
        test_loaders = [test_loaders]

    if not isinstance(model2, BoringModel):
        for dataloader in test_loaders:
            run_prediction_eval_model_template(model2, dataloader, min_acc=min_acc)
def lightning_train(
    trainer: pl.Trainer,
    model: pl.LightningModule,
    data_module: pl.LightningDataModule,
    checkpoint_path: str,
    resume: bool,
):
    if resume:
        trainer = pl.Trainer(resume_from_checkpoint=checkpoint_path)
    exp_key = trainer.logger.experiment.get_key()
    print("Running {}-model...".format(model.name))

    # main part here
    data_module.setup('fit')
    trainer.fit(model=model, datamodule=data_module)
    model.save(trainer, checkpoint_path)

    return trainer, model, exp_key
def test_v1_7_0_datamodule_dims_property(tmpdir):
    dm = MNISTDataModule()
    with pytest.deprecated_call(match=r"DataModule property `dims` was deprecated in v1.5"):
        _ = dm.dims
    with pytest.deprecated_call(match=r"DataModule property `dims` was deprecated in v1.5"):
        _ = LightningDataModule(dims=(1, 1, 1))
def run_model_test(
    trainer_options,
    model: LightningModule,
    data: LightningDataModule = None,
    on_gpu: bool = True,
    version=None,
    with_hpc: bool = True,
    min_acc: float = 0.25,
):
    reset_seed()
    save_dir = trainer_options["default_root_dir"]

    # logger file to get meta
    logger = get_default_logger(save_dir, version=version)
    trainer_options.update(logger=logger)
    trainer = Trainer(**trainer_options)
    initial_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()])
    trainer.fit(model, datamodule=data)
    post_train_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()])

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    # Check that the model has actually changed post-training
    change_ratio = torch.norm(initial_values - post_train_values)
    assert change_ratio > 0.1, f"the model changed by only {change_ratio}"

    # test model loading
    pretrained_model = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model))

    # test new model accuracy
    test_loaders = model.test_dataloader() if not data else data.test_dataloader()
    if not isinstance(test_loaders, list):
        test_loaders = [test_loaders]

    if not isinstance(model, BoringModel):
        for dataloader in test_loaders:
            run_prediction_eval_model_template(model, dataloader, min_acc=min_acc)

    if with_hpc:
        if trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2):
            # on hpc this would work fine... but need to hack it for the purpose of the test
            trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = trainer.init_optimizers(
                pretrained_model
            )

        # test HPC saving
        trainer.checkpoint_connector.hpc_save(save_dir, logger)
        # test HPC loading
        checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir)
        trainer.checkpoint_connector.restore(checkpoint_path)
def test_dm_init_from_datasets_dataloaders(iterable):
    ds = DummyIDS if iterable else DummyDS

    train_ds = ds()
    dm = LightningDataModule.from_datasets(train_ds, batch_size=4, num_workers=0)
    with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
        dm.train_dataloader()
        dl_mock.assert_called_once_with(train_ds, batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True)
    assert dm.val_dataloader() is None
    assert dm.test_dataloader() is None

    train_ds_sequence = [ds(), ds()]
    dm = LightningDataModule.from_datasets(train_ds_sequence, batch_size=4, num_workers=0)
    with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
        dm.train_dataloader()
        dl_mock.assert_has_calls(
            [
                call(train_ds_sequence[0], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True),
                call(train_ds_sequence[1], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True),
            ]
        )
    assert dm.val_dataloader() is None
    assert dm.test_dataloader() is None

    valid_ds = ds()
    test_ds = ds()
    dm = LightningDataModule.from_datasets(val_dataset=valid_ds, test_dataset=test_ds, batch_size=2, num_workers=0)
    with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
        dm.val_dataloader()
        dl_mock.assert_called_with(valid_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
        dm.test_dataloader()
        dl_mock.assert_called_with(test_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
    assert dm.train_dataloader() is None

    valid_dss = [ds(), ds()]
    test_dss = [ds(), ds()]
    dm = LightningDataModule.from_datasets(train_ds, valid_dss, test_dss, batch_size=4, num_workers=0)
    with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
        dm.val_dataloader()
        dm.test_dataloader()
        dl_mock.assert_has_calls(
            [
                call(valid_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
                call(valid_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
                call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
                call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
            ]
        )
def __init__(
    self,
    train_sequences: List[str],
    validation_sequences: List[str],
    alphabet: AlphabetDataLoader,
    masking_ratio: float,
    masking_prob: float,
    random_token_prob: float,
    num_workers: int,
    toks_per_batch: int,
    crop_sizes: Tuple[int, int] = (512, 1024),
):
    LightningDataModule.__init__(self)
    self._train_sequences = train_sequences
    self._validation_sequences = validation_sequences
    self._alphabet = alphabet
    self._masking_ratio = masking_ratio
    self._masking_prob = masking_prob
    self._random_token_prob = random_token_prob
    self._num_workers = num_workers
    self._toks_per_batch = toks_per_batch
    self._crop_sizes = crop_sizes
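# Hedged usage sketch for the constructor above. The enclosing class is not
# shown in this section; `SequenceDataModule` and the bare `AlphabetDataLoader()`
# construction are hypothetical names, and the sequences are placeholder strings.
dm = SequenceDataModule(
    train_sequences=["MKTAYIAKQR", "MSDNELQG"],
    validation_sequences=["MKVLA"],
    alphabet=AlphabetDataLoader(),
    masking_ratio=0.15,
    masking_prob=0.8,
    random_token_prob=0.1,
    num_workers=4,
    toks_per_batch=1024,
)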
def __post_init__(
    self,
    observation_space: gym.Space = None,
    action_space: gym.Space = None,
    reward_space: gym.Space = None,
):
    """Initializes the fields of the setting that weren't set from the command-line."""
    logger.debug("__post_init__ of Setting")
    if len(self.train_transforms) == 1 and isinstance(self.train_transforms[0], list):
        self.train_transforms = self.train_transforms[0]
    if len(self.val_transforms) == 1 and isinstance(self.val_transforms[0], list):
        self.val_transforms = self.val_transforms[0]
    if len(self.test_transforms) == 1 and isinstance(self.test_transforms[0], list):
        self.test_transforms = self.test_transforms[0]

    # Actually compose the list of Transforms or callables into a single transform.
    self.train_transforms: Compose = Compose(self.train_transforms)
    self.val_transforms: Compose = Compose(self.val_transforms)
    self.test_transforms: Compose = Compose(self.test_transforms)

    LightningDataModule.__init__(
        self,
        train_transforms=self.train_transforms,
        val_transforms=self.val_transforms,
        test_transforms=self.test_transforms,
    )

    self._observation_space = observation_space
    self._action_space = action_space
    self._reward_space = reward_space

    # TODO: It's a bit confusing to also have a `config` attribute on the
    # Setting. Might want to change this a bit.
    self.config: Config = None

    self.train_env: Environment = None  # type: ignore
    self.val_env: Environment = None  # type: ignore
    self.test_env: Environment = None  # type: ignore
def model_from_config(parameters: dict, datamodule: LightningDataModule) -> LightningModule:
    steps_per_epoch = len(datamodule.train_dataloader())
    total_steps = parameters["trainer"]["epochs"] * steps_per_epoch
    class_weight = compute_class_weight(datamodule.dataset_train)
    regression = parameters["data"]["targets"] == "regression"
    if regression:
        model = ImageRegression(
            n_channel=datamodule.size(0),
            lr_scheduler_total_steps=total_steps,
            **parameters["model"],
        )
    else:
        model = ImageClassification(
            n_channel=datamodule.size(0),
            n_class=len(parameters["data"]["targets"]["classes"]),
            class_weight=class_weight,
            lr_scheduler_total_steps=total_steps,
            **parameters["model"],
        )
    return model
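# A hedged sketch of the `parameters` layout that `model_from_config` (and
# `data_info`/`train` below) read. Only keys accessed in this section are
# shown; the concrete values are illustrative assumptions.
parameters = {
    "trainer": {"epochs": 10, "patience": 5},
    # "targets" is either the string "regression" or a classification spec
    # whose "classes" is a list of single-entry {name: index} dicts:
    "data": {"targets": {"classes": [{"cat": 0}, {"dog": 1}]}},
    "model": {"backbone": "resnet18"},  # forwarded to the model as **kwargs
}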
def run_model_test(
    trainer_options,
    model: LightningModule,
    data: LightningDataModule = None,
    on_gpu: bool = True,
    version=None,
    with_hpc: bool = True,
    min_acc: float = 0.25,
):
    reset_seed()
    save_dir = trainer_options["default_root_dir"]

    # logger file to get meta
    logger = get_default_logger(save_dir, version=version)
    trainer_options.update(logger=logger)
    trainer = Trainer(**trainer_options)
    initial_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()])
    trainer.fit(model, datamodule=data)
    post_train_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()])

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    # Check that the model has actually changed post-training
    change_ratio = torch.norm(initial_values - post_train_values)
    assert change_ratio > 0.03, f"the model changed by only {change_ratio}"

    # test model loading
    _ = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model))

    # test new model accuracy
    test_loaders = model.test_dataloader() if not data else data.test_dataloader()
    if not isinstance(test_loaders, list):
        test_loaders = [test_loaders]

    if not isinstance(model, BoringModel):
        for dataloader in test_loaders:
            run_model_prediction(model, dataloader, min_acc=min_acc)

    if with_hpc:
        # test HPC saving
        # save logger to make sure we get all the metrics
        if logger:
            logger.finalize("finished")
        hpc_save_path = trainer.checkpoint_connector.hpc_save_path(save_dir)
        trainer.save_checkpoint(hpc_save_path)
        # test HPC loading
        checkpoint_path = trainer.checkpoint_connector._CheckpointConnector__get_max_ckpt_path_from_folder(save_dir)
        trainer.checkpoint_connector.restore(checkpoint_path)
def __init__(
    self,
    datamodule: pl.LightningDataModule = None,
    encoder: Union[str, torch.nn.Module, pl.LightningModule] = 'cpc_encoder',
    patch_size: int = 8,
    patch_overlap: int = 4,
    online_ft: int = True,
    task: str = 'cpc',
    num_workers: int = 4,
    learning_rate: int = 1e-4,
    data_dir: str = '',
    batch_size: int = 32,
    pretrained: str = None,
    **kwargs,
):
    """
    PyTorch Lightning implementation of `Data-Efficient Image Recognition with Contrastive
    Predictive Coding <https://arxiv.org/abs/1905.09272>`_

    Paper authors: (Olivier J. Hénaff, Aravind Srinivas, Jeffrey De Fauw, Ali Razavi,
    Carl Doersch, S. M. Ali Eslami, Aaron van den Oord).

    Model implemented by:

        - `William Falcon <https://github.com/williamFalcon>`_
        - `Tullie Murrell <https://github.com/tullie>`_

    Example:

        >>> from pl_bolts.models.self_supervised import CPCV2
        ...
        >>> model = CPCV2()

    Train::

        trainer = Trainer()
        trainer.fit(model)

    CLI command::

        # cifar10
        python cpc_module.py --gpus 1

        # imagenet
        python cpc_module.py
            --gpus 8
            --dataset imagenet2012
            --data_dir /path/to/imagenet/
            --meta_dir /path/to/folder/with/meta.bin/
            --batch_size 32

    To fine-tune::

        python cpc_finetuner.py
            --ckpt_path path/to/checkpoint.ckpt
            --dataset cifar10
            --gpus x

    Some uses::

        # load resnet18 pretrained using CPC on imagenet
        model = CPCV2(encoder='resnet18', pretrained=True)
        resnet18 = model.encoder
        resnet18.freeze()

        # it supports any torchvision resnet
        model = CPCV2(encoder='resnet50', pretrained=True)

        # use it as a feature extractor
        x = torch.rand(2, 3, 224, 224)
        out = model(x)

    Args:
        datamodule: A Datamodule (optional). Otherwise set the dataloaders directly
        encoder: A string for any of the resnets in torchvision, or the original CPC encoder,
            or a custom nn.Module encoder
        patch_size: How big to make the image patches
        patch_overlap: How much overlap each patch should have
        online_ft: Enable a 1024-unit MLP to fine-tune online
        task: Which self-supervised task to use ('cpc', 'amdim', etc...)
        num_workers: number of dataloader workers
        learning_rate: what learning rate to use
        data_dir: where to store data
        batch_size: batch size
        pretrained: If true, will use the weights pretrained (using CPC) on Imagenet
    """
    super().__init__()
    self.save_hyperparameters()

    self.online_evaluator = self.hparams.online_ft

    if pretrained:
        self.hparams.dataset = pretrained
        self.online_evaluator = True

    # link data
    if datamodule is None:
        datamodule = CIFAR10DataModule(
            self.hparams.data_dir, num_workers=self.hparams.num_workers, batch_size=batch_size
        )
        datamodule.train_transforms = CPCTrainTransformsCIFAR10()
        datamodule.val_transforms = CPCEvalTransformsCIFAR10()
    self.datamodule = datamodule

    # init encoder
    self.encoder = encoder
    if isinstance(encoder, str):
        self.encoder = self.init_encoder()

    # info nce loss
    c, h = self.__compute_final_nb_c(self.hparams.patch_size)
    self.contrastive_task = CPCTask(num_input_channels=c, target_dim=64, embed_scale=0.1)

    self.z_dim = c * h * h
    self.num_classes = self.datamodule.num_classes

    if pretrained:
        self.load_pretrained(encoder)
def kde_pig_dl(
    dm: pl.LightningDataModule,
    batch_size: int,
    N_hat_multiplier: float = 1,
) -> DataLoader:
    gd_n_steps, gd_lr, gd_threshold = 5, 4e-1, 0.005

    # Spherical = each component has a single variance.
    bgm = BayesianGaussianMixture(
        n_components=batch_size,
        covariance_type="spherical",
        warm_start=True,
    )

    x_hat = torch.Tensor()
    for batch in dm.train_dataloader():
        x, _ = batch
        device = x.device
        x = x.detach().cpu().numpy()

        # The last batch might have fewer elements than the original n_components
        if x.shape[0] < bgm.n_components:
            bgm = BayesianGaussianMixture(
                n_components=x.shape[0],
                covariance_type="spherical",
            )

        # Estimate the KDE
        bgm.fit(x)
        # [N_components, 1], [N_components, N_features], [N_components, 1]
        weights, means, variances = (
            torch.Tensor(bgm.weights_).to(device),
            torch.Tensor(bgm.means_).to(device),
            torch.Tensor(bgm.covariances_).to(device),
        )
        filter_weights_idx = weights >= 1e-5
        weights, means, variances = (
            weights[filter_weights_idx],
            means[filter_weights_idx],
            variances[filter_weights_idx][:, None],
        )
        n_selected_components = weights.shape[0]

        p_x = D.Independent(D.Normal(means, torch.sqrt(variances)), 1)
        mix = D.Categorical(weights)
        p_x = D.MixtureSameFamily(mix, p_x)

        # Sample according to the multiplier (cast: the multiplier may be fractional)
        x_start = p_x.sample(
            (int(n_selected_components * ((batch_size // n_selected_components) + 1) * N_hat_multiplier),)
        ).reshape(-1, x.shape[1])

        # Refine the samples with gradient descent on the density
        _x_hat = density_gradient_descent(
            p_x,
            x_start,
            {"N_steps": gd_n_steps, "lr": gd_lr, "threshold": gd_threshold},
        )

        # Ensure everything lands on the same device
        if x_hat.device != device:
            x_hat = x_hat.to(device)
        x_hat = torch.cat((x_hat, _x_hat.detach()))

    dl = DataLoader(TensorDataset(x_hat), batch_size=batch_size, shuffle=True)
    return dl
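# A minimal sketch of the `density_gradient_descent` helper assumed above (it
# is not shown in this section). The assumption: take a few gradient steps on
# the log-density of `p_x`, freezing a point once its update falls below the
# threshold. The step direction (descending the density, matching the
# anti-mean-shift update in `mean_shift_pig_dl` below) is a guess.
def density_gradient_descent(p_x, x_start: torch.Tensor, params: dict) -> torch.Tensor:
    x = x_start.clone().requires_grad_(True)
    for _ in range(params["N_steps"]):
        log_p = p_x.log_prob(x).sum()
        (grad,) = torch.autograd.grad(log_p, x)
        step = params["lr"] * grad
        with torch.no_grad():
            # Only move points whose update is still above the threshold
            mask = step.norm(dim=-1, keepdim=True) > params["threshold"]
            x = (x - step * mask).requires_grad_(True)
    return x.detach()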
def mean_shift_pig_dl(
    dm: pl.LightningDataModule,
    batch_size: int,
    N_hat_multiplier: float = 1,
    max_iters: int = 20,
    h: float = None,
    h_factor: float = 1,
    sigma: float = None,
    sigma_factor: float = 1,
    kernel: str = "tophat",
    τ: float = 1e-5,
) -> DataLoader:
    x, _ = unpack_dataloader(dm.train_dataloader())

    if h is None:
        h = silverman_bandwidth(x) * h_factor
    if sigma is None:
        # Note: the multiplier is arbitrary
        sigma = torch.std(x) * sigma_factor

    N = x.shape[0]
    N_hat = int(N_hat_multiplier * N)
    if N_hat > N:
        # With replacement, to allow for N_hat > N
        idx = torch.randint(x.shape[0], (N_hat,))
    else:
        # Without replacement
        idx = torch.randperm(x.shape[0])[:N_hat]
    # [n_pig, D]
    x_start = x[idx]
    # Sample locally around each point, [n_pig, D]
    x_hat = x_start + sigma * torch.randn_like(x_start)

    for _ in range(max_iters):
        # [N, n_pig]
        dst = sqr_dist(x, x_hat)
        if kernel == "tophat":
            kde = torch.where(dst <= h, torch.ones_like(dst), torch.zeros_like(dst))
        else:
            raise NotImplementedError(f"Kernel {kernel} invalid")
        # Replace NaN by 0
        kde[kde != kde] = 0
        # Threshold for stopping updates
        out_kde = (torch.max(kde, dim=0)[0] < τ)[:, None]
        # [n_pig, 1]
        sum_kde = torch.sum(kde, dim=0).reshape((-1, 1))
        sum_kde = torch.where(out_kde, torch.ones_like(sum_kde), sum_kde)
        # Centroid for all estimates
        mu = (torch.transpose(kde, 0, 1) @ x) / sum_kde
        # Step size
        delta = mu - x_hat
        x_hat = torch.where(out_kde, x_hat, x_hat - delta)

    dl = DataLoader(TensorDataset(x_hat), batch_size=batch_size, shuffle=True)
    return dl
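# Minimal sketches of the helpers `mean_shift_pig_dl` assumes (`sqr_dist` and
# `silverman_bandwidth`); neither is defined in this section.
def sqr_dist(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Pairwise squared Euclidean distances: [N, D] x [M, D] -> [N, M]
    return torch.cdist(a, b) ** 2


def silverman_bandwidth(x: torch.Tensor) -> torch.Tensor:
    # Silverman's rule of thumb for a KDE bandwidth
    n, d = x.shape
    return (4.0 / (d + 2)) ** (1.0 / (d + 4)) * n ** (-1.0 / (d + 4)) * torch.std(x)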
def make_cli_parser(
    parser: argparse.ArgumentParser, datamodule_cls: pl.LightningDataModule
) -> argparse.ArgumentParser:
    """make_cli_parser Augment an argument parser for slp with the default arguments

    Default arguments for training, logging, optimization etc. are added to the input {parser}.
    If you use make_cli_parser, the following command line arguments will be included::

        usage: my_script.py [-h] [--hidden MODEL.INTERMEDIATE_HIDDEN]
            [--optimizer {Adam,AdamW,SGD,Adadelta,Adagrad,Adamax,ASGD,RMSprop}]
            [--lr OPTIM.LR] [--weight-decay OPTIM.WEIGHT_DECAY] [--lr-scheduler]
            [--lr-factor LR_SCHEDULE.FACTOR] [--lr-patience LR_SCHEDULE.PATIENCE]
            [--lr-cooldown LR_SCHEDULE.COOLDOWN] [--min-lr LR_SCHEDULE.MIN_LR]
            [--seed SEED] [--config CONFIG]
            [--experiment-name TRAINER.EXPERIMENT_NAME] [--run-id TRAINER.RUN_ID]
            [--experiment-group TRAINER.EXPERIMENT_GROUP]
            [--experiments-folder TRAINER.EXPERIMENTS_FOLDER]
            [--save-top-k TRAINER.SAVE_TOP_K] [--patience TRAINER.PATIENCE]
            [--wandb-project TRAINER.WANDB_PROJECT]
            [--tags [TRAINER.TAGS [TRAINER.TAGS ...]]] [--stochastic_weight_avg]
            [--gpus TRAINER.GPUS] [--val-interval TRAINER.CHECK_VAL_EVERY_N_EPOCH]
            [--clip-grad-norm TRAINER.GRADIENT_CLIP_VAL]
            [--epochs TRAINER.MAX_EPOCHS] [--steps TRAINER.MAX_STEPS]
            [--tbtt_steps TRAINER.TRUNCATED_BPTT_STEPS] [--debug] [--offline]
            [--early-stop-on TRAINER.EARLY_STOP_ON] [--early-stop-mode {min,max}]
            [--num-trials TUNE.NUM_TRIALS] [--gpus-per-trial TUNE.GPUS_PER_TRIAL]
            [--cpus-per-trial TUNE.CPUS_PER_TRIAL] [--tune-metric TUNE.METRIC]
            [--tune-mode {max,min}] [--val-percent DATA.VAL_PERCENT]
            [--test-percent DATA.TEST_PERCENT] [--bsz DATA.BATCH_SIZE]
            [--bsz-eval DATA.BATCH_SIZE_EVAL] [--num-workers DATA.NUM_WORKERS]
            [--no-pin-memory] [--drop-last] [--no-shuffle-eval]

        optional arguments:
          -h, --help            show this help message and exit
          --hidden MODEL.INTERMEDIATE_HIDDEN
                                Intermediate hidden layers for linear module
          --optimizer {Adam,AdamW,SGD,Adadelta,Adagrad,Adamax,ASGD,RMSprop}
                                Which optimizer to use
          --lr OPTIM.LR         Learning rate
          --weight-decay OPTIM.WEIGHT_DECAY
                                Weight decay
          --lr-scheduler        Use learning rate scheduling. Currently only
                                ReduceLROnPlateau is supported out of the box
          --lr-factor LR_SCHEDULE.FACTOR
                                Multiplicative factor by which LR is reduced. Used if
                                --lr-scheduler is provided.
          --lr-patience LR_SCHEDULE.PATIENCE
                                Number of epochs with no improvement after which
                                learning rate will be reduced. Used if --lr-scheduler
                                is provided.
          --lr-cooldown LR_SCHEDULE.COOLDOWN
                                Number of epochs to wait before resuming normal
                                operation after lr has been reduced. Used if
                                --lr-scheduler is provided.
          --min-lr LR_SCHEDULE.MIN_LR
                                Minimum lr for LR scheduling. Used if --lr-scheduler
                                is provided.
          --seed SEED           Seed for reproducibility
          --config CONFIG       Path to YAML configuration file
          --experiment-name TRAINER.EXPERIMENT_NAME
                                Name of the running experiment
          --run-id TRAINER.RUN_ID
                                Unique identifier for the current run. If not
                                provided it is inferred from datetime.now()
          --experiment-group TRAINER.EXPERIMENT_GROUP
                                Group of current experiment. Useful when evaluating
                                for different seeds / cross-validation etc.
          --experiments-folder TRAINER.EXPERIMENTS_FOLDER
                                Top-level folder where experiment results &
                                checkpoints are saved
          --save-top-k TRAINER.SAVE_TOP_K
                                Save checkpoints for top k models
          --patience TRAINER.PATIENCE
                                Number of epochs to wait before early stopping
          --wandb-project TRAINER.WANDB_PROJECT
                                Wandb project under which results are saved
          --tags [TRAINER.TAGS [TRAINER.TAGS ...]]
                                Tags for current run to make results searchable.
          --stochastic_weight_avg
                                Use Stochastic weight averaging.
          --gpus TRAINER.GPUS   Number of GPUs to use
          --val-interval TRAINER.CHECK_VAL_EVERY_N_EPOCH
                                Run validation every n epochs
          --clip-grad-norm TRAINER.GRADIENT_CLIP_VAL
                                Clip gradients with ||grad(w)|| >= args.clip_grad_norm
          --epochs TRAINER.MAX_EPOCHS
                                Maximum number of training epochs
          --steps TRAINER.MAX_STEPS
                                Maximum number of training steps
          --tbtt_steps TRAINER.TRUNCATED_BPTT_STEPS
                                Truncated Back-propagation-through-time steps.
          --debug               If true, we run a full run on a small subset of the
                                input data and overfit 10 training batches
          --offline             If true, forces offline execution of wandb logger
          --early-stop-on TRAINER.EARLY_STOP_ON
                                Metric for early stopping
          --early-stop-mode {min,max}
                                Minimize or maximize early stopping metric
          --num-trials TUNE.NUM_TRIALS
                                Number of trials to run for hyperparameter tuning
          --gpus-per-trial TUNE.GPUS_PER_TRIAL
                                How many gpus to use for each trial. If
                                gpus_per_trial < 1 multiple trials are packed in the
                                same gpu
          --cpus-per-trial TUNE.CPUS_PER_TRIAL
                                How many cpus to use for each trial.
          --tune-metric TUNE.METRIC
                                Tune this metric. Needs to be one of the keys of the
                                metrics_map passed into make_trainer_for_ray_tune.
          --tune-mode {max,min}
                                Maximize or minimize metric
          --val-percent DATA.VAL_PERCENT
                                Percent of validation data to be randomly split from
                                the training set, if no validation set is provided
          --test-percent DATA.TEST_PERCENT
                                Percent of test data to be randomly split from the
                                training set, if no test set is provided
          --bsz DATA.BATCH_SIZE
                                Training batch size
          --bsz-eval DATA.BATCH_SIZE_EVAL
                                Evaluation batch size
          --num-workers DATA.NUM_WORKERS
                                Number of workers to be used in the DataLoader
          --no-pin-memory       Don't pin data to GPU memory when transferring
          --drop-last           Drop last incomplete batch
          --no-shuffle-eval     Don't shuffle val & test sets

    Args:
        parser (argparse.ArgumentParser): A parent argument parser to be augmented
        datamodule_cls (pytorch_lightning.LightningDataModule): A data module class that
            injects arguments through the add_argparse_args method

    Returns:
        argparse.ArgumentParser: The augmented command line parser

    Examples:
        >>> import argparse
        >>> from slp.plbind.dm import PLDataModuleFromDatasets
        >>> parser = argparse.ArgumentParser("My cool model")
        >>> parser.add_argument("--hidden", dest="model.hidden", type=int)
        >>> # Create parser with model arguments and anything else you need
        >>> parser = make_cli_parser(parser, PLDataModuleFromDatasets)
        >>> args = parser.parse_args(args=["--bsz", "64", "--lr", "0.01"])
        >>> args.data.batch_size
        64
        >>> args.optim.lr
        0.01
    """
    parser = add_optimizer_args(parser)
    parser = add_trainer_args(parser)
    parser = add_tune_args(parser)
    parser = datamodule_cls.add_argparse_args(parser)
    return parser
def data_info(datamodule: pl.LightningDataModule, parameters: dict, save_path: Path = None) -> dict:
    """
    Summary of the data used for training/testing.

    Gives: class balance in each split, shape of the data, range of the data, split sizes,
    and a plot of examples. Datasets in the datamodule must have a y() method to access
    targets (like BaseDataset).
    """
    info = {}

    # Prepare datasets
    datasets = {}
    try:
        datasets["train"] = datamodule.train_dataloader().dataset
    except Exception:
        pass
    try:
        datasets["val"] = datamodule.val_dataloader().dataset
    except Exception:
        pass
    try:
        datasets["test"] = datamodule.test_dataloader().dataset
    except Exception:
        pass

    # Analyze class balance
    def class_balance(dataset: BaseDataset):
        if isinstance(dataset, BaseDataset):
            y = dataset.y()
        elif isinstance(dataset, Subset):
            ds = dataset.dataset
            if not isinstance(ds, BaseDataset):
                raise AttributeError("dataset must be a BaseDataset or a Subset of a BaseDataset.")
            y = ds.y(dataset.indices)
        else:
            raise AttributeError("dataset must be a BaseDataset or a Subset of a BaseDataset.")
        counter = Counter(y)
        class_names = [list(c.keys())[0] for c in parameters["data"]["targets"]["classes"]]
        return {class_names[i]: counter[i] for i in range(len(class_names))}

    if "data" in parameters and "targets" in parameters["data"] and "classes" in parameters["data"]["targets"]:
        info["class-balance"] = {}
        for ds_name, ds in datasets.items():
            info["class-balance"][ds_name] = class_balance(ds)

    # Shape of the data
    info["shape"] = str(datamodule.size())

    # Range of the data
    if "train" in datasets:
        if isinstance(datasets["train"][0][0], tuple):
            t = torch.cat([datasets["train"][0][0][0], datasets["train"][1][0][0], datasets["train"][2][0][0]])
        else:
            t = torch.cat([datasets["train"][0][0], datasets["train"][1][0], datasets["train"][2][0]])
        info["tensor-data"] = {
            "min": t.min().item(),
            "max": t.max().item(),
            "mean": t.mean().item(),
            "std": t.std().item(),
        }

    # Sizes of the splits
    info["set-sizes"] = {}
    for ds_name, ds in datasets.items():
        info["set-sizes"][ds_name] = len(ds)

    # Plot a few examples from each split
    def plot_batch(ds_name: str, ds: Dataset, n_images: int = 32):
        if isinstance(ds[0][0], tuple):
            images = [ds[i][0][0][0] for i in random.sample(range(len(ds)), n_images)]
        else:
            images = [ds[i][0][0] for i in random.sample(range(len(ds)), n_images)]
        plot_image_grid(images, max_images=n_images, columns=8, save_path=save_path / f"data-examples-{ds_name}.png")

    if save_path is not None:
        for ds_name, ds in datasets.items():
            plot_batch(ds_name, ds)

    return info
def __init__(
    self,
    base_encoder: Union[str, torch.nn.Module] = 'resnet18',
    emb_dim: int = 128,
    num_negatives: int = 65536,
    encoder_momentum: float = 0.999,
    softmax_temperature: float = 0.07,
    learning_rate: float = 0.03,
    momentum: float = 0.9,
    weight_decay: float = 1e-4,
    datamodule: pl.LightningDataModule = None,
    data_dir: str = './',
    batch_size: int = 256,
    use_mlp: bool = False,
    num_workers: int = 8,
    *args,
    **kwargs,
):
    """
    PyTorch Lightning implementation of `MoCo <https://arxiv.org/abs/2003.04297>`_

    Paper authors: Xinlei Chen, Haoqi Fan, Ross Girshick, Kaiming He.

    Code adapted from `facebookresearch/moco <https://github.com/facebookresearch/moco>`_ to Lightning by:

        - `William Falcon <https://github.com/williamFalcon>`_

    Example:

        >>> from pl_bolts.models.self_supervised import MocoV2
        ...
        >>> model = MocoV2()

    Train::

        trainer = Trainer()
        trainer.fit(model)

    CLI command::

        # cifar10
        python moco2_module.py --gpus 1

        # imagenet
        python moco2_module.py
            --gpus 8
            --dataset imagenet2012
            --data_dir /path/to/imagenet/
            --meta_dir /path/to/folder/with/meta.bin/
            --batch_size 32

    Args:
        base_encoder: torchvision model name or torch.nn.Module
        emb_dim: feature dimension (default: 128)
        num_negatives: queue size; number of negative keys (default: 65536)
        encoder_momentum: moco momentum of updating key encoder (default: 0.999)
        softmax_temperature: softmax temperature (default: 0.07)
        learning_rate: the learning rate
        momentum: optimizer momentum
        weight_decay: optimizer weight decay
        datamodule: the DataModule (train, val, test dataloaders)
        data_dir: the directory to store data
        batch_size: batch size
        use_mlp: add an mlp to the encoders
        num_workers: workers for the loaders
    """
    super().__init__()
    self.save_hyperparameters()

    # use CIFAR-10 by default if no datamodule passed in
    if datamodule is None:
        datamodule = CIFAR10DataModule(data_dir)
        datamodule.train_transforms = Moco2TrainCIFAR10Transforms()
        datamodule.val_transforms = Moco2EvalCIFAR10Transforms()
    self.datamodule = datamodule

    # create the encoders
    # num_classes is the output fc dimension
    self.encoder_q, self.encoder_k = self.init_encoders(base_encoder)

    if use_mlp:  # hack: brute-force replacement
        dim_mlp = self.encoder_q.fc.weight.shape[1]
        self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
        self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)

    for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
        param_k.data.copy_(param_q.data)  # initialize
        param_k.requires_grad = False  # not updated by gradient

    # create the queue
    self.register_buffer("queue", torch.randn(emb_dim, num_negatives))
    self.queue = nn.functional.normalize(self.queue, dim=0)
    self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
def __post_init__(
    self,
    observation_space: gym.Space = None,
    action_space: gym.Space = None,
    reward_space: gym.Space = None,
):
    """Initializes the fields of the setting that weren't set from the command-line."""
    logger.debug("__post_init__ of Setting")

    # BUG: simple-parsing sometimes parses a list with a single item, itself the
    # list of transforms. Not sure if this still happens.
    def is_list_of_list(v: Any) -> bool:
        return isinstance(v, list) and len(v) == 1 and isinstance(v[0], list)

    if is_list_of_list(self.train_transforms):
        self.train_transforms = self.train_transforms[0]
    if is_list_of_list(self.val_transforms):
        self.val_transforms = self.val_transforms[0]
    if is_list_of_list(self.test_transforms):
        self.test_transforms = self.test_transforms[0]

    if all(
        t is None
        for t in [
            self.transforms,
            self.train_transforms,
            self.val_transforms,
            self.test_transforms,
        ]
    ):
        # Use these two transforms by default if no transforms are passed at all.
        # TODO: Remove this after the competition perhaps.
        self.transforms = Compose([Transforms.to_tensor, Transforms.three_channels])

    # If the constructor is called with just the `transforms` argument, like this:
    # <SomeSetting>(dataset="bob", transforms=foo_transform)
    # Then we use this value as the default for the train, val and test transforms.
    if self.transforms and not any(
        [self.train_transforms, self.val_transforms, self.test_transforms]
    ):
        if not isinstance(self.transforms, list):
            self.transforms = Compose([self.transforms])
        self.train_transforms = self.transforms.copy()
        self.val_transforms = self.transforms.copy()
        self.test_transforms = self.transforms.copy()

    if self.train_transforms is not None and not isinstance(self.train_transforms, list):
        self.train_transforms = [self.train_transforms]
    if self.val_transforms is not None and not isinstance(self.val_transforms, list):
        self.val_transforms = [self.val_transforms]
    if self.test_transforms is not None and not isinstance(self.test_transforms, list):
        self.test_transforms = [self.test_transforms]

    # Actually compose the list of Transforms or callables into a single transform.
    self.train_transforms: Compose = Compose(self.train_transforms or [])
    self.val_transforms: Compose = Compose(self.val_transforms or [])
    self.test_transforms: Compose = Compose(self.test_transforms or [])

    LightningDataModule.__init__(
        self,
        train_transforms=self.train_transforms,
        val_transforms=self.val_transforms,
        test_transforms=self.test_transforms,
    )

    self._observation_space = observation_space
    self._action_space = action_space
    self._reward_space = reward_space

    # TODO: It's a bit confusing to also have a `config` attribute on the
    # Setting. Might want to change this a bit.
    self.config: Config = None

    self.train_env: Environment = None  # type: ignore
    self.val_env: Environment = None  # type: ignore
    self.test_env: Environment = None  # type: ignore
def train(
    parameters: dict,
    datamodule: LightningDataModule,
    model: LightningModule,
    callbacks: Optional[Union[List[Callback], Callback]] = None,
):
    validate_train_parameters(parameters)

    model_path = Path(parameters["path"])
    plot_path = Path(parameters["path"]) / "train_plots"
    plot_path.mkdir(parents=True, exist_ok=True)

    regression = parameters["data"]["targets"] == "regression"
    n_class = 1 if regression else len(parameters["data"]["targets"]["classes"])

    # Callbacks
    if callbacks is None:
        callbacks = []
    elif not isinstance(callbacks, list):
        callbacks = [callbacks]
    timer_callback = TimerCallback()
    early_stop_callback = (
        EarlyStopping(
            monitor="val_loss",
            min_delta=0.00,
            patience=parameters["trainer"]["patience"],
            verbose=True,
            mode="min",
        )
        if parameters["trainer"]["patience"] > 0
        else None
    )
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        dirpath=str(model_path),
        filename="model",
    )
    callbacks += [
        early_stop_callback,
        checkpoint_callback,
        InfoLogCallback(),
        timer_callback,
        PlotTrainValCurveCallback(plot_path, "loss"),
        PlotTrainValCurveCallback(plot_path, "accuracy"),
    ]
    callbacks = [c for c in callbacks if c is not None]

    # Logging
    im_logger = InMemoryLogger()
    pl_logger = [im_logger]

    # Tracking
    tracking: Tracking = NeptuneNewTracking(
        parameters=parameters,
        tags=parameters.get("tags", []),
        disabled=not parameters["tracking"] or parameters["tune_lr"],
    )
    tracking_logger = tracking.get_callback("pytorch-lightning")
    if tracking_logger is not None:
        pl_logger.append(tracking_logger)
        callbacks.append(LearningRateMonitor(logging_interval="step"))

    trainer = pl.Trainer(
        gpus=parameters["system"]["gpus"],
        logger=pl_logger,
        callbacks=callbacks,
        max_epochs=parameters["trainer"]["epochs"],
        num_sanity_val_steps=0,
        fast_dev_run=False,
        accelerator=None if parameters["system"]["gpus"] in [None, 0, 1] else "dp",
        sync_batchnorm=True,
        deterministic=True,
    )

    # When tuning the learning rate, only run the finder and return early
    if parameters["tune_lr"]:
        lr_find(
            trainer,
            model,
            datamodule,
            model_path,
            min_lr=parameters.get("min_lr", 1e-6),
            max_lr=parameters.get("max_lr", 1e-2),
            num_training=parameters.get("num_training", 100),
        )
        return

    trainer.fit(model, datamodule=datamodule)

    # Save pytorch model / nn.Module / weights
    torch.save(model.state_dict(), Path(model_path / "model.pt"))

    # Save meaningful parameters for loading
    model_config = {
        "model": type(model).__name__,
        "backbone": model.backbone_name,
        "n_channel": datamodule.size(0),
        "output_size": model.output_size,
        "hparams": dict(model.hparams),
    }
    write_yaml(model_path / "model_config.yaml", model_config)

    # Log the actual config used for training
    write_yaml(model_path / "config.yaml", parameters)

    # Log model summary
    pytorch_model_summary(model, model_path)

    # Output tracking run id to continue logging in the test step
    run_id = tracking.get_id()

    # Metadata
    steps_per_epoch = len(datamodule.train_dataloader())
    metadata = get_training_summary(
        model_path / "model.ckpt",
        parameters,
        run_id,
        timer_callback,
        datamodule,
        early_stop_callback,
        checkpoint_callback,
        steps_per_epoch,
        save_path=plot_path,
    )
    write_yaml(model_path / "metadata.yaml", metadata)
    tracking.log_property("metadata", metadata)
    tracking.log_artifact(plot_path / "data-examples-train.png")
    tracking.log_artifact(plot_path / "data-examples-val.png")
    if (plot_path / "data-examples-test.png").exists():
        tracking.log_artifact(plot_path / "data-examples-test.png")
    tracking.end()
def __init__(
    self,
    datamodule: pl.LightningDataModule = None,
    data_dir: str = './',
    learning_rate: float = 0.2,
    weight_decay: float = 15e-6,
    input_height: int = 32,
    batch_size: int = 32,
    num_workers: int = 4,
    warmup_epochs: int = 10,
    max_epochs: int = 1000,
    **kwargs,
):
    """
    PyTorch Lightning implementation of `Bootstrap Your Own Latent (BYOL)
    <https://arxiv.org/pdf/2006.07733.pdf>`_

    Paper authors: Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec,
    Pierre H. Richemond, Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires,
    Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, Bilal Piot, Koray Kavukcuoglu,
    Rémi Munos, Michal Valko.

    Model implemented by:

        - `Annika Brundyn <https://github.com/annikabrundyn>`_

    .. warning:: Work in progress. This implementation is still being verified.

    TODOs:
        - verify on CIFAR-10
        - verify on STL-10
        - pre-train on imagenet

    Example:

        >>> from pl_bolts.models.self_supervised import BYOL
        ...
        >>> model = BYOL()

    Train::

        trainer = Trainer()
        trainer.fit(model)

    CLI command::

        # cifar10
        python byol_module.py --gpus 1

        # imagenet
        python byol_module.py
            --gpus 8
            --dataset imagenet2012
            --data_dir /path/to/imagenet/
            --meta_dir /path/to/folder/with/meta.bin/
            --batch_size 32

    Args:
        datamodule: The datamodule
        data_dir: directory to store data
        learning_rate: the learning rate
        weight_decay: optimizer weight decay
        input_height: image input height
        batch_size: the batch size
        num_workers: number of workers
        warmup_epochs: num of epochs for scheduler warm up
        max_epochs: max epochs for scheduler
    """
    super().__init__()
    self.save_hyperparameters()

    # init default datamodule
    if datamodule is None:
        datamodule = CIFAR10DataModule(data_dir, num_workers=num_workers, batch_size=batch_size)
        datamodule.train_transforms = SimCLRTrainDataTransform(input_height)
        datamodule.val_transforms = SimCLREvalDataTransform(input_height)
    self.datamodule = datamodule

    self.online_network = SiameseArm()
    self.target_network = deepcopy(self.online_network)
    self.weight_callback = BYOLMAWeightUpdate()

    # for finetuning callback
    self.z_dim = 2048
    self.num_classes = self.datamodule.num_classes
def __init__(
    self,
    datamodule: pl.LightningDataModule = None,
    encoder: Union[str, torch.nn.Module, pl.LightningModule] = 'cpc_encoder',
    patch_size: int = 8,
    patch_overlap: int = 4,
    online_ft: int = True,
    task: str = 'cpc',
    num_workers: int = 4,
    learning_rate: int = 1e-4,
    data_dir: str = '',
    batch_size: int = 32,
    pretrained: str = None,
    **kwargs,
):
    """
    Args:
        datamodule: A Datamodule (optional). Otherwise set the dataloaders directly
        encoder: A string for any of the resnets in torchvision, or the original CPC encoder,
            or a custom nn.Module encoder
        patch_size: How big to make the image patches
        patch_overlap: How much overlap each patch should have
        online_ft: Enable a 1024-unit MLP to fine-tune online
        task: Which self-supervised task to use ('cpc', 'amdim', etc...)
        num_workers: number of dataloader workers
        learning_rate: what learning rate to use
        data_dir: where to store data
        batch_size: batch size
        pretrained: If true, will use the weights pretrained (using CPC) on Imagenet
    """
    super().__init__()
    self.save_hyperparameters()

    self.online_evaluator = self.hparams.online_ft

    if pretrained:
        self.hparams.dataset = pretrained
        self.online_evaluator = True

    # link data
    if datamodule is None:
        datamodule = CIFAR10DataModule(
            self.hparams.data_dir, num_workers=self.hparams.num_workers, batch_size=batch_size
        )
        datamodule.train_transforms = CPCTrainTransformsCIFAR10()
        datamodule.val_transforms = CPCEvalTransformsCIFAR10()
    self.datamodule = datamodule

    # init encoder
    self.encoder = encoder
    if isinstance(encoder, str):
        self.encoder = self.init_encoder()

    # info nce loss
    c, h = self.__compute_final_nb_c(self.hparams.patch_size)
    self.contrastive_task = CPCTask(num_input_channels=c, target_dim=64, embed_scale=0.1)
    self.z_dim = c * h * h
    self.num_classes = self.datamodule.num_classes

    if pretrained:
        self.load_pretrained(encoder)
def __init__(
    self,
    datamodule: pl.LightningDataModule = None,
    data_dir: str = './',
    learning_rate: float = 0.00006,
    weight_decay: float = 0.0005,
    input_height: int = 32,
    batch_size: int = 128,
    online_ft: bool = False,
    num_workers: int = 4,
    optimizer: str = 'lars',
    lr_sched_step: float = 30.0,
    lr_sched_gamma: float = 0.5,
    lars_momentum: float = 0.9,
    lars_eta: float = 0.001,
    loss_temperature: float = 0.5,
    **kwargs,
):
    """
    PyTorch Lightning implementation of `SimCLR <https://arxiv.org/abs/2002.05709>`_

    Paper authors: Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton.

    Model implemented by:

        - `William Falcon <https://github.com/williamFalcon>`_
        - `Tullie Murrell <https://github.com/tullie>`_

    Example:

        >>> from pl_bolts.models.self_supervised import SimCLR
        ...
        >>> model = SimCLR()

    Train::

        trainer = Trainer()
        trainer.fit(model)

    CLI command::

        # cifar10
        python simclr_module.py --gpus 1

        # imagenet
        python simclr_module.py
            --gpus 8
            --dataset imagenet2012
            --data_dir /path/to/imagenet/
            --meta_dir /path/to/folder/with/meta.bin/
            --batch_size 32

    Args:
        datamodule: The datamodule
        data_dir: directory to store data
        learning_rate: the learning rate
        weight_decay: optimizer weight decay
        input_height: image input height
        batch_size: the batch size
        online_ft: whether to tune online or not
        num_workers: number of workers
        optimizer: optimizer name
        lr_sched_step: step for learning rate scheduler
        lr_sched_gamma: gamma for learning rate scheduler
        lars_momentum: the momentum param for the lars optimizer
        lars_eta: eta param for the lars optimizer
        loss_temperature: temperature for the contrastive loss (default: 0.5)
    """
    super().__init__()
    self.save_hyperparameters()
    self.online_evaluator = online_ft

    # init default datamodule
    if datamodule is None:
        datamodule = CIFAR10DataModule(data_dir, num_workers=num_workers, batch_size=batch_size)
        datamodule.train_transforms = SimCLRTrainDataTransform(input_height)
        datamodule.val_transforms = SimCLREvalDataTransform(input_height)
    self.datamodule = datamodule

    self.loss_func = self.init_loss()
    self.encoder = self.init_encoder()
    self.projection = self.init_projection()

    if self.online_evaluator:
        z_dim = self.projection.output_dim
        num_classes = self.datamodule.num_classes
        self.non_linear_evaluator = SSLEvaluator(n_input=z_dim, n_classes=num_classes, p=0.2, n_hidden=1024)
def manual_test_step(self, dm: LightningDataModule) -> None:
    """Manually perform the test step. Used for debugging."""
    dm.setup()
    batch = next(iter(dm.test_dataloader()))
    self.test_step(batch, 0)
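# Hedged usage sketch for `manual_test_step`: exercise `test_step` on a single
# batch without going through a Trainer. `LitModel` and `MyDataModule` are
# hypothetical names, not classes from this section.
model = LitModel()
dm = MyDataModule()
dm.prepare_data()
model.manual_test_step(dm)  # runs test_step on the first test batch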