Exemple #1
0
 def test_infer_type_with_dict(self):
     """Both insertion orders of a plain dict infer to `<a=int32,b=float32>`."""
     for value in ({'a': 1, 'b': 2.0}, {'b': 2.0, 'a': 1}):
         inferred = type_utils.infer_type(value)
         self.assertEqual(str(inferred), '<a=int32,b=float32>')
            def _unpack_and_call(fn, arg_types, kwarg_types, arg):
                """An interceptor function that unpacks 'arg' before calling `fn`.

                The function verifies the actual parameters before it forwards the
                call as a last-minute check.

                Args:
                  fn: The function or defun to invoke.
                  arg_types: The list of positional argument types (guaranteed to
                    all be instances of computation_types.Types).
                  kwarg_types: The dictionary of keyword argument types (guaranteed
                    to all be instances of computation_types.Types).
                  arg: The argument to unpack, either an
                    `anonymous_tuple.AnonymousTuple` or a `value_base.Value`.
                    Positional elements are read by index, keyword elements by
                    attribute name.

                Returns:
                  The result of invoking `fn` on the unpacked arguments.

                Raises:
                  TypeError: if types don't match.
                """
                py_typecheck.check_type(
                    arg, (anonymous_tuple.AnonymousTuple, value_base.Value))
                args = []
                for idx, expected_type in enumerate(arg_types):
                    element_value = arg[idx]
                    actual_type = type_utils.infer_type(element_value)
                    if not type_utils.is_assignable_from(
                            expected_type, actual_type):
                        raise TypeError(
                            'Expected element at position {} to be '
                            'of type {}, found {}.'.format(
                                idx, str(expected_type), str(actual_type)))
                    # Gate the conversion with the same
                    # `is_anon_tuple_with_py_container` check the keyword path
                    # below uses, so positional and keyword arguments are
                    # converted back to Python containers consistently.
                    if type_utils.is_anon_tuple_with_py_container(
                            element_value, expected_type):
                        element_value = type_utils.convert_to_py_container(
                            element_value, expected_type)
                    args.append(element_value)
                kwargs = {}
                for name, expected_type in six.iteritems(kwarg_types):
                    element_value = getattr(arg, name)
                    actual_type = type_utils.infer_type(element_value)
                    if not type_utils.is_assignable_from(
                            expected_type, actual_type):
                        raise TypeError('Expected element named {} to be '
                                        'of type {}, found {}.'.format(
                                            name, str(expected_type),
                                            str(actual_type)))
                    if type_utils.is_anon_tuple_with_py_container(
                            element_value, expected_type):
                        element_value = type_utils.convert_to_py_container(
                            element_value, expected_type)
                    kwargs[name] = element_value
                return fn(*args, **kwargs)
Exemple #3
0
 def test_infer_type_with_dict_dataset(self):
     """Datasets built from either dict insertion order share one element type."""
     expected = '<a=int32,b=int32>*'
     for elements in ({'a': 10, 'b': 20}, {'b': 20, 'a': 10}):
         dataset = tf.data.Dataset.from_tensors(elements)
         self.assertEqual(str(type_utils.infer_type(dataset)), expected)
Exemple #4
0
 def _call(func, parameter_type, arg):
     """Invokes `func` on `arg` once its inferred type fits `parameter_type`.

     Raises:
       TypeError: If the type inferred for `arg` is not assignable to
         `parameter_type`.
     """
     inferred = type_utils.infer_type(arg)
     if type_utils.is_assignable_from(parameter_type, inferred):
         return func(arg)
     raise TypeError(
         'Expected an argument of type {}, found {}.'.format(
             str(parameter_type), str(inferred)))
Exemple #5
0
 def test_infer_type_with_tff_value(self):
     """A TFF value wrapping a `tf.bool` reference infers as `bool`."""
     reference = computation_building_blocks.Reference('foo', tf.bool)
     value = value_impl.ValueImpl(reference, context_stack_impl.context_stack)
     inferred = type_utils.infer_type(value)
     self.assertEqual(str(inferred), 'bool')
Exemple #6
0
    def __call__(self, *args, **kwargs):
        """Invokes this polymorphic function with a given set of arguments.

        The arguments are packed into a single anonymous tuple. A concrete
        function for the inferred argument type is looked up under the `repr`
        of that type, built by the factory on a cache miss, and then called.

        Args:
          *args: Positional args.
          **kwargs: Keyword args.

        Returns:
          The result of calling a concrete function, instantiated on demand based
          on the argument types (and cached for future calls).

        Raises:
          TypeError: if the concrete functions created by the factory are of the
            wrong computation_types, i.e. their parameter type differs from the
            inferred argument type.
        """
        # TODO(b/113112885): We may need to normalize individual args, such that
        # the type is more predictable and uniform (e.g., if someone supplies an
        # unordered dictionary), possibly by converting dict-like and tuple-like
        # containers into anonymous tuples.
        packed_arg = pack_args_into_anonymous_tuple(args, kwargs)
        arg_type = type_utils.infer_type(packed_arg)
        key = repr(arg_type)
        concrete_fn = self._concrete_function_cache.get(key)
        # The cache lookup yields None on a miss; test for that explicitly so a
        # cached function object is never rebuilt merely because it is falsy.
        if concrete_fn is None:
            concrete_fn = self._concrete_function_factory(arg_type)
            py_typecheck.check_type(concrete_fn, ConcreteFunction,
                                    'concrete function')
            if concrete_fn.type_signature.parameter != arg_type:
                raise TypeError(
                    'Expected a concrete function that takes parameter {}, got one '
                    'that takes {}.'.format(
                        arg_type, concrete_fn.type_signature.parameter))
            self._concrete_function_cache[key] = concrete_fn
        return concrete_fn(packed_arg)
Exemple #7
0
 def test_infer_type_with_nested_dataset_list_tuple(self):
     """A tuple of 1-tuples of datasets infers as nested sequence types."""
     nested = tuple(
         (tf.data.Dataset.from_tensors(x),) for x in [1, True, [0.5]])
     inferred = type_utils.infer_type(nested)
     self.assertEqual(str(inferred), '<<int32*>,<bool*>,<float32[1]*>>')
 def test_infer_type_with_ordered_dict(self):
   """An OrderedDict keeps its key order and records its container type."""
   value = collections.OrderedDict([('b', 2.0), ('a', 1)])
   inferred = type_utils.infer_type(value)
   container_type_cls = computation_types.NamedTupleTypeWithPyContainerType
   self.assertEqual(str(inferred), '<b=float32,a=int32>')
   self.assertIsInstance(inferred, container_type_cls)
   self.assertIs(
       container_type_cls.get_container_type(inferred),
       collections.OrderedDict)
Exemple #9
0
 def test_infer_type_with_int_list(self):
     """A list of ints infers as an unnamed tuple type that remembers `list`."""
     inferred = type_utils.infer_type([1, 2, 3])
     container_type_cls = computation_types.NamedTupleTypeWithPyContainerType
     self.assertEqual(str(inferred), '<int32,int32,int32>')
     self.assertIsInstance(inferred, container_type_cls)
     self.assertIs(container_type_cls.get_container_type(inferred), list)
Exemple #10
0
 def test_infer_type_with_nested_float_list(self):
     """A list of single-float lists infers as nested tuples, outer one a `list`."""
     inferred = type_utils.infer_type([[0.1], [0.2], [0.3]])
     container_type_cls = computation_types.NamedTupleTypeWithPyContainerType
     self.assertEqual(str(inferred), '<<float32>,<float32>,<float32>>')
     self.assertIsInstance(inferred, container_type_cls)
     self.assertIs(container_type_cls.get_container_type(inferred), list)
Exemple #11
0
 def test_infer_type_with_anonymous_tuple(self):
     """Named and unnamed entries of an AnonymousTuple both appear in its type."""
     value = anonymous_tuple.AnonymousTuple([('a', 10), (None, False)])
     self.assertEqual(str(type_utils.infer_type(value)), '<a=int32,bool>')
Exemple #12
0
 def _call(fn, parameter_type, arg):
   """Type-checks `arg` against `parameter_type`, then invokes `fn` on it.

   When `type_utils.is_anon_tuple_with_py_container` holds for `arg` and
   `parameter_type`, `arg` is first passed through
   `type_utils.convert_to_py_container`.

   Raises:
     TypeError: If the type inferred for `arg` is not assignable to
       `parameter_type`.
   """
   arg_type = type_utils.infer_type(arg)
   if not type_utils.is_assignable_from(parameter_type, arg_type):
     message = 'Expected an argument of type {}, found {}.'.format(
         parameter_type, arg_type)
     raise TypeError(message)
   needs_container = type_utils.is_anon_tuple_with_py_container(
       arg, parameter_type)
   converted = (
       type_utils.convert_to_py_container(arg, parameter_type)
       if needs_container else arg)
   return fn(converted)
Exemple #13
0
 def test_infer_type_with_namedtuple(self):
     """A namedtuple keeps its field order and records its own class."""
     named_tuple_cls = collections.namedtuple('TestNamedTuple', 'y x')
     instance = named_tuple_cls(1, True)
     inferred = type_utils.infer_type(instance)
     container_type_cls = computation_types.NamedTupleTypeWithPyContainerType
     self.assertEqual(str(inferred), '<y=int32,x=bool>')
     self.assertIsInstance(inferred, container_type_cls)
     self.assertIs(
         container_type_cls.get_container_type(inferred), named_tuple_cls)
Exemple #14
0
 def test_infer_type_with_ordered_dict_dataset(self):
     """A dataset of OrderedDicts keeps insertion order in its element type."""
     elements = collections.OrderedDict()
     elements['b'] = 20
     elements['a'] = 10
     dataset = tf.data.Dataset.from_tensors(elements)
     self.assertEqual(
         str(type_utils.infer_type(dataset)), '<b=int32,a=int32>*')
Exemple #15
0
 def test_infer_type_with_nested_dataset_list_tuple(self):
   """A tuple of dataset 1-tuples infers nested sequences and remembers `tuple`."""
   pairs = [(tf.data.Dataset.from_tensors(x),) for x in [1, True, [0.5]]]
   inferred = type_utils.infer_type(tuple(pairs))
   container_type_cls = computation_types.NamedTupleTypeWithPyContainerType
   self.assertEqual(str(inferred), '<<int32*>,<bool*>,<float32[1]*>>')
   self.assertIsInstance(inferred, container_type_cls)
   self.assertIs(container_type_cls.get_container_type(inferred), tuple)
Exemple #16
0
 def test_infer_type_with_dataset_list(self):
     """A list of datasets infers as a tuple of sequences that remembers `list`."""
     datasets = list(map(tf.data.Dataset.from_tensors, [1, True, [0.5]]))
     inferred = type_utils.infer_type(datasets)
     container_type_cls = computation_types.NamedTupleTypeWithPyContainerType
     self.assertEqual(str(inferred), '<int32*,bool*,float32[1]*>')
     self.assertIsInstance(inferred, container_type_cls)
     self.assertIs(container_type_cls.get_container_type(inferred), list)
 def _sequence_sum(self, arg):
   """Sums the elements of a sequence-typed computed value.

   Folds `self._generic_plus` over `arg.value`, starting from
   `self._generic_zero` of the declared element type. An empty sequence
   therefore returns that zero instead of failing.

   Args:
     arg: A computed value whose `type_signature` is a
       `computation_types.SequenceType` and whose `value` is iterable.

   Returns:
     The running total produced by the last `self._generic_plus` call, or the
     zero value when the sequence is empty.

   Raises:
     Whatever `py_typecheck.check_type` raises (presumably `TypeError`) when
     `arg.type_signature` is not a `SequenceType`.
   """
   # Validate before touching the elements. Derive the zero from the declared
   # element type (the same type every addition below uses) rather than
   # inferring it from `arg.value[0]`, which fails on an empty sequence.
   py_typecheck.check_type(arg.type_signature, computation_types.SequenceType)
   element_type = arg.type_signature.element
   total = self._generic_zero(element_type)
   for v in arg.value:
     total = self._generic_plus(
         ComputedValue(
             anonymous_tuple.AnonymousTuple([(None, total.value), (None, v)]),
             [element_type, element_type]))
   return total
Exemple #18
0
 def test_infer_type_with_anonymous_tuple(self):
     """An AnonymousTuple infers a plain NamedTupleType with no Python container."""
     value = anonymous_tuple.AnonymousTuple([('a', 10), (None, False)])
     inferred = type_utils.infer_type(value)
     self.assertEqual(str(inferred), '<a=int32,bool>')
     self.assertIsInstance(inferred, computation_types.NamedTupleType)
     self.assertNotIsInstance(
         inferred, computation_types.NamedTupleTypeWithPyContainerType)
Exemple #19
0
 def test_infer_type_with_nested_anonymous_tuple(self):
     """An AnonymousTuple nested inside another infers as a nested tuple type."""
     inner = anonymous_tuple.AnonymousTuple([(None, True), (None, 0.5)])
     outer = anonymous_tuple.AnonymousTuple([('a', 10), (None, inner)])
     self.assertEqual(
         str(type_utils.infer_type(outer)), '<a=int32,<bool,float32>>')
Exemple #20
0
 def test_infer_type_with_dataset_of_named_tuple(self):
     """A dataset mapped to namedtuples records the namedtuple as element container."""
     named_tuple_cls = collections.namedtuple('_', 'A B')
     source = tf.data.Dataset.from_tensor_slices({'x': [0.0], 'y': [1]})
     dataset = source.map(lambda v: named_tuple_cls(v['x'], v['y']))
     inferred = type_utils.infer_type(dataset)
     self.assertEqual(str(inferred), '<A=float32,B=int32>*')
     element_type = inferred.element
     container_type_cls = computation_types.NamedTupleTypeWithPyContainerType
     self.assertIsInstance(element_type, container_type_cls)
     self.assertIs(
         container_type_cls.get_container_type(element_type), named_tuple_cls)
Exemple #21
0
    def test_infer_type_with_dict(self):
        """Either insertion order of a dict infers the same type, remembering `dict`."""
        container_type_cls = computation_types.NamedTupleTypeWithPyContainerType
        for value in ({'a': 1, 'b': 2.0}, {'b': 2.0, 'a': 1}):
            inferred_type = type_utils.infer_type(value)
            self.assertEqual(str(inferred_type), '<a=int32,b=float32>')
            self.assertIsInstance(inferred_type, container_type_cls)
            self.assertIs(
                container_type_cls.get_container_type(inferred_type), dict)
Exemple #22
0
def is_signature_compatible_with_types(signature: inspect.Signature, *args,
                                       **kwargs) -> bool:
    """Determines if functions matching `signature` accept `args` and `kwargs`.

    Args:
      signature: An instance of `inspect.Signature` to verify against the
        arguments.
      *args: Zero or more positional arguments, all of which must be instances
        of computation_types.Type or something convertible to it by
        computation_types.to_type().
      **kwargs: Zero or more keyword arguments, all of which must be instances
        of computation_types.Type or something convertible to it by
        computation_types.to_type().

    Returns:
      `True` or `False`, depending on the outcome of the test.

    Raises:
      TypeError: if the arguments are of the wrong computation_types.
    """
    try:
        bound = signature.bind(*args, **kwargs)
    except TypeError:
        return False

    params = list(signature.parameters.values())
    # With no defaults at all, a successful `bind` is already conclusive.
    if not any(param.default is not inspect.Parameter.empty
               for param in params):
        return True

    # Otherwise, each explicitly supplied argument that overrides a non-None
    # default must have a type the default's inferred type is assignable to.
    for param in params:
        default = param.default
        if default is inspect.Parameter.empty or default is None:
            continue
        supplied = bound.arguments.get(param.name, default)
        if supplied is default:
            continue
        if not type_utils.is_assignable_from(
                computation_types.to_type(supplied),
                type_utils.infer_type(default)):
            return False
    return True
Exemple #23
0
def _wrap_sequence_as_value(elements, element_type, context_stack):
    """Wraps `elements` as a TFF sequence whose elements have type `element_type`.

    Args:
      elements: Python object to be wrapped as a TFF sequence value; currently
        it must be a `list`.
      element_type: An instance of `Type` that determines the type of elements
        of the sequence.
      context_stack: The context stack to use.

    Returns:
      An instance of `tff.Value`.

    Raises:
      TypeError: If `elements` and `element_type` are of incompatible types.
    """
    # TODO(b/113116813): Add support for other representations of sequences.
    py_typecheck.check_type(elements, list)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

    # Every element must be assignable to the element type requested for the
    # sequence as a whole; reject at the first mismatch.
    for item in elements:
        item_type = type_utils.infer_type(item)
        if type_utils.is_assignable_from(element_type, item_type):
            continue
        raise TypeError(
            'Expected all sequence elements to be {}, found {}.'.format(
                str(element_type), str(item_type)))

    # No-arg builder for a `tf.data.Dataset` holding the elements.
    def _create_dataset_from_elements():
        graph = tf.get_default_graph()
        return graph_utils.make_data_set_from_elements(graph, elements,
                                                       element_type)

    # Back the dataset with a no-argument TensorFlow computation and call it.
    serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        _create_dataset_from_elements, None, context_stack)
    call = computation_building_blocks.Call(
        computation_building_blocks.CompiledComputation(serialized))
    return ValueImpl(call, context_stack)
Exemple #24
0
def is_argspec_compatible_with_types(argspec, *args, **kwargs):
    """Determines if functions matching 'argspec' accept given 'args'/'kwargs'.

    Args:
      argspec: An instance of `SimpleArgSpec` to verify against the arguments.
      *args: Zero or more positional arguments, all of which must be instances
        of computation_types.Type or something convertible to it by
        computation_types.to_type().
      **kwargs: Zero or more keyword arguments, all of which must be instances
        of computation_types.Type or something convertible to it by
        computation_types.to_type().

    Returns:
      True or false, depending on the outcome of the test.

    Raises:
      TypeError: if the arguments are of the wrong computation_types.
    """
    try:
        callargs = get_callargs_for_argspec(argspec, *args, **kwargs)
    except TypeError:
        return False
    # A successful binding with no defaults leaves nothing else to verify.
    if not argspec.defaults:
        return True

    # Defaults belong to the trailing parameters. Any explicitly supplied value
    # replacing a non-None default must accept the default's inferred type.
    defaults = argspec.defaults
    first_default_index = len(argspec.args) - len(defaults)
    for offset, default_value in enumerate(defaults):
        if default_value is None:
            continue
        call_arg = callargs[argspec.args[first_default_index + offset]]
        if call_arg is default_value:
            continue
        arg_type = computation_types.to_type(call_arg)
        default_type = type_utils.infer_type(default_value)
        if not type_utils.is_assignable_from(arg_type, default_type):
            return False
    return True
Exemple #25
0
    def __call__(self, *args, **kwargs):
        """Calls the concrete function matching the types of the given arguments.

        Args:
          *args: Positional args.
          **kwargs: Keyword args.

        Returns:
          The result of calling a concrete function, instantiated on demand based
          on the argument types (and cached for future calls).

        Raises:
          TypeError: if the concrete functions created by the factory are of the
            wrong computation_types.
        """
        # TODO(b/113112885): We may need to normalize individual args, such that
        # the type is more predictable and uniform (e.g., if someone supplies an
        # unordered dictionary), possibly by converting dict-like and tuple-like
        # containers into anonymous tuples.
        packed = pack_args_into_anonymous_tuple(args, kwargs)
        packed_type = type_utils.infer_type(packed)
        # The arguments were packed above, so unpacking is forced.
        target = self.fn_for_argument_type(packed_type, unpack=True)
        return target(packed)
Exemple #26
0
 def test_infer_type_with_scalar_int_variable_tensor(self):
     """A scalar int variable infers as `int32`."""
     variable = tf.Variable(10)
     self.assertEqual(str(type_utils.infer_type(variable)), 'int32')
Exemple #27
0
 def test_infer_type_with_int_dataset(self):
     """A dataset of a single int infers as the sequence type `int32*`."""
     dataset = tf.data.Dataset.from_tensors(10)
     inferred = type_utils.infer_type(dataset)
     self.assertEqual(str(inferred), 'int32*')
Exemple #28
0
 def test_infer_type_with_scalar_int_array_variable_tensor(self):
     """A one-element int list variable infers as `int32[1]`."""
     variable = tf.Variable([10])
     self.assertEqual(str(type_utils.infer_type(variable)), 'int32[1]')
Exemple #29
0
 def test_infer_type_with_scalar_float_variable_tensor(self):
     """A scalar float variable infers as `float32`."""
     variable = tf.Variable(0.5)
     self.assertEqual(str(type_utils.infer_type(variable)), 'float32')
Exemple #30
0
 def test_infer_type_with_scalar_bool_variable_tensor(self):
     """A scalar bool variable infers as `bool`."""
     variable = tf.Variable(True)
     self.assertEqual(str(type_utils.infer_type(variable)), 'bool')