def get_dataset(dataset_directory: str,
                dataset_class_name: str,
                dataset_hparams_dict: Union[str, dict],
                dataset_hparams: str,
                mode: str,
                epochs: Optional[int],
                batch_size: int,
                seed: int,
                balance_key: Optional[str] = None,
                shuffle: bool = True):
    if isinstance(dataset_hparams_dict, str):
        with open(dataset_hparams_dict, 'r') as f:
            dataset_hparams_dict = json.load(f)

    dataset_class = get_dataset_class(dataset_class_name)
    my_dataset = dataset_class(dataset_directory,
                               mode=mode,
                               num_epochs=epochs,
                               seed=seed,
                               hparams_dict=dataset_hparams_dict,
                               hparams=dataset_hparams)

    if balance_key is not None:
        tf_dataset = my_dataset.make_dataset(batch_size=batch_size,
                                             use_batches=False)
        tf_dataset = balance_xy_dataset(tf_dataset, balance_key)
        tf_dataset = tf_dataset.batch(batch_size)
    else:
        tf_dataset = my_dataset.make_dataset(batch_size, shuffle=shuffle)

    return my_dataset, tf_dataset
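

# A minimal usage sketch for get_dataset, assuming a dataset class registered
# under the (hypothetical) name 'kth' and an hparams JSON on disk; none of the
# paths or values below are confirmed by this listing:
#
#     my_dataset, tf_dataset = get_dataset(
#         dataset_directory='data/kth',
#         dataset_class_name='kth',
#         dataset_hparams_dict='hparams/kth/dataset_hparams.json',
#         dataset_hparams='sequence_length=20',
#         mode='train',
#         epochs=None,  # Optional[int]; None presumably means no epoch limit
#         batch_size=16,
#         seed=0,
#         balance_key=None)  # set to a feature key to rebalance the dataset
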
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir",
                        type=str,
                        required=True,
                        help="either a directory containing subdirectories "
                        "train, val, test, etc, or a directory containing "
                        "the tfrecords")
    parser.add_argument(
        "--val_input_dirs",
        type=str,
        nargs='+',
        help="directories containing the tfrecords. default: [input_dir]")
    parser.add_argument("--logs_dir",
                        default='logs',
                        help="ignored if output_dir is specified")
    parser.add_argument(
        "--output_dir",
        help=
        "output directory where json files, summary, model, gifs, etc are saved. "
        "default is logs_dir/model_fname, where model_fname consists of "
        "information from model and model_hparams")
    parser.add_argument(
        "--checkpoint",
        help=
        "directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)"
    )
    parser.add_argument("--resume",
                        action='store_true',
                        help='resume from latest checkpoint in output_dir.')

    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument(
        "--dataset_hparams",
        type=str,
        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--dataset_hparams_dict",
                        type=str,
                        help="a json file of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument(
        "--model_hparams",
        type=str,
        help="a string of comma separated list of model hyperparameters")
    parser.add_argument("--model_hparams_dict",
                        type=str,
                        help="a json file of model hyperparameters")

    parser.add_argument(
        "--summary_freq",
        type=int,
        default=1000,
        help=
        "save summaries (except for image and eval summaries) every summary_freq steps"
    )
    parser.add_argument(
        "--image_summary_freq",
        type=int,
        default=5000,
        help="save image summaries every image_summary_freq steps")
    parser.add_argument(
        "--eval_summary_freq",
        type=int,
        default=0,
        help="save eval summaries every eval_summary_freq steps")
    parser.add_argument("--progress_freq",
                        type=int,
                        default=100,
                        help="display progress every progress_freq steps")
    parser.add_argument("--metrics_freq",
                        type=int,
                        default=0,
                        help="run and display metrics every metrics_freq step")
    parser.add_argument(
        "--gif_freq",
        type=int,
        default=0,
        help="save gifs of predicted frames every gif_freq steps")
    parser.add_argument("--save_freq",
                        type=int,
                        default=5000,
                        help="save model every save_freq steps, 0 to disable")

    parser.add_argument("--gpu_mem_frac",
                        type=float,
                        default=0,
                        help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int)

    args = parser.parse_args()

    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    if args.output_dir is None:
        list_depth = 0
        model_fname = ''
        for t in ('model=%s,%s' % (args.model, args.model_hparams)):
            if t == '[':
                list_depth += 1
            if t == ']':
                list_depth -= 1
            if list_depth and t == ',':
                t = '..'
            if t in '=,':
                t = '.'
            if t in '[]':
                t = ''
            model_fname += t
        args.output_dir = os.path.join(args.logs_dir, model_fname)
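        # e.g. model='savp' with model_hparams='l1_weight=1.0,kernel_size=[3,3]'
        # yields model_fname 'model.savp.l1_weight.1.0.kernel_size.3..3'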

    if args.resume:
        if args.checkpoint:
            raise ValueError('resume and checkpoint cannot both be specified')
        args.checkpoint = args.output_dir

    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.dataset_hparams_dict:
        with open(args.dataset_hparams_dict) as f:
            dataset_hparams_dict.update(json.loads(f.read()))
    if args.model_hparams_dict:
        with open(args.model_hparams_dict) as f:
            model_hparams_dict.update(json.loads(f.read()))
    if args.checkpoint:
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.isdir(checkpoint_dir):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                                    checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir,
                                   "dataset_hparams.json")) as f:
                dataset_hparams_dict.update(json.loads(f.read()))
        except FileNotFoundError:
            print(
                "dataset_hparams.json was not loaded because it does not exist"
            )
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict.update(json.loads(f.read()))
                model_hparams_dict.pop('num_gpus',
                                       None)  # backwards-compatibility
        except FileNotFoundError:
            print(
                "model_hparams.json was not loaded because it does not exist")

    print(
        '----------------------------------- Options ------------------------------------'
    )
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print(
        '------------------------------------- End --------------------------------------'
    )

    VideoDataset = datasets.get_dataset_class(args.dataset)
    train_dataset = VideoDataset(args.input_dir,
                                 mode='train',
                                 hparams_dict=dataset_hparams_dict,
                                 hparams=args.dataset_hparams)
    val_input_dirs = args.val_input_dirs or [args.input_dir]
    val_datasets = [
        VideoDataset(val_input_dir,
                     mode='val',
                     hparams_dict=dataset_hparams_dict,
                     hparams=args.dataset_hparams)
        for val_input_dir in val_input_dirs
    ]
    if len(val_input_dirs) > 1:
        if isinstance(val_datasets[-1], datasets.KTHVideoDataset):
            val_datasets[-1].set_sequence_length(40)
        else:
            val_datasets[-1].set_sequence_length(30)

    def override_hparams_dict(dataset):
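        # copy the model hparams and mirror the dataset's frame-count settings;
        # the dataset's time_shift is exposed to the model as 'repeat'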
        hparams_dict = dict(model_hparams_dict)
        hparams_dict['context_frames'] = dataset.hparams.context_frames
        hparams_dict['sequence_length'] = dataset.hparams.sequence_length
        hparams_dict['repeat'] = dataset.hparams.time_shift
        return hparams_dict

    VideoPredictionModel = models.get_model_class(args.model)
    train_model = VideoPredictionModel(
        mode='train',
        hparams_dict=override_hparams_dict(train_dataset),
        hparams=args.model_hparams)
    val_models = [
        VideoPredictionModel(mode='val',
                             hparams_dict=override_hparams_dict(val_dataset),
                             hparams=args.model_hparams)
        for val_dataset in val_datasets
    ]

    batch_size = train_model.hparams.batch_size
    with tf.variable_scope('') as training_scope:
        train_model.build_graph(*train_dataset.make_batch(batch_size))
    for val_model, val_dataset in zip(val_models, val_datasets):
        with tf.variable_scope(training_scope, reuse=True):
            val_model.build_graph(*val_dataset.make_batch(batch_size))

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    with open(os.path.join(args.output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
    with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f:
        f.write(
            json.dumps(train_dataset.hparams.values(),
                       sort_keys=True,
                       indent=4))
    with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f:
        f.write(
            json.dumps(train_model.hparams.values(), sort_keys=True, indent=4))

    if args.gif_freq:
        val_model = val_models[0]
        val_tensors = OrderedDict()
        context_frames = val_model.hparams.context_frames
        context_images = val_model.inputs['images'][:, :context_frames]
        val_tensors['gen_images_vis'] = tf.concat(
            [context_images, val_model.gen_images], axis=1)
        if val_model.gen_images_enc is not None:
            val_tensors['gen_images_enc_vis'] = tf.concat(
                [context_images, val_model.gen_images_enc], axis=1)
        val_tensors.update({
            name: tensor
            for name, tensor in val_model.inputs.items()
            if tensor.shape.ndims >= 4
        })
        val_tensors['targets'] = val_model.targets
        val_tensors.update({
            name: tensor
            for name, tensor in val_model.outputs.items()
            if tensor.shape.ndims >= 4
        })
        val_tensor_clips = OrderedDict([
            (name, tf_utils.tensor_to_clip(output))
            for name, output in val_tensors.items()
        ])

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    saver = tf.train.Saver(max_to_keep=3)
    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
    image_summaries = set(tf.get_collection(tf_utils.IMAGE_SUMMARIES))
    eval_summaries = set(tf.get_collection(tf_utils.EVAL_SUMMARIES))
    eval_image_summaries = image_summaries & eval_summaries
    image_summaries -= eval_image_summaries
    eval_summaries -= eval_image_summaries
    if args.summary_freq:
        summary_op = tf.summary.merge(summaries)
    if args.image_summary_freq:
        image_summary_op = tf.summary.merge(list(image_summaries))
    if args.eval_summary_freq:
        eval_summary_op = tf.summary.merge(list(eval_summaries))
        eval_image_summary_op = tf.summary.merge(list(eval_image_summaries))

    if args.summary_freq or args.image_summary_freq or args.eval_summary_freq:
        summary_writer = tf.summary.FileWriter(args.output_dir)

    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    global_step = tf.train.get_or_create_global_step()
    max_steps = train_model.hparams.max_steps
    with tf.Session(config=config) as sess:
        print("parameter_count =", sess.run(parameter_count))

        sess.run(tf.global_variables_initializer())
        train_model.restore(sess, args.checkpoint)

        start_step = sess.run(global_step)
        # start at one step earlier to log everything without doing any training
        # step is relative to the start_step
        for step in range(-1, max_steps - start_step):
            if step == 0:
                start = time.time()

            def should(freq):
                return freq and ((step + 1) % freq == 0 or
                                 (step + 1) in (0, max_steps - start_step))
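            # should(freq) also fires at step == -1 (step + 1 == 0), so each
            # quantity is logged once before any training happens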

            fetches = {"global_step": global_step}
            if step >= 0:
                fetches["train_op"] = train_model.train_op

            if should(args.progress_freq):
                fetches['d_losses'] = train_model.d_losses
                fetches['g_losses'] = train_model.g_losses
                if isinstance(train_model.learning_rate, tf.Tensor):
                    fetches["learning_rate"] = train_model.learning_rate
            if should(args.metrics_freq):
                fetches['metrics'] = train_model.metrics
            if should(args.summary_freq):
                fetches["summary"] = summary_op
            if should(args.image_summary_freq):
                fetches["image_summary"] = image_summary_op
            if should(args.eval_summary_freq):
                fetches["eval_summary"] = eval_summary_op
                fetches["eval_image_summary"] = eval_image_summary_op

            run_start_time = time.time()
            results = sess.run(fetches)
            run_elapsed_time = time.time() - run_start_time
            if run_elapsed_time > 1.5:
                print('session.run took %0.1fs' % run_elapsed_time)

            if should(args.summary_freq):
                print("recording summary")
                summary_writer.add_summary(results["summary"],
                                           results["global_step"])
                print("done")
            if should(args.image_summary_freq):
                print("recording image summary")
                summary_writer.add_summary(
                    tf_utils.convert_tensor_to_gif_summary(
                        results["image_summary"]), results["global_step"])
                print("done")
            if should(args.eval_summary_freq):
                print("recording eval summary")
                summary_writer.add_summary(results["eval_summary"],
                                           results["global_step"])
                summary_writer.add_summary(
                    tf_utils.convert_tensor_to_gif_summary(
                        results["eval_image_summary"]), results["global_step"])
                print("done")
            if should(args.summary_freq) or should(
                    args.image_summary_freq) or should(args.eval_summary_freq):
                summary_writer.flush()
            if should(args.progress_freq):
                # global_step will have the correct step count if we resume from a checkpoint
                steps_per_epoch = math.ceil(
                    train_dataset.num_examples_per_epoch() / batch_size)
                train_epoch = math.ceil(results["global_step"] /
                                        steps_per_epoch)
                train_step = (results["global_step"] - 1) % steps_per_epoch + 1
                print("progress  global step %d  epoch %d  step %d" %
                      (results["global_step"], train_epoch, train_step))
                if step >= 0:
                    elapsed_time = time.time() - start
                    average_time = elapsed_time / (step + 1)
                    images_per_sec = batch_size / average_time
                    remaining_time = (max_steps -
                                      (start_step + step)) * average_time
                    print(
                        "          image/sec %0.1f  remaining %dm (%0.1fh) (%0.1fd)"
                        %
                        (images_per_sec, remaining_time / 60, remaining_time /
                         60 / 60, remaining_time / 60 / 60 / 24))

                for name, loss in itertools.chain(results['d_losses'].items(),
                                                  results['g_losses'].items()):
                    print(name, loss)
                if isinstance(train_model.learning_rate, tf.Tensor):
                    print("learning_rate", results["learning_rate"])
            if should(args.metrics_freq):
                for name, metric in results['metrics'].items():
                    print(name, metric)

            if should(args.save_freq):
                print("saving model to", args.output_dir)
                saver.save(sess,
                           os.path.join(args.output_dir, "model"),
                           global_step=global_step)
                print("done")

            if should(args.gif_freq):
                image_dir = os.path.join(args.output_dir, 'images')
                if not os.path.exists(image_dir):
                    os.makedirs(image_dir)

                gif_clips = sess.run(val_tensor_clips)
                gif_step = results["global_step"]
                for name, clip in gif_clips.items():
                    filename = "%08d-%s.gif" % (gif_step, name)
                    print("saving gif to", os.path.join(image_dir, filename))
                    ffmpeg_gif.save_gif(os.path.join(image_dir, filename),
                                        clip,
                                        fps=4)
                    print("done")
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir",
                        type=str,
                        required=True,
                        help="either a directory containing subdirectories "
                        "train, val, test, etc, or a directory containing "
                        "the tfrecords")
    parser.add_argument("--results_dir",
                        type=str,
                        default='results',
                        help="ignored if output_gif_dir is specified")
    parser.add_argument(
        "--results_gif_dir",
        type=str,
        help="default is results_dir. ignored if output_gif_dir is specified")
    parser.add_argument(
        "--results_png_dir",
        type=str,
        help="default is results_dir. ignored if output_png_dir is specified")
    parser.add_argument(
        "--output_gif_dir",
        help="output directory where samples are saved as gifs. default is "
        "results_gif_dir/model_fname")
    parser.add_argument(
        "--output_png_dir",
        help="output directory where samples are saved as pngs. default is "
        "results_png_dir/model_fname")
    parser.add_argument(
        "--checkpoint",
        help=
        "directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)"
    )

    parser.add_argument("--mode",
                        type=str,
                        choices=['val', 'test'],
                        default='val',
                        help='mode for dataset, val or test.')

    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument(
        "--dataset_hparams",
        type=str,
        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument(
        "--model_hparams",
        type=str,
        help="a string of comma separated list of model hyperparameters")

    parser.add_argument("--batch_size",
                        type=int,
                        default=8,
                        help="number of samples in batch")
    parser.add_argument(
        "--num_samples",
        type=int,
        help="number of samples in total (all of them by default)")
    parser.add_argument("--num_epochs", type=int, default=1)

    parser.add_argument("--num_stochastic_samples", type=int, default=5)
    parser.add_argument("--gif_length",
                        type=int,
                        help="default is sequence_length")
    parser.add_argument("--fps", type=int, default=4)

    parser.add_argument("--gpu_mem_frac",
                        type=float,
                        default=0,
                        help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int, default=7)

    args = parser.parse_args()

    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    args.results_gif_dir = args.results_gif_dir or args.results_dir
    args.results_png_dir = args.results_png_dir or args.results_dir
    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.checkpoint:
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                                    checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir,
                                   "dataset_hparams.json")) as f:
                dataset_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print(
                "dataset_hparams.json was not loaded because it does not exist"
            )
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print(
                "model_hparams.json was not loaded because it does not exist")
        args.output_gif_dir = args.output_gif_dir or os.path.join(
            args.results_gif_dir,
            os.path.split(checkpoint_dir)[1])
        args.output_png_dir = args.output_png_dir or os.path.join(
            args.results_png_dir,
            os.path.split(checkpoint_dir)[1])
    else:
        if not args.dataset:
            raise ValueError(
                'dataset is required when checkpoint is not specified')
        if not args.model:
            raise ValueError(
                'model is required when checkpoint is not specified')
        args.output_gif_dir = args.output_gif_dir or os.path.join(
            args.results_gif_dir, 'model.%s' % args.model)
        args.output_png_dir = args.output_png_dir or os.path.join(
            args.results_png_dir, 'model.%s' % args.model)

    print(
        '----------------------------------- Options ------------------------------------'
    )
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print(
        '------------------------------------- End --------------------------------------'
    )

    VideoDataset = datasets.get_dataset_class(args.dataset)
    dataset = VideoDataset(args.input_dir,
                           mode=args.mode,
                           num_epochs=args.num_epochs,
                           seed=args.seed,
                           hparams_dict=dataset_hparams_dict,
                           hparams=args.dataset_hparams)

    VideoPredictionModel = models.get_model_class(args.model)
    hparams_dict = dict(model_hparams_dict)
    hparams_dict.update({
        'context_frames': dataset.hparams.context_frames,
        'sequence_length': dataset.hparams.sequence_length,
        'repeat': dataset.hparams.time_shift,
    })
    model = VideoPredictionModel(mode=args.mode,
                                 hparams_dict=hparams_dict,
                                 hparams=args.model_hparams)

    sequence_length = model.hparams.sequence_length
    context_frames = model.hparams.context_frames
    future_length = sequence_length - context_frames

    if args.num_samples:
        if args.num_samples > dataset.num_examples_per_epoch():
            raise ValueError('num_samples cannot be larger than the dataset')
        num_examples_per_epoch = args.num_samples
    else:
        num_examples_per_epoch = dataset.num_examples_per_epoch()
    if num_examples_per_epoch % args.batch_size != 0:
        raise ValueError(
            'batch_size should evenly divide the dataset size %d' %
            num_examples_per_epoch)

    inputs = dataset.make_batch(args.batch_size)
    input_phs = {
        k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k)
        for k, v in inputs.items()
    }
    with tf.variable_scope(''):
        model.build_graph(input_phs)

    for output_dir in (args.output_gif_dir, args.output_png_dir):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        with open(os.path.join(output_dir, "options.json"), "w") as f:
            f.write(json.dumps(vars(args), sort_keys=True, indent=4))
        with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
            f.write(
                json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
        with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
            f.write(
                json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    sess = tf.Session(config=config)

    model.restore(sess, args.checkpoint)

    sample_ind = 0
    key = os.path.basename(os.path.normpath(args.input_dir))
    # the gifs/pngs are written into per-dataset subdirectories named after
    # the input directory, so create them up front
    for dir_ in (args.output_gif_dir, args.output_png_dir):
        if not os.path.exists(os.path.join(dir_, key)):
            os.makedirs(os.path.join(dir_, key))
    while True:
        if args.num_samples and sample_ind >= args.num_samples:
            break
        try:
            input_results = sess.run(inputs)
        except tf.errors.OutOfRangeError:
            break
        print("evaluation samples from %d to %d" %
              (sample_ind, sample_ind + args.batch_size))
        feed_dict = {
            input_ph: input_results[name]
            for name, input_ph in input_phs.items()
        }
        # generate args.num_stochastic_samples predictions per input batch
        for stochastic_sample_ind in range(args.num_stochastic_samples):
            gen_images = sess.run(model.outputs['gen_images'],
                                  feed_dict=feed_dict)
            # only keep the future frames
            gen_images = gen_images[:, -future_length:]
            for i, gen_images_ in enumerate(gen_images):
                context_images_ = (input_results['images'][i] * 255.0).astype(
                    np.uint8)
                gen_images_ = (gen_images_ * 255.0).astype(np.uint8)

                gen_images_fname = 'gen_image_%05d_%02d.gif' % (
                    sample_ind + i, stochastic_sample_ind)
                context_and_gen_images = list(
                    context_images_[:context_frames]) + list(gen_images_)
                if args.gif_length:
                    context_and_gen_images = context_and_gen_images[:args.gif_length]
                save_gif(os.path.join(args.output_gif_dir, key,
                                      gen_images_fname),
                         context_and_gen_images,
                         fps=args.fps)

                gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(
                    2, len(str(len(gen_images_) - 1)))
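                # e.g. 10 generated frames -> 'gen_image_%05d_%02d_%02d.png'
                # (sample index, stochastic sample index, time step)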
                for t, gen_image in enumerate(gen_images_):
                    gen_image_fname = gen_image_fname_pattern % (
                        sample_ind + i, stochastic_sample_ind, t)
                    if gen_image.shape[-1] == 1:
                        gen_image = np.tile(gen_image, (1, 1, 3))
                    else:
                        gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
                    cv2.imwrite(
                        os.path.join(args.output_png_dir, key,
                                     gen_image_fname), gen_image)

        sample_ind += args.batch_size
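

# A hedged example invocation of the sampling entry point above; the script
# filename and paths are placeholders:
#
#     python generate.py --input_dir data/my_dataset \
#         --checkpoint logs/experiment_name --mode test \
#         --results_dir results --fps 4
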
def main():
    """
    results_dir
    ├── output_dir                              # condition / method
    │   ├── prediction_eval_lpips_max           # task: best sample in terms of LPIPS similarity
    │   │   ├── inputs
    │   │   │   ├── context_image_00000_00.png  # indexed by sample index and time step
    │   │   │   └── ...
    │   │   ├── outputs
    │   │   │   ├── gen_image_00000_00.png      # predicted images (only the future ones)
    │   │   │   └── ...
    │   │   └── metrics
    │   │       └── lpips.csv
    │   ├── prediction_eval_ssim_max            # task: best sample in terms of SSIM
    │   │   ├── inputs
    │   │   │   ├── context_image_00000_00.png  # indexed by sample index and time step
    │   │   │   └── ...
    │   │   ├── outputs
    │   │   │   ├── gen_image_00000_00.png      # predicted images (only the future ones)
    │   │   │   └── ...
    │   │   └── metrics
    │   │       └── ssim.csv
    │   └── ...
    └── ...
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir",
                        type=str,
                        required=True,
                        help="either a directory containing subdirectories "
                        "train, val, test, etc, or a directory containing "
                        "the tfrecords")
    parser.add_argument("--results_dir",
                        type=str,
                        default='results',
                        help="ignored if output_dir is specified")
    parser.add_argument(
        "--output_dir",
        help=
        "output directory where results are saved. default is results_dir/model_fname, "
        "where model_fname is the directory name of checkpoint")
    parser.add_argument(
        "--checkpoint",
        help=
        "directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)"
    )

    parser.add_argument("--mode",
                        type=str,
                        choices=['val', 'test'],
                        default='val',
                        help='mode for dataset, val or test.')

    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument(
        "--dataset_hparams",
        type=str,
        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument(
        "--model_hparams",
        type=str,
        help="a string of comma separated list of model hyperparameters")

    parser.add_argument("--batch_size",
                        type=int,
                        default=8,
                        help="number of samples in batch")
    parser.add_argument(
        "--num_samples",
        type=int,
        help="number of samples in total (all of them by default)")
    parser.add_argument("--num_epochs", type=int, default=1)

    parser.add_argument("--eval_substasks",
                        type=str,
                        nargs='+',
                        default=['max', 'avg', 'min'],
                        help='subtasks to evaluate (e.g. max, avg, min)')
    parser.add_argument("--only_metrics", action='store_true')
    parser.add_argument("--num_stochastic_samples", type=int, default=100)

    parser.add_argument(
        "--gt_inputs_dir",
        type=str,
        help="directory containing input ground truth images for simple dataset"
    )
    parser.add_argument(
        "--gt_outputs_dir",
        type=str,
        help=
        "directory containing output ground truth images for simple dataset")

    parser.add_argument("--eval_parallel_iterations", type=int, default=10)
    parser.add_argument("--gpu_mem_frac",
                        type=float,
                        default=0,
                        help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int, default=7)

    args = parser.parse_args()

    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.checkpoint:
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                                    checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir,
                                   "dataset_hparams.json")) as f:
                dataset_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print(
                "dataset_hparams.json was not loaded because it does not exist"
            )
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print(
                "model_hparams.json was not loaded because it does not exist")
        args.output_dir = args.output_dir or os.path.join(
            args.results_dir,
            os.path.split(checkpoint_dir)[1])
    else:
        if not args.dataset:
            raise ValueError(
                'dataset is required when checkpoint is not specified')
        if not args.model:
            raise ValueError(
                'model is required when checkpoint is not specified')
        args.output_dir = args.output_dir or os.path.join(
            args.results_dir, 'model.%s' % args.model)

    print(
        '----------------------------------- Options ------------------------------------'
    )
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print(
        '------------------------------------- End --------------------------------------'
    )

    VideoDataset = datasets.get_dataset_class(args.dataset)
    dataset = VideoDataset(args.input_dir,
                           mode=args.mode,
                           num_epochs=args.num_epochs,
                           seed=args.seed,
                           hparams_dict=dataset_hparams_dict,
                           hparams=args.dataset_hparams)

    VideoPredictionModel = models.get_model_class(args.model)
    hparams_dict = dict(model_hparams_dict)
    hparams_dict.update({
        'context_frames': dataset.hparams.context_frames,
        'sequence_length': dataset.hparams.sequence_length,
        'repeat': dataset.hparams.time_shift,
    })
    model = VideoPredictionModel(
        hparams_dict=hparams_dict,
        hparams=args.model_hparams,
        eval_num_samples=args.num_stochastic_samples,
        eval_parallel_iterations=args.eval_parallel_iterations)

    if args.num_samples:
        if args.num_samples > dataset.num_examples_per_epoch():
            raise ValueError('num_samples cannot be larger than the dataset')
        num_examples_per_epoch = args.num_samples
    else:
        num_examples_per_epoch = dataset.num_examples_per_epoch()
    if num_examples_per_epoch % args.batch_size != 0:
        raise ValueError(
            'batch_size should evenly divide the dataset size %d' %
            num_examples_per_epoch)

    inputs = dataset.make_batch(args.batch_size)
    input_phs = {
        k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k)
        for k, v in inputs.items()
    }
    with tf.variable_scope(''):
        model.build_graph(input_phs)

    output_dir = args.output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    with open(os.path.join(output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
    with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
        f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
    with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    sess = tf.Session(config=config)

    model.restore(sess, args.checkpoint)

    sample_ind = 0
    while True:
        if args.num_samples and sample_ind >= args.num_samples:
            break
        try:
            input_results = sess.run(inputs)
        except tf.errors.OutOfRangeError:
            break
        print("evaluation samples from %d to %d" %
              (sample_ind, sample_ind + args.batch_size))

        feed_dict = {
            input_ph: input_results[name]
            for name, input_ph in input_phs.items()
        }
        # compute "best" metrics using the computation graph
        fetches = {'images': model.inputs['images']}
        fetches.update(model.eval_outputs)
        fetches.update(model.eval_metrics)
        results = sess.run(fetches, feed_dict=feed_dict)
        save_prediction_eval_results(
            os.path.join(output_dir, 'prediction_eval'), results,
            model.hparams, sample_ind, args.only_metrics, args.eval_subtasks)
        sample_ind += args.batch_size

    metric_fnames = []
    metric_names = ['psnr', 'ssim', 'lpips']
    subtasks = ['max']
    for metric_name in metric_names:
        for subtask in subtasks:
            metric_fnames.append(
                os.path.join(output_dir,
                             'prediction_eval_%s_%s' % (metric_name, subtask),
                             'metrics', metric_name))

    for metric_fname in metric_fnames:
        task_name, _, metric_name = metric_fname.split('/')[-3:]
        metric = load_metrics(metric_fname)
        print('=' * 31)
        print(task_name, metric_name)
        print('-' * 31)
        metric_header_format = '{:>10} {:>20}'
        metric_row_format = '{:>10} {:>10.4f} ({:>7.4f})'
        print(
            metric_header_format.format('time step',
                                        os.path.split(metric_fname)[1]))
        for t, (metric_mean, metric_std) in enumerate(
                zip(metric.mean(axis=0), metric.std(axis=0))):
            print(metric_row_format.format(t, metric_mean, metric_std))
        print(
            metric_row_format.format('mean (std)', metric.mean(),
                                     metric.std()))
        print('=' * 31)
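

# A hedged example invocation of the evaluation entry point above; the script
# filename and paths are placeholders:
#
#     python evaluate.py --input_dir data/my_dataset \
#         --checkpoint logs/experiment_name --mode test --results_dir results \
#         --batch_size 8 --num_stochastic_samples 100
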
Example #5

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
                                                                     "train, val, test, etc, or a directory containing "
                                                                     "the tfrecords")
    parser.add_argument("--val_input_dir", type=str, help="directories containing the tfrecords. default: input_dir")
    parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified")
    parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. "
                                             "default is logs_dir/model_fname, where model_fname consists of "
                                             "information from model and model_hparams")
    parser.add_argument("--output_dir_postfix", default="")
    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
    parser.add_argument("--resume", action='store_true', help='resume from lastest checkpoint in output_dir.')

    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--dataset_hparams_dict", type=str, help="a json file of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
    parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters")

    parser.add_argument("--summary_freq", type=int, default=1000, help="save frequency of summaries (except for image and eval summaries) for train/validation set")
    parser.add_argument("--image_summary_freq", type=int, default=5000, help="save frequency of image summaries for train/validation set")
    parser.add_argument("--eval_summary_freq", type=int, default=25000, help="save frequency of eval summaries for train/validation set")
    parser.add_argument("--accum_eval_summary_freq", type=int, default=100000, help="save frequency of accumulated eval summaries for validation set only")
    parser.add_argument("--progress_freq", type=int, default=100, help="display progress every progress_freq steps")
    parser.add_argument("--save_freq", type=int, default=5000, help="save frequence of model, 0 to disable")

    parser.add_argument("--aggregate_nccl", type=int, default=0, help="whether to use nccl or cpu for gradient aggregation in multi-gpu training")
    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int)

    args = parser.parse_args()

    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    if args.output_dir is None:
        list_depth = 0
        model_fname = ''
        for t in ('model=%s,%s' % (args.model, args.model_hparams)):
            if t == '[':
                list_depth += 1
            if t == ']':
                list_depth -= 1
            if list_depth and t == ',':
                t = '..'
            if t in '=,':
                t = '.'
            if t in '[]':
                t = ''
            model_fname += t
        args.output_dir = os.path.join(args.logs_dir, model_fname) + args.output_dir_postfix
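        # as above: e.g. 'model=sv2p,kl_weight=[0.001,0.1]' sanitizes to
        # 'model.sv2p.kl_weight.0.001..0.1'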

    if args.resume:
        if args.checkpoint:
            raise ValueError('resume and checkpoint cannot both be specified')
        args.checkpoint = args.output_dir

    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.dataset_hparams_dict:
        with open(args.dataset_hparams_dict) as f:
            dataset_hparams_dict.update(json.loads(f.read()))
    if args.model_hparams_dict:
        with open(args.model_hparams_dict) as f:
            model_hparams_dict.update(json.loads(f.read()))
    if args.checkpoint:
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict.update(json.loads(f.read()))
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict.update(json.loads(f.read()))
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")

    print('----------------------------------- Options ------------------------------------')
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print('------------------------------------- End --------------------------------------')

    VideoDataset = datasets.get_dataset_class(args.dataset)
    train_dataset = VideoDataset(
        args.input_dir,
        mode='train',
        hparams_dict=dataset_hparams_dict,
        hparams=args.dataset_hparams)
    val_dataset = VideoDataset(
        args.val_input_dir or args.input_dir,
        mode='val',
        hparams_dict=dataset_hparams_dict,
        hparams=args.dataset_hparams)
    if val_dataset.hparams.long_sequence_length != val_dataset.hparams.sequence_length:
        # the longer dataset is only used for the accum_eval_metrics
        long_val_dataset = VideoDataset(
            args.val_input_dir or args.input_dir,
            mode='val',
            hparams_dict=dataset_hparams_dict,
            hparams=args.dataset_hparams)
        long_val_dataset.set_sequence_length(val_dataset.hparams.long_sequence_length)
    else:
        long_val_dataset = None

    variable_scope = tf.get_variable_scope()
    variable_scope.set_use_resource(True)

    VideoPredictionModel = models.get_model_class(args.model)
    hparams_dict = dict(model_hparams_dict)
    hparams_dict.update({
        'context_frames': train_dataset.hparams.context_frames,
        'sequence_length': train_dataset.hparams.sequence_length,
        'repeat': train_dataset.hparams.time_shift,
    })
    model = VideoPredictionModel(
        hparams_dict=hparams_dict,
        hparams=args.model_hparams,
        aggregate_nccl=args.aggregate_nccl)

    batch_size = model.hparams.batch_size
    train_tf_dataset = train_dataset.make_dataset(batch_size)
    train_iterator = train_tf_dataset.make_one_shot_iterator()
    train_handle = train_iterator.string_handle()
    val_tf_dataset = val_dataset.make_dataset(batch_size)
    val_iterator = val_tf_dataset.make_one_shot_iterator()
    val_handle = val_iterator.string_handle()
    iterator = tf.data.Iterator.from_string_handle(
        train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
    inputs = iterator.get_next()

    # inputs comes from the training dataset by default, unless train_handle is remapped to the val_handles
    model.build_graph(inputs)

    if long_val_dataset is not None:
        # separately build a model for the longer sequence.
        # this is needed because the model doesn't support dynamic shapes.
        long_hparams_dict = dict(hparams_dict)
        long_hparams_dict['sequence_length'] = long_val_dataset.hparams.sequence_length
        # use smaller batch size for longer model to prevent running out of memory
        long_hparams_dict['batch_size'] = model.hparams.batch_size // 2
        long_model = VideoPredictionModel(
            mode="test",  # to not build the losses and discriminators
            hparams_dict=long_hparams_dict,
            hparams=args.model_hparams,
            aggregate_nccl=args.aggregate_nccl)
        tf.get_variable_scope().reuse_variables()
        long_model.build_graph(long_val_dataset.make_batch(long_hparams_dict['batch_size']))
    else:
        long_model = None

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    with open(os.path.join(args.output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
    with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f:
        f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4))
    with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f:
        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    with tf.name_scope("parameter_count"):
        # exclude trainable variables that are replicas (used in multi-gpu setting)
        trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables)
        parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in trainable_variables])

    saver = tf.train.Saver(var_list=model.saveable_variables, max_to_keep=2)

    # None has the special meaning of evaluating at the end, so explicitly check for non-equality to zero
    if (args.summary_freq != 0 or args.image_summary_freq != 0 or
            args.eval_summary_freq != 0 or args.accum_eval_summary_freq != 0):
        summary_writer = tf.summary.FileWriter(args.output_dir)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    global_step = tf.train.get_or_create_global_step()
    max_steps = model.hparams.max_steps
    with tf.Session(config=config) as sess:
        print("parameter_count =", sess.run(parameter_count))

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        model.restore(sess, args.checkpoint)
        sess.run(model.post_init_ops)
        val_handle_eval = sess.run(val_handle)
        sess.graph.finalize()

        start_step = sess.run(global_step)

        def should(step, freq):
            if freq is None:
                return (step + 1) == (max_steps - start_step)
            else:
                return freq and ((step + 1) % freq == 0 or (step + 1) in (0, max_steps - start_step))

        def should_eval(step, freq):
            # never run eval summaries at the beginning since it's expensive, unless it's the last iteration
            return should(step, freq) and (step >= 0 or (step + 1) == (max_steps - start_step))

        # start at one step earlier to log everything without doing any training
        # step is relative to the start_step
        for step in range(-1, max_steps - start_step):
            if step == 1:
                # skip step -1 and 0 for timing purposes (for warmstarting)
                start_time = time.time()

            fetches = {"global_step": global_step}
            if step >= 0:
                fetches["train_op"] = model.train_op
            if should(step, args.progress_freq):
                fetches['d_loss'] = model.d_loss
                fetches['g_loss'] = model.g_loss
                fetches['d_losses'] = model.d_losses
                fetches['g_losses'] = model.g_losses
                if isinstance(model.learning_rate, tf.Tensor):
                    fetches["learning_rate"] = model.learning_rate
            if should(step, args.summary_freq):
                fetches["summary"] = model.summary_op
            if should(step, args.image_summary_freq):
                fetches["image_summary"] = model.image_summary_op
            if should_eval(step, args.eval_summary_freq):
                fetches["eval_summary"] = model.eval_summary_op

            run_start_time = time.time()
            results = sess.run(fetches)
            run_elapsed_time = time.time() - run_start_time
            if run_elapsed_time > 1.5 and step > 0 and set(fetches.keys()) == {"global_step", "train_op"}:
                print('running train_op took too long (%0.1fs)' % run_elapsed_time)

            if (should(step, args.summary_freq) or
                    should(step, args.image_summary_freq) or
                    should_eval(step, args.eval_summary_freq)):
                val_fetches = {"global_step": global_step}
                if should(step, args.summary_freq):
                    val_fetches["summary"] = model.summary_op
                if should(step, args.image_summary_freq):
                    val_fetches["image_summary"] = model.image_summary_op
                if should_eval(step, args.eval_summary_freq):
                    val_fetches["eval_summary"] = model.eval_summary_op
                val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval})
                for name, summary in val_results.items():
                    if name == 'global_step':
                        continue
                    val_results[name] = add_tag_suffix(summary, '_1')

            if should(step, args.summary_freq):
                print("recording summary")
                summary_writer.add_summary(results["summary"], results["global_step"])
                summary_writer.add_summary(val_results["summary"], val_results["global_step"])
                print("done")
            if should(step, args.image_summary_freq):
                print("recording image summary")
                summary_writer.add_summary(results["image_summary"], results["global_step"])
                summary_writer.add_summary(val_results["image_summary"], val_results["global_step"])
                print("done")
            if should_eval(step, args.eval_summary_freq):
                print("recording eval summary")
                summary_writer.add_summary(results["eval_summary"], results["global_step"])
                summary_writer.add_summary(val_results["eval_summary"], val_results["global_step"])
                print("done")
            if should_eval(step, args.accum_eval_summary_freq):
                val_datasets = [val_dataset]
                val_models = [model]
                if long_model is not None:
                    val_datasets.append(long_val_dataset)
                    val_models.append(long_model)
                for i, (val_dataset_, val_model) in enumerate(zip(val_datasets, val_models)):
                    sess.run(val_model.accum_eval_metrics_reset_op)
                    # traverse (roughly up to rounding based on the batch size) all the validation dataset
                    accum_eval_summary_num_updates = val_dataset_.num_examples_per_epoch() // val_model.hparams.batch_size
                    val_fetches = {"global_step": global_step, "accum_eval_summary": val_model.accum_eval_summary_op}
                    for update_step in range(accum_eval_summary_num_updates):
                        print('evaluating %d / %d' % (update_step + 1, accum_eval_summary_num_updates))
                        val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval})
                    accum_eval_summary = add_tag_suffix(val_results["accum_eval_summary"], '_%d' % (i + 1))
                    print("recording accum eval summary")
                    summary_writer.add_summary(accum_eval_summary, val_results["global_step"])
                    print("done")
            if (should(step, args.summary_freq) or should(step, args.image_summary_freq) or
                    should_eval(step, args.eval_summary_freq) or should_eval(step, args.accum_eval_summary_freq)):
                summary_writer.flush()
            if should(step, args.progress_freq):
                # global_step will have the correct step count if we resume from a checkpoint
                # global step is read before it's incremented
                steps_per_epoch = train_dataset.num_examples_per_epoch() / batch_size
                train_epoch = results["global_step"] / steps_per_epoch
                print("progress  global step %d  epoch %0.1f" % (results["global_step"] + 1, train_epoch))
                if step > 0:
                    elapsed_time = time.time() - start_time
                    average_time = elapsed_time / step
                    images_per_sec = batch_size / average_time
                    remaining_time = (max_steps - (start_step + step + 1)) * average_time
                    print("          image/sec %0.1f  remaining %dm (%0.1fh) (%0.1fd)" %
                          (images_per_sec, remaining_time / 60, remaining_time / 60 / 60, remaining_time / 60 / 60 / 24))
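                    # e.g. (hypothetical numbers) batch_size=8 and
                    # average_time=0.5 s/step give images_per_sec = 16.0, and
                    # 1000 remaining steps take 1000 * 0.5 s = ~8.3 minutes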

                if results['d_losses']:
                    print("d_loss", results["d_loss"])
                for name, loss in results['d_losses'].items():
                    print("  ", name, loss)
                if results['g_losses']:
                    print("g_loss", results["g_loss"])
                for name, loss in results['g_losses'].items():
                    print("  ", name, loss)
                if isinstance(model.learning_rate, tf.Tensor):
                    print("learning_rate", results["learning_rate"])

            if should(step, args.save_freq):
                print("saving model to", args.output_dir)
                saver.save(sess, os.path.join(args.output_dir, "model"), global_step=global_step)
                print("done")
        if not args.model:
            raise ValueError(
                'model is required when checkpoint is not specified')
        args.output_gif_dir = args.output_gif_dir or os.path.join(
            args.results_gif_dir, 'model.%s' % args.model)
        args.output_png_dir = args.output_png_dir or os.path.join(
            args.results_png_dir, 'model.%s' % args.model)

    print(
        '----------------------------------- Options ------------------------------------'
    )
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print(
        '------------------------------------- End --------------------------------------'
    )

    VideoDataset = datasets.get_dataset_class(args.dataset)
    dataset = VideoDataset(args.input_dir,
                           mode=args.mode,
                           num_epochs=args.num_epochs,
                           seed=args.seed,
                           hparams_dict=dataset_hparams_dict,
                           hparams=args.dataset_hparams)

    VideoPredictionModel = models.get_model_class(args.model)
    hparams_dict = dict(model_hparams_dict)
    hparams_dict.update({
        'context_frames': dataset.hparams.context_frames,
        'sequence_length': dataset.hparams.sequence_length,
        'eval_length': dataset.hparams.eval_length,
        'repeat': dataset.hparams.time_shift,
    })
Example #7
def main():
    """
    results_dir
    ├── output_dir                              # condition / method
    │   ├── prediction                          # task
    │   │   ├── inputs
    │   │   │   ├── context_image_00000_00.png  # indexed by sample index and time step
    │   │   │   └── ...
    │   │   ├── outputs
    │   │   │   ├── gen_image_00000_00.png      # predicted images (only the ones in the loss)
    │   │   │   └── ...
    │   │   └── metrics
    │   │       ├── psnr.csv
    │   │       ├── mse.csv
    │   │       └── ssim.csv
    │   ├── prediction_eval_vgg_csim_max        # task: best sample in terms of VGG cosine similarity
    │   │   ├── inputs
    │   │   │   ├── context_image_00000_00.png  # indexed by sample index and time step
    │   │   │   └── ...
    │   │   ├── outputs
    │   │   │   ├── gen_image_00000_00.png      # predicted images (only the ones in the loss)
    │   │   │   └── ...
    │   │   └── metrics
    │   │       └── vgg_csim.csv
    │   ├── servo
    │   │   ├── inputs
    │   │   │   ├── context_image_00000_00.png
    │   │   │   ├── ...
    │   │   │   ├── goal_image_00000_00.png     # only one goal image per sample
    │   │   │   └── ...
    │   │   ├── outputs
    │   │   │   ├── gen_image_00000_00.png
    │   │   │   ├── ...
    │   │   │   ├── gen_image_goal_diff_00000_00.png
    │   │   │   └── ...
    │   │   └── metrics
    │   │       ├── action_mse.csv
    │   │       └── goal_image_mse.csv
    │   ├── motion
    │   │   ├── inputs
    │   │   │   ├── pix_distrib_00000_00.png
    │   │   │   └── ...
    │   │   ├── outputs
    │   │   │   ├── gen_pix_distrib_00000_00.png
    │   │   │   ├── ...
    │   │   │   ├── gen_pix_distrib_overlaid_00000_00.png
    │   │   │   └── ...
    │   │   └── metrics
    │   │       └── pix_dist.csv
    │   └── ...
    └── ...
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir",
                        type=str,
                        required=True,
                        help="either a directory containing subdirectories "
                        "train, val, test, etc, or a directory containing "
                        "the tfrecords")
    parser.add_argument("--results_dir",
                        type=str,
                        default='results',
                        help="ignored if output_dir is specified")
    parser.add_argument(
        "--output_dir",
        help=
        "output directory where results are saved. default is results_dir/model_fname, "
        "where model_fname is the directory name of checkpoint")
    parser.add_argument(
        "--checkpoint",
        help=
        "directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)"
    )

    parser.add_argument("--mode",
                        type=str,
                        choices=['val', 'test'],
                        default='val',
                        help='mode for dataset, val or test.')

    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument(
        "--dataset_hparams",
        type=str,
        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument(
        "--model_hparams",
        type=str,
        help="a string of comma separated list of model hyperparameters")

    parser.add_argument("--batch_size",
                        type=int,
                        default=8,
                        help="number of samples in batch")
    parser.add_argument(
        "--num_samples",
        type=int,
        help="number of samples in total (all of them by default)")
    parser.add_argument("--num_epochs", type=int, default=1)

    parser.add_argument(
        "--tasks",
        type=str,
        nargs='+',
        help=
        'tasks to evaluate (e.g. prediction, prediction_eval, servo, motion)')
    parser.add_argument(
        "--eval_substasks",
        type=str,
        nargs='+',
        default=['max', 'min'],
        help=
        'subtasks to evaluate (e.g. max, avg, min). only applicable to prediction_eval'
    )
    parser.add_argument("--only_metrics", action='store_true')
    parser.add_argument("--num_stochastic_samples", type=int, default=100)

    parser.add_argument(
        "--gt_inputs_dir",
        type=str,
        help="directory containing input ground truth images for ismple dataset"
    )
    parser.add_argument(
        "--gt_outputs_dir",
        type=str,
        help=
        "directory containing output ground truth images for ismple dataset")

    parser.add_argument("--eval_parallel_iterations", type=int, default=10)
    parser.add_argument("--gpu_mem_frac",
                        type=float,
                        default=0,
                        help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int, default=7)

    args = parser.parse_args()

    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.checkpoint:
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                                    checkpoint_dir)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir,
                                   "dataset_hparams.json")) as f:
                dataset_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print(
                "dataset_hparams.json was not loaded because it does not exist"
            )
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict = json.loads(f.read())
                model_hparams_dict.pop('num_gpus',
                                       None)  # backwards-compatibility
        except FileNotFoundError:
            print(
                "model_hparams.json was not loaded because it does not exist")
        args.output_dir = args.output_dir or os.path.join(
            args.results_dir,
            os.path.split(checkpoint_dir)[1])
    else:
        if not args.dataset:
            raise ValueError(
                'dataset is required when checkpoint is not specified')
        if not args.model:
            raise ValueError(
                'model is required when checkpoint is not specified')
        args.output_dir = args.output_dir or os.path.join(
            args.results_dir, 'model.%s' % args.model)

    print(
        '----------------------------------- Options ------------------------------------'
    )
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print(
        '------------------------------------- End --------------------------------------'
    )

    VideoDataset = datasets.get_dataset_class(args.dataset)
    dataset = VideoDataset(args.input_dir,
                           mode=args.mode,
                           num_epochs=args.num_epochs,
                           seed=args.seed,
                           hparams_dict=dataset_hparams_dict,
                           hparams=args.dataset_hparams)

    def override_hparams_dict(dataset):
        hparams_dict = dict(model_hparams_dict)
        hparams_dict['context_frames'] = dataset.hparams.context_frames
        hparams_dict['sequence_length'] = dataset.hparams.sequence_length
        hparams_dict['repeat'] = dataset.hparams.time_shift
        return hparams_dict

    VideoPredictionModel = models.get_model_class(args.model)
    model = VideoPredictionModel(
        mode='test',
        hparams_dict=override_hparams_dict(dataset),
        hparams=args.model_hparams,
        eval_num_samples=args.num_stochastic_samples,
        eval_parallel_iterations=args.eval_parallel_iterations)
    context_frames = model.hparams.context_frames
    sequence_length = model.hparams.sequence_length

    if args.num_samples:
        if args.num_samples > dataset.num_examples_per_epoch():
            raise ValueError('num_samples cannot be larger than the dataset')
        num_examples_per_epoch = args.num_samples
    else:
        num_examples_per_epoch = dataset.num_examples_per_epoch()
    if num_examples_per_epoch % args.batch_size != 0:
        raise ValueError('batch_size should evenly divide the dataset')

    inputs, target = dataset.make_batch(args.batch_size)
    if not isinstance(model, models.GroundTruthVideoPredictionModel):
        # remove ground truth data past context_frames to prevent accidentally using it
        for k, v in inputs.items():
            if k != 'actions':
                inputs[k] = v[:, :context_frames]

    input_phs = {
        k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k)
        for k, v in inputs.items()
    }
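    # one placeholder per input tensor lets the fetched numpy batch be re-fed
    # multiple times below (e.g. for repeated stochastic sampling)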
    target_ph = tf.placeholder(target.dtype, target.shape, 'targets_ph')

    with tf.variable_scope(''):
        model.build_graph(input_phs, target_ph)

    tasks = args.tasks
    if tasks is None:
        tasks = ['prediction_eval']
        if 'pix_distribs' in inputs:
            tasks.append('motion')

    if 'servo' in tasks:
        servo_model = VideoPredictionModel(mode='test',
                                           hparams_dict=model_hparams_dict,
                                           hparams=args.model_hparams)
        cem_batch_size = 200
        plan_horizon = sequence_length - 1
        image_shape = inputs['images'].shape.as_list()[2:]
        state_shape = inputs['states'].shape.as_list()[2:]
        action_shape = inputs['actions'].shape.as_list()[2:]
        servo_input_phs = {
            'images':
            tf.placeholder(tf.float32,
                           shape=[cem_batch_size, context_frames] +
                           image_shape),
            'states':
            tf.placeholder(tf.float32,
                           shape=[cem_batch_size, 1] + state_shape),
            'actions':
            tf.placeholder(tf.float32,
                           shape=[cem_batch_size, plan_horizon] +
                           action_shape),
        }
        if isinstance(servo_model, models.GroundTruthVideoPredictionModel):
            images_shape = inputs['images'].shape.as_list()[1:]
            servo_input_phs['images'] = tf.placeholder(tf.float32,
                                                       shape=[cem_batch_size] +
                                                       images_shape)
        with tf.variable_scope('', reuse=True):
            servo_model.build_graph(servo_input_phs)
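        # reuse=True shares the variables created by model.build_graph above,
        # so the servo model evaluates the same trained weights on the
        # CEM-sized batch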

    output_dir = args.output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    with open(os.path.join(output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
    with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
        f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
    with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    sess = tf.Session(config=config)

    model.restore(sess, args.checkpoint)

    if 'servo' in tasks:
        servo_policy = ServoPolicy(servo_model, sess)

    sample_ind = 0
    while True:
        if args.num_samples and sample_ind >= args.num_samples:
            break
        try:
            input_results, target_result = sess.run([inputs, target])
        except tf.errors.OutOfRangeError:
            break
        print("evaluation samples from %d to %d" %
              (sample_ind, sample_ind + args.batch_size))

        if 'prediction_eval' in tasks:
            feed_dict = {
                input_ph: input_results[name]
                for name, input_ph in input_phs.items()
            }
            feed_dict.update({target_ph: target_result})
            # compute "best" metrics using the computation graph (if available) or explicitly with python logic
            if model.eval_outputs and model.eval_metrics:
                fetches = {'images': model.inputs['images']}
                fetches.update(model.eval_outputs.items())
                fetches.update(model.eval_metrics.items())
                results = sess.run(fetches, feed_dict=feed_dict)
            else:
                metric_names = [
                    'psnr', 'ssim', 'ssim_scikit', 'ssim_finn', 'vgg_csim'
                ]
                metric_fns = [
                    metrics.peak_signal_to_noise_ratio_np,
                    metrics.structural_similarity_np,
                    metrics.structural_similarity_scikit_np,
                    metrics.structural_similarity_finn_np,
                    metrics.vgg_cosine_similarity_np
                ]

                all_gen_images = []
                all_metrics = [
                    np.empty((args.num_stochastic_samples, args.batch_size,
                              sequence_length - context_frames))
                    for _ in metric_names
                ]
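                # each all_metric entry has shape (num_stochastic_samples,
                # batch_size, num_predicted_frames): [s, i, t] is the score of
                # stochastic sample s for example i at predicted step t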
                for s in range(args.num_stochastic_samples):
                    gen_images = sess.run(model.outputs['gen_images'],
                                          feed_dict=feed_dict)
                    all_gen_images.append(gen_images)
                    for metric_name, metric_fn, all_metric in zip(
                            metric_names, metric_fns, all_metrics):
                        metric = metric_fn(gen_images,
                                           target_result,
                                           keep_axis=(0, 1))
                        all_metric[s] = metric

                results = {}
                for metric_name, all_metric in zip(metric_names, all_metrics):
                    for subtask in args.eval_substasks:
                        results['eval_gen_images_%s/%s' %
                                (metric_name, subtask)] = np.empty_like(
                                    all_gen_images[0])
                        results['eval_%s/%s' %
                                (metric_name, subtask)] = np.empty_like(
                                    all_metric[0])

                for i in range(args.batch_size):
                    for metric_name, all_metric in zip(metric_names,
                                                       all_metrics):
                        ordered = np.argsort(
                            np.mean(all_metric, axis=-1)
                            [:, i])  # mean over time and sort over samples
                        for subtask in args.eval_substasks:
                            if subtask == 'max':
                                sidx = ordered[-1]
                            elif subtask == 'min':
                                sidx = ordered[0]
                            else:
                                raise NotImplementedError
                            results['eval_gen_images_%s/%s' %
                                    (metric_name,
                                     subtask)][i] = all_gen_images[sidx][i]
                            results['eval_%s/%s' %
                                    (metric_name,
                                     subtask)][i] = all_metric[sidx][i]
            save_prediction_eval_results(
                os.path.join(output_dir,
                             'prediction_eval'), results, model.hparams,
                sample_ind, args.only_metrics, args.eval_substasks)

        if 'prediction' in tasks or 'motion' in tasks:  # do these together
            feed_dict = {
                input_ph: input_results[name]
                for name, input_ph in input_phs.items()
            }
            fetches = {
                'images': model.inputs['images'],
                'gen_images': model.outputs['gen_images']
            }
            if 'motion' in tasks:
                fetches.update({
                    'pix_distribs':
                    model.inputs['pix_distribs'],
                    'gen_pix_distribs':
                    model.outputs['gen_pix_distribs']
                })

            if args.num_stochastic_samples:
                all_results = [
                    sess.run(fetches, feed_dict=feed_dict)
                    for _ in range(args.num_stochastic_samples)
                ]
                all_results = nest.map_structure(lambda *x: np.stack(x),
                                                 *all_results)
                all_context_images, all_images = np.split(
                    all_results['images'], [context_frames], axis=2)
                all_gen_images = all_results['gen_images'][:, :,
                                                           context_frames -
                                                           sequence_length:]
                all_mse = metrics.mean_squared_error_np(all_images,
                                                        all_gen_images,
                                                        keep_axis=(0, 1))
                all_mse_argsort = np.argsort(all_mse, axis=0)
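                # argsort is over the sample axis: row 0 indexes the lowest-MSE
                # (best) sample per example, row -1 the highest (worst), and the
                # middle row is taken as the median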

                for subtask, argsort_ind in zip(
                    ['_best', '_median', '_worst'],
                    [0, args.num_stochastic_samples // 2, -1]):
                    all_mse_inds = all_mse_argsort[argsort_ind]
                    gather = lambda x: np.array([
                        x[ind, sample_ind]
                        for sample_ind, ind in enumerate(all_mse_inds)
                    ])
                    results = nest.map_structure(gather, all_results)
                    if 'prediction' in tasks:
                        save_prediction_results(
                            os.path.join(output_dir,
                                         'prediction' + subtask), results,
                            model.hparams, sample_ind, args.only_metrics)
                    if 'motion' in tasks:
                        draw_center = isinstance(
                            model, models.NonTrainableVideoPredictionModel)
                        save_motion_results(
                            os.path.join(output_dir, 'motion' + subtask),
                            results, model.hparams, draw_center, sample_ind,
                            args.only_metrics)
            else:
                results = sess.run(fetches, feed_dict=feed_dict)
                if 'prediction' in tasks:
                    save_prediction_results(
                        os.path.join(output_dir, 'prediction'), results,
                        model.hparams, sample_ind, args.only_metrics)
                if 'motion' in tasks:
                    draw_center = isinstance(
                        model, models.NonTrainableVideoPredictionModel)
                    save_motion_results(os.path.join(output_dir, 'motion'),
                                        results, model.hparams, draw_center,
                                        sample_ind, args.only_metrics)

        if 'servo' in tasks:
            images = input_results['images']
            states = input_results['states']
            gen_actions = []
            gen_images = []
            for images_, states_ in zip(images, states):
                obs = {
                    'context_images': images_[:context_frames],
                    'context_state': states_[0],
                    'goal_image': images_[-1]
                }
                if isinstance(servo_model,
                              models.GroundTruthVideoPredictionModel):
                    obs['context_images'] = images_
                gen_actions_, gen_images_ = servo_policy.act(
                    obs, servo_model.outputs['gen_images'])
                gen_actions.append(gen_actions_)
                gen_images.append(gen_images_)
            gen_actions = np.stack(gen_actions)
            gen_images = np.stack(gen_images)
            results = {
                'images': input_results['images'],
                'actions': input_results['actions'],
                'goal_image': input_results['images'][:, -1],
                'gen_actions': gen_actions,
                'gen_images': gen_images
            }
            save_servo_results(os.path.join(output_dir, 'servo'), results,
                               servo_model.hparams, sample_ind,
                               args.only_metrics)

        sample_ind += args.batch_size

    metric_fnames = []
    if 'prediction_eval' in tasks:
        metric_names = ['psnr', 'ssim', 'ssim_finn', 'vgg_csim']
        subtasks = ['max']
        for metric_name in metric_names:
            for subtask in subtasks:
                metric_fnames.append(
                    os.path.join(
                        output_dir,
                        'prediction_eval_%s_%s' % (metric_name, subtask),
                        'metrics', metric_name))
    if 'prediction' in tasks:
        subtask = '_best' if args.num_stochastic_samples else ''
        metric_fnames.extend([
            os.path.join(output_dir, 'prediction' + subtask, 'metrics',
                         'psnr'),
            os.path.join(output_dir, 'prediction' + subtask, 'metrics', 'mse'),
            os.path.join(output_dir, 'prediction' + subtask, 'metrics',
                         'ssim'),
        ])
    if 'motion' in tasks:
        subtask = '_best' if args.num_stochastic_samples else ''
        metric_fnames.append(
            os.path.join(output_dir, 'motion' + subtask, 'metrics',
                         'pix_dist'))
    if 'servo' in tasks:
        metric_fnames.append(
            os.path.join(output_dir, 'servo', 'metrics', 'goal_image_mse'))
        metric_fnames.append(
            os.path.join(output_dir, 'servo', 'metrics', 'action_mse'))

    for metric_fname in metric_fnames:
        task_name, _, metric_name = metric_fname.split('/')[-3:]
        metric = load_metrics(metric_fname)
        print('=' * 31)
        print(task_name, metric_name)
        print('-' * 31)
        metric_header_format = '{:>10} {:>20}'
        metric_row_format = '{:>10} {:>10.4f} ({:>7.4f})'
        print(
            metric_header_format.format('time step',
                                        os.path.split(metric_fname)[1]))
        for t, (metric_mean, metric_std) in enumerate(
                zip(metric.mean(axis=0), metric.std(axis=0))):
            print(metric_row_format.format(t, metric_mean, metric_std))
        print(
            metric_row_format.format('mean (std)', metric.mean(),
                                     metric.std()))
        print('=' * 31)
Example #8
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
                                                                     "train, val, test, etc, or a directory containing "
                                                                     "the tfrecords")
    parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_dir is specified")
    parser.add_argument("--output_dir", help="output directory where results are saved. default is results_dir/model_fname, "
                                             "where model_fname is the directory name of checkpoint")
    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
    parser.add_argument("--only_metrics", action='store_true')

    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='test', help='mode for dataset, val or test.')

    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")

    parser.add_argument("--batch_size", type=int, default=1, help="number of samples in batch")
    parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)")
    parser.add_argument("--num_epochs", type=int, default=1)

    parser.add_argument("--eval_substasks", type=str, nargs='+', default=['max', 'avg', 'min'], help='subtasks to evaluate (e.g. max, avg, min). only applicable to prediction_eval')
    parser.add_argument("--num_stochastic_samples", type=int, default=100)

    parser.add_argument("--gt_inputs_dir", type=str, help="directory containing input ground truth images for ismple dataset")
    parser.add_argument("--gt_outputs_dir", type=str, help="directory containing output ground truth images for ismple dataset")

    parser.add_argument("--eval_parallel_iterations", type=int, default=1)
    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int, default=7)

    args = parser.parse_args()

    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.checkpoint:
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict = json.loads(f.read())
                model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")
        args.output_dir = args.output_dir or os.path.join(args.results_dir, os.path.split(checkpoint_dir)[1])
    else:
        if not args.dataset:
            raise ValueError('dataset is required when checkpoint is not specified')
        if not args.model:
            raise ValueError('model is required when checkpoint is not specified')
        args.output_dir = args.output_dir or os.path.join(args.results_dir, 'model.%s' % args.model)

    if not args.only_metrics:
        args.num_stochastic_samples = args.num_stochastic_samples // 10

    print('----------------------------------- Options ------------------------------------')
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print('------------------------------------- End --------------------------------------')

    VideoDataset = datasets.get_dataset_class(args.dataset)
    dataset = VideoDataset(args.input_dir, mode=args.mode, num_epochs=args.num_epochs, seed=args.seed,
                           hparams_dict=dataset_hparams_dict, hparams=args.dataset_hparams)

    def override_hparams_dict(dataset):
        hparams_dict = dict(model_hparams_dict)
        hparams_dict['context_frames'] = dataset.hparams.context_frames
        hparams_dict['sequence_length'] = dataset.hparams.sequence_length
        hparams_dict['repeat'] = dataset.hparams.time_shift
        hparams_dict['zs_seed_no'] = args.seed
        return hparams_dict

    VideoPredictionModel = models.get_model_class(args.model)
    model = VideoPredictionModel(mode='test', hparams_dict=override_hparams_dict(dataset), hparams=args.model_hparams,
                                 eval_num_samples=args.num_stochastic_samples, eval_parallel_iterations=args.eval_parallel_iterations)
    context_frames = model.hparams.context_frames
    sequence_length = model.hparams.sequence_length

    if args.num_samples:
        if args.num_samples > dataset.num_examples_per_epoch():
            raise ValueError('num_samples cannot be larger than the dataset')
        num_examples_per_epoch = args.num_samples
    else:
        num_examples_per_epoch = dataset.num_examples_per_epoch()
    if num_examples_per_epoch % args.batch_size != 0:
        raise ValueError('batch_size should evenly divide the dataset')

    inputs, target = dataset.make_batch(args.batch_size)
    if not isinstance(model, models.GroundTruthVideoPredictionModel):
        # remove ground truth data past context_frames to prevent accidentally using it
        for k, v in inputs.items():
            if k != 'actions':
                inputs[k] = v[:, :context_frames]


    input_phs = {k: tf.placeholder(v.dtype, [v.get_shape().as_list()[0] * 2] + v.get_shape().as_list()[1:], '%s_ph' % k) for k, v in inputs.items()}
    target_ph = tf.placeholder(target.dtype, [target.get_shape().as_list()[0] * 2] + target.get_shape().as_list()[1:], 'targets_ph')
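    # note: these placeholders double the batch dimension relative to the
    # dataset tensors (the reason isn't shown here; presumably two batches are
    # packed into a single run)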

    with tf.variable_scope(''):
        model.build_graph(input_phs, target_ph)

    output_dir = args.output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    with open(os.path.join(output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
    with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
        f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
    with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    sess = tf.Session(config=config)

    model.restore(sess, args.checkpoint)

    sample_ind = 0
    our_result_list = {}
    while True:
        if args.num_samples and sample_ind >= args.num_samples:
            break
        try:
            input_results, target_result = sess.run([inputs, target])
        except tf.errors.OutOfRangeError:
            break

        print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))

        ##########################################
        # compute statistics
        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
        feed_dict.update({target_ph: target_result})

        ###################################
        # compute diversity measures
        fetches = {'images': model.inputs['images'],
                   'gen_images': model.outputs['gen_images']}

        all_results = [sess.run(fetches, feed_dict=feed_dict) for _ in range(args.num_stochastic_samples)]
        all_results = nest.map_structure(lambda *x: np.stack(x), *all_results)
        # the result has (num_samples, batch_size, time, height, width, 3)

        all_results['context'] = np.repeat(np.expand_dims(input_results['images'][:args.batch_size], axis=0), args.num_stochastic_samples, axis=0)
        all_results['images'] = np.repeat(np.expand_dims(target_result[:args.batch_size], axis=0), args.num_stochastic_samples, axis=0)
        all_results['gen_images'] = all_results['gen_images'][:, :args.batch_size, context_frames - sequence_length:]
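        # context_frames - sequence_length is negative, so this slice keeps only
        # the last (sequence_length - context_frames) frames, i.e. the generated
        # frames that align with the prediction targets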

        all_images = all_results['images']
        all_gen_images = all_results['gen_images']

        if args.only_metrics:
            # do it for VGG
            csim_result = [metrics.vgg_cosine_similarity_np(np.array(a), np.array(b), keep_axis=(0, 1, 2)) for a, b in zip(group(all_images, 25), group(all_gen_images, 25))]
            csim_max = np.mean(np.array(csim_result), axis=-1).max()
            our_result_list.setdefault('gt_csim', []).append(csim_max)

            # do it for MSE
            all_mse = metrics.mean_squared_error_np(all_images, all_gen_images, keep_axis=(0, 1, 2))
            mse_min = np.mean(all_mse, axis=-1).min()
            our_result_list.setdefault('gt_mse', []).append(mse_min)

            # compute ours
            our_result = compute_our_diversity_np(all_gen_images)
            our_result_list.setdefault('btw_mse', []).append(our_result)

            #all_mse_argsort = np.argsort(all_mse, axis=0)
            #for subtask, argsort_ind in zip(['_best', '_median', '_worst'],
            #                                [0, args.num_stochastic_samples // 2, -1]):
            #    all_mse_inds = all_mse_argsort[argsort_ind]
            #    gather = lambda x: np.array([x[ind, sample_ind] for sample_ind, ind in enumerate(all_mse_inds)])
            #    results = nest.map_structure(gather, all_results)
            #    save_prediction_results(os.path.join(output_dir, 'prediction' + subtask),
            #                            results, model.hparams, sample_ind, (args.only_metrics or (subtask == '_median')) )
        else:
            # write logic for saving results
            print('saving results')
            for sample_no, single_video in enumerate(all_results['gen_images']):
                save_image_sequence(os.path.join(output_dir, '%05d_%02d' % (sample_ind, sample_no)), single_video[0])

        sample_ind += args.batch_size


    summary_file_array = []
    summary_file_path = os.path.join(args.output_dir, 'summary.txt')

    for key, value_list in our_result_list.items():
        our_metric = np.asarray(value_list)
        sentence = '[%s]: %12.9f (%12.9f)' % (key, our_metric.mean(), our_metric.std())
        print(sentence)
        summary_file_array.append(sentence)

    with open(summary_file_path, 'w') as f:
        f.write('\n'.join(summary_file_array))
Example #9
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir",
                        type=str,
                        required=True,
                        help="either a directory containing subdirectories "
                        "train, val, test, etc, or a directory containing "
                        "the tfrecords")
    parser.add_argument(
        "--val_input_dir",
        type=str,
        help="directories containing the tfrecords. default: input_dir")
    parser.add_argument("--logs_dir",
                        default='logs',
                        help="ignored if output_dir is specified")
    parser.add_argument(
        "--output_dir",
        help=
        "output directory where json files, summary, model, gifs, etc are saved. "
        "default is logs_dir/model_fname, where model_fname consists of "
        "information from model and model_hparams")
    parser.add_argument("--output_dir_postfix", default="")
    parser.add_argument(
        "--checkpoint",
        help=
        "directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)"
    )
    parser.add_argument("--resume",
                        action='store_true',
                        help='resume from latest checkpoint in output_dir.')

    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument(
        "--dataset_hparams",
        type=str,
        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--dataset_hparams_dict",
                        type=str,
                        help="a json file of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument(
        "--model_hparams",
        type=str,
        help="a string of comma separated list of model hyperparameters")
    parser.add_argument("--model_hparams_dict",
                        type=str,
                        help="a json file of model hyperparameters")

    parser.add_argument(
        "--summary_freq",
        type=int,
        default=10000,
        help=
        "save frequency of summaries (except for image and eval summaries) for train/validation set"
    )
    # parser.add_argument("--image_summary_freq", type=int, default=5000, help="save frequency of image summaries for train/validation set")
    # parser.add_argument("--eval_summary_freq", type=int, default=25000, help="save frequency of eval summaries for train/validation set")
    # parser.add_argument("--accum_eval_summary_freq", type=int, default=100000, help="save frequency of accumulated eval summaries for validation set only")
    parser.add_argument("--progress_freq",
                        type=int,
                        default=1000,
                        help="display progress every progress_freq steps")
    # parser.add_argument("--save_freq", type=int, default=50000, help="save frequence of model, 0 to disable")

    parser.add_argument(
        "--aggregate_nccl",
        type=int,
        default=0,
        help=
        "whether to use nccl or cpu for gradient aggregation in multi-gpu training"
    )
    parser.add_argument("--gpu_mem_frac",
                        type=float,
                        default=0.9,
                        help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int)

    parser.add_argument("--resnet_size", type=int, default=8)
    parser.add_argument("--batch_size", type=int, default=10)
    parser.add_argument("--max_steps", type=int, default=1000000)

    args = parser.parse_args()

    # Set random seed
    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    if args.output_dir is None:
        args.output_dir = './logs-shape_embedding/'

    if args.resume:
        if args.checkpoint:
            raise ValueError('resume and checkpoint cannot both be specified')
        args.checkpoint = args.output_dir

    # Set dataset & model hparam
    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.dataset_hparams_dict:
        with open(args.dataset_hparams_dict) as f:
            dataset_hparams_dict.update(json.loads(f.read()))
    if args.model_hparams_dict:
        with open(args.model_hparams_dict) as f:
            model_hparams_dict.update(json.loads(f.read()))

    print(
        '----------------------------------- Options ------------------------------------'
    )
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print(
        '------------------------------------- End --------------------------------------'
    )

    # Create dataset
    VideoDataset = datasets.get_dataset_class(args.dataset)
    train_dataset = VideoDataset(args.input_dir,
                                 mode='train',
                                 hparams_dict=dataset_hparams_dict,
                                 hparams=args.dataset_hparams)
    val_dataset = VideoDataset(args.val_input_dir or args.input_dir,
                               mode='val',
                               hparams_dict=dataset_hparams_dict,
                               hparams=args.dataset_hparams)

    batch_size = args.batch_size
    train_tf_dataset = train_dataset.make_dataset(batch_size)
    train_iterator = train_tf_dataset.make_one_shot_iterator()
    train_handle = train_iterator.string_handle()
    val_tf_dataset = val_dataset.make_dataset(batch_size)
    val_iterator = val_tf_dataset.make_one_shot_iterator()
    val_handle = val_iterator.string_handle()
    iterator = tf.data.Iterator.from_string_handle(
        train_handle, train_tf_dataset.output_types,
        train_tf_dataset.output_shapes)
    # inputs come from the training dataset by default, unless train_handle is remapped to val_handle
    inputs = iterator.get_next()
    shape_id_to_embeddings_map = np.load(
        './metadata/shape_id_to_embeddings_map.npy').item()
    shape_id_to_embeddings_map = tf.convert_to_tensor(
        np.array(list(shape_id_to_embeddings_map.values())))
    inputs['shape_id_to_embeddings_map'] = shape_id_to_embeddings_map

    # Create model
    variable_scope = tf.get_variable_scope()
    variable_scope.set_use_resource(True)
    num_blocks = (args.resnet_size - 2) // 6
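    # standard CIFAR-style ResNet sizing: e.g. resnet_size=8 gives
    # num_blocks = (8 - 2) // 6 = 1 residual block per stage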
    model = ShapeEmbeddingModel(args.resnet_size, False, 5, 16, 3, 1, None,
                                None, [num_blocks] * 3, [1, 2, 2])
    model.build_graph(inputs)

    with tf.name_scope("parameter_count"):
        # exclude trainable variables that are replicas (used in multi-gpu setting)
        trainable_variables = set(
            tf.trainable_variables())  # & set(model.saveable_variables)
        parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in trainable_variables])

    saver = tf.train.Saver(var_list=model.saveable_variables, max_to_keep=2)
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    global_step = tf.train.get_or_create_global_step()
    max_steps = args.max_steps

    with tf.Session(config=config) as sess:

        def should(step, freq):
            if freq is None:
                return (step + 1) == (max_steps - start_step)
            else:
                return freq and ((step + 1) % freq == 0 or
                                 (step + 1) in (0, max_steps - start_step))

        print("parameter_count =", sess.run(parameter_count))

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        model.restore(sess, args.checkpoint)
        val_handle_eval = sess.run(val_handle)
        sess.graph.finalize()
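        # finalize() marks the graph read-only, so accidental op creation in
        # the eval loop below raises an error instead of silently growing the
        # graph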

        start_step = sess.run(global_step)
        # run a single pass over the validation set, collecting the logits and
        # shape ids for every example
        shape_ids = []
        logits = []
        if val_dataset.num_examples_per_epoch() % args.batch_size != 0:
            raise ValueError(
                'num_examples_per_epoch should be divisible by batch_size, {} % {} != 0'
                .format(val_dataset.num_examples_per_epoch(), args.batch_size))
        num_iters = val_dataset.num_examples_per_epoch() // args.batch_size
        val_fetches = {
            'logits': model.logits,
            'shape_ids': inputs['shape_ids']
        }
        for i in range(num_iters):
            val_results = sess.run(val_fetches,
                                   feed_dict={train_handle: val_handle_eval})
            logits.append(val_results['logits'])
            shape_ids.append(val_results['shape_ids'][:, 0])
            print(val_results['shape_ids'][:, 0].shape)
        metadata = {}
        metadata['shape_ids'] = np.array(shape_ids).reshape(-1)
        metadata['logits'] = np.array(logits).reshape((-1, 5))
        np.save('./metadata/shape_ids_and_embeddings.npy', metadata)
Example #10
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir",
                        type=str,
                        required=True,
                        help="a directory containing the root dat directory ")
    parser.add_argument("--output_dir",
                        type=str,
                        default='outputs',
                        help="directory to save raw outputs from predictions")
    parser.add_argument("--results_dir",
                        type=str,
                        default='results',
                        help="ignored if output_gif_dir is specified")
    parser.add_argument(
        "--results_gif_dir",
        type=str,
        help="default is results_dir. ignored if output_gif_dir is specified")
    parser.add_argument(
        "--results_png_dir",
        type=str,
        help="default is results_dir. ignored if output_png_dir is specified")
    parser.add_argument(
        "--output_gif_dir",
        help="output directory where samples are saved as gifs. default is "
        "results_gif_dir/model_fname")
    parser.add_argument(
        "--output_png_dir",
        help="output directory where samples are saved as pngs. default is "
        "results_png_dir/model_fname")
    parser.add_argument(
        "--checkpoint",
        help=
        "directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)"
    )

    parser.add_argument("--mode",
                        type=str,
                        choices=['val', 'test'],
                        default='val',
                        help='mode for dataset, val or test.')

    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument(
        "--dataset_hparams",
        type=str,
        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument(
        "--model_hparams",
        type=str,
        help="a string of comma separated list of model hyperparameters")
    parser.add_argument("--context_length",
                        type=int,
                        help="number of context frames")
    parser.add_argument("--prediction_length",
                        type=int,
                        help="number of frames to predict")

    parser.add_argument("--batch_size",
                        type=int,
                        default=5,
                        help="number of samples in batch")
    parser.add_argument(
        "--num_samples",
        type=int,
        help="number of samples in total (all of them by default)")
    parser.add_argument("--num_epochs", type=int, default=1)

    parser.add_argument("--num_stochastic_samples", type=int, default=5)
    parser.add_argument("--gif_length",
                        type=int,
                        help="default is sequence_length")
    parser.add_argument("--fps", type=int, default=4)

    parser.add_argument("--gpu_mem_frac",
                        type=float,
                        default=0,
                        help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int, default=42)

    args = parser.parse_args()

    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    args.results_dir = args.output_dir
    args.results_gif_dir = args.output_dir
    args.results_png_dir = args.output_dir
    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.checkpoint:
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                                    checkpoint_dir)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir,
                                   "dataset_hparams.json")) as f:
                dataset_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print(
                "dataset_hparams.json was not loaded because it does not exist"
            )
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict = json.loads(f.read())
                model_hparams_dict.pop('num_gpus',
                                       None)  # backwards-compatibility
        except FileNotFoundError:
            print(
                "model_hparams.json was not loaded because it does not exist")
    else:
        if not args.dataset:
            raise ValueError(
                'dataset is required when checkpoint is not specified')
        if not args.model:
            raise ValueError(
                'model is required when checkpoint is not specified')
        args.output_gif_dir = args.output_gif_dir or os.path.join(
            args.results_gif_dir, 'model.%s' % args.model)
        args.output_png_dir = args.output_png_dir or os.path.join(
            args.results_png_dir, 'model.%s' % args.model)

    print(
        '----------------------------------- Options ------------------------------------'
    )
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print(
        '------------------------------------- End --------------------------------------'
    )

    VideoDataset = datasets.get_dataset_class(args.dataset)
    dataset = VideoDataset(args.input_dir,
                           mode=args.mode,
                           num_epochs=args.num_epochs,
                           seed=args.seed,
                           hparams_dict=dataset_hparams_dict,
                           hparams=args.dataset_hparams)

    def override_hparams_dict(dataset):
        hparams_dict = dict(model_hparams_dict)
        hparams_dict['context_frames'] = dataset.hparams.context_frames
        hparams_dict['sequence_length'] = dataset.hparams.sequence_length
        hparams_dict['repeat'] = dataset.hparams.time_shift
        if args.context_length is not None:
            hparams_dict['context_frames'] = args.context_length
        if args.prediction_length is not None:
            hparams_dict[
                'sequence_length'] = args.context_length + args.prediction_length
        return hparams_dict

    VideoPredictionModel = models.get_model_class(args.model)
    model = VideoPredictionModel(mode='test',
                                 hparams_dict=override_hparams_dict(dataset),
                                 hparams=args.model_hparams)

    if args.num_samples:
        if args.num_samples > dataset.num_examples_per_epoch():
            raise ValueError('num_samples cannot be larger than the dataset')
        num_examples_per_epoch = args.num_samples
    else:
        num_examples_per_epoch = dataset.num_examples_per_epoch()
    if num_examples_per_epoch % args.batch_size != 0:
        raise ValueError('batch_size should evenly divide the dataset')

    inputs, _ = dataset.make_batch(args.batch_size)
    if not isinstance(model, models.GroundTruthVideoPredictionModel):
        # remove ground truth data past context_frames to prevent accidentally using it
        for k, v in inputs.items():
            if k != 'actions':
                inputs[k] = v[:, :model.hparams.context_frames]

    input_phs = {
        k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k)
        for k, v in inputs.items()
    }

    with tf.variable_scope(''):
        model.build_graph(input_phs)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    with open(os.path.join(args.output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
    with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f:
        f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
    with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f:
        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    sess = tf.Session(config=config)

    model.restore(sess, args.checkpoint)

    # get all tile and file names from input_dir
    tiles = os.listdir(args.input_dir)
    tiles.sort()

    tile_names = []
    file_names = []
    for tile in tiles:
        in_tile_path = os.path.join(args.input_dir, tile)
        files = os.listdir(in_tile_path)
        files.sort()
        for file in files:
            tile_names.append(tile)
            file_names.append(file[:-10])

    # generate and save predictions
    sample_ind = 0
    while True:
        if args.num_samples and sample_ind >= args.num_samples:
            break
        try:
            input_results = sess.run(inputs)
        except tf.errors.OutOfRangeError:
            break
        print("Generating samples from %d to %d" %
              (sample_ind, sample_ind + args.batch_size))

        feed_dict = {
            input_ph: input_results[name]
            for name, input_ph in input_phs.items()
        }

        # if num_stochastic_samples is 1, this mirrors the input dataset layout
        # with predictions; larger values prefix each file name with the
        # stochastic sample index
        for stochastic_sample_ind in range(args.num_stochastic_samples):
            gen_images = sess.run(model.outputs['gen_images'],
                                  feed_dict=feed_dict)
            for i, gen_images_ in enumerate(gen_images):
                # keep only channels 0-3 and move the time axis to the end
                gen_images_ = np.moveaxis(gen_images_[:, :, :, :4], 0,
                                          -1).astype(np.float16)
                # swap the red and blue channels
                r, g, b, n = np.split(gen_images_, 4, axis=2)
                gen_images_ = np.concatenate([b, g, r, n], axis=2)
                #Save
                save_full_path = os.path.join(args.output_dir,
                                              tile_names[sample_ind])
                sample_name = str(stochastic_sample_ind + 1).zfill(
                    4) + '_' + file_names[sample_ind] + '.npz'
                if not os.path.exists(save_full_path):
                    os.makedirs(save_full_path)
                np.savez(os.path.join(save_full_path, sample_name),
                         gen_images_)
                sample_ind += 1
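            # rewind so the next stochastic pass is saved against the same batch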
            if stochastic_sample_ind < (args.num_stochastic_samples - 1):
                sample_ind -= args.batch_size
def main():
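    """Run a trained video-prediction model on the val/test split and save its
    sampled predictions as gifs and per-frame pngs."""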
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
                                                                     "train, val, test, etc, or a directory containing "
                                                                     "the tfrecords")
    parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_gif_dir is specified")
    parser.add_argument("--results_gif_dir", type=str, help="default is results_dir. ignored if output_gif_dir is specified")
    parser.add_argument("--results_png_dir", type=str, help="default is results_dir. ignored if output_png_dir is specified")
    parser.add_argument("--output_gif_dir", help="output directory where samples are saved as gifs. default is "
                                                 "results_gif_dir/model_fname")
    parser.add_argument("--output_png_dir", help="output directory where samples are saved as pngs. default is "
                                                 "results_png_dir/model_fname")
    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")

    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val', help='mode for dataset, val or test.')

    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")

    parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch")
    parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)")
    parser.add_argument("--num_epochs", type=int, default=1)

    parser.add_argument("--num_stochastic_samples", type=int, default=5)
    parser.add_argument("--gif_length", type=int, help="default is sequence_length")
    parser.add_argument("--fps", type=int, default=4)

    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int, default=7)

    args = parser.parse_args()

    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    args.results_gif_dir = args.results_gif_dir or args.results_dir
    args.results_png_dir = args.results_png_dir or args.results_dir
    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.checkpoint:
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict = json.loads(f.read())
                model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")
        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, os.path.split(checkpoint_dir)[1])
        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, os.path.split(checkpoint_dir)[1])
    else:
        if not args.dataset:
            raise ValueError('dataset is required when checkpoint is not specified')
        if not args.model:
            raise ValueError('model is required when checkpoint is not specified')
        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, 'model.%s' % args.model)
        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, 'model.%s' % args.model)

    print('----------------------------------- Options ------------------------------------')
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print('------------------------------------- End --------------------------------------')

    VideoDataset = datasets.get_dataset_class(args.dataset)
    dataset = VideoDataset(args.input_dir, mode=args.mode, num_epochs=args.num_epochs, seed=args.seed,
                           hparams_dict=dataset_hparams_dict, hparams=args.dataset_hparams)

    def override_hparams_dict(dataset):
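        """Force the model's frame-count hparams to match the dataset's."""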
        hparams_dict = dict(model_hparams_dict)
        hparams_dict['context_frames'] = dataset.hparams.context_frames
        hparams_dict['sequence_length'] = dataset.hparams.sequence_length
        hparams_dict['repeat'] = dataset.hparams.time_shift
        return hparams_dict

    VideoPredictionModel = models.get_model_class(args.model)
    model = VideoPredictionModel(mode='test', hparams_dict=override_hparams_dict(dataset), hparams=args.model_hparams)

    if args.num_samples:
        if args.num_samples > dataset.num_examples_per_epoch():
            raise ValueError('num_samples cannot be larger than the dataset size')
        num_examples_per_epoch = args.num_samples
    else:
        num_examples_per_epoch = dataset.num_examples_per_epoch()
    if num_examples_per_epoch % args.batch_size != 0:
        raise ValueError('batch_size should evenly divide the dataset size')

    inputs, target = dataset.make_batch(args.batch_size)
    if not isinstance(model, models.GroundTruthVideoPredictionModel):
        # remove ground truth data past context_frames to prevent accidentally using it
        for k, v in inputs.items():
            if k != 'actions':
                inputs[k] = v[:, :model.hparams.context_frames]

    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
    target_ph = tf.placeholder(target.dtype, target.shape, 'targets_ph')

    with tf.variable_scope(''):
        model.build_graph(input_phs, target_ph)

    for output_dir in (args.output_gif_dir, args.output_png_dir):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        with open(os.path.join(output_dir, "options.json"), "w") as f:
            f.write(json.dumps(vars(args), sort_keys=True, indent=4))
        with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
            f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
        with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
            f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    sess = tf.Session(config=config)

    model.restore(sess, args.checkpoint)

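    # iterate over the dataset, saving every stochastic sample as a gif plus
    # per-frame pngs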
    sample_ind = 0
    while True:
        if args.num_samples and sample_ind >= args.num_samples:
            break
        try:
            input_results, target_result = sess.run([inputs, target])
        except tf.errors.OutOfRangeError:
            break
        print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))

        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
        for stochastic_sample_ind in range(args.num_stochastic_samples):
            gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict)
            for i, gen_images_ in enumerate(gen_images):
                gen_images_ = (gen_images_ * 255.0).astype(np.uint8)  # assumes model outputs in [0, 1]

                gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
                save_gif(os.path.join(args.output_gif_dir, gen_images_fname),
                         gen_images_[:args.gif_length] if args.gif_length else gen_images_, fps=args.fps)

                for t, gen_image in enumerate(gen_images_):
                    gen_image_fname = 'gen_image_%05d_%02d_%02d.png' % (sample_ind + i, stochastic_sample_ind, t)
                    gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
                    cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image)

        sample_ind += args.batch_size
    def __init__(self, num_of_generated_frames, batch_size):
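        """Restore a SAVP video-prediction model from a hard-coded checkpoint and
        build its graph once, so predictions can be run repeatedly from the open
        session (the paths below are machine-specific)."""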
        seed = 7
        self.img_size = 64
        self.batch_size = batch_size
        tf.set_random_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        results_gif_dir = "../results_test/carla_50k"
        results_png_dir = "../results_test/carla_50k"
        dataset_hparams_dict = {}
        model_hparams_dict = {}
        checkpoint = "/media/eslam/426b7820-cb81-4c46-9430-be5429970ddb/home/eslam/Future_Imitiation/video_prediction-master/logs/carla_intel/ours_savp"
        # resolve the checkpoint directory and load its saved hparams
        checkpoint_dir = os.path.normpath(checkpoint)
        if not os.path.isdir(checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)

        dataset = "carla_intel"
        model = "savp"
        mode = "test"
        num_epochs = 1
        gpu_mem_frac = 0
        self.num_stochastic_samples = 1
        self.fps = 4
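        # sequence_length = 4 context frames + the frames to be generated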
        dataset_hparams = "sequence_length = " + str(4 + num_of_generated_frames)
        # TODO: should be changed to feed the 4 images directly to it.
        input_dir = "/media/eslam/426b7820-cb81-4c46-9430-be5429970ddb/home/eslam/Future_Imitiation/Intel_dataset/tf_record/test"
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")
        self.output_gif_dir = os.path.join(results_gif_dir, os.path.split(checkpoint_dir)[1])
        self.output_png_dir = os.path.join(results_png_dir, os.path.split(checkpoint_dir)[1])

        VideoDataset = datasets.get_dataset_class(dataset)
        dataset = VideoDataset(
            input_dir,
            mode=mode,
            num_epochs=num_epochs,
            seed=seed,
            hparams_dict=dataset_hparams_dict,
            hparams=dataset_hparams)

        VideoPredictionModel = models.get_model_class(model)
        hparams_dict = dict(model_hparams_dict)
        hparams_dict.update({
            'context_frames': dataset.hparams.context_frames,
            'sequence_length': dataset.hparams.sequence_length,
            'repeat': dataset.hparams.time_shift,
        })
        self.model = VideoPredictionModel(
            mode=mode,
            hparams_dict=hparams_dict,
            hparams=None)

        self.sequence_length = self.model.hparams.sequence_length
        self.gif_length = self.sequence_length
        context_frames = self.model.hparams.context_frames
        self.future_length = self.sequence_length - context_frames
        num_examples_per_epoch = dataset.num_examples_per_epoch()
        #if num_examples_per_epoch % self.batch_size != 0:
        #    raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)

        inputs, _ = dataset.make_batch(self.batch_size)  # the target tensor is unused here
        self.input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
        with tf.variable_scope(''):
            self.model.build_graph(self.input_phs)

        for output_dir in (self.output_gif_dir, self.output_png_dir):
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_mem_frac)
        config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
        self.sess = tf.Session(config=config)
        self.model.restore(self.sess, checkpoint)
        self.sample_ind = 0