示例#1
0
 def test_network_can_preprocess_and_combine(self):
     """Runs QRnnNetwork on a two-part observation that has a time axis.

     Each observation part is passed through its own preprocessing layer,
     the results are merged with a Keras ``Add`` layer, and the test checks
     the Q-value shape and that the network owns more than four trainable
     variables.
     """
     n_batch, n_frames = 3, 5
     n_actions = 2
     n_lstm_units = 6

     # One [batch, frames, 1] input and one [batch, frames] input.
     vector_input = tf.random.uniform([n_batch, n_frames, 1])
     scalar_input = tf.random.uniform([n_batch, n_frames])

     vector_branch = tf.keras.layers.Dense(4)
     scalar_branch = tf.keras.Sequential([
         expand_dims_layer.ExpandDims(-1),  # Convert to vec size (1,).
         tf.keras.layers.Dense(4)
     ])

     observation_spec = (tensor_spec.TensorSpec([1], tf.float32),
                         tensor_spec.TensorSpec([], tf.float32))
     network = q_rnn_network.QRnnNetwork(
         input_tensor_spec=observation_spec,
         preprocessing_layers=(vector_branch, scalar_branch),
         preprocessing_combiner=tf.keras.layers.Add(),
         lstm_size=(n_lstm_units, ),
         action_spec=tensor_spec.BoundedTensorSpec(
             [1], tf.int32, 0, n_actions - 1))

     # Every step of every batch entry is marked as a FIRST step.
     step_types = tf.constant(
         [[time_step.StepType.FIRST] * n_frames] * n_batch)
     q_values, _ = network((vector_input, scalar_input), step_types)

     expected_shape = [n_batch, n_frames, n_actions]
     self.assertAllEqual(q_values.shape.as_list(), expected_shape)
     # At least 2 variables each for the preprocessing layers.
     self.assertGreater(len(network.trainable_variables), 4)
  def testTrainWithRNN(self):
    """Trains a BehavioralCloningAgent backed by a QRnnNetwork.

    Replays the same trajectory through the agent's policy before and after
    TRAIN_ITERATIONS training steps and asserts that the produced actions
    are not all identical.
    """
    # Emits trajectories shaped (batch=1, time=6, ...)
    random_trajectory = driver_test_utils.make_random_trajectory()
    traj, time_step_spec, action_spec = random_trajectory

    rnn_net = q_rnn_network.QRnnNetwork(
        time_step_spec.observation, action_spec)
    adam = tf.compat.v1.train.AdamOptimizer(learning_rate=0.01)
    agent = behavioral_cloning_agent.BehavioralCloningAgent(
        time_step_spec,
        action_spec,
        cloning_network=rnn_net,
        optimizer=adam,
        num_outer_dims=2)
    # Disable clipping to make sure we can see the difference in behavior
    agent.policy._clip = False
    # Remove policy_info, as BehavioralCloningAgent expects none.
    traj = traj.replace(policy_info=())

    # TODO(b/123883319)
    # Eager mode re-invokes train() on each evaluate(); graph mode builds the
    # train op once and re-runs it.
    train_step = (
        (lambda: agent.train(traj))
        if tf.executing_eagerly() else agent.train(traj))

    replay = trajectory_replay.TrajectoryReplay(agent.policy)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    actions_before = self.evaluate(replay.run(traj)[0])

    for _ in range(TRAIN_ITERATIONS):
      self.evaluate(train_step)
    actions_after = self.evaluate(replay.run(traj)[0])

    # We don't necessarily converge to the same actions as in trajectory after
    # 10 steps of an untuned optimizer, but the policy does change.
    actions_unchanged = np.all(actions_before == actions_after)
    self.assertFalse(actions_unchanged)
示例#3
0
    def testSaveGetInitialState(self):
        """Checks the `get_initial_state` signature of saved QPolicy models.

        A QRnnNetwork-backed QPolicy is exported twice with PolicySaver:
        once with `batch_size=None` and once with `batch_size=3`. For each
        export the reloaded signature specs are compared against the policy,
        and the reloaded initial state is compared against the one produced
        by the in-memory policy for a batch of 3.
        """
        if not tf.executing_eagerly():
            self.skipTest(
                'b/129079730: PolicySaver does not work in TF1.x yet')

        rnn_net = q_rnn_network.QRnnNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)
        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=rnn_net)

        def export_and_check(saver_batch_size, dirname, input_specs):
            # Save, reload, and verify the `get_initial_state` signature.
            saver = policy_saver.PolicySaver(policy,
                                             batch_size=saver_batch_size)
            export_dir = os.path.join(self.get_temp_dir(), dirname)
            saver.save(export_dir)
            model = tf.compat.v2.saved_model.load(export_dir)
            self.assertIn('get_initial_state', model.signatures)
            self._compare_input_output_specs(
                model.signatures['get_initial_state'],
                expected_input_specs=input_specs,
                expected_output_spec=policy.policy_state_spec,
                batch_input=False,
                batch_size=saver_batch_size)
            return model

        # Export without a fixed batch size: batch_size is a signature input.
        batch_size_spec = tf.TensorSpec(dtype=tf.int32,
                                        shape=(),
                                        name='batch_size')
        reloaded_nobatch = export_and_check(
            None, 'save_model_initial_state_nobatch', (batch_size_spec, ))

        expected_state = self.evaluate(
            policy.get_initial_state(batch_size=3))

        nobatch_state = self.evaluate(
            reloaded_nobatch.get_initial_state(batch_size=3))
        tf.nest.map_structure(self.assertAllClose, expected_state,
                              nobatch_state)

        # Export with a fixed batch size: the signature takes no inputs.
        reloaded_batch = export_and_check(
            3, 'save_model_initial_state_batch', ())

        batch_state = self.evaluate(reloaded_batch.get_initial_state())
        tf.nest.map_structure(self.assertAllClose, expected_state,
                              batch_state)
示例#4
0
  def testTrainWithRNN(self):
    """Trains a QRnnNetwork-backed cloning agent and expects the loss to drop.

    The loss reported by the first training step is compared with the loss
    reported by the last of TRAIN_ITERATIONS steps on the same trajectory.
    """
    # Hard code a trajectory shaped (time=6, batch=1, ...).
    traj, time_step_spec, action_spec = create_arbitrary_trajectory()

    rnn_net = q_rnn_network.QRnnNetwork(
        time_step_spec.observation, action_spec)
    adam = tf.compat.v1.train.AdamOptimizer(learning_rate=0.01)
    agent = behavioral_cloning_agent.BehavioralCloningAgent(
        time_step_spec,
        action_spec,
        cloning_network=rnn_net,
        optimizer=adam,
        num_outer_dims=2)
    # Disable clipping to make sure we can see the difference in behavior
    agent.policy._clip = False
    # Remove policy_info, as BehavioralCloningAgent expects none.
    traj = traj.replace(policy_info=())

    # TODO(b/123883319)
    # Eager mode re-invokes train() on each evaluate(); graph mode builds the
    # train op once and re-runs it.
    train_step = (
        (lambda: agent.train(traj))
        if tf.executing_eagerly() else agent.train(traj))
    self.evaluate(tf.compat.v1.global_variables_initializer())

    first_loss = self.evaluate(train_step).loss
    for _ in range(TRAIN_ITERATIONS - 1):
      last_loss = self.evaluate(train_step).loss

    # We don't necessarily converge to the same actions as in trajectory after
    # 10 steps of an untuned optimizer, but the loss should go down.
    self.assertGreater(first_loss, last_loss)
示例#5
0
    def test_network_can_preprocess_and_combine_no_time_dim(self):
        """Runs QRnnNetwork on a two-part observation without a time axis.

        Skipped when executing eagerly. Each observation part goes through
        its own preprocessing layer, the results are merged with a Keras
        ``Add`` layer, and the Q-values are expected to come back shaped
        [batch, num_actions].
        """
        if tf.executing_eagerly():
            self.skipTest('b/123776211')

        n_batch = 3
        n_actions = 2
        n_lstm_units = 5

        # One [batch, 1] input and one [batch] input.
        vector_input = tf.random.uniform([n_batch, 1])
        scalar_input = tf.random.uniform([n_batch])

        vector_branch = tf.keras.layers.Dense(4)
        scalar_branch = tf.keras.Sequential([
            expand_dims_layer.ExpandDims(-1),  # Convert to vec size (1,).
            tf.keras.layers.Dense(4)
        ])

        observation_spec = (tensor_spec.TensorSpec([1], tf.float32),
                            tensor_spec.TensorSpec([], tf.float32))
        network = q_rnn_network.QRnnNetwork(
            input_tensor_spec=observation_spec,
            preprocessing_layers=(vector_branch, scalar_branch),
            preprocessing_combiner=tf.keras.layers.Add(),
            lstm_size=(n_lstm_units, ),
            action_spec=tensor_spec.BoundedTensorSpec(
                [1], tf.int32, 0, n_actions - 1))

        step_types = tf.constant([time_step.StepType.FIRST] * n_batch)
        q_values, _ = network((vector_input, scalar_input), step_types)

        # Processed 1 time step and the time axis was squeezed back.
        expected_shape = [n_batch, n_actions]
        self.assertAllEqual(q_values.shape.as_list(), expected_shape)

        # At least 2 variables each for the preprocessing layers.
        self.assertGreater(len(network.trainable_variables), 4)
示例#6
0
  def test_network_builds(self):
    """Builds a default QRnnNetwork for CartPole-v0 and checks output shapes."""
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
    rnn_network = q_rnn_network.QRnnNetwork(
        tf_env.observation_spec(), tf_env.action_spec())

    current_step = tf_env.current_time_step()
    q_values, state = rnn_network(
        current_step.observation, current_step.step_type)

    self.assertEqual((1, 2), q_values.shape)
    # Both entries of the returned state are expected to be shaped (1, 40).
    for index in (0, 1):
      self.assertEqual((1, 40), state[index].shape)
    def test_network_builds(self):
        """Builds a QRnnNetwork with lstm_size=(40,) and checks output shapes.

        The network is called on the first CartPole-v0 time step with an
        explicitly supplied initial network state for a batch of one.
        """
        tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load('CartPole-v0'))
        rnn_network = q_rnn_network.QRnnNetwork(
            tf_env.observation_spec(),
            tf_env.action_spec(),
            lstm_size=(40, ))

        initial_step = tf_env.current_time_step()
        starting_state = rnn_network.get_initial_state(batch_size=1)
        q_values, state = rnn_network(initial_step.observation,
                                      initial_step.step_type,
                                      network_state=starting_state)

        self.assertEqual((1, 2), q_values.shape)
        # Both entries of the returned state are expected to be (1, 40).
        for index in (0, 1):
            self.assertEqual((1, 40), state[index].shape)
示例#8
0
  def test_network_builds_rnn_construction_fn(self):
      """Builds a QRnnNetwork whose RNN comes from a construction function.

      `rnn_keras_fn` is passed as `rnn_construction_fn` together with
      `{'lstm_size': 3}` as its kwargs; the test then checks the Q-value
      shape and the shape of the first state entry.
      """
      tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
      construction_kwargs = {'lstm_size': 3}
      rnn_network = q_rnn_network.QRnnNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          rnn_construction_fn=rnn_keras_fn,
          rnn_construction_kwargs=construction_kwargs)

      initial_step = tf_env.current_time_step()
      starting_state = rnn_network.get_initial_state(batch_size=1)
      q_values, state = rnn_network(
          initial_step.observation,
          initial_step.step_type,
          network_state=starting_state)

      self.assertEqual((1, 2), q_values.shape)
      self.assertEqual((3,), state[0].shape)
示例#9
0
  def test_network_builds_stacked_cells(self):
    """Builds a QRnnNetwork with lstm_size=(10, 5) and checks the state layout.

    The returned state must be a tuple with one entry per LSTM size, and
    each entry must hold two tensors shaped (1, size).
    """
    lstm_sizes = (10, 5)
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
    rnn_network = q_rnn_network.QRnnNetwork(
        tf_env.observation_spec(), tf_env.action_spec(), lstm_size=lstm_sizes)

    current_step = tf_env.current_time_step()
    q_values, state = rnn_network(
        current_step.observation, current_step.step_type)

    self.assertTrue(isinstance(state, tuple))
    self.assertEqual(2, len(state))

    self.assertEqual((1, 2), q_values.shape)
    for layer_index, units in enumerate(lstm_sizes):
      for part_index in (0, 1):
        self.assertEqual((1, units), state[layer_index][part_index].shape)
示例#10
0
 def setUp(self):
     """Creates the time-step spec, action spec and QRnnNetwork fixture.

     The observation is a scalar int64 named 'callee_users'; the action is
     a scalar int64 named 'inlining_decision' bounded to [0, 1].
     """
     super(PolicySaverTest, self).setUp()

     callee_users_spec = tf.TensorSpec(
         dtype=tf.int64, shape=(), name='callee_users')
     self._time_step_spec = time_step.time_step_spec(callee_users_spec)

     decision_spec = tensor_spec.BoundedTensorSpec(
         dtype=tf.int64,
         shape=(),
         minimum=0,
         maximum=1,
         name='inlining_decision')
     self._action_spec = decision_spec

     self._network = q_rnn_network.QRnnNetwork(
         input_tensor_spec=self._time_step_spec.observation,
         action_spec=self._action_spec,
         lstm_size=(40, ))
  def __init__(self, config, *args, **kwargs):
    """Initialize the agent.

    Args:
      config: Mapping read for the keys 'observation_spec', 'action_spec'
        and 'environment_batch_size'.
      *args: Accepted but not used by this constructor.
      **kwargs: Accepted but not used by this constructor.
    """
    self.observation_spec = config['observation_spec']
    self.time_step_spec = ts.time_step_spec(self.observation_spec)
    self.action_spec = config['action_spec']
    self.environment_batch_size = config['environment_batch_size']

    # Recurrent Q-network: one input FC layer followed by two LSTM sizes.
    self.q_net = q_rnn_network.QRnnNetwork(
        self.observation_spec,
        self.action_spec,
        input_fc_layer_params=(256, ),
        lstm_size=(256, 256))

    # The same optimizer instance is handed to the DQN agent and to the
    # intention classifier compiled below.
    self.optimizer = tf.keras.optimizers.Adam()

    self.agent = dqn_agent.DqnAgent(
        self.time_step_spec,
        self.action_spec,
        q_network=self.q_net,
        optimizer=self.optimizer,
        epsilon_greedy=0.3)
    self.agent.initialize()

    self.replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=self.agent.collect_data_spec,
        batch_size=1,
        dataset_drop_remainder=True)
    self.dataset = self.replay_buffer.as_dataset(
        sample_batch_size=1,
        num_steps=2,
        single_deterministic_pass=True)

    # Single sigmoid unit applied directly to the observation.
    sigmoid_unit = tf.keras.layers.Dense(
        1, input_shape=self.observation_spec.shape, activation='sigmoid')
    self.intention_classifier = tf.keras.models.Sequential([sigmoid_unit])
    self.intention_classifier.compile(
        optimizer=self.optimizer,
        loss=tf.keras.losses.BinaryCrossentropy())

    # Buffers for classifier inputs and their target labels.
    self.intention_dataset_input = []
    self.intention_dataset_true = []

    self.num_hinted = 0
示例#12
0
  def test_network_builds_stacked_cells(self):
    """Builds a QRnnNetwork with lstm_size=(10, 5) and checks the state layout.

    The network is called with an explicit initial state for a batch of one;
    the returned state must match `state_spec` structurally and hold, per
    LSTM size, two tensors shaped (1, size).
    """
    lstm_sizes = (10, 5)
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
    rnn_network = q_rnn_network.QRnnNetwork(
        tf_env.observation_spec(), tf_env.action_spec(), lstm_size=lstm_sizes)

    initial_step = tf_env.current_time_step()
    starting_state = rnn_network.get_initial_state(batch_size=1)
    q_values, state = rnn_network(
        initial_step.observation,
        initial_step.step_type,
        network_state=starting_state)

    tf.nest.assert_same_structure(rnn_network.state_spec, state)
    self.assertEqual(2, len(state))

    self.assertEqual((1, 2), q_values.shape)
    for layer_index, units in enumerate(lstm_sizes):
      for part_index in (0, 1):
        self.assertEqual((1, units), state[layer_index][part_index].shape)
示例#13
0
 def setUp(self):
     """Creates the time-step spec, action spec and QRnnNetwork fixture.

     The observation is a dict with a single scalar int64 entry
     'inlining_default', passed through an identity Lambda preprocessing
     layer; the action is a scalar int64 named 'inlining_decision' bounded
     to [0, 1]. The parent `setUp` runs after the fixtures are created.
     """
     feature_name = 'inlining_default'
     observation_spec = {
         feature_name:
         tf.TensorSpec(dtype=tf.int64, shape=(), name='inlining_default')
     }
     self._time_step_spec = time_step.time_step_spec(observation_spec)

     decision_spec = tensor_spec.BoundedTensorSpec(
         dtype=tf.int64,
         shape=(),
         minimum=0,
         maximum=1,
         name='inlining_decision')
     self._action_spec = decision_spec

     identity_layer = tf.keras.layers.Lambda(lambda x: x)
     self._network = q_rnn_network.QRnnNetwork(
         input_tensor_spec=self._time_step_spec.observation,
         action_spec=self._action_spec,
         lstm_size=(40, ),
         preprocessing_layers={feature_name: identity_layer})

     super(TrainerTest, self).setUp()
示例#14
0
 def testTrainWithRNN(self):
     """Trains a BehavioralCloningAgent backed by a QRnnNetwork.

     Builds the train op once, replays the trajectory through the agent's
     policy before and after TRAIN_ITERATIONS runs of that op, and asserts
     that the produced actions are not all identical.
     """
     # Emits trajectories shaped (batch=1, time=6, ...)
     random_trajectory = driver_test_utils.make_random_trajectory()
     traj, time_step_spec, action_spec = random_trajectory

     rnn_net = q_rnn_network.QRnnNetwork(time_step_spec.observation,
                                         action_spec)
     adam = tf.train.AdamOptimizer(learning_rate=0.01)
     agent = behavioral_cloning_agent.BehavioralCloningAgent(
         time_step_spec,
         action_spec,
         cloning_network=rnn_net,
         optimizer=adam)

     # Remove policy_info, as BehavioralCloningAgent expects none.
     traj = traj.replace(policy_info=())
     train_op = agent.train(traj)
     replay = trajectory_replay.TrajectoryReplay(agent.policy())
     self.evaluate(tf.global_variables_initializer())

     actions_before = self.evaluate(replay.run(traj)[0])
     for _ in range(TRAIN_ITERATIONS):
         self.evaluate(train_op)
     actions_after = self.evaluate(replay.run(traj)[0])

     # We don't necessarily converge to the same actions as in trajectory after
     # 10 steps of an untuned optimizer, but the policy does change.
     actions_unchanged = np.all(actions_before == actions_after)
     self.assertFalse(actions_unchanged)
示例#15
0
    def testSaveAction(self, seeded, has_state):
        """Checks that a saved policy's `action` matches the in-memory policy.

        A QPolicy is exported with PolicySaver, reloaded, and called both
        through its flat `action` signature and through the structured
        `reloaded.action` function. Outputs are compared against
        `policy.action` on the same sampled inputs.

        Args:
            seeded: When true, outputs are compared by value
                (`assertAllClose`); otherwise only dtype and shape are
                compared.
            has_state: When true the policy is backed by a QRnnNetwork,
                otherwise by a QNetwork.
        """
        # NOTE(review): `seeded`/`has_state` are presumably supplied by a
        # parameterized-test decorator outside this view -- confirm.
        if not tf.executing_eagerly():
            self.skipTest(
                'b/129079730: PolicySaver does not work in TF1.x yet')

        if has_state:
            network = q_rnn_network.QRnnNetwork(
                input_tensor_spec=self._time_step_spec.observation,
                action_spec=self._action_spec)
        else:
            network = q_network.QNetwork(
                input_tensor_spec=self._time_step_spec.observation,
                action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=network)

        action_seed = 98723

        # The same seed is given to the saver here and to policy.action()
        # further down.
        saver = policy_saver.PolicySaver(policy,
                                         batch_size=None,
                                         use_nest_path_signatures=False,
                                         seed=action_seed)
        path = os.path.join(self.get_temp_dir(), 'save_model_action')
        saver.save(path)

        reloaded = tf.compat.v2.saved_model.load(path)

        self.assertIn('action', reloaded.signatures)
        reloaded_action = reloaded.signatures['action']
        self._compare_input_output_specs(
            reloaded_action,
            expected_input_specs=(self._time_step_spec,
                                  policy.policy_state_spec),
            expected_output_spec=policy.policy_step_spec,
            batch_input=True)

        batch_size = 3

        # Sample one batch of (time_step, policy_state) inputs from the specs.
        action_inputs = tensor_spec.sample_spec_nest(
            (self._time_step_spec, policy.policy_state_spec),
            outer_dims=(batch_size, ),
            seed=4)

        # Keyword arguments for the flat signature, keyed by spec name.
        function_action_input_dict = dict(
            (spec.name, value) for (spec, value) in zip(
                tf.nest.flatten((self._time_step_spec, policy.policy_state_spec
                                 )), tf.nest.flatten(action_inputs)))

        # NOTE(ebrevdo): The graph-level seeds for the policy and the reloaded model
        # are equal, which in addition to seeding the call to action() and
        # PolicySaver helps ensure equality of the output of action() in both cases.
        self.assertEqual(reloaded_action.graph.seed, self._global_seed)
        action_output = policy.action(*action_inputs, seed=action_seed)

        # The seed= argument for the SavedModel action call was given at creation of
        # the PolicySaver.

        # This is the flat-signature function.
        reloaded_action_output_dict = reloaded_action(
            **function_action_input_dict)

        def match_dtype_shape(x, y, msg=None):
            # Weaker comparison used when values are not expected to match.
            self.assertEqual(x.shape, y.shape, msg=msg)
            self.assertEqual(x.dtype, y.dtype, msg=msg)

        # This is the non-flat function.
        if has_state:
            reloaded_action_output = reloaded.action(*action_inputs)
        else:
            # Try both cases: one with an empty policy_state and one with no
            # policy_state.  Compare them.

            # NOTE(ebrevdo): The first call to .action() must be stored in
            # reloaded_action_output because this is the version being compared later
            # against the true action_output and the values will change after the
            # first call due to randomness.
            reloaded_action_output = reloaded.action(*action_inputs)
            reloaded_action_output_no_input_state = reloaded.action(
                action_inputs[0])
            # Even with a seed, multiple calls to action will get different values,
            # so here we just check the signature matches.
            tf.nest.map_structure(match_dtype_shape,
                                  reloaded_action_output_no_input_state,
                                  reloaded_action_output)

        # Flatten the in-memory policy output into a dict keyed by spec name
        # so it can be compared with the flat-signature output.
        action_output_dict = dict(
            ((spec.name, value)
             for (spec, value) in zip(tf.nest.flatten(policy.policy_step_spec),
                                      tf.nest.flatten(action_output))))

        # Check output of the flattened signature call.
        action_output_dict = self.evaluate(action_output_dict)
        reloaded_action_output_dict = self.evaluate(
            reloaded_action_output_dict)
        self.assertAllEqual(action_output_dict.keys(),
                            reloaded_action_output_dict.keys())

        for k in action_output_dict:
            if seeded:
                self.assertAllClose(action_output_dict[k],
                                    reloaded_action_output_dict[k],
                                    msg='\nMismatched dict key: %s.' % k)
            else:
                match_dtype_shape(action_output_dict[k],
                                  reloaded_action_output_dict[k],
                                  msg='\nMismatch dict key: %s.' % k)

        # Check output of the proper structured call.
        action_output = self.evaluate(action_output)
        reloaded_action_output = self.evaluate(reloaded_action_output)
        # With non-signature functions, we can check that passing a seed does the
        # right thing the second time.
        if seeded:
            tf.nest.map_structure(self.assertAllClose, action_output,
                                  reloaded_action_output)
        else:
            tf.nest.map_structure(match_dtype_shape, action_output,
                                  reloaded_action_output)
示例#16
0
def train(*,
          train_steps_per_iteration=1,
          train_checkpoint_interval=10000,
          policy_checkpoint_interval=5000,
          rb_checkpoint_interval=20000,
          eval_metrics_callback=None):
    """Build and run a graph-mode DQN training loop on `TradeEnvironment`.

    A `QRnnNetwork`-backed `DqnAgent` is created, the replay buffer is
    pre-filled with `DummyTradePolicy`, and then for a fixed number of
    iterations the function collects one episode, runs the train op, logs
    progress, saves checkpoints and evaluates the greedy policy. Summaries
    and checkpoints are written under `/tmp/tensorflow/logs/tfenv01`.

    The keyword-only parameters below were previously referenced inside the
    body without being defined anywhere in the function; they now have
    explicit defaults so that a plain `train()` call works.

    Args:
        train_steps_per_iteration: Number of train-op runs per iteration.
        train_checkpoint_interval: Save the agent checkpoint whenever the
            global step is a multiple of this value.
        policy_checkpoint_interval: Same, for the policy checkpoint.
        rb_checkpoint_interval: Same, for the replay-buffer checkpoint.
        eval_metrics_callback: Passed through as `callback` to
            `metric_utils.compute_summaries`.
    """
    summary_interval = 1000
    summaries_flush_secs = 10
    num_eval_episodes = 5
    root_dir = '/tmp/tensorflow/logs/tfenv01'
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    # maybe py_metrics?
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    environment = TradeEnvironment()
    # utils.validate_py_environment(environment, episodes=5)
    # Environments
    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # NOTE(review): train_env and eval_env wrap the same `environment`
        # object; confirm that sharing one instance between collection and
        # evaluation is intended.
        train_env = tf_py_environment.TFPyEnvironment(environment)
        eval_env = tf_py_environment.TFPyEnvironment(environment)

        num_iterations = 50
        input_fc_layer_params = (50, )
        output_fc_layer_params = (20, )
        lstm_size = (30, )

        # Passed as `num_episodes` to the initial collect driver below.
        initial_collect_steps = 20
        collect_episodes_per_iteration = 1
        batch_size = 64
        replay_buffer_capacity = 10000

        train_sequence_length = 10

        gamma = 0.99  # check if 1.0 works as well
        target_update_tau = 0.05
        target_update_period = 5
        epsilon_greedy = 0.1
        gradient_clipping = None
        reward_scale_factor = 1.0

        learning_rate = 1e-2
        log_interval = 30
        eval_interval = 15

        # train_env.observation_spec(),
        q_net = q_rnn_network.QRnnNetwork(
            train_env.time_step_spec().observation,
            train_env.action_spec(),
            input_fc_layer_params=input_fc_layer_params,
            lstm_size=lstm_size,
            output_fc_layer_params=output_fc_layer_params,
        )

        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        tf_agent = dqn_agent.DqnAgent(
            train_env.time_step_spec(),
            train_env.action_spec(),
            q_network=q_net,
            optimizer=optimizer,
            epsilon_greedy=epsilon_greedy,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=False,
            summarize_grads_and_vars=False,
            train_step_counter=global_step,
        )

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=train_env.batch_size,
            max_length=replay_buffer_capacity,
        )

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # Policy which does not allow some actions in certain states
        q_policy = FilteredQPolicy(
            tf_agent._time_step_spec,
            tf_agent._action_spec,
            q_network=tf_agent._q_network,
        )

        # Valid policy to pre-fill replay buffer
        initial_collect_policy = DummyTradePolicy(
            train_env.time_step_spec(),
            train_env.action_spec(),
        )
        print('Initial collecting...')
        initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            train_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=initial_collect_steps,
        ).run()

        # Main agent's policy; greedy one
        policy = greedy_policy.GreedyPolicy(q_policy)
        # Policy used for evaluation, the same as above
        eval_policy = greedy_policy.GreedyPolicy(q_policy)

        # Replace the agent's own policies with the filtered variants.
        tf_agent._policy = policy
        collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
            q_policy, epsilon=tf_agent._epsilon_greedy)
        # Patch random policy for epsilon greedy collect policy
        filtered_random_tf_policy = FilteredRandomTFPolicy(
            time_step_spec=policy.time_step_spec,
            action_spec=policy.action_spec,
        )
        collect_policy._random_policy = filtered_random_tf_policy
        tf_agent._collect_policy = collect_policy
        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            train_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration,
        ).run()
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1,
        ).prefetch(3)

        iterator = iter(dataset)
        experience, _ = next(iterator)
        loss_info = common.function(tf_agent.train)(experience=experience)

        # Checkpoints
        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
        )
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step,
        )
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer,
        )

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(train_metric.tf_summaries(
                train_step=global_step,
                step_metrics=train_metrics[:2],
            ))

        with eval_summary_writer.as_default(), \
                tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # sess.run(train_summary_writer.init())
            # sess.run(eval_summary_writer.init())

            # Initialize the graph
            # tfe.Saver().restore()
            # train_checkpointer.initialize_or_restore()
            # rb_checkpointer.initialize_or_restore()
            # sess.run(iterator.initializer)
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            print('Collecting initial experience...')
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([loss_info, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step,
            )

            # Train
            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()

                for _ in range(train_steps_per_iteration):
                    loss_info_value, _ = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    # Apply the format strings with `%`; passing the values
                    # as extra print() arguments left the placeholders raw.
                    print('step=%d, loss=%f' %
                          (global_step_val, loss_info_value.loss))
                    steps_per_sec = (global_step_val - timed_at_step) / time_acc
                    print('%.3f steps/sec' % steps_per_sec)
                    sess.run(
                        steps_per_second_summary,
                        feed_dict={steps_per_second_ph: steps_per_sec},
                    )
                    timed_at_step = global_step_val
                    time_acc = 0

                # Save checkpoints
                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                # Evaluate
                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        log=True,
                        callback=eval_metrics_callback,
                    )
    print('Done!')
示例#17
0
def load_agents_and_create_videos(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=NUM_ITERATIONS,
        max_ep_steps=1000,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(128, 64, 32),
        # Params for QRnnNetwork
        input_fc_layer_params=(50,),
        lstm_size=(20,),
        output_fc_layer_params=(20,),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=10000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        num_random_episodes=1,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None,
        random_metrics_callback=None):
    """Restore a DQN agent from checkpoints and record comparison videos.

    The environment, Q-network, agent and replay buffer are rebuilt from the
    given hyper-parameters, checkpoints under ``<root_dir>/train`` are
    restored, and ``create_policy_eval_video`` is called twice: once with the
    restored agent's policy and once with a ``RandomTFPolicy`` baseline.
    Video base names are ``trained-agent<timestamp>`` and
    ``random-agent<timestamp>``.

    Args:
        root_dir: Experiment root directory; a leading ``~`` is expanded.
            Checkpoints are looked up in its ``train`` sub-directory.
        env_name: Gym environment id passed to ``suite_gym.load``.
        max_ep_steps: ``max_episode_steps`` used when loading the environments.
        train_sequence_length: Values greater than 1 select a ``QRnnNetwork``;
            otherwise a feed-forward ``QNetwork`` is built.
        fc_layer_params: Layer sizes for the ``QNetwork``.
        input_fc_layer_params, lstm_size, output_fc_layer_params: Layer sizes
            for the ``QRnnNetwork``.
        epsilon_greedy, n_step_update, target_update_tau,
        target_update_period, learning_rate, gamma, reward_scale_factor,
        gradient_clipping, debug_summaries, summarize_grads_and_vars:
            Forwarded to ``DqnAgent`` (``learning_rate`` via the Adam
            optimizer). These presumably need to match the training run so
            that the checkpoint can be restored.
        replay_buffer_capacity: ``max_length`` of the rebuilt replay buffer.
        num_iterations, initial_collect_steps, collect_steps_per_iteration,
        train_steps_per_iteration, batch_size, use_tf_functions,
        num_eval_episodes, num_random_episodes, eval_interval,
        train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval, log_interval, summary_interval,
        summaries_flush_secs, eval_metrics_callback, random_metrics_callback:
            Accepted so the signature mirrors the training entry point; not
            used by this function.

    Raises:
        NotImplementedError: If both ``train_sequence_length`` and
            ``n_step_update`` differ from 1.
    """
    # Fail fast on an unsupported configuration before building anything.
    if train_sequence_length != 1 and n_step_update != 1:
        raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

    # Expand a leading '~' so a home-relative root_dir resolves to the real
    # checkpoint directory instead of a literal '~' folder (which would
    # contain no checkpoint and silently yield an untrained agent).
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')

    global_step = tf.compat.v1.train.get_or_create_global_step()

    # Match the environments used in training.
    tf_env = tf_py_environment.TFPyEnvironment(
        suite_gym.load(env_name, max_episode_steps=max_ep_steps))
    eval_py_env = suite_gym.load(env_name, max_episode_steps=max_ep_steps)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

    if train_sequence_length > 1:
        q_net = q_rnn_network.QRnnNetwork(
            tf_env.observation_spec(),
            tf_env.action_spec(),
            input_fc_layer_params=input_fc_layer_params,
            lstm_size=lstm_size,
            output_fc_layer_params=output_fc_layer_params)
    else:
        q_net = q_network.QNetwork(
            tf_env.observation_spec(),
            tf_env.action_spec(),
            fc_layer_params=fc_layer_params)

    # Match the agent used in training.
    tf_agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        n_step_update=n_step_update,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    # The train checkpoint tracks these metrics, so they are rebuilt here
    # even though nothing in this function updates them.
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = tf_agent.policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))

    # NOTE(review): policy_checkpointer is never referenced afterwards. It is
    # kept (in its original position) because constructing a tf_agents
    # Checkpointer may itself restore the latest checkpoint found in ckpt_dir
    # -- confirm against the installed tf_agents version before removing.
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)

    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    # Load the data from training.
    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    # Define a random policy for comparison.
    random_policy = random_tf_policy.RandomTFPolicy(
        eval_tf_env.time_step_spec(), eval_tf_env.action_spec())

    # Make movies of the trained agent and a random agent. Both names share
    # one timestamp and the same "<prefix><timestamp>" layout (the random
    # name previously carried a stray space before the timestamp).
    date_string = datetime.datetime.now().strftime('%Y-%m-%d_%H%M%S')

    trained_filename = "trained-agent" + date_string
    create_policy_eval_video(
        eval_tf_env, eval_py_env, tf_agent.policy, trained_filename)

    random_filename = 'random-agent' + date_string
    create_policy_eval_video(
        eval_tf_env, eval_py_env, random_policy, random_filename)
示例#18
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(100, ),
        # Params for QRnnNetwork
        input_fc_layer_params=(50, ),
        lstm_size=(20, ),
        output_fc_layer_params=(20, ),

        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN.

    Builds a DQN agent for ``env_name``, fills a replay buffer with
    ``initial_collect_steps`` steps from a random policy, then alternates
    between collecting ``collect_steps_per_iteration`` environment steps and
    running ``train_steps_per_iteration`` training steps for
    ``num_iterations`` iterations. Summaries are written under
    ``<root_dir>/train`` and ``<root_dir>/eval``; checkpoints for the agent,
    the policy and the replay buffer are written under ``<root_dir>/train``.

    Args:
        root_dir: Output root directory; a leading ``~`` is expanded.
        env_name: Gym environment id passed to ``suite_gym.load``.
        num_iterations: Number of collect/train iterations.
        train_sequence_length: Values greater than 1 select a
            ``QRnnNetwork``; otherwise a ``QNetwork`` is used and the sampled
            sequence length is derived from ``n_step_update``.
        fc_layer_params: Layer sizes for the ``QNetwork``.
        input_fc_layer_params, lstm_size, output_fc_layer_params: Layer sizes
            for the ``QRnnNetwork``.
        initial_collect_steps: Random-policy steps collected before training.
        collect_steps_per_iteration: Environment steps collected per iteration.
        epsilon_greedy, target_update_tau, target_update_period,
        n_step_update, gamma, reward_scale_factor, gradient_clipping,
        debug_summaries, summarize_grads_and_vars: Forwarded to ``DqnAgent``.
        replay_buffer_capacity: ``max_length`` of the replay buffer.
        train_steps_per_iteration: Training steps run per iteration.
        batch_size: Sample batch size drawn from the replay buffer.
        learning_rate: Learning rate of the Adam optimizer.
        use_tf_functions: If true, wrap the collect driver, the agent's
            ``train`` and the local train step with ``common.function``.
        num_eval_episodes: Episodes per evaluation (also the eval metric
            buffer size).
        eval_interval, train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval, log_interval, summary_interval: Periods, in
            global steps, of the corresponding actions.
        summaries_flush_secs: Flush period of the summary writers, in seconds.
        eval_metrics_callback: Optional callable invoked as
            ``callback(results, global_step_value)`` after each evaluation.

    Returns:
        The value returned by the last training step, or ``None`` when no
        training step was executed.

    Raises:
        NotImplementedError: If both ``train_sequence_length`` and
            ``n_step_update`` differ from 1.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Only record summaries on steps that are a multiple of summary_interval.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        if train_sequence_length != 1 and n_step_update != 1:
            raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

        if train_sequence_length > 1:
            q_net = q_rnn_network.QRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=input_fc_layer_params,
                lstm_size=lstm_size,
                output_fc_layer_params=output_fc_layer_params)
        else:
            q_net = q_network.QNetwork(tf_env.observation_spec(),
                                       tf_env.action_spec(),
                                       fc_layer_params=fc_layer_params)
            # For the feed-forward network the sampled sequence length is
            # driven by the n-step setting (see dataset num_steps below).
            train_sequence_length = n_step_update

        # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=n_step_update,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        # Baseline evaluation before any training in this run.
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Stays None when no training step runs (num_iterations == 0 or
        # train_steps_per_iteration == 0), so the final return and the loss
        # log below do not hit an unbound name.
        train_loss = None

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            # global_step is the agent's train_step_counter and no training
            # happens below this point, so read its value once per iteration.
            step = global_step.numpy()

            if step % log_interval == 0:
                if train_loss is not None:
                    logging.info('step = %d, loss = %f', step,
                                 train_loss.loss)
                steps_per_sec = (step - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = step
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if step % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=step)

            if step % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=step)

            if step % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=step)

            if step % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, step)
                metric_utils.log_metrics(eval_metrics)
        return train_loss
def train_eval(
    root_dir,
    experiment_name,  # experiment name
    env_name='carla-v0',
    agent_name='sac',  # agent's name
    num_iterations=int(1e7),
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    model_network_ctor_type='non-hierarchical',  # model net
    input_names=['camera', 'lidar'],  # names for inputs
    mask_names=['birdeye'],  # names for masks
    preprocessing_combiner=tf.keras.layers.Add(
    ),  # takes a flat list of tensors and combines them
    actor_lstm_size=(40, ),  # lstm size for actor
    critic_lstm_size=(40, ),  # lstm size for critic
    actor_output_fc_layers=(100, ),  # lstm output
    critic_output_fc_layers=(100, ),  # lstm output
    epsilon_greedy=0.1,  # exploration parameter for DQN
    q_learning_rate=1e-3,  # q learning rate for DQN
    ou_stddev=0.2,  # exploration paprameter for DDPG
    ou_damping=0.15,  # exploration parameter for DDPG
    dqda_clipping=None,  # for DDPG
    exploration_noise_std=0.1,  # exploration paramter for td3
    actor_update_period=2,  # for td3
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=int(1e5),
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    initial_model_train_steps=100000,  # initial model training
    batch_size=256,
    model_batch_size=32,  # model training batch size
    sequence_length=4,  # number of timesteps to train model
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    model_learning_rate=1e-4,  # learning rate for model training
    td_errors_loss_fn=tf.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for summaries and logging
    num_images_per_summary=1,  # images for each summary
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    gpu_allow_growth=True,  # GPU memory growth
    gpu_memory_limit=None,  # GPU memory limit
    action_repeat=1
):  # Name of single observation channel, ['camera', 'lidar', 'birdeye']
    # Setup GPU
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpu_allow_growth:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    if gpu_memory_limit:
        for gpu in gpus:
            tf.config.experimental.set_virtual_device_configuration(
                gpu, [
                    tf.config.experimental.VirtualDeviceConfiguration(
                        memory_limit=gpu_memory_limit)
                ])

    # Get train and eval directories
    root_dir = os.path.expanduser(root_dir)
    root_dir = os.path.join(root_dir, env_name, experiment_name)

    # Get summary writers
    summary_writer = tf.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    # Eval metrics
    eval_metrics = [
        tf_metrics.AverageReturnMetric(name='AverageReturnEvalPolicy',
                                       buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(
            name='AverageEpisodeLengthEvalPolicy',
            buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()

    # Whether to record for summary
    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create Carla environment
        if agent_name == 'latent_sac':
            py_env, eval_py_env = load_carla_env(env_name='carla-v0',
                                                 obs_channels=input_names +
                                                 mask_names,
                                                 action_repeat=action_repeat)
        elif agent_name == 'dqn':
            py_env, eval_py_env = load_carla_env(env_name='carla-v0',
                                                 discrete=True,
                                                 obs_channels=input_names,
                                                 action_repeat=action_repeat)
        else:
            py_env, eval_py_env = load_carla_env(env_name='carla-v0',
                                                 obs_channels=input_names,
                                                 action_repeat=action_repeat)

        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)
        fps = int(np.round(1.0 / (py_env.dt * action_repeat)))

        # Specs
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        ## Make tf agent
        if agent_name == 'latent_sac':
            # Get model network for latent sac
            if model_network_ctor_type == 'hierarchical':
                model_network_ctor = sequential_latent_network.SequentialLatentModelHierarchical
            elif model_network_ctor_type == 'non-hierarchical':
                model_network_ctor = sequential_latent_network.SequentialLatentModelNonHierarchical
            else:
                raise NotImplementedError
            model_net = model_network_ctor(input_names,
                                           input_names + mask_names)

            # Get the latent spec
            latent_size = model_net.latent_size
            latent_observation_spec = tensor_spec.TensorSpec((latent_size, ),
                                                             dtype=tf.float32)
            latent_time_step_spec = ts.time_step_spec(
                observation_spec=latent_observation_spec)

            # Get actor and critic net
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                latent_observation_spec,
                action_spec,
                fc_layer_params=actor_fc_layers,
                continuous_projection_net=normal_projection_net)
            critic_net = critic_network.CriticNetwork(
                (latent_observation_spec, action_spec),
                observation_fc_layer_params=critic_obs_fc_layers,
                action_fc_layer_params=critic_action_fc_layers,
                joint_fc_layer_params=critic_joint_fc_layers)

            # Build the inner SAC agent based on latent space
            inner_agent = sac_agent.SacAgent(
                latent_time_step_spec,
                action_spec,
                actor_network=actor_net,
                critic_network=critic_net,
                actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=actor_learning_rate),
                critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=critic_learning_rate),
                alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=alpha_learning_rate),
                target_update_tau=target_update_tau,
                target_update_period=target_update_period,
                td_errors_loss_fn=td_errors_loss_fn,
                gamma=gamma,
                reward_scale_factor=reward_scale_factor,
                gradient_clipping=gradient_clipping,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step)
            inner_agent.initialize()

            # Build the latent sac agent
            tf_agent = latent_sac_agent.LatentSACAgent(
                time_step_spec,
                action_spec,
                inner_agent=inner_agent,
                model_network=model_net,
                model_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=model_learning_rate),
                model_batch_size=model_batch_size,
                num_images_per_summary=num_images_per_summary,
                sequence_length=sequence_length,
                gradient_clipping=gradient_clipping,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step,
                fps=fps)

        else:
            # Set up preprosessing layers for dictionary observation inputs
            preprocessing_layers = collections.OrderedDict()
            for name in input_names:
                preprocessing_layers[name] = Preprocessing_Layer(32, 256)
            if len(input_names) < 2:
                preprocessing_combiner = None

            if agent_name == 'dqn':
                q_rnn_net = q_rnn_network.QRnnNetwork(
                    observation_spec,
                    action_spec,
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    input_fc_layer_params=critic_joint_fc_layers,
                    lstm_size=critic_lstm_size,
                    output_fc_layer_params=critic_output_fc_layers)

                tf_agent = dqn_agent.DqnAgent(
                    time_step_spec,
                    action_spec,
                    q_network=q_rnn_net,
                    epsilon_greedy=epsilon_greedy,
                    n_step_update=1,
                    target_update_tau=target_update_tau,
                    target_update_period=target_update_period,
                    optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=q_learning_rate),
                    td_errors_loss_fn=common.element_wise_squared_loss,
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=global_step)

            elif agent_name == 'ddpg' or agent_name == 'td3':
                actor_rnn_net = multi_inputs_actor_rnn_network.MultiInputsActorRnnNetwork(
                    observation_spec,
                    action_spec,
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    input_fc_layer_params=actor_fc_layers,
                    lstm_size=actor_lstm_size,
                    output_fc_layer_params=actor_output_fc_layers)

                critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
                    (observation_spec, action_spec),
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    action_fc_layer_params=critic_action_fc_layers,
                    joint_fc_layer_params=critic_joint_fc_layers,
                    lstm_size=critic_lstm_size,
                    output_fc_layer_params=critic_output_fc_layers)

                if agent_name == 'ddpg':
                    tf_agent = ddpg_agent.DdpgAgent(
                        time_step_spec,
                        action_spec,
                        actor_network=actor_rnn_net,
                        critic_network=critic_rnn_net,
                        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=actor_learning_rate),
                        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=critic_learning_rate),
                        ou_stddev=ou_stddev,
                        ou_damping=ou_damping,
                        target_update_tau=target_update_tau,
                        target_update_period=target_update_period,
                        dqda_clipping=dqda_clipping,
                        td_errors_loss_fn=None,
                        gamma=gamma,
                        reward_scale_factor=reward_scale_factor,
                        gradient_clipping=gradient_clipping,
                        debug_summaries=debug_summaries,
                        summarize_grads_and_vars=summarize_grads_and_vars)
                elif agent_name == 'td3':
                    tf_agent = td3_agent.Td3Agent(
                        time_step_spec,
                        action_spec,
                        actor_network=actor_rnn_net,
                        critic_network=critic_rnn_net,
                        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=actor_learning_rate),
                        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=critic_learning_rate),
                        exploration_noise_std=exploration_noise_std,
                        target_update_tau=target_update_tau,
                        target_update_period=target_update_period,
                        actor_update_period=actor_update_period,
                        dqda_clipping=dqda_clipping,
                        td_errors_loss_fn=None,
                        gamma=gamma,
                        reward_scale_factor=reward_scale_factor,
                        gradient_clipping=gradient_clipping,
                        debug_summaries=debug_summaries,
                        summarize_grads_and_vars=summarize_grads_and_vars,
                        train_step_counter=global_step)

            elif agent_name == 'sac':
                actor_distribution_rnn_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                    observation_spec,
                    action_spec,
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    input_fc_layer_params=actor_fc_layers,
                    lstm_size=actor_lstm_size,
                    output_fc_layer_params=actor_output_fc_layers,
                    continuous_projection_net=normal_projection_net)

                critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
                    (observation_spec, action_spec),
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    action_fc_layer_params=critic_action_fc_layers,
                    joint_fc_layer_params=critic_joint_fc_layers,
                    lstm_size=critic_lstm_size,
                    output_fc_layer_params=critic_output_fc_layers)

                tf_agent = sac_agent.SacAgent(
                    time_step_spec,
                    action_spec,
                    actor_network=actor_distribution_rnn_net,
                    critic_network=critic_rnn_net,
                    actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=actor_learning_rate),
                    critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=critic_learning_rate),
                    alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=alpha_learning_rate),
                    target_update_tau=target_update_tau,
                    target_update_period=target_update_period,
                    td_errors_loss_fn=tf.math.
                    squared_difference,  # make critic loss dimension compatible
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=global_step)

            else:
                raise NotImplementedError

        tf_agent.initialize()

        # Get replay buffer
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,  # No parallel environments
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        # Train metrics
        env_steps = tf_metrics.EnvironmentSteps()
        average_return = tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ]

        # Get policies
        # eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        eval_policy = tf_agent.policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)
        collect_policy = tf_agent.collect_policy

        # Checkpointers
        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
            max_to_keep=2)
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step,
                                                  max_to_keep=2)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)
        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        # Collect driver
        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        # Optimize the performance by using tf functions
        initial_collect_driver.run = common.function(
            initial_collect_driver.run)
        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if (env_steps.result() == 0 or replay_buffer.num_frames() == 0):
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps'
                'with a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        if agent_name == 'latent_sac':
            compute_summaries(eval_metrics,
                              eval_tf_env,
                              eval_policy,
                              train_step=global_step,
                              summary_writer=summary_writer,
                              num_episodes=1,
                              num_episodes_to_render=1,
                              model_net=model_net,
                              fps=10,
                              image_keys=input_names + mask_names)
        else:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=1,
                train_step=env_steps.result(),
                summary_writer=summary_writer,
                summary_prefix='Eval',
            )
            metric_utils.log_metrics(eval_metrics)

        # Dataset generates trajectories with shape [Bxslx...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=sequence_length +
                                           1).prefetch(3)
        iterator = iter(dataset)

        # Get train step
        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        train_step = common.function(train_step)

        if agent_name == 'latent_sac':

            def train_model_step():
                experience, _ = next(iterator)
                return tf_agent.train_model(experience)

            train_model_step = common.function(train_model_step)

        # Training initializations
        time_step = None
        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Start training
        for iteration in range(num_iterations):
            start_time = time.time()

            if agent_name == 'latent_sac' and iteration < initial_model_train_steps:
                train_model_step()
            else:
                # Run collect
                time_step, _ = collect_driver.run(time_step=time_step)

                # Train an iteration
                for _ in range(train_steps_per_iteration):
                    train_step()

            time_acc += time.time() - start_time

            # Log training information
            if global_step.numpy() % log_interval == 0:
                logging.info('env steps = %d, average return = %f',
                             env_steps.result(), average_return.result())
                env_steps_per_sec = (env_steps.result().numpy() -
                                     env_steps_before) / time_acc
                logging.info('%.3f env steps/sec', env_steps_per_sec)
                tf.summary.scalar(name='env_steps_per_sec',
                                  data=env_steps_per_sec,
                                  step=env_steps.result())
                time_acc = 0
                env_steps_before = env_steps.result().numpy()

            # Get training metrics
            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=env_steps.result())

            # Evaluation
            if global_step.numpy() % eval_interval == 0:
                # Log evaluation metrics
                if agent_name == 'latent_sac':
                    compute_summaries(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        train_step=global_step,
                        summary_writer=summary_writer,
                        num_episodes=num_eval_episodes,
                        num_episodes_to_render=num_images_per_summary,
                        model_net=model_net,
                        fps=10,
                        image_keys=input_names + mask_names)
                else:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=env_steps.result(),
                        summary_writer=summary_writer,
                        summary_prefix='Eval',
                    )
                    metric_utils.log_metrics(eval_metrics)

            # Save checkpoints
            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
示例#20
0
    def testSaveAction(self):
        """Round-trips an RNN-backed QPolicy through PolicySaver.

        The policy is exported, loaded back as a SavedModel, and the 'action'
        signature of the reloaded model is compared against a direct call to
        `policy.action` on the same sampled inputs. Skipped outside eager mode.
        """
        if not tf.executing_eagerly():
            self.skipTest(
                'b/129079730: PolicySaver does not work in TF1.x yet')

        rnn_net = q_rnn_network.QRnnNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)
        policy = q_policy.QPolicy(
            time_step_spec=self._time_step_spec,
            action_spec=self._action_spec,
            q_network=rnn_net)

        # The same seed is baked into the exported model and later passed to
        # the direct policy.action() call.
        seed_for_action = 98723
        exporter = policy_saver.PolicySaver(
            policy, batch_size=None, seed=seed_for_action)
        export_dir = os.path.join(
            tf.compat.v1.test.get_temp_dir(), 'save_model_action')
        exporter.save(export_dir)

        loaded = tf.compat.v2.saved_model.load(export_dir)

        self.assertIn('action', loaded.signatures)
        action_signature = loaded.signatures['action']

        # Inputs of the action call: the time step plus the policy state.
        input_specs = (self._time_step_spec, policy.policy_state_spec)
        self._compare_input_output_specs(
            action_signature,
            expected_input_specs=input_specs,
            expected_output_spec=policy.policy_step_spec,
            batch_input=True)

        num_batched = 3
        sampled_inputs = tensor_spec.sample_spec_nest(
            input_specs, outer_dims=(num_batched, ), seed=4)

        # The signature function takes flat keyword arguments named after the
        # flattened input specs.
        signature_kwargs = {
            spec.name: tensor
            for spec, tensor in zip(tf.nest.flatten(input_specs),
                                    tf.nest.flatten(sampled_inputs))
        }

        # NOTE(ebrevdo): The graph-level seeds for the policy and the reloaded
        # model are equal, which in addition to seeding the call to action()
        # and PolicySaver helps ensure equality of the output of action() in
        # both cases.
        self.assertEqual(action_signature.graph.seed, self._global_seed)
        expected_step = policy.action(*sampled_inputs, seed=seed_for_action)
        # No seed is passed here: the SavedModel action call received its seed
        # when the PolicySaver was created.
        loaded_outputs = action_signature(**signature_kwargs)

        expected_outputs = {
            spec.name: tensor
            for spec, tensor in zip(tf.nest.flatten(policy.policy_step_spec),
                                    tf.nest.flatten(expected_step))
        }

        expected_outputs = self.evaluate(expected_outputs)
        loaded_outputs = self.evaluate(loaded_outputs)

        self.assertAllEqual(expected_outputs.keys(), loaded_outputs.keys())
        for key, expected in expected_outputs.items():
            self.assertAllClose(expected,
                                loaded_outputs[key],
                                msg='\nMismatched dict key: %s.' % key)
示例#21
0
# Build two independent instances of the same Gym task, one for training and
# one for evaluation. Each is wrapped in BreakoutEnv and then exposed as a
# TensorFlow environment.
env_name = 'Breakout-v0'
train_py_env = BreakoutEnv(suite_gym.load(env_name))
train_env = tf_py_environment.TFPyEnvironment(train_py_env)

eval_py_env = BreakoutEnv(suite_gym.load(env_name))
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

# Hyperparameters forwarded to QRnnNetwork below.
# NOTE(review): each conv entry is presumably (filters, kernel_size, stride)
# -- confirm against the QRnnNetwork documentation.
conv_layer_params = [(32, 8, 4), (64, 4, 2)]
input_fc_layer_params = (256, )
output_fc_layer_params = (256, )
lstm_size = [256]

# Recurrent Q-network sized from the training environment's observation and
# action specs.
q_net = q_rnn_network.QRnnNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    input_fc_layer_params=input_fc_layer_params,
    output_fc_layer_params=output_fc_layer_params,
    conv_layer_params=conv_layer_params,
    lstm_size=lstm_size)

# `learning_rate` is not assigned in this snippet; it must already be defined
# by the surrounding code.
optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
# Used as the agent's TD-error loss function.
loss_func = common.element_wise_huber_loss
# Counter handed to the agent as its train-step counter.
train_step_counter = tf.Variable(0)

# DQN agent built around the recurrent Q-network; initialize() is called once
# right after construction.
agent = dqn_agent.DqnAgent(train_env.time_step_spec(),
                           train_env.action_spec(),
                           q_network=q_net,
                           optimizer=optimizer,
                           td_errors_loss_fn=loss_func,
                           train_step_counter=train_step_counter)
agent.initialize()
示例#22
0
  def testSaveAction(self, seeded, has_state, distribution_net,
                     has_input_fn_and_spec):
    """Checks that a policy saved with PolicySaver reproduces its outputs.

    The test runs in three phases:
      1. In a fresh graph, build a policy, record its action and distribution
         outputs on sampled inputs, and export it with PolicySaver.
      2. In a second fresh graph with the same graph-level seed, reload the
         SavedModel and compare its 'action' signature, `action` function and
         `distribution` function against the recorded outputs.
      3. When executing eagerly outside functions (TF2), convert the
         SavedModel to TFLite and compare the 'action' signature runner
         output against the recorded outputs.

    Args:
      seeded: If true, reloaded action outputs are compared by value
        (assertAllClose); otherwise only dtype and shape are compared. The
        PolicySaver itself is always created with `seed=action_seed`.
      has_state: When `distribution_net` is false, selects a QRnnNetwork
        (true) or a QNetwork (false) for the QPolicy.
      distribution_net: If true, the policy is an ActorPolicy over an
        ActorDistributionNetwork instead of a QPolicy.
      has_input_fn_and_spec: If true, the PolicySaver is given an
        `input_fn_and_spec` that takes a string vector named 'example', and
        the reloaded functions are called with that string vector.
    """
    # Phase 1: build, run and export the policy in its own graph and session.
    with tf.compat.v1.Graph().as_default():
      tf.compat.v1.set_random_seed(self._global_seed)
      with tf.compat.v1.Session().as_default():
        global_step = common.create_variable('train_step', initial_value=0)
        if distribution_net:
          network = actor_distribution_network.ActorDistributionNetwork(
              self._time_step_spec.observation, self._action_spec)
          policy = actor_policy.ActorPolicy(
              time_step_spec=self._time_step_spec,
              action_spec=self._action_spec,
              actor_network=network)
        else:
          if has_state:
            network = q_rnn_network.QRnnNetwork(
                input_tensor_spec=self._time_step_spec.observation,
                action_spec=self._action_spec,
                lstm_size=(40,))
          else:
            network = q_network.QNetwork(
                input_tensor_spec=self._time_step_spec.observation,
                action_spec=self._action_spec)

          policy = q_policy.QPolicy(
              time_step_spec=self._time_step_spec,
              action_spec=self._action_spec,
              q_network=network)

        action_seed = 98723

        batch_size = 3
        action_inputs = tensor_spec.sample_spec_nest(
            (self._time_step_spec, policy.policy_state_spec),
            outer_dims=(batch_size,),
            seed=4)
        # Evaluate the sampled inputs to concrete values so that the same
        # values can be turned back into tensors in the second graph.
        action_input_values = self.evaluate(action_inputs)
        action_input_tensors = tf.nest.map_structure(tf.convert_to_tensor,
                                                     action_input_values)

        action_output = policy.action(*action_input_tensors, seed=action_seed)
        distribution_output = policy.distribution(*action_input_tensors)
        self.assertIsInstance(
            distribution_output.action, tfp.distributions.Distribution)

        self.evaluate(tf.compat.v1.global_variables_initializer())

        # Flat view of the action output keyed by the names of the flattened
        # policy_step_spec entries.
        action_output_dict = collections.OrderedDict(
            ((spec.name, value) for (spec, value) in zip(
                tf.nest.flatten(policy.policy_step_spec),
                tf.nest.flatten(action_output))))

        # Check output of the flattened signature call.
        (action_output_value, action_output_dict) = self.evaluate(
            (action_output, action_output_dict))

        distribution_output_value = self.evaluate(_sample_from_distributions(
            distribution_output))

        input_fn_and_spec = None
        if has_input_fn_and_spec:
          input_fn_and_spec = (_convert_string_vector_to_action_input,
                               tf.TensorSpec((7,), tf.string, name='example'))

        saver = policy_saver.PolicySaver(
            policy,
            batch_size=None,
            use_nest_path_signatures=False,
            seed=action_seed,
            input_fn_and_spec=input_fn_and_spec,
            train_step=global_step)
        path = os.path.join(self.get_temp_dir(), 'save_model_action')
        saver.save(path)

    # Phase 2: reload the SavedModel in a new graph that uses the same
    # graph-level seed and compare against the values recorded above.
    with tf.compat.v1.Graph().as_default():
      tf.compat.v1.set_random_seed(self._global_seed)
      with tf.compat.v1.Session().as_default():
        reloaded = tf.compat.v2.saved_model.load(path)

        self.assertIn('action', reloaded.signatures)
        reloaded_action = reloaded.signatures['action']
        if has_input_fn_and_spec:
          self._compare_input_output_specs(
              reloaded_action,
              expected_input_specs=input_fn_and_spec[1],
              expected_output_spec=policy.policy_step_spec,
              batch_input=True)

        else:
          self._compare_input_output_specs(
              reloaded_action,
              expected_input_specs=(self._time_step_spec,
                                    policy.policy_state_spec),
              expected_output_spec=policy.policy_step_spec,
              batch_input=True)

        # Reload action_input_values as tensors in the new graph.
        action_input_tensors = tf.nest.map_structure(tf.convert_to_tensor,
                                                     action_input_values)

        action_input_spec = (self._time_step_spec, policy.policy_state_spec)
        function_action_input_dict = collections.OrderedDict(
            (spec.name, value) for (spec, value) in zip(
                tf.nest.flatten(action_input_spec),
                tf.nest.flatten(action_input_tensors)))

        # NOTE(ebrevdo): The graph-level seeds for the policy and the reloaded
        # model are equal, which in addition to seeding the call to action() and
        # PolicySaver helps ensure equality of the output of action() in both
        # cases.
        self.assertEqual(reloaded_action.graph.seed, self._global_seed)

        # The seed= argument for the SavedModel action call was given at
        # creation of the PolicySaver.
        if has_input_fn_and_spec:
          action_string_vector = _convert_action_input_to_string_vector(
              action_input_tensors)
          # Concrete string values are kept for the TFLite phase below.
          action_string_vector_values = self.evaluate(action_string_vector)
          reloaded_action_output_dict = reloaded_action(action_string_vector)
          reloaded_action_output = reloaded.action(action_string_vector)
          reloaded_distribution_output = reloaded.distribution(
              action_string_vector)
          self.assertIsInstance(reloaded_distribution_output.action,
                                tfp.distributions.Distribution)

        else:
          # This is the flat-signature function.
          reloaded_action_output_dict = reloaded_action(
              **function_action_input_dict)
          # This is the non-flat function.
          reloaded_action_output = reloaded.action(*action_input_tensors)
          reloaded_distribution_output = reloaded.distribution(
              *action_input_tensors)
          self.assertIsInstance(reloaded_distribution_output.action,
                                tfp.distributions.Distribution)

          if not has_state:
            # Try both cases: one with an empty policy_state and one with no
            # policy_state.  Compare them.

            # NOTE(ebrevdo): The first call to .action() must be stored in
            # reloaded_action_output because this is the version being compared
            # later against the true action_output and the values will change
            # after the first call due to randomness.
            reloaded_action_output_no_input_state = reloaded.action(
                action_input_tensors[0])
            reloaded_distribution_output_no_input_state = reloaded.distribution(
                action_input_tensors[0])
            # Even with a seed, multiple calls to action will get different
            # values, so here we just check the signature matches.
            self.assertIsInstance(
                reloaded_distribution_output_no_input_state.action,
                tfp.distributions.Distribution)
            tf.nest.map_structure(self.match_dtype_shape,
                                  reloaded_action_output_no_input_state,
                                  reloaded_action_output)

            tf.nest.map_structure(
                self.match_dtype_shape,
                _sample_from_distributions(
                    reloaded_distribution_output_no_input_state),
                _sample_from_distributions(reloaded_distribution_output))

        self.evaluate(tf.compat.v1.global_variables_initializer())
        (reloaded_action_output_dict,
         reloaded_action_output_value) = self.evaluate(
             (reloaded_action_output_dict, reloaded_action_output))

        reloaded_distribution_output_value = self.evaluate(
            _sample_from_distributions(reloaded_distribution_output))

        self.assertAllEqual(action_output_dict.keys(),
                            reloaded_action_output_dict.keys())

        # Signature outputs: compare values when seeded, otherwise only
        # dtype and shape.
        for k in action_output_dict:
          if seeded:
            self.assertAllClose(
                action_output_dict[k],
                reloaded_action_output_dict[k],
                msg='\nMismatched dict key: %s.' % k)
          else:
            self.match_dtype_shape(
                action_output_dict[k],
                reloaded_action_output_dict[k],
                msg='\nMismatch dict key: %s.' % k)

        # With non-signature functions, we can check that passing a seed does
        # the right thing the second time.
        if seeded:
          tf.nest.map_structure(self.assertAllClose, action_output_value,
                                reloaded_action_output_value)
        else:
          tf.nest.map_structure(self.match_dtype_shape, action_output_value,
                                reloaded_action_output_value)

        # Distribution samples are compared by value regardless of `seeded`.
        tf.nest.map_structure(self.assertAllClose,
                              distribution_output_value,
                              reloaded_distribution_output_value)

    ## TFLite tests.

    # The converter must run outside of a TF1 graph context, even in
    # eager mode, to ensure the TF2 path is being executed.  Only
    # works in TF2.
    if tf.compat.v1.executing_eagerly_outside_functions():
      tflite_converter = tf.lite.TFLiteConverter.from_saved_model(
          path, signature_keys=['action'])
      tflite_converter.target_spec.supported_ops = [
          tf.lite.OpsSet.TFLITE_BUILTINS,
          # TODO(b/111309333): Remove this when `has_input_fn_and_spec`
          # is `False` once TFLite has native support for RNG ops, atan, etc.
          tf.lite.OpsSet.SELECT_TF_OPS,
      ]
      tflite_serialized_model = tflite_converter.convert()

      tflite_interpreter = tf.lite.Interpreter(
          model_content=tflite_serialized_model)

      tflite_runner = tflite_interpreter.get_signature_runner('action')
      tflite_signature = tflite_interpreter.get_signature_list()['action']

      # Feed the TFLite runner the same concrete input values that were used
      # for the original policy.
      if has_input_fn_and_spec:
        tflite_action_input_dict = {
            'example': action_string_vector_values,
        }
      else:
        tflite_action_input_dict = collections.OrderedDict(
            (spec.name, value) for (spec, value) in zip(
                tf.nest.flatten(action_input_spec),
                tf.nest.flatten(action_input_values)))

      # The TFLite signature must expose exactly the expected input and
      # output names.
      self.assertEqual(
          set(tflite_signature['inputs']),
          set(tflite_action_input_dict))
      self.assertEqual(
          set(tflite_signature['outputs']),
          set(action_output_dict))

      tflite_output = tflite_runner(**tflite_action_input_dict)

      self.assertAllClose(tflite_output, action_output_dict)
示例#23
0
    def testSaveAction(self, seeded, has_state, distribution_net,
                       has_input_fn_and_spec):
        """Checks that a policy saved with PolicySaver reproduces its actions.

        A policy is built and exported in one graph, then reloaded in a
        second graph with the same graph-level seed. The reloaded 'action'
        signature and `action` function are compared against the outputs
        recorded from the original policy.

        Args:
            seeded: If true, reloaded outputs are compared by value
                (assertAllClose); otherwise only shape and dtype are
                compared. The PolicySaver itself is always created with
                `seed=action_seed`.
            has_state: When `distribution_net` is false, selects a
                QRnnNetwork (true) or a QNetwork (false) for the QPolicy.
            distribution_net: If true, the policy is an ActorPolicy over an
                ActorDistributionNetwork instead of a QPolicy.
            has_input_fn_and_spec: If true, the PolicySaver is given an
                `input_fn_and_spec` that takes a string vector named
                'example', and the reloaded functions are called with that
                string vector.
        """
        # Phase 1: build, run and export the policy in its own graph/session.
        with tf.compat.v1.Graph().as_default():
            tf.compat.v1.set_random_seed(self._global_seed)
            with tf.compat.v1.Session().as_default():
                global_step = common.create_variable('train_step',
                                                     initial_value=0)
                if distribution_net:
                    network = actor_distribution_network.ActorDistributionNetwork(
                        self._time_step_spec.observation, self._action_spec)
                    policy = actor_policy.ActorPolicy(
                        time_step_spec=self._time_step_spec,
                        action_spec=self._action_spec,
                        actor_network=network)
                else:
                    if has_state:
                        network = q_rnn_network.QRnnNetwork(
                            input_tensor_spec=self._time_step_spec.observation,
                            action_spec=self._action_spec)
                    else:
                        network = q_network.QNetwork(
                            input_tensor_spec=self._time_step_spec.observation,
                            action_spec=self._action_spec)

                    policy = q_policy.QPolicy(
                        time_step_spec=self._time_step_spec,
                        action_spec=self._action_spec,
                        q_network=network)

                action_seed = 98723

                batch_size = 3
                action_inputs = tensor_spec.sample_spec_nest(
                    (self._time_step_spec, policy.policy_state_spec),
                    outer_dims=(batch_size, ),
                    seed=4)
                # Evaluate the sampled inputs to concrete values so that the
                # same values can be turned back into tensors in the second
                # graph.
                action_input_values = self.evaluate(action_inputs)
                action_input_tensors = tf.nest.map_structure(
                    tf.convert_to_tensor, action_input_values)

                action_output = policy.action(*action_input_tensors,
                                              seed=action_seed)

                self.evaluate(tf.compat.v1.global_variables_initializer())

                # Flat view of the action output keyed by the names of the
                # flattened policy_step_spec entries.
                action_output_dict = dict(((spec.name, value) for (
                    spec,
                    value) in zip(tf.nest.flatten(policy.policy_step_spec),
                                  tf.nest.flatten(action_output))))

                # Check output of the flattened signature call.
                (action_output_value, action_output_dict) = self.evaluate(
                    (action_output, action_output_dict))

                input_fn_and_spec = None
                if has_input_fn_and_spec:
                    input_fn_and_spec = (
                        self._convert_string_vector_to_action_input,
                        tf.TensorSpec((7, ), tf.string, name='example'))

                saver = policy_saver.PolicySaver(
                    policy,
                    batch_size=None,
                    use_nest_path_signatures=False,
                    seed=action_seed,
                    input_fn_and_spec=input_fn_and_spec,
                    train_step=global_step)
                path = os.path.join(self.get_temp_dir(), 'save_model_action')
                saver.save(path)

        # Phase 2: reload the SavedModel in a new graph that uses the same
        # graph-level seed and compare against the values recorded above.
        with tf.compat.v1.Graph().as_default():
            tf.compat.v1.set_random_seed(self._global_seed)
            with tf.compat.v1.Session().as_default():
                reloaded = tf.compat.v2.saved_model.load(path)

                self.assertIn('action', reloaded.signatures)
                reloaded_action = reloaded.signatures['action']
                if has_input_fn_and_spec:
                    self._compare_input_output_specs(
                        reloaded_action,
                        expected_input_specs=input_fn_and_spec[1],
                        expected_output_spec=policy.policy_step_spec,
                        batch_input=True)

                else:
                    self._compare_input_output_specs(
                        reloaded_action,
                        expected_input_specs=(self._time_step_spec,
                                              policy.policy_state_spec),
                        expected_output_spec=policy.policy_step_spec,
                        batch_input=True)

                # Reload action_input_values as tensors in the new graph.
                action_input_tensors = tf.nest.map_structure(
                    tf.convert_to_tensor, action_input_values)

                action_input_spec = (self._time_step_spec,
                                     policy.policy_state_spec)
                function_action_input_dict = dict(
                    (spec.name, value)
                    for (spec,
                         value) in zip(tf.nest.flatten(action_input_spec),
                                       tf.nest.flatten(action_input_tensors)))

                # NOTE(ebrevdo): The graph-level seeds for the policy and the reloaded
                # model are equal, which in addition to seeding the call to action() and
                # PolicySaver helps ensure equality of the output of action() in both
                # cases.
                self.assertEqual(reloaded_action.graph.seed, self._global_seed)

                # Asserts that x and y have equal shape and dtype; values are
                # not compared.
                def match_dtype_shape(x, y, msg=None):
                    self.assertEqual(x.shape, y.shape, msg=msg)
                    self.assertEqual(x.dtype, y.dtype, msg=msg)

                # The seed= argument for the SavedModel action call was given at
                # creation of the PolicySaver.
                if has_input_fn_and_spec:
                    action_string_vector = self._convert_action_input_to_string_vector(
                        action_input_tensors)
                    reloaded_action_output_dict = reloaded_action(
                        action_string_vector)
                    reloaded_action_output = reloaded.action(
                        action_string_vector)

                else:
                    # This is the flat-signature function.
                    reloaded_action_output_dict = reloaded_action(
                        **function_action_input_dict)
                    # This is the non-flat function.
                    reloaded_action_output = reloaded.action(
                        *action_input_tensors)

                    if not has_state:
                        # Try both cases: one with an empty policy_state and one with no
                        # policy_state.  Compare them.

                        # NOTE(ebrevdo): The first call to .action() must be stored in
                        # reloaded_action_output because this is the version being compared
                        # later against the true action_output and the values will change
                        # after the first call due to randomness.
                        reloaded_action_output_no_input_state = reloaded.action(
                            action_input_tensors[0])
                        # Even with a seed, multiple calls to action will get different
                        # values, so here we just check the signature matches.
                        tf.nest.map_structure(
                            match_dtype_shape,
                            reloaded_action_output_no_input_state,
                            reloaded_action_output)

                self.evaluate(tf.compat.v1.global_variables_initializer())
                (reloaded_action_output_dict,
                 reloaded_action_output_value) = self.evaluate(
                     (reloaded_action_output_dict, reloaded_action_output))

                self.assertAllEqual(action_output_dict.keys(),
                                    reloaded_action_output_dict.keys())

                # Signature outputs: compare values when seeded, otherwise
                # only shape and dtype.
                for k in action_output_dict:
                    if seeded:
                        self.assertAllClose(action_output_dict[k],
                                            reloaded_action_output_dict[k],
                                            msg='\nMismatched dict key: %s.' %
                                            k)
                    else:
                        match_dtype_shape(action_output_dict[k],
                                          reloaded_action_output_dict[k],
                                          msg='\nMismatch dict key: %s.' % k)

                # With non-signature functions, we can check that passing a seed does
                # the right thing the second time.
                if seeded:
                    tf.nest.map_structure(self.assertAllClose,
                                          action_output_value,
                                          reloaded_action_output_value)
                else:
                    tf.nest.map_structure(match_dtype_shape,
                                          action_output_value,
                                          reloaded_action_output_value)
示例#24
0
def train():
    """Train a recurrent DQN agent on ``TradeEnvironment`` and plot returns.

    The function builds a ``QRnnNetwork``-backed ``DqnAgent``, swaps the
    agent's policies for the filtered variants defined in this project,
    pre-fills a uniform replay buffer with ``DummyTradePolicy`` and then
    alternates one collection phase and one training step for
    ``num_iterations`` iterations.  The average return is evaluated once
    before training and again whenever the train step counter is a multiple
    of ``eval_interval``; every evaluation is recorded together with the step
    it was taken at and plotted at the end.

    Side effects: sets the module-level ``VERBOSE`` flag to ``True`` before
    the final single-episode run, prints progress to stdout and calls
    ``plt.show()``.
    """
    global VERBOSE
    environment = TradeEnvironment()
    # Both TF wrappers are built around the same python environment instance.
    train_env = tf_py_environment.TFPyEnvironment(environment)
    eval_env = tf_py_environment.TFPyEnvironment(environment)

    # Network / loop hyper-parameters.
    num_iterations = 50
    input_fc_layer_params = (17, )
    output_fc_layer_params = (20, )
    lstm_size = (17, )
    initial_collect_steps = 20
    collect_steps_per_iteration = 1
    batch_size = 64
    replay_buffer_capacity = 10000

    # Agent hyper-parameters.
    gamma = 0.99  # check if 1 will work here
    target_update_tau = 0.05
    target_update_period = 5
    epsilon_greedy = 0.1
    reward_scale_factor = 1.0
    learning_rate = 1e-2
    log_interval = 30
    num_eval_episodes = 5
    eval_interval = 15

    q_net = q_rnn_network.QRnnNetwork(
        train_env.observation_spec(),
        train_env.action_spec(),
        input_fc_layer_params=input_fc_layer_params,
        lstm_size=lstm_size,
        output_fc_layer_params=output_fc_layer_params,
    )

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

    train_step_counter = tf.compat.v2.Variable(0)

    tf_agent = dqn_agent.DqnAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        epsilon_greedy=epsilon_greedy,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        train_step_counter=train_step_counter,
        gradient_clipping=None,
        debug_summaries=False,
        summarize_grads_and_vars=False,
    )

    q_policy = FilteredQPolicy(
        tf_agent._time_step_spec,
        tf_agent._action_spec,
        q_network=tf_agent._q_network,
    )

    # Valid policy to pre-fill replay buffer
    dummy_policy = DummyTradePolicy(
        train_env.time_step_spec(),
        train_env.action_spec(),
    )

    # Main agent's policy; greedy one
    policy = greedy_policy.GreedyPolicy(q_policy)
    collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
        q_policy, epsilon=tf_agent._epsilon_greedy)

    # Patch random policy for epsilon greedy collect policy
    collect_policy._random_policy = FilteredRandomTFPolicy(
        time_step_spec=policy.time_step_spec,
        action_spec=policy.action_spec,
    )

    # Replace the agent's own policies with the filtered ones built above.
    tf_agent._policy = policy
    tf_agent._collect_policy = collect_policy
    tf_agent.initialize()

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=replay_buffer_capacity,
    )
    print(
        'Pre-filling replay buffer in {} steps'.format(initial_collect_steps))
    for _ in range(initial_collect_steps):
        traj = collect_step(train_env, dummy_policy)
        replay_buffer.add_batch(traj)

    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2,
    ).prefetch(3)

    iterator = iter(dataset)
    # Train
    tf_agent.train = common.function(tf_agent.train)

    tf_agent.train_step_counter.assign(0)

    avg_return = compute_avg_return(eval_env, tf_agent.policy,
                                    num_eval_episodes)

    # Record the step of every evaluation next to its return so the plot's
    # x and y series always have the same length.  The counter was just reset,
    # so the baseline evaluation belongs to step 0.
    eval_steps = [0]
    returns = [avg_return]

    print('Starting iterations...')
    for _ in range(num_iterations):

        # fill replay buffer
        for _ in range(collect_steps_per_iteration):
            traj = collect_step(train_env, tf_agent.collect_policy)
            # Add trajectory to the replay buffer
            replay_buffer.add_batch(traj)

        experience, _ = next(iterator)
        train_loss = tf_agent.train(experience)

        step = tf_agent.train_step_counter.numpy()

        if step % log_interval == 0:
            print('step = {0}: loss = {1}'.format(step, train_loss.loss))

        if step % eval_interval == 0:
            avg_return = compute_avg_return(eval_env, tf_agent.policy,
                                            num_eval_episodes)
            print('step = {0}: avg return = {1}'.format(step, avg_return))
            eval_steps.append(step)
            returns.append(avg_return)

    print('Finished {} iterations!'.format(num_iterations))

    print('Playing with resulting policy')
    VERBOSE = True
    r = compute_avg_return(eval_env, tf_agent.policy, 1)
    print('Result: {}'.format(r))

    print('Check out chart for learning')
    plt.plot(eval_steps, returns)
    plt.ylabel('Average Return')
    plt.xlabel('Step')
    plt.ylim(top=1000)
    plt.show()
示例#25
0
def train_eval(
        root_dir,
        env_name='MaskedCartPole-v0',
        num_iterations=100000,
        input_fc_layer_params=(50, ),
        lstm_size=(20, ),
        output_fc_layer_params=(20, ),
        train_sequence_length=10,
        # Params for collect
        initial_collect_steps=50,
        collect_episodes_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=10,
        batch_size=128,
        learning_rate=1e-3,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=100,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """Trains and periodically evaluates a DQN agent with a QRnnNetwork.

    The whole graph (environment, network, agent, replay buffer, drivers,
    dataset, train op, checkpointers and summaries) is built first and then
    executed inside a single `tf.compat.v1.Session`.

    Args:
      root_dir: Base directory; `~` is expanded. Summaries and checkpoints go
        to the `train` and `eval` sub-directories.
      env_name: Environment name handed to `suite_gym.load`, once for the
        evaluation environment and once for the training environment.
      num_iterations: Number of collect/train iterations of the main loop.
      input_fc_layer_params: Forwarded to `QRnnNetwork`.
      lstm_size: Forwarded to `QRnnNetwork`.
      output_fc_layer_params: Forwarded to `QRnnNetwork`.
      train_sequence_length: The replay-buffer dataset is sampled with
        `num_steps=train_sequence_length + 1`.
      initial_collect_steps: Passed as `num_episodes` to the initial
        random-policy driver, i.e. it counts episodes despite its name.
      collect_episodes_per_iteration: `num_episodes` of the per-iteration
        collect driver.
      epsilon_greedy: Forwarded to `DqnAgent`.
      replay_buffer_capacity: `max_length` of the replay buffer.
      target_update_tau: Forwarded to `DqnAgent`.
      target_update_period: Forwarded to `DqnAgent`.
      train_steps_per_iteration: Number of train-op calls per iteration.
      batch_size: `sample_batch_size` of the replay-buffer dataset.
      learning_rate: Learning rate of the Adam optimizer.
      gamma: Forwarded to `DqnAgent`.
      reward_scale_factor: Forwarded to `DqnAgent`.
      gradient_clipping: Forwarded to `DqnAgent`.
      num_eval_episodes: Episodes per evaluation; also the buffer size of the
        evaluation metrics.
      eval_interval: Evaluation runs when the global step is a multiple of
        this value.
      train_checkpoint_interval: Train checkpoint is saved when the global
        step is a multiple of this value.
      policy_checkpoint_interval: Same, for the policy checkpoint.
      rb_checkpoint_interval: Same, for the replay-buffer checkpoint.
      log_interval: Loss and steps/sec are logged when the global step is a
        multiple of this value.
      summary_interval: Train summaries are recorded only when
        `global_step % summary_interval == 0`.
      summaries_flush_secs: Flush period of both summary writers, in seconds
        (converted to milliseconds for `flush_millis`).
      debug_summaries: Forwarded to `DqnAgent`.
      summarize_grads_and_vars: Forwarded to `DqnAgent`.
      eval_metrics_callback: Passed as `callback` to
        `metric_utils.compute_summaries`.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    # The train writer becomes the default; the eval writer is only used
    # inside explicit `as_default()` scopes.
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Everything below is built under a summary-recording condition tied to
    # the global step.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Two separately loaded environments: a python one for evaluation and
        # a TF-wrapped one for collection.
        eval_py_env = suite_gym.load(env_name)
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

        q_net = q_rnn_network.QRnnNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            input_fc_layer_params=input_fc_layer_params,
            lstm_size=lstm_size,
            output_fc_layer_params=output_fc_layer_params)

        # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            epsilon_greedy=epsilon_greedy,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        # Python-side wrapper of the agent's policy, used for evaluation on
        # eval_py_env.
        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        # The first two entries are reused below as `step_metrics`.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # Initial experience comes from a random policy.
        # NOTE: `initial_collect_steps` is used as an episode count here.
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=initial_collect_steps).run()

        # Regular collection uses the agent's own collect policy.
        collect_policy = tf_agent.collect_policy
        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        # Need extra step to generate transitions of train_sequence_length.
        # Dataset generates trajectories with shape [BxTx...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)

        # Build the train op once; it is executed repeatedly in the session.
        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        experience, _ = iterator.get_next()
        loss_info = common.function(tf_agent.train)(experience=experience)

        # Three checkpointers: agent + metrics, policy only, replay buffer.
        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        # Summary ops of the train metrics; they are run together with the
        # train op via `train_step_call` below.
        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        # Eval summaries go to the eval writer and are always recorded.
        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            # Initialize the graph.
            # NOTE(review): only the train and replay-buffer checkpointers are
            # restored here; policy_checkpointer is saved below but never
            # passed through initialize_or_restore -- confirm intended.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            logging.info('Collecting initial experience.')
            sess.run(initial_collect_op)

            # Compute evaluation metrics.
            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )

            # Callables for the ops that are run every iteration.
            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([loss_info, summary_ops])
            global_step_call = sess.make_callable(global_step)

            # Bookkeeping for the steps/sec measurement: step at the start of
            # the current window and wall time accumulated in it.
            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                # Train/collect/eval.
                start_time = time.time()
                collect_call()
                # NOTE: loss_info_value is only bound when
                # train_steps_per_iteration >= 1.
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _ = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    # Start a new measurement window.
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        log=True,
                        callback=eval_metrics_callback,
                    )
  def testSaveGetInitialState(self):
    """Checks `get_initial_state` after a PolicySaver save/load round trip.

    A QPolicy backed by a QRnnNetwork is exported twice, first with
    `batch_size=None` and then with `batch_size=3`. For each reloaded model
    the `get_initial_state` signature specs and the state it returns are
    compared against the original policy.
    """
    rnn_net = q_rnn_network.QRnnNetwork(
        input_tensor_spec=self._time_step_spec.observation,
        action_spec=self._action_spec,
        lstm_size=(40,))
    rnn_policy = q_policy.QPolicy(
        time_step_spec=self._time_step_spec,
        action_spec=self._action_spec,
        q_network=rnn_net)
    step_var = common.create_variable('train_step', initial_value=0)

    def export_and_reload(saver, dirname):
      """Saves with `saver`, loads the model back and inits its variables."""
      export_dir = os.path.join(self.get_temp_dir(), dirname)
      with self.cached_session():
        saver.save(export_dir)
        loaded = tf.compat.v2.saved_model.load(export_dir)
        self.evaluate(
            tf.compat.v1.initializers.variables(loaded.model_variables))
      return loaded

    def check_signature(loaded, input_specs, batch_size):
      """Verifies the reloaded `get_initial_state` signature specs."""
      self.assertIn('get_initial_state', loaded.signatures)
      signature_fn = loaded.signatures['get_initial_state']
      self._compare_input_output_specs(
          signature_fn,
          expected_input_specs=input_specs,
          expected_output_spec=rnn_policy.policy_state_spec,
          batch_input=False,
          batch_size=batch_size)

    # --- Export without a fixed batch size. ---
    unbatched_saver = policy_saver.PolicySaver(
        rnn_policy,
        train_step=step_var,
        batch_size=None,
        use_nest_path_signatures=False)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    unbatched_model = export_and_reload(
        unbatched_saver, 'save_model_initial_state_nobatch')
    check_signature(
        unbatched_model,
        input_specs=(tf.TensorSpec(
            dtype=tf.int32, shape=(), name='batch_size'),),
        batch_size=None)

    # Reference state from the in-memory policy.
    expected_state = self.evaluate(
        rnn_policy.get_initial_state(batch_size=3))

    unbatched_state = self.evaluate(
        unbatched_model.get_initial_state(batch_size=3))
    tf.nest.map_structure(self.assertAllClose, expected_state,
                          unbatched_state)

    # --- Export with the batch size fixed to 3. ---
    batched_saver = policy_saver.PolicySaver(
        rnn_policy,
        train_step=step_var,
        batch_size=3,
        use_nest_path_signatures=False)
    batched_model = export_and_reload(
        batched_saver, 'save_model_initial_state_batch')
    check_signature(batched_model, input_specs=(), batch_size=3)

    # With a fixed batch size the reloaded function takes no arguments.
    batched_state = self.evaluate(batched_model.get_initial_state())
    tf.nest.map_structure(self.assertAllClose, expected_state,
                          batched_state)