def append_fields(self, fields, elements):
    """Append items from ``elements`` to the buffers named in ``fields``.

    Both arguments are normalized with ``common.as_list`` and paired
    positionally: the i-th element is appended to the buffer of the i-th
    field, so a single field name may be paired with a single element.

    Args:
        fields (str|list[str]): the names used for representing the
            corresponding fields of the buffer.
        elements (list): items to be appended to the corresponding buffer,
            one per field.
    Raises:
        ValueError: if the number of fields differs from the number of
            elements.
    """
    field_names = common.as_list(fields)
    items = common.as_list(elements)
    # zip() would silently drop the unmatched tail of the longer sequence,
    # hiding caller mistakes; fail loudly instead.
    if len(field_names) != len(items):
        raise ValueError(
            "append_fields got %d field(s) but %d element(s)" %
            (len(field_names), len(items)))
    for field, item in zip(field_names, items):
        self._field_to_buffer_mapping[field].append(item)
def train_step(self, exp: TimeStep, state):
    """Compute the multi-step prediction loss for a batch of experiences.

    The observation is encoded into a latent. The latent is rolled forward
    ``self._num_unroll_steps`` steps with ``_multi_step_latent_rollout``
    using ``exp.rollout_info.action``. Every decoder is trained on all
    rolled-out latents, and the per-decoder losses (averaged over dim 0) are
    summed.

    Args:
        exp (TimeStep): experience batch. ``exp.rollout_info`` must provide
            ``target`` (one entry per decoder, each of shape
            [B, num_unroll_steps + 1, ...]), ``action`` and ``mask``.
        state: state of the encoding net.
    Returns:
        AlgStep: ``output`` is the encoded latent, ``state`` is the updated
        encoding-net state, and ``info`` is a ``LossInfo`` whose ``loss``
        and ``extra`` both hold the summed decoder loss.
    """
    rollout_info = exp.rollout_info
    # Each entry: [B, num_unroll_steps + 1, ...]
    per_decoder_targets = common.as_list(rollout_info.target)
    batch_size = exp.step_type.shape[0]
    num_steps = self._num_unroll_steps + 1

    latent, state = self._encoding_net(exp.observation, state)
    sim_latent = self._multi_step_latent_rollout(
        latent, self._num_unroll_steps, rollout_info.action, state)

    def _unflatten_time(x):
        # [(num_unroll_steps + 1) * B, ...] -> [num_unroll_steps + 1, B, ...]
        return x.reshape(num_steps, batch_size, *x.shape[1:])

    total_loss = 0
    for idx, decoder in enumerate(self._decoders):
        # Decoder outputs come back flattened over time and batch:
        # [(num_unroll_steps + 1) * B, ...]
        flat_info = decoder.train_step(sim_latent).info
        info_spec = dist_utils.extract_spec(flat_info)
        # Turn any distributions into their parameters so they can be
        # reshaped, then rebuild them from the reshaped parameters.
        flat_params = dist_utils.distributions_to_params(flat_info)
        time_major_params = alf.nest.map_structure(_unflatten_time,
                                                   flat_params)
        time_major_info = dist_utils.params_to_distributions(
            time_major_params, info_spec)

        # [B, num_unroll_steps + 1, ...] -> [num_unroll_steps + 1, B, ...]
        time_major_target = alf.nest.map_structure(
            lambda t: t.transpose(0, 1), per_decoder_targets[idx])
        step_loss_info = decoder.calc_loss(time_major_target,
                                           time_major_info,
                                           rollout_info.mask.t())
        # Mean over dim 0, the time dimension of the time-major inputs.
        step_loss_info = alf.nest.map_structure(lambda t: t.mean(dim=0),
                                                step_loss_info)
        total_loss += step_loss_info.loss

    loss_info = LossInfo(loss=total_loss, extra=total_loss)
    return AlgStep(output=latent, state=state, info=loss_info)
def pop_fields(self, fields):
    """Remove the last element from each of the named buffers.

    ``append_fields`` adds items to the end of a buffer, so this discards
    the most recently appended item of every listed field.

    Args:
        fields (str|list[str]): the name, or list of names, of the buffers
            to pop from.
    """
    for name in common.as_list(fields):
        buffer = self._field_to_buffer_mapping[name]
        buffer.pop()
def popn_fields(self, fields, n):
    """Remove the first ``n`` elements from each of the named buffers.

    Elements are taken from the front of every buffer, i.e. the oldest
    appended items. This differs from ``pop_fields``, which removes from
    the end.

    Args:
        fields (str|list[str]): the name, or list of names, of the buffers
            to trim.
        n (int): the number of leading elements to remove from each buffer.
    """
    names = common.as_list(fields)
    for name in names:
        buffer = self._field_to_buffer_mapping[name]
        del buffer[:n]
def __init__(self, buffer_fields):
    """Set up a ``RecorderBuffer`` for the given field names.

    Args:
        buffer_fields (str|list[str]): the name, or list of names, of the
            fields; each name identifies one buffer.
    """
    self._buffer_fields = common.as_list(buffer_fields)
    # field name -> buffer; presumably filled in by
    # _create_buffer_for_each_fields below.
    self._field_to_buffer_mapping = {}
    self._create_buffer_for_each_fields()
def get_decoder(self, target_field):
    """Get the decoder which predicts the target specified by
    ``target_field``.

    Args:
        target_field (str): the name of the prediction quantity
            corresponding to the decoder.
    Returns:
        decoder (Algorithm): the decoder registered for ``target_field``.
    Raises:
        ValueError: if no decoder predicts ``target_field``.
    """
    # With a single decoder, ``self._target_fields`` holds that decoder's
    # field(s) directly rather than a one-element list, hence ``as_list``.
    available_fields = common.as_list(self._target_fields)
    try:
        decoder_ind = available_fields.index(target_field)
    except ValueError:
        # Same exception type as before, but say what was available.
        raise ValueError(
            "No decoder predicts target field %r; available target "
            "fields: %s" % (target_field, available_fields)) from None
    return self._decoders[decoder_ind]
def _plot_value_curve(self,
                      name,
                      values,
                      xticks=None,
                      legends=None,
                      fig_size=2,
                      linewidth=2,
                      height=128,
                      width=128):
    """Generate the value curve for elements in values.

    Args:
        name (str): the name (title) of the plot
        values (np.array|list[np.array]): each element from the list
            corresponds to one curve in the generated figure. If values is
            np.array, then a single curve will be generated for values.
        xticks (None|np.array): values for the x-axis of the plot. If None,
            a default value of ``range(len(values[0]))`` will be used.
        legends (None|list[str]): name for each element from values. No
            legends if None is provided
        fig_size (int): the size of the figure
        linewidth (int): the width of the line used in the plot
        height (int): the height of the rendered image in terms of pixels
        width (int): the width of the rendered image in terms of pixels
    Returns:
        the image produced by ``_get_img_from_fig`` for the rendered figure.
    """
    values = common.as_list(values)
    if xticks is None:
        xticks = range(len(values[0]))
    else:
        assert len(xticks) == len(values[0]), (
            "xticks should have the same length as the elements of values")

    fig, ax = plt.subplots(figsize=(fig_size, fig_size))
    try:
        for value in values:
            ax.plot(xticks, value, linewidth=linewidth)
        if legends is not None:
            # Use the explicit axes rather than pyplot's implicit
            # "current axes" state.
            ax.legend(legends)
        ax.set_title(name)
        img = _get_img_from_fig(fig, height=height, width=width)
    finally:
        # Always release the figure, even if plotting or rendering fails,
        # so it does not linger in pyplot's global figure registry.
        plt.close(fig)
    return img
def predict_multi_step(self,
                       init_latent,
                       actions,
                       target_field=None,
                       state=None):
    """Perform multi-step predictions based on the initial latent
    representation and an action sequence.

    The latent is rolled forward ``actions.shape[1]`` steps with
    ``_multi_step_latent_rollout``. The selected decoders are then applied
    to the rolled-out latents via ``predict_step``.

    Args:
        init_latent (Tensor): the latent representation for the initial
            step of the prediction.
        actions (Tensor): [B, unroll_steps, action_dim]; ``unroll_steps``
            must be positive.
        target_field (None|str|list[str]): the name, or a list of names, of
            the quantities to be predicted. It is used for selecting the
            corresponding decoders via ``get_decoder``. If None, all the
            available decoders are used, in construction order.
        state: forwarded to ``_multi_step_latent_rollout`` (``train_step``
            passes the encoding-net state here).
    Returns:
        prediction (Tensor|[Tensor]): the ``info`` of each selected
            decoder's ``predict_step``. A single prediction is returned
            as-is; otherwise a list ordered like the selected decoders.
            Documented shape: [B, unroll_steps + 1, d], where d is the
            dimension of the predicted target.
            NOTE(review): ``train_step`` treats decoder outputs on the
            rolled-out latent as flat [(num_unroll_steps + 1) * B, ...] and
            reshapes them itself; no reshape happens here, so confirm the
            documented shape.
    """
    num_unroll_steps = actions.shape[1]
    assert num_unroll_steps > 0

    sim_latent = self._multi_step_latent_rollout(init_latent,
                                                 num_unroll_steps, actions,
                                                 state)

    if target_field is None:
        predictions = [
            decoder.predict_step(sim_latent).info
            for decoder in self._decoders
        ]
    else:
        predictions = [
            self.get_decoder(field).predict_step(sim_latent).info
            for field in common.as_list(target_field)
        ]

    return predictions[0] if len(predictions) == 1 else predictions
def __init__(self,
             observation_spec,
             action_spec,
             num_unroll_steps,
             decoder_ctor,
             encoding_net_ctor,
             dynamics_net_ctor,
             encoding_optimizer=None,
             dynamics_optimizer=None,
             debug_summaries=False,
             name="PredictiveRepresentationLearner"):
    """
    Args:
        observation_spec (nested TensorSpec): describing the observation.
        action_spec (nested BoundedTensorSpec): describing the action.
        num_unroll_steps (int): the number of future steps to predict. The
            dynamics net is only constructed when this is positive;
            otherwise ``dynamics_net_ctor`` and ``dynamics_optimizer`` are
            not used.
        decoder_ctor (Callable|[Callable]): each individual constructor is
            called as ``decoder_ctor(repr_spec,
            debug_summaries=debug_summaries,
            append_target_field_to_name=True, name=name + ".decoder")``,
            where ``repr_spec`` is the output spec of the encoding net, to
            construct the decoder algorithm. It should follow the
            ``Algorithm`` interface. In addition to the interface of
            ``Algorithm``, it should also implement a member function
            ``get_target_fields()``, which returns a nest of the names of
            target fields. Decoders must not be RNNs (their
            ``train_state_spec`` must be empty). See ``SimpleDecoder`` for
            an example of decoder.
        encoding_net_ctor (Callable): called as
            ``encoding_net_ctor(observation_spec)`` to construct the
            encoding ``Network``. The network takes raw observation as
            input and outputs the latent representation. encoding_net can
            be an RNN.
        dynamics_net_ctor (Callable): called as
            ``dynamics_net_ctor(action_spec)`` to construct the dynamics
            ``Network``. It must be an RNN. The constructed network takes
            action as input and outputs the future latent representation.
            If the state_spec of the dynamics net is exactly the same as the
            state_spec of the encoding net, the current state of the
            encoding net will be used as the initial state of the dynamics
            net. Otherwise, a linear projection will be used to convert the
            current latent representation to the initial state for the
            dynamics net.
        encoding_optimizer (Optimizer|None): if provided, will be used to
            optimize the parameter for the encoding net.
        dynamics_optimizer (Optimizer|None): if provided, will be used to
            optimize the parameter for the dynamics net (and the projection
            layer, when one is created).
        debug_summaries (bool): whether to generate debug summaries
        name (str): name of this instance.
    """
    encoding_net = encoding_net_ctor(observation_spec)
    super().__init__(
        train_state_spec=encoding_net.state_spec,
        debug_summaries=debug_summaries,
        name=name)

    self._encoding_net = encoding_net
    if encoding_optimizer is not None:
        self.add_optimizer(encoding_optimizer, [self._encoding_net])

    repr_spec = self._encoding_net.output_spec

    decoder_ctors = common.as_list(decoder_ctor)
    self._decoders = torch.nn.ModuleList()
    self._target_fields = []
    # Use a distinct loop name so the ``decoder_ctor`` parameter is not
    # shadowed.
    for ctor in decoder_ctors:
        decoder = ctor(
            repr_spec,
            debug_summaries=debug_summaries,
            append_target_field_to_name=True,
            name=name + ".decoder")
        target_field = decoder.get_target_fields()
        self._decoders.append(decoder)
        assert len(alf.nest.flatten(decoder.train_state_spec)) == 0, (
            "RNN decoder is not suported")
        self._target_fields.append(target_field)

    # With a single decoder, store its target field(s) directly instead of
    # a one-element list.
    if len(self._target_fields) == 1:
        self._target_fields = self._target_fields[0]

    self._num_unroll_steps = num_unroll_steps
    self._output_spec = repr_spec
    if num_unroll_steps > 0:
        self._dynamics_net = dynamics_net_ctor(action_spec)
        self._dynamics_state_dims = alf.nest.map_structure(
            lambda spec: spec.numel,
            alf.nest.flatten(self._dynamics_net.state_spec))
        assert sum(
            self._dynamics_state_dims) > 0, ("dynamics_net should be RNN")

        # The encoding-net state can seed the dynamics net only if both
        # state specs have the same structure and equal leaves. Any failure
        # of the structure check is treated as "incompatible".
        try:
            alf.nest.assert_same_structure(self._dynamics_net.state_spec,
                                           self._encoding_net.state_spec)
            compatible_state = all(
                alf.nest.flatten(
                    alf.nest.map_structure(lambda s1, s2: s1 == s2,
                                           self._dynamics_net.state_spec,
                                           self._encoding_net.state_spec)))
        except Exception:
            compatible_state = False

        self._latent_to_dstate_fc = None
        modules = [self._dynamics_net]
        if not compatible_state:
            # FC layer mapping the flattened latent to the concatenated
            # dynamics-net state.
            self._latent_to_dstate_fc = alf.layers.FC(
                repr_spec.numel, sum(self._dynamics_state_dims))
            modules.append(self._latent_to_dstate_fc)

        if dynamics_optimizer is not None:
            self.add_optimizer(dynamics_optimizer, modules)