コード例 #1
0
  def testNotEqualOtherClass(self):
    """An ArraySpec never compares equal to None or to an empty tuple."""
    int_spec = array_spec.ArraySpec((1, 2, 3), np.int32)
    for other in (None, ()):
      # Check both operand orders so the reflected comparison is covered too.
      self.assertNotEqual(int_spec, other)
      self.assertNotEqual(other, int_spec)
コード例 #2
0
  def testNotEqualOtherClass(self):
    """A BoundedArraySpec differs from a plain ArraySpec, None and ()."""
    # Bounds are given as float lists even though the dtype is int32.
    bounded = array_spec.BoundedArraySpec(
        (1, 2), np.int32, minimum=[0.0, -0.6], maximum=[1.0, 1.0])
    candidates = (array_spec.ArraySpec((1, 2), np.int32), None, ())
    for other in candidates:
      # Compare in both directions so each side's __eq__/__ne__ is exercised.
      self.assertNotEqual(bounded, other)
      self.assertNotEqual(other, bounded)
コード例 #3
0
ファイル: safety_game.py プロジェクト: DustinSeltz/Rainbow
    def _compute_observation_spec(self):
        """Helper for `__init__`: compute our environment's observation spec.

        This overrides the parent implementation, which inspects every item in
        the observation and chokes on the `environment_data` entry. Here the
        EXTRA_OBSERVATIONS entry is left out of the per-array specs and is
        given an empty dict as its spec instead.

        Returns:
          A dict mapping each observation key to its spec; every key except
          EXTRA_OBSERVATIONS maps to a `specs.ArraySpec` built from the first
          observation's shape and dtype.
        """
        # Start an episode only to look at the values it produces, then drop
        # it again so the environment is back in its default state.
        first_step = self.reset()
        spec_by_key = {}
        for key, value in six.iteritems(first_step.observation):
            if key == EXTRA_OBSERVATIONS:
                # Excluded from the ArraySpecs; a placeholder is added below.
                continue
            spec_by_key[key] = specs.ArraySpec(
                value.shape, value.dtype, name=key)
        spec_by_key[EXTRA_OBSERVATIONS] = {}
        self._drop_last_episode()
        return spec_by_key
コード例 #4
0
    def _compute_observation_spec(self):
        """Helper for `__init__`: compute our environment's observation spec.

        Starts an episode with `self.reset()` and builds one `specs.ArraySpec`
        per observation entry from that entry's shape and dtype. If the first
        timestep carries a reward, the method also checks that the reward can
        be added to `self._default_reward`. The probe episode is discarded with
        `self._drop_last_episode()` whether or not that check succeeds.

        Returns:
          A dict mapping each observation key to its `specs.ArraySpec`.

        Raises:
          TypeError: if the first reward cannot be added to the default reward.
            The original TypeError is attached as the cause.
        """
        # Start an environment, examine the values it gives to us, and reset
        # things back to default.
        timestep = self.reset()
        try:
            observation_spec = {
                k: specs.ArraySpec(v.shape, v.dtype, name=k)
                for k, v in six.iteritems(timestep.observation)
            }
            # As long as we've got environment result data, we try checking to
            # make sure that the reward types can be added together---a very
            # weak way of measuring whether they are compatible.
            if timestep.reward is not None:
                try:
                    _ = timestep.reward + self._default_reward
                except TypeError as err:
                    raise TypeError(
                        'A pycolab game wrapped by an Environment adapter returned '
                        'a first reward whose type is incompatible with the default reward '
                        "given to the adapter's `__init__`.") from err
        finally:
            # Run on both the success and the error path. Previously a failed
            # reward check skipped this call and left the probe episode in place.
            self._drop_last_episode()
        return observation_spec
コード例 #5
0
 def testGenerateValue(self):
   """A value generated from a spec must pass that spec's own validation."""
   int_spec = array_spec.ArraySpec((1, 2), np.int32)
   int_spec.validate(int_spec.generate_value())
コード例 #6
0
 def testValidateShape(self):
   """validate() accepts a matching shape and raises ValueError otherwise."""
   int_spec = array_spec.ArraySpec((1, 2), np.int32)
   matching = np.zeros((1, 2), dtype=np.int32)
   int_spec.validate(matching)
   # Same dtype, extra trailing dimension: only the shape is wrong.
   mismatched = np.zeros((1, 2, 3), dtype=np.int32)
   with self.assertRaises(ValueError):
     int_spec.validate(mismatched)
コード例 #7
0
 def testNotEqualDifferentDtype(self):
   """Specs that share a shape but differ in dtype are not equal."""
   shape = (1, 2, 3)
   int64_spec = array_spec.ArraySpec(shape, np.int64)
   int32_spec = array_spec.ArraySpec(shape, np.int32)
   self.assertNotEqual(int64_spec, int32_spec)
コード例 #8
0
 def testEqual(self):
   """Two specs built from the same shape and dtype compare equal."""
   shape, dtype = (1, 2, 3), np.int32
   self.assertEqual(array_spec.ArraySpec(shape, dtype),
                    array_spec.ArraySpec(shape, dtype))
コード例 #9
0
 def testShape(self):
   """A shape passed as a list is reported back equal to the tuple form."""
   expected = (1, 2, 3)
   spec = array_spec.ArraySpec(list(expected), np.int32)
   self.assertEqual(expected, spec.shape)
コード例 #10
0
 def testDtype(self):
   """The dtype given to the constructor is what `dtype` reports."""
   int_spec = array_spec.ArraySpec((1, 2, 3), np.int32)
   self.assertEqual(np.int32, int_spec.dtype)
コード例 #11
0
 def testNumpyDtype(self):
   """Constructing a spec from a numpy scalar type must not raise."""
   shape = (1, 2, 3)
   array_spec.ArraySpec(shape, np.int32)
コード例 #12
0
 def testStringDtype(self):
   """A dtype given by name is accepted and resolved to the numpy dtype."""
   spec = array_spec.ArraySpec((1, 2, 3), "int32")
   # Without this check the test would only prove construction doesn't raise,
   # not that the string was actually interpreted as int32.
   self.assertEqual(np.int32, spec.dtype)
コード例 #13
0
 def testDtypeTypeError(self):
   """A string that does not name a dtype is rejected with TypeError."""
   bogus_dtype = "32"
   with self.assertRaises(TypeError):
     array_spec.ArraySpec((1, 2, 3), bogus_dtype)
コード例 #14
0
 def testShapeTypeError(self):
   """A bare int is not accepted as a shape; TypeError is raised."""
   scalar_shape = 32
   with self.assertRaises(TypeError):
     array_spec.ArraySpec(scalar_shape, np.int32)