def setUp(self):
  """Ensures --test_tmpdir and builds a factory for a small `models.Policy`.

  The policy body is two Dense(64) + Relu pairs; `self._model_fn` is a
  `functools.partial` of `models.Policy` with that body pre-bound.
  """
  super().setUp()
  test_utils.ensure_flag('test_tmpdir')

  def _policy_body(mode):
    del mode  # The body is the same regardless of mode.
    return tl.Serial(
        tl.Dense(64), tl.Relu(),
        tl.Dense(64), tl.Relu(),
    )

  self._model_fn = functools.partial(models.Policy, body=_policy_body)
def test_from_file(self):
  """Checks that tl.InitializerFromFile returns the array saved to a .npy file."""
  expected = np.array([[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]])
  # `create_tempfile` needs --test_tmpdir, but in OSS the tests run under
  # pytest, which never calls `absltest.main`, so flags must be parsed by hand.
  test_utils.ensure_flag('test_tmpdir')
  path = self.create_tempfile('params.npy').full_path
  with open(path, 'wb') as out_file:
    np.save(out_file, expected)
  initializer = tl.InitializerFromFile(path)
  actual = initializer(expected.shape, rng())
  self.assertEqual(tl.to_list(actual), tl.to_list(expected))
def setUp(self):
  """Builds a shared Discrete(2) serializer and the kwargs that reference it.

  The same serializer object is used for both observations and actions, and
  the representation length is fixed at 100.
  """
  super().setUp()
  serializer = space_serializer.create(gym.spaces.Discrete(2), vocab_size=2)
  repr_length = 100
  self._serializer = serializer
  self._repr_length = repr_length
  self._serialization_utils_kwargs = dict(
      observation_serializer=serializer,
      action_serializer=serializer,
      representation_length=repr_length,
  )
  test_utils.ensure_flag('test_tmpdir')
def test_from_file(self):
  """Checks that initializers.InitializerFromFile loads the array saved on disk.

  The array is written with `np.save` to a temp file; the initializer built
  from that file must return the same values and shape.
  """
  params = np.array([[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]])
  # `create_tempfile` needs --test_tmpdir, but in OSS the tests run under
  # pytest, which never calls `absltest.main`, so flags must be parsed by hand.
  test_utils.ensure_flag('test_tmpdir')
  filename = self.create_tempfile('params.npy').full_path
  with open(filename, 'wb') as f:
    np.save(f, params)
  initializer = initializers.InitializerFromFile(filename)
  # Request the saved array's own shape instead of a hard-coded copy of it.
  init_value = initializer(params.shape, random.get_prng(0))
  # Compare numerically rather than via '%s' formatting. String comparison
  # depends on print options and the array type, and gives no useful diff.
  # assert_allclose also checks shapes. Its default rtol (1e-7) tolerates
  # values held at float32 precision if the backend does not keep float64 —
  # NOTE(review): confirm the backend's dtype.
  np.testing.assert_allclose(np.asarray(init_value), params)
def setUp(self):
  """Runs the base setup, then makes sure --test_tmpdir has been parsed."""
  super().setUp()
  tmpdir_flag = 'test_tmpdir'
  test_utils.ensure_flag(tmpdir_flag)
def setUp(self):
  """Ensures --test_tmpdir and turns off float64 support in tf_np.

  The previous float64 setting is stored on `self._old_is_allow_float64`;
  presumably a tearDown (not visible here) restores it.
  """
  super().setUp()
  test_utils.ensure_flag('test_tmpdir')
  previous_allow_float64 = tf_np.is_allow_float64()
  self._old_is_allow_float64 = previous_allow_float64
  tf_np.set_allow_float64(False)
def setUp(self):
  """Resets gin, registers `_CONFIG_DIR` as a search path, ensures --test_tmpdir."""
  super().setUp()
  # Start every test from an empty gin configuration.
  gin.clear_config()
  search_path = _CONFIG_DIR
  gin.add_config_file_search_path(search_path)
  test_utils.ensure_flag('test_tmpdir')