def test_in_executor_stack(self):
  """Runs a data-backed federated computation through a local executor stack.

  A DataExecutor over an EagerTFExecutor, with a test backend constructed for
  URI 'foo://bar' and `tf.data.Dataset.range(5)`, is installed as the leaf
  executor of a local executor stack. A federated computation reads the URI
  via `tff_data.data` and sums the sequence; the expected result is
  0 + 1 + 2 + 3 + 4 == 10.
  """
  seq_type = computation_types.SequenceType(tf.int64)
  target = eager_tf_executor.EagerTFExecutor()
  backend = TestDataBackend(self, 'foo://bar', tf.data.Dataset.range(5),
                            seq_type)
  leaf_executor = data_executor.DataExecutor(target, backend)

  def leaf_executor_fn(device):
    # Every device is served by the same pre-built executor.
    del device
    return leaf_executor

  stack_factory = executor_stacks.local_executor_factory(
      leaf_executor_fn=leaf_executor_fn)
  context = execution_context.ExecutionContext(executor_fn=stack_factory)

  @computations.tf_computation(seq_type)
  def foo(ds):
    total = ds.reduce(np.int64(0), lambda acc, elem: acc + elem)
    return tf.cast(total, tf.int32)

  @computations.federated_computation
  def bar():
    return foo(tff_data.data('foo://bar', seq_type))

  with context_stack_impl.context_stack.install(context):
    result = bar()
  self.assertEqual(result, 10)
def test_combo_data_with_comp_and_tensor(self):
  """Calls a TF computation on a struct mixing a data URI and a tensor.

  Element 'x' of the argument struct is a `pb.Data` proto for 'foo://bar'
  (the test backend is constructed with `tf.data.Dataset.range(3)` for that
  URI); element 'y' is the int32 value 10. The computation sums `x` and adds
  `y`, so the expected result is (0 + 1 + 2) + 10 == 13.
  """
  type_spec = computation_types.SequenceType(tf.int64)
  ex = data_executor.DataExecutor(
      eager_tf_executor.EagerTFExecutor(),
      TestDataBackend(self, 'foo://bar', tf.data.Dataset.range(3), type_spec))
  # try/finally so the executor is closed even if an assertion fails.
  try:
    proto = pb.Computation(
        data=pb.Data(uri='foo://bar'),
        type=type_serialization.serialize_type(type_spec))
    arg_val = self._loop.run_until_complete(
        ex.create_value(
            collections.OrderedDict([('x', proto), ('y', 10)]),
            computation_types.StructType([('x', type_spec),
                                          ('y', tf.int32)])))

    @computations.tf_computation(type_spec, tf.int32)
    def comp(x, y):
      return tf.cast(x.reduce(np.int64(0), lambda p, q: p + q), tf.int32) + y

    comp_val = self._loop.run_until_complete(ex.create_value(comp))
    ret_val = self._loop.run_until_complete(ex.create_call(comp_val, arg_val))
    self.assertIsInstance(ret_val, eager_tf_executor.EagerValue)
    self.assertEqual(str(ret_val.type_signature), 'int32')
    self.assertEqual(self._loop.run_until_complete(ret_val.compute()), 13)
  finally:
    ex.close()
def test_pass_through_tensor(self):
  """Checks that a plain (non-data) tensor value passes through.

  The test backend is constructed with URI 'none' and no value or type. The
  int32 value 10 created through the DataExecutor must come back as an int32
  `EagerValue` that computes to 10.
  """
  ex = data_executor.DataExecutor(
      eager_tf_executor.EagerTFExecutor(),
      TestDataBackend(self, 'none', None, None))
  # try/finally so the executor is closed even if an assertion fails.
  try:
    val = self._loop.run_until_complete(ex.create_value(10, tf.int32))
    self.assertIsInstance(val, eager_tf_executor.EagerValue)
    self.assertEqual(str(val.type_signature), 'int32')
    self.assertEqual(self._loop.run_until_complete(val.compute()), 10)
  finally:
    ex.close()
def test_data_proto_tensor(self):
  """Resolves an int32 `pb.Data` proto through the data backend.

  The test backend is constructed with URI 'foo://bar', value 10 and type
  int32. Creating a value from a `pb.Data` proto for that URI must yield an
  int32 `EagerValue` that computes to 10.
  """
  ex = data_executor.DataExecutor(
      eager_tf_executor.EagerTFExecutor(),
      TestDataBackend(self, 'foo://bar', 10, tf.int32))
  # try/finally so the executor is closed even if an assertion fails.
  try:
    proto = pb.Computation(
        data=pb.Data(uri='foo://bar'),
        type=type_serialization.serialize_type(
            computation_types.TensorType(tf.int32)))
    val = self._loop.run_until_complete(ex.create_value(proto))
    self.assertIsInstance(val, eager_tf_executor.EagerValue)
    self.assertEqual(str(val.type_signature), 'int32')
    self.assertEqual(self._loop.run_until_complete(val.compute()), 10)
  finally:
    ex.close()
def test_pass_through_comp(self):
  """Checks that a TF computation passes through and can be invoked.

  The test backend is constructed with URI 'none' and no value or type. A
  no-argument computation returning int32 10 is embedded through the
  DataExecutor (type '( -> int32)') and then called; the call result must be
  int32 and compute to 10.
  """
  ex = data_executor.DataExecutor(
      eager_tf_executor.EagerTFExecutor(),
      TestDataBackend(self, 'none', None, None))
  # try/finally so the executor is closed even if an assertion fails.
  try:

    @computations.tf_computation
    def comp():
      return tf.constant(10, tf.int32)

    val = self._loop.run_until_complete(ex.create_value(comp))
    self.assertIsInstance(val, eager_tf_executor.EagerValue)
    self.assertEqual(str(val.type_signature), '( -> int32)')
    val2 = self._loop.run_until_complete(ex.create_call(val))
    self.assertEqual(str(val2.type_signature), 'int32')
    self.assertEqual(self._loop.run_until_complete(val2.compute()), 10)
  finally:
    ex.close()
def test_data_proto_dataset(self):
  """Resolves a sequence-typed `pb.Data` proto to a dataset.

  The test backend is constructed with URI 'foo://bar' and
  `tf.data.Dataset.range(3)` of type int64*. The created value must be an
  `EagerValue` of type 'int64*' whose computed dataset yields 0, 1 and 2.
  """
  type_spec = computation_types.SequenceType(tf.int64)
  ex = data_executor.DataExecutor(
      eager_tf_executor.EagerTFExecutor(),
      TestDataBackend(self, 'foo://bar', tf.data.Dataset.range(3), type_spec))
  # try/finally so the executor is closed even if an assertion fails.
  try:
    proto = pb.Computation(
        data=pb.Data(uri='foo://bar'),
        type=type_serialization.serialize_type(type_spec))
    val = self._loop.run_until_complete(ex.create_value(proto))
    self.assertIsInstance(val, eager_tf_executor.EagerValue)
    self.assertEqual(str(val.type_signature), 'int64*')
    dataset = self._loop.run_until_complete(val.compute())
    # Iterating the dataset directly; an explicit iter() is redundant.
    self.assertCountEqual([x.numpy() for x in dataset], [0, 1, 2])
  finally:
    ex.close()
def ex_fn(device):
  """Builds a DataExecutor over an EagerTFExecutor created for `device`.

  NOTE(review): `data_constant` is defined in a scope not visible here, and
  `TestDataBackend` is called with a single argument, unlike the
  four-argument `TestDataBackend(self, uri, value, type_spec)` calls in the
  other tests of this file — confirm it matches the helper's signature.
  """
  target = eager_tf_executor.EagerTFExecutor(device)
  backend = TestDataBackend(data_constant)
  return data_executor.DataExecutor(target, backend)