def convert_to_non_torch_type(stats):
    """Converts values in `stats` to non-Tensor numpy or python types.

    Args:
        stats (any): Any (possibly nested) struct, the values in which will be
            converted and returned as a new struct with all torch tensors
            being converted to numpy types.

    Returns:
        Any: A new struct with the same structure as `stats`, but with all
            values converted to non-torch Tensor types.
    """

    # The mapping function used to numpyize torch Tensors.
    def mapping(item):
        if isinstance(item, torch.Tensor):
            return item.cpu().item() if len(item.size()) == 0 else \
                item.detach().cpu().numpy()
        else:
            return item

    return tree.map_structure(mapping, stats)

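# A minimal usage sketch for `convert_to_non_torch_type` (added here for
# illustration; not part of the original module). It assumes `torch`, `numpy`
# (as `np`), and `tree` are importable as in the function above.
def _demo_convert_to_non_torch_type():
    import numpy as np
    import torch

    stats = {"loss": torch.tensor(0.5), "grads": (torch.ones(2, 3),)}
    out = convert_to_non_torch_type(stats)
    assert isinstance(out["loss"], float)  # 0-d tensor -> Python scalar.
    assert isinstance(out["grads"][0], np.ndarray)  # nd tensor -> np.ndarray.
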
def true_fn():
    batch_size = 1
    req = force_tuple(
        action_dist.required_model_output_shape(
            self.action_space, self.model.model_config))
    # Add a batch dimension?
    if len(action_dist.inputs.shape) == len(req) + 1:
        batch_size = tf.shape(action_dist.inputs)[0]

    # Function to produce random samples from primitive space
    # components: (Multi)Discrete or Box.
    def random_component(component):
        if isinstance(component, Discrete):
            return tf.random.uniform(
                shape=(batch_size, ) + component.shape,
                maxval=component.n,
                dtype=component.dtype)
        elif isinstance(component, MultiDiscrete):
            return tf.random.uniform(
                shape=(batch_size, ) + component.shape,
                maxval=component.nvec,
                dtype=component.dtype)
        elif isinstance(component, Box):
            if component.bounded_above.all() and \
                    component.bounded_below.all():
                return tf.random.uniform(
                    shape=(batch_size, ) + component.shape,
                    minval=component.low,
                    maxval=component.high,
                    dtype=component.dtype)
            else:
                return tf.random.normal(
                    shape=(batch_size, ) + component.shape,
                    dtype=component.dtype)

    actions = tree.map_structure(random_component,
                                 self.action_space_struct)
    return actions

def _write_last(self):
    # Maybe determine the delta to the next time we would write a sequence.
    if self._end_of_episode_behavior in (EndBehavior.TRUNCATE,
                                         EndBehavior.ZERO_PAD):
        delta = self._sequence_length - self._writer.episode_steps
        if delta < 0:
            delta = (self._period + delta) % self._period

    # Handle various end-of-episode cases.
    if self._end_of_episode_behavior is EndBehavior.CONTINUE:
        self._maybe_create_item(self._sequence_length, end_of_episode=True)

    elif self._end_of_episode_behavior is EndBehavior.WRITE:
        # Drop episodes that are too short.
        if self._writer.episode_steps < self._sequence_length:
            return
        self._maybe_create_item(
            self._sequence_length, end_of_episode=True, force=True)

    elif self._end_of_episode_behavior is EndBehavior.TRUNCATE:
        self._maybe_create_item(
            self._sequence_length - delta, end_of_episode=True, force=True)

    elif self._end_of_episode_behavior is EndBehavior.ZERO_PAD:
        zero_step = tree.map_structure(
            lambda x: np.zeros_like(x[-2].numpy()), self._writer.history)
        for _ in range(delta):
            self._writer.append(zero_step)

        self._maybe_create_item(
            self._sequence_length, end_of_episode=True, force=True)

    else:
        raise ValueError(
            f'Unhandled end of episode behavior: {self._end_of_episode_behavior}.'
            ' This should never happen, please contact Acme dev team.')

def _replicated_step(self):
    # Update target network.
    online_variables = (
        *self._observation_network.variables,
        *self._critic_network.variables,
        *self._policy_network.variables,
    )
    target_variables = (
        *self._target_observation_network.variables,
        *self._target_critic_network.variables,
        *self._target_policy_network.variables,
    )

    # Make online -> target network update ops.
    if tf.math.mod(self._num_steps, self._target_update_period) == 0:
        for src, dest in zip(online_variables, target_variables):
            dest.assign(src)
    self._num_steps.assign_add(1)

    # Get data from replay (dropping extras if any). Note there is no
    # extra data here because we do not insert any into Reverb.
    sample = next(self._iterator)

    # This mirrors the structure of the fetches returned by self._step(),
    # but the Tensors are replaced with replicated Tensors, one per accelerator.
    replicated_fetches = self._replicator.run(self._step, args=(sample,))

    def reduce_mean_over_replicas(replicated_value):
        """Averages a replicated_value across replicas."""
        # The "axis=None" arg means reduce across replicas, not internal axes.
        return self._replicator.reduce(
            reduce_op=tf.distribute.ReduceOp.MEAN,
            value=replicated_value,
            axis=None)

    fetches = tree.map_structure(reduce_mean_over_replicas, replicated_fetches)

    return fetches

def add(self,
        action: types.NestedArray,
        next_timestep: dm_env.TimeStep,
        extras: types.NestedArray = ()):
    """Record an action and the following timestep."""

    try:
        history = self._writer.history
    except RuntimeError:
        raise ValueError('adder.add_first must be called before adder.add.')

    # Add the timestep to the buffer.
    current_step = dict(
        # Observation was passed at the previous add call.
        action=action,
        reward=next_timestep.reward,
        discount=next_timestep.discount,
        # Start of episode indicator was passed at the previous add call.
        **({'extras': extras} if extras else {}))
    self._writer.append(current_step)

    # Record the next observation and write.
    self._writer.append(
        dict(
            observation=next_timestep.observation,
            start_of_episode=next_timestep.first()),
        partial_step=True)
    self._write()

    if next_timestep.last():
        # Complete the row by appending zeros to remaining open fields.
        # TODO(b/183945808): remove this when fields are no longer expected to
        # be of equal length on the learner side.
        dummy_step = tree.map_structure(np.zeros_like, current_step)
        self._writer.append(dummy_step)
        self._write_last()
        self.reset()

def __call__(self, x: TensorStructType, update: bool = True) -> \
        TensorStructType:
    if self.no_preprocessor:
        x = tree.map_structure(lambda x_: np.asarray(x_), x)
    else:
        x = np.asarray(x)

    def _helper(x, rs, buffer, shape):
        # Discrete|MultiDiscrete spaces -> No normalization.
        if shape is None:
            return x

        # Keep dtype as is throughout this filter.
        orig_dtype = x.dtype

        if update:
            if len(x.shape) == len(rs.shape) + 1:
                # The vectorized case.
                for i in range(x.shape[0]):
                    rs.push(x[i])
                    buffer.push(x[i])
            else:
                # The unvectorized case.
                rs.push(x)
                buffer.push(x)
        if self.demean:
            x = x - rs.mean
        if self.destd:
            x = x / (rs.std + SMALL_NUMBER)
        if self.clip:
            x = np.clip(x, -self.clip, self.clip)
        return x.astype(orig_dtype)

    if self.no_preprocessor:
        return tree.map_structure_up_to(x, _helper, x, self.rs, self.buffer,
                                        self.shape)
    else:
        return _helper(x, self.rs, self.buffer, self.shape)

def test_ema_on_changing_data(self):

    def f():
        return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

    init_fn, _ = transform.transform(f)
    params = init_fn(random.PRNGKey(428))

    def g(x):
        return moving_averages.EMAParamsTree(0.2)(x)

    init_fn, apply_fn = transform.without_apply_rng(
        transform.transform_with_state(g))
    _, params_state = init_fn(None, params)
    params, params_state = apply_fn(None, params_state, params)
    # Let's modify our params.
    changed_params = tree.map_structure(lambda t: 2. * t, params)
    ema_params, params_state = apply_fn(None, params_state, changed_params)

    # ema_params should be different from changed_params!
    tree.assert_same_structure(changed_params, ema_params)
    for p1, p2 in zip(tree.flatten(params), tree.flatten(ema_params)):
        self.assertEqual(p1.shape, p2.shape)
        with self.assertRaisesRegex(AssertionError, "Not equal to tolerance"):
            np.testing.assert_allclose(p1, p2, atol=1e-6)

def test_make_dataset_with_batch_size(self):
    batch_size = 4
    environment = fakes.ContinuousEnvironment()
    environment_spec = specs.make_environment_spec(environment)
    dataset = reverb_dataset.make_dataset(
        client=self.tf_client,
        environment_spec=environment_spec,
        batch_size=batch_size)

    def make_tensor_spec(spec):
        return tf.TensorSpec(shape=(None,) + spec.shape, dtype=spec.dtype)

    expected_spec = tree.map_structure(make_tensor_spec, environment_spec)

    expected_spec = adders.Step(
        observation=expected_spec.observations,
        action=expected_spec.actions,
        reward=expected_spec.rewards,
        discount=expected_spec.discounts,
        start_of_episode=specs.Array(shape=(batch_size,), dtype=bool),
        extras=())

    self.assertTrue(_check_specs(expected_spec, dataset.element_spec.data))

def create_no_nodefeatures_graphs(features, topology):
    comp_to_id = {'R': 1, 'L': 2, 'C': 3, 'V': 4}
    comp_to_value = {'R': 50, 'L': 1e-3, 'C': 1e-6, 'V': 0}
    graphs_list = []
    for circuit in range(len(features.keys())):
        circuit_features = features["circuit_{}".format(circuit + 1)]
        circuit_topology = topology["circuit_{}".format(circuit + 1)]
        nodes = []
        edges = []
        senders = []
        receivers = []
        for sender, receiver in circuit_topology.items():
            for i in range(len(receiver)):
                senders.append(float(sender))
                receivers.append(float(receiver[i][1]))
                edges.append([
                    float(comp_to_id[receiver[i][0][0]]),
                    float(comp_to_value[receiver[i][0][0]])
                ])
        # One feature vector per node, zero-padded to length 7.
        for i in range(len(circuit_topology.keys())):
            nodes.append([0.0])
        for i in range(len(nodes)):
            while len(nodes[i]) < 7:
                nodes[i].append(0.0)
        graph = {
            "nodes": nodes,
            "edges": edges,
            "senders": senders,
            "receivers": receivers,
            "globals": [0.0, 0.0, 0.0]
        }
        graphs_list.append(graph)
    graphs_tuple = utils_np.data_dicts_to_graphs_tuple(graphs_list)
    graphs_tuple = tree.map_structure(
        lambda x: tf.constant(x) if x is not None else None, graphs_tuple)
    return graphs_tuple

def last_action_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
    """Returns the last action for the specified AgentID, or zeros.

    The "last" action is the most recent one taken by the agent.

    Args:
        agent_id: The agent's ID to get the last action for.

    Returns:
        Last action the specified AgentID has executed. Zeros in case the
        agent has never performed any actions in the episode.
    """
    policy_id = self.policy_for(agent_id)
    policy = self.policy_map[policy_id]

    # Agent has already taken at least one action in the episode.
    if agent_id in self._agent_to_last_action:
        if policy.config.get("_disable_action_flattening"):
            return self._agent_to_last_action[agent_id]
        else:
            return flatten_to_single_ndarray(
                self._agent_to_last_action[agent_id])
    # Agent has not acted yet, return all zeros.
    else:
        if policy.config.get("_disable_action_flattening"):
            return tree.map_structure(
                lambda s: np.zeros_like(s.sample(), s.dtype)
                if hasattr(s, "dtype") else np.zeros_like(s.sample()),
                policy.action_space_struct,
            )
        else:
            flat = flatten_to_single_ndarray(policy.action_space.sample())
            if hasattr(policy.action_space, "dtype"):
                return np.zeros_like(flat, dtype=policy.action_space.dtype)
            return np.zeros_like(flat)

def _get_input_signature(module: snt.Module):
    """Get module input signature.

    Works even if the module with the signature is wrapped in snt.Sequential
    or snt.DeepRNN.

    Args:
        module: the module whose input signature we need to get. The module
            has to either have an input_signature itself (i.e. you have to run
            create_variables on the module), or it has to be a module (with
            input_signature) wrapped in (one or multiple) snt.Sequential or
            snt.DeepRNNs.

    Returns:
        Input signature of the module.

    Raises:
        ValueError: if the input signature cannot be determined.
    """
    if hasattr(module, '_input_signature'):
        return module._input_signature  # pylint: disable=protected-access

    if isinstance(module, snt.Sequential):
        first_layer = module._layers[0]  # pylint: disable=protected-access
        return _get_input_signature(first_layer)

    if isinstance(module, snt.DeepRNN):
        first_layer = module._layers[0]  # pylint: disable=protected-access
        input_signature = _get_input_signature(first_layer)

        # Wrapping a module in DeepRNN changes its state shape, so we need to
        # bring it up to date.
        state = module.initial_state(1)
        input_signature[-1] = tree.map_structure(
            lambda t: tf.TensorSpec((None,) + t.shape[1:], t.dtype), state)

        return input_signature

    # If we get here we can't determine the input signature. So give up.
    raise ValueError('module instance has no input_signature attribute; run '
                     'create_variables to add this annotation.')

def convert_element_to_space_type(element: Any, sampled_element: Any) -> Any:
    """Convert all the components of the element to match the space dtypes.

    Args:
        element: The element to be converted.
        sampled_element: An element sampled from a space to be matched to.

    Returns:
        The input element, but with all its components converted to match
        the space dtypes.
    """

    def map_(elem, s):
        if isinstance(s, np.ndarray):
            if not isinstance(elem, np.ndarray):
                assert isinstance(
                    elem, (float, int)
                ), f"ERROR: `elem` ({elem}) must be np.array, float or int!"
                if s.shape == ():
                    elem = np.array(elem, dtype=s.dtype)
                else:
                    raise ValueError(
                        "Element should be of type np.ndarray but is instead "
                        "of type {}".format(type(elem)))
            elif s.dtype != elem.dtype:
                elem = elem.astype(s.dtype)

        elif isinstance(s, int):
            if isinstance(elem, float) and elem.is_integer():
                elem = int(elem)

        return elem

    return tree.map_structure(map_, element, sampled_element, check_types=False)

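# Illustrative-only sketch of `convert_element_to_space_type` (not part of the
# original module): a float64 array matched against a float32 Box sample is
# cast back to the space's dtype.
def _demo_convert_element_to_space_type():
    import gym
    import numpy as np

    space = gym.spaces.Box(low=0.0, high=1.0, shape=(2,), dtype=np.float32)
    sampled = space.sample()  # np.ndarray with dtype float32.
    element = np.zeros(2, dtype=np.float64)
    converted = convert_element_to_space_type(element, sampled)
    assert converted.dtype == np.float32  # Cast to match the space sample.
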
def logp(self, x):
    if isinstance(x, np.ndarray):
        x = torch.Tensor(x)
    # Single tensor input (all merged).
    if isinstance(x, torch.Tensor):
        split_indices = []
        for dist in self.flat_child_distributions:
            if isinstance(dist, TorchCategorical):
                split_indices.append(1)
            elif isinstance(dist, TorchMultiCategorical) and \
                    dist.action_space is not None:
                split_indices.append(int(np.prod(dist.action_space.shape)))
            else:
                sample = dist.sample()
                # Cover Box(shape=()) case.
                if len(sample.shape) == 1:
                    split_indices.append(1)
                else:
                    split_indices.append(sample.size()[1])
        split_x = list(torch.split(x, split_indices, dim=1))
    # Structured or flattened (by single action component) input.
    else:
        split_x = tree.flatten(x)

    def map_(val, dist):
        # Remove extra categorical dimension.
        if isinstance(dist, TorchCategorical):
            val = (torch.squeeze(val, dim=-1)
                   if len(val.shape) > 1 else val).int()
        return dist.logp(val)

    # Remove extra categorical dimension and take the logp of each
    # component.
    flat_logps = tree.map_structure(map_, split_x,
                                    self.flat_child_distributions)

    return functools.reduce(lambda a, b: a + b, flat_logps)

def _write_last(self):
    # Create a final step.
    final_step = utils.final_step_like(self._buffer[0], self._next_observation)

    # Append the final step.
    self._buffer.append(final_step)
    self._writer.append(final_step)
    self._step += 1

    # Determine the delta to the next time we would write a sequence.
    first_write = self._step <= self._max_sequence_length
    if first_write:
        delta = self._max_sequence_length - self._step
    else:
        delta = (self._period -
                 (self._step - self._max_sequence_length)) % self._period

    # Bump up to the position where we will write a sequence.
    self._step += delta

    if self._pad_end_of_episode:
        zero_step = tree.map_structure(utils.zeros_like, final_step)

        # Pad with zeros to get a full sequence.
        for _ in range(delta):
            self._buffer.append(zero_step)
            self._writer.append(zero_step)
    elif not first_write:
        # Pop items from the buffer to get a truncated sequence.
        # Note: this is consistent with the padding loop above, since adding
        # zero steps pops the left-most elements. Here we just pop without
        # padding.
        for _ in range(delta):
            self._buffer.popleft()

    # Write priorities for the sequence.
    self._maybe_add_priorities()

def unsquash_action(action, action_space_struct):
    """Unsquashes all components in `action` according to the given Space.

    Inverse of `normalize_action()`. Useful for mapping policy action outputs
    (normalized between -1.0 and 1.0) to an env's action space. Unsquashing
    results in cont. action component values between the given Space's bounds
    (`low` and `high`). This only applies to Box components within the action
    space, whose dtype is float32 or float64.

    Args:
        action (Any): The action to be unsquashed. This could be any complex
            action, e.g. a dict or tuple.
        action_space_struct (Any): The action space struct,
            e.g. `{"a": Box()}` for a space: Dict({"a": Box()}).

    Returns:
        Any: The input action, but unsquashed, according to the space's
            bounds. An unsquashed action is ready to be sent to the
            environment (`BaseEnv.send_actions([unsquashed actions])`).
    """

    def map_(a, s):
        if (isinstance(s, gym.spaces.Box) and np.all(s.bounded_below)
                and np.all(s.bounded_above)):
            if s.dtype == np.float32 or s.dtype == np.float64:
                # Assuming values are roughly between -1.0 and 1.0 ->
                # unsquash them to the given bounds.
                a = s.low + (a + 1.0) * (s.high - s.low) / 2.0
                # Clip to given bounds, just in case the squashed values were
                # outside [-1.0, 1.0].
                a = np.clip(a, s.low, s.high)
            elif np.issubdtype(s.dtype, np.integer):
                # For Categorical and MultiCategorical actions, shift the
                # selection into the proper range.
                a = s.low + a
        return a

    return tree.map_structure(map_, action, action_space_struct)

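# Worked example for `unsquash_action` (illustrative; the space below is
# hypothetical): a policy output of 0.0 in [-1.0, 1.0] maps to the midpoint
# of a Box(0.0, 10.0) component.
def _demo_unsquash_action():
    import gym
    import numpy as np

    space_struct = {
        "a": gym.spaces.Box(low=0.0, high=10.0, shape=(1,), dtype=np.float32)
    }
    squashed = {"a": np.array([0.0], dtype=np.float32)}
    unsquashed = unsquash_action(squashed, space_struct)
    # low + (0.0 + 1.0) * (high - low) / 2.0 == 5.0
    assert np.allclose(unsquashed["a"], [5.0])
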
def _build_sarsa_example(sequences):
    """Convert raw sequences into a Reverb n-step SARSA sample."""

    o_tm1 = tree.map_structure(lambda t: t[0], sequences['observation'])
    o_t = tree.map_structure(lambda t: t[1], sequences['observation'])
    a_tm1 = tree.map_structure(lambda t: t[0], sequences['action'])
    a_t = tree.map_structure(lambda t: t[1], sequences['action'])
    r_t = tree.map_structure(lambda t: t[0], sequences['reward'])
    p_t = tree.map_structure(
        lambda d, st: d[0] * tf.cast(st[1] != dm_env.StepType.LAST, d.dtype),
        sequences['discount'], sequences['step_type'])

    info = reverb.SampleInfo(
        key=tf.constant(0, tf.uint64),
        probability=tf.constant(1.0, tf.float64),
        table_size=tf.constant(0, tf.int64),
        priority=tf.constant(1.0, tf.float64))
    return reverb.ReplaySample(
        info=info, data=(o_tm1, a_tm1, r_t, p_t, o_t, a_t))

def save_to_path(ckpt_dir: str, state: CheckpointState):
    """Save the state in ckpt_dir."""

    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    is_numpy = lambda x: isinstance(x, (np.ndarray, jnp.DeviceArray))
    flat_state = tree.flatten(state)
    nest_exemplar = tree.map_structure(is_numpy, state)

    array_path = os.path.join(ckpt_dir, _ARRAY_NAME)
    logging.info('Saving flattened array nest to %s', array_path)

    def _disabled_seek(*_):
        raise AttributeError('seek() is disabled on this object.')

    with open(array_path, 'wb') as f:
        setattr(f, 'seek', _disabled_seek)
        np.savez(f, *flat_state)

    exemplar_path = os.path.join(ckpt_dir, _EXEMPLAR_NAME)
    logging.info('Saving nest exemplar to %s', exemplar_path)
    with open(exemplar_path, 'wb') as f:
        pickle.dump(nest_exemplar, f)

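# A hypothetical inverse of `save_to_path`, sketched here for illustration
# only (the original library may provide its own restore function). It assumes
# the same `_ARRAY_NAME` / `_EXEMPLAR_NAME` constants and relies on `np.savez`
# storing the flattened leaves as `arr_0`, `arr_1`, ... in flatten order.
def _restore_from_path_sketch(ckpt_dir: str):
    exemplar_path = os.path.join(ckpt_dir, _EXEMPLAR_NAME)
    with open(exemplar_path, 'rb') as f:
        nest_exemplar = pickle.load(f)

    array_path = os.path.join(ckpt_dir, _ARRAY_NAME)
    with np.load(array_path, allow_pickle=True) as npz:
        # Recover the leaves in their original flatten order.
        flat_state = [npz['arr_{}'.format(i)] for i in range(len(npz.files))]

    # Rebuild the original nest around the flat list of arrays.
    return tree.unflatten_as(nest_exemplar, flat_state)
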
def _maybe_create_item(self,
                       sequence_length: int,
                       *,
                       end_of_episode: bool = False,
                       force: bool = False):
    # Check conditions under which a new item is created.
    first_write = self._writer.episode_steps == sequence_length
    # NOTE(bshahr): the following line assumes that the only way
    # sequence_length is less than self._sequence_length, is if the episode is
    # shorter than self._sequence_length.
    period_reached = (
        self._writer.episode_steps > self._sequence_length and
        ((self._writer.episode_steps - self._sequence_length) % self._period
         == 0))

    if not first_write and not period_reached and not force:
        return

    # TODO(b/183945808): will need to change to adhere to the new protocol.
    if not end_of_episode:
        get_traj = operator.itemgetter(slice(-sequence_length - 1, -1))
    else:
        get_traj = operator.itemgetter(slice(-sequence_length, None))

    history = self._writer.history
    trajectory = base.Trajectory(**tree.map_structure(get_traj, history))

    # Compute priorities for the buffer.
    table_priorities = utils.calculate_priorities(self._priority_fns,
                                                  trajectory)

    # Create a prioritized item for each table.
    for table_name, priority in table_priorities.items():
        self._writer.create_item(table_name, priority, trajectory)
        self._writer.flush(self._max_in_flight_items)

def test_ignore_regex(self):

    def f():
        return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

    init_fn, _ = transform.transform(f)
    params = init_fn(random.PRNGKey(428))

    def g(x):
        return moving_averages.EMAParamsTree(0.2, ignore_regex=".*w")(x)

    init_fn, apply_fn = transform.without_apply_rng(
        transform.transform_with_state(g))
    _, params_state = init_fn(None, params)
    params, params_state = apply_fn(None, params_state, params)
    # Let's modify our params.
    changed_params = tree.map_structure(lambda t: 2. * t, params)
    ema_params, params_state = apply_fn(None, params_state, changed_params)

    # b is averaged, so it should differ from the raw changed value...
    self.assertTrue(
        (changed_params["linear"]["b"] != ema_params["linear"]["b"]).all())
    # ... but w matches the ignore regex, so it passes through unchanged.
    self.assertTrue(
        (changed_params["linear"]["w"] == ema_params["linear"]["w"]).all())

def shuffle(self) -> "SampleBatch":
    """Shuffles the rows of this batch in-place.

    Returns:
        This very (now shuffled) SampleBatch.

    Raises:
        ValueError: If self[SampleBatch.SEQ_LENS] is defined.

    Examples:
        >>> from ray.rllib.policy.sample_batch import SampleBatch
        >>> batch = SampleBatch({"a": [1, 2, 3, 4]})  # doctest: +SKIP
        >>> print(batch.shuffle())  # doctest: +SKIP
        {"a": [4, 1, 3, 2]}
    """
    # Shuffling the data when we have `seq_lens` defined is probably
    # a bad idea!
    if self.get(SampleBatch.SEQ_LENS) is not None:
        raise ValueError(
            "SampleBatch.shuffle not possible when your data has "
            "`seq_lens` defined!"
        )

    # Get a permutation over the single items once and use the same
    # permutation for all the data (otherwise, data would become
    # meaningless).
    permutation = np.random.permutation(self.count)

    self_as_dict = {k: v for k, v in self.items()}
    shuffled = tree.map_structure(lambda v: v[permutation], self_as_dict)
    self.update(shuffled)
    # Flush cache such that intercepted values are recalculated after the
    # shuffling.
    self.intercepted_values = {}
    return self

def __call__(self, inputs: types.NestedTensor) -> tf.Tensor:
    # Inputs are of size [B, ...]. Here we tile them to be of shape [N, B, ...].
    tiled_inputs = tf2_utils.tile_nested(inputs, self._num_action_samples)
    shape = tf.shape(tree.flatten(tiled_inputs)[0])
    n, b = shape[0], shape[1]
    tf.debugging.assert_equal(
        n, self._num_action_samples,
        'Internal Error. Unexpected tiled_inputs shape.')
    dummy_zeros_n_b = tf.zeros((n, b))
    # Reshape to [N * B, ...].
    merge = lambda x: snt.merge_leading_dims(x, 2)
    tiled_inputs = tree.map_structure(merge, tiled_inputs)

    tiled_actions = self._actor_network(tiled_inputs)

    # Compute Q-values and the resulting tempered probabilities.
    q = self._critic_network(tiled_inputs, tiled_actions)
    boltzmann_logits = q / self._beta

    boltzmann_logits = snt.split_leading_dim(boltzmann_logits,
                                             dummy_zeros_n_b, 2)
    # [B, N]
    boltzmann_logits = tf.transpose(boltzmann_logits, perm=(1, 0))
    # Resample one action per batch according to the Boltzmann distribution.
    action_idx = tfp.distributions.Categorical(
        logits=boltzmann_logits).sample()
    # [B, 2], where the first column is 0, 1, 2,... corresponding to indices
    # to the batch dimension.
    action_idx = tf.stack((tf.range(b), action_idx), axis=1)

    tiled_actions = snt.split_leading_dim(tiled_actions, dummy_zeros_n_b, 2)
    action_dim = len(tiled_actions.get_shape().as_list())
    tiled_actions = tf.transpose(tiled_actions,
                                 perm=[1, 0] + list(range(2, action_dim)))
    # [B, ...]
    action_sample = tf.gather_nd(tiled_actions, action_idx)

    return action_sample

def transition_dataset_from_spec(
        spec: specs.EnvironmentSpec) -> tf.data.Dataset:
    """Constructs fake dataset of Reverb N-step transition samples.

    Args:
        spec: Constructed fake transitions match the provided specification.

    Returns:
        tf.data.Dataset that produces the same fake N-step transition
        ReverbSample object indefinitely.
    """
    observation = _generate_from_spec(spec.observations)
    action = _generate_from_spec(spec.actions)
    reward = _generate_from_spec(spec.rewards)
    discount = _generate_from_spec(spec.discounts)
    data = types.Transition(observation, action, reward, discount, observation)

    info = tree.map_structure(
        lambda tf_dtype: tf.ones([], tf_dtype.as_numpy_dtype),
        reverb.SampleInfo.tf_dtypes())
    sample = reverb.ReplaySample(info=info, data=data)

    return tf.data.Dataset.from_tensors(sample).repeat()

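# Usage sketch (illustrative, not part of the original module): building the
# fake dataset from a fake environment's spec, reusing `fakes` and `specs` as
# in the tests above.
def _demo_transition_dataset_from_spec():
    environment = fakes.ContinuousEnvironment()
    environment_spec = specs.make_environment_spec(environment)
    dataset = transition_dataset_from_spec(environment_spec)
    # Every element is the same fake ReplaySample, repeated indefinitely.
    sample = next(iter(dataset))
    assert isinstance(sample, reverb.ReplaySample)
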
def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
    d = ac_data.data
    assert (
        type(d) == dict
    ), "Single agent data must be of type Dict[str, TensorStructType]"

    env_id = ac_data.env_id
    agent_id = ac_data.agent_id
    assert env_id is not None and agent_id is not None, (
        f"StateBufferConnector requires env_id({env_id}) "
        f"and agent_id({agent_id})")

    action, states, fetches = self._states[env_id][agent_id]

    # TODO(jungong): Support buffering more than 1 prev actions.
    if action is not None:
        d[SampleBatch.ACTIONS] = action  # Last action.
    else:
        # Default zero action.
        d[SampleBatch.ACTIONS] = tree.map_structure(
            lambda s: np.zeros_like(s.sample(), s.dtype)
            if hasattr(s, "dtype") else np.zeros_like(s.sample()),
            self._action_space_struct,
        )

    if states is None:
        states = self._initial_states
    for i, v in enumerate(states):
        d["state_out_{}".format(i)] = v

    # Also add extra fetches if available.
    if fetches:
        d.update(fetches)

    return ac_data

def test_make_dataset_with_sequence_length_size(self):
    sequence_length = 6
    environment = fakes.ContinuousEnvironment()
    environment_spec = specs.make_environment_spec(environment)
    dataset = reverb_dataset.make_dataset(
        server_address=self.server_address,
        environment_spec=environment_spec,
        sequence_length=sequence_length)

    def make_tensor_spec(spec):
        return tf.TensorSpec(
            shape=(sequence_length,) + spec.shape, dtype=spec.dtype)

    expected_spec = tree.map_structure(make_tensor_spec, environment_spec)

    expected_spec = adders.Step(
        observation=expected_spec.observations,
        action=expected_spec.actions,
        reward=expected_spec.rewards,
        discount=expected_spec.discounts,
        start_of_episode=specs.Array(shape=(sequence_length,), dtype=bool),
        extras=())

    self.assertTrue(_check_specs(expected_spec, dataset.element_spec.data))

def step(self, action: types.NestedArray) -> dm_env.TimeStep:
    """Steps the environment."""
    if self._reset_next_step:
        return self.reset()

    observation, reward, done, info = self._environment.step(action)
    self._reset_next_step = done
    self._last_info = info

    # Convert the type of the reward based on the spec, respecting the scalar
    # or array property.
    reward = tree.map_structure(
        lambda x, t: (  # pylint: disable=g-long-lambda
            t.dtype.type(x)
            if np.isscalar(x) else np.asarray(x, dtype=t.dtype)),
        reward,
        self.reward_spec())

    if done:
        truncated = info.get('TimeLimit.truncated', False)
        if truncated:
            return dm_env.truncation(reward, observation)
        return dm_env.termination(reward, observation)
    return dm_env.transition(reward, observation)

def test_actions_and_log_pis_symbolic(self):
    observation1_np = self.env.reset()
    observation2_np = self.env.step(self.env.action_space.sample())[0]
    observations_np = {}
    for key in observation1_np.keys():
        observations_np[key] = np.stack(
            (observation1_np[key], observation2_np[key])).astype(np.float32)

    observations_tf = tree.map_structure(
        lambda x: tf.constant(x, dtype=tf.float32), observations_np)

    actions = self.policy.actions(observations_tf)
    with self.assertRaises(NotImplementedError):
        log_pis = self.policy.log_pis(observations_tf, actions)

    self.assertEqual(actions.shape, (2, *self.env.action_shape))

    self.evaluate(tf.compat.v1.global_variables_initializer())

    actions_np = self.evaluate(actions)
    self.assertEqual(actions_np.shape, (2, *self.env.action_shape))

def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset:
    """Fake dataset of Reverb N-step transition samples.

    Args:
        environment: Used to create a fake transition by looking at the
            observation, action, discount and reward specs.

    Returns:
        tf.data.Dataset that produces the same fake N-step transition
        ReverbSample object indefinitely.
    """
    observation = environment.observation_spec().generate_value()
    action = environment.action_spec().generate_value()
    reward = environment.reward_spec().generate_value()
    discount = environment.discount_spec().generate_value()
    data = types.Transition(observation, action, reward, discount, observation)

    info = tree.map_structure(
        lambda tf_dtype: tf.ones([], tf_dtype.as_numpy_dtype),
        reverb.SampleInfo.tf_dtypes())
    sample = reverb.ReplaySample(info=info, data=data)

    return tf.data.Dataset.from_tensors(sample).repeat()

def prev_action_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
    """Returns the previous action for the specified agent, or zeros.

    The "previous" action is the one taken one timestep before the most
    recent action taken by the agent.

    Args:
        agent_id: The agent's ID to get the previous action for.

    Returns:
        Previous action the specified AgentID has executed. Zero in case the
        agent has never performed any actions (or only one) in the episode.
    """
    policy_id = self.policy_for(agent_id)
    policy = self.policy_map[policy_id]

    # We are at t > 1 -> There has been a previous action by this agent.
    if agent_id in self._agent_to_prev_action:
        if policy.config.get("_disable_action_flattening"):
            return self._agent_to_prev_action[agent_id]
        else:
            return flatten_to_single_ndarray(
                self._agent_to_prev_action[agent_id])
    # We're at t <= 1, so return all zeros.
    else:
        if policy.config.get("_disable_action_flattening"):
            return tree.map_structure(
                lambda a: np.zeros_like(a, a.dtype)  # noqa
                if hasattr(a, "dtype") else np.zeros_like(a),  # noqa
                self.last_action_for(agent_id),
            )
        else:
            return np.zeros_like(self.last_action_for(agent_id))

def test_pmap_update_nested(self):
    local_device_count = jax.local_device_count()
    state = running_statistics.init_state({
        'a': specs.Array((5,), jnp.float32),
        'b': specs.Array((2,), jnp.float32)
    })

    x = {
        'a': (jnp.arange(15 * local_device_count,
                         dtype=jnp.float32)).reshape(local_device_count, 3, 5),
        'b': (jnp.arange(6 * local_device_count,
                         dtype=jnp.float32)).reshape(local_device_count, 3, 2),
    }

    devices = jax.local_devices()
    state = jax.device_put_replicated(state, devices)
    pmap_axis_name = 'i'
    state = jax.pmap(
        functools.partial(update_and_validate, pmap_axis_name=pmap_axis_name),
        pmap_axis_name)(state, x)
    state = jax.pmap(
        functools.partial(update_and_validate, pmap_axis_name=pmap_axis_name),
        pmap_axis_name)(state, x)
    normalized = jax.pmap(running_statistics.normalize)(x, state)

    mean = tree.map_structure(lambda x: jnp.mean(x, axis=(0, 1)), normalized)
    std = tree.map_structure(lambda x: jnp.std(x, axis=(0, 1)), normalized)
    tree.map_structure(
        lambda x: self.assert_allclose(x, jnp.zeros_like(x)), mean)
    tree.map_structure(
        lambda x: self.assert_allclose(x, jnp.ones_like(x)), std)

def get_action_dist(action_space,
                    config,
                    dist_type=None,
                    framework="tf",
                    **kwargs):
    """Returns a distribution class and size for the given action space.

    Args:
        action_space (Space): Action space of the target gym env.
        config (Optional[dict]): Optional model config.
        dist_type (Optional[str]): Identifier of the action distribution
            interpreted as a hint.
        framework (str): One of "tf", "tfe", or "torch".
        kwargs (dict): Optional kwargs to pass on to the Distribution's
            constructor.

    Returns:
        Tuple:
            - dist_class (ActionDistribution): Python class of the
              distribution.
            - dist_dim (int): The size of the input vector to the
              distribution.
    """
    dist = None
    config = config or MODEL_DEFAULTS
    # Custom distribution given.
    if config.get("custom_action_dist"):
        action_dist_name = config["custom_action_dist"]
        logger.debug(
            "Using custom action distribution {}".format(action_dist_name))
        dist = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name)
    # Dist_type is given directly as a class.
    elif type(dist_type) is type and \
            issubclass(dist_type, ActionDistribution) and \
            dist_type not in (
                MultiActionDistribution, TorchMultiActionDistribution):
        dist = dist_type
    # Box space -> DiagGaussian OR Deterministic.
    elif isinstance(action_space, gym.spaces.Box):
        if len(action_space.shape) > 1:
            raise UnsupportedSpaceException(
                "Action space has multiple dimensions "
                "{}. ".format(action_space.shape) +
                "Consider reshaping this into a single dimension, "
                "using a custom action distribution, "
                "using a Tuple action space, or the multi-agent API.")
        # TODO(sven): Check for bounds and return SquashedNormal, etc..
        if dist_type is None:
            dist = TorchDiagGaussian if framework == "torch" \
                else DiagGaussian
        elif dist_type == "deterministic":
            dist = TorchDeterministic if framework == "torch" \
                else Deterministic
    # Discrete Space -> Categorical.
    elif isinstance(action_space, gym.spaces.Discrete):
        dist = TorchCategorical if framework == "torch" else Categorical
    # Tuple/Dict Spaces -> MultiAction.
    elif dist_type in (MultiActionDistribution,
                       TorchMultiActionDistribution) or \
            isinstance(action_space, (gym.spaces.Tuple, gym.spaces.Dict)):
        flat_action_space = flatten_space(action_space)
        child_dists_and_in_lens = tree.map_structure(
            lambda s: ModelCatalog.get_action_dist(
                s, config, framework=framework),
            flat_action_space)
        child_dists = [e[0] for e in child_dists_and_in_lens]
        input_lens = [int(e[1]) for e in child_dists_and_in_lens]
        return partial(
            (TorchMultiActionDistribution
             if framework == "torch" else MultiActionDistribution),
            action_space=action_space,
            child_distributions=child_dists,
            input_lens=input_lens), int(sum(input_lens))
    # Simplex -> Dirichlet.
    elif isinstance(action_space, Simplex):
        if framework == "torch":
            # TODO(sven): implement
            raise NotImplementedError(
                "Simplex action spaces not supported for torch.")
        dist = Dirichlet
    # MultiDiscrete -> MultiCategorical.
    elif isinstance(action_space, gym.spaces.MultiDiscrete):
        dist = TorchMultiCategorical if framework == "torch" else \
            MultiCategorical
        return partial(dist, input_lens=action_space.nvec), \
            int(sum(action_space.nvec))
    # Unknown type -> Error.
    else:
        raise NotImplementedError("Unsupported args: {} {}".format(
            action_space, dist_type))

    return dist, dist.required_model_output_shape(action_space, config)

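# Usage sketch (illustrative; the expected values follow from the branches
# above): a 1-D float Box with no dist_type hint resolves to a diagonal
# Gaussian, whose required model output size is twice the action dim (a mean
# and a log-std per dimension).
def _demo_get_action_dist():
    import gym
    import numpy as np

    space = gym.spaces.Box(-1.0, 1.0, shape=(3,), dtype=np.float32)
    dist_class, dist_dim = get_action_dist(
        space, config=None, framework="torch")
    assert dist_class is TorchDiagGaussian
    assert dist_dim == 6  # 3 means + 3 log-stds.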