Пример #1
0
        def dataset_fn(unused_ctx):
            """Build a two-element dataset where each element is tpu_encode(self.data).

            Args:
              unused_ctx: Ignored. NOTE(review): presumably an input context
                supplied by whatever invokes this factory — confirm at caller.

            Returns:
              A `tf.data.Dataset` seeded from the int64 values 0 and 1, with
              every element replaced by `utils.tpu_encode(self.data)`.
            """
            def two_items():
                # Only the element count matters; the values are discarded by map.
                for value in (0, 1):
                    yield value

            source = tf.data.Dataset.from_generator(two_items, tf.int64)
            return source.map(lambda _: utils.tpu_encode(self.data))
Пример #2
0
    def test_packed_bits(self, stacked):
        """Check that bit-packed observations unpack back to the originals.

        Steps the GFootball SMM environment with 10 random actions. For each
        observation, it unpacks `PackedBitsObservation.observation(env, obs)`
        both directly and after `utils.tpu_encode`. It then asserts that the
        leading channels equal the original observation and that the
        remaining channels sum to zero.

        Args:
          stacked: Forwarded to `gym.make` as the `stacked` keyword.
        """
        env = gym.make('gfootball:GFootball-11_vs_11_easy_stochastic-SMM-v0',
                       stacked=stacked)
        # Always close the environment, even when an assertion fails
        # mid-loop; otherwise a failing case leaks the simulator.
        try:
            env.reset()
            for _ in range(10):
                obs, _, done, _ = env.step(env.action_space.sample())
                num_channels = obs.shape[-1]

                baseline_obs = tf.cast(np.array(obs), tf.float32)

                packed_obs = observation.PackedBitsObservation.observation(
                    env, obs)
                packed_obs = tf.convert_to_tensor(packed_obs)
                tpu_obs = observation.unpackbits(utils.tpu_encode(packed_obs))
                non_tpu_obs = observation.unpackbits(packed_obs)
                # baseline_obs has fewer than 16 channels, so the first
                # channels should correspond to baseline_obs and all the
                # rest should be 0 -- for both unpacking paths.
                for unpacked in (tpu_obs, non_tpu_obs):
                    self.assertAllEqual(
                        baseline_obs, unpacked[..., :num_channels])
                    self.assertAllEqual(
                        tf.math.reduce_sum(unpacked[..., num_channels:]), 0)

                if done:
                    env.reset()
        finally:
            env.close()
Пример #3
0
    def test_dataset(self):
        """Check the tpu_encode/tpu_decode round-trip inside a tf.data pipeline.

        Encodes `self.data` inside `Dataset.map` and decodes the dataset's
        single element. It then asserts the result has the same length as
        `self.data` and matches it element by element.
        """
        def gen():
            yield 0

        dataset = tf.data.Dataset.from_generator(gen, tf.int64)
        dataset = dataset.map(lambda _: utils.tpu_encode(self.data))
        # The generator yields exactly one item; take it without
        # materializing the whole dataset into a list.
        encoded = next(iter(dataset))
        decoded = utils.tpu_decode(encoded)

        # zip() stops at the shorter input, so without this check a decode
        # that dropped elements would still pass.
        self.assertEqual(len(self.data), len(decoded))
        for a, b in zip(decoded, self.data):
            self.assertAllEqual(a, b)
Пример #4
0
    def test_simple(self):
        """Check tpu_encode's per-element encodings and the decode round-trip.

        Asserts the encodings `utils.tpu_encode` chooses for entries 1-5 of
        `self.data`:
          - entry 1 is int32;
          - entry 3 is bfloat16;
          - entries 2, 4 and 5 are `TPUEncodedUInt8`.
        It then asserts that `utils.tpu_decode` restores every element of
        `self.data`.
        """
        encoded = utils.tpu_encode(self.data)
        decoded = utils.tpu_decode(encoded)

        self.assertEqual(tf.int32, encoded[1].dtype)
        self.assertIsInstance(encoded[2], utils.TPUEncodedUInt8)
        self.assertEqual(tf.bfloat16, encoded[3].dtype)
        self.assertIsInstance(encoded[4], utils.TPUEncodedUInt8)
        self.assertIsInstance(encoded[5], utils.TPUEncodedUInt8)

        # zip() silently truncates to the shorter input; pin the length so a
        # decode that drops elements cannot pass unnoticed.
        self.assertEqual(len(self.data), len(decoded))
        for a, b in zip(decoded, self.data):
            self.assertAllEqual(a, b)