Example #1
def all_trees_rec_bb(taxa, cm, og):
	'''Return all the trees that can be made out of the set of taxa taxa,
	recursively. The recursive implementation was necessary because otherwise
	taxon was not scoped correctly.'''
	if len(taxa) == 1:
		return [tree.tree(id=taxa.pop())]
	else:
		taxon = taxa.pop()
		trees = (trees_adding_bb(t, taxon, cm, og) for t in all_trees_rec_bb(taxa, cm, og))
		return tree.flatten(trees)
Example #2
    def add_init_obs(
        self,
        episode_id: EpisodeID,
        agent_index: int,
        env_id: EnvID,
        t: int,
        init_obs: TensorType,
    ) -> None:
        """Adds an initial observation (after reset) to the Agent's trajectory.

        Args:
            episode_id (EpisodeID): Unique ID for the episode we are adding the
                initial observation for.
            agent_index (int): Unique int index (starting from 0) for the agent
                within its episode. Not to be confused with AGENT_ID (Any).
            env_id (EnvID): The environment index (in a vectorized setup).
            t (int): The time step (episode length - 1). The initial obs has
                ts=-1(!), then an action/reward/next-obs at t=0, etc.
            init_obs (TensorType): The initial observation tensor (after
                `env.reset()`).
        """
        # Store episode ID + unroll ID, which will be constant throughout this
        # AgentCollector's lifecycle.
        self.episode_id = episode_id
        if self.unroll_id is None:
            self.unroll_id = _AgentCollector._next_unroll_id
            _AgentCollector._next_unroll_id += 1

        if SampleBatch.OBS not in self.buffers:
            self._build_buffers(
                single_row={
                    SampleBatch.OBS: init_obs,
                    SampleBatch.AGENT_INDEX: agent_index,
                    SampleBatch.ENV_ID: env_id,
                    SampleBatch.T: t,
                    SampleBatch.EPS_ID: self.episode_id,
                    SampleBatch.UNROLL_ID: self.unroll_id,
                }
            )

        # Append data to existing buffers.
        flattened = tree.flatten(init_obs)
        for i, sub_obs in enumerate(flattened):
            self.buffers[SampleBatch.OBS][i].append(sub_obs)
        self.buffers[SampleBatch.AGENT_INDEX][0].append(agent_index)
        self.buffers[SampleBatch.ENV_ID][0].append(env_id)
        self.buffers[SampleBatch.T][0].append(t)
        self.buffers[SampleBatch.EPS_ID][0].append(self.episode_id)
        self.buffers[SampleBatch.UNROLL_ID][0].append(self.unroll_id)
Example #3
  def append_sequence(self, sequence: Any):
    """Appends sequence of data to the internal buffer.

    Each element in `sequence` must have the same leading dimension [T].

    A call to `append_sequence` is equivalent to splitting `sequence` along its
    first dimension and calling `append` once for each slice.

    For example:

    ```python

      with client.writer(max_sequence_length=2) as writer:
        sequence = np.array([[1, 2, 3],
                             [4, 5, 6]])

        # Insert two timesteps.
        writer.append_sequence([sequence])

        # Create an item that references the step [4, 5, 6].
        writer.create_item('my_table', num_timesteps=1, priority=1.0)

        # Create an item that references the steps [1, 2, 3] and [4, 5, 6].
        writer.create_item('my_table', num_timesteps=2, priority=1.0)

    ```

    Is equivalent to:

    ```python

      with client.writer(max_sequence_length=2) as writer:
        # Insert two timesteps.
        writer.append([np.array([1, 2, 3])])
        writer.append([np.array([4, 5, 6])])

        # Create an item that references the step [4, 5, 6].
        writer.create_item('my_table', num_timesteps=1, priority=1.0)

        # Create an item that references the steps [1, 2, 3] and [4, 5, 6].
        writer.create_item('my_table', num_timesteps=2, priority=1.0)

    ```

    Args:
      sequence: Batched (possibly nested) structure to make available for items
        to reference.
    """
    self._writer.AppendSequence(tree.flatten(sequence))
Example #4
    def forward(self, input_dict, state, seq_lens):
        if SampleBatch.OBS in input_dict and "obs_flat" in input_dict:
            orig_obs = input_dict[SampleBatch.OBS]
        else:
            orig_obs = restore_original_dimensions(
                input_dict[SampleBatch.OBS], self.processed_obs_space, tensorlib="tf"
            )
        # Push image observations through our CNNs.
        outs = []
        for i, component in enumerate(tree.flatten(orig_obs)):
            if i in self.cnns:
                cnn_out, _ = self.cnns[i](SampleBatch({SampleBatch.OBS: component}))
                outs.append(cnn_out)
            elif i in self.one_hot:
                if "int" in component.dtype.name:
                    one_hot_in = {
                        SampleBatch.OBS: one_hot(
                            component, self.flattened_input_space[i]
                        )
                    }
                else:
                    one_hot_in = {SampleBatch.OBS: component}
                one_hot_out, _ = self.one_hot[i](SampleBatch(one_hot_in))
                outs.append(one_hot_out)
            else:
                nn_out, _ = self.flatten[i](
                    SampleBatch(
                        {
                            SampleBatch.OBS: tf.cast(
                                tf.reshape(component, [-1, self.flatten_dims[i]]),
                                tf.float32,
                            )
                        }
                    )
                )
                outs.append(nn_out)
        # Concat all outputs and the non-image inputs.
        out = tf.concat(outs, axis=1)
        # Push through (optional) FC-stack (this may be an empty stack).
        out, _ = self.post_fc_stack(SampleBatch({SampleBatch.OBS: out}))

        # No logits/value branches.
        if not self.logits_and_value_model:
            return out, []

        # Logits- and value branches.
        logits, values = self.logits_and_value_model(out)
        self._value_out = tf.reshape(values, [-1])
        return logits, []
Example #5
def explicit_ntk(
    fwd_fn: Callable,
    params: hk.Params,
    sigma: hk.Params,
    diag=False,
) -> jnp.ndarray:
    """Calculate J * diag(sigma) * J^T, where J is Jacobian of model with respect to model parameters

   using explicit implementation and einsum

  Args:
    fwd_fn: a function that only takes in parameters and returns model output of
      shape (batch_dim, output_dim).
    params: the model parameters.
    sigma: it has the same structure and array shapes as the parameters of
      model.
    diag: if True, only calculating the diagonal of NTK.

  Returns:
    jnp.ndarray, array of shape (batch_dim, output_dim) if diag==True else
       (batch_dim, output_dim, batch_dim, output_dim)
  """
    jacobian = jax.jacobian(fwd_fn)(params)

    def _get_diag_ntk(jac, sigma):
        # jac has shape (batch_dim, output_dim, params_dims...)
        # jac_2d has shape (batch_dim * output_dim, nb_params)
        batch_dim, output_dim = jac.shape[:2]
        jac_2d = jnp.reshape(jac, (batch_dim * output_dim, -1))
        # sigma_flatten has shape (nb_params,) and will be broadcasted to the same
        # shape as jac_2d
        sigma_flatten = jnp.reshape(sigma, (-1, ))
        # jac_sigma_product has the same shape as jac_2d
        jac_sigma_product = jnp.multiply(jac_2d, sigma_flatten)
        # diag_ntk has shape (batch_dim * output_dim,)
        if diag:
            ntk = jnp.einsum("ij,ji->i", jac_sigma_product, jac_2d.T)
            ntk = jnp.reshape(ntk, (batch_dim, output_dim))
            # ntk has shape (batch_dim, output_dim)
        else:
            ntk = jnp.matmul(jac_sigma_product, jac_2d.T)
            ntk = jnp.reshape(ntk,
                              (batch_dim, output_dim, batch_dim, output_dim))
            # ntk has shape (batch_dim, output_dim, batch_dim, output_dim)
        return ntk

    diag_ntk = tree.map_structure(_get_diag_ntk, jacobian, sigma)
    diag_ntk_sum_array = jnp.stack(tree.flatten(diag_ntk), axis=0).sum(axis=0)
    return diag_ntk_sum_array
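A note on the einsum branch above: `jnp.einsum("ij,ji->i", A, B)` computes only the diagonal of A @ B without materializing the full matrix, which is what makes the `diag=True` path cheaper. A minimal NumPy sketch of that identity (toy shapes, purely illustrative, not part of the original code):

import numpy as np

# Toy shapes: (batch_dim * output_dim) = 4 rows, nb_params = 3 columns.
jac_2d = np.random.randn(4, 3)
sigma_flat = np.random.rand(3)

jac_sigma_product = jac_2d * sigma_flat                        # broadcast over rows
full_ntk = jac_sigma_product @ jac_2d.T                        # full (4, 4) block
diag_only = np.einsum("ij,ji->i", jac_sigma_product, jac_2d.T)  # only its diagonal

np.testing.assert_allclose(diag_only, np.diag(full_ntk))       # identical values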
Example #6
    def add_path(self, path):
        path = path.copy()
        path_length = tree.flatten(path)[0].shape[0]
        path.update({
            'episode_index_forwards': np.arange(
                path_length,
                dtype=self.fields['episode_index_forwards'].dtype
            )[..., np.newaxis],
            'episode_index_backwards': np.arange(
                path_length,
                dtype=self.fields['episode_index_backwards'].dtype
            )[::-1, np.newaxis],
        })

        return self.add_samples(path)
Example #7
    def test_bias_parameter_shape(self):
        params = self.network.init(self.init_rng_key,
                                   _sample_input(self.input_shape))
        self.assertLen(tree.flatten(params), 2)

        def check_params(path, param):
            if path[-1] == 'b':
                self.assertNotEqual(self.output_shape, param.shape)
                chex.assert_shape(param, (1, ))
            elif path[-1] == 'w':
                chex.assert_shape(param, self.weights_shape)
            else:
                self.fail('Unexpected parameter %s.' % path)

        tree.map_structure_with_path(check_params, params)
Example #8
    def actions_np(self, observations):
        if self._deterministic:
            return self.deterministic_actions_model.predict(observations)
        elif self._smoothing_alpha == 0:
            return self.actions_model.predict(observations)
        else:
            alpha, beta = self._smoothing_alpha, self._smoothing_beta
            raw_latents = self.latents_model.predict(observations)
            self._smoothing_x = (alpha * self._smoothing_x +
                                 (1.0 - alpha) * raw_latents)
            latents = beta * self._smoothing_x

            inputs = tree.flatten((observations, latents))
            # TODO(hartikainen/tf2): This can be removed once we use .numpy()
            return self.actions_model_for_fixed_latents.predict(inputs)
Example #9
    def actions(self, observations):
        """Compute actions for given observations."""
        observations = self._filter_observations(observations)

        first_observation = tree.flatten(observations)[0]
        first_input_rank = tf.size(tree.flatten(self._input_shapes)[0])
        batch_shape = tf.shape(first_observation)[:-first_input_rank]

        shifts, scales = self.shift_and_scale_model(observations)

        if self._deterministic:
            actions = self._action_post_processor(shifts)
        else:
            actions = self.action_distribution.sample(batch_shape,
                                                      bijector_kwargs={
                                                          'scale': {
                                                              'scale': scales
                                                          },
                                                          'shift': {
                                                              'shift': shifts
                                                          }
                                                      })

        return actions
Example #10
    def test_ema_on_changing_data(self):
        def f():
            return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

        init_fn, _ = base.transform(f)
        params = init_fn(random.PRNGKey(428))

        def g(x):
            return moving_averages.EMAParamsTree(0.2)(x)

        init_fn, apply_fn = base.transform(g, state=True)
        _, params_state = init_fn(None, params)
        params, params_state = apply_fn(None, params_state, params)
        # Let's modify our params.
        changed_params = tree.map_structure(lambda t: 2. * t, params)
        ema_params, params_state = apply_fn(None, params_state, changed_params)

        # ema_params should be different from changed params!
        tree.assert_same_structure(changed_params, ema_params)
        for p1, p2 in zip(tree.flatten(params), tree.flatten(ema_params)):
            self.assertEqual(p1.shape, p2.shape)
            with self.assertRaisesRegex(AssertionError,
                                        "Not equal to tolerance"):
                np.testing.assert_allclose(p1, p2, atol=1e-6)
Example #11
        def compute_actions_from_input_dict(
            self,
            input_dict: Dict[str, TensorType],
            explore: bool = None,
            timestep: Optional[int] = None,
            episodes: Optional[List[Episode]] = None,
            **kwargs,
        ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:

            if not self.config.get(
                    "eager_tracing") and not tf1.executing_eagerly():
                tf1.enable_eager_execution()

            self._is_training = False

            explore = explore if explore is not None else self.explore
            timestep = timestep if timestep is not None else self.global_timestep
            if isinstance(timestep, tf.Tensor):
                timestep = int(timestep.numpy())

            # Pass lazy (eager) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            input_dict.set_training(False)

            # Pack internal state inputs into (separate) list.
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
            ]
            self._state_in = state_batches
            self._is_recurrent = state_batches != []

            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(timestep=timestep,
                                                    explore=explore,
                                                    tf_sess=self.get_session())

            ret = self._compute_actions_helper(
                input_dict,
                state_batches,
                # TODO: Passing episodes into a traced method does not work.
                None if self.config["eager_tracing"] else episodes,
                explore,
                timestep,
            )
            # Update our global timestep by the batch size.
            self.global_timestep.assign_add(
                tree.flatten(ret[0])[0].shape.as_list()[0])
            return convert_to_numpy(ret)
Example #12
    def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> None:
        """Adds the given dictionary (row) of values to the Agent's trajectory.

        Args:
            values (Dict[str, TensorType]): Data dict (interpreted as a single
                row) to be added to buffer. Must contain keys:
                SampleBatch.ACTIONS, REWARDS, DONES, and NEXT_OBS.
        """
        if self.unroll_id is None:
            self.unroll_id = _AgentCollector._next_unroll_id
            _AgentCollector._next_unroll_id += 1

        # Next obs -> obs.
        assert SampleBatch.OBS not in values
        values[SampleBatch.OBS] = values[SampleBatch.NEXT_OBS]
        del values[SampleBatch.NEXT_OBS]

        # Make sure EPS_ID/UNROLL_ID stay the same for this agent.
        if SampleBatch.EPS_ID in values:
            assert values[SampleBatch.EPS_ID] == self.episode_id
            del values[SampleBatch.EPS_ID]
        self.buffers[SampleBatch.EPS_ID][0].append(self.episode_id)
        if SampleBatch.UNROLL_ID in values:
            assert values[SampleBatch.UNROLL_ID] == self.unroll_id
            del values[SampleBatch.UNROLL_ID]
        self.buffers[SampleBatch.UNROLL_ID][0].append(self.unroll_id)

        for k, v in values.items():
            if k not in self.buffers:
                self._build_buffers(single_row=values)
            # Do not flatten infos, state_out_ and (if configured) actions.
            # Infos/state-outs may be structs that change from timestep to
            # timestep.
            if (
                k == SampleBatch.INFOS
                or k.startswith("state_out_")
                or (
                    k == SampleBatch.ACTIONS
                    and not self.policy.config.get("_disable_action_flattening")
                )
            ):
                self.buffers[k][0].append(v)
            # Flatten all other columns.
            else:
                flattened = tree.flatten(v)
                for i, sub_list in enumerate(self.buffers[k]):
                    sub_list.append(flattened[i])
        self.agent_steps += 1
Example #13
async def materialize_value(value: Any) -> Any:
    """Returns a structure of materialized values.

  Args:
    value: A materialized value, a value reference, or a structure of
      materialized values and value references to materialize.
  """
    async def _materialize(value: Any) -> Any:
        if isinstance(value, MaterializableValueReference):
            return await value.get_value()
        else:
            return value

    flattened = tree.flatten(value)
    flattened = await asyncio.gather(*[_materialize(v) for v in flattened])
    return tree.unflatten_as(value, flattened)
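The flatten / gather / unflatten_as pattern used here is the usual way to apply an async operation to every leaf of a nested structure while preserving its shape. A minimal self-contained sketch of the same pattern (the leaf coroutine is hypothetical, not part of the original code):

import asyncio
import tree

async def _double(leaf):
    # Stand-in for an async per-leaf operation (e.g. resolving a value reference).
    await asyncio.sleep(0)
    return leaf * 2

async def double_structure(struct):
    leaves = tree.flatten(struct)
    doubled = await asyncio.gather(*[_double(v) for v in leaves])
    return tree.unflatten_as(struct, doubled)

# asyncio.run(double_structure({'a': 1, 'b': (2, 3)})) == {'a': 2, 'b': (4, 6)}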
Example #14
    def actions(self, observations):
        if 0 < self._smoothing_alpha:
            raise NotImplementedError(
                "TODO(hartikainen): Smoothing alpha temporarily dropped on tf2"
                " migration. Should add it back. See:"
                " https://github.com/rail-berkeley/softlearning/blob/46374df0294b9b5f6dbe65b9471ec491a82b6944/softlearning/policies/base_policy.py#L80")

        observations = self._filter_observations(observations)

        batch_shape = tf.shape(tree.flatten(observations)[0])[:-1]
        actions = self.action_distribution.sample(
            batch_shape, bijector_kwargs={
                self.flow_model.name: {'observations': observations}
            })

        return actions
Example #15
def batch_concat(inputs: types.NestedTensor) -> tf.Tensor:
    """Concatenate a collection of Tensors while preserving the batch dimension.

  This takes a potentially nested collection of tensors, flattens everything
  but the batch (first) dimension, and concatenates along the resulting data
  (second) dimension.

  Args:
    inputs: a tensor or nested collection of tensors.

  Returns:
    A concatenated tensor which maintains the batch dimension but concatenates
    all other data along the flattened second dimension.
  """
    flat_leaves = tree.map_structure(snt.Flatten(), inputs)
    return tf.concat(tree.flatten(flat_leaves), axis=-1)
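A minimal usage sketch of the helper above (assuming Sonnet and TensorFlow are available; the shapes are illustrative only): `snt.Flatten()` collapses everything except the leading batch dimension of each leaf, and the final concat joins the flattened leaves along the feature axis.

import tensorflow as tf

inputs = {'camera': tf.ones([2, 4, 4, 3]), 'proprio': tf.ones([2, 7])}
out = batch_concat(inputs)
assert out.shape == (2, 4 * 4 * 3 + 7)  # (2, 55): batch dimension preserved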
Example #16
def _tree_merge_into(source, target):
    """Update `target` with content of substructure `source`."""
    path_to_index = {
        path: i
        for i, (path, _) in enumerate(tree.flatten_with_path(target))
    }

    flat_target = tree.flatten(target)
    for path, leaf in tree.flatten_with_path(source):
        if path not in path_to_index:
            raise ValueError(
                f'Cannot expand {source} into {target} as it is not a sub structure.'
            )
        flat_target[path_to_index[path]] = leaf

    return tree.unflatten_as(target, flat_target)
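A small hypothetical usage of the helper above (structures chosen for illustration): only the leaves present in `source` are overwritten, every other leaf of `target` is kept.

target = {'a': 1, 'b': {'c': 2, 'd': 3}}
source = {'b': {'c': 20}}
merged = _tree_merge_into(source, target)
# merged == {'a': 1, 'b': {'c': 20, 'd': 3}}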
Example #17
    def preprocess_value(params, value, weight):
        vectors = tree.map_structure(lambda v: tf.reshape(v, [-1]), value)
        norm = tf.norm(tf.concat(tree.flatten(vectors), axis=0),
                       ord=norm_order)
        quantile_record = quantile_query.preprocess_record(params, norm)

        threshold = params.current_estimate * multiplier
        too_large = (norm > threshold)
        adj_weight = tf.cond(too_large, lambda: tf.constant(0.0),
                             lambda: weight)

        weighted_value = tree.map_structure(
            lambda v: tf.math.multiply_no_nan(v, adj_weight), value)

        too_large = tf.cast(too_large, tf.int32)
        return weighted_value, adj_weight, quantile_record, too_large
Example #18
def random_crop_and_resize(images,
                           target_size,
                           min_crop_fraction=0.5,
                           crop_size=None,
                           methods=tf.image.ResizeMethod.BILINEAR):
    """All tensors are cropped to the same size and resized to target size."""
    if not isinstance(target_size, (tuple, list)):
        target_size = [target_size, target_size]
    aspect_ratio = target_size[0] / target_size[1]
    batch_size = tf.shape(tree.flatten(images)[0])[0]
    bboxes = get_random_bounding_box(hw=tf.ones((batch_size, 2)),
                                     aspect_ratio=aspect_ratio,
                                     min_crop_fraction=min_crop_fraction,
                                     new_height=crop_size,
                                     dtype=tf.float32)
    return crop_and_resize(images, bboxes, target_size, methods)
Example #19
    def __init__(self,
                 environment: dm_env.Environment,
                 random_prob: float = 0.):
        super().__init__(environment)
        self._random_prob = random_prob
        if not 0 <= random_prob <= 1:
            raise ValueError(
                f'random_prob ({random_prob}) must be within [0, 1]')

        self._action_spec = self._environment.action_spec()
        if not all(
                map(lambda spec: isinstance(spec, specs.DiscreteArray),
                    tree.flatten(self._action_spec))):
            raise ValueError(
                'RandomActionWrapper requires the action_spec to be a '
                'DiscreteArray or a nested DiscreteArray.')
Example #20
    def __init__(
        self,
        actor_id,
        environment_module,
        environment_fn_name,
        environment_kwargs,
        network_module,
        network_fn_name,
        network_kwargs,
        adder_module,
        adder_fn_name,
        adder_kwargs,
        replay_server_address,
        variable_server_name,
        variable_server_address,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
    ):
        # Counter and Logger
        self._actor_id = actor_id
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            f'actor_{actor_id}')

        # Create the environment
        self._environment = getattr(environment_module,
                                    environment_fn_name)(**environment_kwargs)
        env_spec = acme.make_environment_spec(self._environment)

        # Create actor's network
        self._network = getattr(network_module,
                                network_fn_name)(**network_kwargs)
        tf2_utils.create_variables(self._network, [env_spec.observations])

        self._variables = tree.flatten(self._network.variables)
        self._policy = tf.function(self._network)

        # The adder is used to insert observations into replay.
        self._adder = getattr(adder_module, adder_fn_name)(
            reverb.Client(replay_server_address), **adder_kwargs)

        variable_client = reverb.TFClient(variable_server_address)
        self._variable_dataset = variable_client.dataset(
            table=variable_server_name,
            dtypes=[tf.float32 for _ in self._variables],
            shapes=[v.shape for v in self._variables])
Example #21
 def add_path(self, path):
     # path: dict 6(1000)
     path = path.copy()
     # when flatten encounters a dict, it only flattens down to the value level
     path_length = tree.flatten(path)[0].shape[0]
     path.update({
         'episode_index_forwards':
         np.arange(path_length,
                   dtype=self.fields['episode_index_forwards'].dtype)[
                       ..., np.newaxis],  #[0...999]
         'episode_index_backwards':
         np.arange(path_length,
                   dtype=self.fields['episode_index_backwards'].dtype)
         [::-1, np.newaxis],  ##[999...0]
     })
     # path: dict(7(1000))
     return self.add_samples(path)
Example #22
    def forward(self, input_dict, state, seq_lens):
        if SampleBatch.OBS in input_dict and "obs_flat" in input_dict:
            orig_obs = input_dict[SampleBatch.OBS]
        else:
            orig_obs = restore_original_dimensions(input_dict[SampleBatch.OBS],
                                                   self.processed_obs_space,
                                                   tensorlib="torch")
        # Push observations through the different components
        # (CNNs, one-hot + FC, etc..).
        outs = []
        for i, component in enumerate(tree.flatten(orig_obs)):
            if i in self.cnns:
                cnn_out, _ = self.cnns[i](SampleBatch(
                    {SampleBatch.OBS: component}))
                outs.append(cnn_out)
            elif i in self.one_hot:
                if component.dtype in [torch.int32, torch.int64, torch.uint8]:
                    one_hot_in = {
                        SampleBatch.OBS:
                        one_hot(component, self.flattened_input_space[i])
                    }
                else:
                    one_hot_in = {SampleBatch.OBS: component}
                one_hot_out, _ = self.one_hot[i](SampleBatch(one_hot_in))
                outs.append(one_hot_out)
            else:
                nn_out, _ = self.flatten[i](SampleBatch({
                    SampleBatch.OBS:
                    torch.reshape(component, [-1, self.flatten_dims[i]])
                }))
                outs.append(nn_out)

        # Concat all outputs and the non-image inputs.
        out = torch.cat(outs, dim=1)
        # Push through (optional) FC-stack (this may be an empty stack).
        out, _ = self.post_fc_stack(SampleBatch({SampleBatch.OBS: out}))

        # No logits/value branches.
        if self.logits_layer is None:
            return out, []

        # Logits- and value branches.
        logits, values = self.logits_layer(out), self.value_layer(out)
        self._value_out = torch.reshape(values, [-1])
        return logits, []
Example #23
    def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
        """Builds the buffers for sample collection, given an example data row.

        Args:
            single_row (Dict[str, TensorType]): A single row (keys=column
                names) of data to base the buffers on.
        """
        for col, data in single_row.items():
            if col in self.buffers:
                continue

            shift = self.shift_before - (
                1
                if col
                in [
                    SampleBatch.OBS,
                    SampleBatch.EPS_ID,
                    SampleBatch.AGENT_INDEX,
                    SampleBatch.ENV_ID,
                    SampleBatch.T,
                    SampleBatch.UNROLL_ID,
                ]
                else 0
            )

            # Store all data as flattened lists, except INFOS and state-out
            # lists. These are monolithic items (infos is a dict that
            # should not be further split, same for state-out items, which
            # could be custom dicts as well).
            if (
                col == SampleBatch.INFOS
                or col.startswith("state_out_")
                or (
                    col == SampleBatch.ACTIONS
                    and not self.policy.config.get("_disable_action_flattening")
                )
            ):
                self.buffers[col] = [[data for _ in range(shift)]]
            else:
                self.buffers[col] = [
                    [v for _ in range(shift)] for v in tree.flatten(data)
                ]
                # Store an example data struct so we know, how to unflatten
                # each data col.
                self.buffer_structs[col] = data
Example #24
def flatten_data(data: Dict[str, TensorStructType]):
    assert (type(data) == dict
            ), "Single agent data must be of type Dict[str, TensorStructType]"

    flattened = {}
    for k, v in data.items():
        if k in [SampleBatch.INFOS, SampleBatch.ACTIONS
                 ] or k.startswith("state_out_"):
            # Do not flatten infos, actions, and state_out_ columns.
            flattened[k] = v
            continue
        if v is None:
            # Keep the same column shape.
            flattened[k] = None
            continue
        flattened[k] = np.array(tree.flatten(v))

    return flattened
Example #25
  def testEntropyGradients(self, is_multi_actions):
    if is_multi_actions:
      loss = self.multi_op.extra.entropy_loss
      policy_logits_nest = self.multi_policy_logits
    else:
      loss = self.op.extra.entropy_loss
      policy_logits_nest = self.policy_logits

    grad_policy_list = [
        tf.gradients(loss, policy_logits)[0] * self.num_actions
        for policy_logits in nest.flatten(policy_logits_nest)]

    for grad_policy in grad_policy_list:
      self.assertEqual(grad_policy.get_shape(), tf.TensorShape([2, 1, 3]))

    self.assertAllEqual(tf.gradients(loss, self.baseline_values), [None])
    self.assertAllEqual(tf.gradients(loss, self.invalid_grad_inputs),
                        self.invalid_grad_outputs)
Example #26
    def testFlattenAndUnflatten_withDicts(self):
        # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
        named_tuple = collections.namedtuple("A", ("b", "c"))
        mess = [
            "z",
            named_tuple(3, 4), {
                "c": [
                    1,
                    collections.OrderedDict([
                        ("b", 3),
                        ("a", 2),
                    ]),
                ],
                "b": 5
            }, 17
        ]

        flattened = tree.flatten(mess)
        self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17])

        structure_of_mess = [
            14,
            named_tuple("a", True),
            {
                "c": [
                    0,
                    collections.OrderedDict([
                        ("b", 9),
                        ("a", 8),
                    ]),
                ],
                "b": 3
            },
            "hi everybody",
        ]

        self.assertEqual(mess, tree.unflatten_as(structure_of_mess, flattened))

        # Check also that the OrderedDict was created, with the correct key order.
        unflattened_ordered_dict = tree.unflatten_as(structure_of_mess,
                                                     flattened)[2]["c"][1]
        self.assertIsInstance(unflattened_ordered_dict,
                              collections.OrderedDict)
        self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
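The leaf order asserted above follows dm-tree's convention: dictionary values are flattened in sorted-key order, while `unflatten_as` restores the original structure (including the `OrderedDict` key order). A minimal sketch of the sorted-key behaviour:

import tree

assert tree.flatten({'b': 1, 'a': 2}) == [2, 1]      # dict keys visited in sorted order
assert tree.flatten([{'b': 1}, {'a': 2}]) == [1, 2]  # sequence order preserved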
Example #27
    def __init__(self,
                 client: core.VariableSource,
                 variables: Mapping[str, Sequence[tf.Variable]],
                 update_period: int = 1):
        self._keys = list(variables.keys())
        self._variables = tree.flatten(list(variables.values()))
        self._call_counter = 0
        self._update_period = update_period
        self._client = client
        self._request = lambda: client.get_variables(self._keys)

        # Create a single background thread to fetch variables without necessarily
        # blocking the actor.
        self._executor = futures.ThreadPoolExecutor(max_workers=1)
        self._async_request = lambda: self._executor.submit(self._request)

        # Initialize this client's future to None to indicate to the `update()`
        # method that there is no pending/running request.
        self._future: Optional[futures.Future] = None
Example #28
    def compute_actions(
        self,
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorType], TensorType] = None,
        prev_reward_batch: Union[List[TensorType], TensorType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes: Optional[List["Episode"]] = None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs,
    ):

        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        builder = TFRunBuilder(self.get_session(), "compute_actions")

        input_dict = {SampleBatch.OBS: obs_batch, "is_training": False}
        if state_batches:
            for i, s in enumerate(state_batches):
                input_dict[f"state_in_{i}"] = s
        if prev_action_batch is not None:
            input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
        if prev_reward_batch is not None:
            input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch

        to_fetch = self._build_compute_actions(
            builder, input_dict=input_dict, explore=explore, timestep=timestep
        )

        # Execute session run to get action (and other fetches).
        fetched = builder.get(to_fetch)

        # Update our global timestep by the batch size.
        self.global_timestep += (
            len(obs_batch)
            if isinstance(obs_batch, list)
            else tree.flatten(obs_batch)[0].shape[0]
        )

        return fetched
Example #29
def flatten_data(data: Dict[str, TensorStructType]):
    assert isinstance(
        data,
        dict), "Single agent data must be of type Dict[str, TensorStructType]"

    flattened = {}
    for k, v in data.items():
        if k in [SampleBatch.INFOS, SampleBatch.ACTIONS
                 ] or k.startswith("state_out_"):
            # Do not flatten infos, actions, and state_out_ columns.
            flattened[k] = v
            continue
        if v is None:
            # Keep the same column shape.
            flattened[k] = None
            continue
        flattened[k] = np.array(tree.flatten(v))
    flattened = SampleBatch(flattened, is_training=False)

    return AgentConnectorsOutput(data, flattened)
Example #30
    async def release(self, value: Any, type_signature: computation_types.Type,
                      key: int) -> None:  # pytype: disable=signature-mismatch
        """Releases `value` from a federated program.

    Args:
      value: A materialized value, a value reference, or a structure of
        materialized values and value references representing the value to
        release.
      type_signature: The `tff.Type` of `value`.
      key: An integer used to reference the released `value`.
    """
        del type_signature  # Unused.
        py_typecheck.check_type(key, int)

        path = self._get_path_for_key(key)
        materialized_value = await value_reference.materialize_value(value)
        flattened_value = tree.flatten(materialized_value)
        await file_utils.write_saved_model(flattened_value,
                                           path,
                                           overwrite=True)
Example #31
    def _build_signature_def(self):
        """Build signature def map for tensorflow SavedModelBuilder.
        """
        # build input signatures
        input_signature = self._extra_input_signature_def()
        input_signature["observations"] = \
            tf.saved_model.utils.build_tensor_info(self._obs_input)

        if self._seq_lens is not None:
            input_signature["seq_lens"] = \
                tf.saved_model.utils.build_tensor_info(self._seq_lens)
        if self._prev_action_input is not None:
            input_signature["prev_action"] = \
                tf.saved_model.utils.build_tensor_info(self._prev_action_input)
        if self._prev_reward_input is not None:
            input_signature["prev_reward"] = \
                tf.saved_model.utils.build_tensor_info(self._prev_reward_input)
        input_signature["is_training"] = \
            tf.saved_model.utils.build_tensor_info(self._is_training)

        for state_input in self._state_inputs:
            input_signature[state_input.name] = \
                tf.saved_model.utils.build_tensor_info(state_input)

        # build output signatures
        output_signature = self._extra_output_signature_def()
        for i, a in enumerate(tree.flatten(self._sampled_action)):
            output_signature["actions_{}".format(i)] = \
                tf.saved_model.utils.build_tensor_info(a)

        for state_output in self._state_outputs:
            output_signature[state_output.name] = \
                tf.saved_model.utils.build_tensor_info(state_output)
        signature_def = (
            tf.saved_model.signature_def_utils.build_signature_def(
                input_signature, output_signature,
                tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
        signature_def_key = (tf.saved_model.signature_constants.
                             DEFAULT_SERVING_SIGNATURE_DEF_KEY)
        signature_def_map = {signature_def_key: signature_def}
        return signature_def_map