def create_dummy_computation_tuple():
  """Returns a tuple computation and type.

  The computation is a one-element tuple whose single, unnamed element is
  itself an empty tuple, i.e. a value of type `<<>>`.

  Returns:
    A `(pb.Computation, computation_types.NamedTupleType)` pair.
  """
  empty_type = computation_types.NamedTupleType([])
  empty_tuple_comp = pb.Computation(
      type=type_serialization.serialize_type(empty_type),
      tuple=pb.Tuple(element=[]))
  outer_type = computation_types.NamedTupleType([empty_type])
  outer_elements = [pb.Tuple.Element(value=empty_tuple_comp)]
  outer_comp = pb.Computation(
      type=type_serialization.serialize_type(outer_type),
      tuple=pb.Tuple(element=outer_elements))
  return outer_comp, outer_type
def proto(self):
  """Serializes `self` into a `pb.Computation` with a `tuple` field.

  `self` is walked with `anonymous_tuple.to_elements`. Each `(name, member)`
  pair becomes a `pb.Tuple.Element` holding `member.proto`. The element's
  `name` is set only when `name` is not `None`.

  Returns:
    A `pb.Computation` whose `type` is the serialized `self.type_signature`.
  """

  def _to_element(name, member):
    # Leave `name` unset for unnamed elements rather than passing it at all.
    fields = {'value': member.proto}
    if name is not None:
      fields['name'] = name
    return pb.Tuple.Element(**fields)

  tuple_elements = [
      _to_element(name, member)
      for name, member in anonymous_tuple.to_elements(self)
  ]
  serialized_type = type_serialization.serialize_type(self.type_signature)
  return pb.Computation(
      type=serialized_type, tuple=pb.Tuple(element=tuple_elements))
def create_dummy_computation_tuple():
  """Returns a tuple computation and type.

  Builds a tuple with three elements named 'a', 'b' and 'c'. Each element holds
  the computation returned by `create_dummy_computation_tensorflow_constant()`
  and has that function's accompanying type.

  Returns:
    A `(pb.Computation, computation_types.NamedTupleType)` pair.
  """
  field_names = ['a', 'b', 'c']
  constant_comp, constant_type = (
      create_dummy_computation_tensorflow_constant())
  tuple_elements = [
      pb.Tuple.Element(name=field_name, value=constant_comp)
      for field_name in field_names
  ]
  named_type = computation_types.NamedTupleType(
      (field_name, constant_type) for field_name in field_names)
  serialized_type = type_serialization.serialize_type(named_type)
  tuple_comp = pb.Computation(
      type=serialized_type, tuple=pb.Tuple(element=tuple_elements))
  return tuple_comp, named_type
def create_dummy_computation_lambda_empty():
  """Returns a lambda computation and type `( -> <>)`.

  The lambda declares no parameter name, and its result is an empty tuple
  computation.

  Returns:
    A `(pb.Computation, computation_types.FunctionType)` pair.
  """
  empty_tuple_type = computation_types.NamedTupleType([])
  fn_type = computation_types.FunctionType(None, empty_tuple_type)
  body = pb.Computation(
      type=type_serialization.serialize_type(empty_tuple_type),
      tuple=pb.Tuple(element=[]))
  lambda_proto = pb.Lambda(parameter_name=None, result=body)
  # `lambda` is a reserved Python keyword, but it is also the name of the
  # `pb.Computation` field being set here, so it is passed by dict unpacking.
  # https://developers.google.com/protocol-buffers/docs/reference/python-generated#keyword-conflicts
  computation_fields = {'lambda': lambda_proto}
  comp = pb.Computation(
      type=type_serialization.serialize_type(fn_type),
      **computation_fields)  # pytype: disable=wrong-keyword-args
  return comp, fn_type
def create_dummy_computation_selection():
  """Returns a selection computation and type `( -> <>)`.

  The selection picks index 0 from a one-element tuple. That tuple's only
  element is `executor_test_utils.create_dummy_empty_tensorflow_computation()`.

  Returns:
    A `(pb.Computation, computation_types.FunctionType)` pair.
  """

  def _empty_result_fn_type():
    # Type of a function with no parameter that returns an empty tuple.
    return computation_types.FunctionType(
        None, computation_types.NamedTupleType([]))

  selected_comp = (
      executor_test_utils.create_dummy_empty_tensorflow_computation())
  selected_type = _empty_result_fn_type()
  source_tuple = pb.Computation(
      type=type_serialization.serialize_type([selected_type]),
      tuple=pb.Tuple(element=[pb.Tuple.Element(value=selected_comp)]))
  selection_comp = pb.Computation(
      type=type_serialization.serialize_type(selected_type),
      selection=pb.Selection(source=source_tuple, index=0))
  return selection_comp, _empty_result_fn_type()
def test_raises_value_error_with_unrecognized_computation_selection(self):
  """Checks that `create_value` rejects a selection with no name or index."""
  executor = create_test_executor(num_clients=3)
  member_comp = (
      executor_test_utils.create_dummy_empty_tensorflow_computation())
  member_type = computation_types.FunctionType(
      None, computation_types.NamedTupleType([]))
  source_tuple = pb.Computation(
      type=type_serialization.serialize_type([member_type]),
      tuple=pb.Tuple(element=[pb.Tuple.Element(value=member_comp)]))
  # `create_value` cannot handle this `pb.Selection` because it sets neither
  # a name nor an index field, so a `ValueError` is expected.
  bad_selection = pb.Computation(
      type=type_serialization.serialize_type(member_type),
      selection=pb.Selection(source=source_tuple))
  expected_type = computation_types.FunctionType(
      None, computation_types.NamedTupleType([]))
  with self.assertRaises(ValueError):
    self.run_sync(executor.create_value(bad_selection, expected_type))
def create_dummy_computation_tuple():
  """Returns an empty tuple computation and its type `<>`.

  Returns:
    A `(pb.Computation, computation_types.NamedTupleType)` pair.
  """
  empty_type = computation_types.NamedTupleType([])
  empty_comp = pb.Computation(
      type=type_serialization.serialize_type(empty_type),
      tuple=pb.Tuple(element=[]))
  return empty_comp, empty_type