コード例 #1
0
 def test_executor_create_value_with_intrinsic_as_pb_computation(self):
     # An intrinsic `pb.Computation` should be embedded as a federated value
     # whose internal representation is the matching intrinsic definition.
     event_loop = asyncio.get_event_loop()
     executor = _make_test_executor()
     intrinsic_comp = pb.Computation(
         intrinsic=pb.Intrinsic(uri='generic_zero'),
         type=type_serialization.serialize_type(tf.int32))
     value = event_loop.run_until_complete(
         executor.create_value(intrinsic_comp))
     self.assertIsInstance(value, federated_executor.FederatedExecutorValue)
     self.assertEqual(str(value.type_signature), 'int32')
     self.assertIs(value.internal_representation, intrinsic_defs.GENERIC_ZERO)
コード例 #2
0
    def test_raises_not_implemented_error_with_unimplemented_intrinsic(self):
        executor = create_test_executor(num_clients=3)
        # The `dummy_intrinsic` definition is needed to allow lookup of the
        # 'dummy_intrinsic' uri below; the local itself is not used afterwards.
        dummy_intrinsic = intrinsic_defs.IntrinsicDef(
            'DUMMY_INTRINSIC', 'dummy_intrinsic',
            computation_types.AbstractType('T'))
        comp = pb.Computation(intrinsic=pb.Intrinsic(uri='dummy_intrinsic'),
                              type=type_serialization.serialize_type(tf.int32))
        del dummy_intrinsic

        # Embedding the intrinsic succeeds; only calling it is unimplemented.
        comp = self.run_sync(executor.create_value(comp))
        with self.assertRaises(NotImplementedError):
            self.run_sync(executor.create_call(comp))
コード例 #3
0
 def proto(self):
   """Serializes this block into a `pb.Computation` with a `block` field.

   Each `(name, value)` pair in `self._locals` becomes a `pb.Block.Local`, and
   `self._result` supplies the block's result.
   """
   serialized_type = type_serialization.serialize_type(self.type_signature)
   local_protos = [
       pb.Block.Local(name=local_name, value=local_value.proto)
       for local_name, local_value in self._locals
   ]
   block_proto = pb.Block(local=local_protos, result=self._result.proto)
   return pb.Computation(type=serialized_type, block=block_proto)
コード例 #4
0
  def test_executor_call_unsupported_intrinsic(self):
    # The `dummy_intrinsic` definition is needed to allow lookup of the
    # 'dummy_intrinsic' uri below; the local itself is not used afterwards.
    dummy_intrinsic = intrinsic_defs.IntrinsicDef(
        'DUMMY_INTRINSIC', 'dummy_intrinsic',
        computation_types.AbstractType('T'))

    comp = pb.Computation(
        intrinsic=pb.Intrinsic(uri='dummy_intrinsic'),
        type=type_serialization.serialize_type(tf.int32))
    del dummy_intrinsic

    with self.assertRaises(NotImplementedError):
      _run_test_comp(comp, num_clients=3)
コード例 #5
0
 def proto(self):
   """Serializes this tuple into a `pb.Computation` with a `tuple` field.

   Elements whose name is `None` are emitted without a `name` field.
   """

   def _element_proto(name, value):
     if name is None:
       return pb.Tuple.Element(value=value.proto)
     return pb.Tuple.Element(name=name, value=value.proto)

   element_protos = [
       _element_proto(name, value)
       for name, value in anonymous_tuple.to_elements(self)
   ]
   return pb.Computation(
       type=type_serialization.serialize_type(self.type_signature),
       tuple=pb.Tuple(element=element_protos))
コード例 #6
0
    def test_raises_value_error_with_unrecognized_computation_selection(self):
        executor = create_test_executor(num_clients=3)
        fn_value = (
            executor_test_utils.create_dummy_empty_tensorflow_computation())
        fn_type = computation_types.FunctionType(
            None, computation_types.NamedTupleType([]))
        single_element = pb.Tuple.Element(value=fn_value)
        source = pb.Computation(
            type=type_serialization.serialize_type([fn_type]),
            tuple=pb.Tuple(element=[single_element]))
        # `create_value` is expected to raise a `ValueError` because the
        # `pb.Selection` below sets neither a name nor an index field.
        selection_comp = pb.Computation(
            type=type_serialization.serialize_type(fn_type),
            selection=pb.Selection(source=source))
        expected_type = computation_types.FunctionType(
            None, computation_types.NamedTupleType([]))

        with self.assertRaises(ValueError):
            self.run_sync(executor.create_value(selection_comp, expected_type))
コード例 #7
0
 def test_data_proto_tensor(self):
     # A `pb.Data` uri is resolved through the test backend to the tensor 10.
     leaf_executor = eager_tf_executor.EagerTFExecutor()
     backend = TestDataBackend(self, 'foo://bar', 10, tf.int32)
     ex = data_executor.DataExecutor(leaf_executor, backend)
     tensor_type = computation_types.TensorType(tf.int32)
     data_comp = pb.Computation(
         data=pb.Data(uri='foo://bar'),
         type=type_serialization.serialize_type(tensor_type))
     run = self._loop.run_until_complete
     value = run(ex.create_value(data_comp))
     self.assertIsInstance(value, eager_tf_executor.EagerValue)
     self.assertEqual(str(value.type_signature), 'int32')
     self.assertEqual(run(value.compute()), 10)
     ex.close()
コード例 #8
0
def create_dummy_computation_tuple():
    """Returns a named tuple computation and its type.

    The tuple has three elements, named 'a', 'b' and 'c'. Each is bound to
    the value returned by `create_dummy_computation_tensorflow_constant`, and
    the tuple type pairs each name with that value's type.
    """
    element_names = ['a', 'b', 'c']
    constant_value, constant_type = (
        create_dummy_computation_tensorflow_constant())
    tuple_elements = [
        pb.Tuple.Element(name=element_name, value=constant_value)
        for element_name in element_names
    ]
    tuple_type = computation_types.NamedTupleType(
        (element_name, constant_type) for element_name in element_names)
    tuple_comp = pb.Computation(
        type=type_serialization.serialize_type(tuple_type),
        tuple=pb.Tuple(element=tuple_elements))
    return tuple_comp, tuple_type
コード例 #9
0
    def test_raises_value_error_with_unrecognized_computation_intrinsic(self):
        executor = create_test_executor()
        # `create_value` is expected to raise a `ValueError`: the intrinsic uri
        # below was never added to the intrinsic registry, so it cannot be
        # recognized.
        unregistered_comp = pb.Computation(
            type=type_serialization.serialize_type(tf.int32),
            intrinsic=pb.Intrinsic(uri='unregistered_intrinsic'))
        expected_type = computation_types.TensorType(tf.int32)

        with self.assertRaises(ValueError):
            self.run_sync(
                executor.create_value(unregistered_comp, expected_type))
コード例 #10
0
    def test_raises_value_error_with_unrecognized_computation_selection(self):
        executor = create_test_executor()
        source, _ = executor_test_utils.create_dummy_computation_tuple()
        empty_tuple_type = computation_types.NamedTupleType([])
        # `create_value` is expected to raise a `ValueError` because the
        # `pb.Selection` below sets neither a name nor an index field.
        keyless_selection = pb.Computation(
            type=type_serialization.serialize_type(empty_tuple_type),
            selection=pb.Selection(source=source))

        with self.assertRaises(ValueError):
            self.run_sync(
                executor.create_value(keyless_selection, empty_tuple_type))
コード例 #11
0
def create_lambda_empty_struct() -> pb.Computation:
  """Builds a no-argument lambda whose body is an empty struct.

  Has the type signature:

  ( -> <>)

  Returns:
    An instance of `pb.Computation`.
  """
  empty_struct_type = computation_types.StructType([])
  fn_type = computation_types.FunctionType(None, empty_struct_type)
  empty_struct = pb.Computation(
      type=type_serialization.serialize_type(empty_struct_type),
      struct=pb.Struct(element=[]))
  lambda_proto = pb.Lambda(parameter_name=None, result=empty_struct)
  # `lambda` is a reserved Python keyword but also the name of the
  # `pb.Computation` field, so it must be passed via dict unpacking.
  # https://developers.google.com/protocol-buffers/docs/reference/python-generated#keyword-conflicts
  lambda_field = {'lambda': lambda_proto}
  return pb.Computation(
      type=type_serialization.serialize_type(fn_type), **lambda_field)  # pytype: disable=wrong-keyword-args
コード例 #12
0
    def test_raises_on_xla(self):
        # XLA computations cannot be compiled down to TensorFlow.
        int32_fn_type = computation_types.FunctionType(
            computation_types.TensorType(tf.int32),
            computation_types.TensorType(tf.int32))
        serialized_fn_type = type_serialization.serialize_type(int32_fn_type)
        xla_proto = computation_pb2.Computation(type=serialized_fn_type,
                                                xla=computation_pb2.Xla())
        compiled = building_blocks.CompiledComputation(proto=xla_proto)

        with self.assertRaises(compiler.XlaToTensorFlowError):
            compiler.compile_local_computation_to_tensorflow(compiled)
コード例 #13
0
  def test_something(self):
    # TODO(b/113112108): Revise these tests after a more complete implementation
    # is in place.
    context_stack = context_stack_impl.context_stack
    int_fn_type = computation_types.FunctionType(tf.int32, tf.int32)

    # At the moment, this should succeed, as both the computation body and the
    # type are well-formed.
    computation_impl.ComputationImpl(
        pb.Computation(
            type=type_serialization.serialize_type(int_fn_type),
            intrinsic=pb.Intrinsic(uri='whatever')), context_stack)

    # This should fail, as the proto is not well-formed.
    with self.assertRaises(TypeError):
      computation_impl.ComputationImpl(pb.Computation(), context_stack)

    # This should fail, as "10" is not an instance of pb.Computation.
    with self.assertRaises(TypeError):
      computation_impl.ComputationImpl(10, context_stack)
コード例 #14
0
def main(_: Sequence[str]) -> None:
  """Runs a federated sum over data that TFF computations reference by URI.

  Installs a default TFF execution context whose leaf executors are
  `DataExecutor`s (backed by `NumpyArrDataBackend`) that wrap
  `EagerTFExecutor`s. It then invokes a federated computation on a
  `DataDescriptor` built from `pb.Data` URI references.

  Args:
    _: Ignored positional argument. Presumably `argv` as passed by an
      `app.run`-style launcher; confirm with the caller.
  """

  def ex_fn(device: tf.config.LogicalDevice) -> tff.framework.DataExecutor:
    # In order to de-reference data URIs bundled in TFF computations, a
    # DataExecutor must exist in the runtime context to process those URIs and
    # return the underlying data. We can wrap an EagerTFExecutor (which handles
    # TF operations) with a DataExecutor instance defined with a DataBackend
    # object.
    return tff.framework.DataExecutor(
        tff.framework.EagerTFExecutor(device),
        data_backend=NumpyArrDataBackend())

  # Executor factory used by the runtime context to spawn executors to run TFF
  # computations; leaf executors are built by `ex_fn`.
  factory = tff.framework.local_executor_factory(leaf_executor_fn=ex_fn)

  # Context in which to execute the following computation. It is installed as
  # the default context, so the call to `foo` below runs in it.
  ctx = tff.framework.ExecutionContext(executor_fn=factory)
  tff.framework.set_default_context(ctx)

  # Type of the data returned by the DataBackend.
  element_type = tff.types.TensorType(tf.int32)
  element_type_proto = tff.framework.serialize_type(element_type)
  # We construct a list of URIs as our references to the dataset.
  uris = [f'uri://{i}' for i in range(3)]
  # The URIs are embedded in TFF computation protos so they can be processed by
  # TFF executors.
  arguments = [
      pb.Computation(data=pb.Data(uri=uri), type=element_type_proto)
      for uri in uris
  ]
  # The embedded URIs are passed to a DataDescriptor which recognizes the
  # underlying dataset as federated and allows combining it with a federated
  # computation.
  # NOTE(review): the last argument, `len(arguments)`, presumably sets the
  # number of clients (one per URI); confirm against DataDescriptor.
  data_handle = tff.framework.DataDescriptor(
      None, arguments, tff.FederatedType(element_type, tff.CLIENTS),
      len(arguments))

  # Federated computation that sums the values in the arrays. Each client's
  # value is reduced with `local_sum` (via federated_map), and the results are
  # then combined with federated_sum.
  @tff.federated_computation(tff.types.FederatedType(element_type, tff.CLIENTS))
  def foo(x):

    @tff.tf_computation(element_type)
    def local_sum(nums):
      return tf.math.reduce_sum(nums)

    return tff.federated_sum(tff.federated_map(local_sum, x))

  # Should print 18.
  # NOTE(review): this depends on the data NumpyArrDataBackend (defined
  # elsewhere) serves for uri://0 .. uri://2, which is not visible here.
  print(foo(data_handle))
コード例 #15
0
def create_intrinsic_comp(intrinsic_def, type_spec):
    """Creates a `pb.Computation` that refers to an intrinsic by its uri.

    Args:
      intrinsic_def: An instance of `intrinsic_defs.IntrinsicDef`.
      type_spec: The concrete type of the intrinsic (`computation_types.Type`).

    Returns:
      An instance of `pb.Computation` that represents the intrinsic.

    Raises:
      TypeError: If an argument is not of the expected type.
    """
    py_typecheck.check_type(intrinsic_def, intrinsic_defs.IntrinsicDef)
    py_typecheck.check_type(type_spec, computation_types.Type)
    serialized_type = type_serialization.serialize_type(type_spec)
    intrinsic_proto = pb.Intrinsic(uri=intrinsic_def.uri)
    return pb.Computation(type=serialized_type, intrinsic=intrinsic_proto)
コード例 #16
0
def create_dummy_identity_lambda_computation(type_spec=tf.int32):
    """Returns a `pb.Computation` representing an identity lambda.

    The type signature of this `pb.Computation` is `(T -> T)`, where `T` is
    `type_spec` (`int32` by default). The lambda's parameter is named 'a' and
    its body is a reference to that parameter.

    Args:
      type_spec: A type signature.

    Returns:
      A `pb.Computation`.
    """
    param_name = 'a'
    fn_type_proto = type_serialization.serialize_type(
        type_factory.unary_op(type_spec))
    reference_body = pb.Computation(
        type=type_serialization.serialize_type(type_spec),
        reference=pb.Reference(name=param_name))
    identity_lambda = pb.Lambda(parameter_name=param_name,
                                result=reference_body)
    # `lambda` is a reserved Python keyword but also the name of the
    # `pb.Computation` field, so it must be passed via dict unpacking.
    # https://developers.google.com/protocol-buffers/docs/reference/python-generated#keyword-conflicts
    return pb.Computation(type=fn_type_proto, **{'lambda': identity_lambda})  # pytype: disable=wrong-keyword-args
コード例 #17
0
  def test_raises_not_implemented_error_with_unimplemented_intrinsic(self):
    executor = create_test_executor()
    # The `whimsy_intrinsic` definition must exist for the uri lookup to
    # succeed; the local reference is dropped once the proto is built.
    whimsy_intrinsic = intrinsic_defs.IntrinsicDef(
        'WHIMSY_INTRINSIC', 'whimsy_intrinsic',
        computation_types.AbstractType('T'))
    int32_type = computation_types.TensorType(tf.int32)
    intrinsic_comp = pb.Computation(
        intrinsic=pb.Intrinsic(uri='whimsy_intrinsic'),
        type=type_serialization.serialize_type(int32_type))
    del whimsy_intrinsic

    embedded = self.run_sync(executor.create_value(intrinsic_comp))
    with self.assertRaises(NotImplementedError):
      self.run_sync(executor.create_call(embedded))
コード例 #18
0
    def test_executor_call_unsupported_intrinsic(self):
        # The `dummy_intrinsic` definition is needed to allow lookup of the
        # 'dummy_intrinsic' uri below; the local itself is not used afterwards.
        dummy_intrinsic = intrinsic_defs.IntrinsicDef(
            'DUMMY_INTRINSIC', 'dummy_intrinsic',
            computation_types.AbstractType('T'))

        comp = pb.Computation(type=type_serialization.serialize_type(tf.int32),
                              intrinsic=pb.Intrinsic(uri='dummy_intrinsic'))
        del dummy_intrinsic

        loop = asyncio.get_event_loop()
        executor = composing_executor.ComposingExecutor(
            _create_bottom_stack(), [_create_worker_stack() for _ in range(3)])

        # NOTE(review): `create_value` runs inside the `assertRaises` context,
        # so this test also passes if `create_value` itself raises
        # `NotImplementedError`. Consider hoisting it out once it is confirmed
        # that embedding this intrinsic succeeds.
        with self.assertRaises(NotImplementedError):
            v1 = loop.run_until_complete(executor.create_value(comp, tf.int32))
            loop.run_until_complete(executor.create_call(v1, None))
コード例 #19
0
  def test_with_type_raises_non_assignable_type(self):
    int_fn_type = computation_types.FunctionType(tf.int32, tf.int32)
    intrinsic_proto = pb.Computation(
        type=type_serialization.serialize_type(int_fn_type),
        intrinsic=pb.Intrinsic(uri='whatever'))
    original_comp = computation_impl.ConcreteComputation(
        intrinsic_proto, context_stack_impl.context_stack)

    # An `int32` result cannot be re-annotated as a list-backed struct result.
    list_fn_type = computation_types.FunctionType(
        tf.int32,
        computation_types.StructWithPythonType([(None, tf.int32)], list))
    with self.assertRaises(computation_types.TypeNotAssignableError):
      computation_impl.ConcreteComputation.with_type(original_comp,
                                                     list_fn_type)
コード例 #20
0
def create_lambda_identity(type_spec: computation_types.Type) -> pb.Computation:
  """Builds an identity lambda over `type_spec`.

  Has the type signature:

  (T -> T)

  The lambda's parameter is named 'a' and its body is a reference to it.

  Args:
    type_spec: A `computation_types.Type`.

  Returns:
    An instance of `pb.Computation`.
  """
  param_name = 'a'
  fn_type = type_factory.unary_op(type_spec)
  reference_body = pb.Computation(
      type=type_serialization.serialize_type(type_spec),
      reference=pb.Reference(name=param_name))
  identity_fn = pb.Lambda(parameter_name=param_name, result=reference_body)
  # `lambda` is a reserved Python keyword but also the name of the
  # `pb.Computation` field, so it must be passed via dict unpacking.
  # https://developers.google.com/protocol-buffers/docs/reference/python-generated#keyword-conflicts
  return pb.Computation(
      type=type_serialization.serialize_type(fn_type), **{'lambda': identity_fn})  # pytype: disable=wrong-keyword-args
コード例 #21
0
    def test_invoke_raises_value_error_with_federated_computation(self):
        # The proto below is of the `reference` variety rather than
        # `tensorflow`. Invoking it from a TensorFlow computation context is
        # expected to raise a `ValueError`.
        fn_type = computation_types.to_type(
            computation_types.FunctionType(tf.int32, tf.int32))
        bogus_proto = pb.Computation(
            type=type_serialization.serialize_type(fn_type),
            reference=pb.Reference(name='boogledy'))
        non_tf_computation = computation_impl.ComputationImpl(
            bogus_proto, context_stack_impl.context_stack)

        default_graph = tf.compat.v1.get_default_graph()
        context = tensorflow_computation_context.TensorFlowComputationContext(
            default_graph)

        expected_message = ('Can only invoke TensorFlow in the body of '
                            'a TensorFlow computation')
        with self.assertRaisesRegex(ValueError, expected_message):
            context.invoke(non_tf_computation, None)
コード例 #22
0
def wrap_graph_parameter_as_tuple(comp, name=None):
    """Wraps the parameter of `comp` in a tuple binding.

    `wrap_graph_parameter_as_tuple` is intended as a preprocessing step for
    `pad_graph_inputs_to_match_type`. It lets `pad_graph_inputs_to_match_type`
    assume that its argument `comp` always has a tuple binding, instead of
    dealing with the possibility of an unwrapped tensor or sequence binding.

    Args:
      comp: Instance of `computation_building_blocks.CompiledComputation` whose
        parameter we wish to wrap in a tuple binding.
      name: Optional string argument, the name to assign to the element type in
        the constructed tuple. Defaults to `None`.

    Returns:
      A transformed version of comp representing exactly the same computation,
      but accepting a tuple containing one element--the parameter of `comp`.

    Raises:
      TypeError: If `comp` is not a
        `computation_building_blocks.CompiledComputation`.
    """
    py_typecheck.check_type(comp,
                            computation_building_blocks.CompiledComputation)
    if name is not None:
        py_typecheck.check_type(name, six.string_types)
    original_proto = comp.proto
    original_type = type_serialization.deserialize_type(original_proto.type)
    original_tf = original_proto.tensorflow

    # The graph itself is reused unchanged; only the parameter binding and the
    # parameter type gain a one-element tuple layer.
    wrapped_binding = pb.TensorFlow.Binding(
        tuple=pb.TensorFlow.NamedTupleBinding(
            element=[original_tf.parameter]))
    wrapped_fn_type = computation_types.FunctionType(
        [(name, original_type.parameter)], original_type.result)

    wrapped_proto = pb.Computation(
        type=type_serialization.serialize_type(wrapped_fn_type),
        tensorflow=pb.TensorFlow(graph_def=original_tf.graph_def,
                                 initialize_op=original_tf.initialize_op,
                                 parameter=wrapped_binding,
                                 result=original_tf.result))
    return computation_building_blocks.CompiledComputation(wrapped_proto)
コード例 #23
0
def create_dummy_computation_tensorflow_empty():
    """Returns a tensorflow computation and type `( -> <>)`.

    The graph contains nothing beyond what
    `tensorflow_utils.capture_result_from_graph` adds for an empty result, and
    the computation has no parameter binding.
    """
    graph = tf.Graph()
    with graph.as_default():
        empty_result_type, empty_result_binding = (
            tensorflow_utils.capture_result_from_graph([], graph))

    fn_type = computation_types.FunctionType(None, empty_result_type)
    packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
    tf_proto = pb.TensorFlow(graph_def=packed_graph_def,
                             parameter=None,
                             result=empty_result_binding)
    empty_comp = pb.Computation(
        type=type_serialization.serialize_type(fn_type), tensorflow=tf_proto)
    return empty_comp, fn_type
コード例 #24
0
def disable_grappler_for_partitioned_calls(proto):
    """Disables grappler for `PartitionedCall` and `StatefulPartitionedCall` nodes in the graph.

    TensorFlow serializes a `ConfigProto` into the `config_proto` `attr` of
    `PartitionedCall` and `StatefulPartitionedCall` graph nodes. This overrides
    any session config that might disable runtime grappler. To disable
    grappler for these nodes as well, this function overwrites the serialized
    `ConfigProto`, setting the `disable_meta_optimizer` field to `True`. Call
    nodes that carry no `config_proto` attr are given one that contains only
    that setting.

    Args:
      proto: Instance of `computation_pb2.Computation` with the `tensorflow`
        field populated.

    Returns:
      A transformed instance of `computation_pb2.Computation` with a
      `tensorflow` field.

    Raises:
      TypeError: If `proto` is not a `computation_pb2.Computation`, or is not
        of the `tensorflow` variety.
    """
    py_typecheck.check_type(proto, computation_pb2.Computation)
    computation_oneof = proto.WhichOneof('computation')
    if computation_oneof != 'tensorflow':
        raise TypeError('`disable_grappler_for_partitioned_calls` only accepts '
                        '`Computation` protos of the "tensorflow" variety; you '
                        'have passed one of variety {}.'.format(
                            computation_oneof))
    original_tf = proto.tensorflow
    graph_def = serialization_utils.unpack_graph_def(original_tf.graph_def)
    # Visit the top-level nodes and the nodes of every library function. The
    # node protos are modified in place, so the edits land in `graph_def`.
    all_nodes = itertools.chain(
        graph_def.node, *[f.node_def for f in graph_def.library.function])
    for node in all_nodes:
        # NOTE(review): `CALL_OPS` is a module-level constant not shown here;
        # presumably the set of partitioned-call op names.
        if node.op not in CALL_OPS:
            continue
        existing_attr = node.attr.get('config_proto')
        if existing_attr is None:
            config_proto = tf.compat.v1.ConfigProto()
        else:
            config_proto = tf.compat.v1.ConfigProto.FromString(
                existing_attr.s)
        config_proto.graph_options.rewrite_options.disable_meta_optimizer = True
        # Index the map rather than reusing `existing_attr`: indexing creates
        # the entry when the node had no `config_proto` attr (where
        # `existing_attr` is `None`).
        node.attr['config_proto'].s = config_proto.SerializeToString(
            deterministic=True)
    tf_block = computation_pb2.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph_def),
        initialize_op=original_tf.initialize_op
        if original_tf.initialize_op else None,
        parameter=original_tf.parameter
        if original_tf.HasField('parameter') else None,
        result=original_tf.result)
    new_proto = computation_pb2.Computation(type=proto.type,
                                            tensorflow=tf_block)
    return new_proto
コード例 #25
0
  def test_with_type_preserves_python_container(self):
    struct_fn_type = computation_types.FunctionType(
        tf.int32, computation_types.StructType([(None, tf.int32)]))
    intrinsic_proto = pb.Computation(
        type=type_serialization.serialize_type(struct_fn_type),
        intrinsic=pb.Intrinsic(uri='whatever'))
    original_comp = computation_impl.ConcreteComputation(
        intrinsic_proto, context_stack_impl.context_stack)

    # Re-annotating with a list-backed struct result must carry the Python
    # container through to the resulting type signature.
    list_fn_type = computation_types.FunctionType(
        tf.int32,
        computation_types.StructWithPythonType([(None, tf.int32)], list))
    annotated_comp = computation_impl.ConcreteComputation.with_type(
        original_comp, list_fn_type)
    type_test_utils.assert_types_identical(list_fn_type,
                                           annotated_comp.type_signature)
コード例 #26
0
  def test_get_wrapped_function_from_comp_raises_with_incorrect_binding(self):
    # The result binding names a tensor ('Invalid') that is absent from the
    # graph, so wrapping or calling the function is expected to raise.
    graph = tf.Graph()
    with graph.as_default():
      counter = tf.Variable(initial_value=0.0, name='var1', import_scope='')
      incremented = counter.assign_add(tf.constant(1.0))
      tf.add(1.0, incremented)

    bad_binding = pb.TensorFlow.Binding(
        tensor=pb.TensorFlow.TensorBinding(tensor_name='Invalid'))
    packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
    comp = pb.Computation(
        tensorflow=pb.TensorFlow(graph_def=packed_graph_def,
                                 result=bad_binding))
    with self.assertRaises(TypeError):
      wrapped_fn = eager_tf_executor._get_wrapped_function_from_comp(
          comp, must_pin_function_to_cpu=False, param_type=None, device=None)
      wrapped_fn()
コード例 #27
0
def create_xla_tff_computation(xla_computation, type_spec):
    """Creates an XLA TFF computation.

    Args:
      xla_computation: An instance of `xla_client.XlaComputation`.
      type_spec: The TFF type of the computation to be constructed.

    Returns:
      An instance of `pb.Computation` whose `xla` field holds the packed HLO
      module plus parameter and result bindings derived from `type_spec`.

    Raises:
      TypeError: If an argument is not of the expected type.
    """
    py_typecheck.check_type(xla_computation, xla_client.XlaComputation)
    py_typecheck.check_type(type_spec, computation_types.FunctionType)
    serialized_type = type_serialization.serialize_type(type_spec)
    hlo_module = pack_xla_computation(xla_computation)
    parameter_binding = _make_xla_binding_for_type(type_spec.parameter)
    result_binding = _make_xla_binding_for_type(type_spec.result)
    xla_proto = pb.Xla(hlo_module=hlo_module,
                       parameter=parameter_binding,
                       result=result_binding)
    return pb.Computation(type=serialized_type, xla=xla_proto)
コード例 #28
0
 def test_data_proto_dataset(self):
     # A `pb.Data` uri is resolved through the test backend to the dataset
     # range(3).
     sequence_type = computation_types.SequenceType(tf.int64)
     leaf_executor = eager_tf_executor.EagerTFExecutor()
     backend = TestDataBackend(self, 'foo://bar', tf.data.Dataset.range(3),
                               sequence_type)
     ex = data_executor.DataExecutor(leaf_executor, backend)
     data_comp = pb.Computation(
         data=pb.Data(uri='foo://bar'),
         type=type_serialization.serialize_type(sequence_type))
     run = self._loop.run_until_complete
     value = run(ex.create_value(data_comp))
     self.assertIsInstance(value, eager_tf_executor.EagerValue)
     self.assertEqual(str(value.type_signature), 'int64*')
     dataset = run(value.compute())
     self.assertCountEqual([element.numpy() for element in iter(dataset)],
                           [0, 1, 2])
     ex.close()
コード例 #29
0
def prune_tensorflow_proto(proto):
    """Extracts subgraph from `proto` preserving parameter, result and initialize.

    Args:
      proto: Instance of `pb.Computation` of the `tensorflow` variety whose
        `graphdef` attribute we wish to prune of extraneous ops.

    Returns:
      A transformed instance of `pb.Computation` of the `tensorflow` variety,
      whose `graphdef` attribute contains only ops which can reach the
      parameter or result bindings, or initialize op.

    Raises:
      TypeError: If `proto` is not a `pb.Computation` of the `tensorflow`
        variety.
    """

    def _op_names(tensor_names):
        # Strips the trailing ':<suffix>' from each tensor name, leaving the
        # op name (a name without ':' maps to '').
        return [tensor_name.rpartition(':')[0] for tensor_name in tensor_names]

    py_typecheck.check_type(proto, pb.Computation)
    variety = proto.WhichOneof('computation')
    if variety != 'tensorflow':
        raise TypeError(
            '`prune_tensorflow_proto` only accepts `Computation` '
            'protos of the \'tensorflow\' variety; you have passed '
            'one of variety {}.'.format(variety))
    tf_proto = proto.tensorflow

    names_to_preserve = []
    if tf_proto.parameter.WhichOneof('binding'):
        names_to_preserve.extend(
            _op_names(graph_utils.extract_tensor_names_from_binding(
                tf_proto.parameter)))
    names_to_preserve.extend(
        _op_names(graph_utils.extract_tensor_names_from_binding(
            tf_proto.result)))
    graph_def = serialization_utils.unpack_graph_def(tf_proto.graph_def)
    if tf_proto.initialize_op:
        names_to_preserve.append(tf_proto.initialize_op)

    subgraph_def = tf.compat.v1.graph_util.extract_sub_graph(
        graph_def, names_to_preserve)
    return pb.Computation(
        type=proto.type,
        tensorflow=pb.TensorFlow(
            graph_def=serialization_utils.pack_graph_def(subgraph_def),
            initialize_op=tf_proto.initialize_op,
            parameter=tf_proto.parameter,
            result=tf_proto.result))
コード例 #30
0
    def test_executor_call_unsupported_intrinsic(self):
        # The `dummy_intrinsic` definition is needed to allow lookup of the
        # 'dummy_intrinsic' uri below; the local itself is not used afterwards.
        dummy_intrinsic = intrinsic_defs.IntrinsicDef(
            'DUMMY_INTRINSIC', 'dummy_intrinsic',
            computation_types.AbstractType('T'))
        type_signature = computation_types.TensorType(tf.int32)
        comp = pb.Computation(
            type=type_serialization.serialize_type(type_signature),
            intrinsic=pb.Intrinsic(uri='dummy_intrinsic'))
        del dummy_intrinsic

        loop = asyncio.get_event_loop()
        factory = federated_composing_strategy.FederatedComposingStrategy.factory(
            _create_bottom_stack(), [_create_worker_stack()])
        executor = federating_executor.FederatingExecutor(
            factory, _create_bottom_stack())

        # Embedding the intrinsic succeeds; only calling it is unimplemented.
        v1 = loop.run_until_complete(executor.create_value(comp))
        with self.assertRaises(NotImplementedError):
            loop.run_until_complete(executor.create_call(v1))