Example #1
    def test_transform_nest(self):
        # Transform a single dict field addressed by the dotted path 'a.x'.
        ntuple = NTuple(
            a=dict(x=torch.zeros(()), y=torch.zeros((2, 4))),
            b=torch.zeros((4, )))
        transformed_ntuple = transform_nest(
            ntuple, field='a.x', func=lambda x: x + 1.0)
        # Apply the same change to the original by hand and compare.
        ntuple.a.update({'x': torch.ones(())})
        nest.map_structure(self.assertEqual, transformed_ntuple, ntuple)

        # Transform a field nested three levels deep ('b.b.b').
        ntuple = NTuple(
            a=dict(x=torch.zeros(()), y=torch.zeros((2, 4))),
            b=NTuple(a=torch.zeros((4, )), b=NTuple(a=[1], b=[1])))
        transformed_ntuple = transform_nest(
            ntuple, field='b.b.b', func=lambda _: [2])
        ntuple = ntuple._replace(
            b=ntuple.b._replace(b=ntuple.b.b._replace(b=[2])))
        nest.map_structure(self.assertEqual, transformed_ntuple, ntuple)

        # With field=None the function is applied to the whole nest.
        ntuple = NTuple(a=1, b=2)
        transformed_ntuple = transform_nest(ntuple, None, NestSum())
        self.assertEqual(transformed_ntuple, 3)

        # py_map_structure_with_path passes each leaf's path to the function
        # and returns a result with the same structure as the input.
        tuples = [("a", 12), ("b", 13)]
        nested = collections.OrderedDict(tuples)

        def _check_path(path, e):
            self.assertEqual(nested[path], e)

        res = nest.py_map_structure_with_path(_check_path, nested)
        nest.assert_same_structure(nested, res)
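
The dotted-path behavior exercised above can be approximated by the minimal sketch below. It is not the library's transform_nest implementation; transform_field is a hypothetical helper that only handles dict and namedtuple containers, with a dotted field path or None for the whole nest.

def transform_field(nested, field, func):
    """Apply ``func`` to the sub-nest addressed by dotted ``field``.

    If ``field`` is None, ``func`` is applied to ``nested`` as a whole.
    This is an illustrative sketch, not the library's transform_nest.
    """
    if field is None:
        return func(nested)
    name, _, rest = field.partition('.')
    if isinstance(nested, dict):
        # Rebuild the dict so the original nest is left unchanged.
        new = dict(nested)
        new[name] = transform_field(nested[name], rest or None, func)
        return new
    # Assume a namedtuple: rebuild it with the transformed child.
    return nested._replace(
        **{name: transform_field(getattr(nested, name), rest or None, func)})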
Example #2
def observation(self, observation):
    """Apply ``self.transform_observation`` to each configured field."""
    for field in self._fields:
        observation = transform_nest(
            nested=observation,
            field=field,
            func=self.transform_observation)
    return observation
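
To make the per-field loop concrete, here is a hypothetical usage that reuses the transform_field sketch from Example #1 in place of the library's transform_nest; the observation layout and field names are made up for illustration.

import torch

# Illustrative observation with a top-level field and a nested one.
obs = {'image': torch.zeros((3, 8, 8)),
       'state': {'velocity': torch.ones(2)}}

# Each pass returns a new nest with only that field transformed, so the
# loop ends up transforming every listed field.
for field in ['image', 'state.velocity']:
    obs = transform_field(obs, field, lambda x: x * 2.0)
# 'image' and 'state.velocity' are doubled; any other fields are untouched.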
Example #3
def image_scale_transformer(observation, fields=None, min=-1.0, max=1.0):
    """Scale image to min and max (0->min, 255->max).

    Args:
        observation (nested Tensor): If observation is a nested structure, only
            ``namedtuple`` and ``dict`` are supported for now.
        fields (list[str]): the fields to which the transformation is applied.
            If None, then ``observation`` must be a ``Tensor`` with dtype
            ``uint8``. A field str can be a multi-step path denoted by "A.B.C".
        min (float): the value that pixel value 0 is mapped to.
        max (float): the value that pixel value 255 is mapped to.
    Returns:
        Transformed observation.
    """

    def _transform_image(obs):
        assert isinstance(obs, torch.Tensor), str(type(obs)) + ' is not Tensor'
        assert obs.dtype == torch.uint8, "Image must have dtype uint8!"
        obs = obs.type(torch.float32)
        return ((max - min) / 255.) * obs + min

    fields = fields or [None]
    for field in fields:
        observation = nest.transform_nest(
            nested=observation, field=field, func=_transform_image)
    return observation
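
A short usage sketch of the function above, assuming the observation is a dict with a uint8 'image' field (the field name and values are illustrative):

import torch

obs = {'image': torch.tensor([0, 128, 255], dtype=torch.uint8)}
obs = image_scale_transformer(obs, fields=['image'], min=-1.0, max=1.0)
# The linear map (max - min) / 255 * x + min sends 0 -> -1.0, 255 -> 1.0,
# and 128 -> roughly 0.004.
print(obs['image'])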