Beispiel #1
0
def train(args):
    """Train an ImageGPT model from scratch or fine-tune a pretrained one.

    When ``args.pretrained`` is set, the model is restored from that
    checkpoint and its hyperparameters are overwritten with ``args``;
    otherwise a fresh model is built. Classification fine-tuning checkpoints
    on validation accuracy, pretraining checkpoints on validation loss.
    """
    if args.pretrained is None:
        model = ImageGPT(args)
    else:
        model = ImageGPT.load_from_checkpoint(args.pretrained)
        # potentially modify model for classification
        model.hparams = args

    logger = pl.logging.TensorBoardLogger("logs", name=args.name)

    if args.classify:
        # Classification fine-tuning: track validation accuracy.
        # NOTE(review): the early-stopping callback is constructed but never
        # handed to the Trainer (the kwarg was commented out in the original);
        # kept here only to preserve behavior exactly.
        early_stopping = pl.callbacks.EarlyStopping(monitor="val_acc",
                                                    patience=3)
        metric = "val_acc"
    else:
        # Pretraining: track validation loss.
        metric = "val_loss"

    checkpoint = pl.callbacks.ModelCheckpoint(monitor=metric)
    trainer = pl.Trainer(
        max_steps=args.steps,
        gpus=args.gpus,
        checkpoint_callback=checkpoint,
        logger=logger,
    )

    trainer.fit(model)
Beispiel #2
0
def main(args):
    """Complete the second half of held-out test images with a trained GPT.

    Loads the test images, conditions the model on the first half of each
    flattened image for ``args.num_examples`` examples, draws
    ``args.num_samples`` completions per example, and saves a comparison
    grid (context / samples / ground truth) to ``figure.png``.
    """
    model = ImageGPT.load_from_checkpoint(args.checkpoint).gpt.cuda()

    test_x = np.load(args.test_x)
    np.random.shuffle(test_x)

    rows = []  # one row of image panels per example

    for idx in tqdm(range(args.num_examples)):
        image = test_x[idx]
        seq = image.reshape(-1)
        half = int(len(seq) / 2)

        # The first half of the flattened image is the conditioning context;
        # zero-pad it back to full size for display.
        context = seq[:half]
        context_img = np.pad(context, (0, half)).reshape(*image.shape)
        context = torch.from_numpy(context).cuda()

        # Sample the remaining pixels of the image.
        sampled = sample(model, context, half, num_samples=args.num_samples)
        preds = sampled.cpu().numpy().transpose().reshape(-1, *image.shape)

        # Stack context, predictions, and ground truth into one figure row.
        panels = [context_img[None, ...], preds, image[None, ...]]
        rows.append(np.concatenate(panels, axis=0))

    figure = make_figure(rows)
    figure.save("figure.png")
Beispiel #3
0
def main(args):
    """Render an animated GIF of unconditional samples from a trained GPT.

    Generates ``args.rows * args.cols`` pixel sequences from an empty
    context, tiles each intermediate generation step into an image grid,
    and writes the resulting animation to ``out.gif``.
    """
    model = ImageGPT.load_from_checkpoint(args.checkpoint).gpt.cuda()
    model.eval()

    # Empty context -> fully unconditional generation.
    empty = torch.zeros(0, dtype=torch.long).cuda()

    frames = generate(model,
                      empty,
                      28 * 28,
                      num_samples=args.rows * args.cols)

    # Right-pad each partial frame out to the full 28*28 sequence length.
    padded = []
    for frame in frames:
        widths = ((0, 0), (0, 28 * 28 - frame.shape[1]))
        padded.append(np.pad(frame, pad_width=widths))

    padded = np.stack(padded)
    f, n, _ = padded.shape
    # Tile the samples into an (args.rows x args.cols) grid of 28x28 images.
    padded = padded.reshape(f, args.rows, args.cols, 28, 28)
    padded = padded.swapaxes(2, 3).reshape(f, 28 * args.rows, 28 * args.cols)
    # Broadcast grayscale to 3 channels and scale by 17 (presumably mapping
    # 0-15 palette indices onto 0-255 — confirm against the data pipeline).
    padded = padded[..., np.newaxis] * np.ones(3) * 17
    padded = padded.astype(np.uint8)

    clip = ImageSequenceClip(list(padded)[::args.downsample],
                             fps=args.fps).resize(args.scale)
    clip.write_gif("out.gif", fps=args.fps)
Beispiel #4
0
import pytorch_lightning as pl
import pytorch_lightning.logging
import argparse

from module import ImageGPT

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = ImageGPT.add_model_specific_args(parser)

    parser.add_argument("--train_x", default="./data/train_x.npy")
    parser.add_argument("--train_y", default="./data/train_y.npy")
    parser.add_argument("--test_x", default="./data/test_x.npy")
    parser.add_argument("--test_y", default="./data/test_y.npy")

    parser.add_argument("--gpus", default="0")
    subparsers = parser.add_subparsers()
    parser.add_argument("--pretrained", type=str, default=None)
    # parser_train.add_argument("-n", "--name", type=str, required=True)
    parser.add_argument("-n", "--name", type=str, default='fmnist_gen')

    parser_test = subparsers.add_parser("test")
    parser_test.add_argument("checkpoint", type=str)

    args = parser.parse_args()

    logger = pl.logging.TensorBoardLogger("logs", name=args.name)

    model = ImageGPT(args)
    checkpoint = pl.callbacks.ModelCheckpoint(monitor="val_loss")
    trainer = pl.Trainer(
Beispiel #5
0
def test(args):
    """Evaluate a checkpointed ImageGPT on its test split."""
    model = ImageGPT.load_from_checkpoint(args.checkpoint)
    model.prepare_data()
    runner = pl.Trainer(gpus=args.gpus)
    runner.test(model)