def build_model(checkpoint, model_str, model_hparams, input_placeholders, context_length, sequence_length):
    """Instantiate a video-prediction model in test mode from a checkpoint directory.

    Resolves `checkpoint` to its containing directory, reads `options.json`
    (to fill in `model_str` when not given) and, when present,
    `model_hparams.json`, then builds the model graph on
    `input_placeholders` inside an empty variable scope.

    Raises:
        FileNotFoundError: if the resolved checkpoint directory does not exist.
    """
    ckpt_dir = os.path.normpath(checkpoint)
    # `checkpoint` may name either a directory or a file inside one
    # (e.g. checkpoint_dir/model-200000); strip the file part if needed.
    if not os.path.isdir(checkpoint):
        ckpt_dir, _ = os.path.split(ckpt_dir)
    if not os.path.exists(ckpt_dir):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), ckpt_dir)
    with open(os.path.join(ckpt_dir, "options.json")) as f:
        print("loading options from checkpoint %s" % checkpoint)
        options = json.loads(f.read())
    # The saved options only supply the model name when the caller left it blank.
    model_str = model_str or options['model']
    hparams_overrides = {}
    try:
        with open(os.path.join(ckpt_dir, "model_hparams.json")) as f:
            hparams_overrides = json.loads(f.read())
    except FileNotFoundError:
        # Best-effort: older checkpoints may not ship model_hparams.json.
        print("model_hparams.json was not loaded because it does not exist")
    model_cls = models.get_model_class(model_str)
    hparams_dict = dict(hparams_overrides)
    # The caller-supplied lengths override whatever the checkpoint recorded.
    hparams_dict['context_frames'] = context_length
    hparams_dict['sequence_length'] = sequence_length
    model = model_cls(mode='test', hparams_dict=hparams_dict, hparams=model_hparams)
    # Empty scope keeps variable names identical to those in the checkpoint.
    with tf.variable_scope(''):
        model.build_graph(input_placeholders)
    return model
def main():
    """Meta-test evaluation entry point.

    Restores a trained video-prediction model, then for each validation set
    runs a few inner-loop gradient steps on the first 4 examples ("val train")
    and evaluates metrics on the next 4 ("val test"), restoring the original
    weights between sets. Finally prints per-timestep psnr/ssim/lpips tables
    from the saved metric files.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True,
                        help="either a directory containing subdirectories "
                             "train, val, test, etc, or a directory containing "
                             "the tfrecords")
    parser.add_argument("--results_dir", type=str, default='results',
                        help="ignored if output_dir is specified")
    parser.add_argument("--output_dir",
                        help="output directory where results are saved. default is results_dir/model_fname, "
                             "where model_fname is the directory name of checkpoint")
    parser.add_argument("--checkpoint",
                        help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument("--dataset_hparams", type=str,
                        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument("--model_hparams", type=str,
                        help="a string of comma separated list of model hyperparameters")
    parser.add_argument("--batch_size", type=int, default=4,
                        help="number of samples in batch")
    parser.add_argument("--num_samples", type=int,
                        help="number of samples in total (all of them by default)")
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--eval_substasks", type=str, nargs='+', default=['max', 'avg', 'min'],
                        help='subtasks to evaluate (e.g. max, avg, min)')
    parser.add_argument("--only_metrics", action='store_true')
    parser.add_argument("--num_stochastic_samples", type=int, default=100)
    parser.add_argument("--eval_parallel_iterations", type=int, default=10)
    parser.add_argument("--gpu_mem_frac", type=float, default=0.8,
                        help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int)
    parser.add_argument("--meta_batch_size", type=int, default=5,
                        help="how many inner-loops to run")
    parser.add_argument("--inner_iters", type=int, default=5,
                        help="number of inner-loop iterations")
    parser.add_argument("--meta_step_size", type=float, default=1.0,
                        help="initial step size of meta optimization")
    parser.add_argument("--final_meta_step_size", type=float, default=0.0,
                        help="final sep size of meta optimization")
    args = parser.parse_args()

    # Seed every RNG we use so runs are reproducible.
    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.checkpoint:
        checkpoint_dir = os.path.normpath(args.checkpoint)
        # --checkpoint may be a file inside the directory; strip the file part.
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
        # CLI flags win over the options recorded in the checkpoint.
        args.dataset = args.dataset or options['dataset']
        args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")
        args.output_dir = args.output_dir or os.path.join(
            args.results_dir, os.path.split(checkpoint_dir)[1])
    else:
        if not args.dataset:
            raise ValueError('dataset is required when checkpoint is not specified')
        if not args.model:
            raise ValueError('model is required when checkpoint is not specified')
        args.output_dir = args.output_dir or os.path.join(
            args.results_dir, 'model.%s' % args.model)

    print('----------------------------------- Options ------------------------------------')
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print('------------------------------------- End --------------------------------------')

    # Dataset
    val_sets = read_dataset(args.input_dir,
                            mode='val',
                            hparams_dict=dataset_hparams_dict,
                            hparams=args.dataset_hparams)
    val_sets = list(val_sets)
    # Backward compatibility, used to set hypermeter for others
    train_dataset = val_sets[0]
    variable_scope = tf.get_variable_scope()
    variable_scope.set_use_resource(True)

    VideoPredictionModel = models.get_model_class(args.model)
    hparams_dict = dict(model_hparams_dict)
    hparams_dict.update({
        'context_frames': train_dataset.hparams.context_frames,
        'sequence_length': train_dataset.hparams.sequence_length,
        'repeat': train_dataset.hparams.time_shift,
    })
    model = VideoPredictionModel(
        hparams_dict=hparams_dict,
        hparams=args.model_hparams,
        eval_num_samples=args.num_stochastic_samples,
        eval_parallel_iterations=args.eval_parallel_iterations)

    batch_size = args.batch_size
    # Explicit check instead of `assert` (asserts are stripped under `python -O`).
    # The 4/4 split below hard-codes this batch size.
    if batch_size != 4:
        raise ValueError('batch_size must be 4 (4 val-train + 4 val-test examples '
                         'per val set), got %d' % batch_size)

    output_dir = args.output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # Persist the effective configuration next to the results for provenance.
    with open(os.path.join(output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
    with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
        f.write(json.dumps(val_sets[0].hparams.values(), sort_keys=True, indent=4))
    with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    # Each val set contains 10 examples: use first 4 for val train,
    # next 4 for val test, drop last 2.
    # Val train set
    val_train_tf_datasets = [
        dataset.make_dataset(batch_size, skip_size=0, take_size=4)
        for dataset in val_sets
    ]
    val_train_iterators = [
        tf_dataset.make_one_shot_iterator() for tf_dataset in val_train_tf_datasets
    ]
    val_train_handles = [
        iterator.string_handle() for iterator in val_train_iterators
    ]
    # Val test set
    val_test_tf_datasets = [
        dataset.make_dataset(batch_size, skip_size=4, take_size=4)
        for dataset in val_sets
    ]
    val_test_iterators = [
        tf_dataset.make_one_shot_iterator() for tf_dataset in val_test_tf_datasets
    ]
    val_test_handles = [
        iterator.string_handle() for iterator in val_test_iterators
    ]

    # Backward compatibility, use first train set to build graph
    val_train_handle = val_train_handles[0]
    iterator = tf.data.Iterator.from_string_handle(
        val_train_handle, val_train_tf_datasets[0].output_types,
        val_train_tf_datasets[0].output_shapes)
    inputs = iterator.get_next()
    # inputs comes from the "first training dataset" by default, unless
    # val_train_handle is remapped to other handles via feed_dict.
    model.build_graph(inputs)

    with tf.name_scope("parameter_count"):
        # exclude trainable variables that are replicas (used in multi-gpu setting)
        trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables)
        parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in trainable_variables])

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    sess = tf.Session(config=config)
    model.restore(sess, args.checkpoint)
    print("parameter_count =", sess.run(parameter_count))
    print("number of test sets =", len(val_sets))

    # Evaluate handle for each dataset (string handles are fed per-set below).
    val_train_handle_evals = [sess.run(handle) for handle in val_train_handles]
    val_test_handle_evals = [sess.run(handle) for handle in val_test_handles]

    # Set up variables recorder so weights can be snapshotted/restored per set.
    model._state = VariableState(sess, tf.trainable_variables())
    sess.graph.finalize()

    sample_ind = 0
    for i, (val_train_handle_eval, val_test_handle_eval) in enumerate(
            zip(val_train_handle_evals, val_test_handle_evals)):
        print('evaluating %d / %d test set' % (i + 1, len(val_train_handle_evals)))
        # Snapshot weights so each set's inner update starts from the checkpoint.
        old_vars = model._state.export_variables()
        # Inner update on the val-train split.
        # NOTE: loop variable renamed from `i`, which shadowed the enumerate index.
        train_fetches = {"train_op": model.train_op}
        for inner_step in range(args.inner_iters):
            _ = sess.run(train_fetches,
                         feed_dict={val_train_handle: val_train_handle_eval})
        # Compute metrics on the val-test split using the adapted weights.
        fetches = {'images': model.inputs['images']}
        fetches.update(model.eval_outputs.items())
        fetches.update(model.eval_metrics.items())
        results = sess.run(fetches,
                           feed_dict={val_train_handle: val_test_handle_eval})
        save_prediction_eval_results(
            os.path.join(output_dir, 'prediction_eval'), results, model.hparams,
            sample_ind, args.only_metrics, args.eval_substasks)
        # Return to original model parameters
        model._state.import_variables(old_vars)
        sample_ind += args.batch_size

    # Print per-timestep metric tables from the files written above.
    metric_fnames = []
    metric_names = ['psnr', 'ssim', 'lpips']
    subtasks = ['max']
    for metric_name in metric_names:
        for subtask in subtasks:
            metric_fnames.append(
                os.path.join(output_dir,
                             'prediction_eval_%s_%s' % (metric_name, subtask),
                             'metrics', metric_name))
    for metric_fname in metric_fnames:
        # NOTE(review): splitting on '/' assumes POSIX-style paths — TODO confirm
        # this script never runs on Windows.
        task_name, _, metric_name = metric_fname.split('/')[-3:]
        metric = load_metrics(metric_fname)
        print('=' * 31)
        print(task_name, metric_name)
        print('-' * 31)
        metric_header_format = '{:>10} {:>20}'
        metric_row_format = '{:>10} {:>10.4f} ({:>7.4f})'
        print(metric_header_format.format('time step', os.path.split(metric_fname)[1]))
        for t, (metric_mean, metric_std) in enumerate(
                zip(metric.mean(axis=0), metric.std(axis=0))):
            print(metric_row_format.format(t, metric_mean, metric_std))
        print(metric_row_format.format('mean (std)', metric.mean(), metric.std()))
        print('=' * 31)
def main():
    """Meta-training entry point (Reptile-style outer loop).

    Builds a video-prediction model, then repeatedly: samples a task
    (one training dataset), runs `inner_iters` gradient steps on it,
    averages the resulting weights over `meta_batch_size` tasks, and
    interpolates the original weights toward that average with a linearly
    decayed meta step size. Also records train/val summaries and
    periodically saves checkpoints.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True,
                        help="either a directory containing subdirectories "
                             "train, val, test, etc, or a directory containing "
                             "the tfrecords")
    parser.add_argument("--val_input_dir", type=str, help="directories containing the tfrecords. default: input_dir")
    parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified")
    parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. "
                                             "default is logs_dir/model_fname, where model_fname consists of "
                                             "information from model and model_hparams")
    parser.add_argument("--output_dir_postfix", default="")
    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
    parser.add_argument("--resume", action='store_true', help='resume from lastest checkpoint in output_dir.')
    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--dataset_hparams_dict", type=str, help="a json file of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
    parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters")
    parser.add_argument("--debug_num_datasets", type=int, default=-1, help="number of dataset to use")
    parser.add_argument("--summary_freq", type=int, default=10, help="save frequency of summaries (except for image and eval summaries) for train/validation set")
    parser.add_argument("--image_summary_freq", type=int, default=50, help="save frequency of image summaries for train/validation set")
    parser.add_argument("--eval_summary_freq", type=int, default=100, help="save frequency of eval summaries for train/validation set")
    parser.add_argument("--accum_eval_summary_freq", type=int, default=400, help="save frequency of accumulated eval summaries for validation set only")
    parser.add_argument("--progress_freq", type=int, default=10, help="display progress every progress_freq steps")
    parser.add_argument("--save_freq", type=int, default=50, help="save frequence of model, 0 to disable")
    parser.add_argument("--aggregate_nccl", type=int, default=0, help="whether to use nccl or cpu for gradient aggregation in multi-gpu training")
    parser.add_argument("--gpu_mem_frac", type=float, default=0.8, help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int)
    parser.add_argument("--meta_batch_size", type=int, default=8, help="how many inner-loops to run")
    parser.add_argument("--exp_size", type=int, default=5, help="how many videos to compute embedding")
    parser.add_argument("--inner_iters", type=int, default=1, help="number of inner-loop iterations")
    parser.add_argument("--meta_step_size", type=float, default=1.0, help="initial step size of meta optimization")
    parser.add_argument("--final_meta_step_size", type=float, default=1.0, help="final sep size of meta optimization")
    args = parser.parse_args()

    # Seed every RNG we use so runs are reproducible.
    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    if args.output_dir is None:
        # Derive a filesystem-safe run name from the model name and hparams
        # string: '=' and top-level ',' become '.', commas inside [...] lists
        # become '..', and the brackets themselves are dropped.
        list_depth = 0
        model_fname = ''
        for t in ('model=%s,%s' % (args.model, args.model_hparams)):
            if t == '[':
                list_depth += 1
            if t == ']':
                list_depth -= 1
            if list_depth and t == ',':
                t = '..'
            if t in '=,':
                t = '.'
            if t in '[]':
                t = ''
            model_fname += t
        args.output_dir = os.path.join(args.logs_dir, model_fname) + args.output_dir_postfix

    if args.resume:
        if args.checkpoint:
            raise ValueError('resume and checkpoint cannot both be specified')
        args.checkpoint = args.output_dir

    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.dataset_hparams_dict:
        with open(args.dataset_hparams_dict) as f:
            dataset_hparams_dict.update(json.loads(f.read()))
    if args.model_hparams_dict:
        with open(args.model_hparams_dict) as f:
            model_hparams_dict.update(json.loads(f.read()))

    if args.checkpoint:
        checkpoint_dir = os.path.normpath(args.checkpoint)
        # --checkpoint may be a file inside the directory; strip the file part.
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            # CLI flags win over the options recorded in the checkpoint.
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict.update(json.loads(f.read()))
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict.update(json.loads(f.read()))
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")

    print('----------------------------------- Options ------------------------------------')
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print('------------------------------------- End --------------------------------------')

    # Dataset
    train_sets = read_dataset(args.input_dir, mode='train', hparams_dict=dataset_hparams_dict, hparams=args.dataset_hparams)
    val_sets = read_dataset(args.input_dir, mode='val', hparams_dict=dataset_hparams_dict, hparams=args.dataset_hparams)
    train_sets = list(train_sets)
    val_sets = list(val_sets)

    # Backward compatibility, used to set hypermeter for others
    train_dataset = train_sets[0]
    variable_scope = tf.get_variable_scope()
    variable_scope.set_use_resource(True)

    VideoPredictionModel = models.get_model_class(args.model)
    hparams_dict = dict(model_hparams_dict)
    hparams_dict.update({
        'context_frames': train_dataset.hparams.context_frames,
        'sequence_length': train_dataset.hparams.sequence_length,
        'repeat': train_dataset.hparams.time_shift,
    })
    model = VideoPredictionModel(
        hparams_dict=hparams_dict,
        hparams=args.model_hparams,
        aggregate_nccl=args.aggregate_nccl)
    model.exp_size = args.exp_size
    batch_size = model.hparams.batch_size

    # Train set: one tf.data pipeline + string handle per dataset.
    train_tf_datasets = [dataset.make_dataset(batch_size) for dataset in train_sets]
    train_iterators = [tf_dataset.make_one_shot_iterator() for tf_dataset in train_tf_datasets]
    train_handles = [iterator.string_handle() for iterator in train_iterators]
    # Val train set
    val_tf_datasets = [dataset.make_dataset(batch_size) for dataset in val_sets]
    val_iterators = [tf_dataset.make_one_shot_iterator() for tf_dataset in val_tf_datasets]
    val_handles = [iterator.string_handle() for iterator in val_iterators]

    # Backward compatibility, use first train set to build graph
    train_handle = train_handles[0]
    iterator = tf.data.Iterator.from_string_handle(train_handle, train_tf_datasets[0].output_types, train_tf_datasets[0].output_shapes)
    inputs = iterator.get_next()
    # inputs comes from the "first training dataset" by default, unless
    # train_handle is remapped to other handles via feed_dict
    model.build_graph(inputs)

    # Persist the effective configuration next to the logs for provenance.
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    with open(os.path.join(args.output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
    with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f:
        f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4))
    with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f:
        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    with tf.name_scope("parameter_count"):
        # exclude trainable variables that are replicas (used in multi-gpu setting)
        trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables)
        parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in trainable_variables])

    saver = tf.train.Saver(var_list=model.saveable_variables, max_to_keep=2)

    # None has the special meaning of evaluating at the end, so explicitly
    # check for non-equality to zero
    if (args.summary_freq != 0 or args.image_summary_freq != 0 or
            args.eval_summary_freq != 0 or args.accum_eval_summary_freq != 0):
        summary_writer = tf.summary.FileWriter(args.output_dir)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    global_step = tf.train.get_or_create_global_step()
    # Each outer step performs inner_iters * meta_batch_size session steps,
    # so divide the configured budget accordingly.
    max_steps = model.hparams.max_steps // (args.inner_iters * args.meta_batch_size)

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        model.restore(sess, args.checkpoint)
        sess.run(model.post_init_ops)

        # Evaluate handle for each dataset
        # Note: this step is super slow, so we only use a few datasets for debugging.
        if args.debug_num_datasets == -1:
            train_handle_evals = [sess.run(handle) for handle in train_handles]
            val_handle_evals = [sess.run(handle) for handle in val_handles]
        else:
            train_handle_evals = [sess.run(handle) for handle in train_handles[:args.debug_num_datasets]]
            val_handle_evals = [sess.run(handle) for handle in val_handles[:args.debug_num_datasets]]
        print("parameter_count =", sess.run(parameter_count))
        print("number of train sets =", len(train_handle_evals))
        print("number of test sets =", len(val_handle_evals))

        # Set input for the first step
        current_handle_eval = random.choice(train_handle_evals)
        # Set up variables recorder (snapshots/restores weights around inner loops)
        model._state = VariableState(sess, tf.trainable_variables())
        sess.graph.finalize()
        start_step = sess.run(global_step)

        def should(step, freq):
            # freq None -> only at the very last step; freq 0 -> never;
            # otherwise every freq steps and at the last step.
            if freq is None:
                return (step + 1) == (max_steps - start_step)
            else:
                return freq and ((step + 1) % freq == 0 or (step + 1) in (0, max_steps - start_step))

        def should_eval(step, freq):
            # never run eval summaries at the beginning since it's expensive,
            # unless it's the last iteration
            return should(step, freq) and (step >= 0 or (step + 1) == (max_steps - start_step))

        # start at one step earlier to log everything without doing any training
        # step is relative to the start_step
        for step in range(-1, max_steps - start_step):
            if step == 1:
                # skip step -1 and 0 for timing purposes (for warmstarting)
                start_time = time.time()

            fetches = {"global_step": global_step}
            if step >= 0:
                # Set up train fetches
                fetches["train_op"] = model.train_op
                # Linearly decrease the meta step size from meta_step_size to
                # final_meta_step_size over the course of training.
                frac_done = step / max_steps
                cur_meta_step_size = args.meta_step_size * (1 - frac_done) + args.final_meta_step_size * frac_done
                # Start meta training: snapshot weights, adapt per task, average.
                old_vars = model._state.export_variables()
                new_vars = []
                for meta_idx in range(args.meta_batch_size):
                    print("step %d, meta batch %d / %d" % (step, meta_idx + 1, args.meta_batch_size))
                    # Sample task (videos from one specific object)
                    current_handle_eval = random.choice(train_handle_evals)
                    for i in range(args.inner_iters):
                        # Run inner update
                        run_start_time = time.time()
                        results = sess.run(fetches, feed_dict={train_handle: current_handle_eval})
                        run_elapsed_time = time.time() - run_start_time
                        if run_elapsed_time > 1.5 and step > 0 and set(fetches.keys()) == {"global_step", "train_op"}:
                            print('running train_op took too long (%0.1fs)' % run_elapsed_time)
                    # Record parameters after doing inner update for each task,
                    # then reset to the pre-adaptation snapshot.
                    new_vars.append(model._state.export_variables())
                    model._state.import_variables(old_vars)
                new_vars = average_vars(new_vars)
                # Perform meta update: move old weights toward the task average.
                model._state.import_variables(interpolate_vars(old_vars, new_vars, cur_meta_step_size))

            # Re-build fetches for logging only (no train_op from here on).
            fetches = {"global_step": global_step}
            if should(step, args.progress_freq):
                fetches['d_loss'] = model.d_loss
                fetches['g_loss'] = model.g_loss
                fetches['d_losses'] = model.d_losses
                fetches['g_losses'] = model.g_losses
                if isinstance(model.learning_rate, tf.Tensor):
                    fetches["learning_rate"] = model.learning_rate
            if should(step, args.summary_freq):
                fetches["summary"] = model.summary_op
            if should(step, args.image_summary_freq):
                fetches["image_summary"] = model.image_summary_op
            if should_eval(step, args.eval_summary_freq):
                fetches["eval_summary"] = model.eval_summary_op
            results = sess.run(fetches, feed_dict={train_handle: current_handle_eval})
            print(step)

            # Val
            if (should(step, args.summary_freq) or should(step, args.image_summary_freq) or should_eval(step, args.eval_summary_freq)):
                # Set up val fetches for summary
                current_handle_eval = val_handle_evals[0]
                val_fetches = {"global_step": global_step}
                if should(step, args.summary_freq):
                    val_fetches["summary"] = model.summary_op
                if should(step, args.image_summary_freq):
                    val_fetches["image_summary"] = model.image_summary_op
                if should_eval(step, args.eval_summary_freq):
                    val_fetches["eval_summary"] = model.eval_summary_op
                # Eval
                val_results = sess.run(val_fetches, feed_dict={train_handle: current_handle_eval})
                # Tag val summaries with '_1' to keep them distinct from train.
                for name, summary in val_results.items():
                    if name == 'global_step':
                        continue
                    val_results[name] = add_tag_suffix(summary, '_1')

            if should(step, args.summary_freq):
                print("recording summary")
                summary_writer.add_summary(results["summary"], results["global_step"])
                summary_writer.add_summary(val_results["summary"], val_results["global_step"])
                print("done")
            if should(step, args.image_summary_freq):
                print("recording image summary")
                summary_writer.add_summary(results["image_summary"], results["global_step"])
                summary_writer.add_summary(val_results["image_summary"], val_results["global_step"])
                print("done")
            if should_eval(step, args.eval_summary_freq):
                print("recording eval summary")
                summary_writer.add_summary(results["eval_summary"], results["global_step"])
                summary_writer.add_summary(val_results["eval_summary"], val_results["global_step"])
                print("done")
            if should_eval(step, args.accum_eval_summary_freq):
                sess.run(model.accum_eval_metrics_reset_op)
                val_fetches = {"global_step": global_step, "accum_eval_summary": model.accum_eval_summary_op}
                for i, val_handle_eval in enumerate(val_handle_evals):
                    # traverse (roughly up to rounding based on the batch size)
                    # all the validation dataset
                    print('evaluating %d / %d test set' % (i, len(val_handle_evals)))
                    val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval})
                accum_eval_summary = add_tag_suffix(val_results["accum_eval_summary"], '_inner_update')
                print("recording accum eval summary")
                summary_writer.add_summary(accum_eval_summary, val_results["global_step"])
                print("done")
            if (should(step, args.summary_freq) or should(step, args.image_summary_freq) or
                    should_eval(step, args.eval_summary_freq) or should_eval(step, args.accum_eval_summary_freq)):
                summary_writer.flush()

            if should(step, args.progress_freq):
                # global_step will have the correct step count if we resume from a checkpoint
                # global step is read before it's incremented
                steps_per_epoch = train_dataset.num_examples_per_epoch() / batch_size
                train_epoch = results["global_step"] / steps_per_epoch
                print("progress global step %d epoch %0.1f" % (results["global_step"] + 1, train_epoch))
                if step > 0:
                    elapsed_time = time.time() - start_time
                    average_time = elapsed_time / step
                    images_per_sec = batch_size / average_time
                    remaining_time = (max_steps - (start_step + step + 1)) * average_time
                    print(" image/sec %0.1f remaining %dm (%0.1fh) (%0.1fd)" %
                          (images_per_sec, remaining_time / 60, remaining_time / 60 / 60, remaining_time / 60 / 60 / 24))
                if results['d_losses']:
                    print("d_loss", results["d_loss"])
                    for name, loss in results['d_losses'].items():
                        print(" ", name, loss)
                if results['g_losses']:
                    print("g_loss", results["g_loss"])
                    for name, loss in results['g_losses'].items():
                        print(" ", name, loss)
                if isinstance(model.learning_rate, tf.Tensor):
                    print("learning_rate", results["learning_rate"])

            if should(step, args.save_freq):
                print("saving model to", args.output_dir)
                saver.save(sess, os.path.join(args.output_dir, "model"), global_step=global_step)
                print("done")
def main():
    """Train a video prediction model, logging summaries/checkpoints/gifs.

    Command-line entry point: builds one training dataset/model pair and one
    validation pair per validation input directory (validation graphs reuse
    the training variables via scope reuse), then runs the training loop,
    emitting scalar/image/eval summaries, progress reports, model checkpoints
    and optional gif dumps at the configured frequencies.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True,
                        help="either a directory containing subdirectories "
                        "train, val, test, etc, or a directory containing "
                        "the tfrecords")
    parser.add_argument("--val_input_dirs", type=str, nargs='+',
                        help="directories containing the tfrecords. default: [input_dir]")
    parser.add_argument("--logs_dir", default='logs',
                        help="ignored if output_dir is specified")
    parser.add_argument("--output_dir",
                        help="output directory where json files, summary, model, gifs, etc are saved. "
                        "default is logs_dir/model_fname, where model_fname consists of "
                        "information from model and model_hparams")
    parser.add_argument("--checkpoint",
                        help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
    parser.add_argument("--resume", action='store_true',
                        help='resume from lastest checkpoint in output_dir.')
    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument("--dataset_hparams", type=str,
                        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--dataset_hparams_dict", type=str,
                        help="a json file of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument("--model_hparams", type=str,
                        help="a string of comma separated list of model hyperparameters")
    parser.add_argument("--model_hparams_dict", type=str,
                        help="a json file of model hyperparameters")
    parser.add_argument("--summary_freq", type=int, default=1000,
                        help="save summaries (except for image and eval summaries) every summary_freq steps")
    parser.add_argument("--image_summary_freq", type=int, default=5000,
                        help="save image summaries every image_summary_freq steps")
    parser.add_argument("--eval_summary_freq", type=int, default=0,
                        help="save eval summaries every eval_summary_freq steps")
    parser.add_argument("--progress_freq", type=int, default=100,
                        help="display progress every progress_freq steps")
    parser.add_argument("--metrics_freq", type=int, default=0,
                        help="run and display metrics every metrics_freq step")
    parser.add_argument("--gif_freq", type=int, default=0,
                        help="save gifs of predicted frames every gif_freq steps")
    parser.add_argument("--save_freq", type=int, default=5000,
                        help="save model every save_freq steps, 0 to disable")
    parser.add_argument("--gpu_mem_frac", type=float, default=0,
                        help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int)
    args = parser.parse_args()

    # Seed every RNG in play so runs are reproducible when --seed is given.
    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    if args.output_dir is None:
        # Derive a filesystem-friendly run name from the model spec:
        # '=' and top-level ',' become '.', commas nested inside [...] lists
        # become '..', and the brackets themselves are dropped.
        list_depth = 0
        model_fname = ''
        for t in ('model=%s,%s' % (args.model, args.model_hparams)):
            if t == '[':
                list_depth += 1
            if t == ']':
                list_depth -= 1
            if list_depth and t == ',':
                t = '..'
            if t in '=,':
                t = '.'
            if t in '[]':
                t = ''
            model_fname += t
        args.output_dir = os.path.join(args.logs_dir, model_fname)

    if args.resume:
        if args.checkpoint:
            raise ValueError('resume and checkpoint cannot both be specified')
        # Resuming means restoring from the run's own output directory.
        args.checkpoint = args.output_dir

    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.dataset_hparams_dict:
        with open(args.dataset_hparams_dict) as f:
            dataset_hparams_dict.update(json.loads(f.read()))
    if args.model_hparams_dict:
        with open(args.model_hparams_dict) as f:
            model_hparams_dict.update(json.loads(f.read()))
    if args.checkpoint:
        # Accept either a checkpoint directory or a specific checkpoint file;
        # in the latter case use its parent directory for the json sidecars.
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            # Command-line values take precedence over checkpointed options.
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict.update(json.loads(f.read()))
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict.update(json.loads(f.read()))
                model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")

    print('----------------------------------- Options ------------------------------------')
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print('------------------------------------- End --------------------------------------')

    VideoDataset = datasets.get_dataset_class(args.dataset)
    train_dataset = VideoDataset(args.input_dir, mode='train',
                                 hparams_dict=dataset_hparams_dict,
                                 hparams=args.dataset_hparams)
    val_input_dirs = args.val_input_dirs or [args.input_dir]
    val_datasets = [VideoDataset(val_input_dir, mode='val',
                                 hparams_dict=dataset_hparams_dict,
                                 hparams=args.dataset_hparams)
                    for val_input_dir in val_input_dirs]
    if len(val_input_dirs) > 1:
        # With multiple validation dirs, the last one is evaluated on longer
        # sequences (40 for KTH, 30 otherwise) — presumably for long-horizon
        # validation; confirm against the calling scripts.
        if isinstance(val_datasets[-1], datasets.KTHVideoDataset):
            val_datasets[-1].set_sequence_length(40)
        else:
            val_datasets[-1].set_sequence_length(30)

    def override_hparams_dict(dataset):
        # The dataset's sequence geometry overrides the model hparams so that
        # each model is built to match the batches its dataset produces.
        hparams_dict = dict(model_hparams_dict)
        hparams_dict['context_frames'] = dataset.hparams.context_frames
        hparams_dict['sequence_length'] = dataset.hparams.sequence_length
        hparams_dict['repeat'] = dataset.hparams.time_shift
        return hparams_dict

    VideoPredictionModel = models.get_model_class(args.model)
    train_model = VideoPredictionModel(mode='train',
                                       hparams_dict=override_hparams_dict(train_dataset),
                                       hparams=args.model_hparams)
    val_models = [VideoPredictionModel(mode='val',
                                       hparams_dict=override_hparams_dict(val_dataset),
                                       hparams=args.model_hparams)
                  for val_dataset in val_datasets]
    batch_size = train_model.hparams.batch_size
    # Build the training graph first, then build each validation graph in the
    # same variable scope with reuse=True so all models share the variables.
    with tf.variable_scope('') as training_scope:
        train_model.build_graph(*train_dataset.make_batch(batch_size))
    for val_model, val_dataset in zip(val_models, val_datasets):
        with tf.variable_scope(training_scope, reuse=True):
            val_model.build_graph(*val_dataset.make_batch(batch_size))

    # Persist the run configuration alongside the checkpoints so it can be
    # reloaded by the evaluation/sampling scripts.
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    with open(os.path.join(args.output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
    with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f:
        f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4))
    with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f:
        f.write(json.dumps(train_model.hparams.values(), sort_keys=True, indent=4))

    if args.gif_freq:
        # Collect every >=4-D input/output tensor of the first validation
        # model (plus context+generated concatenations) as gif clip sources.
        val_model = val_models[0]
        val_tensors = OrderedDict()
        context_images = val_model.inputs['images'][:, :val_model.hparams.context_frames]
        val_tensors['gen_images_vis'] = tf.concat([context_images, val_model.gen_images], axis=1)
        if val_model.gen_images_enc is not None:
            val_tensors['gen_images_enc_vis'] = tf.concat([context_images, val_model.gen_images_enc], axis=1)
        val_tensors.update({name: tensor for name, tensor in val_model.inputs.items()
                            if tensor.shape.ndims >= 4})
        val_tensors['targets'] = val_model.targets
        val_tensors.update({name: tensor for name, tensor in val_model.outputs.items()
                            if tensor.shape.ndims >= 4})
        val_tensor_clips = OrderedDict([(name, tf_utils.tensor_to_clip(output))
                                        for name, output in val_tensors.items()])

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v))
                                         for v in tf.trainable_variables()])

    saver = tf.train.Saver(max_to_keep=3)

    # Partition the graph's summaries: image summaries that are also eval
    # summaries are tracked separately so each merge op is disjoint.
    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
    image_summaries = set(tf.get_collection(tf_utils.IMAGE_SUMMARIES))
    eval_summaries = set(tf.get_collection(tf_utils.EVAL_SUMMARIES))
    eval_image_summaries = image_summaries & eval_summaries
    image_summaries -= eval_image_summaries
    eval_summaries -= eval_image_summaries
    if args.summary_freq:
        summary_op = tf.summary.merge(summaries)
    if args.image_summary_freq:
        image_summary_op = tf.summary.merge(list(image_summaries))
    if args.eval_summary_freq:
        eval_summary_op = tf.summary.merge(list(eval_summaries))
        eval_image_summary_op = tf.summary.merge(list(eval_image_summaries))
    if args.summary_freq or args.image_summary_freq or args.eval_summary_freq:
        summary_writer = tf.summary.FileWriter(args.output_dir)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    global_step = tf.train.get_or_create_global_step()
    max_steps = train_model.hparams.max_steps
    with tf.Session(config=config) as sess:
        print("parameter_count =", sess.run(parameter_count))
        sess.run(tf.global_variables_initializer())
        train_model.restore(sess, args.checkpoint)
        start_step = sess.run(global_step)
        # start at one step earlier to log everything without doing any training
        # step is relative to the start_step
        for step in range(-1, max_steps - start_step):
            if step == 0:
                start = time.time()

            def should(freq):
                # True every `freq` steps, plus at the pre-training step
                # (step == -1, i.e. step + 1 == 0) and at the final step.
                # Falsy freq (0/None) disables the corresponding action.
                return freq and ((step + 1) % freq == 0 or
                                 (step + 1) in (0, max_steps - start_step))

            fetches = {"global_step": global_step}
            if step >= 0:
                # No train_op at step == -1: that iteration only logs.
                fetches["train_op"] = train_model.train_op
            if should(args.progress_freq):
                fetches['d_losses'] = train_model.d_losses
                fetches['g_losses'] = train_model.g_losses
                if isinstance(train_model.learning_rate, tf.Tensor):
                    fetches["learning_rate"] = train_model.learning_rate
            if should(args.metrics_freq):
                fetches['metrics'] = train_model.metrics
            if should(args.summary_freq):
                fetches["summary"] = summary_op
            if should(args.image_summary_freq):
                fetches["image_summary"] = image_summary_op
            if should(args.eval_summary_freq):
                fetches["eval_summary"] = eval_summary_op
                fetches["eval_image_summary"] = eval_image_summary_op

            run_start_time = time.time()
            results = sess.run(fetches)
            run_elapsed_time = time.time() - run_start_time
            if run_elapsed_time > 1.5:
                print('session.run took %0.1fs' % run_elapsed_time)

            if should(args.summary_freq):
                print("recording summary")
                summary_writer.add_summary(results["summary"], results["global_step"])
                print("done")
            if should(args.image_summary_freq):
                print("recording image summary")
                # Image summaries are converted to animated gif summaries
                # before being written to the event file.
                summary_writer.add_summary(
                    tf_utils.convert_tensor_to_gif_summary(results["image_summary"]),
                    results["global_step"])
                print("done")
            if should(args.eval_summary_freq):
                print("recording eval summary")
                summary_writer.add_summary(results["eval_summary"], results["global_step"])
                summary_writer.add_summary(
                    tf_utils.convert_tensor_to_gif_summary(results["eval_image_summary"]),
                    results["global_step"])
                print("done")
            if should(args.summary_freq) or should(args.image_summary_freq) or should(args.eval_summary_freq):
                summary_writer.flush()

            if should(args.progress_freq):
                # global_step will have the correct step count if we resume from a checkpoint
                steps_per_epoch = math.ceil(train_dataset.num_examples_per_epoch() / batch_size)
                train_epoch = math.ceil(results["global_step"] / steps_per_epoch)
                train_step = (results["global_step"] - 1) % steps_per_epoch + 1
                print("progress global step %d epoch %d step %d" %
                      (results["global_step"], train_epoch, train_step))
                if step >= 0:
                    # Throughput/ETA from wall time averaged since step 0.
                    elapsed_time = time.time() - start
                    average_time = elapsed_time / (step + 1)
                    images_per_sec = batch_size / average_time
                    remaining_time = (max_steps - (start_step + step)) * average_time
                    print(" image/sec %0.1f remaining %dm (%0.1fh) (%0.1fd)" %
                          (images_per_sec, remaining_time / 60,
                           remaining_time / 60 / 60, remaining_time / 60 / 60 / 24))
                for name, loss in itertools.chain(results['d_losses'].items(),
                                                  results['g_losses'].items()):
                    print(name, loss)
                if isinstance(train_model.learning_rate, tf.Tensor):
                    print("learning_rate", results["learning_rate"])
            if should(args.metrics_freq):
                for name, metric in results['metrics'].items():
                    print(name, metric)

            if should(args.save_freq):
                print("saving model to", args.output_dir)
                saver.save(sess, os.path.join(args.output_dir, "model"), global_step=global_step)
                print("done")

            if should(args.gif_freq):
                # Dump every collected clip tensor as a gif under output_dir/images.
                image_dir = os.path.join(args.output_dir, 'images')
                if not os.path.exists(image_dir):
                    os.makedirs(image_dir)
                gif_clips = sess.run(val_tensor_clips)
                gif_step = results["global_step"]
                for name, clip in gif_clips.items():
                    filename = "%08d-%s.gif" % (gif_step, name)
                    print("saving gif to", os.path.join(image_dir, filename))
                    ffmpeg_gif.save_gif(os.path.join(image_dir, filename), clip, fps=4)
                    print("done")
def main():
    """Sample a trained video prediction model and save predictions to disk.

    Command-line entry point: restores a checkpoint, feeds batches from the
    chosen dataset split through the model, and saves each sample's predicted
    sequence as a gif (context + future frames) under
    ``output_gif_dir/<input_dir basename>/`` and the individual future frames
    as pngs under ``output_png_dir/<input_dir basename>/``.

    Fixes vs. previous revision (no CLI/interface change):
    - the per-input-dir ``key`` subdirectories are now created up front;
      ``cv2.imwrite`` does not create directories and fails silently
      (returns False), so pngs/gifs were previously lost when they
      did not already exist.
    - leftover debug prints ('HELLOO', 'SAVE PNG') and commented-out
      experiment code removed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True,
                        help="either a directory containing subdirectories "
                        "train, val, test, etc, or a directory containing "
                        "the tfrecords")
    parser.add_argument("--results_dir", type=str, default='results',
                        help="ignored if output_gif_dir is specified")
    parser.add_argument("--results_gif_dir", type=str,
                        help="default is results_dir. ignored if output_gif_dir is specified")
    parser.add_argument("--results_png_dir", type=str,
                        help="default is results_dir. ignored if output_png_dir is specified")
    parser.add_argument("--output_gif_dir",
                        help="output directory where samples are saved as gifs. default is "
                        "results_gif_dir/model_fname")
    parser.add_argument("--output_png_dir",
                        help="output directory where samples are saved as pngs. default is "
                        "results_png_dir/model_fname")
    parser.add_argument("--checkpoint",
                        help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val',
                        help='mode for dataset, val or test.')
    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument("--dataset_hparams", type=str,
                        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument("--model_hparams", type=str,
                        help="a string of comma separated list of model hyperparameters")
    parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch")
    parser.add_argument("--num_samples", type=int,
                        help="number of samples in total (all of them by default)")
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--num_stochastic_samples", type=int, default=5)
    parser.add_argument("--gif_length", type=int, help="default is sequence_length")
    parser.add_argument("--fps", type=int, default=4)
    parser.add_argument("--gpu_mem_frac", type=float, default=0,
                        help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int, default=7)
    args = parser.parse_args()

    # Seed every RNG in play so sampling is reproducible.
    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    args.results_gif_dir = args.results_gif_dir or args.results_dir
    args.results_png_dir = args.results_png_dir or args.results_dir
    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.checkpoint:
        # Accept either a checkpoint directory or a specific checkpoint file;
        # in the latter case use its parent directory for the json sidecars.
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            # Command-line values take precedence over checkpointed options.
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")
        args.output_gif_dir = args.output_gif_dir or os.path.join(
            args.results_gif_dir, os.path.split(checkpoint_dir)[1])
        args.output_png_dir = args.output_png_dir or os.path.join(
            args.results_png_dir, os.path.split(checkpoint_dir)[1])
    else:
        if not args.dataset:
            raise ValueError('dataset is required when checkpoint is not specified')
        if not args.model:
            raise ValueError('model is required when checkpoint is not specified')
        args.output_gif_dir = args.output_gif_dir or os.path.join(
            args.results_gif_dir, 'model.%s' % args.model)
        args.output_png_dir = args.output_png_dir or os.path.join(
            args.results_png_dir, 'model.%s' % args.model)

    print('----------------------------------- Options ------------------------------------')
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print('------------------------------------- End --------------------------------------')

    VideoDataset = datasets.get_dataset_class(args.dataset)
    dataset = VideoDataset(args.input_dir, mode=args.mode, num_epochs=args.num_epochs,
                           seed=args.seed, hparams_dict=dataset_hparams_dict,
                           hparams=args.dataset_hparams)
    VideoPredictionModel = models.get_model_class(args.model)
    # The dataset's sequence geometry overrides the model hparams so the model
    # is built to match the batches the dataset produces.
    hparams_dict = dict(model_hparams_dict)
    hparams_dict.update({
        'context_frames': dataset.hparams.context_frames,
        'sequence_length': dataset.hparams.sequence_length,
        'repeat': dataset.hparams.time_shift,
    })
    model = VideoPredictionModel(mode=args.mode, hparams_dict=hparams_dict,
                                 hparams=args.model_hparams)
    sequence_length = model.hparams.sequence_length
    context_frames = model.hparams.context_frames
    future_length = sequence_length - context_frames

    if args.num_samples:
        if args.num_samples > dataset.num_examples_per_epoch():
            raise ValueError('num_samples cannot be larger than the dataset')
        num_examples_per_epoch = args.num_samples
    else:
        num_examples_per_epoch = dataset.num_examples_per_epoch()
    if num_examples_per_epoch % args.batch_size != 0:
        raise ValueError('batch_size should evenly divide the dataset size %d' %
                         num_examples_per_epoch)

    # Feed batches through placeholders so the same graph can be reused with
    # arbitrary inputs.
    inputs = dataset.make_batch(args.batch_size)
    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k)
                 for k, v in inputs.items()}
    with tf.variable_scope(''):
        model.build_graph(input_phs)

    # Persist the run configuration alongside the outputs.
    for output_dir in (args.output_gif_dir, args.output_png_dir):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        with open(os.path.join(output_dir, "options.json"), "w") as f:
            f.write(json.dumps(vars(args), sort_keys=True, indent=4))
        with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
            f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
        with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
            f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    sess = tf.Session(config=config)
    sess.graph.as_default()
    model.restore(sess, args.checkpoint)

    sample_ind = 0
    # Samples are grouped under a subdirectory named after the input dir.
    key = args.input_dir.split('/')[-1]
    # Create the per-key subdirectories up front: cv2.imwrite does not create
    # directories and fails silently (returns False) when the path is missing.
    for media_dir in (os.path.join(args.output_gif_dir, key),
                      os.path.join(args.output_png_dir, key)):
        if not os.path.exists(media_dir):
            os.makedirs(media_dir)

    while True:
        if args.num_samples and sample_ind >= args.num_samples:
            break
        try:
            input_results = sess.run(inputs)
        except tf.errors.OutOfRangeError:
            # Dataset exhausted.
            break
        print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))
        feed_dict = {input_ph: input_results[name]
                     for name, input_ph in input_phs.items()}
        # NOTE: only a single stochastic sample is generated per batch
        # (originally up to args.num_stochastic_samples).
        for stochastic_sample_ind in range(1):
            gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict)
            # only keep the future frames
            gen_images = gen_images[:, -future_length:]
            for i, gen_images_ in enumerate(gen_images):
                # Convert [0, 1] floats to uint8 pixels.
                context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8)
                gen_images_ = (gen_images_ * 255.0).astype(np.uint8)
                gen_images_fname = 'gen_image_%05d_%02d.gif' % (
                    sample_ind + i, stochastic_sample_ind)
                context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_)
                if args.gif_length:
                    context_and_gen_images = context_and_gen_images[:args.gif_length]
                save_gif(os.path.join(args.output_gif_dir, key, gen_images_fname),
                         context_and_gen_images, fps=args.fps)
                # Zero-padded frame index, at least two digits wide.
                gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(
                    2, len(str(len(gen_images_) - 1)))
                for t, gen_image in enumerate(gen_images_):
                    gen_image_fname = gen_image_fname_pattern % (
                        sample_ind + i, stochastic_sample_ind, t)
                    if gen_image.shape[-1] == 1:
                        # Grayscale: replicate the channel for a 3-channel png.
                        gen_image = np.tile(gen_image, (1, 1, 3))
                    else:
                        # cv2 expects BGR channel order.
                        gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
                    cv2.imwrite(os.path.join(args.output_png_dir, key, gen_image_fname),
                                gen_image)
        sample_ind += args.batch_size
def main():
    """Evaluate a trained video prediction model and print metric tables.

    Restores a checkpoint, runs the model's evaluation graph over the chosen
    dataset split (selecting the best of many stochastic samples per metric),
    saves per-task images and per-metric csv files, then prints per-time-step
    mean/std tables for the saved metrics.

    results_dir
    ├── output_dir                            # condition / method
    │   ├── prediction_eval_lpips_max         # task: best sample in terms of LPIPS similarity
    │   │   ├── inputs
    │   │   │   ├── context_image_00000_00.png  # indexed by sample index and time step
    │   │   │   └── ...
    │   │   ├── outputs
    │   │   │   ├── gen_image_00000_00.png      # predicted images (only the future ones)
    │   │   │   └── ...
    │   │   └── metrics
    │   │       └── lpips.csv
    │   ├── prediction_eval_ssim_max          # task: best sample in terms of SSIM
    │   │   ├── inputs
    │   │   │   ├── context_image_00000_00.png  # indexed by sample index and time step
    │   │   │   └── ...
    │   │   ├── outputs
    │   │   │   ├── gen_image_00000_00.png      # predicted images (only the future ones)
    │   │   │   └── ...
    │   │   └── metrics
    │   │       └── ssim.csv
    │   └── ...
    └── ...
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True,
                        help="either a directory containing subdirectories "
                        "train, val, test, etc, or a directory containing "
                        "the tfrecords")
    parser.add_argument("--results_dir", type=str, default='results',
                        help="ignored if output_dir is specified")
    parser.add_argument("--output_dir",
                        help="output directory where results are saved. default is results_dir/model_fname, "
                        "where model_fname is the directory name of checkpoint")
    parser.add_argument("--checkpoint",
                        help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val',
                        help='mode for dataset, val or test.')
    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument("--dataset_hparams", type=str,
                        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument("--model_hparams", type=str,
                        help="a string of comma separated list of model hyperparameters")
    parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch")
    parser.add_argument("--num_samples", type=int,
                        help="number of samples in total (all of them by default)")
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--eval_substasks", type=str, nargs='+', default=['max', 'avg', 'min'],
                        help='subtasks to evaluate (e.g. max, avg, min)')
    parser.add_argument("--only_metrics", action='store_true')
    parser.add_argument("--num_stochastic_samples", type=int, default=100)
    parser.add_argument("--gt_inputs_dir", type=str,
                        help="directory containing input ground truth images for ismple dataset")
    parser.add_argument("--gt_outputs_dir", type=str,
                        help="directory containing output ground truth images for ismple dataset")
    parser.add_argument("--eval_parallel_iterations", type=int, default=10)
    parser.add_argument("--gpu_mem_frac", type=float, default=0,
                        help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int, default=7)
    args = parser.parse_args()

    # Seed every RNG in play so evaluation is reproducible.
    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.checkpoint:
        # Accept either a checkpoint directory or a specific checkpoint file;
        # in the latter case use its parent directory for the json sidecars.
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            # Command-line values take precedence over checkpointed options.
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")
        args.output_dir = args.output_dir or os.path.join(
            args.results_dir, os.path.split(checkpoint_dir)[1])
    else:
        if not args.dataset:
            raise ValueError('dataset is required when checkpoint is not specified')
        if not args.model:
            raise ValueError('model is required when checkpoint is not specified')
        args.output_dir = args.output_dir or os.path.join(
            args.results_dir, 'model.%s' % args.model)

    print('----------------------------------- Options ------------------------------------')
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print('------------------------------------- End --------------------------------------')

    VideoDataset = datasets.get_dataset_class(args.dataset)
    dataset = VideoDataset(args.input_dir, mode=args.mode, num_epochs=args.num_epochs,
                           seed=args.seed, hparams_dict=dataset_hparams_dict,
                           hparams=args.dataset_hparams)
    VideoPredictionModel = models.get_model_class(args.model)
    # The dataset's sequence geometry overrides the model hparams so the model
    # is built to match the batches the dataset produces.
    hparams_dict = dict(model_hparams_dict)
    hparams_dict.update({
        'context_frames': dataset.hparams.context_frames,
        'sequence_length': dataset.hparams.sequence_length,
        'repeat': dataset.hparams.time_shift,
    })
    # NOTE(review): unlike the sampling script, no mode= is passed here, so the
    # model's default mode is used — confirm this is intended.
    model = VideoPredictionModel(
        hparams_dict=hparams_dict,
        hparams=args.model_hparams,
        eval_num_samples=args.num_stochastic_samples,
        eval_parallel_iterations=args.eval_parallel_iterations)

    if args.num_samples:
        if args.num_samples > dataset.num_examples_per_epoch():
            raise ValueError('num_samples cannot be larger than the dataset')
        num_examples_per_epoch = args.num_samples
    else:
        num_examples_per_epoch = dataset.num_examples_per_epoch()
    if num_examples_per_epoch % args.batch_size != 0:
        raise ValueError('batch_size should evenly divide the dataset size %d' %
                         num_examples_per_epoch)

    # Feed batches through placeholders so the same graph can be reused with
    # arbitrary inputs.
    inputs = dataset.make_batch(args.batch_size)
    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k)
                 for k, v in inputs.items()}
    with tf.variable_scope(''):
        model.build_graph(input_phs)

    # Persist the run configuration alongside the results.
    output_dir = args.output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    with open(os.path.join(output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
    with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
        f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
    with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    sess = tf.Session(config=config)
    sess.graph.as_default()
    model.restore(sess, args.checkpoint)

    sample_ind = 0
    while True:
        if args.num_samples and sample_ind >= args.num_samples:
            break
        try:
            input_results = sess.run(inputs)
        except tf.errors.OutOfRangeError:
            # Dataset exhausted.
            break
        print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))
        feed_dict = {input_ph: input_results[name]
                     for name, input_ph in input_phs.items()}
        # compute "best" metrics using the computation graph
        fetches = {'images': model.inputs['images']}
        fetches.update(model.eval_outputs.items())
        fetches.update(model.eval_metrics.items())
        results = sess.run(fetches, feed_dict=feed_dict)
        save_prediction_eval_results(os.path.join(output_dir, 'prediction_eval'),
                                     results, model.hparams, sample_ind,
                                     args.only_metrics, args.eval_substasks)
        sample_ind += args.batch_size

    # Print summary tables for the saved metrics. Only the 'max' subtask is
    # reported here, even if more subtasks were evaluated above.
    metric_fnames = []
    metric_names = ['psnr', 'ssim', 'lpips']
    subtasks = ['max']
    for metric_name in metric_names:
        for subtask in subtasks:
            metric_fnames.append(
                os.path.join(output_dir,
                             'prediction_eval_%s_%s' % (metric_name, subtask),
                             'metrics', metric_name))
    for metric_fname in metric_fnames:
        # NOTE(review): splitting on '/' assumes POSIX-style paths.
        task_name, _, metric_name = metric_fname.split('/')[-3:]
        metric = load_metrics(metric_fname)
        print('=' * 31)
        print(task_name, metric_name)
        print('-' * 31)
        metric_header_format = '{:>10} {:>20}'
        metric_row_format = '{:>10} {:>10.4f} ({:>7.4f})'
        print(metric_header_format.format('time step', os.path.split(metric_fname)[1]))
        # Per-time-step mean and std across samples (axis 0), then overall.
        for t, (metric_mean, metric_std) in enumerate(
                zip(metric.mean(axis=0), metric.std(axis=0))):
            print(metric_row_format.format(t, metric_mean, metric_std))
        print(metric_row_format.format('mean (std)', metric.mean(), metric.std()))
        print('=' * 31)
def main():
    """Evaluate a trained video prediction model on one or more tasks.

    Results are written under results_dir with the following layout:

    results_dir
    ├── output_dir                               # condition / method
    │   ├── prediction                           # task
    │   │   ├── inputs
    │   │   │   ├── context_image_00000_00.png   # indexed by sample index and time step
    │   │   │   └── ...
    │   │   ├── outputs
    │   │   │   ├── gen_image_00000_00.png       # predicted images (only the ones in the loss)
    │   │   │   └── ...
    │   │   └── metrics
    │   │       ├── psnr.csv
    │   │       ├── mse.csv
    │   │       └── ssim.csv
    │   ├── prediction_eval_vgg_csim_max         # task: best sample in terms of VGG cosine similarity
    │   │   ├── inputs
    │   │   │   ├── context_image_00000_00.png   # indexed by sample index and time step
    │   │   │   └── ...
    │   │   ├── outputs
    │   │   │   ├── gen_image_00000_00.png       # predicted images (only the ones in the loss)
    │   │   │   └── ...
    │   │   └── metrics
    │   │       └── vgg_csim.csv
    │   ├── servo
    │   │   ├── inputs
    │   │   │   ├── context_image_00000_00.png
    │   │   │   ├── ...
    │   │   │   ├── goal_image_00000_00.png      # only one goal image per sample
    │   │   │   └── ...
    │   │   ├── outputs
    │   │   │   ├── gen_image_00000_00.png
    │   │   │   ├── ...
    │   │   │   ├── gen_image_goal_diff_00000_00.png
    │   │   │   └── ...
    │   │   └── metrics
    │   │       ├── action_mse.csv
    │   │       └── goal_image_mse.csv
    │   ├── motion
    │   │   ├── inputs
    │   │   │   ├── pix_distrib_00000_00.png
    │   │   │   └── ...
    │   │   ├── outputs
    │   │   │   ├── gen_pix_distrib_00000_00.png
    │   │   │   ├── ...
    │   │   │   ├── gen_pix_distrib_overlaid_00000_00.png
    │   │   │   └── ...
    │   │   └── metrics
    │   │       └── pix_dist.csv
    │   └── ...
    └── ...
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True,
                        help="either a directory containing subdirectories "
                             "train, val, test, etc, or a directory containing "
                             "the tfrecords")
    parser.add_argument("--results_dir", type=str, default='results',
                        help="ignored if output_dir is specified")
    parser.add_argument("--output_dir",
                        help="output directory where results are saved. default is results_dir/model_fname, "
                             "where model_fname is the directory name of checkpoint")
    parser.add_argument("--checkpoint",
                        help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val',
                        help='mode for dataset, val or test.')
    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument("--dataset_hparams", type=str,
                        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument("--model_hparams", type=str,
                        help="a string of comma separated list of model hyperparameters")
    parser.add_argument("--batch_size", type=int, default=8,
                        help="number of samples in batch")
    parser.add_argument("--num_samples", type=int,
                        help="number of samples in total (all of them by default)")
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--tasks", type=str, nargs='+',
                        help='tasks to evaluate (e.g. prediction, prediction_eval, servo, motion)')
    parser.add_argument("--eval_substasks", type=str, nargs='+', default=['max', 'min'],
                        help='subtasks to evaluate (e.g. max, avg, min). only applicable to prediction_eval')
    parser.add_argument("--only_metrics", action='store_true')
    parser.add_argument("--num_stochastic_samples", type=int, default=100)
    parser.add_argument("--gt_inputs_dir", type=str,
                        help="directory containing input ground truth images for ismple dataset")
    parser.add_argument("--gt_outputs_dir", type=str,
                        help="directory containing output ground truth images for ismple dataset")
    parser.add_argument("--eval_parallel_iterations", type=int, default=10)
    parser.add_argument("--gpu_mem_frac", type=float, default=0,
                        help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int, default=7)
    args = parser.parse_args()

    # Seed TF graph ops, numpy, and python's RNG so stochastic sampling is reproducible.
    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.checkpoint:
        # Resolve the checkpoint directory: a checkpoint *name* like
        # checkpoint_dir/model-200000 is reduced to its containing directory.
        # NOTE(review): the exists() check runs before the file-name is stripped,
        # so a checkpoint name whose literal path does not exist raises — confirm intended.
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        # options.json (written at training time) supplies defaults for --dataset/--model.
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        # Hyperparameter files are optional; fall back to empty dicts when absent.
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict = json.loads(f.read())
                model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")
        args.output_dir = args.output_dir or os.path.join(args.results_dir, os.path.split(checkpoint_dir)[1])
    else:
        # Without a checkpoint, dataset and model classes must be given explicitly.
        if not args.dataset:
            raise ValueError('dataset is required when checkpoint is not specified')
        if not args.model:
            raise ValueError('model is required when checkpoint is not specified')
        args.output_dir = args.output_dir or os.path.join(args.results_dir, 'model.%s' % args.model)

    print('----------------------------------- Options ------------------------------------')
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print('------------------------------------- End --------------------------------------')

    VideoDataset = datasets.get_dataset_class(args.dataset)
    dataset = VideoDataset(args.input_dir,
                           mode=args.mode,
                           num_epochs=args.num_epochs,
                           seed=args.seed,
                           hparams_dict=dataset_hparams_dict,
                           hparams=args.dataset_hparams)

    def override_hparams_dict(dataset):
        # Model hparams that must agree with the dataset's settings.
        hparams_dict = dict(model_hparams_dict)
        hparams_dict['context_frames'] = dataset.hparams.context_frames
        hparams_dict['sequence_length'] = dataset.hparams.sequence_length
        hparams_dict['repeat'] = dataset.hparams.time_shift
        return hparams_dict

    VideoPredictionModel = models.get_model_class(args.model)
    model = VideoPredictionModel(mode='test',
                                 hparams_dict=override_hparams_dict(dataset),
                                 hparams=args.model_hparams,
                                 eval_num_samples=args.num_stochastic_samples,
                                 eval_parallel_iterations=args.eval_parallel_iterations)
    context_frames = model.hparams.context_frames
    sequence_length = model.hparams.sequence_length

    # The evaluation loop assumes whole batches, so batch_size must divide the sample count.
    if args.num_samples:
        if args.num_samples > dataset.num_examples_per_epoch():
            raise ValueError('num_samples cannot be larger than the dataset')
        num_examples_per_epoch = args.num_samples
    else:
        num_examples_per_epoch = dataset.num_examples_per_epoch()
    if num_examples_per_epoch % args.batch_size != 0:
        raise ValueError('batch_size should evenly divide the dataset')

    inputs, target = dataset.make_batch(args.batch_size)
    if not isinstance(model, models.GroundTruthVideoPredictionModel):
        # remove ground truth data past context_frames to prevent accidentally using it
        for k, v in inputs.items():
            if k != 'actions':
                inputs[k] = v[:, :context_frames]

    # Feed data through placeholders so the same graph can be re-fed (e.g. for
    # repeated stochastic sampling of the same batch).
    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
    target_ph = tf.placeholder(target.dtype, target.shape, 'targets_ph')

    with tf.variable_scope(''):
        model.build_graph(input_phs, target_ph)

    # Default task set; 'motion' is only meaningful when pixel distributions are available.
    tasks = args.tasks
    if tasks is None:
        tasks = ['prediction_eval']
        if 'pix_distribs' in inputs:
            tasks.append('motion')

    if 'servo' in tasks:
        # Second model instance sharing variables (reuse=True below), fed with
        # CEM-sized batches for visual servoing.
        servo_model = VideoPredictionModel(mode='test',
                                           hparams_dict=model_hparams_dict,
                                           hparams=args.model_hparams)
        cem_batch_size = 200
        plan_horizon = sequence_length - 1
        image_shape = inputs['images'].shape.as_list()[2:]
        state_shape = inputs['states'].shape.as_list()[2:]
        action_shape = inputs['actions'].shape.as_list()[2:]
        servo_input_phs = {
            'images': tf.placeholder(tf.float32, shape=[cem_batch_size, context_frames] + image_shape),
            'states': tf.placeholder(tf.float32, shape=[cem_batch_size, 1] + state_shape),
            'actions': tf.placeholder(tf.float32, shape=[cem_batch_size, plan_horizon] + action_shape),
        }
        if isinstance(servo_model, models.GroundTruthVideoPredictionModel):
            # Ground-truth models consume the full sequence, not just the context.
            images_shape = inputs['images'].shape.as_list()[1:]
            servo_input_phs['images'] = tf.placeholder(tf.float32, shape=[cem_batch_size] + images_shape)
        with tf.variable_scope('', reuse=True):
            servo_model.build_graph(servo_input_phs)

    # Record the exact configuration next to the results for later inspection.
    output_dir = args.output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    with open(os.path.join(output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
    with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
        f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
    with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    sess = tf.Session(config=config)
    model.restore(sess, args.checkpoint)

    if 'servo' in tasks:
        servo_policy = ServoPolicy(servo_model, sess)

    # Main evaluation loop: pull a batch from the dataset pipeline, then re-feed
    # it through the placeholders for each requested task.
    sample_ind = 0
    while True:
        if args.num_samples and sample_ind >= args.num_samples:
            break
        try:
            input_results, target_result = sess.run([inputs, target])
        except tf.errors.OutOfRangeError:
            # Dataset iterator exhausted (num_epochs reached).
            break
        print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))

        if 'prediction_eval' in tasks:
            feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
            feed_dict.update({target_ph: target_result})
            # compute "best" metrics using the computation graph (if available) or explicitly with python logic
            if model.eval_outputs and model.eval_metrics:
                fetches = {'images': model.inputs['images']}
                fetches.update(model.eval_outputs.items())
                fetches.update(model.eval_metrics.items())
                results = sess.run(fetches, feed_dict=feed_dict)
            else:
                # Python fallback: draw num_stochastic_samples predictions and
                # score each against the target with the metrics below.
                metric_names = ['psnr', 'ssim', 'ssim_scikit', 'ssim_finn', 'vgg_csim']
                metric_fns = [metrics.peak_signal_to_noise_ratio_np,
                              metrics.structural_similarity_np,
                              metrics.structural_similarity_scikit_np,
                              metrics.structural_similarity_finn_np,
                              metrics.vgg_cosine_similarity_np]
                all_gen_images = []
                # (num_stochastic_samples, batch_size, future_length) per metric.
                all_metrics = [np.empty((args.num_stochastic_samples, args.batch_size,
                                         sequence_length - context_frames)) for _ in metric_names]
                for s in range(args.num_stochastic_samples):
                    gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict)
                    all_gen_images.append(gen_images)
                    for metric_name, metric_fn, all_metric in zip(metric_names, metric_fns, all_metrics):
                        metric = metric_fn(gen_images, target_result, keep_axis=(0, 1))
                        all_metric[s] = metric
                results = {}
                for metric_name, all_metric in zip(metric_names, all_metrics):
                    for subtask in args.eval_substasks:
                        results['eval_gen_images_%s/%s' % (metric_name, subtask)] = np.empty_like(all_gen_images[0])
                        results['eval_%s/%s' % (metric_name, subtask)] = np.empty_like(all_metric[0])
                # Per batch element, pick the best/worst stochastic sample for each metric.
                for i in range(args.batch_size):
                    for metric_name, all_metric in zip(metric_names, all_metrics):
                        ordered = np.argsort(np.mean(all_metric, axis=-1)[:, i])  # mean over time and sort over samples
                        for subtask in args.eval_substasks:
                            if subtask == 'max':
                                sidx = ordered[-1]
                            elif subtask == 'min':
                                sidx = ordered[0]
                            else:
                                # 'avg' (in the --eval_substasks help text) is not implemented here.
                                raise NotImplementedError
                            results['eval_gen_images_%s/%s' % (metric_name, subtask)][i] = all_gen_images[sidx][i]
                            results['eval_%s/%s' % (metric_name, subtask)][i] = all_metric[sidx][i]
            save_prediction_eval_results(os.path.join(output_dir, 'prediction_eval'),
                                         results, model.hparams, sample_ind,
                                         args.only_metrics, args.eval_substasks)

        if 'prediction' in tasks or 'motion' in tasks:  # do these together
            feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
            fetches = {'images': model.inputs['images'],
                       'gen_images': model.outputs['gen_images']}
            if 'motion' in tasks:
                fetches.update({'pix_distribs': model.inputs['pix_distribs'],
                                'gen_pix_distribs': model.outputs['gen_pix_distribs']})
            if args.num_stochastic_samples:
                # Stack runs along a new leading sample axis, then rank samples by MSE.
                all_results = [sess.run(fetches, feed_dict=feed_dict)
                               for _ in range(args.num_stochastic_samples)]
                all_results = nest.map_structure(lambda *x: np.stack(x), *all_results)
                all_context_images, all_images = np.split(all_results['images'], [context_frames], axis=2)
                # Keep only the predicted (non-context) frames.
                all_gen_images = all_results['gen_images'][:, :, context_frames - sequence_length:]
                all_mse = metrics.mean_squared_error_np(all_images, all_gen_images, keep_axis=(0, 1))
                all_mse_argsort = np.argsort(all_mse, axis=0)
                for subtask, argsort_ind in zip(['_best', '_median', '_worst'],
                                                [0, args.num_stochastic_samples // 2, -1]):
                    all_mse_inds = all_mse_argsort[argsort_ind]
                    # Select, per batch element, the stochastic sample at this MSE rank.
                    gather = lambda x: np.array([x[ind, sample_ind]
                                                 for sample_ind, ind in enumerate(all_mse_inds)])
                    results = nest.map_structure(gather, all_results)
                    if 'prediction' in tasks:
                        save_prediction_results(os.path.join(output_dir, 'prediction' + subtask),
                                                results, model.hparams, sample_ind, args.only_metrics)
                    if 'motion' in tasks:
                        draw_center = isinstance(model, models.NonTrainableVideoPredictionModel)
                        save_motion_results(os.path.join(output_dir, 'motion' + subtask),
                                            results, model.hparams, draw_center,
                                            sample_ind, args.only_metrics)
            else:
                # Deterministic case: a single run per batch.
                results = sess.run(fetches, feed_dict=feed_dict)
                if 'prediction' in tasks:
                    save_prediction_results(os.path.join(output_dir, 'prediction'),
                                            results, model.hparams, sample_ind, args.only_metrics)
                if 'motion' in tasks:
                    draw_center = isinstance(model, models.NonTrainableVideoPredictionModel)
                    save_motion_results(os.path.join(output_dir, 'motion'),
                                        results, model.hparams, draw_center,
                                        sample_ind, args.only_metrics)

        if 'servo' in tasks:
            # Run the servo policy once per batch element; the last frame serves
            # as the goal image.
            images = input_results['images']
            states = input_results['states']
            gen_actions = []
            gen_images = []
            for images_, states_ in zip(images, states):
                obs = {'context_images': images_[:context_frames],
                       'context_state': states_[0],
                       'goal_image': images_[-1]}
                if isinstance(servo_model, models.GroundTruthVideoPredictionModel):
                    obs['context_images'] = images_
                gen_actions_, gen_images_ = servo_policy.act(obs, servo_model.outputs['gen_images'])
                gen_actions.append(gen_actions_)
                gen_images.append(gen_images_)
            gen_actions = np.stack(gen_actions)
            gen_images = np.stack(gen_images)
            results = {'images': input_results['images'],
                       'actions': input_results['actions'],
                       'goal_image': input_results['images'][:, -1],
                       'gen_actions': gen_actions,
                       'gen_images': gen_images}
            save_servo_results(os.path.join(output_dir, 'servo'),
                               results, servo_model.hparams, sample_ind, args.only_metrics)

        sample_ind += args.batch_size

    # Collect the metric CSVs written by the save_* helpers and print summaries.
    metric_fnames = []
    if 'prediction_eval' in tasks:
        metric_names = ['psnr', 'ssim', 'ssim_finn', 'vgg_csim']
        subtasks = ['max']
        for metric_name in metric_names:
            for subtask in subtasks:
                metric_fnames.append(
                    os.path.join(output_dir,
                                 'prediction_eval_%s_%s' % (metric_name, subtask),
                                 'metrics', metric_name))
    if 'prediction' in tasks:
        subtask = '_best' if args.num_stochastic_samples else ''
        metric_fnames.extend([
            os.path.join(output_dir, 'prediction' + subtask, 'metrics', 'psnr'),
            os.path.join(output_dir, 'prediction' + subtask, 'metrics', 'mse'),
            os.path.join(output_dir, 'prediction' + subtask, 'metrics', 'ssim'),
        ])
    if 'motion' in tasks:
        subtask = '_best' if args.num_stochastic_samples else ''
        metric_fnames.append(os.path.join(output_dir, 'motion' + subtask, 'metrics', 'pix_dist'))
    if 'servo' in tasks:
        metric_fnames.append(os.path.join(output_dir, 'servo', 'metrics', 'goal_image_mse'))
        metric_fnames.append(os.path.join(output_dir, 'servo', 'metrics', 'action_mse'))

    for metric_fname in metric_fnames:
        # NOTE(review): splitting on '/' assumes POSIX-style paths.
        task_name, _, metric_name = metric_fname.split('/')[-3:]
        metric = load_metrics(metric_fname)
        print('=' * 31)
        print(task_name, metric_name)
        print('-' * 31)
        metric_header_format = '{:>10} {:>20}'
        metric_row_format = '{:>10} {:>10.4f} ({:>7.4f})'
        print(metric_header_format.format('time step', os.path.split(metric_fname)[1]))
        # Per-time-step mean (std) over all evaluated samples, then the overall aggregate.
        for t, (metric_mean, metric_std) in enumerate(zip(metric.mean(axis=0), metric.std(axis=0))):
            print(metric_row_format.format(t, metric_mean, metric_std))
        print(metric_row_format.format('mean (std)', metric.mean(), metric.std()))
        print('=' * 31)
) for k, v in args._get_kwargs(): print(k, "=", v) print( '------------------------------------- End --------------------------------------' ) VideoDataset = datasets.get_dataset_class(args.dataset) dataset = VideoDataset(args.input_dir, mode=args.mode, num_epochs=args.num_epochs, seed=args.seed, hparams_dict=dataset_hparams_dict, hparams=args.dataset_hparams) VideoPredictionModel = models.get_model_class(args.model) hparams_dict = dict(model_hparams_dict) hparams_dict.update({ 'context_frames': dataset.hparams.context_frames, 'sequence_length': dataset.hparams.sequence_length, 'eval_length': dataset.hparams.eval_length, 'repeat': dataset.hparams.time_shift, }) model = VideoPredictionModel(mode=args.mode, hparams_dict=hparams_dict, hparams=args.model_hparams) sequence_length = model.hparams.sequence_length context_frames = model.hparams.context_frames future_length = sequence_length - context_frames
def main():
    """Evaluate diversity/accuracy statistics of stochastic video predictions.

    Draws multiple stochastic samples per batch and either accumulates
    best-case VGG cosine similarity, best-case MSE, and a between-sample
    diversity measure (with --only_metrics, written to summary.txt), or
    saves the generated image sequences to disk.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True,
                        help="either a directory containing subdirectories "
                             "train, val, test, etc, or a directory containing "
                             "the tfrecords")
    parser.add_argument("--results_dir", type=str, default='results',
                        help="ignored if output_dir is specified")
    parser.add_argument("--output_dir",
                        help="output directory where results are saved. default is results_dir/model_fname, "
                             "where model_fname is the directory name of checkpoint")
    parser.add_argument("--checkpoint",
                        help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
    parser.add_argument("--only_metrics", action='store_true')
    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='test',
                        help='mode for dataset, val or test.')
    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument("--dataset_hparams", type=str,
                        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument("--model_hparams", type=str,
                        help="a string of comma separated list of model hyperparameters")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="number of samples in batch")
    parser.add_argument("--num_samples", type=int,
                        help="number of samples in total (all of them by default)")
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--eval_substasks", type=str, nargs='+', default=['max', 'avg', 'min'],
                        help='subtasks to evaluate (e.g. max, avg, min). only applicable to prediction_eval')
    parser.add_argument("--num_stochastic_samples", type=int, default=100)
    parser.add_argument("--gt_inputs_dir", type=str,
                        help="directory containing input ground truth images for ismple dataset")
    parser.add_argument("--gt_outputs_dir", type=str,
                        help="directory containing output ground truth images for ismple dataset")
    parser.add_argument("--eval_parallel_iterations", type=int, default=1)
    parser.add_argument("--gpu_mem_frac", type=float, default=0,
                        help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int, default=7)
    args = parser.parse_args()

    # Seed TF graph ops, numpy, and python's RNG so stochastic sampling is reproducible.
    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.checkpoint:
        # Resolve the checkpoint directory: a checkpoint *name* like
        # checkpoint_dir/model-200000 is reduced to its containing directory.
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        # options.json (written at training time) supplies defaults for --dataset/--model.
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        # Hyperparameter files are optional; fall back to empty dicts when absent.
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict = json.loads(f.read())
                model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")
        args.output_dir = args.output_dir or os.path.join(args.results_dir, os.path.split(checkpoint_dir)[1])
    else:
        if not args.dataset:
            raise ValueError('dataset is required when checkpoint is not specified')
        if not args.model:
            raise ValueError('model is required when checkpoint is not specified')
        args.output_dir = args.output_dir or os.path.join(args.results_dir, 'model.%s' % args.model)

    # When images are saved (not metrics-only), use one tenth of the samples to
    # keep the output volume manageable.
    if not args.only_metrics:
        args.num_stochastic_samples = args.num_stochastic_samples // 10

    print('----------------------------------- Options ------------------------------------')
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print('------------------------------------- End --------------------------------------')

    VideoDataset = datasets.get_dataset_class(args.dataset)
    dataset = VideoDataset(args.input_dir,
                           mode=args.mode,
                           num_epochs=args.num_epochs,
                           seed=args.seed,
                           hparams_dict=dataset_hparams_dict,
                           hparams=args.dataset_hparams)

    def override_hparams_dict(dataset):
        # Model hparams that must agree with the dataset's settings.
        hparams_dict = dict(model_hparams_dict)
        hparams_dict['context_frames'] = dataset.hparams.context_frames
        hparams_dict['sequence_length'] = dataset.hparams.sequence_length
        hparams_dict['repeat'] = dataset.hparams.time_shift
        hparams_dict['zs_seed_no'] = args.seed
        return hparams_dict

    VideoPredictionModel = models.get_model_class(args.model)
    model = VideoPredictionModel(mode='test',
                                 hparams_dict=override_hparams_dict(dataset),
                                 hparams=args.model_hparams,
                                 eval_num_samples=args.num_stochastic_samples,
                                 eval_parallel_iterations=args.eval_parallel_iterations)
    context_frames = model.hparams.context_frames
    sequence_length = model.hparams.sequence_length

    # The evaluation loop assumes whole batches, so batch_size must divide the sample count.
    if args.num_samples:
        if args.num_samples > dataset.num_examples_per_epoch():
            raise ValueError('num_samples cannot be larger than the dataset')
        num_examples_per_epoch = args.num_samples
    else:
        num_examples_per_epoch = dataset.num_examples_per_epoch()
    if num_examples_per_epoch % args.batch_size != 0:
        raise ValueError('batch_size should evenly divide the dataset')

    inputs, target = dataset.make_batch(args.batch_size)
    if not isinstance(model, models.GroundTruthVideoPredictionModel):
        # remove ground truth data past context_frames to prevent accidentally using it
        for k, v in inputs.items():
            if k != 'actions':
                inputs[k] = v[:, :context_frames]

    # Placeholders are built with TWICE the dataset batch size — presumably the
    # model pairs/splits batches internally; TODO confirm against model.build_graph.
    input_phs = {k: tf.placeholder(v.dtype,
                                   [v.get_shape().as_list()[0] * 2] + v.get_shape().as_list()[1:],
                                   '%s_ph' % k)
                 for k, v in inputs.items()}
    target_ph = tf.placeholder(target.dtype,
                               [target.get_shape().as_list()[0] * 2] + target.get_shape().as_list()[1:],
                               'targets_ph')

    with tf.variable_scope(''):
        model.build_graph(input_phs, target_ph)

    # Record the exact configuration next to the results for later inspection.
    output_dir = args.output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    with open(os.path.join(output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
    with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
        f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
    with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    sess = tf.Session(config=config)
    model.restore(sess, args.checkpoint)

    sample_ind = 0
    our_result_list = {}  # metric name -> list of per-batch scalar results
    while True:
        if args.num_samples and sample_ind >= args.num_samples:
            break
        try:
            input_results, target_result = sess.run([inputs, target])
        except tf.errors.OutOfRangeError:
            # Dataset iterator exhausted (num_epochs reached).
            break
        print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))

        ##########################################
        # compute statistics
        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
        feed_dict.update({target_ph: target_result})

        ###################################
        # compute diversity measures
        fetches = {'images': model.inputs['images'],
                   'gen_images': model.outputs['gen_images']}
        all_results = [sess.run(fetches, feed_dict=feed_dict)
                       for _ in range(args.num_stochastic_samples)]
        all_results = nest.map_structure(lambda *x: np.stack(x), *all_results)
        # the result has (num_samples, batch_size, time, height, width, 3)
        # Broadcast the ground-truth context/targets across the sample axis so
        # they align with the stacked stochastic predictions.
        all_results['context'] = np.repeat(
            np.expand_dims(input_results['images'][:args.batch_size], axis=0),
            args.num_stochastic_samples, axis=0)
        all_results['images'] = np.repeat(
            np.expand_dims(target_result[:args.batch_size], axis=0),
            args.num_stochastic_samples, axis=0)
        # Keep only the predicted (non-context) frames of the first batch_size elements.
        all_results['gen_images'] = all_results['gen_images'][:, :args.batch_size,
                                                             context_frames - sequence_length:]
        all_images = all_results['images']
        all_gen_images = all_results['gen_images']

        if args.only_metrics:
            # do it for VGG
            # Evaluate in chunks of 25 samples (group()) to bound memory, then
            # take the best (max) mean similarity over samples.
            csim_result = [metrics.vgg_cosine_similarity_np(np.array(a), np.array(b), keep_axis=(0, 1, 2))
                           for a, b in zip(group(all_images, 25), group(all_gen_images, 25))]
            csim_max = np.mean(np.array(csim_result), axis=-1).max()
            our_result_list.setdefault('gt_csim', []).append(csim_max)
            # do it for MSE
            all_mse = metrics.mean_squared_error_np(all_images, all_gen_images, keep_axis=(0, 1, 2))
            mse_min = np.mean(all_mse, axis=-1).min()
            our_result_list.setdefault('gt_mse', []).append(mse_min)
            # compute ours
            # Between-sample diversity of the generated sequences.
            our_result = compute_our_diversity_np(all_gen_images)
            our_result_list.setdefault('btw_mse', []).append(our_result)
            #all_mse_argsort = np.argsort(all_mse, axis=0)
            #for subtask, argsort_ind in zip(['_best', '_median', '_worst'],
            #                                [0, args.num_stochastic_samples // 2, -1]):
            #    all_mse_inds = all_mse_argsort[argsort_ind]
            #    gather = lambda x: np.array([x[ind, sample_ind] for sample_ind, ind in enumerate(all_mse_inds)])
            #    results = nest.map_structure(gather, all_results)
            #    save_prediction_results(os.path.join(output_dir, 'prediction' + subtask),
            #                            results, model.hparams, sample_ind, (args.only_metrics or (subtask == '_median')) )
        else:
            # write logic for saving results
            print('saving results')
            # One directory per (batch position, stochastic sample); only the
            # first batch element of each stacked run is saved.
            for sample_no, single_video in enumerate(all_results['gen_images']):
                save_image_sequence(os.path.join(output_dir, '%05d_%02d' % (sample_ind, sample_no)),
                                    single_video[0])

        sample_ind += args.batch_size

    # Aggregate the accumulated per-batch scalars and write a human-readable summary.
    summary_file_array = []
    summary_file_path = os.path.join(args.output_dir, 'summary.txt')
    for key, value_list in our_result_list.items():
        our_metric = np.asarray(value_list)
        sentence = '[%s]: %12.9f (%12.9f)' % (key, our_metric.mean(), our_metric.std())
        print(sentence)
        summary_file_array.append(sentence)
    with open(summary_file_path, 'w') as f:
        f.write('\n'.join(summary_file_array))
def main():
    """Meta-train a video prediction model for ``--inner_train_steps`` steps,
    then run one full evaluation pass over the validation set.

    Options and hyperparameters can be restored from a checkpoint directory
    (``options.json`` / ``dataset_hparams.json`` / ``model_hparams.json``);
    command-line values take precedence. Summaries, hparam snapshots, and the
    final model checkpoint are written under ``args.output_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True,
                        help="either a directory containing subdirectories "
                             "train, val, test, etc, or a directory containing "
                             "the tfrecords")
    parser.add_argument("--results_dir", type=str, default='results',
                        help="ignored if output_dir is specified")
    parser.add_argument(
        "--checkpoint",
        help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
    parser.add_argument(
        "--output_dir",
        help="output directory where json files, summary, model, gifs, etc are saved. "
             "default is logs_dir/model_fname, where model_fname consists of "
             "information from model and model_hparams")
    parser.add_argument("--output_dir_postfix", default="")
    parser.add_argument("--resume", action='store_true',
                        help='resume from lastest checkpoint in output_dir.')
    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument("--dataset_hparams", type=str,
                        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--dataset_hparams_dict", type=str,
                        help="a json file of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument("--model_hparams", type=str,
                        help="a string of comma separated list of model hyperparameters")
    parser.add_argument("--model_hparams_dict", type=str,
                        help="a json file of model hyperparameters")
    parser.add_argument("--summary_freq", type=int, default=1000,
                        help="save frequency of summaries (except for image and eval summaries) for train/validation set")
    parser.add_argument("--image_summary_freq", type=int, default=5000,
                        help="save frequency of image summaries for train/validation set")
    parser.add_argument("--eval_summary_freq", type=int, default=25000,
                        help="save frequency of eval summaries for train/validation set")
    parser.add_argument("--accum_eval_summary_freq", type=int, default=100000,
                        help="save frequency of accumulated eval summaries for validation set only")
    parser.add_argument("--progress_freq", type=int, default=100,
                        help="display progress every progress_freq steps")
    parser.add_argument("--save_freq", type=int, default=5000,
                        help="save frequence of model, 0 to disable")
    parser.add_argument("--aggregate_nccl", type=int, default=0,
                        help="whether to use nccl or cpu for gradient aggregation in multi-gpu training")
    parser.add_argument("--gpu_mem_frac", type=float, default=0.7,
                        help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int)
    # NOTE(review): this help text looks copy-pasted from --aggregate_nccl; it
    # does not describe --inner_train_steps — confirm and fix the wording.
    parser.add_argument("--inner_train_steps", type=int, default=0,
                        help="whether to use nccl or cpu for gradient aggregation in multi-gpu training")
    parser.add_argument("--eval_substasks", type=str, nargs='+', default=['max', 'avg', 'min'],
                        help='subtasks to evaluate (e.g. max, avg, min)')
    parser.add_argument("--only_metrics", action='store_true')
    parser.add_argument("--num_stochastic_samples", type=int, default=100)
    parser.add_argument("--eval_parallel_iterations", type=int, default=10)
    parser.add_argument("--inner_train_batch_size", type=int, default=8,
                        help="number of samples in batch")
    parser.add_argument("--eval_batch_size", type=int, default=2,
                        help="number of samples in batch")
    args = parser.parse_args()

    # Seed every RNG used (TF graph-level, numpy, stdlib) for reproducibility.
    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    # --resume reuses output_dir as the checkpoint source; mutually exclusive
    # with an explicit --checkpoint.
    if args.resume:
        if args.checkpoint:
            raise ValueError('resume and checkpoint cannot both be specified')
        args.checkpoint = args.output_dir

    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.checkpoint:
        # --checkpoint may be either a checkpoint file path or its directory;
        # normalize to the containing directory.
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            # Command-line values override the checkpoint's recorded options.
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        # Hparams files are optional in a checkpoint directory; absence is
        # tolerated (falls back to empty dicts).
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")
        args.output_dir = args.output_dir or os.path.join(
            args.results_dir, os.path.split(checkpoint_dir)[1])
    else:
        if not args.dataset:
            raise ValueError('dataset is required when checkpoint is not specified')
        if not args.model:
            raise ValueError('model is required when checkpoint is not specified')
        args.output_dir = args.output_dir or os.path.join(
            args.results_dir, 'model.%s' % args.model)

    print('----------------------------------- Options ------------------------------------')
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print('------------------------------------- End --------------------------------------')

    VideoDataset = datasets.get_dataset_class(args.dataset)
    train_dataset = VideoDataset(args.input_dir,
                                 mode='train',
                                 hparams_dict=dataset_hparams_dict,
                                 hparams=args.dataset_hparams)
    val_dataset = VideoDataset(args.input_dir,
                               mode='val',
                               hparams_dict=dataset_hparams_dict,
                               hparams=args.dataset_hparams)

    # Use resource variables throughout the graph.
    variable_scope = tf.get_variable_scope()
    variable_scope.set_use_resource(True)

    VideoPredictionModel = models.get_model_class(args.model)
    # Sequence geometry for the model comes from the training dataset hparams.
    hparams_dict = dict(model_hparams_dict)
    hparams_dict.update({
        'context_frames': train_dataset.hparams.context_frames,
        'sequence_length': train_dataset.hparams.sequence_length,
        'repeat': train_dataset.hparams.time_shift,
    })
    model = VideoPredictionModel(
        hparams_dict=hparams_dict,
        hparams=args.model_hparams,
        aggregate_nccl=args.aggregate_nccl,
        eval_num_samples=args.num_stochastic_samples,
        eval_parallel_iterations=args.eval_parallel_iterations)

    # Require an exact number of eval batches so every validation example is seen once.
    num_examples_per_epoch = val_dataset.num_examples_per_epoch()
    if num_examples_per_epoch % args.eval_batch_size != 0:
        raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)

    # Batch size of inner train & eval
    inner_train_batch_size = args.inner_train_batch_size
    eval_batch_size = args.eval_batch_size

    # Dataset for inner training
    train_tf_dataset = train_dataset.make_dataset(inner_train_batch_size)
    train_iterator = train_tf_dataset.make_one_shot_iterator()
    train_handle = train_iterator.string_handle()

    # Dataset for evaluation
    val_tf_dataset = val_dataset.make_dataset(eval_batch_size)
    val_iterator = val_tf_dataset.make_one_shot_iterator()
    val_handle = val_iterator.string_handle()

    # Iterator for inner training
    iterator = tf.data.Iterator.from_string_handle(
        train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
    inputs = iterator.get_next()

    # Inputs comes from the training dataset by default, unless train_handle is remapped to the val_handles
    model.build_graph(inputs)

    # Snapshot the options and hyperparameters next to the outputs so the run
    # can be reproduced / resumed later.
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    with open(os.path.join(args.output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
    with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f:
        f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4))
    with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f:
        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    with tf.name_scope("parameter_count"):
        # exclude trainable variables that are replicas (used in multi-gpu setting)
        trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables)
        parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in trainable_variables])

    saver = tf.train.Saver(var_list=model.saveable_variables, max_to_keep=2)

    # None has the special meaning of evaluating at the end, so explicitly check for non-equality to zero
    if (args.summary_freq != 0 or args.image_summary_freq != 0
            or args.eval_summary_freq != 0 or args.accum_eval_summary_freq != 0):
        summary_writer = tf.summary.FileWriter(args.output_dir)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    global_step = tf.train.get_or_create_global_step()
    inner_train_steps = args.inner_train_steps
    with tf.Session(config=config) as sess:
        print("parameter_count =", sess.run(parameter_count))
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        model.restore(sess, args.checkpoint)
        sess.run(model.post_init_ops)
        val_handle_eval = sess.run(val_handle)
        # Freeze the graph so no ops are accidentally added inside the loop.
        sess.graph.finalize()
        start_step = sess.run(global_step)

        def should(step, freq):
            # Whether to perform an action on this (relative) step, given the
            # action's frequency; freq=None means "only on the last step".
            if freq is None:
                return (step + 1) == (inner_train_steps - start_step)
            else:
                return freq and ((step + 1) % freq == 0
                                 or (step + 1) in (0, inner_train_steps - start_step))

        def should_eval(step, freq):
            # never run eval summaries at the beginning since it's expensive, unless it's the last iteration
            return should(step, freq) and (step >= 0 or (step + 1) == (inner_train_steps - start_step))

        # start at one step earlier to log everything without doing any training
        # step is relative to the start_step
        for step in range(-1, inner_train_steps):
            if step == 1:
                # skip step -1 and 0 for timing purposes (for warmstarting)
                start_time = time.time()
            fetches = {"global_step": global_step}
            if step >= 0:
                # Only apply gradient updates on real steps (not the -1 warmup step).
                fetches["train_op"] = model.train_op
            fetches['d_loss'] = model.d_loss
            fetches['g_loss'] = model.g_loss
            fetches['d_losses'] = model.d_losses
            fetches['g_losses'] = model.g_losses
            if isinstance(model.learning_rate, tf.Tensor):
                fetches["learning_rate"] = model.learning_rate
            fetches["summary"] = model.summary_op
            fetches["image_summary"] = model.image_summary_op
            fetches["eval_summary"] = model.eval_summary_op
            run_start_time = time.time()
            print(step)
            results = sess.run(fetches)
            run_elapsed_time = time.time() - run_start_time
            # NOTE(review): with the unconditional summary fetches above, the
            # keys can never equal {"global_step", "train_op"}, so this warning
            # appears unreachable — confirm whether that is intended.
            if run_elapsed_time > 1.5 and step > 0 and set(
                    fetches.keys()) == {"global_step", "train_op"}:
                print('running train_op took too long (%0.1fs)' % run_elapsed_time)

        # Finish meta training, start evaluation
        # Run the same summary ops once on a validation batch by remapping the
        # shared iterator handle to the validation iterator.
        val_fetches = {"global_step": global_step}
        val_fetches["summary"] = model.summary_op
        val_fetches["image_summary"] = model.image_summary_op
        val_fetches["eval_summary"] = model.eval_summary_op
        val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval})
        # Suffix validation summary tags with '_1' to separate them from the
        # training summaries in TensorBoard.
        for name, summary in val_results.items():
            if name == 'global_step':
                continue
            val_results[name] = add_tag_suffix(summary, '_1')

        # Training-side summaries are only written when inner training actually ran.
        print("recording summary")
        if inner_train_steps > 0:
            summary_writer.add_summary(results["summary"], results["global_step"])
        summary_writer.add_summary(val_results["summary"], val_results["global_step"])
        print("done")
        print("recording image summary")
        if inner_train_steps > 0:
            summary_writer.add_summary(results["image_summary"], results["global_step"])
        summary_writer.add_summary(val_results["image_summary"], val_results["global_step"])
        print("done")
        print("recording eval summary")
        if inner_train_steps > 0:
            summary_writer.add_summary(results["eval_summary"], results["global_step"])
        summary_writer.add_summary(val_results["eval_summary"], val_results["global_step"])
        print("done")

        # Accumulated evaluation over the whole validation set.
        val_datasets = [val_dataset]
        val_models = [model]
        for i, (val_dataset_, val_model) in enumerate(zip(val_datasets, val_models)):
            sess.run(val_model.accum_eval_metrics_reset_op)
            # traverse (roughly up to rounding based on the batch size) all the validation dataset
            accum_eval_summary_num_updates = val_dataset_.num_examples_per_epoch() // eval_batch_size
            val_fetches = {
                "global_step": global_step,
                "accum_eval_summary": val_model.accum_eval_summary_op,
                "images": model.inputs['images']
            }
            val_fetches.update(model.eval_outputs.items())
            val_fetches.update(model.eval_metrics.items())
            for update_step in range(accum_eval_summary_num_updates):
                print('evaluating %d / %d' % (update_step + 1, accum_eval_summary_num_updates))
                val_results = sess.run(val_fetches,
                                       feed_dict={train_handle: val_handle_eval})
                # Persist per-batch predictions/metrics as we go.
                save_prediction_eval_results(
                    os.path.join(args.output_dir, 'prediction_eval'),
                    val_results, model.hparams, update_step * eval_batch_size,
                    args.only_metrics, args.eval_substasks)
            accum_eval_summary = add_tag_suffix(
                val_results["accum_eval_summary"], '_%d' % (i + 1))
            print("recording accum eval summary")
            summary_writer.add_summary(accum_eval_summary, val_results["global_step"])
            print("done")
        summary_writer.flush()

        # global_step will have the correct step count if we resume from a checkpoint
        # global step is read before it's incremented
        steps_per_epoch = train_dataset.num_examples_per_epoch() / inner_train_batch_size
        train_epoch = results["global_step"] / steps_per_epoch
        print("progress global step %d epoch %0.1f" %
              (results["global_step"] + 1, train_epoch))
        # 'step' holds its value from the last loop iteration here.
        if step > 0:
            elapsed_time = time.time() - start_time
            average_time = elapsed_time / step
            images_per_sec = inner_train_batch_size / average_time
            remaining_time = (inner_train_steps - (start_step + step + 1)) * average_time
            print(" image/sec %0.1f remaining %dm (%0.1fh) (%0.1fd)" %
                  (images_per_sec, remaining_time / 60,
                   remaining_time / 60 / 60, remaining_time / 60 / 60 / 24))

        # Loss reporting only makes sense if training actually happened.
        if inner_train_steps > 0:
            if results['d_losses']:
                print("d_loss", results["d_loss"])
                for name, loss in results['d_losses'].items():
                    print(" ", name, loss)
            if results['g_losses']:
                print("g_loss", results["g_loss"])
                for name, loss in results['g_losses'].items():
                    print(" ", name, loss)
            if isinstance(model.learning_rate, tf.Tensor):
                print("learning_rate", results["learning_rate"])

        print("saving model to", args.output_dir)
        saver.save(sess, os.path.join(args.output_dir, "model"),
                   global_step=global_step)
        print("done")
def main():
    """Train a video prediction model, periodically recording train/validation
    summaries and accumulated evaluation metrics, and saving checkpoints.

    When --output_dir is not given, a directory name is derived from the model
    spec, a timestamp, and the current git revision. Options and hparams are
    snapshotted into the output directory; --resume restarts from the latest
    checkpoint found there.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True,
                        help="either a directory containing subdirectories "
                             "train, val, test, etc, or a directory containing "
                             "the tfrecords")
    parser.add_argument(
        "--val_input_dir", type=str,
        help="directories containing the tfrecords. default: input_dir")
    parser.add_argument("--logs_dir", default='logs',
                        help="ignored if output_dir is specified")
    parser.add_argument(
        "--output_dir",
        help="output directory where json files, summary, model, gifs, etc are saved. "
             "default is logs_dir/model_fname, where model_fname consists of "
             "information from model and model_hparams")
    parser.add_argument("--output_dir_postfix", default="")
    parser.add_argument(
        "--checkpoint",
        help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
    parser.add_argument("--resume", action='store_true',
                        help='resume from lastest checkpoint in output_dir.')
    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument("--dataset_hparams", type=str,
                        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--dataset_hparams_dict", type=str,
                        help="a json file of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument("--model_hparams", type=str,
                        help="a string of comma separated list of model hyperparameters")
    parser.add_argument("--model_hparams_dict", type=str,
                        help="a json file of model hyperparameters")
    parser.add_argument("--summary_freq", type=int, default=1000,
                        help="save frequency of summaries (except for image and eval summaries) for train/validation set")
    parser.add_argument("--image_summary_freq", type=int, default=5000,
                        help="save frequency of image summaries for train/validation set")
    parser.add_argument("--eval_summary_freq", type=int, default=25000,
                        help="save frequency of eval summaries for train/validation set")
    parser.add_argument("--accum_eval_summary_freq", type=int, default=100000,
                        help="save frequency of accumulated eval summaries for validation set only")
    parser.add_argument("--progress_freq", type=int, default=100,
                        help="display progress every progress_freq steps")
    parser.add_argument("--save_freq", type=int, default=5000,
                        help="save frequence of model, 0 to disable")
    parser.add_argument("--aggregate_nccl", type=int, default=0,
                        help="whether to use nccl or cpu for gradient aggregation in multi-gpu training")
    parser.add_argument("--gpu_mem_frac", type=float, default=1.0,
                        help="fraction of gpu memory to use")
    parser.add_argument("--num_gpus", type=int, default=1,
                        help="number of gpus to use")
    parser.add_argument("--seed", type=int)
    args = parser.parse_args()

    # Seed every RNG used (TF graph-level, numpy, stdlib) for reproducibility.
    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    if args.output_dir is None:
        # Derive an output directory name from the model spec, a timestamp,
        # and the short git revision of the working tree.
        list_depth = 0
        repo = git.Repo(search_parent_directories=True)
        sha = repo.head.object.hexsha[:10]
        stamp = "{:%B_%d_%H-%M-%S}".format(datetime.now())
        model_fname = ''
        # Sanitize 'model=<name>,<hparams>' into a filesystem-friendly name:
        # commas inside [...] lists become '..', '=' and ',' become '.', and
        # the brackets themselves are dropped.
        for t in ('model=%s,%s' % (args.model, args.model_hparams)):
            if t == '[':
                list_depth += 1
            if t == ']':
                list_depth -= 1
            if list_depth and t == ',':
                t = '..'
            if t in '=,':
                t = '.'
            if t in '[]':
                t = ''
            model_fname += t
        model_fname += '_{}_{}'.format(stamp, sha)
        args.output_dir = os.path.join(args.logs_dir, model_fname) + args.output_dir_postfix

    # --resume reuses output_dir as the checkpoint source; mutually exclusive
    # with an explicit --checkpoint.
    if args.resume:
        if args.checkpoint:
            raise ValueError('resume and checkpoint cannot both be specified')
        args.checkpoint = args.output_dir

    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.dataset_hparams_dict:
        with open(args.dataset_hparams_dict) as f:
            dataset_hparams_dict.update(json.loads(f.read()))
    if args.model_hparams_dict:
        with open(args.model_hparams_dict) as f:
            model_hparams_dict.update(json.loads(f.read()))
    if args.checkpoint:
        # --checkpoint may be either a checkpoint file path or its directory;
        # normalize to the containing directory.
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            # Command-line values override the checkpoint's recorded options.
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        # Checkpoint hparams overlay (not replace) any --*_hparams_dict files
        # loaded above; missing files are tolerated.
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict.update(json.loads(f.read()))
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict.update(json.loads(f.read()))
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")

    print('----------------------------------- Options ------------------------------------')
    for k, v in args._get_kwargs():
        print(k, "=", v)
    print('------------------------------------- End --------------------------------------')

    # Extra arguments forwarded to build_graph (used when rendering summaries).
    summary_args = {
        'fps': 2,
    }

    VideoDataset = get_dataset_class(args.dataset)
    train_dataset = VideoDataset(args.input_dir,
                                 mode='train',
                                 hparams_dict=dataset_hparams_dict,
                                 hparams=args.dataset_hparams)
    val_dataset = VideoDataset(args.val_input_dir or args.input_dir,
                               mode='val',
                               hparams_dict=dataset_hparams_dict,
                               hparams=args.dataset_hparams)
    if val_dataset.hparams.long_sequence_length != val_dataset.hparams.sequence_length:
        # the longer dataset is only used for the accum_eval_metrics
        long_val_dataset = VideoDataset(args.val_input_dir or args.input_dir,
                                        mode='val',
                                        hparams_dict=dataset_hparams_dict,
                                        hparams=args.dataset_hparams)
        long_val_dataset.set_sequence_length(val_dataset.hparams.long_sequence_length)
    else:
        long_val_dataset = None

    # Use resource variables throughout the graph.
    variable_scope = tf.get_variable_scope()
    variable_scope.set_use_resource(True)

    VideoPredictionModel = models.get_model_class(args.model)
    # Sequence geometry for the model comes from the training dataset hparams.
    hparams_dict = dict(model_hparams_dict)
    hparams_dict.update({
        'context_frames': train_dataset.hparams.context_frames,
        'sequence_length': train_dataset.hparams.sequence_length,
        'repeat': train_dataset.hparams.time_shift,
        'dt': train_dataset.hparams.dt,
    })
    model = VideoPredictionModel(hparams_dict=hparams_dict,
                                 hparams=args.model_hparams,
                                 num_gpus=args.num_gpus,
                                 aggregate_nccl=args.aggregate_nccl)

    batch_size = model.hparams.batch_size
    train_tf_dataset = train_dataset.make_dataset(batch_size)
    train_iterator = train_tf_dataset.make_one_shot_iterator()
    train_handle = train_iterator.string_handle()
    val_tf_dataset = val_dataset.make_dataset(batch_size)
    val_iterator = val_tf_dataset.make_one_shot_iterator()
    val_handle = val_iterator.string_handle()
    # A feedable iterator lets us switch between train and val pipelines at
    # sess.run time via the string handle.
    iterator = tf.data.Iterator.from_string_handle(
        train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
    inputs = iterator.get_next()

    # inputs comes from the training dataset by default, unless train_handle is remapped to the val_handles
    model.build_graph(inputs, summary_args)

    if long_val_dataset is not None:
        # separately build a model for the longer sequence.
        # this is needed because the model doesn't support dynamic shapes.
        long_hparams_dict = dict(hparams_dict)
        long_hparams_dict['sequence_length'] = long_val_dataset.hparams.sequence_length
        # use smaller batch size for longer model to prevent running out of memory
        long_hparams_dict['batch_size'] = model.hparams.batch_size // 2
        long_model = VideoPredictionModel(
            mode="test",  # to not build the losses and discriminators
            hparams_dict=long_hparams_dict,
            hparams=args.model_hparams,
            aggregate_nccl=args.aggregate_nccl)
        # Share weights with the main model.
        tf.get_variable_scope().reuse_variables()
        long_model.build_graph(long_val_dataset.make_batch(batch_size), summary_args)
    else:
        long_model = None

    # Snapshot the options and hyperparameters next to the outputs so the run
    # can be reproduced / resumed later.
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    with open(os.path.join(args.output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
    with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f:
        f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4))
    with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f:
        f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    with tf.name_scope("parameter_count"):
        # exclude trainable variables that are replicas (used in multi-gpu setting)
        trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables)
        parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in trainable_variables])

    saver = tf.train.Saver(var_list=model.saveable_variables, max_to_keep=2)

    # None has the special meaning of evaluating at the end, so explicitly check for non-equality to zero
    if (args.summary_freq != 0 or args.image_summary_freq != 0
            or args.eval_summary_freq != 0 or args.accum_eval_summary_freq != 0):
        summary_writer = tf.summary.FileWriter(args.output_dir)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    global_step = tf.train.get_or_create_global_step()
    max_steps = model.hparams.max_steps
    with tf.Session(config=config) as sess:
        print("parameter_count =", sess.run(parameter_count))
        summary_writer.add_graph(sess.graph)
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        model.restore(sess, args.checkpoint)
        sess.run(model.post_init_ops)
        val_handle_eval = sess.run(val_handle)
        # Freeze the graph so no ops are accidentally added inside the loop.
        sess.graph.finalize()
        start_step = sess.run(global_step)

        def should(step, freq):
            # Whether to perform an action on this (relative) step, given the
            # action's frequency; freq=None means "only on the last step".
            if freq is None:
                return (step + 1) == (max_steps - start_step)
            else:
                return freq and ((step + 1) % freq == 0
                                 or (step + 1) in (0, max_steps - start_step))

        def should_eval(step, freq):
            # FIXME
            # NOTE(review): this early return disables ALL eval and accum-eval
            # summaries below (the second return is unreachable) — confirm this
            # is still intended before shipping.
            return False
            # never run eval summaries at the beginning since it's expensive, unless it's the last iteration
            return should(step, freq) and (step >= 0 or (step + 1) == (max_steps - start_step))

        # start at one step earlier to log everything without doing any training
        # step is relative to the start_step
        for step in range(-1, max_steps - start_step):
            if step == 1:
                # skip step -1 and 0 for timing purposes (for warmstarting)
                start_time = time.time()
            fetches = {"global_step": global_step}
            if step >= 0:
                # Only apply gradient updates on real steps (not the -1 warmup step).
                fetches["train_op"] = model.train_op
            if should(step, args.progress_freq):
                fetches['d_loss'] = model.d_loss
                fetches['g_loss'] = model.g_loss
                fetches['d_losses'] = model.d_losses
                fetches['g_losses'] = model.g_losses
                if isinstance(model.learning_rate, tf.Tensor):
                    fetches["learning_rate"] = model.learning_rate
            if should(step, args.summary_freq):
                fetches["summary"] = model.summary_op
            if should(step, args.image_summary_freq):
                fetches["image_summary"] = model.image_summary_op
            if should_eval(step, args.eval_summary_freq):
                fetches["eval_summary"] = model.eval_summary_op
            results = sess.run(fetches)
            if (should(step, args.summary_freq)
                    or should(step, args.image_summary_freq)
                    or should_eval(step, args.eval_summary_freq)):
                # Run the same summary ops on a validation batch by remapping
                # the shared iterator handle to the validation iterator.
                val_fetches = {"global_step": global_step}
                if should(step, args.summary_freq):
                    val_fetches["summary"] = model.summary_op
                if should(step, args.image_summary_freq):
                    val_fetches["image_summary"] = model.image_summary_op
                if should_eval(step, args.eval_summary_freq):
                    val_fetches["eval_summary"] = model.eval_summary_op
                val_results = sess.run(val_fetches,
                                       feed_dict={train_handle: val_handle_eval})
                # Suffix validation summary tags with '_1' to separate them
                # from the training summaries in TensorBoard.
                for name, summary in val_results.items():
                    if name == 'global_step':
                        continue
                    val_results[name] = add_tag_suffix(summary, '_1')
                if should(step, args.summary_freq):
                    print("recording summary")
                    summary_writer.add_summary(results["summary"], results["global_step"])
                    summary_writer.add_summary(val_results["summary"], val_results["global_step"])
                    print("done")
                if should(step, args.image_summary_freq):
                    print("recording image summary")
                    summary_writer.add_summary(results["image_summary"], results["global_step"])
                    summary_writer.add_summary(val_results["image_summary"], val_results["global_step"])
                    print("done")
                if should_eval(step, args.eval_summary_freq):
                    print("recording eval summary")
                    summary_writer.add_summary(results["eval_summary"], results["global_step"])
                    summary_writer.add_summary(val_results["eval_summary"], val_results["global_step"])
                    print("done")
            if should_eval(step, args.accum_eval_summary_freq):
                # Accumulated evaluation over the whole validation set, also
                # on the long-sequence model when one was built.
                val_datasets = [val_dataset]
                val_models = [model]
                if long_model is not None:
                    val_datasets.append(long_val_dataset)
                    val_models.append(long_model)
                for i, (val_dataset_, val_model) in enumerate(zip(val_datasets, val_models)):
                    sess.run(val_model.accum_eval_metrics_reset_op)
                    # traverse (roughly up to rounding based on the batch size) all the validation dataset
                    accum_eval_summary_num_updates = val_dataset_.num_examples_per_epoch() // val_model.hparams.batch_size
                    val_fetches = {
                        "global_step": global_step,
                        "accum_eval_summary": val_model.accum_eval_summary_op
                    }
                    for update_step in range(accum_eval_summary_num_updates):
                        print('evaluating %d / %d' %
                              (update_step + 1, accum_eval_summary_num_updates))
                        val_results = sess.run(val_fetches,
                                               feed_dict={train_handle: val_handle_eval})
                    accum_eval_summary = add_tag_suffix(
                        val_results["accum_eval_summary"], '_%d' % (i + 1))
                    print("recording accum eval summary")
                    summary_writer.add_summary(accum_eval_summary,
                                               val_results["global_step"])
                    print("done")
            if (should(step, args.summary_freq)
                    or should(step, args.image_summary_freq)
                    or should_eval(step, args.eval_summary_freq)
                    or should_eval(step, args.accum_eval_summary_freq)):
                summary_writer.flush()
            if should(step, args.progress_freq):
                # global_step will have the correct step count if we resume from a checkpoint
                # global step is read before it's incremented
                steps_per_epoch = train_dataset.num_examples_per_epoch() / batch_size
                train_epoch = results["global_step"] / steps_per_epoch
                print("progress global step %d epoch %0.1f" %
                      (results["global_step"] + 1, train_epoch))
                if step > 0:
                    elapsed_time = time.time() - start_time
                    average_time = elapsed_time / step
                    images_per_sec = batch_size / average_time
                    remaining_time = (max_steps - (start_step + step + 1)) * average_time
                    print(" image/sec %0.1f remaining %dm (%0.1fh) (%0.1fd)" %
                          (images_per_sec, remaining_time / 60,
                           remaining_time / 60 / 60, remaining_time / 60 / 60 / 24))
            if should(step, args.save_freq):
                print("saving model to", args.output_dir)
                saver.save(sess, os.path.join(args.output_dir, "model"),
                           global_step=global_step)
                print("done")
def main(): parser = argparse.ArgumentParser() parser.add_argument("--input_dir", type=str, required=True, help="a directory containing the root dat directory ") parser.add_argument("--output_dir", type=str, default='outputs', help="directory to save raw outputs from predictions") parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_gif_dir is specified") parser.add_argument( "--results_gif_dir", type=str, help="default is results_dir. ignored if output_gif_dir is specified") parser.add_argument( "--results_png_dir", type=str, help="default is results_dir. ignored if output_png_dir is specified") parser.add_argument( "--output_gif_dir", help="output directory where samples are saved as gifs. default is " "results_gif_dir/model_fname") parser.add_argument( "--output_png_dir", help="output directory where samples are saved as pngs. default is " "results_png_dir/model_fname") parser.add_argument( "--checkpoint", help= "directory with checkpoint or checkpoint name (e.g. 
checkpoint_dir/model-200000)" ) parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val', help='mode for dataset, val or test.') parser.add_argument("--dataset", type=str, help="dataset class name") parser.add_argument( "--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters") parser.add_argument("--model", type=str, help="model class name") parser.add_argument( "--model_hparams", type=str, help="a string of comma separated list of model hyperparameters") parser.add_argument("--context_length", type=int, help="number of context frames") parser.add_argument("--prediction_length", type=int, help="number of frames to predict") parser.add_argument("--batch_size", type=int, default=5, help="number of samples in batch") parser.add_argument( "--num_samples", type=int, help="number of samples in total (all of them by default)") parser.add_argument("--num_epochs", type=int, default=1) parser.add_argument("--num_stochastic_samples", type=int, default=5) parser.add_argument("--gif_length", type=int, help="default is sequence_length") parser.add_argument("--fps", type=int, default=4) parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use") parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() if args.seed is not None: tf.set_random_seed(args.seed) np.random.seed(args.seed) random.seed(args.seed) args.results_dir = args.output_dir args.results_gif_dir = args.output_dir args.results_png_dir = args.output_dir dataset_hparams_dict = {} model_hparams_dict = {} if args.checkpoint: checkpoint_dir = os.path.normpath(args.checkpoint) if not os.path.exists(checkpoint_dir): raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir) if not os.path.isdir(args.checkpoint): checkpoint_dir, _ = os.path.split(checkpoint_dir) with open(os.path.join(checkpoint_dir, "options.json")) as f: print("loading options from checkpoint %s" % 
args.checkpoint) options = json.loads(f.read()) args.dataset = args.dataset or options['dataset'] args.model = args.model or options['model'] try: with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f: dataset_hparams_dict = json.loads(f.read()) except FileNotFoundError: print( "dataset_hparams.json was not loaded because it does not exist" ) try: with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f: model_hparams_dict = json.loads(f.read()) model_hparams_dict.pop('num_gpus', None) # backwards-compatibility except FileNotFoundError: print( "model_hparams.json was not loaded because it does not exist") else: if not args.dataset: raise ValueError( 'dataset is required when checkpoint is not specified') if not args.model: raise ValueError( 'model is required when checkpoint is not specified') args.output_gif_dir = args.output_gif_dir or os.path.join( args.results_gif_dir, 'model.%s' % args.model) args.output_png_dir = args.output_png_dir or os.path.join( args.results_png_dir, 'model.%s' % args.model) print( '----------------------------------- Options ------------------------------------' ) for k, v in args._get_kwargs(): print(k, "=", v) print( '------------------------------------- End --------------------------------------' ) VideoDataset = datasets.get_dataset_class(args.dataset) dataset = VideoDataset(args.input_dir, mode=args.mode, num_epochs=args.num_epochs, seed=args.seed, hparams_dict=dataset_hparams_dict, hparams=args.dataset_hparams) def override_hparams_dict(dataset): hparams_dict = dict(model_hparams_dict) hparams_dict['context_frames'] = dataset.hparams.context_frames hparams_dict['sequence_length'] = dataset.hparams.sequence_length hparams_dict['repeat'] = dataset.hparams.time_shift if args.context_length is not None: hparams_dict['context_frames'] = args.context_length if args.prediction_length is not None: hparams_dict[ 'sequence_length'] = args.context_length + args.prediction_length return hparams_dict 
VideoPredictionModel = models.get_model_class(args.model) model = VideoPredictionModel(mode='test', hparams_dict=override_hparams_dict(dataset), hparams=args.model_hparams) if args.num_samples: if args.num_samples > dataset.num_examples_per_epoch(): raise ValueError('num_samples cannot be larger than the dataset') num_examples_per_epoch = args.num_samples else: num_examples_per_epoch = dataset.num_examples_per_epoch() if num_examples_per_epoch % args.batch_size != 0: raise ValueError('batch_size should evenly divide the dataset') inputs, _ = dataset.make_batch(args.batch_size) if not isinstance(model, models.GroundTruthVideoPredictionModel): # remove ground truth data past context_frames to prevent accidentally using it for k, v in inputs.items(): if k != 'actions': inputs[k] = v[:, :model.hparams.context_frames] input_phs = { k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items() } with tf.variable_scope(''): model.build_graph(input_phs) if not os.path.exists(args.output_dir): os.makedirs(output_dir) with open(os.path.join(args.output_dir, "options.json"), "w") as f: f.write(json.dumps(vars(args), sort_keys=True, indent=4)) with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f: f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4)) with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f: f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4)) gpu_options = tf.GPUOptions( per_process_gpu_memory_fraction=args.gpu_mem_frac) config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True) sess = tf.Session(config=config) model.restore(sess, args.checkpoint) #get all tile/filenames from input_dir tiles = os.listdir(args.input_dir) tiles.sort() tile_names = [] file_names = [] for tile in tiles: in_tile_path = os.path.join(args.input_dir, tile) files = os.listdir(in_tile_path) files.sort() for file in files: tile_names.append(tile) file_names.append(file[:-10]) #Generate 
and save sample_ind = 0 while True: if args.num_samples and sample_ind >= args.num_samples: break try: input_results = sess.run(inputs) except tf.errors.OutOfRangeError: break print("Generating samples from %d to %d" % (sample_ind, sample_ind + args.batch_size)) feed_dict = { input_ph: input_results[name] for name, input_ph in input_phs.items() } #If num stochastic samples is 1, make a mirror of input dataset with predictions #If num stochastic samples bigger, make a deeper subdirectory for stochastic_sample_ind in range(args.num_stochastic_samples): gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict) for i, gen_images_ in enumerate(gen_images): #get only channels 0,1,2,3 send time to the last gen_images_ = np.moveaxis(gen_images_[:, :, :, :4], 0, -1).astype(np.float16) #switch red and blue r, g, b, n = np.split(gen_images_, 4, axis=2) gen_images_ = np.concatenate([b, g, r, n], axis=2) #Save save_full_path = os.path.join(args.output_dir, tile_names[sample_ind]) sample_name = str(stochastic_sample_ind + 1).zfill( 4) + '_' + file_names[sample_ind] + '.npz' if not os.path.exists(save_full_path): os.makedirs(save_full_path) np.savez(os.path.join(save_full_path, sample_name), gen_images_) sample_ind += 1 if stochastic_sample_ind < (args.num_stochastic_samples - 1): sample_ind -= args.batch_size
def main():
    """Evaluate a trained video-prediction model and save sampled predictions.

    Parses CLI options (optionally backfilling dataset/model choices from the
    checkpoint's ``options.json``), builds the dataset input pipeline and the
    model graph, restores the checkpoint, then for each batch draws
    ``num_stochastic_samples`` predictions and writes each one out as a GIF
    plus per-frame PNGs under the output directories.

    Side effects: creates output directories, writes options/hparams JSON
    files, GIFs and PNGs; prints progress to stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True,
                        help="either a directory containing subdirectories "
                             "train, val, test, etc, or a directory containing "
                             "the tfrecords")
    parser.add_argument("--results_dir", type=str, default='results',
                        help="ignored if output_gif_dir is specified")
    parser.add_argument("--results_gif_dir", type=str,
                        help="default is results_dir. ignored if output_gif_dir is specified")
    parser.add_argument("--results_png_dir", type=str,
                        help="default is results_dir. ignored if output_png_dir is specified")
    parser.add_argument("--output_gif_dir",
                        help="output directory where samples are saved as gifs. default is "
                             "results_gif_dir/model_fname")
    parser.add_argument("--output_png_dir",
                        help="output directory where samples are saved as pngs. default is "
                             "results_png_dir/model_fname")
    parser.add_argument("--checkpoint",
                        help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
    parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val',
                        help='mode for dataset, val or test.')
    parser.add_argument("--dataset", type=str, help="dataset class name")
    parser.add_argument("--dataset_hparams", type=str,
                        help="a string of comma separated list of dataset hyperparameters")
    parser.add_argument("--model", type=str, help="model class name")
    parser.add_argument("--model_hparams", type=str,
                        help="a string of comma separated list of model hyperparameters")
    parser.add_argument("--batch_size", type=int, default=8,
                        help="number of samples in batch")
    parser.add_argument("--num_samples", type=int,
                        help="number of samples in total (all of them by default)")
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--num_stochastic_samples", type=int, default=5)
    parser.add_argument("--gif_length", type=int, help="default is sequence_length")
    parser.add_argument("--fps", type=int, default=4)
    parser.add_argument("--gpu_mem_frac", type=float, default=0,
                        help="fraction of gpu memory to use")
    parser.add_argument("--seed", type=int, default=7)
    args = parser.parse_args()

    # Seed every RNG in play (TF graph ops, numpy, stdlib) for reproducibility.
    if args.seed is not None:
        tf.set_random_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    args.results_gif_dir = args.results_gif_dir or args.results_dir
    args.results_png_dir = args.results_png_dir or args.results_dir

    dataset_hparams_dict = {}
    model_hparams_dict = {}
    if args.checkpoint:
        # A checkpoint may be a directory or a file name inside it
        # (e.g. checkpoint_dir/model-200000); resolve to the directory.
        checkpoint_dir = os.path.normpath(args.checkpoint)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        if not os.path.isdir(args.checkpoint):
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        with open(os.path.join(checkpoint_dir, "options.json")) as f:
            print("loading options from checkpoint %s" % args.checkpoint)
            options = json.loads(f.read())
            # CLI flags win; otherwise fall back to what the checkpoint was trained with.
            args.dataset = args.dataset or options['dataset']
            args.model = args.model or options['model']
        try:
            with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                dataset_hparams_dict = json.loads(f.read())
        except FileNotFoundError:
            print("dataset_hparams.json was not loaded because it does not exist")
        try:
            with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                model_hparams_dict = json.loads(f.read())
                model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility
        except FileNotFoundError:
            print("model_hparams.json was not loaded because it does not exist")
        args.output_gif_dir = args.output_gif_dir or os.path.join(
            args.results_gif_dir, os.path.split(checkpoint_dir)[1])
        args.output_png_dir = args.output_png_dir or os.path.join(
            args.results_png_dir, os.path.split(checkpoint_dir)[1])
    else:
        if not args.dataset:
            raise ValueError('dataset is required when checkpoint is not specified')
        if not args.model:
            raise ValueError('model is required when checkpoint is not specified')
        args.output_gif_dir = args.output_gif_dir or os.path.join(
            args.results_gif_dir, 'model.%s' % args.model)
        args.output_png_dir = args.output_png_dir or os.path.join(
            args.results_png_dir, 'model.%s' % args.model)

    print('----------------------------------- Options ------------------------------------')
    # Public equivalent of the private args._get_kwargs() (sorted (name, value) pairs).
    for k, v in sorted(vars(args).items()):
        print(k, "=", v)
    print('------------------------------------- End --------------------------------------')

    VideoDataset = datasets.get_dataset_class(args.dataset)
    dataset = VideoDataset(args.input_dir,
                           mode=args.mode,
                           num_epochs=args.num_epochs,
                           seed=args.seed,
                           hparams_dict=dataset_hparams_dict,
                           hparams=args.dataset_hparams)

    def override_hparams_dict(dataset):
        # Model hparams must agree with the dataset's clip geometry.
        hparams_dict = dict(model_hparams_dict)
        hparams_dict['context_frames'] = dataset.hparams.context_frames
        hparams_dict['sequence_length'] = dataset.hparams.sequence_length
        hparams_dict['repeat'] = dataset.hparams.time_shift
        return hparams_dict

    VideoPredictionModel = models.get_model_class(args.model)
    model = VideoPredictionModel(mode='test',
                                 hparams_dict=override_hparams_dict(dataset),
                                 hparams=args.model_hparams)

    if args.num_samples:
        if args.num_samples > dataset.num_examples_per_epoch():
            raise ValueError('num_samples cannot be larger than the dataset')
        num_examples_per_epoch = args.num_samples
    else:
        num_examples_per_epoch = dataset.num_examples_per_epoch()
    if num_examples_per_epoch % args.batch_size != 0:
        raise ValueError('batch_size should evenly divide the dataset')

    inputs, target = dataset.make_batch(args.batch_size)
    if not isinstance(model, models.GroundTruthVideoPredictionModel):
        # remove ground truth data past context_frames to prevent accidentally using it
        for k, v in inputs.items():
            if k != 'actions':
                inputs[k] = v[:, :model.hparams.context_frames]

    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k)
                 for k, v in inputs.items()}
    target_ph = tf.placeholder(target.dtype, target.shape, 'targets_ph')

    with tf.variable_scope(''):
        model.build_graph(input_phs, target_ph)

    # Record the exact configuration next to the generated samples.
    for output_dir in (args.output_gif_dir, args.output_png_dir):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        with open(os.path.join(output_dir, "options.json"), "w") as f:
            f.write(json.dumps(vars(args), sort_keys=True, indent=4))
        with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
            f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
        with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
            f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    # Context manager guarantees the session is closed even on error
    # (the original leaked it).
    with tf.Session(config=config) as sess:
        model.restore(sess, args.checkpoint)

        sample_ind = 0
        while True:
            if args.num_samples and sample_ind >= args.num_samples:
                break
            try:
                # Fetch inputs and target in ONE run so the pipeline stays in
                # sync; the target values themselves are not needed here.
                input_results, _ = sess.run([inputs, target])
            except tf.errors.OutOfRangeError:
                break
            print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))
            feed_dict = {input_ph: input_results[name]
                         for name, input_ph in input_phs.items()}
            for stochastic_sample_ind in range(args.num_stochastic_samples):
                gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict)
                for i, gen_images_ in enumerate(gen_images):
                    # Float frames -> uint8 pixel values.
                    gen_images_ = (gen_images_ * 255.0).astype(np.uint8)
                    gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
                    save_gif(os.path.join(args.output_gif_dir, gen_images_fname),
                             gen_images_[:args.gif_length] if args.gif_length else gen_images_,
                             fps=args.fps)
                    for t, gen_image in enumerate(gen_images_):
                        gen_image_fname = 'gen_image_%05d_%02d_%02d.png' % (sample_ind + i, stochastic_sample_ind, t)
                        # Model outputs RGB; OpenCV writes BGR.
                        gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
                        cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image)
            sample_ind += args.batch_size
def __init__(self, num_of_generated_frames, batch_size):
    """Build and restore a SAVP video-prediction model for the CARLA dataset.

    Loads hparams from a fixed checkpoint directory, constructs the dataset
    input pipeline and model graph, creates the output directories, opens a
    TF session and restores the checkpoint weights.

    Args:
        num_of_generated_frames: number of future frames to predict; the
            dataset sequence length is set to 4 context frames plus this.
        batch_size: number of clips per batch fed to the model.
    """
    seed = 7
    self.img_size = 64
    self.batch_size = batch_size
    # Seed all RNGs for reproducible sampling.
    tf.set_random_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    results_gif_dir = "../results_test/carla_50k"
    results_png_dir = "../results_test/carla_50k"
    dataset_hparams_dict = {}
    model_hparams_dict = {}
    # NOTE(review): hard-coded, machine-specific paths — should be supplied
    # via configuration rather than baked into the constructor.
    checkpoint = "/media/eslam/426b7820-cb81-4c46-9430-be5429970ddb/home/eslam/Future_Imitiation/video_prediction-master/logs/carla_intel/ours_savp"  # loading weights
    # Resolve a checkpoint file name to its containing directory.
    checkpoint_dir = os.path.normpath(checkpoint)
    if not os.path.isdir(checkpoint):
        checkpoint_dir, _ = os.path.split(checkpoint_dir)
    if not os.path.exists(checkpoint_dir):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
    # Renamed from `dataset`/`model` so the class-name strings are not
    # shadowed by the dataset/model objects created below.
    dataset_name = "carla_intel"
    model_name = "savp"
    mode = "test"
    num_epochs = 1
    gpu_mem_frac = 0
    self.num_stochastic_samples = 1
    self.fps = 4
    # NOTE(review): presumably parsed by the dataset's HParams.parse —
    # confirm the spaces around '=' are tolerated.
    dataset_hparams = "sequence_length = " + str(4 + num_of_generated_frames)
    # TODO: should be changed to feed the 4 images directly to it.
    input_dir = "/media/eslam/426b7820-cb81-4c46-9430-be5429970ddb/home/eslam/Future_Imitiation/Intel_dataset/tf_record/test"
    try:
        with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
            dataset_hparams_dict = json.loads(f.read())
    except FileNotFoundError:
        print("dataset_hparams.json was not loaded because it does not exist")
    try:
        with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
            model_hparams_dict = json.loads(f.read())
    except FileNotFoundError:
        print("model_hparams.json was not loaded because it does not exist")
    self.output_gif_dir = os.path.join(results_gif_dir, os.path.split(checkpoint_dir)[1])
    self.output_png_dir = os.path.join(results_png_dir, os.path.split(checkpoint_dir)[1])
    VideoDataset = datasets.get_dataset_class(dataset_name)
    dataset = VideoDataset(
        input_dir,
        mode=mode,
        num_epochs=num_epochs,
        seed=seed,
        hparams_dict=dataset_hparams_dict,
        hparams=dataset_hparams)
    VideoPredictionModel = models.get_model_class(model_name)
    # Model hparams must agree with the dataset's clip geometry.
    hparams_dict = dict(model_hparams_dict)
    hparams_dict.update({
        'context_frames': dataset.hparams.context_frames,
        'sequence_length': dataset.hparams.sequence_length,
        'repeat': dataset.hparams.time_shift,
    })
    self.model = VideoPredictionModel(
        mode=mode,
        hparams_dict=hparams_dict,
        hparams=None)
    self.sequence_length = self.model.hparams.sequence_length
    self.gif_length = self.sequence_length
    context_frames = self.model.hparams.context_frames
    self.future_length = self.sequence_length - context_frames
    # Kept for the (currently disabled) divisibility check below.
    num_examples_per_epoch = dataset.num_examples_per_epoch()
    #if num_examples_per_epoch % self.batch_size != 0:
    #    raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)
    inputs = dataset.make_batch(self.batch_size)
    self.input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k)
                      for k, v in inputs.items()}
    with tf.variable_scope(''):
        self.model.build_graph(self.input_phs)
    for output_dir in (self.output_gif_dir, self.output_png_dir):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_mem_frac)
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    self.sess = tf.Session(config=config)
    # Fix: the original called `self.sess.graph.as_default()` bare, which
    # returns a context manager and discards it (a no-op). Enter it
    # explicitly so the restore ops are created in the session's graph.
    with self.sess.graph.as_default():
        self.model.restore(self.sess, checkpoint)
    self.sample_ind = 0