Example #1
    def testTrainStepNotSaved(self):
        network = q_network.QNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=network)

        saver = policy_saver.PolicySaver(policy, batch_size=None)
        path = os.path.join(self.get_temp_dir(), 'save_model')

        saver.save(path)
        reloaded = tf.compat.v2.saved_model.load(path)

        self.assertIn('get_train_step', reloaded.signatures)
        train_step_value = self.evaluate(reloaded.get_train_step())
        self.assertEqual(-1, train_step_value)
Example #2
    def testCheckpointSave(self):
        network = q_network.QNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=network)

        saver = policy_saver.PolicySaver(policy, batch_size=None)
        path = os.path.join(self.get_temp_dir(), 'save_model')

        self.evaluate(tf.compat.v1.global_variables_initializer())
        saver.save(path)
        checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
        saver.save_checkpoint(checkpoint_path)

        self.assertTrue(tf.compat.v2.io.gfile.exists(checkpoint_path))
Example #3
    def testActionWithinBounds(self):
        bounded_action_spec = tensor_spec.BoundedTensorSpec([1],
                                                            tf.int32,
                                                            minimum=-6,
                                                            maximum=-5)
        policy = q_policy.QPolicy(self._time_step_spec,
                                  bounded_action_spec,
                                  q_network=DummyNet())

        observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
        time_step = ts.restart(observations, batch_size=2)
        action_step = policy.action(time_step)
        self.assertEqual(action_step.action.shape.as_list(), [2, 1])
        self.assertEqual(action_step.action.dtype, tf.int32)
        # Initialize all variables
        self.evaluate(tf.compat.v1.global_variables_initializer())
        action = self.evaluate(action_step.action)
        self.assertTrue(np.all(action <= -5) and np.all(action >= -6))
Example #4
    def testLogits(self):
        tf.compat.v1.set_random_seed(1)
        wrapped = q_policy.QPolicy(self._time_step_spec,
                                   self._action_spec,
                                   q_network=DummyNet())
        policy = boltzmann_policy.BoltzmannPolicy(wrapped, temperature=0.5)

        observations = tf.constant([[1, 2]], dtype=tf.float32)
        time_step = ts.restart(observations, batch_size=1)
        distribution_step = policy.distribution(time_step)
        logits = distribution_step.action.logits
        original_logits = wrapped.distribution(time_step).action.logits
        self.evaluate(tf.compat.v1.global_variables_initializer())
        # Without the temperature, the logits are 4 and 5.5: (1, 2) . (1, 1) + 1
        # and (1, 2) . (1.5, 1.5) + 1. Dividing by temperature=0.5 doubles them.
        self.assertAllEqual([[[4., 5.5]]], self.evaluate(original_logits))
        self.assertAllEqual([[[8., 11.]]], self.evaluate(logits))
Example #5
    def testSaveGetInitialState(self):
        if not tf.executing_eagerly():
            self.skipTest(
                'b/129079730: PolicySaver does not work in TF1.x yet')

        q_network = q_rnn_network.QRnnNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=q_network)

        saver_nobatch = policy_saver.PolicySaver(policy, batch_size=None)
        path = os.path.join(tf.compat.v1.test.get_temp_dir(),
                            'save_model_initial_state_nobatch')
        saver_nobatch.save(path)
        reloaded_nobatch = tf.compat.v2.saved_model.load(path)
        self.assertIn('get_initial_state', reloaded_nobatch.signatures)
        reloaded_get_initial_state = (
            reloaded_nobatch.signatures['get_initial_state'])
        self._compare_input_output_specs(
            reloaded_get_initial_state,
            expected_input_specs=(tf.TensorSpec(dtype=tf.int32,
                                                shape=(),
                                                name='batch_size'), ),
            expected_output_spec=policy.policy_state_spec,
            batch_input=False,
            batch_size=None)

        saver_batch = policy_saver.PolicySaver(policy, batch_size=3)
        path = os.path.join(tf.compat.v1.test.get_temp_dir(),
                            'save_model_initial_state_batch')
        saver_batch.save(path)
        reloaded_batch = tf.compat.v2.saved_model.load(path)
        self.assertIn('get_initial_state', reloaded_batch.signatures)
        reloaded_get_initial_state = reloaded_batch.signatures[
            'get_initial_state']
        self._compare_input_output_specs(
            reloaded_get_initial_state,
            expected_input_specs=(),
            expected_output_spec=policy.policy_state_spec,
            batch_input=False,
            batch_size=3)
Example #6
  def _setup_as_discrete(self, time_step_spec, action_spec, loss_fn,
                         epsilon_greedy):
    self._bc_loss_fn = loss_fn or self._discrete_loss

    if any(isinstance(d, distribution_utils.DistributionSpecV2) for
           d in tf.nest.flatten([self._network_output_spec])):
      # If the output of the cloning network contains a distribution.
      base_policy = actor_policy.ActorPolicy(time_step_spec, action_spec,
                                             self._cloning_network)
    else:
      # If the output of the cloning network is logits.
      base_policy = q_policy.QPolicy(
          time_step_spec,
          action_spec,
          q_network=self._cloning_network,
          validate_action_spec_and_network=False)
    policy = greedy_policy.GreedyPolicy(base_policy)
    collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
        base_policy, epsilon=epsilon_greedy)
    return policy, collect_policy
Example #7
    def testDeferredBatchingAction(self):
        # Construct policy without providing batch_size.
        tf_policy = q_policy.QPolicy(self._time_step_spec,
                                     self._action_spec,
                                     q_network=DummyNet(stateful=False))
        policy = py_tf_policy.PyTFPolicy(tf_policy)

        # But time_steps have batch_size of 5
        batch_size = 5
        single_observation = np.array([1, 2], dtype=np.float32)
        time_steps = [ts.restart(single_observation)] * batch_size
        time_steps = fast_map_structure(lambda *arrays: np.stack(arrays),
                                        *time_steps)

        with self.test_session():
            tf.global_variables_initializer().run()
            action_steps = policy.action(time_steps)
            self.assertEqual(action_steps.action.shape, (batch_size, ))
            self.assertAllEqual(action_steps.action, [1] * batch_size)
            self.assertAllEqual(action_steps.state, ())
Example #8
    def testTrainStepSaved(self):
        # We need to use one default session so that self.evaluate and the
        # SavedModel loader share the same session.
        with tf.compat.v1.Session().as_default():
            network = q_network.QNetwork(
                input_tensor_spec=self._time_step_spec.observation,
                action_spec=self._action_spec)

            policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                      action_spec=self._action_spec,
                                      q_network=network)
            self.evaluate(
                tf.compat.v1.initializers.variables(policy.variables()))

            train_step = common.create_variable('train_step', initial_value=7)
            self.evaluate(tf.compat.v1.initializers.variables([train_step]))

            saver = policy_saver.PolicySaver(policy,
                                             batch_size=None,
                                             train_step=train_step)
            if tf.executing_eagerly():
                step = saver.get_train_step()
            else:
                step = self.evaluate(saver.get_train_step())
            self.assertEqual(7, step)
            path = os.path.join(self.get_temp_dir(), 'save_model')
            saver.save(path)

            reloaded = tf.compat.v2.saved_model.load(path)
            self.assertIn('get_train_step', reloaded.signatures)
            self.evaluate(tf.compat.v1.global_variables_initializer())
            train_step_value = self.evaluate(reloaded.get_train_step())
            self.assertEqual(7, train_step_value)
            train_step = train_step.assign_add(3)
            self.evaluate(train_step)
            saver.save(path)

            reloaded = tf.compat.v2.saved_model.load(path)
            self.evaluate(tf.compat.v1.global_variables_initializer())
            train_step_value = self.evaluate(reloaded.get_train_step())
            self.assertEqual(10, train_step_value)
Example #9
    def testUniqueSignatures(self):
        network = q_network.QNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=network)

        saver = policy_saver.PolicySaver(policy, batch_size=None)
        action_signature_names = [
            s.name for s in saver._signatures['action'].input_signature
        ]
        self.assertAllEqual(
            ['0/step_type', '0/reward', '0/discount', '0/observation'],
            action_signature_names)
        initial_state_signature_names = [
            s.name
            for s in saver._signatures['get_initial_state'].input_signature
        ]
        self.assertAllEqual(['batch_size'], initial_state_signature_names)
Example #10
    def testTrainStepNotSaved(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in TF2.x. Step is required in TF1.x')

        network = q_network.QNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=network)

        saver = policy_saver.PolicySaver(policy, batch_size=None)
        path = os.path.join(self.get_temp_dir(), 'save_model')

        saver.save(path)
        reloaded = tf.compat.v2.saved_model.load(path)

        self.assertIn('get_train_step', reloaded.signatures)
        train_step_value = self.evaluate(reloaded.get_train_step())
        self.assertEqual(-1, train_step_value)
Example #11
    def testDeferredBatchingAction(self):
        if tf.executing_eagerly():
            self.skipTest('b/123770140')

        # Construct policy without providing batch_size.
        tf_policy = q_policy.QPolicy(self._time_step_spec,
                                     self._action_spec,
                                     q_network=DummyNet(stateful=False))
        policy = py_tf_policy.PyTFPolicy(tf_policy)

        # But time_steps have batch_size of 5
        batch_size = 5
        single_observation = np.array([1, 2], dtype=np.float32)
        time_steps = [ts.restart(single_observation)] * batch_size
        time_steps = fast_map_structure(lambda *arrays: np.stack(arrays),
                                        *time_steps)

        with self.cached_session():
            self.evaluate(tf.compat.v1.global_variables_initializer())
            action_steps = policy.action(time_steps)
            self.assertEqual(action_steps.action.shape, (batch_size, ))
            for a in action_steps.action:
                self.assertIn(a, (0, 1))
            self.assertAllEqual(action_steps.state, ())
Example #12
  def testSaveAction(self, seeded, has_state, distribution_net,
                     has_input_fn_and_spec):
    with tf.compat.v1.Graph().as_default():
      tf.compat.v1.set_random_seed(self._global_seed)
      with tf.compat.v1.Session().as_default():
        global_step = common.create_variable('train_step', initial_value=0)
        if distribution_net:
          network = actor_distribution_network.ActorDistributionNetwork(
              self._time_step_spec.observation, self._action_spec)
          policy = actor_policy.ActorPolicy(
              time_step_spec=self._time_step_spec,
              action_spec=self._action_spec,
              actor_network=network)
        else:
          if has_state:
            network = q_rnn_network.QRnnNetwork(
                input_tensor_spec=self._time_step_spec.observation,
                action_spec=self._action_spec,
                lstm_size=(40,))
          else:
            network = q_network.QNetwork(
                input_tensor_spec=self._time_step_spec.observation,
                action_spec=self._action_spec)

          policy = q_policy.QPolicy(
              time_step_spec=self._time_step_spec,
              action_spec=self._action_spec,
              q_network=network)

        action_seed = 98723

        batch_size = 3
        action_inputs = tensor_spec.sample_spec_nest(
            (self._time_step_spec, policy.policy_state_spec),
            outer_dims=(batch_size,),
            seed=4)
        action_input_values = self.evaluate(action_inputs)
        action_input_tensors = tf.nest.map_structure(tf.convert_to_tensor,
                                                     action_input_values)

        action_output = policy.action(*action_input_tensors, seed=action_seed)
        distribution_output = policy.distribution(*action_input_tensors)
        self.assertIsInstance(
            distribution_output.action, tfp.distributions.Distribution)

        self.evaluate(tf.compat.v1.global_variables_initializer())

        action_output_dict = collections.OrderedDict(
            ((spec.name, value) for (spec, value) in zip(
                tf.nest.flatten(policy.policy_step_spec),
                tf.nest.flatten(action_output))))

        # Check output of the flattened signature call.
        (action_output_value, action_output_dict) = self.evaluate(
            (action_output, action_output_dict))

        distribution_output_value = self.evaluate(_sample_from_distributions(
            distribution_output))

        input_fn_and_spec = None
        if has_input_fn_and_spec:
          input_fn_and_spec = (_convert_string_vector_to_action_input,
                               tf.TensorSpec((7,), tf.string, name='example'))

        saver = policy_saver.PolicySaver(
            policy,
            batch_size=None,
            use_nest_path_signatures=False,
            seed=action_seed,
            input_fn_and_spec=input_fn_and_spec,
            train_step=global_step)
        path = os.path.join(self.get_temp_dir(), 'save_model_action')
        saver.save(path)

    with tf.compat.v1.Graph().as_default():
      tf.compat.v1.set_random_seed(self._global_seed)
      with tf.compat.v1.Session().as_default():
        reloaded = tf.compat.v2.saved_model.load(path)

        self.assertIn('action', reloaded.signatures)
        reloaded_action = reloaded.signatures['action']
        if has_input_fn_and_spec:
          self._compare_input_output_specs(
              reloaded_action,
              expected_input_specs=input_fn_and_spec[1],
              expected_output_spec=policy.policy_step_spec,
              batch_input=True)

        else:
          self._compare_input_output_specs(
              reloaded_action,
              expected_input_specs=(self._time_step_spec,
                                    policy.policy_state_spec),
              expected_output_spec=policy.policy_step_spec,
              batch_input=True)

        # Reload action_input_values as tensors in the new graph.
        action_input_tensors = tf.nest.map_structure(tf.convert_to_tensor,
                                                     action_input_values)

        action_input_spec = (self._time_step_spec, policy.policy_state_spec)
        function_action_input_dict = collections.OrderedDict(
            (spec.name, value) for (spec, value) in zip(
                tf.nest.flatten(action_input_spec),
                tf.nest.flatten(action_input_tensors)))

        # NOTE(ebrevdo): The graph-level seeds for the policy and the reloaded
        # model are equal, which in addition to seeding the call to action() and
        # PolicySaver helps ensure equality of the output of action() in both
        # cases.
        self.assertEqual(reloaded_action.graph.seed, self._global_seed)

        # The seed= argument for the SavedModel action call was given at
        # creation of the PolicySaver.
        if has_input_fn_and_spec:
          action_string_vector = _convert_action_input_to_string_vector(
              action_input_tensors)
          action_string_vector_values = self.evaluate(action_string_vector)
          reloaded_action_output_dict = reloaded_action(action_string_vector)
          reloaded_action_output = reloaded.action(action_string_vector)
          reloaded_distribution_output = reloaded.distribution(
              action_string_vector)
          self.assertIsInstance(reloaded_distribution_output.action,
                                tfp.distributions.Distribution)

        else:
          # This is the flat-signature function.
          reloaded_action_output_dict = reloaded_action(
              **function_action_input_dict)
          # This is the non-flat function.
          reloaded_action_output = reloaded.action(*action_input_tensors)
          reloaded_distribution_output = reloaded.distribution(
              *action_input_tensors)
          self.assertIsInstance(reloaded_distribution_output.action,
                                tfp.distributions.Distribution)

          if not has_state:
            # Try both cases: one with an empty policy_state and one with no
            # policy_state.  Compare them.

            # NOTE(ebrevdo): The first call to .action() must be stored in
            # reloaded_action_output because this is the version being compared
            # later against the true action_output and the values will change
            # after the first call due to randomness.
            reloaded_action_output_no_input_state = reloaded.action(
                action_input_tensors[0])
            reloaded_distribution_output_no_input_state = reloaded.distribution(
                action_input_tensors[0])
            # Even with a seed, multiple calls to action will get different
            # values, so here we just check the signature matches.
            self.assertIsInstance(
                reloaded_distribution_output_no_input_state.action,
                tfp.distributions.Distribution)
            tf.nest.map_structure(self.match_dtype_shape,
                                  reloaded_action_output_no_input_state,
                                  reloaded_action_output)

            tf.nest.map_structure(
                self.match_dtype_shape,
                _sample_from_distributions(
                    reloaded_distribution_output_no_input_state),
                _sample_from_distributions(reloaded_distribution_output))

        self.evaluate(tf.compat.v1.global_variables_initializer())
        (reloaded_action_output_dict,
         reloaded_action_output_value) = self.evaluate(
             (reloaded_action_output_dict, reloaded_action_output))

        reloaded_distribution_output_value = self.evaluate(
            _sample_from_distributions(reloaded_distribution_output))

        self.assertAllEqual(action_output_dict.keys(),
                            reloaded_action_output_dict.keys())

        for k in action_output_dict:
          if seeded:
            self.assertAllClose(
                action_output_dict[k],
                reloaded_action_output_dict[k],
                msg='\nMismatched dict key: %s.' % k)
          else:
            self.match_dtype_shape(
                action_output_dict[k],
                reloaded_action_output_dict[k],
                msg='\nMismatch dict key: %s.' % k)

        # With non-signature functions, we can check that passing a seed does
        # the right thing the second time.
        if seeded:
          tf.nest.map_structure(self.assertAllClose, action_output_value,
                                reloaded_action_output_value)
        else:
          tf.nest.map_structure(self.match_dtype_shape, action_output_value,
                                reloaded_action_output_value)

        tf.nest.map_structure(self.assertAllClose,
                              distribution_output_value,
                              reloaded_distribution_output_value)

    ## TFLite tests.

    # The converter must run outside of a TF1 graph context, even in
    # eager mode, to ensure the TF2 path is being executed.  Only
    # works in TF2.
    if tf.compat.v1.executing_eagerly_outside_functions():
      tflite_converter = tf.lite.TFLiteConverter.from_saved_model(
          path, signature_keys=['action'])
      tflite_converter.target_spec.supported_ops = [
          tf.lite.OpsSet.TFLITE_BUILTINS,
          # TODO(b/111309333): Remove this when `has_input_fn_and_spec`
          # is `False` once TFLite has native support for RNG ops, atan, etc.
          tf.lite.OpsSet.SELECT_TF_OPS,
      ]
      tflite_serialized_model = tflite_converter.convert()

      tflite_interpreter = tf.lite.Interpreter(
          model_content=tflite_serialized_model)

      tflite_runner = tflite_interpreter.get_signature_runner('action')
      tflite_signature = tflite_interpreter.get_signature_list()['action']

      if has_input_fn_and_spec:
        tflite_action_input_dict = {
            'example': action_string_vector_values,
        }
      else:
        tflite_action_input_dict = collections.OrderedDict(
            (spec.name, value) for (spec, value) in zip(
                tf.nest.flatten(action_input_spec),
                tf.nest.flatten(action_input_values)))

      self.assertEqual(
          set(tflite_signature['inputs']),
          set(tflite_action_input_dict))
      self.assertEqual(
          set(tflite_signature['outputs']),
          set(action_output_dict))

      tflite_output = tflite_runner(**tflite_action_input_dict)

      self.assertAllClose(tflite_output, action_output_dict)
Example #13
    def testInferenceWithCheckpoint(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in TF2.x.')

        # Create a saved_model for a q_policy.
        network = q_network.QNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=network)
        sample_input = self.evaluate(
            tensor_spec.sample_spec_nest(self._time_step_spec,
                                         outer_dims=(3, )))

        saver = policy_saver.PolicySaver(policy, batch_size=None)
        path = os.path.join(self.get_temp_dir(), 'save_model')

        self.evaluate(tf.compat.v1.global_variables_initializer())
        original_eval = self.evaluate(policy.action(sample_input))
        saver.save(path)
        # Assign -1 to all variables in the policy, making the checkpoint
        # different from the initial saved_model.
        self.evaluate(
            tf.nest.map_structure(lambda v: v.assign(v * 0 + -1),
                                  policy.variables()))
        checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
        saver.save_checkpoint(checkpoint_path)

        # Get an instance of the saved_model.
        reloaded_policy = tf.compat.v2.saved_model.load(path)
        self.evaluate(
            tf.compat.v1.initializers.variables(
                reloaded_policy.model_variables))

        # Verify loaded saved_model variables are different than the current policy.
        model_variables = self.evaluate(policy.variables())
        reloaded_model_variables = self.evaluate(
            reloaded_policy.model_variables)

        assert_np_not_equal = lambda a, b: self.assertFalse(
            np.equal(a, b).any())
        tf.nest.map_structure(assert_np_not_equal, model_variables,
                              reloaded_model_variables)

        # Update from checkpoint.
        checkpoint = tf.train.Checkpoint(policy=reloaded_policy)
        checkpoint_file_prefix = os.path.join(checkpoint_path, 'variables',
                                              'variables')
        checkpoint.read(
            checkpoint_file_prefix).assert_existing_objects_matched()

        self.evaluate(
            tf.compat.v1.initializers.variables(
                reloaded_policy.model_variables))

        # Verify variables are now equal.
        model_variables = self.evaluate(policy.variables())
        reloaded_model_variables = self.evaluate(
            reloaded_policy.model_variables)

        assert_np_all_equal = lambda a, b: self.assertTrue(
            np.equal(a, b).all())
        tf.nest.map_structure(assert_np_all_equal, model_variables,
                              reloaded_model_variables)

        # Verify variable update affects inference.
        reloaded_eval = self.evaluate(reloaded_policy.action(sample_input))
        tf.nest.map_structure(assert_np_not_equal, original_eval,
                              reloaded_eval)
        current_eval = self.evaluate(policy.action(sample_input))
        tf.nest.map_structure(assert_np_not_equal, current_eval, reloaded_eval)
Example #14
    def testSaveAction(self, seeded, has_state, distribution_net,
                       has_input_fn_and_spec):
        with tf.compat.v1.Graph().as_default():
            tf.compat.v1.set_random_seed(self._global_seed)
            with tf.compat.v1.Session().as_default():
                global_step = common.create_variable('train_step',
                                                     initial_value=0)
                if distribution_net:
                    network = actor_distribution_network.ActorDistributionNetwork(
                        self._time_step_spec.observation, self._action_spec)
                    policy = actor_policy.ActorPolicy(
                        time_step_spec=self._time_step_spec,
                        action_spec=self._action_spec,
                        actor_network=network)
                else:
                    if has_state:
                        network = q_rnn_network.QRnnNetwork(
                            input_tensor_spec=self._time_step_spec.observation,
                            action_spec=self._action_spec)
                    else:
                        network = q_network.QNetwork(
                            input_tensor_spec=self._time_step_spec.observation,
                            action_spec=self._action_spec)

                    policy = q_policy.QPolicy(
                        time_step_spec=self._time_step_spec,
                        action_spec=self._action_spec,
                        q_network=network)

                action_seed = 98723

                batch_size = 3
                action_inputs = tensor_spec.sample_spec_nest(
                    (self._time_step_spec, policy.policy_state_spec),
                    outer_dims=(batch_size, ),
                    seed=4)
                action_input_values = self.evaluate(action_inputs)
                action_input_tensors = tf.nest.map_structure(
                    tf.convert_to_tensor, action_input_values)

                action_output = policy.action(*action_input_tensors,
                                              seed=action_seed)

                self.evaluate(tf.compat.v1.global_variables_initializer())

                action_output_dict = dict(((spec.name, value) for (
                    spec,
                    value) in zip(tf.nest.flatten(policy.policy_step_spec),
                                  tf.nest.flatten(action_output))))

                # Check output of the flattened signature call.
                (action_output_value, action_output_dict) = self.evaluate(
                    (action_output, action_output_dict))

                input_fn_and_spec = None
                if has_input_fn_and_spec:
                    input_fn_and_spec = (
                        self._convert_string_vector_to_action_input,
                        tf.TensorSpec((7, ), tf.string, name='example'))

                saver = policy_saver.PolicySaver(
                    policy,
                    batch_size=None,
                    use_nest_path_signatures=False,
                    seed=action_seed,
                    input_fn_and_spec=input_fn_and_spec,
                    train_step=global_step)
                path = os.path.join(self.get_temp_dir(), 'save_model_action')
                saver.save(path)

        with tf.compat.v1.Graph().as_default():
            tf.compat.v1.set_random_seed(self._global_seed)
            with tf.compat.v1.Session().as_default():
                reloaded = tf.compat.v2.saved_model.load(path)

                self.assertIn('action', reloaded.signatures)
                reloaded_action = reloaded.signatures['action']
                if has_input_fn_and_spec:
                    self._compare_input_output_specs(
                        reloaded_action,
                        expected_input_specs=input_fn_and_spec[1],
                        expected_output_spec=policy.policy_step_spec,
                        batch_input=True)

                else:
                    self._compare_input_output_specs(
                        reloaded_action,
                        expected_input_specs=(self._time_step_spec,
                                              policy.policy_state_spec),
                        expected_output_spec=policy.policy_step_spec,
                        batch_input=True)

                # Reload action_input_values as tensors in the new graph.
                action_input_tensors = tf.nest.map_structure(
                    tf.convert_to_tensor, action_input_values)

                action_input_spec = (self._time_step_spec,
                                     policy.policy_state_spec)
                function_action_input_dict = dict(
                    (spec.name, value)
                    for (spec,
                         value) in zip(tf.nest.flatten(action_input_spec),
                                       tf.nest.flatten(action_input_tensors)))

                # NOTE(ebrevdo): The graph-level seeds for the policy and the reloaded
                # model are equal, which in addition to seeding the call to action() and
                # PolicySaver helps ensure equality of the output of action() in both
                # cases.
                self.assertEqual(reloaded_action.graph.seed, self._global_seed)

                def match_dtype_shape(x, y, msg=None):
                    self.assertEqual(x.shape, y.shape, msg=msg)
                    self.assertEqual(x.dtype, y.dtype, msg=msg)

                # The seed= argument for the SavedModel action call was given at
                # creation of the PolicySaver.
                if has_input_fn_and_spec:
                    action_string_vector = self._convert_action_input_to_string_vector(
                        action_input_tensors)
                    reloaded_action_output_dict = reloaded_action(
                        action_string_vector)
                    reloaded_action_output = reloaded.action(
                        action_string_vector)

                else:
                    # This is the flat-signature function.
                    reloaded_action_output_dict = reloaded_action(
                        **function_action_input_dict)
                    # This is the non-flat function.
                    reloaded_action_output = reloaded.action(
                        *action_input_tensors)

                    if not has_state:
                        # Try both cases: one with an empty policy_state and one with no
                        # policy_state.  Compare them.

                        # NOTE(ebrevdo): The first call to .action() must be stored in
                        # reloaded_action_output because this is the version being compared
                        # later against the true action_output and the values will change
                        # after the first call due to randomness.
                        reloaded_action_output_no_input_state = reloaded.action(
                            action_input_tensors[0])
                        # Even with a seed, multiple calls to action will get different
                        # values, so here we just check the signature matches.
                        tf.nest.map_structure(
                            match_dtype_shape,
                            reloaded_action_output_no_input_state,
                            reloaded_action_output)

                self.evaluate(tf.compat.v1.global_variables_initializer())
                (reloaded_action_output_dict,
                 reloaded_action_output_value) = self.evaluate(
                     (reloaded_action_output_dict, reloaded_action_output))

                self.assertAllEqual(action_output_dict.keys(),
                                    reloaded_action_output_dict.keys())

                for k in action_output_dict:
                    if seeded:
                        self.assertAllClose(action_output_dict[k],
                                            reloaded_action_output_dict[k],
                                            msg='\nMismatched dict key: %s.' %
                                            k)
                    else:
                        match_dtype_shape(action_output_dict[k],
                                          reloaded_action_output_dict[k],
                                          msg='\nMismatch dict key: %s.' % k)

                # With non-signature functions, we can check that passing a seed does
                # the right thing the second time.
                if seeded:
                    tf.nest.map_structure(self.assertAllClose,
                                          action_output_value,
                                          reloaded_action_output_value)
                else:
                    tf.nest.map_structure(match_dtype_shape,
                                          action_output_value,
                                          reloaded_action_output_value)
Example #15
  def testSaveGetInitialState(self):
    network = q_rnn_network.QRnnNetwork(
        input_tensor_spec=self._time_step_spec.observation,
        action_spec=self._action_spec,
        lstm_size=(40,))

    policy = q_policy.QPolicy(
        time_step_spec=self._time_step_spec,
        action_spec=self._action_spec,
        q_network=network)

    train_step = common.create_variable('train_step', initial_value=0)
    saver_nobatch = policy_saver.PolicySaver(
        policy,
        train_step=train_step,
        batch_size=None,
        use_nest_path_signatures=False)
    path = os.path.join(self.get_temp_dir(), 'save_model_initial_state_nobatch')

    self.evaluate(tf.compat.v1.global_variables_initializer())

    with self.cached_session():
      saver_nobatch.save(path)
      reloaded_nobatch = tf.compat.v2.saved_model.load(path)
      self.evaluate(
          tf.compat.v1.initializers.variables(reloaded_nobatch.model_variables))

    self.assertIn('get_initial_state', reloaded_nobatch.signatures)
    reloaded_get_initial_state = (
        reloaded_nobatch.signatures['get_initial_state'])
    self._compare_input_output_specs(
        reloaded_get_initial_state,
        expected_input_specs=(tf.TensorSpec(
            dtype=tf.int32, shape=(), name='batch_size'),),
        expected_output_spec=policy.policy_state_spec,
        batch_input=False,
        batch_size=None)

    initial_state = policy.get_initial_state(batch_size=3)
    initial_state = self.evaluate(initial_state)

    reloaded_nobatch_initial_state = reloaded_nobatch.get_initial_state(
        batch_size=3)
    reloaded_nobatch_initial_state = self.evaluate(
        reloaded_nobatch_initial_state)
    tf.nest.map_structure(self.assertAllClose, initial_state,
                          reloaded_nobatch_initial_state)

    saver_batch = policy_saver.PolicySaver(
        policy,
        train_step=train_step,
        batch_size=3,
        use_nest_path_signatures=False)
    path = os.path.join(self.get_temp_dir(), 'save_model_initial_state_batch')
    with self.cached_session():
      saver_batch.save(path)
      reloaded_batch = tf.compat.v2.saved_model.load(path)
      self.evaluate(
          tf.compat.v1.initializers.variables(reloaded_batch.model_variables))
    self.assertIn('get_initial_state', reloaded_batch.signatures)
    reloaded_get_initial_state = reloaded_batch.signatures['get_initial_state']
    self._compare_input_output_specs(
        reloaded_get_initial_state,
        expected_input_specs=(),
        expected_output_spec=policy.policy_state_spec,
        batch_input=False,
        batch_size=3)

    reloaded_batch_initial_state = reloaded_batch.get_initial_state()
    reloaded_batch_initial_state = self.evaluate(reloaded_batch_initial_state)
    tf.nest.map_structure(self.assertAllClose, initial_state,
                          reloaded_batch_initial_state)
Example #16
input_tensor_spec = tf_env.observation_spec()
time_step_spec = ts.time_step_spec(input_tensor_spec)
action_spec = tf_env.action_spec()

num_actions = env.size**2
batch_size = 1
observation = tf.cast(
    [(np.random.randint(env.size - 2, size=(env.size, env.size)) +
      1).reshape(25) for _ in range(1)], tf.int32)
time_steps = ts.restart(observation, batch_size=batch_size)

my_q_network = QNetwork(input_tensor_spec=input_tensor_spec,
                        action_spec=action_spec,
                        num_actions=num_actions)
my_q_policy = q_policy.QPolicy(time_step_spec,
                               action_spec,
                               q_network=my_q_network)
action_step = my_q_policy.action(time_steps)
distribution_step = my_q_policy.distribution(time_steps)

print('Action:')
print(action_step.action)

print('Action distribution:')
print(distribution_step.action)

num_episodes = tf_metrics.NumberOfEpisodes()
env_steps = tf_metrics.EnvironmentSteps()
observers = [num_episodes, env_steps]
driver = dynamic_episode_driver.DynamicEpisodeDriver(tf_env,
                                                     my_q_policy,
                                                     observers=observers)
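A hedged follow-up to the snippet above (not part of the original example): it assumes TF2 eager execution and the TF-Agents objects already constructed here, and shows how the driver and the metric observers are typically exercised.

# Run one collection pass with the driver and read back the metric observers
# defined above (NumberOfEpisodes and EnvironmentSteps accumulate as the
# driver steps the environment).
initial_time_step = tf_env.reset()
final_time_step, policy_state = driver.run(initial_time_step)
print('Episodes collected:', num_episodes.result().numpy())
print('Environment steps:', env_steps.result().numpy())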
Example #17
    def __init__(
            self,
            time_step_spec,
            action_spec,
            q_network,
            optimizer,
            epsilon_greedy=0.1,
            # Params for target network updates
            target_update_tau=1.0,
            target_update_period=1,
            # Params for training.
            td_errors_loss_fn=None,
            gamma=1.0,
            reward_scale_factor=1.0,
            gradient_clipping=None,
            # Params for debugging
            debug_summaries=False,
            summarize_grads_and_vars=False):
        """Creates a DQN Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      q_network: A tf_agents.network.Network to be used by the agent. The
        network will be called with call(observation, step_type).
      optimizer: The optimizer to use for training.
      epsilon_greedy: probability of choosing a random action in the default
        epsilon-greedy collect policy (used only if a wrapper is not provided to
        the collect_policy method).
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn: A function for computing the TD errors loss. If None, a
        default value of element_wise_huber_loss is used. This function takes as
        input the target and the estimated Q values and returns the loss for
        each element of the batch.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.

    Raises:
      ValueError: If the action spec contains more than one action.
    """
        flat_action_spec = nest.flatten(action_spec)
        self._num_actions = [
            spec.maximum - spec.minimum + 1 for spec in flat_action_spec
        ]

        # TODO(oars): Get DQN working with more than one dim in the actions.
        if len(flat_action_spec) > 1 or flat_action_spec[0].shape.ndims > 1:
            raise ValueError('Only one dimensional actions are supported now.')

        self._q_network = q_network
        self._target_q_network = self._q_network.copy(name='TargetQNetwork')
        self._epsilon_greedy = epsilon_greedy
        self._target_update_tau = target_update_tau
        self._target_update_period = target_update_period
        self._optimizer = optimizer
        self._td_errors_loss_fn = td_errors_loss_fn or element_wise_huber_loss
        self._gamma = gamma
        self._reward_scale_factor = reward_scale_factor
        self._gradient_clipping = gradient_clipping

        self._target_update_train_op = None

        policy = q_policy.QPolicy(time_step_spec,
                                  action_spec,
                                  q_network=self._q_network)

        collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
            policy, epsilon=self._epsilon_greedy)
        policy = greedy_policy.GreedyPolicy(policy)

        super(DqnAgent, self).__init__(
            time_step_spec,
            action_spec,
            policy,
            collect_policy,
            train_sequence_length=2 if not q_network.state_spec else None,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)
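A hedged construction sketch for the constructor documented above (illustrative, not code from the original source): the spec shapes, the AdamOptimizer choice, and the bare `DqnAgent`, `q_network`, `tensor_spec`, and `ts` names are assumptions based on the TF-Agents imports seen in the other examples on this page.

# Illustrative specs for a tiny discrete-action problem (assumed values).
observation_spec = tensor_spec.TensorSpec([4], tf.float32, name='observation')
action_spec = tensor_spec.BoundedTensorSpec([], tf.int32, minimum=0, maximum=1)
time_step_spec = ts.time_step_spec(observation_spec)

q_net = q_network.QNetwork(input_tensor_spec=observation_spec,
                           action_spec=action_spec)
agent = DqnAgent(time_step_spec,
                 action_spec,
                 q_network=q_net,
                 optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3),
                 epsilon_greedy=0.1,
                 gamma=0.99)
# `agent.policy` is the greedy evaluation policy and `agent.collect_policy`
# the epsilon-greedy policy built inside the constructor above.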
Example #18
    def testUpdateWithCheckpoint(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in TF2.x.')

        # Create a saved_model for a q_policy.
        network = q_network.QNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=network)

        saver = policy_saver.PolicySaver(policy, batch_size=None)
        path = os.path.join(self.get_temp_dir(), 'save_model')

        self.evaluate(tf.compat.v1.global_variables_initializer())
        saver.save(path)

        # Assign -1 to all variables in the policy, making the checkpoint
        # different from the initial saved_model.
        self.evaluate(
            tf.nest.map_structure(lambda v: v.assign(v * 0 + -1),
                                  policy.variables()))
        checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
        saver.save_checkpoint(checkpoint_path)

        # Get an instance of the saved_model.
        reloaded_policy = tf.compat.v2.saved_model.load(path)
        self.evaluate(
            tf.compat.v1.initializers.variables(
                reloaded_policy.model_variables))

        # Verify loaded saved_model variables are different than the current policy.
        model_variables = self.evaluate(policy.variables())
        reloaded_model_variables = self.evaluate(
            reloaded_policy.model_variables)

        assert_np_not_equal = lambda a, b: self.assertFalse(
            np.equal(a, b).any())
        tf.nest.map_structure(assert_np_not_equal, model_variables,
                              reloaded_model_variables)

        # Update from checkpoint.
        checkpoint = tf.train.Checkpoint(policy=reloaded_policy)
        manager = tf.train.CheckpointManager(checkpoint,
                                             directory=checkpoint_path,
                                             max_to_keep=None)
        checkpoint.restore(manager.latest_checkpoint).expect_partial()

        self.evaluate(
            tf.compat.v1.initializers.variables(
                reloaded_policy.model_variables))

        # Verify variables are now equal.
        model_variables = self.evaluate(policy.variables())
        reloaded_model_variables = self.evaluate(
            reloaded_policy.model_variables)

        assert_np_all_equal = lambda a, b: self.assertTrue(
            np.equal(a, b).all())
        tf.nest.map_structure(assert_np_all_equal, model_variables,
                              reloaded_model_variables)
Example #19
  def __init__(
      self,
      time_step_spec,
      action_spec,
      q_network,
      optimizer,
      epsilon_greedy=0.1,
      n_step_update=1,
      boltzmann_temperature=None,
      emit_log_probability=False,
      update_period=None,
      # Params for target network updates
      target_update_tau=1.0,
      target_update_period=1,
      # Params for training.
      td_errors_loss_fn=None,
      gamma=1.0,
      reward_scale_factor=1.0,
      gradient_clipping=None,
      # Params for debugging
      debug_summaries=False,
      enable_functions=True,
      summarize_grads_and_vars=False,
      train_step_counter=None,
      name=None):
    """Creates a DQN Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      q_network: A tf_agents.network.Network to be used by the agent. The
        network will be called with call(observation, step_type).
      optimizer: The optimizer to use for training.
      epsilon_greedy: probability of choosing a random action in the default
        epsilon-greedy collect policy (used only if a wrapper is not provided to
        the collect_policy method).
      n_step_update: The number of steps to consider when computing TD error and
        TD loss. Defaults to single-step updates. Note that this requires the
        user to call train on Trajectory objects with a time dimension of
        `n_step_update + 1`. However, note that we do not yet support
        `n_step_update > 1` in the case of RNNs (i.e., non-empty
        `q_network.state_spec`).
      boltzmann_temperature: Temperature value to use for Boltzmann sampling of
        the actions during data collection. The closer to 0.0, the higher the
        probability of choosing the best action.
      emit_log_probability: Whether policies emit log probabilities or not.
      update_period: Update period.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn: A function for computing the TD errors loss. If None, a
        default value of element_wise_huber_loss is used. This function takes as
        input the target and the estimated Q values and returns the loss for
        each element of the batch.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      enable_functions: A bool to decide whether or not to enable tf.function use.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall under
        that name. Defaults to the class name.

    Raises:
      ValueError: If the action spec contains more than one action or action
        spec minimum is not equal to 0.
      NotImplementedError: If `q_network` has non-empty `state_spec` (i.e., an
        RNN is provided) and `n_step_update > 1`.
    """
    tf.Module.__init__(self, name=name)

    flat_action_spec = tf.nest.flatten(action_spec)
    self._num_actions = [
        spec.maximum - spec.minimum + 1 for spec in flat_action_spec
    ]

    if len(flat_action_spec) > 1 or flat_action_spec[0].shape.ndims > 1:
      raise ValueError('Only one dimensional actions are supported now.')

    if not all(spec.minimum == 0 for spec in flat_action_spec):
      raise ValueError(
          'Action specs should have minimum of 0, but saw: {0}'.format(
              [spec.minimum for spec in flat_action_spec]))

    if epsilon_greedy is not None and boltzmann_temperature is not None:
      raise ValueError(
          'Configured both epsilon_greedy value {} and temperature {}, '
          'however only one of them can be used for exploration.'.format(
              epsilon_greedy, boltzmann_temperature))

    self._q_network = q_network
    self._target_q_network = self._q_network.copy(name='TargetQNetwork')
    self._epsilon_greedy = epsilon_greedy
    self._n_step_update = n_step_update
    self._boltzmann_temperature = boltzmann_temperature
    self._optimizer = optimizer
    self._td_errors_loss_fn = td_errors_loss_fn or element_wise_huber_loss
    self._gamma = gamma
    self._reward_scale_factor = reward_scale_factor
    self._gradient_clipping = gradient_clipping
    self._update_target = self._get_target_updater(target_update_tau,
                                                   target_update_period)

    policy = q_policy.QPolicy(
        time_step_spec,
        action_spec,
        q_network=self._q_network,
        emit_log_probability=emit_log_probability)

    if boltzmann_temperature is not None:
      collect_policy = boltzmann_policy.BoltzmannPolicy(
          policy, temperature=self._boltzmann_temperature)
    else:
      collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
          policy, epsilon=self._epsilon_greedy)
    policy = greedy_policy.GreedyPolicy(policy)

    if q_network.state_spec and n_step_update != 1:
      raise NotImplementedError(
          'DqnAgent does not currently support n-step updates with stateful '
          'networks (i.e., RNNs), but n_step_update = {}'.format(n_step_update))

    train_sequence_length = (
        n_step_update + 1 if not q_network.state_spec else None)

    super(DqnAgent, self).__init__(
        time_step_spec,
        action_spec,
        policy,
        collect_policy,
        train_sequence_length=train_sequence_length,
        update_period=update_period,
        debug_summaries=debug_summaries,
        enable_functions=enable_functions,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter)

    tf.compat.v1.summary.scalar(
        'epsilon/' + self.name,
        self._epsilon_greedy,
        collections=['train_' + self.name])
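A hedged sketch of the `n_step_update` contract described in the docstring above; `replay_buffer` (a TF-Agents replay buffer) and `agent` are hypothetical stand-ins and the batch size is arbitrary.

# Illustrative only: with n_step_update=2 the agent's train() expects
# Trajectory batches whose time dimension is n_step_update + 1 = 3.
dataset = replay_buffer.as_dataset(sample_batch_size=32,
                                   num_steps=3,
                                   num_parallel_calls=3).prefetch(3)
for experience, _ in dataset.take(1):
    loss_info = agent.train(experience)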
Example #20
  def __init__(
      self,
      time_step_spec,
      action_spec,
      q_network,
      optimizer,
      epsilon_greedy=0.1,
      boltzmann_temperature=None,
      # Params for target network updates
      target_update_tau=1.0,
      target_update_period=1,
      # Params for training.
      td_errors_loss_fn=None,
      gamma=1.0,
      reward_scale_factor=1.0,
      gradient_clipping=None,
      # Params for debugging
      debug_summaries=False,
      summarize_grads_and_vars=False,
      train_step_counter=None,
      name=None):
    """Creates a DQN Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      q_network: A tf_agents.network.Network to be used by the agent. The
        network will be called with call(observation, step_type).
      optimizer: The optimizer to use for training.
      epsilon_greedy: probability of choosing a random action in the default
        epsilon-greedy collect policy (used only if a wrapper is not provided to
        the collect_policy method).
      boltzmann_temperature: Temperature value to use for Boltzmann sampling of
        the actions during data collection. The closer to 0.0, the higher the
        probability of choosing the best action.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn: A function for computing the TD errors loss. If None, a
        default value of element_wise_huber_loss is used. This function takes as
        input the target and the estimated Q values and returns the loss for
        each element of the batch.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      ValueError: If the action spec contains more than one action or action
        spec minimum is not equal to 0.
    """
    tf.Module.__init__(self, name=name)

    flat_action_spec = tf.nest.flatten(action_spec)
    self._num_actions = [
        spec.maximum - spec.minimum + 1 for spec in flat_action_spec
    ]

    # TODO(oars): Get DQN working with more than one dim in the actions.
    if len(flat_action_spec) > 1 or flat_action_spec[0].shape.ndims > 1:
      raise ValueError('Only one dimensional actions are supported now.')

    if not all(spec.minimum == 0 for spec in flat_action_spec):
      raise ValueError(
          'Action specs should have minimum of 0, but saw: {0}'.format(
              [spec.minimum for spec in flat_action_spec]))

    if epsilon_greedy is not None and boltzmann_temperature is not None:
      raise ValueError(
          'Configured both epsilon_greedy value {} and temperature {}, '
          'however only one of them can be used for exploration.'.format(
              epsilon_greedy, boltzmann_temperature))

    self._q_network = q_network
    self._target_q_network = self._q_network.copy(name='TargetQNetwork')
    self._epsilon_greedy = epsilon_greedy
    self._boltzmann_temperature = boltzmann_temperature
    self._optimizer = optimizer
    self._td_errors_loss_fn = td_errors_loss_fn or element_wise_huber_loss
    self._gamma = gamma
    self._reward_scale_factor = reward_scale_factor
    self._gradient_clipping = gradient_clipping
    self._update_target = self._get_target_updater(
        target_update_tau, target_update_period)

    policy = q_policy.QPolicy(
        time_step_spec, action_spec, q_network=self._q_network)

    if boltzmann_temperature is not None:
      collect_policy = boltzmann_policy.BoltzmannPolicy(
          policy, temperature=self._boltzmann_temperature)
    else:
      collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
          policy, epsilon=self._epsilon_greedy)
    policy = greedy_policy.GreedyPolicy(policy)

    super(DqnAgent, self).__init__(
        time_step_spec,
        action_spec,
        policy,
        collect_policy,
        train_sequence_length=2 if not q_network.state_spec else None,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter)
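
The `target_update_tau` / `target_update_period` arguments documented above describe a periodic soft ("Polyak") update of the target Q-network toward the online Q-network, which is what `_get_target_updater` wires up. As a rough illustration only (the helper name `soft_update_targets` is ours, not a TF-Agents API), the update rule amounts to:

```python
import tensorflow as tf

def soft_update_targets(target_variables, source_variables, tau=0.005):
  """Polyak-average the online variables into the target variables.

  tau=1.0 reproduces a hard copy; smaller tau gives a slower-moving target.
  """
  for target_var, source_var in zip(target_variables, source_variables):
    target_var.assign((1.0 - tau) * target_var + tau * source_var)

# With target_update_period=N, an update like this runs every N train steps.
online = [tf.Variable([1.0, 1.0])]
target = [tf.Variable([0.0, 0.0])]
soft_update_targets(target, online, tau=0.5)  # target becomes [0.5, 0.5]
```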
Example #21
    def testSaveRestore(self, batch_size):
        policy_save_path = os.path.join(flags.FLAGS.test_tmpdir, 'policy',
                                        str(batch_size))

        # Construct a policy to be saved under a tf.Graph instance.
        policy_saved_graph = tf.Graph()
        with policy_saved_graph.as_default():
            tf_policy = q_policy.QPolicy(
                self._time_step_spec, self._action_spec,
                DummyNet(use_constant_initializer=False))

            # Parameterized tests reuse temp directories, so make sure no prior save exists.
            try:
                tf.io.gfile.listdir(policy_save_path)
                tf.io.gfile.rmtree(policy_save_path)
            except tf.errors.NotFoundError:
                pass
            policy_saved = py_tf_policy.PyTFPolicy(tf_policy)
            policy_saved.session = tf.compat.v1.Session(
                graph=policy_saved_graph)
            policy_saved.initialize(batch_size)
            policy_saved.save(policy_dir=policy_save_path,
                              graph=policy_saved_graph)
            # Verify that index files were written. There will also be some number of
            # data files, but this depends on the number of devices.
            self.assertContainsSubset(
                set(['checkpoint', 'ckpt-0.index']),
                set(tf.io.gfile.listdir(policy_save_path)))

        # Construct a policy to be restored under another tf.Graph instance.
        policy_restore_graph = tf.Graph()
        with policy_restore_graph.as_default():
            tf_policy = q_policy.QPolicy(
                self._time_step_spec, self._action_spec,
                DummyNet(use_constant_initializer=False))
            policy_restored = py_tf_policy.PyTFPolicy(tf_policy)
            policy_restored.session = tf.compat.v1.Session(
                graph=policy_restore_graph)
            policy_restored.initialize(batch_size)
            random_init_vals = policy_restored.session.run(
                tf_policy.variables())
            policy_restored.restore(policy_dir=policy_save_path,
                                    graph=policy_restore_graph)
            restored_vals = policy_restored.session.run(tf_policy.variables())
            for random_init_var, restored_var in zip(random_init_vals,
                                                     restored_vals):
                self.assertFalse(np.array_equal(random_init_var, restored_var))

        # Check that variables in the two policies have identical values.
        with policy_restore_graph.as_default():
            restored_values = policy_restored.session.run(
                tf.compat.v1.global_variables())
        with policy_saved_graph.as_default():
            initial_values = policy_saved.session.run(
                tf.compat.v1.global_variables())

        # Networks have two fully connected layers.
        self.assertLen(initial_values, 4)
        self.assertLen(restored_values, 4)

        for initial_var, restored_var in zip(initial_values, restored_values):
            np.testing.assert_array_equal(initial_var, restored_var)
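
For context, the `checkpoint` and `ckpt-0.index` filenames asserted above are the standard artifacts of a TensorFlow checkpoint save (the `-0` suffix presumably reflects the policy's step counter). A bare `tf.train.Checkpoint` produces the same layout; the directory and variable below are illustrative only:

```python
import os
import tensorflow as tf

ckpt_dir = '/tmp/ckpt_demo'  # illustrative path
tf.io.gfile.makedirs(ckpt_dir)

ckpt = tf.train.Checkpoint(weights=tf.Variable([1.0, 2.0]))
ckpt.save(os.path.join(ckpt_dir, 'ckpt'))  # writes ckpt-1.index and ckpt-1.data-*

# The directory now also contains a 'checkpoint' bookkeeping file, matching the
# set of filenames the test above checks for.
print(tf.io.gfile.listdir(ckpt_dir))
```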
Example #22
    def testSaveAction(self):
        if not tf.executing_eagerly():
            self.skipTest(
                'b/129079730: PolicySaver does not work in TF1.x yet')

        q_network = q_rnn_network.QRnnNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=q_network)

        action_seed = 98723
        saver = policy_saver.PolicySaver(policy,
                                         batch_size=None,
                                         seed=action_seed)
        path = os.path.join(tf.compat.v1.test.get_temp_dir(),
                            'save_model_action')
        saver.save(path)

        reloaded = tf.compat.v2.saved_model.load(path)

        self.assertIn('action', reloaded.signatures)
        reloaded_action = reloaded.signatures['action']
        self._compare_input_output_specs(
            reloaded_action,
            expected_input_specs=(self._time_step_spec,
                                  policy.policy_state_spec),
            expected_output_spec=policy.policy_step_spec,
            batch_input=True)

        batch_size = 3

        action_inputs = tensor_spec.sample_spec_nest(
            (self._time_step_spec, policy.policy_state_spec),
            outer_dims=(batch_size, ),
            seed=4)

        function_action_input_dict = dict(
            (spec.name, value) for (spec, value) in zip(
                tf.nest.flatten((self._time_step_spec, policy.policy_state_spec
                                 )), tf.nest.flatten(action_inputs)))

        # NOTE(ebrevdo): The graph-level seeds for the policy and the reloaded model
        # are equal, which in addition to seeding the call to action() and
        # PolicySaver helps ensure equality of the output of action() in both cases.
        self.assertEqual(reloaded_action.graph.seed, self._global_seed)
        action_output = policy.action(*action_inputs, seed=action_seed)
        # The seed= argument for the SavedModel action call was given at creation of
        # the PolicySaver.
        reloaded_action_output_dict = reloaded_action(
            **function_action_input_dict)

        action_output_dict = dict(
            ((spec.name, value)
             for (spec, value) in zip(tf.nest.flatten(policy.policy_step_spec),
                                      tf.nest.flatten(action_output))))

        action_output_dict = self.evaluate(action_output_dict)
        reloaded_action_output_dict = self.evaluate(
            reloaded_action_output_dict)

        self.assertAllEqual(action_output_dict.keys(),
                            reloaded_action_output_dict.keys())
        for k in action_output_dict:
            self.assertAllClose(action_output_dict[k],
                                reloaded_action_output_dict[k],
                                msg='\nMismatched dict key: %s.' % k)
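
The same save/load round trip can be exercised outside a test harness. The sketch below builds a small Q-network policy, saves it with `PolicySaver`, and calls the reloaded structured `action` function; the observation/action specs and the path are assumptions made up for this example:

```python
import os
import tensorflow as tf
from tf_agents.networks import q_network
from tf_agents.policies import policy_saver, q_policy
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

# Requires TF2 eager execution; PolicySaver does not work in TF1.x (see above).
observation_spec = tensor_spec.TensorSpec([2], tf.float32)
action_spec = tensor_spec.BoundedTensorSpec([], tf.int32, minimum=0, maximum=1)
time_step_spec = ts.time_step_spec(observation_spec)

net = q_network.QNetwork(observation_spec, action_spec)
policy = q_policy.QPolicy(time_step_spec, action_spec, q_network=net)

path = os.path.join('/tmp', 'saved_q_policy')  # illustrative path
policy_saver.PolicySaver(policy, batch_size=None).save(path)

reloaded = tf.compat.v2.saved_model.load(path)
batched_time_step = tensor_spec.sample_spec_nest(time_step_spec, outer_dims=(1,))
# For a stateless QNetwork the policy state can be omitted, as in the tests.
print(reloaded.action(batched_time_step))
```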
Example #23
    def __init__(
            self,
            time_step_spec,
            action_spec,
            q_network,
            optimizer,
            epsilon_greedy=0.1,
            n_step_update=1,
            boltzmann_temperature=None,
            emit_log_probability=False,
            # Params for target network updates
            target_q_network=None,
            target_update_tau=1.0,
            target_update_period=1,
            # Params for training.
            td_errors_loss_fn=None,
            gamma=1.0,
            reward_scale_factor=1.0,
            gradient_clipping=None,
            # Params for debugging
            debug_summaries=False,
            summarize_grads_and_vars=False,
            train_step_counter=None,
            name=None):
        """Creates a DQN Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      q_network: A `tf_agents.network.Network` to be used by the agent. The
        network will be called with `call(observation, step_type)` and should
        emit logits over the action space.
      optimizer: The optimizer to use for training.
      epsilon_greedy: probability of choosing a random action in the default
        epsilon-greedy collect policy (used only if a wrapper is not provided to
        the collect_policy method).
      n_step_update: The number of steps to consider when computing TD error and
        TD loss. Defaults to single-step updates. Note that this requires the
        user to call train on Trajectory objects with a time dimension of
        `n_step_update + 1`. However, note that we do not yet support
        `n_step_update > 1` in the case of RNNs (i.e., non-empty
        `q_network.state_spec`).
      boltzmann_temperature: Temperature value to use for Boltzmann sampling of
        the actions during data collection. The closer to 0.0, the higher the
        probability of choosing the best action.
      emit_log_probability: Whether policies emit log probabilities or not.
      target_q_network: (Optional.)  A `tf_agents.network.Network` to be used
        as the target network during Q learning.  Every `target_update_period`
        train steps, the weights from `q_network` are copied (possibly with
        smoothing via `target_update_tau`) to `target_q_network`.

        If `target_q_network` is not provided, it is created by making a
        copy of `q_network`, which initializes a new network with the same
        structure and its own layers and weights.

        Performing a `Network.copy` does not work when the network instance
        already has trainable parameters (e.g., has already been built, or
        when the network is sharing layers with another).  In these cases, it is
        up to you to build a copy having weights that are not
        shared with the original `q_network`, so that this can be used as a
        target network.  If you provide a `target_q_network` that shares any
        weights with `q_network`, a warning will be logged but no exception
        is thrown.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn: A function for computing the TD errors loss. If None, a
        default value of element_wise_huber_loss is used. This function takes as
        input the target and the estimated Q values and returns the loss for
        each element of the batch.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      ValueError: If the action spec contains more than one action or action
        spec minimum is not equal to 0.
      NotImplementedError: If `q_network` has non-empty `state_spec` (i.e., an
        RNN is provided) and `n_step_update > 1`.
    """
        tf.Module.__init__(self, name=name)

        self._check_action_spec(action_spec)

        if epsilon_greedy is not None and boltzmann_temperature is not None:
            raise ValueError(
                'Configured both epsilon_greedy value {} and temperature {}, '
                'however only one of them can be used for exploration.'.format(
                    epsilon_greedy, boltzmann_temperature))

        self._q_network = q_network
        self._target_q_network = common.maybe_copy_target_network_with_checks(
            self._q_network, target_q_network, 'TargetQNetwork')

        self._epsilon_greedy = epsilon_greedy
        self._n_step_update = n_step_update
        self._boltzmann_temperature = boltzmann_temperature
        self._optimizer = optimizer
        self._td_errors_loss_fn = td_errors_loss_fn or common.element_wise_huber_loss
        self._gamma = gamma
        self._reward_scale_factor = reward_scale_factor
        self._gradient_clipping = gradient_clipping
        self._update_target = self._get_target_updater(target_update_tau,
                                                       target_update_period)

        policy, collect_policy = self._setup_policy(time_step_spec,
                                                    action_spec,
                                                    boltzmann_temperature,
                                                    emit_log_probability)

        self._greedy_policy = policy
        self._target_policy = q_policy.QPolicy(
            time_step_spec, action_spec, q_network=self._target_q_network)
        self._target_greedy_policy = greedy_policy.GreedyPolicy(
            self._target_policy)

        if q_network.state_spec and n_step_update != 1:
            raise NotImplementedError(
                'DqnAgent does not currently support n-step updates with stateful '
                'networks (i.e., RNNs), but n_step_update = {}'.format(
                    n_step_update))

        train_sequence_length = (n_step_update +
                                 1 if not q_network.state_spec else None)

        super(DqnAgent,
              self).__init__(time_step_spec,
                             action_spec,
                             policy,
                             collect_policy,
                             train_sequence_length=train_sequence_length,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars,
                             train_step_counter=train_step_counter)
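
The `n_step_update` argument documented above changes the bootstrapping target from a one-step to an n-step return, which is why `train_sequence_length` becomes `n_step_update + 1`. A small sketch of that target computation; the function and tensor names are illustrative, not TF-Agents internals:

```python
import tensorflow as tf

def n_step_td_target(rewards, discounts, next_state_value, gamma=0.99):
  """Return sum_{k<n} gamma^k * r_k + gamma^n * V(s_n) for each batch element.

  rewards, discounts: [batch, n] tensors for the n intermediate transitions.
  next_state_value:   [batch] tensor, e.g. max_a Q_target(s_n, a).
  """
  n = rewards.shape[1]
  target = next_state_value
  for k in reversed(range(n)):  # fold the bootstrap value back through time
    target = rewards[:, k] + gamma * discounts[:, k] * target
  return target

# n=2, no terminal states: 1 + 0.9*1 + 0.81*10 = 10.0
rewards = tf.constant([[1.0, 1.0]])
discounts = tf.ones_like(rewards)
print(n_step_td_target(rewards, discounts, tf.constant([10.0]), gamma=0.9))
```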
Example #24
    def __init__(
            self,
            time_step_spec,
            action_spec,
            cloning_network,
            optimizer,
            epsilon_greedy=0.1,
            # Params for training.
            loss_fn=None,
            gradient_clipping=None,
            # Params for debugging
            debug_summaries=False,
            summarize_grads_and_vars=False):
        """Creates an behavioral cloning Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      cloning_network: A tf_agents.network.Network to be used by the agent.
        The network will be called as

          ```
          network(observation, step_type, network_state=None)
          ```
        (with `network_state` optional) and must return a 2-tuple with elements
        `(output, next_network_state)` where `output` will be passed as the
        first argument to `loss_fn`, and used by a `Policy`.  Input tensors will
        be shaped `[batch, time, ...]` when training, and they will be shaped
        `[batch, ...]` when the network is called within a `Policy`.  If
        `cloning_network` has an empty network state, then for training
        `time` will always be `1` (individual examples).
      optimizer: The optimizer to use for training.
      epsilon_greedy: probability of choosing a random action in the default
        epsilon-greedy collect policy (used only if a wrapper is not provided to
        the collect_policy method).
      loss_fn: A function for computing the error between the output of the
        cloning network and the action that was taken. If None, the loss
        depends on the action dtype.  If the dtype is integer, then `loss_fn`
        is

        ```python
        def loss_fn(logits, action):
          return tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=action - action_spec.minimum, logits=logits)
        ```

        If the dtype is floating point, the loss is
        `tf.math.squared_difference`.

        `loss_fn` must return a loss value for each element of the batch.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.

    Raises:
      NotImplementedError: If the action spec contains more than one action.
    """
        flat_action_spec = nest.flatten(action_spec)
        self._num_actions = [
            spec.maximum - spec.minimum + 1 for spec in flat_action_spec
        ]

        # TODO(oars): Get behavioral cloning working with more than one dim in
        # the actions.
        if len(flat_action_spec) > 1:
            raise NotImplementedError(
                'Multi-arity actions are not currently supported.')
        if flat_action_spec[0].dtype.is_floating:
            if loss_fn is None:
                loss_fn = tf.math.squared_difference
        else:
            if flat_action_spec[0].shape.ndims > 1:
                raise NotImplementedError(
                    'Only scalar and one dimensional integer actions are supported.'
                )
            if loss_fn is None:
                # TODO(ebrevdo): Maybe move the subtraction of the minimum into a
                # self._label_fn and rewrite this.
                def xent_loss_fn(logits, actions):
                    # Subtract the minimum so that we get a proper cross entropy loss on
                    # [0, maximum - minimum).
                    return tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=logits,
                        labels=actions - flat_action_spec[0].minimum)

                loss_fn = xent_loss_fn

        self._cloning_network = cloning_network
        self._loss_fn = loss_fn
        self._epsilon_greedy = epsilon_greedy
        self._optimizer = optimizer
        self._gradient_clipping = gradient_clipping

        policy = q_policy.QPolicy(time_step_spec,
                                  action_spec,
                                  q_network=self._cloning_network)
        collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
            policy, epsilon=self._epsilon_greedy)
        policy = greedy_policy.GreedyPolicy(policy)

        super(BehavioralCloningAgent,
              self).__init__(time_step_spec,
                             action_spec,
                             policy,
                             collect_policy,
                             train_sequence_length=1
                             if not cloning_network.state_spec else None,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars)
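
The minimum subtraction inside `xent_loss_fn` above matters whenever the action range does not start at 0: `tf.nn.sparse_softmax_cross_entropy_with_logits` expects labels in `[0, num_actions)`. A quick illustration with made-up specs and values:

```python
import tensorflow as tf

# Suppose integer actions live in [-2, 1], i.e. four discrete choices.
action_minimum = -2
logits = tf.constant([[0.1, 0.2, 0.3, 0.4],
                      [0.4, 0.3, 0.2, 0.1]])
actions = tf.constant([-2, 1])  # raw actions as stored in the dataset

# Shifting by the minimum maps -2 -> 0 and 1 -> 3, which are valid label indices.
per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=actions - action_minimum, logits=logits)
print(per_example_loss)  # one loss value per batch element, as the agent expects
```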
Example #25
 def testMultipleActionsRaiseError(self):
     action_spec = [tensor_spec.BoundedTensorSpec([], tf.int32, 0, 1)] * 2
     with self.assertRaisesRegexp(ValueError, 'Only scalar actions'):
         q_policy.QPolicy(self._time_step_spec,
                          action_spec,
                          q_network=DummyNet())
Example #26
 def testActionSpecsCompatible(self):
     q_net = DummyNetWithActionSpec(self._action_spec)
     q_policy.QPolicy(self._time_step_spec, self._action_spec, q_net)
Example #27
  def testBuild(self):
    policy = q_policy.QPolicy(
        self._time_step_spec, self._action_spec, q_network=DummyNet())

    self.assertEqual(policy.time_step_spec, self._time_step_spec)
    self.assertEqual(policy.action_spec, self._action_spec)
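
As the agent constructors above suggest, a bare `QPolicy` samples from a categorical distribution whose logits are the Q-values, and is typically wrapped in `GreedyPolicy` to obtain the argmax behaviour exposed as the agent's `policy`. A self-contained sketch with made-up specs:

```python
import tensorflow as tf
from tf_agents.networks import q_network
from tf_agents.policies import greedy_policy, q_policy
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

observation_spec = tensor_spec.TensorSpec([2], tf.float32)
action_spec = tensor_spec.BoundedTensorSpec([], tf.int32, minimum=0, maximum=3)
time_step_spec = ts.time_step_spec(observation_spec)

net = q_network.QNetwork(observation_spec, action_spec)
stochastic = q_policy.QPolicy(time_step_spec, action_spec, q_network=net)
greedy = greedy_policy.GreedyPolicy(stochastic)

time_step = ts.restart(tf.constant([[0.5, -0.5]]), batch_size=1)
print(stochastic.distribution(time_step).action)  # categorical over the 4 Q-values
print(greedy.action(time_step).action)            # argmax action, shape [1]
```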
Example #28
    def testSaveAction(self, seeded, has_state):
        if not tf.executing_eagerly():
            self.skipTest(
                'b/129079730: PolicySaver does not work in TF1.x yet')

        if has_state:
            network = q_rnn_network.QRnnNetwork(
                input_tensor_spec=self._time_step_spec.observation,
                action_spec=self._action_spec)
        else:
            network = q_network.QNetwork(
                input_tensor_spec=self._time_step_spec.observation,
                action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=network)

        action_seed = 98723

        saver = policy_saver.PolicySaver(policy,
                                         batch_size=None,
                                         use_nest_path_signatures=False,
                                         seed=action_seed)
        path = os.path.join(self.get_temp_dir(), 'save_model_action')
        saver.save(path)

        reloaded = tf.compat.v2.saved_model.load(path)

        self.assertIn('action', reloaded.signatures)
        reloaded_action = reloaded.signatures['action']
        self._compare_input_output_specs(
            reloaded_action,
            expected_input_specs=(self._time_step_spec,
                                  policy.policy_state_spec),
            expected_output_spec=policy.policy_step_spec,
            batch_input=True)

        batch_size = 3

        action_inputs = tensor_spec.sample_spec_nest(
            (self._time_step_spec, policy.policy_state_spec),
            outer_dims=(batch_size, ),
            seed=4)

        function_action_input_dict = dict(
            (spec.name, value) for (spec, value) in zip(
                tf.nest.flatten((self._time_step_spec, policy.policy_state_spec
                                 )), tf.nest.flatten(action_inputs)))

        # NOTE(ebrevdo): The graph-level seeds for the policy and the reloaded model
        # are equal, which in addition to seeding the call to action() and
        # PolicySaver helps ensure equality of the output of action() in both cases.
        self.assertEqual(reloaded_action.graph.seed, self._global_seed)
        action_output = policy.action(*action_inputs, seed=action_seed)

        # The seed= argument for the SavedModel action call was given at creation of
        # the PolicySaver.

        # This is the flat-signature function.
        reloaded_action_output_dict = reloaded_action(
            **function_action_input_dict)

        def match_dtype_shape(x, y, msg=None):
            self.assertEqual(x.shape, y.shape, msg=msg)
            self.assertEqual(x.dtype, y.dtype, msg=msg)

        # This is the non-flat function.
        if has_state:
            reloaded_action_output = reloaded.action(*action_inputs)
        else:
            # Try both cases: one with an empty policy_state and one with no
            # policy_state.  Compare them.

            # NOTE(ebrevdo): The first call to .action() must be stored in
            # reloaded_action_output because this is the version being compared later
            # against the true action_output and the values will change after the
            # first call due to randomness.
            reloaded_action_output = reloaded.action(*action_inputs)
            reloaded_action_output_no_input_state = reloaded.action(
                action_inputs[0])
            # Even with a seed, multiple calls to action will get different values,
            # so here we just check the signature matches.
            tf.nest.map_structure(match_dtype_shape,
                                  reloaded_action_output_no_input_state,
                                  reloaded_action_output)

        action_output_dict = dict(
            ((spec.name, value)
             for (spec, value) in zip(tf.nest.flatten(policy.policy_step_spec),
                                      tf.nest.flatten(action_output))))

        # Check output of the flattened signature call.
        action_output_dict = self.evaluate(action_output_dict)
        reloaded_action_output_dict = self.evaluate(
            reloaded_action_output_dict)
        self.assertAllEqual(action_output_dict.keys(),
                            reloaded_action_output_dict.keys())

        for k in action_output_dict:
            if seeded:
                self.assertAllClose(action_output_dict[k],
                                    reloaded_action_output_dict[k],
                                    msg='\nMismatched dict key: %s.' % k)
            else:
                match_dtype_shape(action_output_dict[k],
                                  reloaded_action_output_dict[k],
                                  msg='\nMismatch dict key: %s.' % k)

        # Check output of the proper structured call.
        action_output = self.evaluate(action_output)
        reloaded_action_output = self.evaluate(reloaded_action_output)
        # With non-signature functions, we can check that passing a seed does the
        # right thing the second time.
        if seeded:
            tf.nest.map_structure(self.assertAllClose, action_output,
                                  reloaded_action_output)
        else:
            tf.nest.map_structure(match_dtype_shape, action_output,
                                  reloaded_action_output)
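
The seeding comments above rely on TensorFlow's two-level seeding: fixing both the graph/global seed and a per-call op seed makes the sampled action sequence reproducible across program runs, while successive calls within one run still yield different draws. A generic (non-TF-Agents) sketch of that behaviour:

```python
import tensorflow as tf

tf.random.set_seed(12345)  # analogous to the policy's graph-level seed
logits = tf.math.log([[0.25, 0.25, 0.25, 0.25]])

first = tf.random.categorical(logits, num_samples=5, seed=98723)
second = tf.random.categorical(logits, num_samples=5, seed=98723)

# `first` and `second` generally differ, like repeated policy.action() calls,
# but re-running the whole program after set_seed reproduces the same pair.
print(first.numpy(), second.numpy())
```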
Example #29
    def __init__(self,
                 time_step_spec,
                 env_action_spec,
                 replay_buffer_action_spec,
                 actor_network,
                 critic_network,
                 actor_optimizer,
                 critic_optimizer,
                 exploration_noise_std=0.1,
                 epsilon_greedy=0.1,
                 boltzmann_temperature=None,
                 emit_log_probability=False,
                 critic_network_2=None,
                 target_actor_network=None,
                 target_critic_network=None,
                 target_critic_network_2=None,
                 target_update_tau=1.0,
                 target_update_period=1,
                 actor_update_period=1,
                 dqda_clipping=None,
                 td_errors_loss_fn=None,
                 gamma=1.0,
                 reward_scale_factor=1.0,
                 target_policy_noise=0.2,
                 target_policy_noise_clip=0.5,
                 gradient_clipping=None,
                 debug_summaries=False,
                 summarize_grads_and_vars=False,
                 train_step_counter=None,
                 name=None):
        """Creates a Td3Agent Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      env_action_spec: A nest of BoundedTensorSpec representing the environment
        actions.
      replay_buffer_action_spec: A nest of BoundedTensorSpec representing the
        actions that serve as input to the critic.
      actor_network: A tf_agents.network.Network to be used by the agent. The
        network will be called with call(observation, step_type).
      critic_network: A tf_agents.network.Network to be used by the agent. The
        network will be called with call(observation, action, step_type).
      actor_optimizer: The default optimizer to use for the actor network.
      critic_optimizer: The default optimizer to use for the critic network.
      exploration_noise_std: Scale factor on exploration policy noise.
      critic_network_2: (Optional.)  A `tf_agents.network.Network` to be used as
        the second critic network during Q learning.  The weights from
        `critic_network` are copied if this is not provided.
      target_actor_network: (Optional.)  A `tf_agents.network.Network` to be
        used as the target actor network during Q learning. Every
        `target_update_period` train steps, the weights from `actor_network` are
        copied (possibly with smoothing via `target_update_tau`) to
        `target_actor_network`.  If `target_actor_network` is not provided, it is
        created by making a copy of `actor_network`, which initializes a new
        network with the same structure and its own layers and weights.
        Performing a `Network.copy` does not work when the network instance
        already has trainable parameters (e.g., has already been built, or when
        the network is sharing layers with another).  In these cases, it is up
        to you to build a copy having weights that are not shared with the
        original `actor_network`, so that this can be used as a target network.
        If you provide a `target_actor_network` that shares any weights with
        `actor_network`, a warning will be logged but no exception is thrown.
      target_critic_network: (Optional.) A network similar to
        target_actor_network, but for critic_network. See the documentation for
        target_actor_network.
      target_critic_network_2: (Optional.) A network similar to
        target_actor_network, but for critic_network_2. See the documentation
        for target_actor_network. Only used if `critic_network_2` is also
        specified.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      actor_update_period: Period for the optimization step on actor network.
      dqda_clipping: A scalar float that clips the gradient dqda element-wise
        to [-dqda_clipping, dqda_clipping]. Default is None, which means no
        clipping.
      td_errors_loss_fn: A function for computing the TD errors loss. If None,
        a default of element_wise_huber_loss is used.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      target_policy_noise: Scale factor on the target action noise.
      target_policy_noise_clip: Value to clip noise.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall
        under that name. Defaults to the class name.
    """
        tf.Module.__init__(self, name=name)
        self._actor_network = actor_network
        self._target_actor_network = common.maybe_copy_target_network_with_checks(
            self._actor_network, target_actor_network, 'TargetActorNetwork')

        self._critic_network_1 = critic_network
        self._target_critic_network_1 = (
            common.maybe_copy_target_network_with_checks(
                self._critic_network_1, target_critic_network,
                'TargetCriticNetwork1'))

        if critic_network_2 is not None:
            self._critic_network_2 = critic_network_2
        else:
            self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
            # Do not use target_critic_network_2 if critic_network_2 is None.
            target_critic_network_2 = None
        self._target_critic_network_2 = (
            common.maybe_copy_target_network_with_checks(
                self._critic_network_2, target_critic_network_2,
                'TargetCriticNetwork2'))

        self._actor_optimizer = actor_optimizer
        self._critic_optimizer = critic_optimizer

        self._boltzmann_temperature = boltzmann_temperature
        self._epsilon_greedy = epsilon_greedy

        self._replay_buffer_action_spec = replay_buffer_action_spec
        self._exploration_noise_std = exploration_noise_std
        self._target_update_tau = target_update_tau
        self._target_update_period = target_update_period
        self._actor_update_period = actor_update_period
        self._dqda_clipping = dqda_clipping
        self._td_errors_loss_fn = (td_errors_loss_fn
                                   or common.element_wise_huber_loss)
        self._gamma = gamma
        self._reward_scale_factor = reward_scale_factor
        self._target_policy_noise = target_policy_noise
        self._target_policy_noise_clip = target_policy_noise_clip
        self._gradient_clipping = gradient_clipping

        self._update_target = self._get_target_updater(target_update_tau,
                                                       target_update_period)

        #    policy = actor_policy.ActorPolicy(
        #        time_step_spec=time_step_spec, action_spec=action_spec,
        #        actor_network=self._actor_network, clip=True)

        #    collect_policy = actor_policy.ActorPolicy(
        #        time_step_spec=time_step_spec, action_spec=action_spec,
        #        actor_network=self._actor_network, clip=False)
        #    collect_policy = gaussian_policy.GaussianPolicy(
        #        collect_policy,
        #        scale=self._exploration_noise_std,
        #        clip=True)

        policy = q_policy.QPolicy(time_step_spec,
                                  replay_buffer_action_spec,
                                  q_network=self._actor_network,
                                  emit_log_probability=emit_log_probability)
        policy._clip = False

        collect_policy = epsilon_discrete_boltzmann_policy.EpsilonDiscreteBoltzmannPolicy(
            policy,
            epsilon=self._epsilon_greedy,
            env_action_spec=env_action_spec)

        collect_policy = discrete_boltzmann_policy.DiscreteBoltzmannPolicy(
            policy, temperature=self._boltzmann_temperature)

        if boltzmann_temperature is not None:
            policy = discrete_boltzmann_policy.DiscreteBoltzmannPolicy(
                policy, temperature=self._boltzmann_temperature / 0.1)

#    collect_policy = policy

#    if boltzmann_temperature is not None:
#      collect_policy = discrete_boltzmann_policy.DiscreteBoltzmannPolicy(
#          policy, temperature=self._boltzmann_temperature)
#    else:
#      collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
#          policy, epsilon=self._epsilon_greedy)
#     policy = greedy_policy.GreedyPolicy(policy)
#    policy = collect_policy

        super(Td3DiscreteAgent,
              self).__init__(time_step_spec,
                             env_action_spec,
                             policy,
                             collect_policy,
                             train_sequence_length=2
                             if not self._actor_network.state_spec else None,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars,
                             train_step_counter=train_step_counter)
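
The Boltzmann-style collect policies used above pick discrete actions by softmax-sampling temperature-scaled Q-values: a lower temperature sharpens the distribution toward the greedy action, a higher one flattens it. A minimal generic sketch of that sampling step, not the DiscreteBoltzmannPolicy implementation itself:

```python
import tensorflow as tf

def boltzmann_sample(q_values, temperature, seed=None):
  """Sample one action per row from softmax(q_values / temperature)."""
  scaled_logits = q_values / temperature
  return tf.random.categorical(scaled_logits, num_samples=1, seed=seed)

q_values = tf.constant([[4.0, 5.5]])
print(boltzmann_sample(q_values, temperature=0.5))   # near-greedy sampling
print(boltzmann_sample(q_values, temperature=10.0))  # close to uniform
```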