import argparse
import errno
import itertools
import json
import math
import os
import random
import time
from collections import OrderedDict

import cv2
import numpy as np
import tensorflow as tf

# repo-local modules (assuming the video_prediction package layout of the SAVP codebase)
from video_prediction import datasets, models
from video_prediction.utils import ffmpeg_gif, tf_utils
from video_prediction.utils.ffmpeg_gif import save_gif


# Presumably a method of a predictor class that wraps the model, session,
# input placeholders, and sampling hyperparameters.
def generate_future_frames(self, frame_0, frame_1, frame_2, frame_3, debug=False):
    # resize the four context frames and normalize them to [0, 1]
    input_results = np.zeros((self.batch_size, self.sequence_length, self.img_size, self.img_size, 3))
    for row_num in range(self.batch_size):
        input_results[row_num][0] = cv2.resize(frame_0[row_num], (self.img_size, self.img_size)) / 255
        input_results[row_num][1] = cv2.resize(frame_1[row_num], (self.img_size, self.img_size)) / 255
        input_results[row_num][2] = cv2.resize(frame_2[row_num], (self.img_size, self.img_size)) / 255
        input_results[row_num][3] = cv2.resize(frame_3[row_num], (self.img_size, self.img_size)) / 255
    # feed the (resized) context frames to the input placeholder(s)
    feed_dict = {input_ph: input_results for input_ph in self.input_phs.values()}
    for stochastic_sample_ind in range(self.num_stochastic_samples):
        gen_images = self.sess.run(self.model.outputs['gen_images'], feed_dict=feed_dict)
        # only keep the future frames
        gen_images = gen_images[:, -self.future_length:]
        if debug:
            for i, gen_images_ in enumerate(gen_images):
                context_images_ = (input_results[i] * 255.0).astype(np.uint8)
                gen_images_ = (gen_images_ * 255.0).astype(np.uint8)
                gen_images_fname = 'gen_image_%05d_%02d.gif' % (self.sample_ind + i, stochastic_sample_ind)
                context_and_gen_images = list(context_images_[:self.context_frames]) + list(gen_images_)
                if self.gif_length:
                    context_and_gen_images = context_and_gen_images[:self.gif_length]
                save_gif(os.path.join(self.output_gif_dir, gen_images_fname),
                         context_and_gen_images, fps=self.fps)
                # zero-pad the frame index to at least two digits
                gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1)))
                for t, gen_image in enumerate(gen_images_):
                    gen_image_fname = gen_image_fname_pattern % (self.sample_ind + i, stochastic_sample_ind, t)
                    if gen_image.shape[-1] == 1:
                        gen_image = np.tile(gen_image, (1, 1, 3))
                    else:
                        gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
                    cv2.imwrite(os.path.join(self.output_png_dir, gen_image_fname), gen_image)
    self.sample_ind += self.batch_size
    return gen_images
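# A minimal usage sketch (hypothetical): this method presumably lives on a
# predictor class whose constructor builds the model graph, the session, and
# the input placeholders. `FramePredictor` and the frame arrays below are
# illustrative only, not names from the repo.
#
#   predictor = FramePredictor(checkpoint='logs/model-200000', batch_size=8)
#   frames = [np.zeros((8, 240, 320, 3), dtype=np.uint8) for _ in range(4)]
#   gen = predictor.generate_future_frames(*frames, debug=True)
#   # gen has shape (batch_size, future_length, img_size, img_size, 3), values in [0, 1]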
# Training entry point (a separate script from the predictor above): builds
# the train/val graphs and runs the training loop with periodic summaries.
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True,
                        help="either a directory containing subdirectories "
                             "train, val, test, etc, or a directory containing "
                             "the tfrecords")
    parser.add_argument("--val_input_dirs", type=str, nargs='+',
                        help="directories containing the tfrecords. default: [input_dir]")
    parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified")
    parser.add_argument("--output_dir",
                        help="output directory where json files, summary, model, gifs, etc are saved. "
                             "default is logs_dir/model_fname, where model_fname consists of "
                             "information from model and model_hparams")
    parser.add_argument("--checkpoint",
                        help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
    parser.add_argument("--resume", action='store_true',
                        help="resume from latest checkpoint in output_dir.")
    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument("--dataset_hparams", type=str,
                        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--dataset_hparams_dict", type=str,
                        help="a json file of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument("--model_hparams", type=str,
                        help="a string of comma separated list of model hyperparameters")
    parser.add_argument("--model_hparams_dict", type=str,
                        help="a json file of model hyperparameters")
    parser.add_argument("--summary_freq", type=int, default=1000,
                        help="save summaries (except for image and eval summaries) every summary_freq steps")
    parser.add_argument("--image_summary_freq", type=int, default=5000,
                        help="save image summaries every image_summary_freq steps")
    parser.add_argument("--eval_summary_freq", type=int, default=0,
                        help="save eval summaries every eval_summary_freq steps")
    parser.add_argument("--progress_freq", type=int, default=100,
                        help="display progress every progress_freq steps")
    parser.add_argument("--metrics_freq", type=int, default=0,
                        help="run and display metrics every metrics_freq steps")
    parser.add_argument("--gif_freq", type=int, default=0,
                        help="save gifs of predicted frames every gif_freq steps")
    parser.add_argument("--save_freq", type=int, default=5000,
                        help="save model every save_freq steps, 0 to disable")
    parser.add_argument("--gpu_mem_frac", type=float, default=0,
                        help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int)
    args = parser.parse_args()

    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    if args.output_dir is None:
        # derive a filesystem-friendly directory name from the model name and hparams
        list_depth = 0
        model_fname = ''
        for t in ('model=%s,%s' % (args.model, args.model_hparams)):
            if t == '[':
                list_depth += 1
            if t == ']':
                list_depth -= 1
            if list_depth and t == ',':
                t = '..'
            if t in '=,':
                t = '.'
            if t in '[]':
                t = ''
            model_fname += t
        args.output_dir = os.path.join(args.logs_dir, model_fname)
    if args.resume:
        if args.checkpoint:
            raise ValueError('resume and checkpoint cannot both be specified')
        args.checkpoint = args.output_dir

    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.dataset_hparams_dict:
        with open(args.dataset_hparams_dict) as f:
            dataset_hparams_dict.update(json.loads(f.read()))
    if args.model_hparams_dict:
        with open(args.model_hparams_dict) as f:
            model_hparams_dict.update(json.loads(f.read()))
    if args.checkpoint:
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict.update(json.loads(f.read()))
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict.update(json.loads(f.read()))
                model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")

    print('----------------------------------- Options ------------------------------------')
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print('------------------------------------- End --------------------------------------')

    VideoDataset = datasets.get_dataset_class(args.dataset)
    train_dataset = VideoDataset(args.input_dir, mode='train',
                                 hparams_dict=dataset_hparams_dict,
                                 hparams=args.dataset_hparams)
    val_input_dirs = args.val_input_dirs or [args.input_dir]
    val_datasets = [VideoDataset(val_input_dir, mode='val',
                                 hparams_dict=dataset_hparams_dict,
                                 hparams=args.dataset_hparams)
                    for val_input_dir in val_input_dirs]
    if len(val_input_dirs) > 1:
        if isinstance(val_datasets[-1], datasets.KTHVideoDataset):
            val_datasets[-1].set_sequence_length(40)
        else:
            val_datasets[-1].set_sequence_length(30)

    def override_hparams_dict(dataset):
        hparams_dict = dict(model_hparams_dict)
        hparams_dict['context_frames'] = dataset.hparams.context_frames
        hparams_dict['sequence_length'] = dataset.hparams.sequence_length
        hparams_dict['repeat'] = dataset.hparams.time_shift
        return hparams_dict

    VideoPredictionModel = models.get_model_class(args.model)
    train_model = VideoPredictionModel(mode='train',
                                       hparams_dict=override_hparams_dict(train_dataset),
                                       hparams=args.model_hparams)
    val_models = [VideoPredictionModel(mode='val',
                                       hparams_dict=override_hparams_dict(val_dataset),
                                       hparams=args.model_hparams)
                  for val_dataset in val_datasets]

    batch_size = train_model.hparams.batch_size
    with tf.variable_scope('') as training_scope:
        train_model.build_graph(*train_dataset.make_batch(batch_size))
    for val_model, val_dataset in zip(val_models, val_datasets):
        with tf.variable_scope(training_scope, reuse=True):
            val_model.build_graph(*val_dataset.make_batch(batch_size))

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    with open(os.path.join(args.output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
    with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f:
"dataset_hparams.json"), "w") as f: f.write( json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4)) with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f: f.write( json.dumps(train_model.hparams.values(), sort_keys=True, indent=4)) if args.gif_freq: val_model = val_models[0] val_tensors = OrderedDict() context_images = val_model.inputs['images'][:, :val_model.hparams. context_frames] val_tensors['gen_images_vis'] = tf.concat( [context_images, val_model.gen_images], axis=1) if val_model.gen_images_enc is not None: val_tensors['gen_images_enc_vis'] = tf.concat( [context_images, val_model.gen_images_enc], axis=1) val_tensors.update({ name: tensor for name, tensor in val_model.inputs.items() if tensor.shape.ndims >= 4 }) val_tensors['targets'] = val_model.targets val_tensors.update({ name: tensor for name, tensor in val_model.outputs.items() if tensor.shape.ndims >= 4 }) val_tensor_clips = OrderedDict([ (name, tf_utils.tensor_to_clip(output)) for name, output in val_tensors.items() ]) with tf.name_scope("parameter_count"): parameter_count = tf.reduce_sum( [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()]) saver = tf.train.Saver(max_to_keep=3) summaries = tf.get_collection(tf.GraphKeys.SUMMARIES) image_summaries = set(tf.get_collection(tf_utils.IMAGE_SUMMARIES)) eval_summaries = set(tf.get_collection(tf_utils.EVAL_SUMMARIES)) eval_image_summaries = image_summaries & eval_summaries image_summaries -= eval_image_summaries eval_summaries -= eval_image_summaries if args.summary_freq: summary_op = tf.summary.merge(summaries) if args.image_summary_freq: image_summary_op = tf.summary.merge(list(image_summaries)) if args.eval_summary_freq: eval_summary_op = tf.summary.merge(list(eval_summaries)) eval_image_summary_op = tf.summary.merge(list(eval_image_summaries)) if args.summary_freq or args.image_summary_freq or args.eval_summary_freq: summary_writer = tf.summary.FileWriter(args.output_dir) gpu_options = tf.GPUOptions( per_process_gpu_memory_fraction=args.gpu_mem_frac) config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True) global_step = tf.train.get_or_create_global_step() max_steps = train_model.hparams.max_steps with tf.Session(config=config) as sess: print("parameter_count =", sess.run(parameter_count)) sess.run(tf.global_variables_initializer()) train_model.restore(sess, args.checkpoint) start_step = sess.run(global_step) # start at one step earlier to log everything without doing any training # step is relative to the start_step for step in range(-1, max_steps - start_step): if step == 0: start = time.time() def should(freq): return freq and ((step + 1) % freq == 0 or (step + 1) in (0, max_steps - start_step)) fetches = {"global_step": global_step} if step >= 0: fetches["train_op"] = train_model.train_op if should(args.progress_freq): fetches['d_losses'] = train_model.d_losses fetches['g_losses'] = train_model.g_losses if isinstance(train_model.learning_rate, tf.Tensor): fetches["learning_rate"] = train_model.learning_rate if should(args.metrics_freq): fetches['metrics'] = train_model.metrics if should(args.summary_freq): fetches["summary"] = summary_op if should(args.image_summary_freq): fetches["image_summary"] = image_summary_op if should(args.eval_summary_freq): fetches["eval_summary"] = eval_summary_op fetches["eval_image_summary"] = eval_image_summary_op run_start_time = time.time() results = sess.run(fetches) run_elapsed_time = time.time() - run_start_time if run_elapsed_time > 1.5: print('session.run took %0.1fs' % 
            if should(args.summary_freq):
                print("recording summary")
                summary_writer.add_summary(results["summary"], results["global_step"])
                print("done")
            if should(args.image_summary_freq):
                print("recording image summary")
                summary_writer.add_summary(
                    tf_utils.convert_tensor_to_gif_summary(results["image_summary"]),
                    results["global_step"])
                print("done")
            if should(args.eval_summary_freq):
                print("recording eval summary")
                summary_writer.add_summary(results["eval_summary"], results["global_step"])
                summary_writer.add_summary(
                    tf_utils.convert_tensor_to_gif_summary(results["eval_image_summary"]),
                    results["global_step"])
                print("done")
            if should(args.summary_freq) or should(args.image_summary_freq) or should(args.eval_summary_freq):
                summary_writer.flush()
            if should(args.progress_freq):
                # global_step will have the correct step count if we resume from a checkpoint
                steps_per_epoch = math.ceil(train_dataset.num_examples_per_epoch() / batch_size)
                train_epoch = math.ceil(results["global_step"] / steps_per_epoch)
                train_step = (results["global_step"] - 1) % steps_per_epoch + 1
                print("progress global step %d epoch %d step %d" %
                      (results["global_step"], train_epoch, train_step))
                if step >= 0:
                    elapsed_time = time.time() - start
                    average_time = elapsed_time / (step + 1)
                    images_per_sec = batch_size / average_time
                    remaining_time = (max_steps - (start_step + step)) * average_time
                    print(" image/sec %0.1f remaining %dm (%0.1fh) (%0.1fd)" %
                          (images_per_sec, remaining_time / 60,
                           remaining_time / 60 / 60, remaining_time / 60 / 60 / 24))
                for name, loss in itertools.chain(results['d_losses'].items(),
                                                  results['g_losses'].items()):
                    print(name, loss)
                if isinstance(train_model.learning_rate, tf.Tensor):
                    print("learning_rate", results["learning_rate"])
            if should(args.metrics_freq):
                for name, metric in results['metrics'].items():
                    print(name, metric)
            if should(args.save_freq):
                print("saving model to", args.output_dir)
                saver.save(sess, os.path.join(args.output_dir, "model"), global_step=global_step)
                print("done")
            if should(args.gif_freq):
                image_dir = os.path.join(args.output_dir, 'images')
                if not os.path.exists(image_dir):
                    os.makedirs(image_dir)
                gif_clips = sess.run(val_tensor_clips)
                gif_step = results["global_step"]
                for name, clip in gif_clips.items():
                    filename = "%08d-%s.gif" % (gif_step, name)
                    print("saving gif to", os.path.join(image_dir, filename))
                    ffmpeg_gif.save_gif(os.path.join(image_dir, filename), clip, fps=4)
                    print("done")
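# Example invocation of this training entry point (script name and paths are
# illustrative, not taken from the repo; the flags are the argparse options
# defined above):
#
#   python train.py --input_dir data/bair --dataset bair \
#       --model savp --model_hparams_dict hparams/model_hparams.json \
#       --output_dir logs/bair/savp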
# Generation entry point: restores a checkpoint and writes sampled future
# frames to gif/png directories, grouped by input directory name.
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True,
                        help="either a directory containing subdirectories "
                             "train, val, test, etc, or a directory containing "
                             "the tfrecords")
    parser.add_argument("--results_dir", type=str, default='results',
                        help="ignored if output_gif_dir is specified")
    parser.add_argument("--results_gif_dir", type=str,
                        help="default is results_dir. ignored if output_gif_dir is specified")
    parser.add_argument("--results_png_dir", type=str,
                        help="default is results_dir. ignored if output_png_dir is specified")
    parser.add_argument("--output_gif_dir",
                        help="output directory where samples are saved as gifs. default is "
                             "results_gif_dir/model_fname")
    parser.add_argument("--output_png_dir",
                        help="output directory where samples are saved as pngs. default is "
                             "results_png_dir/model_fname")
    parser.add_argument("--checkpoint",
                        help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val',
                        help="mode for dataset, val or test.")
    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument("--dataset_hparams", type=str,
                        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument("--model_hparams", type=str,
                        help="a string of comma separated list of model hyperparameters")
    parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch")
    parser.add_argument("--num_samples", type=int,
                        help="number of samples in total (all of them by default)")
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--num_stochastic_samples", type=int, default=5)
    parser.add_argument("--gif_length", type=int, help="default is sequence_length")
    parser.add_argument("--fps", type=int, default=4)
    parser.add_argument("--gpu_mem_frac", type=float, default=0,
                        help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int, default=7)
    args = parser.parse_args()

    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    args.results_gif_dir = args.results_gif_dir or args.results_dir
    args.results_png_dir = args.results_png_dir or args.results_dir
    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.checkpoint:
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")
        args.output_gif_dir = args.output_gif_dir or os.path.join(
            args.results_gif_dir, os.path.split(checkpoint_dir)[1])
        args.output_png_dir = args.output_png_dir or os.path.join(
            args.results_png_dir, os.path.split(checkpoint_dir)[1])
    else:
        if not args.dataset:
            raise ValueError('dataset is required when checkpoint is not specified')
        if not args.model:
            raise ValueError('model is required when checkpoint is not specified')
        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, 'model.%s' % args.model)
        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, 'model.%s' % args.model)

    print('----------------------------------- Options ------------------------------------')
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print('------------------------------------- End --------------------------------------')

    VideoDataset = datasets.get_dataset_class(args.dataset)
    dataset = VideoDataset(args.input_dir, mode=args.mode, num_epochs=args.num_epochs,
                           seed=args.seed, hparams_dict=dataset_hparams_dict,
                           hparams=args.dataset_hparams)
    VideoPredictionModel = models.get_model_class(args.model)
    hparams_dict = dict(model_hparams_dict)
    hparams_dict.update({
        'context_frames': dataset.hparams.context_frames,
        'sequence_length': dataset.hparams.sequence_length,
        'repeat': dataset.hparams.time_shift,
    })
    model = VideoPredictionModel(mode=args.mode, hparams_dict=hparams_dict,
                                 hparams=args.model_hparams)

    sequence_length = model.hparams.sequence_length
    context_frames = model.hparams.context_frames
    future_length = sequence_length - context_frames

    if args.num_samples:
        if args.num_samples > dataset.num_examples_per_epoch():
            raise ValueError('num_samples cannot be larger than the dataset')
        num_examples_per_epoch = args.num_samples
    else:
        num_examples_per_epoch = dataset.num_examples_per_epoch()
    if num_examples_per_epoch % args.batch_size != 0:
        raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)

    inputs = dataset.make_batch(args.batch_size)
    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
    with tf.variable_scope(''):
        model.build_graph(input_phs)

    for output_dir in (args.output_gif_dir, args.output_png_dir):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        with open(os.path.join(output_dir, "options.json"), "w") as f:
            f.write(json.dumps(vars(args), sort_keys=True, indent=4))
        with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
            f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
        with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
            f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    sess = tf.Session(config=config)
    model.restore(sess, args.checkpoint)

    sample_ind = 0
    # samples are grouped into a subdirectory named after the input directory
    key = args.input_dir.split('/')[-1]
    while True:
        if args.num_samples and sample_ind >= args.num_samples:
            break
        try:
            input_results = sess.run(inputs)
        except tf.errors.OutOfRangeError:
            break
        print("evaluating samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))
        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
        # draw a single stochastic sample per batch (the full
        # num_stochastic_samples loop is disabled in this variant)
        for stochastic_sample_ind in range(1):
            gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict)
            # only keep the future frames
            gen_images = gen_images[:, -future_length:]
            for i, gen_images_ in enumerate(gen_images):
                context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8)
                gen_images_ = (gen_images_ * 255.0).astype(np.uint8)
                gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
                context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_)
                if args.gif_length:
                    context_and_gen_images = context_and_gen_images[:args.gif_length]
                # make sure the per-input-dir subdirectory exists before writing
                gif_dir = os.path.join(args.output_gif_dir, key)
                if not os.path.exists(gif_dir):
                    os.makedirs(gif_dir)
                save_gif(os.path.join(gif_dir, gen_images_fname),
                         context_and_gen_images, fps=args.fps)
                # zero-pad the frame index to at least two digits
                gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1)))
                png_dir = os.path.join(args.output_png_dir, key)
                if not os.path.exists(png_dir):
                    os.makedirs(png_dir)
                for t, gen_image in enumerate(gen_images_):
                    gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
                    if gen_image.shape[-1] == 1:
                        gen_image = np.tile(gen_image, (1, 1, 3))
                    else:
                        gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
                    cv2.imwrite(os.path.join(png_dir, gen_image_fname), gen_image)
        sample_ind += args.batch_size
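# A minimal sketch of the '%%0%dd' filename trick used above: the inner
# %-substitution picks a zero-pad width wide enough for the largest frame
# index (at least 2 digits), producing a concrete pattern such as
# 'gen_image_%05d_%02d_%02d.png'. The helper name is illustrative.
def _demo_fname_pattern(num_frames=10):
    pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(num_frames - 1)))
    # e.g. sample 3, stochastic sample 0, frame 7 -> 'gen_image_00003_00_07.png'
    return pattern % (3, 0, 7)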
# Excerpt from a CEM-based planning loop (surrounding context not included):
# saves the best-scoring sampled rollout as a gif, its final frame as a png,
# and reports the best action for the current round.
gen_images_ = (gen_images_ * 255.0).astype(np.uint8)
gen_images_fname = 'best_gen_images.gif'
context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_)
if args.gif_length:
    context_and_gen_images = context_and_gen_images[:args.gif_length]
save_gif(os.path.join('/home/mcube/yenchen/savp-omnipush/data/', gen_images_fname),
         context_and_gen_images, fps=args.fps)

best_gen_last_image = (gen_last_imgs[costs.index(min(costs))] * 255.0).astype(np.uint8)
best_gen_last_image_name = '/home/mcube/yenchen/savp-omnipush/data/gen_img.png'
cv2.imwrite(best_gen_last_image_name,
            cv2.cvtColor(best_gen_last_image, cv2.COLOR_RGB2BGR))

cem_mean_reshaped = cem.mean.reshape((cem.n_steps, cem.a_dim))
mean_radians = np.arctan2(cem_mean_reshaped[:, 0], cem_mean_reshaped[:, 1])
print('---')
print("round: %d, min cost: %.3f" % (r, min(costs)))
print("Best action: {}".format(best_action))
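# A minimal, self-contained sketch of the reshape-and-angle step above,
# assuming the planner's flat mean encodes (sin, cos) action components per
# step, as the arctan2(y, x) call suggests. `n_steps` and `a_dim` mirror the
# cem.n_steps / cem.a_dim attributes; the planner itself is hypothetical here.
def _demo_cem_mean_to_angles(n_steps=5, a_dim=2):
    flat_mean = np.random.randn(n_steps * a_dim)         # stands in for cem.mean
    mean_reshaped = flat_mean.reshape((n_steps, a_dim))  # one row per planned step
    # arctan2(sin, cos) recovers the per-step pushing angle in radians
    return np.arctan2(mean_reshaped[:, 0], mean_reshaped[:, 1])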
# A second generation entry point (the unmodified variant): evaluates
# num_stochastic_samples per batch and writes gifs/pngs directly to the
# output directories.
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True,
                        help="either a directory containing subdirectories "
                             "train, val, test, etc, or a directory containing "
                             "the tfrecords")
    parser.add_argument("--results_dir", type=str, default='results',
                        help="ignored if output_gif_dir is specified")
    parser.add_argument("--results_gif_dir", type=str,
                        help="default is results_dir. ignored if output_gif_dir is specified")
    parser.add_argument("--results_png_dir", type=str,
                        help="default is results_dir. ignored if output_png_dir is specified")
    parser.add_argument("--output_gif_dir",
                        help="output directory where samples are saved as gifs. default is "
                             "results_gif_dir/model_fname")
    parser.add_argument("--output_png_dir",
                        help="output directory where samples are saved as pngs. default is "
                             "results_png_dir/model_fname")
    parser.add_argument("--checkpoint",
                        help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val',
                        help="mode for dataset, val or test.")
    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument("--dataset_hparams", type=str,
                        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument("--model_hparams", type=str,
                        help="a string of comma separated list of model hyperparameters")
    parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch")
    parser.add_argument("--num_samples", type=int,
                        help="number of samples in total (all of them by default)")
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--num_stochastic_samples", type=int, default=5)
    parser.add_argument("--gif_length", type=int, help="default is sequence_length")
    parser.add_argument("--fps", type=int, default=4)
    parser.add_argument("--gpu_mem_frac", type=float, default=0,
                        help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int, default=7)
    args = parser.parse_args()

    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    args.results_gif_dir = args.results_gif_dir or args.results_dir
    args.results_png_dir = args.results_png_dir or args.results_dir
    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.checkpoint:
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict = json.loads(f.read())
                model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")
        args.output_gif_dir = args.output_gif_dir or os.path.join(
            args.results_gif_dir, os.path.split(checkpoint_dir)[1])
        args.output_png_dir = args.output_png_dir or os.path.join(
            args.results_png_dir, os.path.split(checkpoint_dir)[1])
    else:
        if not args.dataset:
            raise ValueError('dataset is required when checkpoint is not specified')
        if not args.model:
            raise ValueError('model is required when checkpoint is not specified')
        args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, 'model.%s' % args.model)
        args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, 'model.%s' % args.model)

    print('----------------------------------- Options ------------------------------------')
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print('------------------------------------- End --------------------------------------')

    VideoDataset = datasets.get_dataset_class(args.dataset)
    dataset = VideoDataset(args.input_dir, mode=args.mode, num_epochs=args.num_epochs,
                           seed=args.seed, hparams_dict=dataset_hparams_dict,
                           hparams=args.dataset_hparams)

    def override_hparams_dict(dataset):
        hparams_dict = dict(model_hparams_dict)
        hparams_dict['context_frames'] = dataset.hparams.context_frames
        hparams_dict['sequence_length'] = dataset.hparams.sequence_length
        hparams_dict['repeat'] = dataset.hparams.time_shift
        return hparams_dict

    VideoPredictionModel = models.get_model_class(args.model)
    model = VideoPredictionModel(mode='test', hparams_dict=override_hparams_dict(dataset),
                                 hparams=args.model_hparams)

    if args.num_samples:
        if args.num_samples > dataset.num_examples_per_epoch():
            raise ValueError('num_samples cannot be larger than the dataset')
        num_examples_per_epoch = args.num_samples
    else:
        num_examples_per_epoch = dataset.num_examples_per_epoch()
    if num_examples_per_epoch % args.batch_size != 0:
        raise ValueError('batch_size should evenly divide the dataset')

    inputs, target = dataset.make_batch(args.batch_size)
    if not isinstance(model, models.GroundTruthVideoPredictionModel):
        # remove ground truth data past context_frames to prevent accidentally using it
        for k, v in inputs.items():
            if k != 'actions':
                inputs[k] = v[:, :model.hparams.context_frames]
    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
    target_ph = tf.placeholder(target.dtype, target.shape, 'targets_ph')
    with tf.variable_scope(''):
        model.build_graph(input_phs, target_ph)

    for output_dir in (args.output_gif_dir, args.output_png_dir):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        with open(os.path.join(output_dir, "options.json"), "w") as f:
            f.write(json.dumps(vars(args), sort_keys=True, indent=4))
        with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
            f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
        with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
            f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    sess = tf.Session(config=config)
    model.restore(sess, args.checkpoint)

    sample_ind = 0
    while True:
        if args.num_samples and sample_ind >= args.num_samples:
            break
        try:
            input_results, target_result = sess.run([inputs, target])
        except tf.errors.OutOfRangeError:
            break
        print("evaluating samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))
        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
        for stochastic_sample_ind in range(args.num_stochastic_samples):
            gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict)
            for i, gen_images_ in enumerate(gen_images):
                gen_images_ = (gen_images_ * 255.0).astype(np.uint8)
                gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
                save_gif(os.path.join(args.output_gif_dir, gen_images_fname),
                         gen_images_[:args.gif_length] if args.gif_length else gen_images_,
                         fps=args.fps)
                for t, gen_image in enumerate(gen_images_):
                    gen_image_fname = 'gen_image_%05d_%02d_%02d.png' % (sample_ind + i, stochastic_sample_ind, t)
                    gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
                    cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image)
        sample_ind += args.batch_size
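# Example invocation of this generation entry point (script and checkpoint
# names are illustrative; the flags are the argparse options defined above):
#
#   python generate.py --input_dir data/bair --dataset bair --model savp \
#       --checkpoint logs/bair/savp/model-200000 --mode test \
#       --results_dir results --num_stochastic_samples 5 --fps 4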