Esempio n. 1
0
    def test_in_executor_stack(self):
        """Resolves a data URI through a full local executor stack.

        A federated computation materializes `foo://bar` with `tff_data.data`
        and reduces it inside a TF computation. The test backend serves
        `tf.data.Dataset.range(5)`, whose elements sum to 10.
        """
        seq_type = computation_types.SequenceType(tf.int64)
        leaf = eager_tf_executor.EagerTFExecutor()
        backend = TestDataBackend(self, 'foo://bar', tf.data.Dataset.range(5),
                                  seq_type)
        data_ex = data_executor.DataExecutor(leaf, backend)

        def leaf_executor_fn(device):
            # The device is ignored: every leaf is the same data executor.
            del device
            return data_ex

        factory = executor_stacks.local_executor_factory(
            leaf_executor_fn=leaf_executor_fn)
        context = execution_context.ExecutionContext(executor_fn=factory)

        @computations.tf_computation(seq_type)
        def foo(ds):
            total = ds.reduce(np.int64(0), lambda acc, elem: acc + elem)
            return tf.cast(total, tf.int32)

        @computations.federated_computation
        def bar():
            return foo(tff_data.data('foo://bar', seq_type))

        with context_stack_impl.context_stack.install(context):
            result = bar()

        self.assertEqual(result, 10)
Esempio n. 2
0
    def test_combo_data_with_comp_and_tensor(self):
        """Calls a TF computation on a struct mixing a data URI and a tensor.

        Field `x` is a `foo://bar` data proto that the test backend maps to
        `tf.data.Dataset.range(3)` (sum 3). Field `y` is the int32 value 10,
        so the call is expected to produce 13.
        """
        type_spec = computation_types.SequenceType(tf.int64)
        ex = data_executor.DataExecutor(
            eager_tf_executor.EagerTFExecutor(),
            TestDataBackend(self, 'foo://bar', tf.data.Dataset.range(3),
                            type_spec))
        # Close the executor even if an assertion below fails, so it is not
        # left open.
        try:
            proto = pb.Computation(
                data=pb.Data(uri='foo://bar'),
                type=type_serialization.serialize_type(type_spec))
            arg_val = self._loop.run_until_complete(
                ex.create_value(
                    collections.OrderedDict([('x', proto), ('y', 10)]),
                    computation_types.StructType([('x', type_spec),
                                                  ('y', tf.int32)])))

            @computations.tf_computation(type_spec, tf.int32)
            def comp(x, y):
                return tf.cast(x.reduce(np.int64(0), lambda p, q: p + q),
                               tf.int32) + y

            comp_val = self._loop.run_until_complete(ex.create_value(comp))
            ret_val = self._loop.run_until_complete(
                ex.create_call(comp_val, arg_val))
            self.assertIsInstance(ret_val, eager_tf_executor.EagerValue)
            self.assertEqual(str(ret_val.type_signature), 'int32')
            self.assertEqual(
                self._loop.run_until_complete(ret_val.compute()), 13)
        finally:
            ex.close()
Esempio n. 3
0
 def test_pass_through_tensor(self):
     """A plain (non-data) int32 value comes back as an `EagerValue` of 10."""
     ex = data_executor.DataExecutor(
         eager_tf_executor.EagerTFExecutor(),
         TestDataBackend(self, 'none', None, None))
     # Close the executor even if an assertion fails, so it is not left open.
     try:
         val = self._loop.run_until_complete(ex.create_value(10, tf.int32))
         self.assertIsInstance(val, eager_tf_executor.EagerValue)
         self.assertEqual(str(val.type_signature), 'int32')
         self.assertEqual(self._loop.run_until_complete(val.compute()), 10)
     finally:
         ex.close()
Esempio n. 4
0
 def test_data_proto_tensor(self):
     """A `Data` proto with a tensor type is resolved through the backend.

     The test backend maps `foo://bar` to the int32 value 10.
     """
     ex = data_executor.DataExecutor(
         eager_tf_executor.EagerTFExecutor(),
         TestDataBackend(self, 'foo://bar', 10, tf.int32))
     # Close the executor even if an assertion fails, so it is not left open.
     try:
         proto = pb.Computation(data=pb.Data(uri='foo://bar'),
                                type=type_serialization.serialize_type(
                                    computation_types.TensorType(tf.int32)))
         val = self._loop.run_until_complete(ex.create_value(proto))
         self.assertIsInstance(val, eager_tf_executor.EagerValue)
         self.assertEqual(str(val.type_signature), 'int32')
         self.assertEqual(self._loop.run_until_complete(val.compute()), 10)
     finally:
         ex.close()
Esempio n. 5
0
    def test_pass_through_comp(self):
        """A no-arg TF computation can be embedded and then called.

        The computation returns the int32 constant 10.
        """
        ex = data_executor.DataExecutor(
            eager_tf_executor.EagerTFExecutor(),
            TestDataBackend(self, 'none', None, None))
        # Close the executor even if an assertion fails, so it is not left
        # open.
        try:
            @computations.tf_computation
            def comp():
                return tf.constant(10, tf.int32)

            val = self._loop.run_until_complete(ex.create_value(comp))
            self.assertIsInstance(val, eager_tf_executor.EagerValue)
            self.assertEqual(str(val.type_signature), '( -> int32)')
            val2 = self._loop.run_until_complete(ex.create_call(val))
            self.assertEqual(str(val2.type_signature), 'int32')
            self.assertEqual(self._loop.run_until_complete(val2.compute()), 10)
        finally:
            ex.close()
Esempio n. 6
0
 def test_data_proto_dataset(self):
     """A `Data` proto with a sequence type resolves to a dataset.

     The test backend maps `foo://bar` to `tf.data.Dataset.range(3)`, so the
     computed value should yield 0, 1 and 2.
     """
     type_spec = computation_types.SequenceType(tf.int64)
     ex = data_executor.DataExecutor(
         eager_tf_executor.EagerTFExecutor(),
         TestDataBackend(self, 'foo://bar', tf.data.Dataset.range(3),
                         type_spec))
     # Close the executor even if an assertion fails, so it is not left open.
     try:
         proto = pb.Computation(
             data=pb.Data(uri='foo://bar'),
             type=type_serialization.serialize_type(type_spec))
         val = self._loop.run_until_complete(ex.create_value(proto))
         self.assertIsInstance(val, eager_tf_executor.EagerValue)
         self.assertEqual(str(val.type_signature), 'int64*')
         dataset = self._loop.run_until_complete(val.compute())
         # A for-loop already calls iter(); no explicit iter() needed.
         self.assertCountEqual([x.numpy() for x in dataset], [0, 1, 2])
     finally:
         ex.close()
Esempio n. 7
0
 def ex_fn(device):
     """Builds a `DataExecutor` over an `EagerTFExecutor` for `device`.

     The data backend wraps `data_constant`, a free name defined outside
     this snippet.
     """
     # NOTE(review): TestDataBackend is called with a single argument here,
     # unlike the (self, uri, value, type_spec) calls in the other examples.
     # This presumably targets a different TestDataBackend signature; confirm.
     leaf_executor = eager_tf_executor.EagerTFExecutor(device)
     data_backend = TestDataBackend(data_constant)
     return data_executor.DataExecutor(leaf_executor, data_backend)