Example #1
0
    def test_strategy(self, num_cores):
        """Round-trips ``self.data`` through tpu_encode/tpu_decode on TPU.

        The data is encoded inside a distributed input pipeline built with
        ``experimental_distribute_datasets_from_function``, decoded on the
        replicas via ``strategy.run``, and the first local replica's result
        is compared element-wise against ``self.data``.

        Args:
            num_cores: Number of replicas requested from
                ``DeviceAssignment.build``.
        """
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver('')
        topology = tf.tpu.experimental.initialize_tpu_system(resolver)
        da = tf.tpu.experimental.DeviceAssignment.build(topology,
                                                        num_replicas=num_cores)
        strategy = tf.distribute.experimental.TPUStrategy(resolver,
                                                          device_assignment=da)

        def dataset_fn(unused_ctx):
            # NOTE(review): two elements, presumably so a multi-replica
            # strategy has one element available per replica -- confirm.
            def gen():
                yield 0
                yield 1

            dataset = tf.data.Dataset.from_generator(gen, tf.int64)
            # The generator's values are ignored; every element becomes the
            # encoded form of self.data.
            return dataset.map(lambda _: utils.tpu_encode(self.data))

        dist_dataset = strategy.experimental_distribute_datasets_from_function(
            dataset_fn)
        encoded = next(iter(dist_dataset))

        # `encoded` is also passed as tpu_decode's second argument
        # (presumably a structure/reference hint -- confirm against utils).
        decoded = strategy.run(
            tf.function(lambda args: utils.tpu_decode(args, encoded)),
            (encoded, ))
        # Only the first local replica's result is kept for comparison.
        decoded = tf.nest.map_structure(
            lambda t: strategy.experimental_local_results(t)[0], decoded)

        # zip() stops at the shorter sequence, so a decode that dropped
        # trailing elements would otherwise pass; pin the length first.
        self.assertEqual(len(self.data), len(decoded))
        for a, b in zip(decoded, self.data):
            self.assertAllEqual(a, b)
Example #2
0
    def test_dataset(self):
        """Round-trips ``self.data`` through tpu_encode inside a tf.data map.

        A one-element dataset is mapped to the encoded form of ``self.data``.
        Its first element is decoded with ``utils.tpu_decode`` and compared
        element-wise against ``self.data``.
        """
        def gen():
            yield 0

        dataset = tf.data.Dataset.from_generator(gen, tf.int64)
        dataset = dataset.map(lambda _: utils.tpu_encode(self.data))
        # Take the first element directly rather than materialising the
        # whole dataset into a list.
        encoded = next(iter(dataset))
        decoded = utils.tpu_decode(encoded)

        # zip() stops at the shorter sequence, so a decode that dropped
        # trailing elements would otherwise pass; pin the length first.
        self.assertEqual(len(self.data), len(decoded))
        for a, b in zip(decoded, self.data):
            self.assertAllEqual(a, b)
Example #3
0
    def test_simple(self):
        """Checks tpu_encode's per-element output and the eager round trip.

        After encoding, element 1 must have dtype int32 and element 3 dtype
        bfloat16. Elements 2, 4 and 5 must be ``utils.TPUEncodedUInt8``
        instances. Decoding must reproduce ``self.data`` element-wise.
        """
        encoded = utils.tpu_encode(self.data)
        decoded = utils.tpu_decode(encoded)

        self.assertEqual(tf.int32, encoded[1].dtype)
        self.assertIsInstance(encoded[2], utils.TPUEncodedUInt8)
        self.assertEqual(tf.bfloat16, encoded[3].dtype)
        self.assertIsInstance(encoded[4], utils.TPUEncodedUInt8)
        self.assertIsInstance(encoded[5], utils.TPUEncodedUInt8)

        # zip() stops at the shorter sequence, so a decode that dropped
        # trailing elements would otherwise pass; pin the length first.
        self.assertEqual(len(self.data), len(decoded))
        for a, b in zip(decoded, self.data):
            self.assertAllEqual(a, b)