Example #1
def convert_to_non_torch_type(stats):
    """Converts values in `stats` to non-Tensor numpy or python types.

    Args:
        stats (any): Any (possibly nested) struct, the values in which will be
            converted and returned as a new struct with all torch tensors
            being converted to numpy types.

    Returns:
        Any: A new struct with the same structure as `stats`, but with all
            values converted to non-torch Tensor types.
    """

    # The mapping function used to numpyize torch Tensors.
    def mapping(item):
        if isinstance(item, torch.Tensor):
            return item.cpu().item() if len(item.size()) == 0 else \
                item.detach().cpu().numpy()
        else:
            return item

    return tree.map_structure(mapping, stats)
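Below is a minimal usage sketch for the helper above. The nested `stats` dict is hypothetical; it assumes `torch` and the `dm-tree` package (imported as `tree`) are available.

import torch
import tree  # dm-tree

stats = {"loss": torch.tensor(0.25), "grads": {"w": torch.ones(2, 3)}}
out = convert_to_non_torch_type(stats)
# out["loss"] is the Python float 0.25; out["grads"]["w"] is a (2, 3) numpy array.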
Example #2
        def true_fn():
            batch_size = 1
            req = force_tuple(
                action_dist.required_model_output_shape(
                    self.action_space, self.model.model_config))
            # Add a batch dimension?
            if len(action_dist.inputs.shape) == len(req) + 1:
                batch_size = tf.shape(action_dist.inputs)[0]

            # Function to produce random samples from primitive space
            # components: (Multi)Discrete or Box.
            def random_component(component):
                if isinstance(component, Discrete):
                    return tf.random.uniform(shape=(batch_size, ) +
                                             component.shape,
                                             maxval=component.n,
                                             dtype=component.dtype)
                elif isinstance(component, MultiDiscrete):
                    return tf.random.uniform(shape=(batch_size, ) +
                                             component.shape,
                                             maxval=component.nvec,
                                             dtype=component.dtype)
                elif isinstance(component, Box):
                    if component.bounded_above.all() and \
                            component.bounded_below.all():
                        return tf.random.uniform(shape=(batch_size, ) +
                                                 component.shape,
                                                 minval=component.low,
                                                 maxval=component.high,
                                                 dtype=component.dtype)
                    else:
                        return tf.random.normal(shape=(batch_size, ) +
                                                component.shape,
                                                dtype=component.dtype)

            actions = tree.map_structure(random_component,
                                         self.action_space_struct)
            return actions
Example #3
    def _write_last(self):
        # Maybe determine the delta to the next time we would write a sequence.
        if self._end_of_episode_behavior in (EndBehavior.TRUNCATE,
                                             EndBehavior.ZERO_PAD):
            delta = self._sequence_length - self._writer.episode_steps
            if delta < 0:
                delta = (self._period + delta) % self._period

        # Handle various end-of-episode cases.
        if self._end_of_episode_behavior is EndBehavior.CONTINUE:
            self._maybe_create_item(self._sequence_length, end_of_episode=True)

        elif self._end_of_episode_behavior is EndBehavior.WRITE:
            # Drop episodes that are too short.
            if self._writer.episode_steps < self._sequence_length:
                return
            self._maybe_create_item(self._sequence_length,
                                    end_of_episode=True,
                                    force=True)

        elif self._end_of_episode_behavior is EndBehavior.TRUNCATE:
            self._maybe_create_item(self._sequence_length - delta,
                                    end_of_episode=True,
                                    force=True)

        elif self._end_of_episode_behavior is EndBehavior.ZERO_PAD:
            zero_step = tree.map_structure(
                lambda x: np.zeros_like(x[-2].numpy()), self._writer.history)
            for _ in range(delta):
                self._writer.append(zero_step)

            self._maybe_create_item(self._sequence_length,
                                    end_of_episode=True,
                                    force=True)
        else:
            raise ValueError(
                f'Unhandled end of episode behavior: {self._end_of_episode_behavior}.'
                ' This should never happen, please contact Acme dev team.')
Example #4
  def _replicated_step(self):
    # Update target network
    online_variables = (
        *self._observation_network.variables,
        *self._critic_network.variables,
        *self._policy_network.variables,
    )
    target_variables = (
        *self._target_observation_network.variables,
        *self._target_critic_network.variables,
        *self._target_policy_network.variables,
    )

    # Make online -> target network update ops.
    if tf.math.mod(self._num_steps, self._target_update_period) == 0:
      for src, dest in zip(online_variables, target_variables):
        dest.assign(src)
    self._num_steps.assign_add(1)

    # Get data from replay (dropping extras if any). Note there is no
    # extra data here because we do not insert any into Reverb.
    sample = next(self._iterator)

    # This mirrors the structure of the fetches returned by self._step(),
    # but the Tensors are replaced with replicated Tensors, one per accelerator.
    replicated_fetches = self._replicator.run(self._step, args=(sample,))

    def reduce_mean_over_replicas(replicated_value):
      """Averages a replicated_value across replicas."""
      # The "axis=None" arg means reduce across replicas, not internal axes.
      return self._replicator.reduce(
          reduce_op=tf.distribute.ReduceOp.MEAN,
          value=replicated_value,
          axis=None)

    fetches = tree.map_structure(reduce_mean_over_replicas, replicated_fetches)

    return fetches
Example #5
    def add(self,
            action: types.NestedArray,
            next_timestep: dm_env.TimeStep,
            extras: types.NestedArray = ()):
        """Record an action and the following timestep."""

        try:
            history = self._writer.history
        except RuntimeError:
            raise ValueError(
                'adder.add_first must be called before adder.add.')

        # Add the timestep to the buffer.
        current_step = dict(
            # Observation was passed at the previous add call.
            action=action,
            reward=next_timestep.reward,
            discount=next_timestep.discount,
            # Start of episode indicator was passed at the previous add call.
            **({
                'extras': extras
            } if extras else {}))
        self._writer.append(current_step)

        # Record the next observation and write.
        self._writer.append(dict(observation=next_timestep.observation,
                                 start_of_episode=next_timestep.first()),
                            partial_step=True)
        self._write()

        if next_timestep.last():
            # Complete the row by appending zeros to remaining open fields.
            # TODO(b/183945808): remove this when fields are no longer expected to be
            # of equal length on the learner side.
            dummy_step = tree.map_structure(np.zeros_like, current_step)
            self._writer.append(dummy_step)
            self._write_last()
            self.reset()
Example #6
    def __call__(self, x: TensorStructType, update: bool = True) -> \
            TensorStructType:
        if self.no_preprocessor:
            x = tree.map_structure(lambda x_: np.asarray(x_), x)
        else:
            x = np.asarray(x)

        def _helper(x, rs, buffer, shape):
            # Discrete|MultiDiscrete spaces -> No normalization.
            if shape is None:
                return x

            # Keep dtype as is throughout this filter.
            orig_dtype = x.dtype

            if update:
                if len(x.shape) == len(rs.shape) + 1:
                    # The vectorized case.
                    for i in range(x.shape[0]):
                        rs.push(x[i])
                        buffer.push(x[i])
                else:
                    # The unvectorized case.
                    rs.push(x)
                    buffer.push(x)
            if self.demean:
                x = x - rs.mean
            if self.destd:
                x = x / (rs.std + SMALL_NUMBER)
            if self.clip:
                x = np.clip(x, -self.clip, self.clip)
            return x.astype(orig_dtype)

        if self.no_preprocessor:
            return tree.map_structure_up_to(x, _helper, x, self.rs,
                                            self.buffer, self.shape)
        else:
            return _helper(x, self.rs, self.buffer, self.shape)
Example #7
  def test_ema_on_changing_data(self):
    def f():
      return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

    init_fn, _ = transform.transform(f)
    params = init_fn(random.PRNGKey(428))

    def g(x):
      return moving_averages.EMAParamsTree(0.2)(x)
    init_fn, apply_fn = transform.without_apply_rng(
        transform.transform_with_state(g))
    _, params_state = init_fn(None, params)
    params, params_state = apply_fn(None, params_state, params)
    # Let's modify our params.
    changed_params = tree.map_structure(lambda t: 2. * t, params)
    ema_params, params_state = apply_fn(None, params_state, changed_params)

    # ema_params should be different from changed params!
    tree.assert_same_structure(changed_params, ema_params)
    for p1, p2 in zip(tree.flatten(params), tree.flatten(ema_params)):
      self.assertEqual(p1.shape, p2.shape)
      with self.assertRaisesRegex(AssertionError, "Not equal to tolerance"):
        np.testing.assert_allclose(p1, p2, atol=1e-6)
Example #8
    def test_make_dataset_with_batch_size(self):
        batch_size = 4
        environment = fakes.ContinuousEnvironment()
        environment_spec = specs.make_environment_spec(environment)
        dataset = reverb_dataset.make_dataset(
            client=self.tf_client,
            environment_spec=environment_spec,
            batch_size=batch_size)

        def make_tensor_spec(spec):
            return tf.TensorSpec(shape=(None, ) + spec.shape, dtype=spec.dtype)

        expected_spec = tree.map_structure(make_tensor_spec, environment_spec)

        expected_spec = adders.Step(observation=expected_spec.observations,
                                    action=expected_spec.actions,
                                    reward=expected_spec.rewards,
                                    discount=expected_spec.discounts,
                                    start_of_episode=specs.Array(
                                        shape=(batch_size, ), dtype=bool),
                                    extras=())

        self.assertTrue(_check_specs(expected_spec, dataset.element_spec.data))
Example #9
def create_no_nodefeatures_graphs(features, topology):
    comp_to_id = {'R': 1, 'L': 2, 'C': 3, 'V': 4}
    comp_to_value = {'R': 50, 'L': 1e-3, 'C': 1e-6, 'V': 0}
    graphs_list = []
    for circuit in range(len(features.keys())):
        circuit_features = features["circuit_{}".format(circuit + 1)]
        circuit_topology = topology["circuit_{}".format(circuit + 1)]
        nodes = []
        edges = []
        senders = []
        receivers = []
        for sender, receiver in circuit_topology.items():
            for i in range(len(receiver)):
                senders.append(float(sender))
                receivers.append(float(receiver[i][1]))
                edges.append([float(comp_to_id[receiver[i][0][0]]), float(comp_to_value[receiver[i][0][0]])])
        # One placeholder feature per node, padded below to a fixed length of 7.
        for _ in range(len(circuit_topology.keys())):
            nodes.append([0.0])

        for i in range(len(nodes)):
            while len(nodes[i]) < 7:
                nodes[i].append(0.0)

        graph = {
            "nodes": nodes,
            "edges": edges,
            "senders": senders,
            "receivers": receivers,
            "globals": [0.0, 0.0, 0.0]
        }
        graphs_list.append(graph)

    graphs_tuple = utils_np.data_dicts_to_graphs_tuple(graphs_list)
    graphs_tuple = tree.map_structure(lambda x: tf.constant(x) if x is not None else None, graphs_tuple)
    return graphs_tuple
Example #10
    def last_action_for(self,
                        agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
        """Returns the last action for the specified AgentID, or zeros.

        The "last" action is the most recent one taken by the agent.

        Args:
            agent_id: The agent's ID to get the last action for.

        Returns:
            Last action the specified AgentID has executed.
            Zeros in case the agent has never performed any actions in the
            episode.
        """
        policy_id = self.policy_for(agent_id)
        policy = self.policy_map[policy_id]

        # Agent has already taken at least one action in the episode.
        if agent_id in self._agent_to_last_action:
            if policy.config.get("_disable_action_flattening"):
                return self._agent_to_last_action[agent_id]
            else:
                return flatten_to_single_ndarray(
                    self._agent_to_last_action[agent_id])
        # Agent has not acted yet, return all zeros.
        else:
            if policy.config.get("_disable_action_flattening"):
                return tree.map_structure(
                    lambda s: np.zeros_like(s.sample(), s.dtype)
                    if hasattr(s, "dtype") else np.zeros_like(s.sample()),
                    policy.action_space_struct,
                )
            else:
                flat = flatten_to_single_ndarray(policy.action_space.sample())
                if hasattr(policy.action_space, "dtype"):
                    return np.zeros_like(flat, dtype=policy.action_space.dtype)
                return np.zeros_like(flat)
Example #11
def _get_input_signature(module: snt.Module):
  """Get module input signature.

  Works even if the module with signature is wrapper into snt.Sequentual or
  snt.DeepRNN.

  Args:
    module: the module which input signature we need to get. The module has to
      either have input_signature itself (i.e. you have to run create_variables
      on the module), or it has to be a module (with input_signature) wrapped in
      (one or multiple) snt.Sequential or snt.DeepRNNs.

  Returns:
    Input signature of the module or None if it's not available.
  """
  if hasattr(module, '_input_signature'):
    return module._input_signature  # pylint: disable=protected-access

  if isinstance(module, snt.Sequential):
    first_layer = module._layers[0]  # pylint: disable=protected-access
    return _get_input_signature(first_layer)

  if isinstance(module, snt.DeepRNN):
    first_layer = module._layers[0]  # pylint: disable=protected-access
    input_signature = _get_input_signature(first_layer)

    # Wrapping a module in DeepRNN changes its state shape, so we need to bring
    # it up to date.
    state = module.initial_state(1)
    input_signature[-1] = tree.map_structure(
        lambda t: tf.TensorSpec((None,) + t.shape[1:], t.dtype), state)

    return input_signature

  # If we get here we can't determine the input signature. So give up.
  raise ValueError('module instance has no input_signature attribute; run '
                   'create_variables to add this annotation.')
Example #12
def convert_element_to_space_type(element: Any, sampled_element: Any) -> Any:
    """Convert all the components of the element to match the space dtypes.

    Args:
        element: The element to be converted.
        sampled_element: An element sampled from a space to be matched
            to.

    Returns:
        The input element, but with all its components converted to match
        the space dtypes.
    """
    def map_(elem, s):
        if isinstance(s, np.ndarray):
            if not isinstance(elem, np.ndarray):
                assert isinstance(
                    elem, (float, int)
                ), f"ERROR: `elem` ({elem}) must be np.array, float or int!"
                if s.shape == ():
                    elem = np.array(elem, dtype=s.dtype)
                else:
                    raise ValueError(
                        "Element should be of type np.ndarray but is instead "
                        "of type {}".format(type(elem)))
            elif s.dtype != elem.dtype:
                elem = elem.astype(s.dtype)

        elif isinstance(s, int):
            if isinstance(elem, float) and elem.is_integer():
                elem = int(elem)

        return elem

    return tree.map_structure(map_,
                              element,
                              sampled_element,
                              check_types=False)
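A quick, hypothetical illustration of the converter above; only `numpy` and `dm-tree` are needed, and the two dicts are made up for the example.

import numpy as np

element = {"pos": np.array([1, 2], dtype=np.int64), "flag": 3.0}
sampled = {"pos": np.array([0.0, 0.0], dtype=np.float32), "flag": 0}
converted = convert_element_to_space_type(element, sampled)
# converted["pos"] now has dtype float32; converted["flag"] is the int 3.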
Example #13
    def logp(self, x):
        if isinstance(x, np.ndarray):
            x = torch.Tensor(x)
        # Single tensor input (all merged).
        if isinstance(x, torch.Tensor):
            split_indices = []
            for dist in self.flat_child_distributions:
                if isinstance(dist, TorchCategorical):
                    split_indices.append(1)
                elif isinstance(dist, TorchMultiCategorical) and \
                        dist.action_space is not None:
                    split_indices.append(int(np.prod(dist.action_space.shape)))
                else:
                    sample = dist.sample()
                    # Cover Box(shape=()) case.
                    if len(sample.shape) == 1:
                        split_indices.append(1)
                    else:
                        split_indices.append(sample.size()[1])
            split_x = list(torch.split(x, split_indices, dim=1))
        # Structured or flattened (by single action component) input.
        else:
            split_x = tree.flatten(x)

        def map_(val, dist):
            # Remove extra categorical dimension.
            if isinstance(dist, TorchCategorical):
                val = (torch.squeeze(val, dim=-1)
                       if len(val.shape) > 1 else val).int()
            return dist.logp(val)

        # Remove extra categorical dimension and take the logp of each
        # component.
        flat_logps = tree.map_structure(map_, split_x,
                                        self.flat_child_distributions)

        return functools.reduce(lambda a, b: a + b, flat_logps)
Example #14
    def _write_last(self):
        # Create a final step.
        final_step = utils.final_step_like(self._buffer[0],
                                           self._next_observation)

        # Append the final step.
        self._buffer.append(final_step)
        self._writer.append(final_step)
        self._step += 1

        # Determine the delta to the next time we would write a sequence.
        first_write = self._step <= self._max_sequence_length
        if first_write:
            delta = self._max_sequence_length - self._step
        else:
            delta = (self._period -
                     (self._step - self._max_sequence_length)) % self._period

        # Bump up to the position where we will write a sequence.
        self._step += delta

        if self._pad_end_of_episode:
            zero_step = tree.map_structure(utils.zeros_like, final_step)

            # Pad with zeros to get a full sequence.
            for _ in range(delta):
                self._buffer.append(zero_step)
                self._writer.append(zero_step)
        elif not first_write:
            # Pop items from the buffer to get a truncated sequence.
            # Note: this is consistent with the padding loop above, since adding zero
            # steps pops the left-most elements. Here we just pop without padding.
            for _ in range(delta):
                self._buffer.popleft()

        # Write priorities for the sequence.
        self._maybe_add_priorities()
Example #15
def unsquash_action(action, action_space_struct):
    """Unsquashes all components in `action` according to the given Space.

    Inverse of `normalize_action()`. Useful for mapping policy action
    outputs (normalized between -1.0 and 1.0) to an env's action space.
    Unsquashing results in continuous action component values between the
    given Space's bounds (`low` and `high`). This only applies to Box
    components within the action space whose dtype is float32 or float64.

    Args:
        action (Any): The action to be unsquashed. This could be any complex
            action, e.g. a dict or tuple.
        action_space_struct (Any): The action space struct,
            e.g. `{"a": Box()}` for a space: Dict({"a": Box()}).

    Returns:
        Any: The input action, but unsquashed, according to the space's
            bounds. An unsquashed action is ready to be sent to the
            environment (`BaseEnv.send_actions([unsquashed actions])`).
    """
    def map_(a, s):
        if (isinstance(s, gym.spaces.Box) and np.all(s.bounded_below)
                and np.all(s.bounded_above)):
            if s.dtype == np.float32 or s.dtype == np.float64:
                # Assuming values are roughly between -1.0 and 1.0 ->
                # unsquash them to the given bounds.
                a = s.low + (a + 1.0) * (s.high - s.low) / 2.0
                # Clip to given bounds, just in case the squashed values were
                # outside [-1.0, 1.0].
                a = np.clip(a, s.low, s.high)
            elif np.issubdtype(s.dtype, np.integer):
                # For Categorical and MultiCategorical actions, shift the selection
                # into the proper range.
                a = s.low + a
        return a

    return tree.map_structure(map_, action, action_space_struct)
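A small sketch of the unsquashing above on a single bounded Box component. The space struct and squashed action are hypothetical; it assumes `gym` and `numpy`.

import gym
import numpy as np

space_struct = {"a": gym.spaces.Box(low=-2.0, high=2.0, shape=(3,), dtype=np.float32)}
squashed = {"a": np.array([-1.0, 0.0, 1.0], dtype=np.float32)}
env_action = unsquash_action(squashed, space_struct)
# env_action["a"] is approximately [-2.0, 0.0, 2.0], i.e. mapped onto [low, high].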
Example #16
def _build_sarsa_example(sequences):
    """Convert raw sequences into a Reverb n-step SARSA sample."""

    o_tm1 = tree.map_structure(lambda t: t[0], sequences['observation'])
    o_t = tree.map_structure(lambda t: t[1], sequences['observation'])
    a_tm1 = tree.map_structure(lambda t: t[0], sequences['action'])
    a_t = tree.map_structure(lambda t: t[1], sequences['action'])
    r_t = tree.map_structure(lambda t: t[0], sequences['reward'])
    p_t = tree.map_structure(
        lambda d, st: d[0] * tf.cast(st[1] != dm_env.StepType.LAST, d.dtype),
        sequences['discount'], sequences['step_type'])

    info = reverb.SampleInfo(key=tf.constant(0, tf.uint64),
                             probability=tf.constant(1.0, tf.float64),
                             table_size=tf.constant(0, tf.int64),
                             priority=tf.constant(1.0, tf.float64))
    return reverb.ReplaySample(info=info,
                               data=(o_tm1, a_tm1, r_t, p_t, o_t, a_t))
Example #17
def save_to_path(ckpt_dir: str, state: CheckpointState):
    """Save the state in ckpt_dir."""

    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    is_numpy = lambda x: isinstance(x, (np.ndarray, jnp.DeviceArray))
    flat_state = tree.flatten(state)
    nest_exemplar = tree.map_structure(is_numpy, state)

    array_path = os.path.join(ckpt_dir, _ARRAY_NAME)
    logging.info('Saving flattened array nest to %s', array_path)

    def _disabled_seek(*_):
        raise AttributeError('seek() is disabled on this object.')

    with open(array_path, 'wb') as f:
        setattr(f, 'seek', _disabled_seek)
        np.savez(f, *flat_state)

    exemplar_path = os.path.join(ckpt_dir, _EXEMPLAR_NAME)
    logging.info('Saving nest exemplar to %s', exemplar_path)
    with open(exemplar_path, 'wb') as f:
        pickle.dump(nest_exemplar, f)
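For symmetry, here is a hedged sketch of the inverse operation. This is not the library's own restore function; `_restore_from_path` is a hypothetical helper that reuses the `_ARRAY_NAME` constant and `CheckpointState` type from above.

def _restore_from_path(ckpt_dir: str, target: CheckpointState) -> CheckpointState:
    """Hypothetical inverse of save_to_path: unflattens the saved arrays into `target`."""
    array_path = os.path.join(ckpt_dir, _ARRAY_NAME)
    with open(array_path, 'rb') as f:
        npz = np.load(f)
        # np.savez stored the leaves positionally as arr_0, arr_1, ...
        flat_state = [npz['arr_{}'.format(i)] for i in range(len(npz.files))]
    return tree.unflatten_as(target, flat_state)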
Example #18
    def _maybe_create_item(self,
                           sequence_length: int,
                           *,
                           end_of_episode: bool = False,
                           force: bool = False):

        # Check conditions under which a new item is created.
        first_write = self._writer.episode_steps == sequence_length
        # NOTE(bshahr): the following line assumes that the only way sequence_length
        # is less than self._sequence_length, is if the episode is shorter than
        # self._sequence_length.
        period_reached = (
            self._writer.episode_steps > self._sequence_length
            and ((self._writer.episode_steps - self._sequence_length) %
                 self._period == 0))

        if not first_write and not period_reached and not force:
            return

        # TODO(b/183945808): will need to change to adhere to the new protocol.
        if not end_of_episode:
            get_traj = operator.itemgetter(slice(-sequence_length - 1, -1))
        else:
            get_traj = operator.itemgetter(slice(-sequence_length, None))

        history = self._writer.history
        trajectory = base.Trajectory(**tree.map_structure(get_traj, history))

        # Compute priorities for the buffer.
        table_priorities = utils.calculate_priorities(self._priority_fns,
                                                      trajectory)

        # Create a prioritized item for each table.
        for table_name, priority in table_priorities.items():
            self._writer.create_item(table_name, priority, trajectory)
            self._writer.flush(self._max_in_flight_items)
Example #19
    def test_ignore_regex(self):
        def f():
            return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

        init_fn, _ = transform.transform(f)
        params = init_fn(random.PRNGKey(428))

        def g(x):
            return moving_averages.EMAParamsTree(0.2, ignore_regex=".*w")(x)

        init_fn, apply_fn = transform.without_apply_rng(
            transform.transform_with_state(g))
        _, params_state = init_fn(None, params)
        params, params_state = apply_fn(None, params_state, params)
        # Let's modify our params.
        changed_params = tree.map_structure(lambda t: 2. * t, params)
        ema_params, params_state = apply_fn(None, params_state, changed_params)

        # b is not ignored, so its EMA differs from the changed value...
        self.assertTrue(
            (changed_params["linear"]["b"] != ema_params["linear"]["b"]).all())
        # ... but w matches ignore_regex and is passed through unchanged.
        self.assertTrue(
            (changed_params["linear"]["w"] == ema_params["linear"]["w"]).all())
Example #20
    def shuffle(self) -> "SampleBatch":
        """Shuffles the rows of this batch in-place.

        Returns:
            This very (now shuffled) SampleBatch.

        Raises:
            ValueError: If self[SampleBatch.SEQ_LENS] is defined.

        Examples:
            >>> from ray.rllib.policy.sample_batch import SampleBatch
            >>> batch = SampleBatch({"a": [1, 2, 3, 4]})  # doctest: +SKIP
            >>> print(batch.shuffle()) # doctest: +SKIP
            {"a": [4, 1, 3, 2]}
        """

        # Shuffling the data when we have `seq_lens` defined is probably
        # a bad idea!
        if self.get(SampleBatch.SEQ_LENS) is not None:
            raise ValueError(
                "SampleBatch.shuffle not possible when your data has "
                "`seq_lens` defined!"
            )

        # Get a permutation over the single items once and use the same
        # permutation for all the data (otherwise, data would become
        # meaningless).
        permutation = np.random.permutation(self.count)

        self_as_dict = {k: v for k, v in self.items()}
        shuffled = tree.map_structure(lambda v: v[permutation], self_as_dict)
        self.update(shuffled)
        # Flush cache such that intercepted values are recalculated after the
        # shuffling.
        self.intercepted_values = {}
        return self
Example #21
  def __call__(self, inputs: types.NestedTensor) -> tf.Tensor:
    # Inputs are of size [B, ...]. Here we tile them to be of shape [N, B, ...].
    tiled_inputs = tf2_utils.tile_nested(inputs, self._num_action_samples)
    shape = tf.shape(tree.flatten(tiled_inputs)[0])
    n, b = shape[0], shape[1]
    tf.debugging.assert_equal(n, self._num_action_samples,
                              'Internal Error. Unexpected tiled_inputs shape.')
    dummy_zeros_n_b = tf.zeros((n, b))
    # Reshape to [N * B, ...].
    merge = lambda x: snt.merge_leading_dims(x, 2)
    tiled_inputs = tree.map_structure(merge, tiled_inputs)

    tiled_actions = self._actor_network(tiled_inputs)

    # Compute Q-values and the resulting tempered probabilities.
    q = self._critic_network(tiled_inputs, tiled_actions)
    boltzmann_logits = q / self._beta

    boltzmann_logits = snt.split_leading_dim(boltzmann_logits, dummy_zeros_n_b,
                                             2)
    # [B, N]
    boltzmann_logits = tf.transpose(boltzmann_logits, perm=(1, 0))
    # Resample one action per batch according to the Boltzmann distribution.
    action_idx = tfp.distributions.Categorical(logits=boltzmann_logits).sample()
    # [B, 2], where the first column is 0, 1, 2,... corresponding to indices to
    # the batch dimension.
    action_idx = tf.stack((tf.range(b), action_idx), axis=1)

    tiled_actions = snt.split_leading_dim(tiled_actions, dummy_zeros_n_b, 2)
    action_dim = len(tiled_actions.get_shape().as_list())
    tiled_actions = tf.transpose(tiled_actions,
                                 perm=[1, 0] + list(range(2, action_dim)))
    # [B, ...]
    action_sample = tf.gather_nd(tiled_actions, action_idx)

    return action_sample
Example #22
def transition_dataset_from_spec(
        spec: specs.EnvironmentSpec) -> tf.data.Dataset:
    """Constructs fake dataset of Reverb N-step transition samples.

  Args:
    spec: Constructed fake transitions match the provided specification.

  Returns:
    tf.data.Dataset that produces the same fake N-step transition ReplaySample
    object indefinitely.
  """

    observation = _generate_from_spec(spec.observations)
    action = _generate_from_spec(spec.actions)
    reward = _generate_from_spec(spec.rewards)
    discount = _generate_from_spec(spec.discounts)
    data = types.Transition(observation, action, reward, discount, observation)

    info = tree.map_structure(
        lambda tf_dtype: tf.ones([], tf_dtype.as_numpy_dtype),
        reverb.SampleInfo.tf_dtypes())
    sample = reverb.ReplaySample(info=info, data=data)

    return tf.data.Dataset.from_tensors(sample).repeat()
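Hypothetical usage, assuming an `EnvironmentSpec` named `environment_spec` built elsewhere (e.g. via `specs.make_environment_spec`):

dataset = transition_dataset_from_spec(environment_spec)
sample = next(iter(dataset))
# sample.data is a types.Transition matching the spec; sample.info holds the
# dummy SampleInfo of ones built above.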
Example #23
    def transform(self,
                  ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
        d = ac_data.data
        assert (
            type(d) == dict
        ), "Single agent data must be of type Dict[str, TensorStructType]"

        env_id = ac_data.env_id
        agent_id = ac_data.agent_id
        assert (
            env_id is not None and agent_id is not None
        ), f"StateBufferConnector requires env_id ({env_id}) and agent_id ({agent_id})."

        action, states, fetches = self._states[env_id][agent_id]

        # TODO(jungong): Support buffering more than 1 prev actions.
        if action is not None:
            d[SampleBatch.ACTIONS] = action  # Last action
        else:
            # Default zero action.
            d[SampleBatch.ACTIONS] = tree.map_structure(
                lambda s: np.zeros_like(s.sample(), s.dtype)
                if hasattr(s, "dtype") else np.zeros_like(s.sample()),
                self._action_space_struct,
            )

        if states is None:
            states = self._initial_states
        for i, v in enumerate(states):
            d["state_out_{}".format(i)] = v

        # Also add extra fetches if available.
        if fetches:
            d.update(fetches)

        return ac_data
Example #24
    def test_make_dataset_with_sequence_length_size(self):
        sequence_length = 6
        environment = fakes.ContinuousEnvironment()
        environment_spec = specs.make_environment_spec(environment)
        dataset = reverb_dataset.make_dataset(
            server_address=self.server_address,
            environment_spec=environment_spec,
            sequence_length=sequence_length)

        def make_tensor_spec(spec):
            return tf.TensorSpec(shape=(sequence_length, ) + spec.shape,
                                 dtype=spec.dtype)

        expected_spec = tree.map_structure(make_tensor_spec, environment_spec)

        expected_spec = adders.Step(observation=expected_spec.observations,
                                    action=expected_spec.actions,
                                    reward=expected_spec.rewards,
                                    discount=expected_spec.discounts,
                                    start_of_episode=specs.Array(
                                        shape=(sequence_length, ), dtype=bool),
                                    extras=())

        self.assertTrue(_check_specs(expected_spec, dataset.element_spec.data))
Example #25
    def step(self, action: types.NestedArray) -> dm_env.TimeStep:
        """Steps the environment."""
        if self._reset_next_step:
            return self.reset()

        observation, reward, done, info = self._environment.step(action)
        self._reset_next_step = done
        self._last_info = info

        # Convert the type of the reward based on the spec, respecting the scalar or
        # array property.
        reward = tree.map_structure(
            lambda x, t: (  # pylint: disable=g-long-lambda
                t.dtype.type(x)
                if np.isscalar(x) else np.asarray(x, dtype=t.dtype)),
            reward,
            self.reward_spec())

        if done:
            truncated = info.get('TimeLimit.truncated', False)
            if truncated:
                return dm_env.truncation(reward, observation)
            return dm_env.termination(reward, observation)
        return dm_env.transition(reward, observation)
Example #26
    def test_actions_and_log_pis_symbolic(self):
        observation1_np = self.env.reset()
        observation2_np = self.env.step(self.env.action_space.sample())[0]

        observations_np = {}
        for key in observation1_np.keys():
            observations_np[key] = np.stack(
                (observation1_np[key],
                 observation2_np[key])).astype(np.float32)

        observations_tf = tree.map_structure(
            lambda x: tf.constant(x, dtype=tf.float32), observations_np)

        actions = self.policy.actions(observations_tf)
        with self.assertRaises(NotImplementedError):
            log_pis = self.policy.log_pis(observations_tf, actions)

        self.assertEqual(actions.shape, (2, *self.env.action_shape))

        self.evaluate(tf.compat.v1.global_variables_initializer())

        actions_np = self.evaluate(actions)

        self.assertEqual(actions_np.shape, (2, *self.env.action_shape))
Example #27
def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset:
    """Fake dataset of Reverb N-step transition samples.

  Args:
    environment: Used to create a fake transition by looking at the observation,
      action, discount and reward specs.

  Returns:
    tf.data.Dataset that produces the same fake N-step transition ReplaySample
    object indefinitely.
  """

    observation = environment.observation_spec().generate_value()
    action = environment.action_spec().generate_value()
    reward = environment.reward_spec().generate_value()
    discount = environment.discount_spec().generate_value()
    data = types.Transition(observation, action, reward, discount, observation)

    info = tree.map_structure(
        lambda tf_dtype: tf.ones([], tf_dtype.as_numpy_dtype),
        reverb.SampleInfo.tf_dtypes())
    sample = reverb.ReplaySample(info=info, data=data)

    return tf.data.Dataset.from_tensors(sample).repeat()
Example #28
    def prev_action_for(self,
                        agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
        """Returns the previous action for the specified agent, or zeros.

        The "previous" action is the one taken one timestep before the
        most recent action taken by the agent.

        Args:
            agent_id: The agent's ID to get the previous action for.

        Returns:
            Previous action the specified AgentID has executed.
            Zero in case the agent has never performed any actions (or only
            one) in the episode.
        """
        policy_id = self.policy_for(agent_id)
        policy = self.policy_map[policy_id]

        # We are at t > 1 -> There has been a previous action by this agent.
        if agent_id in self._agent_to_prev_action:
            if policy.config.get("_disable_action_flattening"):
                return self._agent_to_prev_action[agent_id]
            else:
                return flatten_to_single_ndarray(
                    self._agent_to_prev_action[agent_id])
        # We're at t <= 1, so return all zeros.
        else:
            if policy.config.get("_disable_action_flattening"):
                return tree.map_structure(
                    lambda a: np.zeros_like(a, a.dtype)
                    if hasattr(a, "dtype")  # noqa
                    else np.zeros_like(a),  # noqa
                    self.last_action_for(agent_id),
                )
            else:
                return np.zeros_like(self.last_action_for(agent_id))
Example #29
    def test_pmap_update_nested(self):
        local_device_count = jax.local_device_count()
        state = running_statistics.init_state({
            'a': specs.Array((5,), jnp.float32),
            'b': specs.Array((2,), jnp.float32),
        })

        x = {
            'a': jnp.arange(15 * local_device_count,
                            dtype=jnp.float32).reshape(local_device_count, 3, 5),
            'b': jnp.arange(6 * local_device_count,
                            dtype=jnp.float32).reshape(local_device_count, 3, 2),
        }

        devices = jax.local_devices()
        state = jax.device_put_replicated(state, devices)
        pmap_axis_name = 'i'
        state = jax.pmap(
            functools.partial(update_and_validate,
                              pmap_axis_name=pmap_axis_name),
            pmap_axis_name)(state, x)
        state = jax.pmap(
            functools.partial(update_and_validate,
                              pmap_axis_name=pmap_axis_name),
            pmap_axis_name)(state, x)
        normalized = jax.pmap(running_statistics.normalize)(x, state)

        mean = tree.map_structure(lambda x: jnp.mean(x, axis=(0, 1)),
                                  normalized)
        std = tree.map_structure(lambda x: jnp.std(x, axis=(0, 1)), normalized)
        tree.map_structure(
            lambda x: self.assert_allclose(x, jnp.zeros_like(x)), mean)
        tree.map_structure(lambda x: self.assert_allclose(x, jnp.ones_like(x)),
                           std)
Example #30
    def get_action_dist(action_space,
                        config,
                        dist_type=None,
                        framework="tf",
                        **kwargs):
        """Returns a distribution class and size for the given action space.

        Args:
            action_space (Space): Action space of the target gym env.
            config (Optional[dict]): Optional model config.
            dist_type (Optional[str]): Identifier of the action distribution
                interpreted as a hint.
            framework (str): One of "tf", "tfe", or "torch".
            kwargs (dict): Optional kwargs to pass on to the Distribution's
                constructor.

        Returns:
            Tuple:
                - dist_class (ActionDistribution): Python class of the
                    distribution.
                - dist_dim (int): The size of the input vector to the
                    distribution.
        """

        dist = None
        config = config or MODEL_DEFAULTS
        # Custom distribution given.
        if config.get("custom_action_dist"):
            action_dist_name = config["custom_action_dist"]
            logger.debug(
                "Using custom action distribution {}".format(action_dist_name))
            dist = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name)
        # Dist_type is given directly as a class.
        elif type(dist_type) is type and \
                issubclass(dist_type, ActionDistribution) and \
                dist_type not in (
                MultiActionDistribution, TorchMultiActionDistribution):
            dist = dist_type
        # Box space -> DiagGaussian OR Deterministic.
        elif isinstance(action_space, gym.spaces.Box):
            if len(action_space.shape) > 1:
                raise UnsupportedSpaceException(
                    "Action space has multiple dimensions "
                    "{}. ".format(action_space.shape) +
                    "Consider reshaping this into a single dimension, "
                    "using a custom action distribution, "
                    "using a Tuple action space, or the multi-agent API.")
            # TODO(sven): Check for bounds and return SquashedNormal, etc..
            if dist_type is None:
                dist = TorchDiagGaussian if framework == "torch" \
                    else DiagGaussian
            elif dist_type == "deterministic":
                dist = TorchDeterministic if framework == "torch" \
                    else Deterministic
        # Discrete Space -> Categorical.
        elif isinstance(action_space, gym.spaces.Discrete):
            dist = TorchCategorical if framework == "torch" else Categorical
        # Tuple/Dict Spaces -> MultiAction.
        elif dist_type in (MultiActionDistribution,
                           TorchMultiActionDistribution) or \
                isinstance(action_space, (gym.spaces.Tuple, gym.spaces.Dict)):
            flat_action_space = flatten_space(action_space)
            child_dists_and_in_lens = tree.map_structure(
                lambda s: ModelCatalog.get_action_dist(
                    s, config, framework=framework), flat_action_space)
            child_dists = [e[0] for e in child_dists_and_in_lens]
            input_lens = [int(e[1]) for e in child_dists_and_in_lens]
            return partial(
                (TorchMultiActionDistribution
                 if framework == "torch" else MultiActionDistribution),
                action_space=action_space,
                child_distributions=child_dists,
                input_lens=input_lens), int(sum(input_lens))
        # Simplex -> Dirichlet.
        elif isinstance(action_space, Simplex):
            if framework == "torch":
                # TODO(sven): implement
                raise NotImplementedError(
                    "Simplex action spaces not supported for torch.")
            dist = Dirichlet
        # MultiDiscrete -> MultiCategorical.
        elif isinstance(action_space, gym.spaces.MultiDiscrete):
            dist = TorchMultiCategorical if framework == "torch" else \
                MultiCategorical
            return partial(dist, input_lens=action_space.nvec), \
                int(sum(action_space.nvec))
        # Unknown type -> Error.
        else:
            raise NotImplementedError("Unsupported args: {} {}".format(
                action_space, dist_type))

        return dist, dist.required_model_output_shape(action_space, config)
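A hypothetical call to the method above (in RLlib it lives on `ModelCatalog`), exercising the single-dimensional Box branch:

import gym
import numpy as np

dist_class, dist_dim = ModelCatalog.get_action_dist(
    gym.spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32),
    config=None,
    framework="torch")
# dist_class is TorchDiagGaussian; dist_dim is the model output size it needs
# (mean plus log-std per action dimension, i.e. 8 here).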