def test_can_return_tensor_with_more_than_one_element(tmpdir):
    """Ensure {validation,test}_step return values are not included as callback metrics. #6623"""

    class TestModel(BoringModel):
        def validation_step(self, batch, *args, **kwargs):
            return {"val": torch.tensor([0, 1])}

        def validation_epoch_end(self, outputs):
            # ensure validation step returns still appear here
            assert len(outputs) == 2
            assert all(list(d) == ["val"] for d in outputs)  # check keys
            assert all(torch.equal(d["val"], torch.tensor([0, 1])) for d in outputs)  # check values

        def test_step(self, batch, *args, **kwargs):
            return {"test": torch.tensor([0, 1])}

        def test_epoch_end(self, outputs):
            assert len(outputs) == 2
            assert all(list(d) == ["test"] for d in outputs)  # check keys
            assert all(torch.equal(d["test"], torch.tensor([0, 1])) for d in outputs)  # check values

    model = TestModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2, enable_progress_bar=False)
    trainer.fit(model)
    trainer.validate(model)
    trainer.test(model)
def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, num_dataloaders):
    """Tests that the LoggerConnector properly captures logged information in a multi-dataloader scenario."""
    os.environ['PL_DEV_DEBUG'] = '1'

    class TestModel(BoringModel):
        test_losses = {}

        @Helper.decorator_with_arguments(fx_name="test_step")
        def test_step(self, batch, batch_idx, dl_idx=0):
            output = self.layer(batch)
            loss = self.loss(batch, output)

            primary_key = str(dl_idx)
            if primary_key not in self.test_losses:
                self.test_losses[primary_key] = []
            self.test_losses[primary_key].append(loss)

            self.log("test_loss", loss, on_step=True, on_epoch=True)
            return {"test_loss": loss}

        def test_dataloader(self):
            return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(num_dataloaders)]

    model = TestModel()
    model.val_dataloader = None
    model.test_epoch_end = None

    limit_test_batches = 4
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=0,
        limit_val_batches=0,
        limit_test_batches=limit_test_batches,
        max_epochs=1,
        log_every_n_steps=1,
        weights_summary=None,
    )
    trainer.test(model)

    test_results = trainer.logger_connector._cached_results["test"]

    generated = test_results(fx_name="test_step")
    assert len(generated) == num_dataloaders

    for dl_idx in range(num_dataloaders):
        generated = len(test_results(fx_name="test_step", dl_idx=str(dl_idx)))
        assert generated == limit_test_batches

    test_results.has_batch_loop_finished = True

    for dl_idx in range(num_dataloaders):
        expected = torch.stack(model.test_losses[str(dl_idx)]).mean()
        generated = test_results(fx_name="test_step", dl_idx=str(dl_idx), reduced=True)["test_loss_epoch"]
        assert abs(expected.item() - generated.item()) < 1e-6
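# The test above receives ``num_dataloaders`` as an argument; to run under pytest this value has to be
# supplied somehow, typically via parametrization. A minimal sketch follows -- the decorator and the
# chosen values [1, 2] are an assumption for illustration, not necessarily how the original suite
# wires it up:
#
#     import pytest
#
#     @pytest.mark.parametrize("num_dataloaders", [1, 2])
#     def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, num_dataloaders):
#         ...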
def train(hparams):
    NUM_GPUS = hparams.num_gpus
    USE_AMP = False  # True if NUM_GPUS > 1 else False
    MAX_EPOCHS = 50

    dataset = load_link_dataset(hparams.dataset, hparams=hparams)
    hparams.n_classes = dataset.n_classes

    model = LATTELinkPredictor(hparams, dataset, collate_fn="triples_batch", metrics=[hparams.dataset])

    wandb_logger = WandbLogger(name=model.name(), tags=[dataset.name()], project="multiplex-comparison")

    trainer = Trainer(
        gpus=NUM_GPUS,
        distributed_backend='ddp' if NUM_GPUS > 1 else None,
        auto_lr_find=False,
        max_epochs=MAX_EPOCHS,
        early_stop_callback=EarlyStopping(monitor='val_loss', patience=10, min_delta=0.01, strict=False),
        logger=wandb_logger,
        # regularizers=regularizers,
        weights_summary='top',
        amp_level='O1' if USE_AMP else None,
        precision=16 if USE_AMP else 32,
    )
    trainer.fit(model)
    trainer.test(model)
def test_epoch_results_cache_dp(tmpdir):
    root_device = torch.device("cuda", 0)

    class TestModel(BoringModel):
        def training_step(self, *args, **kwargs):
            result = super().training_step(*args, **kwargs)
            self.log("train_loss_epoch", result["loss"], on_step=False, on_epoch=True)
            return result

        def training_step_end(self, training_step_outputs):  # required for dp
            loss = training_step_outputs["loss"].mean()
            return loss

        def training_epoch_end(self, outputs):
            assert all(out["loss"].device == root_device for out in outputs)
            assert self.trainer.callback_metrics["train_loss_epoch"].device == root_device

        def validation_step(self, *args, **kwargs):
            val_loss = torch.rand(1, device=torch.device("cuda", 1))
            self.log("val_loss_epoch", val_loss, on_step=False, on_epoch=True)
            return val_loss

        def validation_epoch_end(self, outputs):
            assert all(loss.device == root_device for loss in outputs)
            assert self.trainer.callback_metrics["val_loss_epoch"].device == root_device

        def test_step(self, *args, **kwargs):
            test_loss = torch.rand(1, device=torch.device("cuda", 1))
            self.log("test_loss_epoch", test_loss, on_step=False, on_epoch=True)
            return test_loss

        def test_epoch_end(self, outputs):
            assert all(loss.device == root_device for loss in outputs)
            assert self.trainer.callback_metrics["test_loss_epoch"].device == root_device

        def train_dataloader(self):
            return DataLoader(RandomDataset(32, 64), batch_size=4)

        def val_dataloader(self):
            return DataLoader(RandomDataset(32, 64), batch_size=4)

        def test_dataloader(self):
            return DataLoader(RandomDataset(32, 64), batch_size=4)

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        strategy="dp",
        accelerator="gpu",
        devices=2,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
    )
    trainer.fit(model)
    trainer.test(model)
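# ``test_epoch_results_cache_dp`` above assumes a machine with at least two CUDA devices, since it
# pins tensors to ``cuda:0`` and ``cuda:1`` and runs the "dp" strategy on two GPUs. A minimal sketch
# of a skip guard, assuming plain pytest (the original suite may use its own skip helper instead):
#
#     import pytest
#     import torch
#
#     requires_two_gpus = pytest.mark.skipif(
#         torch.cuda.device_count() < 2, reason="requires at least 2 GPUs"
#     )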
def run_encoder(train, test, epochs):
    """Instantiates and runs the autoencoder.

    Parameters:
        train (pandas.DataFrame): DataFrame of training data
        test (pandas.DataFrame): DataFrame of testing data
        epochs (int): Training epochs

    Returns:
        Autoencoder loss on test data
    """
    # Instantiate the training dataset
    data_train = MELoader(train)
    # Instantiate the testing dataset
    data_test = MELoader(test)

    # Instantiate the non-mechanistic autoencoder
    feats = data_train.data.shape[1]
    encoder = NMEncoder(feats, feats // 2)

    # Instantiate the PyTorch Lightning trainer
    trainer = Trainer(gpus=1, num_nodes=1, max_epochs=epochs)

    # Fit the model on the training set
    trainer.fit(encoder, DataLoader(dataset=data_train))
    # Evaluate on the testing set
    performance = trainer.test(encoder, DataLoader(dataset=data_test))

    return performance[0]["test_loss"]
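# Minimal usage sketch for ``run_encoder`` above, assuming ``train.csv`` and ``test.csv`` hold the
# already-preprocessed feature matrices expected by ``MELoader`` (the file names and column layout
# are assumptions for illustration only):
#
#     import pandas as pd
#
#     train_df = pd.read_csv("train.csv")
#     test_df = pd.read_csv("test.csv")
#     test_loss = run_encoder(train_df, test_df, epochs=100)
#     print(f"Autoencoder test loss: {test_loss:.4f}")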
def train(hparams: Namespace):
    NUM_GPUS = hparams.num_gpus
    USE_AMP = False  # True if NUM_GPUS > 1 else False
    MAX_EPOCHS = 50

    neighbor_sizes = [hparams.n_neighbors]
    for t in range(1, hparams.t_order):
        neighbor_sizes.extend([neighbor_sizes[-1] // 2])
    print("neighbor_sizes", neighbor_sizes)
    hparams.neighbor_sizes = neighbor_sizes

    dataset = load_node_dataset(hparams.dataset, method="LATTE", hparams=hparams, train_ratio=None,
                                dir_path=hparams.dir_path)

    METRICS = ["precision", "recall", "f1",
               "accuracy" if dataset.multilabel else hparams.dataset,
               "top_k"]

    hparams.loss_type = "BCE" if dataset.multilabel else hparams.loss_type
    hparams.n_classes = dataset.n_classes

    model = LATTENodeClassifier(hparams, dataset, collate_fn="neighbor_sampler", metrics=METRICS)

    logger = WandbLogger(name=model.name(), tags=[dataset.name()], project="multiplex-comparison")

    trainer = Trainer(
        gpus=NUM_GPUS,
        distributed_backend='ddp' if NUM_GPUS > 1 else None,
        gradient_clip_val=hparams.gradient_clip_val,
        # auto_lr_find=True,
        max_epochs=MAX_EPOCHS,
        # early_stop_callback=EarlyStopping(monitor='val_loss', patience=5, min_delta=0.001, strict=False),
        logger=logger,
        amp_level='O1' if USE_AMP else None,
        precision=16 if USE_AMP else 32,
    )
    trainer.fit(model)
    trainer.test(model)
def run_encoder(train, test, epochs, width, depth, dropout_prob=0.2, reg_coef=0):
    """Instantiates and runs the extendable autoencoder.

    Parameters:
        train (pandas.DataFrame): DataFrame of training data
        test (pandas.DataFrame): DataFrame of testing data
        epochs (int): Training epochs
        width (int): Number of latent attributes
        depth (int): Number of encoding/decoding layers
        dropout_prob (float, default=0.2): Probability of drop-out
        reg_coef (float, default=0): Regularization coefficient

    Returns:
        Autoencoder loss on test data and the latent representation
    """
    # Instantiate the training dataset
    data_train = MELoader(train)
    # Instantiate the testing dataset
    data_test = MELoader(test)

    # Instantiate the non-mechanistic autoencoder
    feats = data_train.data.shape[1]
    encoder = NMEncoder(feats, width, dropout_prob=dropout_prob, n_layers=depth, reg_coef=reg_coef)

    # Instantiate the PyTorch Lightning trainer
    trainer = Trainer(
        auto_scale_batch_size=True,
        auto_select_gpus=True,
        checkpoint_callback=False,
        gpus=1,
        logger=False,
        max_epochs=epochs,
        # progress_bar_refresh_rate=0,
        weights_summary=None,
    )

    # Fit the model on the training set
    trainer.fit(encoder, DataLoader(dataset=data_train))
    # Evaluate on the testing set
    performance = trainer.test(encoder, DataLoader(dataset=data_test))

    loss = performance[0]["test_loss"]
    latent = performance[0]["latent"]
    return loss, latent
def test_metrics_reset(tmpdir):
    """Tests that metrics are reset correctly after the end of the train/val/test epoch."""

    class TestModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 1)

        def _create_metrics(self):
            acc = Accuracy()
            acc.reset = mock.Mock(side_effect=acc.reset)
            ap = AveragePrecision(num_classes=1, pos_label=1)
            ap.reset = mock.Mock(side_effect=ap.reset)
            return acc, ap

        def setup(self, stage):
            fn = stage
            if fn == "fit":
                for stage in ("train", "validate"):
                    acc, ap = self._create_metrics()
                    self.add_module(f"acc_{fn}_{stage}", acc)
                    self.add_module(f"ap_{fn}_{stage}", ap)
            else:
                acc, ap = self._create_metrics()
                stage = self.trainer.state.stage
                self.add_module(f"acc_{fn}_{stage}", acc)
                self.add_module(f"ap_{fn}_{stage}", ap)

        def forward(self, x):
            return self.layer(x)

        def _step(self, batch):
            fn, stage = self.trainer.state.fn, self.trainer.state.stage

            logits = self(batch)
            loss = logits.sum()
            self.log(f"loss/{fn}_{stage}", loss)

            acc = self._modules[f"acc_{fn}_{stage}"]
            ap = self._modules[f"ap_{fn}_{stage}"]

            preds = torch.rand(len(batch))  # Fake preds
            labels = torch.randint(0, 1, [len(batch)])  # Fake targets
            acc(preds, labels)
            ap(preds, labels)

            # Metric.forward calls reset so reset the mocks here
            acc.reset.reset_mock()
            ap.reset.reset_mock()

            self.log(f"acc/{fn}_{stage}", acc)
            self.log(f"ap/{fn}_{stage}", ap)
            return loss

        def training_step(self, batch, batch_idx, *args, **kwargs):
            return self._step(batch)

        def validation_step(self, batch, batch_idx, *args, **kwargs):
            if self.trainer.sanity_checking:
                return
            return self._step(batch)

        def test_step(self, batch, batch_idx, *args, **kwargs):
            return self._step(batch)

        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [lr_scheduler]

        def train_dataloader(self):
            return DataLoader(RandomDataset(32, 64))

        def val_dataloader(self):
            return DataLoader(RandomDataset(32, 64))

        def test_dataloader(self):
            return DataLoader(RandomDataset(32, 64))

    def _assert_called(model, fn, stage):
        acc = model._modules[f"acc_{fn}_{stage}"]
        ap = model._modules[f"ap_{fn}_{stage}"]
        acc.reset.assert_called_once()
        ap.reset.assert_called_once()

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        max_epochs=1,
        enable_progress_bar=False,
        num_sanity_val_steps=2,
        enable_checkpointing=False,
    )

    trainer.fit(model)
    _assert_called(model, "fit", "train")
    _assert_called(model, "fit", "validate")

    trainer.validate(model)
    _assert_called(model, "validate", "validate")

    trainer.test(model)
    _assert_called(model, "test", "test")
def test_fx_validator_integration(tmpdir):
    """Tries to log inside all `LightningModule` and `Callback` hooks to check any expected errors."""
    not_supported = {
        None: "`self.trainer` reference is not registered",
        "on_before_accelerator_backend_setup": "You can't",
        "setup": "You can't",
        "configure_sharded_model": "You can't",
        "on_configure_sharded_model": "You can't",
        "configure_optimizers": "You can't",
        "on_fit_start": "You can't",
        "on_pretrain_routine_start": "You can't",
        "on_pretrain_routine_end": "You can't",
        "train_dataloader": "You can't",
        "val_dataloader": "You can't",
        "on_validation_end": "You can't",
        "on_train_end": "You can't",
        "on_fit_end": "You can't",
        "teardown": "You can't",
        "on_sanity_check_start": "You can't",
        "on_sanity_check_end": "You can't",
        "prepare_data": "You can't",
        "configure_callbacks": "You can't",
        "on_validation_model_eval": "You can't",
        "on_validation_model_train": "You can't",
        "lr_scheduler_step": "You can't",
        "on_save_checkpoint": "You can't",
        "on_load_checkpoint": "You can't",
        "on_exception": "You can't",
    }
    model = HookedModel(not_supported)

    with pytest.warns(UserWarning, match=not_supported[None]):
        model.log("foo", 1)

    callback = HookedCallback(not_supported)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        limit_predict_batches=1,
        callbacks=callback,
    )
    with pytest.deprecated_call(match="is deprecated in"):
        trainer.fit(model)

    not_supported.update({
        # `lightning_module` ref is now present from the `fit` call
        "on_before_accelerator_backend_setup": "You can't",
        "test_dataloader": "You can't",
        "on_test_model_eval": "You can't",
        "on_test_model_train": "You can't",
        "on_test_end": "You can't",
    })
    with pytest.deprecated_call(match="is deprecated in"):
        trainer.test(model, verbose=False)

    not_supported.update({k: "result collection is not registered yet" for k in not_supported})
    not_supported.update({
        "predict_dataloader": "result collection is not registered yet",
        "on_predict_model_eval": "result collection is not registered yet",
        "on_predict_start": "result collection is not registered yet",
        "on_predict_epoch_start": "result collection is not registered yet",
        "on_predict_batch_start": "result collection is not registered yet",
        "predict_step": "result collection is not registered yet",
        "on_predict_batch_end": "result collection is not registered yet",
        "on_predict_epoch_end": "result collection is not registered yet",
        "on_predict_end": "result collection is not registered yet",
    })
    with pytest.deprecated_call(match="is deprecated in"):
        trainer.predict(model)
def train(hparams):
    EMBEDDING_DIM = 128
    USE_AMP = None
    NUM_GPUS = hparams.num_gpus
    MAX_EPOCHS = 1000
    batch_order = 11

    dataset = load_node_dataset(hparams.dataset, hparams.method, hparams=hparams, train_ratio=hparams.train_ratio)
    METRICS = ["precision", "recall", "f1", "accuracy", "top_k" if dataset.multilabel else "ogbn-mag"]

    if hparams.method == "HAN":
        USE_AMP = False
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "batch_size": 2 ** batch_order,
            "num_layers": 2,
            "collate_fn": "HAN_batch",
            "train_ratio": dataset.train_ratio,
            "loss_type": "BINARY_CROSS_ENTROPY" if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "n_classes": dataset.n_classes,
            "lr": 0.001,
        }
        model = HAN(Namespace(**model_hparams), dataset=dataset, metrics=METRICS)
    elif hparams.method == "GTN":
        USE_AMP = False
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "num_channels": len(dataset.metapaths),
            "num_layers": 2,
            "batch_size": 2 ** batch_order,
            "collate_fn": "HAN_batch",
            "train_ratio": dataset.train_ratio,
            "loss_type": "BINARY_CROSS_ENTROPY" if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "n_classes": dataset.n_classes,
            "lr": 0.001,
        }
        model = GTN(Namespace(**model_hparams), dataset=dataset, metrics=METRICS)
    elif hparams.method == "MetaPath2Vec":
        USE_AMP = False
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "walk_length": 50,
            "context_size": 7,
            "walks_per_node": 5,
            "num_negative_samples": 5,
            "sparse": True,
            "batch_size": 400,
            "train_ratio": dataset.train_ratio,
            "n_classes": dataset.n_classes,
            "lr": 0.01,
        }
        model = MetaPath2Vec(Namespace(**model_hparams), dataset=dataset, metrics=METRICS)
    elif hparams.method == "HGT":
        USE_AMP = False
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "num_channels": len(dataset.metapaths),
            "n_layers": 2,
            "attn_heads": 8,
            "attn_dropout": 0.2,
            "prev_norm": True,
            "last_norm": True,
            "nb_cls_dense_size": 0,
            "nb_cls_dropout": 0.0,
            "use_class_weights": False,
            "batch_size": 2 ** batch_order,
            "n_epoch": MAX_EPOCHS,
            "train_ratio": dataset.train_ratio,
            "loss_type": "BCE" if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "n_classes": dataset.n_classes,
            "collate_fn": "collate_HGT_batch",
            "lr": 0.001,  # Not used here, defaults to 1e-3
        }
        model = HGT(Namespace(**model_hparams), dataset, metrics=METRICS)
    elif "LATTE" in hparams.method:
        USE_AMP = False
        num_gpus = 1

        if "-1" in hparams.method:
            n_layers = 1
        elif "-2" in hparams.method:
            n_layers = 2
        elif "-3" in hparams.method:
            n_layers = 3
        else:
            n_layers = 2

        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "layer_pooling": "concat",
            "n_layers": n_layers,
            "batch_size": 2 ** batch_order,
            "nb_cls_dense_size": 0,
            "nb_cls_dropout": 0.4,
            "activation": "relu",
            "dropout": 0.2,
            "attn_heads": 2,
            "attn_activation": "sharpening",
            "batchnorm": False,
            "layernorm": False,
            "edge_sampling": False,
            "edge_threshold": 0.5,
            "attn_dropout": 0.2,
            "loss_type": "BCE" if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "use_proximity": True if "proximity" in hparams.method else False,
            "neg_sampling_ratio": 2.0,
            "n_classes": dataset.n_classes,
            "use_class_weights": False,
            "lr": 0.001,
            "momentum": 0.9,
            "weight_decay": 1e-2,
        }
        model_hparams.update(hparams.__dict__)

        metrics = ["precision", "recall", "micro_f1", "macro_f1",
                   "accuracy" if dataset.multilabel else "ogbn-mag", "top_k"]
        model = LATTENodeClf(Namespace(**model_hparams), dataset, collate_fn="neighbor_sampler", metrics=metrics)
        MAX_EPOCHS = 250

    wandb_logger = WandbLogger(name=model.name(), tags=[dataset.name()], anonymous=True, project="anon-demo")
    wandb_logger.log_hyperparams(model_hparams)

    trainer = Trainer(
        gpus=NUM_GPUS,
        distributed_backend='dp' if NUM_GPUS > 1 else None,
        max_epochs=MAX_EPOCHS,
        stochastic_weight_avg=True,
        callbacks=[EarlyStopping(monitor='val_loss', patience=10, min_delta=0.0001, strict=False)],
        logger=wandb_logger,
        weights_summary='top',
        amp_level='O1' if USE_AMP and NUM_GPUS > 0 else None,
        precision=16 if USE_AMP else 32,
    )
    trainer.fit(model)

    model.register_hooks()
    trainer.test(model)

    wandb_logger.log_metrics(model.clustering_metrics(n_runs=10, compare_node_types=False))
def train(hparams):
    EMBEDDING_DIM = 128
    NUM_GPUS = hparams.num_gpus
    batch_order = 11

    dataset = load_node_dataset(hparams.dataset, hparams.method, hparams=hparams, train_ratio=hparams.train_ratio)
    METRICS = ["precision", "recall", "f1", "accuracy", "top_k" if dataset.multilabel else "ogbn-mag"]

    if hparams.method == "HAN":
        USE_AMP = True
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "batch_size": 2 ** batch_order * NUM_GPUS,
            "num_layers": 2,
            "collate_fn": "HAN_batch",
            "train_ratio": dataset.train_ratio,
            "loss_type": "BINARY_CROSS_ENTROPY" if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "n_classes": dataset.n_classes,
            "lr": 0.0005 * NUM_GPUS,
        }
        model = HAN(Namespace(**model_hparams), dataset=dataset, metrics=METRICS)
    elif hparams.method == "GTN":
        USE_AMP = True
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "num_channels": len(dataset.metapaths),
            "num_layers": 2,
            "batch_size": 2 ** batch_order * NUM_GPUS,
            "collate_fn": "HAN_batch",
            "train_ratio": dataset.train_ratio,
            "loss_type": "BINARY_CROSS_ENTROPY" if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "n_classes": dataset.n_classes,
            "lr": 0.0005 * NUM_GPUS,
        }
        model = GTN(Namespace(**model_hparams), dataset=dataset, metrics=METRICS)
    elif hparams.method == "MetaPath2Vec":
        USE_AMP = True
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "walk_length": 50,
            "context_size": 7,
            "walks_per_node": 5,
            "num_negative_samples": 5,
            "sparse": True,
            "batch_size": 400 * NUM_GPUS,
            "train_ratio": dataset.train_ratio,
            "n_classes": dataset.n_classes,
            "lr": 0.01 * NUM_GPUS,
        }
        model = MetaPath2Vec(Namespace(**model_hparams), dataset=dataset, metrics=METRICS)
    elif "LATTE" in hparams.method:
        USE_AMP = False
        num_gpus = 1

        if "-1" in hparams.method:
            t_order = 1
        elif "-2" in hparams.method:
            t_order = 2
        elif "-3" in hparams.method:
            t_order = 3
        else:
            t_order = 2

        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "t_order": t_order,
            "batch_size": 2 ** batch_order * max(num_gpus, 1),
            "nb_cls_dense_size": 0,
            "nb_cls_dropout": 0.4,
            "activation": "relu",
            "attn_heads": 2,
            "attn_activation": "sharpening",
            "attn_dropout": 0.2,
            "loss_type": "BCE" if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "use_proximity": True if "proximity" in hparams.method else False,
            "neg_sampling_ratio": 2.0,
            "n_classes": dataset.n_classes,
            "use_class_weights": False,
            "lr": 0.001 * num_gpus,
            "momentum": 0.9,
            "weight_decay": 1e-2,
        }
        metrics = ["precision", "recall", "micro_f1",
                   "accuracy" if dataset.multilabel else "ogbn-mag", "top_k"]
        model = LATTENodeClassifier(Namespace(**model_hparams), dataset, collate_fn="neighbor_sampler",
                                    metrics=metrics)

    MAX_EPOCHS = 250
    wandb_logger = WandbLogger(name=model.name(), tags=[dataset.name()], project="multiplex-comparison")
    wandb_logger.log_hyperparams(model_hparams)

    trainer = Trainer(
        gpus=NUM_GPUS,
        auto_select_gpus=True,
        distributed_backend='dp' if NUM_GPUS > 1 else None,
        max_epochs=MAX_EPOCHS,
        callbacks=[EarlyStopping(monitor='val_loss', patience=10, min_delta=0.0001, strict=False)],
        logger=wandb_logger,
        weights_summary='top',
        amp_level='O1' if USE_AMP else None,
        precision=16 if USE_AMP else 32,
    )
    # trainer.fit(model)
    trainer.fit(model, train_dataloader=model.valtrain_dataloader(), val_dataloaders=model.test_dataloader())
    trainer.test(model)
def test_metrics_reset(tmpdir):
    """Tests that metrics are reset correctly after the end of the train/val/test epoch."""

    class TestModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 1)

            for stage in ['train', 'val', 'test']:
                acc = Accuracy()
                acc.reset = mock.Mock(side_effect=acc.reset)
                ap = AveragePrecision(num_classes=1, pos_label=1)
                ap.reset = mock.Mock(side_effect=ap.reset)
                self.add_module(f"acc_{stage}", acc)
                self.add_module(f"ap_{stage}", ap)

        def forward(self, x):
            return self.layer(x)

        def _step(self, stage, batch):
            labels = (batch.detach().sum(1) > 0).float()  # Fake some targets
            logits = self.forward(batch)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels.unsqueeze(1))
            probs = torch.sigmoid(logits.detach())
            self.log(f"loss/{stage}", loss)

            acc = self._modules[f"acc_{stage}"]
            ap = self._modules[f"ap_{stage}"]

            labels_int = labels.to(torch.long)
            acc(probs.flatten(), labels_int)
            ap(probs.flatten(), labels_int)

            # Metric.forward calls reset so reset the mocks here
            acc.reset.reset_mock()
            ap.reset.reset_mock()

            self.log(f"{stage}/accuracy", acc)
            self.log(f"{stage}/ap", ap)

            return loss

        def training_step(self, batch, batch_idx, *args, **kwargs):
            return self._step('train', batch)

        def validation_step(self, batch, batch_idx, *args, **kwargs):
            return self._step('val', batch)

        def test_step(self, batch, batch_idx, *args, **kwargs):
            return self._step('test', batch)

        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [lr_scheduler]

        def train_dataloader(self):
            return DataLoader(RandomDataset(32, 64))

        def val_dataloader(self):
            return DataLoader(RandomDataset(32, 64))

        def test_dataloader(self):
            return DataLoader(RandomDataset(32, 64))

        def _assert_epoch_end(self, stage):
            acc = self._modules[f"acc_{stage}"]
            ap = self._modules[f"ap_{stage}"]
            acc.reset.assert_called_once()
            ap.reset.assert_called_once()

        def teardown(self, stage):
            if stage == TrainerFn.FITTING:
                self._assert_epoch_end('train')
                self._assert_epoch_end('val')
            elif stage == TrainerFn.VALIDATING:
                self._assert_epoch_end('val')
            elif stage == TrainerFn.TESTING:
                self._assert_epoch_end('test')

    def _assert_called(model, stage):
        acc = model._modules[f"acc_{stage}"]
        ap = model._modules[f"ap_{stage}"]
        assert acc.reset.call_count == 1
        acc.reset.reset_mock()
        assert ap.reset.call_count == 1
        ap.reset.reset_mock()

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        max_epochs=1,
        progress_bar_refresh_rate=0,
        num_sanity_val_steps=2,
    )

    trainer.fit(model)
    _assert_called(model, 'train')
    _assert_called(model, 'val')

    trainer.validate(model)
    _assert_called(model, 'val')

    trainer.test(model)
    _assert_called(model, 'test')
def main(model_name='usp_1d', max_epochs=1020, data_dir='./data', dataset='sc09', ps=False, wn=False, mx=False,
         perc=1, ts=False, fd=False, tts=False, tm=False, train=True, order=True, model_num=None):
    model_name = model_name + '_' + str(int(perc * 100))
    dataset_f = dataset
    nsynth_class = None

    if dataset == 'sc09':
        sample_rate = 16000
        n_classes = 10
        length = 1
        batch_size = 256
        train_transform = transforms.Compose([
            torchaudio.transforms.Resample(orig_freq=16000, new_freq=sample_rate),
            pad,
        ])
    elif dataset == 'sc':
        sample_rate = 16000
        n_classes = 35
        batch_size = 128
        length = 1
        train_transform = transforms.Compose([
            torchaudio.transforms.Resample(orig_freq=16000, new_freq=sample_rate),
            partial(pad, length=length),
        ])
    elif dataset == 'nsynth11':
        sample_rate = 16000
        n_classes = 11
        batch_size = 32
        max_epochs = 120
        dataset = 'nsynth'
        nsynth_class = 'instrument_family'
        length = 4
        train_transform = transforms.Compose([
            torchaudio.transforms.Resample(orig_freq=16000, new_freq=sample_rate),
            partial(pad, length=length),
        ])
    elif dataset == 'nsynth128':
        sample_rate = 16000
        n_classes = 128
        batch_size = 16
        max_epochs = 120
        dataset = 'nsynth'
        nsynth_class = 'pitch'
        length = 4
        train_transform = transforms.Compose([
            torchaudio.transforms.Resample(orig_freq=16000, new_freq=sample_rate),
            partial(pad, length=length),
        ])
    elif dataset == 'esc50':
        sample_rate = 16000
        max_epochs = 2000
        n_classes = 50
        batch_size = 64
        length = 5
        train_transform = transforms.Compose([
            torchaudio.transforms.Resample(orig_freq=44100, new_freq=sample_rate),
            partial(pad, length=length),
        ])
    elif dataset == 'esc10':
        sample_rate = 16000
        n_classes = 10
        max_epochs = 2000
        batch_size = 64
        length = 5
        train_transform = transforms.Compose([
            torchaudio.transforms.Resample(orig_freq=44100, new_freq=sample_rate),
            partial(pad, length=length),
        ])

    # model_name = model_name + '_' + dataset
    spec_transform = None
    aug_transform = []

    if order:
        if fd:
            aug_transform.append(transforms.RandomApply(add_fade))
            model_name = model_name + '_fd'
        if tm:
            aug_transform.append(transforms.RandomApply(time_masking))
            model_name = model_name + '_tm'
        if tts:
            aug_transform.append(transforms.RandomApply(partial(time_stret, length=length)))
            model_name = model_name + '_tts'
        if ps:
            aug_transform.append(transforms.RandomApply(pitch_shift))
            model_name = model_name + '_ps'
        if ts:
            aug_transform.append(transforms.RandomApply(time_shift))
            model_name = model_name + '_ts'
        if wn:
            aug_transform.append(transforms.RandomApply(add_white_noise))
            model_name = model_name + '_wn'
        if mx:
            m_x = Mixed_Noise(data_dir, sample_rate)
            aug_transform.append(transforms.RandomApply(m_x))
            model_name = model_name + '_mx'
    else:
        if mx:
            m_x = Mixed_Noise(data_dir, sample_rate)
            aug_transform.append(transforms.RandomApply(m_x))
            model_name = model_name + '_mx'
        if wn:
            aug_transform.append(transforms.RandomApply(add_white_noise))
            model_name = model_name + '_wn'
        if ts:
            aug_transform.append(transforms.RandomApply(time_shift))
            model_name = model_name + '_ts'
        if ps:
            aug_transform.append(transforms.RandomApply(pitch_shift))
            model_name = model_name + '_ps'
        if fd:
            aug_transform.append(transforms.RandomApply(add_fade))
            model_name = model_name + '_fd'
        if tts:
            aug_transform.append(transforms.RandomApply(partial(time_stret, length=length)))
            model_name = model_name + '_tts'
        if tm:
            aug_transform.append(transforms.RandomApply(time_masking))
            model_name = model_name + '_tm'

    aug_transform = transforms.Compose(aug_transform)
    print(f"Model: {model_name}")

    net = Main(batch_size=batch_size, sampling_rate=sample_rate, data_dir=data_dir, dataset=dataset, perc=perc,
               nsynth_class=nsynth_class, n_classes=n_classes, train_transform=train_transform,
               aug_transform=aug_transform, spec_transform=spec_transform, model=model_name)

    model_path = os.path.join(MODELS_FOLDER, model_name, dataset_f)
    os.makedirs(model_path, exist_ok=True)

    checkpoint_callback = ModelCheckpoint(
        filepath=model_path,
        save_last=True,
        mode='min',
        period=10,
        save_top_k=20000000,
    )

    if model_num is not None:
        checkpoint = os.path.join(model_path, get_last(os.listdir(model_path), model_num))
    elif os.path.exists(model_path) and len(os.listdir(model_path)) > 0:
        checkpoint = os.path.join(model_path, get_last(os.listdir(model_path)))
    else:
        checkpoint = None

    logger = TensorBoardLogger(save_dir=LOGS_FOLDER, version=dataset_f, name=model_name)

    # finetune in real-time
    print(f"Loading model: {checkpoint}")

    def to_device(batch, device):
        (x1, x2), y = batch
        x1 = x1.to(device)
        y = y.to(device).squeeze()
        return x1, y

    online_eval = SSLOnlineEvaluator(hidden_dim=512, z_dim=512, num_classes=n_classes,
                                     train_transform=train_transform, data_dir=data_dir, dataset=dataset,
                                     batch_size=batch_size, nsynth_class=nsynth_class)
    online_eval.to_device = to_device

    trainer = Trainer(resume_from_checkpoint=checkpoint,
                      distributed_backend='ddp',
                      max_epochs=max_epochs,
                      sync_batchnorm=True,
                      checkpoint_callback=checkpoint_callback,
                      logger=logger,
                      gpus=-1 if train else 1,
                      log_save_interval=25,
                      callbacks=[online_eval])
    if train:
        trainer.fit(net)
    else:
        trainer.test(net)