def generate_future_frames(self, frame_0, frame_1, frame_2, frame_3, debug=False):
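        """Resize four context frames to the model's input size, run the
        prediction network, and return the generated future frames.

        Assumes self.input_phs holds the model's input placeholders (typically
        just 'images') and that each frame_i is a batch of HxWx3 uint8 images.
        """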
        input_results = np.zeros((self.batch_size, self.sequence_length, self.img_size, self.img_size, 3))
        # resize each context frame to the model's input size and scale pixels to [0, 1]
        for row_num in range(self.batch_size):
            for t, frame in enumerate((frame_0, frame_1, frame_2, frame_3)):
                input_results[row_num][t] = cv2.resize(
                    frame[row_num], (self.img_size, self.img_size)) / 255.0
        # build the feed dict once; typically the only input placeholder here is 'images'
        feed_dict = {input_ph: input_results for input_ph in self.input_phs.values()}

        # each pass of the stochastic model can yield a different sampled future
        for stochastic_sample_ind in range(self.num_stochastic_samples):
            gen_images = self.sess.run(self.model.outputs['gen_images'], feed_dict=feed_dict)
            # only keep the future frames
            gen_images = gen_images[:, -self.future_length:]
            if debug:
                # context frames are everything before the predicted future frames
                context_frames = self.sequence_length - self.future_length
                for i, gen_images_ in enumerate(gen_images):
                    context_images_ = (input_results[i] * 255.0).astype(np.uint8)
                    gen_images_ = (gen_images_ * 255.0).astype(np.uint8)

                    gen_images_fname = 'gen_image_%05d_%02d.gif' % (self.sample_ind + i, stochastic_sample_ind)
                    context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_)
                    if self.gif_length:
                        context_and_gen_images = context_and_gen_images[:self.gif_length]
                    save_gif(os.path.join(self.output_gif_dir, gen_images_fname),
                             context_and_gen_images, fps=self.fps)

                    gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1)))
                    for t, gen_image in enumerate(gen_images_):
                        gen_image_fname = gen_image_fname_pattern % (self.sample_ind + i, stochastic_sample_ind, t)
                        if gen_image.shape[-1] == 1:
                            gen_image = np.tile(gen_image, (1, 1, 3))
                        else:
                            gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
                        cv2.imwrite(os.path.join(self.output_png_dir, gen_image_fname), gen_image)

        # advance the running sample index once per batch, not once per stochastic sample
        self.sample_ind += self.batch_size
        return gen_images
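
# A minimal usage sketch (hypothetical wrapper and attribute names), assuming
# the enclosing object holds a restored tf.Session as self.sess, the model as
# self.model, and its input placeholders as self.input_phs:
#
#   predictor = FuturePredictor(...)  # hypothetical class owning sess/model/placeholders
#   futures = predictor.generate_future_frames(f0, f1, f2, f3, debug=True)
#   # futures: (batch_size, future_length, img_size, img_size, 3), values in [0, 1]
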
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir",
                        type=str,
                        required=True,
                        help="either a directory containing subdirectories "
                        "train, val, test, etc, or a directory containing "
                        "the tfrecords")
    parser.add_argument(
        "--val_input_dirs",
        type=str,
        nargs='+',
        help="directories containing the tfrecords. default: [input_dir]")
    parser.add_argument("--logs_dir",
                        default='logs',
                        help="ignored if output_dir is specified")
    parser.add_argument(
        "--output_dir",
        help="output directory where json files, summary, model, gifs, etc are saved. "
        "default is logs_dir/model_fname, where model_fname consists of "
        "information from model and model_hparams")
    parser.add_argument(
        "--checkpoint",
        help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
    parser.add_argument("--resume",
                        action='store_true',
                        help='resume from latest checkpoint in output_dir.')

    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument(
        "--dataset_hparams",
        type=str,
        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--dataset_hparams_dict",
                        type=str,
                        help="a json file of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument(
        "--model_hparams",
        type=str,
        help="a string of comma separated list of model hyperparameters")
    parser.add_argument("--model_hparams_dict",
                        type=str,
                        help="a json file of model hyperparameters")

    parser.add_argument(
        "--summary_freq",
        type=int,
        default=1000,
        help="save summaries (except for image and eval summaries) every summary_freq steps")
    parser.add_argument(
        "--image_summary_freq",
        type=int,
        default=5000,
        help="save image summaries every image_summary_freq steps")
    parser.add_argument(
        "--eval_summary_freq",
        type=int,
        default=0,
        help="save eval summaries every eval_summary_freq steps")
    parser.add_argument("--progress_freq",
                        type=int,
                        default=100,
                        help="display progress every progress_freq steps")
    parser.add_argument("--metrics_freq",
                        type=int,
                        default=0,
                        help="run and display metrics every metrics_freq step")
    parser.add_argument(
        "--gif_freq",
        type=int,
        default=0,
        help="save gifs of predicted frames every gif_freq steps")
    parser.add_argument("--save_freq",
                        type=int,
                        default=5000,
                        help="save model every save_freq steps, 0 to disable")

    parser.add_argument("--gpu_mem_frac",
                        type=float,
                        default=0,
                        help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int)

    args = parser.parse_args()

    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    if args.output_dir is None:
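        # derive a filesystem-safe run name from the model name and hparams:
        # '=' and top-level ',' become '.', commas inside '[...]' lists become
        # '..', and the brackets themselves are dropped,
        # e.g. 'model=savp,l=[1,2]' -> 'model.savp.l.1..2'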
        list_depth = 0
        model_fname = ''
        for t in ('model=%s,%s' % (args.model, args.model_hparams)):
            if t == '[':
                list_depth += 1
            if t == ']':
                list_depth -= 1
            if list_depth and t == ',':
                t = '..'
            if t in '=,':
                t = '.'
            if t in '[]':
                t = ''
            model_fname += t
        args.output_dir = os.path.join(args.logs_dir, model_fname)

    if args.resume:
        if args.checkpoint:
            raise ValueError('resume and checkpoint cannot both be specified')
        args.checkpoint = args.output_dir

    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.dataset_hparams_dict:
        with open(args.dataset_hparams_dict) as f:
            dataset_hparams_dict.update(json.loads(f.read()))
    if args.model_hparams_dict:
        with open(args.model_hparams_dict) as f:
            model_hparams_dict.update(json.loads(f.read()))
    if args.checkpoint:
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                                    checkpoint_dir)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir,
                                   "dataset_hparams.json")) as f:
                dataset_hparams_dict.update(json.loads(f.read()))
        except FileNotFoundError:
            print(
                "dataset_hparams.json was not loaded because it does not exist"
            )
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict.update(json.loads(f.read()))
                model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility
        except FileNotFoundError:
            print(
                "model_hparams.json was not loaded because it does not exist")

    print(
        '----------------------------------- Options ------------------------------------'
    )
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print(
        '------------------------------------- End --------------------------------------'
    )

    VideoDataset = datasets.get_dataset_class(args.dataset)
    train_dataset = VideoDataset(args.input_dir,
                                 mode='train',
                                 hparams_dict=dataset_hparams_dict,
                                 hparams=args.dataset_hparams)
    val_input_dirs = args.val_input_dirs or [args.input_dir]
    val_datasets = [
        VideoDataset(val_input_dir,
                     mode='val',
                     hparams_dict=dataset_hparams_dict,
                     hparams=args.dataset_hparams)
        for val_input_dir in val_input_dirs
    ]
    if len(val_input_dirs) > 1:
        if isinstance(val_datasets[-1], datasets.KTHVideoDataset):
            val_datasets[-1].set_sequence_length(40)
        else:
            val_datasets[-1].set_sequence_length(30)

    def override_hparams_dict(dataset):
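        # the model must agree with the dataset on frame counts, so copy those
        # settings over from the dataset's hparams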
        hparams_dict = dict(model_hparams_dict)
        hparams_dict['context_frames'] = dataset.hparams.context_frames
        hparams_dict['sequence_length'] = dataset.hparams.sequence_length
        hparams_dict['repeat'] = dataset.hparams.time_shift
        return hparams_dict

    VideoPredictionModel = models.get_model_class(args.model)
    train_model = VideoPredictionModel(
        mode='train',
        hparams_dict=override_hparams_dict(train_dataset),
        hparams=args.model_hparams)
    val_models = [
        VideoPredictionModel(mode='val',
                             hparams_dict=override_hparams_dict(val_dataset),
                             hparams=args.model_hparams)
        for val_dataset in val_datasets
    ]

    batch_size = train_model.hparams.batch_size
    with tf.variable_scope('') as training_scope:
        train_model.build_graph(*train_dataset.make_batch(batch_size))
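    # build the validation graphs in the same variable scope with reuse=True
    # so they share the training model's weights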
    for val_model, val_dataset in zip(val_models, val_datasets):
        with tf.variable_scope(training_scope, reuse=True):
            val_model.build_graph(*val_dataset.make_batch(batch_size))

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    with open(os.path.join(args.output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
    with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f:
        f.write(
            json.dumps(train_dataset.hparams.values(),
                       sort_keys=True,
                       indent=4))
    with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f:
        f.write(
            json.dumps(train_model.hparams.values(), sort_keys=True, indent=4))

    if args.gif_freq:
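        # collect every image-shaped validation tensor and convert each one to a
        # clip so predicted sequences can be periodically dumped as gifs below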
        val_model = val_models[0]
        val_tensors = OrderedDict()
        context_images = val_model.inputs['images'][:, :val_model.hparams.context_frames]
        val_tensors['gen_images_vis'] = tf.concat(
            [context_images, val_model.gen_images], axis=1)
        if val_model.gen_images_enc is not None:
            val_tensors['gen_images_enc_vis'] = tf.concat(
                [context_images, val_model.gen_images_enc], axis=1)
        val_tensors.update({
            name: tensor
            for name, tensor in val_model.inputs.items()
            if tensor.shape.ndims >= 4
        })
        val_tensors['targets'] = val_model.targets
        val_tensors.update({
            name: tensor
            for name, tensor in val_model.outputs.items()
            if tensor.shape.ndims >= 4
        })
        val_tensor_clips = OrderedDict([
            (name, tf_utils.tensor_to_clip(output))
            for name, output in val_tensors.items()
        ])

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    saver = tf.train.Saver(max_to_keep=3)
    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
    image_summaries = set(tf.get_collection(tf_utils.IMAGE_SUMMARIES))
    eval_summaries = set(tf.get_collection(tf_utils.EVAL_SUMMARIES))
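    # split the summaries into disjoint groups so plain, image, and eval
    # summaries can each be written on their own schedule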
    eval_image_summaries = image_summaries & eval_summaries
    image_summaries -= eval_image_summaries
    eval_summaries -= eval_image_summaries
    if args.summary_freq:
        summary_op = tf.summary.merge(summaries)
    if args.image_summary_freq:
        image_summary_op = tf.summary.merge(list(image_summaries))
    if args.eval_summary_freq:
        eval_summary_op = tf.summary.merge(list(eval_summaries))
        eval_image_summary_op = tf.summary.merge(list(eval_image_summaries))

    if args.summary_freq or args.image_summary_freq or args.eval_summary_freq:
        summary_writer = tf.summary.FileWriter(args.output_dir)

    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    global_step = tf.train.get_or_create_global_step()
    max_steps = train_model.hparams.max_steps
    with tf.Session(config=config) as sess:
        print("parameter_count =", sess.run(parameter_count))

        sess.run(tf.global_variables_initializer())
        train_model.restore(sess, args.checkpoint)

        start_step = sess.run(global_step)
        # start at one step earlier to log everything without doing any training
        # step is relative to the start_step
        for step in range(-1, max_steps - start_step):
            if step == 0:
                start = time.time()

            def should(freq):
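                # fire every freq steps, on the extra logging pass (step == -1),
                # and on the final step; freq == 0 disables the event entirely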
                return freq and ((step + 1) % freq == 0 or
                                 (step + 1) in (0, max_steps - start_step))

            fetches = {"global_step": global_step}
            if step >= 0:
                fetches["train_op"] = train_model.train_op

            if should(args.progress_freq):
                fetches['d_losses'] = train_model.d_losses
                fetches['g_losses'] = train_model.g_losses
                if isinstance(train_model.learning_rate, tf.Tensor):
                    fetches["learning_rate"] = train_model.learning_rate
            if should(args.metrics_freq):
                fetches['metrics'] = train_model.metrics
            if should(args.summary_freq):
                fetches["summary"] = summary_op
            if should(args.image_summary_freq):
                fetches["image_summary"] = image_summary_op
            if should(args.eval_summary_freq):
                fetches["eval_summary"] = eval_summary_op
                fetches["eval_image_summary"] = eval_image_summary_op

            run_start_time = time.time()
            results = sess.run(fetches)
            run_elapsed_time = time.time() - run_start_time
            if run_elapsed_time > 1.5:
                print('session.run took %0.1fs' % run_elapsed_time)

            if should(args.summary_freq):
                print("recording summary")
                summary_writer.add_summary(results["summary"],
                                           results["global_step"])
                print("done")
            if should(args.image_summary_freq):
                print("recording image summary")
                summary_writer.add_summary(
                    tf_utils.convert_tensor_to_gif_summary(
                        results["image_summary"]), results["global_step"])
                print("done")
            if should(args.eval_summary_freq):
                print("recording eval summary")
                summary_writer.add_summary(results["eval_summary"],
                                           results["global_step"])
                summary_writer.add_summary(
                    tf_utils.convert_tensor_to_gif_summary(
                        results["eval_image_summary"]), results["global_step"])
                print("done")
            if should(args.summary_freq) or should(
                    args.image_summary_freq) or should(args.eval_summary_freq):
                summary_writer.flush()
            if should(args.progress_freq):
                # global_step will have the correct step count if we resume from a checkpoint
                steps_per_epoch = math.ceil(
                    train_dataset.num_examples_per_epoch() / batch_size)
                train_epoch = math.ceil(results["global_step"] /
                                        steps_per_epoch)
                train_step = (results["global_step"] - 1) % steps_per_epoch + 1
                print("progress  global step %d  epoch %d  step %d" %
                      (results["global_step"], train_epoch, train_step))
                if step >= 0:
                    elapsed_time = time.time() - start
                    average_time = elapsed_time / (step + 1)
                    images_per_sec = batch_size / average_time
                    remaining_time = (max_steps -
                                      (start_step + step)) * average_time
                    print(
                        "          image/sec %0.1f  remaining %dm (%0.1fh) (%0.1fd)"
                        %
                        (images_per_sec, remaining_time / 60, remaining_time /
                         60 / 60, remaining_time / 60 / 60 / 24))

                for name, loss in itertools.chain(results['d_losses'].items(),
                                                  results['g_losses'].items()):
                    print(name, loss)
                if isinstance(train_model.learning_rate, tf.Tensor):
                    print("learning_rate", results["learning_rate"])
            if should(args.metrics_freq):
                for name, metric in results['metrics'].items():
                    print(name, metric)

            if should(args.save_freq):
                print("saving model to", args.output_dir)
                saver.save(sess,
                           os.path.join(args.output_dir, "model"),
                           global_step=global_step)
                print("done")

            if should(args.gif_freq):
                image_dir = os.path.join(args.output_dir, 'images')
                if not os.path.exists(image_dir):
                    os.makedirs(image_dir)

                gif_clips = sess.run(val_tensor_clips)
                gif_step = results["global_step"]
                for name, clip in gif_clips.items():
                    filename = "%08d-%s.gif" % (gif_step, name)
                    print("saving gif to", os.path.join(image_dir, filename))
                    ffmpeg_gif.save_gif(os.path.join(image_dir, filename),
                                        clip,
                                        fps=4)
                    print("done")
Example #3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir",
                        type=str,
                        required=True,
                        help="either a directory containing subdirectories "
                        "train, val, test, etc, or a directory containing "
                        "the tfrecords")
    parser.add_argument("--results_dir",
                        type=str,
                        default='results',
                        help="ignored if output_gif_dir is specified")
    parser.add_argument(
        "--results_gif_dir",
        type=str,
        help="default is results_dir. ignored if output_gif_dir is specified")
    parser.add_argument(
        "--results_png_dir",
        type=str,
        help="default is results_dir. ignored if output_png_dir is specified")
    parser.add_argument(
        "--output_gif_dir",
        help="output directory where samples are saved as gifs. default is "
        "results_gif_dir/model_fname")
    parser.add_argument(
        "--output_png_dir",
        help="output directory where samples are saved as pngs. default is "
        "results_png_dir/model_fname")
    parser.add_argument(
        "--checkpoint",
        help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")

    parser.add_argument("--mode",
                        type=str,
                        choices=['val', 'test'],
                        default='val',
                        help='mode for dataset, val or test.')

    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument(
        "--dataset_hparams",
        type=str,
        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument(
        "--model_hparams",
        type=str,
        help="a string of comma separated list of model hyperparameters")

    parser.add_argument("--batch_size",
                        type=int,
                        default=8,
                        help="number of samples in batch")
    parser.add_argument(
        "--num_samples",
        type=int,
        help="number of samples in total (all of them by default)")
    parser.add_argument("--num_epochs", type=int, default=1)

    parser.add_argument("--num_stochastic_samples", type=int, default=5)
    parser.add_argument("--gif_length",
                        type=int,
                        help="default is sequence_length")
    parser.add_argument("--fps", type=int, default=4)

    parser.add_argument("--gpu_mem_frac",
                        type=float,
                        default=0,
                        help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int, default=7)

    args = parser.parse_args()

    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    args.results_gif_dir = args.results_gif_dir or args.results_dir
    args.results_png_dir = args.results_png_dir or args.results_dir
    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.checkpoint:
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                                    checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir,
                                   "dataset_hparams.json")) as f:
                dataset_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print(
                "dataset_hparams.json was not loaded because it does not exist"
            )
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print(
                "model_hparams.json was not loaded because it does not exist")
        args.output_gif_dir = args.output_gif_dir or os.path.join(
            args.results_gif_dir,
            os.path.split(checkpoint_dir)[1])
        args.output_png_dir = args.output_png_dir or os.path.join(
            args.results_png_dir,
            os.path.split(checkpoint_dir)[1])
    else:
        if not args.dataset:
            raise ValueError(
                'dataset is required when checkpoint is not specified')
        if not args.model:
            raise ValueError(
                'model is required when checkpoint is not specified')
        args.output_gif_dir = args.output_gif_dir or os.path.join(
            args.results_gif_dir, 'model.%s' % args.model)
        args.output_png_dir = args.output_png_dir or os.path.join(
            args.results_png_dir, 'model.%s' % args.model)

    print(
        '----------------------------------- Options ------------------------------------'
    )
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print(
        '------------------------------------- End --------------------------------------'
    )

    VideoDataset = datasets.get_dataset_class(args.dataset)
    dataset = VideoDataset(args.input_dir,
                           mode=args.mode,
                           num_epochs=args.num_epochs,
                           seed=args.seed,
                           hparams_dict=dataset_hparams_dict,
                           hparams=args.dataset_hparams)

    VideoPredictionModel = models.get_model_class(args.model)
    hparams_dict = dict(model_hparams_dict)
    hparams_dict.update({
        'context_frames': dataset.hparams.context_frames,
        'sequence_length': dataset.hparams.sequence_length,
        'repeat': dataset.hparams.time_shift,
    })
    model = VideoPredictionModel(mode=args.mode,
                                 hparams_dict=hparams_dict,
                                 hparams=args.model_hparams)

    sequence_length = model.hparams.sequence_length
    context_frames = model.hparams.context_frames
    future_length = sequence_length - context_frames
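    # the model predicts future_length frames beyond the observed context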

    if args.num_samples:
        if args.num_samples > dataset.num_examples_per_epoch():
            raise ValueError('num_samples cannot be larger than the dataset')
        num_examples_per_epoch = args.num_samples
    else:
        num_examples_per_epoch = dataset.num_examples_per_epoch()
    print('dataset size: %d, batch size: %d' % (num_examples_per_epoch, args.batch_size))
    if num_examples_per_epoch % args.batch_size != 0:
        raise ValueError(
            'batch_size should evenly divide the dataset size %d' %
            num_examples_per_epoch)

    inputs = dataset.make_batch(args.batch_size)
    input_phs = {
        k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k)
        for k, v in inputs.items()
    }
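    # the placeholders mirror the dataset's batch tensors, so numpy arrays from
    # sess.run(inputs) can be fed straight back in through feed_dict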
    with tf.variable_scope(''):
        model.build_graph(input_phs)

    for output_dir in (args.output_gif_dir, args.output_png_dir):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        with open(os.path.join(output_dir, "options.json"), "w") as f:
            f.write(json.dumps(vars(args), sort_keys=True, indent=4))
        with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
            f.write(
                json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
        with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
            f.write(
                json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    sess = tf.Session(config=config)

    model.restore(sess, args.checkpoint)

    sample_ind = 0
    # outputs below are written into a per-input subdirectory; create it up
    # front, since cv2.imwrite does not create missing directories
    key = os.path.basename(args.input_dir)
    for output_dir in (args.output_gif_dir, args.output_png_dir):
        key_dir = os.path.join(output_dir, key)
        if not os.path.exists(key_dir):
            os.makedirs(key_dir)
    while True:
        if args.num_samples and sample_ind >= args.num_samples:
            break
        try:
            input_results = sess.run(inputs)
        except tf.errors.OutOfRangeError:
            break
        print("evaluation samples from %d to %d" %
              (sample_ind, sample_ind + args.batch_size))
        feed_dict = {
            input_ph: input_results[name]
            for name, input_ph in input_phs.items()
        }
        # only a single stochastic sample is drawn per batch here
        # (args.num_stochastic_samples is parsed but unused in this script)
        for stochastic_sample_ind in range(1):
            gen_images = sess.run(model.outputs['gen_images'],
                                  feed_dict=feed_dict)
            # only keep the future frames
            gen_images = gen_images[:, -future_length:]
            for i, gen_images_ in enumerate(gen_images):
                context_images_ = (input_results['images'][i] * 255.0).astype(
                    np.uint8)
                gen_images_ = (gen_images_ * 255.0).astype(np.uint8)

                gen_images_fname = 'gen_image_%05d_%02d.gif' % (
                    sample_ind + i, stochastic_sample_ind)
                context_and_gen_images = list(
                    context_images_[:context_frames]) + list(gen_images_)
                if args.gif_length:
                    context_and_gen_images = context_and_gen_images[:args.gif_length]
                save_gif(os.path.join(args.output_gif_dir, key, gen_images_fname),
                         context_and_gen_images,
                         fps=args.fps)

                gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(
                    2, len(str(len(gen_images_) - 1)))
                for t, gen_image in enumerate(gen_images_):
                    gen_image_fname = gen_image_fname_pattern % (
                        sample_ind + i, stochastic_sample_ind, t)
                    if gen_image.shape[-1] == 1:
                        gen_image = np.tile(gen_image, (1, 1, 3))
                    else:
                        gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
                    cv2.imwrite(
                        os.path.join(args.output_png_dir, key, gen_image_fname),
                        gen_image)

        sample_ind += args.batch_size
Example #4
                            gen_images_ = (gen_images_ * 255.0).astype(np.uint8)

                            gen_images_fname = 'best_gen_images.gif'

                            context_and_gen_images = list(
                                context_images_[:context_frames]) + list(gen_images_)
                            if args.gif_length:
                                context_and_gen_images = context_and_gen_images[:args.gif_length]
                            save_gif(os.path.join('/home/mcube/yenchen/savp-omnipush/data/',
                                                  gen_images_fname),
                                     context_and_gen_images,
                                     fps=args.fps)

            best_gen_last_image = (gen_last_imgs[costs.index(min(costs))] *
                                   255.0).astype(np.uint8)
            best_gen_last_image_name = '/home/mcube/yenchen/savp-omnipush/data/gen_img.png'
            cv2.imwrite(best_gen_last_image_name,
                        cv2.cvtColor(best_gen_last_image, cv2.COLOR_RGB2BGR))

            cem_mean_reshaped = cem.mean.reshape((cem.n_steps, cem.a_dim))
            mean_radians = np.arctan2(cem_mean_reshaped[:, 0],
                                      cem_mean_reshaped[:, 1])
            print('---')
            print("round: %d, min cost: %.3f" % (r, min(costs)))
            print("Best action: {}".format(best_action))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
                                                                     "train, val, test, etc, or a directory containing "
                                                                     "the tfrecords")
    parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_gif_dir is specified")
    parser.add_argument("--results_gif_dir", type=str, help="default is results_dir. ignored if output_gif_dir is specified")
    parser.add_argument("--results_png_dir", type=str, help="default is results_dir. ignored if output_png_dir is specified")
    parser.add_argument("--output_gif_dir", help="output directory where samples are saved as gifs. default is "
                                                 "results_gif_dir/model_fname")
    parser.add_argument("--output_png_dir", help="output directory where samples are saved as pngs. default is "
                                                 "results_png_dir/model_fname")
    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")

    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val', help='mode for dataset, val or test.')

    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")

    parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch")
    parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)")
    parser.add_argument("--num_epochs", type=int, default=1)

    parser.add_argument("--num_stochastic_samples", type=int, default=5)
    parser.add_argument("--gif_length", type=int, help="default is sequence_length")
    parser.add_argument("--fps", type=int, default=4)

    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int, default=7)

    args = parser.parse_args()

    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    args.results_gif_dir = args.results_gif_dir or args.results_dir
    args.results_png_dir = args.results_png_dir or args.results_dir
    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.checkpoint:
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict = json.loads(f.read())
                model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")
        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, os.path.split(checkpoint_dir)[1])
        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, os.path.split(checkpoint_dir)[1])
    else:
        if not args.dataset:
            raise ValueError('dataset is required when checkpoint is not specified')
        if not args.model:
            raise ValueError('model is required when checkpoint is not specified')
        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, 'model.%s' % args.model)
        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, 'model.%s' % args.model)

    print('----------------------------------- Options ------------------------------------')
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print('------------------------------------- End --------------------------------------')

    VideoDataset = datasets.get_dataset_class(args.dataset)
    dataset = VideoDataset(args.input_dir, mode=args.mode, num_epochs=args.num_epochs, seed=args.seed,
                           hparams_dict=dataset_hparams_dict, hparams=args.dataset_hparams)

    def override_hparams_dict(dataset):
        hparams_dict = dict(model_hparams_dict)
        hparams_dict['context_frames'] = dataset.hparams.context_frames
        hparams_dict['sequence_length'] = dataset.hparams.sequence_length
        hparams_dict['repeat'] = dataset.hparams.time_shift
        return hparams_dict

    VideoPredictionModel = models.get_model_class(args.model)
    model = VideoPredictionModel(mode='test', hparams_dict=override_hparams_dict(dataset), hparams=args.model_hparams)

    if args.num_samples:
        if args.num_samples > dataset.num_examples_per_epoch():
            raise ValueError('num_samples cannot be larger than the dataset')
        num_examples_per_epoch = args.num_samples
    else:
        num_examples_per_epoch = dataset.num_examples_per_epoch()
    if num_examples_per_epoch % args.batch_size != 0:
        raise ValueError('batch_size should evenly divide the dataset')

    inputs, target = dataset.make_batch(args.batch_size)
    if not isinstance(model, models.GroundTruthVideoPredictionModel):
        # remove ground truth data past context_frames to prevent accidentally using it
        for k, v in inputs.items():
            if k != 'actions':
                inputs[k] = v[:, :model.hparams.context_frames]

    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
    target_ph = tf.placeholder(target.dtype, target.shape, 'targets_ph')

    with tf.variable_scope(''):
        model.build_graph(input_phs, target_ph)

    for output_dir in (args.output_gif_dir, args.output_png_dir):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        with open(os.path.join(output_dir, "options.json"), "w") as f:
            f.write(json.dumps(vars(args), sort_keys=True, indent=4))
        with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
            f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
        with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
            f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    sess = tf.Session(config=config)

    model.restore(sess, args.checkpoint)

    sample_ind = 0
    while True:
        if args.num_samples and sample_ind >= args.num_samples:
            break
        try:
            input_results, target_result = sess.run([inputs, target])
        except tf.errors.OutOfRangeError:
            break
        print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))

        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
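        # each stochastic pass samples a different plausible future for the same context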
        for stochastic_sample_ind in range(args.num_stochastic_samples):
            gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict)
            for i, gen_images_ in enumerate(gen_images):
                gen_images_ = (gen_images_ * 255.0).astype(np.uint8)

                gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
                save_gif(os.path.join(args.output_gif_dir, gen_images_fname),
                         gen_images_[:args.gif_length] if args.gif_length else gen_images_, fps=args.fps)

                for t, gen_image in enumerate(gen_images_):
                    gen_image_fname = 'gen_image_%05d_%02d_%02d.png' % (sample_ind + i, stochastic_sample_ind, t)
                    gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
                    cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image)

        sample_ind += args.batch_size