Example #1
def main(hparams, model = None):
    # ------------------------
    # 1 INIT LIGHTNING MODEL
    # ------------------------
    if model is None:
        model = GAN(hparams)

    # ------------------------
    # 2 INIT TRAINER
    # ------------------------
    os.makedirs(hparams.log_dir, exist_ok=True)
    try:
        log_dir = os.path.join(hparams.log_dir, sorted(os.listdir(hparams.log_dir))[-1])
    except IndexError:
        log_dir = os.path.join(hparams.log_dir, 'version_0')

    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(log_dir, 'checkpoints'),
        save_top_k=-1,
        verbose=True,
    )
    trainer = Trainer(gpus=1, limit_train_batches=1.0, max_epochs=5, benchmark=True,
                      checkpoint_callback=checkpoint_callback)

    # ------------------------
    # 3 START TRAINING
    # ------------------------
    trainer.fit(model)

    return model
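A minimal invocation sketch for the main() above, assuming it lives in a script: the only attribute main() reads directly is hparams.log_dir, so the extra --lr flag below is an illustrative assumption about what GAN might expect, not part of the original interface.

from argparse import ArgumentParser

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--log_dir', type=str, default='lightning_logs')
    parser.add_argument('--lr', type=float, default=2e-4)  # assumed GAN hyperparameter
    hparams = parser.parse_args()
    trained_model = main(hparams)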
def test_auto_add_dataloader_idx(tmpdir, add_dataloader_idx):
    """test that auto_add_dataloader_idx argument works."""
    class TestModel(BoringModel):
        def val_dataloader(self):
            dl = super().val_dataloader()
            return [dl, dl]

        def validation_step(self, *args, **kwargs):
            output = super().validation_step(*args[:-1], **kwargs)
            if add_dataloader_idx:
                name = "val_loss"
            else:
                name = f"val_loss_custom_naming_{args[-1]}"

            self.log(name, output["x"], add_dataloader_idx=add_dataloader_idx)
            return output

    model = TestModel()
    model.validation_epoch_end = None

    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
    trainer.fit(model)
    logged = trainer.logged_metrics

    # Check that the correct keys exist
    if add_dataloader_idx:
        assert "val_loss/dataloader_idx_0" in logged
        assert "val_loss/dataloader_idx_1" in logged
    else:
        assert "val_loss_custom_naming_0" in logged
        assert "val_loss_custom_naming_1" in logged
Example #3
def main(args: Namespace) -> None:
    # init
    model = WGAN(**vars(args))

    trainer = Trainer(max_epochs=20, gpus=args.gpus)

    trainer.fit(model)
Example #4
def run_encoder(train, test, epochs):
    """
    Instantiates and runs the autoencoder.

    Parameters:
        train (pandas.DataFrame): DataFrame of training data
        test (pandas.DataFrame): DataFrame of testing data
        epochs (int): Training epochs

    Returns:
        Autoencoder loss on test data
    """
    # Instantiate the training dataset
    data_train = MELoader(train)

    # Instantiate the testing dataset
    data_test = MELoader(test)

    # Instantiate the non-mechanistic autoencoder
    feats = data_train.data.shape[1]
    encoder = NMEncoder(feats, feats // 2)

    # Instantiate the PyTorch Lightning trainer
    trainer = Trainer(gpus=1, num_nodes=1, max_epochs=epochs)

    # Performs model fitting on training set
    trainer.fit(encoder, DataLoader(dataset=data_train))

    # Performs test on testing set
    performance = trainer.test(encoder, DataLoader(dataset=data_test))

    return performance[0]["test_loss"]
Example #5
def test_val_check_interval_third(tmpdir, max_epochs):

    class TestModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.train_epoch_calls = 0
            self.val_epoch_calls = 0

        def on_train_epoch_start(self) -> None:
            self.train_epoch_calls += 1

        def on_validation_epoch_start(self) -> None:
            if not self.trainer.running_sanity_check:
                self.val_epoch_calls += 1

    model = TestModel()
    trainer = Trainer(
        max_epochs=max_epochs,
        val_check_interval=0.33,
        logger=False,
    )
    trainer.fit(model)

    assert model.val_epoch_calls == max_epochs * 3
Example #6
def test_check_val_every_n_epoch(tmpdir, max_epochs, expected_val_loop_calls,
                                 expected_val_batches):
    class TestModel(BoringModel):
        val_epoch_calls = 0
        val_batches = []

        def on_train_epoch_end(self, *args, **kwargs):
            self.val_batches.append(
                self.trainer.progress_bar_callback.total_val_batches)

        def on_validation_epoch_start(self) -> None:
            self.val_epoch_calls += 1

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=max_epochs,
        num_sanity_val_steps=0,
        limit_val_batches=2,
        check_val_every_n_epoch=2,
        logger=False,
    )
    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    assert model.val_epoch_calls == expected_val_loop_calls
    assert model.val_batches == expected_val_batches
def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, num_dataloaders):
    """
    Tests that LoggerConnector properly captures logged information in a multi-dataloader scenario
    """

    os.environ['PL_DEV_DEBUG'] = '1'

    class TestModel(BoringModel):

        test_losses = {}

        @Helper.decorator_with_arguments(fx_name="test_step")
        def test_step(self, batch, batch_idx, dl_idx=0):
            output = self.layer(batch)
            loss = self.loss(batch, output)

            primary_key = str(dl_idx)
            if primary_key not in self.test_losses:
                self.test_losses[primary_key] = []

            self.test_losses[primary_key].append(loss)

            self.log("test_loss", loss, on_step=True, on_epoch=True)
            return {"test_loss": loss}

        def test_dataloader(self):
            return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(num_dataloaders)]

    model = TestModel()
    model.val_dataloader = None
    model.test_epoch_end = None

    limit_test_batches = 4

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=0,
        limit_val_batches=0,
        limit_test_batches=limit_test_batches,
        max_epochs=1,
        log_every_n_steps=1,
        weights_summary=None,
    )
    trainer.test(model)

    test_results = trainer.logger_connector._cached_results["test"]

    generated = test_results(fx_name="test_step")
    assert len(generated) == num_dataloaders

    for dl_idx in range(num_dataloaders):
        generated = len(test_results(fx_name="test_step", dl_idx=str(dl_idx)))
        assert generated == limit_test_batches

    test_results.has_batch_loop_finished = True

    for dl_idx in range(num_dataloaders):
        expected = torch.stack(model.test_losses[str(dl_idx)]).mean()
        generated = test_results(fx_name="test_step", dl_idx=str(dl_idx), reduced=True)["test_loss_epoch"]
        assert abs(expected.item() - generated.item()) < 1e-6
Example #8
def main():
    # create the model
    model = VGGNet(num_classes=10)

    # create data loader
    train_loader = create_dataloader('CIFAR10',
                                     './dataset/cifar10',
                                     split='train',
                                     batch_size=128)
    val_loader = create_dataloader('CIFAR10',
                                   './dataset/cifar10',
                                   split='val',
                                   batch_size=128,
                                   shuffle=False)

    # create the trainer
    trainer = Trainer(
        max_epochs=100,
        gpus=1,
        auto_select_gpus=True,
        callbacks=[
            LearningRateMonitor(logging_interval='step'),
            EarlyStopping(monitor='val_loss', patience=15)
        ],
    )
    trainer.fit(model, train_loader, val_loader)
    def handle_return(self, trainer: Trainer) -> None:
        """Writes a PyTorch Lightning trainer.

        Args:
            trainer: A PyTorch Lightning trainer object.
        """
        super().handle_return(trainer)
        trainer.save_checkpoint(
            os.path.join(self.artifact.uri, CHECKPOINT_NAME))
Example #10
def run_encoder(train,
                test,
                epochs,
                width,
                depth,
                dropout_prob=0.2,
                reg_coef=0):
    """
    Instantiates and runs the extendable autoencoder.

    Parameters:
        train (pandas.DataFrame): DataFrame of training data
        test (pandas.DataFrame): DataFrame of testing data
        epochs (int): Training epochs
        width (int): Number of latent attributes
        depth (int): Number of encoding/decoding layers
        dropout_prob (float, default=0.2): Probability of drop-out
        reg_coef (float, default=0): Regularization coefficient

    Returns:
        Autoencoder loss and latent representation on test data
    """
    # Instantiate the training dataset
    data_train = MELoader(train)

    # Instantiate the testing dataset
    data_test = MELoader(test)

    # Instantiate the non-mechanistic autoencoder
    feats = data_train.data.shape[1]
    encoder = NMEncoder(feats,
                        width,
                        dropout_prob=dropout_prob,
                        n_layers=depth,
                        reg_coef=reg_coef)

    # Instantiate the PyTorch Lightning trainer
    trainer = Trainer(
        auto_scale_batch_size=True,
        auto_select_gpus=True,
        checkpoint_callback=False,
        gpus=1,
        logger=False,
        max_epochs=epochs,
        # progress_bar_refresh_rate=0,
        weights_summary=None,
    )

    # Performs model fitting on training set
    trainer.fit(encoder, DataLoader(dataset=data_train))

    # Performs test on testing set
    performance = trainer.test(encoder, DataLoader(dataset=data_test))
    loss = performance[0]["test_loss"]
    latent = performance[0]["latent"]

    return loss, latent
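A hedged call sketch for this extended variant, reusing the illustrative train/test frames from the sketch after Example #4; the width, depth and regularization values are assumptions, not tuned settings.

# Returns the scalar test loss plus the latent representation logged by the encoder.
loss, latent = run_encoder(train, test, epochs=10, width=8, depth=3,
                           dropout_prob=0.2, reg_coef=1e-4)
print('test loss: {:.4f}'.format(loss))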
Example #11
def main(hparams):
    """
    Main training routine specific for this project
    """
    # ------------------------
    # 1 INIT LIGHTNING MODEL
    # ------------------------
    print('loading model...')
    model = DSANet(hparams)
    print('model built')

    # ------------------------
    # 2 INIT TEST TUBE EXP
    # ------------------------

    # init experiment
    exp = Experiment(
        name='dsanet_exp_{}_window={}_horizon={}'.format(hparams.data_name, hparams.window, hparams.horizon),
        save_dir=hparams.test_tube_save_path,
        autosave=False,
        description='test demo'
    )

    exp.argparse(hparams)
    exp.save()

    # ------------------------
    # 3 DEFINE CALLBACKS
    # ------------------------
    model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version)
    early_stop = EarlyStopping(
        monitor='val_loss',
        patience=5,
        verbose=True,
        mode='min'
    )

    # ------------------------
    # 4 INIT TRAINER
    # ------------------------
    trainer = Trainer(
        gpus=[0],
        # auto_scale_batch_size=True,
        max_epochs=10,
        # num_processes=2,
        # num_nodes=2
    )

    # ------------------------
    # 5 START TRAINING
    # ------------------------
    trainer.fit(model)

    print('View tensorboard logs by running\ntensorboard --logdir %s' % os.getcwd())
    print('and going to http://localhost:6006 on your browser')
def main():
    wandb.login()
    wandb.init(entity='hyeonsu')

    model = CGAN()
    dataset = MNIST_dataset()
    wandb_logger = WandbLogger(project='CGAN-Wandb')
    trainer = Trainer(gpus=1, logger=wandb_logger, max_epochs=50)

    trainer.fit(model, dataset)
Example #13
def main(config, resume: bool):

    model = FOTSModel(config)
    if resume:
        assert pathlib.Path(config.pretrain).exists()
        resume_ckpt = config.pretrain
        logger.info('Resume training from: {}'.format(config.pretrain))
    else:
        if config.pretrain:
            assert pathlib.Path(config.pretrain).exists()
            logger.info('Finetune with: {}'.format(config.pretrain))
            model = model.load_from_checkpoint(config.pretrain, config=config, map_location='cpu')
            resume_ckpt = None
        else:
            resume_ckpt = None

    if config.data_loader.dataset == 'synth800k':
        data_module = SynthTextDataModule(config)
    else:
        data_module = ICDARDataModule(config)
    data_module.setup()

    root_dir = str(pathlib.Path(config.trainer.save_dir).absolute() / config.name)
    checkpoint_callback = ModelCheckpoint(
        dirpath=root_dir + '/checkpoints',
        every_n_train_steps=config.trainer.every_n_train_steps)
    wandb_dir = pathlib.Path(root_dir) / 'wandb'
    if not wandb_dir.exists():
        wandb_dir.mkdir(parents=True, exist_ok=True)
    wandb_logger = WandbLogger(name=config.name,
                               project='FOTS',
                               config=config,
                               save_dir=root_dir)
    if not config.cuda:
        gpus = 0
    else:
        gpus = config.gpus

    trainer = Trainer(
        logger=wandb_logger,
        callbacks=[checkpoint_callback],
        max_epochs=config.trainer.epochs,
        default_root_dir=root_dir,
        gpus=gpus,
        accelerator='ddp',
        benchmark=True,
        sync_batchnorm=True,
        precision=config.precision,
        log_gpu_memory=config.trainer.log_gpu_memory,
        log_every_n_steps=config.trainer.log_every_n_steps,
        overfit_batches=config.trainer.overfit_batches,
        weights_summary='full',
        terminate_on_nan=config.trainer.terminate_on_nan,
        fast_dev_run=config.trainer.fast_dev_run,
        check_val_every_n_epoch=config.trainer.check_val_every_n_epoch,
        resume_from_checkpoint=resume_ckpt)
    trainer.fit(model=model, datamodule=data_module)
def test_logged_metrics_has_logged_epoch_value(tmpdir):
    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            self.log("epoch", -batch_idx, logger=True)
            return super().training_step(batch, batch_idx)

    model = TestModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
    trainer.fit(model)

    # should not get overridden if logged manually
    assert trainer.logged_metrics == {"epoch": -1}
def test_logging_to_progress_bar_with_reserved_key(tmpdir):
    """Test that logging a metric with a reserved name to the progress bar raises a warning."""

    class TestModel(BoringModel):
        def training_step(self, *args, **kwargs):
            output = super().training_step(*args, **kwargs)
            self.log("loss", output["loss"], prog_bar=True)
            return output

    model = TestModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"):
        trainer.fit(model)
Example #16
def test_val_check_interval_info_message(caplog, value):
    with caplog.at_level(logging.INFO):
        Trainer(val_check_interval=value)
    assert f"`Trainer(val_check_interval={value})` was configured" in caplog.text
    message = "configured so validation will run"
    assert message in caplog.text

    caplog.clear()

    # the message should not appear by default
    with caplog.at_level(logging.INFO):
        Trainer()
    assert message not in caplog.text
def test_can_return_tensor_with_more_than_one_element(tmpdir):
    """Ensure {validation,test}_step return values are not included as callback metrics.

    #6623
    """
    class TestModel(BoringModel):
        def validation_step(self, batch, *args, **kwargs):
            return {"val": torch.tensor([0, 1])}

        def validation_epoch_end(self, outputs):
            # ensure validation step returns still appear here
            assert len(outputs) == 2
            assert all(list(d) == ["val"] for d in outputs)  # check keys
            assert all(
                torch.equal(d["val"], torch.tensor([0, 1]))
                for d in outputs)  # check values

        def test_step(self, batch, *args, **kwargs):
            return {"test": torch.tensor([0, 1])}

        def test_epoch_end(self, outputs):
            assert len(outputs) == 2
            assert all(list(d) == ["test"] for d in outputs)  # check keys
            assert all(
                torch.equal(d["test"], torch.tensor([0, 1]))
                for d in outputs)  # check values

    model = TestModel()
    trainer = Trainer(default_root_dir=tmpdir,
                      fast_dev_run=2,
                      enable_progress_bar=False)
    trainer.fit(model)
    trainer.validate(model)
    trainer.test(model)
Example #18
def run(config: argparse.Namespace):
    logger = TensorBoardLogger(
        save_dir=os.path.join(config.metaconf["ws_path"], 'tensorboard_logs'),
        name=config.metaconf["experiment_name"],
    )
    trainer = Trainer(
        gpus=config.metaconf["ngpus"],
        distributed_backend="dp",
        max_epochs=config.hyperparams["max_epochs"],
        logger=logger,
        # truncated_bptt_steps=10
    )

    # Start training
    model = CRLSModel(config)
    trainer.fit(model)
Example #19
def main(hparams):
    # ------------------------
    # 1 INIT LIGHTNING MODEL
    # ------------------------
    model = GAN(hparams)

    # ------------------------
    # 2 INIT TRAINER
    # ------------------------
    trainer = Trainer()

    # ------------------------
    # 3 START TRAINING
    # ------------------------
    trainer.fit(model)
def main(args: Namespace) -> None:
    # ------------------------
    # 1 INIT LIGHTNING MODEL
    # ------------------------
    model = GAN(**vars(args))

    # ------------------------
    # 2 INIT TRAINER
    # ------------------------
    # If using distributed training, PyTorch recommends DistributedDataParallel.
    # See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
    trainer = Trainer()

    # ------------------------
    # 3 START TRAINING
    # ------------------------
    trainer.fit(model)
Example #21
def run(config: argparse.Namespace):
    logger = TensorBoardLogger(
        save_dir=os.path.join(config.metaconf["ws_path"], "tensorboard_logs"),
        name=config.metaconf["experiment_name"],
    )
    trainer = Trainer(
        gpus=config.metaconf["ngpus"],
        distributed_backend="dp",
        max_epochs=config.metaconf["max_epochs"],
        logger=logger,
        # truncated_bptt_steps=10
    )

    # Start training
    vae_model = get_vae_model(config.model_params["vae_model"])(**config.model_params)
    model = VAEExperiment(vae_model, config)
    trainer.fit(model)
Example #22
def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
    """Make sure DDP works with dataloaders passed to fit()"""
    tutils.set_random_main_port()

    model = BoringModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=1,
        limit_train_batches=0.2,
        limit_val_batches=0.2,
        gpus=[0, 1],
        strategy="ddp_spawn",
    )
    trainer.fit(model, train_dataloaders=model.train_dataloader(), val_dataloaders=model.val_dataloader())
    assert trainer.state.finished, "DDP doesn't work with dataloaders passed to fit()."
Example #23
def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
    """Make sure DDP works with dataloaders passed to fit()"""
    tutils.set_random_master_port()

    model = EvalModelTemplate()
    fit_options = dict(train_dataloader=model.train_dataloader(),
                       val_dataloaders=model.val_dataloader())

    trainer = Trainer(default_root_dir=tmpdir,
                      progress_bar_refresh_rate=0,
                      max_epochs=1,
                      limit_train_batches=0.2,
                      limit_val_batches=0.2,
                      gpus=[0, 1],
                      distributed_backend='ddp_spawn')
    result = trainer.fit(model, **fit_options)
    assert result == 1, "DDP doesn't work with dataloaders passed to fit()."
    def handle_input(self, data_type: Type[Any]) -> Trainer:
        """Reads and returns a PyTorch Lightning trainer.

        Returns:
            A PyTorch Lightning trainer object.
        """
        super().handle_input(data_type)
        return Trainer(resume_from_checkpoint=os.path.join(
            self.artifact.uri, CHECKPOINT_NAME))
def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
    """Make sure DDP works with dataloaders passed to fit()"""
    tutils.set_random_master_port()

    model = BoringModel()
    fit_options = dict(train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader())

    trainer = Trainer(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=1,
        limit_train_batches=0.2,
        limit_val_batches=0.2,
        gpus=[0, 1],
        accelerator="ddp_spawn",
    )
    trainer.fit(model, **fit_options)
    assert trainer.state.finished, "DDP doesn't work with dataloaders passed to fit()."
Example #26
def train_self_supervised():
    logger = TensorBoardLogger('runs', name='SimCLR_libri_speech')

    # 8, 224, 8 worked well
    # 16, 224, 4 as well
    batch_size = 16
    input_height = 224
    num_workers = 4

    train_dataset = LibrispeechSpectrogramDataset(
        transform=SimCLRTrainDataTransform(input_height=input_height,
                                           gaussian_blur=False),
        train=True)
    val_dataset = LibrispeechSpectrogramDataset(
        transform=SimCLREvalDataTransform(input_height=input_height,
                                          gaussian_blur=False),
        train=False)

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              num_workers=num_workers)
    test_loader = DataLoader(val_dataset,
                             batch_size=batch_size,
                             num_workers=num_workers)

    model = SimCLR(gpus=1,
                   num_samples=len(train_dataset),
                   batch_size=batch_size,
                   dataset='librispeech')

    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=r'D:\Users\lVavrek\research\data',
        filename="self-supervised-librispeech-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
    )

    early_stopping = EarlyStopping(monitor="val_loss")

    trainer = Trainer(gpus=1,
                      callbacks=[checkpoint_callback, early_stopping],
                      logger=logger)
    trainer.fit(model, train_loader, test_loader)
Example #27
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", choices=["srcnn", "srgan"], required=True)
    parser.add_argument("--scale_factor", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--patch_size", type=int, default=96)
    parser.add_argument("--gpus", type=str, default="0")
    opt = parser.parse_args()

    # load model class
    if opt.model == "srcnn":
        model = models.SRCNNModel
    elif opt.model == "srgan":
        model = models.SRGANModel
    else:
        raise RuntimeError(opt.model)

    # add model specific arguments to original parser
    parser = model.add_model_specific_args(parser)
    opt = parser.parse_args()

    # instantiate experiment
    exp = test_tube.TestTubeLogger(save_dir=f"./logs/{opt.model}")
    exp.experiment.argparse(opt)

    model = model(opt)

    # define callbacks
    checkpoint_callback = ModelCheckpoint(
        filepath=exp.experiment.get_media_path(exp.name, exp.version),
    )

    # instantiate trainer
    trainer = Trainer(
        logger=exp,
        max_nb_epochs=4000,
        row_log_interval=50,
        check_val_every_n_epoch=10,
        checkpoint_callback=checkpoint_callback,
        gpus=[int(i)
              for i in opt.gpus.split(",")] if opt.gpus != "-1" else None,
    )

    # start training!
    trainer.fit(model)
Example #28
def test_lightning_integration(tmp_dir):
    # init model
    model = LitMNIST()
    # init logger
    dvclive_logger = DvcLiveLogger("test_run", path="logs")
    trainer = Trainer(logger=dvclive_logger,
                      max_epochs=1,
                      checkpoint_callback=False)
    trainer.fit(model)

    assert os.path.exists("logs")
    assert not os.path.exists("DvcLiveLogger")

    logs, _ = read_logs(tmp_dir / "logs" / Scalar.subfolder)

    assert len(logs) == 3
    assert "train_loss_step" in logs
    assert "train_loss_epoch" in logs
    assert "epoch" in logs
def test_val_check_interval(tmpdir, max_epochs, denominator):
    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.train_epoch_calls = 0
            self.val_epoch_calls = 0

        def on_train_epoch_start(self) -> None:
            self.train_epoch_calls += 1

        def on_validation_epoch_start(self) -> None:
            if not self.trainer.sanity_checking:
                self.val_epoch_calls += 1

    model = TestModel()
    trainer = Trainer(max_epochs=max_epochs, val_check_interval=1 / denominator, logger=False)
    trainer.fit(model)

    assert model.train_epoch_calls == max_epochs
    assert model.val_epoch_calls == max_epochs * denominator
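As with the earlier excerpts, max_epochs and denominator are expected to come from a pytest parametrize decorator that is not shown; a sketch under that assumption, with illustrative value pairs. With val_check_interval=1/denominator, validation starts denominator times per training epoch, which is what the final assertion counts.

import pytest

# Assumed parametrization; the original value pairs are not part of the excerpt.
@pytest.mark.parametrize("max_epochs, denominator", [(1, 2), (2, 4)])
def test_val_check_interval(tmpdir, max_epochs, denominator):
    ...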
Example #30
def train(hparams):
    NUM_GPUS = hparams.num_gpus
    USE_AMP = False  # True if NUM_GPUS > 1 else False
    MAX_EPOCHS = 50

    dataset = load_link_dataset(hparams.dataset, hparams=hparams)
    hparams.n_classes = dataset.n_classes

    model = LATTELinkPredictor(hparams,
                               dataset,
                               collate_fn="triples_batch",
                               metrics=[hparams.dataset])
    wandb_logger = WandbLogger(name=model.name(),
                               tags=[dataset.name()],
                               project="multiplex-comparison")

    trainer = Trainer(
        gpus=NUM_GPUS,
        distributed_backend='ddp' if NUM_GPUS > 1 else None,
        auto_lr_find=False,
        max_epochs=MAX_EPOCHS,
        early_stop_callback=EarlyStopping(monitor='val_loss',
                                          patience=10,
                                          min_delta=0.01,
                                          strict=False),
        logger=wandb_logger,
        # regularizers=regularizers,
        weights_summary='top',
        amp_level='O1' if USE_AMP else None,
        precision=16 if USE_AMP else 32)

    trainer.fit(model)
    trainer.test(model)