def _get_qvalues(self):
    """Computes discounted Q-values for the stored reward sequence.

    Delegates to :func:`discount_reward` with the discount factor and
    normalization flag taken from the agent's hyperparameters.

    Returns:
        The (possibly normalized) discounted rewards for the recorded
        rollout.
    """
    hparams = self._hparams
    return discount_reward(
        self._rewards,
        self._sequence_length_py,
        discount=hparams.discount_factor,
        normalize=hparams.normalize_reward)
def _train_policy(self, feed_dict=None):
    """Updates the policy with one policy-gradient step.

    Discounted Q-values are computed from the recorded episode rewards
    and fed, together with the recorded observations and actions, into
    the training op. The outputs of the run are stored in
    ``self._train_outputs``.

    Args:
        feed_dict (dict, optional): Additional entries merged into the
            feed dict passed to ``session.run``. Entries given here
            override the defaults built from the recorded rollout.
    """
    # `discount_reward` here operates on a batch; wrap the single
    # reward sequence into a batch of one and unwrap the result.
    qvalues = discount_reward(
        [self._rewards],
        discount=self._hparams.discount_factor,
        normalize=self._hparams.normalize_reward)
    qvalues = qvalues[0, :]

    # NOTE(review): the fetch is keyed 'loss' but fetches the train op;
    # presumably callers only care that the op ran — confirm upstream.
    fetches = dict(loss=self._train_op)
    feed_dict_ = {
        self._observ_inputs: self._observs,
        self._action_inputs: self._actions,
        self._advantage_inputs: qvalues}
    feed_dict_.update(feed_dict or {})

    self._train_outputs = self._sess.run(fetches, feed_dict=feed_dict_)
def test_discount_reward(self):
    """Tests :func:`texar.tf.losses.rewards.discount_reward`
    """

    def _run_case(reward, sequence_length, tensor_rank):
        # Reference results computed on the NumPy input.
        expected = discount_reward(reward, sequence_length, discount=1.)
        expected_n = discount_reward(
            reward, sequence_length, discount=.1, normalize=True)

        # Same computation on a TF tensor input.
        reward_tensor = tf.constant(reward, dtype=tf.float64)
        if tensor_rank == 1:
            out = discount_reward(
                reward_tensor, sequence_length, discount=1.)
            out_n = discount_reward(
                reward_tensor, sequence_length, discount=.1,
                normalize=True)
        else:
            out = discount_reward(
                reward_tensor, sequence_length, discount=1.,
                tensor_rank=tensor_rank)
            out_n = discount_reward(
                reward_tensor, sequence_length, discount=.1,
                tensor_rank=tensor_rank, normalize=True)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            r, r_n = sess.run([out, out_n])

        np.testing.assert_array_almost_equal(expected, r, decimal=6)
        np.testing.assert_array_almost_equal(expected_n, r_n, decimal=6)

    # 1D rewards
    _run_case(np.ones([2], dtype=np.float64), [3, 5], tensor_rank=1)
    # 2D rewards
    _run_case(np.ones([2, 10], dtype=np.float64), [5, 10], tensor_rank=2)