def _initialize_graph(self, sess):
  """Initialize the graph for sess.

  Restores or initializes the train and replay-buffer checkpointers, runs the
  remaining initializers (uninitialized variables, dataset iterator, agent),
  builds the train-step callable, creates the timers, saves one checkpoint
  per checkpointer and initializes the contrib summary writers.

  Args:
    sess: Session in which every initialization op is run.
  """
  for restorable in (self._train_checkpointer, self._rb_checkpointer):
    restorable.initialize_or_restore(sess)

  # TODO(sguada) Remove once Periodically can be saved.
  common_utils.initialize_uninitialized_variables(sess)

  sess.run(self._ds_itr.initializer)
  sess.run(self._init_agent_op)

  train_fetches = [self._train_op, self._summary_op]
  self._train_step_call = sess.make_callable(train_fetches)

  self._collect_timer = timer.Timer()
  self._train_timer = timer.Timer()
  self._action_timer = timer.Timer()
  self._step_timer = timer.Timer()
  self._observer_timer = timer.Timer()

  start_step = sess.run(self._global_step)
  self._timed_at_step = start_step

  # Call save to initialize the save_counter (need to do this before
  # finalizing the graph).
  all_checkpointers = (
      self._train_checkpointer,
      self._policy_checkpointer,
      self._rb_checkpointer,
  )
  for checkpointer in all_checkpointers:
    checkpointer.save(global_step=start_step)

  default_graph = tf.get_default_graph()
  tf.contrib.summary.initialize(session=sess, graph=default_graph)
def _initialize_graph(self, sess):
  """Initialize the graph for sess.

  Restores or initializes the train and replay-buffer checkpointers, runs the
  remaining variable and agent initializers, builds the train-step callable,
  creates the timers, saves one checkpoint per checkpointer and initializes
  the summary writer(s).

  Args:
    sess: Session in which every initialization op is run.
  """
  for restorable in (self._train_checkpointer, self._rb_checkpointer):
    restorable.initialize_or_restore(sess)

  common.initialize_uninitialized_variables(sess)
  sess.run(self._init_agent_op)

  self._train_step_call = sess.make_callable(self._train_op)

  self._collect_timer = timer.Timer()
  self._train_timer = timer.Timer()
  self._action_timer = timer.Timer()
  self._step_timer = timer.Timer()
  self._observer_timer = timer.Timer()

  start_step = sess.run(self._global_step)
  self._timed_at_step = start_step

  # Call save to initialize the save_counter (need to do this before
  # finalizing the graph).
  all_checkpointers = (
      self._train_checkpointer,
      self._policy_checkpointer,
      self._rb_checkpointer,
  )
  for checkpointer in all_checkpointers:
    checkpointer.save(global_step=start_step)

  # The eval writer is only touched when evaluation is enabled.
  sess.run(self._train_summary_writer.init())
  if not self._do_eval:
    return
  sess.run(self._eval_summary_writer.init())
def testObjectiveDependentLosses(self):
  """Checks the agent loss when objectives use different loss functions.

  The second and third objectives are switched to sigmoid cross-entropy and
  absolute difference respectively before the agent is built.
  """
  networks_and_loss_fns = self._create_objective_network_and_loss_fn_sequence()
  networks_and_loss_fns[1] = (networks_and_loss_fns[1][0],
                              tf.compat.v1.losses.sigmoid_cross_entropy)
  networks_and_loss_fns[2] = (networks_and_loss_fns[2][0],
                              tf.compat.v1.losses.absolute_difference)
  agent = greedy_multi_objective_agent.GreedyMultiObjectiveNeuralAgent(
      self._time_step_spec,
      self._action_spec,
      self._scalarizer,
      objective_network_and_loss_fn_sequence=networks_and_loss_fns,
      optimizer=None)
  observations = np.array([[0.1, 0.2], [1, 0.5]], dtype=np.float32)
  actions = np.array([0, 1], dtype=np.int32)
  objectives = np.array([[0.2, 1, 1.5], [4, 0, 5.5]], dtype=np.float32)
  initial_step, final_step = _get_initial_and_final_steps(
      observations, objectives)
  action_step = _get_action_step(actions)
  experience = _get_experience(initial_step, action_step, final_step)
  init_op = agent.initialize()
  if not tf.executing_eagerly():
    with self.cached_session() as sess:
      common.initialize_uninitialized_variables(sess)
      self.assertIsNone(sess.run(init_op))
  loss, _ = agent._loss(experience)
  # `initialize_all_variables` is deprecated; this is its documented
  # replacement and builds the same initializer.
  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.assertAllClose(self.evaluate(loss), 2.410641)
def testComputeLossWithArmFeatures(self):
  """Checks that a NeuralConstraint on per-arm features has a positive loss."""
  per_arm_obs_spec = bandit_spec_utils.create_per_arm_observation_spec(
      global_dim=2, per_arm_dim=3, num_actions=3)
  per_arm_time_step_spec = ts.time_step_spec(per_arm_obs_spec)
  build_tower = (
      global_and_arm_feature_network.create_feed_forward_common_tower_network)
  tower_net = build_tower(
      per_arm_obs_spec,
      global_layers=(4,),
      arm_layers=(4,),
      common_layers=(4,))
  constraint = constraints.NeuralConstraint(
      per_arm_time_step_spec,
      self._action_spec,
      constraint_network=tower_net)

  # Ops are created in the same order as before: global features first,
  # then range -> reshape -> cast for the per-arm features.
  global_features = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
  arm_indices = tf.reshape(tf.range(18), shape=[2, 3, 3])
  arm_features = tf.cast(arm_indices, dtype=tf.float32)
  observations = {
      bandit_spec_utils.GLOBAL_FEATURE_KEY: global_features,
      bandit_spec_utils.PER_ARM_FEATURE_KEY: arm_features,
  }
  actions = tf.constant([0, 1], dtype=tf.int32)
  rewards = tf.constant([0.5, 3.0], dtype=tf.float32)

  initialize_op = constraint.initialize()
  if not tf.executing_eagerly():
    with self.cached_session() as session:
      common.initialize_uninitialized_variables(session)
      run_result = session.run(initialize_op)
      self.assertIsNone(run_result)

  loss = constraint.compute_loss(observations, actions, rewards)
  loss_value = self.evaluate(loss)
  self.assertGreater(loss_value, 0.0)
def load_pol_ckpt(train_eval_dir, sess, eval_global_step, meld_agent,
                  global_step, eval_second_pol):
  """Restores a policy checkpoint into `sess` and syncs the global step.

  Args:
    train_eval_dir: Directory whose `train` subdirectory holds the policy
      checkpoint directories.
    sess: Session used to restore and initialize variables.
    eval_global_step: Step forwarded to `set_loading_step` to select the
      checkpoint to load.
    meld_agent: Agent whose `policy` is restored.
    global_step: Global step variable handed to the checkpointer and to
      `set_global_step`.
    eval_second_pol: If truthy, load from `policy2` instead of `policy`.

  Returns:
    A tuple `(actual_loaded_step, policy_status)`: the value returned by the
    first `set_loading_step` call and the value returned by
    `initialize_or_restore`.
  """
  train_root = os.path.join(train_eval_dir, 'train')
  policy_dirname = 'policy2' if eval_second_pol else 'policy'

  loaded_step = set_loading_step(train_root, eval_global_step, policy_dirname)

  # keep many policy checkpoints, in case of future eval
  checkpointer = common.Checkpointer(
      ckpt_dir=os.path.join(train_root, policy_dirname),
      policy=meld_agent.policy,
      global_step=global_step,
      max_to_keep=99999999999)
  restore_status = checkpointer.initialize_or_restore(sess)

  # Initialize variables
  common.initialize_uninitialized_variables(sess)
  set_global_step(global_step, sess, loaded_step)

  # make the checkpoint file pointing back to the latest checkpoint
  set_loading_step(train_root, step=None)

  return loaded_step, restore_status
def testInitializeAgent(self):
  """Checks that running the agent's initialize op in graph mode gives None."""
  agent = bern_ts_agent.BernoulliThompsonSamplingAgent(
      self._time_step_spec, self._action_spec)
  initialize_op = agent.initialize()
  if tf.executing_eagerly():
    # There is no session to run the op through in eager mode.
    return
  with self.cached_session() as session:
    common.initialize_uninitialized_variables(session)
    run_result = session.run(initialize_op)
    self.assertIsNone(run_result)
def testInitializeConstraint(self):
  """Checks that the constraint's initialize op runs to None in graph mode."""
  net = DummyNet(self._observation_spec, self._action_spec)
  constraint = constraints.NeuralConstraint(
      self._time_step_spec, self._action_spec, constraint_network=net)
  initialize_op = constraint.initialize()
  if tf.executing_eagerly():
    # There is no session to run the op through in eager mode.
    return
  with self.cached_session() as session:
    common.initialize_uninitialized_variables(session)
    run_result = session.run(initialize_op)
    self.assertIsNone(run_result)
def testInitializeAgent(self, agent_class):
  """Checks that `agent_class` builds and its initialize op runs to None.

  Args:
    agent_class: Agent constructor under test; called with the time-step
      spec, action spec, a `DummyNet` Q-network and no optimizer.
  """
  dummy_q_net = DummyNet(self._observation_spec, self._action_spec)
  agent = agent_class(
      self._time_step_spec,
      self._action_spec,
      q_network=dummy_q_net,
      optimizer=None)
  initialize_op = agent.initialize()
  if tf.executing_eagerly():
    # There is no session to run the op through in eager mode.
    return
  with self.cached_session() as session:
    common.initialize_uninitialized_variables(session)
    run_result = session.run(initialize_op)
    self.assertIsNone(run_result)
def initialize(self, batch_size, graph=None):
  """Builds the policy for `batch_size` and initializes its variables.

  Args:
    batch_size: Batch size forwarded to `self._construct`.
    graph: Graph to build in; when falsy, the default graph is used.

  Raises:
    RuntimeError: If this policy has already been initialized.
  """
  if self._built:
    raise RuntimeError('PyTFPolicy can only be initialized once.')

  target_graph = graph or tf.compat.v1.get_default_graph()
  self._construct(batch_size, target_graph)

  policy_variables = tf.nest.flatten(self._tf_policy.variables())
  common.initialize_uninitialized_variables(self.session, policy_variables)
  self._built = True
def testInitializeAgent(self):
  """Checks that the multi-objective agent's initialize op runs to None."""
  objective_nets = self._create_objective_networks()
  agent = greedy_multi_objective_agent.GreedyMultiObjectiveNeuralAgent(
      self._time_step_spec,
      self._action_spec,
      self._scalarizer,
      objective_networks=objective_nets,
      optimizer=None)
  initialize_op = agent.initialize()
  if tf.executing_eagerly():
    # There is no session to run the op through in eager mode.
    return
  with self.cached_session() as session:
    common.initialize_uninitialized_variables(session)
    run_result = session.run(initialize_op)
    self.assertIsNone(run_result)
def testInitializeAgent(self):
  """Checks that the greedy reward agent's initialize op runs to None."""
  dummy_reward_net = DummyNet(self._observation_spec, self._action_spec)
  agent = greedy_agent.GreedyRewardPredictionAgent(
      self._time_step_spec,
      self._action_spec,
      reward_network=dummy_reward_net,
      optimizer=None)
  initialize_op = agent.initialize()
  if tf.executing_eagerly():
    # There is no session to run the op through in eager mode.
    return
  with self.cached_session() as session:
    common.initialize_uninitialized_variables(session)
    run_result = session.run(initialize_op)
    self.assertIsNone(run_result)
def collect(tf_env,
            tf_policy,
            output_dir,
            checkpoint=None,
            num_iterations=500000,
            episodes_per_file=500,
            summary_interval=1000):
  """Runs `tf_policy` in `tf_env` and writes the trajectories to TFRecords.

  Each iteration runs a one-step `DynamicStepDriver` whose only observer adds
  the step to a `TFRecordReplayBuffer` rooted at `output_dir`.

  Args:
    tf_env: Environment the driver steps through.
    tf_policy: Policy used for collection; its `trajectory_spec` defines the
      replay buffer's data spec.
    output_dir: Directory for the record files (prefix `data`). Created if it
      does not exist.
    checkpoint: Optional checkpoint to restore before collecting. If it is a
      directory, the latest checkpoint in its `train` subdirectory is used;
      otherwise it is treated as a checkpoint path.
    num_iterations: Number of times the collect op is run.
    episodes_per_file: Forwarded to `TFRecordReplayBuffer`.
    summary_interval: Summaries are recorded when
      `global_step % summary_interval == 0`.
  """
  if not os.path.isdir(output_dir):
    logger.info('Making output directory %s...', output_dir)
    os.makedirs(output_dir)

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    # Make the replay buffer.
    replay_buffer = tfrecord_replay_buffer.TFRecordReplayBuffer(
        data_spec=tf_policy.trajectory_spec,
        experiment_id='exp',
        file_prefix=os.path.join(output_dir, 'data'),
        episodes_per_file=episodes_per_file)
    replay_observer = [replay_buffer.add_batch]

    collect_policy = tf_policy
    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer,
        num_steps=1).run()

    with tf.compat.v1.Session() as sess:
      # Initialize training. Initialization stays best-effort as before, but
      # a failure is now logged with its traceback instead of being dropped.
      try:
        common.initialize_uninitialized_variables(sess)
      except Exception:  # pylint: disable=broad-except
        logger.exception(
            'Initializing uninitialized variables failed; continuing.')

      # Restore checkpoint.
      if checkpoint is not None:
        if os.path.isdir(checkpoint):
          train_dir = os.path.join(checkpoint, 'train')
          checkpoint_path = tf.train.latest_checkpoint(train_dir)
        else:
          checkpoint_path = checkpoint

        # Use the compat.v1 Saver, matching the compat.v1 Session above;
        # `tf.train.Saver` is not available under the TF2 namespace.
        restorer = tf.compat.v1.train.Saver(name='restorer')
        restorer.restore(sess, checkpoint_path)

      collect_call = sess.make_callable(collect_op)
      for _ in range(num_iterations):
        collect_call()
def testInitializeAgent(self, agent_class, run_mode):
  """Checks that `agent_class` initializes under the given run mode.

  Args:
    agent_class: Agent constructor under test.
    run_mode: Context-manager factory entered around the test body. The
      test is skipped when it is `context.graph_mode` while executing
      eagerly.
  """
  wants_graph_mode = run_mode == context.graph_mode
  if wants_graph_mode and tf.executing_eagerly():
    self.skipTest('b/123778560')

  with run_mode():
    dummy_q_net = DummyNet(self._observation_spec, self._action_spec)
    agent = agent_class(
        self._time_step_spec,
        self._action_spec,
        q_network=dummy_q_net,
        optimizer=None)
    initialize_op = agent.initialize()
    if tf.executing_eagerly():
      # There is no session to run the op through in eager mode.
      return
    with self.cached_session() as session:
      common.initialize_uninitialized_variables(session)
      run_result = session.run(initialize_op)
      self.assertIsNone(run_result)
def testComputeActionFeasibility(self):
  """Checks that the NeuralConstraint reports all-ones feasibility."""
  net = DummyNet(self._observation_spec, self._action_spec)
  constraint = constraints.NeuralConstraint(
      self._time_step_spec, self._action_spec, constraint_network=net)

  initialize_op = constraint.initialize()
  if not tf.executing_eagerly():
    with self.cached_session() as session:
      common.initialize_uninitialized_variables(session)
      run_result = session.run(initialize_op)
      self.assertIsNone(run_result)

  observation = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
  feasibility_prob = constraint(observation)
  feasibility_value = self.evaluate(feasibility_prob)
  self.assertAllClose(feasibility_value, np.ones([2, 3]))
def testComputeActionFeasibility(self):
  """Checks that QuantileConstraint feasibility values lie in [0, 1]."""
  net = DummyNet(self._observation_spec, self._action_spec)
  constraint = constraints.QuantileConstraint(
      self._time_step_spec, self._action_spec, constraint_network=net)

  initialize_op = constraint.initialize()
  if not tf.executing_eagerly():
    with self.cached_session() as session:
      common.initialize_uninitialized_variables(session)
      run_result = session.run(initialize_op)
      self.assertIsNone(run_result)

  observation = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
  feasibility_prob = constraint(observation)

  # The tensor is evaluated once per bound check, as before.
  lower_check_value = self.evaluate(feasibility_prob)
  self.assertAllGreaterEqual(lower_check_value, 0.0)
  upper_check_value = self.evaluate(feasibility_prob)
  self.assertAllLessEqual(upper_check_value, 1.0)
def testComputeLoss(self):
  """Checks that NeuralConstraint.compute_loss evaluates to 42.25."""
  net = DummyNet(self._observation_spec, self._action_spec)
  observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
  actions = tf.constant([0, 1], dtype=tf.int32)
  rewards = tf.constant([0.5, 3.0], dtype=tf.float32)

  constraint = constraints.NeuralConstraint(
      self._time_step_spec, self._action_spec, constraint_network=net)
  initialize_op = constraint.initialize()
  if not tf.executing_eagerly():
    with self.cached_session() as session:
      common.initialize_uninitialized_variables(session)
      run_result = session.run(initialize_op)
      self.assertIsNone(run_result)

  loss = constraint.compute_loss(observations, actions, rewards)
  loss_value = self.evaluate(loss)
  self.assertAllClose(loss_value, 42.25)
def testLoss(self):
  """Checks that the multi-objective agent's loss evaluates to 0.0."""
  agent = greedy_multi_objective_agent.GreedyMultiObjectiveNeuralAgent(
      self._time_step_spec,
      self._action_spec,
      self._scalarizer,
      objective_networks=self._create_objective_networks(),
      optimizer=None)
  observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
  actions = tf.constant([0, 1], dtype=tf.int32)
  objectives = tf.constant([[8, 12, 11], [25, 18, 32]], dtype=tf.float32)
  init_op = agent.initialize()
  if not tf.executing_eagerly():
    with self.cached_session() as sess:
      common.initialize_uninitialized_variables(sess)
      self.assertIsNone(sess.run(init_op))
  loss, _ = agent.loss(observations, actions, objectives)
  # `initialize_all_variables` is deprecated; this is its documented
  # replacement and builds the same initializer.
  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.assertAllClose(self.evaluate(loss), 0.0)
def testLoss(self):
  """Checks that the greedy reward agent's loss evaluates to 42.25."""
  reward_net = DummyNet(self._observation_spec, self._action_spec)
  observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
  actions = tf.constant([0, 1], dtype=tf.int32)
  rewards = tf.constant([0.5, 3.0], dtype=tf.float32)
  agent = greedy_agent.GreedyRewardPredictionAgent(
      self._time_step_spec,
      self._action_spec,
      reward_network=reward_net,
      optimizer=None)
  init_op = agent.initialize()
  if not tf.executing_eagerly():
    with self.cached_session() as sess:
      common.initialize_uninitialized_variables(sess)
      self.assertIsNone(sess.run(init_op))
  loss, _ = agent.loss(observations, actions, rewards)
  # `initialize_all_variables` is deprecated; this is its documented
  # replacement and builds the same initializer.
  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.assertAllClose(self.evaluate(loss), 42.25)
def testTrainAgent(self):
  """Checks that one train step of the Bernoulli TS agent has loss -1.0."""
  observations = np.array([[1, 1]], dtype=np.float32)
  actions = np.array([0, 1], dtype=np.int32)
  rewards = np.array([0.0, 1.0], dtype=np.float32)
  initial_step, final_step = _get_initial_and_final_steps(
      observations, rewards)
  action_step = _get_action_step(actions)
  experience = _get_experience(initial_step, action_step, final_step)
  agent = bern_ts_agent.BernoulliThompsonSamplingAgent(
      self._time_step_spec, self._action_spec, batch_size=2)
  init_op = agent.initialize()
  if not tf.executing_eagerly():
    with self.cached_session() as sess:
      common.initialize_uninitialized_variables(sess)
      self.assertIsNone(sess.run(init_op))
  loss, _ = agent._train(experience, weights=None)
  # `initialize_all_variables` is deprecated; this is its documented
  # replacement and builds the same initializer.
  self.evaluate(tf.compat.v1.global_variables_initializer())
  # The loss is -sum(rewards).
  self.assertAllClose(self.evaluate(loss), -1.0)
def testLoss(self):
  """Checks the greedy reward agent's `_loss` on nested-reward experience."""
  reward_net = DummyNet(self._observation_spec, self._action_spec)
  observations = np.array([[1, 2], [3, 4]], dtype=np.float32)
  actions = np.array([0, 1], dtype=np.int32)
  rewards = np.array([0.5, 3.0], dtype=np.float32)
  initial_step, final_step = _get_initial_and_final_steps_nested_rewards(
      observations, rewards)
  action_step = _get_action_step(actions)
  experience = _get_experience(initial_step, action_step, final_step)
  agent = greedy_agent.GreedyRewardPredictionAgent(
      self._time_step_spec,
      self._action_spec,
      reward_network=reward_net,
      optimizer=None)
  init_op = agent.initialize()
  if not tf.executing_eagerly():
    with self.cached_session() as sess:
      common.initialize_uninitialized_variables(sess)
      self.assertIsNone(sess.run(init_op))
  loss, _ = agent._loss(experience)
  # `initialize_all_variables` is deprecated; this is its documented
  # replacement and builds the same initializer.
  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.assertAllClose(self.evaluate(loss), 42.25)
def testLoss(self):
  """Checks that the multi-objective agent's `_loss` evaluates to 0.0."""
  agent = greedy_multi_objective_agent.GreedyMultiObjectiveNeuralAgent(
      self._time_step_spec,
      self._action_spec,
      self._scalarizer,
      objective_network_and_loss_fn_sequence=self.
      _create_objective_network_and_loss_fn_sequence(),
      optimizer=None)
  observations = np.array([[1, 2], [3, 4]], dtype=np.float32)
  actions = np.array([0, 1], dtype=np.int32)
  objectives = np.array([[8, 12, 11], [25, 18, 32]], dtype=np.float32)
  initial_step, final_step = _get_initial_and_final_steps(
      observations, objectives)
  action_step = _get_action_step(actions)
  experience = _get_experience(initial_step, action_step, final_step)
  init_op = agent.initialize()
  if not tf.executing_eagerly():
    with self.cached_session() as sess:
      common.initialize_uninitialized_variables(sess)
      self.assertIsNone(sess.run(init_op))
  loss, _ = agent._loss(experience)
  # `initialize_all_variables` is deprecated; this is its documented
  # replacement and builds the same initializer.
  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.assertAllClose(self.evaluate(loss), 0.0)
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=100000,
    fc_layer_params=(100,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    agent_class=dqn_agent.DqnAgent,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN.

  Builds the environments, Q-network, agent, replay buffer, drivers,
  checkpointers and summaries in graph mode, then runs the
  collect/train/log/checkpoint/eval loop inside a single session.

  Args:
    root_dir: Base directory (`~` is expanded). Summaries and checkpoints go
      under its `train` and `eval` subdirectories.
    env_name: Name passed to `suite_gym.load` for both the training and the
      evaluation environment.
    num_iterations: Number of collect/train iterations of the main loop.
    fc_layer_params: Forwarded to `q_network.QNetwork`.
    initial_collect_steps: Steps collected with a random policy before the
      main loop.
    collect_steps_per_iteration: Steps collected with the agent's collect
      policy per iteration.
    epsilon_greedy: Forwarded to `agent_class`.
    replay_buffer_capacity: `max_length` of the replay buffer.
    target_update_tau: Forwarded to `agent_class`.
    target_update_period: Forwarded to `agent_class`.
    train_steps_per_iteration: Number of train-step calls per iteration.
    batch_size: `sample_batch_size` of the replay-buffer dataset.
    learning_rate: Learning rate of the Adam optimizer.
    gamma: Forwarded to `agent_class`.
    reward_scale_factor: Forwarded to `agent_class`.
    gradient_clipping: Forwarded to `agent_class`.
    num_eval_episodes: Episodes per evaluation; also the eval metrics'
      buffer size.
    eval_interval: Evaluate when the global step is a multiple of this.
    train_checkpoint_interval: Save the train checkpoint when the global step
      is a multiple of this.
    policy_checkpoint_interval: Same, for the policy checkpoint.
    rb_checkpoint_interval: Same, for the replay-buffer checkpoint.
    log_interval: Log loss and steps/sec when the global step is a multiple
      of this.
    summary_interval: Summaries are recorded when
      `global_step % summary_interval == 0`.
    summaries_flush_secs: Flush period of both summary writers, in seconds.
    agent_class: Agent constructor; called with the keyword arguments shown
      in the body.
    debug_summaries: Forwarded to `agent_class`.
    summarize_grads_and_vars: Forwarded to `agent_class`.
    eval_metrics_callback: Forwarded to `metric_utils.compute_summaries` as
      `callback`.
  """
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  # flush_millis is in milliseconds, hence the conversion from seconds.
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
    eval_py_env = suite_gym.load(env_name)

    q_net = q_network.QNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=fc_layer_params)

    # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
    tf_agent = agent_class(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
        epsilon_greedy=epsilon_greedy,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    replay_observer = [replay_buffer.add_batch]
    # The initial collection uses a random policy rather than the agent's.
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    initial_collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps).run()

    collect_policy = tf_agent.collect_policy
    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration).run()

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)

    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    experience, _ = iterator.get_next()
    train_op = common.function(tf_agent.train)(experience=experience)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    # The first two train metrics serve as step metrics for the summaries.
    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(train_metric.tf_summaries(
          train_step=global_step, step_metrics=train_metrics[:2]))

    # Eval summaries go to the eval writer and are always recorded.
    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())
      sess.run(initial_collect_op)

      # Evaluate once before any training step.
      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )

      collect_call = sess.make_callable(collect_op)
      global_step_call = sess.make_callable(global_step)
      train_step_call = sess.make_callable([train_op, summary_ops])

      timed_at_step = global_step_call()
      collect_time = 0
      train_time = 0
      # The steps/sec value is computed in Python and fed via a placeholder.
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec', data=steps_per_second_ph,
          step=global_step)

      for _ in range(num_iterations):
        # Train/collect/eval.
        start_time = time.time()
        collect_call()
        collect_time += time.time() - start_time
        start_time = time.time()
        for _ in range(train_steps_per_iteration):
          loss_info_value, _ = train_step_call()
        train_time += time.time() - start_time

        global_step_val = global_step_call()

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val,
                       loss_info_value.loss)
          steps_per_sec = (
              (global_step_val - timed_at_step) / (collect_time + train_time))
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          logging.info('%.3f steps/sec', steps_per_sec)
          logging.info('%s', 'collect_time = {}, train_time = {}'.format(
              collect_time, train_time))
          # Restart the timing window after each log.
          timed_at_step = global_step_val
          collect_time = 0
          train_time = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
          )
def train_eval(
    root_dir,
    tf_master='',
    env_name='HalfCheetah-v2',
    env_load_fn=suite_mujoco.load,
    random_seed=0,
    # TODO(b/127576522): rename to policy_fc_layers.
    actor_fc_layers=(200, 100),
    value_fc_layers=(200, 100),
    use_rnns=False,
    # Params for collect
    num_environment_steps=10000000,
    collect_episodes_per_iteration=30,
    num_parallel_environments=30,
    replay_buffer_capacity=1001,  # Per-environment
    # Params for train
    num_epochs=25,
    learning_rate=1e-4,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=500,
    # Params for summaries and logging
    train_checkpoint_interval=100,
    policy_checkpoint_interval=50,
    rb_checkpoint_interval=200,
    log_interval=50,
    summary_interval=50,
    summaries_flush_secs=1,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for PPO.

  Builds parallel environments, actor/value networks, the PPO agent, replay
  buffer, episode driver, checkpointers and summaries in graph mode, then
  alternates collection and training until the environment-steps metric
  reaches `num_environment_steps`. A final evaluation runs before returning.

  Args:
    root_dir: Base directory (`~` is expanded). Summaries and checkpoints go
      under its `train` and `eval` subdirectories. Must not be None.
    tf_master: Target passed to `tf.compat.v1.Session`.
    env_name: Name passed to `env_load_fn`.
    env_load_fn: Callable that loads one environment from `env_name`.
    random_seed: Graph-level random seed.
    actor_fc_layers: Fully connected layer sizes for the actor network.
    value_fc_layers: Fully connected layer sizes for the value network.
    use_rnns: If True, build the RNN variants of the actor/value networks.
    num_environment_steps: Training stops once the environment-steps metric
      reaches this value.
    collect_episodes_per_iteration: Episodes collected per iteration.
    num_parallel_environments: Number of parallel environments, for both
      collection and evaluation.
    replay_buffer_capacity: `max_length` of the replay buffer, per
      environment.
    num_epochs: Forwarded to `ppo_agent.PPOAgent`.
    learning_rate: Learning rate of the Adam optimizer.
    num_eval_episodes: Episodes per evaluation; also the eval metrics'
      buffer size.
    eval_interval: Evaluate when the global step is a multiple of this.
    train_checkpoint_interval: Save the train checkpoint when the global step
      is a multiple of this.
    policy_checkpoint_interval: Same, for the policy checkpoint.
    rb_checkpoint_interval: Same, for the replay-buffer checkpoint.
    log_interval: Log loss and steps/sec when the global step is a multiple
      of this.
    summary_interval: Summaries are recorded when
      `global_step % summary_interval == 0`.
    summaries_flush_secs: Flush period of both summary writers, in seconds.
    debug_summaries: Forwarded to `ppo_agent.PPOAgent`.
    summarize_grads_and_vars: Forwarded to `ppo_agent.PPOAgent`.
    eval_metrics_callback: Forwarded to `metric_utils.compute_summaries` as
      `callback`.

  Raises:
    AttributeError: If `root_dir` is None.
  """
  if root_dir is None:
    raise AttributeError('train_eval requires a root_dir.')

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  # flush_millis is in milliseconds, hence the conversion from seconds.
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      batched_py_metric.BatchedPyMetric(
          AverageReturnMetric,
          metric_args={'buffer_size': num_eval_episodes},
          batch_size=num_parallel_environments),
      batched_py_metric.BatchedPyMetric(
          AverageEpisodeLengthMetric,
          metric_args={'buffer_size': num_eval_episodes},
          batch_size=num_parallel_environments),
  ]
  eval_summary_writer_flush_op = eval_summary_writer.flush()

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf.compat.v1.set_random_seed(random_seed)
    eval_py_env = parallel_py_environment.ParallelPyEnvironment(
        [lambda: env_load_fn(env_name)] * num_parallel_environments)
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments))
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

    if use_rnns:
      actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          input_fc_layer_params=actor_fc_layers,
          output_fc_layer_params=None)
      value_net = value_rnn_network.ValueRnnNetwork(
          tf_env.observation_spec(),
          input_fc_layer_params=value_fc_layers,
          output_fc_layer_params=None)
    else:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          fc_layer_params=actor_fc_layers)
      value_net = value_network.ValueNetwork(
          tf_env.observation_spec(), fc_layer_params=value_fc_layers)

    tf_agent = ppo_agent.PPOAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        optimizer,
        actor_net=actor_net,
        value_net=value_net,
        num_epochs=num_epochs,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    environment_steps_metric = tf_metrics.EnvironmentSteps()
    environment_steps_count = environment_steps_metric.result()
    step_metrics = [
        tf_metrics.NumberOfEpisodes(),
        environment_steps_metric,
    ]
    train_metrics = step_metrics + [
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    # Add to replay buffer and other agent specific observers.
    replay_buffer_observer = [replay_buffer.add_batch]

    collect_policy = tf_agent.collect_policy

    collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=replay_buffer_observer + train_metrics,
        num_episodes=collect_episodes_per_iteration).run()

    trajectories = replay_buffer.gather_all()

    train_op, _ = tf_agent.train(experience=trajectories)

    # Chain train -> clear buffer -> identity, so that fetching `train_op`
    # also clears the replay buffer after the training step.
    with tf.control_dependencies([train_op]):
      clear_replay_op = replay_buffer.clear()

    with tf.control_dependencies([clear_replay_op]):
      train_op = tf.identity(train_op)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    for train_metric in train_metrics:
      train_metric.tf_summaries(
          train_step=global_step, step_metrics=step_metrics)

    # Eval summaries go to the eval writer and are always recorded.
    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(step_metrics=step_metrics)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session(tf_master) as sess:
      # Initialize graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())

      collect_time = 0
      train_time = 0
      # Bind global_step_val before the loop: if restored state already
      # satisfies the stop condition, the loop body never runs and the final
      # eval below would otherwise hit an unbound name.
      global_step_val = sess.run(global_step)
      timed_at_step = global_step_val

      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      # Use the compat.v2 summary API, matching the v2 file writers created
      # above, instead of tf.contrib. The tag is unchanged.
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps/sec', data=steps_per_second_ph,
          step=global_step)

      while sess.run(environment_steps_count) < num_environment_steps:
        global_step_val = sess.run(global_step)
        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
              log=True,
          )
          sess.run(eval_summary_writer_flush_op)

        start_time = time.time()
        sess.run(collect_op)
        collect_time += time.time() - start_time
        start_time = time.time()
        total_loss = sess.run(train_op)
        train_time += time.time() - start_time

        global_step_val = sess.run(global_step)
        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val, total_loss)
          steps_per_sec = ((global_step_val - timed_at_step) /
                           (collect_time + train_time))
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          logging.info(
              '%s', 'collect_time = {}, train_time = {}'.format(
                  collect_time, train_time))
          # Restart the timing window after each log.
          timed_at_step = global_step_val
          collect_time = 0
          train_time = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

      # One final eval before exiting.
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )
      sess.run(eval_summary_writer_flush_op)
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_whitelist='position',
        num_iterations=100000,
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100,),
        actor_lstm_size=(40,),
        critic_obs_fc_layers=(400,),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300,),
        critic_output_fc_layers=(100,),
        critic_lstm_size=(40,),
        # Params for collect
        initial_collect_steps=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=200,
        batch_size=64,
        train_sequence_length=10,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=10000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """Train and periodically evaluate a recurrent DDPG agent.

    The function builds the graph (environments, RNN actor/critic, agent,
    replay buffer, drivers, dataset, checkpointers, summaries) and then runs
    the collect/train loop inside a `tf.compat.v1.Session`.

    Args:
      root_dir: Output directory ("~" is expanded). Train summaries and
        checkpoints are written under `<root_dir>/train`, eval summaries
        under `<root_dir>/eval`.
      env_name: Domain name passed to `suite_dm_control.load`.
      task_name: Task name passed to `suite_dm_control.load`.
      observations_whitelist: A single observation key. When not None it is
        wrapped in a one-element list and handed to
        `wrappers.FlattenObservationsWrapper`; None disables the wrapper.
      num_iterations: Number of outer iterations; each runs one collect call
        followed by `train_steps_per_iteration` train calls.
      actor_fc_layers: `input_fc_layer_params` of the actor RNN network.
      actor_output_fc_layers: `output_fc_layer_params` of the actor network.
      actor_lstm_size: `lstm_size` of the actor network.
      critic_obs_fc_layers: `observation_fc_layer_params` of the critic.
      critic_action_fc_layers: `action_fc_layer_params` of the critic.
      critic_joint_fc_layers: `joint_fc_layer_params` of the critic.
      critic_output_fc_layers: `output_fc_layer_params` of the critic.
      critic_lstm_size: `lstm_size` of the critic.
      initial_collect_steps: Despite its name, this is forwarded as
        `num_episodes` to the initial `DynamicEpisodeDriver`.
      collect_episodes_per_iteration: `num_episodes` of the per-iteration
        collect driver.
      replay_buffer_capacity: `max_length` of the uniform replay buffer.
      ou_stddev: Forwarded to `DdpgAgent`.
      ou_damping: Forwarded to `DdpgAgent`.
      target_update_tau: Forwarded to `DdpgAgent`.
      target_update_period: Forwarded to `DdpgAgent`.
      train_steps_per_iteration: Train calls issued per outer iteration.
      batch_size: `sample_batch_size` used when sampling the replay buffer.
      train_sequence_length: Sequences are sampled with
        `num_steps=train_sequence_length + 1`.
      actor_learning_rate: Learning rate of the actor's Adam optimizer.
      critic_learning_rate: Learning rate of the critic's Adam optimizer.
      dqda_clipping: Forwarded to `DdpgAgent`.
      gamma: Forwarded to `DdpgAgent`.
      reward_scale_factor: Forwarded to `DdpgAgent`.
      gradient_clipping: Forwarded to `DdpgAgent`.
      num_eval_episodes: Episodes per evaluation; also the buffer size of
        the eval metrics.
      eval_interval: Evaluate when `global_step % eval_interval == 0`.
      train_checkpoint_interval: Same modulo rule for the train checkpoint.
      policy_checkpoint_interval: Same modulo rule for the policy checkpoint.
      rb_checkpoint_interval: Same modulo rule for the replay-buffer
        checkpoint.
      log_interval: Same modulo rule for loss / steps-per-second logging.
      summary_interval: Summaries are recorded only when
        `global_step % summary_interval == 0`.
      summaries_flush_secs: Flush period of both summary writers, in seconds.
      debug_summaries: Forwarded to `DdpgAgent`.
      summarize_grads_and_vars: Forwarded to `DdpgAgent`.
      eval_metrics_callback: Forwarded as `callback` to
        `metric_utils.compute_summaries`.

    Note:
      All interval checks use the global step read once per outer iteration
      (after the inner train loop), so an interval only fires if the step
      value observed at that point is an exact multiple of it.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]
    # Explicit flush op so eval results reach disk right after each
    # evaluation instead of waiting for the writer's periodic flush.
    eval_summary_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_whitelist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_whitelist=[observations_whitelist])
            ]
        else:
            env_wrappers = []

        environment = suite_dm_control.load(
            env_name, task_name, env_wrappers=env_wrappers)
        tf_env = tf_py_environment.TFPyEnvironment(environment)
        eval_py_env = suite_dm_control.load(
            env_name, task_name, env_wrappers=env_wrappers)

        actor_net = actor_rnn_network.ActorRnnNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers)

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())
        critic_net = critic_rnn_network.CriticRnnNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
        )

        tf_agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        collect_policy = tf_agent.collect_policy
        initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=initial_collect_steps).run()

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        # Need extra step to generate transitions of train_sequence_length.
        # Dataset generates trajectories with shape [BxTx...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)

        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        trajectories, unused_info = iterator.get_next()

        train_fn = common.function(tf_agent.train)
        train_op = train_fn(experience=trajectories,
                            train_step_counter=global_step)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        # The first two train metrics (episodes, environment steps) double as
        # the step metrics the other summaries are plotted against.
        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        # Eval summaries are always recorded, regardless of summary_interval.
        with eval_summary_writer.as_default(), \
                tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:

            def _run_eval(step):
                # Evaluate the current policy and flush the eval summaries.
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    eval_py_policy,
                    num_episodes=num_eval_episodes,
                    global_step=step,
                    callback=eval_metrics_callback,
                    log=True,
                )
                sess.run(eval_summary_flush_op)

            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(b/126239733) Remove once Periodically can be saved.
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            # NOTE: the initial collect and initial eval run unconditionally,
            # including when the checkpointers restored an earlier run.
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            _run_eval(global_step_val)

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _ = train_step_call()
                global_step_val = global_step_call()
                time_acc += time.time() - start_time

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = (
                        (global_step_val - timed_at_step) / time_acc)
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    _run_eval(global_step_val)
def testTrainWithRnn(self):
    """One train call on RNN actor/critic networks bumps the counter 0 -> 1."""
    rnn_actor = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        self._obs_spec,
        self._action_spec,
        input_fc_layer_params=None,
        output_fc_layer_params=None,
        conv_layer_params=None,
        lstm_size=(40,),
    )
    rnn_critic = critic_rnn_network.CriticRnnNetwork(
        (self._obs_spec, self._action_spec),
        observation_fc_layer_params=(16,),
        action_fc_layer_params=(16,),
        joint_fc_layer_params=(16,),
        lstm_size=(16,),
        output_fc_layer_params=None,
    )

    step_counter = common.create_variable('test_train_counter')
    make_adam = tf.compat.v1.train.AdamOptimizer
    agent = sac_agent.SacAgent(
        self._time_step_spec,
        self._action_spec,
        critic_network=rnn_critic,
        actor_network=rnn_actor,
        actor_optimizer=make_adam(1e-3),
        critic_optimizer=make_adam(1e-3),
        alpha_optimizer=make_adam(1e-3),
        train_step_counter=step_counter,
    )

    # A batch of five identical three-step sequences.
    batch_size = 5
    observations = tf.constant([[[1, 2], [3, 4], [5, 6]]] * batch_size,
                               dtype=tf.float32)
    actions = tf.constant([[[0], [1], [1]]] * batch_size, dtype=tf.float32)

    all_ones = [[1] * 3] * batch_size
    time_steps = ts.TimeStep(
        step_type=tf.constant(all_ones, dtype=tf.int32),
        reward=tf.constant(all_ones, dtype=tf.float32),
        discount=tf.constant(all_ones, dtype=tf.float32),
        observation=observations)

    experience = trajectory.Trajectory(
        time_steps.step_type,
        observations,
        actions,
        (),
        time_steps.step_type,
        time_steps.reward,
        time_steps.discount)

    # Force variable creation.
    agent.policy.variables()

    if tf.executing_eagerly():
        self.assertEqual(self.evaluate(step_counter), 0)
        self.evaluate(agent.train(experience))
        self.assertEqual(self.evaluate(step_counter), 1)
        return

    # Graph mode: build the train op first so the optimizer variables exist
    # by the time the uninitialized variables are initialized.
    train_result = agent.train(experience)
    with self.cached_session() as sess:
        common.initialize_uninitialized_variables(sess)
    self.assertEqual(self.evaluate(step_counter), 0)
    self.evaluate(train_result)
    self.assertEqual(self.evaluate(step_counter), 1)
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        num_iterations=1000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """Build the SAC training graph and run the collect/train/eval loop.

    Train summaries and checkpoints are written under `<root_dir>/train`
    and eval summaries under `<root_dir>/eval`. When the restored global
    step is 0, the policy is evaluated once and the replay buffer is seeded
    with `initial_collect_steps` steps from a random policy; otherwise both
    are skipped. Each of the `num_iterations` iterations runs one collect
    call and `train_steps_per_iteration` train calls, then logs, evaluates
    and checkpoints whenever the global step is an exact multiple of the
    corresponding interval.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    flush_millis = summaries_flush_secs * 1000

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=flush_millis)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=flush_millis)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]
    eval_summary_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()

    def _should_record_summaries():
        return tf.math.equal(global_step % summary_interval, 0)

    with tf.compat.v2.summary.record_if(_should_record_summaries):
        # Environments: a TF-wrapped one for collection, a py one for eval.
        tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
        eval_py_env = suite_mujoco.load(env_name)

        # Specs that the networks and the agent are built from.
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=normal_projection_net)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        adam = tf.compat.v1.train.AdamOptimizer
        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=adam(learning_rate=actor_learning_rate),
            critic_optimizer=adam(learning_rate=critic_learning_rate),
            alpha_optimizer=adam(learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        # Replay buffer and the observer that feeds it.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        eval_py_policy = py_tf_policy.PyTFPolicy(
            greedy_policy.GreedyPolicy(tf_agent.policy))

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
            tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
        ]

        collect_policy = tf_agent.collect_policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps)
        initial_collect_op = initial_collect_driver.run()

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)
        collect_op = collect_driver.run()

        # Sample two-step sequences, unbatch them, drop those whose first
        # step is a boundary, then re-batch to the training batch size.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        sampled = replay_buffer.as_dataset(
            sample_batch_size=5 * batch_size, num_steps=2)
        unbatched = sampled.apply(tf.data.experimental.unbatch())
        valid_only = unbatched.filter(_filter_invalid_transition)
        dataset = valid_only.batch(batch_size).prefetch(batch_size * 5)

        dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
            dataset)
        trajectories, unused_info = dataset_iterator.get_next()
        train_op = tf_agent.train(trajectories)

        # The first two train metrics serve as step metrics for the rest.
        summary_ops = [
            metric.tf_summaries(train_step=global_step,
                                step_metrics=train_metrics[:2])
            for metric in train_metrics
        ]

        # Eval summaries bypass the summary_interval gate.
        with eval_summary_writer.as_default(), \
                tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        # Checked in this order at the end of every iteration.
        periodic_checkpoints = (
            (train_checkpoint_interval, train_checkpointer),
            (policy_checkpoint_interval, policy_checkpointer),
            (rb_checkpoint_interval, rb_checkpointer),
        )

        with tf.compat.v1.Session() as sess:

            def _evaluate(step):
                # Run the eval episodes, then flush the eval summaries.
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    eval_py_policy,
                    num_episodes=num_eval_episodes,
                    global_step=step,
                    callback=eval_metrics_callback,
                    log=True,
                )
                sess.run(eval_summary_flush_op)

            # Restore (or initialize) checkpointed state.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)

            # Remaining initialization before training starts.
            sess.run(dataset_iterator.initializer)
            common.initialize_uninitialized_variables(sess)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            global_step_val = sess.run(global_step)

            if global_step_val != 0:
                logging.info('Global step %d: Skipping initial collect op.',
                             global_step_val)
            else:
                # Fresh run: evaluate the untrained policy once.
                _evaluate(global_step_val)

                # Seed the replay buffer with the random policy.
                logging.info('Global step %d: Running initial collect op.',
                             global_step_val)
                sess.run(initial_collect_op)

                # Persist the freshly collected replay buffer contents.
                rb_checkpointer.save(global_step=global_step_val)
                logging.info('Finished initial collect.')

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    total_loss, _ = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss.loss)
                    steps_per_sec = (
                        (global_step_val - timed_at_step) / time_acc)
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % eval_interval == 0:
                    _evaluate(global_step_val)

                for interval, checkpointer in periodic_checkpoints:
                    if global_step_val % interval == 0:
                        checkpointer.save(global_step=global_step_val)
def train_eval( ############################################## # types of params: # 0: specific to algorithm (gin file 0) # 1: specific to environment (gin file 1) # 2: specific to experiment (gin file 2 + command line) # Note: there are other important params # in eg ModelDistributionNetwork that the gin files specify # like sparse vs dense rewards, latent dimensions, etc. ############################################## # basic params for running/logging experiment root_dir, # 2 experiment_name, # 2 num_iterations=int(1e7), # 2 seed=1, # 2 gpu_allow_growth=False, # 2 gpu_memory_limit=None, # 2 verbose=True, # 2 policy_checkpoint_freq_in_iter=100, # policies needed for future eval # 2 train_checkpoint_freq_in_iter=0, #default don't save # 2 rb_checkpoint_freq_in_iter=0, #default don't save # 2 logging_freq_in_iter=10, # printing to terminal # 2 summary_freq_in_iter=10, # saving to tb # 2 num_images_per_summary=2, # 2 summaries_flush_secs=10, # 2 max_episode_len_override=None, # 2 num_trials_to_render=1, # 2 # environment, action mode, etc. 
env_name='HalfCheetah-v2', # 1 action_repeat=1, # 1 action_mode='joint_position', # joint_position or joint_delta_position # 1 double_camera=False, # camera input # 1 universe='gym', # default task_reward_dim=1, # default # dims for all networks actor_fc_layers=(256, 256), # 1 critic_obs_fc_layers=None, # 1 critic_action_fc_layers=None, # 1 critic_joint_fc_layers=(256, 256), # 1 num_repeat_when_concatenate=None, # 1 # networks critic_input='state', # 0 actor_input='state', # 0 # specifying tasks and eval episodes_per_trial=1, # 2 num_train_tasks=10, # 2 num_eval_tasks=10, # 2 num_eval_trials=10, # 2 eval_interval=10, # 2 eval_on_holdout_tasks=True, # 2 # data collection/buffer init_collect_trials_per_task=None, # 2 collect_trials_per_task=None, # 2 num_tasks_to_collect_per_iter=5, # 2 replay_buffer_capacity=int(1e5), # 2 # training init_model_train_ratio=0.8, # 2 model_train_ratio=1, # 2 model_train_freq=1, # 2 ac_train_ratio=1, # 2 ac_train_freq=1, # 2 num_tasks_per_train=5, # 2 train_trials_per_task=5, # 2 model_bs_in_steps=256, # 2 ac_bs_in_steps=128, # 2 # default AC learning rates, gamma, etc. 
target_update_tau=0.005, target_update_period=1, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, model_learning_rate=1e-4, td_errors_loss_fn=functools.partial( tf.compat.v1.losses.mean_squared_error, weights=0.5), gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, log_image_strips=False, stop_model_training=1E10, eval_only=False, # evaluate checkpoints ONLY log_image_observations=False, load_offline_data=False, # whether to use offline data offline_data_dir=None, # replay buffer's dir offline_episode_len=None, # episode len of episodes stored in rb offline_ratio=0, # ratio of data that is from offline buffer ): g = tf.Graph() # register all gym envs max_steps_dict = { "HalfCheetahVel-v0": 50, "SawyerReach-v0": 40, "SawyerReachMT-v0": 40, "SawyerPeg-v0": 40, "SawyerPegMT-v0": 40, "SawyerPegMT4box-v0": 40, "SawyerShelfMT-v0": 40, "SawyerKitchenMT-v0": 40, "SawyerShelfMT-v2": 40, "SawyerButtons-v0": 40, } if max_episode_len_override: max_steps_dict[env_name] = max_episode_len_override register_all_gym_envs(max_steps_dict) # set max_episode_len based on our env max_episode_len = max_steps_dict[env_name] ###################################################### # Calculate additional params ###################################################### # convert to number of steps env_steps_per_trial = episodes_per_trial * max_episode_len real_env_steps_per_trial = episodes_per_trial * (max_episode_len + 1) env_steps_per_iter = num_tasks_to_collect_per_iter * collect_trials_per_task * env_steps_per_trial per_task_collect_steps = collect_trials_per_task * env_steps_per_trial # initial collect + train init_collect_env_steps = num_train_tasks * init_collect_trials_per_task * env_steps_per_trial init_model_train_steps = int(init_collect_env_steps * init_model_train_ratio) # collect + train collect_env_steps_per_iter = num_tasks_to_collect_per_iter * per_task_collect_steps model_train_steps_per_iter = int(env_steps_per_iter * model_train_ratio) 
ac_train_steps_per_iter = int(env_steps_per_iter * ac_train_ratio) # other global_steps_per_iter = collect_env_steps_per_iter + model_train_steps_per_iter + ac_train_steps_per_iter sample_episodes_per_task = train_trials_per_task * episodes_per_trial # number of episodes to sample from each replay model_bs_in_trials = model_bs_in_steps // real_env_steps_per_trial # assertions that make sure parameters make sense assert model_bs_in_trials > 0, "model batch size need to be at least as big as one full real trial" assert num_tasks_to_collect_per_iter <= num_train_tasks, "when sampling replace=False" assert num_tasks_per_train * train_trials_per_task >= model_bs_in_trials, "not enough data for one batch model train" assert num_tasks_per_train * train_trials_per_task * env_steps_per_trial >= ac_bs_in_steps, "not enough data for one batch ac train" ###################################################### # Print a summary of params ###################################################### MELD_summary_string = f"""\n\n\n ============================================================== ============================================================== \n MELD algorithm summary: * each trial consists of {episodes_per_trial} episodes * episode length: {max_episode_len}, trial length: {env_steps_per_trial} * {num_train_tasks} train tasks, {num_eval_tasks} eval tasks, hold-out: {eval_on_holdout_tasks} * environment: {env_name} For each of {num_train_tasks} tasks: Do {init_collect_trials_per_task} trials of initial collect (total {init_collect_env_steps} env steps) Do {init_model_train_steps} steps of initial model training For i in range(inf): For each of {num_tasks_to_collect_per_iter} randomly selected tasks: Do {collect_trials_per_task} trials of collect (which is {collect_trials_per_task*env_steps_per_trial} env steps per task) (for a total of {num_tasks_to_collect_per_iter*collect_trials_per_task*env_steps_per_trial} env steps in the iteration) if i % 
model_train_freq(={model_train_freq}): Do {model_train_steps_per_iter} steps of model training - select {sample_episodes_per_task} episodes from each of {num_tasks_per_train} random train_tasks, combine into {num_tasks_per_train*train_trials_per_task} total trials. - pick randomly {model_bs_in_trials} trials, train model on whole trials. if i % ac_train_freq(={ac_train_freq}): Do {ac_train_steps_per_iter} steps of ac training - select {sample_episodes_per_task} episodes from each of {num_tasks_per_train} random train_tasks, combine into {num_tasks_per_train*train_trials_per_task} total trials. - pick randomly {ac_bs_in_steps} transitions, not including between trial transitions, to train ac. * Other important params: Evaluate policy every {eval_interval} iters, equivalent to {global_steps_per_iter*eval_interval/1000:.1f}k global steps Average evaluation across {num_eval_trials} trials Save summary to tensorboard every {summary_freq_in_iter} iters, equivalent to {global_steps_per_iter*summary_freq_in_iter/1000:.1f}k global steps Checkpoint: - training checkpoint every {train_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*train_checkpoint_freq_in_iter//1000}k global steps, keep 1 checkpoint - policy checkpoint every {policy_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*policy_checkpoint_freq_in_iter//1000}k global steps, keep all checkpoints - replay buffer checkpoint every {rb_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*rb_checkpoint_freq_in_iter//1000}k global steps, keep 1 checkpoint \n ============================================================= ============================================================= """ print(MELD_summary_string) time.sleep(1) ###################################################### # Seed + name + GPU configs + directories for saving ###################################################### np.random.seed(int(seed)) experiment_name += "_seed" + str(seed) gpus = 
tf.config.experimental.list_physical_devices('GPU') if gpu_allow_growth: for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) if gpu_memory_limit: for gpu in gpus: tf.config.experimental.set_virtual_device_configuration( gpu, [ tf.config.experimental.VirtualDeviceConfiguration( memory_limit=gpu_memory_limit) ]) train_eval_dir = get_train_eval_dir(root_dir, universe, env_name, experiment_name) train_dir = os.path.join(train_eval_dir, 'train') eval_dir = os.path.join(train_eval_dir, 'eval') eval_dir_2 = os.path.join(train_eval_dir, 'eval2') ###################################################### # Train and Eval Summary Writers ###################################################### train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_summary_flush_op = eval_summary_writer.flush() eval_logger = Logger(eval_dir_2) ###################################################### # Train and Eval metrics ###################################################### eval_buffer_size = num_eval_trials * episodes_per_trial * max_episode_len # across all eval trials in each evaluation eval_metrics = [] for position in range( episodes_per_trial ): # have metrics for each episode position, to track whether it is learning eval_metrics_pos = [ py_metrics.AverageReturnMetric(name='c_AverageReturnEval_' + str(position), buffer_size=eval_buffer_size), py_metrics.AverageEpisodeLengthMetric( name='f_AverageEpisodeLengthEval_' + str(position), buffer_size=eval_buffer_size), custom_metrics.AverageScoreMetric( name="d_AverageScoreMetricEval_" + str(position), buffer_size=eval_buffer_size), ] eval_metrics.extend(eval_metrics_pos) train_buffer_size = num_train_tasks * episodes_per_trial train_metrics = [ tf_metrics.NumberOfEpisodes(name='NumberOfEpisodes'), 
tf_metrics.EnvironmentSteps(name='EnvironmentSteps'), tf_py_metric.TFPyMetric( py_metrics.AverageReturnMetric(name="a_AverageReturnTrain", buffer_size=train_buffer_size)), tf_py_metric.TFPyMetric( py_metrics.AverageEpisodeLengthMetric( name="e_AverageEpisodeLengthTrain", buffer_size=train_buffer_size)), tf_py_metric.TFPyMetric( custom_metrics.AverageScoreMetric(name="b_AverageScoreTrain", buffer_size=train_buffer_size)), ] global_step = tf.compat.v1.train.get_or_create_global_step( ) # will be use to record number of model grad steps + ac grad steps + env_step log_cond = get_log_condition_tensor( global_step, init_collect_trials_per_task, env_steps_per_trial, num_train_tasks, init_model_train_steps, collect_trials_per_task, num_tasks_to_collect_per_iter, model_train_steps_per_iter, ac_train_steps_per_iter, summary_freq_in_iter, eval_interval) with tf.compat.v2.summary.record_if(log_cond): ###################################################### # Create env ###################################################### py_env, eval_py_env, train_tasks, eval_tasks = load_environments( universe, action_mode, env_name=env_name, observations_whitelist=['state', 'pixels', "env_info"], action_repeat=action_repeat, num_train_tasks=num_train_tasks, num_eval_tasks=num_eval_tasks, eval_on_holdout_tasks=eval_on_holdout_tasks, return_multiple_tasks=True, ) override_reward_func = None if load_offline_data: py_env.set_task_dict(train_tasks) override_reward_func = py_env.override_reward_func tf_env = tf_py_environment.TFPyEnvironment(py_env, isolation=True) # Get data specs from env time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() original_control_timestep = get_control_timestep(eval_py_env) # fps control_timestep = original_control_timestep * float(action_repeat) render_fps = int(np.round(1.0 / original_control_timestep)) ###################################################### # Latent variable model 
###################################################### if verbose: print("-- start constructing model networks --") model_net = ModelDistributionNetwork( double_camera=double_camera, observation_spec=observation_spec, num_repeat_when_concatenate=num_repeat_when_concatenate, task_reward_dim=task_reward_dim, episodes_per_trial=episodes_per_trial, max_episode_len=max_episode_len ) # rest of arguments provided via gin if verbose: print("-- finish constructing AC networks --") ###################################################### # Compressor Network for Actor/Critic # The model's compressor is also used by the AC # compressor function: images --> features ###################################################### compressor_net = model_net.compressor ###################################################### # Specs for Actor and Critic ###################################################### if actor_input == 'state': actor_state_size = observation_spec['state'].shape[0] elif actor_input == 'latentSample': actor_state_size = model_net.state_size elif actor_input == "latentDistribution": actor_state_size = 2 * model_net.state_size # mean and (diagonal) variance of gaussian, of two latents else: raise NotImplementedError actor_input_spec = tensor_spec.TensorSpec((actor_state_size, ), dtype=tf.float32) if critic_input == 'state': critic_state_size = observation_spec['state'].shape[0] elif critic_input == 'latentSample': critic_state_size = model_net.state_size elif critic_input == "latentDistribution": critic_state_size = 2 * model_net.state_size # mean and (diagonal) variance of gaussian, of two latents else: raise NotImplementedError critic_input_spec = tensor_spec.TensorSpec((critic_state_size, ), dtype=tf.float32) ###################################################### # Actor and Critic Networks ###################################################### if verbose: print("-- start constructing Actor and Critic networks --") actor_net = 
actor_distribution_network.ActorDistributionNetwork( actor_input_spec, action_spec, fc_layer_params=actor_fc_layers, ) critic_net = critic_network.CriticNetwork( (critic_input_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers) if verbose: print("-- finish constructing AC networks --") print("-- start constructing agent --") ###################################################### # Create the agent ###################################################### which_posterior_overwrite = None which_reward_overwrite = None meld_agent = MeldAgent( # specs time_step_spec=time_step_spec, action_spec=action_spec, # step counter train_step_counter= global_step, # will count number of model training steps # networks actor_network=actor_net, critic_network=critic_net, model_network=model_net, compressor_network=compressor_net, # optimizers actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), model_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=model_learning_rate), # target update target_update_tau=target_update_tau, target_update_period=target_update_period, # inputs critic_input=critic_input, actor_input=actor_input, # bs stuff model_batch_size=model_bs_in_steps, ac_batch_size=ac_bs_in_steps, # other num_tasks_per_train=num_tasks_per_train, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, control_timestep=control_timestep, num_images_per_summary=num_images_per_summary, task_reward_dim=task_reward_dim, episodes_per_trial=episodes_per_trial, # offline data override_reward_func=override_reward_func, offline_ratio=offline_ratio, ) if verbose: print("-- finish constructing agent --") 
###################################################### # Replay buffers + observers to add data to them ###################################################### replay_buffers = [] replay_observers = [] for _ in range(num_train_tasks): replay_buffer_episodic = episodic_replay_buffer.EpisodicReplayBuffer( meld_agent.collect_policy. trajectory_spec, # spec of each point stored in here (i.e. Trajectory) capacity=replay_buffer_capacity, completed_only= True, # in as_dataset, if num_steps is None, this means return full episodes # device='GPU:0', # gpu not supported for some reason begin_episode_fn=lambda traj: traj.is_first()[ 0], # first step of seq we add should be is_first end_episode_fn=lambda traj: traj.is_last()[ 0], # last step of seq we add should be is_last dataset_drop_remainder= True, #`as_dataset` makes the final batch be dropped if it does not contain exactly `sample_batch_size` items ) replay_buffer = StatefulEpisodicReplayBuffer( replay_buffer_episodic) # adding num_episodes here is bad replay_buffers.append(replay_buffer) replay_observers.append([replay_buffer.add_sequence]) if load_offline_data: # for each task, has a separate replay buffer for relabeled data replay_buffers_withRelabel = [] replay_observers_withRelabel = [] for _ in range(num_train_tasks): replay_buffer_episodic_withRelabel = episodic_replay_buffer.EpisodicReplayBuffer( meld_agent.collect_policy. trajectory_spec, # spec of each point stored in here (i.e. 
Trajectory) capacity=replay_buffer_capacity, completed_only= True, # in as_dataset, if num_steps is None, this means return full episodes # device='GPU:0', # gpu not supported for some reason begin_episode_fn=lambda traj: traj.is_first()[ 0], # first step of seq we add should be is_first end_episode_fn=lambda traj: traj.is_last()[ 0], # last step of seq we add should be is_last dataset_drop_remainder=True, # `as_dataset` makes the final batch be dropped if it does not contain exactly `sample_batch_size` items ) replay_buffer_withRelabel = StatefulEpisodicReplayBuffer( replay_buffer_episodic_withRelabel ) # adding num_episodes here is bad replay_buffers_withRelabel.append(replay_buffer_withRelabel) replay_observers_withRelabel.append( [replay_buffer_withRelabel.add_sequence]) if verbose: print("-- finish constructing replay buffers --") print("-- start constructing policies and collect ops --") ###################################################### # Policies ##################################################### # init collect policy (random) init_collect_policy = random_tf_policy.RandomTFPolicy( time_step_spec, action_spec) # eval eval_py_policy = py_tf_policy.PyTFPolicy(meld_agent.policy) ################################################################################ # Collect ops : use policies to get data + have the observer put data into corresponding RB ################################################################################ #init collection (with random policy) init_collect_ops = [] for task_idx in range(num_train_tasks): # put init data into the rb + track with the train metric observers = replay_observers[task_idx] + train_metrics # initial collect op init_collect_op = DynamicTrialDriver( tf_env, init_collect_policy, num_trials_to_collect=init_collect_trials_per_task, observers=observers, episodes_per_trial= episodes_per_trial, # policy state will not be reset within these episodes max_episode_len=max_episode_len, ).run() # collect one trial 
init_collect_ops.append(init_collect_op) # data collection for training (with collect policy) collect_ops = [] for task_idx in range(num_train_tasks): collect_op = DynamicTrialDriver( tf_env, meld_agent.collect_policy, num_trials_to_collect=collect_trials_per_task, observers=replay_observers[task_idx] + train_metrics, # put data into 1st RB + track with 1st pol metrics episodes_per_trial= episodes_per_trial, # policy state will not be reset within these episodes max_episode_len=max_episode_len, ).run() # collect one trial collect_ops.append(collect_op) if verbose: print("-- finish constructing policies and collect ops --") print("-- start constructing replay buffer->training pipeline --") ###################################################### # replay buffer --> dataset --> iterate to get trajecs for training ###################################################### # get some data from all task replay buffers (even though won't actually train on all of them) dataset_iterators = [] all_tasks_trajectories_fromdense = [] for task_idx in range(num_train_tasks): dataset = replay_buffers[task_idx].as_dataset( sample_batch_size= sample_episodes_per_task, # number of episodes to sample num_steps=max_episode_len + 1 ).prefetch( 3 ) # +1 to include the last state: a trajectory with n transition has n+1 states # iterator to go through the data dataset_iterator = tf.compat.v1.data.make_initializable_iterator( dataset) dataset_iterators.append(dataset_iterator) # get sample_episodes_per_task sequences, each of length num_steps trajectories_task_i, _ = dataset_iterator.get_next() all_tasks_trajectories_fromdense.append(trajectories_task_i) if load_offline_data: # have separate dataset for relabel data dataset_iterators_withRelabel = [] all_tasks_trajectories_fromdense_withRelabel = [] for task_idx in range(num_train_tasks): dataset = replay_buffers_withRelabel[task_idx].as_dataset( sample_batch_size= sample_episodes_per_task, # number of episodes to sample 
num_steps=offline_episode_len + 1 ).prefetch( 3 ) # +1 to include the last state: a trajectory with n transition has n+1 states # iterator to go through the data dataset_iterator = tf.compat.v1.data.make_initializable_iterator( dataset) dataset_iterators_withRelabel.append(dataset_iterator) # get sample_episodes_per_task sequences, each of length num_steps trajectories_task_i, _ = dataset_iterator.get_next() all_tasks_trajectories_fromdense_withRelabel.append( trajectories_task_i) if verbose: print("-- finish constructing replay buffer->training pipeline --") print("-- start constructing model and AC training ops --") ###################################### # Decoding latent samples into rewards ###################################### latent_samples_1_ph = tf.compat.v1.placeholder( dtype=tf.float32, shape=(None, None, meld_agent._model_network.latent1_size)) latent_samples_2_ph = tf.compat.v1.placeholder( dtype=tf.float32, shape=(None, None, meld_agent._model_network.latent2_size)) decode_rews_op = meld_agent._model_network.decode_latents_into_reward( latent_samples_1_ph, latent_samples_2_ph) ###################################### # Model/Actor/Critic train + summary ops ###################################### # train AC on data from replay buffer if load_offline_data: ac_train_op = meld_agent.train_ac_meld( all_tasks_trajectories_fromdense, all_tasks_trajectories_fromdense_withRelabel) else: ac_train_op = meld_agent.train_ac_meld( all_tasks_trajectories_fromdense) summary_ops = [] for train_metric in train_metrics: summary_ops.append( train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2])) if verbose: print("-- finish constructing AC training ops --") ############################ # Model train + summary ops ############################ # train model on data from replay buffer if load_offline_data: model_train_op, check_step_types = meld_agent.train_model_meld( all_tasks_trajectories_fromdense, all_tasks_trajectories_fromdense_withRelabel) 
else: model_train_op, check_step_types = meld_agent.train_model_meld( all_tasks_trajectories_fromdense) model_summary_ops, model_summary_ops_2 = [], [] for summary_op in tf.compat.v1.summary.all_v2_summary_ops(): if summary_op not in summary_ops: model_summary_ops.append(summary_op) if verbose: print("-- finish constructing model training ops --") print("-- start constructing checkpointers --") ######################## # Eval metrics ######################## with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2]) ######################## # Create savers ######################## train_config_saver = gin.tf.GinConfigSaverHook(train_dir, summarize_config=False) eval_config_saver = gin.tf.GinConfigSaverHook(eval_dir, summarize_config=False) ######################## # Create checkpointers ######################## train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=meld_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'), max_to_keep=1) policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=meld_agent.policy, global_step=global_step, max_to_keep=99999999999 ) # keep many policy checkpoints, in case of future eval rb_checkpointers = [] for buffer_idx in range(len(replay_buffers)): rb_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffers/', "task" + str(buffer_idx)), max_to_keep=1, replay_buffer=replay_buffers[buffer_idx]) rb_checkpointers.append(rb_checkpointer) if load_offline_data: # for LOADING data not for checkpointing. 
No new data going in anyways rb_checkpointers_withRelabel = [] for buffer_idx in range(len(replay_buffers_withRelabel)): ckpt_dir = os.path.join(offline_data_dir, "task" + str(buffer_idx)) rb_checkpointer = common.Checkpointer( ckpt_dir=ckpt_dir, max_to_keep=99999999999, replay_buffer=replay_buffers_withRelabel[buffer_idx]) rb_checkpointers_withRelabel.append(rb_checkpointer) # Notice: these replay buffers need to follow the same sequence of tasks as the current one if verbose: print("-- finish constructing checkpointers --") print("-- start main training loop --") with tf.compat.v1.Session() as sess: ######################## # Initialize ######################## if eval_only: sess.run(eval_summary_writer.init()) load_eval_log( train_eval_dir=train_eval_dir, meld_agent=meld_agent, global_step=global_step, sess=sess, eval_metrics=eval_metrics, eval_py_env=eval_py_env, eval_py_policy=eval_py_policy, num_eval_trials=num_eval_trials, max_episode_len=max_episode_len, episodes_per_trial=episodes_per_trial, log_image_strips=log_image_strips, num_trials_to_render=num_trials_to_render, train_tasks= train_tasks, # in case want to eval on a train task eval_tasks=eval_tasks, model_net=model_net, render_fps=render_fps, decode_rews_op=decode_rews_op, latent_samples_1_ph=latent_samples_1_ph, latent_samples_2_ph=latent_samples_2_ph, ) return # Initialize checkpointing train_checkpointer.initialize_or_restore(sess) for rb_checkpointer in rb_checkpointers: rb_checkpointer.initialize_or_restore(sess) if load_offline_data: for rb_checkpointer in rb_checkpointers_withRelabel: rb_checkpointer.initialize_or_restore(sess) # Initialize dataset iterators for dataset_iterator in dataset_iterators: sess.run(dataset_iterator.initializer) if load_offline_data: for dataset_iterator in dataset_iterators_withRelabel: sess.run(dataset_iterator.initializer) # Initialize variables common.initialize_uninitialized_variables(sess) # Initialize summary writers sess.run(train_summary_writer.init()) 
sess.run(eval_summary_writer.init()) # Initialize savers train_config_saver.after_create_session(sess) eval_config_saver.after_create_session(sess) # Get value of step counter global_step_val = sess.run(global_step) if verbose: print("====== finished initialization ======") ################################################################ # If this is start of new exp (i.e., 1st step) and not continuing old exp # eval rand policy + do initial data collection ################################################################ fresh_start = (global_step_val == 0) if fresh_start: ######################## # Evaluate initial policy ######################## if eval_interval: logging.info( '\n\nDoing evaluation of initial policy on %d trials with randomly sampled tasks', num_eval_trials) perform_eval_and_summaries_meld( eval_metrics, eval_py_env, eval_py_policy, num_eval_trials, max_episode_len, episodes_per_trial, log_image_strips=log_image_strips, num_trials_to_render=num_eval_tasks, eval_tasks=eval_tasks, latent1_size=model_net.latent1_size, latent2_size=model_net.latent2_size, logger=eval_logger, global_step_val=global_step_val, render_fps=render_fps, decode_rews_op=decode_rews_op, latent_samples_1_ph=latent_samples_1_ph, latent_samples_2_ph=latent_samples_2_ph, log_image_observations=log_image_observations, ) sess.run(eval_summary_flush_op) logging.info( 'Done with evaluation of initial (random) policy.\n\n') ######################## # Initial data collection ######################## logging.info( '\n\nGlobal step %d: Beginning init collect op with random policy. 
Collecting %dx {%d, %d} trials for each task', global_step_val, init_collect_trials_per_task, max_episode_len, episodes_per_trial) init_increment_global_step_op = global_step.assign_add( env_steps_per_trial * init_collect_trials_per_task) for task_idx in range(num_train_tasks): logging.info('on task %d / %d', task_idx + 1, num_train_tasks) py_env.set_task_for_env(train_tasks[task_idx]) sess.run([ init_collect_ops[task_idx], init_increment_global_step_op ]) # incremented gs in granularity of task rb_checkpointer.save(global_step=global_step_val) logging.info('Finished init collect.\n\n') else: logging.info( '\n\nGlobal step %d from loaded experiment: Skipping init collect op.\n\n', global_step_val) ######################### # Create calls ######################### # [1] calls for running the policies to collect training data collect_calls = [] increment_global_step_op = global_step.assign_add( env_steps_per_trial * collect_trials_per_task) for task_idx in range(num_train_tasks): collect_calls.append( sess.make_callable( [collect_ops[task_idx], increment_global_step_op])) # [2] call for doing a training step (A + C) ac_train_step_call = sess.make_callable([ac_train_op, summary_ops]) # [3] call for doing a training step (model) model_train_step_call = sess.make_callable( [model_train_op, check_step_types, model_summary_ops]) # [4] call for evaluating what global_step number we're on global_step_call = sess.make_callable(global_step) # reset keeping track of steps/time timed_at_step = global_step_call() time_acc = 0 steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') with train_summary_writer.as_default( ), tf.compat.v2.summary.record_if(True): steps_per_second_summary = tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_second_ph, step=global_step) ################################# # init model training ################################# if fresh_start: logging.info( '\n\nPerforming %d steps of init model 
training, each step on %d random tasks', init_model_train_steps, num_tasks_per_train) for i in range(init_model_train_steps): temp_start = time.time() if i % 100 == 0: print(".... init model training ", i, "/", init_model_train_steps) # init model training total_loss_value_model, check_step_types, _ = model_train_step_call( ) if PRINT_TIMING: print("single model train step: ", time.time() - temp_start) if verbose: print("\n\n\n-- start training loop --\n") ################################# # Training Loop ################################# start_time = time.time() for iteration in range(num_iterations): if iteration > 0: g.finalize() # print("\n\n\niter", iteration, sess.run(curr_iter)) print("global step", global_step_call()) logging.info("Iteration: %d, Global step: %d\n", iteration, global_step_val) #################### # collect data #################### logging.info( '\nStarting batch data collection. Collecting %d {%d, %d} trials for each of %d tasks', collect_trials_per_task, max_episode_len, episodes_per_trial, num_tasks_to_collect_per_iter) # randomly select tasks to collect this iteration list_of_collect_task_idxs = np.random.choice( len(train_tasks), num_tasks_to_collect_per_iter, replace=False) for count, task_idx in enumerate(list_of_collect_task_idxs): logging.info('on randomly selected task %d / %d', count + 1, num_tasks_to_collect_per_iter) # set task for the env py_env.set_task_for_env(train_tasks[task_idx]) # collect data with collect policy _, policy_state_val = collect_calls[task_idx]() logging.info('Finish data collection. 
Global step: %d\n', global_step_call()) #################### # train model #################### if (iteration == 0) or ((iteration % model_train_freq == 0) and (global_step_val < stop_model_training)): logging.info( '\n\nPerforming %d steps of model training, each on %d random tasks', model_train_steps_per_iter, num_tasks_per_train) for model_iter in range(model_train_steps_per_iter): temp_start_2 = time.time() # train model total_loss_value_model, _, _ = model_train_step_call() # print("is logging step", model_iter, sess.run(is_logging_step)) if PRINT_TIMING: print("2: single model train step: ", time.time() - temp_start_2) logging.info('Finish model training. Global step: %d\n', global_step_call()) else: print("SKIPPING MODEL TRAINING") #################### # train actor critic #################### if iteration % ac_train_freq == 0: logging.info( '\n\nPerforming %d steps of AC training, each on %d random tasks \n\n', ac_train_steps_per_iter, num_tasks_per_train) for ac_iter in range(ac_train_steps_per_iter): temp_start_2_ac = time.time() # train ac total_loss_value_ac, _ = ac_train_step_call() if PRINT_TIMING: print("2: single AC train step: ", time.time() - temp_start_2_ac) logging.info('Finish AC training. 
Global step: %d\n', global_step_call()) # add up time time_acc += time.time() - start_time #################### # logging/summaries #################### ### Eval if eval_interval and (iteration % eval_interval == 0): logging.info( '\n\nDoing evaluation of trained policy on %d trials with randomly sampled tasks', num_eval_trials) perform_eval_and_summaries_meld( eval_metrics, eval_py_env, eval_py_policy, num_eval_trials, max_episode_len, episodes_per_trial, log_image_strips=log_image_strips, num_trials_to_render= num_trials_to_render, # hardcoded: or gif will get too long eval_tasks=eval_tasks, latent1_size=model_net.latent1_size, latent2_size=model_net.latent2_size, logger=eval_logger, global_step_val=global_step_call(), render_fps=render_fps, decode_rews_op=decode_rews_op, latent_samples_1_ph=latent_samples_1_ph, latent_samples_2_ph=latent_samples_2_ph, log_image_observations=log_image_observations, ) ### steps_per_second_summary global_step_val = global_step_call() if logging_freq_in_iter and (iteration % logging_freq_in_iter == 0): # log step number + speed (steps/sec) logging.info( 'step = %d, loss = %f', global_step_val, total_loss_value_ac.loss + total_loss_value_model.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info('%.3f env_steps/sec', steps_per_sec) sess.run(steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) # reset keeping track of steps/time timed_at_step = global_step_val time_acc = 0 ### train_checkpoint if train_checkpoint_freq_in_iter and ( iteration % train_checkpoint_freq_in_iter == 0): train_checkpointer.save(global_step=global_step_val) ### policy_checkpointer if policy_checkpoint_freq_in_iter and ( iteration % policy_checkpoint_freq_in_iter == 0): policy_checkpointer.save(global_step=global_step_val) ### rb_checkpointer if rb_checkpoint_freq_in_iter and ( iteration % rb_checkpoint_freq_in_iter == 0): for rb_checkpointer in rb_checkpointers: rb_checkpointer.save(global_step=global_step_val)
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=1000,
        # TODO(kbanoop): rename to policy_fc_layers.
        actor_fc_layers=(100, ),
        # Params for collect
        collect_episodes_per_iteration=2,
        replay_buffer_capacity=2000,
        # Params for train
        learning_rate=1e-3,
        gradient_clipping=None,
        normalize_returns=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=100,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=100,
        rb_checkpoint_interval=200,
        log_interval=100,
        summary_interval=100,
        summaries_flush_secs=1,
        debug_summaries=True,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for Reinforce.

    Builds the graph (environment, actor network, agent, replay buffer,
    collect/train/clear ops, checkpointers and summaries) and then runs
    `num_iterations` rounds of collect -> train -> clear replay buffer inside
    a single `tf.compat.v1.Session`.

    Args:
        root_dir: Base output directory (`~` is expanded). Train summaries
            and checkpoints go under `<root_dir>/train`, eval summaries under
            `<root_dir>/eval`.
        env_name: Environment name handed to `suite_gym.load` for both the
            collection and the evaluation environment.
        num_iterations: Number of collect/train/clear iterations to run.
        actor_fc_layers: `fc_layer_params` of the actor distribution network.
        collect_episodes_per_iteration: Episodes the episode driver collects
            per iteration.
        replay_buffer_capacity: `max_length` of the uniform replay buffer.
        learning_rate: Learning rate of the Adam optimizer.
        gradient_clipping: Forwarded to `ReinforceAgent`.
        normalize_returns: Forwarded to `ReinforceAgent`.
        num_eval_episodes: Episodes per evaluation; also the buffer size of
            the eval metrics.
        eval_interval: Evaluate whenever the global step is a multiple of
            this value.
        train_checkpoint_interval: Save the train checkpoint whenever the
            global step is a multiple of this value.
        policy_checkpoint_interval: Same, for the policy checkpoint.
        rb_checkpoint_interval: Same, for the replay-buffer checkpoint.
        log_interval: Log loss and steps/sec whenever the global step is a
            multiple of this value.
        summary_interval: Summaries are recorded only when
            `global_step % summary_interval == 0`.
        summaries_flush_secs: Flush period of both summary writers, in
            seconds.
        debug_summaries: Forwarded to `ReinforceAgent`.
        summarize_grads_and_vars: Forwarded to `ReinforceAgent`.
        eval_metrics_callback: Forwarded as `callback` to
            `metric_utils.compute_summaries`.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        eval_py_env = suite_gym.load(env_name)
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

        # TODO(kbanoop): Handle distributions without gin.
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers)

        tf_agent = reinforce_agent.ReinforceAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            normalize_returns=normalize_returns,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        # The first two metrics double as the step metrics for the summaries.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        collect_policy = tf_agent.collect_policy

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        # Train on everything collected this iteration, then empty the buffer.
        experience = replay_buffer.gather_all()
        train_op = tf_agent.train(experience)
        clear_rb_op = replay_buffer.clear()

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(train_step=global_step,
                                      step_metrics=train_metrics[:2])

        with eval_summary_writer.as_default(), \
                tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            # TODO(sguada) Remove once Periodically can be saved.
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            # Compute evaluation metrics.
            global_step_call = sess.make_callable(global_step)
            global_step_val = global_step_call()
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(train_op)
            clear_rb_call = sess.make_callable(clear_rb_op)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            # Use the v2 summary API like the rest of this function does;
            # tf.contrib.summary is not available in TF 2.x. The tag is kept
            # unchanged so existing dashboards keep working.
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps/sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                # Only collect/train/clear time counts towards steps/sec.
                start_time = time.time()
                collect_call()
                total_loss = train_step_call()
                clear_rb_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss.loss)
                    steps_per_sec = (
                        (global_step_val - timed_at_step) / time_acc)
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )
def train_eval(
        root_dir,
        env_name='HalfCheetah-v1',
        env_load_fn=suite_mujoco.load,
        num_iterations=2000000,
        actor_fc_layers=(400, 300),
        critic_obs_fc_layers=(400,),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300,),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        num_parallel_environments=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        td_errors_loss_fn=tf.losses.huber_loss,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DDPG.

    Builds the graph (environments, actor/critic networks, agent, replay
    buffer, collect and train ops, checkpointers and summaries), runs an
    initial collection, and then alternates collection and training for
    `num_iterations` iterations inside a single `tf.Session`.

    Args:
        root_dir: Base output directory (`~` is expanded). Train summaries
            and checkpoints go under `<root_dir>/train`, eval summaries under
            `<root_dir>/eval`.
        env_name: Environment name handed to `env_load_fn`.
        env_load_fn: Callable that loads a Python environment from a name.
        num_iterations: Number of collect/train iterations to run.
        actor_fc_layers: `fc_layer_params` of the actor network.
        critic_obs_fc_layers: `observation_fc_layer_params` of the critic.
        critic_action_fc_layers: `action_fc_layer_params` of the critic.
        critic_joint_fc_layers: `joint_fc_layer_params` of the critic.
        initial_collect_steps: Steps collected once before training starts.
        collect_steps_per_iteration: Steps collected in every iteration.
        num_parallel_environments: When greater than 1, that many
            environments are wrapped in a `ParallelPyEnvironment`.
        replay_buffer_capacity: `max_length` of the uniform replay buffer.
        ou_stddev: Forwarded to `DdpgAgent`.
        ou_damping: Forwarded to `DdpgAgent`.
        target_update_tau: Forwarded to `DdpgAgent`.
        target_update_period: Forwarded to `DdpgAgent`.
        train_steps_per_iteration: Train-op runs per iteration.
        batch_size: `sample_batch_size` of the replay-buffer dataset.
        actor_learning_rate: Learning rate of the actor's Adam optimizer.
        critic_learning_rate: Learning rate of the critic's Adam optimizer.
        dqda_clipping: Forwarded to `DdpgAgent`.
        td_errors_loss_fn: Forwarded to `DdpgAgent`.
        gamma: Forwarded to `DdpgAgent`.
        reward_scale_factor: Forwarded to `DdpgAgent`.
        gradient_clipping: Forwarded to `DdpgAgent`.
        num_eval_episodes: Episodes per evaluation; also the buffer size of
            the eval metrics.
        eval_interval: Evaluate whenever the global step is a multiple of
            this value.
        train_checkpoint_interval: Save the train checkpoint whenever the
            global step is a multiple of this value.
        policy_checkpoint_interval: Same, for the policy checkpoint.
        rb_checkpoint_interval: Same, for the replay-buffer checkpoint.
        log_interval: Log loss and steps/sec whenever the global step is a
            multiple of this value.
        summary_interval: Passed to
            `record_summaries_every_n_global_steps`.
        summaries_flush_secs: Flush period of both summary writers, in
            seconds.
        debug_summaries: Forwarded to `DdpgAgent`.
        summarize_grads_and_vars: Forwarded to `DdpgAgent`.
        eval_metrics_callback: Forwarded as `callback` to
            `metric_utils.compute_summaries`.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.contrib.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.contrib.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    # TODO(kbanoop): Figure out if it is possible to avoid the with block.
    with tf.contrib.summary.record_summaries_every_n_global_steps(
            summary_interval):
        if num_parallel_environments > 1:
            tf_env = tf_py_environment.TFPyEnvironment(
                parallel_py_environment.ParallelPyEnvironment(
                    [lambda: env_load_fn(env_name)] *
                    num_parallel_environments))
        else:
            tf_env = tf_py_environment.TFPyEnvironment(
                env_load_fn(env_name))
        eval_py_env = env_load_fn(env_name)

        actor_net = actor_network.ActorNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers,
        )

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_network.CriticNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
        )

        tf_agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec(),
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy())

        # The first two metrics double as the step metrics for the summaries.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        global_step = tf.train.get_or_create_global_step()

        collect_policy = tf_agent.collect_policy()
        # The initial collection fills the buffer only; it has no metric
        # observers attached.
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch],
            num_steps=initial_collect_steps).run()

        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=2).prefetch(3)

        iterator = dataset.make_initializable_iterator()
        trajectories, unused_info = iterator.get_next()
        train_op = tf_agent.train(
            experience=trajectories, train_step_counter=global_step)

        train_checkpointer = common_utils.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=tf.contrib.checkpoint.List(train_metrics))
        policy_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy(),
            global_step=global_step)
        rb_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(step_metrics=train_metrics[:2])
        # Snapshot the summary ops before the eval summaries are created
        # below.
        summary_op = tf.contrib.summary.all_summary_ops()

        with eval_summary_writer.as_default(), \
                tf.contrib.summary.always_record_summaries():
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(sguada) Remove once Periodically can be saved.
            common_utils.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            tf.contrib.summary.initialize(session=sess)
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(
                [train_op, summary_op, global_step])

            timed_at_step = sess.run(global_step)
            time_acc = 0
            steps_per_second_ph = tf.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            for _ in range(num_iterations):
                # Only collect/train time counts towards steps/sec.
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _, global_step_val = train_step_call()
                time_acc += time.time() - start_time

                if global_step_val % log_interval == 0:
                    tf.logging.info('step = %d, loss = %f', global_step_val,
                                    loss_info_value.loss)
                    steps_per_sec = (
                        (global_step_val - timed_at_step) / time_acc)
                    # Pass the value as a logging argument so formatting is
                    # deferred, matching the call above.
                    tf.logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(
                        steps_per_second_summary,
                        feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        fc_layer_params=(100, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        log_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN.

    Builds a DQN agent in graph mode, fills a Python replay buffer with
    transitions from a random policy, then alternates between collecting
    experience with the agent's collect policy and running train steps inside
    a `tf.compat.v1.Session`. Periodically logs throughput, writes
    checkpoints and computes evaluation metrics.

    Args:
      root_dir: Base output directory (`~` is expanded). Training artifacts
        go to `<root_dir>/train`, evaluation summaries to `<root_dir>/eval`.
      env_name: Name passed to `suite_gym.load` for both the collection and
        the evaluation environment.
      num_iterations: Number of outer collect/train iterations.
      fc_layer_params: Forwarded to `q_network.QNetwork`.
      initial_collect_steps: Number of random-policy steps used to seed the
        replay buffer before training starts.
      collect_steps_per_iteration: Collect-policy steps per outer iteration.
      epsilon_greedy: Forwarded to `dqn_agent.DqnAgent`.
      replay_buffer_capacity: Capacity of the Python uniform replay buffer.
      target_update_tau: Forwarded to `dqn_agent.DqnAgent`.
      target_update_period: Forwarded to `dqn_agent.DqnAgent`.
      train_steps_per_iteration: Train-op invocations per outer iteration.
      batch_size: `sample_batch_size` used when sampling the replay buffer.
      learning_rate: Learning rate of the Adam optimizer.
      n_step_update: Forwarded to `dqn_agent.DqnAgent`; sampled trajectories
        have `n_step_update + 1` steps.
      gamma: Forwarded to `dqn_agent.DqnAgent`.
      reward_scale_factor: Forwarded to `dqn_agent.DqnAgent`.
      gradient_clipping: Forwarded to `dqn_agent.DqnAgent`.
      num_eval_episodes: Episodes per evaluation; also the buffer size of
        the evaluation metrics.
      eval_interval: Evaluate when the global step is a multiple of this.
      train_checkpoint_interval: Save the train checkpoint when the global
        step is a multiple of this.
      policy_checkpoint_interval: Save the policy checkpoint when the global
        step is a multiple of this.
      log_interval: Log loss and throughput when the global step is a
        multiple of this.
      summaries_flush_secs: Flush period, in seconds, of both summary
        writers.
      debug_summaries: Forwarded to `dqn_agent.DqnAgent`.
      summarize_grads_and_vars: Forwarded to `dqn_agent.DqnAgent`.
      eval_metrics_callback: Passed as `callback` to
        `metric_utils.compute_summaries`.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    flush_millis = summaries_flush_secs * 1000
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=flush_millis)
    # Summaries created below (outside an explicit writer scope) go to the
    # train writer.
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=flush_millis)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    # Note this is a python environment.
    env = batched_py_environment.BatchedPyEnvironment(
        [suite_gym.load(env_name)])
    eval_py_env = suite_gym.load(env_name)

    # Convert specs to BoundedTensorSpec.
    action_spec = tensor_spec.from_spec(env.action_spec())
    observation_spec = tensor_spec.from_spec(env.observation_spec())
    time_step_spec = ts.time_step_spec(observation_spec)

    # Reuse the converted specs instead of converting the env specs again.
    q_net = q_network.QNetwork(
        observation_spec, action_spec, fc_layer_params=fc_layer_params)

    # The agent must be in graph.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    agent = dqn_agent.DqnAgent(
        time_step_spec,
        action_spec,
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        n_step_update=n_step_update,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_policy.PyTFPolicy(tf_collect_policy)
    greedy_policy = py_tf_policy.PyTFPolicy(agent.policy)
    random_policy = random_py_policy.RandomPyPolicy(env.time_step_spec(),
                                                    env.action_spec())

    # Python replay buffer.
    replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
        capacity=replay_buffer_capacity,
        data_spec=tensor_spec.to_nest_array_spec(agent.collect_data_spec))

    time_step = env.reset()

    # Initialize the replay buffer with some transitions. We use the random
    # policy to initialize the replay buffer to make sure we get a good
    # distribution of actions.
    for _ in range(initial_collect_steps):
        time_step = collect_step(env, time_step, random_policy, replay_buffer)

    # TODO(b/112041045) Use global_step as counter.
    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir, agent=agent, global_step=global_step)

    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=agent.policy,
        global_step=global_step)

    # Sample trajectories of n_step_update + 1 steps from the replay buffer.
    ds = replay_buffer.as_dataset(
        sample_batch_size=batch_size, num_steps=n_step_update + 1)
    ds = ds.prefetch(4)
    itr = tf.compat.v1.data.make_initializable_iterator(ds)

    experience = itr.get_next()

    train_op = common.function(agent.train)(experience)

    with eval_summary_writer.as_default(), \
            tf.compat.v2.summary.record_if(True):
        for eval_metric in eval_metrics:
            eval_metric.tf_summaries(train_step=global_step)

    with tf.compat.v1.Session() as session:
        train_checkpointer.initialize_or_restore(session)
        common.initialize_uninitialized_variables(session)
        session.run(itr.initializer)
        # Run the agent's initialization op.
        # NOTE(review): for DQN this presumably syncs the target Q-network
        # with the online Q-network -- confirm against DqnAgent.
        session.run(agent.initialize())
        train = session.make_callable(train_op)
        global_step_call = session.make_callable(global_step)
        session.run(train_summary_writer.init())
        session.run(eval_summary_writer.init())

        # Compute initial evaluation metrics.
        global_step_val = global_step_call()
        metric_utils.compute_summaries(
            eval_metrics,
            eval_py_env,
            greedy_policy,
            num_episodes=num_eval_episodes,
            global_step=global_step_val,
            log=True,
            callback=eval_metrics_callback,
        )

        # Throughput bookkeeping: steps since `timed_at_step` divided by the
        # collect/train time accumulated since the last log.
        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0

        steps_per_second_ph = tf.compat.v1.placeholder(
            tf.float32, shape=(), name='steps_per_sec_ph')
        steps_per_second_summary = tf.compat.v2.summary.scalar(
            name='global_steps_per_sec',
            data=steps_per_second_ph,
            step=global_step)

        for _ in range(num_iterations):
            start_time = time.time()
            for _ in range(collect_steps_per_iteration):
                time_step = collect_step(env, time_step, collect_policy,
                                         replay_buffer)
            collect_time += time.time() - start_time

            start_time = time.time()
            for _ in range(train_steps_per_iteration):
                loss = train()
            train_time += time.time() - start_time

            global_step_val = global_step_call()
            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             loss.loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                session.run(steps_per_second_summary,
                            feed_dict={steps_per_second_ph: steps_per_sec})
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %s, train_time = %s',
                             collect_time, train_time)
                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % eval_interval == 0:
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    greedy_policy,
                    num_episodes=num_eval_episodes,
                    global_step=global_step_val,
                    log=True,
                    callback=eval_metrics_callback,
                )
                # No timing reset is needed here: evaluation time is never
                # added to collect_time/train_time. Resetting timed_at_step
                # without also clearing those accumulators would under-report
                # steps/sec whenever eval_interval is not a multiple of
                # log_interval.