コード例 #1
0
ファイル: sample.py プロジェクト: kassellevi1/DeepProject
def load_model(model, checkpoint, device):
    """Instantiate the requested architecture and load checkpoint weights.

    Args:
        model: architecture name, 'vqvae' or 'pixelsnail_bottom'. The name is
            replaced by the instantiated module.
        checkpoint: checkpoint filename, resolved under ./checkpoint. May be a
            dict with 'args'/'model' keys or a raw state dict.
        device: device the loaded model is moved to.

    Returns:
        The model in eval mode on `device`.

    Raises:
        ValueError: if `model` names an unknown architecture.
    """
    # map_location lets checkpoints saved on GPU load on CPU-only hosts.
    ckpt = torch.load(os.path.join('checkpoint', checkpoint), map_location=device)

    # Training-time hyperparameters are bundled in the checkpoint for the
    # prior model; absent for raw state dicts (only needed by the
    # pixelsnail branch below).
    args = ckpt['args'] if 'args' in ckpt else None

    if model == 'vqvae':
        model = VQVAE()

    elif model == 'pixelsnail_bottom':
        model = PixelSNAIL(
            [64, 64],
            512,
            args.channel,
            5,  # positional arg — TODO confirm meaning in PixelSNAIL signature
            4,  # positional arg — TODO confirm
            args.n_res_block,
            args.n_res_channel,
            attention=False,
            dropout=args.dropout,
            n_cond_res_block=args.n_cond_res_block,
            cond_res_channel=args.n_res_channel,
        )

    else:
        # Bug fix: previously an unknown name fell through and
        # load_state_dict() was called on the string itself.
        raise ValueError(f'unknown model type: {model!r}')

    if 'model' in ckpt:
        ckpt = ckpt['model']

    model.load_state_dict(ckpt)
    model = model.to(device)
    model.eval()

    return model
コード例 #2
0
def main(args):
    """Train an autoregressive prior (PixelCNN or PixelSNAIL) over LMDB codes.

    Args:
        args: parsed CLI namespace; expects dataset_path, level, batch_size,
            use_model, plus PyTorch Lightning trainer flags.

    Raises:
        ValueError: if args.use_model is not 'pixelcnn' or 'pixelsnail'.
    """
    torch.cuda.empty_cache()
    # Fixed seed for reproducible runs.
    pl.trainer.seed_everything(seed=42)

    datamodule = LMDBDataModule(
        path=args.dataset_path,
        embedding_id=args.level,
        batch_size=args.batch_size,
        num_workers=5,
    )

    datamodule.setup()
    # The codebook size is only known after the datamodule has opened the
    # LMDB, so propagate it into args for the model constructors.
    args.num_embeddings = datamodule.num_embeddings

    if args.use_model == 'pixelcnn':
        model = PixelCNN(args)
    elif args.use_model == 'pixelsnail':
        model = PixelSNAIL(args)
    else:
        # Bug fix: previously an unrecognized name left `model` unbound and
        # crashed later with NameError at trainer.fit; fail fast instead.
        raise ValueError(f'unknown model: {args.use_model!r}')

    checkpoint_callback = pl.callbacks.ModelCheckpoint(save_top_k=1,
                                                       save_last=True,
                                                       monitor='val_loss_mean')
    trainer = pl.Trainer.from_argparse_args(args,
                                            callbacks=[checkpoint_callback])
    trainer.fit(model, datamodule=datamodule)
コード例 #3
0
                        num_workers=4,
                        drop_last=True)

    # Empty-dict sentinel: the "'model' in ckpt" style lookup elsewhere is a
    # no-op when no checkpoint path was given.
    ckpt = {}

    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt)
        # NOTE(review): args is replaced wholesale by the checkpointed args,
        # discarding any CLI flags passed for this run — confirm intended.
        args = ckpt['args']

    if args.hier == 'top':
        # Top-level prior over 32x32 latent codes, 512 codebook entries.
        model = PixelSNAIL(
            [32, 32],
            512,
            args.channel,
            5,  # positional arg — TODO confirm meaning in PixelSNAIL signature
            4,  # positional arg — TODO confirm
            args.n_res_block,
            args.n_res_channel,
            dropout=args.dropout,
            n_out_res_block=args.n_out_res_block,
        )

    elif args.hier == 'bottom':
        # Bottom-level prior over 64x64 latent codes (call continues beyond
        # this excerpt).
        model = PixelSNAIL(
            [64, 64],
            512,
            args.channel,
            5,
            4,
            args.n_res_block,
            args.n_res_channel,
コード例 #4
0
        dataset, batch_size=args.batch, shuffle=True, drop_last=True
    )

    # Empty-dict sentinel so the "'model' in ckpt" check below only fires
    # when a checkpoint was actually loaded.
    ckpt = {}
    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt)
        # NOTE(review): args is replaced wholesale by the checkpointed args,
        # discarding any CLI flags passed for this run — confirm intended.
        args = ckpt['args']

    model = PixelSNAIL(
            [24, 24],  # latent map size
            args.n_embedding,
            args.channel,
            5,  # positional arg — TODO confirm meaning in PixelSNAIL signature
            4,  # positional arg — TODO confirm
            args.n_res_block,
            args.n_res_channel,
            attention=(True if args.self_attention==1 else False),
            dropout=args.dropout,
            n_out_res_block=args.n_out_res_block,
            n_cond_res_block=args.n_cond_res_block,
            cond_res_channel=args.cond_res_channel,
            n_class=25,  # hard-coded class count — TODO confirm where 25 comes from
            cond_hidden_dim=args.cond_hidden_dim
        )

    if 'model' in ckpt:
        model.load_state_dict(ckpt['model'])

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Wrap for multi-GPU data parallelism before moving to the target device.
    model = nn.DataParallel(model)
    model = model.to(device)
コード例 #5
0
        loader.set_description(
            # Progress-bar text: current epoch (1-based), batch loss, accuracy.
            (f'epoch: {epoch + 1}; loss: {loss.item():.5f}; ' f'acc: {accuracy:.5f}')
        )


class PixelTransform:
    """Torchvision-style transform: PIL image / array-like -> int64 tensor.

    Keeps raw pixel values (0-255 for MNIST) as discrete class indices
    rather than normalized floats, as required by the autoregressive
    pixel model's cross-entropy targets.
    """

    # No state needed, so the no-op __init__ from the original was dropped;
    # PixelTransform() still constructs identically.

    def __call__(self, input):
        # np.array handles both PIL images and nested sequences.
        ar = np.array(input)

        # .long() -> int64, the dtype expected for classification targets.
        return torch.from_numpy(ar).long()


if __name__ == '__main__':
    import os

    device = 'cuda'
    epoch = 10  # number of training epochs

    # MNIST with raw integer pixels (PixelTransform), not normalized floats.
    dataset = datasets.MNIST('.', transform=PixelTransform(), download=True)
    loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    # 28x28 input, 256 pixel classes; remaining positional args — TODO
    # confirm meaning against the PixelSNAIL signature.
    model = PixelSNAIL([28, 28], 256, 128, 5, 2, 4, 128)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Bug fixes: the loop previously hard-coded range(10), ignoring the
    # `epoch` constant above; and torch.save failed if ./checkpoint did
    # not already exist.
    os.makedirs('checkpoint', exist_ok=True)
    for i in range(epoch):
        train(i, loader, model, optimizer, device)
        torch.save(model.state_dict(), f'checkpoint/mnist_{str(i + 1).zfill(3)}.pt')
コード例 #6
0
                        num_workers=4,
                        drop_last=True)

    # Empty-dict sentinel for the no-checkpoint case.
    ckpt = {}

    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt)
        # NOTE(review): args is replaced wholesale by the checkpointed args,
        # discarding any CLI flags passed for this run — confirm intended.
        args = ckpt['args']

    if args.hier == 'top':
        # Top-level prior; latent map is the image size downsampled 8x.
        model = PixelSNAIL(
            [args.size // 8, args.size // 8],
            512,
            args.channel,
            5,  # positional arg — TODO confirm meaning in PixelSNAIL signature
            4,  # positional arg — TODO confirm
            args.n_res_block,
            args.n_res_channel,
            dropout=args.dropout,
            n_out_res_block=args.n_out_res_block,
        )

    elif args.hier == 'bottom':
        # Bottom-level prior; latent map is the image size downsampled 4x
        # (call continues beyond this excerpt).
        model = PixelSNAIL(
            [args.size // 4, args.size // 4],
            512,
            args.channel,
            5,
            4,
            args.n_res_block,
            args.n_res_channel,
コード例 #7
0
    parser.add_argument('path', type=str)

    args = parser.parse_args()

    print(args)

    device = 'cuda'

    # LMDB store of pre-extracted latent codes — presumably written by the
    # VQ-VAE extraction step; verify against the producing script.
    dataset = LMDBDataset(args.path)
    loader = DataLoader(dataset,
                        batch_size=args.batch,
                        shuffle=True,
                        num_workers=4)

    if args.hier == 'top':
        # Top-level prior: 32x32 codes, 512 codebook entries; remaining
        # positional args — TODO confirm against the PixelSNAIL signature.
        model = PixelSNAIL([32, 32], 512, 256, 5, 4, 4, 256)

    elif args.hier == 'bottom':
        # Bottom-level prior: 64x64 codes, conditioned on the top level
        # via n_cond_res_block / cond_res_channel.
        model = PixelSNAIL(
            [64, 64],
            512,
            256,
            5,
            3,
            4,
            256,
            attention=False,
            n_cond_res_block=3,
            cond_res_channel=256,
        )
コード例 #8
0
ファイル: sample.py プロジェクト: Cynetics/MSGNet
def load_model(model, checkpoint, device):
    """Instantiate the requested architecture and load checkpoint weights.

    Args:
        model: architecture name, one of 'vqvae', 'label_model',
            'pixelsnail'. The name is replaced by the instantiated module.
        checkpoint: path to a torch checkpoint file; may be a dict with
            'args'/'model' keys or a raw state dict.
        device: device the loaded model is moved to.

    Returns:
        The model in eval mode on `device`.

    Raises:
        ValueError: if `model` names an unknown architecture.
    """
    # map_location lets checkpoints saved on GPU load on CPU-only hosts.
    ckpt = torch.load(checkpoint, map_location=device)

    # Training-time hyperparameters bundled in the checkpoint; absent for
    # raw state dicts (only needed by the PixelSNAIL branches below).
    args = ckpt['args'] if 'args' in ckpt else None

    if model == 'vqvae':
        # Fixed VQ-VAE hyperparameters for this project.
        num_hiddens = 128
        num_residual_hiddens = 64
        num_residual_layers = 2
        embedding_dim = 64
        num_embeddings = 256

        model = Model(3,
                      num_hiddens,
                      num_residual_layers,
                      num_residual_hiddens,
                      num_embeddings,
                      embedding_dim,
                      0.25,  # commitment cost — TODO confirm against Model signature
                      embedding_prior_input=True)

    elif model == 'label_model':
        # Prior over 24x24 label maps with 25 classes.
        model = PixelSNAIL(
            [24, 24],
            25,
            args.channel,
            5,  # positional arg — TODO confirm meaning in PixelSNAIL signature
            4,  # positional arg — TODO confirm
            args.n_res_block,
            args.n_res_channel,
            attention=(args.self_attention == 1),
            dropout=args.dropout,
            n_out_res_block=args.n_out_res_block,
            n_cond_res_block=args.n_cond_res_block,
            cond_res_channel=args.cond_res_channel,
            n_class=25,
            cond_hidden_dim=args.cond_hidden_dim)

    elif model == 'pixelsnail':
        # Prior over 48x24 code maps.
        model = PixelSNAIL(
            [48, 24],
            args.n_embedding,
            args.channel,
            5,
            4,
            args.n_res_block,
            args.n_res_channel,
            attention=(args.self_attention == 1),
            dropout=args.dropout,
            n_out_res_block=args.n_out_res_block,
            n_cond_res_block=args.n_cond_res_block,
            cond_res_channel=args.cond_res_channel,
            n_class=25,
            cond_hidden_dim=args.cond_hidden_dim)

    else:
        # Bug fix: previously an unknown name fell through and
        # load_state_dict() was called on the string itself. Also removed
        # the dead "#vqvae = vqvae" comment.
        raise ValueError(f'unknown model type: {model!r}')

    if 'model' in ckpt:
        ckpt = ckpt['model']

    model.load_state_dict(ckpt)
    model = model.to(device)
    model.eval()

    return model
コード例 #9
0
    elif args.img_size == 256:
        # 256px images: 8x downsampling for the top level, 4x for the bottom.
        top_input = [32, 32]
        bottom_input = [64, 64]
    else:
        raise NotImplementedError('Currently supports img_size = 32 | 256')

    # Create the output directory for generated samples if missing.
    if not os.path.exists(args.out):
        os.mkdir(args.out)

    if args.hier == 'top':
        # Top-level prior, 512 codebook entries.
        model = PixelSNAIL(
            top_input,
            512,
            args.channel,
            5,  # positional arg — TODO confirm meaning in PixelSNAIL signature
            4,  # positional arg — TODO confirm
            args.n_res_block,
            args.n_res_channel,
            dropout=args.dropout,
            n_out_res_block=args.n_out_res_block,
        )

    elif args.hier == 'bottom':
        # Bottom-level prior (call continues beyond this excerpt).
        model = PixelSNAIL(
            bottom_input,
            512,
            args.channel,
            5,
            4,
            args.n_res_block,
            args.n_res_channel,
コード例 #10
0
ファイル: train_pixelsnail.py プロジェクト: juvu/vq-vae-2
        dataset, batch_size=args.batch, shuffle=True, num_workers=4, drop_last=True
    )

    # Empty-dict sentinel for the no-checkpoint case.
    ckpt = {}

    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt)
        # NOTE(review): args is replaced wholesale by the checkpointed args,
        # discarding any CLI flags passed for this run — confirm intended.
        args = ckpt['args']

    if args.hier == 'top':
        # Top-level prior over 32x32 codes.
        model = PixelSNAIL(
            [32, 32],
            32,  # number of classes in the embedding space (codebook size)
            args.channel,
            5,  # positional arg — TODO confirm meaning in PixelSNAIL signature
            4,  # positional arg — TODO confirm
            args.n_res_block,
            args.n_res_channel,
            dropout=args.dropout,
            n_out_res_block=args.n_out_res_block,
        )

    elif args.hier == 'bottom':
        # Bottom-level prior over 64x64 codes (call continues beyond this excerpt).
        model = PixelSNAIL(
            [64, 64],
            32,  # number of classes in the embedding space (codebook size)
            args.channel,
            5,
            4,
            args.n_res_block,
            args.n_res_channel,