def wrapped_method(*args, **kwargs):
  """A wrapped method around a TF-Hub module signature."""
  inputs = _getcallargs(self._method_specs[method]["specs"], *args, **kwargs)
  nest.assert_same_structure(self._method_specs[method]["inputs"], inputs)
  flat_inputs = nest.flatten(inputs)
  flat_inputs = {
      str(k): v for k, v in zip(range(len(flat_inputs)), flat_inputs)
  }
  signature = "default" if method == "__call__" else method
  flat_outputs = self._module(flat_inputs, signature=signature, as_dict=True)
  flat_outputs = [v for _, v in sorted(flat_outputs.items())]

  output_spec = self._method_specs[method]["outputs"]
  if output_spec is None:
    if len(flat_outputs) != 1:
      raise ValueError(
          "Expected output containing a single tensor, found {}".format(
              flat_outputs))
    outputs = flat_outputs[0]
  else:
    outputs = nest.unflatten_as(output_spec, flat_outputs)

  return outputs
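The flatten/round-trip pattern used above can be shown in isolation. A minimal sketch, assuming dm-tree (`pip install dm-tree`) in place of `nest`; the input names are illustrative only. Note that sorting the string keys matches the flat order only while there are fewer than ten leaves, which is the same caveat the wrapper's `sorted(flat_outputs.items())` relies on.

import numpy as np
import tree

inputs = {"mask": np.array([1, 1, 0]), "tokens": np.array([1, 2, 3])}

# Flatten to a list of leaves, then key each leaf by its flat index.
flat_inputs = tree.flatten(inputs)
indexed = {str(i): v for i, v in enumerate(flat_inputs)}

# Recover the original structure from the (string-)sorted flat values.
restored = tree.unflatten_as(inputs, [v for _, v in sorted(indexed.items())])
tree.assert_same_structure(inputs, restored)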
def test_nested_structure(self):
  regular_graph = self._graph
  graph_with_nested_fields = regular_graph.map(
      lambda x: {"a": x, "b": tf.random.uniform([4, 6])})

  nested_structure = [
      None,
      regular_graph,
      (graph_with_nested_fields,),
      tf.random.uniform([10, 6])]
  nested_structure_numpy = utils_tf.nest_to_numpy(nested_structure)

  tree.assert_same_structure(nested_structure, nested_structure_numpy)

  for tensor_or_none, array_or_none in zip(
      tree.flatten(nested_structure),
      tree.flatten(nested_structure_numpy)):
    if tensor_or_none is None:
      self.assertIsNone(array_or_none)
      continue

    self.assertIsNotNone(array_or_none)
    self.assertNDArrayNear(tensor_or_none.numpy(), array_or_none, 1e-8)
def testAssertSameStructure_sameNamedTupleDifferentStructuredContents(self):
  with self.assertRaisesRegex(
      ValueError,
      ("don't have the same nested structure\\.\n\n"
       "First structure: .*?\n\nSecond structure: ")):
    tree.assert_same_structure(NestTest.Named0ab(3, 4),
                               NestTest.Named0ab([3], 4))
def test_ema_on_changing_data(self):
  def f():
    return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

  init_fn, _ = base.transform(f)
  params = init_fn(random.PRNGKey(428))

  def g(x):
    return moving_averages.EMAParamsTree(0.2)(x)

  init_fn, apply_fn = base.without_apply_rng(base.transform_with_state(g))
  _, params_state = init_fn(None, params)
  params, params_state = apply_fn(None, params_state, params)

  # Let's modify our params.
  changed_params = tree.map_structure(lambda t: 2. * t, params)
  ema_params, params_state = apply_fn(None, params_state, changed_params)

  # ema_params should be different from changed_params!
  tree.assert_same_structure(changed_params, ema_params)
  for p1, p2 in zip(tree.flatten(params), tree.flatten(ema_params)):
    self.assertEqual(p1.shape, p2.shape)
    with self.assertRaisesRegex(AssertionError, "Not equal to tolerance"):
      np.testing.assert_allclose(p1, p2, atol=1e-6)
def policy_gradient_loss(policies, actions, action_values, policy_vars=None,
                         name="policy_gradient_loss"):
  """Computes policy gradient losses for a batch of trajectories.

  This wraps `policy_gradient` to accept a possibly nested array of `policies`
  and `actions` in order to allow for multiple action distribution types or
  independent multivariate distributions if not directly available. It also
  sums up losses along the time dimension, and is more restrictive about
  shapes, assuming a [T, B] layout for the `batch_shape` of the policies and
  a concatenate(`[T, B]`, `event_shape` of the policies) shape for the
  actions.

  Args:
    policies: A (possibly nested structure of) distribution(s) supporting
        `batch_shape` and `event_shape` properties along with a `log_prob`
        method (e.g. an instance of `tfp.distributions.Distribution`), with
        `batch_shape` equal to `[T, B]`.
    actions: A (possibly nested structure of) N-D Tensor(s) with shape
        `[T, B, ...]` where the final dimensions are the `event_shape` of the
        corresponding distribution in the nested structure (the shape can be
        just `[T, B]` if the `event_shape` is scalar).
    action_values: Tensor of shape `[T, B]` containing an estimate of the
        value of the selected `actions`.
    policy_vars: An optional (possibly nested structure of) iterable(s) of
        Tensors used by `policies`. If provided, it is used in scope checks.
    name: Customises the name_scope for this op.

  Returns:
    loss: Tensor of shape `[B]` containing the total loss for each sequence
    in the batch. Differentiable w.r.t `policies` only.
  """
  actions = nest.flatten(actions)
  if policy_vars:
    policy_vars = nest.flatten_up_to(policies, policy_vars)
  else:
    policy_vars = [list()] * len(actions)
  policies = nest.flatten(policies)

  # Check happens after flatten so that we can be more flexible on nest
  # structures. This is equivalent to asserting that
  # `len(policies) == len(actions)`, which is sufficient for what we're doing
  # here.
  nest.assert_same_structure(policies, actions)
  for policies_, actions_ in zip(policies, actions):
    policies_.batch_shape.assert_has_rank(2)
    actions_.get_shape().assert_is_compatible_with(
        policies_.batch_shape.concatenate(policies_.event_shape))

  scoped_values = policy_vars + actions + [action_values]
  with tf.name_scope(name, values=scoped_values):
    # Loss for the policy gradient. Doesn't push additional gradients
    # through the action_values.
    policy_gradient_loss_sequence = tf.add_n([
        policy_gradient(policies_, actions_, action_values, pvars)
        for policies_, actions_, pvars in zip(policies, actions,
                                              policy_vars)])

    return tf.reduce_sum(
        policy_gradient_loss_sequence, axis=[0],
        name="policy_gradient_loss")
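A minimal usage sketch, assuming TF1 graph mode and `tensorflow_probability` for the distribution, with the rest of the trfl module (notably `policy_gradient`) in scope; shapes follow the `[T, B]` convention from the docstring.

import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp

T, B, num_actions = 5, 2, 3
# A categorical policy whose batch_shape is [T, B] and event_shape is scalar.
policies = tfp.distributions.Categorical(
    logits=tf.random.normal([T, B, num_actions]))
actions = tf.zeros([T, B], dtype=tf.int32)
action_values = tf.ones([T, B])

loss = policy_gradient_loss(policies, actions, action_values)  # shape [B]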
def mirror_structure(value, reference_tree):
  if tree.is_nested(value):
    # Use check_types=True so that the types of the trees we construct aren't
    # dependent on our arbitrary choice of which nested arg to use as the
    # reference_tree.
    tree.assert_same_structure(value, reference_tree, check_types=True)
    return value
  else:
    return tree.map_structure(lambda _: value, reference_tree)
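A quick illustration of the broadcast-or-verify behaviour; a minimal sketch assuming dm-tree and the `mirror_structure` helper above, with an illustrative reference tree.

import tree

reference = {"obs": 0, "reward": 0}

# A scalar is broadcast to every leaf of the reference structure...
assert mirror_structure(1.0, reference) == {"obs": 1.0, "reward": 1.0}

# ...while an already-nested value must match the reference exactly.
assert mirror_structure({"obs": 2, "reward": 3}, reference) == {
    "obs": 2, "reward": 3}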
def test_nested_action_spaces(self):
    config = DEFAULT_CONFIG.copy()
    config["env"] = RandomEnv
    # Write output to check whether actions are written correctly.
    tmp_dir = os.popen("mktemp -d").read()[:-1]
    if not os.path.exists(tmp_dir):
        # Last resort: Resolve via underlying tempdir (and cut the "tmp"
        # prefix).
        tmp_dir = ray._private.utils.tempfile.gettempdir() + tmp_dir[4:]
        assert os.path.exists(tmp_dir), f"'{tmp_dir}' not found!"
    config["output"] = tmp_dir
    # Switch off OPE as we don't write action-probs.
    # TODO: We should probably always write those if `output` is given.
    config["input_evaluation"] = []
    # Pretend actions in offline files are already normalized.
    config["actions_in_input_normalized"] = True

    for _ in framework_iterator(config):
        for name, action_space in SPACES.items():
            config["env_config"] = {
                "action_space": action_space,
            }
            for flatten in [False, True]:
                print(f"A={action_space} flatten={flatten}")
                shutil.rmtree(config["output"])
                config["_disable_action_flattening"] = not flatten
                trainer = PGTrainer(config)
                trainer.train()
                trainer.stop()

                # Check actions in output file (whether properly flattened
                # or not).
                reader = JsonReader(
                    inputs=config["output"],
                    ioctx=trainer.workers.local_worker().io_context,
                )
                sample_batch = reader.next()
                if flatten:
                    assert isinstance(sample_batch["actions"], np.ndarray)
                    assert len(sample_batch["actions"].shape) == 2
                    assert sample_batch["actions"].shape[0] == len(
                        sample_batch)
                else:
                    tree.assert_same_structure(
                        trainer.get_policy().action_space_struct,
                        sample_batch["actions"],
                    )

                # Test whether offline data can be properly read by a
                # BCTrainer, configured accordingly.
                config["input"] = config["output"]
                del config["output"]
                bc_trainer = BCTrainer(config=config)
                bc_trainer.train()
                bc_trainer.stop()
                config["output"] = tmp_dir
                config["input"] = "sampler"
def apply_preprocessors(preprocessors, inputs):
    tree.assert_same_structure(inputs, preprocessors)
    preprocessed_inputs = tree.map_structure(
        lambda preprocessor, input_: (
            preprocessor(input_) if preprocessor is not None else input_),
        preprocessors,
        inputs,
    )
    return preprocessed_inputs
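A small usage sketch, assuming dm-tree; the preprocessor tree mirrors the input tree, and a `None` preprocessor means pass-through. The field names are illustrative.

import numpy as np
import tree

inputs = {"pixels": np.array([0., 255.]), "proprio": np.array([1., 2.])}
preprocessors = {"pixels": lambda x: x / 255., "proprio": None}

out = apply_preprocessors(preprocessors, inputs)
# pixels are rescaled, proprio passes through unchanged.
np.testing.assert_allclose(out["pixels"], [0., 1.])
np.testing.assert_allclose(out["proprio"], [1., 2.])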
def testAssertSameStructure_intVsList(self):
  with self.assertRaisesRegex(
      ValueError,
      ("The two structures don't have the same nested structure\\.\n\n"
       "First structure:.*?\n\n"
       "Second structure:.*\n\n"
       r'More specifically: Substructure "type=list str=\[0, 1\]" '
       'is a sequence, while substructure "type=int str=0" '
       "is not")):
    tree.assert_same_structure(0, [0, 1])
def testAssertSameStructure_listVsNdArray(self):
  with self.assertRaisesRegex(
      ValueError,
      ("The two structures don't have the same nested structure\\.\n\n"
       "First structure:.*?\n\n"
       "Second structure:.*\n\n"
       r'More specifically: Substructure "type=list str=\[0, 1\]" '
       r'is a sequence, while substructure "type=ndarray str=\[0 1\]" '
       "is not")):
    tree.assert_same_structure([0, 1], np.array([0, 1]))
def _check_matching_structures(output_tree, bound_tree):
  """Replaces all bounds/arrays with `True`, then compares pytrees."""
  output_struct = tree.traverse(
      lambda x: True if isinstance(x, jnp.ndarray) else None, output_tree)
  bound_struct = tree.traverse(
      lambda x: True if isinstance(x, bound_propagation.Bound) else None,
      bound_tree)
  tree.assert_same_structure(output_struct, bound_struct)
def test_append_returns_same_structure_as_data(self):
  first_step_data = {'x': 1, 'y': 2}
  first_step_ref = self.writer.append(first_step_data)
  tree.assert_same_structure(first_step_data, first_step_ref)

  # Check that this holds true even if the data structure changes between
  # steps.
  second_step_data = {'y': 2, 'z': 3}
  second_step_ref = self.writer.append(second_step_data)
  tree.assert_same_structure(second_step_data, second_step_ref)
def discrete_policy_gradient_loss(policy_logits, actions, action_values,
                                  name="discrete_policy_gradient_loss"):
  """Computes discrete policy gradient losses for a batch of trajectories.

  This wraps `discrete_policy_gradient` to accept a possibly nested array of
  `policy_logits` and `actions` in order to allow for multiple discrete
  actions. It also sums up losses along the time dimension, and is more
  restrictive about shapes, assuming a [T, B] layout.

  Args:
    policy_logits: A (possibly nested structure of) Tensor(s) of shape
        `[T, B, num_actions]` containing uncentered log-probabilities.
    actions: A (possibly nested structure of) Tensor(s) of shape `[T, B]`
        and integer type, containing indices for the selected actions.
    action_values: Tensor of shape `[T, B]` containing an estimate of the
        value of the selected `actions`, see `discrete_policy_gradient`.
    name: Customises the name_scope for this op.

  Returns:
    loss: Tensor of shape `[B]` containing the total loss for each sequence
    in the batch. Differentiable w.r.t `policy_logits` only.
  """
  policy_logits = nest.flatten(policy_logits)
  actions = nest.flatten(actions)

  # Check happens after flatten so that we can be more flexible on nest
  # structures. This is equivalent to asserting that
  # `len(policy_logits) == len(actions)`, which is sufficient for what we're
  # doing here. In particular, it means that we can allow one argument to be
  # a tensor and the other to be a single-element iterable of tensors.
  nest.assert_same_structure(policy_logits, actions)
  for scalar_policy_logits in policy_logits:
    scalar_policy_logits.get_shape().assert_has_rank(3)
  for scalar_actions in actions:
    scalar_actions.get_shape().assert_has_rank(2)

  scoped_values = policy_logits + actions + [action_values]
  with tf.name_scope(name, values=scoped_values):
    # Loss for the policy gradient. Doesn't push additional gradients
    # through the action_values.
    policy_gradient_loss_sequence = tf.add_n([
        discrete_policy_gradient(scalar_policy_logits, scalar_actions,
                                 action_values)
        for scalar_policy_logits, scalar_actions in zip(
            policy_logits, actions)
    ])

    return tf.reduce_sum(policy_gradient_loss_sequence, axis=[0],
                         name="policy_gradient_loss")
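A matching sketch for the discrete case, under the same TF1/trfl assumptions as the earlier `policy_gradient_loss` example: rank-3 `[T, B, num_actions]` logits and integer `[T, B]` actions.

import tensorflow.compat.v1 as tf

T, B, num_actions = 5, 2, 3
policy_logits = tf.random.normal([T, B, num_actions])
actions = tf.zeros([T, B], dtype=tf.int32)
action_values = tf.ones([T, B])

# loss has shape [B]: per-sequence totals, summed over time.
loss = discrete_policy_gradient_loss(policy_logits, actions, action_values)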
def test_add_batch(self):
  sample_tree = dict(
      a=[jnp.zeros([]), jnp.zeros([1])],
      b=jnp.zeros([1, 1]),
  )
  batch_size = 2
  out = recurrent.add_batch(sample_tree, batch_size)
  tree.assert_same_structure(sample_tree, out)

  flat_in = tree.flatten(sample_tree)
  flat_out = tree.flatten(out)
  for in_array, out_array in zip(flat_in, flat_out):
    self.assertEqual(out_array.shape[0], batch_size)
    self.assertEqual(out_array.shape[1:], in_array.shape)
def testAssertSameStructure_differentNumElements(self):
  with self.assertRaisesRegex(
      ValueError,
      ("The two structures don't have the same nested structure\\.\n\n"
       "First structure:.*?\n\n"
       "Second structure:.*\n\n"
       "More specifically: Substructure "
       r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
       'substructure "type=str str=spam" is not\n'
       "Entire first structure:\n"
       r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n"
       "Entire second structure:\n"
       r"\(\., \.\)")):
    tree.assert_same_structure(STRUCTURE1, STRUCTURE_DIFFERENT_NUM_ELEMENTS)
def test_sample_variable_length_trajectory(self):
  with self._client.trajectory_writer(10) as writer:
    for i in range(10):
      writer.append([np.ones([3, 3], np.int32) * i])
      writer.create_item(TABLE, 1.0, {
          'last': writer.history[0][-1],
          'all': writer.history[0][:],
      })

  dataset = trajectory_dataset.TrajectoryDataset(
      tf.constant(self._client.server_address),
      table=tf.constant(TABLE),
      dtypes={
          'last': tf.int32,
          'all': tf.int32,
      },
      shapes={
          'last': tf.TensorShape([3, 3]),
          'all': tf.TensorShape([None, 3, 3]),
      },
      max_in_flight_samples_per_worker=1,
      flexible_batch_size=1)

  # Continue sampling until we have observed all the trajectories.
  seen_lengths = set()
  while len(seen_lengths) < 10:
    sample = self._sample_from(dataset, 1)[0]

    # The structure should always be the same.
    tree.assert_same_structure(
        sample,
        replay_sample.ReplaySample(
            info=replay_sample.SampleInfo(
                key=1,
                probability=1.0,
                table_size=10,
                priority=0.5,
            ),
            data={
                'last': None,
                'all': None,
            }))

    seen_lengths.add(sample.data['all'].shape[0])

  self.assertEqual(seen_lengths, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
def testMapStructureWithStrings(self):
  ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  inp_a = ab_tuple(a="foo", b=("bar", "baz"))
  inp_b = ab_tuple(a=2, b=(1, 3))
  out = tree.map_structure(lambda string, repeats: string * repeats,
                           inp_a, inp_b)
  self.assertEqual("foofoo", out.a)
  self.assertEqual("bar", out.b[0])
  self.assertEqual("bazbazbaz", out.b[1])

  nt = ab_tuple(a=("something", "something_else"), b="yet another thing")
  rev_nt = tree.map_structure(lambda x: x[::-1], nt)
  # Check the output is the correct structure, and all strings are reversed.
  tree.assert_same_structure(nt, rev_nt)
  self.assertEqual(nt.a[0][::-1], rev_nt.a[0])
  self.assertEqual(nt.a[1][::-1], rev_nt.a[1])
  self.assertEqual(nt.b[::-1], rev_nt.b)
def test_sync_params(self):
  mock_learner = mock.MagicMock()
  frame_count = 428
  params = self.initial_params
  mock_learner.params_for_actor.return_value = frame_count, params
  traj_len = 10
  actor = actor_lib.Actor(
      agent=self.agent,
      env=self.env,
      learner=mock_learner,
      unroll_length=traj_len,
  )
  received_frame_count, received_params = actor.pull_params()
  self.assertEqual(received_frame_count, frame_count)
  tree.assert_same_structure(received_params, params)
  tree.map_structure(np.testing.assert_array_almost_equal, received_params,
                     params)
def test_converts_spec_lists_into_tuples(self):
  for _ in range(10):
    data = [
        (np.ones([1, 1], dtype=np.int32),),
        [
            np.ones([3, 3], dtype=np.int8),
            (np.ones([2, 2], dtype=np.float64),),
        ],
    ]
    self._client.insert(data, {'dist': 1})

  dataset = reverb_dataset.ReplayDataset(
      self._client.server_address,
      table='dist',
      dtypes=[
          (tf.int32,),
          [
              tf.int8,
              (tf.float64,),
          ],
      ],
      shapes=[
          (tf.TensorShape([1, 1]),),
          [
              tf.TensorShape([3, 3]),
              (tf.TensorShape([2, 2]),),
          ],
      ],
      max_in_flight_samples_per_worker=100)

  got = self._sample_from(dataset, 10)
  for sample in got:
    self.assertIsInstance(sample, replay_sample.ReplaySample)
    self.assertIsInstance(sample.info.key, np.uint64)
    tree.assert_same_structure(sample.data, (
        (None,),
        (
            None,
            (None,),
        ),
    ))
def test_sample_fixed_length_trajectory(self):
  self._populate_replay()

  dataset = trajectory_dataset.TrajectoryDataset(
      tf.constant(self._client.server_address),
      table=tf.constant(TABLE),
      dtypes=DTYPES,
      shapes=SHAPES,
      max_in_flight_samples_per_worker=1,
      flexible_batch_size=1)

  tree.assert_same_structure(
      self._sample_from(dataset, 1)[0],
      replay_sample.ReplaySample(
          info=replay_sample.SampleInfo(
              key=1,
              probability=1.0,
              table_size=10,
              priority=0.5,
          ),
          data=SHAPES))
def test_stack_sequence_fields(self):
  """Tests that `stack_sequence_fields` behaves correctly on nested data."""
  stacked = tree_utils.stack_sequence_fields(TEST_SEQUENCE)

  # Check that the stacked output has the correct structure.
  tree.assert_same_structure(stacked, TEST_SEQUENCE[0])

  # Check that the leaves have the correct array shapes.
  self.assertEqual(stacked['action'].shape, (3, 1))
  self.assertEqual(stacked['observation'][0].shape, (3, 3))
  self.assertEqual(stacked['reward'].shape, (3,))

  # Check values.
  self.assertEqual(stacked['observation'][0].tolist(), [
      [0., 1., 2.],
      [1., 2., 3.],
      [2., 3., 4.],
  ])
  self.assertEqual(stacked['action'].tolist(), [[1.], [0.5], [0.3]])
  self.assertEqual(stacked['reward'].tolist(), [1., 0., 0.5])
def assertConformsToSpec(self, value, spec):
  """Checks that `value` conforms to `spec`.

  Args:
    value: A potentially nested structure of numpy arrays or scalars.
    spec: A potentially nested structure of `specs.Array` instances.
  """
  try:
    tree.assert_same_structure(value, spec)
  except (TypeError, ValueError) as e:
    self.fail(
        "`spec` and `value` have mismatching structures: {}".format(e))

  def validate(path, item, array_spec):
    try:
      array_spec.validate(item)
    except ValueError as e:
      raise ValueError(
          "Value at path {!r} failed validation: {}.".format(
              "/".join(map(str, path)), e))

  tree.map_structure_with_path(validate, value, spec)
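The structure-then-leaves pattern above also works outside a test class. A minimal sketch, assuming `dm_env` provides the `specs.Array` class (as in the method's docstring); the spec and value trees here are made up for illustration.

import numpy as np
import tree
from dm_env import specs

spec = {
    'observation': specs.Array(shape=(3,), dtype=np.float32),
    'reward': specs.Array(shape=(), dtype=np.float32),
}
value = {
    'observation': np.zeros(3, dtype=np.float32),
    'reward': np.float32(0.5),
}

# Structure first, then leaf-by-leaf validation, as in the method above.
tree.assert_same_structure(value, spec)
tree.map_structure(lambda s, v: s.validate(v), spec, value)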
def testAssertSameStructure_listStructureWithAndWithoutTypes(self):
  structure1_list = [[[1, 2], 3], 4, [5, 6]]
  with self.assertRaisesRegex(TypeError,
                              "don't have the same sequence type"):
    tree.assert_same_structure(STRUCTURE1, structure1_list)
  tree.assert_same_structure(STRUCTURE1, STRUCTURE2, check_types=False)
  tree.assert_same_structure(STRUCTURE1, structure1_list, check_types=False)
def test_nested_features(self):
  graph_0 = utils_np.networkxs_to_graphs_tuple(
      [_generate_graph(0, 3), _generate_graph(1, 2)])
  graph_1 = utils_np.networkxs_to_graphs_tuple([_generate_graph(2, 2)])
  graph_2 = utils_np.networkxs_to_graphs_tuple([_generate_graph(3, 3)])

  graphs_ = [
      gr.map(tf.convert_to_tensor, graphs.ALL_FIELDS)
      for gr in [graph_0, graph_1, graph_2]
  ]

  def _create_nested_fields(graphs_tuple):
    new_nodes = ({
        "a": graphs_tuple.nodes,
        "b": [graphs_tuple.nodes + 1, graphs_tuple.nodes + 2]
    },)
    new_edges = [{
        "c": graphs_tuple.edges + 5,
        "d": (graphs_tuple.edges + 1, graphs_tuple.edges + 3),
    }]
    new_globals = []
    return graphs_tuple.replace(
        nodes=new_nodes, edges=new_edges, globals=new_globals)

  graphs_ = [_create_nested_fields(gr) for gr in graphs_]
  concat_graph = utils_tf.concat(graphs_, axis=0)

  actual_nodes = concat_graph.nodes
  actual_edges = concat_graph.edges
  actual_globals = concat_graph.globals

  expected_nodes = tree.map_structure(
      lambda *x: tf.concat(x, axis=0), *[gr.nodes for gr in graphs_])
  expected_edges = tree.map_structure(
      lambda *x: tf.concat(x, axis=0), *[gr.edges for gr in graphs_])
  expected_globals = tree.map_structure(
      lambda *x: tf.concat(x, axis=0), *[gr.globals for gr in graphs_])

  tree.assert_same_structure(expected_nodes, actual_nodes)
  tree.assert_same_structure(expected_edges, actual_edges)
  tree.assert_same_structure(expected_globals, actual_globals)

  tree.map_structure(self.assertAllEqual, expected_nodes, actual_nodes)
  tree.map_structure(self.assertAllEqual, expected_edges, actual_edges)
  tree.map_structure(self.assertAllEqual, expected_globals, actual_globals)

  # Borrowed from `test_concat_first_axis`:
  self.assertAllEqual(np.array([3, 2, 2, 3]), concat_graph.n_node)
  self.assertAllEqual(np.array([2, 1, 1, 2]), concat_graph.n_edge)
  self.assertAllEqual(np.array([1, 2, 4, 6, 8, 9]), concat_graph.senders)
  self.assertAllEqual(np.array([0, 0, 3, 5, 7, 7]), concat_graph.receivers)
def training_iteration(self) -> ResultDict:
    # Trigger asynchronous rollouts on all RolloutWorkers.
    # - Rollout results are sent directly to the correct replay buffer
    #   shards, instead of here (to the driver).
    with self._timers[SAMPLE_TIMER]:
        sample_results = asynchronous_parallel_requests(
            remote_requests_in_flight=self.remote_requests_in_flight,
            actors=self.workers.remote_workers()
            or [self.workers.local_worker()],
            ray_wait_timeout_s=self.config["sample_wait_timeout"],
            max_remote_requests_in_flight_per_actor=2,
            remote_fn=self._sample_and_send_to_buffer,
        )
    # Update sample counters.
    for sample_result in sample_results.values():
        for (env_steps, agent_steps) in sample_result:
            self._counters[NUM_ENV_STEPS_SAMPLED] += env_steps
            self._counters[NUM_AGENT_STEPS_SAMPLED] += agent_steps

    # Trigger asynchronous training update requests on all learning
    # policies.
    with self._timers[LEARN_ON_BATCH_TIMER]:
        pol_actors = []
        args = []
        for pid, pol_actor, repl_actor in self.distributed_learners:
            pol_actors.append(pol_actor)
            args.append([repl_actor, pid])
        train_results = asynchronous_parallel_requests(
            remote_requests_in_flight=self.remote_requests_in_flight,
            actors=pol_actors,
            ray_wait_timeout_s=self.config["learn_wait_timeout"],
            max_remote_requests_in_flight_per_actor=2,
            remote_fn=self._update_policy,
            remote_args=args,
        )
    # Update training counters.
    for train_result in train_results.values():
        for result in train_result:
            if NUM_AGENT_STEPS_TRAINED in result:
                self._counters[NUM_AGENT_STEPS_TRAINED] += result[
                    NUM_AGENT_STEPS_TRAINED]

    # For those policies that have been updated in this iteration
    # (not all policies may have undergone an update as we are
    # requesting updates asynchronously):
    # - Gather train infos.
    # - Update weights to those remote rollout workers that contain
    #   the respective policy.
    with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
        train_infos = {}
        policy_weights = {}
        for pol_actor, policy_results in train_results.items():
            results_have_same_structure = True
            for result1, result2 in zip(policy_results,
                                        policy_results[1:]):
                try:
                    tree.assert_same_structure(result1, result2)
                except (ValueError, TypeError):
                    results_have_same_structure = False
                    break
            if len(policy_results) > 1 and results_have_same_structure:
                policy_result = tree.map_structure(
                    lambda *_args: sum(_args) / len(policy_results),
                    *policy_results)
            else:
                policy_result = policy_results[-1]

            if policy_result:
                pid = self.distributed_learners.get_policy_id(pol_actor)
                train_infos[pid] = policy_result
                policy_weights[pid] = pol_actor.get_weights.remote()

        policy_weights_ref = ray.put(policy_weights)
        global_vars = {
            "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
            "league_builder": self.league_builder.__getstate__(),
        }
        for worker in self.workers.remote_workers():
            worker.set_weights.remote(policy_weights_ref, global_vars)

    return train_infos
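The check-then-average pattern in the weight-sync block can be shown in isolation. A minimal sketch with dm-tree and made-up result dicts; on a structure mismatch it falls back to the most recent result, just as above.

import tree

policy_results = [
    {"loss": 0.5, "grad_norm": 1.0},
    {"loss": 0.25, "grad_norm": 3.0},
]

# Only average when every result shares the same nested structure;
# otherwise fall back to the most recent result.
try:
    for r1, r2 in zip(policy_results, policy_results[1:]):
        tree.assert_same_structure(r1, r2)
    policy_result = tree.map_structure(
        lambda *args: sum(args) / len(args), *policy_results)
except (ValueError, TypeError):
    policy_result = policy_results[-1]

assert policy_result == {"loss": 0.375, "grad_norm": 2.0}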
def __init__(self,
             server_address: Union[str, tf.Tensor],
             table: Union[str, tf.Tensor],
             dtypes: Any,
             shapes: Any,
             max_in_flight_samples_per_worker: int,
             num_workers_per_iterator: int = -1,
             max_samples_per_stream: int = -1,
             rate_limiter_timeout_ms: int = -1,
             flexible_batch_size: int = -1):
  """Constructs a new TimestepDataset.

  Args:
    server_address: Address of gRPC ReverbService.
    table: Probability table to sample from.
    dtypes: Dtypes of the data output. Can be nested.
    shapes: Shapes of the data output. Can be nested.
    max_in_flight_samples_per_worker: The number of samples requested in
      each batch of samples. Higher values give higher throughput, but too
      big values can result in skewed sampling distributions as a large
      number of samples is fetched from a single snapshot of the replay
      (followed by a period of lower activity as the samples are consumed).
      A good rule of thumb is to set this value to 2-3x the batch size used.
    num_workers_per_iterator: (Defaults to -1, i.e. auto selected) The
      number of worker threads to create per dataset iterator. When the
      selected table uses a FIFO sampler (i.e. a queue) then exactly 1
      worker must be used to avoid races causing invalid ordering of items.
      For all other samplers, this value should be roughly equal to the
      number of threads available on the CPU.
    max_samples_per_stream: (Defaults to -1, i.e. auto selected) The maximum
      number of samples to fetch from a stream before a new call is made.
      Keeping this number low ensures that the data is fetched uniformly
      from all servers.
    rate_limiter_timeout_ms: (Defaults to -1: infinite). Timeout (in
      milliseconds) to wait on the rate limiter when sampling from the
      table. If `rate_limiter_timeout_ms >= 0`, this is the timeout passed
      to `Table::Sample` describing how long to wait for the rate limiter
      to allow sampling. The first time that a request times out (across
      any of the workers), the Dataset iterator is closed and the sequence
      is considered finished.
    flexible_batch_size: (Defaults to -1: auto selected) The maximum number
      of items to sample from `Table` with a single call. Values > 1 enable
      `Table::SampleFlexibleBatch` to return more than one item (but no
      more than `flexible_batch_size`) in a single call without releasing
      the table lock iff the rate limiter allows it. NOTE! It is unlikely
      that you need to tune this value yourself. The auto selected value
      should almost always be preferred. Larger `flexible_batch_size`
      values result in a bias towards sampling over inserts. In highly
      overloaded systems this results in higher sample QPS and lower insert
      QPS compared to lower `flexible_batch_size` values.

  Raises:
    ValueError: If `dtypes` and `shapes` don't share the same structure.
    ValueError: If `max_in_flight_samples_per_worker` is not a positive
      integer.
    ValueError: If `num_workers_per_iterator` is not a positive integer
      or -1.
    ValueError: If `max_samples_per_stream` is not a positive integer or -1.
    ValueError: If `rate_limiter_timeout_ms < -1`.
    ValueError: If `flexible_batch_size` is not a positive integer or -1.
  """
  tree.assert_same_structure(dtypes, shapes, False)
  if max_in_flight_samples_per_worker < 1:
    raise ValueError(
        'max_in_flight_samples_per_worker (%d) must be a positive integer'
        % max_in_flight_samples_per_worker)
  if num_workers_per_iterator < 1 and num_workers_per_iterator != -1:
    raise ValueError(
        'num_workers_per_iterator (%d) must be a positive integer or -1' %
        num_workers_per_iterator)
  if max_samples_per_stream < 1 and max_samples_per_stream != -1:
    raise ValueError(
        'max_samples_per_stream (%d) must be a positive integer or -1' %
        max_samples_per_stream)
  if rate_limiter_timeout_ms < -1:
    raise ValueError(
        'rate_limiter_timeout_ms (%d) must be an integer >= -1' %
        rate_limiter_timeout_ms)
  if flexible_batch_size < 1 and flexible_batch_size != -1:
    raise ValueError(
        'flexible_batch_size (%d) must be a positive integer or -1' %
        flexible_batch_size)

  # Add the info fields (all scalars).
  dtypes = replay_sample.ReplaySample(
      info=replay_sample.SampleInfo.tf_dtypes(), data=dtypes)
  shapes = replay_sample.ReplaySample(
      info=replay_sample.SampleInfo(
          key=tf.TensorShape([]),
          probability=tf.TensorShape([]),
          table_size=tf.TensorShape([]),
          priority=tf.TensorShape([])),
      data=shapes)

  # The tf.data API doesn't fully support lists so we convert all uses of
  # lists into tuples.
  dtypes = _convert_lists_to_tuples(dtypes)
  shapes = _convert_lists_to_tuples(shapes)

  self._server_address = server_address
  self._table = table
  self._dtypes = dtypes
  self._shapes = shapes
  self._max_in_flight_samples_per_worker = max_in_flight_samples_per_worker
  self._num_workers_per_iterator = num_workers_per_iterator
  self._max_samples_per_stream = max_samples_per_stream
  self._rate_limiter_timeout_ms = rate_limiter_timeout_ms
  self._flexible_batch_size = flexible_batch_size

  if _is_tf1_runtime():
    # Disabling to avoid errors given the different tf.data.Dataset init
    # args between v1 and v2 APIs.
    # pytype: disable=wrong-arg-count
    super().__init__()
  else:
    # DatasetV2 requires the dataset as a variant tensor during init.
    super().__init__(self._as_variant_tensor())
def _validate_spec(spec: types.NestedSpec, value: types.NestedArray):
  """Validates a value against a potentially nested spec."""
  tree.assert_same_structure(value, spec)
  tree.map_structure(lambda s, v: s.validate(v), spec, value)
def append(self, data: Any, *, partial_step: bool = False):
  """Columnwise append of data leaf nodes to internal buffers.

  If `data` includes fields or sub structures which haven't been present in
  any previous calls then the types and shapes of the new fields are
  extracted and used to validate future `append` calls. The structure of
  `history` is also updated to include the union of the structure across
  all `append` calls.

  When new fields are added after the first step then the newly created
  history field will be filled with `None` in all preceding positions. This
  ensures equal indexing across columns. That is, `a[i]` and `b[i]`
  reference the same step in the sequence even if `b` was first observed
  after `a` had already been seen.

  It is possible to create a "step" using more than one `append` call by
  setting the `partial_step` flag. Partial steps can be used when some
  parts of the step become available only as a result of inserting (and
  learning from) trajectories that include the fields available first (e.g.
  learn from the SARS trajectory to select the next action in an on-policy
  agent). In the final `append` call of the step, `partial_step` must be
  set to False. Failing to "close" the partial step will result in an
  error, as the same field must NOT be provided more than once in the same
  step.

  Args:
    data: The (possibly nested) structure to make available for new items
      to reference.
    partial_step: If `True` then the step is not considered "done" with
      this call. See above for more details. Defaults to `False`.

  Returns:
    References to the data structured just like provided `data`.

  Raises:
    ValueError: If the same column is provided more than once in the same
      step.
  """
  # Unless it is the first step, check that the structure is the same.
  if self._structure is None:
    self._update_structure(tree.map_structure(lambda _: None, data))

  try:
    tree.assert_same_structure(data, self._structure, True)
    expanded_data = data
  except ValueError:
    try:
      # If `data` is a subset of the full spec then we can simply fill in
      # the gaps with None.
      expanded_data = _tree_merge_into(source=data, target=self._structure)
    except ValueError:
      # `data` contains fields which haven't been observed before so we
      # need to expand the spec using the union of the history and `data`.
      self._update_structure(
          _tree_union(self._structure,
                      tree.map_structure(lambda x: None, data)))

      # Now that the structure has been updated to include all the fields
      # in `data` we are able to expand `data` to the full structure. Note
      # that if `data` is a superset of the previous history structure then
      # this "expansion" is just a noop.
      expanded_data = _tree_merge_into(data, self._structure)

  # Use our custom mapping to flatten the expanded structure into columns.
  flat_column_data = self._flatten(expanded_data)

  # If the last step is still open then verify that already populated
  # columns are None in the new `data`.
  if self._last_step_is_open:
    for i, (column, column_data) in enumerate(
        zip(self._column_history, flat_column_data)):
      if column_data is None or column.can_set_last:
        continue

      raise ValueError(
          f'Field {self._get_path_for_column_index(i)} has already been '
          f'set in the active step by a previous (partial) append call and '
          f'thus must be omitted or set to None but got: {column_data}')

  # Flatten the data and pass it to the C++ writer for column wise append.
  # All columns where data is provided (i.e. not None) will return a
  # reference to the data (`pybind.WeakCellRef`) which is used to define
  # trajectories for `create_item`. The columns which did not receive a
  # value (i.e. None) will return None.
  if partial_step:
    flat_column_data_references = self._writer.AppendPartial(
        flat_column_data)
  else:
    flat_column_data_references = self._writer.Append(flat_column_data)

  # Append references to respective columns. Note that we use the expanded
  # structure in order to populate the columns missing from the data with
  # None.
  for column, data_reference in zip(self._column_history,
                                    flat_column_data_references):
    # If the last step is still open (i.e. `partial_step` was set) then we
    # populate that step instead of creating a new one.
    if not self._last_step_is_open:
      column.append(data_reference)
    elif data_reference is not None:
      column.set_last(data_reference)

  # Save the flag so the next `append` call either populates the same step
  # or begins a new step.
  self._last_step_is_open = partial_step

  # Unpack the column data into the expanded structure.
  expanded_structured_data_references = self._unflatten(
      flat_column_data_references)

  # Return the references structured in the same way as `data`. If only a
  # subset of the fields were present in the input data then only those
  # fields will exist in the output.
  return _tree_filter(expanded_structured_data_references, data)
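A hedged sketch of the partial-step flow described in the docstring, assuming a running Reverb server; the address, table name and field names are illustrative.

import numpy as np
import reverb

client = reverb.Client('localhost:8000')  # assumed server address

with client.trajectory_writer(num_keep_alive_refs=2) as writer:
  # First half of the step: the observation and the action taken.
  writer.append({'obs': np.zeros([3]), 'action': 1}, partial_step=True)

  # Second half, filled in once the environment responds; this call closes
  # the step (partial_step defaults to False).
  writer.append({'reward': 0.5})

  # Reference the accumulated columns to build an item.
  writer.create_item(
      table='my_table',  # assumed table name
      priority=1.0,
      trajectory={
          'obs': writer.history['obs'][:],
          'action': writer.history['action'][:],
          'reward': writer.history['reward'][:],
      })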
def update(state: RunningStatisticsState,
           batch: types.NestedArray,
           *,
           config: NestStatisticsConfig = NestStatisticsConfig(),
           weights: Optional[jnp.ndarray] = None,
           std_min_value: float = 1e-6,
           std_max_value: float = 1e6,
           pmap_axis_name: Optional[str] = None,
           validate_shapes: bool = True) -> RunningStatisticsState:
  """Updates the running statistics with the given batch of data.

  Note: data batch and state elements (mean, etc.) must have the same
  structure.

  Note: by default will use int32 for counts and float32 for accumulated
  variance. This results in an integer overflow after 2^31 data points and
  degrading precision after 2^24 batch updates or even earlier if variance
  updates have large dynamic range. To improve precision, consider setting
  jax_enable_x64 to True, see
  https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision

  Arguments:
    state: The running statistics before the update.
    batch: The data to be used to update the running statistics.
    config: The config that specifies for which leaves of the nested
      structure the running statistics should be computed.
    weights: Weights of the batch data. Should match the batch dimensions.
      Passing a weight of 2. should be equivalent to updating on the
      corresponding data point twice.
    std_min_value: Minimum value for the standard deviation.
    std_max_value: Maximum value for the standard deviation.
    pmap_axis_name: Name of the pmapped axis, if any.
    validate_shapes: If true, the shapes of all leaves of the batch will be
      validated. Enabled by default. Doesn't impact performance when jitted.

  Returns:
    Updated running statistics.
  """
  # We require exactly the same structure to avoid issues when flattened
  # batch and state have different order of elements.
  tree.assert_same_structure(batch, state.mean)
  batch_shape = tree.flatten(batch)[0].shape
  # We assume the batch dimensions always go first.
  batch_dims = batch_shape[:len(batch_shape) -
                           tree.flatten(state.mean)[0].ndim]
  batch_axis = range(len(batch_dims))
  if weights is None:
    step_increment = np.prod(batch_dims)
  else:
    step_increment = jnp.sum(weights)
  if pmap_axis_name is not None:
    step_increment = jax.lax.psum(step_increment, axis_name=pmap_axis_name)
  count = state.count + step_increment

  # Validation is important. If the shapes don't match exactly, but are
  # compatible, arrays will be silently broadcasted resulting in incorrect
  # statistics.
  if validate_shapes:
    if weights is not None:
      if weights.shape != batch_dims:
        raise ValueError(f'{weights.shape} != {batch_dims}')
    _validate_batch_shapes(batch, state.mean, batch_dims)

  def _compute_node_statistics(
      path: Path, mean: jnp.ndarray, summed_variance: jnp.ndarray,
      batch: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    assert isinstance(mean, jnp.ndarray), type(mean)
    assert isinstance(summed_variance, jnp.ndarray), type(summed_variance)
    if not _is_path_included(config, path):
      # Return unchanged.
      return mean, summed_variance
    # The mean and the sum of past variances are updated with Welford's
    # algorithm using batches (see https://stackoverflow.com/q/56402955).
    diff_to_old_mean = batch - mean
    if weights is not None:
      expanded_weights = jnp.reshape(
          weights,
          list(weights.shape) + [1] * (batch.ndim - weights.ndim))
      diff_to_old_mean = diff_to_old_mean * expanded_weights
    mean_update = jnp.sum(diff_to_old_mean, axis=batch_axis) / count
    if pmap_axis_name is not None:
      mean_update = jax.lax.psum(mean_update, axis_name=pmap_axis_name)
    mean = mean + mean_update

    diff_to_new_mean = batch - mean
    variance_update = diff_to_old_mean * diff_to_new_mean
    variance_update = jnp.sum(variance_update, axis=batch_axis)
    if pmap_axis_name is not None:
      variance_update = jax.lax.psum(
          variance_update, axis_name=pmap_axis_name)
    summed_variance = summed_variance + variance_update
    return mean, summed_variance

  updated_stats = tree_utils.fast_map_structure_with_path(
      _compute_node_statistics, state.mean, state.summed_variance, batch)
  # map_structure_up_to is slow, so shortcut if we know the input is not
  # structured.
  if isinstance(state.mean, jnp.ndarray):
    mean, summed_variance = updated_stats
  else:
    # Reshape the updated stats from `nest(mean, summed_variance)` to
    # `nest(mean), nest(summed_variance)`.
    mean, summed_variance = [
        tree.map_structure_up_to(state.mean, lambda s, i=idx: s[i],
                                 updated_stats) for idx in range(2)
    ]

  def compute_std(path: Path, summed_variance: jnp.ndarray,
                  std: jnp.ndarray) -> jnp.ndarray:
    assert isinstance(summed_variance, jnp.ndarray)
    if not _is_path_included(config, path):
      return std
    # Summed variance can get negative due to rounding errors.
    summed_variance = jnp.maximum(summed_variance, 0)
    std = jnp.sqrt(summed_variance / count)
    std = jnp.clip(std, std_min_value, std_max_value)
    return std

  std = tree_utils.fast_map_structure_with_path(compute_std,
                                                summed_variance, state.std)

  return RunningStatisticsState(
      count=count, mean=mean, summed_variance=summed_variance, std=std)
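The batched Welford update in `_compute_node_statistics` can be sanity-checked against the direct formulas on a tiny example. A self-contained numpy sketch, independent of the JAX code above.

import numpy as np

count, mean, summed_variance = 0.0, 0.0, 0.0
for batch in [np.array([1., 2.]), np.array([3., 4., 5.])]:
    count += batch.size
    diff_to_old_mean = batch - mean
    mean += np.sum(diff_to_old_mean) / count
    diff_to_new_mean = batch - mean
    summed_variance += np.sum(diff_to_old_mean * diff_to_new_mean)

# Matches the direct (population) statistics over all 5 points.
data = np.array([1., 2., 3., 4., 5.])
assert np.isclose(mean, data.mean())
assert np.isclose(np.sqrt(summed_variance / count), data.std())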
def training_step(self) -> ResultDict:
    # Trigger asynchronous rollouts on all RolloutWorkers.
    # - Rollout results are sent directly to the correct replay buffer
    #   shards, instead of here (to the driver).
    with self._timers[SAMPLE_TIMER]:
        # If there are no remote workers (e.g. num_workers=0), sample from
        # the local worker directly.
        if not self.workers.remote_workers():
            worker = self.workers.local_worker()
            statistics = worker.apply(self._sample_and_send_to_buffer)
            sample_results = {worker: [statistics]}
        else:
            self._sampling_actor_manager.call_on_all_available(
                self._sample_and_send_to_buffer)
            sample_results = self._sampling_actor_manager.get_ready()
    # Update sample counters.
    for sample_result in sample_results.values():
        for (env_steps, agent_steps) in sample_result:
            self._counters[NUM_ENV_STEPS_SAMPLED] += env_steps
            self._counters[NUM_AGENT_STEPS_SAMPLED] += agent_steps

    # Trigger asynchronous training update requests on all learning
    # policies.
    with self._timers[LEARN_ON_BATCH_TIMER]:
        for pid, pol_actor, repl_actor in self.distributed_learners:
            if pol_actor not in self._learner_worker_manager.workers:
                self._learner_worker_manager.add_workers(pol_actor)
            self._learner_worker_manager.call(
                self._update_policy,
                actor=pol_actor,
                fn_args=[repl_actor, pid])
        train_results = self._learner_worker_manager.get_ready()
    # Update training counters.
    for train_result in train_results.values():
        for result in train_result:
            if NUM_AGENT_STEPS_TRAINED in result:
                self._counters[NUM_AGENT_STEPS_TRAINED] += result[
                    NUM_AGENT_STEPS_TRAINED]

    # For those policies that have been updated in this iteration
    # (not all policies may have undergone an update as we are
    # requesting updates asynchronously):
    # - Gather train infos.
    # - Update weights to those remote rollout workers that contain
    #   the respective policy.
    with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
        train_infos = {}
        policy_weights = {}
        for pol_actor, policy_results in train_results.items():
            results_have_same_structure = True
            for result1, result2 in zip(policy_results,
                                        policy_results[1:]):
                try:
                    tree.assert_same_structure(result1, result2)
                except (ValueError, TypeError):
                    results_have_same_structure = False
                    break
            if len(policy_results) > 1 and results_have_same_structure:
                policy_result = tree.map_structure(
                    lambda *_args: sum(_args) / len(policy_results),
                    *policy_results)
            else:
                policy_result = policy_results[-1]

            if policy_result:
                pid = self.distributed_learners.get_policy_id(pol_actor)
                train_infos[pid] = policy_result
                policy_weights[pid] = pol_actor.get_weights.remote()

        policy_weights_ref = ray.put(policy_weights)
        global_vars = {
            "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
            "league_builder": self.league_builder.__getstate__(),
        }
        for worker in self.workers.remote_workers():
            worker.set_weights.remote(policy_weights_ref, global_vars)

    return train_infos