Ejemplo n.º 1
0
 def testComplexStruct(self):
     """Two specs built from the same nested structure must agree.

     The structure mixes tuple-keyed dicts, nested dicts and a tuple value.
     The resulting specs must be equal and each a subtype of the other.
     """
     nested = {(1, 2, 3): {(1, 2): {12: 2}}, (3, 2, 3): (2, {2: 3})}
     first = function_trace_type.get_arg_spec(nested, False, False, True)
     second = function_trace_type.get_arg_spec(nested, False, False, True)

     self.assertEqual(first, second)
     # Subtyping is checked in both directions.
     for candidate, other in ((first, second), (second, first)):
         self.assertTrue(candidate.is_subtype_of(other))
Ejemplo n.º 2
0
    def testCustomUnequableTypeSucceeds(self):
        """Checks trace types built from objects whose `__eq__` always raises.

        Traces of the same live object compare equal. Comparing traces of two
        different objects propagates the `ValueError`. The traces no longer
        compare equal once the local references to the objects are deleted.
        """

        # Hashable (constant hash), but any `==` comparison raises.
        class CustomUnequable:
            def __hash__(self):
                return 0

            def __eq__(self, other):
                raise ValueError

        first_obj = CustomUnequable()
        second_obj = CustomUnequable()

        first_trace = function_trace_type.get_arg_spec(
            first_obj, False, True, True)
        first_trace_again = function_trace_type.get_arg_spec(
            first_obj, False, True, True)
        second_trace = function_trace_type.get_arg_spec(
            second_obj, False, True, True)

        self.assertEqual(first_trace, first_trace_again)

        # `__eq__` is invoked directly so the ValueError reaches the test.
        with self.assertRaises(ValueError):
            first_trace.__eq__(second_trace)

        # Dropping the only local reference to the first object.
        del first_obj
        self.assertNotEqual(first_trace, first_trace_again)
        self.assertNotEqual(first_trace_again, first_trace)

        # Dropping the second object must not change that outcome.
        del second_obj
        self.assertNotEqual(first_trace, first_trace_again)
        self.assertNotEqual(first_trace_again, first_trace)
  def testCompositeAndSpec(self):
    """A ragged tensor and a `RaggedTensorSpec` must yield equal specs."""
    ragged = ragged_tensor.RaggedTensor.from_row_splits(
        values=[1, 2, 3], row_splits=[0, 2, 3])
    ragged_spec = ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32)

    from_value = function_trace_type.get_arg_spec(
        ragged, False, False, True)
    from_spec = function_trace_type.get_arg_spec(
        ragged_spec, False, False, True)
    self.assertEqual(from_value, from_spec)
  def testDictEquality(self):
    """Dict specs differ when values differ and ignore insertion order."""

    def spec_of(mapping):
      return function_trace_type.get_arg_spec(mapping, False, False, True)

    base = spec_of({1: 2, 3: 4})
    changed_value = spec_of({1: 2, 3: 2})
    zero_value = spec_of({1: 2, 3: 0})
    reordered = spec_of({3: 4, 1: 2})

    # Any difference in a value must make the specs unequal.
    for left, right in ((base, changed_value),
                        (base, zero_value),
                        (changed_value, zero_value)):
      self.assertNotEqual(left, right)
    # Same items inserted in a different order compare equal.
    self.assertEqual(base, reordered)
  def testListEquality(self):
    """List specs differ on element or length changes, else compare equal."""
    original = function_trace_type.get_arg_spec(
        [1, 2, 3, 4], False, False, True)
    different_element = function_trace_type.get_arg_spec(
        [1, 2, 2, 4], False, False, True)
    shorter = function_trace_type.get_arg_spec(
        [1, 2, 3], False, False, True)
    identical = function_trace_type.get_arg_spec(
        [1, 2, 3, 4], False, False, True)

    self.assertNotEqual(original, different_element)
    self.assertNotEqual(original, shorter)
    self.assertNotEqual(different_element, shorter)
    # A separately built list with the same contents yields an equal spec.
    self.assertEqual(original, identical)
Ejemplo n.º 6
0
    def testCustomUnhashableTypeFailsGracefully(self):
        """An unhashable custom object must raise `InvalidArgumentError`."""

        # Defining `__eq__` without `__hash__` makes instances unhashable.
        class CustomUnhashable:
            def __eq__(self, other):
                return True

        unhashable = CustomUnhashable()
        expected_message = (
            r'could not be represented through the generic tracing type')
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    expected_message):
            function_trace_type.get_arg_spec(unhashable, False, True, True)
  def testCustomUnhashableTypeFailsGracefully(self):
    """An unhashable custom object must raise `InvalidArgumentError`."""

    # Defining `__eq__` without `__hash__` makes instances unhashable.
    class CustomUnhashable:

      def __eq__(self, other):
        return True

    unhashable = CustomUnhashable()
    expected_message = r'Could not determine tracing type of generic object'
    with self.assertRaisesRegex(errors.InvalidArgumentError, expected_message):
      function_trace_type.get_arg_spec(unhashable, False, True, True)
Ejemplo n.º 8
0
def make_cache_key(
    args, include_tensor_ranks_only: bool = False
) -> FunctionCacheKey:
    """Build the function cache key for a set of call arguments.

    Args:
        args: The function arguments to encode.
        include_tensor_ranks_only: Forwarded unchanged to
            `function_trace_type.get_arg_spec`.

    Returns:
        A `FunctionCacheKey` pairing the argument spec with the value
        returned by `_make_execution_context()`.
    """
    trace_type = function_trace_type.get_arg_spec(
        args,
        include_tensor_ranks_only,
        _ENCODE_VARIABLES_BY_RESOURCE_ID,
        USE_FULL_TRACE_TYPE,
    )
    execution_context = _make_execution_context()
    return FunctionCacheKey(trace_type, execution_context)
  def testVariableAliasing(self):
    """Specs of variable tuples must reflect aliasing between the entries.

    Three distinct variables and one variable repeated three times must give
    different specs. A second set of variables arranged with the same
    aliasing pattern must give specs equal to the first set's.
    """
    var_a = resource_variable_ops.ResourceVariable([1])
    var_b = resource_variable_ops.ResourceVariable([1])
    var_c = resource_variable_ops.ResourceVariable([1])
    distinct = function_trace_type.get_arg_spec(
        (var_a, var_b, var_c), False, True, True)
    repeated = function_trace_type.get_arg_spec(
        (var_a, var_a, var_a), False, True, True)
    self.assertNotEqual(distinct, repeated)

    # `var_c` is rebound to a fresh variable here, which releases this test's
    # reference to the earlier one.
    var_c = resource_variable_ops.ResourceVariable([2])
    var_d = resource_variable_ops.ResourceVariable([2])
    var_e = resource_variable_ops.ResourceVariable([2])
    distinct_again = function_trace_type.get_arg_spec(
        (var_c, var_d, var_e), False, True, True)
    repeated_again = function_trace_type.get_arg_spec(
        (var_d, var_d, var_d), False, True, True)
    self.assertEqual(distinct, distinct_again)
    self.assertEqual(repeated, repeated_again)
Ejemplo n.º 10
0
  def testAttrsCacheKeyGeneration(self):
    """An attrs instance must encode as an `AttrsType` of its field values.

    Skipped when the `attr` module is unavailable.
    """
    if attr is None:
      self.skipTest('attr module is unavailable.')

    actual = function_trace_type.get_arg_spec(
        TestAttrsClass(1, 2), False, False, True)

    # Each field value is expected to be wrapped in a `GenericType`.
    field_types = (function_trace_type.GenericType(1),
                   function_trace_type.GenericType(2))
    expected = function_trace_type.AttrsType(TestAttrsClass, field_types)

    self.assertEqual(actual, expected)
    # A spec must be a subtype of itself.
    self.assertTrue(actual.is_subtype_of(actual))
Ejemplo n.º 11
0
def make_cache_key_from_args(args, kwargs, include_tensor_ranks_only=False):
    """Build a `CacheKey` from call inputs plus the current context.

    Args:
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.
        include_tensor_ranks_only: Forwarded unchanged to
            `function_trace_type.get_arg_spec`.

    Returns:
        A `CacheKey` made of the spec of `(args, kwargs)` followed by the six
        values returned by `_cache_key_context()`.
    """
    arg_signature = function_trace_type.get_arg_spec(
        (args, kwargs),
        include_tensor_ranks_only,
        ENCODE_VARIABLES_BY_RESOURCE_ID,
        USE_FULL_TRACE_TYPE,
    )

    # Unpacked explicitly so a change in the context's arity fails here.
    context = _cache_key_context()
    (graph, device_fns, colocations, cross_replica, var_policy,
     xla_id) = context

    return CacheKey(
        arg_signature,
        graph,
        device_fns,
        colocations,
        cross_replica,
        var_policy,
        xla_id,
    )
Ejemplo n.º 12
0
    def testIteratorAliasing(self):
        """Specs of iterator pairs must depend only on the aliasing pattern.

        A repeated iterator matches another repeated iterator, two distinct
        iterators match in either order, and a repeated pair differs from a
        distinct pair.
        """
        first_iter = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))
        second_iter = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))

        def spec_for(*iterators):
            # `iterators` is the tuple handed to `get_arg_spec`.
            return function_trace_type.get_arg_spec(
                iterators, False, False, True)

        self.assertEqual(spec_for(first_iter, first_iter),
                         spec_for(second_iter, second_iter))
        self.assertEqual(spec_for(first_iter, second_iter),
                         spec_for(second_iter, first_iter))
        self.assertNotEqual(spec_for(first_iter, first_iter),
                            spec_for(first_iter, second_iter))
Ejemplo n.º 13
0
 def testGeneric(self):
     """Encodes an int and a `DummyGenericClass` instance; results unused."""
     function_trace_type.get_arg_spec(
         1, False, True, True)
     function_trace_type.get_arg_spec(
         DummyGenericClass(), False, True, True)
 def encode_variables(var_list):
     """Encode `var_list` with `get_arg_spec`; the result is discarded."""
     function_trace_type.get_arg_spec(
         var_list, False, False, True)
 def encode_tensor_specs(tensor_specs):
     """Encode `tensor_specs` with `get_arg_spec`; the result is discarded."""
     function_trace_type.get_arg_spec(
         tensor_specs, False, False, True)
Ejemplo n.º 16
0
 def encode_model(model):
     """Encode `model` with `get_arg_spec`; the result is discarded.

     The last flag is taken from `function.USE_FULL_TRACE_TYPE`.
     """
     function_trace_type.get_arg_spec(
         model, False, False, function.USE_FULL_TRACE_TYPE)
Ejemplo n.º 17
0
 def encode_struct(struct):
     """Encode `struct` with `get_arg_spec`; the result is discarded.

     The last flag is taken from `function.USE_FULL_TRACE_TYPE`.
     """
     function_trace_type.get_arg_spec(
         struct, False, False, function.USE_FULL_TRACE_TYPE)
Ejemplo n.º 18
0
 def encode_tensor_specs(tensor_specs):
     """Encode `tensor_specs` with `get_arg_spec`; the result is discarded.

     The last flag is taken from `function.USE_FULL_TRACE_TYPE`.
     """
     function_trace_type.get_arg_spec(
         tensor_specs, False, False, function.USE_FULL_TRACE_TYPE)
Ejemplo n.º 19
0
 def encode_variables(var_list):
     """Encode `var_list` with `get_arg_spec`; the result is discarded.

     The last flag is taken from `function.USE_FULL_TRACE_TYPE`.
     """
     function_trace_type.get_arg_spec(
         var_list, False, False, function.USE_FULL_TRACE_TYPE)
Ejemplo n.º 20
0
 def testList(self):
     """Encodes a three-element list of ints; the result is unused."""
     function_trace_type.get_arg_spec(
         [1, 2, 3], False, True, True)
Ejemplo n.º 21
0
    def testAttrs(self):
        """Encodes a `TestAttrsClass` instance; the result is unused.

        Skipped when the `attr` module is unavailable.
        """
        if attr is None:
            self.skipTest('attr module is unavailable.')

        function_trace_type.get_arg_spec(
            TestAttrsClass(1, 2), False, True, True)
Ejemplo n.º 22
0
 def testDict(self):
     """Encodes a three-item int-to-int dict; the result is unused."""
     function_trace_type.get_arg_spec(
         {1: 1, 2: 2, 3: 3}, False, True, True)
Ejemplo n.º 23
0
 def testTuple(self):
     """Encodes a three-element tuple of ints; the result is unused."""
     function_trace_type.get_arg_spec(
         (1, 2, 3), False, True, True)
Ejemplo n.º 24
0
 def testTensor(self):
     """Encodes a tensor built by `array_ops.zeros([10])`; result unused."""
     zeros = array_ops.zeros([10])
     function_trace_type.get_arg_spec(
         zeros, False, True, True)
 def encode_model(model):
     """Encode `model` with `get_arg_spec`; the result is discarded."""
     function_trace_type.get_arg_spec(
         model, False, False, True)
 def encode_struct(struct):
     """Encode `struct` with `get_arg_spec`; the result is discarded."""
     function_trace_type.get_arg_spec(
         struct, False, False, True)