Exemplo n.º 1
0
    def test_transform_nest(self):
        """Exercise ``transform_nest`` and ``nest.py_map_structure_with_path``.

        Covers: transforming a leaf inside a dict field, transforming a leaf
        deep inside nested ``NTuple``s, reducing a whole nest with
        ``NestSum`` (``field=None``), and mapping with paths over an
        ``OrderedDict``.
        """
        check_equal = self.assertEqual

        # Case 1: add 1.0 to the scalar stored at 'a.x'.
        source = NTuple(
            a=dict(x=torch.zeros(()), y=torch.zeros((2, 4))),
            b=torch.zeros((4, )))
        result = transform_nest(source, field='a.x', func=lambda t: t + 1.0)
        # The expected nest is obtained by patching the input's dict in place.
        source.a['x'] = torch.ones(())
        nest.map_structure(check_equal, result, source)

        # Case 2: replace the list stored at 'b.b.b'.
        source = NTuple(
            a=dict(x=torch.zeros(()), y=torch.zeros((2, 4))),
            b=NTuple(a=torch.zeros((4, )), b=NTuple(a=[1], b=[1])))
        result = transform_nest(source, field='b.b.b', func=lambda _: [2])
        patched_inner = source.b.b._replace(b=[2])
        expected = source._replace(b=source.b._replace(b=patched_inner))
        nest.map_structure(check_equal, result, expected)

        # Case 3: with no field, the function is applied to the whole nest.
        self.assertEqual(transform_nest(NTuple(a=1, b=2), None, NestSum()), 3)

        # Case 4: map with paths over an ordered dict; each path is a key.
        nested = collections.OrderedDict([("a", 12), ("b", 13)])

        def _verify(key, value):
            check_equal(nested[key], value)

        mapped = nest.py_map_structure_with_path(_verify, nested)
        nest.assert_same_structure(nested, mapped)
Exemplo n.º 2
0
 def _predict(self,
              inputs=None,
              noise=None,
              batch_size=None,
              training=True):
     """Run the generator network on noise and optional conditioning inputs.

     Args:
         inputs: optional nest of conditioning inputs. When given, it must
             match ``self._input_tensor_spec`` and determines the batch size.
             When omitted, ``self._input_tensor_spec`` must be ``None``.
         noise: optional noise tensor. If not given, it is sampled with
             ``torch.randn(batch_size, self._noise_dim)``.
         batch_size: number of samples to draw when neither ``inputs`` nor
             ``noise`` is given. It is ignored when ``inputs`` is given,
             because the batch size is then taken from ``inputs``.
         training: when False and ``self._predict_net`` is set (truthy),
             ``self._predict_net`` is used instead of ``self._net``.

     Returns:
         tuple: ``(outputs, gen_inputs)``, where ``outputs`` is the first
         element returned by the selected network and ``gen_inputs`` is
         ``noise`` (unconditional case) or ``[noise, inputs]``.
     """
     if inputs is not None:
         nest.assert_same_structure(inputs, self._input_tensor_spec)
         batch_size = nest.get_nest_batch_size(inputs)
         if noise is not None:
             # Caller-provided noise must agree with the inputs' batch and
             # with the configured noise dimension.
             assert noise.shape[0] == batch_size
             assert noise.shape[1] == self._noise_dim
         else:
             noise = torch.randn(batch_size, self._noise_dim)
         gen_inputs = [noise, inputs]
     else:
         assert self._input_tensor_spec is None
         if noise is None:
             assert batch_size is not None
             noise = torch.randn(batch_size, self._noise_dim)
         gen_inputs = noise

     if self._predict_net and not training:
         network = self._predict_net
     else:
         network = self._net
     outputs = network(gen_inputs)[0]
     return outputs, gen_inputs
Exemplo n.º 3
0
    def build_distribution(self, input_params):
        """Build a Distribution from ``input_params``.

        Args:
            input_params (nested Tensor): the parameters used to build the
                distribution. Its structure must match ``input_params_spec``
                provided to ``__init__``. It is unpacked as keyword arguments
                into ``self.builder``, so it must be a mapping.
        Returns:
            Distribution: the object returned by ``self.builder`` for these
            parameters.
        """
        expected_spec = self.input_params_spec
        nest.assert_same_structure(input_params, expected_spec)
        make_distribution = self.builder
        return make_distribution(**input_params)
Exemplo n.º 4
0
def compute_log_probability(distributions, actions):
    """Compute the combined log probability of ``actions`` under ``distributions``.

    ``distributions`` and ``actions`` must be nests with the same structure.
    Each action component is scored with the ``log_prob`` method of the
    matching distribution. The per-component results are then added together
    with the builtin ``sum``. This is an element-wise addition across nest
    components; no tensor dimension (such as a batch dimension) is reduced.

    Args:
        distributions: A possibly batched (nested) tuple of distributions.
        actions: A possibly batched (nested) action tuple with the same
            structure as ``distributions``.

    Returns:
        Tensor: the sum of the per-component log probabilities. An empty
        nest yields the integer ``0`` from the builtin ``sum``.
    """
    nest.assert_same_structure(distributions, actions)
    per_component = nest.map_structure(
        lambda dist, act: dist.log_prob(act), distributions, actions)
    return sum(nest.flatten(per_component))