def dataset_fn(unused_ctx):
  """Build a two-element dataset in which each element is `self.data`, TPU-encoded.

  Args:
    unused_ctx: Ignored. Presumably an input context handed in by the caller
      (e.g. a distribution strategy) -- confirm against the call site.

  Returns:
    A `tf.data.Dataset` of two elements, each `utils.tpu_encode(self.data)`.
  """
  def two_ticks():
    # The yielded values are placeholders; only the element count (2) matters,
    # because the map below ignores its input.
    for tick in (0, 1):
      yield tick

  def encode(_):
    return utils.tpu_encode(self.data)

  ticks = tf.data.Dataset.from_generator(two_ticks, tf.int64)
  return ticks.map(encode)
def test_packed_bits(self, stacked):
  """Check that packed-bit observations unpack back to the raw observation.

  Steps a GFootball environment with random actions for 10 steps. On each step
  it verifies that unpacking the packed observation, both after
  `utils.tpu_encode` and directly, gives the following:
  - the leading channels reproduce the raw observation;
  - the remaining channels sum to zero.

  Args:
    stacked: Forwarded to `gym.make` as the `stacked` option.
  """
  env = gym.make('gfootball:GFootball-11_vs_11_easy_stochastic-SMM-v0',
                 stacked=stacked)
  # Close the environment even when an assertion fails mid-loop, so a failing
  # run does not leak the environment into subsequent tests.
  try:
    env.reset()
    for _ in range(10):
      obs, _, done, _ = env.step(env.action_space.sample())
      num_channels = obs.shape[-1]
      baseline_obs = tf.cast(np.array(obs), tf.float32)
      packed_obs = observation.PackedBitsObservation.observation(env, obs)
      packed_obs = tf.convert_to_tensor(packed_obs)
      tpu_obs = observation.unpackbits(utils.tpu_encode(packed_obs))
      non_tpu_obs = observation.unpackbits(packed_obs)
      # baseline_obs has fewer than 16 channels, so the first channels should
      # correspond to baseline_obs and all remaining channels should be 0.
      self.assertAllEqual(baseline_obs, tpu_obs[..., :num_channels])
      self.assertAllEqual(baseline_obs, non_tpu_obs[..., :num_channels])
      self.assertAllEqual(
          tf.math.reduce_sum(tpu_obs[..., num_channels:]), 0)
      self.assertAllEqual(
          tf.math.reduce_sum(non_tpu_obs[..., num_channels:]), 0)
      if done:
        env.reset()
  finally:
    env.close()
def test_dataset(self):
  """Check that `tpu_encode` applied inside `Dataset.map` round-trips via `tpu_decode`."""
  def gen():
    yield 0

  dataset = tf.data.Dataset.from_generator(gen, tf.int64)
  dataset = dataset.map(lambda _: utils.tpu_encode(self.data))
  # The generator yields exactly one element; take it without building a list.
  encoded = next(iter(dataset))
  decoded = utils.tpu_decode(encoded)
  # zip() stops at the shorter input, so compare lengths explicitly; otherwise
  # a decode that dropped trailing elements would still pass.
  self.assertEqual(len(self.data), len(decoded))
  for a, b in zip(decoded, self.data):
    self.assertAllEqual(a, b)
def test_simple(self):
  """Check per-element `tpu_encode` output types and the `tpu_decode` round trip."""
  encoded = utils.tpu_encode(self.data)
  decoded = utils.tpu_decode(encoded)
  # Expectations on the encoded form. Which index maps to which type depends
  # on the fixture `self.data`, which is defined outside this method.
  self.assertEqual(tf.int32, encoded[1].dtype)
  self.assertIsInstance(encoded[2], utils.TPUEncodedUInt8)
  self.assertEqual(tf.bfloat16, encoded[3].dtype)
  self.assertIsInstance(encoded[4], utils.TPUEncodedUInt8)
  self.assertIsInstance(encoded[5], utils.TPUEncodedUInt8)
  # zip() stops at the shorter input, so compare lengths explicitly; otherwise
  # a decode that dropped trailing elements would still pass.
  self.assertEqual(len(self.data), len(decoded))
  for a, b in zip(decoded, self.data):
    self.assertAllEqual(a, b)