Exemplo n.º 1
0
    def wrapped_method(*args, **kwargs):
      """A wrapped method around a TF-Hub module signature.

      Binds `args`/`kwargs` against the stored argument specs for `method`,
      flattens the bound inputs into a dict keyed by their flat position
      ("0", "1", ...), calls the module signature, and re-nests the outputs
      according to the stored output spec.

      Raises:
        ValueError: If no output spec is stored and the signature does not
          return exactly one tensor.
      """
      method_spec = self._method_specs[method]
      inputs = _getcallargs(method_spec["specs"], *args, **kwargs)
      nest.assert_same_structure(method_spec["inputs"], inputs)
      # Signature inputs are named by their position in the flattened nest.
      flat_inputs = {str(i): v for i, v in enumerate(nest.flatten(inputs))}

      signature = "default" if method == "__call__" else method
      flat_outputs = self._module(
          flat_inputs, signature=signature, as_dict=True)

      def _output_sort_key(item):
        # Purely numeric keys (the positional naming used for the inputs
        # above) are ordered numerically so that "10" follows "9" instead of
        # "1". Any other keys keep plain lexicographic order.
        key = item[0]
        if isinstance(key, str) and key.isdecimal():
          return (0, int(key), "")
        return (1, 0, key)

      flat_outputs = [
          v for _, v in sorted(flat_outputs.items(), key=_output_sort_key)]

      output_spec = method_spec["outputs"]
      if output_spec is not None:
        return nest.unflatten_as(output_spec, flat_outputs)

      # Without an output spec, exactly one output tensor is expected.
      if len(flat_outputs) != 1:
        raise ValueError(
            "Expected output containing a single tensor, found {}".format(
                flat_outputs))
      return flat_outputs[0]
Exemplo n.º 2
0
  def test_nested_structure(self):
    """`nest_to_numpy` keeps the nest layout and converts tensors to arrays."""
    graph = self._graph
    nested_graph = graph.map(
        lambda field: {"a": field, "b": tf.random.uniform([4, 6])})

    structure = [None, graph, (nested_graph,), tf.random.uniform([10, 6])]
    structure_np = utils_tf.nest_to_numpy(structure)

    tree.assert_same_structure(structure, structure_np)

    leaf_pairs = zip(tree.flatten(structure), tree.flatten(structure_np))
    for tensor, array in leaf_pairs:
      if tensor is not None:
        # Every tensor leaf must map to a numerically matching numpy leaf.
        self.assertIsNotNone(array)
        self.assertNDArrayNear(tensor.numpy(), array, 1e-8)
      else:
        # `None` leaves must stay `None` after conversion.
        self.assertIsNone(array)
Exemplo n.º 3
0
 def testAssertSameStructure_sameNamedTupleDifferentStructuredContents(
         self):
     # Same namedtuple type, but the first field is a scalar in one value and
     # a list in the other, so the nested structures differ.
     expected_message = ("don't have the same nested structure\\.\n\n"
                         "First structure: .*?\n\nSecond structure: ")
     scalar_field = NestTest.Named0ab(3, 4)
     list_field = NestTest.Named0ab([3], 4)
     with self.assertRaisesRegex(ValueError, expected_message):
         tree.assert_same_structure(scalar_field, list_field)
Exemplo n.º 4
0
    def test_ema_on_changing_data(self):
        """The EMA differs from both the previous and the new params."""
        def f():
            return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

        init_fn, _ = base.transform(f)
        params = init_fn(random.PRNGKey(428))

        def g(x):
            return moving_averages.EMAParamsTree(0.2)(x)

        init_fn, apply_fn = base.without_apply_rng(
            base.transform_with_state(g))
        _, params_state = init_fn(None, params)
        # First EMA step; `params` is rebound to that step's EMA output.
        params, params_state = apply_fn(None, params_state, params)
        # Let's modify our params.
        changed_params = tree.map_structure(lambda t: 2. * t, params)
        ema_params, params_state = apply_fn(None, params_state, changed_params)

        # ema_params should be different from the changed params (and from
        # the previous params): the original only compared against the
        # latter despite this comment.
        tree.assert_same_structure(changed_params, ema_params)
        for reference in (params, changed_params):
            flat_pairs = zip(tree.flatten(reference), tree.flatten(ema_params))
            for p1, p2 in flat_pairs:
                self.assertEqual(p1.shape, p2.shape)
                with self.assertRaisesRegex(AssertionError,
                                            "Not equal to tolerance"):
                    np.testing.assert_allclose(p1, p2, atol=1e-6)
Exemplo n.º 5
0
def policy_gradient_loss(policies, actions, action_values, policy_vars=None,
                         name="policy_gradient_loss"):
  """Computes policy gradient losses for a batch of trajectories.

  This wraps `policy_gradient` to accept a possibly nested array of `policies`
  and `actions` in order to allow for multiple action distribution types or
  independent multivariate distributions if not directly available. It also sums
  up losses along the time dimension, and is more restrictive about shapes,
  assuming a [T, B] layout for the `batch_shape` of the policies and a
  concatenate(`[T, B]`, `event_shape` of the policies) shape for the actions.

  Args:
    policies: A (possibly nested structure of) distribution(s) supporting
        `batch_shape` and `event_shape` properties along with a `log_prob`
        method (e.g. an instance of `tfp.distributions.Distribution`),
        with `batch_shape` equal to `[T, B]`.
    actions: A (possibly nested structure of) N-D Tensor(s) with shape
        `[T, B, ...]` where the final dimensions are the `event_shape` of the
        corresponding distribution in the nested structure (the shape can be
        just `[T, B]` if the `event_shape` is scalar).
    action_values: Tensor of shape `[T, B]` containing an estimate of the value
        of the selected `actions`.
    policy_vars: An optional (possibly nested structure of) iterable(s) of
        Tensors used by `policies`. If provided, it is flattened only down to
        the structure of `policies` (one iterable per policy) and used in
        scope checks.
    name: Customises the name_scope for this op.

  Returns:
    loss: Tensor of shape `[B]` containing the total loss for each sequence
    in the batch. Differentiable w.r.t. `policies` only; no additional
    gradients are pushed through `action_values`.

  Raises:
    ValueError: If the flattened `policies` and `actions` differ in length,
      a policy's `batch_shape` is not rank 2, or an action's shape is not
      compatible with its policy's `batch_shape + event_shape`.
  """
  actions = nest.flatten(actions)
  if policy_vars:
    # One iterable of variables per (flattened) policy.
    policy_vars = nest.flatten_up_to(policies, policy_vars)
  else:
    # A distinct empty list per action, rather than N aliases of one list.
    policy_vars = [[] for _ in actions]
  policies = nest.flatten(policies)

  # Check happens after flatten so that we can be more flexible on nest
  # structures. This is equivalent to asserting that `len(policies) ==
  # len(actions)`, which is sufficient for what we're doing here.
  nest.assert_same_structure(policies, actions)

  for policies_, actions_ in zip(policies, actions):
    policies_.batch_shape.assert_has_rank(2)
    actions_.get_shape().assert_is_compatible_with(
        policies_.batch_shape.concatenate(policies_.event_shape))

  scoped_values = policy_vars + actions + [action_values]
  with tf.name_scope(name, values=scoped_values):
    # Loss for the policy gradient. Doesn't push additional gradients through
    # the action_values.
    policy_gradient_loss_sequence = tf.add_n([
        policy_gradient(policies_, actions_, action_values, pvars)
        for policies_, actions_, pvars in zip(policies, actions, policy_vars)])

    # Sum over the time dimension (axis 0 of the [T, B] losses).
    return tf.reduce_sum(
        policy_gradient_loss_sequence, axis=[0],
        name="policy_gradient_loss")
Exemplo n.º 6
0
 def mirror_structure(value, reference_tree):
   """Returns `value` if it is nested, else broadcasts it over `reference_tree`."""
   if not tree.is_nested(value):
     # Leaf value: replicate it at every leaf position of the reference.
     return tree.map_structure(lambda _: value, reference_tree)
   # Use check_types=True so that the types of the trees we construct aren't
   # dependent on our arbitrary choice of which nested arg to use as the
   # reference_tree.
   tree.assert_same_structure(value, reference_tree, check_types=True)
   return value
Exemplo n.º 7
0
    def test_nested_action_spaces(self):
        """Round-trips nested action spaces through offline output and BC.

        For every framework and every action space in `SPACES`, trains a
        PGTrainer that writes its samples to a temp dir (with and without
        action flattening), checks the written actions, and then trains a
        BCTrainer on that offline data.
        """
        import tempfile

        config = DEFAULT_CONFIG.copy()
        config["env"] = RandomEnv
        # Write output to check, whether actions are written correctly.
        # `tempfile.mkdtemp` is portable and always returns an existing
        # directory, unlike shelling out to `mktemp -d`.
        tmp_dir = tempfile.mkdtemp()
        config["output"] = tmp_dir
        # Switch off OPE as we don't write action-probs.
        # TODO: We should probably always write those if `output` is given.
        config["input_evaluation"] = []

        # Pretend actions in offline files are already normalized.
        config["actions_in_input_normalized"] = True

        try:
            for _ in framework_iterator(config):
                for action_space in SPACES.values():
                    config["env_config"] = {
                        "action_space": action_space,
                    }
                    for flatten in [False, True]:
                        print(f"A={action_space} flatten={flatten}")
                        # Clear out the previous run's output.
                        shutil.rmtree(config["output"])
                        config["_disable_action_flattening"] = not flatten
                        trainer = PGTrainer(config)
                        trainer.train()
                        trainer.stop()

                        # Check actions in output file (whether properly
                        # flattened or not).
                        reader = JsonReader(
                            inputs=config["output"],
                            ioctx=trainer.workers.local_worker().io_context,
                        )
                        sample_batch = reader.next()
                        if flatten:
                            assert isinstance(sample_batch["actions"],
                                              np.ndarray)
                            assert len(sample_batch["actions"].shape) == 2
                            assert sample_batch["actions"].shape[0] == len(
                                sample_batch)
                        else:
                            tree.assert_same_structure(
                                trainer.get_policy().action_space_struct,
                                sample_batch["actions"],
                            )

                        # Test, whether offline data can be properly read by
                        # a BCTrainer, configured accordingly.
                        config["input"] = config["output"]
                        del config["output"]
                        bc_trainer = BCTrainer(config=config)
                        bc_trainer.train()
                        bc_trainer.stop()
                        config["output"] = tmp_dir
                        config["input"] = "sampler"
        finally:
            # Don't leave the temp output dir behind once the test is done.
            shutil.rmtree(tmp_dir, ignore_errors=True)
Exemplo n.º 8
0
def apply_preprocessors(preprocessors, inputs):
    """Applies each preprocessor to the input at the same position.

    `preprocessors` and `inputs` must share a nested structure. A `None`
    preprocessor leaves the corresponding input unchanged.
    """
    tree.assert_same_structure(inputs, preprocessors)

    def _apply(preprocessor, input_):
        if preprocessor is None:
            return input_
        return preprocessor(input_)

    return tree.map_structure(_apply, preprocessors, inputs)
Exemplo n.º 9
0
 def testAssertSameStructure_intVsList(self):
     # Scalar vs. list: the list is reported as the sequence and the int as
     # the non-sequence.
     pattern = "".join([
         "The two structures don't have the same nested structure\\.\n\n",
         "First structure:.*?\n\n",
         "Second structure:.*\n\n",
         r'More specifically: Substructure "type=list str=\[0, 1\]" ',
         'is a sequence, while substructure "type=int str=0" ',
         "is not",
     ])
     with self.assertRaisesRegex(ValueError, pattern):
         tree.assert_same_structure(0, [0, 1])
Exemplo n.º 10
0
 def testAssertSameStructure_listVsNdArray(self):
     # A numpy array is a leaf, so it does not match a list with the same
     # contents.
     pattern = "".join([
         "The two structures don't have the same nested structure\\.\n\n",
         "First structure:.*?\n\n",
         "Second structure:.*\n\n",
         r'More specifically: Substructure "type=list str=\[0, 1\]" ',
         r'is a sequence, while substructure "type=ndarray str=\[0 1\]" ',
         "is not",
     ])
     as_list = [0, 1]
     as_array = np.array([0, 1])
     with self.assertRaisesRegex(ValueError, pattern):
         tree.assert_same_structure(as_list, as_array)
Exemplo n.º 11
0
 def _check_matching_structures(output_tree, bound_tree):
     """Replace all bounds/arrays with True, then compare pytrees."""

     def _mark(leaf_type):
         # Marker for `tree.traverse`: instances of `leaf_type` become True.
         # Anything else yields None, so traversal keeps descending/as-is.
         return lambda x: True if isinstance(x, leaf_type) else None

     output_struct = tree.traverse(_mark(jnp.ndarray), output_tree)
     bound_struct = tree.traverse(_mark(bound_propagation.Bound), bound_tree)
     tree.assert_same_structure(output_struct, bound_struct)
Exemplo n.º 12
0
    def test_append_returns_same_structure_as_data(self):
        """`append` returns references structured like the appended data."""
        # The second step drops 'x' and adds 'z': the property must hold even
        # if the data structure changes between steps.
        steps = ({'x': 1, 'y': 2}, {'y': 2, 'z': 3})
        for step_data in steps:
            step_ref = self.writer.append(step_data)
            tree.assert_same_structure(step_data, step_ref)
def discrete_policy_gradient_loss(policy_logits,
                                  actions,
                                  action_values,
                                  name="discrete_policy_gradient_loss"):
    """Computes discrete policy gradient losses for a batch of trajectories.

    This wraps `discrete_policy_gradient` to accept a possibly nested array of
    `policy_logits` and `actions`, allowing for multiple discrete actions. It
    also sums up losses along the time dimension, and is more restrictive
    about shapes, assuming a [T, B] layout.

    Args:
      policy_logits: A (possibly nested structure of) Tensor(s) of shape
          `[T, B, num_actions]` containing uncentered log-probabilities.
      actions: A (possibly nested structure of) Tensor(s) of shape
          `[T, B]` and integer type, containing indices for the selected
          actions.
      action_values: Tensor of shape `[T, B]` containing an estimate of the
          value of the selected `actions`, see `discrete_policy_gradient`.
      name: Customises the name_scope for this op.

    Returns:
      loss: Tensor of shape `[B]` containing the total loss for each sequence
      in the batch. Differentiable w.r.t `policy_logits` only.
    """
    flat_logits = nest.flatten(policy_logits)
    flat_actions = nest.flatten(actions)

    # Comparing the *flattened* lists only asserts that both have the same
    # number of leaves. That lets one argument be a bare tensor while the
    # other is a single-element iterable of tensors.
    nest.assert_same_structure(flat_logits, flat_actions)
    for logits in flat_logits:
        logits.get_shape().assert_has_rank(3)
    for acts in flat_actions:
        acts.get_shape().assert_has_rank(2)

    scope_inputs = flat_logits + flat_actions + [action_values]
    with tf.name_scope(name, values=scope_inputs):
        # One loss per action component; gradients are not pushed through
        # `action_values`.
        per_component_losses = [
            discrete_policy_gradient(logits, acts, action_values)
            for logits, acts in zip(flat_logits, flat_actions)
        ]
        loss_sequence = tf.add_n(per_component_losses)

        # Sum over the time dimension.
        return tf.reduce_sum(loss_sequence,
                             axis=[0],
                             name="policy_gradient_loss")
Exemplo n.º 14
0
 def test_add_batch(self):
     batch_size = 2
     sample_tree = dict(
         a=[jnp.zeros([]), jnp.zeros([1])],
         b=jnp.zeros([1, 1]),
     )
     batched = recurrent.add_batch(sample_tree, batch_size)
     tree.assert_same_structure(sample_tree, batched)
     # Every output leaf is its input leaf with a new leading batch axis.
     leaf_pairs = zip(tree.flatten(sample_tree), tree.flatten(batched))
     for leaf_in, leaf_out in leaf_pairs:
         self.assertEqual(leaf_out.shape[0], batch_size)
         self.assertEqual(leaf_out.shape[1:], leaf_in.shape)
Exemplo n.º 15
0
 def testAssertSameStructure_differentNumElements(self):
     # At the first position, STRUCTURE1 holds the sequence ((1, 2), 3) while
     # the other structure holds the string "spam". The error also prints
     # both structures in full.
     pattern = "".join((
         "The two structures don't have the same nested structure\\.\n\n",
         "First structure:.*?\n\n",
         "Second structure:.*\n\n",
         "More specifically: Substructure ",
         r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while ',
         'substructure "type=str str=spam" is not\n',
         "Entire first structure:\n",
         r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n",
         "Entire second structure:\n",
         r"\(\., \.\)",
     ))
     with self.assertRaisesRegex(ValueError, pattern):
         tree.assert_same_structure(STRUCTURE1,
                                    STRUCTURE_DIFFERENT_NUM_ELEMENTS)
Exemplo n.º 16
0
  def test_sample_variable_length_trajectory(self):
    # After each append, create an item referencing the latest step and the
    # whole history of column 0 so far.
    with self._client.trajectory_writer(10) as writer:
      for step in range(10):
        writer.append([np.ones([3, 3], np.int32) * step])
        writer.create_item(TABLE, 1.0, {
            'last': writer.history[0][-1],
            'all': writer.history[0][:],
        })

    dtypes = {'last': tf.int32, 'all': tf.int32}
    shapes = {
        'last': tf.TensorShape([3, 3]),
        'all': tf.TensorShape([None, 3, 3]),
    }
    dataset = trajectory_dataset.TrajectoryDataset(
        tf.constant(self._client.server_address),
        table=tf.constant(TABLE),
        dtypes=dtypes,
        shapes=shapes,
        max_in_flight_samples_per_worker=1,
        flexible_batch_size=1)

    expected_structure = replay_sample.ReplaySample(
        info=replay_sample.SampleInfo(
            key=1, probability=1.0, table_size=10, priority=0.5),
        data={'last': None, 'all': None})

    # Keep sampling until all ten trajectory lengths have been observed.
    seen_lengths = set()
    while len(seen_lengths) < 10:
      sample = self._sample_from(dataset, 1)[0]
      # The structure should always be the same.
      tree.assert_same_structure(sample, expected_structure)
      seen_lengths.add(sample.data['all'].shape[0])

    self.assertEqual(seen_lengths, set(range(1, 11)))
Exemplo n.º 17
0
    def testMapStructureWithStrings(self):
        ab_tuple = collections.namedtuple("ab_tuple", "a, b")

        # Strings are leaves: each string is repeated by its paired count.
        strings = ab_tuple(a="foo", b=("bar", "baz"))
        counts = ab_tuple(a=2, b=(1, 3))
        repeated = tree.map_structure(lambda text, times: text * times,
                                      strings, counts)
        self.assertEqual("foofoo", repeated.a)
        self.assertEqual("bar", repeated.b[0])
        self.assertEqual("bazbazbaz", repeated.b[1])

        nt = ab_tuple(a=("something", "something_else"), b="yet another thing")
        flipped = tree.map_structure(lambda text: text[::-1], nt)
        # Check the output is the correct structure, and all strings are reversed.
        tree.assert_same_structure(nt, flipped)
        for before, after in ((nt.a[0], flipped.a[0]),
                              (nt.a[1], flipped.a[1]),
                              (nt.b, flipped.b)):
            self.assertEqual(before[::-1], after)
Exemplo n.º 18
0
 def test_sync_params(self):
     frame_count = 428
     params = self.initial_params
     learner = mock.MagicMock()
     learner.params_for_actor.return_value = frame_count, params
     actor = actor_lib.Actor(
         agent=self.agent,
         env=self.env,
         learner=learner,
         unroll_length=10,
     )
     got_frame_count, got_params = actor.pull_params()
     self.assertEqual(got_frame_count, frame_count)
     # Params must come back unchanged: same structure and same values.
     tree.assert_same_structure(got_params, params)
     tree.map_structure(np.testing.assert_array_almost_equal,
                        got_params, params)
Exemplo n.º 19
0
    def test_converts_spec_lists_into_tuples(self):
        """Lists in `dtypes`/`shapes` come back as tuples in sampled data."""
        def make_item():
            return [
                (np.ones([1, 1], dtype=np.int32), ),
                [
                    np.ones([3, 3], dtype=np.int8),
                    (np.ones([2, 2], dtype=np.float64), )
                ],
            ]

        for _ in range(10):
            self._client.insert(make_item(), {'dist': 1})

        dtypes = [(tf.int32, ), [tf.int8, (tf.float64, )]]
        shapes = [
            (tf.TensorShape([1, 1]), ),
            [tf.TensorShape([3, 3]), (tf.TensorShape([2, 2]), )],
        ]
        dataset = reverb_dataset.ReplayDataset(
            self._client.server_address,
            table='dist',
            dtypes=dtypes,
            shapes=shapes,
            max_in_flight_samples_per_worker=100)

        # Every list in the specs is expected to have become a tuple.
        expected_structure = ((None, ), (None, (None, )))
        for sample in self._sample_from(dataset, 10):
            self.assertIsInstance(sample, replay_sample.ReplaySample)
            self.assertIsInstance(sample.info.key, np.uint64)
            tree.assert_same_structure(sample.data, expected_structure)
Exemplo n.º 20
0
    def test_sample_fixed_length_trajectory(self):
        self._populate_replay()

        dataset = trajectory_dataset.TrajectoryDataset(
            tf.constant(self._client.server_address),
            table=tf.constant(TABLE),
            dtypes=DTYPES,
            shapes=SHAPES,
            max_in_flight_samples_per_worker=1,
            flexible_batch_size=1)
        first_sample = self._sample_from(dataset, 1)[0]

        # Only the structure matters here; the leaf values are placeholders.
        info = replay_sample.SampleInfo(
            key=1, probability=1.0, table_size=10, priority=0.5)
        expected = replay_sample.ReplaySample(info=info, data=SHAPES)
        tree.assert_same_structure(first_sample, expected)
Exemplo n.º 21
0
    def test_stack_sequence_fields(self):
        """Tests that `stack_sequence_fields` behaves correctly on nested data."""
        stacked = tree_utils.stack_sequence_fields(TEST_SEQUENCE)

        # The stacked output mirrors the structure of a single element.
        tree.assert_same_structure(stacked, TEST_SEQUENCE[0])

        action = stacked['action']
        observation = stacked['observation'][0]
        reward = stacked['reward']

        # Each leaf gains a leading axis of length 3.
        self.assertEqual(action.shape, (3, 1))
        self.assertEqual(observation.shape, (3, 3))
        self.assertEqual(reward.shape, (3, ))

        # Check values.
        expected_observation = [
            [0., 1., 2.],
            [1., 2., 3.],
            [2., 3., 4.],
        ]
        self.assertEqual(observation.tolist(), expected_observation)
        self.assertEqual(action.tolist(), [[1.], [0.5], [0.3]])
        self.assertEqual(reward.tolist(), [1., 0., 0.5])
Exemplo n.º 22
0
    def assertConformsToSpec(self, value, spec):
        """Checks that `value` conforms to `spec`.

        Args:
          value: A potentially nested structure of numpy arrays or scalars.
          spec: A potentially nested structure of `specs.Array` instances.

        Raises:
          AssertionError: If `value` and `spec` have mismatching structures.
          ValueError: If a leaf of `value` fails its spec's `validate` check.
            The message names the offending path, and the original error is
            chained as the direct cause.
        """
        try:
            tree.assert_same_structure(value, spec)
        except (TypeError, ValueError) as e:
            self.fail(
                "`spec` and `value` have mismatching structures: {}".format(e))

        def validate(path, item, array_spec):
            try:
                array_spec.validate(item)
            except ValueError as e:
                # Chain explicitly so the traceback shows the spec's error
                # as the cause, not as an unrelated "during handling" error.
                raise ValueError(
                    "Value at path {!r} failed validation: {}.".format(
                        "/".join(map(str, path)), e)) from e

        tree.map_structure_with_path(validate, value, spec)
Exemplo n.º 23
0
 def testAssertSameStructure_listStructureWithAndWithoutTypes(self):
     structure1_list = [[[1, 2], 3], 4, [5, 6]]
     # Sequence types are checked by default, so the list-based copy of
     # STRUCTURE1 is rejected.
     with self.assertRaisesRegex(TypeError,
                                 "don't have the same sequence type"):
         tree.assert_same_structure(STRUCTURE1, structure1_list)
     # With check_types=False, only the nesting itself is compared.
     for other in (STRUCTURE2, structure1_list):
         tree.assert_same_structure(STRUCTURE1, other, check_types=False)
Exemplo n.º 24
0
    def test_nested_features(self):
        graph_0 = utils_np.networkxs_to_graphs_tuple(
            [_generate_graph(0, 3), _generate_graph(1, 2)])
        graph_1 = utils_np.networkxs_to_graphs_tuple([_generate_graph(2, 2)])
        graph_2 = utils_np.networkxs_to_graphs_tuple([_generate_graph(3, 3)])
        tensor_graphs = [
            gr.map(tf.convert_to_tensor, graphs.ALL_FIELDS)
            for gr in (graph_0, graph_1, graph_2)
        ]

        def _create_nested_fields(graphs_tuple):
            # Replace the flat fields with nested dict/list/tuple structures
            # (and empty globals).
            nodes = graphs_tuple.nodes
            edges = graphs_tuple.edges
            return graphs_tuple.replace(
                nodes=({"a": nodes, "b": [nodes + 1, nodes + 2]}, ),
                edges=[{"c": edges + 5, "d": (edges + 1, edges + 3)}],
                globals=[])

        graphs_ = [_create_nested_fields(gr) for gr in tensor_graphs]
        concat_graph = utils_tf.concat(graphs_, axis=0)

        def _concat_leaves(*leaves):
            return tf.concat(leaves, axis=0)

        fields = ("nodes", "edges", "globals")
        actual = {field: getattr(concat_graph, field) for field in fields}
        expected = {
            field: tree.map_structure(
                _concat_leaves, *[getattr(gr, field) for gr in graphs_])
            for field in fields
        }

        for field in fields:
            tree.assert_same_structure(expected[field], actual[field])
        for field in fields:
            tree.map_structure(self.assertAllEqual, expected[field],
                               actual[field])

        # Borrowed from `test_concat_first_axis`:
        self.assertAllEqual(np.array([3, 2, 2, 3]), concat_graph.n_node)
        self.assertAllEqual(np.array([2, 1, 1, 2]), concat_graph.n_edge)
        self.assertAllEqual(np.array([1, 2, 4, 6, 8, 9]), concat_graph.senders)
        self.assertAllEqual(np.array([0, 0, 3, 5, 7, 7]),
                            concat_graph.receivers)
Exemplo n.º 25
0
    def training_iteration(self) -> ResultDict:
        """Runs one asynchronous sample / train / weight-sync iteration.

        Returns:
            Dict mapping policy ID to the (possibly averaged) train result of
            each learning policy that returned a non-empty result during this
            iteration.
        """
        # Trigger asynchronous rollouts on all RolloutWorkers.
        # - Rollout results are sent directly to correct replay buffer
        #   shards, instead of here (to the driver).
        with self._timers[SAMPLE_TIMER]:
            sample_results = asynchronous_parallel_requests(
                remote_requests_in_flight=self.remote_requests_in_flight,
                # Fall back to the local worker if there are no remote ones.
                actors=self.workers.remote_workers()
                or [self.workers.local_worker()],
                ray_wait_timeout_s=self.config["sample_wait_timeout"],
                max_remote_requests_in_flight_per_actor=2,
                remote_fn=self._sample_and_send_to_buffer,
            )
        # Update sample counters.
        # Each result is unpacked as an (env_steps, agent_steps) pair,
        # presumably returned by `_sample_and_send_to_buffer`.
        for sample_result in sample_results.values():
            for (env_steps, agent_steps) in sample_result:
                self._counters[NUM_ENV_STEPS_SAMPLED] += env_steps
                self._counters[NUM_AGENT_STEPS_SAMPLED] += agent_steps

        # Trigger asynchronous training update requests on all learning
        # policies.
        with self._timers[LEARN_ON_BATCH_TIMER]:
            pol_actors = []
            args = []
            # Each policy actor is asked to update from its paired replay
            # actor; `args[i]` holds the remote args for `pol_actors[i]`.
            for pid, pol_actor, repl_actor in self.distributed_learners:
                pol_actors.append(pol_actor)
                args.append([repl_actor, pid])
            train_results = asynchronous_parallel_requests(
                remote_requests_in_flight=self.remote_requests_in_flight,
                actors=pol_actors,
                ray_wait_timeout_s=self.config["learn_wait_timeout"],
                max_remote_requests_in_flight_per_actor=2,
                remote_fn=self._update_policy,
                remote_args=args,
            )

        # Update training counters (agent steps trained, where reported).
        for train_result in train_results.values():
            for result in train_result:
                if NUM_AGENT_STEPS_TRAINED in result:
                    self._counters[NUM_AGENT_STEPS_TRAINED] += result[
                        NUM_AGENT_STEPS_TRAINED]

        # For those policies that have been updated in this iteration
        # (not all policies may have undergone an updated as we are
        # requesting updates asynchronously):
        # - Gather train infos.
        # - Update weights to those remote rollout workers that contain
        #   the respective policy.
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            train_infos = {}
            policy_weights = {}
            for pol_actor, policy_results in train_results.items():
                # Results can only be averaged leaf-wise if every pair of
                # consecutive results shares the same nested structure.
                results_have_same_structure = True
                for result1, result2 in zip(policy_results,
                                            policy_results[1:]):
                    try:
                        tree.assert_same_structure(result1, result2)
                    except (ValueError, TypeError):
                        results_have_same_structure = False
                        break
                # Several structurally identical results: use their mean.
                # Otherwise keep only the most recent result.
                if len(policy_results) > 1 and results_have_same_structure:
                    policy_result = tree.map_structure(
                        lambda *_args: sum(_args) / len(policy_results),
                        *policy_results)
                else:
                    policy_result = policy_results[-1]
                # Empty results are skipped: no train info, no weight push.
                if policy_result:
                    pid = self.distributed_learners.get_policy_id(pol_actor)
                    train_infos[pid] = policy_result
                    # `.remote()` stores an object reference, not the weights
                    # themselves.
                    policy_weights[pid] = pol_actor.get_weights.remote()

            # Put the weights dict into the object store once and share the
            # single reference with all workers.
            policy_weights_ref = ray.put(policy_weights)

            global_vars = {
                "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
                "league_builder": self.league_builder.__getstate__(),
            }

            for worker in self.workers.remote_workers():
                worker.set_weights.remote(policy_weights_ref, global_vars)

        return train_infos
Exemplo n.º 26
0
  def __init__(self,
               server_address: Union[str, tf.Tensor],
               table: Union[str, tf.Tensor],
               dtypes: Any,
               shapes: Any,
               max_in_flight_samples_per_worker: int,
               num_workers_per_iterator: int = -1,
               max_samples_per_stream: int = -1,
               rate_limiter_timeout_ms: int = -1,
               flexible_batch_size: int = -1):
    """Constructs a new dataset that samples from a Reverb table.

    Args:
      server_address: Address of gRPC ReverbService.
      table: Probability table to sample from.
      dtypes: Dtypes of the data output. Can be nested.
      shapes: Shapes of the data output. Can be nested.
      max_in_flight_samples_per_worker: The number of samples requested in each
        batch of samples. Higher values give higher throughput, but values that
        are too big can skew the sampling distribution, because a large number
        of samples is fetched from a single snapshot of the replay (followed by
        a period of lower activity while the samples are consumed). A good rule
        of thumb is to set this value to 2-3x the batch size used.
      num_workers_per_iterator: (Defaults to -1, i.e. auto selected) The number
        of worker threads to create per dataset iterator. When the selected
        table uses a FIFO sampler (i.e. a queue), exactly 1 worker must be used
        to avoid races causing invalid ordering of items. For all other
        samplers, this value should be roughly equal to the number of threads
        available on the CPU.
      max_samples_per_stream: (Defaults to -1, i.e. auto selected) The maximum
        number of samples to fetch from a stream before a new call is made.
        Keeping this number low ensures that the data is fetched uniformly from
        all servers.
      rate_limiter_timeout_ms: (Defaults to -1: infinite). Timeout (in
        milliseconds) to wait on the rate limiter when sampling from the table.
        If `rate_limiter_timeout_ms >= 0`, this is the timeout passed to
        `Table::Sample` describing how long to wait for the rate limiter to
        allow sampling. The first time that a request times out (across any of
        the workers), the Dataset iterator is closed and the sequence is
        considered finished.
      flexible_batch_size: (Defaults to -1: auto selected) The maximum number
        of items to sample from `Table` with a single call. Values > 1 enable
        `Table::SampleFlexibleBatch` to return more than one item (but no more
        than `flexible_batch_size`) in a single call without releasing the
        table lock iff the rate limiter allows it. NOTE! It is unlikely that
        you need to tune this value yourself. The auto selected value should
        almost always be preferred. Larger `flexible_batch_size` values result
        in a bias towards sampling over inserts. In highly overloaded systems
        this results in higher sample QPS and lower insert QPS compared to
        lower `flexible_batch_size` values.

    Raises:
      ValueError: If `dtypes` and `shapes` don't share the same structure.
      ValueError: If `max_in_flight_samples_per_worker` is not a
        positive integer.
      ValueError: If `num_workers_per_iterator` is not a positive integer or -1.
      ValueError: If `max_samples_per_stream` is not a positive integer or -1.
      ValueError: If `rate_limiter_timeout_ms < -1`.
      ValueError: If `flexible_batch_size` is not a positive integer or -1.
    """

    def _check_positive_or_auto(arg_name, arg_value):
      # -1 is the "auto selected" sentinel; anything else must be >= 1.
      if arg_value < 1 and arg_value != -1:
        raise ValueError('%s (%d) must be a positive integer or -1' %
                         (arg_name, arg_value))

    tree.assert_same_structure(dtypes, shapes, check_types=False)

    # Validate arguments in declaration order so the first bad one is reported.
    if max_in_flight_samples_per_worker < 1:
      raise ValueError(
          'max_in_flight_samples_per_worker (%d) must be a positive integer' %
          max_in_flight_samples_per_worker)
    _check_positive_or_auto('num_workers_per_iterator',
                            num_workers_per_iterator)
    _check_positive_or_auto('max_samples_per_stream', max_samples_per_stream)
    if rate_limiter_timeout_ms < -1:
      raise ValueError('rate_limiter_timeout_ms (%d) must be an integer >= -1' %
                       rate_limiter_timeout_ms)
    _check_positive_or_auto('flexible_batch_size', flexible_batch_size)

    # Wrap the user data with the sample info fields (all scalars).
    info_fields = ('key', 'probability', 'table_size', 'priority')
    full_dtypes = replay_sample.ReplaySample(
        info=replay_sample.SampleInfo.tf_dtypes(), data=dtypes)
    full_shapes = replay_sample.ReplaySample(
        info=replay_sample.SampleInfo(
            **{field: tf.TensorShape([]) for field in info_fields}),
        data=shapes)

    # The tf.data API doesn't fully support lists so we convert all uses of
    # lists into tuples.
    full_dtypes = _convert_lists_to_tuples(full_dtypes)
    full_shapes = _convert_lists_to_tuples(full_shapes)

    self._server_address = server_address
    self._table = table
    self._dtypes = full_dtypes
    self._shapes = full_shapes
    self._max_in_flight_samples_per_worker = max_in_flight_samples_per_worker
    self._num_workers_per_iterator = num_workers_per_iterator
    self._max_samples_per_stream = max_samples_per_stream
    self._rate_limiter_timeout_ms = rate_limiter_timeout_ms
    self._flexible_batch_size = flexible_batch_size

    if not _is_tf1_runtime():
      # DatasetV2 requires the dataset as a variant tensor during init.
      super().__init__(self._as_variant_tensor())
    else:
      # Disabling to avoid errors given the different tf.data.Dataset init args
      # between v1 and v2 APIs.
      # pytype: disable=wrong-arg-count
      super().__init__()
Exemplo n.º 27
0
def _validate_spec(spec: types.NestedSpec, value: types.NestedArray):
    """Check `value` against a (possibly nested) `spec`.

    The two nests must share the same structure. Then each leaf spec's
    `validate` method is called on the matching leaf of `value`. Nothing is
    returned and the results of `validate` are discarded. Failures are
    therefore presumably reported by `validate` (or by
    `tree.assert_same_structure`) raising.
    """

    def _validate_leaf(leaf_spec, leaf_value):
        return leaf_spec.validate(leaf_value)

    tree.assert_same_structure(value, spec)
    tree.map_structure(_validate_leaf, spec, value)
Exemplo n.º 28
0
    def append(self, data: Any, *, partial_step: bool = False):
        """Columnwise append of data leaf nodes to internal buffers.

        If `data` includes fields or sub structures which haven't been present
        in any previous calls then the types and shapes of the new fields are
        extracted and used to validate future `append` calls. The structure of
        `history` is also updated to include the union of the structure across
        all `append` calls.

        When new fields are added after the first step then the newly created
        history field will be filled with `None` in all preceding positions.
        This results in equal indexing across columns. That is, `a[i]` and
        `b[i]` reference the same step in the sequence even if `b` was first
        observed after `a` had already been seen.

        It is possible to create a "step" using more than one `append` call by
        setting the `partial_step` flag. Partial steps can be used when some
        parts of the step become available only as a result of inserting (and
        learning from) trajectories that include the fields available first
        (e.g. learn from the SARS trajectory to select the next action in an
        on-policy agent). In the final `append` call of the step,
        `partial_step` must be set to False. Failing to "close" the partial
        step will result in an error, as the same field must NOT be provided
        more than once in the same step.

        Args:
          data: The (possibly nested) structure to make available for new items
            to reference.
          partial_step: If `True` then the step is not considered "done" with
            this call. See above for more details. Defaults to `False`.

        Returns:
          References to the data structured just like provided `data`. Columns
          that received no value (i.e. `None`) yield `None` instead of a
          reference.

        Raises:
          ValueError: If the same column is provided more than once in the same
            step.
        """
        # On the very first call there is no accumulated structure yet, so seed
        # it from `data` (with every leaf replaced by None). On later calls
        # `data` is checked against the accumulated structure below.
        if self._structure is None:
            self._update_structure(tree.map_structure(lambda _: None, data))

        # Fast path: `data` already has exactly the accumulated structure.
        # NOTE(review): the positional `True` is presumably dm-tree's
        # `check_types` flag; confirm against the installed `tree` version.
        try:
            tree.assert_same_structure(data, self._structure, True)
            expanded_data = data
        except ValueError:
            try:
                # If `data` is a subset of the full spec then we can simply fill in
                # the gaps with None.
                expanded_data = _tree_merge_into(source=data,
                                                 target=self._structure)
            except ValueError:
                # `data` contains fields which haven't been observed before, so we
                # need to expand the spec using the union of the history and
                # `data`.
                self._update_structure(
                    _tree_union(self._structure,
                                tree.map_structure(lambda x: None, data)))

                # Now that the structure has been updated to include all the fields
                # in `data` we are able to expand `data` to the full structure. Note
                # that if `data` is a superset of the previous history structure
                # then this "expansion" is just a noop.
                expanded_data = _tree_merge_into(data, self._structure)

        # Use our custom mapping to flatten the expanded structure into columns.
        # The flat order is assumed to line up with `self._column_history`
        # (they are zipped together below).
        flat_column_data = self._flatten(expanded_data)

        # If the last step is still open then verify that already populated
        # columns are None in the new `data`. `column.can_set_last` is presumably
        # True only while the open step has no value yet for that column;
        # verify against the column class.
        if self._last_step_is_open:
            for i, (column, column_data) in enumerate(
                    zip(self._column_history, flat_column_data)):
                if column_data is None or column.can_set_last:
                    continue

                raise ValueError(
                    f'Field {self._get_path_for_column_index(i)} has already been set '
                    f'in the active step by previous (partial) append call and thus '
                    f'must be omitted or set to None but got: {column_data}')

        # Flatten the data and pass it to the C++ writer for column-wise append.
        # All columns where data is provided (i.e. not None) will return a
        # reference to the data (`pybind.WeakCellRef`), which is used to define
        # trajectories for `create_item`. The columns which did not receive a
        # value (i.e. None) will return None.
        if partial_step:
            flat_column_data_references = self._writer.AppendPartial(
                flat_column_data)
        else:
            flat_column_data_references = self._writer.Append(flat_column_data)

        # Append references to respective columns. Note that we use the expanded
        # structure in order to populate the columns missing from the data with
        # None.
        for column, data_reference in zip(self._column_history,
                                          flat_column_data_references):
            # If the last step is still open (i.e. `partial_step` was set on the
            # previous call) then we populate that step instead of creating a
            # new one. Only non-None references overwrite the open step.
            if not self._last_step_is_open:
                column.append(data_reference)
            elif data_reference is not None:
                column.set_last(data_reference)

        # Save the flag so the next `append` call either populates the same step
        # or begins a new step.
        self._last_step_is_open = partial_step

        # Unpack the column data into the expanded structure.
        expanded_structured_data_references = self._unflatten(
            flat_column_data_references)

        # Return the references structured in the same way as `data`. If only a
        # subset of the fields were present in the input data then only these
        # fields will exist in the output.
        return _tree_filter(expanded_structured_data_references, data)
Exemplo n.º 29
0
def update(state: RunningStatisticsState,
           batch: types.NestedArray,
           *,
           config: NestStatisticsConfig = NestStatisticsConfig(),
           weights: Optional[jnp.ndarray] = None,
           std_min_value: float = 1e-6,
           std_max_value: float = 1e6,
           pmap_axis_name: Optional[str] = None,
           validate_shapes: bool = True) -> RunningStatisticsState:
    """Fold a batch of data into the running statistics.

    `batch` and the state entries (`mean`, etc.) must have exactly the same
    nested structure.

    By default counts are accumulated in int32 and the summed variance in
    float32. This overflows the count after 2^31 data points and degrades
    precision after 2^24 batch updates, or even earlier if variance updates
    have a large dynamic range. To improve precision, consider setting
    jax_enable_x64 to True, see
    https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision

    Args:
      state: The running statistics before the update.
      batch: The data used to update the running statistics. Batch
        dimensions are assumed to come first in every leaf.
      config: Selects which leaves of the nested structure get running
        statistics; excluded leaves are passed through unchanged.
      weights: Optional weights of the batch data, matching the batch
        dimensions. A weight of 2. should be equivalent to updating on the
        corresponding data point twice.
      std_min_value: Lower clip bound for the standard deviation.
      std_max_value: Upper clip bound for the standard deviation.
      pmap_axis_name: Name of the pmapped axis, if any. When set, counts and
        per-leaf updates are summed across that axis with `jax.lax.psum`.
      validate_shapes: If true, the shapes of all leaves of the batch are
        validated. Enabled by default. Doesn't impact performance when jitted.

    Returns:
      A new `RunningStatisticsState` with the updated count, mean, summed
      variance and standard deviation.

    Raises:
      ValueError: If `validate_shapes` is set and `weights.shape` differs from
        the batch dimensions. Structure mismatches are reported by
        `tree.assert_same_structure`; per-leaf shape mismatches presumably by
        `_validate_batch_shapes`.
    """
    # Identical structure is required: otherwise flattening `batch` and
    # `state.mean` could pair leaves up in different orders.
    tree.assert_same_structure(batch, state.mean)

    # Leading axes of a batch leaf beyond the statistics' rank are batch axes.
    first_leaf_shape = tree.flatten(batch)[0].shape
    stats_rank = tree.flatten(state.mean)[0].ndim
    batch_dims = first_leaf_shape[:len(first_leaf_shape) - stats_rank]
    batch_axis = range(len(batch_dims))

    step_increment = (np.prod(batch_dims) if weights is None
                      else jnp.sum(weights))
    if pmap_axis_name is not None:
        step_increment = jax.lax.psum(step_increment, axis_name=pmap_axis_name)
    count = state.count + step_increment

    # Exact shapes matter: merely compatible shapes would broadcast silently
    # and produce incorrect statistics.
    if validate_shapes:
        if weights is not None and weights.shape != batch_dims:
            raise ValueError(f'{weights.shape} != {batch_dims}')
        _validate_batch_shapes(batch, state.mean, batch_dims)

    def _welford_leaf(
            path: Path, mean: jnp.ndarray, summed_variance: jnp.ndarray,
            batch: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        assert isinstance(mean, jnp.ndarray), type(mean)
        assert isinstance(summed_variance, jnp.ndarray), type(summed_variance)
        if not _is_path_included(config, path):
            # Leaf excluded by `config`: hand it back untouched.
            return mean, summed_variance

        # Batched Welford update (see https://stackoverflow.com/q/56402955).
        delta_old = batch - mean
        if weights is not None:
            # Pad `weights` with trailing singleton axes so it broadcasts
            # over the non-batch dimensions of the leaf.
            padded_shape = (list(weights.shape) +
                            [1] * (batch.ndim - weights.ndim))
            delta_old = delta_old * jnp.reshape(weights, padded_shape)
        mean_delta = jnp.sum(delta_old, axis=batch_axis) / count
        if pmap_axis_name is not None:
            mean_delta = jax.lax.psum(mean_delta, axis_name=pmap_axis_name)
        new_mean = mean + mean_delta

        var_delta = jnp.sum(delta_old * (batch - new_mean), axis=batch_axis)
        if pmap_axis_name is not None:
            var_delta = jax.lax.psum(var_delta, axis_name=pmap_axis_name)
        return new_mean, summed_variance + var_delta

    per_leaf = tree_utils.fast_map_structure_with_path(
        _welford_leaf, state.mean, state.summed_variance, batch)

    if isinstance(state.mean, jnp.ndarray):
        # Unstructured input: `per_leaf` already is the (mean, variance) pair,
        # so the slow `map_structure_up_to` can be skipped.
        mean, summed_variance = per_leaf
    else:
        # Transpose `nest((mean, summed_variance))` into
        # `nest(mean), nest(summed_variance)`.
        mean = tree.map_structure_up_to(
            state.mean, lambda pair: pair[0], per_leaf)
        summed_variance = tree.map_structure_up_to(
            state.mean, lambda pair: pair[1], per_leaf)

    def _leaf_std(path: Path, summed_variance: jnp.ndarray,
                  std: jnp.ndarray) -> jnp.ndarray:
        assert isinstance(summed_variance, jnp.ndarray)
        if not _is_path_included(config, path):
            return std
        # Rounding errors can push the summed variance slightly below zero.
        non_negative = jnp.maximum(summed_variance, 0)
        return jnp.clip(jnp.sqrt(non_negative / count),
                        std_min_value, std_max_value)

    std = tree_utils.fast_map_structure_with_path(_leaf_std, summed_variance,
                                                  state.std)

    return RunningStatisticsState(count=count,
                                  mean=mean,
                                  summed_variance=summed_variance,
                                  std=std)
Exemplo n.º 30
0
    def training_step(self) -> ResultDict:
        """Run one iteration: rollouts, asynchronous learner updates, weight sync.

        Rollouts and learner updates are requested asynchronously. Only the
        results that are ready when polled are accounted for in this call.

        Returns:
            A dict mapping the policy ID of every learner that reported a
            truthy result this iteration to that result. When a learner
            returned several results with identical structure, they are
            averaged leaf-wise; otherwise its last result is used.
        """

        def _all_share_structure(results):
            # True iff every consecutive pair of results has the same nesting.
            for earlier, later in zip(results, results[1:]):
                try:
                    tree.assert_same_structure(earlier, later)
                except (ValueError, TypeError):
                    return False
            return True

        # 1) Sampling. Rollout workers push their samples directly into the
        #    replay-buffer shards; only (env_steps, agent_steps) come back.
        with self._timers[SAMPLE_TIMER]:
            if self.workers.remote_workers():
                self._sampling_actor_manager.call_on_all_available(
                    self._sample_and_send_to_buffer)
                sample_results = self._sampling_actor_manager.get_ready()
            else:
                # No remote workers (e.g. num_workers=0): sample locally.
                local_worker = self.workers.local_worker()
                sample_results = {
                    local_worker: [
                        local_worker.apply(self._sample_and_send_to_buffer)
                    ]
                }

        # Update sample counters.
        for step_stats in sample_results.values():
            for env_steps, agent_steps in step_stats:
                self._counters[NUM_ENV_STEPS_SAMPLED] += env_steps
                self._counters[NUM_AGENT_STEPS_SAMPLED] += agent_steps

        # 2) Learning. Queue one asynchronous update per learning policy,
        #    registering policy actors the manager has not seen yet.
        with self._timers[LEARN_ON_BATCH_TIMER]:
            for pid, pol_actor, repl_actor in self.distributed_learners:
                if pol_actor not in self._learner_worker_manager.workers:
                    self._learner_worker_manager.add_workers(pol_actor)
                self._learner_worker_manager.call(
                    self._update_policy,
                    actor=pol_actor,
                    fn_args=[repl_actor, pid])
            train_results = self._learner_worker_manager.get_ready()

        # Update train counters.
        for per_actor_results in train_results.values():
            for result in per_actor_results:
                if NUM_AGENT_STEPS_TRAINED in result:
                    self._counters[NUM_AGENT_STEPS_TRAINED] += result[
                        NUM_AGENT_STEPS_TRAINED]

        # 3) Sync. Not every policy necessarily finished an update this
        #    iteration (updates are asynchronous). For those that did, gather
        #    their train infos and push their weights to the remote rollout
        #    workers.
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            train_infos = {}
            policy_weights = {}
            for pol_actor, policy_results in train_results.items():
                same_structure = _all_share_structure(policy_results)
                num_results = len(policy_results)
                if num_results > 1 and same_structure:
                    policy_result = tree.map_structure(
                        lambda *leaves: sum(leaves) / num_results,
                        *policy_results)
                else:
                    policy_result = policy_results[-1]
                if not policy_result:
                    continue
                pid = self.distributed_learners.get_policy_id(pol_actor)
                train_infos[pid] = policy_result
                policy_weights[pid] = pol_actor.get_weights.remote()

            policy_weights_ref = ray.put(policy_weights)

            global_vars = {
                "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
                "league_builder": self.league_builder.__getstate__(),
            }

            for worker in self.workers.remote_workers():
                worker.set_weights.remote(policy_weights_ref, global_vars)

        return train_infos