def test_raises_unknown_uri(self):
  """Materializing a URI the example backend does not recognize must fail.

  Expects `status.StatusNotOk` with a message matching 'Unknown URI'.
  """
  backend_under_test = data_backend_example.DataBackendExample()
  with self.assertRaisesRegex(status.StatusNotOk, 'Unknown URI'):
    bogus_block = tff_computation_proto.Data(uri='unknown_uri')
    unit_type = tff.to_type(())
    asyncio.run(backend_under_test.materialize(bogus_block, unit_type))
def test_combo_data_with_comp_and_tensor(self):
  """Calls a TF computation on a struct mixing a data URI and a plain tensor.

  `TestDataBackend` is configured with `tf.data.Dataset.range(3)` for URI
  'foo://bar'. Element `x` of the argument is a `Data` proto pointing at that
  URI, and element `y` is the literal 10. The computation sums the dataset
  (0 + 1 + 2 = 3) and adds `y`, so the result is asserted to be 13.
  """
  type_spec = computation_types.SequenceType(tf.int64)
  ex = data_executor.DataExecutor(
      eager_tf_executor.EagerTFExecutor(),
      TestDataBackend(self, 'foo://bar', tf.data.Dataset.range(3), type_spec))
  # Close the executor even when an assertion below fails, so a failing test
  # does not leave the executor open for later tests.
  try:
    proto = pb.Computation(
        data=pb.Data(uri='foo://bar'),
        type=type_serialization.serialize_type(type_spec))
    arg_val = self._loop.run_until_complete(
        ex.create_value(
            collections.OrderedDict([('x', proto), ('y', 10)]),
            computation_types.StructType([('x', type_spec),
                                          ('y', tf.int32)])))

    @computations.tf_computation(type_spec, tf.int32)
    def comp(x, y):
      # Reduce the int64 dataset to a scalar, cast to int32, then add `y`.
      return tf.cast(x.reduce(np.int64(0), lambda p, q: p + q), tf.int32) + y

    comp_val = self._loop.run_until_complete(ex.create_value(comp))
    ret_val = self._loop.run_until_complete(ex.create_call(comp_val, arg_val))
    self.assertIsInstance(ret_val, eager_tf_executor.EagerValue)
    self.assertEqual(str(ret_val.type_signature), 'int32')
    self.assertEqual(self._loop.run_until_complete(ret_val.compute()), 13)
  finally:
    ex.close()
def test_data_proto_tensor(self):
  """A tensor-typed `Data` proto is materialized into an `EagerValue`.

  `TestDataBackend` is configured with the value 10 (type `tf.int32`) for URI
  'foo://bar'. The created value is asserted to have type signature 'int32'
  and to compute to 10.
  """
  ex = data_executor.DataExecutor(
      eager_tf_executor.EagerTFExecutor(),
      TestDataBackend(self, 'foo://bar', 10, tf.int32))
  # Close the executor even when an assertion below fails, so a failing test
  # does not leave the executor open for later tests.
  try:
    proto = pb.Computation(
        data=pb.Data(uri='foo://bar'),
        type=type_serialization.serialize_type(
            computation_types.TensorType(tf.int32)))
    val = self._loop.run_until_complete(ex.create_value(proto))
    self.assertIsInstance(val, eager_tf_executor.EagerValue)
    self.assertEqual(str(val.type_signature), 'int32')
    self.assertEqual(self._loop.run_until_complete(val.compute()), 10)
  finally:
    ex.close()
def main(_: Sequence[str]) -> None:
  """Runs a federated sum over data that is referenced only by URIs.

  Installs a default execution context whose leaf executors can dereference
  data URIs. It then wraps three URIs in a CLIENTS-placed `DataDescriptor`,
  evaluates a federated sum over it, and prints the result.
  """

  def ex_fn(device: tf.config.LogicalDevice) -> tff.framework.DataExecutor:
    # Data URIs inside TFF computation protos only mean something if an
    # executor in the stack can resolve them. The DataExecutor resolves them
    # through its DataBackend and hands TF work to the wrapped EagerTFExecutor.
    tf_leaf = tff.framework.EagerTFExecutor(device)
    return tff.framework.DataExecutor(
        tf_leaf, data_backend=NumpyArrDataBackend())

  # Make a context backed by `ex_fn`-built leaf executors the default one, so
  # that invoking a TFF computation below runs through it.
  tff.framework.set_default_context(
      tff.framework.ExecutionContext(
          executor_fn=tff.framework.local_executor_factory(
              leaf_executor_fn=ex_fn)))

  # Type each URI is declared to resolve to.
  element_type = tff.types.TensorType(tf.int32)
  serialized_element_type = tff.framework.serialize_type(element_type)

  # Wrap every URI in a computation proto that also carries its serialized
  # type, so TFF executors can process it.
  arguments = []
  for index in range(3):
    arguments.append(
        pb.Computation(
            data=pb.Data(uri=f'uri://{index}'),
            type=serialized_element_type))

  # Describe the protos as a CLIENTS-placed federated value so they can be
  # passed straight into a federated computation.
  data_handle = tff.framework.DataDescriptor(
      None, arguments, tff.FederatedType(element_type, tff.CLIENTS),
      len(arguments))

  @tff.federated_computation(tff.types.FederatedType(element_type, tff.CLIENTS))
  def sum_over_clients(x):
    # Reduce each client's value locally, then add the results across clients.
    @tff.tf_computation(element_type)
    def per_client_sum(nums):
      return tf.math.reduce_sum(nums)

    return tff.federated_sum(tff.federated_map(per_client_sum, x))

  # Expected output: 18. The actual value depends on what NumpyArrDataBackend
  # serves for these URIs.
  print(sum_over_clients(data_handle))
def test_data_proto_dataset(self):
  """A sequence-typed `Data` proto is materialized into an `EagerValue`.

  `TestDataBackend` is configured with `tf.data.Dataset.range(3)` for URI
  'foo://bar'. The created value is asserted to have type 'int64*' and to
  yield exactly the elements 0, 1 and 2, in any order.
  """
  type_spec = computation_types.SequenceType(tf.int64)
  ex = data_executor.DataExecutor(
      eager_tf_executor.EagerTFExecutor(),
      TestDataBackend(self, 'foo://bar', tf.data.Dataset.range(3), type_spec))
  # Close the executor even when an assertion below fails, so a failing test
  # does not leave the executor open for later tests.
  try:
    proto = pb.Computation(
        data=pb.Data(uri='foo://bar'),
        type=type_serialization.serialize_type(type_spec))
    val = self._loop.run_until_complete(ex.create_value(proto))
    self.assertIsInstance(val, eager_tf_executor.EagerValue)
    self.assertEqual(str(val.type_signature), 'int64*')
    self.assertCountEqual(
        [x.numpy() for x in iter(self._loop.run_until_complete(val.compute()))],
        [0, 1, 2])
  finally:
    ex.close()
def CreateDataDescriptor(arg_uris: List[str], arg_type: computation_types.Type):
  """Builds a `DataDescriptor` over data that is referenced by URIs.

  Each URI is wrapped in a `pb.Computation` data proto tagged with the
  serialized `arg_type`. The protos are described as a value of federated type
  `arg_type` placed at `CLIENTS`, and the number of URIs is passed as the
  descriptor's final argument.

  Args:
    arg_uris: URIs compatible with the data backend embedded in the given
      `tff.framework.ExecutionContext`.
    arg_type: A `tff.Type` describing the data each URI refers to.

  Returns:
    Instance of `DataDescriptor`.
  """
  serialized_type = serialize_type(arg_type)
  data_protos = []
  for uri in arg_uris:
    data_protos.append(
        pb.Computation(data=pb.Data(uri=uri), type=serialized_type))
  federated_type = computation_types.FederatedType(arg_type,
                                                   placements.CLIENTS)
  return DataDescriptor(None, data_protos, federated_type, len(data_protos))
def proto(self):
  """Returns a `pb.Computation` that references this value's data by URI.

  The proto pairs the serialized `self.type_signature` with a `Data` message
  holding `self._uri`.
  """
  serialized_type = type_serialization.serialize_type(self.type_signature)
  uri_data = pb.Data(uri=self._uri)
  return pb.Computation(type=serialized_type, data=uri_data)
def test_raises_no_uri(self):
  """A `Data` proto that carries no URI is rejected by the example backend.

  Expects `status.StatusNotOk` with a message matching 'non-URI data blocks'.
  """
  backend_under_test = data_backend_example.DataBackendExample()
  with self.assertRaisesRegex(status.StatusNotOk, 'non-URI data blocks'):
    uri_less_block = tff_computation_proto.Data()
    unit_type = tff.to_type(())
    asyncio.run(backend_under_test.materialize(uri_less_block, unit_type))
def test_materialize_returns(self, uri, type_signature, expected_value):
  """Materializing `uri` as `type_signature` yields `expected_value`."""
  example_backend = data_backend_example.DataBackendExample()
  data_block = tff_computation_proto.Data(uri=uri)
  requested_type = tff.to_type(type_signature)
  materialized = asyncio.run(
      example_backend.materialize(data_block, requested_type))
  self.assertEqual(materialized, expected_value)