def test_getitem_call_unnamed(self, placement):
  """Checks federated getitem on an unnamed `<int32,bool>` federated tuple.

  Integer keys 0 and 1 must select the `int32` and `bool` members, and a
  reversing slice must yield a `<bool,int32>` member. Every result must still
  have a `FederatedType` signature whose placement is `placement`.
  """
  ctx_stack = context_stack_impl.context_stack
  getitem = computation_constructing_utils.construct_federated_getitem_call

  def assert_placement(comp):
    # Wrap the building block as a Value so its placement can be checked.
    type_utils.check_federated_value_placement(
        value_impl.to_value(comp, None, ctx_stack), placement)

  fed_tuple_type = computation_types.FederatedType(
      [tf.int32, tf.bool], placement, True)
  source = computation_building_blocks.Reference('test', fed_tuple_type)
  self.assertEqual(str(source.type_signature.member), '<int32,bool>')

  first = getitem(source, 0)
  second = getitem(source, 1)
  for selected in (first, second):
    self.assertIsInstance(selected.type_signature,
                          computation_types.FederatedType)
  self.assertEqual(str(first.type_signature.member), 'int32')
  self.assertEqual(str(second.type_signature.member), 'bool')
  assert_placement(first)
  assert_placement(second)

  # A step of -1 reverses the tuple's element order.
  reversed_tuple = getitem(source, slice(None, None, -1))
  self.assertIsInstance(reversed_tuple.type_signature,
                        computation_types.FederatedType)
  self.assertEqual(str(reversed_tuple.type_signature.member), '<bool,int32>')
  assert_placement(reversed_tuple)
def __getitem__(self, key):
  """Selects one element, or a slice of elements, from this value.

  A federated value whose member is a named tuple is delegated to
  `construct_federated_getitem_call`. Otherwise the value itself must be a
  named tuple.

  Args:
    key: An `int` index or a `slice`. Negative integer indices are rejected.

  Returns:
    A `ValueImpl` for an integer key or a federated value. For a slice on a
    plain named tuple, the result of `to_value` applied to the list of
    selected elements.

  Raises:
    TypeError: If `key` is neither `int` nor `slice`, or the value is not a
      (federated) named tuple.
    IndexError: If an integer key is out of range, or the slice selects no
      elements.
  """
  py_typecheck.check_type(key, (int, slice))
  comp = self._comp
  type_sig = comp.type_signature

  is_federated_tuple = (
      isinstance(type_sig, computation_types.FederatedType) and
      isinstance(type_sig.member, computation_types.NamedTupleType))
  if is_federated_tuple:
    getitem_call = (
        computation_constructing_utils.construct_federated_getitem_call(
            comp, key))
    return ValueImpl(getitem_call, self._context_stack)

  if not isinstance(type_sig, computation_types.NamedTupleType):
    raise TypeError(
        'Operator getitem() is only supported for named tuples, but the '
        'object on which it has been invoked is of type {}.'.format(
            str(type_sig)))

  num_elements = len(type_sig)

  if isinstance(key, slice):
    selected_indices = range(*key.indices(num_elements))
    if not selected_indices:
      raise IndexError('Attempted to slice 0 elements, which is not '
                       'currently supported.')
    return to_value([self[i] for i in selected_indices], None,
                    self._context_stack)

  # `key` is an int here, because check_type admitted only int or slice.
  if not 0 <= key < num_elements:
    raise IndexError(
        'The index of the selected element {} is out of range.'.format(key))
  # A literal Tuple block can hand back its element directly; any other
  # computation needs an explicit Selection node.
  if isinstance(comp, computation_building_blocks.Tuple):
    element = comp[key]
  else:
    element = computation_building_blocks.Selection(comp, index=key)
  return ValueImpl(element, self._context_stack)
def test_federated_getitem_call_fails_value(self):
  """construct_federated_getitem_call must reject a ValueImpl argument.

  The ValueImpl is built before entering `assertRaises`, so the expected
  TypeError can only come from the function under test. The previous version
  called `value_impl.to_value(x)` with a single argument inside the block.
  Other call sites pass `type_spec` and `context_stack` explicitly, so that
  call could itself raise TypeError and let the test pass vacuously.
  """
  x = computation_building_blocks.Reference(
      'x', computation_types.to_type([tf.int32]))
  x_value = value_impl.to_value(x, None, context_stack_impl.context_stack)
  with self.assertRaises(TypeError):
    computation_constructing_utils.construct_federated_getitem_call(
        x_value, 0)