def __init__(self, batch_size, observation_space, action_space,
             policy_hparams, policy_dir, sampling_temp):
  super(PolicyAgent, self).__init__(
      batch_size, observation_space, action_space)
  self._sampling_temp = sampling_temp
  with tf.Graph().as_default():
    self._observations_t = tf.placeholder(
        shape=((batch_size,) + self.observation_space.shape),
        dtype=self.observation_space.dtype)
    (logits, self._values_t) = rl.get_policy(
        self._observations_t, policy_hparams, self.action_space)
    actions = common_layers.sample_with_temperature(logits, sampling_temp)
    self._probs_t = tf.nn.softmax(logits / sampling_temp)
    self._actions_t = tf.cast(actions, tf.int32)
    model_saver = tf.train.Saver(
        tf.global_variables(policy_hparams.policy_network + "/.*")  # pylint: disable=unexpected-keyword-arg
    )
    self._sess = tf.Session()
    self._sess.run(tf.global_variables_initializer())
    trainer_lib.restore_checkpoint(policy_dir, model_saver, self._sess)
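# A minimal usage sketch, not part of the original source: assuming a
# hypothetical `step` method on the same class, the restored policy can be
# queried by feeding a batch of observations into the placeholder that the
# constructor above builds.
def step(self, observations):
  """Hypothetical example: sample actions and value estimates."""
  (actions, values) = self._sess.run(
      [self._actions_t, self._values_t],
      feed_dict={self._observations_t: observations})
  return (actions, values)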
def evaluate(self, env_fn, hparams, stochastic):
  if stochastic:
    policy_to_actions_lambda = lambda policy: policy.sample()
  else:
    policy_to_actions_lambda = lambda policy: policy.mode()

  with tf.Graph().as_default():
    with tf.name_scope("rl_eval"):
      eval_env = env_fn(in_graph=True)
      (collect_memory, _, collect_init) = _define_collect(
          eval_env,
          hparams,
          "ppo_eval",
          eval_phase=True,
          frame_stack_size=self.frame_stack_size,
          force_beginning_resets=False,
          policy_to_actions_lambda=policy_to_actions_lambda)

      model_saver = tf.train.Saver(
          tf.global_variables(".*network_parameters.*"))

      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        collect_init(sess)
        trainer_lib.restore_checkpoint(self.agent_model_dir, model_saver,
                                       sess)
        sess.run(collect_memory)
def _run_train(ppo_hparams, event_dir, model_dir, restarter,
               train_summary_op, eval_summary_op, initializers,
               report_fn=None):
  """Train."""
  summary_writer = tf.summary.FileWriter(
      event_dir, graph=tf.get_default_graph(), flush_secs=60)

  model_saver = tf.train.Saver(
      tf.global_variables(ppo_hparams.policy_network + "/.*") +
      tf.global_variables("training/" + ppo_hparams.policy_network + "/.*") +
      # tf.global_variables("clean_scope.*") +  # Needed for sharing params.
      tf.global_variables("global_step") +
      tf.global_variables("losses_avg.*") +
      tf.global_variables("train_stats.*"))

  global_step = tf.train.get_or_create_global_step()
  with tf.control_dependencies([tf.assign_add(global_step, 1)]):
    train_summary_op = tf.identity(train_summary_op)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for initializer in initializers:
      initializer(sess)
    trainer_lib.restore_checkpoint(model_dir, model_saver, sess)

    num_target_iterations = restarter.target_local_step
    num_completed_iterations = num_target_iterations - restarter.steps_to_go
    with restarter.training_loop():
      for epoch_index in range(num_completed_iterations,
                               num_target_iterations):
        summary = sess.run(train_summary_op)
        if summary_writer:
          summary_writer.add_summary(summary, epoch_index)

        if (ppo_hparams.eval_every_epochs and
            epoch_index % ppo_hparams.eval_every_epochs == 0):
          eval_summary = sess.run(eval_summary_op)
          if summary_writer:
            summary_writer.add_summary(eval_summary, epoch_index)
          if report_fn:
            summary_proto = tf.Summary()
            summary_proto.ParseFromString(eval_summary)
            for elem in summary_proto.value:
              if "mean_score" in elem.tag:
                report_fn(elem.simple_value, epoch_index)
                break

        if (model_saver and ppo_hparams.save_models_every_epochs and
            (epoch_index % ppo_hparams.save_models_every_epochs == 0 or
             (epoch_index + 1) == num_target_iterations)):
          ckpt_path = os.path.join(
              model_dir,
              "model.ckpt-{}".format(tf.train.global_step(sess, global_step)))
          model_saver.save(sess, ckpt_path)
def initialize(self, sess):
  model_loader = tf.train.Saver(
      var_list=tf.global_variables(scope="next_frame*")  # pylint:disable=unexpected-keyword-arg
  )
  trainer_lib.restore_checkpoint(
      self._model_dir, saver=model_loader, sess=sess, must_restore=True)
def evaluate(self, env_fn, hparams, sampling_temp):
  with tf.Graph().as_default():
    with tf.name_scope("rl_eval"):
      eval_env = env_fn(in_graph=True)
      (collect_memory, _, collect_init) = _define_collect(
          eval_env,
          hparams,
          "ppo_eval",
          eval_phase=True,
          frame_stack_size=self.frame_stack_size,
          force_beginning_resets=False,
          sampling_temp=sampling_temp,
          distributional_size=self._distributional_size,
      )

      model_saver = tf.train.Saver(
          tf.global_variables(hparams.policy_network + "/.*")
          # tf.global_variables("clean_scope.*")  # Needed for sharing params.
      )

      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        collect_init(sess)
        trainer_lib.restore_checkpoint(self.agent_model_dir, model_saver,
                                       sess)
        sess.run(collect_memory)
def encode_dataset(model, dataset, problem, ae_hparams, autoencoder_path,
                   out_files):
  """Encode all frames in dataset with model and write them out to out_files."""
  batch_size = 8
  dataset = dataset.batch(batch_size)
  examples = dataset.make_one_shot_iterator().get_next()
  images = examples.pop("frame")
  images = tf.cast(images, tf.int32)

  encoded = model.encode(images)
  encoded_frame_height = int(
      math.ceil(problem.frame_height / 2**ae_hparams.num_hidden_layers))
  encoded_frame_width = int(
      math.ceil(problem.frame_width / 2**ae_hparams.num_hidden_layers))
  num_bits = 8
  encoded = tf.reshape(
      encoded, [-1, encoded_frame_height, encoded_frame_width, 3, num_bits])
  encoded = tf.cast(discretization.bit_to_int(encoded, num_bits), tf.uint8)

  pngs = tf.map_fn(tf.image.encode_png, encoded, dtype=tf.string,
                   back_prop=False)

  with tf.Session() as sess:
    autoencoder_saver = tf.train.Saver(tf.global_variables("autoencoder.*"))
    trainer_lib.restore_checkpoint(autoencoder_path, autoencoder_saver, sess,
                                   must_restore=True)

    def generator():
      """Generate examples."""
      while True:
        try:
          pngs_np, examples_np = sess.run([pngs, examples])
          rewards = examples_np["reward"].tolist()
          actions = examples_np["action"].tolist()
          frame_numbers = examples_np["frame_number"].tolist()
          for action, reward, frame_number, png in zip(
              actions, rewards, frame_numbers, pngs_np):
            yield {
                "action": action,
                "reward": reward,
                "frame_number": frame_number,
                "image/encoded": [png],
                "image/format": ["png"],
                "image/height": [encoded_frame_height],
                "image/width": [encoded_frame_width],
            }
        except tf.errors.OutOfRangeError:
          break

    generator_utils.generate_files(
        generator(), out_files,
        cycle_every_n=problem.total_number_of_frames // 10)
def train(hparams, event_dir=None, model_dir=None,
          restore_agent=True, epoch=0):
  """Train."""
  with tf.name_scope("rl_train"):
    train_summary_op, _, initialization = define_train(hparams, event_dir)
  if event_dir:
    summary_writer = tf.summary.FileWriter(
        event_dir, graph=tf.get_default_graph(), flush_secs=60)
    if model_dir:
      model_saver = tf.train.Saver(
          tf.global_variables(".*network_parameters.*"))
  else:
    summary_writer = None
    model_saver = None

  # TODO(piotrmilos): This should be refactored, possibly with
  # handlers for each type of env
  if hparams.environment_spec.simulated_env:
    env_model_loader = tf.train.Saver(
        tf.global_variables("next_frame*"))
  else:
    env_model_loader = None

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    initialization(sess)
    if env_model_loader:
      trainer_lib.restore_checkpoint(
          hparams.world_model_dir, env_model_loader, sess, must_restore=True)
    start_step = 0
    if model_saver and restore_agent:
      start_step = trainer_lib.restore_checkpoint(model_dir, model_saver, sess)

    # Fail-friendly, don't train if already trained for this epoch
    if start_step >= (hparams.epochs_num * (epoch + 1)):
      tf.logging.info("Skipping PPO training for epoch %d as train steps "
                      "(%d) already reached", epoch, start_step)
      return

    for epoch_index in range(hparams.epochs_num):
      summary = sess.run(train_summary_op)
      if summary_writer:
        summary_writer.add_summary(summary, epoch_index)

      if (hparams.eval_every_epochs and
          epoch_index % hparams.eval_every_epochs == 0):
        if summary_writer and summary:
          summary_writer.add_summary(summary, epoch_index)
        else:
          tf.logging.info("Eval summary not saved")

      if (model_saver and hparams.save_models_every_epochs and
          (epoch_index % hparams.save_models_every_epochs == 0 or
           (epoch_index + 1) == hparams.epochs_num)):
        ckpt_path = os.path.join(
            model_dir, "model.ckpt-{}".format(epoch_index + 1 + start_step))
        model_saver.save(sess, ckpt_path)
def train(hparams, environment_spec, event_dir=None, model_dir=None,
          restore_agent=True):
  """Train."""
  with tf.name_scope("rl_train"):
    train_summary_op, eval_summary_op = define_train(
        hparams, environment_spec, event_dir)
  if event_dir:
    summary_writer = tf.summary.FileWriter(
        event_dir, graph=tf.get_default_graph(), flush_secs=60)
    if model_dir:
      model_saver = tf.train.Saver(
          tf.global_variables(".*network_parameters.*"))
  else:
    summary_writer = None
    model_saver = None

  if hparams.simulated_environment:
    env_model_loader = tf.train.Saver(
        tf.global_variables("basic_conv_gen.*"))
  else:
    env_model_loader = None

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if env_model_loader:
      trainer_lib.restore_checkpoint(
          hparams.world_model_dir, env_model_loader, sess, must_restore=True)
    start_step = 0
    if model_saver and restore_agent:
      start_step = trainer_lib.restore_checkpoint(model_dir, model_saver, sess)

    for epoch_index in range(hparams.epochs_num):
      summary = sess.run(train_summary_op)
      if summary_writer:
        summary_writer.add_summary(summary, epoch_index)

      if (hparams.eval_every_epochs and
          epoch_index % hparams.eval_every_epochs == 0):
        summary = sess.run(eval_summary_op)
        if summary_writer and summary:
          summary_writer.add_summary(summary, epoch_index)
        else:
          tf.logging.info("Eval summary not saved")

      if (model_saver and hparams.save_models_every_epochs and
          (epoch_index % hparams.save_models_every_epochs == 0 or
           (epoch_index + 1) == hparams.epochs_num)):
        ckpt_path = os.path.join(
            model_dir, "model.ckpt-{}".format(epoch_index + start_step))
        model_saver.save(sess, ckpt_path)
def initialize(self, sess):
  model_loader = tf.train.Saver(
      var_list=tf.global_variables(scope="next_frame*")  # pylint:disable=unexpected-keyword-arg
  )
  # TODO(afrozm): use TF methods to be on the safe side here.
  if os.path.isdir(self._model_dir):
    trainer_lib.restore_checkpoint(
        self._model_dir, saver=model_loader, sess=sess, must_restore=True)
  else:
    model_loader.restore(sess=sess, save_path=self._model_dir)
def encode_dataset(model, dataset, problem, ae_hparams, autoencoder_path,
                   out_files):
  """Encode all frames in dataset with model and write them out to out_files."""
  batch_size = 8
  dataset = dataset.batch(batch_size)
  examples = dataset.make_one_shot_iterator().get_next()
  images = examples.pop("frame")
  images = tf.expand_dims(images, 1)

  encoded = model.encode(images)
  encoded_frame_height = int(
      math.ceil(problem.frame_height / 2**ae_hparams.num_hidden_layers))
  encoded_frame_width = int(
      math.ceil(problem.frame_width / 2**ae_hparams.num_hidden_layers))
  num_bits = 8
  encoded = tf.reshape(
      encoded, [-1, encoded_frame_height, encoded_frame_width, 3, num_bits])
  encoded = tf.cast(discretization.bit_to_int(encoded, num_bits), tf.uint8)

  pngs = tf.map_fn(tf.image.encode_png, encoded, dtype=tf.string,
                   back_prop=False)

  with tf.Session() as sess:
    autoencoder_saver = tf.train.Saver(tf.global_variables("autoencoder.*"))
    trainer_lib.restore_checkpoint(autoencoder_path, autoencoder_saver, sess,
                                   must_restore=True)

    def generator():
      """Generate examples."""
      while True:
        try:
          pngs_np, examples_np = sess.run([pngs, examples])
          rewards_np = [list(el) for el in examples_np["reward"]]
          actions_np = [list(el) for el in examples_np["action"]]
          pngs_np = [el for el in pngs_np]
          for action, reward, png in zip(actions_np, rewards_np, pngs_np):
            yield {
                "action": action,
                "reward": reward,
                "image/encoded": [png],
                "image/format": ["png"],
                "image/height": [encoded_frame_height],
                "image/width": [encoded_frame_width],
            }
        except tf.errors.OutOfRangeError:
          break

    generator_utils.generate_files(
        generator(), out_files,
        cycle_every_n=problem.total_number_of_frames // 10)
def evaluate(hparams, model_dir, name_scope="rl_eval"):
  """Evaluate."""
  hparams = copy.copy(hparams)
  hparams.add_hparam("eval_phase", True)
  with tf.Graph().as_default():
    with tf.name_scope(name_scope):
      (collect_memory, _, collect_init) = collect.define_collect(
          hparams, "ppo_eval")
      model_saver = tf.train.Saver(
          tf.global_variables(".*network_parameters.*"))

      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        collect_init(sess)
        trainer_lib.restore_checkpoint(model_dir, model_saver, sess)
        sess.run(collect_memory)
def _run_train(ppo_hparams, event_dir, model_dir, num_target_iterations,
               train_summary_op, eval_summary_op, initializers,
               report_fn=None):
  """Train."""
  summary_writer = tf.summary.FileWriter(
      event_dir, graph=tf.get_default_graph(), flush_secs=60)
  model_saver = tf.train.Saver(tf.global_variables(".*network_parameters.*"))

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for initializer in initializers:
      initializer(sess)
    num_completed_iterations = trainer_lib.restore_checkpoint(
        model_dir, model_saver, sess)

    # Fail-friendly, complete only unfinished epoch
    num_iterations_to_go = num_target_iterations - num_completed_iterations
    if num_iterations_to_go <= 0:
      tf.logging.info(
          "Skipping PPO training. Requested %d iterations while %d train "
          "iterations already reached", num_target_iterations,
          num_completed_iterations)
      return

    for epoch_index in range(num_iterations_to_go):
      summary = sess.run(train_summary_op)
      if summary_writer:
        summary_writer.add_summary(summary, epoch_index)

      if (ppo_hparams.eval_every_epochs and
          epoch_index % ppo_hparams.eval_every_epochs == 0):
        eval_summary = sess.run(eval_summary_op)
        if summary_writer:
          summary_writer.add_summary(eval_summary, epoch_index)
        if report_fn:
          summary_proto = tf.Summary()
          summary_proto.ParseFromString(eval_summary)
          for elem in summary_proto.value:
            if "mean_score" in elem.tag:
              report_fn(elem.simple_value, epoch_index)
              break

      epoch_index_and_start = epoch_index + num_completed_iterations
      if (model_saver and ppo_hparams.save_models_every_epochs and
          (epoch_index_and_start % ppo_hparams.save_models_every_epochs == 0 or
           (epoch_index + 1) == num_iterations_to_go)):
        ckpt_path = os.path.join(
            model_dir,
            "model.ckpt-{}".format(epoch_index + 1 + num_completed_iterations))
        model_saver.save(sess, ckpt_path)
def __init__(self, environment_spec, batch_size, model_dir=None, sess=None):
  self.batch_size = batch_size

  with tf.Graph().as_default():
    self._batch_env = SimulatedBatchEnv(environment_spec, self.batch_size)

    self.action_space = self._batch_env.action_space
    # TODO(kc): check for the stack wrapper and correct number of channels in
    # observation_space
    self.observation_space = self._batch_env.observ_space
    self._sess = sess if sess is not None else tf.Session()
    self._to_initialize = [self._batch_env]
    environment_wrappers = environment_spec.wrappers
    wrappers = copy.copy(environment_wrappers) if environment_wrappers else []

    for w in wrappers:
      self._batch_env = w[0](self._batch_env, **w[1])
      self._to_initialize.append(self._batch_env)

    self._sess.run(tf.global_variables_initializer())
    for wrapped_env in self._to_initialize:
      wrapped_env.initialize(self._sess)

    self._actions_t = tf.placeholder(shape=(batch_size,), dtype=tf.int32)
    self._rewards_t, self._dones_t = self._batch_env.simulate(self._actions_t)
    self._obs_t = self._batch_env.observ
    self._reset_op = self._batch_env.reset(
        tf.range(batch_size, dtype=tf.int32))

    env_model_loader = tf.train.Saver(
        var_list=tf.global_variables(scope="next_frame*"))  # pylint:disable=unexpected-keyword-arg
    trainer_lib.restore_checkpoint(
        model_dir, saver=env_model_loader, sess=self._sess, must_restore=True)
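# A minimal usage sketch, not part of the original source: assuming
# hypothetical `reset` and `step` methods on the same class, the restored
# simulated batch environment can be driven by running the ops defined in the
# constructor above.
def reset(self):
  """Hypothetical example: run the reset op for all environments."""
  return self._sess.run(self._reset_op)

def step(self, actions):
  """Hypothetical example: simulate one step for a batch of actions."""
  rewards, dones = self._sess.run(
      [self._rewards_t, self._dones_t],
      feed_dict={self._actions_t: actions})
  observations = self._sess.run(self._obs_t)
  return observations, rewards, dones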
def __init__(self, hparams, action_space, observation_space, policy_dir):
  assert hparams.base_algo == "ppo"
  ppo_hparams = trainer_lib.create_hparams(hparams.base_algo_params)

  frame_stack_shape = (1, hparams.frame_stack_size) + observation_space.shape
  self._frame_stack = np.zeros(frame_stack_shape, dtype=np.uint8)

  with tf.Graph().as_default():
    self.obs_t = tf.placeholder(shape=frame_stack_shape, dtype=np.uint8)
    self.logits_t, self.value_function_t = get_policy(
        self.obs_t, ppo_hparams, action_space)
    model_saver = tf.train.Saver(
        tf.global_variables(scope=ppo_hparams.policy_network + "/.*")  # pylint: disable=unexpected-keyword-arg
    )
    self.sess = tf.Session()
    self.sess.run(tf.global_variables_initializer())
    trainer_lib.restore_checkpoint(policy_dir, model_saver, self.sess)
def train(hparams, event_dir=None, model_dir=None,
          restore_agent=True, name_scope="rl_train", report_fn=None):
  """Train."""
  with tf.Graph().as_default():
    with tf.name_scope(name_scope):
      train_summary_op, eval_summary_op, initializers = define_train(hparams)
      if event_dir:
        summary_writer = tf.summary.FileWriter(
            event_dir, graph=tf.get_default_graph(), flush_secs=60)
      else:
        summary_writer = None

      if model_dir:
        model_saver = tf.train.Saver(
            tf.global_variables(".*network_parameters.*"))
      else:
        model_saver = None

      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for initializer in initializers:
          initializer(sess)
        start_step = 0
        if model_saver and restore_agent:
          start_step = trainer_lib.restore_checkpoint(
              model_dir, model_saver, sess)

        # Fail-friendly, complete only unfinished epoch
        steps_to_go = hparams.epochs_num - start_step
        if steps_to_go <= 0:
          tf.logging.info(
              "Skipping PPO training. Requested %d steps while "
              "%d train steps already reached", hparams.epochs_num, start_step)
          return

        for epoch_index in range(steps_to_go):
          summary = sess.run(train_summary_op)
          if summary_writer:
            summary_writer.add_summary(summary, epoch_index)

          if (hparams.eval_every_epochs and
              epoch_index % hparams.eval_every_epochs == 0):
            eval_summary = sess.run(eval_summary_op)
            if summary_writer:
              summary_writer.add_summary(eval_summary, epoch_index)
            if report_fn:
              summary_proto = tf.Summary()
              summary_proto.ParseFromString(eval_summary)
              for elem in summary_proto.value:
                if "mean_score" in elem.tag:
                  report_fn(elem.simple_value, epoch_index)
                  break

          epoch_index_and_start = epoch_index + start_step
          if (model_saver and hparams.save_models_every_epochs and
              (epoch_index_and_start % hparams.save_models_every_epochs == 0
               or (epoch_index + 1) == steps_to_go)):
            ckpt_path = os.path.join(
                model_dir,
                "model.ckpt-{}".format(epoch_index + 1 + start_step))
            model_saver.save(sess, ckpt_path)
def main(_):
  hparams = registry.hparams(FLAGS.loop_hparams_set)
  hparams.parse(FLAGS.loop_hparams)
  output_dir = FLAGS.output_dir

  subdirectories = ["data", "tmp", "world_model", "ppo"]
  using_autoencoder = hparams.autoencoder_train_steps > 0
  if using_autoencoder:
    subdirectories.append("autoencoder")
  directories = setup_directories(output_dir, subdirectories)

  if hparams.game in gym_problems_specs.ATARI_GAMES:
    game_with_mode = hparams.game + "_deterministic-v4"
  else:
    game_with_mode = hparams.game

  if using_autoencoder:
    simulated_problem_name = (
        "gym_simulated_discrete_problem_with_agent_on_%s_autoencoded"
        % game_with_mode)
  else:
    simulated_problem_name = (
        "gym_simulated_discrete_problem_with_agent_on_%s" % game_with_mode)

  if simulated_problem_name not in registry.list_problems():
    tf.logging.info("Game Problem %s not found; dynamically registering",
                    simulated_problem_name)
    gym_problems_specs.create_problems_for_game(
        hparams.game, game_mode="Deterministic-v4")

  epoch = hparams.epochs - 1
  epoch_data_dir = os.path.join(directories["data"], str(epoch))
  ppo_model_dir = directories["ppo"]
  world_model_dir = directories["world_model"]

  gym_problem = registry.problem(simulated_problem_name)

  model_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
  environment_spec = copy.copy(gym_problem.environment_spec)
  environment_spec.simulation_random_starts = hparams.simulation_random_starts

  batch_env_hparams = trainer_lib.create_hparams(hparams.ppo_params)
  batch_env_hparams.add_hparam("model_hparams", model_hparams)
  batch_env_hparams.add_hparam("environment_spec", environment_spec)
  batch_env_hparams.num_agents = 1

  with temporary_flags({
      "problem": simulated_problem_name,
      "model": hparams.generative_model,
      "hparams_set": hparams.generative_model_params,
      "output_dir": world_model_dir,
      "data_dir": epoch_data_dir,
  }):
    sess = tf.Session()
    env = DebugBatchEnv(batch_env_hparams, sess)
    sess.run(tf.global_variables_initializer())
    env.initialize()

    env_model_loader = tf.train.Saver(tf.global_variables("next_frame*"))
    trainer_lib.restore_checkpoint(
        world_model_dir, env_model_loader, sess, must_restore=True)

    model_saver = tf.train.Saver(
        tf.global_variables(".*network_parameters.*"))
    trainer_lib.restore_checkpoint(ppo_model_dir, model_saver, sess)

    key_mapping = gym_problem.env.env.get_keys_to_action()
    # map special codes
    key_mapping[()] = 100
    key_mapping[(ord("r"),)] = 101
    key_mapping[(ord("p"),)] = 102

    play.play(env, zoom=2, fps=10, keys_to_action=key_mapping)
def train(hparams, event_dir=None, model_dir=None,
          restore_agent=True, name_scope="rl_train"):
  """Train."""
  with tf.Graph().as_default():
    with tf.name_scope(name_scope):
      train_summary_op, _, initialization = define_train(hparams)
      if event_dir:
        summary_writer = tf.summary.FileWriter(
            event_dir, graph=tf.get_default_graph(), flush_secs=60)
      else:
        summary_writer = None

      if model_dir:
        model_saver = tf.train.Saver(
            tf.global_variables(".*network_parameters.*"))
      else:
        model_saver = None

      # TODO(piotrmilos): This should be refactored, possibly with
      # handlers for each type of env
      if hparams.environment_spec.simulated_env:
        env_model_loader = tf.train.Saver(
            tf.global_variables("next_frame*"))
      else:
        env_model_loader = None

      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        initialization(sess)
        if env_model_loader:
          trainer_lib.restore_checkpoint(
              hparams.world_model_dir, env_model_loader, sess,
              must_restore=True)
        start_step = 0
        if model_saver and restore_agent:
          start_step = trainer_lib.restore_checkpoint(
              model_dir, model_saver, sess)

        # Fail-friendly, complete only unfinished epoch
        steps_to_go = hparams.epochs_num - start_step
        if steps_to_go <= 0:
          tf.logging.info(
              "Skipping PPO training. Requested %d steps while "
              "%d train steps already reached", hparams.epochs_num, start_step)
          return

        for epoch_index in range(steps_to_go):
          summary = sess.run(train_summary_op)
          if summary_writer:
            summary_writer.add_summary(summary, epoch_index)

          if (hparams.eval_every_epochs and
              epoch_index % hparams.eval_every_epochs == 0):
            if summary_writer and summary:
              summary_writer.add_summary(summary, epoch_index)
            else:
              tf.logging.info("Eval summary not saved")

          epoch_index_and_start = epoch_index + start_step
          if (model_saver and hparams.save_models_every_epochs and
              (epoch_index_and_start % hparams.save_models_every_epochs == 0
               or (epoch_index + 1) == steps_to_go)):
            ckpt_path = os.path.join(
                model_dir,
                "model.ckpt-{}".format(epoch_index + 1 + start_step))
            model_saver.save(sess, ckpt_path)
def main(_):
  hparams = registry.hparams(FLAGS.loop_hparams_set)
  hparams.parse(FLAGS.loop_hparams)
  output_dir = FLAGS.output_dir

  subdirectories = ["data", "tmp", "world_model", "ppo"]
  using_autoencoder = hparams.autoencoder_train_steps > 0
  if using_autoencoder:
    subdirectories.append("autoencoder")
  directories = setup_directories(output_dir, subdirectories)

  if hparams.game in gym_env.ATARI_GAMES:
    game_with_mode = hparams.game + "_deterministic-v4"
  else:
    game_with_mode = hparams.game

  if using_autoencoder:
    simulated_problem_name = (
        "gym_simulated_discrete_problem_with_agent_on_%s_autoencoded"
        % game_with_mode)
  else:
    simulated_problem_name = (
        "gym_simulated_discrete_problem_with_agent_on_%s" % game_with_mode)

  if simulated_problem_name not in registry.list_problems():
    tf.logging.info("Game Problem %s not found; dynamically registering",
                    simulated_problem_name)
    gym_env.register_game(hparams.game, game_mode="Deterministic-v4")

  epoch = hparams.epochs - 1
  epoch_data_dir = os.path.join(directories["data"], str(epoch))
  ppo_model_dir = directories["ppo"]
  world_model_dir = directories["world_model"]

  gym_problem = registry.problem(simulated_problem_name)

  model_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
  environment_spec = copy.copy(gym_problem.environment_spec)
  environment_spec.simulation_random_starts = hparams.simulation_random_starts

  batch_env_hparams = trainer_lib.create_hparams(hparams.ppo_params)
  batch_env_hparams.add_hparam("model_hparams", model_hparams)
  batch_env_hparams.add_hparam("environment_spec", environment_spec)
  batch_env_hparams.num_agents = 1

  with temporary_flags({
      "problem": simulated_problem_name,
      "model": hparams.generative_model,
      "hparams_set": hparams.generative_model_params,
      "output_dir": world_model_dir,
      "data_dir": epoch_data_dir,
  }):
    sess = tf.Session()
    env = DebugBatchEnv(batch_env_hparams, sess)
    sess.run(tf.global_variables_initializer())
    env.initialize()

    env_model_loader = tf.train.Saver(tf.global_variables("next_frame*"))
    trainer_lib.restore_checkpoint(
        world_model_dir, env_model_loader, sess, must_restore=True)

    model_saver = tf.train.Saver(
        tf.global_variables(".*network_parameters.*"))
    trainer_lib.restore_checkpoint(ppo_model_dir, model_saver, sess)

    key_mapping = gym_problem.env.env.get_keys_to_action()
    # map special codes
    key_mapping[()] = 100
    key_mapping[(ord("r"),)] = 101
    key_mapping[(ord("p"),)] = 102

    play.play(env, zoom=2, fps=10, keys_to_action=key_mapping)