Ejemplo n.º 1
0
 def test_construct_generic_raises_federated_type(self):
     """Constructing an operator from a federated type raises TypeError.

     The argument is a CLIENTS-placed federated type wrapping a pair, and
     the error message is expected to mention that the argument is not a
     two-tuple.
     """
     int_pair = [tf.int32, tf.int32]
     float_vector = computation_types.TensorType(tf.float32, [2])
     federated_operands = computation_types.FederatedType(
         [int_pair, float_vector], placement_literals.CLIENTS)
     expected_message = 'argument that is not a two-tuple'
     with self.assertRaisesRegex(TypeError, expected_message):
         intrinsic_utils.create_binary_operator_with_upcast(
             federated_operands, tf.multiply)
Ejemplo n.º 2
0
 def test_raises_non_callable_op(self):
     """Both upcast helpers reject an operator that is not callable.

     A tensor (`tf.constant(0)`) stands in for the operator. The operand
     itself is a well-formed pair of float32 values, so the only thing
     wrong with each call is the operator.
     """
     operand_ref = computation_building_blocks.Reference(
         'x', [tf.float32, tf.float32])
     not_an_operator = tf.constant(0)
     with self.assertRaisesRegex(TypeError, 'non-callable'):
         intrinsic_utils.apply_binary_operator_with_upcast(
             operand_ref, not_an_operator)
     # The constructor is given a type everywhere else in these tests, so
     # hand it the reference's type signature rather than the building
     # block; otherwise the test depends on the callable check happening
     # before the first argument is inspected.
     with self.assertRaisesRegex(TypeError, 'non-callable'):
         intrinsic_utils.create_binary_operator_with_upcast(
             operand_ref.type_signature, not_an_operator)
Ejemplo n.º 3
0
 def test_raises_tuple_scalar_multiplied_by_nonscalar(self):
     """Multiplying an int32 scalar by a float32 vector raises TypeError.

     Checked both when applying the operator to a reference and when
     constructing the operator from the reference's type signature.
     """
     scalar_and_vector = [
         tf.int32,
         computation_types.TensorType(tf.float32, [2]),
     ]
     operand_ref = computation_building_blocks.Reference(
         'x', scalar_and_vector)
     expected_message = 'incompatible with upcasted'
     with self.assertRaisesRegex(TypeError, expected_message):
         intrinsic_utils.apply_binary_operator_with_upcast(
             operand_ref, tf.multiply)
     with self.assertRaisesRegex(TypeError, expected_message):
         operand_type = operand_ref.type_signature
         intrinsic_utils.create_binary_operator_with_upcast(
             operand_type, tf.multiply)
Ejemplo n.º 4
0
 def test_raises_incompatible_tuple_and_tensor(self):
     """Multiplying an int32 pair by a float32 tensor raises TypeError.

     The operand is a CLIENTS-placed federated pair. The apply helper is
     given the federated reference, while the constructor is given the
     member type of its signature.
     """
     int_pair = [tf.int32, tf.int32]
     federated_operands = computation_types.FederatedType(
         [int_pair, tf.float32], placement_literals.CLIENTS)
     operand_ref = computation_building_blocks.Reference(
         'x', federated_operands)
     expected_message = 'incompatible with upcasted'
     with self.assertRaisesRegex(TypeError, expected_message):
         intrinsic_utils.apply_binary_operator_with_upcast(
             operand_ref, tf.multiply)
     with self.assertRaisesRegex(TypeError, expected_message):
         member_type = operand_ref.type_signature.member
         intrinsic_utils.create_binary_operator_with_upcast(
             member_type, tf.multiply)
Ejemplo n.º 5
0
 def test_construct_integer_type_signature(self):
     """A multiply operator over an int32 pair is a binary op on int32."""
     int_pair_ref = computation_building_blocks.Reference(
         'x', [tf.int32, tf.int32])
     product_op = intrinsic_utils.create_binary_operator_with_upcast(
         int_pair_ref.type_signature, tf.multiply)
     actual_signature = product_op.type_signature
     expected_signature = type_constructors.binary_op(
         computation_types.to_type(tf.int32))
     self.assertEqual(actual_signature, expected_signature)
Ejemplo n.º 6
0
 def test_construct_divide_op_named_tuple_with_scalar_type_signature(self):
     """A divide operator over (named tuple, scalar) returns the tuple type.

     The expected signature is a function from the full operand pair to
     the type of its first element.
     """
     named_floats = [('a', tf.float32), ('b', tf.float32)]
     operand_type = computation_types.to_type([named_floats, tf.float32])
     divider = intrinsic_utils.create_binary_operator_with_upcast(
         operand_type, tf.divide)
     result_type = operand_type[0]
     expected_signature = computation_types.FunctionType(
         operand_type, result_type)
     self.assertEqual(divider.type_signature, expected_signature)
Ejemplo n.º 7
0
 def federated_sum(x):
   """Reduces `x` with a generic zero constant and an addition operator.

   Args:
     x: A value whose `type_signature` exposes a `member` type
       (presumably a federated value; confirm against the caller).

   Returns:
     Whatever `federated_reduce` returns for `[x, zero, plus_op]`.
   """
   member_type = x.type_signature.member
   # Both the zero constant and the addition operator are built as
   # building blocks first, then wrapped as values in the current context.
   zero_block = intrinsic_utils.create_generic_constant(member_type, 0)
   zero = value_impl.ValueImpl(zero_block, context_stack)
   operand_pair_type = computation_types.NamedTupleType(
       [member_type, member_type])
   add_block = intrinsic_utils.create_binary_operator_with_upcast(
       operand_pair_type, tf.add)
   plus_op = value_impl.ValueImpl(add_block, context_stack)
   return federated_reduce([x, zero, plus_op])
Ejemplo n.º 8
0
 def test_construct_op_raises_on_none_operator(self):
     """Constructing an operator with `None` as the operator raises."""
     missing_operator = None
     expected_message = 'found non-callable'
     with self.assertRaisesRegex(TypeError, expected_message):
         intrinsic_utils.create_binary_operator_with_upcast(
             tf.int32, missing_operator)