def testTruncation(self):
    """A scalar truncation yields a LAST step echoing the inputs."""
    obs, rew, disc = -1, 2.0, 1.0
    step = ts.truncation(obs, rew, disc)
    self.assertEqual(ts.StepType.LAST, step.step_type)
    self.assertEqual(-1, step.observation)
    self.assertEqual(2.0, step.reward)
    self.assertEqual(1.0, step.discount)
def testTruncation(self):
    """A tensor truncation yields a LAST step echoing the inputs."""
    step = ts.truncation(
        tf.constant(-1), tf.constant(2.0), tf.constant(1.0))
    # Evaluate the graph/eager tensors to concrete numpy values first.
    evaluated = self.evaluate(step)
    self.assertEqual(ts.StepType.LAST, evaluated.step_type)
    self.assertEqual(-1, evaluated.observation)
    self.assertEqual(2.0, evaluated.reward)
    self.assertEqual(1.0, evaluated.discount)
def testTruncationBatched(self):
    """A batch-2 truncation yields LAST steps and passes values through."""
    obs = np.array([[-1], [-1]])
    rew = np.array([2., 2.])
    disc = np.array([1., 1.])
    step = ts.truncation(obs, rew, disc)
    self.assertItemsEqual([ts.StepType.LAST] * 2, step.step_type)
    self.assertItemsEqual(obs, step.observation)
    self.assertItemsEqual(rew, step.reward)
    self.assertItemsEqual(disc, step.discount)
def testTruncationMultiRewards(self):
    """Truncation with a multi-reward structure, with and without outer_dims."""
    obs = np.array([[-1], [-1]])
    # Two reward components with different trailing shapes.
    rew = [np.array([[2.], [2.]]), np.array([[3., 3.], [4., 4.]])]
    disc = np.array([1., 1.])
    plain = ts.truncation(obs, rew, disc)
    with_dims = ts.truncation(obs, rew, disc, outer_dims=[2])
    for step in (plain, with_dims):
        self.assertItemsEqual([ts.StepType.LAST] * 2, step.step_type)
        self.assertItemsEqual(obs, step.observation)
        self.assertAllEqual(rew[0], step.reward[0])
        self.assertAllEqual(rew[1], step.reward[1])
        self.assertItemsEqual(disc, step.discount)
def testTruncationMultiRewards(self):
    """Tensor truncation with a multi-reward structure, with and without outer_dims.

    Uses a scalar discount (0.5) to exercise broadcasting over the batch
    dimension of size 2.
    """
    observation = tf.constant([[-1], [-1]])
    # Two reward components with different trailing shapes.
    reward = [tf.constant([[2.], [2.]]), tf.constant([[3., 3.], [4., 4.]])]
    discount = tf.constant(0.5)
    time_step = ts.truncation(observation, reward, discount)
    time_step_ = self.evaluate(time_step)
    time_step_with_outerdims = ts.truncation(
        observation, reward, discount, outer_dims=[2])
    time_step_with_outerdims_ = self.evaluate(time_step_with_outerdims)
    self.assertItemsEqual([ts.StepType.LAST] * 2, time_step_.step_type)
    self.assertItemsEqual([ts.StepType.LAST] * 2,
                          time_step_with_outerdims_.step_type)
    self.assertItemsEqual([-1, -1], time_step_.observation)
    self.assertItemsEqual([-1, -1], time_step_with_outerdims_.observation)
    self.assertAllEqual(reward[0], time_step_.reward[0])
    self.assertAllEqual(reward[1], time_step_.reward[1])
    self.assertAllEqual(reward[0], time_step_with_outerdims_.reward[0])
    self.assertAllEqual(reward[1], time_step_with_outerdims_.reward[1])
    self.assertItemsEqual([0.5, 0.5], time_step_.discount)
    # Fix: the numpy sibling test checks the discount of BOTH time steps;
    # this variant previously omitted the outer_dims discount assertion.
    self.assertItemsEqual([0.5, 0.5], time_step_with_outerdims_.discount)