Exemplo n.º 1
0
 def test_unpack_args_from_tuple_type(self, tuple_with_args, expected_args,
                                      expected_kwargs):
     """Checks that unpacked positional/keyword types match expectations.

     Positional types are compared by position after checking the counts
     agree; keyword types are compared by name after checking both sides
     carry the same set of names.
     """
     unpacked_args, unpacked_kwargs = function_utils.unpack_args_from_tuple(
         tuple_with_args)

     # Positional arguments: same arity, then per-position type equivalence.
     self.assertEqual(len(unpacked_args), len(expected_args))
     for position, actual_type in enumerate(unpacked_args):
         expected_type = computation_types.to_type(expected_args[position])
         self.assertTrue(
             type_utils.are_equivalent_types(actual_type, expected_type))

     # Keyword arguments: identical names, then per-name type equivalence.
     self.assertEqual(set(unpacked_kwargs.keys()), set(expected_kwargs.keys()))
     for kw_name, kw_value in six.iteritems(unpacked_kwargs):
         actual_type = computation_types.to_type(kw_value)
         self.assertTrue(
             type_utils.are_equivalent_types(actual_type,
                                             expected_kwargs[kw_name]))
Exemplo n.º 2
0
 def test_unpack_args_from_anonymous_tuple(self, tuple_with_args,
                                           expected_args, expected_kwargs):
     """Unpacking an anonymous tuple must yield exactly (args, kwargs)."""
     unpacked = function_utils.unpack_args_from_tuple(tuple_with_args)
     self.assertEqual(unpacked, (expected_args, expected_kwargs))
Exemplo n.º 3
0
def get_tf_typespec_and_binding(parameter_type, arg_names, unpack=None):
  """Computes a `TensorSpec` input_signature and bindings for parameter_type.

  This is the TF2 analog to `stamp_parameter_in_graph`.

  Args:
    parameter_type: The TFF type of the input to a tensorflow function. Must be
      either an instance of computation_types.Type (or convertible to it), or
      None in the case of a no-arg function.
    arg_names: String names for any positional arguments to the tensorflow
      function. Any iterable is accepted; it is materialized into a list so
      that its i-th element names the i-th positional argument.
    unpack: Whether or not to unpack parameter_type into args and kwargs. See
      e.g. `function_utils.pack_args_into_anonymous_tuple`. Any falsy value
      (including the default None) means "do not unpack".

  Returns:
    A tuple (args_typespec, kwargs_typespec, binding), where args_typespec is a
    list and kwargs_typespec is a dict, both containing `tf.TensorSpec`
    objects. These structures are intended to be passed to the
    `get_concrete_function` method of a `tf.function`.
    Note the "binding" is "preliminary" in that it includes the names embedded
    in the TensorSpecs produced; these must be converted to the names of actual
    tensors based on the SignatureDef of the SavedModel before the binding is
    finalized. When `parameter_type` is None, returns `([], {}, None)`.

  Raises:
    ValueError: If `arg_names` has fewer entries than there are positional
      arguments, if a positional argument name is not a string, if a nested
      structure mixes named and unnamed elements, or if a type component
      cannot be converted to a `tf.TensorSpec`.
    NotImplementedError: If a sequence type is encountered.
    A non-iterable `arg_names` is rejected by `py_typecheck.check_type`
    (presumably with a `TypeError` -- confirm against py_typecheck).
  """
  if parameter_type is None:
    return ([], {}, None)

  # When unpacking, the per-argument bindings are later re-packed into one
  # tuple binding; otherwise the single argument's binding is used directly.
  pack_in_tuple = bool(unpack)
  if pack_in_tuple:
    arg_types, kwarg_types = function_utils.unpack_args_from_tuple(
        parameter_type)
  else:
    arg_types, kwarg_types = [parameter_type], {}

  try:
    # `collections.Iterable` was removed in Python 3.10; the ABCs live in
    # `collections.abc` on Python 3.3+.
    from collections import abc as collections_abc
  except ImportError:  # Python 2 exposes the ABCs directly on `collections`.
    collections_abc = collections
  py_typecheck.check_type(arg_names, collections_abc.Iterable)
  # Materialize so `len()` and indexing below work for any iterable (e.g. a
  # generator), not only for sequences.
  arg_names = list(arg_names)
  if len(arg_names) < len(arg_types):
    raise ValueError(
        'If provided, arg_names must be a list of at least {} strings to '
        'match the number of positional arguments. Found: {}'.format(
            len(arg_types), arg_names))

  # Presumably returns a de-duplicated name on each call, so that distinct
  # TensorSpecs never share a name (see UniqueNameFn to confirm).
  get_unique_name = UniqueNameFn()

  def _get_one_typespec_and_binding(parameter_name, parameter_type):
    """Returns a (typespec, binding) pair for a single parameter.

    The typespec is a `tf.TensorSpec` for a tensor type, or an
    `anonymous_tuple.AnonymousTuple` of nested typespecs for a named tuple
    type; the binding mirrors that structure.
    """
    parameter_type = computation_types.to_type(parameter_type)
    if isinstance(parameter_type, computation_types.TensorType):
      name = get_unique_name(parameter_name)
      tf_spec = tf.TensorSpec(
          shape=parameter_type.shape, dtype=parameter_type.dtype, name=name)
      binding = pb.TensorFlow.Binding(
          tensor=pb.TensorFlow.TensorBinding(tensor_name=name))
      return (tf_spec, binding)
    elif isinstance(parameter_type, computation_types.NamedTupleType):
      element_typespec_pairs = []
      element_bindings = []
      have_names = False
      have_nones = False
      for e_name, e_type in anonymous_tuple.to_elements(parameter_type):
        if e_name is None:
          have_nones = True
        else:
          have_names = True
        # Nested names are built as '<parent>_<element>', skipping missing
        # parts; an empty result is passed down as None.
        name = '_'.join([n for n in [parameter_name, e_name] if n])
        e_typespec, e_binding = _get_one_typespec_and_binding(
            name if name else None, e_type)
        element_typespec_pairs.append((e_name, e_typespec))
        element_bindings.append(e_binding)
      # For a given argument or kwarg, we shouldn't have both:
      if have_names and have_nones:
        raise ValueError(
            'A mix of named and unnamed entries are not supported inside '
            'a nested structure representing a single argument in a call '
            'to a TensorFlow or Python function.\n' + str(parameter_type))
      tf_typespec = anonymous_tuple.AnonymousTuple(element_typespec_pairs)
      return (tf_typespec,
              pb.TensorFlow.Binding(
                  tuple=pb.TensorFlow.NamedTupleBinding(
                      element=element_bindings)))
    elif isinstance(parameter_type, computation_types.SequenceType):
      raise NotImplementedError('Sequence inputs not yet supported for TF 2.0.')
    else:
      raise ValueError('Parameter type component {} cannot be converted '
                       'to a TensorSpec'.format(repr(parameter_type)))

  def get_arg_name(i):
    """Returns the i-th positional argument name, validating it is a string."""
    name = arg_names[i]
    if not isinstance(name, six.string_types):
      raise ValueError('arg_names must be strings, but got ' + str(name))
    return name

  # Main logic --- process arg_types and kwarg_types:
  arg_typespecs = []
  kwarg_typespecs = {}
  bindings = []
  for i, arg_type in enumerate(arg_types):
    name = get_arg_name(i)
    typespec, binding = _get_one_typespec_and_binding(name, arg_type)
    typespec = type_utils.convert_to_py_container(typespec, arg_type)
    arg_typespecs.append(typespec)
    bindings.append(binding)
  for name, kwarg_type in six.iteritems(kwarg_types):
    typespec, binding = _get_one_typespec_and_binding(name, kwarg_type)
    typespec = type_utils.convert_to_py_container(typespec, kwarg_type)
    kwarg_typespecs[name] = typespec
    bindings.append(binding)

  # Internal invariant: without unpacking there is always exactly one binding.
  assert bindings, 'Given parameter_type {}, but produced no bindings.'.format(
      parameter_type)
  if pack_in_tuple:
    final_binding = pb.TensorFlow.Binding(
        tuple=pb.TensorFlow.NamedTupleBinding(element=bindings))
  else:
    final_binding = bindings[0]

  return (arg_typespecs, kwarg_typespecs, final_binding)