Ejemplo n.º 1
0
    def debatch_timestep(self, ts):
        """Split one batched timestep into a list of per-element timesteps.

        Every non-None leaf of ``ts`` (flattened only down to the structure of
        ``self._traj_spec``) is split along its leading (batch) axis and that
        axis is removed. ``None`` leaves are passed through unchanged into
        every output element.

        Args:
            ts: Nested structure matching ``self._traj_spec``. Its non-None
                leaves are expected to be NumPy arrays that share the same
                leading dimension.

        Returns:
            A list of ``bs`` structures packed like ``self._traj_spec``, where
            ``bs`` is the shared leading dimension. The list is empty when
            ``bs == 0``.

        Raises:
            ValueError: If every leaf is ``None`` (the batch size cannot be
                determined) or if the non-None leaves disagree on their
                leading dimension.
        """
        traj_spec = self._traj_spec

        def split_leading(arr):
            if arr is None:
                return None
            if len(arr) == 0:
                # np.split(arr, 0) raises; an empty batch has no elements.
                return []
            # len(arr) slices of shape (1, ...), each with that axis dropped.
            return [np.squeeze(piece, axis=0)
                    for piece in np.split(arr, len(arr))]

        # Split along the batch dimension, then flatten once for reuse.
        split_tree = nest.map_structure_up_to(traj_spec, split_leading, ts)
        split_leaves = nest.flatten_up_to(traj_spec, split_tree)

        # Determine the batch size; all non-None leaves must agree on it.
        lens = [len(v) for v in split_leaves if v is not None]
        if not lens:
            raise ValueError(
                'Cannot determine batch size: every leaf of ts is None.')
        bs = lens[0]
        if any(n != bs for n in lens):
            raise ValueError(
                'Inconsistent batch sizes across leaves: {}'.format(lens))

        # Rebuild one structure per batch index; None leaves are shared.
        return [
            nest.pack_sequence_as(
                traj_spec,
                [v if v is None else v[i] for v in split_leaves])
            for i in range(bs)
        ]
Ejemplo n.º 2
0
 def split_batch(template, tf_structure):
     """Split each leaf of a batched structure into ``self.batch_size`` parts.

     ``self`` is not a parameter of this function; it is a free variable
     resolved from the enclosing scope (not visible in this snippet).

     Args:
       template: Shallow structure that bounds how deep ``tf_structure`` is
         flattened.
       tf_structure: Nested structure whose leaves (down to ``template``) are
         each passed to ``tf.split``.

     Returns:
       A list of structures packed like ``template``. The k-th entry holds the
       k-th piece of every leaf. The list is empty if there are no leaves.
     """
     leaves = nest.flatten_up_to(template, tf_structure)
     pieces_per_leaf = [tf.split(leaf, self.batch_size) for leaf in leaves]
     # Transpose leaf-major pieces into piece-major groups (tuples).
     regrouped = zip(*pieces_per_leaf)
     return [nest.pack_sequence_as(template, group) for group in regrouped]
Ejemplo n.º 3
0
    def debatch_and_stack(self):
        """Debatch every stored timestep, then stack the results per element.

        Each entry of ``self._trajs`` is treated as one batched timestep
        structured like ``self._traj_spec``. Its non-None leaves are split
        along the leading (batch) axis, and that axis is removed. For each
        batch index, the resulting per-element timesteps are collected in
        ``self._trajs`` order and passed to ``Trajectory._stack``.

        Returns:
            A list with one ``Trajectory._stack(steps, traj_spec=...)`` result
            per batch element. The list is empty when ``self._trajs`` is empty.

        Raises:
            ValueError: If a timestep has only ``None`` leaves, if the leaves
                of one timestep disagree on the batch size, or if timesteps
                disagree with each other on the batch size.
        """
        traj_spec = self._traj_spec

        def split_leading(arr):
            if arr is None:
                return None
            if len(arr) == 0:
                # np.split(arr, 0) raises; an empty batch has no elements.
                return []
            # len(arr) slices of shape (1, ...), each with that axis dropped.
            return [np.squeeze(piece, axis=0)
                    for piece in np.split(arr, len(arr))]

        # One list of debatched timesteps per batch index; sized lazily from
        # the first timestep.
        per_element = None
        for traj in self._trajs:
            # Split along the batch dimension, then flatten once for reuse.
            split_tree = nest.map_structure_up_to(
                traj_spec, split_leading, traj)
            split_leaves = nest.flatten_up_to(traj_spec, split_tree)

            # Determine the batch size; all non-None leaves must agree on it.
            lens = [len(v) for v in split_leaves if v is not None]
            if not lens:
                raise ValueError(
                    'Cannot determine batch size: every leaf is None.')
            bs = lens[0]
            if any(n != bs for n in lens):
                raise ValueError(
                    'Inconsistent batch sizes across leaves: {}'.format(lens))

            if per_element is None:
                per_element = [[] for _ in range(bs)]
            elif bs != len(per_element):
                # Otherwise later steps would index past the end or leave the
                # per-element histories with different lengths.
                raise ValueError(
                    'Batch size {} differs from earlier timesteps ({}).'
                    .format(bs, len(per_element)))

            # Append this timestep's slice to each element's history;
            # None leaves are shared.
            for i, steps in enumerate(per_element):
                steps.append(
                    nest.pack_sequence_as(
                        traj_spec,
                        [v if v is None else v[i] for v in split_leaves]))

        if per_element is None:
            return []
        return [Trajectory._stack(steps, traj_spec=traj_spec)
                for steps in per_element]
Ejemplo n.º 4
0
 def flatten(self, struct_input):
   """Return the leaves of ``struct_input`` down to ``self.template_spec``.

   The leaves are returned as a tuple, in the order produced by
   ``nest.flatten_up_to``.
   """
   leaves = nest.flatten_up_to(self.template_spec, struct_input)
   return tuple(leaves)
Ejemplo n.º 5
0
 def __init__(self, fields, specs, templates):
   """Store the field names and build the spec and template structures.

   Args:
     fields: Passed to ``namedlist``; the result is stored as
       ``self._structure``. The original comment calls these "dict fields";
       TODO confirm the exact expected type.
     specs: Passed to ``self.structure``; the result is stored as
       ``self.spec``.
     templates: Passed to ``self.structure``; the result is stored as
       ``self.template_spec``.
   """
   self.fields = fields # dict fields
   self._structure = namedlist(fields)
   # NOTE(review): this stores ``self._structure``, but the next two lines call
   # ``self.structure``. Presumably a property or method defined elsewhere in
   # the class wraps ``_structure``; confirm. Statement order matters here.
   self.spec = self.structure(specs) # whole data structure
   self.template_spec = self.structure(templates)
   # Leaves of ``self.spec``, flattened only down to the depth of
   # ``self.template_spec``.
   self.flatten_spec = nest.flatten_up_to(self.template_spec, self.spec)