Exemple #1
0
    def _initialize_graph(self, sess):
        """Bring every piece of training state to a runnable state in `sess`.

        Args:
          sess: Session in which all initializers below are executed.
        """
        for checkpointer in (self._train_checkpointer, self._rb_checkpointer):
            checkpointer.initialize_or_restore(sess)
        # TODO(sguada) Remove once Periodically can be saved.
        common_utils.initialize_uninitialized_variables(sess)

        sess.run(self._ds_itr.initializer)
        sess.run(self._init_agent_op)

        # A single callable that fetches the train op and the summary op
        # together.
        fetches = [self._train_op, self._summary_op]
        self._train_step_call = sess.make_callable(fetches)

        # One timer per phase that is measured separately.
        self._collect_timer = timer.Timer()
        self._train_timer = timer.Timer()
        self._action_timer = timer.Timer()
        self._step_timer = timer.Timer()
        self._observer_timer = timer.Timer()

        step = sess.run(self._global_step)
        self._timed_at_step = step

        # Saving once here initializes the save_counter; it has to happen
        # before the graph is finalized.
        for checkpointer in (self._train_checkpointer,
                             self._policy_checkpointer,
                             self._rb_checkpointer):
            checkpointer.save(global_step=step)

        tf.contrib.summary.initialize(
            session=sess, graph=tf.get_default_graph())
Exemple #2
0
    def _initialize_graph(self, sess):
        """Bring every piece of training state to a runnable state in `sess`.

        Args:
          sess: Session in which all initializers below are executed.
        """
        for checkpointer in (self._train_checkpointer, self._rb_checkpointer):
            checkpointer.initialize_or_restore(sess)
        common.initialize_uninitialized_variables(sess)

        sess.run(self._init_agent_op)

        self._train_step_call = sess.make_callable(self._train_op)

        # One timer per phase that is measured separately.
        self._collect_timer = timer.Timer()
        self._train_timer = timer.Timer()
        self._action_timer = timer.Timer()
        self._step_timer = timer.Timer()
        self._observer_timer = timer.Timer()

        step = sess.run(self._global_step)
        self._timed_at_step = step

        # Saving once here initializes the save_counter; it has to happen
        # before the graph is finalized.
        for checkpointer in (self._train_checkpointer,
                             self._policy_checkpointer,
                             self._rb_checkpointer):
            checkpointer.save(global_step=step)
        sess.run(self._train_summary_writer.init())

        # The eval writer is only initialized when evaluation is enabled.
        if not self._do_eval:
            return
        sess.run(self._eval_summary_writer.init())
Exemple #3
0
    def testObjectiveDependentLosses(self):
        """Checks the agent loss when objectives use different loss functions.

        The entries at index 1 and 2 of the network/loss sequence are re-paired
        with sigmoid cross entropy and absolute difference respectively; the
        entry at index 0 keeps whatever
        `_create_objective_network_and_loss_fn_sequence` supplies.
        """
        networks_and_loss_fns = (
            self._create_objective_network_and_loss_fn_sequence())
        networks_and_loss_fns[1] = (networks_and_loss_fns[1][0],
                                    tf.compat.v1.losses.sigmoid_cross_entropy)
        networks_and_loss_fns[2] = (networks_and_loss_fns[2][0],
                                    tf.compat.v1.losses.absolute_difference)
        agent = greedy_multi_objective_agent.GreedyMultiObjectiveNeuralAgent(
            self._time_step_spec,
            self._action_spec,
            self._scalarizer,
            objective_network_and_loss_fn_sequence=networks_and_loss_fns,
            optimizer=None)
        observations = np.array([[0.1, 0.2], [1, 0.5]], dtype=np.float32)
        actions = np.array([0, 1], dtype=np.int32)
        objectives = np.array([[0.2, 1, 1.5], [4, 0, 5.5]], dtype=np.float32)
        initial_step, final_step = _get_initial_and_final_steps(
            observations, objectives)
        action_step = _get_action_step(actions)
        experience = _get_experience(initial_step, action_step, final_step)

        init_op = agent.initialize()
        if not tf.executing_eagerly():
            with self.cached_session() as sess:
                common.initialize_uninitialized_variables(sess)
                self.assertIsNone(sess.run(init_op))
        loss, _ = agent._loss(experience)
        # `initialize_all_variables` is deprecated; this is its documented
        # replacement and builds the same initializer.
        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.assertAllClose(self.evaluate(loss), 2.410641)
Exemple #4
0
    def testComputeLossWithArmFeatures(self):
        """A NeuralConstraint fed per-arm observations reports a positive loss."""
        observation_spec = bandit_spec_utils.create_per_arm_observation_spec(
            global_dim=2, per_arm_dim=3, num_actions=3)
        step_spec = ts.time_step_spec(observation_spec)
        network = (
            global_and_arm_feature_network
            .create_feed_forward_common_tower_network(
                observation_spec,
                global_layers=(4, ),
                arm_layers=(4, ),
                common_layers=(4, )))
        constraint = constraints.NeuralConstraint(
            step_spec, self._action_spec, constraint_network=network)

        # Two samples: a 2-d global feature and a 3x3 per-arm feature each.
        observations = {
            bandit_spec_utils.GLOBAL_FEATURE_KEY:
                tf.constant([[1, 2], [3, 4]], dtype=tf.float32),
            bandit_spec_utils.PER_ARM_FEATURE_KEY:
                tf.cast(
                    tf.reshape(tf.range(18), shape=[2, 3, 3]),
                    dtype=tf.float32),
        }
        actions = tf.constant([0, 1], dtype=tf.int32)
        rewards = tf.constant([0.5, 3.0], dtype=tf.float32)

        initializer = constraint.initialize()
        if not tf.executing_eagerly():
            with self.cached_session() as session:
                common.initialize_uninitialized_variables(session)
                self.assertIsNone(session.run(initializer))
        loss = constraint.compute_loss(observations, actions, rewards)
        self.assertGreater(self.evaluate(loss), 0.0)
def load_pol_ckpt(train_eval_dir, sess, eval_global_step, meld_agent,
                  global_step, eval_second_pol):
    """Restore a policy checkpoint into `sess` and sync the global step.

    Args:
      train_eval_dir: Directory whose `train` sub-directory holds the policy
        checkpoint folders.
      sess: Session the checkpoint is restored into.
      eval_global_step: Step passed to `set_loading_step` to pick the
        checkpoint to load.
      meld_agent: Agent whose `policy` is tracked by the checkpointer.
      global_step: Global step handed to the checkpointer and to
        `set_global_step`.
      eval_second_pol: When truthy, the `policy2` folder is used instead of
        `policy`.

    Returns:
      A tuple of the value returned by `set_loading_step` and the value
      returned by the checkpointer's `initialize_or_restore`.
    """
    train_dir = os.path.join(train_eval_dir, 'train')
    pol_name = 'policy2' if eval_second_pol else 'policy'

    actual_loaded_step = set_loading_step(
        train_dir, eval_global_step, pol_name)

    # keep many policy checkpoints, in case of future eval
    checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, pol_name),
        policy=meld_agent.policy,
        global_step=global_step,
        max_to_keep=99999999999)
    policy_status = checkpointer.initialize_or_restore(sess)

    # Anything the restore did not cover still needs an initial value.
    common.initialize_uninitialized_variables(sess)

    set_global_step(global_step, sess, actual_loaded_step)

    # make the checkpoint file pointing back to the latest checkpoint
    set_loading_step(train_dir, step=None)

    return actual_loaded_step, policy_status
Exemple #6
0
 def testInitializeAgent(self):
   """In graph mode, running the agent's initialize op yields None."""
   ts_agent = bern_ts_agent.BernoulliThompsonSamplingAgent(
       self._time_step_spec, self._action_spec)
   initializer = ts_agent.initialize()
   if tf.executing_eagerly():
     # Eager mode: no session run is performed and nothing more is checked.
     return
   with self.cached_session() as session:
     common.initialize_uninitialized_variables(session)
     self.assertIsNone(session.run(initializer))
Exemple #7
0
 def testInitializeConstraint(self):
     """In graph mode, running the constraint's initialize op yields None."""
     network = DummyNet(self._observation_spec, self._action_spec)
     constraint = constraints.NeuralConstraint(
         self._time_step_spec, self._action_spec, constraint_network=network)
     initializer = constraint.initialize()
     if tf.executing_eagerly():
         # Eager mode: no session run is performed and nothing more is checked.
         return
     with self.cached_session() as session:
         common.initialize_uninitialized_variables(session)
         self.assertIsNone(session.run(initializer))
Exemple #8
0
 def testInitializeAgent(self, agent_class):
     """In graph mode, running the `agent_class` initialize op yields None.

     Args:
       agent_class: Agent constructor under test; it is called with the
         time step spec, the action spec, a `q_network` and `optimizer=None`.
     """
     network = DummyNet(self._observation_spec, self._action_spec)
     agent_under_test = agent_class(
         self._time_step_spec,
         self._action_spec,
         q_network=network,
         optimizer=None)
     initializer = agent_under_test.initialize()
     if tf.executing_eagerly():
         # Eager mode: no session run is performed and nothing more is checked.
         return
     with self.cached_session() as session:
         common.initialize_uninitialized_variables(session)
         self.assertIsNone(session.run(initializer))
Exemple #9
0
    def initialize(self, batch_size, graph=None):
        """Build the policy and initialize its variables, exactly once.

        Args:
          batch_size: Forwarded to `self._construct`.
          graph: Graph to construct in; the default graph is used when this is
            falsy.

        Raises:
          RuntimeError: If the policy has already been built.
        """
        if self._built:
            raise RuntimeError('PyTFPolicy can only be initialized once.')

        target_graph = graph if graph else tf.compat.v1.get_default_graph()
        self._construct(batch_size, target_graph)

        # Only the wrapped policy's own variables are passed for
        # initialization.
        policy_variables = tf.nest.flatten(self._tf_policy.variables())
        common.initialize_uninitialized_variables(
            self.session, policy_variables)
        self._built = True
Exemple #10
0
 def testInitializeAgent(self):
     """In graph mode, running the multi-objective agent's init op yields None."""
     objective_networks = self._create_objective_networks()
     multi_objective_agent = (
         greedy_multi_objective_agent.GreedyMultiObjectiveNeuralAgent(
             self._time_step_spec,
             self._action_spec,
             self._scalarizer,
             objective_networks=objective_networks,
             optimizer=None))
     initializer = multi_objective_agent.initialize()
     if tf.executing_eagerly():
         # Eager mode: no session run is performed and nothing more is checked.
         return
     with self.cached_session() as session:
         common.initialize_uninitialized_variables(session)
         self.assertIsNone(session.run(initializer))
 def testInitializeAgent(self):
     """In graph mode, running the greedy agent's initialize op yields None."""
     network = DummyNet(self._observation_spec, self._action_spec)
     greedy = greedy_agent.GreedyRewardPredictionAgent(
         self._time_step_spec,
         self._action_spec,
         reward_network=network,
         optimizer=None)
     initializer = greedy.initialize()
     if tf.executing_eagerly():
         # Eager mode: no session run is performed and nothing more is checked.
         return
     with self.cached_session() as session:
         common.initialize_uninitialized_variables(session)
         self.assertIsNone(session.run(initializer))
Exemple #12
0
def collect(tf_env,
            tf_policy,
            output_dir,
            checkpoint=None,
            num_iterations=500000,
            episodes_per_file=500,
            summary_interval=1000):
    """Run `tf_policy` in `tf_env` and write the experience to TFRecords.

    Args:
      tf_env: Environment stepped by the collection driver.
      tf_policy: Policy used for collection; its `trajectory_spec` is the
        replay buffer's data spec.
      output_dir: Directory the data files are written to (file prefix
        `data`); created when it does not exist.
      checkpoint: Optional checkpoint to restore before collecting. Either a
        checkpoint path, or a directory whose `train` sub-directory is
        searched for the latest checkpoint.
      num_iterations: Number of times the one-step collect op is run.
      episodes_per_file: Forwarded to `TFRecordReplayBuffer`.
      summary_interval: Summaries are recorded only when the global step is a
        multiple of this value.

    Raises:
      ValueError: If `checkpoint` is a directory but no checkpoint is found
        under its `train` sub-directory.
    """
    if not os.path.isdir(output_dir):
        logger.info('Making output directory %s...', output_dir)
        os.makedirs(output_dir)

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):

        # Make the replay buffer.
        replay_buffer = tfrecord_replay_buffer.TFRecordReplayBuffer(
            data_spec=tf_policy.trajectory_spec,
            experiment_id='exp',
            file_prefix=os.path.join(output_dir, 'data'),
            episodes_per_file=episodes_per_file)
        replay_observer = [replay_buffer.add_batch]

        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env, tf_policy, observers=replay_observer,
            num_steps=1).run()

        with tf.compat.v1.Session() as sess:
            # Initialization stays best-effort, but a failure is now logged
            # instead of vanishing silently.
            try:
                common.initialize_uninitialized_variables(sess)
            except Exception as e:  # pylint: disable=broad-except
                logger.warning(
                    'Variable initialization failed; continuing: %s', e)

            # Restore checkpoint.
            if checkpoint is not None:
                if os.path.isdir(checkpoint):
                    train_dir = os.path.join(checkpoint, 'train')
                    checkpoint_path = tf.train.latest_checkpoint(train_dir)
                    if checkpoint_path is None:
                        # `latest_checkpoint` returns None when nothing is
                        # found; fail with a message naming the directory.
                        raise ValueError(
                            'No checkpoint found in %s' % train_dir)
                else:
                    checkpoint_path = checkpoint

                # Use the compat.v1 Saver, matching the compat.v1 Session
                # above; `tf.train.Saver` is absent from the TF2 namespace.
                restorer = tf.compat.v1.train.Saver(name='restorer')
                restorer.restore(sess, checkpoint_path)

            collect_call = sess.make_callable(collect_op)
            for _ in range(num_iterations):
                collect_call()
Exemple #13
0
 def testInitializeAgent(self, agent_class, run_mode):
     """In graph mode, running the `agent_class` initialize op yields None.

     Args:
       agent_class: Agent constructor under test.
       run_mode: Callable returning the context manager the test body runs
         under; the test is skipped when it is `context.graph_mode` while
         eager execution is already on.
     """
     graph_mode_requested_under_eager = (
         tf.executing_eagerly() and run_mode == context.graph_mode)
     if graph_mode_requested_under_eager:
         self.skipTest('b/123778560')
     with run_mode():
         network = DummyNet(self._observation_spec, self._action_spec)
         agent_under_test = agent_class(
             self._time_step_spec,
             self._action_spec,
             q_network=network,
             optimizer=None)
         initializer = agent_under_test.initialize()
         if tf.executing_eagerly():
             # Eager mode: no session run is performed.
             return
         with self.cached_session() as session:
             common.initialize_uninitialized_variables(session)
             self.assertIsNone(session.run(initializer))
Exemple #14
0
  def testComputeActionFeasibility(self):
    """A NeuralConstraint over DummyNet reports feasibility 1 everywhere."""
    network = DummyNet(self._observation_spec, self._action_spec)
    constraint = constraints.NeuralConstraint(
        self._time_step_spec, self._action_spec, constraint_network=network)

    initializer = constraint.initialize()
    if not tf.executing_eagerly():
      with self.cached_session() as session:
        common.initialize_uninitialized_variables(session)
        self.assertIsNone(session.run(initializer))

    batch = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
    feasibility = constraint(batch)
    # Expect an all-ones 2x3 result for the two observations.
    expected = np.ones([2, 3])
    self.assertAllClose(self.evaluate(feasibility), expected)
Exemple #15
0
  def testComputeActionFeasibility(self):
    """QuantileConstraint feasibility values all lie within [0, 1]."""
    network = DummyNet(self._observation_spec, self._action_spec)
    constraint = constraints.QuantileConstraint(
        self._time_step_spec, self._action_spec, constraint_network=network)

    initializer = constraint.initialize()
    if not tf.executing_eagerly():
      with self.cached_session() as session:
        common.initialize_uninitialized_variables(session)
        self.assertIsNone(session.run(initializer))

    batch = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
    feasibility = constraint(batch)
    # The value is evaluated separately for each bound check.
    self.assertAllGreaterEqual(self.evaluate(feasibility), 0.0)
    self.assertAllLessEqual(self.evaluate(feasibility), 1.0)
Exemple #16
0
    def testComputeLoss(self):
        """NeuralConstraint.compute_loss on the fixed batch equals 42.25."""
        network = DummyNet(self._observation_spec, self._action_spec)
        obs_batch = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
        action_batch = tf.constant([0, 1], dtype=tf.int32)
        reward_batch = tf.constant([0.5, 3.0], dtype=tf.float32)

        constraint = constraints.NeuralConstraint(
            self._time_step_spec,
            self._action_spec,
            constraint_network=network)
        initializer = constraint.initialize()
        if not tf.executing_eagerly():
            with self.cached_session() as session:
                common.initialize_uninitialized_variables(session)
                self.assertIsNone(session.run(initializer))

        computed = constraint.compute_loss(
            obs_batch, action_batch, reward_batch)
        self.assertAllClose(self.evaluate(computed), 42.25)
Exemple #17
0
    def testLoss(self):
        """The multi-objective agent's loss on the fixed batch is 0."""
        agent = greedy_multi_objective_agent.GreedyMultiObjectiveNeuralAgent(
            self._time_step_spec,
            self._action_spec,
            self._scalarizer,
            objective_networks=self._create_objective_networks(),
            optimizer=None)
        observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
        actions = tf.constant([0, 1], dtype=tf.int32)
        objectives = tf.constant([[8, 12, 11], [25, 18, 32]], dtype=tf.float32)

        init_op = agent.initialize()
        if not tf.executing_eagerly():
            with self.cached_session() as sess:
                common.initialize_uninitialized_variables(sess)
                self.assertIsNone(sess.run(init_op))
        loss, _ = agent.loss(observations, actions, objectives)
        # `initialize_all_variables` is deprecated; this is its documented
        # replacement and builds the same initializer.
        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.assertAllClose(self.evaluate(loss), 0.0)
    def testLoss(self):
        """The greedy reward agent's loss on the fixed batch is 42.25."""
        reward_net = DummyNet(self._observation_spec, self._action_spec)
        observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
        actions = tf.constant([0, 1], dtype=tf.int32)
        rewards = tf.constant([0.5, 3.0], dtype=tf.float32)

        agent = greedy_agent.GreedyRewardPredictionAgent(
            self._time_step_spec,
            self._action_spec,
            reward_network=reward_net,
            optimizer=None)
        init_op = agent.initialize()
        if not tf.executing_eagerly():
            with self.cached_session() as sess:
                common.initialize_uninitialized_variables(sess)
                self.assertIsNone(sess.run(init_op))
        loss, _ = agent.loss(observations, actions, rewards)
        # `initialize_all_variables` is deprecated; this is its documented
        # replacement and builds the same initializer.
        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.assertAllClose(self.evaluate(loss), 42.25)
Exemple #19
0
  def testTrainAgent(self):
    """One `_train` call on the fixed experience reports a loss of -1."""
    observations = np.array([[1, 1]], dtype=np.float32)
    actions = np.array([0, 1], dtype=np.int32)
    rewards = np.array([0.0, 1.0], dtype=np.float32)
    initial_step, final_step = _get_initial_and_final_steps(
        observations, rewards)
    action_step = _get_action_step(actions)
    experience = _get_experience(initial_step, action_step, final_step)

    agent = bern_ts_agent.BernoulliThompsonSamplingAgent(
        self._time_step_spec,
        self._action_spec,
        batch_size=2)
    init_op = agent.initialize()
    if not tf.executing_eagerly():
      with self.cached_session() as sess:
        common.initialize_uninitialized_variables(sess)
        self.assertIsNone(sess.run(init_op))
    loss, _ = agent._train(experience, weights=None)
    # `initialize_all_variables` is deprecated; this is its documented
    # replacement and builds the same initializer.
    self.evaluate(tf.compat.v1.global_variables_initializer())
    # The loss is -sum(rewards).
    self.assertAllClose(self.evaluate(loss), -1.0)
Exemple #20
0
    def testLoss(self):
        """The greedy agent's `_loss` on nested-reward experience is 42.25."""
        reward_net = DummyNet(self._observation_spec, self._action_spec)
        observations = np.array([[1, 2], [3, 4]], dtype=np.float32)
        actions = np.array([0, 1], dtype=np.int32)
        rewards = np.array([0.5, 3.0], dtype=np.float32)
        initial_step, final_step = _get_initial_and_final_steps_nested_rewards(
            observations, rewards)
        action_step = _get_action_step(actions)
        experience = _get_experience(initial_step, action_step, final_step)

        agent = greedy_agent.GreedyRewardPredictionAgent(
            self._time_step_spec,
            self._action_spec,
            reward_network=reward_net,
            optimizer=None)
        init_op = agent.initialize()
        if not tf.executing_eagerly():
            with self.cached_session() as sess:
                common.initialize_uninitialized_variables(sess)
                self.assertIsNone(sess.run(init_op))
        loss, _ = agent._loss(experience)
        # `initialize_all_variables` is deprecated; this is its documented
        # replacement and builds the same initializer.
        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.assertAllClose(self.evaluate(loss), 42.25)
Exemple #21
0
    def testLoss(self):
        """The multi-objective agent's `_loss` on the fixed experience is 0."""
        agent = greedy_multi_objective_agent.GreedyMultiObjectiveNeuralAgent(
            self._time_step_spec,
            self._action_spec,
            self._scalarizer,
            objective_network_and_loss_fn_sequence=(
                self._create_objective_network_and_loss_fn_sequence()),
            optimizer=None)
        observations = np.array([[1, 2], [3, 4]], dtype=np.float32)
        actions = np.array([0, 1], dtype=np.int32)
        objectives = np.array([[8, 12, 11], [25, 18, 32]], dtype=np.float32)
        initial_step, final_step = _get_initial_and_final_steps(
            observations, objectives)
        action_step = _get_action_step(actions)
        experience = _get_experience(initial_step, action_step, final_step)

        init_op = agent.initialize()
        if not tf.executing_eagerly():
            with self.cached_session() as sess:
                common.initialize_uninitialized_variables(sess)
                self.assertIsNone(sess.run(init_op))
        loss, _ = agent._loss(experience)
        # `initialize_all_variables` is deprecated; this is its documented
        # replacement and builds the same initializer.
        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.assertAllClose(self.evaluate(loss), 0.0)
Exemple #22
0
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=100000,
    fc_layer_params=(100,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    agent_class=dqn_agent.DqnAgent,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN.

  Builds the environment, Q-network, agent, replay buffer, drivers and
  checkpointers in graph mode, then runs a collect/train loop inside a
  `tf.compat.v1.Session`, periodically logging, checkpointing and evaluating.

  Args:
    root_dir: Base directory (`~` is expanded). Summaries go to its `train`
      and `eval` sub-directories; checkpoints go to `train`, `train/policy`
      and `train/replay_buffer`.
    env_name: Name passed to `suite_gym.load` for both the training and the
      eval environment.
    num_iterations: Number of iterations of the main collect/train loop.
    fc_layer_params: Forwarded to `q_network.QNetwork`.
    initial_collect_steps: Number of steps collected with a random policy
      before the main loop starts.
    collect_steps_per_iteration: Number of steps the collect driver runs per
      loop iteration.
    epsilon_greedy: Forwarded to `agent_class`.
    replay_buffer_capacity: `max_length` of the replay buffer.
    target_update_tau: Forwarded to `agent_class`.
    target_update_period: Forwarded to `agent_class`.
    train_steps_per_iteration: Number of train-step calls per loop iteration.
    batch_size: `sample_batch_size` of the replay-buffer dataset.
    learning_rate: Learning rate of the Adam optimizer.
    gamma: Forwarded to `agent_class`.
    reward_scale_factor: Forwarded to `agent_class`.
    gradient_clipping: Forwarded to `agent_class`.
    num_eval_episodes: Buffer size of the eval metrics and number of episodes
      per evaluation.
    eval_interval: Evaluate whenever the global step is a multiple of this.
    train_checkpoint_interval: Save the train checkpoint whenever the global
      step is a multiple of this.
    policy_checkpoint_interval: Same, for the policy checkpoint.
    rb_checkpoint_interval: Same, for the replay-buffer checkpoint.
    log_interval: Log loss and steps/sec whenever the global step is a
      multiple of this.
    summary_interval: Summaries are recorded only when the global step is a
      multiple of this.
    summaries_flush_secs: Flush period of both summary writers, in seconds.
    agent_class: Agent constructor; called with the keyword arguments listed
      in the body below.
    debug_summaries: Forwarded to `agent_class`.
    summarize_grads_and_vars: Forwarded to `agent_class`.
    eval_metrics_callback: Forwarded as `callback` to
      `metric_utils.compute_summaries`.
  """
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  # The train writer becomes the default writer; the eval writer is only used
  # inside the explicit `as_default()` scope further down.
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  # Everything below records summaries only every `summary_interval` steps.
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    # Separate environment instances for training (TF-wrapped) and eval.
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
    eval_py_env = suite_gym.load(env_name)

    q_net = q_network.QNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=fc_layer_params)

    # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
    tf_agent = agent_class(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
        epsilon_greedy=epsilon_greedy,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    # Both drivers feed the replay buffer and the train metrics; the first one
    # uses a random policy, the second the agent's collect policy.
    replay_observer = [replay_buffer.add_batch]
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    initial_collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps).run()

    collect_policy = tf_agent.collect_policy
    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration).run()

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)

    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    # The second element returned by the iterator is not used here.
    experience, _ = iterator.get_next()
    train_op = common.function(tf_agent.train)(experience=experience)

    # Three independent checkpointers: agent + metrics, policy only, and the
    # replay buffer (of which only the most recent checkpoint is kept).
    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    # Train-metric summaries use the first two train metrics (episodes and
    # environment steps) as step metrics.
    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(train_metric.tf_summaries(
          train_step=global_step, step_metrics=train_metrics[:2]))

    # Eval summaries go to the eval writer and are always recorded.
    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())
      sess.run(initial_collect_op)

      # Evaluate once before any training step, with logging enabled.
      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )

      # Callables for the ops that are run on every loop iteration.
      collect_call = sess.make_callable(collect_op)
      global_step_call = sess.make_callable(global_step)
      train_step_call = sess.make_callable([train_op, summary_ops])

      timed_at_step = global_step_call()
      collect_time = 0
      train_time = 0
      # The steps/sec value is computed in Python and fed in via a placeholder.
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec', data=steps_per_second_ph,
          step=global_step)

      for _ in range(num_iterations):
        # Train/collect/eval.
        start_time = time.time()
        collect_call()
        collect_time += time.time() - start_time
        start_time = time.time()
        # NOTE(review): `loss_info_value` is only bound inside this loop; with
        # train_steps_per_iteration == 0 the log statement below would raise
        # NameError.
        for _ in range(train_steps_per_iteration):
          loss_info_value, _ = train_step_call()
        train_time += time.time() - start_time

        global_step_val = global_step_call()

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val,
                       loss_info_value.loss)
          # Throughput since the last log, over time spent collecting and
          # training.
          steps_per_sec = (
              (global_step_val - timed_at_step) / (collect_time + train_time))
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          logging.info('%.3f steps/sec', steps_per_sec)
          logging.info('%s', 'collect_time = {}, train_time = {}'.format(
              collect_time, train_time))
          # Reset the measurement window.
          timed_at_step = global_step_val
          collect_time = 0
          train_time = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
          )
Exemple #23
0
def train_eval(
        root_dir,
        tf_master='',
        env_name='HalfCheetah-v2',
        env_load_fn=suite_mujoco.load,
        random_seed=0,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=10000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=50,
        rb_checkpoint_interval=200,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for PPO.

    Builds the whole collect/train graph, then runs it in a
    `tf.compat.v1.Session` until the `EnvironmentSteps` metric reaches
    `num_environment_steps`, evaluating, logging and checkpointing at the
    configured global-step intervals. One final evaluation is run before
    returning.

    Args:
      root_dir: Base output directory (`~` is expanded). Summaries and
        checkpoints go to its `train` and `eval` sub-directories.
      tf_master: Target passed to `tf.compat.v1.Session`.
      env_name: Name handed to `env_load_fn`.
      env_load_fn: Callable invoked as `env_load_fn(env_name)` to build each
        of the parallel environments.
      random_seed: Value passed to `tf.compat.v1.set_random_seed`.
      actor_fc_layers: Layer sizes for the actor network
        (`input_fc_layer_params` when `use_rnns`, else `fc_layer_params`).
      value_fc_layers: Same as `actor_fc_layers`, for the value network.
      use_rnns: If true, build the RNN variants of both networks.
      num_environment_steps: The training loop runs while the
        `EnvironmentSteps` metric is below this value.
      collect_episodes_per_iteration: `num_episodes` of the collect driver.
      num_parallel_environments: Number of environments in each
        `ParallelPyEnvironment`; also used as the replay buffer and eval
        metric batch size.
      replay_buffer_capacity: `max_length` of the replay buffer.
      num_epochs: Passed to `PPOAgent`.
      learning_rate: Learning rate of the Adam optimizer.
      num_eval_episodes: Episodes per evaluation; also the eval metric
        buffer size.
      eval_interval: Evaluate when `global_step % eval_interval == 0`.
      train_checkpoint_interval: Global-step interval for saving the train
        checkpoint.
      policy_checkpoint_interval: Global-step interval for saving the policy
        checkpoint.
      rb_checkpoint_interval: Global-step interval for saving the replay
        buffer checkpoint.
      log_interval: Global-step interval for logging loss and steps/sec.
      summary_interval: Train summaries are recorded only when
        `global_step % summary_interval == 0`.
      summaries_flush_secs: Flush period of both summary writers, in seconds.
      debug_summaries: Passed to `PPOAgent`.
      summarize_grads_and_vars: Passed to `PPOAgent`.
      eval_metrics_callback: Passed as `callback` to
        `metric_utils.compute_summaries`.

    Raises:
      AttributeError: If `root_dir` is None.
    """
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        batched_py_metric.BatchedPyMetric(
            AverageReturnMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
        batched_py_metric.BatchedPyMetric(
            AverageEpisodeLengthMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
    ]
    eval_summary_writer_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf.compat.v1.set_random_seed(random_seed)
        eval_py_env = parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(), fc_layer_params=value_fc_layers)

        tf_agent = ppo_agent.PPOAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        environment_steps_count = environment_steps_metric.result()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # Add to replay buffer and other agent specific observers.
        replay_buffer_observer = [replay_buffer.add_batch]

        collect_policy = tf_agent.collect_policy

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_buffer_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        trajectories = replay_buffer.gather_all()

        train_op, _ = tf_agent.train(experience=trajectories)

        # Chain the ops so that running `train_op` trains first, then clears
        # the replay buffer, and only then yields the training result.
        with tf.control_dependencies([train_op]):
            clear_replay_op = replay_buffer.clear()

        with tf.control_dependencies([clear_replay_op]):
            train_op = tf.identity(train_op)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(train_step=global_step,
                                      step_metrics=step_metrics)

        # Eval summaries go to their own writer and are always recorded.
        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(step_metrics=step_metrics)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session(tf_master) as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            collect_time = 0
            train_time = 0
            timed_at_step = sess.run(global_step)
            # Bind the step value up front: the final evaluation below reads
            # it even when the training loop body never executes (e.g. a
            # restored run that already reached `num_environment_steps`).
            global_step_val = timed_at_step
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            # Use the same v2 summary API as the writers created above
            # instead of `tf.contrib`; the tag is kept unchanged.
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps/sec',
                data=steps_per_second_ph,
                step=global_step)

            while sess.run(environment_steps_count) < num_environment_steps:
                global_step_val = sess.run(global_step)
                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
                    sess.run(eval_summary_writer_flush_op)

                start_time = time.time()
                sess.run(collect_op)
                collect_time += time.time() - start_time
                start_time = time.time()
                total_loss = sess.run(train_op)
                train_time += time.time() - start_time

                global_step_val = sess.run(global_step)
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss)
                    steps_per_sec = ((global_step_val - timed_at_step) /
                                     (collect_time + train_time))
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    logging.info(
                        '%s', 'collect_time = {}, train_time = {}'.format(
                            collect_time, train_time))
                    # Restart the timing window.
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

            # One final eval before exiting.
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )
            sess.run(eval_summary_writer_flush_op)
Exemple #24
0
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_whitelist='position',
        num_iterations=100000,
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        # Params for collect
        initial_collect_steps=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=200,
        batch_size=64,
        train_sequence_length=10,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=10000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DDPG.

    Builds RNN actor and critic networks for a `suite_dm_control`
    environment, then for `num_iterations` iterations runs one collect call
    followed by `train_steps_per_iteration` train steps, with logging,
    checkpointing and evaluation at the configured global-step intervals.

    Args:
      root_dir: Base output directory (`~` is expanded). Summaries and
        checkpoints go to its `train` and `eval` sub-directories.
      env_name: Environment name passed to `suite_dm_control.load`.
      task_name: Task name passed to `suite_dm_control.load`.
      observations_whitelist: If not None, environments are wrapped in
        `FlattenObservationsWrapper` with
        `observations_whitelist=[observations_whitelist]`.
      num_iterations: Number of collect + train iterations.
      actor_fc_layers: `input_fc_layer_params` of the actor network.
      actor_output_fc_layers: `output_fc_layer_params` of the actor network.
      actor_lstm_size: `lstm_size` of the actor network.
      critic_obs_fc_layers: `observation_fc_layer_params` of the critic.
      critic_action_fc_layers: `action_fc_layer_params` of the critic.
      critic_joint_fc_layers: `joint_fc_layer_params` of the critic.
      critic_output_fc_layers: `output_fc_layer_params` of the critic.
      critic_lstm_size: `lstm_size` of the critic.
      initial_collect_steps: Passed as `num_episodes` to the initial
        `DynamicEpisodeDriver`, i.e. it counts episodes despite its name.
      collect_episodes_per_iteration: `num_episodes` of the per-iteration
        collect driver.
      replay_buffer_capacity: `max_length` of the replay buffer.
      ou_stddev: Passed to `DdpgAgent`.
      ou_damping: Passed to `DdpgAgent`.
      target_update_tau: Passed to `DdpgAgent`.
      target_update_period: Passed to `DdpgAgent`.
      train_steps_per_iteration: Train steps run after each collect call.
      batch_size: `sample_batch_size` of the replay dataset.
      train_sequence_length: The dataset yields sequences of
        `train_sequence_length + 1` steps.
      actor_learning_rate: Learning rate of the actor Adam optimizer.
      critic_learning_rate: Learning rate of the critic Adam optimizer.
      dqda_clipping: Passed to `DdpgAgent`.
      gamma: Passed to `DdpgAgent`.
      reward_scale_factor: Passed to `DdpgAgent`.
      gradient_clipping: Passed to `DdpgAgent`.
      num_eval_episodes: Episodes per evaluation; also the eval metric
        buffer size.
      eval_interval: Evaluate when `global_step % eval_interval == 0`.
      train_checkpoint_interval: Global-step interval for saving the train
        checkpoint.
      policy_checkpoint_interval: Global-step interval for saving the policy
        checkpoint.
      rb_checkpoint_interval: Global-step interval for saving the replay
        buffer checkpoint.
      log_interval: Global-step interval for logging loss and steps/sec.
      summary_interval: Train summaries are recorded only when
        `global_step % summary_interval == 0`.
      summaries_flush_secs: Flush period of both summary writers, in seconds.
      debug_summaries: Passed to `DdpgAgent`.
      summarize_grads_and_vars: Passed to `DdpgAgent`.
      eval_metrics_callback: Passed as `callback` to
        `metric_utils.compute_summaries`.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]
    # Explicit flush op, run after every evaluation so eval summaries are
    # written out without waiting for the writer's periodic flush.
    eval_summary_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_whitelist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_whitelist=[observations_whitelist])
            ]
        else:
            env_wrappers = []
        environment = suite_dm_control.load(env_name,
                                            task_name,
                                            env_wrappers=env_wrappers)

        tf_env = tf_py_environment.TFPyEnvironment(environment)
        eval_py_env = suite_dm_control.load(env_name,
                                            task_name,
                                            env_wrappers=env_wrappers)

        actor_net = actor_rnn_network.ActorRnnNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers)

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_rnn_network.CriticRnnNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
        )

        tf_agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        collect_policy = tf_agent.collect_policy

        initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=initial_collect_steps).run()

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        # Need extra step to generate transitions of train_sequence_length.
        # Dataset generates trajectories with shape [BxTx...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)

        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        trajectories, unused_info = iterator.get_next()

        train_fn = common.function(tf_agent.train)
        train_op = train_fn(experience=trajectories,
                            train_step_counter=global_step)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        # The first two train metrics (episodes, environment steps) double
        # as the step metrics the others are plotted against.
        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        # Eval summaries go to their own writer and are always recorded.
        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(b/126239733) Remove once Periodically can be saved.
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )
            sess.run(eval_summary_flush_op)

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _ = train_step_call()
                # Only the step value after the last train step is used
                # below, so read it once per iteration instead of once per
                # train step.
                global_step_val = global_step_call()
                time_acc += time.time() - start_time

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    # Restart the timing window.
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True)
                    sess.run(eval_summary_flush_op)
    def testTrainWithRnn(self):
        """Checks that one SAC train call with RNN networks bumps the counter.

        An RNN actor and an RNN critic are wrapped in a `SacAgent` whose
        `train_step_counter` is a freshly created variable. The agent is
        trained once on a constant batch of 3-step sequences and the counter
        is asserted to go from 0 to 1, in both graph and eager execution.
        """
        rnn_actor = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            self._obs_spec,
            self._action_spec,
            input_fc_layer_params=None,
            output_fc_layer_params=None,
            conv_layer_params=None,
            lstm_size=(40, ),
        )
        rnn_critic = critic_rnn_network.CriticRnnNetwork(
            (self._obs_spec, self._action_spec),
            observation_fc_layer_params=(16, ),
            action_fc_layer_params=(16, ),
            joint_fc_layer_params=(16, ),
            lstm_size=(16, ),
            output_fc_layer_params=None,
        )

        step_counter = common.create_variable('test_train_counter')

        def make_adam():
            # One independent optimizer per loss, all with the same rate.
            return tf.compat.v1.train.AdamOptimizer(1e-3)

        agent = sac_agent.SacAgent(
            self._time_step_spec,
            self._action_spec,
            critic_network=rnn_critic,
            actor_network=rnn_actor,
            actor_optimizer=make_adam(),
            critic_optimizer=make_adam(),
            alpha_optimizer=make_adam(),
            train_step_counter=step_counter,
        )

        # Constant batch: `num_sequences` identical sequences of 3 steps.
        num_sequences = 5
        obs = tf.constant([[[1, 2], [3, 4], [5, 6]]] * num_sequences,
                          dtype=tf.float32)
        acts = tf.constant([[[0], [1], [1]]] * num_sequences,
                           dtype=tf.float32)
        per_step_ones = [[1] * 3] * num_sequences
        step_types = tf.constant(per_step_ones, dtype=tf.int32)
        rewards = tf.constant(per_step_ones, dtype=tf.float32)
        discounts = tf.constant(per_step_ones, dtype=tf.float32)
        time_steps = ts.TimeStep(step_type=step_types,
                                 reward=rewards,
                                 discount=discounts,
                                 observation=obs)

        batch = trajectory.Trajectory(time_steps.step_type, obs, acts, (),
                                      time_steps.step_type,
                                      time_steps.reward,
                                      time_steps.discount)

        # Force variable creation.
        agent.policy.variables()

        if tf.executing_eagerly():
            # In eager mode the train call itself runs between the asserts.
            def train_once():
                return agent.train(batch)
        else:
            # Build the train op first so that the optimizer variables exist
            # and can be initialized before anything is evaluated.
            train_op = agent.train(batch)
            with self.cached_session() as session:
                common.initialize_uninitialized_variables(session)

            def train_once():
                return train_op

        self.assertEqual(self.evaluate(step_counter), 0)
        self.evaluate(train_once())
        self.assertEqual(self.evaluate(step_counter), 1)
Exemple #26
0
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        num_iterations=1000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC.

    Builds the collect/train graph for a `suite_mujoco` environment, then for
    `num_iterations` iterations runs one collect call followed by
    `train_steps_per_iteration` train steps, with logging, evaluation and
    checkpointing at the configured global-step intervals. When the restored
    global step is 0, an initial evaluation and a random-policy initial
    collect are run first.

    Args:
      root_dir: Base output directory (`~` is expanded). Summaries and
        checkpoints go to its `train` and `eval` sub-directories.
      env_name: Environment name passed to `suite_mujoco.load`.
      num_iterations: Number of collect + train iterations.
      actor_fc_layers: `fc_layer_params` of the actor network.
      critic_obs_fc_layers: `observation_fc_layer_params` of the critic.
      critic_action_fc_layers: `action_fc_layer_params` of the critic.
      critic_joint_fc_layers: `joint_fc_layer_params` of the critic.
      initial_collect_steps: `num_steps` of the initial (random policy)
        collect driver.
      collect_steps_per_iteration: `num_steps` of the per-iteration collect
        driver.
      replay_buffer_capacity: `max_length` of the replay buffer.
      target_update_tau: Passed to `SacAgent`.
      target_update_period: Passed to `SacAgent`.
      train_steps_per_iteration: Train steps run after each collect call.
      batch_size: Batch size of the training dataset; the replay buffer is
        sampled in chunks of `5 * batch_size` before filtering.
      actor_learning_rate: Learning rate of the actor Adam optimizer.
      critic_learning_rate: Learning rate of the critic Adam optimizer.
      alpha_learning_rate: Learning rate of the alpha Adam optimizer.
      td_errors_loss_fn: Passed to `SacAgent`.
      gamma: Passed to `SacAgent`.
      reward_scale_factor: Passed to `SacAgent`.
      gradient_clipping: Passed to `SacAgent`.
      num_eval_episodes: Episodes per evaluation; also the eval metric
        buffer size.
      eval_interval: Evaluate when `global_step % eval_interval == 0`.
      train_checkpoint_interval: Global-step interval for saving the train
        checkpoint.
      policy_checkpoint_interval: Global-step interval for saving the policy
        checkpoint.
      rb_checkpoint_interval: Global-step interval for saving the replay
        buffer checkpoint.
      log_interval: Global-step interval for logging loss and steps/sec.
      summary_interval: Train summaries are recorded only when
        `global_step % summary_interval == 0`.
      summaries_flush_secs: Flush period of both summary writers, in seconds.
      debug_summaries: Passed to `SacAgent`.
      summarize_grads_and_vars: Passed to `SacAgent`.
      eval_metrics_callback: Passed as `callback` to
        `metric_utils.compute_summaries`.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]
    # Flush op that is run after each evaluation below.
    eval_summary_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Everything built inside this context records summaries only on global
    # steps that are a multiple of `summary_interval`.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create the environment.
        tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
        eval_py_env = suite_mujoco.load(env_name)

        # Get the data specs from the environment
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        # `normal_projection_net` is not defined in this function; it is
        # expected to be available from the enclosing module.
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=normal_projection_net)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        # Evaluation uses the greedy version of the agent's policy.
        eval_py_policy = py_tf_policy.PyTFPolicy(
            greedy_policy.GreedyPolicy(tf_agent.policy))

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
            tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
        ]

        # The initial collect uses a random policy; later collects use the
        # agent's own collect policy.
        collect_policy = tf_agent.collect_policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps).run()

        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            # Keep a 2-step sample only if its first step is not a boundary.
            return ~trajectories.is_boundary()[0]

        # Sample 5x the batch size, unbatch, drop invalid transitions, then
        # re-batch to `batch_size`.
        dataset = replay_buffer.as_dataset(
            sample_batch_size=5 * batch_size,
            num_steps=2).apply(tf.data.experimental.unbatch()).filter(
                _filter_invalid_transition).batch(batch_size).prefetch(
                    batch_size * 5)
        dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
            dataset)
        trajectories, unused_info = dataset_iterator.get_next()
        train_op = tf_agent.train(trajectories)

        # The first two train metrics (episodes, environment steps) double
        # as the step metrics the others are plotted against.
        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        # Eval summaries go to their own writer and are always recorded.
        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        # NOTE(review): unlike the PPO and DDPG train_eval functions in this
        # file, `tf_agent.initialize()` is never built or run here -- confirm
        # whether that is intentional for SacAgent.
        with tf.compat.v1.Session() as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)

            # Initialize training.
            sess.run(dataset_iterator.initializer)
            common.initialize_uninitialized_variables(sess)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            global_step_val = sess.run(global_step)

            # A global step of 0 is treated as a fresh run: evaluate the
            # untrained policy and seed the replay buffer.
            if global_step_val == 0:
                # Initial eval of randomly initialized policy
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    eval_py_policy,
                    num_episodes=num_eval_episodes,
                    global_step=global_step_val,
                    callback=eval_metrics_callback,
                    log=True,
                )
                sess.run(eval_summary_flush_op)

                # Run initial collect.
                logging.info('Global step %d: Running initial collect op.',
                             global_step_val)
                sess.run(initial_collect_op)

                # Checkpoint the initial replay buffer contents.
                rb_checkpointer.save(global_step=global_step_val)

                logging.info('Finished initial collect.')
            else:
                logging.info('Global step %d: Skipping initial collect op.',
                             global_step_val)

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    total_loss, _ = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    # Restart the timing window.
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
                    sess.run(eval_summary_flush_op)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)
Exemple #27
0
def train_eval(

        ##############################################
        # types of params:
        # 0: specific to algorithm (gin file 0)
        # 1: specific to environment (gin file 1)
        # 2: specific to experiment (gin file 2 + command line)

        # Note: other important params (e.g. those of
        # ModelDistributionNetwork, such as sparse vs. dense rewards and
        # latent dimensions) are specified in the gin files.
        ##############################################

    # basic params for running/logging experiment
    root_dir,  # 2
        experiment_name,  # 2
        num_iterations=int(1e7),  # 2
        seed=1,  # 2
        gpu_allow_growth=False,  # 2
        gpu_memory_limit=None,  # 2
        verbose=True,  # 2
        policy_checkpoint_freq_in_iter=100,  # policies needed for future eval                             # 2
        train_checkpoint_freq_in_iter=0,  #default don't save                                              # 2
        rb_checkpoint_freq_in_iter=0,  #default don't save                                                 # 2
        logging_freq_in_iter=10,  # printing to terminal                                                   # 2
        summary_freq_in_iter=10,  # saving to tb                                                           # 2
        num_images_per_summary=2,  # 2
        summaries_flush_secs=10,  # 2
        max_episode_len_override=None,  # 2
        num_trials_to_render=1,  # 2

        # environment, action mode, etc.
    env_name='HalfCheetah-v2',  # 1
        action_repeat=1,  # 1
        action_mode='joint_position',  # joint_position or joint_delta_position                           # 1
        double_camera=False,  # camera input                                                               # 1
        universe='gym',  # default
        task_reward_dim=1,  # default

        # dims for all networks
    actor_fc_layers=(256, 256),  # 1
        critic_obs_fc_layers=None,  # 1
        critic_action_fc_layers=None,  # 1
        critic_joint_fc_layers=(256, 256),  # 1
        num_repeat_when_concatenate=None,  # 1

        # networks
    critic_input='state',  # 0
        actor_input='state',  # 0

        # specifying tasks and eval
    episodes_per_trial=1,  # 2
        num_train_tasks=10,  # 2
        num_eval_tasks=10,  # 2
        num_eval_trials=10,  # 2
        eval_interval=10,  # 2
        eval_on_holdout_tasks=True,  # 2

        # data collection/buffer
    init_collect_trials_per_task=None,  # 2
        collect_trials_per_task=None,  # 2
        num_tasks_to_collect_per_iter=5,  # 2
        replay_buffer_capacity=int(1e5),  # 2

        # training
    init_model_train_ratio=0.8,  # 2
        model_train_ratio=1,  # 2
        model_train_freq=1,  # 2
        ac_train_ratio=1,  # 2
        ac_train_freq=1,  # 2
        num_tasks_per_train=5,  # 2
        train_trials_per_task=5,  # 2
        model_bs_in_steps=256,  # 2
        ac_bs_in_steps=128,  # 2

        # default AC learning rates, gamma, etc.
    target_update_tau=0.005,
        target_update_period=1,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        model_learning_rate=1e-4,
        td_errors_loss_fn=functools.partial(
            tf.compat.v1.losses.mean_squared_error, weights=0.5),
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        log_image_strips=False,
        stop_model_training=1E10,
        eval_only=False,  # evaluate checkpoints ONLY
        log_image_observations=False,
        load_offline_data=False,  # whether to use offline data
        offline_data_dir=None,  # replay buffer's dir
        offline_episode_len=None,  # episode len of episodes stored in rb
        offline_ratio=0,  # ratio of data that is from offline buffer
):

    g = tf.Graph()

    # register all gym envs
    max_steps_dict = {
        "HalfCheetahVel-v0": 50,
        "SawyerReach-v0": 40,
        "SawyerReachMT-v0": 40,
        "SawyerPeg-v0": 40,
        "SawyerPegMT-v0": 40,
        "SawyerPegMT4box-v0": 40,
        "SawyerShelfMT-v0": 40,
        "SawyerKitchenMT-v0": 40,
        "SawyerShelfMT-v2": 40,
        "SawyerButtons-v0": 40,
    }
    if max_episode_len_override:
        max_steps_dict[env_name] = max_episode_len_override
    register_all_gym_envs(max_steps_dict)

    # set max_episode_len based on our env
    max_episode_len = max_steps_dict[env_name]

    ######################################################
    # Calculate additional params
    ######################################################

    # convert to number of steps
    env_steps_per_trial = episodes_per_trial * max_episode_len
    real_env_steps_per_trial = episodes_per_trial * (max_episode_len + 1)
    env_steps_per_iter = num_tasks_to_collect_per_iter * collect_trials_per_task * env_steps_per_trial
    per_task_collect_steps = collect_trials_per_task * env_steps_per_trial

    # initial collect + train
    init_collect_env_steps = num_train_tasks * init_collect_trials_per_task * env_steps_per_trial
    init_model_train_steps = int(init_collect_env_steps *
                                 init_model_train_ratio)

    # collect + train
    collect_env_steps_per_iter = num_tasks_to_collect_per_iter * per_task_collect_steps
    model_train_steps_per_iter = int(env_steps_per_iter * model_train_ratio)
    ac_train_steps_per_iter = int(env_steps_per_iter * ac_train_ratio)

    # other
    global_steps_per_iter = collect_env_steps_per_iter + model_train_steps_per_iter + ac_train_steps_per_iter
    sample_episodes_per_task = train_trials_per_task * episodes_per_trial  # number of episodes to sample from each replay
    model_bs_in_trials = model_bs_in_steps // real_env_steps_per_trial

    # assertions that make sure parameters make sense
    assert model_bs_in_trials > 0, "model batch size need to be at least as big as one full real trial"
    assert num_tasks_to_collect_per_iter <= num_train_tasks, "when sampling replace=False"
    assert num_tasks_per_train * train_trials_per_task >= model_bs_in_trials, "not enough data for one batch model train"
    assert num_tasks_per_train * train_trials_per_task * env_steps_per_trial >= ac_bs_in_steps, "not enough data for one batch ac train"

    ######################################################
    # Print a summary of params
    ######################################################
    MELD_summary_string = f"""\n\n\n
==============================================================
==============================================================
  \n
  MELD algorithm summary:

  * each trial consists of {episodes_per_trial} episodes
  * episode length: {max_episode_len}, trial length: {env_steps_per_trial}
  * {num_train_tasks} train tasks, {num_eval_tasks} eval tasks, hold-out: {eval_on_holdout_tasks}
  * environment: {env_name}
  
  For each of {num_train_tasks} tasks:
    Do {init_collect_trials_per_task} trials of initial collect
  (total {init_collect_env_steps} env steps)
  
  Do {init_model_train_steps} steps of initial model training
    
  For i in range(inf):
    For each of {num_tasks_to_collect_per_iter} randomly selected tasks:
      Do {collect_trials_per_task} trials of collect
    (which is {collect_trials_per_task*env_steps_per_trial} env steps per task)
    (for a total of {num_tasks_to_collect_per_iter*collect_trials_per_task*env_steps_per_trial} env steps in the iteration)
    
    if i % model_train_freq(={model_train_freq}):
      Do {model_train_steps_per_iter} steps of model training
        - select {sample_episodes_per_task} episodes from each of {num_tasks_per_train} random train_tasks, combine into {num_tasks_per_train*train_trials_per_task} total trials.
        - pick randomly {model_bs_in_trials} trials, train model on whole trials.
    
    if i % ac_train_freq(={ac_train_freq}):
      Do {ac_train_steps_per_iter} steps of ac training
        - select {sample_episodes_per_task} episodes from each of {num_tasks_per_train} random train_tasks, combine into {num_tasks_per_train*train_trials_per_task} total trials.
        - pick randomly {ac_bs_in_steps} transitions, not including between trial transitions, 
          to train ac.
  
  
  * Other important params:
  Evaluate policy every {eval_interval} iters, equivalent to {global_steps_per_iter*eval_interval/1000:.1f}k global steps
  Average evaluation across {num_eval_trials} trials
  Save summary to tensorboard every {summary_freq_in_iter} iters, equivalent to {global_steps_per_iter*summary_freq_in_iter/1000:.1f}k global steps
  Checkpoint:
   - training checkpoint every {train_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*train_checkpoint_freq_in_iter//1000}k global steps, keep 1 checkpoint
   - policy checkpoint every {policy_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*policy_checkpoint_freq_in_iter//1000}k global steps, keep all checkpoints
   - replay buffer checkpoint every {rb_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*rb_checkpoint_freq_in_iter//1000}k global steps, keep 1 checkpoint
    
  \n
=============================================================
=============================================================
  """

    print(MELD_summary_string)
    time.sleep(1)

    ######################################################
    # Seed + name + GPU configs + directories for saving
    ######################################################
    np.random.seed(int(seed))
    experiment_name += "_seed" + str(seed)

    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpu_allow_growth:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    if gpu_memory_limit:
        for gpu in gpus:
            tf.config.experimental.set_virtual_device_configuration(
                gpu, [
                    tf.config.experimental.VirtualDeviceConfiguration(
                        memory_limit=gpu_memory_limit)
                ])

    train_eval_dir = get_train_eval_dir(root_dir, universe, env_name,
                                        experiment_name)
    train_dir = os.path.join(train_eval_dir, 'train')
    eval_dir = os.path.join(train_eval_dir, 'eval')
    eval_dir_2 = os.path.join(train_eval_dir, 'eval2')

    ######################################################
    # Train and Eval Summary Writers
    ######################################################
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_summary_flush_op = eval_summary_writer.flush()

    eval_logger = Logger(eval_dir_2)

    ######################################################
    # Train and Eval metrics
    ######################################################
    eval_buffer_size = num_eval_trials * episodes_per_trial * max_episode_len  # across all eval trials in each evaluation
    eval_metrics = []
    for position in range(
            episodes_per_trial
    ):  # have metrics for each episode position, to track whether it is learning
        eval_metrics_pos = [
            py_metrics.AverageReturnMetric(name='c_AverageReturnEval_' +
                                           str(position),
                                           buffer_size=eval_buffer_size),
            py_metrics.AverageEpisodeLengthMetric(
                name='f_AverageEpisodeLengthEval_' + str(position),
                buffer_size=eval_buffer_size),
            custom_metrics.AverageScoreMetric(
                name="d_AverageScoreMetricEval_" + str(position),
                buffer_size=eval_buffer_size),
        ]
        eval_metrics.extend(eval_metrics_pos)

    train_buffer_size = num_train_tasks * episodes_per_trial
    train_metrics = [
        tf_metrics.NumberOfEpisodes(name='NumberOfEpisodes'),
        tf_metrics.EnvironmentSteps(name='EnvironmentSteps'),
        tf_py_metric.TFPyMetric(
            py_metrics.AverageReturnMetric(name="a_AverageReturnTrain",
                                           buffer_size=train_buffer_size)),
        tf_py_metric.TFPyMetric(
            py_metrics.AverageEpisodeLengthMetric(
                name="e_AverageEpisodeLengthTrain",
                buffer_size=train_buffer_size)),
        tf_py_metric.TFPyMetric(
            custom_metrics.AverageScoreMetric(name="b_AverageScoreTrain",
                                              buffer_size=train_buffer_size)),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step(
    )  # will be used to record number of model grad steps + ac grad steps + env_step

    log_cond = get_log_condition_tensor(
        global_step, init_collect_trials_per_task, env_steps_per_trial,
        num_train_tasks, init_model_train_steps, collect_trials_per_task,
        num_tasks_to_collect_per_iter, model_train_steps_per_iter,
        ac_train_steps_per_iter, summary_freq_in_iter, eval_interval)

    with tf.compat.v2.summary.record_if(log_cond):

        ######################################################
        # Create env
        ######################################################
        py_env, eval_py_env, train_tasks, eval_tasks = load_environments(
            universe,
            action_mode,
            env_name=env_name,
            observations_whitelist=['state', 'pixels', "env_info"],
            action_repeat=action_repeat,
            num_train_tasks=num_train_tasks,
            num_eval_tasks=num_eval_tasks,
            eval_on_holdout_tasks=eval_on_holdout_tasks,
            return_multiple_tasks=True,
        )
        override_reward_func = None
        if load_offline_data:
            py_env.set_task_dict(train_tasks)
            override_reward_func = py_env.override_reward_func

        tf_env = tf_py_environment.TFPyEnvironment(py_env, isolation=True)

        # Get data specs from env
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()
        original_control_timestep = get_control_timestep(eval_py_env)

        # fps
        control_timestep = original_control_timestep * float(action_repeat)
        render_fps = int(np.round(1.0 / original_control_timestep))

        ######################################################
        # Latent variable model
        ######################################################
        if verbose:
            print("-- start constructing model networks --")

        model_net = ModelDistributionNetwork(
            double_camera=double_camera,
            observation_spec=observation_spec,
            num_repeat_when_concatenate=num_repeat_when_concatenate,
            task_reward_dim=task_reward_dim,
            episodes_per_trial=episodes_per_trial,
            max_episode_len=max_episode_len
        )  # rest of arguments provided via gin

        if verbose:
            print("-- finish constructing AC networks --")

        ######################################################
        # Compressor Network for Actor/Critic
        # The model's compressor is also used by the AC
        # compressor function: images --> features
        ######################################################

        compressor_net = model_net.compressor

        ######################################################
        # Specs for Actor and Critic
        ######################################################
        if actor_input == 'state':
            actor_state_size = observation_spec['state'].shape[0]
        elif actor_input == 'latentSample':
            actor_state_size = model_net.state_size
        elif actor_input == "latentDistribution":
            actor_state_size = 2 * model_net.state_size  # mean and (diagonal) variance of gaussian, of two latents
        else:
            raise NotImplementedError
        actor_input_spec = tensor_spec.TensorSpec((actor_state_size, ),
                                                  dtype=tf.float32)

        if critic_input == 'state':
            critic_state_size = observation_spec['state'].shape[0]
        elif critic_input == 'latentSample':
            critic_state_size = model_net.state_size
        elif critic_input == "latentDistribution":
            critic_state_size = 2 * model_net.state_size  # mean and (diagonal) variance of gaussian, of two latents
        else:
            raise NotImplementedError
        critic_input_spec = tensor_spec.TensorSpec((critic_state_size, ),
                                                   dtype=tf.float32)

        ######################################################
        # Actor and Critic Networks
        ######################################################
        if verbose:
            print("-- start constructing Actor and Critic networks --")

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            actor_input_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
        )

        critic_net = critic_network.CriticNetwork(
            (critic_input_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        if verbose:
            print("-- finish constructing AC networks --")
            print("-- start constructing agent --")

        ######################################################
        # Create the agent
        ######################################################

        which_posterior_overwrite = None
        which_reward_overwrite = None

        meld_agent = MeldAgent(
            # specs
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            # step counter
            train_step_counter=
            global_step,  # will count number of model training steps
            # networks
            actor_network=actor_net,
            critic_network=critic_net,
            model_network=model_net,
            compressor_network=compressor_net,
            # optimizers
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            model_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=model_learning_rate),
            # target update
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            # inputs
            critic_input=critic_input,
            actor_input=actor_input,
            # bs stuff
            model_batch_size=model_bs_in_steps,
            ac_batch_size=ac_bs_in_steps,
            # other
            num_tasks_per_train=num_tasks_per_train,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            control_timestep=control_timestep,
            num_images_per_summary=num_images_per_summary,
            task_reward_dim=task_reward_dim,
            episodes_per_trial=episodes_per_trial,
            # offline data
            override_reward_func=override_reward_func,
            offline_ratio=offline_ratio,
        )

        if verbose:
            print("-- finish constructing agent --")

        ######################################################
        # Replay buffers + observers to add data to them
        ######################################################
        replay_buffers = []
        replay_observers = []
        for _ in range(num_train_tasks):
            replay_buffer_episodic = episodic_replay_buffer.EpisodicReplayBuffer(
                meld_agent.collect_policy.
                trajectory_spec,  # spec of each point stored in here (i.e. Trajectory)
                capacity=replay_buffer_capacity,
                completed_only=
                True,  # in as_dataset, if num_steps is None, this means return full episodes
                # device='GPU:0', # gpu not supported for some reason
                begin_episode_fn=lambda traj: traj.is_first()[
                    0],  # first step of seq we add should be is_first
                end_episode_fn=lambda traj: traj.is_last()[
                    0],  # last step of seq we add should be is_last
                dataset_drop_remainder=
                True,  #`as_dataset` makes the final batch be dropped if it does not contain exactly `sample_batch_size` items
            )
            replay_buffer = StatefulEpisodicReplayBuffer(
                replay_buffer_episodic)  # adding num_episodes here is bad
            replay_buffers.append(replay_buffer)
            replay_observers.append([replay_buffer.add_sequence])

        if load_offline_data:
            # each task has a separate replay buffer for the relabeled data
            replay_buffers_withRelabel = []
            replay_observers_withRelabel = []
            for _ in range(num_train_tasks):
                replay_buffer_episodic_withRelabel = episodic_replay_buffer.EpisodicReplayBuffer(
                    meld_agent.collect_policy.
                    trajectory_spec,  # spec of each point stored in here (i.e. Trajectory)
                    capacity=replay_buffer_capacity,
                    completed_only=
                    True,  # in as_dataset, if num_steps is None, this means return full episodes
                    # device='GPU:0', # gpu not supported for some reason
                    begin_episode_fn=lambda traj: traj.is_first()[
                        0],  # first step of seq we add should be is_first
                    end_episode_fn=lambda traj: traj.is_last()[
                        0],  # last step of seq we add should be is_last
                    dataset_drop_remainder=True,
                    # `as_dataset` makes the final batch be dropped if it does not contain exactly `sample_batch_size` items
                )
                replay_buffer_withRelabel = StatefulEpisodicReplayBuffer(
                    replay_buffer_episodic_withRelabel
                )  # adding num_episodes here is bad
                replay_buffers_withRelabel.append(replay_buffer_withRelabel)
                replay_observers_withRelabel.append(
                    [replay_buffer_withRelabel.add_sequence])

        if verbose:
            print("-- finish constructing replay buffers --")
            print("-- start constructing policies and collect ops --")

        ######################################################
        # Policies
        #####################################################

        # init collect policy (random)
        init_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)

        # eval
        eval_py_policy = py_tf_policy.PyTFPolicy(meld_agent.policy)

        ################################################################################
        # Collect ops : use policies to get data + have the observer put data into corresponding RB
        ################################################################################

        #init collection (with random policy)
        init_collect_ops = []
        for task_idx in range(num_train_tasks):
            # put init data into the rb + track with the train metric
            observers = replay_observers[task_idx] + train_metrics

            # initial collect op
            init_collect_op = DynamicTrialDriver(
                tf_env,
                init_collect_policy,
                num_trials_to_collect=init_collect_trials_per_task,
                observers=observers,
                episodes_per_trial=
                episodes_per_trial,  # policy state will not be reset within these episodes
                max_episode_len=max_episode_len,
            ).run()  # collect one trial
            init_collect_ops.append(init_collect_op)

        # data collection for training (with collect policy)
        collect_ops = []
        for task_idx in range(num_train_tasks):
            collect_op = DynamicTrialDriver(
                tf_env,
                meld_agent.collect_policy,
                num_trials_to_collect=collect_trials_per_task,
                observers=replay_observers[task_idx] +
                train_metrics,  # put data into 1st RB + track with 1st pol metrics
                episodes_per_trial=
                episodes_per_trial,  # policy state will not be reset within these episodes
                max_episode_len=max_episode_len,
            ).run()  # collect one trial
            collect_ops.append(collect_op)

        if verbose:
            print("-- finish constructing policies and collect ops --")
            print("-- start constructing replay buffer->training pipeline --")

        ######################################################
        # replay buffer --> dataset --> iterate to get trajecs for training
        ######################################################

        # get some data from all task replay buffers (even though won't actually train on all of them)
        dataset_iterators = []
        all_tasks_trajectories_fromdense = []
        for task_idx in range(num_train_tasks):
            dataset = replay_buffers[task_idx].as_dataset(
                sample_batch_size=
                sample_episodes_per_task,  # number of episodes to sample
                num_steps=max_episode_len + 1
            ).prefetch(
                3
            )  # +1 to include the last state: a trajectory with n transition has n+1 states
            # iterator to go through the data
            dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
                dataset)
            dataset_iterators.append(dataset_iterator)
            # get sample_episodes_per_task sequences, each of length num_steps
            trajectories_task_i, _ = dataset_iterator.get_next()
            all_tasks_trajectories_fromdense.append(trajectories_task_i)

        if load_offline_data:
            # have separate dataset for relabel data
            dataset_iterators_withRelabel = []
            all_tasks_trajectories_fromdense_withRelabel = []
            for task_idx in range(num_train_tasks):
                dataset = replay_buffers_withRelabel[task_idx].as_dataset(
                    sample_batch_size=
                    sample_episodes_per_task,  # number of episodes to sample
                    num_steps=offline_episode_len + 1
                ).prefetch(
                    3
                )  # +1 to include the last state: a trajectory with n transition has n+1 states
                # iterator to go through the data
                dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
                    dataset)
                dataset_iterators_withRelabel.append(dataset_iterator)
                # get sample_episodes_per_task sequences, each of length num_steps
                trajectories_task_i, _ = dataset_iterator.get_next()
                all_tasks_trajectories_fromdense_withRelabel.append(
                    trajectories_task_i)

        if verbose:
            print("-- finish constructing replay buffer->training pipeline --")
            print("-- start constructing model and AC training ops --")

        ######################################
        # Decoding latent samples into rewards
        ######################################

        latent_samples_1_ph = tf.compat.v1.placeholder(
            dtype=tf.float32,
            shape=(None, None, meld_agent._model_network.latent1_size))
        latent_samples_2_ph = tf.compat.v1.placeholder(
            dtype=tf.float32,
            shape=(None, None, meld_agent._model_network.latent2_size))
        decode_rews_op = meld_agent._model_network.decode_latents_into_reward(
            latent_samples_1_ph, latent_samples_2_ph)

        ######################################
        # Actor/Critic train + summary ops
        ######################################

        # train AC on data from replay buffer
        if load_offline_data:
            ac_train_op = meld_agent.train_ac_meld(
                all_tasks_trajectories_fromdense,
                all_tasks_trajectories_fromdense_withRelabel)
        else:
            ac_train_op = meld_agent.train_ac_meld(
                all_tasks_trajectories_fromdense)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        if verbose:
            print("-- finish constructing AC training ops --")

        ############################
        # Model train + summary ops
        ############################

        # train model on data from replay buffer
        if load_offline_data:
            model_train_op, check_step_types = meld_agent.train_model_meld(
                all_tasks_trajectories_fromdense,
                all_tasks_trajectories_fromdense_withRelabel)
        else:
            model_train_op, check_step_types = meld_agent.train_model_meld(
                all_tasks_trajectories_fromdense)

        model_summary_ops, model_summary_ops_2 = [], []
        for summary_op in tf.compat.v1.summary.all_v2_summary_ops():
            if summary_op not in summary_ops:
                model_summary_ops.append(summary_op)

        if verbose:
            print("-- finish constructing model training ops --")
            print("-- start constructing checkpointers --")

        ########################
        # Eval metrics
        ########################

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step,
                                         step_metrics=train_metrics[:2])

        ########################
        # Create savers
        ########################
        train_config_saver = gin.tf.GinConfigSaverHook(train_dir,
                                                       summarize_config=False)
        eval_config_saver = gin.tf.GinConfigSaverHook(eval_dir,
                                                      summarize_config=False)

        ########################
        # Create checkpointers
        ########################

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=meld_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
            max_to_keep=1)
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=meld_agent.policy,
            global_step=global_step,
            max_to_keep=99999999999
        )  # keep many policy checkpoints, in case of future eval
        rb_checkpointers = []
        for buffer_idx in range(len(replay_buffers)):
            rb_checkpointer = common.Checkpointer(
                ckpt_dir=os.path.join(train_dir, 'replay_buffers/',
                                      "task" + str(buffer_idx)),
                max_to_keep=1,
                replay_buffer=replay_buffers[buffer_idx])
            rb_checkpointers.append(rb_checkpointer)

        if load_offline_data:  # for LOADING data not for checkpointing. No new data going in anyways
            rb_checkpointers_withRelabel = []
            for buffer_idx in range(len(replay_buffers_withRelabel)):
                ckpt_dir = os.path.join(offline_data_dir,
                                        "task" + str(buffer_idx))
                rb_checkpointer = common.Checkpointer(
                    ckpt_dir=ckpt_dir,
                    max_to_keep=99999999999,
                    replay_buffer=replay_buffers_withRelabel[buffer_idx])
                rb_checkpointers_withRelabel.append(rb_checkpointer)
            # Notice: these replay buffers need to follow the same sequence of tasks as the current one

        if verbose:
            print("-- finish constructing checkpointers --")
            print("-- start main training loop --")

        with tf.compat.v1.Session() as sess:

            ########################
            # Initialize
            ########################

            if eval_only:
                sess.run(eval_summary_writer.init())
                load_eval_log(
                    train_eval_dir=train_eval_dir,
                    meld_agent=meld_agent,
                    global_step=global_step,
                    sess=sess,
                    eval_metrics=eval_metrics,
                    eval_py_env=eval_py_env,
                    eval_py_policy=eval_py_policy,
                    num_eval_trials=num_eval_trials,
                    max_episode_len=max_episode_len,
                    episodes_per_trial=episodes_per_trial,
                    log_image_strips=log_image_strips,
                    num_trials_to_render=num_trials_to_render,
                    train_tasks=
                    train_tasks,  # in case want to eval on a train task
                    eval_tasks=eval_tasks,
                    model_net=model_net,
                    render_fps=render_fps,
                    decode_rews_op=decode_rews_op,
                    latent_samples_1_ph=latent_samples_1_ph,
                    latent_samples_2_ph=latent_samples_2_ph,
                )
                return

            # Initialize checkpointing
            train_checkpointer.initialize_or_restore(sess)
            for rb_checkpointer in rb_checkpointers:
                rb_checkpointer.initialize_or_restore(sess)

            if load_offline_data:
                for rb_checkpointer in rb_checkpointers_withRelabel:
                    rb_checkpointer.initialize_or_restore(sess)

            # Initialize dataset iterators
            for dataset_iterator in dataset_iterators:
                sess.run(dataset_iterator.initializer)

            if load_offline_data:
                for dataset_iterator in dataset_iterators_withRelabel:
                    sess.run(dataset_iterator.initializer)

            # Initialize variables
            common.initialize_uninitialized_variables(sess)

            # Initialize summary writers
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            # Initialize savers
            train_config_saver.after_create_session(sess)
            eval_config_saver.after_create_session(sess)
            # Get value of step counter
            global_step_val = sess.run(global_step)

            if verbose:
                print("====== finished initialization ======")

            ################################################################
            # If this is start of new exp (i.e., 1st step) and not continuing old exp
            # eval rand policy + do initial data collection
            ################################################################
            fresh_start = (global_step_val == 0)

            if fresh_start:

                ########################
                # Evaluate initial policy
                ########################

                if eval_interval:
                    logging.info(
                        '\n\nDoing evaluation of initial policy on %d trials with randomly sampled tasks',
                        num_eval_trials)
                    perform_eval_and_summaries_meld(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_eval_trials,
                        max_episode_len,
                        episodes_per_trial,
                        log_image_strips=log_image_strips,
                        num_trials_to_render=num_eval_tasks,
                        eval_tasks=eval_tasks,
                        latent1_size=model_net.latent1_size,
                        latent2_size=model_net.latent2_size,
                        logger=eval_logger,
                        global_step_val=global_step_val,
                        render_fps=render_fps,
                        decode_rews_op=decode_rews_op,
                        latent_samples_1_ph=latent_samples_1_ph,
                        latent_samples_2_ph=latent_samples_2_ph,
                        log_image_observations=log_image_observations,
                    )
                    sess.run(eval_summary_flush_op)
                    logging.info(
                        'Done with evaluation of initial (random) policy.\n\n')

                ########################
                # Initial data collection
                ########################

                logging.info(
                    '\n\nGlobal step %d: Beginning init collect op with random policy. Collecting %dx {%d, %d} trials for each task',
                    global_step_val, init_collect_trials_per_task,
                    max_episode_len, episodes_per_trial)

                init_increment_global_step_op = global_step.assign_add(
                    env_steps_per_trial * init_collect_trials_per_task)

                for task_idx in range(num_train_tasks):
                    logging.info('on task %d / %d', task_idx + 1,
                                 num_train_tasks)
                    py_env.set_task_for_env(train_tasks[task_idx])
                    sess.run([
                        init_collect_ops[task_idx],
                        init_increment_global_step_op
                    ])  # incremented gs in granularity of task

                rb_checkpointer.save(global_step=global_step_val)
                logging.info('Finished init collect.\n\n')

            else:
                logging.info(
                    '\n\nGlobal step %d from loaded experiment: Skipping init collect op.\n\n',
                    global_step_val)

            #########################
            # Create calls
            #########################

            # [1] calls for running the policies to collect training data
            collect_calls = []
            increment_global_step_op = global_step.assign_add(
                env_steps_per_trial * collect_trials_per_task)
            for task_idx in range(num_train_tasks):
                collect_calls.append(
                    sess.make_callable(
                        [collect_ops[task_idx], increment_global_step_op]))

            # [2] call for doing a training step (A + C)
            ac_train_step_call = sess.make_callable([ac_train_op, summary_ops])

            # [3] call for doing a training step (model)
            model_train_step_call = sess.make_callable(
                [model_train_op, check_step_types, model_summary_ops])

            # [4] call for evaluating what global_step number we're on
            global_step_call = sess.make_callable(global_step)

            # reset keeping track of steps/time
            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            with train_summary_writer.as_default(
            ), tf.compat.v2.summary.record_if(True):
                steps_per_second_summary = tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec',
                    data=steps_per_second_ph,
                    step=global_step)

            #################################
            # init model training
            #################################
            if fresh_start:
                logging.info(
                    '\n\nPerforming %d steps of init model training, each step on %d random tasks',
                    init_model_train_steps, num_tasks_per_train)
                for i in range(init_model_train_steps):

                    temp_start = time.time()
                    if i % 100 == 0:
                        print(".... init model training ", i, "/",
                              init_model_train_steps)

                    # init model training
                    total_loss_value_model, check_step_types, _ = model_train_step_call(
                    )

                    if PRINT_TIMING:
                        print("single model train step: ",
                              time.time() - temp_start)

            if verbose:
                print("\n\n\n-- start training loop --\n")

            #################################
            # Training Loop
            #################################
            start_time = time.time()
            for iteration in range(num_iterations):

                if iteration > 0:
                    g.finalize()

                # print("\n\n\niter", iteration, sess.run(curr_iter))
                print("global step", global_step_call())

                logging.info("Iteration: %d, Global step: %d\n", iteration,
                             global_step_val)

                ####################
                # collect data
                ####################
                logging.info(
                    '\nStarting batch data collection. Collecting %d {%d, %d} trials for each of %d tasks',
                    collect_trials_per_task, max_episode_len,
                    episodes_per_trial, num_tasks_to_collect_per_iter)

                # randomly select tasks to collect this iteration
                list_of_collect_task_idxs = np.random.choice(
                    len(train_tasks),
                    num_tasks_to_collect_per_iter,
                    replace=False)
                for count, task_idx in enumerate(list_of_collect_task_idxs):
                    logging.info('on randomly selected task %d / %d',
                                 count + 1, num_tasks_to_collect_per_iter)

                    # set task for the env
                    py_env.set_task_for_env(train_tasks[task_idx])

                    # collect data with collect policy
                    _, policy_state_val = collect_calls[task_idx]()

                logging.info('Finish data collection. Global step: %d\n',
                             global_step_call())

                ####################
                # train model
                ####################
                if (iteration
                        == 0) or ((iteration % model_train_freq == 0) and
                                  (global_step_val < stop_model_training)):
                    logging.info(
                        '\n\nPerforming %d steps of model training, each on %d random tasks',
                        model_train_steps_per_iter, num_tasks_per_train)
                    for model_iter in range(model_train_steps_per_iter):
                        temp_start_2 = time.time()

                        # train model
                        total_loss_value_model, _, _ = model_train_step_call()

                        # print("is logging step", model_iter, sess.run(is_logging_step))
                        if PRINT_TIMING:
                            print("2: single model train step: ",
                                  time.time() - temp_start_2)
                    logging.info('Finish model training. Global step: %d\n',
                                 global_step_call())
                else:
                    print("SKIPPING MODEL TRAINING")

                ####################
                # train actor critic
                ####################
                if iteration % ac_train_freq == 0:
                    logging.info(
                        '\n\nPerforming %d steps of AC training, each on %d random tasks \n\n',
                        ac_train_steps_per_iter, num_tasks_per_train)
                    for ac_iter in range(ac_train_steps_per_iter):
                        temp_start_2_ac = time.time()

                        # train ac
                        total_loss_value_ac, _ = ac_train_step_call()
                        if PRINT_TIMING:
                            print("2: single AC train step: ",
                                  time.time() - temp_start_2_ac)
                logging.info('Finish AC training. Global step: %d\n',
                             global_step_call())

                # add up time
                time_acc += time.time() - start_time

                ####################
                # logging/summaries
                ####################

                ### Eval
                if eval_interval and (iteration % eval_interval == 0):
                    logging.info(
                        '\n\nDoing evaluation of trained policy on %d trials with randomly sampled tasks',
                        num_eval_trials)

                    perform_eval_and_summaries_meld(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_eval_trials,
                        max_episode_len,
                        episodes_per_trial,
                        log_image_strips=log_image_strips,
                        num_trials_to_render=
                        num_trials_to_render,  # hardcoded: or gif will get too long
                        eval_tasks=eval_tasks,
                        latent1_size=model_net.latent1_size,
                        latent2_size=model_net.latent2_size,
                        logger=eval_logger,
                        global_step_val=global_step_call(),
                        render_fps=render_fps,
                        decode_rews_op=decode_rews_op,
                        latent_samples_1_ph=latent_samples_1_ph,
                        latent_samples_2_ph=latent_samples_2_ph,
                        log_image_observations=log_image_observations,
                    )

                ### steps_per_second_summary
                global_step_val = global_step_call()
                if logging_freq_in_iter and (iteration % logging_freq_in_iter
                                             == 0):
                    # log step number + speed (steps/sec)
                    logging.info(
                        'step = %d, loss = %f', global_step_val,
                        total_loss_value_ac.loss + total_loss_value_model.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f env_steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})

                    # reset keeping track of steps/time
                    timed_at_step = global_step_val
                    time_acc = 0

                ### train_checkpoint
                if train_checkpoint_freq_in_iter and (
                        iteration % train_checkpoint_freq_in_iter == 0):
                    train_checkpointer.save(global_step=global_step_val)

                ### policy_checkpointer
                if policy_checkpoint_freq_in_iter and (
                        iteration % policy_checkpoint_freq_in_iter == 0):
                    policy_checkpointer.save(global_step=global_step_val)

                ### rb_checkpointer
                if rb_checkpoint_freq_in_iter and (
                        iteration % rb_checkpoint_freq_in_iter == 0):
                    for rb_checkpointer in rb_checkpointers:
                        rb_checkpointer.save(global_step=global_step_val)
Exemple #28
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=1000,
        # TODO(kbanoop): rename to policy_fc_layers.
        actor_fc_layers=(100, ),
        # Params for collect
        collect_episodes_per_iteration=2,
        replay_buffer_capacity=2000,
        # Params for train
        learning_rate=1e-3,
        gradient_clipping=None,
        normalize_returns=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=100,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=100,
        rb_checkpoint_interval=200,
        log_interval=100,
        summary_interval=100,
        summaries_flush_secs=1,
        debug_summaries=True,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for Reinforce.

    Builds the whole graph (environment, actor network, agent, replay
    buffer, episode driver, checkpointers and summaries) and then runs it in
    a `tf.compat.v1.Session`. Every iteration runs the collect op, one
    training step on everything currently in the replay buffer, and then
    clears the buffer. Logging, checkpointing and evaluation are triggered
    whenever the global step is a multiple of the corresponding interval.

    Args:
        root_dir: Base output directory (`~` is expanded). Summaries and
            checkpoints go to its `train` and `eval` sub-directories.
        env_name: Environment name passed to `suite_gym.load` for both the
            collection environment and the evaluation environment.
        num_iterations: Number of collect/train/clear iterations to run.
        actor_fc_layers: Forwarded as `fc_layer_params` to the actor
            distribution network.
        collect_episodes_per_iteration: `num_episodes` given to the episode
            driver, i.e. episodes gathered by each collect call.
        replay_buffer_capacity: `max_length` of the replay buffer.
        learning_rate: Learning rate of the Adam optimizer.
        gradient_clipping: Forwarded to `ReinforceAgent`.
        normalize_returns: Forwarded to `ReinforceAgent`.
        num_eval_episodes: Episodes per evaluation; also the buffer size of
            the evaluation metrics.
        eval_interval: Evaluate when `global_step % eval_interval == 0`.
        train_checkpoint_interval: Save the train checkpoint when the global
            step is a multiple of this value.
        policy_checkpoint_interval: Same, for the policy checkpoint.
        rb_checkpoint_interval: Same, for the replay-buffer checkpoint.
        log_interval: Log loss and throughput when the global step is a
            multiple of this value.
        summary_interval: Summaries created inside the graph are recorded
            only when `global_step % summary_interval == 0`.
        summaries_flush_secs: Flush period of both summary writers, in
            seconds.
        debug_summaries: Forwarded to `ReinforceAgent`.
        summarize_grads_and_vars: Forwarded to `ReinforceAgent`.
        eval_metrics_callback: Forwarded as `callback` to
            `metric_utils.compute_summaries`.

    Returns:
        None.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    # The train writer is the default one; the eval writer is only used
    # inside an explicit `as_default()` block further down.
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Everything built inside this block only records summaries on steps
    # that are a multiple of `summary_interval`.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Separate environment instances for evaluation and collection.
        eval_py_env = suite_gym.load(env_name)
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

        # TODO(kbanoop): Handle distributions without gin.
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers)

        tf_agent = reinforce_agent.ReinforceAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            normalize_returns=normalize_returns,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        # The first two metrics double as `step_metrics` for the train
        # summaries below.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        collect_policy = tf_agent.collect_policy

        # Each run of this op collects `collect_episodes_per_iteration`
        # episodes, feeding the replay buffer and the train metrics.
        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        # Train on the full buffer contents; the buffer is cleared after
        # every training step (see the loop below), so each update only sees
        # the data gathered during that iteration.
        experience = replay_buffer.gather_all()
        train_op = tf_agent.train(experience)
        clear_rb_op = replay_buffer.clear()

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(train_step=global_step,
                                      step_metrics=train_metrics[:2])

        # Evaluation summaries go to the eval writer and are always recorded.
        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            # TODO(sguada) Remove once Periodically can be saved.
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            # Compute evaluation metrics once before any training happens.
            global_step_call = sess.make_callable(global_step)
            global_step_val = global_step_call()
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(train_op)
            clear_rb_call = sess.make_callable(clear_rb_op)

            # Throughput bookkeeping: steps and wall time since the last log.
            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            # Use the same compat.v2 summary API as the writers and the other
            # summaries in this function instead of `tf.contrib.summary`,
            # which does not exist in TF 2.x. The tag and the step (the
            # global step) are the same as before.
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps/sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                total_loss = train_step_call()
                clear_rb_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    # Restart the measurement window.
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )
Exemple #29
0
def train_eval(
    root_dir,
    env_name='HalfCheetah-v1',
    env_load_fn=suite_mujoco.load,
    num_iterations=2000000,
    actor_fc_layers=(400, 300),
    critic_obs_fc_layers=(400,),
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    num_parallel_environments=1,
    replay_buffer_capacity=100000,
    ou_stddev=0.2,
    ou_damping=0.15,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    actor_learning_rate=1e-4,
    critic_learning_rate=1e-3,
    dqda_clipping=None,
    td_errors_loss_fn=tf.losses.huber_loss,
    gamma=0.995,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DDPG.

  Builds the whole graph (environment, actor and critic networks, agent,
  replay buffer, step drivers, dataset, checkpointers and summaries) and
  then runs it in a `tf.Session`. After an initial collection run, every
  iteration executes the collect op once and the train op
  `train_steps_per_iteration` times. Logging, checkpointing and evaluation
  are triggered whenever the global step is a multiple of the corresponding
  interval.

  Args:
    root_dir: Base output directory (`~` is expanded). Summaries and
      checkpoints go to its `train` and `eval` sub-directories.
    env_name: Environment name passed to `env_load_fn`.
    env_load_fn: Callable that builds an environment from `env_name`; used
      for both the collection and the evaluation environment.
    num_iterations: Number of collect/train iterations to run.
    actor_fc_layers: Forwarded as `fc_layer_params` to the actor network.
    critic_obs_fc_layers: Forwarded as `observation_fc_layer_params` to the
      critic network.
    critic_action_fc_layers: Forwarded as `action_fc_layer_params` to the
      critic network.
    critic_joint_fc_layers: Forwarded as `joint_fc_layer_params` to the
      critic network.
    initial_collect_steps: `num_steps` of the driver that is run once before
      training starts.
    collect_steps_per_iteration: `num_steps` of the per-iteration driver.
    num_parallel_environments: When greater than 1, that many environments
      are wrapped in a `ParallelPyEnvironment`.
    replay_buffer_capacity: `max_length` of the replay buffer.
    ou_stddev: Forwarded to `DdpgAgent`.
    ou_damping: Forwarded to `DdpgAgent`.
    target_update_tau: Forwarded to `DdpgAgent`.
    target_update_period: Forwarded to `DdpgAgent`.
    train_steps_per_iteration: Train-op executions per iteration.
    batch_size: `sample_batch_size` of the replay-buffer dataset.
    actor_learning_rate: Learning rate of the actor's Adam optimizer.
    critic_learning_rate: Learning rate of the critic's Adam optimizer.
    dqda_clipping: Forwarded to `DdpgAgent`.
    td_errors_loss_fn: Forwarded to `DdpgAgent`.
    gamma: Forwarded to `DdpgAgent`.
    reward_scale_factor: Forwarded to `DdpgAgent`.
    gradient_clipping: Forwarded to `DdpgAgent`.
    num_eval_episodes: Episodes per evaluation; also the buffer size of the
      evaluation metrics.
    eval_interval: Evaluate when `global_step % eval_interval == 0`.
    train_checkpoint_interval: Save the train checkpoint when the global
      step is a multiple of this value.
    policy_checkpoint_interval: Same, for the policy checkpoint.
    rb_checkpoint_interval: Same, for the replay-buffer checkpoint.
    log_interval: Log loss and throughput when the global step is a multiple
      of this value.
    summary_interval: Passed to
      `tf.contrib.summary.record_summaries_every_n_global_steps`.
    summaries_flush_secs: Flush period of both summary writers, in seconds.
    debug_summaries: Forwarded to `DdpgAgent`.
    summarize_grads_and_vars: Forwarded to `DdpgAgent`.
    eval_metrics_callback: Forwarded as `callback` to
      `metric_utils.compute_summaries`.

  Returns:
    None.
  """
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  # The train writer is the default one; the eval writer is only used inside
  # an explicit `as_default()` block further down.
  train_summary_writer = tf.contrib.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.contrib.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  # TODO(kbanoop): Figure out if it is possible to avoid the with block.
  with tf.contrib.summary.record_summaries_every_n_global_steps(
      summary_interval):
    if num_parallel_environments > 1:
      tf_env = tf_py_environment.TFPyEnvironment(
          parallel_py_environment.ParallelPyEnvironment(
              [lambda: env_load_fn(env_name)] * num_parallel_environments))
    else:
      tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
    # Evaluation always uses a single, separate environment instance.
    eval_py_env = env_load_fn(env_name)

    actor_net = actor_network.ActorNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers,
    )

    # The critic takes an (observation, action) pair as input.
    critic_net_input_specs = (tf_env.time_step_spec().observation,
                              tf_env.action_spec())

    critic_net = critic_network.CriticNetwork(
        critic_net_input_specs,
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
    )

    tf_agent = ddpg_agent.DdpgAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        ou_stddev=ou_stddev,
        ou_damping=ou_damping,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        dqda_clipping=dqda_clipping,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec(),
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy())

    # The first two metrics double as `step_metrics` for the train summaries
    # below.
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    global_step = tf.train.get_or_create_global_step()

    collect_policy = tf_agent.collect_policy()
    # Run once before training; only fills the replay buffer (the train
    # metrics are not attached as observers here).
    initial_collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch],
        num_steps=initial_collect_steps).run()

    # Run every iteration; feeds both the replay buffer and the metrics.
    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration).run()

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)

    iterator = dataset.make_initializable_iterator()
    trajectories, unused_info = iterator.get_next()
    train_op = tf_agent.train(
        experience=trajectories, train_step_counter=global_step)

    train_checkpointer = common_utils.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=tf.contrib.checkpoint.List(train_metrics))
    policy_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy(),
        global_step=global_step)
    rb_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    for train_metric in train_metrics:
      train_metric.tf_summaries(step_metrics=train_metrics[:2])
    # Captured here, before the eval summaries and the steps/sec summary are
    # created; this op list is run together with every training step.
    summary_op = tf.contrib.summary.all_summary_ops()

    # Evaluation summaries go to the eval writer and are always recorded.
    with eval_summary_writer.as_default(), \
         tf.contrib.summary.always_record_summaries():
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries()

    init_agent_op = tf_agent.initialize()

    with tf.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      # TODO(sguada) Remove once Periodically can be saved.
      common_utils.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      tf.contrib.summary.initialize(session=sess)
      sess.run(initial_collect_op)

      # Evaluate once before any training happens.
      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
      )

      collect_call = sess.make_callable(collect_op)
      # Fetching the global step with the train op avoids a separate
      # session call per training step.
      train_step_call = sess.make_callable([train_op, summary_op, global_step])

      # Throughput bookkeeping: steps and wall time since the last log.
      timed_at_step = sess.run(global_step)
      time_acc = 0
      steps_per_second_ph = tf.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.contrib.summary.scalar(
          name='global_steps/sec', tensor=steps_per_second_ph)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          loss_info_value, _, global_step_val = train_step_call()
        time_acc += time.time() - start_time

        if global_step_val % log_interval == 0:
          tf.logging.info('step = %d, loss = %f', global_step_val,
                          loss_info_value.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          # Pass the value as a logging argument (as the call above does)
          # so the message is only formatted when it is actually emitted.
          tf.logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          # Restart the measurement window.
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
          )
Exemple #30
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        fc_layer_params=(100, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        log_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN.

    Builds a DQN agent for the gym environment `env_name`, pre-fills a Python
    replay buffer using a random policy, and then, inside a TF1 session,
    alternates between collecting experience with the agent's collect policy
    and running train steps. Logging, checkpointing and evaluation are
    triggered whenever the global step is a multiple of the matching interval.

    Args:
      root_dir: Base output directory ('~' is expanded). Train summaries and
        checkpoints go under `<root_dir>/train` (policy checkpoints under
        `<root_dir>/train/policy`); eval summaries under `<root_dir>/eval`.
      env_name: Name passed to `suite_gym.load` for both the collection and
        the evaluation environment.
      num_iterations: Number of collect-then-train iterations to run.
      fc_layer_params: Forwarded to `q_network.QNetwork`.
      initial_collect_steps: Number of random-policy steps added to the replay
        buffer before training starts.
      collect_steps_per_iteration: Collect-policy steps per iteration.
      epsilon_greedy: Forwarded to `dqn_agent.DqnAgent`.
      replay_buffer_capacity: Capacity of the Python uniform replay buffer.
      target_update_tau: Forwarded to `dqn_agent.DqnAgent`.
      target_update_period: Forwarded to `dqn_agent.DqnAgent`.
      train_steps_per_iteration: Number of train-op runs per iteration.
      batch_size: `sample_batch_size` used when sampling the replay buffer.
      learning_rate: Learning rate of the Adam optimizer.
      n_step_update: Forwarded to `dqn_agent.DqnAgent`; sampled trajectories
        contain `n_step_update + 1` steps.
      gamma: Forwarded to `dqn_agent.DqnAgent`.
      reward_scale_factor: Forwarded to `dqn_agent.DqnAgent`.
      gradient_clipping: Forwarded to `dqn_agent.DqnAgent`.
      num_eval_episodes: Episodes per evaluation; also the buffer size of the
        eval metrics.
      eval_interval: Evaluate when `global_step % eval_interval == 0`.
      train_checkpoint_interval: Save the train checkpoint when the global
        step is a multiple of this value.
      policy_checkpoint_interval: Save the policy checkpoint when the global
        step is a multiple of this value.
      log_interval: Log loss and throughput when the global step is a multiple
        of this value.
      summaries_flush_secs: Flush period of both summary writers, in seconds.
      debug_summaries: Forwarded to `dqn_agent.DqnAgent`.
      summarize_grads_and_vars: Forwarded to `dqn_agent.DqnAgent`.
      eval_metrics_callback: Forwarded as `callback` to
        `metric_utils.compute_summaries`.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    # Note this is a python environment.
    env = batched_py_environment.BatchedPyEnvironment(
        [suite_gym.load(env_name)])
    eval_py_env = suite_gym.load(env_name)

    # Convert specs to BoundedTensorSpec.
    action_spec = tensor_spec.from_spec(env.action_spec())
    observation_spec = tensor_spec.from_spec(env.observation_spec())
    time_step_spec = ts.time_step_spec(observation_spec)

    # Reuse the specs converted above rather than converting them again.
    q_net = q_network.QNetwork(observation_spec,
                               action_spec,
                               fc_layer_params=fc_layer_params)

    # The agent must be in graph.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    agent = dqn_agent.DqnAgent(
        time_step_spec,
        action_spec,
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        n_step_update=n_step_update,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_policy.PyTFPolicy(tf_collect_policy)
    greedy_policy = py_tf_policy.PyTFPolicy(agent.policy)
    random_policy = random_py_policy.RandomPyPolicy(env.time_step_spec(),
                                                    env.action_spec())

    # Python replay buffer.
    replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
        capacity=replay_buffer_capacity,
        data_spec=tensor_spec.to_nest_array_spec(agent.collect_data_spec))

    time_step = env.reset()

    # Initialize the replay buffer with some transitions. We use the random
    # policy to initialize the replay buffer to make sure we get a good
    # distribution of actions.
    for _ in range(initial_collect_steps):
        time_step = collect_step(env, time_step, random_policy, replay_buffer)

    # TODO(b/112041045) Use global_step as counter.
    train_checkpointer = common.Checkpointer(ckpt_dir=train_dir,
                                             agent=agent,
                                             global_step=global_step)

    policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
        train_dir, 'policy'),
                                              policy=agent.policy,
                                              global_step=global_step)

    # Each sampled element spans n_step_update + 1 consecutive steps.
    ds = replay_buffer.as_dataset(sample_batch_size=batch_size,
                                  num_steps=n_step_update + 1)
    ds = ds.prefetch(4)
    itr = tf.compat.v1.data.make_initializable_iterator(ds)

    experience = itr.get_next()

    train_op = common.function(agent.train)(experience)

    # Register the eval metric summaries against the eval writer.
    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
        for eval_metric in eval_metrics:
            eval_metric.tf_summaries(train_step=global_step)

    with tf.compat.v1.Session() as session:
        train_checkpointer.initialize_or_restore(session)
        common.initialize_uninitialized_variables(session)
        session.run(itr.initializer)
        # Run the agent's initialization op (original note: copies network
        # values to the target network).
        session.run(agent.initialize())
        train = session.make_callable(train_op)
        global_step_call = session.make_callable(global_step)
        session.run(train_summary_writer.init())
        session.run(eval_summary_writer.init())

        # Compute initial evaluation metrics.
        global_step_val = global_step_call()
        metric_utils.compute_summaries(
            eval_metrics,
            eval_py_env,
            greedy_policy,
            num_episodes=num_eval_episodes,
            global_step=global_step_val,
            log=True,
            callback=eval_metrics_callback,
        )

        # Throughput bookkeeping: steps since `timed_at_step` divided by the
        # collect + train wall time accumulated since the last reset.
        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0
        steps_per_second_ph = tf.compat.v1.placeholder(tf.float32,
                                                       shape=(),
                                                       name='steps_per_sec_ph')
        steps_per_second_summary = tf.compat.v2.summary.scalar(
            name='global_steps_per_sec',
            data=steps_per_second_ph,
            step=global_step)

        for _ in range(num_iterations):
            start_time = time.time()
            for _ in range(collect_steps_per_iteration):
                time_step = collect_step(env, time_step, collect_policy,
                                         replay_buffer)
            collect_time += time.time() - start_time
            start_time = time.time()
            for _ in range(train_steps_per_iteration):
                loss = train()
            train_time += time.time() - start_time
            global_step_val = global_step_call()
            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             loss.loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                session.run(steps_per_second_summary,
                            feed_dict={steps_per_second_ph: steps_per_sec})
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info(
                    '%s', 'collect_time = {}, train_time = {}'.format(
                        collect_time, train_time))
                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % eval_interval == 0:
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    greedy_policy,
                    num_episodes=num_eval_episodes,
                    global_step=global_step_val,
                    log=True,
                    callback=eval_metrics_callback,
                )
                # Restart the throughput window after an eval. The step
                # marker and both time accumulators must be reset together:
                # moving only `timed_at_step` would divide a shortened step
                # count by the full accumulated time and under-report
                # steps/sec whenever eval and log intervals are not aligned.
                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0