def fetch_mmn_data(args, shape=None):
    """Download Moving MNIST, normalize it, and wrap it in DataLoaders.

    Returns a ``(train, dev, test)`` triple of DataLoaders; the dev and
    test entries are the *same* loader object built from the test split.

    Args:
        args: config object providing ``batch_size``.
        shape: optional shape forwarded to the MovingMNIST constructor.
    """
    print('Fetching train...\n')
    train_set = MovingMNIST(root=MMN_PATH,
                            train=True,
                            download=True,
                            shape=shape)
    print('Fetching test...\n')
    test_set = MovingMNIST(root=MMN_PATH,
                           train=False,
                           download=True,
                           shape=shape)

    # Normalize both splits in place before constructing the loaders.
    for split in (train_set, test_set):
        split.normalize()

    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(dataset=train_set,
                               batch_size=args.batch_size,
                               shuffle=True,
                               drop_last=True)
    eval_loader = make_loader(dataset=test_set,
                              batch_size=args.batch_size,
                              shuffle=False,
                              drop_last=True)

    # The test split doubles as the dev split (single shared loader).
    return train_loader, eval_loader, eval_loader
# Esempio n. 2
# 0
    # Reproducibility: seed torch and the stdlib RNG when a seed is given.
    if args.seed is not None:  # fix: identity comparison, not `!= None`
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    # Logging
    log_interval_num = args.log_interval
    log_dir = os.path.join(args.root_log_dir, args.log_dir)
    # NOTE(review): os.mkdir raises FileExistsError if a previous run used
    # the same log_dir — confirm that failing fast (no silent overwrite of
    # old logs/checkpoints) is the intended behavior.
    os.mkdir(log_dir)
    os.mkdir(os.path.join(log_dir, 'models'))
    os.mkdir(os.path.join(log_dir, 'runs'))
    writer = SummaryWriter(log_dir=os.path.join(log_dir, 'runs'))

    # Dataset
    if args.dataset_type == 'MovingMNIST':
        # Single .npy sequence file, split 90/10 into train/test at random.
        data_path = os.path.join(args.data_dir, 'mnist_test_seq.npy')
        full_dataset = MovingMNIST(data_path, rescale=args.rescale)
        data_num = len(full_dataset)
        train_size = int(0.9 * data_num)
        test_size = data_num - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, test_size])
    elif args.dataset_type == 'MovingMNISTLR':
        train_dataset = MovingMNISTLR(args.data_dir, train=True, download=True)
        test_dataset = MovingMNISTLR(args.data_dir, train=False, download=True)
    else:
        raise NotImplementedError()
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers)
    train_loader_iterator = iter(train_loader)
# Esempio n. 3
# 0
def main():
    """Run inference with a pretrained SVP model on one Moving MNIST batch.

    For each time step the predicted and ground-truth frames are saved as
    image grids under ``result_dir``, then the per-step images are stitched
    into two GIFs (generated vs. original).  Only the first batch is used.
    """
    args = TrainOptions()
    if not os.path.exists(args.ckpt_dir):
        os.makedirs(args.ckpt_dir)

    # No normalization: frames are used as raw ToTensor() output.
    transforms = T.ToTensor()

    dataset = MovingMNIST(args.data_dir, transforms, False)
    dataloader = DataLoader(dataset, 32, True, pin_memory=True)

    result_dir = '/home/chaehuny/data/svp/result/'
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    model = SVP(args.img_dim, args.h_dim, args.g_dim, args.latent_dim,
                args.rnn_size, args.pre_n_layers, args.pos_n_layers,
                args.pri_n_layers).to(device)
    _ = model.load(os.path.join('/home/chaehuny/data/svp/ckpt', args.resume))

    for i, video in enumerate(dataloader):
        # NOTE(review): indexing below assumes layout (batch, channel,
        # time, H, W) — confirm against the MovingMNIST transform.
        video = video.to(device)
        prediction_frames = video[:, :, 0]
        prediction_frames = make_grid(prediction_frames, nrow=8).cpu().detach()
        # Frame 0 has no prediction, so the ground-truth grid is written
        # under both filenames to keep the two GIFs aligned.
        save_image(prediction_frames,
                   os.path.join(result_dir, 'generate_image_%03d.jpg' % (0)))
        save_image(prediction_frames,
                   os.path.join(result_dir, 'original_image_%03d.jpg' % (0)))

        posterior_state = model.posterior_initial_update(video[:, :, 0])
        prior_state = None
        predictor_state = None

        skip = None
        prev_frames = video[:, :, 0]
        curr_frames = video[:, :, 1]

        for t in range(1, video.shape[2]):
            print(t)
            if t < args.n_past:
                # Conditioning phase: feed ground-truth frame pairs through
                # the full model (posterior + prior + predictor).
                predict_frames, _, _, _, _, posterior_state, prior_state, predictor_state, skip \
                    = model(prev_frames, curr_frames, posterior_state, prior_state, predictor_state, skip)
                generate_frames = make_grid(predict_frames,
                                            nrow=8).cpu().detach()
                original_frames = make_grid(curr_frames, nrow=8).cpu().detach()
                save_image(
                    generate_frames,
                    os.path.join(result_dir, 'generate_image_%03d.jpg' % (t)))
                save_image(
                    original_frames,
                    os.path.join(result_dir, 'original_image_%03d.jpg' % (t)))

                # NOTE(review): skip connections are reset on every
                # conditioning step — confirm this is intended (SVG-style
                # models usually keep the skip of the last context frame).
                skip = None
                prev_frames = video[:, :, t]
                if t + 1 < video.shape[2]:
                    # Bug fix: guard the t+1 lookup (mirrors the guard in
                    # the generation branch) so n_past >= sequence length
                    # no longer raises IndexError on the final frame.
                    curr_frames = video[:, :, t + 1]
            else:
                # Generation phase: autoregressive rollout from the model's
                # own predictions (prior only, no posterior update).
                predict_frames, prior_state, predictor_state, skip = model.inference(
                    prev_frames, prior_state, predictor_state, skip)

                generate_frames = make_grid(predict_frames,
                                            nrow=8).cpu().detach()
                original_frames = make_grid(curr_frames, nrow=8).cpu().detach()
                save_image(
                    generate_frames,
                    os.path.join(result_dir, 'generate_image_%03d.jpg' % (t)))
                save_image(
                    original_frames,
                    os.path.join(result_dir, 'original_image_%03d.jpg' % (t)))

                if t < video.shape[2] - 1:
                    prev_frames = predict_frames
                    curr_frames = video[:, :, t + 1]

        grid2gif(os.path.join(result_dir, 'generate_image_*'),
                 os.path.join(result_dir, 'generate_image.gif'),
                 delay=10)
        grid2gif(os.path.join(result_dir, 'original_image_*'),
                 os.path.join(result_dir, 'original_image.gif'),
                 delay=10)

        # Only the first batch is needed for visualization.
        if i == 0:
            break
# Esempio n. 4
# 0
    # Differentiate the critic's output at the interpolated samples w.r.t.
    # those samples; create_graph/retain_graph keep the graph alive so the
    # penalty itself can be backpropagated.  This matches the WGAN-GP
    # gradient-penalty computation.
    gradients = grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones(disc_interpolates.size()).cuda()
        if opt['use_cuda'] else torch.ones(disc_interpolates.size()),
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0]

    # Penalize deviation of each sample's gradient L2 norm from 1,
    # weighted by opt['lambda'].
    gradient_penalty = (
        (gradients.norm(2, dim=1) - 1)**2).mean() * opt['lambda']
    return gradient_penalty


# Dataset and loader built from the config; the iterator is advanced
# manually elsewhere (presumably in the training loop — confirm).
dset = MovingMNIST(opt['dataset_path'])
denormalizer = dset.denormalize  # bound method, kept for later un-normalizing
loader = DataLoader(dset,
                    batch_size=opt["batch_size"],
                    shuffle=True,
                    num_workers=4)
loader_iterator = iter(loader)
data_len = len(loader)  # number of batches per epoch

# Per-model hyperparameter sub-dicts from the config.
gen_o = opt['models']['generator']
dis_o = opt['models']['discriminator']

gen = Gen(z_slow_dim=gen_o["z_slow_dim"],
          z_fast_dim=gen_o["z_fast_dim"],
          out_channels=gen_o["out_channels"],
          bottom_width=gen_o["bottom_width"],
# Esempio n. 5
# 0
    # Differentiate the critic's output at the interpolated samples w.r.t.
    # those samples; create_graph/retain_graph keep the graph alive so the
    # penalty itself can be backpropagated.  This matches the WGAN-GP
    # gradient-penalty computation.
    gradients = grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones(disc_interpolates.size()).cuda()
        if opt['use_cuda'] else torch.ones(disc_interpolates.size()),
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0]

    # Penalize deviation of each sample's gradient L2 norm from 1,
    # weighted by opt['lambda'].
    gradient_penalty = (
        (gradients.norm(2, dim=1) - 1)**2).mean() * opt['lambda']
    return gradient_penalty


# dset = MovingMNIST(opt['dataset_path'])
# NOTE(review): hard-coded local path overrides the config value above;
# num_workers=0 keeps loading in the main process — confirm both are
# deliberate (e.g. for debugging) and not leftover local edits.
dset = MovingMNIST("./data/mnist_test_seq.npy")
denormalizer = dset.denormalize  # bound method, kept for later un-normalizing
loader = DataLoader(dset,
                    batch_size=opt["batch_size"],
                    shuffle=True,
                    num_workers=0)
loader_iterator = iter(loader)
data_len = len(loader)  # number of batches per epoch

# Per-model hyperparameter sub-dicts from the config.
gen_o = opt['models']['generator']
dis_o = opt['models']['discriminator']

gen = Gen(z_slow_dim=gen_o["z_slow_dim"],
          z_fast_dim=gen_o["z_fast_dim"],
          out_channels=gen_o["out_channels"],
          bottom_width=gen_o["bottom_width"],