Esempio n. 1
0
def main(args):
    """Train the prior selected by ``args.use_model`` on an LMDB datamodule.

    Args:
        args: Parsed argument namespace. This function reads ``use_model``,
            ``dataset_path``, ``level`` and ``batch_size`` and writes
            ``num_embeddings``; the whole namespace is also handed to the
            model constructor and to ``pl.Trainer.from_argparse_args``.

    Raises:
        ValueError: If ``args.use_model`` is neither ``'pixelcnn'`` nor
            ``'pixelsnail'``.
    """
    # Fail fast: without this check an unexpected name would only surface as
    # an unbound ``model`` after the datamodule had already been set up.
    if args.use_model not in ('pixelcnn', 'pixelsnail'):
        raise ValueError(
            f"unknown model {args.use_model!r}; "
            "expected 'pixelcnn' or 'pixelsnail'")

    torch.cuda.empty_cache()
    pl.trainer.seed_everything(seed=42)

    datamodule = LMDBDataModule(
        path=args.dataset_path,
        embedding_id=args.level,
        batch_size=args.batch_size,
        num_workers=5,
    )

    # setup() has to run before the model is built because the model reads
    # ``args.num_embeddings``, which is taken from the datamodule.
    datamodule.setup()
    args.num_embeddings = datamodule.num_embeddings

    if args.use_model == 'pixelcnn':
        model = PixelCNN(args)
    else:
        model = PixelSNAIL(args)

    # Keep the best checkpoint according to 'val_loss_mean' plus the latest.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(save_top_k=1,
                                                       save_last=True,
                                                       monitor='val_loss_mean')
    trainer = pl.Trainer.from_argparse_args(args,
                                            callbacks=[checkpoint_callback])
    trainer.fit(model, datamodule=datamodule)
Esempio n. 2
0
def load_model(model, checkpoint, device):
    """Load checkpoint weights into a model and return it in eval mode.

    Args:
        model: ``'vqvae'`` or ``'pixelsnail_bottom'`` to have the network
            constructed here. Any non-string value is used as is and must
            provide ``load_state_dict``, ``to`` and ``eval``.
        checkpoint: Checkpoint file name; it is joined onto the
            ``checkpoint`` directory.
        device: Device the model is moved to after loading.

    Returns:
        The model with the checkpoint weights loaded, on ``device``, in
        eval mode.

    Raises:
        ValueError: If ``model`` is an unrecognised string, or if
            ``'pixelsnail_bottom'`` is requested but the checkpoint has no
            ``'args'`` entry to read the hyper-parameters from.
    """
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    # Loading onto the CPU first keeps GPU-saved checkpoints readable on
    # CPU-only hosts; the model is moved to ``device`` below.
    ckpt = torch.load(os.path.join('checkpoint', checkpoint),
                      map_location='cpu')

    args = ckpt['args'] if 'args' in ckpt else None

    if model == 'vqvae':
        model = VQVAE()

    elif model == 'pixelsnail_bottom':
        if args is None:
            raise ValueError(
                "checkpoint has no 'args' entry; cannot build "
                "'pixelsnail_bottom'")
        model = PixelSNAIL(
            [64, 64],
            512,
            args.channel,
            5,
            4,
            args.n_res_block,
            args.n_res_channel,
            attention=False,
            dropout=args.dropout,
            n_cond_res_block=args.n_cond_res_block,
            cond_res_channel=args.n_res_channel,
        )

    elif isinstance(model, str):
        # Previously this fell through and failed on str.load_state_dict.
        raise ValueError(f"unknown model name {model!r}")

    # Checkpoints may wrap the state dict under a 'model' key.
    if 'model' in ckpt:
        ckpt = ckpt['model']

    model.load_state_dict(ckpt)
    model = model.to(device)
    model.eval()

    return model
Esempio n. 3
0
def parse_arguments():
    """Parse the training command line in two passes.

    The first pass reads only ``--use-model`` so that the matching model
    class can register its own options; the second pass parses everything.

    Returns:
        The parsed namespace, with ``use_model`` set and ``dataset_path``
        replaced by its resolved path as a string.

    Exits via ``parser.error`` when ``dataset_path`` does not exist.
    """
    parser = ArgumentParser()
    parser.add_argument("--use-model",
                        type=str,
                        choices=['pixelcnn', 'pixelsnail'],
                        default='pixelcnn')
    # First pass: unknown options are ignored here and parsed for real below.
    use_model = parser.parse_known_args()[0].use_model

    if use_model == 'pixelcnn':
        parser = PixelCNN.add_model_specific_args(parser)
    elif use_model == 'pixelsnail':
        parser = PixelSNAIL.add_model_specific_args(parser)

    parser = pl.Trainer.add_argparse_args(parser)

    parser.add_argument("dataset_path", type=Path)
    parser.add_argument("level",
                        type=int,
                        help="Which PixelCNN hierarchy level to train")
    parser.add_argument("--batch-size", type=int)

    # Project defaults for the Trainer options registered above.
    parser.set_defaults(gpus="-1",
                        distributed_backend='ddp',
                        benchmark=True,
                        num_sanity_val_steps=0,
                        precision=16,
                        log_every_n_steps=50,
                        val_check_interval=0.5,
                        flush_logs_every_n_steps=100,
                        weights_summary='full',
                        max_epochs=int(5e4))

    args = parser.parse_args()
    args.use_model = use_model

    # Validate explicitly: an ``assert`` is stripped when Python runs with -O.
    dataset_path = args.dataset_path.resolve()
    if not dataset_path.exists():
        parser.error(f"dataset_path does not exist: {dataset_path}")
    args.dataset_path = str(dataset_path)

    return args
Esempio n. 4
0
                        num_workers=4,
                        drop_last=True)

    ckpt = {}

    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt)
        args = ckpt['args']

    if args.hier == 'top':
        model = PixelSNAIL(
            [32, 32],
            512,
            args.channel,
            5,
            4,
            args.n_res_block,
            args.n_res_channel,
            dropout=args.dropout,
            n_out_res_block=args.n_out_res_block,
        )

    elif args.hier == 'bottom':
        model = PixelSNAIL(
            [64, 64],
            512,
            args.channel,
            5,
            4,
            args.n_res_block,
            args.n_res_channel,
Esempio n. 5
0
        dataset, batch_size=args.batch, shuffle=True, drop_last=True
    )

    ckpt = {}
    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt)
        args = ckpt['args']

    model = PixelSNAIL(
            [24, 24],
            args.n_embedding,
            args.channel,
            5,
            4,
            args.n_res_block,
            args.n_res_channel,
            attention=(True if args.self_attention==1 else False),
            dropout=args.dropout,
            n_out_res_block=args.n_out_res_block,
            n_cond_res_block=args.n_cond_res_block,
            cond_res_channel=args.cond_res_channel,
            n_class=25,
            cond_hidden_dim=args.cond_hidden_dim
        )

    if 'model' in ckpt:
        model.load_state_dict(ckpt['model'])

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    model = nn.DataParallel(model)
    model = model.to(device)
Esempio n. 6
0
        loader.set_description(
            (f'epoch: {epoch + 1}; loss: {loss.item():.5f}; ' f'acc: {accuracy:.5f}')
        )


class PixelTransform:
    """Callable transform that returns its input as an int64 tensor."""

    def __init__(self):
        """Stateless; there is nothing to configure."""

    def __call__(self, input):
        """Copy ``input`` into a NumPy array and return it as a long tensor."""
        pixels = np.array(input)
        tensor = torch.from_numpy(pixels)
        return tensor.to(torch.int64)


# Script entry point: train PixelSNAIL on MNIST and save a checkpoint after
# every epoch.
if __name__ == '__main__':
    import os

    device = 'cuda'
    epoch = 10

    dataset = datasets.MNIST('.', transform=PixelTransform(), download=True)
    loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    model = PixelSNAIL([28, 28], 256, 128, 5, 2, 4, 128)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Create the output directory up front so that the first torch.save does
    # not fail after a whole epoch of training.
    os.makedirs('checkpoint', exist_ok=True)

    # Use the configured epoch count instead of a second hard-coded 10.
    for i in range(epoch):
        train(i, loader, model, optimizer, device)
        torch.save(model.state_dict(), f'checkpoint/mnist_{str(i + 1).zfill(3)}.pt')
Esempio n. 7
0
            (f'epoch: {epoch + 1}; loss: {loss.item():.5f}; '
             f'acc: {accuracy:.5f}'))


class PixelTransform:
    """Transform object: converts whatever it is called with to a long tensor."""

    def __init__(self):
        """No state is kept between calls."""

    def __call__(self, input):
        """Return ``input`` as an int64 ``torch`` tensor via a NumPy copy."""
        as_array = np.array(input)
        return torch.from_numpy(as_array).to(torch.int64)


# Script entry point: train PixelSNAIL on MNIST and save a checkpoint after
# every epoch.
if __name__ == '__main__':
    import os

    device = 'cuda'
    epoch = 10

    dataset = datasets.MNIST('.', transform=PixelTransform(), download=True)
    loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    model = PixelSNAIL([28, 28], 256, 128, 5, 2, 4, 128)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # The checkpoint directory must exist before the first save; creating it
    # here avoids losing an epoch of training to a missing-directory error.
    os.makedirs('checkpoint', exist_ok=True)

    # Loop over the configured epoch count rather than a duplicated literal.
    for i in range(epoch):
        train(i, loader, model, optimizer, device)
        torch.save(model.state_dict(),
                   f'checkpoint/mnist_{str(i + 1).zfill(3)}.pt')
Esempio n. 8
0
    parser.add_argument('path', type=str)

    args = parser.parse_args()

    print(args)

    device = 'cuda'

    dataset = LMDBDataset(args.path)
    loader = DataLoader(dataset,
                        batch_size=args.batch,
                        shuffle=True,
                        num_workers=4)

    if args.hier == 'top':
        model = PixelSNAIL([32, 32], 512, 256, 5, 4, 4, 256)

    elif args.hier == 'bottom':
        model = PixelSNAIL(
            [64, 64],
            512,
            256,
            5,
            3,
            4,
            256,
            attention=False,
            n_cond_res_block=3,
            cond_res_channel=256,
        )
Esempio n. 9
0
# Script entry point: parse the command line, construct the VQ-VAE and the two
# PixelSNAIL networks, and load checkpoints into them. Only the part of the
# script visible in this excerpt is described here.
if __name__ == '__main__':
    device = 'cuda'

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8)
    # Checkpoint arguments: --vqvae and --top are passed to load_model below.
    # NOTE(review): --bottom is presumably used the same way further down;
    # confirm.
    parser.add_argument('--vqvae', type=str)
    parser.add_argument('--top', type=str)
    parser.add_argument('--bottom', type=str)
    parser.add_argument('--temp', type=float, default=1.0)
    parser.add_argument('filename', type=str)

    args = parser.parse_args()

    # The network hyper-parameters are hard-coded here rather than read from
    # the checkpoints.
    model_vqvae = VQVAE()
    model_top = PixelSNAIL([32, 32], 512, 256, 5, 4, 4, 256)
    model_bottom = PixelSNAIL(
        [64, 64],
        512,
        256,
        5,
        3,
        4,
        256,
        attention=False,
        n_cond_res_block=3,
        cond_res_channel=256,
    )

    # load_model receives the already-constructed module as its first
    # argument and its return value replaces the local reference.
    model_vqvae = load_model(model_vqvae, args.vqvae, device)
    model_top = load_model(model_top, args.top, device)
Esempio n. 10
0
def load_model(model, checkpoint, device):
    """Load checkpoint weights into a model and return it in eval mode.

    Args:
        model: ``'vqvae'``, ``'label_model'`` or ``'pixelsnail'`` to have the
            network constructed here. Any non-string value is used as is and
            must provide ``load_state_dict``, ``to`` and ``eval``.
        checkpoint: Path of the checkpoint file passed to ``torch.load``.
        device: Device the model is moved to after loading.

    Returns:
        The model with the checkpoint weights loaded, on ``device``, in
        eval mode.

    Raises:
        ValueError: If ``model`` is an unrecognised string, or if a
            PixelSNAIL variant is requested but the checkpoint has no
            ``'args'`` entry to read the hyper-parameters from.
    """
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    # Loading onto the CPU first keeps GPU-saved checkpoints readable on
    # CPU-only hosts; the model is moved to ``device`` below.
    ckpt = torch.load(checkpoint, map_location='cpu')

    args = ckpt['args'] if 'args' in ckpt else None

    if model in ('label_model', 'pixelsnail') and args is None:
        raise ValueError(
            f"checkpoint has no 'args' entry; cannot build {model!r}")

    if model == 'vqvae':
        num_hiddens = 128
        num_residual_hiddens = 64
        num_residual_layers = 2
        embedding_dim = 64
        num_embeddings = 256

        model = Model(3,
                      num_hiddens,
                      num_residual_layers,
                      num_residual_hiddens,
                      num_embeddings,
                      embedding_dim,
                      0.25,
                      embedding_prior_input=True)

    elif model == 'label_model':
        model = PixelSNAIL(
            [24, 24],
            25,
            args.channel,
            5,
            4,
            args.n_res_block,
            args.n_res_channel,
            attention=args.self_attention == 1,
            dropout=args.dropout,
            n_out_res_block=args.n_out_res_block,
            n_cond_res_block=args.n_cond_res_block,
            cond_res_channel=args.cond_res_channel,
            n_class=25,
            cond_hidden_dim=args.cond_hidden_dim)

    elif model == 'pixelsnail':
        model = PixelSNAIL(
            [48, 24],
            args.n_embedding,
            args.channel,
            5,
            4,
            args.n_res_block,
            args.n_res_channel,
            attention=args.self_attention == 1,
            dropout=args.dropout,
            n_out_res_block=args.n_out_res_block,
            n_cond_res_block=args.n_cond_res_block,
            cond_res_channel=args.cond_res_channel,
            n_class=25,
            cond_hidden_dim=args.cond_hidden_dim)

    elif isinstance(model, str):
        # Previously this fell through and failed on str.load_state_dict.
        raise ValueError(f"unknown model name {model!r}")

    # Checkpoints may wrap the state dict under a 'model' key.
    if 'model' in ckpt:
        ckpt = ckpt['model']

    model.load_state_dict(ckpt)
    model = model.to(device)
    model.eval()

    return model
Esempio n. 11
0
                        num_workers=4,
                        drop_last=True)

    ckpt = {}

    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt)
        args = ckpt['args']

    if args.hier == 'top':
        model = PixelSNAIL(
            [32, 32],
            embedDim,  # represent for the class numbers of the embed space
            args.channel,
            5,
            4,
            args.n_res_block,
            args.n_res_channel,
            dropout=args.dropout,
            n_out_res_block=args.n_out_res_block,
        )

    elif args.hier == 'bottom':
        model = PixelSNAIL(
            [64, 64],
            embedDim,  # represent for the class numbers of the embed space
            args.channel,
            5,
            4,
            args.n_res_block,
            args.n_res_channel,