def test_can_return_tensor_with_more_than_one_element(tmpdir):
    """Ensure {validation,test}_step return values are not included as callback metrics.

    #6623
    """
    class TestModel(BoringModel):
        def validation_step(self, batch, *args, **kwargs):
            return {"val": torch.tensor([0, 1])}

        def validation_epoch_end(self, outputs):
            # ensure validation step returns still appear here
            assert len(outputs) == 2
            assert all(list(d) == ["val"] for d in outputs)  # check keys
            assert all(
                torch.equal(d["val"], torch.tensor([0, 1]))
                for d in outputs)  # check values

        def test_step(self, batch, *args, **kwargs):
            return {"test": torch.tensor([0, 1])}

        def test_epoch_end(self, outputs):
            assert len(outputs) == 2
            assert all(list(d) == ["test"] for d in outputs)  # check keys
            assert all(
                torch.equal(d["test"], torch.tensor([0, 1]))
                for d in outputs)  # check values

    model = TestModel()
    trainer = Trainer(default_root_dir=tmpdir,
                      fast_dev_run=2,
                      enable_progress_bar=False)
    trainer.fit(model)
    trainer.validate(model)
    trainer.test(model)
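
# A hedged follow-up check (not part of the original test), assuming the Lightning
# behavior described in the docstring above: the raw dicts returned from
# validation_step/test_step should not have leaked into the trainer's callback metrics.
def _check_no_step_returns_in_callback_metrics(trainer):
    # `trainer.callback_metrics` only contains values logged via `self.log(...)`,
    # so the keys returned directly from the steps must be absent.
    assert "val" not in trainer.callback_metrics
    assert "test" not in trainer.callback_metrics
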
def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, num_dataloaders):
    """
    Tests that the LoggerConnector properly captures logged information in a multi-dataloader scenario.
    """

    os.environ['PL_DEV_DEBUG'] = '1'

    class TestModel(BoringModel):

        test_losses = {}

        @Helper.decorator_with_arguments(fx_name="test_step")
        def test_step(self, batch, batch_idx, dl_idx=0):
            output = self.layer(batch)
            loss = self.loss(batch, output)

            primary_key = str(dl_idx)
            if primary_key not in self.test_losses:
                self.test_losses[primary_key] = []

            self.test_losses[primary_key].append(loss)

            self.log("test_loss", loss, on_step=True, on_epoch=True)
            return {"test_loss": loss}

        def test_dataloader(self):
            return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(num_dataloaders)]

    model = TestModel()
    model.val_dataloader = None
    model.test_epoch_end = None

    limit_test_batches = 4

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=0,
        limit_val_batches=0,
        limit_test_batches=limit_test_batches,
        max_epochs=1,
        log_every_n_steps=1,
        weights_summary=None,
    )
    trainer.test(model)

    test_results = trainer.logger_connector._cached_results["test"]

    generated = test_results(fx_name="test_step")
    assert len(generated) == num_dataloaders

    for dl_idx in range(num_dataloaders):
        generated = len(test_results(fx_name="test_step", dl_idx=str(dl_idx)))
        assert generated == limit_test_batches

    test_results.has_batch_loop_finished = True

    for dl_idx in range(num_dataloaders):
        expected = torch.stack(model.test_losses[str(dl_idx)]).mean()
        generated = test_results(fx_name="test_step", dl_idx=str(dl_idx), reduced=True)["test_loss_epoch"]
        assert abs(expected.item() - generated.item()) < 1e-6
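
# The test above pokes at internal LoggerConnector caches. As a hedged sketch (not from
# the original example), the public API handles multiple test dataloaders through the
# `dataloader_idx` argument, and logged keys are typically suffixed per dataloader
# (e.g. "test_loss/dataloader_idx_0"), depending on the Lightning version.
class MultiDataloaderSketch(BoringModel):
    def test_step(self, batch, batch_idx, dataloader_idx=0):
        loss = self.loss(batch, self.layer(batch))
        # Logged once per dataloader; Lightning aggregates the epoch value per loader.
        self.log("test_loss", loss, on_step=True, on_epoch=True)
        return loss

    def test_dataloader(self):
        # Two identical random dataloaders, mirroring the test above.
        return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(2)]
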
def train(hparams):
    NUM_GPUS = hparams.num_gpus
    USE_AMP = False  # True if NUM_GPUS > 1 else False
    MAX_EPOCHS = 50

    dataset = load_link_dataset(hparams.dataset, hparams=hparams)
    hparams.n_classes = dataset.n_classes

    model = LATTELinkPredictor(hparams,
                               dataset,
                               collate_fn="triples_batch",
                               metrics=[hparams.dataset])
    wandb_logger = WandbLogger(name=model.name(),
                               tags=[dataset.name()],
                               project="multiplex-comparison")

    trainer = Trainer(
        gpus=NUM_GPUS,
        distributed_backend='ddp' if NUM_GPUS > 1 else None,
        auto_lr_find=False,
        max_epochs=MAX_EPOCHS,
        early_stop_callback=EarlyStopping(monitor='val_loss',
                                          patience=10,
                                          min_delta=0.01,
                                          strict=False),
        logger=wandb_logger,
        # regularizers=regularizers,
        weights_summary='top',
        amp_level='O1' if USE_AMP else None,
        precision=16 if USE_AMP else 32)

    trainer.fit(model)
    trainer.test(model)
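
# The Trainer arguments above (distributed_backend, early_stop_callback, weights_summary,
# amp_level) come from an older PyTorch Lightning release. A hedged sketch of a roughly
# equivalent setup against the newer accelerator/devices API; exact argument availability
# depends on the installed Lightning version.
def build_trainer_with_recent_api(num_gpus, max_epochs, wandb_logger):
    kwargs = dict(
        accelerator="gpu" if num_gpus > 0 else "cpu",
        devices=max(num_gpus, 1),
        max_epochs=max_epochs,
        callbacks=[EarlyStopping(monitor='val_loss', patience=10, min_delta=0.01, strict=False)],
        logger=wandb_logger,
        precision=32,
    )
    if num_gpus > 1:
        kwargs["strategy"] = "ddp"
    return Trainer(**kwargs)
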
def test_epoch_results_cache_dp(tmpdir):

    root_device = torch.device("cuda", 0)

    class TestModel(BoringModel):
        def training_step(self, *args, **kwargs):
            result = super().training_step(*args, **kwargs)
            self.log("train_loss_epoch", result["loss"], on_step=False, on_epoch=True)
            return result

        def training_step_end(self, training_step_outputs):  # required for dp
            loss = training_step_outputs["loss"].mean()
            return loss

        def training_epoch_end(self, outputs):
            assert all(out["loss"].device == root_device for out in outputs)
            assert self.trainer.callback_metrics["train_loss_epoch"].device == root_device

        def validation_step(self, *args, **kwargs):
            val_loss = torch.rand(1, device=torch.device("cuda", 1))
            self.log("val_loss_epoch", val_loss, on_step=False, on_epoch=True)
            return val_loss

        def validation_epoch_end(self, outputs):
            assert all(loss.device == root_device for loss in outputs)
            assert self.trainer.callback_metrics["val_loss_epoch"].device == root_device

        def test_step(self, *args, **kwargs):
            test_loss = torch.rand(1, device=torch.device("cuda", 1))
            self.log("test_loss_epoch", test_loss, on_step=False, on_epoch=True)
            return test_loss

        def test_epoch_end(self, outputs):
            assert all(loss.device == root_device for loss in outputs)
            assert self.trainer.callback_metrics["test_loss_epoch"].device == root_device

        def train_dataloader(self):
            return DataLoader(RandomDataset(32, 64), batch_size=4)

        def val_dataloader(self):
            return DataLoader(RandomDataset(32, 64), batch_size=4)

        def test_dataloader(self):
            return DataLoader(RandomDataset(32, 64), batch_size=4)

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        strategy="dp",
        accelerator="gpu",
        devices=2,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
    )
    trainer.fit(model)
    trainer.test(model)
def run_encoder(train, test, epochs):
    """
    Instantiates and runs the autoencoder.

    Parameters:
        train (pandas.DataFrame): DataFrame of training data
        test (pandas.DataFrame): DataFrame of testing data
        epochs (int): Training epochs

    Returns:
        Autoencoder loss on test data
    """
    # Instantiates the training dataset
    data_train = MELoader(train)

    # Instantiates the testing dataset
    data_test = MELoader(test)

    # Instantiates the non-mechanistic autoencoder
    feats = data_train.data.shape[1]
    encoder = NMEncoder(feats, feats // 2)

    # Instantiates the PyTorch Lightning trainer
    trainer = Trainer(gpus=1, num_nodes=1, max_epochs=epochs)

    # Fits the model on the training set
    trainer.fit(encoder, DataLoader(dataset=data_train))

    # Evaluates on the test set
    performance = trainer.test(encoder, DataLoader(dataset=data_test))

    return performance[0]["test_loss"]
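
# A hedged usage sketch: the DataFrames below are random placeholders (not from the
# original source), only to illustrate run_encoder's call signature and return value.
def _example_run_encoder_usage():
    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(0)
    train_df = pd.DataFrame(rng.normal(size=(128, 16)))
    test_df = pd.DataFrame(rng.normal(size=(32, 16)))
    # Returns the autoencoder loss on the test data.
    return run_encoder(train_df, test_df, epochs=5)
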
def train(hparams: Namespace):
    NUM_GPUS = hparams.num_gpus
    USE_AMP = False  # True if NUM_GPUS > 1 else False
    MAX_EPOCHS = 50

    neighbor_sizes = [
        hparams.n_neighbors,
    ]
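    # e.g. with n_neighbors=20 and t_order=3 this builds [20, 10, 5]: each additional
    # order samples half as many neighbors as the previous one.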
    for _ in range(1, hparams.t_order):
        neighbor_sizes.append(neighbor_sizes[-1] // 2)
    print("neighbor_sizes", neighbor_sizes)
    hparams.neighbor_sizes = neighbor_sizes

    dataset = load_node_dataset(hparams.dataset,
                                method="LATTE",
                                hparams=hparams,
                                train_ratio=None,
                                dir_path=hparams.dir_path)

    METRICS = [
        "precision", "recall", "f1",
        "accuracy" if dataset.multilabel else hparams.dataset, "top_k"
    ]
    hparams.loss_type = "BCE" if dataset.multilabel else hparams.loss_type
    hparams.n_classes = dataset.n_classes
    model = LATTENodeClassifier(hparams,
                                dataset,
                                collate_fn="neighbor_sampler",
                                metrics=METRICS)

    logger = WandbLogger(name=model.name(),
                         tags=[dataset.name()],
                         project="multiplex-comparison")

    trainer = Trainer(
        gpus=NUM_GPUS,
        distributed_backend='ddp' if NUM_GPUS > 1 else None,
        gradient_clip_val=hparams.gradient_clip_val,
        # auto_lr_find=True,
        max_epochs=MAX_EPOCHS,
        # early_stop_callback=EarlyStopping(monitor='val_loss', patience=5, min_delta=0.001, strict=False),
        logger=logger,
        amp_level='O1' if USE_AMP else None,
        precision=16 if USE_AMP else 32)

    trainer.fit(model)
    trainer.test(model)
def run_encoder(train,
                test,
                epochs,
                width,
                depth,
                dropout_prob=0.2,
                reg_coef=0):
    """
    Instantiates and runs the extendable autoencoder.

    Parameters:
        train (pandas.DataFrame): DataFrame of training data
        test (pandas.DataFrame): DataFrame of testing data
        epochs (int): Training epochs
        width (int): Number of latent attributes
        depth (int): Number of encoding/decoding layers
        dropout_prob (float, default=0.2): Probability of drop-out
        reg_coef (float, default=0): Regularization coefficient

    Returns:
        Autoencoder loss on test data
    """
    # Instantiates the training dataset
    data_train = MELoader(train)

    # Instantiates the testing dataset
    data_test = MELoader(test)

    # Instantiates the non-mechanistic autoencoder
    feats = data_train.data.shape[1]
    encoder = NMEncoder(feats,
                        width,
                        dropout_prob=dropout_prob,
                        n_layers=depth,
                        reg_coef=reg_coef)

    # Instantiates the PyTorch Lightning trainer
    trainer = Trainer(
        auto_scale_batch_size=True,
        auto_select_gpus=True,
        checkpoint_callback=False,
        gpus=1,
        logger=False,
        max_epochs=epochs,
        # progress_bar_refresh_rate=0,
        weights_summary=None,
    )

    # Fits the model on the training set
    trainer.fit(encoder, DataLoader(dataset=data_train))

    # Evaluates on the test set
    performance = trainer.test(encoder, DataLoader(dataset=data_test))
    loss = performance[0]["test_loss"]
    latent = performance[0]["latent"]

    return loss, latent
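
# Compared with the earlier run_encoder that only takes epochs, this variant also exposes
# the latent width, encoder/decoder depth, dropout probability and regularization
# coefficient, and returns the latent representation alongside the test loss. A hedged
# call sketch (argument values are illustrative only):
#     loss, latent = run_encoder(train_df, test_df, epochs=5, width=8, depth=2,
#                                dropout_prob=0.2, reg_coef=1e-4)
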
def test_metrics_reset(tmpdir):
    """Tests that metrics are reset correctly after the end of the train/val/test epoch."""
    class TestModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 1)

        def _create_metrics(self):
            acc = Accuracy()
            acc.reset = mock.Mock(side_effect=acc.reset)
            ap = AveragePrecision(num_classes=1, pos_label=1)
            ap.reset = mock.Mock(side_effect=ap.reset)
            return acc, ap

        def setup(self, stage):
            fn = stage
            if fn == "fit":
                for stage in ("train", "validate"):
                    acc, ap = self._create_metrics()
                    self.add_module(f"acc_{fn}_{stage}", acc)
                    self.add_module(f"ap_{fn}_{stage}", ap)
            else:
                acc, ap = self._create_metrics()
                stage = self.trainer.state.stage
                self.add_module(f"acc_{fn}_{stage}", acc)
                self.add_module(f"ap_{fn}_{stage}", ap)

        def forward(self, x):
            return self.layer(x)

        def _step(self, batch):
            fn, stage = self.trainer.state.fn, self.trainer.state.stage

            logits = self(batch)
            loss = logits.sum()
            self.log(f"loss/{fn}_{stage}", loss)

            acc = self._modules[f"acc_{fn}_{stage}"]
            ap = self._modules[f"ap_{fn}_{stage}"]

            preds = torch.rand(len(batch))  # Fake preds
            labels = torch.randint(0, 1, [len(batch)])  # Fake targets
            acc(preds, labels)
            ap(preds, labels)

            # Metric.forward calls reset, so reset the mocks here
            acc.reset.reset_mock()
            ap.reset.reset_mock()

            self.log(f"acc/{fn}_{stage}", acc)
            self.log(f"ap/{fn}_{stage}", ap)

            return loss

        def training_step(self, batch, batch_idx, *args, **kwargs):
            return self._step(batch)

        def validation_step(self, batch, batch_idx, *args, **kwargs):
            if self.trainer.sanity_checking:
                return
            return self._step(batch)

        def test_step(self, batch, batch_idx, *args, **kwargs):
            return self._step(batch)

        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                           step_size=1)
            return [optimizer], [lr_scheduler]

        def train_dataloader(self):
            return DataLoader(RandomDataset(32, 64))

        def val_dataloader(self):
            return DataLoader(RandomDataset(32, 64))

        def test_dataloader(self):
            return DataLoader(RandomDataset(32, 64))

    def _assert_called(model, fn, stage):
        acc = model._modules[f"acc_{fn}_{stage}"]
        ap = model._modules[f"ap_{fn}_{stage}"]
        acc.reset.assert_called_once()
        ap.reset.assert_called_once()

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        max_epochs=1,
        enable_progress_bar=False,
        num_sanity_val_steps=2,
        enable_checkpointing=False,
    )

    trainer.fit(model)
    _assert_called(model, "fit", "train")
    _assert_called(model, "fit", "validate")

    trainer.validate(model)
    _assert_called(model, "validate", "validate")

    trainer.test(model)
    _assert_called(model, "test", "test")
def test_fx_validator_integration(tmpdir):
    """Tries to log inside all `LightningModule` and `Callback` hooks to check any expected errors."""
    not_supported = {
        None: "`self.trainer` reference is not registered",
        "on_before_accelerator_backend_setup": "You can't",
        "setup": "You can't",
        "configure_sharded_model": "You can't",
        "on_configure_sharded_model": "You can't",
        "configure_optimizers": "You can't",
        "on_fit_start": "You can't",
        "on_pretrain_routine_start": "You can't",
        "on_pretrain_routine_end": "You can't",
        "train_dataloader": "You can't",
        "val_dataloader": "You can't",
        "on_validation_end": "You can't",
        "on_train_end": "You can't",
        "on_fit_end": "You can't",
        "teardown": "You can't",
        "on_sanity_check_start": "You can't",
        "on_sanity_check_end": "You can't",
        "prepare_data": "You can't",
        "configure_callbacks": "You can't",
        "on_validation_model_eval": "You can't",
        "on_validation_model_train": "You can't",
        "lr_scheduler_step": "You can't",
        "on_save_checkpoint": "You can't",
        "on_load_checkpoint": "You can't",
        "on_exception": "You can't",
    }
    model = HookedModel(not_supported)

    with pytest.warns(UserWarning, match=not_supported[None]):
        model.log("foo", 1)

    callback = HookedCallback(not_supported)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        limit_predict_batches=1,
        callbacks=callback,
    )
    with pytest.deprecated_call(match="is deprecated in"):
        trainer.fit(model)

    not_supported.update({
        # `lightning_module` ref is now present from the `fit` call
        "on_before_accelerator_backend_setup": "You can't",
        "test_dataloader": "You can't",
        "on_test_model_eval": "You can't",
        "on_test_model_train": "You can't",
        "on_test_end": "You can't",
    })
    with pytest.deprecated_call(match="is deprecated in"):
        trainer.test(model, verbose=False)

    not_supported.update({k: "result collection is not registered yet" for k in not_supported})
    not_supported.update({
        k: "result collection is not registered yet"
        for k in (
            "predict_dataloader",
            "on_predict_model_eval",
            "on_predict_start",
            "on_predict_epoch_start",
            "on_predict_batch_start",
            "predict_step",
            "on_predict_batch_end",
            "on_predict_epoch_end",
            "on_predict_end",
        )
    })
    with pytest.deprecated_call(match="is deprecated in"):
        trainer.predict(model)
def train(hparams):
    EMBEDDING_DIM = 128
    USE_AMP = None
    NUM_GPUS = hparams.num_gpus
    MAX_EPOCHS = 1000
    batch_order = 11
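    # batch_order = 11 gives a batch size of 2**11 = 2048 samples per step.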

    dataset = load_node_dataset(hparams.dataset,
                                hparams.method,
                                hparams=hparams,
                                train_ratio=hparams.train_ratio)

    METRICS = [
        "precision",
        "recall",
        "f1",
        "accuracy",
        "top_k" if dataset.multilabel else "ogbn-mag",
    ]

    if hparams.method == "HAN":
        USE_AMP = False
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "batch_size": 2**batch_order,
            "num_layers": 2,
            "collate_fn": "HAN_batch",
            "train_ratio": dataset.train_ratio,
            "loss_type": "BINARY_CROSS_ENTROPY"
            if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "n_classes": dataset.n_classes,
            "lr": 0.001,
        }
        model = HAN(Namespace(**model_hparams),
                    dataset=dataset,
                    metrics=METRICS)
    elif hparams.method == "GTN":
        USE_AMP = False
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "num_channels": len(dataset.metapaths),
            "num_layers": 2,
            "batch_size": 2**batch_order,
            "collate_fn": "HAN_batch",
            "train_ratio": dataset.train_ratio,
            "loss_type": "BINARY_CROSS_ENTROPY"
            if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "n_classes": dataset.n_classes,
            "lr": 0.001,
        }
        model = GTN(Namespace(**model_hparams),
                    dataset=dataset,
                    metrics=METRICS)

    elif hparams.method == "MetaPath2Vec":
        USE_AMP = False
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "walk_length": 50,
            "context_size": 7,
            "walks_per_node": 5,
            "num_negative_samples": 5,
            "sparse": True,
            "batch_size": 400,
            "train_ratio": dataset.train_ratio,
            "n_classes": dataset.n_classes,
            "lr": 0.01,
        }
        model = MetaPath2Vec(Namespace(**model_hparams),
                             dataset=dataset,
                             metrics=METRICS)

    elif hparams.method == "HGT":
        USE_AMP = False
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "num_channels": len(dataset.metapaths),
            "n_layers": 2,
            "attn_heads": 8,
            "attn_dropout": 0.2,
            "prev_norm": True,
            "last_norm": True,
            "nb_cls_dense_size": 0,
            "nb_cls_dropout": 0.0,
            "use_class_weights": False,
            "batch_size": 2**batch_order,
            "n_epoch": MAX_EPOCHS,
            "train_ratio": dataset.train_ratio,
            "loss_type":
            "BCE" if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "n_classes": dataset.n_classes,
            "collate_fn": "collate_HGT_batch",
            "lr": 0.001,  # Not used here, defaults to 1e-3
        }
        model = HGT(Namespace(**model_hparams), dataset, metrics=METRICS)

    elif "LATTE" in hparams.method:
        USE_AMP = False
        num_gpus = 1

        if "-1" in hparams.method:
            n_layers = 1
        elif "-2" in hparams.method:
            n_layers = 2
        elif "-3" in hparams.method:
            n_layers = 3
        else:
            n_layers = 2

        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "layer_pooling": "concat",
            "n_layers": n_layers,
            "batch_size": 2**batch_order,
            "nb_cls_dense_size": 0,
            "nb_cls_dropout": 0.4,
            "activation": "relu",
            "dropout": 0.2,
            "attn_heads": 2,
            "attn_activation": "sharpening",
            "batchnorm": False,
            "layernorm": False,
            "edge_sampling": False,
            "edge_threshold": 0.5,
            "attn_dropout": 0.2,
            "loss_type":
            "BCE" if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "use_proximity": True if "proximity" in hparams.method else False,
            "neg_sampling_ratio": 2.0,
            "n_classes": dataset.n_classes,
            "use_class_weights": False,
            "lr": 0.001,
            "momentum": 0.9,
            "weight_decay": 1e-2,
        }

        model_hparams.update(hparams.__dict__)

        metrics = [
            "precision", "recall", "micro_f1", "macro_f1",
            "accuracy" if dataset.multilabel else "ogbn-mag", "top_k"
        ]

        model = LATTENodeClf(Namespace(**model_hparams),
                             dataset,
                             collate_fn="neighbor_sampler",
                             metrics=metrics)

    MAX_EPOCHS = 250
    wandb_logger = WandbLogger(name=model.name(),
                               tags=[dataset.name()],
                               anonymous=True,
                               project="anon-demo")
    wandb_logger.log_hyperparams(model_hparams)

    trainer = Trainer(gpus=NUM_GPUS,
                      distributed_backend='dp' if NUM_GPUS > 1 else None,
                      max_epochs=MAX_EPOCHS,
                      stochastic_weight_avg=True,
                      callbacks=[
                          EarlyStopping(monitor='val_loss',
                                        patience=10,
                                        min_delta=0.0001,
                                        strict=False)
                      ],
                      logger=wandb_logger,
                      weights_summary='top',
                      amp_level='O1' if USE_AMP and NUM_GPUS > 0 else None,
                      precision=16 if USE_AMP else 32)
    trainer.fit(model)

    model.register_hooks()
    trainer.test(model)

    wandb_logger.log_metrics(
        model.clustering_metrics(n_runs=10, compare_node_types=False))
def train(hparams):
    EMBEDDING_DIM = 128
    NUM_GPUS = hparams.num_gpus
    batch_order = 11

    dataset = load_node_dataset(hparams.dataset, hparams.method, hparams=hparams, train_ratio=hparams.train_ratio)

    METRICS = ["precision", "recall", "f1", "accuracy", "top_k" if dataset.multilabel else "ogbn-mag", ]

    if hparams.method == "HAN":
        USE_AMP = True
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "batch_size": 2 ** batch_order * NUM_GPUS,
            "num_layers": 2,
            "collate_fn": "HAN_batch",
            "train_ratio": dataset.train_ratio,
            "loss_type": "BINARY_CROSS_ENTROPY" if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "n_classes": dataset.n_classes,
            "lr": 0.0005 * NUM_GPUS,
        }
        model = HAN(Namespace(**model_hparams), dataset=dataset, metrics=METRICS)
    elif hparams.method == "GTN":
        USE_AMP = True
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "num_channels": len(dataset.metapaths),
            "num_layers": 2,
            "batch_size": 2 ** batch_order * NUM_GPUS,
            "collate_fn": "HAN_batch",
            "train_ratio": dataset.train_ratio,
            "loss_type": "BINARY_CROSS_ENTROPY" if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "n_classes": dataset.n_classes,
            "lr": 0.0005 * NUM_GPUS,
        }
        model = GTN(Namespace(**model_hparams), dataset=dataset, metrics=METRICS)
    elif hparams.method == "MetaPath2Vec":
        USE_AMP = True
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "walk_length": 50,
            "context_size": 7,
            "walks_per_node": 5,
            "num_negative_samples": 5,
            "sparse": True,
            "batch_size": 400 * NUM_GPUS,
            "train_ratio": dataset.train_ratio,
            "n_classes": dataset.n_classes,
            "lr": 0.01 * NUM_GPUS,
        }
        model = MetaPath2Vec(Namespace(**model_hparams), dataset=dataset, metrics=METRICS)
    elif "LATTE" in hparams.method:
        USE_AMP = False
        num_gpus = 1

        if "-1" in hparams.method:
            t_order = 1
        elif "-2" in hparams.method:
            t_order = 2
        elif "-3" in hparams.method:
            t_order = 3
        else:
            t_order = 2

        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "t_order": t_order,
            "batch_size": 2 ** batch_order * max(num_gpus, 1),
            "nb_cls_dense_size": 0,
            "nb_cls_dropout": 0.4,
            "activation": "relu",
            "attn_heads": 2,
            "attn_activation": "sharpening",
            "attn_dropout": 0.2,
            "loss_type": "BCE" if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "use_proximity": True if "proximity" in hparams.method else False,
            "neg_sampling_ratio": 2.0,
            "n_classes": dataset.n_classes,
            "use_class_weights": False,
            "lr": 0.001 * num_gpus,
            "momentum": 0.9,
            "weight_decay": 1e-2,
        }

        metrics = ["precision", "recall", "micro_f1",
                   "accuracy" if dataset.multilabel else "ogbn-mag", "top_k"]

        model = LATTENodeClassifier(Namespace(**model_hparams), dataset, collate_fn="neighbor_sampler", metrics=metrics)

    MAX_EPOCHS = 250
    wandb_logger = WandbLogger(name=model.name(),
                               tags=[dataset.name()],
                               project="multiplex-comparison")
    wandb_logger.log_hyperparams(model_hparams)

    trainer = Trainer(
        gpus=NUM_GPUS, auto_select_gpus=True,
        distributed_backend='dp' if NUM_GPUS > 1 else None,
        max_epochs=MAX_EPOCHS,
        callbacks=[EarlyStopping(monitor='val_loss', patience=10, min_delta=0.0001, strict=False)],
        logger=wandb_logger,
        weights_summary='top',
        amp_level='O1' if USE_AMP else None,
        precision=16 if USE_AMP else 32
    )

    # trainer.fit(model)
    trainer.fit(model, train_dataloader=model.valtrain_dataloader(), val_dataloaders=model.test_dataloader())
    trainer.test(model)
def test_metrics_reset(tmpdir):
    """Tests that metrics are reset correctly after the end of the train/val/test epoch."""
    class TestModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 1)

            for stage in ['train', 'val', 'test']:
                acc = Accuracy()
                acc.reset = mock.Mock(side_effect=acc.reset)
                ap = AveragePrecision(num_classes=1, pos_label=1)
                ap.reset = mock.Mock(side_effect=ap.reset)
                self.add_module(f"acc_{stage}", acc)
                self.add_module(f"ap_{stage}", ap)

        def forward(self, x):
            return self.layer(x)

        def _step(self, stage, batch):
            labels = (batch.detach().sum(1) > 0).float()  # Fake some targets
            logits = self.forward(batch)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                logits, labels.unsqueeze(1))
            probs = torch.sigmoid(logits.detach())
            self.log(f"loss/{stage}", loss)

            acc = self._modules[f"acc_{stage}"]
            ap = self._modules[f"ap_{stage}"]

            labels_int = labels.to(torch.long)
            acc(probs.flatten(), labels_int)
            ap(probs.flatten(), labels_int)

            # Metric.forward calls reset, so reset the mocks here
            acc.reset.reset_mock()
            ap.reset.reset_mock()

            self.log(f"{stage}/accuracy", acc)
            self.log(f"{stage}/ap", ap)

            return loss

        def training_step(self, batch, batch_idx, *args, **kwargs):
            return self._step('train', batch)

        def validation_step(self, batch, batch_idx, *args, **kwargs):
            return self._step('val', batch)

        def test_step(self, batch, batch_idx, *args, **kwargs):
            return self._step('test', batch)

        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                           step_size=1)
            return [optimizer], [lr_scheduler]

        def train_dataloader(self):
            return DataLoader(RandomDataset(32, 64))

        def val_dataloader(self):
            return DataLoader(RandomDataset(32, 64))

        def test_dataloader(self):
            return DataLoader(RandomDataset(32, 64))

        def _assert_epoch_end(self, stage):
            acc = self._modules[f"acc_{stage}"]
            ap = self._modules[f"ap_{stage}"]

            acc.reset.assert_called_once()
            ap.reset.assert_called_once()

        def teardown(self, stage):
            if stage == TrainerFn.FITTING:
                self._assert_epoch_end('train')
                self._assert_epoch_end('val')

            elif stage == TrainerFn.VALIDATING:
                self._assert_epoch_end('val')

            elif stage == TrainerFn.TESTING:
                self._assert_epoch_end('test')

    def _assert_called(model, stage):
        acc = model._modules[f"acc_{stage}"]
        ap = model._modules[f"ap_{stage}"]

        assert acc.reset.call_count == 1
        acc.reset.reset_mock()

        assert ap.reset.call_count == 1
        ap.reset.reset_mock()

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        max_epochs=1,
        progress_bar_refresh_rate=0,
        num_sanity_val_steps=2,
    )

    trainer.fit(model)
    _assert_called(model, 'train')
    _assert_called(model, 'val')

    trainer.validate(model)
    _assert_called(model, 'val')

    trainer.test(model)
    _assert_called(model, 'test')
def main(model_name='usp_1d',
         max_epochs=1020,
         data_dir='./data',
         dataset='sc09',
         ps=False,
         wn=False,
         mx=False,
         perc=1,
         ts=False,
         fd=False,
         tts=False,
         tm=False,
         train=True,
         order=True,
         model_num=None):
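    # Augmentation flags (see the branches below): ps=pitch shift, wn=white noise,
    # mx=mixed noise, ts=time shift, fd=fade, tts=time stretch, tm=time masking.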
    model_name = model_name + '_' + str(int(perc * 100))
    dataset_f = dataset
    nsynth_class = None
    if dataset == 'sc09':
        sample_rate = 16000
        n_classes = 10
        length = 1
        batch_size = 256
        train_transform = transforms.Compose([
            torchaudio.transforms.Resample(orig_freq=16000,
                                           new_freq=sample_rate),
            pad,
        ])
    elif dataset == 'sc':
        sample_rate = 16000
        n_classes = 35
        batch_size = 128
        length = 1
        train_transform = transforms.Compose([
            torchaudio.transforms.Resample(orig_freq=16000,
                                           new_freq=sample_rate),
            partial(pad, length=length),
        ])
    elif dataset == 'nsynth11':
        sample_rate = 16000
        n_classes = 11
        batch_size = 32
        max_epochs = 120
        dataset = 'nsynth'
        nsynth_class = 'instrument_family'
        length = 4
        train_transform = transforms.Compose([
            torchaudio.transforms.Resample(orig_freq=16000,
                                           new_freq=sample_rate),
            partial(pad, length=length),
        ])
    elif dataset == 'nsynth128':
        sample_rate = 16000
        n_classes = 128
        batch_size = 16
        max_epochs = 120
        dataset = 'nsynth'
        nsynth_class = 'pitch'
        length = 4
        train_transform = transforms.Compose([
            torchaudio.transforms.Resample(orig_freq=16000,
                                           new_freq=sample_rate),
            partial(pad, length=length),
        ])
    elif dataset == 'esc50':
        sample_rate = 16000
        max_epochs = 2000
        n_classes = 50
        batch_size = 64
        length = 5
        train_transform = transforms.Compose([
            torchaudio.transforms.Resample(orig_freq=44100,
                                           new_freq=sample_rate),
            partial(pad, length=length),
        ])
    elif dataset == 'esc10':
        sample_rate = 16000
        n_classes = 10
        max_epochs = 2000
        batch_size = 64
        length = 5
        train_transform = transforms.Compose([
            torchaudio.transforms.Resample(orig_freq=44100,
                                           new_freq=sample_rate),
            partial(pad, length=length),
        ])

    # model_name = model_name + '_' + dataset
    spec_transform = None
    aug_transform = []
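    # `order` only changes the sequence in which the selected augmentations are composed
    # (and thus the order of the suffixes appended to `model_name`); the set of
    # augmentations enabled by the boolean flags is the same either way.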
    if order:
        if fd:
            aug_transform.append(transforms.RandomApply(add_fade))
            model_name = model_name + '_fd'
        if tm:
            aug_transform.append(transforms.RandomApply(time_masking))
            model_name = model_name + '_tm'
        if tts:
            aug_transform.append(
                transforms.RandomApply(partial(time_stret, length=length)))
            model_name = model_name + '_tts'
        if ps:
            aug_transform.append(transforms.RandomApply(pitch_shift))
            model_name = model_name + '_ps'
        if ts:
            aug_transform.append(transforms.RandomApply(time_shift))
            model_name = model_name + '_ts'
        if wn:
            aug_transform.append(transforms.RandomApply(add_white_noise))
            model_name = model_name + '_wn'
        if mx:
            m_x = Mixed_Noise(data_dir, sample_rate)
            aug_transform.append(transforms.RandomApply(m_x))
            model_name = model_name + '_mx'
    else:
        if mx:
            m_x = Mixed_Noise(data_dir, sample_rate)
            aug_transform.append(transforms.RandomApply(m_x))
            model_name = model_name + '_mx'
        if wn:
            aug_transform.append(transforms.RandomApply(add_white_noise))
            model_name = model_name + '_wn'
        if ts:
            aug_transform.append(transforms.RandomApply(time_shift))
            model_name = model_name + '_ts'
        if ps:
            aug_transform.append(transforms.RandomApply(pitch_shift))
            model_name = model_name + '_ps'
        if fd:
            aug_transform.append(transforms.RandomApply(add_fade))
            model_name = model_name + '_fd'
        if tts:
            aug_transform.append(
                transforms.RandomApply(partial(time_stret, length=length)))
            model_name = model_name + '_tts'
        if tm:
            aug_transform.append(transforms.RandomApply(time_masking))
            model_name = model_name + '_tm'
    aug_transform = transforms.Compose(aug_transform)
    print(f"Model: {model_name}")

    net = Main(batch_size=batch_size,
               sampling_rate=sample_rate,
               data_dir=data_dir,
               dataset=dataset,
               perc=perc,
               nsynth_class=nsynth_class,
               n_classes=n_classes,
               train_transform=train_transform,
               aug_transform=aug_transform,
               spec_transform=spec_transform,
               model=model_name)

    model_path = os.path.join(MODELS_FOLDER, model_name, dataset_f)
    os.makedirs(model_path, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        filepath=model_path,
        save_last=True,
        mode='min',
        period=10,
        save_top_k=20000000,
    )
    if model_num is not None:
        checkpoint = os.path.join(model_path,
                                  get_last(os.listdir(model_path), model_num))
    elif os.path.exists(model_path) and len(os.listdir(model_path)) > 0:
        checkpoint = os.path.join(model_path, get_last(os.listdir(model_path)))
    else:
        checkpoint = None

    logger = TensorBoardLogger(save_dir=LOGS_FOLDER,
                               version=dataset_f,
                               name=model_name)

    # finetune in real-time
    print(f"Loading model: {checkpoint}")

    def to_device(batch, device):
        (x1, x2), y = batch
        x1 = x1.to(device)
        y = y.to(device).squeeze()
        return x1, y

    online_eval = SSLOnlineEvaluator(hidden_dim=512,
                                     z_dim=512,
                                     num_classes=n_classes,
                                     train_transform=train_transform,
                                     data_dir=data_dir,
                                     dataset=dataset,
                                     batch_size=batch_size,
                                     nsynth_class=nsynth_class)
    online_eval.to_device = to_device

    trainer = Trainer(resume_from_checkpoint=checkpoint,
                      distributed_backend='ddp',
                      max_epochs=max_epochs,
                      sync_batchnorm=True,
                      checkpoint_callback=checkpoint_callback,
                      logger=logger,
                      gpus=-1 if train else 1,
                      log_save_interval=25,
                      callbacks=[online_eval])
    if train:
        trainer.fit(net)
    else:
        trainer.test(net)
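

# A hedged usage sketch (argument names from the signature above; the original script's
# real entry point is not shown here): train the 1-D model on sc09 with pitch shift,
# white noise and time shift augmentations enabled.
if __name__ == '__main__':
    main(model_name='usp_1d', dataset='sc09', ps=True, wn=True, ts=True, train=True)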