コード例 #1
0
def visualize_spectrograms():
    """Log a 6x4 grid of Librispeech spectrograms to TensorBoard.

    Converts the first 24 training items to tensors, assembles them into
    an image grid under the tag "generated_images", and finalizes the
    logger so the event file is flushed and closed.
    """
    logger = TensorBoardLogger('runs', name='SimCLR_libri_speech')

    td = LibrispeechSpectrogramDataset(transform=None, train=True)

    to_tensor = transforms.ToTensor()
    # Index the dataset directly (td[i]) rather than calling __getitem__;
    # the class label returned alongside each image is not needed here.
    samples = [to_tensor(td[i][0]) for i in range(24)]

    grid = torchvision.utils.make_grid(samples, padding=10, nrow=4)

    logger.experiment.add_image("generated_images", grid, 0)
    logger.finalize("success")
コード例 #2
0
ファイル: lighting.py プロジェクト: sssilvar/VAE
def train_vae(model, min_epochs, max_epochs, train_loader, val_loader=None, logger=True, logger_path='model_logs'):
    """Train a VAE LightningModule with early stopping on the loss.

    Parameters
    ----------
    model : LightningModule to train.
    min_epochs, max_epochs : epoch bounds passed to the Trainer.
    train_loader : training DataLoader.
    val_loader : optional validation DataLoader; when given, early stopping
        monitors 'Validation/Loss' instead of 'Train/Loss'.
    logger : True to create a TensorBoardLogger at `logger_path`, False to
        disable logging, or a pre-built Lightning logger instance.
    logger_path : save directory used when `logger is True`.

    Returns
    -------
    The trained `model`.
    """
    import torch
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping
    from pytorch_lightning.loggers import TensorBoardLogger

    n_gpus = torch.cuda.device_count()
    # Automatic GPU selection only makes sense with more than one device.
    auto_select_gpus = n_gpus >= 2

    loss_monitor_key = 'Validation/Loss' if val_loader is not None else 'Train/Loss'
    early_stop_loss = EarlyStopping(loss_monitor_key, patience=5, mode='min')

    # Only build a default logger when the caller asked for one with True.
    # The previous `if logger:` test also matched a caller-supplied logger
    # instance (truthy) and silently replaced it with a fresh one.
    if logger is True:
        logger = TensorBoardLogger(logger_path, name='VAE')
    trainer = Trainer(min_epochs=min_epochs, max_epochs=max_epochs, logger=logger,
                      callbacks=[early_stop_loss],
                      gpus=n_gpus, auto_select_gpus=auto_select_gpus)
    trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)

    # Finalize only an actual logger object (skip the False/None cases).
    if logger is not None and logger is not False:
        logger.finalize('success')
    return model
コード例 #3
0
def test_tensorboard_finalize(summary_writer, tmpdir):
    """finalize() must flush and close the underlying SummaryWriter."""
    tb_logger = TensorBoardLogger(save_dir=tmpdir)
    tb_logger.finalize("any")

    writer = summary_writer()
    writer.flush.assert_called()
    writer.close.assert_called()
コード例 #4
0
    def objective(trial):
        """Optuna objective: train a PreTrainSystem with trial-sampled
        hyperparameters, evaluate the best checkpoint on the test set,
        upload that checkpoint to S3, and return the test loss (the value
        Optuna minimizes).
        """

        # Give each run a unique TensorBoard version unless one was set
        # explicitly on the CLI hparams.
        if hparams.version is None:
            hparams.version = str(uuid1())

        # main LightningModule; learning rate and weight decay are the two
        # hyperparameters sampled per trial.
        pretrain_system = PreTrainSystem(
            learning_rate=trial.suggest_loguniform("learning_rate", 1e-5, 1e-2),
            beta_1=hparams.beta_1,
            beta_2=hparams.beta_2,
            weight_decay=trial.suggest_uniform("weight_decay", 1e-5, 1e-2),
            optimizer=hparams.optimizer,
            batch_size=hparams.batch_size,
            multiplier=hparams.multiplier,
            scheduler_patience=hparams.scheduler_patience,
        )

        # Keep the top-k checkpoints by epoch-level validation loss.
        pretrain_checkpoints = ModelCheckpoint(
            dirpath=MODEL_CHECKPOINTS_DIR,
            monitor="Val/loss_epoch",
            verbose=True,
            mode="min",
            save_top_k=hparams.save_top_k,
        )

        pretrain_early_stopping = EarlyStopping(
            monitor="Val/loss_epoch",
            min_delta=0.00,
            patience=hparams.patience,
            verbose=False,
            mode="min",
        )

        pretrain_gpu_stats_monitor = GPUStatsMonitor(temperature=True)

        log_recoloring_to_tensorboard = LogPairRecoloringToTensorboard()

        # Lets Optuna prune unpromising trials based on the same metric.
        optuna_pruning = PyTorchLightningPruningCallback(monitor="Val/loss_epoch", trial=trial)

        # default_hp_metric=False: hyperparameters are logged manually with
        # the final test metric after training (see below).
        logger = TensorBoardLogger(
            S3_LIGHTNING_LOGS_DIR,
            name=hparams.name,
            version=hparams.version,
            log_graph=True,
            default_hp_metric=False,
        )

        trainer = Trainer.from_argparse_args(
            hparams,
            logger=logger,
            checkpoint_callback=pretrain_checkpoints,
            callbacks=[
                pretrain_early_stopping,
                log_recoloring_to_tensorboard,
                pretrain_gpu_stats_monitor,
                optuna_pruning,
            ],
            profiler="simple",
        )

        datamodule = PreTrainDataModule(
            batch_size=pretrain_system.hparams.batch_size,
            multiplier=pretrain_system.hparams.multiplier,
            shuffle=hparams.shuffle,
            num_workers=hparams.num_workers,
            size=hparams.size,
            pin_memory=hparams.pin_memory,
            train_batch_from_same_image=hparams.train_batch_from_same_image,
            val_batch_from_same_image=hparams.val_batch_from_same_image,
            test_batch_from_same_image=hparams.test_batch_from_same_image,
        )

        # trainer.tune(pretrain_system, datamodule=datamodule)

        trainer.fit(pretrain_system, datamodule=datamodule)

        # get best checkpoint
        best_model_path = pretrain_checkpoints.best_model_path

        # Reload the best weights so the test run does not use whatever
        # state the model ended training with.
        pretrain_system = PreTrainSystem.load_from_checkpoint(best_model_path)

        test_result = trainer.test(pretrain_system, datamodule=datamodule)

        # Attach the test metric to hparams so it shows up in the
        # TensorBoard HPARAMS tab, then flush/close the logger.
        pretrain_system.hparams.test_metric_name = test_result[0]["Test/loss_epoch"]
        logger.log_hyperparams(pretrain_system.hparams)
        logger.finalize(status="success")

        # upload best model to S3; key is <name>/<version>.<checkpoint-ext>
        S3_best_model_path = os.path.join(
            S3_MODEL_CHECKPOINTS_RELATIVE_DIR,
            hparams.name,
            ".".join([hparams.version, best_model_path.split(".")[-1]]),
        )
        upload_to_s3(best_model_path, S3_best_model_path)

        return test_result[0]["Test/loss_epoch"]
コード例 #5
0
# Checkpoint on the (unregularized) SWA loss.
checkpointer = ModelCheckpoint(filepath=checkpoint_filename + '.ckpt',
                               monitor='swa_loss_no_reg')

trainer = Trainer(gpus=1,
                  num_nodes=1,
                  max_epochs=epochs,
                  logger=logger,
                  callbacks=[lr_logger],
                  checkpoint_callback=checkpointer,
                  benchmark=True,
                  terminate_on_nan=True,
                  gradient_clip_val=max_l2_norm)

# terminate_on_nan / bad configs surface as ValueError; treat that as a
# failed run and exit non-zero so the calling job notices.
try:
    trainer.fit(swag_model)
except ValueError:
    print("Model", checkpoint_filename, 'exited early!', flush=True)
    exit(1)

# Save model:

# Record the best monitored score alongside the hyperparameters, then
# flush and close the TensorBoard logger.
logger.log_hyperparams(
    params=swag_model.hparams,
    metrics={'swa_loss_no_reg': checkpointer.best_model_score.item()})
logger.save()
logger.finalize('success')

spock_reg_model.save_swag(swag_model, output_filename + '.pkl')

import pickle as pkl

# Use a context manager so the file handle is closed even if dump raises;
# the original `pkl.dump(..., open(...))` never closed the file.
with open(output_filename + '_ssX.pkl', 'wb') as f:
    pkl.dump(swag_model.ssX, f)