Beispiel #1
0
 def test_getitem_call_unnamed(self, placement):
     """Index and slice a federated unnamed tuple; placement must survive."""
     federated_ref = computation_building_blocks.Reference(
         'test',
         computation_types.FederatedType([tf.int32, tf.bool], placement,
                                         True))
     self.assertEqual(str(federated_ref.type_signature.member),
                      '<int32,bool>')
     getitem_call = (
         computation_constructing_utils.construct_federated_getitem_call)
     stack = context_stack_impl.context_stack

     # Select each member by integer index; member types are expected in
     # declaration order.
     selections = [getitem_call(federated_ref, i) for i in range(2)]
     for selection in selections:
         self.assertIsInstance(selection.type_signature,
                               computation_types.FederatedType)
     for selection, member_str in zip(selections, ('int32', 'bool')):
         self.assertEqual(str(selection.type_signature.member), member_str)
     for selection in selections:
         type_utils.check_federated_value_placement(
             value_impl.to_value(selection, None, stack), placement)

     # A reversing slice swaps the member order but keeps the placement.
     flipped = getitem_call(federated_ref, slice(None, None, -1))
     self.assertIsInstance(flipped.type_signature,
                           computation_types.FederatedType)
     self.assertEqual(str(flipped.type_signature.member), '<bool,int32>')
     type_utils.check_federated_value_placement(
         value_impl.to_value(flipped, None, stack), placement)
Beispiel #2
0
 def __getitem__(self, key):
   """Selects one element, or a slice of elements, from this value.

   A federated value whose member is a `NamedTupleType` is delegated to
   `construct_federated_getitem_call`. Otherwise the value itself must be a
   `NamedTupleType`.

   Args:
     key: An `int` index or a `slice`.

   Returns:
     A `ValueImpl` for the federated and integer-index cases. For a slice, the
     result of `to_value` on the list of individually selected elements.

   Raises:
     TypeError: If `py_typecheck.check_type` rejects `key` as neither int nor
       slice, or if the value is not a (federated) named tuple.
     IndexError: If an integer key falls outside `[0, len)`, or a slice
       selects no elements.
   """
   py_typecheck.check_type(key, (int, slice))
   comp = self._comp
   comp_type = comp.type_signature

   is_federated_tuple = (
       isinstance(comp_type, computation_types.FederatedType) and
       isinstance(comp_type.member, computation_types.NamedTupleType))
   if is_federated_tuple:
     federated_call = (
         computation_constructing_utils.construct_federated_getitem_call(
             comp, key))
     return ValueImpl(federated_call, self._context_stack)

   if not isinstance(comp_type, computation_types.NamedTupleType):
     raise TypeError(
         'Operator getitem() is only supported for named tuples, but the '
         'object on which it has been invoked is of type {}.'.format(
             str(comp_type)))

   num_elements = len(comp_type)

   if isinstance(key, slice):
     selected = range(*key.indices(num_elements))
     if not selected:
       raise IndexError('Attempted to slice 0 elements, which is not '
                        'currently supported.')
     # Each index goes back through the integer path below.
     return to_value([self[i] for i in selected], None, self._context_stack)

   # Only non-negative, in-range integer indices are accepted.
   if not 0 <= key < num_elements:
     raise IndexError(
         'The index of the selected element {} is out of range.'.format(key))
   if isinstance(comp, computation_building_blocks.Tuple):
     # A literal tuple can hand back its element directly.
     element = comp[key]
   else:
     element = computation_building_blocks.Selection(comp, index=key)
   return ValueImpl(element, self._context_stack)
Beispiel #3
0
 def test_federated_getitem_call_fails_value(self):
     """Passing a `ValueImpl` instead of a building block raises TypeError."""
     # Function-scope import keeps this test self-contained with respect to
     # the module-level imports.
     from tensorflow_federated.python.core.impl import context_stack_impl
     x = computation_building_blocks.Reference(
         'x', computation_types.to_type([tf.int32]))
     # Build the value OUTSIDE `assertRaises` so that only
     # `construct_federated_getitem_call` can satisfy the assertion. A
     # TypeError raised while building the value (e.g. from a malformed
     # `to_value` call) would otherwise make this test pass vacuously.
     x_value = value_impl.to_value(x, None, context_stack_impl.context_stack)
     with self.assertRaises(TypeError):
         computation_constructing_utils.construct_federated_getitem_call(
             x_value, 0)