Exemplo n.º 1
0
 def append_fields(self, fields, elements):
     """Append each element to the buffer registered under its paired field.

     Both ``fields`` and ``elements`` are normalized with ``common.as_list``
     and then paired positionally; pairing stops at the shorter of the two.

     Args:
         fields (str|list[str]): name(s) of the buffer field(s) to append to.
         elements (list): item(s) to append, one per entry of ``fields``.
     """
     field_names = common.as_list(fields)
     items = common.as_list(elements)
     for name, item in zip(field_names, items):
         target_buffer = self._field_to_buffer_mapping[name]
         target_buffer.append(item)
    def train_step(self, exp: TimeStep, state):
        """Compute the multi-step prediction loss for a batch of experiences.

        The current observation is encoded into a latent representation, which
        is unrolled for ``self._num_unroll_steps`` steps using the actions in
        ``exp.rollout_info``. Each decoder is trained on the unrolled latents
        against its own target, and the per-decoder losses are summed.

        Args:
            exp (TimeStep): experience batch; ``exp.rollout_info`` provides
                ``target``, ``action`` and ``mask`` (batch-major,
                ``[B, num_unroll_steps + 1, ...]``).
            state: state of the encoding network.
        Returns:
            AlgStep: ``output`` is the encoded latent, ``state`` is the updated
            encoding-net state, and ``info`` is a ``LossInfo`` holding the
            summed decoder loss in both ``loss`` and ``extra``.
        """
        rollout_info = exp.rollout_info
        per_decoder_targets = common.as_list(rollout_info.target)
        batch_size = exp.step_type.shape[0]
        latent, state = self._encoding_net(exp.observation, state)

        unrolled = self._multi_step_latent_rollout(
            latent, self._num_unroll_steps, rollout_info.action, state)
        num_steps = self._num_unroll_steps + 1

        total_loss = 0
        for idx, decoder in enumerate(self._decoders):
            # Decoder output, flattened over time and batch:
            # [(num_unroll_steps + 1) * B, ...]
            flat_info = decoder.train_step(unrolled).info
            info_spec = dist_utils.extract_spec(flat_info)
            flat_params = dist_utils.distributions_to_params(flat_info)
            # Split the leading dim into [num_unroll_steps + 1, B, ...].
            step_params = alf.nest.map_structure(
                lambda p: p.reshape(num_steps, batch_size, *p.shape[1:]),
                flat_params)
            step_info = dist_utils.params_to_distributions(
                step_params, info_spec)
            # Targets are batch-major; swap the first two dims to be time-major.
            step_target = alf.nest.map_structure(lambda t: t.transpose(0, 1),
                                                 per_decoder_targets[idx])
            step_mask = rollout_info.mask.t()
            decoder_loss = decoder.calc_loss(step_target, step_info, step_mask)
            # Average every loss field over the leading (time) dimension.
            decoder_loss = alf.nest.map_structure(lambda v: v.mean(dim=0),
                                                  decoder_loss)
            total_loss += decoder_loss.loss

        return AlgStep(output=latent,
                       state=state,
                       info=LossInfo(loss=total_loss, extra=total_loss))
Exemplo n.º 3
0
 def pop_fields(self, fields):
     """Call ``pop()`` on the buffer of every named field.

     For list-like buffers this removes the most recently appended element.

     Args:
         fields (str|list[str]): name(s) of the buffer field(s) to pop from.
     """
     names = common.as_list(fields)
     for name in names:
         target_buffer = self._field_to_buffer_mapping[name]
         target_buffer.pop()
Exemplo n.º 4
0
 def popn_fields(self, fields, n):
     """Delete the leading ``n`` elements from the buffer of every named field.

     The deletion is ``del buffer[:n]``, so for list-like buffers the
     elements are taken from the front (the oldest entries), not the end.

     Args:
         fields (str|list[str]): name(s) of the buffer field(s) to trim.
         n (int): the number of elements to remove from each buffer.
     """
     names = common.as_list(fields)
     for name in names:
         target_buffer = self._field_to_buffer_mapping[name]
         del target_buffer[:n]
Exemplo n.º 5
0
 def __init__(self, buffer_fields):
     """Set up a recorder buffer that keeps one buffer per field name.

     Args:
         buffer_fields (str|list[str]): name(s) of the fields to keep a
             buffer for.
     """
     # Field name -> buffer; presumably populated by
     # ``_create_buffer_for_each_fields`` below (not visible here).
     self._field_to_buffer_mapping = {}
     self._buffer_fields = common.as_list(buffer_fields)
     self._create_buffer_for_each_fields()
 def get_decoder(self, target_field):
     """Get the decoder which predicts the target specified by ``target_field``.

     Args:
         target_field (str): the name of the prediction quantity corresponding
             to the decoder.
     Returns:
         decoder (Algorithm): the entry of ``self._decoders`` at the position
             of ``target_field`` within ``common.as_list(self._target_fields)``.
     Raises:
         ValueError: if ``target_field`` is not one of the known target fields.
     """
     known_fields = common.as_list(self._target_fields)
     try:
         decoder_ind = known_fields.index(target_field)
     except ValueError:
         # Re-raise with the offending name and the available choices instead
         # of the bare "x is not in list" message produced by ``list.index``.
         raise ValueError("Unknown target field %r; available target fields: "
                          "%r" % (target_field, known_fields)) from None
     # NOTE(review): this assumes ``self._target_fields`` lines up one-to-one
     # with ``self._decoders``. If a single decoder reports several target
     # fields as a list, the index above is a position inside that list rather
     # than a decoder index -- confirm whether such decoders are supported.
     return self._decoders[decoder_ind]
Exemplo n.º 7
0
    def _plot_value_curve(self,
                          name,
                          values,
                          xticks=None,
                          legends=None,
                          fig_size=2,
                          linewidth=2,
                          height=128,
                          width=128):
        """Generate the value curve for elements in values.

        Args:
            name (str): the name (title) of the plot
            values (np.array|list[np.array]): each element from the list
                corresponds to one curve in the generated figure. If values
                is np.array, then a single curve will be generated for values.
            xticks (None|np.array): values for the x-axis of the plot. If None,
                a default value of ``range(len(values[0]))`` will be used.
                Otherwise it must have the same length as ``values[0]``.
            legends (None|list[str]): name for each element from values. No
                legends if None is provided
            fig_size (int): the size of the figure
            linewidth (int): the width of the line used in the plot
            height (int): the height of the rendered image in terms of pixels
            width (int): the width of the rendered image in terms of pixels
        Returns:
            The rendered image, as returned by ``_get_img_from_fig``.
        """

        values = common.as_list(values)

        if xticks is None:
            xticks = range(len(values[0]))
        else:
            assert len(xticks) == len(
                values[0]), ("xticks should have the "
                             "same length as the elements of values")

        fig, ax = plt.subplots(figsize=(fig_size, fig_size))
        try:
            for value in values:
                ax.plot(xticks, value, linewidth=linewidth)
            if legends is not None:
                # Attach the legend to this figure's axes explicitly instead of
                # relying on pyplot's implicit "current axes" state.
                ax.legend(legends)
            ax.set_title(name)

            img = _get_img_from_fig(fig, height=height, width=width)
        finally:
            # Always release the figure, even if plotting or rendering fails,
            # so pyplot does not keep accumulating open figures.
            plt.close(fig)
        return img
    def predict_multi_step(self,
                           init_latent,
                           actions,
                           target_field=None,
                           state=None):
        """Perform multi-step predictions based on the initial latent
            representation and actions sequences.

        The initial latent is unrolled for ``actions.shape[1]`` steps with
        ``_multi_step_latent_rollout``, and the selected decoders are applied
        to the resulting latent sequence.

        Args:
            init_latent (Tensor): the latent representation for the initial
                step of the prediction
            actions (Tensor): [B, unroll_steps, action_dim]; ``unroll_steps``
                must be positive.
            target_field (None|str|[str]): the name or a list of names of the
                quantities to be predicted. It is used for selecting the
                corresponding decoder(s) via ``get_decoder``. If None, all the
                available decoders will be used for generating predictions.
            state: passed unchanged to ``_multi_step_latent_rollout`` together
                with ``init_latent``.
        Returns:
            prediction (Tensor|[Tensor]): predicted target of shape
                [B, unroll_steps + 1, d], where d is the dimension of
                the predicted target. The return is a list of Tensors when
                there are multiple targets to be predicted.
        """
        num_unroll_steps = actions.shape[1]

        assert num_unroll_steps > 0

        sim_latent = self._multi_step_latent_rollout(init_latent,
                                                     num_unroll_steps, actions,
                                                     state)

        # Use an identity test: ``== None`` can be overloaded (e.g. it is
        # element-wise for array-like inputs) and is not a reliable None check.
        if target_field is None:
            decoders = self._decoders
        else:
            decoders = [
                self.get_decoder(field)
                for field in common.as_list(target_field)
            ]

        predictions = [
            decoder.predict_step(sim_latent).info for decoder in decoders
        ]

        return predictions[0] if len(predictions) == 1 else predictions
    def __init__(self,
                 observation_spec,
                 action_spec,
                 num_unroll_steps,
                 decoder_ctor,
                 encoding_net_ctor,
                 dynamics_net_ctor,
                 encoding_optimizer=None,
                 dynamics_optimizer=None,
                 debug_summaries=False,
                 name="PredictiveRepresentationLearner"):
        """Construct the encoding net, the decoder(s) and the dynamics net.

        Args:
            observation_spec (nested TensorSpec): describing the observation.
            action_spec (nested BoundedTensorSpec): describing the action.
            num_unroll_steps (int): the number of future steps to predict. The
                dynamics net (and its optional state projection) is only
                constructed when this is positive.
            decoder_ctor (Callable|[Callable]): each individual constructor is
                called as ``decoder_ctor(repr_spec, debug_summaries=...,
                append_target_field_to_name=True, name=name + ".decoder")``,
                where ``repr_spec`` is the output spec of the encoding net, to
                construct the decoder algorithm. It should follow the
                ``Algorithm`` interface. In addition to the interface of
                ``Algorithm``, it should also implement a member function
                ``get_target_fields()``, which returns a nest of the names of
                target fields. Decoders must be stateless: a decoder with a
                non-empty ``train_state_spec`` (an RNN decoder) is rejected.
                See ``SimpleDecoder`` for an example of decoder.
            encoding_net_ctor (Callable): called as
                ``encoding_net_ctor(observation_spec)`` to construct the
                encoding ``Network``. The network takes raw observation as
                input and outputs the latent representation. encoding_net can
                be an RNN; its ``state_spec`` is used as the
                ``train_state_spec`` of this algorithm.
            dynamics_net_ctor (Callable): called as
                ``dynamics_net_ctor(action_spec)`` to construct the dynamics
                ``Network``. It must be an RNN (the total size of its
                flattened state must be positive). The constructed network
                takes action as input and outputs the future latent
                representation. If the state_spec of the dynamics net is
                exactly the same as the state_spec of the encoding net, the
                current state of the encoding net will be used as the initial
                state of the dynamics net. Otherwise, a linear projection will
                be used to convert the current latent representation to the
                initial state for the dynamics net.
            encoding_optimizer (Optimizer|None): if provided, will be used to
                optimize the parameters of the encoding net.
            dynamics_optimizer (Optimizer|None): if provided, will be used to
                optimize the parameters of the dynamics net (and of the
                latent-to-state projection, if one is created).
            debug_summaries (bool): whether to generate debug summaries
            name (str): name of this instance.
        """
        encoding_net = encoding_net_ctor(observation_spec)
        super().__init__(train_state_spec=encoding_net.state_spec,
                         debug_summaries=debug_summaries,
                         name=name)

        self._encoding_net = encoding_net
        if encoding_optimizer is not None:
            self.add_optimizer(encoding_optimizer, [self._encoding_net])
        # Spec of the latent representation produced by the encoding net. It
        # is the input spec of every decoder and the output spec of this
        # algorithm.
        repr_spec = self._encoding_net.output_spec

        decoder_ctors = common.as_list(decoder_ctor)

        self._decoders = torch.nn.ModuleList()
        # One entry per decoder: the nest returned by its get_target_fields().
        self._target_fields = []

        # Note: the loop variable rebinds the ``decoder_ctor`` argument name;
        # the full list of constructors is ``decoder_ctors``.
        for decoder_ctor in decoder_ctors:
            decoder = decoder_ctor(repr_spec,
                                   debug_summaries=debug_summaries,
                                   append_target_field_to_name=True,
                                   name=name + ".decoder")
            target_field = decoder.get_target_fields()
            self._decoders.append(decoder)

            assert len(alf.nest.flatten(decoder.train_state_spec)) == 0, (
                "RNN decoder is not suported")
            self._target_fields.append(target_field)

        # With a single decoder, store its target-field nest directly instead
        # of a one-element list.
        if len(self._target_fields) == 1:
            self._target_fields = self._target_fields[0]

        self._num_unroll_steps = num_unroll_steps

        self._output_spec = repr_spec

        if num_unroll_steps > 0:
            self._dynamics_net = dynamics_net_ctor(action_spec)
            # Number of elements in each flattened state tensor of the
            # dynamics net.
            self._dynamics_state_dims = alf.nest.map_structure(
                lambda spec: spec.numel,
                alf.nest.flatten(self._dynamics_net.state_spec))
            assert sum(
                self._dynamics_state_dims) > 0, ("dynamics_net should be RNN")
            compatible_state = True

            # The encoding-net state can seed the dynamics net only if both
            # state specs share the same structure and every leaf spec
            # compares equal. Any exception raised during the comparison
            # (e.g. a structure mismatch) is treated as "incompatible".
            try:
                alf.nest.assert_same_structure(self._dynamics_net.state_spec,
                                               self._encoding_net.state_spec)
                compatible_state = all(
                    alf.nest.flatten(
                        alf.nest.map_structure(lambda s1, s2: s1 == s2,
                                               self._dynamics_net.state_spec,
                                               self._encoding_net.state_spec)))
            except Exception:
                compatible_state = False

            self._latent_to_dstate_fc = None
            # Modules handed to ``dynamics_optimizer`` (if one is given).
            modules = [self._dynamics_net]
            if not compatible_state:
                # Project the latent (repr_spec.numel values) to a flat vector
                # whose size is the total dynamics-net state size.
                self._latent_to_dstate_fc = alf.layers.FC(
                    repr_spec.numel, sum(self._dynamics_state_dims))
                modules.append(self._latent_to_dstate_fc)

            if dynamics_optimizer is not None:
                self.add_optimizer(dynamics_optimizer, modules)