def _unpack_and_call(fn, arg_types, kwarg_types, arg):
  """An interceptor function that unpacks 'arg' before calling `fn`.

  The function verifies the actual parameters before it forwards the call as
  a last-minute check.

  Args:
    fn: The function or defun to invoke.
    arg_types: The list of positional argument types (guaranteed to all be
      instances of computation_types.Types).
    kwarg_types: The dictionary of keyword argument types (guaranteed to all
      be instances of computation_types.Types).
    arg: The argument to unpack. Positional elements are read by index and
      keyword elements by attribute name.

  Returns:
    The result of invoking `fn` on the unpacked arguments.

  Raises:
    TypeError: if `arg` is neither an `AnonymousTuple` nor a `Value`, or if
      the type of any unpacked element is not assignable to its expected type.
  """
  py_typecheck.check_type(
      arg, (anonymous_tuple.AnonymousTuple, value_base.Value))

  def _check_and_convert(element_value, expected_type, location):
    """Type-checks one element and converts it to a Python container if needed."""
    actual_type = type_utils.infer_type(element_value)
    if not type_utils.is_assignable_from(expected_type, actual_type):
      raise TypeError('Expected element {} to be of type {}, found {}.'.format(
          location, str(expected_type), str(actual_type)))
    # Only convert when the expected type actually carries a Python container
    # annotation; `convert_to_py_container` raises TypeError for a plain
    # NamedTupleType, so an unconditional conversion of every AnonymousTuple
    # (as the positional branch previously did) would reject valid input.
    if type_utils.is_anon_tuple_with_py_container(element_value, expected_type):
      return type_utils.convert_to_py_container(element_value, expected_type)
    return element_value

  args = [
      _check_and_convert(arg[idx], expected_type,
                         'at position {}'.format(idx))
      for idx, expected_type in enumerate(arg_types)
  ]
  kwargs = {
      name: _check_and_convert(getattr(arg, name), expected_type,
                               'named {}'.format(name))
      for name, expected_type in six.iteritems(kwarg_types)
  }
  return fn(*args, **kwargs)
def _convert_to_py_container(value, type_spec):
  """Converts value to a Python container if type_spec has an annotation.

  A value whose type carries a Python container annotation is converted
  directly. For a sequence type, the value is converted to a list of
  containers, but only if every element carries such an annotation.
  Otherwise `value` is returned unchanged.
  """
  if type_utils.is_anon_tuple_with_py_container(value, type_spec):
    return type_utils.convert_to_py_container(value, type_spec)
  if not isinstance(type_spec, computation_types.SequenceType):
    return value
  element_type = type_spec.element
  every_element_convertible = all(
      type_utils.is_anon_tuple_with_py_container(item, element_type)
      for item in value)
  if not every_element_convertible:
    return value
  return [
      type_utils.convert_to_py_container(item, element_type) for item in value
  ]
def test_anon_tuple_with_names_to_container_without_names(self):
  # A partially-named AnonymousTuple cannot be turned into an unnamed
  # container; the conversion is expected to hand back the input unchanged.
  mixed_names_tuple = anonymous_tuple.AnonymousTuple([(None, 1), ('a', 2.0)])
  element_types = [tf.int32, tf.float32]
  for container_cls in (tuple, list):
    container_type = computation_types.NamedTupleTypeWithPyContainerType(
        element_types, container_cls)
    converted = type_utils.convert_to_py_container(mixed_names_tuple,
                                                   container_type)
    self.assertEqual(converted, mixed_names_tuple)
async def _invoke(executor, comp, arg):
  """A coroutine that handles invocation.

  The computation is embedded in `executor`. If `arg` is an `AnonymousTuple`,
  it is embedded concurrently via `create_tuple`; any other `arg` is passed
  to `create_call` as given.

  Args:
    executor: An instance of `executor_base.Executor`.
    comp: The first argument to `context_base.Context.invoke()`.
    arg: The optional second argument to `context_base.Context.invoke()`.

  Returns:
    The result of the invocation, converted to a Python container when the
    result type carries a container annotation.
  """
  result_type = comp.type_signature.result
  pending = [executor.create_value(comp)]
  if isinstance(arg, anonymous_tuple.AnonymousTuple):
    pending.append(executor.create_tuple(arg))
  embedded_comp, *embedded_rest = await asyncio.gather(*pending)
  embedded_arg = embedded_rest[0] if embedded_rest else arg
  call_result = await executor.create_call(embedded_comp, embedded_arg)
  result_val = _unwrap(await call_result.compute())
  if not type_utils.is_anon_tuple_with_py_container(result_val, result_type):
    return result_val
  return type_utils.convert_to_py_container(result_val, result_type)
def _call(fn, parameter_type, arg):
  """Type-checks `arg` against `parameter_type`, then calls `fn` with it.

  If `arg` is an AnonymousTuple whose expected type carries a Python
  container annotation, it is converted to that container first.

  Raises:
    TypeError: if the inferred type of `arg` is not assignable to
      `parameter_type`.
  """
  arg_type = type_utils.infer_type(arg)
  if type_utils.is_assignable_from(parameter_type, arg_type):
    needs_container = type_utils.is_anon_tuple_with_py_container(
        arg, parameter_type)
    call_arg = (
        type_utils.convert_to_py_container(arg, parameter_type)
        if needs_container else arg)
    return fn(call_arg)
  raise TypeError('Expected an argument of type {}, found {}.'.format(
      parameter_type, arg_type))
def invoke(self, fn, arg):
  """Compiles and evaluates `fn`, applying it to `arg` when it is a function.

  Args:
    fn: The computation to run; its declared result type is used to convert
      the final value to a Python container when annotated.
    arg: The argument for `fn`, or None.

  Returns:
    The computed value, converted to a Python container if `fn`'s result
    type carries a container annotation.

  Raises:
    TypeError: if an argument is supplied for a computation whose computed
      value is not a function (type-check helpers may raise as well).
  """
  comp = self._compile(fn)
  cardinalities = {}
  root_context = ComputationContext(cardinalities=cardinalities)
  computed_comp = self._compute(comp, root_context)
  type_utils.check_assignable_from(comp.type_signature,
                                   computed_comp.type_signature)

  def _maybe_to_py_container(value):
    # NOTE(review): conversion uses `fn`'s declared result type rather than the
    # compiled `comp`'s; presumably compilation may drop container annotations.
    declared_result_type = fn.type_signature.result
    if type_utils.is_anon_tuple_with_py_container(value, declared_result_type):
      return type_utils.convert_to_py_container(value, declared_result_type)
    return value

  computed_type = computed_comp.type_signature
  if not isinstance(computed_type, computation_types.FunctionType):
    # A non-functional value cannot accept an argument.
    if arg is not None:
      raise TypeError('Unexpected argument {}.'.format(arg))
    return _maybe_to_py_container(computed_comp.value)

  if arg is None:
    computed_arg = None
  else:

    def _handle_callable(fn, fn_type):
      py_typecheck.check_type(fn, computation_base.Computation)
      type_utils.check_assignable_from(fn.type_signature, fn_type)
      computed_fn = self._compute(self._compile(fn), root_context)
      return computed_fn.value

    parameter_type = computed_type.parameter
    computed_arg = ComputedValue(
        to_representation_for_type(arg, parameter_type, _handle_callable),
        parameter_type)
    # Cardinalities inferred from the argument are shared with root_context
    # through the same dict object.
    cardinalities.update(
        runtime_utils.infer_cardinalities(computed_arg.value,
                                          computed_arg.type_signature))

  result = computed_comp.value(computed_arg)
  py_typecheck.check_type(result, ComputedValue)
  type_utils.check_assignable_from(comp.type_signature.result,
                                   result.type_signature)
  return _maybe_to_py_container(result.value)
def test_anon_tuple_without_names_to_container_with_names(self):
  # An unnamed AnonymousTuple cannot populate a named container; each
  # conversion is expected to return the input unchanged.
  unnamed_tuple = anonymous_tuple.AnonymousTuple([(None, 1), (None, 2.0)])
  named_types = [('a', tf.int32), ('b', tf.float32)]
  test_named_tuple = collections.namedtuple('TestNamedTuple', ['a', 'b'])
  for container_cls in (dict, collections.OrderedDict, test_named_tuple):
    container_type = computation_types.NamedTupleTypeWithPyContainerType(
        named_types, container_cls)
    converted = type_utils.convert_to_py_container(unnamed_tuple,
                                                   container_type)
    self.assertEqual(converted, unnamed_tuple)
def test_nested_py_containers(self):
  # Build the value from the innermost level outwards.
  innermost_value = anonymous_tuple.AnonymousTuple([(None, 4), (None, 5)])
  middle_value = anonymous_tuple.AnonymousTuple([('a', 3),
                                                 ('b', innermost_value)])
  anon_tuple = anonymous_tuple.AnonymousTuple([(None, 1), (None, 2.0),
                                               (None, middle_value)])
  # Matching type: list[int32, float32, dict{a: int32, b: tuple(int32, int32)}]
  container_type_cls = computation_types.NamedTupleTypeWithPyContainerType
  innermost_type = container_type_cls([tf.int32, tf.int32], tuple)
  middle_type = container_type_cls([('a', tf.int32), ('b', innermost_type)],
                                   dict)
  outer_type = container_type_cls([tf.int32, tf.float32, middle_type], list)
  expected_nested_structure = [1, 2.0, {'a': 3, 'b': (4, 5)}]
  actual = type_utils.convert_to_py_container(anon_tuple, outer_type)
  self.assertEqual(actual, expected_nested_structure)
async def _invoke(executor, comp, arg):
  """A coroutine that handles invocation.

  Args:
    executor: An instance of `executor_base.Executor`.
    comp: The first argument to `context_base.Context.invoke()`; its type
      signature must be a `FunctionType`.
    arg: The optional second argument to `context_base.Context.invoke()`;
      when present it must already be an `ExecutorValue`.

  Returns:
    The result of the invocation, converted to a Python container when the
    result type carries a container annotation.
  """
  py_typecheck.check_type(comp.type_signature, computation_types.FunctionType)
  result_type = comp.type_signature.result
  if arg is not None:
    py_typecheck.check_type(arg, executor_value_base.ExecutorValue)
  embedded_comp = await executor.create_value(comp)
  call_result = await executor.create_call(embedded_comp, arg)
  py_typecheck.check_type(call_result, executor_value_base.ExecutorValue)
  computed = _unwrap(await call_result.compute())
  if type_utils.is_anon_tuple_with_py_container(computed, result_type):
    computed = type_utils.convert_to_py_container(computed, result_type)
  return computed
def get_tf_typespec_and_binding(parameter_type, arg_names, unpack=None):
  """Computes a `TensorSpec` input_signature and bindings for parameter_type.

  This is the TF2 analog to `stamp_parameter_in_graph`.

  Args:
    parameter_type: The TFF type of the input to a tensorflow function. Must be
      either an instance of computation_types.Type (or convertible to it), or
      None in the case of a no-arg function.
    arg_names: String names for any positional arguments to the tensorflow
      function. Besides being iterable it must support `len()` and indexing
      (e.g. a list).
    unpack: Whether or not to unpack parameter_type into args and kwargs. See
      e.g. `function_utils.pack_args_into_anonymous_tuple`.

  Returns:
    A tuple (args_typespec, kwargs_typespec, binding), where args_typespec is a
    list and kwargs_typespec is a dict, both containing `tf.TensorSpec`
    objects. These structures are intended to be passed to the
    `get_concrete_function` method of a `tf.function`.
    Note the "binding" is "preliminary" in that it includes the names embedded
    in the TensorSpecs produced; these must be converted to the names of actual
    tensors based on the SignatureDef of the SavedModel before the binding is
    finalized.

  Raises:
    ValueError: If `arg_names` has fewer entries than there are positional
      arguments, contains a non-string entry, a nested structure mixes named
      and unnamed elements, or a type component cannot be converted to a
      `tf.TensorSpec`.
    NotImplementedError: If a sequence type is encountered.
  """
  # `collections.Iterable` was removed in Python 3.10; the ABC lives in
  # `collections.abc`.
  from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

  if parameter_type is None:
    return ([], {}, None)

  if unpack:
    arg_types, kwarg_types = function_utils.unpack_args_from_tuple(
        parameter_type)
    pack_in_tuple = True
  else:
    pack_in_tuple = False
    arg_types, kwarg_types = [parameter_type], {}

  py_typecheck.check_type(arg_names, collections_abc.Iterable)
  if len(arg_names) < len(arg_types):
    raise ValueError(
        'If provided, arg_names must be a list of at least {} strings to '
        'match the number of positional arguments. Found: {}'.format(
            len(arg_types), arg_names))

  get_unique_name = UniqueNameFn()

  def _get_one_typespec_and_binding(parameter_name, parameter_type):
    """Returns a (tf.TensorSpec, binding) pair."""
    parameter_type = computation_types.to_type(parameter_type)
    if isinstance(parameter_type, computation_types.TensorType):
      name = get_unique_name(parameter_name)
      tf_spec = tf.TensorSpec(
          shape=parameter_type.shape, dtype=parameter_type.dtype, name=name)
      binding = pb.TensorFlow.Binding(
          tensor=pb.TensorFlow.TensorBinding(tensor_name=name))
      return (tf_spec, binding)
    elif isinstance(parameter_type, computation_types.NamedTupleType):
      element_typespec_pairs = []
      element_bindings = []
      have_names = False
      have_nones = False
      for e_name, e_type in anonymous_tuple.to_elements(parameter_type):
        if e_name is None:
          have_nones = True
        else:
          have_names = True
        # Nested spec names are the parent name and element name joined by
        # '_', skipping whichever of the two is missing.
        name = '_'.join([n for n in [parameter_name, e_name] if n])
        e_typespec, e_binding = _get_one_typespec_and_binding(
            name if name else None, e_type)
        element_typespec_pairs.append((e_name, e_typespec))
        element_bindings.append(e_binding)
      # For a given argument or kwarg, we shouldn't have both:
      if have_names and have_nones:
        raise ValueError(
            'A mix of named and unnamed entries are not supported inside '
            'a nested structure representing a single argument in a call '
            'to a TensorFlow or Python function.\n' + str(parameter_type))
      tf_typespec = anonymous_tuple.AnonymousTuple(element_typespec_pairs)
      return (tf_typespec,
              pb.TensorFlow.Binding(
                  tuple=pb.TensorFlow.NamedTupleBinding(
                      element=element_bindings)))
    elif isinstance(parameter_type, computation_types.SequenceType):
      raise NotImplementedError('Sequence inputs not yet supported for TF 2.0.')
    else:
      raise ValueError('Parameter type component {} cannot be converted '
                       'to a TensorSpec'.format(repr(parameter_type)))

  def get_arg_name(i):
    name = arg_names[i]
    if not isinstance(name, six.string_types):
      raise ValueError('arg_names must be strings, but got ' + str(name))
    return name

  # Main logic --- process arg_types and kwarg_types:
  arg_typespecs = []
  kwarg_typespecs = {}
  bindings = []
  for i, arg_type in enumerate(arg_types):
    name = get_arg_name(i)
    typespec, binding = _get_one_typespec_and_binding(name, arg_type)
    typespec = type_utils.convert_to_py_container(typespec, arg_type)
    arg_typespecs.append(typespec)
    bindings.append(binding)
  for name, kwarg_type in six.iteritems(kwarg_types):
    typespec, binding = _get_one_typespec_and_binding(name, kwarg_type)
    typespec = type_utils.convert_to_py_container(typespec, kwarg_type)
    kwarg_typespecs[name] = typespec
    bindings.append(binding)

  assert bindings, 'Given parameter_type {}, but produced no bindings.'.format(
      parameter_type)
  if pack_in_tuple:
    final_binding = pb.TensorFlow.Binding(
        tuple=pb.TensorFlow.NamedTupleBinding(element=bindings))
  else:
    final_binding = bindings[0]

  return (arg_typespecs, kwarg_typespecs, final_binding)
def test_fails_not_named_tuple_type_with_py_container(self):
  # A plain NamedTupleType carries no container annotation, so conversion
  # must be rejected with a TypeError.
  with self.assertRaises(TypeError):
    unnamed_tuple = anonymous_tuple.AnonymousTuple([(None, 1), (None, 2.0)])
    plain_tuple_type = computation_types.NamedTupleType(
        [tf.int32, tf.float32])
    type_utils.convert_to_py_container(unnamed_tuple, plain_tuple_type)
def test_fails_not_anon_tuple(self):
  # Only AnonymousTuple values may be converted; a plain Python tuple must be
  # rejected with a TypeError.
  with self.assertRaises(TypeError):
    plain_python_tuple = (1, 2.0)
    list_container_type = computation_types.NamedTupleTypeWithPyContainerType(
        [tf.int32, tf.float32], list)
    type_utils.convert_to_py_container(plain_python_tuple, list_container_type)