Ejemplo n.º 1
0
def reproduce(
    n_epochs=457, batch_size=64, log_dir="/tmp/run", device="cuda", debug_loader=None
):
    """Train an ImageGPT on MNIST with the default hyperparameters.

    All imports and objects are created inside this function, so the body can
    be lifted as-is into a standalone script or a Jupyter notebook cell.

    Args:
        n_epochs: Number of epochs passed to `interleaved_train_and_eval`.
        batch_size: Batch size for the MNIST training and evaluation loaders.
        log_dir: Directory handed to the Trainer for its state and TensorBoard
            summaries.
        device: Device string handed to the Trainer ('cuda' or 'cpu').
        debug_loader: If not None, used as BOTH the training and evaluation
            loader instead of MNIST. Only meant for unit tests.
    """
    from torch import optim
    from torch.nn import functional as F
    from torch.optim import lr_scheduler

    from pytorch_generative import datasets
    from pytorch_generative import models
    from pytorch_generative import trainer

    if debug_loader is not None:
        # Unit tests inject a single loader that serves both roles.
        train_loader = test_loader = debug_loader
    else:
        train_loader, test_loader = datasets.get_mnist_loaders(
            batch_size, dynamically_binarize=True
        )

    model = models.ImageGPT(
        in_channels=1,
        out_channels=1,
        in_size=28,
        n_transformer_blocks=8,
        n_attention_heads=2,
        n_embedding_channels=64,
    )
    optimizer = optim.Adam(model.parameters(), lr=5e-3)

    # MultiplicativeLR multiplies the learning rate by this constant factor on
    # every scheduler step; when steps happen is decided by the Trainer.
    lr_decay = 0.999977
    scheduler = lr_scheduler.MultiplicativeLR(
        optimizer, lr_lambda=lambda _epoch: lr_decay
    )

    def loss_fn(x, _, preds):
        # Flatten every example to one row. Take the elementwise BCE between the
        # logits and the targets, sum it per example, then average over the batch.
        n_examples = x.shape[0]
        targets = x.view((n_examples, -1))
        logits = preds.view((n_examples, -1))
        per_pixel = F.binary_cross_entropy_with_logits(
            logits, targets, reduction="none"
        )
        per_example = per_pixel.sum(dim=1)
        return per_example.mean()

    model_trainer = trainer.Trainer(
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        loss_fn=loss_fn,
        train_loader=train_loader,
        eval_loader=test_loader,
        device=device,
        log_dir=log_dir,
    )
    model_trainer.interleaved_train_and_eval(n_epochs)
Ejemplo n.º 2
0
 def test_ImageGPT(self):
     """Build a tiny 3-channel ImageGPT for 5x5 inputs and pass it to `_smoke_test`."""
     channels = 3
     tiny_model = models.ImageGPT(
         in_channels=channels,
         out_channels=channels,
         in_size=5,
         n_transformer_blocks=1,
         n_attention_heads=2,
         n_embedding_channels=4,
     )
     self._smoke_test(tiny_model, in_channels=channels)
Ejemplo n.º 3
0
 def test_ImageGPT(self):
   """Test ImageGPT using MaskedAttention.

   Builds a tiny single-channel ImageGPT for 5x5 inputs and passes it to
   `_smoke_test`.
   """
   hyperparams = dict(
       in_channels=1,
       out_channels=1,
       in_size=5,
       n_transformer_blocks=1,
       n_attention_heads=2,
       n_embedding_channels=4,
   )
   tiny_model = models.ImageGPT(**hyperparams)
   self._smoke_test(tiny_model)
Ejemplo n.º 4
0
 def test_ImageGPT(self):
     """Check a 3-channel ImageGPT for 8x8 inputs via `_test_multiple_channels`.

     Conditional sampling is enabled for this check.
     """
     three_channel_model = models.ImageGPT(
         n_embedding_channels=4,
         n_attention_heads=2,
         n_transformer_blocks=1,
         in_size=8,
         out_channels=3,
         in_channels=3,
     )
     self._test_multiple_channels(three_channel_model, conditional_sample=True)