Example #1
0
  def testTruncation(self):
    """A scalar truncation yields a LAST step carrying the given values."""
    obs, rew, disc = -1, 2.0, 1.0
    step = ts.truncation(obs, rew, disc)

    self.assertEqual(ts.StepType.LAST, step.step_type)
    self.assertEqual(obs, step.observation)
    self.assertEqual(rew, step.reward)
    self.assertEqual(disc, step.discount)
Example #2
0
 def testTruncation(self):
   """Tensor variant: truncation produces a LAST step with the given values."""
   obs = tf.constant(-1)
   rew = tf.constant(2.0)
   disc = tf.constant(1.0)
   evaluated = self.evaluate(ts.truncation(obs, rew, disc))
   self.assertEqual(ts.StepType.LAST, evaluated.step_type)
   self.assertEqual(-1, evaluated.observation)
   self.assertEqual(2.0, evaluated.reward)
   self.assertEqual(1.0, evaluated.discount)
Example #3
0
  def testTruncationBatched(self):
    """Batched numpy inputs each become LAST steps with values passed through."""
    obs = np.array([[-1], [-1]])
    rew = np.array([2.0, 2.0])
    disc = np.array([1.0, 1.0])
    step = ts.truncation(obs, rew, disc)

    self.assertItemsEqual([ts.StepType.LAST, ts.StepType.LAST], step.step_type)
    self.assertItemsEqual(obs, step.observation)
    self.assertItemsEqual(rew, step.reward)
    self.assertItemsEqual(disc, step.discount)
Example #4
0
  def testTruncationMultiRewards(self):
    """A list of reward arrays survives truncation, with and without outer_dims."""
    obs = np.array([[-1], [-1]])
    rewards = [np.array([[2.], [2.]]),
               np.array([[3., 3.], [4., 4.]])]
    disc = np.array([1., 1.])

    plain = ts.truncation(obs, rewards, disc)
    with_dims = ts.truncation(obs, rewards, disc, outer_dims=[2])

    # Both calls must agree: LAST step types, with every field passed through.
    expected_types = [ts.StepType.LAST, ts.StepType.LAST]
    for step in (plain, with_dims):
      self.assertItemsEqual(expected_types, step.step_type)
      self.assertItemsEqual(obs, step.observation)
      self.assertAllEqual(rewards[0], step.reward[0])
      self.assertAllEqual(rewards[1], step.reward[1])
      self.assertItemsEqual(disc, step.discount)
Example #5
0
  def testTruncationMultiRewards(self):
    """Tensor variant: a list of reward tensors survives ts.truncation.

    Also checks that a scalar discount (0.5) is broadcast to the batch of 2,
    and that passing outer_dims=[2] explicitly gives the same result.
    """
    observation = tf.constant([[-1], [-1]])
    # Two reward tensors with different trailing shapes, as a list.
    reward = [tf.constant([[2.], [2.]]),
              tf.constant([[3., 3.], [4., 4.]])]
    # Scalar discount; expected below as [0.5, 0.5] after broadcasting.
    discount = tf.constant(0.5)
    time_step = ts.truncation(observation, reward, discount)
    time_step_ = self.evaluate(time_step)

    time_step_with_outerdims = ts.truncation(
        observation, reward, discount, outer_dims=[2])
    time_step_with_outerdims_ = self.evaluate(time_step_with_outerdims)

    self.assertItemsEqual([ts.StepType.LAST] * 2, time_step_.step_type)
    self.assertItemsEqual([ts.StepType.LAST] * 2,
                          time_step_with_outerdims_.step_type)
    self.assertItemsEqual([-1, -1], time_step_.observation)
    self.assertItemsEqual([-1, -1], time_step_with_outerdims_.observation)
    self.assertAllEqual(reward[0], time_step_.reward[0])
    self.assertAllEqual(reward[1], time_step_.reward[1])
    self.assertAllEqual(reward[0], time_step_with_outerdims_.reward[0])
    self.assertAllEqual(reward[1], time_step_with_outerdims_.reward[1])
    # NOTE(review): the numpy sibling test also asserts the outer-dims
    # discount; this method may be truncated at the chunk boundary — confirm.
    self.assertItemsEqual([0.5, 0.5], time_step_.discount)