def q_retrace(rewards, dones, q_i, values, rho_i, n_envs, n_steps, gamma):
    """
    Calculates the target Q-retrace

    :param rewards: ([TensorFlow Tensor]) The rewards
    :param dones: ([TensorFlow Tensor]) The done mask (1 if the episode ended at that step, 0 otherwise)
    :param q_i: ([TensorFlow Tensor]) The Q values for actions taken
    :param values: ([TensorFlow Tensor]) The output of the value functions
    :param rho_i: ([TensorFlow Tensor]) The importance weight for each action
    :param n_envs: (int) The number of environments
    :param n_steps: (int) The number of steps to run for each environment
    :param gamma: (float) The discount value
    :return: ([TensorFlow Tensor]) the target Q-retrace
    """
    # Truncated importance weights: rho_bar = min(1, rho)
    rho_bar = batch_to_seq(tf.minimum(1.0, rho_i), n_envs, n_steps, True)  # list of len steps, shape [n_envs]
    reward_seq = batch_to_seq(rewards, n_envs, n_steps, True)  # list of len steps, shape [n_envs]
    done_seq = batch_to_seq(dones, n_envs, n_steps, True)  # list of len steps, shape [n_envs]
    q_is = batch_to_seq(q_i, n_envs, n_steps, True)
    value_sequence = batch_to_seq(values, n_envs, n_steps + 1, True)
    final_value = value_sequence[-1]

    # Work backwards through the rollout, bootstrapping from the final value estimate
    qret = final_value
    qrets = []
    for i in range(n_steps - 1, -1, -1):
        check_shape([qret, done_seq[i], reward_seq[i], rho_bar[i], q_is[i], value_sequence[i]], [[n_envs]] * 6)
        qret = reward_seq[i] + gamma * qret * (1.0 - done_seq[i])
        qrets.append(qret)
        qret = (rho_bar[i] * (qret - q_is[i])) + value_sequence[i]
    qrets = qrets[::-1]
    qret = seq_to_batch(qrets, flat=True)
    return qret
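# Usage sketch (illustrative only, not part of the library): builds a Retrace target graph from
# flat placeholder tensors. All inputs have length n_envs * n_steps except `values`, which
# carries one extra bootstrap step per environment (n_envs * (n_steps + 1)).
# The backward recursion in q_retrace computes
#     Q_ret(t) = r_t + gamma * (1 - done_t) * C(t + 1)
#     C(t)     = rho_bar_t * (Q_ret(t) - Q(x_t, a_t)) + V(x_t),   with C(T) = V(x_T)
# where rho_bar = min(1, rho) is the truncated importance weight.
# The names, shapes, and hyperparameters below are assumptions chosen for this example.
def _example_q_retrace_graph(n_envs=4, n_steps=20, gamma=0.99):
    rewards_ph = tf.placeholder(tf.float32, [n_envs * n_steps])
    dones_ph = tf.placeholder(tf.float32, [n_envs * n_steps])
    q_taken_ph = tf.placeholder(tf.float32, [n_envs * n_steps])
    values_ph = tf.placeholder(tf.float32, [n_envs * (n_steps + 1)])
    rho_ph = tf.placeholder(tf.float32, [n_envs * n_steps])
    return q_retrace(rewards_ph, dones_ph, q_taken_ph, values_ph, rho_ph, n_envs, n_steps, gamma)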
def strip(var, n_envs, n_steps, flat=False):
    """
    Removes the last step in the batch

    :param var: (TensorFlow Tensor) The input Tensor
    :param n_envs: (int) The number of environments
    :param n_steps: (int) The number of steps to run for each environment
    :param flat: (bool) If the input Tensor is flat
    :return: (TensorFlow Tensor) the input tensor, without the last step in the batch
    """
    out_vars = batch_to_seq(var, n_envs, n_steps + 1, flat)
    return seq_to_batch(out_vars[:-1], flat)
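# Usage sketch (illustrative only, not part of the library): `strip` drops the trailing
# bootstrap step. Given a flat tensor laid out as n_envs blocks of (n_steps + 1) values,
# it returns the first n_steps values of every environment as a flat tensor of length
# n_envs * n_steps. Names below are assumptions chosen for this example.
def _example_strip(n_envs=4, n_steps=20):
    values_ph = tf.placeholder(tf.float32, [n_envs * (n_steps + 1)])
    return strip(values_ph, n_envs, n_steps, flat=True)  # shape [n_envs * n_steps]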
def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256, reuse=False, layers=None,
             net_arch=None, act_fun=tf.tanh, cnn_extractor=nature_cnn, layer_norm=False, feature_extraction="cnn",
             **kwargs):
    # state_shape = [n_lstm * 2] dim because of the cell and hidden states of the LSTM
    super(LstmPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch,
                                     state_shape=(2 * n_lstm, ), reuse=reuse,
                                     scale=(feature_extraction == "cnn"))

    self._kwargs_check(feature_extraction, kwargs)

    if net_arch is None:  # Legacy mode
        if layers is None:
            layers = [64, 64]
        else:
            warnings.warn("The layers parameter is deprecated. Use the net_arch parameter instead.")

        with tf.variable_scope("model", reuse=reuse):
            if feature_extraction == "cnn":
                extracted_features = cnn_extractor(self.processed_obs, **kwargs)
            else:
                extracted_features = tf.layers.flatten(self.processed_obs)
                for i, layer_size in enumerate(layers):
                    extracted_features = act_fun(linear(extracted_features, 'pi_fc' + str(i),
                                                        n_hidden=layer_size, init_scale=np.sqrt(2)))
            input_sequence = batch_to_seq(extracted_features, self.n_env, n_steps)
            masks = batch_to_seq(self.dones_ph, self.n_env, n_steps)
            rnn_output, self.snew = lstm(input_sequence, masks, self.states_ph, 'lstm1',
                                         n_hidden=n_lstm, layer_norm=layer_norm)
            rnn_output = seq_to_batch(rnn_output)
            value_fn = linear(rnn_output, 'vf', 1)

            self._proba_distribution, self._policy, self.q_value = \
                self.pdtype.proba_distribution_from_latent(rnn_output, rnn_output)

        self._value_fn = value_fn
    else:  # Use the new net_arch parameter
        if layers is not None:
            warnings.warn("The new net_arch parameter overrides the deprecated layers parameter.")
        if feature_extraction == "cnn":
            raise NotImplementedError()

        with tf.variable_scope("model", reuse=reuse):
            latent = tf.layers.flatten(self.processed_obs)
            policy_only_layers = []  # Layer sizes of the layers that belong only to the policy network
            value_only_layers = []  # Layer sizes of the layers that belong only to the value network

            # Iterate through the shared layers and build the shared parts of the network
            lstm_layer_constructed = False
            for idx, layer in enumerate(net_arch):
                if isinstance(layer, int):  # Check that this is a shared layer
                    layer_size = layer
                    latent = act_fun(linear(latent, "shared_fc{}".format(idx), layer_size, init_scale=np.sqrt(2)))
                elif layer == "lstm":
                    if lstm_layer_constructed:
                        raise ValueError("The net_arch parameter must only contain one occurrence of 'lstm'!")
                    input_sequence = batch_to_seq(latent, self.n_env, n_steps)
                    masks = batch_to_seq(self.dones_ph, self.n_env, n_steps)
                    rnn_output, self.snew = lstm(input_sequence, masks, self.states_ph, 'lstm1',
                                                 n_hidden=n_lstm, layer_norm=layer_norm)
                    latent = seq_to_batch(rnn_output)
                    lstm_layer_constructed = True
                else:
                    assert isinstance(layer, dict), "Error: the net_arch list can only contain ints and dicts"
                    if 'pi' in layer:
                        assert isinstance(layer['pi'], list), \
                            "Error: net_arch[-1]['pi'] must contain a list of integers."
                        policy_only_layers = layer['pi']

                    if 'vf' in layer:
                        assert isinstance(layer['vf'], list), \
                            "Error: net_arch[-1]['vf'] must contain a list of integers."
                        value_only_layers = layer['vf']
                    break  # From here on, the network splits into the policy and value networks

            # Build the non-shared part of the policy network
            latent_policy = latent
            for idx, pi_layer_size in enumerate(policy_only_layers):
                if pi_layer_size == "lstm":
                    raise NotImplementedError("LSTMs are only supported in the shared part of the policy network.")
                assert isinstance(pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers."
                latent_policy = act_fun(linear(latent_policy, "pi_fc{}".format(idx), pi_layer_size,
                                               init_scale=np.sqrt(2)))

            # Build the non-shared part of the value network
            latent_value = latent
            for idx, vf_layer_size in enumerate(value_only_layers):
                if vf_layer_size == "lstm":
                    raise NotImplementedError("LSTMs are only supported in the shared part of the value function "
                                              "network.")
                assert isinstance(vf_layer_size, int), "Error: net_arch[-1]['vf'] must only contain integers."
                latent_value = act_fun(linear(latent_value, "vf_fc{}".format(idx), vf_layer_size,
                                              init_scale=np.sqrt(2)))

            if not lstm_layer_constructed:
                raise ValueError("The net_arch parameter must contain at least one occurrence of 'lstm'!")

            self._value_fn = linear(latent_value, 'vf', 1)
            # TODO: why not init_scale = 0.001 here like in the feedforward
            self._proba_distribution, self._policy, self.q_value = \
                self.pdtype.proba_distribution_from_latent(latent_policy, latent_value)
    self._setup_init()
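# Usage sketch (illustrative only, not part of the library): with the net_arch parameter,
# the shared part of the network is given as a list of layer sizes plus a single 'lstm' entry,
# optionally followed by a dict holding the non-shared policy ('pi') and value ('vf') layer sizes.
# The class name and layer sizes below are assumptions chosen for this example.
class ExampleLstmPolicy(LstmPolicy):
    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=64, reuse=False, **_kwargs):
        # One shared fully connected layer of 64 units, then the LSTM, then separate
        # 32-unit heads for the policy and the value function, on top of an MLP extractor.
        super(ExampleLstmPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse,
                                                net_arch=[64, 'lstm', dict(pi=[32], vf=[32])],
                                                layer_norm=False, feature_extraction="mlp", **_kwargs)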