def test_raises_unknown_uri(self):
     """Materializing a `Data` proto whose URI the backend does not know fails.

     `DataBackendExample.materialize` is expected to raise `status.StatusNotOk`
     with a message containing 'Unknown URI'.
     """
     example_backend = data_backend_example.DataBackendExample()
     with self.assertRaisesRegex(status.StatusNotOk, 'Unknown URI'):
         unknown_ref = tff_computation_proto.Data(uri='unknown_uri')
         empty_type = tff.to_type(())
         pending = example_backend.materialize(unknown_ref, empty_type)
         asyncio.run(pending)
    def test_combo_data_with_comp_and_tensor(self):
        """Feeds a data-backed sequence plus a plain int into one TF computation.

        The struct argument pairs a `pb.Data` reference for 'foo://bar' with
        the literal 10. `TestDataBackend` is presumably configured here to
        serve `tf.data.Dataset.range(3)` for that URI. `comp` sums the
        dataset (0 + 1 + 2 = 3), casts the sum to int32 and adds `y`, so the
        expected result is 13.
        """
        type_spec = computation_types.SequenceType(tf.int64)
        ex = data_executor.DataExecutor(
            eager_tf_executor.EagerTFExecutor(),
            TestDataBackend(self, 'foo://bar', tf.data.Dataset.range(3),
                            type_spec))
        # Close the executor even if a step or assertion below fails, so a
        # failing test does not leak it into later tests.
        try:
            proto = pb.Computation(
                data=pb.Data(uri='foo://bar'),
                type=type_serialization.serialize_type(type_spec))
            arg_val = self._loop.run_until_complete(
                ex.create_value(
                    collections.OrderedDict([('x', proto), ('y', 10)]),
                    computation_types.StructType([('x', type_spec),
                                                  ('y', tf.int32)])))

            @computations.tf_computation(type_spec, tf.int32)
            def comp(x, y):
                return tf.cast(x.reduce(np.int64(0), lambda p, q: p + q),
                               tf.int32) + y

            comp_val = self._loop.run_until_complete(ex.create_value(comp))
            ret_val = self._loop.run_until_complete(
                ex.create_call(comp_val, arg_val))
            self.assertIsInstance(ret_val, eager_tf_executor.EagerValue)
            self.assertEqual(str(ret_val.type_signature), 'int32')
            self.assertEqual(
                self._loop.run_until_complete(ret_val.compute()), 13)
        finally:
            ex.close()
 def test_data_proto_tensor(self):
     """A `pb.Data` URI of int32 tensor type resolves to an `EagerValue`.

     The test backend is given the value 10 for 'foo://bar'. The executor
     must produce an `EagerValue` typed `int32` whose computed value is 10.
     """
     ex = data_executor.DataExecutor(
         eager_tf_executor.EagerTFExecutor(),
         TestDataBackend(self, 'foo://bar', 10, tf.int32))
     # Close the executor even if a step or assertion below fails, so a
     # failing test does not leak it.
     try:
         proto = pb.Computation(data=pb.Data(uri='foo://bar'),
                                type=type_serialization.serialize_type(
                                    computation_types.TensorType(tf.int32)))
         val = self._loop.run_until_complete(ex.create_value(proto))
         self.assertIsInstance(val, eager_tf_executor.EagerValue)
         self.assertEqual(str(val.type_signature), 'int32')
         self.assertEqual(self._loop.run_until_complete(val.compute()), 10)
     finally:
         ex.close()
def main(_: Sequence[str]) -> None:
  """Runs a federated sum over data that is referenced only by URIs.

  Three URIs (`uri://0` .. `uri://2`) are wrapped in `pb.Data` computation
  protos and passed to the runtime through a `DataDescriptor`. The
  `DataExecutor` installed below resolves them through `NumpyArrDataBackend`.
  """

  def make_leaf_executor(
      device: tf.config.LogicalDevice) -> tff.framework.DataExecutor:
    # URIs embedded in a TFF computation can only be dereferenced when a
    # DataExecutor sits in the runtime stack. It wraps an EagerTFExecutor
    # (which runs the TF ops) and asks the DataBackend for the data behind
    # each URI.
    eager_executor = tff.framework.EagerTFExecutor(device)
    return tff.framework.DataExecutor(
        eager_executor, data_backend=NumpyArrDataBackend())

  # The runtime context spawns its executors through this factory.
  executor_factory = tff.framework.local_executor_factory(
      leaf_executor_fn=make_leaf_executor)
  context = tff.framework.ExecutionContext(executor_fn=executor_factory)
  tff.framework.set_default_context(context)

  # Each URI stands for a single int32 tensor served by the DataBackend.
  element_type = tff.types.TensorType(tf.int32)
  serialized_element_type = tff.framework.serialize_type(element_type)

  # One URI-carrying computation proto per client, so the executors can
  # process the references.
  arguments = [
      pb.Computation(data=pb.Data(uri=f'uri://{index}'),
                     type=serialized_element_type)
      for index in range(3)
  ]

  # The DataDescriptor marks the referenced data as federated (placed at
  # CLIENTS), so it can be passed directly to a federated computation.
  data_handle = tff.framework.DataDescriptor(
      None, arguments, tff.FederatedType(element_type, tff.CLIENTS),
      len(arguments))

  # Sums each client's value locally, then sums across all clients.
  @tff.federated_computation(tff.types.FederatedType(element_type, tff.CLIENTS))
  def foo(x):

    @tff.tf_computation(element_type)
    def local_sum(nums):
      return tf.math.reduce_sum(nums)

    return tff.federated_sum(tff.federated_map(local_sum, x))

  # Expected output: 18.
  print(foo(data_handle))
 def test_data_proto_dataset(self):
     """A `pb.Data` URI of `int64*` sequence type resolves to an `EagerValue`.

     The test backend is given `tf.data.Dataset.range(3)` for 'foo://bar'.
     The computed value must iterate to the elements 0, 1 and 2, in any
     order.
     """
     type_spec = computation_types.SequenceType(tf.int64)
     ex = data_executor.DataExecutor(
         eager_tf_executor.EagerTFExecutor(),
         TestDataBackend(self, 'foo://bar', tf.data.Dataset.range(3),
                         type_spec))
     # Close the executor even if a step or assertion below fails, so a
     # failing test does not leak it.
     try:
         proto = pb.Computation(
             data=pb.Data(uri='foo://bar'),
             type=type_serialization.serialize_type(type_spec))
         val = self._loop.run_until_complete(ex.create_value(proto))
         self.assertIsInstance(val, eager_tf_executor.EagerValue)
         self.assertEqual(str(val.type_signature), 'int64*')
         dataset = self._loop.run_until_complete(val.compute())
         # A for-clause already calls iter(); no explicit iter() is needed.
         self.assertCountEqual([x.numpy() for x in dataset], [0, 1, 2])
     finally:
         ex.close()
def CreateDataDescriptor(arg_uris: List[str], arg_type: computation_types.Type):
  """Wraps data-backend URIs in a federated `DataDescriptor`.

  Each URI becomes a `pb.Computation` holding a `pb.Data` reference together
  with the serialized `arg_type`. The resulting descriptor is typed as
  `arg_type` placed at CLIENTS, with a cardinality of one entry per URI.

  Args:
    arg_uris: List of URIs understood by the data backend embedded in the
      `tff.framework.ExecutionContext` that will consume the descriptor.
    arg_type: The `tff.Type` of the data each URI refers to.

  Returns:
    A `DataDescriptor` with no attached computation (first argument `None`).
  """
  serialized_type = serialize_type(arg_type)
  uri_protos = [
      pb.Computation(data=pb.Data(uri=data_uri), type=serialized_type)
      for data_uri in arg_uris
  ]
  clients_type = computation_types.FederatedType(arg_type, placements.CLIENTS)
  return DataDescriptor(None, uri_protos, clients_type, len(uri_protos))
Beispiel #7
0
 def proto(self):
     """Builds a `pb.Computation` that refers to this value's data by URI.

     The computation carries the serialized `self.type_signature` and a
     `pb.Data` block whose `uri` is `self._uri`.
     """
     serialized_type = type_serialization.serialize_type(self.type_signature)
     data_ref = pb.Data(uri=self._uri)
     return pb.Computation(type=serialized_type, data=data_ref)
 def test_raises_no_uri(self):
     """Materializing a `Data` proto with no URI set is rejected.

     `DataBackendExample.materialize` is expected to raise `status.StatusNotOk`
     with a message mentioning 'non-URI data blocks'.
     """
     example_backend = data_backend_example.DataBackendExample()
     with self.assertRaisesRegex(status.StatusNotOk, 'non-URI data blocks'):
         uri_less_data = tff_computation_proto.Data()
         empty_type = tff.to_type(())
         asyncio.run(example_backend.materialize(uri_less_data, empty_type))
 def test_materialize_returns(self, uri, type_signature, expected_value):
     """Materializing `uri` as `type_signature` yields `expected_value`.

     The arguments are supplied per test case (presumably by a
     parameterization decorator outside this view).
     """
     example_backend = data_backend_example.DataBackendExample()
     data_proto = tff_computation_proto.Data(uri=uri)
     tff_type = tff.to_type(type_signature)
     materialized = asyncio.run(example_backend.materialize(data_proto, tff_type))
     self.assertEqual(materialized, expected_value)