示例#1
0
def define_simulation_graph(batch_env, algo_cls, config):
    """Define the algorithm and environment interaction.

    Simulates a single step in all environments. See `tools.simulate` for
    details of how one step is performed.

    The caller controls the step through boolean placeholders (`is_training`,
    `should_log`, `do_report`, `force_reset`, `should_step`,
    `use_external_action`) and through `external_action`, whose dtype and
    shape are taken from `batch_env.action_info`.

    Args:
      batch_env: In-graph environments object.
      algo_cls: Constructor of a batch algorithm. Called as
        `algo_cls(batch_env, step, is_training, should_log, config)`.
      config: Configuration object for the algorithm.

    Returns:
      Object providing graph elements via attributes. It is built from
      `locals()`, so every name bound in this function is an attribute: the
      arguments, `step`, the placeholders listed above, `algo`, the `done`,
      `score` and `summary` outputs of `tools.simulate`, and `message`.
    """
    # pylint: disable=unused-variable
    # Most locals below are never read inside this function; the caller reads
    # them through the `locals()` snapshot returned at the end. Renaming or
    # adding a local therefore changes the attributes of the returned object.
    # NOTE(review): the second positional argument presumably maps to
    # `trainable=False` in the TF1 `tf.Variable` signature — confirm.
    step = tf.Variable(0, False, dtype=tf.int32, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')
    should_log = tf.placeholder(tf.bool, name='should_log')
    # Not used in this function; only exposed via the returned object.
    do_report = tf.placeholder(tf.bool, name='do_report')
    force_reset = tf.placeholder(tf.bool, name='force_reset')
    algo = algo_cls(batch_env, step, is_training, should_log, config)
    # The next three placeholders are handed to `tools.simulate`. Judging by
    # their names, they gate the step and presumably let an externally fed
    # action replace the algorithm's action — see `tools.simulate` to confirm.
    should_step = tf.placeholder(tf.bool, name='should_step')
    use_external_action = tf.placeholder(tf.bool, name='use_external_action')
    # `action_info` is indexed here as a (dtype, shape) pair.
    external_action = tf.placeholder(batch_env.action_info[0],
                                     shape=batch_env.action_info[1],
                                     name='external_action')
    done, score, summary = tools.simulate(batch_env, algo, should_step,
                                          use_external_action, external_action,
                                          should_log, force_reset)
    # Log the count reported by `tools.count_weights()` once the graph is built.
    message = 'Graph contains {} trainable variables.'
    tf.logging.info(message.format(tools.count_weights()))
    # pylint: enable=unused-variable
    return tools.AttrDict(locals())
示例#2
0
def define_simulation_graph(batch_env, algo_cls, config):
    """Define the algorithm and environment interaction.

    Builds the global step, the control placeholders and the algorithm, then
    wires one simulation step in all environments via `tools.simulate`.

    Args:
      batch_env: In-graph environments object.
      algo_cls: Constructor of a batch algorithm. Called as
        `algo_cls(batch_env, step, is_training, should_log, config,
        is_optimizing_offense=is_optimizing_offense)`.
      config: Configuration object for the algorithm.

    Returns:
      Object providing graph elements via attributes. It is built from
      `locals()`, so every name bound in this function is an attribute: the
      arguments, `step`, the placeholders (`is_training`, `should_log`,
      `do_report`, `force_reset`, `is_optimizing_offense`), `algo`, the
      `done`, `score`, `summary` and `gail_summary` outputs of
      `tools.simulate`, and `message`.
    """
    # pylint: disable=unused-variable
    # Most locals below are never read inside this function; the caller reads
    # them through the `locals()` snapshot returned at the end. Renaming or
    # adding a local therefore changes the attributes of the returned object.
    # NOTE(review): the second positional argument presumably maps to
    # `trainable=False` in the TF1 `tf.Variable` signature — confirm.
    step = tf.Variable(0, False, dtype=tf.int32, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')
    should_log = tf.placeholder(tf.bool, name='should_log')
    # Not used in this function; only exposed via the returned object.
    do_report = tf.placeholder(tf.bool, name='do_report')
    force_reset = tf.placeholder(tf.bool, name='force_reset')
    # Extended: an extra boolean placeholder, forwarded to the algorithm as
    # the `is_optimizing_offense` keyword argument.
    is_optimizing_offense = tf.placeholder(tf.bool,
                                           name='is_optimizing_offense')
    algo = algo_cls(batch_env,
                    step,
                    is_training,
                    should_log,
                    config,
                    is_optimizing_offense=is_optimizing_offense)
    # `tools.simulate` is unpacked into four outputs here, including
    # `gail_summary`.
    done, score, summary, gail_summary = tools.simulate(
        batch_env, algo, should_log, force_reset)
    # Log the count reported by `tools.count_weights()` once the graph is built.
    message = 'Graph contains {} trainable variables.'
    tf.logging.info(message.format(tools.count_weights()))
    # pylint: enable=unused-variable
    return tools.AttrDict(locals())
示例#3
0
 def test_done_automatic(self):
     # The expected rows show environment i reporting `done` on every
     # (i + 1)-th step without any forced reset. The tuple passed to
     # `_create_test_batch_env` therefore presumably sets per-environment
     # episode lengths.
     batch_env = self._create_test_batch_env((1, 2, 3, 4))
     algo = tools.MockAlgorithm(batch_env)
     done, _score, _summary = tools.simulate(
         batch_env, algo, log=False, reset=False)
     expected_done_per_step = (
         [True, False, False, False],
         [True, True, False, False],
         [True, False, True, False],
         [True, True, False, True],
     )
     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         # Each `sess.run(done)` advances the simulation by one step.
         for expected in expected_done_per_step:
             self.assertAllEqual(expected, sess.run(done))
示例#4
0
 def test_reset_automatic(self):
     batch_env = self._create_test_batch_env((1, 2, 3, 4))
     algo = tools.MockAlgorithm(batch_env)
     done, _score, _summary = tools.simulate(
         batch_env, algo, log=False, reset=False)
     num_steps = 10
     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         for _ in range(num_steps):
             sess.run(done)
     # Expected `steps` record of each environment after the run. Every row
     # sums to 10, so each row presumably lists per-episode step counts, with
     # a shorter final episode wherever the length does not divide 10.
     expected_steps = (
         [1] * 10,
         [2] * 5,
         [3, 3, 3, 1],
         [4, 4, 2],
     )
     for index, expected in enumerate(expected_steps):
         self.assertAllEqual(expected, batch_env[index].steps)
示例#5
0
 def test_done_forced(self):
     reset = tf.placeholder_with_default(False, ())
     batch_env = self._create_test_batch_env((2, 4))
     algo = tools.MockAlgorithm(batch_env)
     done, _score, _summary = tools.simulate(batch_env, algo, False, reset)
     # One entry per simulated step: (feed reset=True on this step?,
     # expected done flags for the two environments).
     schedule = (
         (False, [False, False]),
         (True, [False, False]),
         (False, [True, False]),
         (True, [False, False]),
         (False, [True, False]),
         (False, [False, False]),
         (False, [True, True]),
     )
     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         for force_reset, expected in schedule:
             # A `None` feed is the same as running without a feed dict.
             feed = {reset: True} if force_reset else None
             self.assertAllEqual(expected, sess.run(done, feed))
示例#6
0
def define_simulation_graph(batch_env, algo_cls, config):
  """Define the algorithm and environment interaction.

  Builds the global step, the control placeholders and the algorithm, then
  wires one simulation step in all environments via `tools.simulate`.

  Args:
    batch_env: In-graph environments object.
    algo_cls: Constructor of a batch algorithm. Called as
      `algo_cls(batch_env, step, is_training, should_log, config)`.
    config: Configuration object for the algorithm.

  Returns:
    Object providing graph elements via attributes. It is built from
    `locals()`, so every name bound in this function is an attribute: the
    arguments, `step`, the placeholders `is_training`, `should_log`,
    `do_report` and `force_reset`, `algo`, the `done`, `score` and `summary`
    outputs of `tools.simulate`, and `message`.
  """
  # pylint: disable=unused-variable
  # Most locals below are never read inside this function; the caller reads
  # them through the `locals()` snapshot returned at the end. Renaming or
  # adding a local therefore changes the attributes of the returned object.
  # NOTE(review): the second positional argument presumably maps to
  # `trainable=False` in the TF1 `tf.Variable` signature — confirm.
  step = tf.Variable(0, False, dtype=tf.int32, name='global_step')
  is_training = tf.placeholder(tf.bool, name='is_training')
  should_log = tf.placeholder(tf.bool, name='should_log')
  # Not used in this function; only exposed via the returned object.
  do_report = tf.placeholder(tf.bool, name='do_report')
  force_reset = tf.placeholder(tf.bool, name='force_reset')
  algo = algo_cls(batch_env, step, is_training, should_log, config)
  done, score, summary = tools.simulate(
      batch_env, algo, should_log, force_reset)
  # Log the count reported by `tools.count_weights()` once the graph is built.
  message = 'Graph contains {} trainable variables.'
  tf.logging.info(message.format(tools.count_weights()))
  # pylint: enable=unused-variable
  return tools.AttrDict(locals())