예제 #1
0
 def test_raises_type_error(self):
     """Each string-rendering helper rejects a `None` computation."""
     cbb = computation_building_blocks
     representation_fns = (
         cbb.compact_representation,
         cbb.formatted_representation,
         cbb.structural_representation,
     )
     for representation_fn in representation_fns:
         with self.assertRaises(TypeError):
             representation_fn(None)
 def test_basic_functionality_of_tuple_class(self):
     """Exercises `Tuple` construction, element access, renderings and proto."""
     cbb = computation_building_blocks
     foo_ref = cbb.Reference('foo', tf.int32)
     bar_ref = cbb.Reference('bar', tf.bool)
     tup = cbb.Tuple([foo_ref, ('y', bar_ref)])
     # An element with an empty-string name is rejected.
     with self.assertRaises(ValueError):
         _ = cbb.Tuple([('', bar_ref)])
     self.assertIsInstance(tup, anonymous_tuple.AnonymousTuple)
     self.assertEqual(str(tup.type_signature), '<int32,y=bool>')
     expected_repr = (
         'Tuple([(None, Reference(\'foo\', TensorType(tf.int32))), (\'y\', '
         'Reference(\'bar\', TensorType(tf.bool)))])')
     self.assertEqual(repr(tup), expected_repr)
     self.assertEqual(cbb.compact_representation(tup), '<foo,y=bar>')
     # `dir` lists only the named element.
     self.assertEqual(dir(tup), ['y'])
     self.assertIs(tup.y, bar_ref)
     self.assertLen(tup, 2)
     self.assertIs(tup[0], foo_ref)
     self.assertIs(tup[1], bar_ref)
     element_strings = [cbb.compact_representation(elem) for elem in tup]
     self.assertEqual(','.join(element_strings), 'foo,bar')

     tup_proto = tup.proto
     self.assertEqual(type_serialization.deserialize_type(tup_proto.type),
                      tup.type_signature)
     self.assertEqual(tup_proto.WhichOneof('computation'), 'tuple')
     element_names = [elem.name for elem in tup_proto.tuple.element]
     self.assertEqual(element_names, ['', 'y'])
     self._serialize_deserialize_roundtrip_test(tup)
    def _serialize_deserialize_roundtrip_test(self, target):
        """Performs roundtrip serialization/deserialization of the given target.

        Serializes `target` to its proto and rebuilds a building block from that
        proto with `ComputationBuildingBlock.from_proto`. It then serializes the
        rebuilt block again and checks that both the compact representations
        and the proto text of the two blocks agree.

        Args:
          target: An instance of ComputationBuildingBlock to
            serialize-deserialize.

        Raises:
          AssertionError: If `target` is not a `ComputationBuildingBlock`, or if
            the roundtrip does not reproduce the same representation and proto.
        """
        # Use the unittest assertion instead of a bare `assert`: a bare `assert`
        # is stripped under `python -O`, which would silently skip this check.
        self.assertIsInstance(
            target, computation_building_blocks.ComputationBuildingBlock)
        proto = target.proto
        target2 = computation_building_blocks.ComputationBuildingBlock.from_proto(
            proto)
        proto2 = target2.proto
        self.assertEqual(
            computation_building_blocks.compact_representation(target),
            computation_building_blocks.compact_representation(target2))
        self.assertEqual(str(proto), str(proto2))
 def test_intrinsic_class_succeeds_simple_federated_map(self):
     """Builds a concretely typed `federated_map` intrinsic and inspects it."""
     cbb = computation_building_blocks
     mapped_fn_type = computation_types.FunctionType(tf.int32, tf.float32)
     arg_at_clients = computation_types.FederatedType(
         mapped_fn_type.parameter, placements.CLIENTS)
     result_at_clients = computation_types.FederatedType(
         mapped_fn_type.result, placements.CLIENTS)
     map_type = computation_types.FunctionType(
         [mapped_fn_type, arg_at_clients], result_at_clients)
     federated_map = cbb.Intrinsic(intrinsic_defs.FEDERATED_MAP.uri, map_type)

     self.assertIsInstance(federated_map, cbb.Intrinsic)
     self.assertEqual(
         str(federated_map.type_signature),
         '(<(int32 -> float32),{int32}@CLIENTS> -> {float32}@CLIENTS)')
     self.assertEqual(federated_map.uri, 'federated_map')
     self.assertEqual(cbb.compact_representation(federated_map),
                      'federated_map')

     map_proto = federated_map.proto
     self.assertEqual(type_serialization.deserialize_type(map_proto.type),
                      federated_map.type_signature)
     self.assertEqual(map_proto.WhichOneof('computation'), 'intrinsic')
     self.assertEqual(map_proto.intrinsic.uri, federated_map.uri)
     self._serialize_deserialize_roundtrip_test(federated_map)
예제 #5
0
 def test_returns_string_for_federated_aggregate(self):
     """Checks all three string renderings of a called `federated_aggregate`."""
     cbb = computation_building_blocks
     called_aggregate = (
         computation_test_utils.create_dummy_called_federated_aggregate(
             accumulate_parameter_name='a',
             merge_parameter_name='b',
             report_parameter_name='c'))
     # pyformat: disable
     expectations = (
         (cbb.compact_representation,
          'federated_aggregate(<data,data,(a -> data),(b -> data),(c -> data)>)'),
         (cbb.formatted_representation,
          'federated_aggregate(<\n'
          '  data,\n'
          '  data,\n'
          '  (a -> data),\n'
          '  (b -> data),\n'
          '  (c -> data)\n'
          '>)'),
         (cbb.structural_representation,
          '                    Call\n'
          '                   /    \\\n'
          'federated_aggregate      Tuple\n'
          '                         |\n'
          '                         [data, data, Lambda(a), Lambda(b), Lambda(c)]\n'
          '                                      |          |          |\n'
          '                                      data       data       data'),
     )
     # pyformat: enable
     for render, expected in expectations:
         self.assertEqual(render(called_aggregate), expected)
예제 #6
0
 def test_returns_string_for_federated_map(self):
     """Checks all three string renderings of a called `federated_map`."""
     cbb = computation_building_blocks
     called_map = computation_test_utils.create_dummy_called_federated_map(
         parameter_name='a')
     # pyformat: disable
     expectations = (
         (cbb.compact_representation, 'federated_map(<(a -> a),data>)'),
         (cbb.formatted_representation,
          'federated_map(<\n'
          '  (a -> a),\n'
          '  data\n'
          '>)'),
         (cbb.structural_representation,
          '              Call\n'
          '             /    \\\n'
          'federated_map      Tuple\n'
          '                   |\n'
          '                   [Lambda(a), data]\n'
          '                    |\n'
          '                    Ref(a)'),
     )
     # pyformat: enable
     for render, expected in expectations:
         self.assertEqual(render(called_map), expected)
예제 #7
0
    def test_returns_string_for_comp_with_left_overhang(self):
        """Renders a call whose structural diagram overhangs to the left."""
        cbb = computation_building_blocks
        fn = cbb.Reference(
            'a', computation_types.FunctionType(tf.int32, tf.int32))
        proto, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
            lambda: tf.constant(1), None, context_stack_impl.context_stack)
        inner_call = cbb.Call(cbb.CompiledComputation(proto, 'bbbbb'))
        outer_call = cbb.Call(fn, inner_call)

        # pyformat: disable
        expectations = (
            (cbb.compact_representation, 'a(comp#bbbbb())'),
            (cbb.formatted_representation, 'a(comp#bbbbb())'),
            (cbb.structural_representation,
             '           Call\n'
             '          /    \\\n'
             '    Ref(a)      Call\n'
             '               /\n'
             'Compiled(bbbbb)'),
        )
        # pyformat: enable
        for render, expected in expectations:
            self.assertEqual(render(outer_call), expected)
 def test_basic_functionality_of_call_class(self):
     """Exercises `Call` construction, type checking, renderings and proto."""
     cbb = computation_building_blocks
     fn_ref = cbb.Reference(
         'foo', computation_types.FunctionType(tf.int32, tf.bool))
     arg_ref = cbb.Reference('bar', tf.int32)
     call = cbb.Call(fn_ref, arg_ref)
     self.assertEqual(str(call.type_signature), 'bool')
     self.assertIs(call.function, fn_ref)
     self.assertIs(call.argument, arg_ref)
     expected_repr = (
         'Call(Reference(\'foo\', '
         'FunctionType(TensorType(tf.int32), TensorType(tf.bool))), '
         'Reference(\'bar\', TensorType(tf.int32)))')
     self.assertEqual(repr(call), expected_repr)
     self.assertEqual(cbb.compact_representation(call), 'foo(bar)')
     # Omitting the argument of a function that takes an int32 is rejected.
     with self.assertRaises(TypeError):
         cbb.Call(fn_ref)
     # An argument whose type does not match the parameter type is rejected.
     float_ref = cbb.Reference('bak', tf.float32)
     with self.assertRaises(TypeError):
         cbb.Call(fn_ref, float_ref)

     call_proto = call.proto
     self.assertEqual(type_serialization.deserialize_type(call_proto.type),
                      call.type_signature)
     self.assertEqual(call_proto.WhichOneof('computation'), 'call')
     self.assertEqual(str(call_proto.call.function), str(fn_ref.proto))
     self.assertEqual(str(call_proto.call.argument), str(arg_ref.proto))
     self._serialize_deserialize_roundtrip_test(call)
 def test_basic_functionality_of_selection_class(self):
     """Exercises name- and index-based `Selection` over a named tuple.

     Covers the selected type signatures, `repr`, and compact representations.
     Also covers rejection of an unknown name and of out-of-range (including
     negative) indices, the serialized proto of a name-based selection, and a
     serialization roundtrip for every selection built here.
     """
     x = computation_building_blocks.Reference('foo', [('bar', tf.int32),
                                                       ('baz', tf.bool)])
     y = computation_building_blocks.Selection(x, name='bar')
     self.assertEqual(y.name, 'bar')
     # `assertIsNone` tests identity with `None` (the idiomatic check) and
     # reports the actual value on failure.
     self.assertIsNone(y.index)
     self.assertEqual(str(y.type_signature), 'int32')
     self.assertEqual(
         repr(y), 'Selection(Reference(\'foo\', NamedTupleType(['
         '(\'bar\', TensorType(tf.int32)), (\'baz\', TensorType(tf.bool))]))'
         ', name=\'bar\')')
     self.assertEqual(computation_building_blocks.compact_representation(y),
                      'foo.bar')
     z = computation_building_blocks.Selection(x, name='baz')
     self.assertEqual(str(z.type_signature), 'bool')
     self.assertEqual(computation_building_blocks.compact_representation(z),
                      'foo.baz')
     # Selecting a name that the tuple does not have is rejected.
     with self.assertRaises(ValueError):
         _ = computation_building_blocks.Selection(x, name='bak')
     x0 = computation_building_blocks.Selection(x, index=0)
     self.assertIsNone(x0.name)
     self.assertEqual(x0.index, 0)
     self.assertEqual(str(x0.type_signature), 'int32')
     self.assertEqual(
         repr(x0), 'Selection(Reference(\'foo\', NamedTupleType(['
         '(\'bar\', TensorType(tf.int32)), (\'baz\', TensorType(tf.bool))]))'
         ', index=0)')
     self.assertEqual(
         computation_building_blocks.compact_representation(x0), 'foo[0]')
     x1 = computation_building_blocks.Selection(x, index=1)
     self.assertEqual(str(x1.type_signature), 'bool')
     self.assertEqual(
         computation_building_blocks.compact_representation(x1), 'foo[1]')
     # An index past the end, and a negative index, are both rejected.
     with self.assertRaises(ValueError):
         _ = computation_building_blocks.Selection(x, index=2)
     with self.assertRaises(ValueError):
         _ = computation_building_blocks.Selection(x, index=-1)
     y_proto = y.proto
     self.assertEqual(type_serialization.deserialize_type(y_proto.type),
                      y.type_signature)
     self.assertEqual(y_proto.WhichOneof('computation'), 'selection')
     self.assertEqual(str(y_proto.selection.source), str(x.proto))
     self.assertEqual(y_proto.selection.name, 'bar')
     self._serialize_deserialize_roundtrip_test(y)
     self._serialize_deserialize_roundtrip_test(z)
     self._serialize_deserialize_roundtrip_test(x0)
     self._serialize_deserialize_roundtrip_test(x1)
예제 #10
0
 def _check_whitelisted(comp):
     """Rejects any `Intrinsic` whose URI is not in `uri_whitelist`.

     `uri_whitelist` is resolved from the enclosing scope.

     Args:
       comp: The computation building block to inspect.

     Returns:
       The pair `(comp, False)`, with `comp` returned unchanged. The `False`
       presumably signals "not modified" to a transform driver; confirm
       against the caller.

     Raises:
       ValueError: If `comp` is an `Intrinsic` whose `uri` is not in
         `uri_whitelist`.
     """
     if not isinstance(comp, computation_building_blocks.Intrinsic):
         return comp, False
     if comp.uri in uri_whitelist:
         return comp, False
     raise ValueError(
         'Encountered an Intrinsic not currently reducible to aggregate or '
         'broadcast, the intrinsic {}'.format(
             computation_building_blocks.compact_representation(comp)))
 def test_basic_functionality_of_compiled_computation_class(self):
     """Exercises `CompiledComputation` type, proto, repr and naming.

     Without an explicit name, the compact representation is `comp#` followed
     by hex digits. With `name='foo'` it is exactly `comp#foo`.
     """
     comp, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
         lambda x: x + 3, tf.int32, context_stack_impl.context_stack)
     x = computation_building_blocks.CompiledComputation(comp)
     self.assertEqual(str(x.type_signature), '(int32 -> int32)')
     self.assertEqual(str(x.proto), str(comp))
     # `assertRegex` uses `re.search`, so the leading `^` keeps the original
     # anchored-at-start (`re.match`) semantics. On failure it reports the
     # text and pattern instead of the opaque `assertTrue(None)`.
     self.assertRegex(
         repr(x),
         r'^CompiledComputation\([0-9a-f]+, '
         r'FunctionType\(TensorType\(tf\.int32\), '
         r'TensorType\(tf\.int32\)\)\)')
     self.assertRegex(
         computation_building_blocks.compact_representation(x),
         r'^comp#[0-9a-f]+')
     y = computation_building_blocks.CompiledComputation(comp, name='foo')
     self.assertEqual(computation_building_blocks.compact_representation(y),
                      'comp#foo')
     self._serialize_deserialize_roundtrip_test(x)
예제 #12
0
 def _check_single_placement(comp):
     """Checks that a federated `comp` is placed at `single_placement`.

     `single_placement` is resolved from the enclosing scope. A `comp` whose
     type is not a `FederatedType` always passes.

     Args:
       comp: The computation building block to inspect.

     Returns:
       The pair `(comp, False)`, with `comp` returned unchanged.

     Raises:
       ValueError: If `comp.type_signature` is a `FederatedType` whose
         placement differs from `single_placement`.
     """
     comp_type = comp.type_signature
     placed_elsewhere = (
         isinstance(comp_type, computation_types.FederatedType)
         and comp_type.placement != single_placement)
     if placed_elsewhere:
         raise ValueError(
             'Comp contains a placement other than {}; '
             'placement {} on comp {} inside the structure. '.format(
                 single_placement, comp_type.placement,
                 computation_building_blocks.compact_representation(comp)))
     return comp, False
예제 #13
0
 def test_returns_string_for_intrinsic(self):
     """All three renderings of a bare `Intrinsic` are its URI."""
     cbb = computation_building_blocks
     intrinsic = cbb.Intrinsic('intrinsic', tf.int32)
     for render in (cbb.compact_representation,
                    cbb.formatted_representation,
                    cbb.structural_representation):
         self.assertEqual(render(intrinsic), 'intrinsic')
예제 #14
0
 def test_returns_string_for_reference(self):
     """Checks the compact, formatted and structural renderings of a ref."""
     cbb = computation_building_blocks
     ref = cbb.Reference('a', tf.int32)
     expectations = (
         (cbb.compact_representation, 'a'),
         (cbb.formatted_representation, 'a'),
         (cbb.structural_representation, 'Ref(a)'),
     )
     for render, expected in expectations:
         self.assertEqual(render(ref), expected)
예제 #15
0
 def test_returns_string_for_placement(self):
     """Checks the three renderings of a `CLIENTS` placement."""
     cbb = computation_building_blocks
     clients = cbb.Placement(placements.CLIENTS)
     expectations = (
         (cbb.compact_representation, 'CLIENTS'),
         (cbb.formatted_representation, 'CLIENTS'),
         (cbb.structural_representation, 'Placement'),
     )
     for render, expected in expectations:
         self.assertEqual(render(clients), expected)
예제 #16
0
    def test_raises_with_federated_mean(self):
        """Checks that a `federated_mean` intrinsic raises `ValueError`.

        Runs `check_intrinsics_whitelisted_for_reduction` on the intrinsic and
        requires the error message to contain the intrinsic's compact
        representation.
        """
        import re  # pylint: disable=g-import-not-at-top
        intrinsic = computation_building_blocks.Intrinsic(
            intrinsic_defs.FEDERATED_MEAN.uri,
            computation_types.FunctionType(
                computation_types.FederatedType(tf.int32, placements.CLIENTS),
                computation_types.FederatedType(tf.int32, placements.SERVER)))

        # `assertRaisesRegex` treats its second argument as a regular
        # expression. Escape the representation so that characters such as
        # `(`, `.`, `<` or `[` are matched literally.
        expected_message = re.escape(
            computation_building_blocks.compact_representation(intrinsic))
        with self.assertRaisesRegex(ValueError, expected_message):
            tree_analysis.check_intrinsics_whitelisted_for_reduction(intrinsic)
예제 #17
0
 def test_returns_string_for_compiled_computation(self):
     """Checks the three renderings of a `CompiledComputation` named 'a'."""
     cbb = computation_building_blocks
     proto, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
         lambda: tf.constant(1), None, context_stack_impl.context_stack)
     compiled = cbb.CompiledComputation(proto, 'a')
     expectations = (
         (cbb.compact_representation, 'comp#a'),
         (cbb.formatted_representation, 'comp#a'),
         (cbb.structural_representation, 'Compiled(a)'),
     )
     for render, expected in expectations:
         self.assertEqual(render(compiled), expected)
 def test_basic_functionality_of_reference_class(self):
     """Exercises `Reference` accessors, renderings and proto."""
     cbb = computation_building_blocks
     ref = cbb.Reference('foo', tf.int32)
     self.assertEqual(ref.name, 'foo')
     self.assertEqual(str(ref.type_signature), 'int32')
     self.assertEqual(repr(ref), 'Reference(\'foo\', TensorType(tf.int32))')
     self.assertEqual(cbb.compact_representation(ref), 'foo')

     ref_proto = ref.proto
     self.assertEqual(type_serialization.deserialize_type(ref_proto.type),
                      ref.type_signature)
     self.assertEqual(ref_proto.WhichOneof('computation'), 'reference')
     self.assertEqual(ref_proto.reference.name, ref.name)
     self._serialize_deserialize_roundtrip_test(ref)
예제 #19
0
 def __eq__(self, other):
     """Base class equality checks names and values equal.

     Trackers are equal when their names match and their values match. Two
     `ComputationBuildingBlock` values match when their compact
     representations are equal and their type signatures are equivalent.
     Any other values match only if they are the same object. Comparison
     with a non-`BoundVariableTracker` returns `NotImplemented`.
     """
     # TODO(b/130890785): Delegate value-checking to
     # `computation_building_blocks.ComputationBuildingBlock`.
     if self is other:
         return True
     if not isinstance(other, BoundVariableTracker):
         return NotImplemented
     if self.name != other.name:
         return False
     block_cls = computation_building_blocks.ComputationBuildingBlock
     mine, theirs = self.value, other.value
     if not (isinstance(mine, block_cls) and isinstance(theirs, block_cls)):
         return mine is theirs
     same_text = (
         computation_building_blocks.compact_representation(mine) ==
         computation_building_blocks.compact_representation(theirs))
     return same_text and type_utils.are_equivalent_types(
         mine.type_signature, theirs.type_signature)
 def test_basic_functionality_of_placement_class(self):
     """Exercises `Placement` type, URI, renderings and proto."""
     cbb = computation_building_blocks
     clients = cbb.Placement(placements.CLIENTS)
     self.assertEqual(str(clients.type_signature), 'placement')
     self.assertEqual(clients.uri, 'clients')
     self.assertEqual(repr(clients), 'Placement(\'clients\')')
     self.assertEqual(cbb.compact_representation(clients), 'CLIENTS')

     clients_proto = clients.proto
     self.assertEqual(
         type_serialization.deserialize_type(clients_proto.type),
         clients.type_signature)
     self.assertEqual(clients_proto.WhichOneof('computation'), 'placement')
     self.assertEqual(clients_proto.placement.uri, clients.uri)
     self._serialize_deserialize_roundtrip_test(clients)
 def test_basic_functionality_of_lambda_class(self):
     """Exercises a `Lambda` whose body computes `arg.f(arg.f(arg.x))`."""
     cbb = computation_building_blocks
     param_name = 'arg'
     param_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
                   ('x', tf.int32)]
     param_ref = cbb.Reference(param_name, param_type)
     select_f = cbb.Selection(param_ref, name='f')
     select_x = cbb.Selection(param_ref, name='x')
     body = cbb.Call(select_f, cbb.Call(select_f, select_x))
     fn = cbb.Lambda(param_name, param_type, body)

     self.assertEqual(str(fn.type_signature),
                      '(<f=(int32 -> int32),x=int32> -> int32)')
     self.assertEqual(fn.parameter_name, param_name)
     self.assertEqual(str(fn.parameter_type), '<f=(int32 -> int32),x=int32>')
     self.assertEqual(cbb.compact_representation(fn.result),
                      'arg.f(arg.f(arg.x))')
     param_type_repr = (
         'NamedTupleType(['
         '(\'f\', FunctionType(TensorType(tf.int32), TensorType(tf.int32))), '
         '(\'x\', TensorType(tf.int32))])')
     expected_repr = (
         'Lambda(\'arg\', {0}, '
         'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
         'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
         'Selection(Reference(\'arg\', {0}), name=\'x\'))))').format(
             param_type_repr)
     self.assertEqual(repr(fn), expected_repr)
     self.assertEqual(cbb.compact_representation(fn),
                      '(arg -> arg.f(arg.f(arg.x)))')

     fn_proto = fn.proto
     self.assertEqual(type_serialization.deserialize_type(fn_proto.type),
                      fn.type_signature)
     self.assertEqual(fn_proto.WhichOneof('computation'), 'lambda')
     # `lambda` is a Python keyword, so the proto field needs `getattr`.
     lambda_proto = getattr(fn_proto, 'lambda')
     self.assertEqual(lambda_proto.parameter_name, param_name)
     self.assertEqual(str(lambda_proto.result), str(fn.result.proto))
     self._serialize_deserialize_roundtrip_test(fn)
 def test_basic_intrinsic_functionality_plus_canonical_typecheck(self):
     """Exercises a `generic_plus` intrinsic over a pair of int32s."""
     cbb = computation_building_blocks
     plus_type = computation_types.FunctionType([tf.int32, tf.int32], tf.int32)
     plus = cbb.Intrinsic('generic_plus', plus_type)
     self.assertEqual(str(plus.type_signature), '(<int32,int32> -> int32)')
     self.assertEqual(plus.uri, 'generic_plus')
     self.assertEqual(cbb.compact_representation(plus), 'generic_plus')

     plus_proto = plus.proto
     self.assertEqual(type_serialization.deserialize_type(plus_proto.type),
                      plus.type_signature)
     self.assertEqual(plus_proto.WhichOneof('computation'), 'intrinsic')
     self.assertEqual(plus_proto.intrinsic.uri, plus.uri)
     self._serialize_deserialize_roundtrip_test(plus)
예제 #23
0
 def test_returns_string_for_selection_with_index(self):
     """Checks the three renderings of an index-based `Selection`."""
     cbb = computation_building_blocks
     ref = cbb.Reference('a', (('b', tf.int32), ('c', tf.bool)))
     first_element = cbb.Selection(ref, index=0)
     # pyformat: disable
     expectations = (
         (cbb.compact_representation, 'a[0]'),
         (cbb.formatted_representation, 'a[0]'),
         (cbb.structural_representation, 'Sel(0)\n' '|\n' 'Ref(a)'),
     )
     # pyformat: enable
     for render, expected in expectations:
         self.assertEqual(render(first_element), expected)
예제 #24
0
 def test_returns_string_for_lambda(self):
     """Checks the three renderings of the identity `Lambda` on `a`."""
     cbb = computation_building_blocks
     ref = cbb.Reference('a', tf.int32)
     identity = cbb.Lambda(ref.name, ref.type_signature, ref)
     # pyformat: disable
     expectations = (
         (cbb.compact_representation, '(a -> a)'),
         (cbb.formatted_representation, '(a -> a)'),
         (cbb.structural_representation, 'Lambda(a)\n' '|\n' 'Ref(a)'),
     )
     # pyformat: enable
     for render, expected in expectations:
         self.assertEqual(render(identity), expected)
예제 #25
0
 def test_returns_string_for_tuple_with_no_names(self):
     """Checks the three renderings of an unnamed two-element `Tuple`."""
     cbb = computation_building_blocks
     data = cbb.Data('data', tf.int32)
     unnamed_pair = cbb.Tuple((data, data))
     # pyformat: disable
     expectations = (
         (cbb.compact_representation, '<data,data>'),
         (cbb.formatted_representation, '<\n' '  data,\n' '  data\n' '>'),
         (cbb.structural_representation, 'Tuple\n' '|\n' '[data, data]'),
     )
     # pyformat: enable
     for render, expected in expectations:
         self.assertEqual(render(unnamed_pair), expected)
 def test_basic_functionality_of_block_class(self):
     """Exercises `Block` locals, result, renderings, proto and roundtrip.

     The block binds `x=arg` and `y=x[0]`, then returns `y`.
     """
     x = computation_building_blocks.Block([
         ('x',
          computation_building_blocks.Reference('arg',
                                                (tf.int32, tf.int32))),
         ('y',
          computation_building_blocks.Selection(
              computation_building_blocks.Reference('x',
                                                    (tf.int32, tf.int32)),
              index=0))
     ], computation_building_blocks.Reference('y', tf.int32))
     self.assertEqual(str(x.type_signature), 'int32')
     self.assertEqual(
         [(k, computation_building_blocks.compact_representation(v))
          for k, v in x.locals], [('x', 'arg'), ('y', 'x[0]')])
     self.assertEqual(
         computation_building_blocks.compact_representation(x.result), 'y')
     self.assertEqual(
         repr(x), 'Block([(\'x\', Reference(\'arg\', '
         'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)]))), '
         '(\'y\', Selection(Reference(\'x\', '
         'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)])), '
         'index=0))], '
         'Reference(\'y\', TensorType(tf.int32)))')
     self.assertEqual(computation_building_blocks.compact_representation(x),
                      '(let x=arg,y=x[0] in y)')
     x_proto = x.proto
     self.assertEqual(type_serialization.deserialize_type(x_proto.type),
                      x.type_signature)
     self.assertEqual(x_proto.WhichOneof('computation'), 'block')
     self.assertEqual(str(x_proto.block.result), str(x.result.proto))
     # Check the counts first so `zip` cannot silently drop unmatched locals.
     self.assertLen(x_proto.block.local, len(x.locals))
     for loc_proto, (loc_name, loc_value) in zip(x_proto.block.local,
                                                 x.locals):
         self.assertEqual(loc_proto.name, loc_name)
         self.assertEqual(str(loc_proto.value), str(loc_value.proto))
     # The roundtrip covers the whole block. Run it once, after the per-local
     # checks; it used to be mis-indented inside the loop, which ran it once
     # per local and never for a block with no locals.
     self._serialize_deserialize_roundtrip_test(x)
 def test_basic_functionality_of_intrinsic_class(self):
     """Exercises an `add_one` intrinsic of type `(int32 -> int32)`."""
     cbb = computation_building_blocks
     add_one = cbb.Intrinsic(
         'add_one', computation_types.FunctionType(tf.int32, tf.int32))
     self.assertEqual(str(add_one.type_signature), '(int32 -> int32)')
     self.assertEqual(add_one.uri, 'add_one')
     expected_repr = (
         'Intrinsic(\'add_one\', '
         'FunctionType(TensorType(tf.int32), TensorType(tf.int32)))')
     self.assertEqual(repr(add_one), expected_repr)
     self.assertEqual(cbb.compact_representation(add_one), 'add_one')

     add_one_proto = add_one.proto
     self.assertEqual(
         type_serialization.deserialize_type(add_one_proto.type),
         add_one.type_signature)
     self.assertEqual(add_one_proto.WhichOneof('computation'), 'intrinsic')
     self.assertEqual(add_one_proto.intrinsic.uri, add_one.uri)
     self._serialize_deserialize_roundtrip_test(add_one)
 def test_basic_functionality_of_data_class(self):
     """Exercises a `Data` block over an int32 sequence."""
     cbb = computation_building_blocks
     data = cbb.Data('/tmp/mydata', computation_types.SequenceType(tf.int32))
     self.assertEqual(str(data.type_signature), 'int32*')
     self.assertEqual(data.uri, '/tmp/mydata')
     expected_repr = (
         'Data(\'/tmp/mydata\', SequenceType(TensorType(tf.int32)))')
     self.assertEqual(repr(data), expected_repr)
     self.assertEqual(cbb.compact_representation(data), '/tmp/mydata')

     data_proto = data.proto
     self.assertEqual(type_serialization.deserialize_type(data_proto.type),
                      data.type_signature)
     self.assertEqual(data_proto.WhichOneof('computation'), 'data')
     self.assertEqual(data_proto.data.uri, data.uri)
     self._serialize_deserialize_roundtrip_test(data)
예제 #29
0
    def test_get_curried(self):
        """Currying a two-argument TF add yields nested one-argument lambdas."""
        add_proto = computation_impl.ComputationImpl.get_proto(
            computations.tf_computation(tf.add, [tf.int32, tf.int32]))
        add_block = (
            computation_building_blocks.ComputationBuildingBlock.from_proto(
                add_proto))
        add_numbers = value_impl.ValueImpl(add_block, _context_stack)

        curried = value_utils.get_curried(add_numbers)
        self.assertEqual(str(curried.type_signature),
                         '(int32 -> (int32 -> int32))')

        curried_comp = value_impl.ValueImpl.get_comp(curried)
        # Uniquify compiled-computation names so the expected string can
        # refer to the compiled add as `comp#1`.
        uniquified, _ = transformations.uniquify_compiled_computation_names(
            curried_comp)
        self.assertEqual(
            computation_building_blocks.compact_representation(uniquified),
            '(arg0 -> (arg1 -> comp#1(<arg0,arg1>)))')
예제 #30
0
 def test_returns_string_for_call_with_arg(self):
     """Checks the renderings of the identity lambda applied to `data`."""
     cbb = computation_building_blocks
     ref = cbb.Reference('a', tf.int32)
     identity = cbb.Lambda(ref.name, ref.type_signature, ref)
     data = cbb.Data('data', tf.int32)
     applied = cbb.Call(identity, data)
     # pyformat: disable
     expectations = (
         (cbb.compact_representation, '(a -> a)(data)'),
         (cbb.formatted_representation, '(a -> a)(data)'),
         (cbb.structural_representation,
          '          Call\n'
          '         /    \\\n'
          'Lambda(a)      data\n'
          '|\n'
          'Ref(a)'),
     )
     # pyformat: enable
     for render, expected in expectations:
         self.assertEqual(render(applied), expected)