Exemplo n.º 1
0
 def setUp(self):
     """Ensures --test_tmpdir is set and builds a Policy model factory.

     The factory is `models.Policy` partially applied with a body made of two
     Dense(64)+Relu stages.
     """
     super().setUp()
     test_utils.ensure_flag('test_tmpdir')

     def policy_body(mode):
         # `mode` is accepted to satisfy the body-callable signature but unused.
         del mode
         return tl.Serial(tl.Dense(64), tl.Relu(), tl.Dense(64), tl.Relu())

     self._model_fn = functools.partial(models.Policy, body=policy_body)
Exemplo n.º 2
0
 def test_from_file(self):
   """Checks that InitializerFromFile yields the array saved in a .npy file."""
   params = np.array([[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]])
   # `create_tempfile` needs --test_tmpdir. Under pytest (the OSS setup)
   # `absltest.main` never runs, so the flags have to be parsed manually.
   test_utils.ensure_flag('test_tmpdir')
   filename = self.create_tempfile('params.npy').full_path
   with open(filename, 'wb') as npy_file:
     np.save(npy_file, params)
   # Keep the initializer under its own name rather than reusing the file
   # handle's name, so it is clear which object is being called below.
   initializer = tl.InitializerFromFile(filename)
   # `rng()` is a helper defined elsewhere in this file; presumably it returns
   # a PRNG key for the initializer — TODO confirm.
   init_value = initializer(params.shape, rng())
   self.assertEqual(tl.to_list(init_value), tl.to_list(params))
Exemplo n.º 3
0
 def setUp(self):
     """Creates a shared Discrete(2) serializer and serialization kwargs.

     The same serializer is used for both observations and actions, and the
     representation length is fixed at 100. Finally makes sure --test_tmpdir
     is available.
     """
     super().setUp()
     serializer = space_serializer.create(
         gym.spaces.Discrete(2), vocab_size=2)
     self._serializer = serializer
     self._repr_length = 100
     self._serialization_utils_kwargs = dict(
         observation_serializer=serializer,
         action_serializer=serializer,
         representation_length=self._repr_length,
     )
     test_utils.ensure_flag('test_tmpdir')
Exemplo n.º 4
0
 def test_from_file(self):
   """Checks that InitializerFromFile yields the array saved in a .npy file."""
   params = np.array([[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]])
   # `create_tempfile` needs access to --test_tmpdir, however in the OSS world
   # pytest doesn't run `absltest.main`, so we need to manually parse the flags
   test_utils.ensure_flag('test_tmpdir')
   filename = self.create_tempfile('params.npy').full_path
   with open(filename, 'wb') as f:
     np.save(f, params)
   initializer = initializers.InitializerFromFile(filename)
   # Request exactly the stored shape instead of restating it by hand, so the
   # test stays correct if `params` changes.
   init_value = initializer(params.shape, random.get_prng(0))
   # Compare numerically (values and shape) instead of via '%s' formatting,
   # which depends on numpy print options/precision. The small tolerance
   # allows for a possible float32 cast of the float64 data — TODO confirm.
   np.testing.assert_allclose(np.asarray(init_value), params, rtol=1e-6)
Exemplo n.º 5
0
 def setUp(self):
   """Runs the base setUp, then makes sure --test_tmpdir is available."""
   super().setUp()
   flag_name = 'test_tmpdir'
   test_utils.ensure_flag(flag_name)
Exemplo n.º 6
0
 def setUp(self):
     """Ensures --test_tmpdir is set and disables float64 in tf_np.

     The prior float64 setting is stored on the instance. Presumably tearDown
     (not visible here) restores it — TODO confirm.
     """
     super().setUp()
     test_utils.ensure_flag('test_tmpdir')
     previously_allowed = tf_np.is_allow_float64()
     self._old_is_allow_float64 = previously_allowed
     tf_np.set_allow_float64(False)
Exemplo n.º 7
0
 def setUp(self):
     """Resets gin, registers the config search path, and ensures flags."""
     super().setUp()
     # Start each test from an empty gin configuration.
     gin.clear_config()
     search_dir = _CONFIG_DIR
     gin.add_config_file_search_path(search_dir)
     test_utils.ensure_flag('test_tmpdir')