def test_basic_functionality_of_compiled_computation_class(self):
    """Checks type, proto, repr, naming and round-tripping of a compiled comp."""
    serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        lambda x: x + 3, tf.int32, context_stack_impl.context_stack)
    unnamed = computation_building_blocks.CompiledComputation(serialized)

    self.assertEqual(str(unnamed.type_signature), '(int32 -> int32)')
    self.assertEqual(str(unnamed.proto), str(serialized))
    # An unnamed computation gets a hex identifier in both repr and tff_repr.
    repr_pattern = (r'CompiledComputation\([0-9a-f]+, '
                    r'FunctionType\(TensorType\(tf\.int32\), '
                    r'TensorType\(tf\.int32\)\)\)')
    self.assertTrue(re.match(repr_pattern, repr(unnamed)))
    self.assertTrue(re.match(r'comp#[0-9a-f]+', unnamed.tff_repr))

    # An explicit name replaces the generated identifier.
    named = computation_building_blocks.CompiledComputation(
        serialized, name='foo')
    self.assertEqual(named.tff_repr, 'comp#foo')

    self._serialize_deserialize_roundtrip_test(unnamed)
Esempio n. 2
0
    def test_returns_string_for_comp_with_left_overhang(self):
        """Renders `a(comp#bbbbb())`, whose structural tree leans left."""
        function_ref = computation_building_blocks.Reference(
            'a', computation_types.FunctionType(tf.int32, tf.int32))
        serialized, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
            lambda: tf.constant(1), None, context_stack_impl.context_stack)
        inner_call = computation_building_blocks.Call(
            computation_building_blocks.CompiledComputation(
                serialized, 'bbbbb'))
        outer_call = computation_building_blocks.Call(function_ref, inner_call)

        self.assertEqual(
            computation_building_blocks.compact_representation(outer_call),
            'a(comp#bbbbb())')
        self.assertEqual(
            computation_building_blocks.formatted_representation(outer_call),
            'a(comp#bbbbb())')
        # pyformat: disable
        expected_structure = ('           Call\n'
                              '          /    \\\n'
                              '    Ref(a)      Call\n'
                              '               /\n'
                              'Compiled(bbbbb)')
        self.assertEqual(
            computation_building_blocks.structural_representation(outer_call),
            expected_structure)
 def test_returns_string_for_compiled_computation(self):
     """Checks every representation of a bare `CompiledComputation` named 'a'."""
     serialized, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
         lambda: tf.constant(1), None, context_stack_impl.context_stack)
     compiled = computation_building_blocks.CompiledComputation(serialized, 'a')
     # (renderer, expected output) pairs, checked in the original order.
     expectations = (
         (compiled.compact_representation, 'comp#a'),
         (compiled.formatted_representation, 'comp#a'),
         (compiled.structural_representation, 'Compiled(a)'),
     )
     for render, expected in expectations:
         self.assertEqual(render(), expected)
Esempio n. 4
0
    def test_does_not_reduce_no_unnecessary_ops(self):
        """Pruning an identity graph must leave its op count unchanged."""

        def identity_fn(x):
            return x

        original = _create_compiled_computation(identity_fn, tf.int32)
        pruned_proto = proto_transformations.prune_tensorflow_proto(
            original.proto)
        pruned = computation_building_blocks.CompiledComputation(pruned_proto)

        count_ops = computation_building_block_utils.count_tensorflow_ops_in
        self.assertEqual(count_ops(original), count_ops(pruned))
Esempio n. 5
0
    def test_replace_compiled_computations_names_replaces_name(self):
        """A lone compiled computation comes back under a different name."""
        serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
            lambda: tf.constant(1), None, context_stack_impl.context_stack)
        original = computation_building_blocks.CompiledComputation(serialized)

        renamer = (
            transformations.replace_compiled_computations_names_with_unique_names)
        renamed = renamer(original)

        self.assertNotEqual(renamed._name, original._name)
Esempio n. 6
0
    def test_reduces_unnecessary_ops(self):
        """Pruning a graph containing an unused constant lowers its op count."""

        def fn_with_dead_op(x):
            _ = tf.constant(0)
            return x

        count_ops = computation_building_block_utils.count_tensorflow_ops_in
        original = _create_compiled_computation(fn_with_dead_op, tf.int32)
        ops_before = count_ops(original)

        pruned = computation_building_blocks.CompiledComputation(
            proto_transformations.prune_tensorflow_proto(original.proto))

        self.assertLess(count_ops(pruned), ops_before)
def wrap_graph_parameter_as_tuple(comp, name=None):
    """Wraps the parameter of `comp` in a one-element tuple binding.

    This is intended as a preprocessing step for
    `pad_graph_inputs_to_match_type`, which can then assume its argument
    always has a tuple parameter binding instead of also handling a bare
    tensor or sequence binding.

    Args:
      comp: Instance of `computation_building_blocks.CompiledComputation`
        whose parameter should be wrapped in a tuple binding.
      name: Optional string, the name given to the single element of the
        constructed tuple type. Defaults to `None` (unnamed element).

    Returns:
      A new `computation_building_blocks.CompiledComputation` that reuses the
      graph, initialize op and result binding of `comp`, but whose parameter
      is a one-element tuple holding the original parameter.

    Raises:
      TypeError: If `comp` is not a
        `computation_building_blocks.CompiledComputation`, or `name` is
        neither `None` nor a string.
    """
    py_typecheck.check_type(comp,
                            computation_building_blocks.CompiledComputation)
    if name is not None:
        py_typecheck.check_type(name, six.string_types)

    original = comp.proto
    tf_proto = original.tensorflow
    original_type = type_serialization.deserialize_type(original.type)

    # The binding and the type must agree: both become a 1-tuple wrapping the
    # original parameter.
    wrapped_binding = pb.TensorFlow.Binding(
        tuple=pb.TensorFlow.NamedTupleBinding(element=[tf_proto.parameter]))
    wrapped_type = computation_types.FunctionType(
        [(name, original_type.parameter)], original_type.result)

    return computation_building_blocks.CompiledComputation(
        pb.Computation(
            type=type_serialization.serialize_type(wrapped_type),
            tensorflow=pb.TensorFlow(
                graph_def=tf_proto.graph_def,
                initialize_op=tf_proto.initialize_op,
                parameter=wrapped_binding,
                result=tf_proto.result)))
Esempio n. 8
0
    def test_prune_does_not_change_exeuction(self):
        """Pruning dead ops must not change the values the computation returns."""

        def fn_with_dead_op(x):
            _ = tf.constant(0)
            return x

        original = _create_compiled_computation(fn_with_dead_op, tf.int32)
        pruned = computation_building_blocks.CompiledComputation(
            proto_transformations.prune_tensorflow_proto(original.proto))

        to_executable = (
            computation_wrapper_instances.building_block_to_computation)
        run_original = to_executable(original)
        run_pruned = to_executable(pruned)
        for value in range(5):
            self.assertEqual(run_original(value), run_pruned(value))
Esempio n. 9
0
def _wrap_constant_as_value(const, context_stack):
    """Wraps a Python constant as a value backed by a no-argument TF call.

    Args:
      const: Python constant to convert; anything `tf.constant` accepts.
      context_stack: The context stack to use.

    Returns:
      An instance of `value_base.Value`.

    Raises:
      TypeError: If `context_stack` is not a `context_stack_base.ContextStack`.
    """
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        lambda: tf.constant(const), None, context_stack)
    no_arg_call = computation_building_blocks.Call(
        computation_building_blocks.CompiledComputation(serialized))
    return ValueImpl(no_arg_call, context_stack)
Esempio n. 10
0
 def test_returns_string_for_call_with_no_arg(self):
     """Checks every representation of `comp#a()`, a call with no argument."""
     serialized, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
         lambda: tf.constant(1), None, context_stack_impl.context_stack)
     call = computation_building_blocks.Call(
         computation_building_blocks.CompiledComputation(serialized, 'a'))
     # pyformat: disable
     expected_structure = ('            Call\n'
                           '           /\n'
                           'Compiled(a)')
     # (renderer, expected output) pairs, checked in the original order.
     checks = (
         (computation_building_blocks.compact_representation, 'comp#a()'),
         (computation_building_blocks.formatted_representation, 'comp#a()'),
         (computation_building_blocks.structural_representation,
          expected_structure),
     )
     for render, expected in checks:
         self.assertEqual(render(call), expected)
Esempio n. 11
0
def _create_call_to_py_fn(fn):
    r"""Builds a no-argument `Call` of `fn` serialized as TensorFlow.

    The resulting structure is:

               Call
              /
      Compiled
      Computation

    Args:
      fn: The Python function to wrap. It is serialized with no parameter
        type, so it is called without arguments.

    Returns:
      An instance of `computation_building_blocks.Call` wrapping the compiled
      form of `fn`.
    """
    return computation_building_blocks.Call(
        computation_building_blocks.CompiledComputation(
            tensorflow_serialization.serialize_py_fn_as_tf_computation(
                fn, None, context_stack_impl.context_stack)))
Esempio n. 12
0
def _wrap_sequence_as_value(elements, element_type, context_stack):
    """Wraps `elements` as a TFF sequence whose elements have `element_type`.

    Args:
      elements: Python `list` of objects to wrap as a TFF sequence value.
      element_type: An instance of `Type` describing each sequence element.
      context_stack: The context stack to use.

    Returns:
      An instance of `tff.Value`.

    Raises:
      TypeError: If `elements` is not a list, `context_stack` has the wrong
        type, or some element is not assignable to `element_type`.
    """
    # TODO(b/113116813): Add support for other representations of sequences.
    py_typecheck.check_type(elements, list)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

    # Every individual element must be compatible with the requested element
    # type of the sequence as a whole.
    for element in elements:
        inferred = type_utils.infer_type(element)
        if type_utils.is_assignable_from(element_type, inferred):
            continue
        raise TypeError(
            'Expected all sequence elements to be {}, found {}.'.format(
                str(element_type), str(inferred)))

    def _create_dataset_from_elements():
        # Builds the dataset in whatever graph is the default at call time.
        return graph_utils.make_data_set_from_elements(
            tf.get_default_graph(), elements, element_type)

    # The dataset is exposed as the result of a no-argument TF computation.
    serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        _create_dataset_from_elements, None, context_stack)
    dataset_call = computation_building_blocks.Call(
        computation_building_blocks.CompiledComputation(serialized))
    return ValueImpl(dataset_call, context_stack)
Esempio n. 13
0
def _create_lambda_to_cast(dtype1, dtype2):
    r"""Creates a TFF lambda that TensorFlow-casts from `dtype1` to `dtype2`.

    Lambda
          \
           Call
          /    \
    Compiled   Reference
    Computation

    Here `CompiledComputation` is a TensorFlow computation casting its
    argument from `dtype1` to `dtype2`.

    Either dtype may be given as a `tf.dtypes.DType` or as a
    `computation_types.TensorType`; for the latter, its `dtype` is used.

    Args:
      dtype1: The type of the argument.
      dtype2: The type to cast the argument to.

    Returns:
      An instance of `computation_building_blocks.Lambda` whose body calls the
      compiled cast on its parameter.

    Raises:
      TypeError: If either dtype does not resolve to a `tf.dtypes.DType`.
    """

    def _as_dtype(spec):
        # Unwrap a `TensorType` to its element dtype; pass anything else on.
        if isinstance(spec, computation_types.TensorType):
            return spec.dtype
        return spec

    source_dtype = _as_dtype(dtype1)
    target_dtype = _as_dtype(dtype2)
    py_typecheck.check_type(source_dtype, tf.dtypes.DType)
    py_typecheck.check_type(target_dtype, tf.dtypes.DType)

    param_ref = computation_building_blocks.Reference('arg', source_dtype)
    cast_proto = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        lambda x: tf.cast(x, target_dtype), source_dtype,
        context_stack_impl.context_stack)
    body = computation_building_blocks.Call(
        computation_building_blocks.CompiledComputation(cast_proto), param_ref)
    return computation_building_blocks.Lambda(param_ref.name, source_dtype,
                                              body)
Esempio n. 14
0
    def test_replace_compiled_computations_names_replaces_multiple_names(self):
        """Every compiled computation in a tuple gets a new, distinct name."""

        def _make_compiled():
            serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
                lambda: tf.constant(1), None, context_stack_impl.context_stack)
            return computation_building_blocks.CompiledComputation(serialized)

        comp = computation_building_blocks.Tuple(
            [_make_compiled() for _ in range(10)])

        renamer = (
            transformations.replace_compiled_computations_names_with_unique_names)
        transformed_comp = renamer(comp)

        comp_names = [element._name for element in comp]
        transformed_comp_names = [element._name for element in transformed_comp]
        self.assertNotEqual(transformed_comp_names, comp_names)
        self.assertEqual(
            len(transformed_comp_names), len(set(transformed_comp_names)),
            'The transformed computation names are not unique: {}.'.format(
                transformed_comp_names))
Esempio n. 15
0
def select_graph_output(comp, name=None, index=None):
    r"""Makes a `CompiledComputation` that selects one output of `comp`.

    Given an instance of `computation_building_blocks.CompiledComputation`
    `comp` with type signature (T -> <U, ..., V>), `select_graph_output`
    returns a `CompiledComputation` with the same parameter as `comp` that
    represents calling `comp` and then selecting the element `name` or
    `index` from the resulting tuple. Exactly one of `name` or `index` must
    be specified.

    At the level of a TFF AST, `select_graph_output` is necessary to transform
    the structure below:

                                  Select(x)
                                     |
                                    Call
                                   /    \
                              Graph      Comp

    into:

                                  Call
                                 /    \
    select_graph_output(Graph, x)      Comp


    Args:
      comp: Instance of `computation_building_blocks.CompiledComputation` which
        must have result type `computation_types.NamedTupleType`, the function
        from which to select an output.
      name: Instance of `str`, the name of the field to select from the output
        of `comp`. Optional, but exactly one of `name` or `index` must be
        specified.
      index: Instance of `int`, the index of the field to select from the
        output of `comp`. Optional, but exactly one of `name` or `index` must
        be specified.

    Returns:
      An instance of `computation_building_blocks.CompiledComputation` as
      described, the result of selecting the appropriate output from `comp`.

    Raises:
      TypeError: If `comp` is not a `CompiledComputation`, `name`/`index` has
        the wrong type, or the result of `comp` is not a tuple.
      ValueError: If both or neither of `name` and `index` are given, or if
        `name` is not the name of any element of the result of `comp`.
    """
    py_typecheck.check_type(comp,
                            computation_building_blocks.CompiledComputation)
    # Compare against `None` explicitly: `index=0` (and `name=''`) are falsy
    # but still count as specified, so they must conflict with each other.
    if index is not None and name is not None:
        raise ValueError(
            'Please specify at most one of `name` or `index` to '
            '`select_graph_output`.')
    if index is not None:
        py_typecheck.check_type(index, int)
    elif name is not None:
        py_typecheck.check_type(name, str)
    else:
        raise ValueError(
            'Please pass a `name` or `index` to `select_graph_output`.')

    proto = comp.proto
    graph_result_binding = proto.tensorflow.result
    binding_oneof = graph_result_binding.WhichOneof('binding')
    if binding_oneof != 'tuple':
        raise TypeError(
            'Can only select output from a CompiledComputation with return type '
            'tuple; you have attempted a selection from a CompiledComputation '
            'with return type {}'.format(binding_oneof))
    proto_type = type_serialization.deserialize_type(proto.type)
    py_typecheck.check_type(proto_type.result,
                            computation_types.NamedTupleType)

    if name is not None:
        # Resolve the name to a position so both paths share the selection.
        element_names = [
            elem_name
            for elem_name, _ in anonymous_tuple.to_elements(proto_type.result)
        ]
        if name not in element_names:
            raise ValueError(
                'The result of `comp` has no element named {!r}; available '
                'names are {}.'.format(name, element_names))
        index = element_names.index(name)

    result = list(graph_result_binding.tuple.element)[index]
    result_type = proto_type.result[index]

    # Same graph and parameter; only the result binding and type are narrowed.
    serialized_type = type_serialization.serialize_type(
        computation_types.FunctionType(proto_type.parameter, result_type))
    selected_proto = pb.Computation(
        type=serialized_type,
        tensorflow=pb.TensorFlow(graph_def=proto.tensorflow.graph_def,
                                 initialize_op=proto.tensorflow.initialize_op,
                                 parameter=proto.tensorflow.parameter,
                                 result=result))
    return computation_building_blocks.CompiledComputation(selected_proto)
Esempio n. 16
0
 def _transform(comp):
   # Rebuilds eligible compiled computations under the next name from the
   # enclosing `name_generator`; returns (comp, whether it was replaced).
   if _should_transform(comp):
     renamed = computation_building_blocks.CompiledComputation(
         comp.proto, six.next(name_generator))
     return renamed, True
   return comp, False
Esempio n. 17
0
 def _transform(comp):
     # Ineligible nodes pass through; eligible ones are rebuilt under the next
     # name from the enclosing `name_generator`, coerced to `str`.
     if _should_transform(comp):
         proto = comp.proto
         fresh_name = str(six.next(name_generator))
         return computation_building_blocks.CompiledComputation(
             proto, fresh_name)
     return comp
Esempio n. 18
0
def _create_compiled_computation(py_fn, arg_type):
    """Serializes `py_fn` with parameter type `arg_type` as a compiled comp.

    Only the proto (first element) returned by the serializer is kept.
    """
    serialized_proto, _unused = (
        tensorflow_serialization.serialize_py_fn_as_tf_computation(
            py_fn, arg_type, context_stack_impl.context_stack))
    return computation_building_blocks.CompiledComputation(serialized_proto)
Esempio n. 19
0
 def _transformation_func(x, name_sequence):
     # Only compiled computations are renamed (with the next name drawn from
     # `name_sequence`); every other node is returned unchanged.
     if isinstance(x, computation_building_blocks.CompiledComputation):
         return computation_building_blocks.CompiledComputation(
             x.proto, six.next(name_sequence))
     return x
Esempio n. 20
0
def permute_graph_inputs(comp, input_permutation):
    r"""Reorders the tuple parameter of `comp` according to `input_permutation`.

    `comp` must be a `computation_building_blocks.CompiledComputation` whose
    parameter is a tuple. The permutation is given in 0-indexed one-line
    notation (https://en.wikipedia.org/wiki/Permutation#One-line_notation):
    position `i` of the new parameter is element `input_permutation[i]` of the
    old one. For example, if `comp` accepts a 3-tuple of types
    `[tf.int32, tf.float32, tf.bool]`, the permutation

                            [2, 0, 1]

    yields a `computation_building_blocks.CompiledComputation` accepting a
    3-tuple of types `[tf.bool, tf.int32, tf.float32]`.

    Structurally this just swaps one `CompiledComputation` for another, but it
    enables replacements such as turning:

                            Call
                           /    \
                      Graph      Tuple
                                / ... \
                    Selection(i)       Selection(j)
                         |                  |
                       Comp               Comp

    into:
                                       Call
                                      /    \
    permute_graph_inputs(Graph, [...])      Comp

    Args:
      comp: Instance of `computation_building_blocks.CompiledComputation` whose
        parameter bindings should be permuted.
      input_permutation: A Python `list` or `tuple` of `int`s giving the
        permutation in 0-indexed one-line notation.

    Returns:
      An instance of `computation_building_blocks.CompiledComputation` whose
      parameter bindings and parameter type are those of `comp`, reordered by
      `input_permutation`.

    Raises:
      TypeError: If the argument types are wrong, or the parameter of `comp`
        is not a tuple.
      ValueError: If `input_permutation` is not a permutation of the indices
        of the parameter of `comp`.
    """
    py_typecheck.check_type(comp,
                            computation_building_blocks.CompiledComputation)
    py_typecheck.check_type(input_permutation, (tuple, list))
    size = len(input_permutation)
    for position in input_permutation:
        py_typecheck.check_type(position, int)

    proto = comp.proto
    tf_proto = proto.tensorflow
    parameter_binding = tf_proto.parameter
    function_type = type_serialization.deserialize_type(proto.type)
    py_typecheck.check_type(function_type.parameter,
                            computation_types.NamedTupleType)
    binding_kind = parameter_binding.WhichOneof('binding')
    if binding_kind != 'tuple':
        raise TypeError(
            'Can only permute inputs of a CompiledComputation with parameter type '
            'tuple; you have attempted a permutation with a CompiledComputation '
            'with parameter type {}'.format(binding_kind))

    type_elements = anonymous_tuple.to_elements(function_type.parameter)
    bindings = list(parameter_binding.tuple.element)

    # Sorting is acceptable here: parameter tuples are expected to be short.
    if len(bindings) != size or sorted(input_permutation) != list(range(size)):
        raise ValueError(
            'Can only map the inputs with a true permutation; that '
            'is, the position of each input element must be uniquely specified. '
            'You have tried to map inputs {} with permutation {}'.format(
                bindings, input_permutation))

    permuted_bindings = [bindings[k] for k in input_permutation]
    permuted_type_elements = [type_elements[k] for k in input_permutation]
    permuted_type = computation_types.FunctionType(permuted_type_elements,
                                                   function_type.result)

    # The graph itself is untouched; only the parameter binding order and the
    # declared parameter type change.
    return computation_building_blocks.CompiledComputation(
        pb.Computation(
            type=type_serialization.serialize_type(permuted_type),
            tensorflow=pb.TensorFlow(
                graph_def=tf_proto.graph_def,
                initialize_op=tf_proto.initialize_op,
                parameter=pb.TensorFlow.Binding(
                    tuple=pb.TensorFlow.NamedTupleBinding(
                        element=permuted_bindings)),
                result=tf_proto.result)))
Esempio n. 21
0
def pad_graph_inputs_to_match_type(comp, type_signature):
    r"""Extends the tuple parameter of `comp` with dummy bindings.

    New placeholders are stamped into the graph for every element of
    `type_signature` beyond the existing parameter of `comp`. These dummy
    bindings are not wired into the rest of the graph. This is needed to turn
    TFF expressions of the form:

                              Lambda(arg)
                                  |
                                Call
                               /     \
            CompiledComputation       Tuple
                                        |
                                    Selection[i]
                                        |
                                      Ref(arg)

    into the form:

                            CompiledComputation

    when `arg` above is an n-tuple with n > 1. The Lambda accepts an n-tuple
    while the inner `CompiledComputation` accepts only a 1-tuple, so the
    parameter of the latter must be padded (and possibly reordered with
    `permute_graph_inputs`) before the replacement can be made.

    The existing parameter elements of `comp` must equal the leading elements
    of `type_signature`, so only compatible computations are padded.

    Args:
      comp: Instance of `computation_building_blocks.CompiledComputation`
        whose inputs should be padded to match `type_signature`.
      type_signature: Instance of `computation_types.NamedTupleType`, the
        parameter type `comp` should accept after padding.

    Returns:
      A new `computation_building_blocks.CompiledComputation` taking an
      argument of type `type_signature`, with the same logic and the same
      result type as `comp`.

    Raises:
      TypeError: If the parameter of `comp` is not a tuple, `type_signature` is
        not a `NamedTupleType`, or the existing parameter elements of `comp`
        do not match the start of `type_signature`.
      ValueError: If `type_signature` has fewer elements than the parameter of
        `comp`.
    """
    py_typecheck.check_type(type_signature, computation_types.NamedTupleType)
    py_typecheck.check_type(comp,
                            computation_building_blocks.CompiledComputation)
    proto = comp.proto
    tf_proto = proto.tensorflow
    parameter_binding = tf_proto.parameter
    function_type = type_serialization.deserialize_type(proto.type)
    binding_kind = parameter_binding.WhichOneof('binding')
    if binding_kind != 'tuple':
        raise TypeError(
            'Can only pad inputs of a CompiledComputation with parameter type '
            'tuple; you have attempted to pad a CompiledComputation '
            'with parameter type {}'.format(binding_kind))
    # Guards against a proto whose declared type disagrees with its binding.
    py_typecheck.check_type(function_type.parameter,
                            computation_types.NamedTupleType)

    bindings = list(parameter_binding.tuple.element)
    param_elements = anonymous_tuple.to_elements(function_type.parameter)
    target_elements = anonymous_tuple.to_elements(type_signature)
    num_existing = len(bindings)
    if num_existing > len(type_signature):
        raise ValueError(
            'We can only pad graph input bindings, never mask them. '
            'This means that a proposed type signature passed to '
            '`pad_graph_inputs_to_match_type` must have more elements '
            'than the existing type signature of the compiled '
            'computation. You have proposed a type signature of '
            'length {} be assigned to a computation with parameter '
            'type signature of length {}.'.format(len(type_signature),
                                                  num_existing))
    for position, element in enumerate(param_elements):
        if element != target_elements[position]:
            raise TypeError(
                'The existing elements of the parameter type signature '
                'of the compiled computation in `pad_graph_inputs_to_match_type` '
                'must match the beginning of the proposed new type signature; '
                'you have proposed a parameter type of {} for a computation '
                'with existing parameter type {}.'.format(
                    type_signature, function_type.parameter))

    graph = tf.Graph()
    with graph.as_default():
        tf.graph_util.import_graph_def(
            serialization_utils.unpack_graph_def(tf_proto.graph_def), name='')

    # Stamp a placeholder for every element past the existing parameters.
    for elem_name, elem_type in target_elements[num_existing:]:
        # Unnamed elements are stamped under the literal name 'name'.
        stamp_name = 'name' if elem_name is None else elem_name
        _, stamped_binding = graph_utils.stamp_parameter_in_graph(
            stamp_name, elem_type, graph)
        bindings.append(stamped_binding)
        param_elements.append((elem_name, elem_type))

    padded_type = computation_types.FunctionType(param_elements,
                                                 function_type.result)
    padded_proto = pb.Computation(
        type=type_serialization.serialize_type(padded_type),
        tensorflow=pb.TensorFlow(
            graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
            initialize_op=tf_proto.initialize_op,
            parameter=pb.TensorFlow.Binding(
                tuple=pb.TensorFlow.NamedTupleBinding(element=bindings)),
            result=tf_proto.result))
    return computation_building_blocks.CompiledComputation(padded_proto)
Esempio n. 22
0
def to_value(arg, type_spec, context_stack):
    """Converts the argument into an instance of `tff.Value`.

    The types of non-`tff.Value` arguments that are currently convertible to
    `tff.Value` include the following:

    * Lists, tuples, anonymous tuples, named tuples, and dictionaries, all
      of which are converted into instances of `tff.Tuple`.
    * Placement literals, converted into instances of `tff.Placement`.
    * Computations.
    * Python constants of type `str`, `int`, `float`, `bool`
    * Numpy objects inheriting from `np.ndarray` or `np.generic` (the parent
      of numpy scalar types)

    If `type_spec` is a sequence type, `arg` is wrapped as a sequence whose
    elements have `type_spec.element` as their type.

    Element order of a resulting `tff.Tuple`: lists, tuples, anonymous tuples,
    named tuples (in field declaration order) and `collections.OrderedDict`s
    keep their own order; any other `dict` has its items sorted by key.

    Args:
      arg: Either an instance of `tff.Value`, or an argument convertible to
        `tff.Value`. The argument must not be `None`.
      type_spec: A type specifier that allows for disambiguating the target
        type (e.g., when two TFF types can be mapped to the same Python
        representations), or `None` if none available, in which case TFF tries
        to determine the type of the TFF value automatically.
      context_stack: The context stack to use.

    Returns:
      An instance of `tff.Value` corresponding to the given `arg`, and of TFF
      type matching the `type_spec` if specified (not `None`).

    Raises:
      TypeError: if `arg` is of an unsupported type, or of a type that does not
        match `type_spec`. Raises explicit error message if TensorFlow
        constructs are encountered, as TensorFlow code should be sealed away
        from TFF federated context.
    """
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if type_spec is not None:
        type_spec = computation_types.to_type(type_spec)
        type_utils.check_well_formed(type_spec)
    if isinstance(arg, ValueImpl):
        result = arg
    elif isinstance(arg, computation_building_blocks.ComputationBuildingBlock):
        result = ValueImpl(arg, context_stack)
    elif isinstance(arg, placement_literals.PlacementLiteral):
        result = ValueImpl(computation_building_blocks.Placement(arg),
                           context_stack)
    elif isinstance(arg, computation_base.Computation):
        result = ValueImpl(
            computation_building_blocks.CompiledComputation(
                computation_impl.ComputationImpl.get_proto(arg)),
            context_stack)
    elif isinstance(type_spec, computation_types.SequenceType):
        # `isinstance(None, ...)` is False, so an absent `type_spec` never
        # reaches this branch.
        result = _wrap_sequence_as_value(arg, type_spec.element, context_stack)
    elif isinstance(arg, anonymous_tuple.AnonymousTuple):
        result = ValueImpl(
            computation_building_blocks.Tuple([
                (k, ValueImpl.get_comp(to_value(v, None, context_stack)))
                for k, v in anonymous_tuple.to_elements(arg)
            ]), context_stack)
    elif py_typecheck.is_named_tuple(arg):
        # Since Python 3.8 `_asdict()` returns a plain `dict`, which the `dict`
        # branch below would sort by key, silently reordering the fields.
        # Wrapping it in an `OrderedDict` preserves field declaration order.
        result = to_value(
            collections.OrderedDict(arg._asdict()), None, context_stack)
    elif isinstance(arg, dict):
        if isinstance(arg, collections.OrderedDict):
            items = six.iteritems(arg)
        else:
            # Plain dicts are sorted so the resulting tuple is deterministic.
            items = sorted(six.iteritems(arg))
        value = computation_building_blocks.Tuple([
            (k, ValueImpl.get_comp(to_value(v, None, context_stack)))
            for k, v in items
        ])
        result = ValueImpl(value, context_stack)
    elif isinstance(arg, (tuple, list)):
        result = ValueImpl(
            computation_building_blocks.Tuple([
                ValueImpl.get_comp(to_value(x, None, context_stack))
                for x in arg
            ]), context_stack)
    elif isinstance(arg, dtype_utils.TENSOR_REPRESENTATION_TYPES):
        result = _wrap_constant_as_value(arg, context_stack)
    elif isinstance(arg, (tf.Tensor, tf.Variable)):
        raise TypeError(
            'TensorFlow construct {} has been encountered in a federated '
            'context. TFF does not support mixing TF and federated orchestration '
            'code. Please wrap any TensorFlow constructs with '
            '`tff.tf_computation`.'.format(arg))
    else:
        raise TypeError(
            'Unable to interpret an argument of type {} as a TFF value.'.
            format(py_typecheck.type_string(type(arg))))
    py_typecheck.check_type(result, ValueImpl)
    if (type_spec is not None and not type_utils.is_assignable_from(
            type_spec, result.type_signature)):
        raise TypeError(
            'The supplied argument maps to TFF type {}, which is incompatible '
            'with the requested type {}.'.format(str(result.type_signature),
                                                 str(type_spec)))
    return result
Esempio n. 23
0
def concatenate_tensorflow_blocks(tf_comp_list):
    """Merges several TF blocks side-by-side into one `CompiledComputation`.

    Accepts a Python `list` or `tuple` of
    `computation_building_blocks.CompiledComputation` instances and builds a
    single block of the same kind. That block's graph contains all of the
    input computations next to each other, with their parameters and results
    concatenated.

    Callers should be aware of one convention: no packing into tuples is done
    beyond what is necessary. If `tf_comp_list` contains only a single TF
    computation which declares a parameter, the parameter type of the result
    is exactly that single parameter type. Every TF block declares a result,
    so this only matters for parameters. The returned function's result is
    always a tuple.

    Args:
      tf_comp_list: Python `list` or `tuple` of
        `computation_building_blocks.CompiledComputation`s whose inputs and
        outputs should be concatenated.

    Returns:
      A single `computation_building_blocks.CompiledComputation` representing
      every computation in `tf_comp_list` placed side-by-side.

    Raises:
      ValueError: If fewer than 2 computations are passed in `tf_comp_list`;
        in that case the caller most likely wants a different function.
      TypeError: If `tf_comp_list` is not a `list` or `tuple`, or if it
        contains anything other than TF blocks.
    """
    py_typecheck.check_type(tf_comp_list, (list, tuple))
    if len(tf_comp_list) < 2:
        raise ValueError(
            'We expect to concatenate at least two blocks of '
            'TensorFlow; otherwise the transformation you seek '
            'represents simply type manipulation, and you will find '
            'your desired function elsewhere in '
            '`compiled_computation_transforms`. You passed a tuple of '
            'length {}'.format(len(tf_comp_list)))
    for block in tf_comp_list:
        py_typecheck.check_type(
            block, computation_building_blocks.CompiledComputation)
    protos = [block.proto for block in tf_comp_list]

    graph_specs = [_unpack_proto_into_graph_spec(p) for p in protos]
    merged_graph, init_op_name, param_maps, result_maps = (
        graph_merge.concatenate_inputs_and_outputs(graph_specs))

    parameter_binding = _pack_concatenated_bindings(
        [p.tensorflow.parameter for p in protos], param_maps)
    result_binding = _pack_concatenated_bindings(
        [p.tensorflow.result for p in protos], result_maps)

    tensorflow_fields = {
        'graph_def': serialization_utils.pack_graph_def(
            merged_graph.as_graph_def()),
        'initialize_op': init_op_name,
        'result': result_binding,
    }
    # `parameter` is only set when the packed bindings are truthy (presumably
    # when at least one input declares a parameter); otherwise it stays unset.
    if parameter_binding:
        tensorflow_fields['parameter'] = parameter_binding
    tf_result_proto = pb.TensorFlow(**tensorflow_fields)

    function_type = computation_types.FunctionType(
        _construct_concatenated_type(
            [c.type_signature.parameter for c in tf_comp_list]),
        _construct_concatenated_type(
            [c.type_signature.result for c in tf_comp_list]))
    return computation_building_blocks.CompiledComputation(
        pb.Computation(
            type=type_serialization.serialize_type(function_type),
            tensorflow=tf_result_proto))