def test_basic_functionality_of_compiled_computation_class(self):
  """Checks type signature, proto, repr and tff_repr of a compiled comp."""
  serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda x: x + 3, tf.int32, context_stack_impl.context_stack)
  unnamed = computation_building_blocks.CompiledComputation(serialized)
  self.assertEqual(str(unnamed.type_signature), '(int32 -> int32)')
  self.assertEqual(str(unnamed.proto), str(serialized))
  # Without an explicit name, repr/tff_repr carry an auto-generated hex id.
  repr_pattern = (r'CompiledComputation\([0-9a-f]+, '
                  r'FunctionType\(TensorType\(tf\.int32\), '
                  r'TensorType\(tf\.int32\)\)\)')
  self.assertTrue(re.match(repr_pattern, repr(unnamed)))
  self.assertTrue(re.match(r'comp#[0-9a-f]+', unnamed.tff_repr))
  named = computation_building_blocks.CompiledComputation(
      serialized, name='foo')
  self.assertEqual(named.tff_repr, 'comp#foo')
  self._serialize_deserialize_roundtrip_test(unnamed)
def test_returns_string_for_comp_with_left_overhang(self):
  """Renders `a(comp#bbbbb())`, a call whose argument is itself a call."""
  ref_type = computation_types.FunctionType(tf.int32, tf.int32)
  ref = computation_building_blocks.Reference('a', ref_type)
  proto, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda: tf.constant(1), None, context_stack_impl.context_stack)
  inner_call = computation_building_blocks.Call(
      computation_building_blocks.CompiledComputation(proto, 'bbbbb'))
  outer_call = computation_building_blocks.Call(ref, inner_call)

  self.assertEqual(
      computation_building_blocks.compact_representation(outer_call),
      'a(comp#bbbbb())')
  self.assertEqual(
      computation_building_blocks.formatted_representation(outer_call),
      'a(comp#bbbbb())')
  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.structural_representation(outer_call),
      ' Call\n'
      ' / \\\n'
      ' Ref(a) Call\n'
      ' /\n'
      'Compiled(bbbbb)')
  # pyformat: enable
def test_returns_string_for_compiled_computation(self):
  """Checks all three renderings of a lone compiled computation named `a`."""
  proto, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda: tf.constant(1), None, context_stack_impl.context_stack)
  compiled = computation_building_blocks.CompiledComputation(proto, 'a')

  self.assertEqual(compiled.compact_representation(), 'comp#a')
  self.assertEqual(compiled.formatted_representation(), 'comp#a')
  self.assertEqual(compiled.structural_representation(), 'Compiled(a)')
def test_does_not_reduce_no_unnecessary_ops(self):
  """Pruning an identity graph must leave its op count unchanged."""

  def _identity(x):
    return x

  unpruned = _create_compiled_computation(_identity, tf.int32)
  pruned = computation_building_blocks.CompiledComputation(
      proto_transformations.prune_tensorflow_proto(unpruned.proto))
  count_ops = computation_building_block_utils.count_tensorflow_ops_in
  self.assertEqual(count_ops(unpruned), count_ops(pruned))
def test_replace_compiled_computations_names_replaces_name(self):
  """A single compiled computation is given a different name."""
  tf_comp = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda: tf.constant(1), None, context_stack_impl.context_stack)
  original = computation_building_blocks.CompiledComputation(tf_comp)
  renamed = (
      transformations.replace_compiled_computations_names_with_unique_names(
          original))
  self.assertNotEqual(renamed._name, original._name)
def test_reduces_unnecessary_ops(self):
  """Pruning must drop the unused constant op, lowering the op count."""

  def _identity_with_dead_constant(x):
    _ = tf.constant(0)  # Deliberately unused; pruning should remove it.
    return x

  count_ops = computation_building_block_utils.count_tensorflow_ops_in
  unpruned = _create_compiled_computation(_identity_with_dead_constant,
                                          tf.int32)
  ops_before = count_ops(unpruned)
  pruned = computation_building_blocks.CompiledComputation(
      proto_transformations.prune_tensorflow_proto(unpruned.proto))
  self.assertLess(count_ops(pruned), ops_before)
def wrap_graph_parameter_as_tuple(comp, name=None):
  """Returns a copy of `comp` whose parameter is nested in a 1-tuple binding.

  This is meant to run before `pad_graph_inputs_to_match_type`. That function
  can then assume its `comp` always has a tuple parameter binding, rather
  than also handling a bare tensor or sequence binding.

  Args:
    comp: Instance of `computation_building_blocks.CompiledComputation` whose
      parameter should be wrapped in a tuple binding.
    name: Optional string naming the single element of the new parameter
      tuple. Defaults to `None` (an unnamed element).

  Returns:
    A `computation_building_blocks.CompiledComputation` that runs the same
    graph. It accepts a one-element tuple holding the original parameter of
    `comp`.

  Raises:
    TypeError: If `comp` is not a
      `computation_building_blocks.CompiledComputation`, or if `name` is
      given and is not a string.
  """
  py_typecheck.check_type(comp, computation_building_blocks.CompiledComputation)
  if name is not None:
    py_typecheck.check_type(name, six.string_types)

  original_proto = comp.proto
  original_type = type_serialization.deserialize_type(original_proto.type)
  tf_block = original_proto.tensorflow

  # The graph itself is untouched; only the binding and the type change.
  wrapped_binding = pb.TensorFlow.Binding(
      tuple=pb.TensorFlow.NamedTupleBinding(element=[tf_block.parameter]))
  wrapped_type = computation_types.FunctionType(
      [(name, original_type.parameter)], original_type.result)

  return computation_building_blocks.CompiledComputation(
      pb.Computation(
          type=type_serialization.serialize_type(wrapped_type),
          tensorflow=pb.TensorFlow(
              graph_def=tf_block.graph_def,
              initialize_op=tf_block.initialize_op,
              parameter=wrapped_binding,
              result=tf_block.result)))
def test_prune_does_not_change_exeuction(self):
  """Pruned and unpruned computations return identical results."""

  def _identity_with_dead_constant(x):
    _ = tf.constant(0)  # Deliberately unused; pruning should remove it.
    return x

  unpruned = _create_compiled_computation(_identity_with_dead_constant,
                                          tf.int32)
  pruned = computation_building_blocks.CompiledComputation(
      proto_transformations.prune_tensorflow_proto(unpruned.proto))
  to_computation = (
      computation_wrapper_instances.building_block_to_computation)
  original_fn = to_computation(unpruned)
  pruned_fn = to_computation(pruned)
  for value in range(5):
    self.assertEqual(original_fn(value), pruned_fn(value))
def _wrap_constant_as_value(const, context_stack):
  """Builds a `tff.Value` that evaluates to the Python constant `const`.

  A no-argument TensorFlow function returning `tf.constant(const)` is
  serialized. It is then wrapped in a `CompiledComputation` and invoked
  through a `Call`.

  Args:
    const: A Python constant accepted by `tf.constant`.
    context_stack: The context stack to use.

  Returns:
    A `ValueImpl` wrapping the `Call` to the compiled constant.

  Raises:
    TypeError: If `context_stack` is not a `context_stack_base.ContextStack`.
  """
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda: tf.constant(const), None, context_stack)
  constant_call = computation_building_blocks.Call(
      computation_building_blocks.CompiledComputation(serialized))
  return ValueImpl(constant_call, context_stack)
def test_returns_string_for_call_with_no_arg(self):
  """Renders a no-argument call to a compiled computation named `a`."""
  proto, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda: tf.constant(1), None, context_stack_impl.context_stack)
  call = computation_building_blocks.Call(
      computation_building_blocks.CompiledComputation(proto, 'a'))

  self.assertEqual(
      computation_building_blocks.compact_representation(call), 'comp#a()')
  self.assertEqual(
      computation_building_blocks.formatted_representation(call), 'comp#a()')
  # pyformat: disable
  self.assertEqual(
      computation_building_blocks.structural_representation(call),
      ' Call\n'
      ' /\n'
      'Compiled(a)')
  # pyformat: enable
def _create_call_to_py_fn(fn):
  r"""Returns a `Call` that invokes `fn` as a no-argument TF computation.

      Call
     /
  Compiled
  Computation

  Args:
    fn: A no-argument Python function to serialize as TensorFlow.

  Returns:
    A `computation_building_blocks.Call` of the serialized `fn`, with no
    argument.
  """
  serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      fn, None, context_stack_impl.context_stack)
  return computation_building_blocks.Call(
      computation_building_blocks.CompiledComputation(serialized))
def _wrap_sequence_as_value(elements, element_type, context_stack):
  """Wraps a Python list as a TFF sequence whose elements have `element_type`.

  Args:
    elements: A Python `list` of values to be wrapped as a TFF sequence.
    element_type: An instance of `Type` giving the type of each sequence
      element.
    context_stack: The context stack to use.

  Returns:
    A `ValueImpl` backed by a no-argument TensorFlow computation that builds
    the dataset.

  Raises:
    TypeError: If `elements` is not a list, if `context_stack` is not a
      `ContextStack`, or if any element's inferred type is not assignable to
      `element_type`.
  """
  # TODO(b/113116813): Add support for other representations of sequences.
  py_typecheck.check_type(elements, list)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

  # Each element must fit the requested element type of the sequence as a
  # whole.
  for element in elements:
    inferred_type = type_utils.infer_type(element)
    if not type_utils.is_assignable_from(element_type, inferred_type):
      raise TypeError(
          'Expected all sequence elements to be {}, found {}.'.format(
              str(element_type), str(inferred_type)))

  def _build_dataset():
    # Executed while the computation below is being serialized; it reads the
    # default graph at that time.
    return graph_utils.make_data_set_from_elements(tf.get_default_graph(),
                                                   elements, element_type)

  serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      _build_dataset, None, context_stack)
  dataset_call = computation_building_blocks.Call(
      computation_building_blocks.CompiledComputation(serialized))
  return ValueImpl(dataset_call, context_stack)
def _create_lambda_to_cast(dtype1, dtype2):
  r"""Builds a TFF `Lambda` that casts its argument from `dtype1` to `dtype2`.

  Lambda
        \
         Call
        /    \
  Compiled   Reference
  Computation

  The `CompiledComputation` is a TensorFlow cast from `dtype1` to `dtype2`.
  Each dtype may be a `tf.dtypes.DType` or a `computation_types.TensorType`.
  For a `TensorType`, its `dtype` is used.

  Args:
    dtype1: The type of the argument.
    dtype2: The type to cast the argument to.

  Returns:
    A `computation_building_blocks.Lambda` whose body calls the compiled
    cast on its parameter.

  Raises:
    TypeError: If either argument does not resolve to a `tf.dtypes.DType`.
  """
  source_dtype = (
      dtype1.dtype
      if isinstance(dtype1, computation_types.TensorType) else dtype1)
  target_dtype = (
      dtype2.dtype
      if isinstance(dtype2, computation_types.TensorType) else dtype2)
  py_typecheck.check_type(source_dtype, tf.dtypes.DType)
  py_typecheck.check_type(target_dtype, tf.dtypes.DType)

  param_ref = computation_building_blocks.Reference('arg', source_dtype)
  serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda x: tf.cast(x, target_dtype), source_dtype,
      context_stack_impl.context_stack)
  cast_call = computation_building_blocks.Call(
      computation_building_blocks.CompiledComputation(serialized), param_ref)
  return computation_building_blocks.Lambda(param_ref.name, source_dtype,
                                            cast_call)
def test_replace_compiled_computations_names_replaces_multiple_names(self):
  """Ten compiled computations in a tuple all receive distinct new names."""
  original = computation_building_blocks.Tuple([
      computation_building_blocks.CompiledComputation(
          tensorflow_serialization.serialize_py_fn_as_tf_computation(
              lambda: tf.constant(1), None,
              context_stack_impl.context_stack)) for _ in range(10)
  ])
  renamed = (
      transformations.replace_compiled_computations_names_with_unique_names(
          original))

  original_names = [element._name for element in original]
  renamed_names = [element._name for element in renamed]
  self.assertNotEqual(renamed_names, original_names)
  self.assertEqual(
      len(renamed_names), len(set(renamed_names)),
      'The transformed computation names are not unique: {}.'.format(
          renamed_names))
def select_graph_output(comp, name=None, index=None):
  r"""Makes a `CompiledComputation` that selects one output of `comp`.

  `comp` is an instance of `computation_building_blocks.CompiledComputation`
  with type signature (T -> <U, ..., V>). `select_graph_output` returns a
  `CompiledComputation` that calls `comp` and then selects the field `name`
  or position `index` from the resulting tuple. Exactly one of `name` or
  `index` must be specified.

  At the level of a TFF AST, `select_graph_output` is needed to transform
  the structure below:

      Select(x)
          |
         Call
        /    \
     Graph   Comp

  into:

                  Call
                 /    \
  select_graph_output(Graph, x)   Comp

  Args:
    comp: Instance of `computation_building_blocks.CompiledComputation`. Its
      result type must be `computation_types.NamedTupleType`; this is the
      function whose output is selected.
    name: Optional `str`, the name of the field to select from the output of
      `comp`.
    index: Optional `int`, the index of the field to select from the output
      of `comp`.

  Returns:
    A `computation_building_blocks.CompiledComputation` with the same
    parameter as `comp` whose result is the selected output of `comp`.

  Raises:
    TypeError: If `comp` is not a `CompiledComputation`, if `index`/`name`
      has the wrong type, or if the result binding of `comp` is not a tuple.
    ValueError: If both or neither of `name` and `index` are specified, or if
      `name` is not a field of the result of `comp`.
  """
  py_typecheck.check_type(comp, computation_building_blocks.CompiledComputation)
  # Compare against `None` explicitly: `index=0` (and `name=''`) are falsy but
  # still count as "specified".
  if index is not None and name is not None:
    raise ValueError('Please specify at most one of `name` or `index` to '
                     '`select_graph_output`.')
  if index is not None:
    py_typecheck.check_type(index, int)
  elif name is not None:
    py_typecheck.check_type(name, str)
  else:
    raise ValueError(
        'Please pass a `name` or `index` to `select_graph_output`.')

  proto = comp.proto
  graph_result_binding = proto.tensorflow.result
  binding_oneof = graph_result_binding.WhichOneof('binding')
  if binding_oneof != 'tuple':
    raise TypeError(
        'Can only select output from a CompiledComputation with return type '
        'tuple; you have attempted a selection from a CompiledComputation '
        'with return type {}'.format(binding_oneof))
  proto_type = type_serialization.deserialize_type(proto.type)
  py_typecheck.check_type(proto_type.result, computation_types.NamedTupleType)

  if name is not None:
    # Resolve the field name to a position so both paths share one lookup.
    field_names = [
        field_name
        for field_name, _ in anonymous_tuple.to_elements(proto_type.result)
    ]
    if name not in field_names:
      raise ValueError(
          'No output named {} in the result of the CompiledComputation; '
          'available names are {}.'.format(name, field_names))
    index = field_names.index(name)

  result = list(graph_result_binding.tuple.element)[index]
  result_type = proto_type.result[index]

  serialized_type = type_serialization.serialize_type(
      computation_types.FunctionType(proto_type.parameter, result_type))
  selected_proto = pb.Computation(
      type=serialized_type,
      tensorflow=pb.TensorFlow(
          graph_def=proto.tensorflow.graph_def,
          initialize_op=proto.tensorflow.initialize_op,
          parameter=proto.tensorflow.parameter,
          result=result))
  return computation_building_blocks.CompiledComputation(selected_proto)
def _transform(comp):
  """Renames `comp` with the next generated name when it qualifies.

  Relies on `_should_transform` and `name_generator`, which are not visible
  here. Presumably they come from an enclosing scope; confirm against the
  surrounding definition.

  Returns:
    A `(comp, modified)` pair. `modified` is True when a renamed
    `CompiledComputation` was produced.
  """
  if _should_transform(comp):
    renamed = computation_building_blocks.CompiledComputation(
        comp.proto, six.next(name_generator))
    return renamed, True
  return comp, False
def _transform(comp):
  """Returns `comp` renamed with the stringified next generated name.

  If `comp` does not qualify, it is returned unchanged. `_should_transform`
  and `name_generator` are not defined here. Presumably they come from an
  enclosing scope; confirm against the surrounding definition.
  """
  if _should_transform(comp):
    return computation_building_blocks.CompiledComputation(
        comp.proto, str(six.next(name_generator)))
  return comp
def _create_compiled_computation(py_fn, arg_type):
  """Serializes `py_fn` (taking `arg_type`) into a `CompiledComputation`."""
  serialized_proto, _unused_type = (
      tensorflow_serialization.serialize_py_fn_as_tf_computation(
          py_fn, arg_type, context_stack_impl.context_stack))
  return computation_building_blocks.CompiledComputation(serialized_proto)
def _transformation_func(x, name_sequence):
  """Renames compiled computations; passes any other building block through.

  Args:
    x: A building block to inspect.
    name_sequence: An iterator; the next item becomes the new name of `x`
      when `x` is a `CompiledComputation`.

  Returns:
    A new `CompiledComputation` over `x.proto` with the next name, or `x`
    itself if it is not a `CompiledComputation`.
  """
  if isinstance(x, computation_building_blocks.CompiledComputation):
    return computation_building_blocks.CompiledComputation(
        x.proto, six.next(name_sequence))
  return x
def permute_graph_inputs(comp, input_permutation):
  r"""Reorders the tuple parameter of `comp` according to `input_permutation`.

  `input_permutation` uses 0-indexed one-line notation
  (https://en.wikipedia.org/wiki/Permutation#One-line_notation). Position `i`
  of the new parameter is taken from position `input_permutation[i]` of the
  old one. For example, suppose `comp` accepts
  `[tf.int32, tf.float32, tf.bool]`. The permutation `[2, 0, 1]` yields a
  `CompiledComputation` accepting `[tf.bool, tf.int32, tf.float32]`.

  Structurally this swaps one `CompiledComputation` for another. It is
  needed to rewrite

                  Call
                 /    \
            Graph      Tuple
                      / ... \
           Selection(i)      Selection(j)
                |                 |
              Comp              Comp

  into:

                        Call
                       /    \
  permute_graph_inputs(Graph, [...])   Comp

  Args:
    comp: Instance of `computation_building_blocks.CompiledComputation` whose
      parameter bindings should be permuted.
    input_permutation: A Python `list` or `tuple` of `int`s, the permutation
      in 0-indexed one-line notation.

  Returns:
    A `computation_building_blocks.CompiledComputation` running the same
    graph. Its parameter bindings and parameter type are reordered by
    `input_permutation`.

  Raises:
    TypeError: If argument types are wrong, or if `comp` does not take a
      tuple parameter.
    ValueError: If `input_permutation` is not a permutation of the parameter
      positions of `comp`.
  """
  py_typecheck.check_type(comp, computation_building_blocks.CompiledComputation)
  py_typecheck.check_type(input_permutation, (tuple, list))
  num_positions = len(input_permutation)
  for position in input_permutation:
    py_typecheck.check_type(position, int)

  proto = comp.proto
  parameter_binding = proto.tensorflow.parameter
  proto_type = type_serialization.deserialize_type(proto.type)
  py_typecheck.check_type(proto_type.parameter,
                          computation_types.NamedTupleType)
  binding_kind = parameter_binding.WhichOneof('binding')
  if binding_kind != 'tuple':
    raise TypeError(
        'Can only permute inputs of a CompiledComputation with parameter type '
        'tuple; you have attempted a permutation with a CompiledComputation '
        'with parameter type {}'.format(binding_kind))

  old_type_elements = anonymous_tuple.to_elements(proto_type.parameter)
  old_bindings = list(parameter_binding.tuple.element)

  # Sorting is acceptable here; parameter tuples are expected to be short.
  if (len(old_bindings) != num_positions or
      sorted(input_permutation) != list(range(num_positions))):
    raise ValueError(
        'Can only map the inputs with a true permutation; that '
        'is, the position of each input element must be uniquely specified. '
        'You have tried to map inputs {} with permutation {}'.format(
            old_bindings, input_permutation))

  new_bindings = [old_bindings[k] for k in input_permutation]
  new_type_elements = [old_type_elements[k] for k in input_permutation]

  permuted_type = computation_types.FunctionType(new_type_elements,
                                                 proto_type.result)
  permuted_proto = pb.Computation(
      type=type_serialization.serialize_type(permuted_type),
      tensorflow=pb.TensorFlow(
          graph_def=proto.tensorflow.graph_def,
          initialize_op=proto.tensorflow.initialize_op,
          parameter=pb.TensorFlow.Binding(
              tuple=pb.TensorFlow.NamedTupleBinding(element=new_bindings)),
          result=proto.tensorflow.result))
  return computation_building_blocks.CompiledComputation(permuted_proto)
def pad_graph_inputs_to_match_type(comp, type_signature):
  r"""Pads the parameter bindings of `comp` to match `type_signature`.

  The padded parameters are dummy bindings; they are not plugged in anywhere
  else in `comp`. This pattern is needed to transform TFF expressions of the
  form:

                 Lambda(arg)
                     |
                   Call
                  /    \
  CompiledComputation   Tuple
                          |
                      Selection[i]
                          |
                       Ref(arg)

  into the form:

  CompiledComputation

  This applies when `arg` above is an n-tuple with n > 1. Some type
  manipulation is required here. The Lambda being replaced accepts an
  n-tuple, while the `CompiledComputation` shown accepts only a 1-tuple.

  `pad_graph_inputs_to_match_type` is an intermediate step in that
  transformation, since some parameter permutation via
  `permute_graph_inputs` may also be needed.

  The existing parameter elements of `comp` must match the first elements
  of `type_signature`. This ensures that only compatible
  `CompiledComputation`s are padded to a given type signature.

  Args:
    comp: Instance of `computation_building_blocks.CompiledComputation`, the
      graph whose inputs should be padded to match `type_signature`.
    type_signature: Instance of `computation_types.NamedTupleType`, the
      parameter type that the padded `comp` should accept.

  Returns:
    A transformed version of `comp`, an instance of
    `computation_building_blocks.CompiledComputation`. It takes an argument
    of type `type_signature`, runs the same logic as `comp`, and has the same
    result type.

  Raises:
    TypeError: If the proto underlying `comp` has a parameter type that is
      not a `NamedTupleType`, or `type_signature` is not a `NamedTupleType`,
      or the declared parameters of `comp` do not match the beginning of
      `type_signature`.
    ValueError: If `type_signature` has fewer elements than the parameter
      type declared by `comp`.
  """
  py_typecheck.check_type(type_signature, computation_types.NamedTupleType)
  py_typecheck.check_type(comp, computation_building_blocks.CompiledComputation)
  proto = comp.proto
  graph_def = proto.tensorflow.graph_def
  graph_parameter_binding = proto.tensorflow.parameter
  proto_type = type_serialization.deserialize_type(proto.type)
  binding_oneof = graph_parameter_binding.WhichOneof('binding')
  if binding_oneof != 'tuple':
    raise TypeError(
        'Can only pad inputs of a CompiledComputation with parameter type '
        'tuple; you have attempted to pad a CompiledComputation '
        'with parameter type {}'.format(binding_oneof))
  # Guards against a proto whose serialized type disagrees with its binding.
  py_typecheck.check_type(proto_type.parameter,
                          computation_types.NamedTupleType)

  parameter_bindings = list(graph_parameter_binding.tuple.element)
  parameter_type_elements = anonymous_tuple.to_elements(proto_type.parameter)
  type_signature_elements = anonymous_tuple.to_elements(type_signature)
  num_existing = len(parameter_bindings)

  # Equal length is allowed (nothing to pad); only a shorter signature fails.
  if num_existing > len(type_signature):
    raise ValueError(
        'We can only pad graph input bindings, never mask them. '
        'This means that a proposed type signature passed to '
        '`pad_graph_inputs_to_match_type` must have at least as many '
        'elements as the existing type signature of the compiled '
        'computation. You have proposed a type signature of '
        'length {} be assigned to a computation with parameter '
        'type signature of length {}.'.format(
            len(type_signature), num_existing))
  if any(x != type_signature_elements[idx]
         for idx, x in enumerate(parameter_type_elements)):
    raise TypeError(
        'The existing elements of the parameter type signature '
        'of the compiled computation in `pad_graph_inputs_to_match_type` '
        'must match the beginning of the proposed new type signature; '
        'you have proposed a parameter type of {} for a computation '
        'with existing parameter type {}.'.format(type_signature,
                                                  proto_type.parameter))

  g = tf.Graph()
  with g.as_default():
    tf.graph_util.import_graph_def(
        serialization_utils.unpack_graph_def(graph_def), name='')

  # Stamp a new (unused) parameter into `g` for every trailing element of
  # `type_signature` that `comp` does not already bind.
  for name, type_spec in type_signature_elements[num_existing:]:
    stamp_name = 'name' if name is None else name
    _, stamped_binding = graph_utils.stamp_parameter_in_graph(
        stamp_name, type_spec, g)
    parameter_bindings.append(stamped_binding)
    parameter_type_elements.append((name, type_spec))

  new_parameter_binding = pb.TensorFlow.Binding(
      tuple=pb.TensorFlow.NamedTupleBinding(element=parameter_bindings))
  new_graph_def = g.as_graph_def()
  new_function_type = computation_types.FunctionType(parameter_type_elements,
                                                     proto_type.result)
  serialized_type = type_serialization.serialize_type(new_function_type)
  input_padded_proto = pb.Computation(
      type=serialized_type,
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(new_graph_def),
          initialize_op=proto.tensorflow.initialize_op,
          parameter=new_parameter_binding,
          result=proto.tensorflow.result))
  return computation_building_blocks.CompiledComputation(input_padded_proto)
def to_value(arg, type_spec, context_stack):
  """Converts `arg` into a `tff.Value` (an instance of `ValueImpl`).

  Accepted inputs, besides an existing `tff.Value`:

  * Computation building blocks, wrapped as-is.
  * Placement literals, converted into `tff.Placement` values.
  * Computations, converted into compiled computations.
  * Any value, when `type_spec` is a sequence type (wrapped as a sequence).
  * Lists, tuples, anonymous tuples, named tuples, and dictionaries, all
    converted into tuples. Plain dicts are ordered by sorted key;
    `OrderedDict`s keep their order.
  * Python constants of type `str`, `int`, `float`, `bool`.
  * Numpy objects inheriting from `np.ndarray` or `np.generic` (the parent of
    numpy scalar types).

  Nested elements are converted recursively without a type hint.

  Args:
    arg: A `tff.Value`, or an argument convertible to one. It must not be
      `None`.
    type_spec: A type specifier that disambiguates the target type (e.g.
      when two TFF types map to the same Python representation), or `None`
      to let TFF infer the type.
    context_stack: The context stack to use.

  Returns:
    A `ValueImpl` for `arg`. If `type_spec` is given, its type is assignable
    to `type_spec`.

  Raises:
    TypeError: If `arg` is of an unsupported type, or of a type that does not
      match `type_spec`. A dedicated message is raised for TensorFlow tensors
      and variables, since TensorFlow code should be sealed away from the
      TFF federated context.
  """
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  if type_spec is not None:
    type_spec = computation_types.to_type(type_spec)
    type_utils.check_well_formed(type_spec)

  def _nested_comp(element):
    # Nested elements are converted with no type hint.
    return ValueImpl.get_comp(to_value(element, None, context_stack))

  def _convert():
    # The order of these checks matters: e.g. named tuples must be handled
    # before generic tuples, and sequence types before structural matches.
    if isinstance(arg, ValueImpl):
      return arg
    if isinstance(arg, computation_building_blocks.ComputationBuildingBlock):
      return ValueImpl(arg, context_stack)
    if isinstance(arg, placement_literals.PlacementLiteral):
      return ValueImpl(
          computation_building_blocks.Placement(arg), context_stack)
    if isinstance(arg, computation_base.Computation):
      arg_proto = computation_impl.ComputationImpl.get_proto(arg)
      return ValueImpl(
          computation_building_blocks.CompiledComputation(arg_proto),
          context_stack)
    if isinstance(type_spec, computation_types.SequenceType):
      return _wrap_sequence_as_value(arg, type_spec.element, context_stack)
    if isinstance(arg, anonymous_tuple.AnonymousTuple):
      return ValueImpl(
          computation_building_blocks.Tuple([
              (k, _nested_comp(v))
              for k, v in anonymous_tuple.to_elements(arg)
          ]), context_stack)
    if py_typecheck.is_named_tuple(arg):
      return to_value(arg._asdict(), None, context_stack)
    if isinstance(arg, dict):
      if isinstance(arg, collections.OrderedDict):
        items = six.iteritems(arg)
      else:
        items = sorted(six.iteritems(arg))
      return ValueImpl(
          computation_building_blocks.Tuple([
              (k, _nested_comp(v)) for k, v in items
          ]), context_stack)
    if isinstance(arg, (tuple, list)):
      return ValueImpl(
          computation_building_blocks.Tuple([_nested_comp(x) for x in arg]),
          context_stack)
    if isinstance(arg, dtype_utils.TENSOR_REPRESENTATION_TYPES):
      return _wrap_constant_as_value(arg, context_stack)
    if isinstance(arg, (tf.Tensor, tf.Variable)):
      raise TypeError(
          'TensorFlow construct {} has been encountered in a federated '
          'context. TFF does not support mixing TF and federated orchestration '
          'code. Please wrap any TensorFlow constructs with '
          '`tff.tf_computation`.'.format(arg))
    raise TypeError(
        'Unable to interpret an argument of type {} as a TFF value.'.format(
            py_typecheck.type_string(type(arg))))

  result = _convert()
  py_typecheck.check_type(result, ValueImpl)
  if type_spec is not None and not type_utils.is_assignable_from(
      type_spec, result.type_signature):
    raise TypeError(
        'The supplied argument maps to TFF type {}, which is incompatible '
        'with the requested type {}.'.format(
            str(result.type_signature), str(type_spec)))
  return result
def concatenate_tensorflow_blocks(tf_comp_list):
  """Concatenates several TF blocks side by side into a single TF block.

  Takes a Python `list` or `tuple` of
  `computation_building_blocks.CompiledComputation` instances. Builds one
  block of the same kind whose inputs and outputs are the concatenation of
  those of the inputs.

  Convention for callers: no more tuple packing is done than necessary.
  Suppose only one computation in `tf_comp_list` declares a parameter. Then
  the parameter type of the result is exactly that parameter type, not a
  1-tuple. Every TF block declares a result, so this affects only
  parameters. The result is always a tuple.

  Args:
    tf_comp_list: Python `list` or `tuple` of
      `computation_building_blocks.CompiledComputation`s to concatenate.

  Returns:
    A single `computation_building_blocks.CompiledComputation` representing
    the computations in `tf_comp_list` concatenated side by side.

  Raises:
    ValueError: If `tf_comp_list` holds fewer than 2 computations. In that
      case the caller likely wants a different function.
    TypeError: If `tf_comp_list` is not a `list` or `tuple`, or if it
      contains anything other than TF blocks.
  """
  py_typecheck.check_type(tf_comp_list, (list, tuple))
  if len(tf_comp_list) < 2:
    raise ValueError(
        'We expect to concatenate at least two blocks of '
        'TensorFlow; otherwise the transformation you seek '
        'represents simply type manipulation, and you will find '
        'your desired function elsewhere in '
        '`compiled_computation_transforms`. You passed a tuple of '
        'length {}'.format(len(tf_comp_list)))

  for block in tf_comp_list:
    py_typecheck.check_type(block,
                            computation_building_blocks.CompiledComputation)
  block_protos = [block.proto for block in tf_comp_list]

  graph_specs = [_unpack_proto_into_graph_spec(p) for p in block_protos]
  (merged_graph, init_op_name, parameter_name_maps,
   result_name_maps) = graph_merge.concatenate_inputs_and_outputs(graph_specs)

  parameter_bindings = _pack_concatenated_bindings(
      [p.tensorflow.parameter for p in block_protos], parameter_name_maps)
  result_bindings = _pack_concatenated_bindings(
      [p.tensorflow.result for p in block_protos], result_name_maps)

  tf_fields = dict(
      graph_def=serialization_utils.pack_graph_def(
          merged_graph.as_graph_def()),
      initialize_op=init_op_name,
      result=result_bindings)
  # A parameter binding is attached only when the packed bindings are truthy.
  if parameter_bindings:
    tf_fields['parameter'] = parameter_bindings
  tf_result_proto = pb.TensorFlow(**tf_fields)

  parameter_type = _construct_concatenated_type(
      [block.type_signature.parameter for block in tf_comp_list])
  return_type = _construct_concatenated_type(
      [block.type_signature.result for block in tf_comp_list])
  function_type = computation_types.FunctionType(parameter_type, return_type)
  constructed_proto = pb.Computation(
      type=type_serialization.serialize_type(function_type),
      tensorflow=tf_result_proto)
  return computation_building_blocks.CompiledComputation(constructed_proto)