def proto(self):
  """Serializes this selection into a `pb.Computation`.

  The selection is keyed by `self._name` when that attribute is set, and by
  `self._index` otherwise.

  Returns:
    A `pb.Computation` whose `type` is the serialized `self.type_signature`
    and whose `selection` field points at `self._source.proto`.
  """
  if self._name is None:
    selection = pb.Selection(source=self._source.proto, index=self._index)
  else:
    selection = pb.Selection(source=self._source.proto, name=self._name)
  serialized_type = type_serialization.serialize_type(self.type_signature)
  return pb.Computation(type=serialized_type, selection=selection)
def create_dummy_computation_selection():
  """Returns a selection computation and type.

  The selection picks element 0 out of the dummy tuple computation produced
  by `create_dummy_computation_tuple()`.

  Returns:
    A `(value, type_signature)` pair, where `value` is a `pb.Computation`
    wrapping a `pb.Selection` of index 0, and `type_signature` is the type
    of that first tuple element.
  """
  tuple_value, tuple_type = create_dummy_computation_tuple()
  element_type = tuple_type[0]
  serialized_type = type_serialization.serialize_type(element_type)
  selection = pb.Selection(source=tuple_value, index=0)
  computation = pb.Computation(type=serialized_type, selection=selection)
  return computation, element_type
def test_raises_value_error_with_unrecognized_computation_selection(self):
  """Checks that a selection with neither a name nor an index is rejected."""
  executor = create_test_executor()
  tuple_value, _ = executor_test_utils.create_dummy_computation_tuple()
  empty_tuple_type = computation_types.NamedTupleType([])
  serialized_type = type_serialization.serialize_type(empty_tuple_type)
  # The `pb.Selection` below sets neither its `name` nor its `index` field.
  # `create_value` cannot handle such a selection, so a `ValueError` is
  # expected.
  keyless_selection = pb.Selection(source=tuple_value)
  value = pb.Computation(type=serialized_type, selection=keyless_selection)
  with self.assertRaises(ValueError):
    self.run_sync(executor.create_value(value, empty_tuple_type))
def create_dummy_computation_selection():
  """Returns a selection computation and its type.

  Wraps the dummy empty TensorFlow computation from `executor_test_utils` as
  the only element of a tuple computation, then selects index 0 of it.

  Returns:
    A `(value, type_signature)` pair, where `value` is a `pb.Computation`
    wrapping a `pb.Selection` of index 0, and `type_signature` is the
    no-argument `FunctionType` whose result is an empty `NamedTupleType`.
  """
  element_value = (
      executor_test_utils.create_dummy_empty_tensorflow_computation())
  # The tuple element and the selection share this type. Build it once and
  # reuse it instead of constructing an identical `FunctionType` twice.
  element_type = computation_types.FunctionType(
      None, computation_types.NamedTupleType([]))
  element = pb.Tuple.Element(value=element_value)
  source = pb.Computation(
      type=type_serialization.serialize_type([element_type]),
      tuple=pb.Tuple(element=[element]))
  value = pb.Computation(
      type=type_serialization.serialize_type(element_type),
      selection=pb.Selection(source=source, index=0))
  return value, element_type
def test_raises_value_error_with_unrecognized_computation_selection(self):
  """Checks that a selection with neither a name nor an index is rejected."""
  executor = create_test_executor(num_clients=3)
  element_value = (
      executor_test_utils.create_dummy_empty_tensorflow_computation())
  # Used both as the tuple element's type and as the type passed to
  # `create_value`, so it is built once rather than twice.
  element_type = computation_types.FunctionType(
      None, computation_types.NamedTupleType([]))
  element = pb.Tuple.Element(value=element_value)
  source = pb.Computation(
      type=type_serialization.serialize_type([element_type]),
      tuple=pb.Tuple(element=[element]))
  # `create_value` cannot handle the following `pb.Selection` because it sets
  # neither a `name` nor an `index` field, so a `ValueError` is expected.
  value = pb.Computation(
      type=type_serialization.serialize_type(element_type),
      selection=pb.Selection(source=source))
  with self.assertRaises(ValueError):
    self.run_sync(executor.create_value(value, element_type))