Example #1
import argparse
import json
import os

import torch
import torch.utils.data

# Project-local modules from the imSitu VSRL codebase; train() and eval()
# are assumed to be defined alongside main() in the original file.
import utils
import imsitu_encoder
import imsitu_loader
import top_down_baseline


def main():
    parser = argparse.ArgumentParser(
        description="imsitu VSRL. Training, evaluation and prediction.")
    parser.add_argument("--gpuid",
                        default=-1,
                        help="put GPU id > -1 in GPU mode",
                        type=int)
    parser.add_argument('--output_dir',
                        type=str,
                        default='./trained_models',
                        help='Location to output the model')
    parser.add_argument('--resume_training',
                        action='store_true',
                        help='Resume training from the model [resume_model]')
    parser.add_argument('--resume_model',
                        type=str,
                        default='',
                        help='Path of the model checkpoint to resume from')
    parser.add_argument('--evaluate',
                        action='store_true',
                        help='Only run evaluation on the dev set')
    parser.add_argument('--evaluate_visualize',
                        action='store_true',
                        help='Only run evaluation with visualization')
    parser.add_argument('--evaluate_rare',
                        action='store_true',
                        help='Only run evaluation on the rare portion')
    parser.add_argument('--test',
                        action='store_true',
                        help='Only run evaluation on the test set')
    parser.add_argument('--dataset_folder',
                        type=str,
                        default='./imSitu',
                        help='Location of annotations')
    parser.add_argument('--imgset_dir',
                        type=str,
                        default='./resized_256',
                        help='Location of original images')
    parser.add_argument('--train_file',
                        default="train_freq2000.json",
                        type=str,
                        help='train file name')
    parser.add_argument('--dev_file',
                        default="dev_freq2000.json",
                        type=str,
                        help='dev file name')
    parser.add_argument('--test_file',
                        default="test_freq2000.json",
                        type=str,
                        help='test file name')
    parser.add_argument('--model_saving_name',
                        type=str,
                        help='saving name of the output model')

    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--model', type=str, default='top_down_baseline')
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--seed', type=int, default=1111, help='random seed')
    parser.add_argument('--clip_norm', type=float, default=0.25)
    parser.add_argument('--num_workers', type=int, default=3)

    args = parser.parse_args()

    n_epoch = args.epochs
    batch_size = args.batch_size
    clip_norm = args.clip_norm
    n_worker = args.num_workers

    dataset_folder = args.dataset_folder
    imgset_folder = args.imgset_dir

    with open(os.path.join(dataset_folder, args.train_file)) as f:
        train_set = json.load(f)

    encoder = imsitu_encoder.imsitu_encoder(train_set)

    train_set = imsitu_loader.imsitu_loader(imgset_folder, train_set, encoder,
                                            'train', encoder.train_transform)

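    # --model selects the constructor by name: getattr resolves
    # top_down_baseline.build_<model>, e.g. build_top_down_baseline(...).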
    constructor = 'build_%s' % args.model
    model = getattr(top_down_baseline, constructor)(encoder.get_num_roles(),
                                                    encoder.get_num_verbs(),
                                                    encoder.get_num_labels(),
                                                    encoder)

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=n_worker)

    with open(os.path.join(dataset_folder, args.dev_file)) as f:
        dev_set = json.load(f)
    dev_set = imsitu_loader.imsitu_loader(imgset_folder, dev_set, encoder,
                                          'val', encoder.dev_transform)
    dev_loader = torch.utils.data.DataLoader(dev_set,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=n_worker)

    with open(os.path.join(dataset_folder, args.test_file)) as f:
        test_set = json.load(f)
    test_set = imsitu_loader.imsitu_loader(imgset_folder, test_set, encoder,
                                           'test', encoder.dev_transform)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=n_worker)
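    # Both evaluation splits reuse encoder.dev_transform, which presumably
    # applies deterministic preprocessing rather than training augmentation.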

    os.makedirs(args.output_dir, exist_ok=True)

    torch.manual_seed(args.seed)
    if args.gpuid >= 0:
        model.cuda()
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.benchmark = True
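        # benchmark mode lets cuDNN autotune convolution kernels, which pays
        # off here because every batch shares the same input shape.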

    if args.resume_training:
        print('Resume training from: {}'.format(args.resume_model))
        args.train_all = True
        if len(args.resume_model) == 0:
            raise Exception('[pretrained module] not specified')
        utils.load_net(args.resume_model, [model])
        optimizer = torch.optim.Adamax(model.parameters(), lr=1e-3)
        model_name = 'resume_all'

    else:
        print('Training from scratch.')
        model_name = 'train_full'
        utils.set_trainable(model, True)
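        # Per-module parameter groups: the convnet backbone trains at a
        # reduced 5e-5, while groups without an explicit 'lr' fall back to
        # the optimizer-wide default of 1e-3.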
        optimizer = torch.optim.Adamax(
            [{
                'params': model.convnet.parameters(),
                'lr': 5e-5
            }, {
                'params': model.role_emb.parameters()
            }, {
                'params': model.verb_emb.parameters()
            }, {
                'params': model.query_composer.parameters()
            }, {
                'params': model.v_att.parameters()
            }, {
                'params': model.q_net.parameters()
            }, {
                'params': model.v_net.parameters()
            }, {
                'params': model.classifier.parameters()
            }],
            lr=1e-3)

    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
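    # ExponentialLR multiplies every group's learning rate by gamma=0.9 each
    # time scheduler.step() is called (presumably once per epoch inside
    # train()).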

    if args.evaluate:
        top1, top5, val_loss = eval(model, dev_loader, encoder, args.gpuid)

        top1_avg = top1.get_average_results_nouns()
        top5_avg = top5.get_average_results_nouns()

        avg_score = top1_avg["verb"] + top1_avg["value"] + top1_avg["value-all"] + top5_avg["verb"] + \
                    top5_avg["value"] + top5_avg["value-all"] + top5_avg["value*"] + top5_avg["value-all*"]
        avg_score /= 8

        print('Dev average: {:.2f} {} {}'.format(
            avg_score * 100, utils.format_dict(top1_avg, '{:.2f}', '1-'),
            utils.format_dict(top5_avg, '{:.2f}', '5-')))

    elif args.test:
        top1, top5, val_loss = eval(model, test_loader, encoder, args.gpuid)

        top1_avg = top1.get_average_results_nouns()
        top5_avg = top5.get_average_results_nouns()

        avg_score = top1_avg["verb"] + top1_avg["value"] + top1_avg["value-all"] + top5_avg["verb"] + \
                    top5_avg["value"] + top5_avg["value-all"] + top5_avg["value*"] + top5_avg["value-all*"]
        avg_score /= 8

        print('Test average: {:.2f} {} {}'.format(
            avg_score * 100, utils.format_dict(top1_avg, '{:.2f}', '1-'),
            utils.format_dict(top5_avg, '{:.2f}', '5-')))

    else:

        print('Model training started!')
        train(
            model,
            train_loader,
            dev_loader,
            optimizer,
            scheduler,
            n_epoch,
            args.output_dir,
            encoder,
            args.gpuid,
            clip_norm,
            model_name,
            args.model_saving_name,
        )
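

# Entry-point guard, assumed from the original script; the invocation below
# is illustrative only.
if __name__ == '__main__':
    # e.g. python main.py --gpuid 0 --model_saving_name baseline_model
    main()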
Example #2
import os

import numpy as np
from keras import backend as K
from keras.optimizers import Adam
from keras.utils import plot_model

# Project-local helpers (model factories, data generators and losses); their
# module layout is not shown here, so the following names are assumed to be
# importable from the surrounding SAVP codebase: encoder, generator,
# discriminator, build_graph, set_trainable, kl_loss, video_generator,
# count_images, discriminator_data, generator_data, encoder_data.


def main(args):
    optimizer = Adam()
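    # This single Adam instance is reused when compiling the encoder,
    # generator and discriminator training graphs below.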

    enc = encoder(batch_size=args.batch_size,
                  time=2 * args.time - 1,
                  latent_dim=args.latent_dim,
                  frame_height=args.frame_height,
                  frame_width=args.frame_width,
                  frame_channels=args.frame_channels)
    gen = generator(batch_size=args.batch_size,
                    time=args.time,
                    latent_dim=args.latent_dim,
                    frame_height=args.frame_height,
                    frame_width=args.frame_width,
                    frame_channels=args.frame_channels)
    dis_gan = discriminator(batch_size=args.batch_size,
                            time=2 * args.time - 1,
                            name='gan',
                            frame_height=args.frame_height,
                            frame_width=args.frame_width,
                            frame_channels=args.frame_channels)
    dis_vae = discriminator(batch_size=args.batch_size,
                            time=2 * args.time - 1,
                            name='vae',
                            frame_height=args.frame_height,
                            frame_width=args.frame_width,
                            frame_channels=args.frame_channels)
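    # The encoder and both discriminators operate on the full 2*time-1 frame
    # window (context plus continuation); the generator itself emits
    # args.time frames.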

    gen.compile(optimizer, ['mean_absolute_error'])

    print(
        "\n\n\n-------------------------------- Encoder Summary --------------------------------\n"
    )
    enc.summary()

    print(
        "\n\n\n-------------------------------- Generator Summary --------------------------------\n"
    )
    gen.summary()

    print(
        "\n\n\n-------------------------------- Discriminator Summary --------------------------------\n"
    )
    dis_gan.summary()

    plot_model(enc, to_file='../models/savp_encoder.png', show_shapes=True)
    plot_model(gen, to_file='../models/savp_generator.png', show_shapes=True)
    plot_model(dis_gan,
               to_file='../models/savp_discriminator.png',
               show_shapes=True)

    encoder_train, generator_train, discriminator_train, vaegan = build_graph(
        encoder=enc,
        generator=gen,
        discriminator_gan=dis_gan,
        discriminator_vae=dis_vae,
        batch_size=args.batch_size,
        time=args.time,
        latent_dim=args.latent_dim,
        frame_height=args.frame_height,
        frame_width=args.frame_width,
        frame_channels=args.frame_channels)
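    # build_graph wires the shared submodels into three trainable views
    # (encoder_train, generator_train, discriminator_train) plus the combined
    # vaegan model; each view is compiled with only its own submodel left
    # trainable.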

    set_trainable(enc, False)
    set_trainable(gen, False)
    discriminator_train.compile(optimizer, [
        'binary_crossentropy', 'binary_crossentropy', 'binary_crossentropy',
        'binary_crossentropy'
    ],
                                loss_weights=[1., 1., 1., 1.])
    print(
        "\n\n\n-------------------------------- Discriminator Train Summary --------------------------------\n"
    )
    discriminator_train.summary()

    plot_model(discriminator_train,
               to_file='../models/savp_discriminator_train.png',
               show_shapes=True)

    set_trainable(dis_vae, False)
    set_trainable(dis_gan, False)
    set_trainable(gen, True)
    generator_train.compile(
        optimizer,
        ['binary_crossentropy', 'binary_crossentropy', 'mean_absolute_error'],
        loss_weights=[1., 1., 1.])
    print(
        "\n\n\n-------------------------------- Generator Train Summary --------------------------------\n"
    )
    generator_train.summary()

    plot_model(generator_train,
               to_file='../models/savp_generator_train.png',
               show_shapes=True)

    set_trainable(gen, False)
    set_trainable(enc, True)
    encoder_train.compile(optimizer, [kl_loss, 'mean_absolute_error'],
                          loss_weights=[0.1, 100.])
    print(
        "\n\n\n-------------------------------- Encoder Train Summary --------------------------------\n"
    )
    encoder_train.summary()

    plot_model(encoder_train,
               to_file='../models/savp_encoder_train.png',
               show_shapes=True)

    #Determine the steps per epoch for training
    train_video_dir = os.path.join(args.train_directory, '*_act_14_*/*.jpg')
    steps_per_epoch = count_images(train_video_dir,
                                   batch_size=args.batch_size,
                                   time=args.time)

    print("Steps per epoch: {}".format(steps_per_epoch))

    #Determine the steps per epoch for evaluating
    val_video_dir = os.path.join(args.val_directory, '*_act_14_*/*.jpg')
    val_steps = count_images(val_video_dir,
                             batch_size=args.batch_size,
                             time=args.time)

    # #Determine the steps per epoch for training
    # train_video_dir = os.path.join(args.train_directory, '*.avi')
    # steps_per_epoch = count_frames(train_video_dir, batch_size=args.batch_size, time=args.time, camera_fps=args.camera_fps)

    # print("Steps per epoch: {}".format(steps_per_epoch))

    # #Determine the steps per epoch for evaluating
    # val_video_dir = os.path.join(args.val_directory, '*.avi')
    # val_steps = count_frames(val_video_dir, batch_size=args.batch_size, time=args.time, camera_fps=args.camera_fps)

    initial_epoch = args.load_epoch + 1

    # Total generated frames seen so far; this drives the scheduled-sampling
    # schedule in the training loop below.
    frames_seen = 0

    if args.load_epoch > 0:
        discriminator_train.load_weights(
            '../weights/savp.discriminator.{:03d}.h5'.format(args.load_epoch))
        generator_train.load_weights(
            '../weights/savp.generator.{:03d}.h5'.format(args.load_epoch))
        encoder_train.load_weights('../weights/savp.encoder.{:03d}.h5'.format(
            args.load_epoch))
        print('Loaded weights {:03d}'.format(args.load_epoch))
        # Fast-forward the schedule to where the loaded run left off.
        frames_seen = (args.load_epoch * steps_per_epoch * (args.time - 1) *
                       args.batch_size)

    for i in range(initial_epoch, args.epochs + 1):
        video_loader = video_generator(video_dir=args.train_directory,
                                       frame_height=args.frame_height,
                                       frame_width=args.frame_width,
                                       frame_channels=args.frame_channels,
                                       batch_size=args.batch_size,
                                       time=2 * args.time - 1,
                                       camera_fps=args.camera_fps,
                                       json_filename="kinetics_train.json")

        dis_losses_avg = np.zeros((5))
        gen_losses_avg = np.zeros((4))
        enc_losses_avg = np.zeros((3))
        rec_losses_avg = 0

        seed = 0
        rng = np.random.RandomState(seed)
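        # RandomState(0) is rebuilt every epoch, so the latent-noise sequence
        # repeats identically from epoch to epoch.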

        print("\nEpoch {:03d} \n".format(i))

        for j in range(steps_per_epoch):

            x, y = next(video_loader)

            dis_inputs, dis_outputs = discriminator_data(
                x,
                y,
                latent_dim=args.latent_dim,
                seed=seed,
                time_init=args.time)
            gen_inputs, gen_outputs = generator_data(
                x,
                y,
                latent_dim=args.latent_dim,
                seed=seed,
                time_init=args.time)
            enc_inputs, enc_outputs = encoder_data(x,
                                                   y,
                                                   latent_dim=args.latent_dim,
                                                   seed=seed,
                                                   time_init=args.time)
            # print(enc_outputs[0].shape)

            # if j == 0:
            # previous_inputs = np.zeros_like(gen_outputs[2])
            # current_data = x[0]
            # output_data = y[0]
            # print(current_data.shape)
            # for m in range(args.time):
            # 	cv2.imshow("Input", current_data[m])
            # 	cv2.imshow("Output", output_data[m])
            # 	cv2.waitKey(25)
            # cv2.destroyAllWindows()

            # Scheduled sampling: with probability (1 - epsilon) the
            # conditioning window is fed the model's own previous prediction
            # instead of ground truth; epsilon decays linearly from 1 to 0
            # over args.k generated frames, easing from teacher forcing into
            # free-running inference.
            z_p = rng.normal(size=(args.batch_size, args.time,
                                   args.latent_dim))
            gen_sample_input = [gen_inputs[0], gen_inputs[1], z_p]
            for k in range(args.time - 1):
                frames_seen += args.batch_size
                epsilon = 1. - float(frames_seen) / float(args.k)
                if epsilon < np.random.random_sample():
                    # Free-running step: roll the model's own prediction back
                    # into the training inputs and the sampling window.
                    gen_sample_output = gen.predict_on_batch(gen_sample_input)
                    gen_inputs[3][:, k + args.time - 1] = \
                        gen_sample_output[:, args.time - 1]
                    gen_sample_input[1][:, :args.time - 1] = \
                        gen_sample_input[1][:, 1:args.time]
                    gen_sample_input[1][:, args.time - 1] = \
                        gen_sample_output[:, args.time - 1]
                    gen_sample_input[2] = rng.normal(
                        size=(args.batch_size, args.time, args.latent_dim))
                else:
                    # Teacher-forced step: condition on ground-truth frames.
                    gen_sample_input[1][:, :args.time] = \
                        gen_inputs[3][:, k + 1:args.time + k + 1]

            dis_losses = discriminator_train.train_on_batch(
                dis_inputs, dis_outputs)
            gen_losses = generator_train.train_on_batch(
                gen_inputs, gen_outputs)
            enc_losses = encoder_train.train_on_batch(enc_inputs, enc_outputs)
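            # The three train_on_batch calls above perform one alternating
            # update: discriminators first, then the generator against the
            # frozen discriminators, then the encoder's KL + reconstruction
            # objective.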

            # Maintain running means of each loss over the epoch (j is the
            # 0-based step index).
            for m, loss in enumerate(dis_losses):
                dis_losses_avg[m] = (dis_losses_avg[m] * j + loss) / (j + 1)

            for m, loss in enumerate(gen_losses):
                gen_losses_avg[m] = (gen_losses_avg[m] * j + loss) / (j + 1)

            for m, loss in enumerate(enc_losses):
                enc_losses_avg[m] = (enc_losses_avg[m] * j + loss) / (j + 1)

            print(
                '\rEpoch {:03d} Step {:04d} of {:04d}. D_t_vae: {:.4f} D_t_gan: {:.4f} D_f_vae: {:.4f} D_f_gan: {:.4f} G_t_vae: {:.4f} G_t_gan: {:.4f} E_kl: {:.4f} Rec: {:.4f}'
                .format(i, j, steps_per_epoch, dis_losses_avg[1],
                        dis_losses_avg[2], dis_losses_avg[3],
                        dis_losses_avg[4], gen_losses_avg[1],
                        gen_losses_avg[2], enc_losses_avg[1],
                        enc_losses_avg[2]),
                end="",
                flush=True)

        if i % 10 == 0 and i > 0:
            # Decay the learning rate by 10x every 10 epochs. All three
            # training graphs were compiled with the same Adam instance, so a
            # single K.set_value suffices; plain attribute assignment would
            # not update the variable already baked into the compiled update
            # ops.
            K.set_value(optimizer.lr, K.get_value(optimizer.lr) / 10.)
            print("\nCurrent learning rate is now: {}".format(
                K.eval(optimizer.lr)))

        print('\nSaving models')
        discriminator_train.save_weights(
            '../weights/savp.discriminator.{:03d}.h5'.format(i))
        generator_train.save_weights(
            '../weights/savp.generator.{:03d}.h5'.format(i))
        encoder_train.save_weights(
            '../weights/savp.encoder.{:03d}.h5'.format(i))
        enc.save('../weights/savp.enc.{:03d}.h5'.format(i))

        dis_losses_avg = np.zeros((5))
        gen_losses_avg = np.zeros((4))
        enc_losses_avg = np.zeros((3))
        rec_losses_avg = 0

        seed = 0
        rng = np.random.RandomState(seed)

        video_loader_val = video_generator(video_dir=args.val_directory,
                                           frame_height=args.frame_height,
                                           frame_width=args.frame_width,
                                           frame_channels=args.frame_channels,
                                           batch_size=args.batch_size,
                                           time=2 * args.time - 1,
                                           camera_fps=args.camera_fps,
                                           json_filename="kinetics_val.json")

        print('Evaluating the model')
        for j in range(val_steps):

            x, y = next(video_loader_val)

            dis_inputs, dis_outputs = discriminator_data(
                x,
                y,
                latent_dim=args.latent_dim,
                seed=seed,
                time_init=args.time)
            gen_inputs, gen_outputs = generator_data(
                x,
                y,
                latent_dim=args.latent_dim,
                seed=seed,
                time_init=args.time)
            enc_inputs, enc_outputs = encoder_data(x,
                                                   y,
                                                   latent_dim=args.latent_dim,
                                                   seed=seed,
                                                   time_init=args.time)

            dis_losses = discriminator_train.test_on_batch(
                dis_inputs, dis_outputs)
            gen_losses = generator_train.test_on_batch(gen_inputs, gen_outputs)
            enc_losses = encoder_train.test_on_batch(enc_inputs, enc_outputs)

            for m, loss in enumerate(dis_losses):
                dis_losses_avg[m] = (dis_losses_avg[m] * j + loss) / (j + 1)

            for m, loss in enumerate(gen_losses):
                gen_losses_avg[m] = (gen_losses_avg[m] * j + loss) / (j + 1)

            for m, loss in enumerate(enc_losses):
                enc_losses_avg[m] = (enc_losses_avg[m] * j + loss) / (j + 1)

            print(
                '\rEpoch {:03d} Evaluation Step {:04d} of {:04d}. D_t_vae: {:.4f} D_t_gan: {:.4f} D_f_vae: {:.4f} D_f_gan: {:.4f} G_t_vae: {:.4f} G_t_gan: {:.4f} E_kl: {:.4f} Rec: {:.4f}'
                .format(i, j, val_steps, dis_losses_avg[1], dis_losses_avg[2],
                        dis_losses_avg[3], dis_losses_avg[4],
                        gen_losses_avg[1], gen_losses_avg[2],
                        enc_losses_avg[1], enc_losses_avg[2]),
                end="",
                flush=True)
Example #3
import os

import cv2
import numpy as np
from keras.optimizers import RMSprop

# Project-local helpers assumed importable from the surrounding SAVP
# codebase, as in Example #2: encoder, generator, discriminator, build_graph,
# set_trainable, kl_loss, video_generator, vaegan_loader, encoder_loader,
# count_images, write_to_video.


def main(args):
    video_gen = video_generator(video_dir=args.test_directory,
                                frame_height=args.frame_height,
                                frame_width=args.frame_width,
                                frame_channels=args.frame_channels,
                                batch_size=args.batch_size,
                                time=args.time,
                                camera_fps=args.camera_fps,
                                json_filename="kinetics_test.json")

    rmsprop = RMSprop(lr=0.0003)

    predict_gen = vaegan_loader(video_gen,
                                time_init=2 * args.time - 1,
                                latent_dim=args.latent_dim)
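    # vaegan_loader appears to yield (inputs, next_actual_frames) pairs
    # covering the full 2*time-1 prediction window.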

    enc = encoder(time=2 * args.time - 1,
                  latent_dim=args.latent_dim,
                  batch_size=args.batch_size,
                  frame_height=args.frame_height,
                  frame_width=args.frame_width,
                  frame_channels=args.frame_channels)
    gen = generator(time=args.time,
                    latent_dim=args.latent_dim,
                    batch_size=args.batch_size,
                    frame_height=args.frame_height,
                    frame_width=args.frame_width,
                    frame_channels=args.frame_channels)
    dis_gan = discriminator(time=2 * args.time - 1,
                            name='gan',
                            batch_size=args.batch_size,
                            frame_height=args.frame_height,
                            frame_width=args.frame_width,
                            frame_channels=args.frame_channels)
    dis_vae = discriminator(time=2 * args.time - 1,
                            name='vae',
                            batch_size=args.batch_size,
                            frame_height=args.frame_height,
                            frame_width=args.frame_width,
                            frame_channels=args.frame_channels)

    encoder_train, generator_train, discriminator_train, vaegan = build_graph(
        enc,
        gen,
        dis_gan,
        dis_vae,
        time=args.time,
        latent_dim=args.latent_dim,
        frame_channels=args.frame_channels)

    set_trainable(dis_vae, True)
    set_trainable(enc, False)
    set_trainable(gen, False)
    discriminator_train.compile(rmsprop, [
        'binary_crossentropy', 'binary_crossentropy', 'binary_crossentropy',
        'binary_crossentropy'
    ],
                                loss_weights=[1., 1., 1., 1.])
    discriminator_train.summary()

    set_trainable(dis_vae, False)
    set_trainable(dis_gan, False)
    set_trainable(gen, True)
    generator_train.compile(
        rmsprop,
        ['binary_crossentropy', 'binary_crossentropy', 'mean_absolute_error'],
        loss_weights=[1., 1., 5.])
    generator_train.summary()

    set_trainable(gen, False)
    set_trainable(enc, True)
    encoder_train.compile(rmsprop, [kl_loss, 'mean_absolute_error'])
    encoder_train.summary()

    if args.load_epoch > 0:
        # discriminator_train.load_weights('../weights/discriminator.{:03d}.h5'.format(args.load_epoch))
        # generator_train.load_weights('../weights/generator.{:03d}.h5'.format(args.load_epoch))
        encoder_train.load_weights('../weights/encoder.{:03d}.h5'.format(
            args.load_epoch))
        print('Loaded weights {:03d}'.format(args.load_epoch))
    # gen_single = generator(time=1, latent_dim=args.latent_dim, frame_height=args.frame_height, frame_width=args.frame_width, frame_channels=args.frame_channels)
    # gen_single.compile(rmsprop, ['mean_absolute_error'])
    # gen_single.load_weights(args.weights_directory, by_name=True)
    # gen_single.summary()
    # for j in range(len(glob.glob(os.path.join(args.test_directory,'*.mp4')))):

    seed = 0

    val_video_dir = os.path.join(args.test_directory, '*_act_14_*/*.jpg')
    val_steps = count_images(val_video_dir,
                             batch_size=args.batch_size,
                             time=args.time)

    # val_video_dir = os.path.join(args.test_directory, '*.avi')
    # val_steps = count_frames(val_video_dir, batch_size=args.batch_size, time=args.time, camera_fps=args.camera_fps)

    video_loader_val_enc = video_generator(video_dir=args.test_directory,
                                           frame_height=args.frame_height,
                                           frame_width=args.frame_width,
                                           frame_channels=args.frame_channels,
                                           batch_size=args.batch_size,
                                           time=args.time,
                                           camera_fps=args.camera_fps,
                                           json_filename="kinetics_val.json")
    enc_loader_val = encoder_loader(video_loader_val_enc,
                                    latent_dim=args.latent_dim,
                                    seed=seed,
                                    time_init=args.time)
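    # Sanity-check the encoder's KL / reconstruction metrics on the
    # validation split before generating any videos.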
    enc_losses_val = encoder_train.evaluate_generator(enc_loader_val,
                                                      steps=val_steps,
                                                      verbose=1)
    print('Encoder: ')
    print(encoder_train.metrics_names)
    print(enc_losses_val)

    j = 0
    while True:
        j = j + 1
        inputs, next_actual_frames = next(predict_gen)
        # print(inputs.shape)
        # next_pred_frames = np.empty_like(next_actual_frames)
        # next_pred_frames[:,0] = gen.predict([np.repeat(np.expand_dims(inputs[0][:,0], axis=0), args.time, axis=1), np.repeat(inputs[1][0], args.time, axis=1)], batch_size=args.batch_size)[:,0]
        # for k in range(1, args.time):
        # 	next_pred_frames[:,k] = gen.predict([np.expand_dims(next_pred_frames[:,k-1], axis=0), inputs[1][k]]*args.time, batch_size=args.batch_size)[:,k]
        # out = dis_gan.predict(next_pred_frames, batch_size=args.batch_size)
        # print(out)
        # current_pred_frames = vaegan.predict(inputs, batch_size=args.batch_size)
        # current_pred_frames[:,0] = x
        # for i in range(30):
        # next_pred_frames = vaegan.predict(inputs, batch_size=args.batch_size)
        next_pred_frames = np.empty_like(next_actual_frames)
        # print(next_pred_frames.shape)
        # next_pred_frames = gen.predict(inputs, batch_size= args.batch_size)
        # next_pred_frames[:,0:args.time - 1] = inputs[1][:, 1:args.time]
        # Autoregressive rollout: predict args.time frames once, then keep
        # sliding the last predicted frame into the conditioning window to
        # extend the sequence one frame at a time.
        for i in range(args.time):
            next_gen_frames = gen.predict(inputs, batch_size=args.batch_size)
            if i == 0:
                next_pred_frames[:, :args.time] = next_gen_frames
            else:
                next_pred_frames[:, args.time + i - 1] = \
                    next_gen_frames[:, args.time - 1]
            inputs[1][:, :args.time - 1] = inputs[1][:, 1:args.time]
            inputs[1][:, args.time - 1] = next_gen_frames[:, args.time - 1]
        # losses = vaegan.evaluate(inputs, next_actual_frames, batch_size=args.batch_size)
        # print("losses", losses)
        for i in range(next_pred_frames.shape[1]):
            # cv2.imshow('previous', inputs[1][0,i])
            # cv2.imshow('next pred {:04d}'.format(j), next_pred_frames[0,i])
            # cv2.imshow('next actual {:04d}'.format(j), next_actual_frames[0,i])
            cv2.imshow('next pred', next_pred_frames[0, i])
            cv2.imshow('next actual', next_actual_frames[0, i])
            # filename_pred = "results_images/predicted_{:04d}_{:04d}.jpg".format(j,i)
            # filename_actual = "results_images/actual_{:04d}_{:04d}.jpg".format(j,i)
            # cv2.imwrite(filename_pred, (255.*next_pred_frames[0,i]).astype(int))
            # cv2.imwrite(filename_actual, (255.*next_actual_frames[0,i]).astype(int))
            cv2.waitKey(0)
            # inputs[0] = np.expand_dims(next_actual_frames[0, next_pred_frames.shape[1] - 1], axis=0)
            # inputs[3] = next_pred_frames

        # out = dis_gan.predict(next_pred_frames, batch_size=args.batch_size)
        # # print('Discriminator: ', out)
        # next_pred_frames = np.empty_like(next_actual_frames)
        # # next_pred_frames = vaegan.predict(inputs)
        # # next_pred_frames = gen_single.predict(inputs)
        # # prev_frame = np.expand_dims(x, axis=0)
        # for i in range(args.time):
        # # 	z = z_p[:,i]
        # # 	z = np.expand_dims(z, axis=1)
        # 	pred_frame = gen_single.predict([inputs[0], np.expand_dims(inputs[1][:,i], axis=1), np.expand_dims(inputs[2][:,i], axis=1)], batch_size=args.batch_size)
        # 	next_pred_frames[0,i] = pred_frame.squeeze()
        # # 	loss = gen_single.evaluate([x, prev_frame, z], np.expand_dims(next_actual_frames[:,i], axis=1), batch_size=args.batch_size)
        # # 	print('Generator loss: ', loss.squeeze())
        # # 	# pred_frames = vaegan.predict(inputs, batch_size=args.batch_size)
        # 	cv2.imshow('previous', inputs[1][0,i])
        # 	cv2.imshow('next pred', next_pred_frames[0,i])
        # 	cv2.imshow('next actual', next_actual_frames[0,i])
        # 	cv2.waitKey(0)
        # # 	next_pred_frames[:, i] = pred_frame
        # # 	prev_frame = pred_frame
        # 	# print(i)
        # 	# print(np.sum(pred_frames[:,0]))
        # 	# inputs[1] = pred_frames
        # prev_frame = pred_frame
        # for k in range(args.time, 2*args.time-1):
        # 	pred_frame = gen_single.predict([inputs[0], prev_frame, np.expand_dims(inputs[2][:,k], axis=1)], batch_size=args.batch_size)
        # 	next_pred_frames[0,k] = pred_frame.squeeze()
        # # 	loss = gen_single.evaluate([x, prev_frame, z], np.expand_dims(next_actual_frames[:,i], axis=1), batch_size=args.batch_size)
        # # 	print('Generator loss: ', loss.squeeze())
        # # 	# pred_frames = vaegan.predict(inputs, batch_size=args.batch_size)
        # 	# cv2.imshow('previous', inputs[1][0,j])
        # 	cv2.imshow('next pred', next_pred_frames[0,k])
        # 	cv2.imshow('next actual', next_actual_frames[0,k])
        # 	cv2.waitKey(0)
        # 	prev_frame = pred_frame

        # out = dis_gan.predict(next_pred_frames, batch_size=args.batch_size)
        # print('Discriminator for predicted: ', out.squeeze())
        # out = dis_gan.predict(next_actual_frames, batch_size=args.batch_size)
        # print('Discriminator for actual: ', out.squeeze())

        # encoder_out = encoder_train.evaluate(
        # 	[inputs[0], inputs[1], inputs[3], inputs[2]],
        # 	[inputs[4], next_actual_frames],
        # 	batch_size=args.batch_size)
        # print("Encoder {:04d} KL: {:.4f} Rec: {:.4f}".format(j, encoder_out[1], encoder_out[2]))

        # cv2.waitKey(0)

        # cv2.destroyAllWindows()
        # next_pred_frames = np.empty_like(frames)

        # for i in range(len(next_actual_frames.squeeze())):
        # 	if i == 0:
        # 		next_pred_frames[0,i] = model.predict(np.expand_dims(frames[:,i],axis=0), batch_size=args.batch_size)
        # 	else:
        # 		next_pred_frames[0,i] = model.predict(np.expand_dims(next_pred_frames[:,i-1],axis=0), batch_size=args.batch_size)

        # frames = frames.squeeze()
        next_pred_frames = next_pred_frames.squeeze()
        next_actual_frames = next_actual_frames.squeeze()
        # print(next_actual_frames.shape)
        # current_pred_frames = current_pred_frames.squeeze()

        write_to_video(next_actual_frames,
                       '../results/{:03d}_actual.avi'.format(j),
                       frame_height=args.frame_height,
                       frame_width=args.frame_width,
                       video_fps=2)
        write_to_video(next_pred_frames,
                       '../results/{:03d}_pred.avi'.format(j),
                       frame_height=args.frame_height,
                       frame_width=args.frame_width,
                       video_fps=2)