def test_slicing_tuple_values(self, sequence_type):
        """Slicing a tuple-typed value yields a value of the selected elements.

        Builds a 5-element value from `sequence_type(range(0, 50, 10))`, gives
        its compiled computations unique names, and checks both the type
        signature and the string form of several slices.
        """

        def _to_value(cbb):
            return value_impl.to_value(
                cbb, None, context_stack_impl.context_stack)

        initial_value = _to_value(sequence_type(range(0, 50, 10)))
        # The rename pass returns a pair; only the rewritten computation is used.
        renamed, _ = (
            transformations.replace_compiled_computations_names_with_unique_names(
                value_impl.ValueImpl.get_comp(initial_value)))
        value = _to_value(renamed)

        self.assertEqual(str(value.type_signature),
                         '<int32,int32,int32,int32,int32>')
        # A full slice must render exactly like the unsliced value.
        self.assertEqual(str(value[:]), str(value))

        # (slice, expected type signature, expected string form)
        expected_slices = (
            (slice(None, 2), '<int32,int32>', '<comp#1(),comp#2()>'),
            (slice(-3, None), '<int32,int32,int32>',
             '<comp#3(),comp#4(),comp#5()>'),
            (slice(None, None, 2), '<int32,int32,int32>',
             '<comp#1(),comp#3(),comp#5()>'),
        )
        for index, expected_type, expected_repr in expected_slices:
            part = value[index]
            self.assertEqual(str(part.type_signature), expected_type)
            self.assertEqual(str(part), expected_repr)
Example #2
0
    def test_replace_compiled_computations_names_does_not_replace_other_name(
            self):
        """A Reference building block keeps its original name after renaming."""
        reference = computation_building_blocks.Reference('name', tf.int32)

        # NOTE(review): other call sites unpack this transformation as
        # `(comp, _)`; confirm whether the targeted version returns a bare
        # building block (as assumed here) or a pair.
        result = (
            transformations.replace_compiled_computations_names_with_unique_names(
                reference))

        self.assertEqual(result._name, reference._name)
Example #3
0
    def test_replace_compiled_computations_names_replaces_name(self):
        """A CompiledComputation receives a different name from the rename pass."""
        serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
            lambda: tf.constant(1), None, context_stack_impl.context_stack)
        original = computation_building_blocks.CompiledComputation(serialized)

        renamed = (
            transformations.replace_compiled_computations_names_with_unique_names(
                original))

        self.assertNotEqual(renamed._name, original._name)
Example #4
0
    def test_get_curried(self):
        """get_curried turns a two-argument add into nested unary functions."""
        add_proto = computation_impl.ComputationImpl.get_proto(
            computations.tf_computation(tf.add, [tf.int32, tf.int32]))
        add_numbers = value_impl.ValueImpl(
            computation_building_blocks.ComputationBuildingBlock.from_proto(
                add_proto),
            _context_stack)

        curried = value_utils.get_curried(add_numbers)
        self.assertEqual(str(curried.type_signature),
                         '(int32 -> (int32 -> int32))')

        # Presumably renamed so the expected repr can refer to a stable
        # `comp#1` name; only the rewritten computation from the pair is used.
        renamed, _ = (
            transformations.replace_compiled_computations_names_with_unique_names(
                value_impl.ValueImpl.get_comp(curried)))
        self.assertEqual(renamed.tff_repr,
                         '(arg0 -> (arg1 -> comp#1(<arg0,arg1>)))')
Example #5
0
    def test_replace_compiled_computations_names_replaces_multiple_names(self):
        """Each compiled computation in a tuple gets a new, distinct name."""

        def _make_compiled_constant():
            # One freshly serialized `tf.constant(1)` wrapped as a building block.
            serialized = (
                tensorflow_serialization.serialize_py_fn_as_tf_computation(
                    lambda: tf.constant(1), None,
                    context_stack_impl.context_stack))
            return computation_building_blocks.CompiledComputation(serialized)

        comp = computation_building_blocks.Tuple(
            [_make_compiled_constant() for _ in range(10)])

        transformed_comp = (
            transformations.replace_compiled_computations_names_with_unique_names(
                comp))

        original_names = [element._name for element in comp]
        new_names = [element._name for element in transformed_comp]
        self.assertNotEqual(new_names, original_names)
        self.assertEqual(
            len(new_names), len(set(new_names)),
            'The transformed computation names are not unique: {}.'.format(
                new_names))
Example #6
0
 def test_replace_compiled_computations_names_raises_type_error(self):
     """Passing None instead of a building block raises TypeError."""
     self.assertRaises(
         TypeError,
         transformations.replace_compiled_computations_names_with_unique_names,
         None)