示例#1
0
def create_replicate_input(type_signature: computation_types.Type,
                           count: int) -> ProtoAndType:
  """Returns a tensorflow computation returning `count` copies of its argument.

  The returned computation has the type signature `(T -> <T, T, T, ...>)`, where
  `T` is `type_signature` and the length of the result is `count`.

  Args:
    type_signature: A `computation_types.Type` to replicate.
    count: An integer, the number of times the input is replicated.

  Returns:
    The `ProtoAndType` built by `_tensorflow_comp` from the generated graph and
    its `FunctionType` signature.

  Raises:
    TypeError: If `type_signature` contains any types which cannot appear in
      TensorFlow bindings or if `count` is not an integer.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)
  py_typecheck.check_type(count, int)
  parameter_type = type_signature

  with tf.Graph().as_default() as graph:
    parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', parameter_type, graph)
    # The result is a Python list repeating the same parameter value `count`
    # times; `capture_result_from_graph` derives the struct result type from it.
    result = [parameter_value] * count
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  type_signature = computation_types.FunctionType(parameter_type, result_type)
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tensorflow, type_signature)
def create_identity(
        type_signature: computation_types.Type) -> ComputationProtoAndType:
    """Builds a TensorFlow computation that returns its argument unchanged.

    The resulting computation is typed `(T -> T)` with `T` equal to
    `type_signature`. NOTE: any `computation_types.StructType` inside `T` that
    has no associated container type is given the container type `tuple` by
    this function.

    Args:
      type_signature: The `computation_types.Type` used both as the parameter
        type and as the result type of the identity function.

    Returns:
      The value produced by `create_computation_for_py_fn` for the identity
      function over `type_signature`.

    Raises:
      TypeError: If `type_signature` is `None` or contains any types which
        cannot appear in TensorFlow bindings.
      NotImplementedError: If `type_signature` is not a tensor, sequence, or
        struct type.
    """
    type_analysis.check_tensorflow_compatible_type(type_signature)
    if type_signature is None:
        raise TypeError('TensorFlow identity cannot be created for NoneType.')

    # Values are routed through `tf.identity` instead of being returned
    # directly: TF relies on feeds being non-identical to fetches in certain
    # circumstances.
    if type_signature.is_struct():
        identity_fn = functools.partial(structure.map_structure, tf.identity)
    elif type_signature.is_tensor() or type_signature.is_sequence():
        identity_fn = tf.identity
    else:
        raise NotImplementedError(
            f'TensorFlow identity cannot be created for type {type_signature}')

    return create_computation_for_py_fn(identity_fn, type_signature)
def create_identity(type_signature: computation_types.Type) -> ProtoAndType:
    """Builds a TensorFlow graph computation for the identity function.

    The computation is typed `(T -> T)`, where `T` is `type_signature`. NOTE:
    any `computation_types.StructType` in `T` lacking an associated container
    type is given the container type `tuple` by this function.

    Args:
      type_signature: The `computation_types.Type` serving as both the
        parameter type and the result type.

    Returns:
      The `ProtoAndType` built by `_tensorflow_comp` from the stamped graph and
      its `FunctionType` signature.

    Raises:
      TypeError: If `type_signature` contains any types which cannot appear in
        TensorFlow bindings.
    """
    type_analysis.check_tensorflow_compatible_type(type_signature)

    graph = tf.Graph()
    with graph.as_default():
        # Stamp a placeholder parameter and capture it directly as the result.
        arg_value, arg_binding = tensorflow_utils.stamp_parameter_in_graph(
            'x', type_signature, graph)
        out_type, out_binding = tensorflow_utils.capture_result_from_graph(
            arg_value, graph)

    fn_type = computation_types.FunctionType(type_signature, out_type)
    packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
    tf_proto = pb.TensorFlow(
        graph_def=packed_graph_def, parameter=arg_binding, result=out_binding)
    return _tensorflow_comp(tf_proto, fn_type)
示例#4
0
def create_replicate_input(type_spec, count: int) -> pb.Computation:
  """Returns a tensorflow computation which returns `count` clones of an input.

  The returned computation has the type signature `(T -> <T, T, T, ...>)`, where
  `T` is `type_spec` and the length of the result is `count`.

  Args:
    type_spec: A type convertible to instance of `computation_types.Type` via
      `computation_types.to_type`.
    count: An integer, the number of times the input is replicated.

  Returns:
    A `pb.Computation` whose `tensorflow` field holds the generated graph and
    whose `type` field is the serialized `FunctionType` signature.

  Raises:
    TypeError: If `type_spec` contains any types which cannot appear in
      TensorFlow bindings or if `count` is not an integer.
  """
  type_spec = computation_types.to_type(type_spec)
  type_analysis.check_tensorflow_compatible_type(type_spec)
  py_typecheck.check_type(count, int)

  with tf.Graph().as_default() as graph:
    parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', type_spec, graph)
    # The result is a Python list repeating the same parameter value `count`
    # times; its struct type is derived by `capture_result_from_graph`.
    result = [parameter_value] * count
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  type_signature = computation_types.FunctionType(type_spec, result_type)
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return pb.Computation(
      type=type_serialization.serialize_type(type_signature),
      tensorflow=tensorflow)
示例#5
0
def create_identity(type_spec) -> pb.Computation:
  """Builds a `pb.Computation` for the TensorFlow identity function.

  The computation is typed `(T -> T)`, where `T` is `type_spec` after
  conversion with `computation_types.to_type`.

  Args:
    type_spec: Anything `computation_types.to_type` can turn into a
      `computation_types.Type`.

  Returns:
    A `pb.Computation` with a `tensorflow` body and a serialized `FunctionType`
    as its `type`.

  Raises:
    TypeError: If `type_spec` contains any types which cannot appear in
      TensorFlow bindings.
  """
  tff_type = computation_types.to_type(type_spec)
  type_analysis.check_tensorflow_compatible_type(tff_type)

  graph = tf.Graph()
  with graph.as_default():
    # The stamped parameter is captured unchanged as the result.
    arg_value, arg_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', tff_type, graph)
    out_type, out_binding = tensorflow_utils.capture_result_from_graph(
        arg_value, graph)

  function_type = computation_types.FunctionType(tff_type, out_type)
  tf_body = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=arg_binding,
      result=out_binding)
  serialized_type = type_serialization.serialize_type(function_type)
  return pb.Computation(type=serialized_type, tensorflow=tf_body)
示例#6
0
def compile_local_computation_to_tensorflow(
    comp: building_blocks.ComputationBuildingBlock,
) -> building_blocks.ComputationBuildingBlock:
    """Compiles a fully specified local computation to TensorFlow.

    Args:
      comp: A `building_blocks.ComputationBuildingBlock` which can be compiled
        to TensorFlow. In order to compile a computation to TensorFlow, it must
        not contain 1. References to values defined outside of comp, 2.
        `Data`, `Intrinsic`, or `Placement` blocks, or 3. Calls to intrinsics
        or non-TensorFlow computations.

    Returns:
      A `building_blocks.ComputationBuildingBlock` containing a TensorFlow-only
      representation of `comp`. If `comp` is of functional type, this will be
      a `building_blocks.CompiledComputation`. Otherwise, it will be a
      `building_blocks.Call` which wraps a
      `building_blocks.CompiledComputation`.

    Raises:
      TypeError: If the parameter or result type of the (possibly wrapped)
        function cannot appear in TensorFlow bindings, as reported by
        `type_analysis.check_tensorflow_compatible_type`.
    """
    # Non-functional computations are wrapped in a no-argument lambda,
    # compiled recursively, and then invoked with no argument.
    if not comp.type_signature.is_function():
        lambda_wrapped = building_blocks.Lambda(None, None, comp)
        return building_blocks.Call(
            compile_local_computation_to_tensorflow(lambda_wrapped), None)

    parameter_type = comp.type_signature.parameter
    type_analysis.check_tensorflow_compatible_type(parameter_type)
    type_analysis.check_tensorflow_compatible_type(comp.type_signature.result)

    # Already a compiled TensorFlow computation: nothing to do.
    if (comp.is_compiled_computation()
            and comp.proto.WhichOneof('computation') == 'tensorflow'):
        return comp

    # Ensure that unused values are removed and that reference bindings have
    # unique names.
    comp = unpack_compiled_computations(comp)
    comp = transformations.to_call_dominant(comp)

    # Build a `tf_computation` whose body evaluates `comp` applied to the
    # computation's argument (if any), then convert it to a building block.
    if parameter_type is None:
        to_evaluate = building_blocks.Call(comp)

        @tensorflow_computation.tf_computation
        def result_computation():
            return _evaluate_to_tensorflow(to_evaluate, {})
    else:
        # Pick a reference name that is not already used inside `comp`, so the
        # parameter binding cannot collide with existing names.
        name_generator = building_block_factory.unique_name_generator(comp)
        parameter_name = next(name_generator)
        to_evaluate = building_blocks.Call(
            comp, building_blocks.Reference(parameter_name, parameter_type))

        @tensorflow_computation.tf_computation(parameter_type)
        def result_computation(arg):
            # Struct arguments are converted (recursively) to
            # `structure.Struct`; presumably `_evaluate_to_tensorflow` expects
            # that representation -- confirm against its implementation.
            if parameter_type.is_struct():
                arg = structure.from_container(arg, recursive=True)
            return _evaluate_to_tensorflow(to_evaluate, {parameter_name: arg})

    return result_computation.to_compiled_building_block()
def create_replicate_input(type_signature: computation_types.Type,
                           count: int) -> ComputationProtoAndType:
    """Returns a tensorflow computation returning `count` copies of its argument.

    The returned computation has the type signature `(T -> <T, T, T, ...>)`,
    where `T` is `type_signature` and the length of the result is `count`.

    Args:
      type_signature: A `computation_types.Type` to replicate.
      count: An integer, the number of times the input is replicated.

    Returns:
      The value produced by `create_computation_for_py_fn` for a Python
      function that replicates its argument `count` times.

    Raises:
      TypeError: If `type_signature` contains any types which cannot appear in
        TensorFlow bindings or if `count` is not an integer.
    """
    type_analysis.check_tensorflow_compatible_type(type_signature)
    py_typecheck.check_type(count, int)
    parameter_type = type_signature
    # The Python function returns a list holding its single argument `count`
    # times; the graph itself is built by `create_computation_for_py_fn`.
    return create_computation_for_py_fn(lambda v: [v] * count, parameter_type)
示例#8
0
def create_replicate_input(type_signature: computation_types.Type,
                           count: int) -> ComputationProtoAndType:
  """Returns a tensorflow computation returning `count` copies of its argument.

  The returned computation has the type signature `(T -> <T, T, T, ...>)`, where
  `T` is `type_signature` and the length of the result is `count`.

  Args:
    type_signature: A `computation_types.Type` to replicate.
    count: An integer, the number of times the input is replicated.

  Returns:
    The `ComputationProtoAndType` built by `_tensorflow_comp`: the identity
    graph for `type_signature`, with a result binding that repeats the
    identity's result binding `count` times.

  Raises:
    TypeError: If `type_signature` contains any types which cannot appear in
      TensorFlow bindings or if `count` is not an integer.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)
  py_typecheck.check_type(count, int)
  parameter_type = type_signature
  identity_comp, _ = create_identity(parameter_type)
  # This manual proto manipulation is significantly faster than using TFF's
  # GraphDef serialization for large `count` arguments: the identity graph is
  # reused as-is and only its result binding is replicated.
  tensorflow_comp = identity_comp.tensorflow
  replicated_result = pb.TensorFlow.Binding(
      struct=pb.TensorFlow.StructBinding(
          element=[tensorflow_comp.result] * count))
  proto_fields = dict(
      graph_def=tensorflow_comp.graph_def, result=replicated_result)
  # Use an explicit presence check: in Python protobuf an unset submessage
  # field still returns a (default) message object, so its truthiness does not
  # tell whether the field was set.
  if tensorflow_comp.HasField('parameter'):
    proto_fields['parameter'] = tensorflow_comp.parameter
  new_tf_pb = pb.TensorFlow(**proto_fields)
  fn_type = computation_types.FunctionType(
      parameter_type,
      computation_types.StructType([(None, parameter_type)] * count))
  return _tensorflow_comp(new_tf_pb, fn_type)
示例#9
0
def create_binary_operator_with_upcast(
    type_signature: computation_types.StructType,
    operator: Callable[[Any, Any], Any]) -> ProtoAndType:
  """Creates TF computation upcasting its argument and applying `operator`.

  Args:
    type_signature: A `computation_types.StructType` with two elements, both of
      the same type or the second able to be upcast to the first, as explained
      in `apply_binary_operator_with_upcast`, and both containing only tuples
      and tensors in their type tree.
    operator: Callable defining the operator.

  Returns:
    The `ProtoAndType` built by `_tensorflow_comp` (presumably the computation
    proto paired with its `FunctionType`) for a function which upcasts the
    second element of its argument and applies the binary operator.

  Raises:
    TypeError: If `type_signature` is not a two-element `StructType` of
      TensorFlow-compatible types, or if it contains a `SequenceType`.
  """
  py_typecheck.check_type(type_signature, computation_types.StructType)
  py_typecheck.check_callable(operator)
  type_analysis.check_tensorflow_compatible_type(type_signature)
  if not type_signature.is_struct() or len(type_signature) != 2:
    raise TypeError('To apply a binary operator, we must by definition have an '
                    'argument which is a `StructType` with 2 elements; '
                    'asked to create a binary operator for type: {t}'.format(
                        t=type_signature))
  if type_analysis.contains(type_signature, lambda t: t.is_sequence()):
    raise TypeError(
        'Applying binary operators in TensorFlow is only '
        'supported on Tensors and StructTypes; you '
        'passed {t} which contains a SequenceType.'.format(t=type_signature))

  def _pack_into_type(to_pack, type_spec):
    """Pack Tensor value `to_pack` into the nested structure `type_spec`."""
    if type_spec.is_struct():
      elem_iter = structure.iter_elements(type_spec)
      return structure.Struct([(elem_name, _pack_into_type(to_pack, elem_type))
                               for elem_name, elem_type in elem_iter])
    elif type_spec.is_tensor():
      # Broadcast the single tensor to the shape of this leaf.
      return tf.broadcast_to(to_pack, type_spec.shape)

  with tf.Graph().as_default() as graph:
    first_arg, operand_1_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', type_signature[0], graph)
    operand_2_value, operand_2_binding = tensorflow_utils.stamp_parameter_in_graph(
        'y', type_signature[1], graph)
    if type_signature[0].is_equivalent_to(type_signature[1]):
      second_arg = operand_2_value
    else:
      second_arg = _pack_into_type(operand_2_value, type_signature[0])

    if type_signature[0].is_tensor():
      result_value = operator(first_arg, second_arg)
    elif type_signature[0].is_struct():
      result_value = structure.map_structure(operator, first_arg, second_arg)
    else:
      raise TypeError('Encountered unexpected type {t}; can only handle Tensor '
                      'and StructTypes.'.format(t=type_signature[0]))

  result_type, result_binding = tensorflow_utils.capture_result_from_graph(
      result_value, graph)

  type_signature = computation_types.FunctionType(type_signature, result_type)
  # The two operands were stamped as separate parameters `x` and `y`; bind them
  # together as a two-element struct parameter.
  parameter_binding = pb.TensorFlow.Binding(
      struct=pb.TensorFlow.StructBinding(
          element=[operand_1_binding, operand_2_binding]))
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tensorflow, type_signature)
def _get_type_info(initialize_tree, before_broadcast, after_broadcast,
                   before_aggregate, after_aggregate):
  """Returns type information for an `tff.templates.IterativeProcess`.

  This function is intended to be used by
  `get_canonical_form_for_iterative_process` to create the expected type
  signatures when compiling a given `tff.templates.IterativeProcess` into a
  `tff.backends.mapreduce.CanonicalForm` and returns a `collections.OrderedDict`
  whose keys and order match the explicit and intermediate components of
  `tff.backends.mapreduce.CanonicalForm` defined here:

  ```
  s1 = arg[0]
  c1 = arg[1]
  s2 = intrinsics.federated_map(cf.prepare, s1)
  c2 = intrinsics.federated_broadcast(s2)
  c3 = intrinsics.federated_zip([c1, c2])
  c4 = intrinsics.federated_map(cf.work, c3)
  c5 = c4[0]
  c6 = c4[1]
  s3 = intrinsics.federated_aggregate(c5,
                                      cf.zero(),
                                      cf.accumulate,
                                      cf.merge,
                                      cf.report)
  s4 = intrinsics.federated_secure_sum(c6, cf.bitwidth())
  s5 = intrinsics.federated_zip([s3, s4])
  s6 = intrinsics.federated_zip([s1, s5])
  s7 = intrinsics.federated_map(cf.update, s6)
  s8 = s7[0]
  s9 = s7[1]
  ```

  Note that the type signatures for the `initialize` and `next` components of
  an `tff.templates.IterativeProcess` are:

  initialize:  `( -> s1)`
  next:        `(<s1,c1> -> <s8,s9>)`

  However, the `next` component of an `tff.templates.IterativeProcess` has been
  split into a before and after broadcast and a before and after aggregate with
  the given semantics:

  ```
  (arg -> after(<arg, intrinsic(before(arg))>))
  ```

  as a result, the type signatures for the components split from the `next`
  component of an `tff.templates.IterativeProcess` are:

  before_broadcast:  `(<s1,c1> -> s2)`
  after_broadcast:   `(<<s1,c1>,c2> -> <s8,s9>)`
  before_aggregate:  `(<<s1,c1>,c2> ->
                       <<c5,zero,accumulate,merge,report>,<c6,bitwidth>>)`
  after_aggregate:   `(<<<s1,c1>,c2>,<s3,s4>> -> <s8,s9>)`

  Args:
    initialize_tree: An instance of `building_blocks.ComputationBuildingBlock`
      representing the `initialize` component of an
      `tff.templates.IterativeProcess`.
    before_broadcast: The first result of splitting `next` component of an
      `tff.templates.IterativeProcess` on broadcast.
    after_broadcast: The second result of splitting `next` component of an
      `tff.templates.IterativeProcess` on broadcast.
    before_aggregate: The first result of splitting `next` component of an
      `tff.templates.IterativeProcess` on aggregate.
    after_aggregate: The second result of splitting `next` component of an
      `tff.templates.IterativeProcess` on aggregate.

  Returns:
    A `collections.OrderedDict` whose keys are `initialize_type`, `s1_type`,
    `c1_type`, `prepare_type`, `s2_type`, `c2_type`, `c3_type`, `work_type`,
    `c4_type`, `c5_type`, `c6_type`, `zero_type`, `accumulate_type`,
    `merge_type`, `report_type`, `s3_type`, `bitwidth_type`, `s4_type`,
    `s5_type`, `s6_type`, `update_type`, `s7_type`, `s8_type`, `s9_type`, in
    that order. The federated values keep their placements, while the
    component function types (`initialize_type`, `prepare_type`, `work_type`,
    `zero_type`, `bitwidth_type`, `update_type`) are unplaced.

  Raises:
    transformations.CanonicalFormCompilationError: If the arguments are of the
      wrong types.
    TypeError: Presumably, if the `zero` or `bitwidth` result types are not
      TensorFlow-compatible (raised by
      `type_analysis.check_tensorflow_compatible_type`).
  """

  # The type signature of `initialize` is: `( -> s1)`.
  init_tree_ty = initialize_tree.type_signature
  _check_type_is_no_arg_fn(init_tree_ty)
  _check_type(init_tree_ty.result, computation_types.FederatedType)
  _check_placement(init_tree_ty.result, placements.SERVER)
  # The named components of canonical form have no placement, so we must
  # remove the placement on the return type of initialize_tree.
  initialize_type = computation_types.FunctionType(init_tree_ty.parameter,
                                                   init_tree_ty.result.member)

  # The type signature of `before_broadcast` is: `(<s1,c1> -> s2)`.
  _check_type(before_broadcast.type_signature, computation_types.FunctionType)
  _check_type(before_broadcast.type_signature.parameter,
              computation_types.StructType)
  _check_len(before_broadcast.type_signature.parameter, 2)
  s1_type = before_broadcast.type_signature.parameter[0]
  _check_type(s1_type, computation_types.FederatedType)
  _check_placement(s1_type, placements.SERVER)
  c1_type = before_broadcast.type_signature.parameter[1]
  _check_type(c1_type, computation_types.FederatedType)
  _check_placement(c1_type, placements.CLIENTS)
  s2_type = before_broadcast.type_signature.result
  _check_type(s2_type, computation_types.FederatedType)
  _check_placement(s2_type, placements.SERVER)

  prepare_type = computation_types.FunctionType(s1_type.member, s2_type.member)

  # The type signature of `after_broadcast` is: `(<<s1,c1>,c2> -> <s8,s9>)`.
  _check_type(after_broadcast.type_signature, computation_types.FunctionType)
  _check_type(after_broadcast.type_signature.parameter,
              computation_types.StructType)
  _check_len(after_broadcast.type_signature.parameter, 2)
  _check_type(after_broadcast.type_signature.parameter[0],
              computation_types.StructType)
  _check_len(after_broadcast.type_signature.parameter[0], 2)
  _check_type_equal(after_broadcast.type_signature.parameter[0][0], s1_type)
  _check_type_equal(after_broadcast.type_signature.parameter[0][1], c1_type)
  c2_type = after_broadcast.type_signature.parameter[1]
  _check_type(c2_type, computation_types.FederatedType)
  _check_placement(c2_type, placements.CLIENTS)
  _check_type(after_broadcast.type_signature.result,
              computation_types.StructType)
  _check_len(after_broadcast.type_signature.result, 2)
  s8_type = after_broadcast.type_signature.result[0]
  _check_type(s8_type, computation_types.FederatedType)
  _check_placement(s8_type, placements.SERVER)
  s9_type = after_broadcast.type_signature.result[1]
  _check_type(s9_type, computation_types.FederatedType)
  _check_placement(s9_type, placements.SERVER)

  # The type signature of `before_aggregate` is:
  # `(<<s1,c1>,c2> -> <<c5,zero,accumulate,merge,report>,<c6,bitwidth>>)`.
  _check_type(before_aggregate.type_signature, computation_types.FunctionType)
  _check_type(before_aggregate.type_signature.parameter,
              computation_types.StructType)
  _check_len(before_aggregate.type_signature.parameter, 2)
  _check_type(before_aggregate.type_signature.parameter[0],
              computation_types.StructType)
  _check_len(before_aggregate.type_signature.parameter[0], 2)
  _check_type_equal(before_aggregate.type_signature.parameter[0][0], s1_type)
  _check_type_equal(before_aggregate.type_signature.parameter[0][1], c1_type)
  _check_type_equal(before_aggregate.type_signature.parameter[1], c2_type)
  _check_type(before_aggregate.type_signature.result,
              computation_types.StructType)
  _check_len(before_aggregate.type_signature.result, 2)
  # Check the struct type before its length, as is done for every other struct
  # above, so a non-struct is reported through `_check_type`.
  _check_type(before_aggregate.type_signature.result[0],
              computation_types.StructType)
  _check_len(before_aggregate.type_signature.result[0], 5)
  c5_type = before_aggregate.type_signature.result[0][0]
  _check_type(c5_type, computation_types.FederatedType)
  _check_placement(c5_type, placements.CLIENTS)
  zero_type = computation_types.FunctionType(
      None, before_aggregate.type_signature.result[0][1])
  type_analysis.check_tensorflow_compatible_type(zero_type.result)
  accumulate_type = before_aggregate.type_signature.result[0][2]
  _check_type(accumulate_type, computation_types.FunctionType)
  merge_type = before_aggregate.type_signature.result[0][3]
  _check_type(merge_type, computation_types.FunctionType)
  report_type = before_aggregate.type_signature.result[0][4]
  _check_type(report_type, computation_types.FunctionType)
  _check_type(before_aggregate.type_signature.result[1],
              computation_types.StructType)
  _check_len(before_aggregate.type_signature.result[1], 2)
  c6_type = before_aggregate.type_signature.result[1][0]
  _check_type(c6_type, computation_types.FederatedType)
  _check_placement(c6_type, placements.CLIENTS)
  bitwidth_type = computation_types.FunctionType(
      None, before_aggregate.type_signature.result[1][1])
  type_analysis.check_tensorflow_compatible_type(bitwidth_type.result)

  c3_type = computation_types.FederatedType([c1_type.member, c2_type.member],
                                            placements.CLIENTS)
  c4_type = computation_types.FederatedType([c5_type.member, c6_type.member],
                                            placements.CLIENTS)

  # The type signature of `after_aggregate` is:
  # `(<<<s1,c1>,c2>,<s3,s4>> -> <s8,s9>)`.
  _check_type(after_aggregate.type_signature, computation_types.FunctionType)
  _check_type(after_aggregate.type_signature.parameter,
              computation_types.StructType)
  _check_len(after_aggregate.type_signature.parameter, 2)
  _check_type(after_aggregate.type_signature.parameter[0],
              computation_types.StructType)
  _check_len(after_aggregate.type_signature.parameter[0], 2)
  _check_type(after_aggregate.type_signature.parameter[0][0],
              computation_types.StructType)
  _check_len(after_aggregate.type_signature.parameter[0][0], 2)
  _check_type_equal(after_aggregate.type_signature.parameter[0][0][0], s1_type)
  _check_type_equal(after_aggregate.type_signature.parameter[0][0][1], c1_type)
  _check_type_equal(after_aggregate.type_signature.parameter[0][1], c2_type)
  _check_type(after_aggregate.type_signature.parameter[1],
              computation_types.StructType)
  _check_len(after_aggregate.type_signature.parameter[1], 2)
  s3_type = after_aggregate.type_signature.parameter[1][0]
  _check_type(s3_type, computation_types.FederatedType)
  _check_placement(s3_type, placements.SERVER)
  s4_type = after_aggregate.type_signature.parameter[1][1]
  _check_type(s4_type, computation_types.FederatedType)
  _check_placement(s4_type, placements.SERVER)
  _check_type(after_aggregate.type_signature.result,
              computation_types.StructType)
  _check_len(after_aggregate.type_signature.result, 2)
  _check_type_equal(after_aggregate.type_signature.result[0], s8_type)
  _check_type_equal(after_aggregate.type_signature.result[1], s9_type)

  work_type = computation_types.FunctionType(c3_type.member, c4_type.member)

  s5_type = computation_types.FederatedType([s3_type.member, s4_type.member],
                                            placements.SERVER)
  s6_type = computation_types.FederatedType([s1_type.member, s5_type.member],
                                            placements.SERVER)
  s7_type = computation_types.FederatedType([s8_type.member, s9_type.member],
                                            placements.SERVER)
  update_type = computation_types.FunctionType(s6_type.member, s7_type.member)

  return collections.OrderedDict(
      initialize_type=initialize_type,
      s1_type=s1_type,
      c1_type=c1_type,
      prepare_type=prepare_type,
      s2_type=s2_type,
      c2_type=c2_type,
      c3_type=c3_type,
      work_type=work_type,
      c4_type=c4_type,
      c5_type=c5_type,
      c6_type=c6_type,
      zero_type=zero_type,
      accumulate_type=accumulate_type,
      merge_type=merge_type,
      report_type=report_type,
      s3_type=s3_type,
      bitwidth_type=bitwidth_type,
      s4_type=s4_type,
      s5_type=s5_type,
      s6_type=s6_type,
      update_type=update_type,
      s7_type=s7_type,
      s8_type=s8_type,
      s9_type=s9_type,
  )
示例#11
0
def _deserialize_dataset_from_graph_def(serialized_graph_def: bytes,
                                        element_type: computation_types.Type):
    """Rebuilds a `tf.data.Dataset` from a serialized `tf.compat.v1.GraphDef`.

    Args:
      serialized_graph_def: `bytes` object produced by
        `tensorflow_serialization.serialize_dataset`.
      element_type: a `tff.Type` object describing the structure of the
        elements yielded by the dataset.

    Returns:
      A `tf.data.Dataset` instance whose elements have been coerced to the
      (possibly container-annotated) TFF element type.

    Raises:
      TypeError: If `element_type` is not a `computation_types.Type`, cannot
        appear in TensorFlow bindings, or contains a partially named struct.
    """
    py_typecheck.check_type(element_type, computation_types.Type)
    type_analysis.check_tensorflow_compatible_type(element_type)

    def _attach_python_container(
        type_spec: computation_types.Type
    ) -> Tuple[computation_types.Type, bool]:
        """Maps a bare `StructType` to a `StructWithPythonType`.

        Fully named structs get `collections.OrderedDict`, fully unnamed
        structs get `tuple`. Any other type is returned unchanged together with
        `False`.
        """
        if not type_spec.is_struct() or type_spec.is_struct_with_python():
            return type_spec, False
        named_flags = tuple(
            name is not None
            for name, _ in structure.iter_elements(type_spec))
        if all(named_flags):
            container = collections.OrderedDict
        elif not any(named_flags):
            container = tuple
        else:
            raise TypeError(
                'Cannot represent TFF type in TF because it contains '
                f'partially named structures. Type: {type_spec}')
        return computation_types.StructWithPythonType(
            elements=structure.iter_elements(type_spec),
            container_type=container), True

    # TF does not support `structure.Struct` values, so bare structs are given
    # a Python container type that can later serve as the dataset's
    # `element_spec`. Any non-struct element type (presumably a `TensorType`
    # here) is used as-is.
    tf_compatible_type = element_type
    if element_type.is_struct():
        tf_compatible_type, _ = type_transformations.transform_type_postorder(
            element_type, _attach_python_container)

    element_spec = type_conversions.structure_from_tensor_type_tree(
        lambda t: tf.TensorSpec(shape=t.shape, dtype=t.dtype),
        tf_compatible_type)
    dataset_variant = tf.raw_ops.DatasetFromGraph(
        graph_def=serialized_graph_def)
    ds = tf.data.experimental.from_variant(dataset_variant,
                                           structure=element_spec)
    # A dataset whose elements were nested structures of tensors (e.g. `dict`,
    # `OrderedDict`) comes back yielding `dict`, `tuple`, or `namedtuple`; the
    # `collections.OrderedDict` is lost in the round trip. Since the dataset is
    # only used inside TFF, elements are re-wrapped (in `OrderedDict` where
    # needed) following the field order of the TFF type recorded at
    # serialization time.
    return tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
        ds, tf_compatible_type)
def create_binary_operator_with_upcast(
        type_signature: computation_types.StructType,
        operator: Callable[[Any, Any], Any]) -> ComputationProtoAndType:
    """Creates TF computation upcasting its argument and applying `operator`.

    Args:
      type_signature: A `computation_types.StructType` with two elements, both
        only containing structs or tensors in their type tree. The first and
        second element must match in structure, or the second element may be a
        single tensor type that is broadcasted (upcast) to the leaves of the
        structure of the first type.
      operator: Callable defining the operator.

    Returns:
      Same as `create_binary_operator()`: the `ComputationProtoAndType` built by
      `_tensorflow_comp` from the generated graph.

    Raises:
      TypeError: If `type_signature` is not a two-element `StructType` of
        TensorFlow-compatible types, contains a `SequenceType`, or if its
        second element cannot be upcast to its first (two structs of different
        structure, or a struct second element with a non-struct first element).
    """
    py_typecheck.check_type(type_signature, computation_types.StructType)
    py_typecheck.check_callable(operator)
    type_analysis.check_tensorflow_compatible_type(type_signature)
    if not type_signature.is_struct() or len(type_signature) != 2:
        raise TypeError(
            'To apply a binary operator, we must by definition have an '
            'argument which is a `StructType` with 2 elements; '
            'asked to create a binary operator for type: {t}'.format(
                t=type_signature))
    if type_analysis.contains(type_signature, lambda t: t.is_sequence()):
        raise TypeError('Applying binary operators in TensorFlow is only '
                        'supported on Tensors and StructTypes; you '
                        'passed {t} which contains a SequenceType.'.format(
                            t=type_signature))

    def _pack_into_type(to_pack, type_spec):
        """Pack Tensor value `to_pack` into the nested structure `type_spec`."""
        if type_spec.is_struct():
            elem_iter = structure.iter_elements(type_spec)
            return structure.Struct([(elem_name,
                                      _pack_into_type(to_pack, elem_type))
                                     for elem_name, elem_type in elem_iter])
        elif type_spec.is_tensor():
            # Broadcast the single tensor to the shape of this leaf.
            return tf.broadcast_to(to_pack, type_spec.shape)

    with tf.Graph().as_default() as graph:
        first_arg, operand_1_binding = tensorflow_utils.stamp_parameter_in_graph(
            'x', type_signature[0], graph)
        operand_2_value, operand_2_binding = tensorflow_utils.stamp_parameter_in_graph(
            'y', type_signature[1], graph)

        if type_signature[0].is_struct() and type_signature[1].is_struct():
            # If both the first and second arguments are structs with the same
            # structure, simply re-use `operand_2_value` as the second
            # argument; `structure.map_structure` below will map the binary
            # operator pointwise to the leaves of the structure.
            if structure.is_same_structure(type_signature[0],
                                           type_signature[1]):
                second_arg = operand_2_value
            else:
                raise TypeError(
                    'Cannot upcast one structure to a different structure. '
                    '{x} -> {y}'.format(x=type_signature[1],
                                        y=type_signature[0]))
        elif type_signature[1].is_struct():
            # Only a single tensor can be broadcast; a struct cannot be upcast
            # to a non-struct first operand. Fail here with a clear message
            # instead of inside `tf.broadcast_to`.
            raise TypeError(
                'Cannot upcast a structure to a non-structure type. '
                '{x} -> {y}'.format(x=type_signature[1],
                                    y=type_signature[0]))
        elif type_signature[0].is_equivalent_to(type_signature[1]):
            second_arg = operand_2_value
        else:
            # The second operand is a single tensor: broadcast it to every
            # leaf of the first operand's type.
            second_arg = _pack_into_type(operand_2_value, type_signature[0])

        if type_signature[0].is_tensor():
            result_value = operator(first_arg, second_arg)
        elif type_signature[0].is_struct():
            result_value = structure.map_structure(operator, first_arg,
                                                   second_arg)
        else:
            raise TypeError(
                'Encountered unexpected type {t}; can only handle Tensor '
                'and StructTypes.'.format(t=type_signature[0]))

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

    type_signature = computation_types.FunctionType(type_signature,
                                                    result_type)
    # The operands were stamped as separate parameters `x` and `y`; expose
    # them together as one two-element struct parameter.
    parameter_binding = pb.TensorFlow.Binding(
        struct=pb.TensorFlow.StructBinding(
            element=[operand_1_binding, operand_2_binding]))
    tensorflow = pb.TensorFlow(graph_def=serialization_utils.pack_graph_def(
        graph.as_graph_def()),
                               parameter=parameter_binding,
                               result=result_binding)
    return _tensorflow_comp(tensorflow, type_signature)