def test_federated_setattr_call_fails_non_federated_type(self):
  """Checks that a non-federated named tuple is rejected with a TypeError."""
  tuple_type = computation_types.NamedTupleType(
      [('a', tf.int32), (None, tf.float32), ('b', tf.bool)])
  non_federated_comp = computation_building_blocks.Data('data', tuple_type)
  new_value = computation_building_blocks.Data('x', tf.int32)
  with self.assertRaises(TypeError):
    computation_constructing_utils.construct_federated_setattr_call(
        non_federated_comp, 'a', new_value)
def test_federated_setattr_call_fails_on_none_value(self):
  """Checks that passing `None` as the new element value raises a TypeError."""
  member_type = computation_types.NamedTupleType(
      [('a', tf.int32), (None, tf.float32), ('b', tf.bool)])
  clients_type = computation_types.FederatedType(
      member_type, placement_literals.CLIENTS)
  federated_data = computation_building_blocks.Data('data', clients_type)
  with self.assertRaises(TypeError):
    computation_constructing_utils.construct_federated_setattr_call(
        federated_data, 'a', None)
def test_federated_setattr_call_leaves_type_signatures_alone(self, placement):
  """Checks that the setattr call keeps the input's federated type signature.

  Args:
    placement: The placement at which the named tuple is federated;
      presumably supplied by a parameterization decorator not visible here.
  """
  member_type = computation_types.NamedTupleType(
      [('a', tf.int32), (None, tf.float32), ('b', tf.bool)])
  federated_type = computation_types.FederatedType(member_type, placement)
  original = computation_building_blocks.Data('federated_comp', federated_type)
  replacement_value = computation_building_blocks.Data('x', tf.int32)
  setattr_call = computation_constructing_utils.construct_federated_setattr_call(
      original, 'a', replacement_value)
  types_match = type_utils.are_equivalent_types(setattr_call.type_signature,
                                                original.type_signature)
  self.assertTrue(types_match)
def test_federated_setattr_call_constructs_correct_intrinsic_server(self):
  """Checks that a SERVER-placed setattr call invokes `federated_apply`."""
  member_type = computation_types.NamedTupleType(
      [('a', tf.int32), (None, tf.float32), ('b', tf.bool)])
  server_type = computation_types.FederatedType(member_type,
                                                placement_literals.SERVER)
  server_data = computation_building_blocks.Data('federated_comp', server_type)
  replacement_value = computation_building_blocks.Data('x', tf.int32)
  setattr_call = computation_constructing_utils.construct_federated_setattr_call(
      server_data, 'a', replacement_value)
  called_uri = setattr_call.function.uri
  self.assertEqual(called_uri, intrinsic_defs.FEDERATED_APPLY.uri)
def test_federated_setattr_call_constructs_correct_computation_server(self):
  """Checks the exact computation built for a SERVER-placed setattr call.

  Element `a` is replaced by the bound placeholder, while the unnamed element
  and element `b` are forwarded by index from the lambda argument.
  """
  member_type = computation_types.NamedTupleType(
      [('a', tf.int32), (None, tf.float32), ('b', tf.bool)])
  server_type = computation_types.FederatedType(member_type,
                                                placement_literals.SERVER)
  server_data = computation_building_blocks.Data('federated_comp', server_type)
  replacement_value = computation_building_blocks.Data('x', tf.int32)
  setattr_call = computation_constructing_utils.construct_federated_setattr_call(
      server_data, 'a', replacement_value)
  expected_repr = 'federated_apply(<(let value_comp_placeholder=x in (lambda_arg -> <a=value_comp_placeholder,lambda_arg[1],b=lambda_arg[2]>)),federated_comp>)'
  self.assertEqual(setattr_call.tff_repr, expected_repr)
def __setattr__(self, name, value):
  """Rebinds this value's computation with element `name` set to `value`.

  Two kinds of underlying computation are handled:
    * a federated value whose member is a named tuple, which is routed
      through `construct_federated_setattr_call`;
    * a plain named tuple, which is wrapped in a call to the lambda built by
      `construct_named_tuple_setattr_lambda`.

  Args:
    name: Name of the element to replace; validated as a string with
      `py_typecheck.check_type`.
    value: The new element value; converted via `to_value` and
      `ValueImpl.get_comp` before use.

  Raises:
    TypeError: If the underlying computation is neither a named tuple nor a
      federated value with a named-tuple member.
  """
  py_typecheck.check_type(name, six.string_types)
  value_comp = ValueImpl.get_comp(
      to_value(value, None, self._context_stack))
  comp_type = self._comp.type_signature
  is_federated_named_tuple = (
      isinstance(comp_type, computation_types.FederatedType) and
      isinstance(comp_type.member, computation_types.NamedTupleType))
  if is_federated_named_tuple:
    replacement = computation_constructing_utils.construct_federated_setattr_call(
        self._comp, name, value_comp)
  elif isinstance(comp_type, computation_types.NamedTupleType):
    setattr_fn = computation_constructing_utils.construct_named_tuple_setattr_lambda(
        comp_type, name, value_comp)
    replacement = computation_building_blocks.Call(setattr_fn, self._comp)
  else:
    raise TypeError(
        'Operator setattr() is only supported for named tuples, but the '
        'object on which it has been invoked is of type {}.'.format(
            str(comp_type)))
  # Store through the base class: a plain `self._comp = ...` would re-enter
  # this override.
  super(ValueImpl, self).__setattr__('_comp', replacement)
def test_federated_setattr_call_fails_on_none_federated_comp(self):
  """Checks that passing `None` as the federated computation raises TypeError."""
  new_value = computation_building_blocks.Data('x', tf.int32)
  with self.assertRaises(TypeError):
    computation_constructing_utils.construct_federated_setattr_call(
        None, 'a', new_value)