Example #1
    def _graph_fn_get_probabilities_log_probs(self, logits):
        """
        Creates probabilities/parameters and log-probs from some reshaped output.

        Args:
            logits (SingleDataOp): The output of some layer that is already reshaped
                according to our action Space.

        Returns:
            tuple (2x SingleDataOp):
                parameters (DataOp): The parameters, ready to be passed to a Distribution object's
                    get_distribution API-method (usually some probabilities or loc/scale pairs).
                log_probs (DataOp): Simply the log(parameters).
        """
        if get_backend() == "tf":
            if isinstance(self.action_space, IntBox):
                # Discrete actions.
                parameters = tf.maximum(x=tf.nn.softmax(logits=logits,
                                                        axis=-1),
                                        y=SMALL_NUMBER)
                parameters._batch_rank = 0
                # Log probs.
                log_probs = tf.log(x=parameters)
                log_probs._batch_rank = 0
            elif isinstance(self.action_space, FloatBox):
                # Continuous actions.
                mean, log_sd = tf.split(value=logits,
                                        num_or_size_splits=2,
                                        axis=1)
                # Remove moments rank.
                mean = tf.squeeze(input=mean, axis=1)
                log_sd = tf.squeeze(input=log_sd, axis=1)

                # Clip log_sd. log(SMALL_NUMBER) is negative.
                log_sd = tf.clip_by_value(t=log_sd,
                                          clip_value_min=log(SMALL_NUMBER),
                                          clip_value_max=-log(SMALL_NUMBER))

                # Turn log sd into sd.
                sd = tf.exp(x=log_sd)

                parameters = DataOpTuple(mean, sd)
                # Note: tf.log(mean) is NaN for negative means (and -inf at 0);
                # kept as-is to match the docstring ("simply the log(parameters)").
                log_probs = DataOpTuple(tf.log(x=mean), log_sd)
            else:
                raise NotImplementedError

            return parameters, log_probs

        elif get_backend() == "pytorch":
            if isinstance(self.action_space, IntBox):
                # Discrete actions.
                softmax_logits = torch.softmax(logits, dim=-1)
                parameters = torch.max(softmax_logits, SMALL_NUMBER_TORCH)
                # Log probs.
                log_probs = torch.log(parameters)
            elif isinstance(self.action_space, FloatBox):
                # Continuous actions.
                # NOTE: torch.split takes a chunk *size* (unlike tf.split's
                # number of splits), so split with size 1 along the moments
                # rank to obtain the two (mean, log_sd) pieces.
                mean, log_sd = torch.split(logits,
                                           split_size_or_sections=1,
                                           dim=1)
                # Remove moments rank.
                mean = torch.squeeze(mean, dim=1)
                log_sd = torch.squeeze(log_sd, dim=1)

                # Clip log_sd. log(SMALL_NUMBER) is negative.
                log_sd = torch.clamp(log_sd,
                                     min=LOG_SMALL_NUMBER,
                                     max=-LOG_SMALL_NUMBER)

                # Turn log sd into sd.
                sd = torch.exp(log_sd)

                parameters = DataOpTuple(mean, sd)
                log_probs = DataOpTuple(torch.log(mean), log_sd)
            else:
                raise NotImplementedError

            return parameters, log_probs
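
A quick standalone sketch of how the discrete branch above is typically consumed, written against the PyTorch backend (the 1e-6 floor standing in for SMALL_NUMBER_TORCH and the use of torch.distributions.Categorical are assumptions for illustration, not part of the snippet):

    # Replicate the discrete branch: floor the softmax output, take logs,
    # then build a categorical distribution from the resulting parameters.
    import torch

    SMALL_NUMBER_TORCH = torch.tensor(1e-6)  # assumed scalar numeric floor

    logits = torch.randn(4, 3)  # batch of 4 items, 3 discrete actions
    parameters = torch.max(torch.softmax(logits, dim=-1), SMALL_NUMBER_TORCH)
    log_probs = torch.log(parameters)  # safe: parameters >= 1e-6

    dist = torch.distributions.Categorical(probs=parameters)
    actions = dist.sample()  # shape: (4,)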
Example #2
    def _graph_fn_call(self,
                       inputs,
                       initial_c_and_h_states=None,
                       sequence_length=None):
        """
        Args:
            inputs (SingleDataOp): The data to pass through the layer (batch of n items, m timesteps).
                The position of the batch- and time-ranks in the input depends on the `self.time_major` setting.

            initial_c_and_h_states (DataOpTuple): The initial cell- and hidden-states to use.
                None for the default behavior (all zeros).
                The cell-state of an LSTM is passed along from time step to time step and is only modified by
                element-wise operations; the hidden state is identical to the LSTM's output at the previous time step.

            sequence_length (Optional[SingleDataOp]): An int tensor mapping each batch item to a sequence length
                such that the remaining time slots for each batch item are filled with zeros.

        Returns:
            tuple:
                - The outputs over all timesteps of the LSTM.
                - DataOpTuple: The final cell- and hidden-states.
        """
        if get_backend() == "tf":
            # Convert to tf's LSTMStateTuple from DataOpTuple.
            if initial_c_and_h_states is not None:
                initial_c_and_h_states = tf.nn.rnn_cell.LSTMStateTuple(
                    initial_c_and_h_states[0], initial_c_and_h_states[1])

            # We are running the LSTM as a dynamic while-loop.
            if self.static_loop is False:
                lstm_out, lstm_state_tuple = tf.nn.dynamic_rnn(
                    cell=self.lstm,
                    inputs=inputs,
                    sequence_length=sequence_length,
                    initial_state=initial_c_and_h_states,
                    parallel_iterations=self.parallel_iterations,
                    swap_memory=self.swap_memory,
                    time_major=self.in_space.time_major,
                    dtype="float")
            # We are running with a fixed number of time steps (static unroll).
            else:
                # Set to zeros, as the tf LSTM cell does not handle None.
                if initial_c_and_h_states is None:
                    batch_dim = 0 if self.in_space.time_major is False else 1
                    shape = (tf.shape(inputs)[batch_dim], self.units)
                    initial_c_and_h_states = tf.nn.rnn_cell.LSTMStateTuple(
                        tf.zeros(shape=shape, dtype=tf.float32),
                        tf.zeros(shape=shape, dtype=tf.float32))
                output_list = list()
                lstm_state_tuple = initial_c_and_h_states
                # TODO: Add option to reset the internal state in the middle of this loop iff some reset signal
                # TODO: (e.g. terminal) is True during the loop.
                inputs.set_shape([self.static_loop] +
                                 inputs.shape.as_list()[1:])
                #for input_, terminal in zip(tf.unstack(inputs), tf.unstack(terminals)):
                for input_ in tf.unstack(inputs):
                    # If the episode ended, the core state should be reset before the next time step.
                    #core_state = nest.map_structure(functools.partial(tf.where, d),
                    #                                initial_core_state, core_state)
                    output, lstm_state_tuple = self.lstm(
                        input_, lstm_state_tuple)
                    output_list.append(output)
                lstm_out = tf.stack(output_list)

            # Only return last value.
            if self.return_sequences is False:
                if self.in_space.time_major is True:
                    lstm_out = lstm_out[-1]
                else:
                    lstm_out = lstm_out[:, -1]
                lstm_out._batch_rank = 0
            # Return entire sequence.
            else:
                lstm_out._batch_rank = 0 if self.in_space.time_major is False else 1
                lstm_out._time_rank = 0 if self.in_space.time_major is True else 1

            # Returns: Unrolled-outputs (time series of all encountered h-states), final c- and h-states.
            return lstm_out, DataOpTuple(lstm_state_tuple)

        elif get_backend() == "pytorch":
            # TODO: The initial hidden state has to be available at variable-creation time to be usable here.
            inputs = torch.cat(inputs).view(len(inputs), 1, -1)
            # TODO: support `self.return_sequences` = False
            out, self.hidden_state = self.lstm(inputs, self.hidden_state)
            return out, DataOpTuple(self.hidden_state)
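
For reference, a self-contained version of the PyTorch branch with concrete shapes (the sizes are assumptions; torch.nn.LSTM expects time-major input by default, which matches the (time, batch, features) reshape above):

    # Feed a list of per-step feature vectors through an nn.LSTM exactly
    # the way the PyTorch branch above does (batch size 1).
    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=8, hidden_size=16)  # time-major by default
    inputs = [torch.randn(8) for _ in range(5)]   # 5 time steps
    stacked = torch.cat(inputs).view(len(inputs), 1, -1)  # (5, 1, 8)

    out, (h_n, c_n) = lstm(stacked)  # zero initial state by default
    print(out.shape)  # torch.Size([5, 1, 16]) -- outputs for all time steps
    print(h_n.shape)  # torch.Size([1, 1, 16]) -- final hidden state
    print(c_n.shape)  # torch.Size([1, 1, 16]) -- final cell state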
Example #3
    def dtype(self):
        """Returns a DataOpTuple holding the dtype of each component op in this tuple."""
        return DataOpTuple([c.dtype for c in self])
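
A minimal standalone analogue of this property, with a plain tuple subclass and numpy arrays standing in for DataOpTuple and the backend tensors (both names here are illustrative, not RLgraph's):

    # Collect the dtype of every component op into a tuple of the same shape.
    import numpy as np

    class TupleOp(tuple):
        def dtype(self):
            return TupleOp(c.dtype for c in self)

    ops = TupleOp((np.zeros(2, dtype=np.float32), np.zeros(3, dtype=np.int64)))
    print(ops.dtype())  # (dtype('float32'), dtype('int64'))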
Example #4
    def _graph_fn_update_alpha(self, alpha_loss, alpha_loss_per_item):
        """
        Updates the entropy-coefficient parameter `log_alpha` by running one
        step of its dedicated optimizer on the given alpha loss.
        """
        alpha_step_op, _, _ = self.alpha_optimizer.step(
            DataOpTuple([self.log_alpha]), alpha_loss, alpha_loss_per_item)
        return alpha_step_op
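
This reads like the entropy-temperature (alpha) update from SAC: one optimizer step applied to log_alpha alone. A standalone PyTorch sketch of the same mechanics (the loss form and the target_entropy value are assumptions taken from the standard SAC objective; only the step itself appears in the snippet above):

    # One optimizer step on log_alpha only; nothing else carries gradients.
    import torch

    target_entropy = -1.0  # assumption: common -|A| heuristic
    log_alpha = torch.zeros(1, requires_grad=True)
    alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)

    log_pi = torch.tensor([-1.2, -0.7, -2.0])  # log-probs of sampled actions
    alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()

    alpha_optimizer.zero_grad()
    alpha_loss.backward()
    alpha_optimizer.step()  # updates log_alpha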