Пример #1
0
    # Wrap in PLModule, & configure metrics ####################
    # Bundle the model, optimizer, loss and LR scheduler into the project's
    # Lightning wrapper. `model`, `optimizer`, `criterion`, `lr_scheduler`,
    # `config`, `ldm` and `logger` are defined earlier in the enclosing
    # function, outside this view.
    lm = RnnPLModule(
        model,
        optimizer,
        criterion,
        lr_scheduler=lr_scheduler,
        # Accuracy is wrapped in FromLogits, presumably so raw model outputs
        # are converted to class predictions before scoring — TODO confirm.
        # NOTE(review): `pl.metrics` was split out into the `torchmetrics`
        # package and later removed from Lightning; this only works on
        # older Lightning releases.
        metrics={"acc": FromLogits(pl.metrics.classification.Accuracy())},
        # presumably stored/logged as the run's hyperparameters — verify in
        # RnnPLModule.
        hparams=config,
    )

    # Run debugging session or fit & test the model ############

    if config.debug:
        # Debug mode: two short smoke-test fits and no test-set evaluation.
        logger.info("Running in debug mode: Fast run on 5 batches")
        # fast_dev_run=5 (assuming make_trainer forwards it to pl.Trainer)
        # limits the run to 5 batches as a quick end-to-end check.
        trainer = make_trainer(fast_dev_run=5)
        trainer.fit(lm, datamodule=ldm)

        logger.info("Running in debug mode: Overfitting 5 batches")
        # overfit_batches=5 trains repeatedly on the same 5 batches, a sanity
        # check that the model can fit a tiny subset.
        # NOTE(review): the same `lm` instance (with its current weights) is
        # passed to both fit calls; nothing here re-initialises it between
        # the two runs.
        trainer = make_trainer(overfit_batches=5)
        trainer.fit(lm, datamodule=ldm)

    else:
        # Full run: trainer options come from the config (assumed to be a
        # mapping of Trainer keyword arguments — TODO confirm).
        trainer = make_trainer(**config.trainer)
        # presumably registers the model with the experiment logger (e.g.
        # for gradient/parameter tracking) — verify against watch_model.
        watch_model(trainer, model)

        trainer.fit(lm, datamodule=ldm)

        # Evaluate on the test split using the "best" checkpoint. This relies
        # on the trainer having a checkpoint callback that tracks a best
        # model (assumed to be configured by make_trainer — TODO confirm).
        # NOTE(review): newer Lightning versions renamed the
        # `test_dataloaders` keyword to `dataloaders`.
        trainer.test(ckpt_path="best", test_dataloaders=ldm.test_dataloader())

        # NOTE(review): no upload is performed in the visible code;
        # presumably it happens after this point.
        logger.info("Run finished. Uploading files to wandb...")
    # NOTE(review): this section re-binds `model`, `optimizer`, `criterion`,
    # `lm` and `trainer` immediately after the RNN run above. It reads like a
    # second, separate example (a Transformer language model) pasted into
    # the same body — confirm this is intended.
    model = TransformerLM(
        vocab_size=ldm.vocab_size,
        num_layers=2,
        hidden_size=200,
        num_heads=2,
        inner_size=256,
        dropout=0.2,
        # presumably ties the input embedding and output projection
        # weights — verify in TransformerLM.
        tie_weights=True,
    )

    # Only parameters with requires_grad=True are given to Adam, so frozen
    # parameters are excluded. `lr` is defined outside this view.
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=lr)
    criterion = nn.CrossEntropyLoss()

    # calculate_perplexity presumably makes the module report perplexity
    # alongside the loss — TODO confirm in TransformerPLModule.
    lm = TransformerPLModule(
        model,
        optimizer,
        criterion,
        calculate_perplexity=True,
    )

    # Hard-coded run settings: 100 epochs, gpus=1 (so a GPU is presumably
    # required, assuming make_trainer forwards these to pl.Trainer) and a
    # gradient-clipping threshold of 0.25. EXPERIMENT_NAME is defined
    # outside this view.
    trainer = make_trainer(EXPERIMENT_NAME,
                           max_epochs=100,
                           gpus=1,
                           gradient_clip_val=0.25)
    # presumably registers the model with the experiment logger — verify
    # against watch_model.
    watch_model(trainer, model)

    trainer.fit(lm, datamodule=ldm)

    # Evaluate on the test split using the "best" checkpoint; this requires
    # a checkpoint callback that tracks a best model (TODO confirm it is set
    # up by make_trainer).
    # NOTE(review): newer Lightning versions renamed the `test_dataloaders`
    # keyword to `dataloaders`.
    trainer.test(ckpt_path="best", test_dataloaders=ldm.test_dataloader())