def train(args):
    """Train (pretrain) or fine-tune an ImageGPT model.

    If ``args.pretrained`` is given, that checkpoint is loaded and fine-tuned
    (its hparams are replaced by the current ``args``); otherwise a fresh
    model is built. When ``args.classify`` is set, checkpoints track
    ``val_acc`` (fine-tuning for classification); otherwise they track
    ``val_loss`` (generative pretraining).

    Args:
        args: parsed CLI namespace; reads ``pretrained``, ``name``,
            ``classify``, ``steps``, and ``gpus``.
    """
    if args.pretrained is not None:
        model = ImageGPT.load_from_checkpoint(args.pretrained)
        # potentially modify model for classification
        model.hparams = args
    else:
        model = ImageGPT(args)

    logger = pl.logging.TensorBoardLogger("logs", name=args.name)

    # The two branches previously built identical Trainers that differed only
    # in the checkpoint monitor, so pick the monitor and build once.
    # NOTE(review): the original constructed an EarlyStopping callback for the
    # classification branch but never passed it to the Trainer (the kwarg was
    # commented out) — early stopping stays intentionally disabled here.
    monitor = "val_acc" if args.classify else "val_loss"
    checkpoint = pl.callbacks.ModelCheckpoint(monitor=monitor)

    trainer = pl.Trainer(
        max_steps=args.steps,
        gpus=args.gpus,
        checkpoint_callback=checkpoint,
        logger=logger,
    )
    trainer.fit(model)
def main(args):
    """Complete the bottom half of held-out test images with a trained ImageGPT.

    Loads the GPT from ``args.checkpoint``, takes ``args.num_examples``
    shuffled test images, feeds the first half of each flattened image as
    context, samples ``args.num_samples`` completions of the second half,
    and saves a comparison figure (context / predictions / ground truth)
    to ``figure.png``.

    Args:
        args: parsed CLI namespace; reads ``checkpoint``, ``test_x``,
            ``num_examples``, and ``num_samples``.
    """
    model = ImageGPT.load_from_checkpoint(args.checkpoint).gpt.cuda()

    test_x = np.load(args.test_x)
    np.random.shuffle(test_x)

    # rows for figure
    rows = []
    for example in tqdm(range(args.num_examples)):
        img = test_x[example]
        seq = img.reshape(-1)

        # Integer halving computed once — the original repeated
        # int(len(seq) / 2) (float division) three times.
        # Assumes an even pixel count (true for square images).
        half = len(seq) // 2

        # first half of image is context
        context = seq[:half]
        context_img = np.pad(context, (0, half)).reshape(*img.shape)
        context = torch.from_numpy(context).cuda()

        # predict second half of image
        preds = (
            sample(model, context, half, num_samples=args.num_samples)
            .cpu()
            .numpy()
            .transpose()
        )
        preds = preds.reshape(-1, *img.shape)

        # combine context, preds, and truth for figure
        rows.append(
            np.concatenate([context_img[None, ...], preds, img[None, ...]], axis=0)
        )

    figure = make_figure(rows)
    figure.save("figure.png")
def main(args):
    """Render an animated GIF of ImageGPT unconditionally generating images.

    Generates ``args.rows * args.cols`` image sequences pixel-by-pixel from an
    empty context, tiles each partial frame into a rows-by-cols grid, and
    writes the (downsampled, rescaled) animation to ``out.gif``.

    Args:
        args: parsed CLI namespace; reads ``checkpoint``, ``rows``, ``cols``,
            ``downsample``, ``fps``, and ``scale``.
    """
    model = ImageGPT.load_from_checkpoint(args.checkpoint).gpt.cuda()
    model.eval()

    # Empty context -> fully unconditional generation of 28x28 = 784 pixels.
    context = torch.zeros(0, dtype=torch.long).cuda()
    frames = generate(model, context, 28 * 28, num_samples=args.rows * args.cols)

    # Right-pad every partially generated frame out to the full 784 pixels.
    padded = np.stack([
        np.pad(frame, pad_width=((0, 0), (0, 28 * 28 - frame.shape[1])))
        for frame in frames
    ])

    # Tile the samples of each frame into a rows-by-cols image grid.
    n_frames = padded.shape[0]
    grid = padded.reshape(n_frames, args.rows, args.cols, 28, 28)
    grid = grid.swapaxes(2, 3).reshape(n_frames, 28 * args.rows, 28 * args.cols)

    # Broadcast to 3 channels and scale by 17 (presumably the pixel values are
    # 0-15 palette indices, so 15 * 17 = 255 — TODO confirm against `generate`).
    grid = (grid[..., np.newaxis] * np.ones(3) * 17).astype(np.uint8)

    clip = ImageSequenceClip(
        list(grid)[::args.downsample], fps=args.fps
    ).resize(args.scale)
    clip.write_gif("out.gif", fps=args.fps)
import pytorch_lightning as pl
import pytorch_lightning.logging
import argparse
from module import ImageGPT

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # ImageGPT contributes its own model/hyperparameter flags to the parser.
    parser = ImageGPT.add_model_specific_args(parser)
    parser.add_argument("--train_x", default="./data/train_x.npy")
    parser.add_argument("--train_y", default="./data/train_y.npy")
    parser.add_argument("--test_x", default="./data/test_x.npy")
    parser.add_argument("--test_y", default="./data/test_y.npy")
    parser.add_argument("--gpus", default="0")
    # NOTE(review): subparsers are created but only a "test" subcommand is
    # registered; the remaining flags are added to the top-level parser.
    subparsers = parser.add_subparsers()
    parser.add_argument("--pretrained", type=str, default=None)
    # parser_train.add_argument("-n", "--name", type=str, required=True)
    parser.add_argument("-n", "--name", type=str, default='fmnist_gen')
    parser_test = subparsers.add_parser("test")
    parser_test.add_argument("checkpoint", type=str)
    args = parser.parse_args()
    logger = pl.logging.TensorBoardLogger("logs", name=args.name)
    model = ImageGPT(args)
    # Checkpoint on best validation loss (generative pretraining).
    checkpoint = pl.callbacks.ModelCheckpoint(monitor="val_loss")
    # NOTE(review): SOURCE is truncated here — the Trainer(...) call is cut
    # off mid-statement; the remaining kwargs are outside this view.
    trainer = pl.Trainer(
def test(args):
    """Evaluate a saved ImageGPT checkpoint on its test set.

    Args:
        args: parsed CLI namespace; reads ``checkpoint`` and ``gpus``.
    """
    # Load the model first, make sure its data is prepared, then hand it
    # to a fresh Trainer for evaluation.
    model = ImageGPT.load_from_checkpoint(args.checkpoint)
    model.prepare_data()
    trainer = pl.Trainer(gpus=args.gpus)
    trainer.test(model)