def _graph_fn_get_probabilities_log_probs(self, logits):
    """
    Creates properties/parameters and log-probs from some reshaped output.

    Args:
        logits (SingleDataOp): The output of some layer that is already reshaped according to our action
            Space.

    Returns:
        tuple (2x SingleDataOp):
            parameters (DataOp): The parameters, ready to be passed to a Distribution object's
                get_distribution API-method (usually some probabilities or loc/scale pairs).
            log_probs (DataOp): Simply the log(parameters).

    Raises:
        NotImplementedError: If `self.action_space` is neither an IntBox nor a FloatBox.
    """
    if get_backend() == "tf":
        if isinstance(self.action_space, IntBox):
            # Discrete actions: softmax over the last axis, floored at SMALL_NUMBER so the log below
            # never sees an exact zero.
            parameters = tf.maximum(x=tf.nn.softmax(logits=logits, axis=-1), y=SMALL_NUMBER)
            parameters._batch_rank = 0
            # Log probs.
            log_probs = tf.log(x=parameters)
            log_probs._batch_rank = 0
        elif isinstance(self.action_space, FloatBox):
            # Continuous actions: axis 1 holds the two moments (mean, log-sd), split into 2 equal parts.
            mean, log_sd = tf.split(value=logits, num_or_size_splits=2, axis=1)
            # Remove moments rank.
            mean = tf.squeeze(input=mean, axis=1)
            log_sd = tf.squeeze(input=log_sd, axis=1)
            # Clip log_sd to [log(SMALL_NUMBER), -log(SMALL_NUMBER)]; log(SMALL_NUMBER) is negative.
            log_sd = tf.clip_by_value(
                t=log_sd, clip_value_min=log(SMALL_NUMBER), clip_value_max=-log(SMALL_NUMBER)
            )
            # Turn log sd into sd.
            sd = tf.exp(x=log_sd)
            parameters = DataOpTuple(mean, sd)
            # NOTE(review): log(mean) is NaN for negative means - confirm callers do not consume it.
            log_probs = DataOpTuple(tf.log(x=mean), log_sd)
        else:
            raise NotImplementedError
        return parameters, log_probs
    elif get_backend() == "pytorch":
        if isinstance(self.action_space, IntBox):
            # Discrete actions: softmax, then element-wise floor at SMALL_NUMBER_TORCH
            # (torch.max(a, b) with two tensors is an element-wise maximum).
            softmax_logits = torch.softmax(logits, dim=-1)
            parameters = torch.max(softmax_logits, SMALL_NUMBER_TORCH)
            # Log probs.
            log_probs = torch.log(parameters)
        elif isinstance(self.action_space, FloatBox):
            # Continuous actions. torch.split(x, 2) would produce chunks *of size* 2 (i.e. a single
            # chunk on the size-2 moments axis); torch.chunk(x, 2) produces *two* chunks, matching
            # tf.split(num_or_size_splits=2) above.
            mean, log_sd = torch.chunk(logits, chunks=2, dim=1)
            # Remove moments rank.
            mean = torch.squeeze(mean, dim=1)
            log_sd = torch.squeeze(log_sd, dim=1)
            # Clip log_sd. LOG_SMALL_NUMBER is negative (presumably log(SMALL_NUMBER) - confirm).
            log_sd = torch.clamp(log_sd, min=LOG_SMALL_NUMBER, max=-LOG_SMALL_NUMBER)
            # Turn log sd into sd.
            sd = torch.exp(log_sd)
            parameters = DataOpTuple(mean, sd)
            # NOTE(review): log(mean) is NaN for negative means - confirm callers do not consume it.
            log_probs = DataOpTuple(torch.log(mean), log_sd)
        else:
            raise NotImplementedError
        return parameters, log_probs
def _graph_fn_call(self, inputs, initial_c_and_h_states=None, sequence_length=None):
    """
    Args:
        inputs (SingleDataOp): The data to pass through the layer (batch of n items, m timesteps).
            Position of batch- and time-ranks in the input depend on `self.time_major` setting.

        initial_c_and_h_states (DataOpTuple): The initial cell- and hidden-states to use.
            None for the default behavior (all zeros).
            The cell-state in an LSTM is passed between cells from step to step and only affected by
            element-wise operations. The hidden state is identical to the output of the LSTM on the
            previous time step.

        sequence_length (Optional[SingleDataOp]): An int tensor mapping each batch item to a sequence
            length such that the remaining time slots for each batch item are filled with zeros.
            Only used by the TF dynamic-loop path.

    Returns:
        tuple:
            - The outputs over all timesteps of the LSTM (or only the last timestep if
              `self.return_sequences` is False, TF backend).
            - DataOpTuple: The final cell- and hidden-states.
    """
    if get_backend() == "tf":
        # Axis holding the time steps in `inputs` (and in the outputs); batch is the other of the
        # first two axes. Truthiness mirrors how tf.nn.dynamic_rnn interprets `time_major`.
        time_axis = 0 if self.in_space.time_major else 1
        batch_axis = 1 - time_axis

        # Convert to tf's LSTMStateTuple from DataOpTuple.
        if initial_c_and_h_states is not None:
            initial_c_and_h_states = tf.nn.rnn_cell.LSTMStateTuple(
                initial_c_and_h_states[0], initial_c_and_h_states[1]
            )

        if self.static_loop is False:
            # We are running the LSTM as a dynamic while-loop.
            lstm_out, lstm_state_tuple = tf.nn.dynamic_rnn(
                cell=self.lstm,
                inputs=inputs,
                sequence_length=sequence_length,
                initial_state=initial_c_and_h_states,
                parallel_iterations=self.parallel_iterations,
                swap_memory=self.swap_memory,
                time_major=self.in_space.time_major,
                dtype="float",
            )
        else:
            # We are running with a fixed number of time steps (static unroll).
            # Set to zeros as tf lstm object does not handle None.
            if initial_c_and_h_states is None:
                shape = (tf.shape(inputs)[batch_axis], self.units)
                initial_c_and_h_states = tf.nn.rnn_cell.LSTMStateTuple(
                    tf.zeros(shape=shape, dtype=tf.float32), tf.zeros(shape=shape, dtype=tf.float32)
                )
            output_list = list()
            lstm_state_tuple = initial_c_and_h_states
            # TODO: Add option to reset the internal state in the middle of this loop iff some reset signal
            # TODO: (e.g. terminal) is True during the loop.

            # Pin the static number of time steps on the *time* axis (not blindly on axis 0) so that
            # tf.unstack knows how many per-step slices to produce for either layout.
            static_shape = inputs.shape.as_list()
            static_shape[time_axis] = self.static_loop
            inputs.set_shape(static_shape)
            # Each `input_` is one time step for the whole batch.
            for input_ in tf.unstack(inputs, axis=time_axis):
                output, lstm_state_tuple = self.lstm(input_, lstm_state_tuple)
                output_list.append(output)
            # Re-stack along the same axis so the output layout matches the input layout.
            lstm_out = tf.stack(output_list, axis=time_axis)

        if self.return_sequences is False:
            # Only return last value.
            lstm_out = lstm_out[-1] if time_axis == 0 else lstm_out[:, -1]
            lstm_out._batch_rank = 0
        else:
            # Return entire sequence.
            lstm_out._batch_rank = batch_axis
            lstm_out._time_rank = time_axis

        # Returns: Unrolled-outputs (time series of all encountered h-states), final c- and h-states.
        return lstm_out, DataOpTuple(lstm_state_tuple)
    elif get_backend() == "pytorch":
        # TODO init hidden state has to be available at create variable time to use.
        # NOTE(review): `initial_c_and_h_states` and `sequence_length` are ignored on this backend;
        # the recurrent state is carried across calls in `self.hidden_state` instead.
        # `inputs` is treated as a sequence of tensors: concatenated, then viewed as
        # (len(inputs), 1, -1) - presumably (seq, batch=1, features) for a non-batch_first LSTM; confirm.
        inputs = torch.cat(inputs).view(len(inputs), 1, -1)
        # TODO: support `self.return_sequences` = False
        out, self.hidden_state = self.lstm(inputs, self.hidden_state)
        # NOTE(review): if `self.lstm` is a torch.nn.LSTM, this state tuple is ordered (h, c), unlike
        # the TF branch's (c, h) LSTMStateTuple - confirm callers account for this.
        return out, DataOpTuple(self.hidden_state)
def dtype(self):
    """
    Collects the `dtype` of each component of this container, in iteration order,
    and returns them packed into a DataOpTuple.
    """
    component_dtypes = []
    for component in self:
        component_dtypes.append(component.dtype)
    return DataOpTuple(component_dtypes)
def _graph_fn_update_alpha(self, alpha_loss, alpha_loss_per_item):
    """
    Runs one optimizer step on `self.log_alpha` using `self.alpha_optimizer`.

    Args:
        alpha_loss: Loss forwarded unchanged to `self.alpha_optimizer.step`.
        alpha_loss_per_item: Per-item loss forwarded unchanged to `self.alpha_optimizer.step`.

    Returns:
        The first of the three values returned by `self.alpha_optimizer.step` (the step op).
    """
    # The optimizer expects the variables to update wrapped in a DataOpTuple.
    variables_to_update = DataOpTuple([self.log_alpha])
    # `step` yields exactly three values; only the first one is needed here.
    step_op, _second, _third = self.alpha_optimizer.step(
        variables_to_update, alpha_loss, alpha_loss_per_item
    )
    return step_op