def iterator(self):
  """Yields a create-value, a create-call and a compute `ExecuteRequest`.

  The first request carries a serialized no-argument TF computation. After
  each of the first two yields the generator blocks on `self.queue.get()` and
  uses the value reference found in that response to build the next request.

  Yields:
    Three `executor_pb2.ExecuteRequest` messages, in the order described above.
  """

  @computations.tf_computation()
  def comp():
    return 1

  serialized_comp, _ = executor_serialization.serialize_value(comp)

  # Step 1: ask for the computation to be created as a value.
  yield executor_pb2.ExecuteRequest(
      create_value=executor_pb2.CreateValueRequest(value=serialized_comp))

  # Step 2: call the created function (it takes no argument).
  created = self.queue.get()
  call_request = executor_pb2.CreateCallRequest(
      function_ref=created.create_value.value_ref, argument_ref=None)
  yield executor_pb2.ExecuteRequest(create_call=call_request)

  # Step 3: request the result of the call.
  called = self.queue.get()
  compute_request = executor_pb2.ComputeRequest(
      value_ref=called.create_call.value_ref)
  yield executor_pb2.ExecuteRequest(compute=compute_request)
def _iterator():
  """Yields one `ExecuteRequest` that creates a serialized TF computation.

  The computation takes a `tf.int32` and returns `tf.add(x, 1)`.

  Yields:
    A single `executor_pb2.ExecuteRequest` with its `create_value` field set.
  """

  @computations.tf_computation(tf.int32)
  def comp(x):
    return tf.add(x, 1)

  serialized_comp, _ = executor_serialization.serialize_value(comp)
  create_value_request = executor_pb2.CreateValueRequest(value=serialized_comp)
  yield executor_pb2.ExecuteRequest(create_value=create_value_request)
async def _compute(self, value_ref):
  """Requests the value behind `value_ref` and returns it deserialized.

  Args:
    value_ref: An `executor_pb2.ValueRef`; type-checked before use.

  Returns:
    The first element of what `executor_serialization.deserialize_value`
    returns for the response's `value` field.
  """
  py_typecheck.check_type(value_ref, executor_pb2.ValueRef)
  compute_request = executor_pb2.ComputeRequest(value_ref=value_ref)

  if self._bidi_stream is not None:
    # Streaming mode: wrap the request and unwrap the matching response field.
    stream_reply = await self._bidi_stream.send_request(
        executor_pb2.ExecuteRequest(compute=compute_request))
    compute_response = stream_reply.compute
  else:
    # No stream configured: go through the stub via the `_request` helper.
    compute_response = _request(self._stub.Compute, compute_request)

  py_typecheck.check_type(compute_response, executor_pb2.ComputeResponse)
  deserialized, _ = executor_serialization.deserialize_value(
      compute_response.value)
  return deserialized
async def _compute(self, value_ref):
  """Requests the value behind `value_ref` and returns it deserialized.

  Args:
    value_ref: An `executor_pb2.ValueRef`; type-checked before use.

  Returns:
    The first element of what `executor_service_utils.deserialize_value`
    returns for the response's `value` field.
  """
  py_typecheck.check_type(value_ref, executor_pb2.ValueRef)
  request = executor_pb2.ComputeRequest(value_ref=value_ref)
  if not self._bidi_stream:
    response = self._stub.Compute(request)
  else:
    # `send_request` must be awaited before reading `.compute`; the other
    # streaming call sites in this file all await it. Without the await the
    # attribute would be looked up on the un-awaited awaitable instead of on
    # the `ExecuteResponse`.
    response = (await self._bidi_stream.send_request(
        executor_pb2.ExecuteRequest(compute=request))).compute
  py_typecheck.check_type(response, executor_pb2.ComputeResponse)
  value, _ = executor_service_utils.deserialize_value(response.value)
  return value
async def create_value(self, value, type_spec=None):
  """Serializes `value`, sends a create-value request, and wraps the reference.

  Args:
    value: The value to serialize with `executor_service_utils.serialize_value`.
    type_spec: Optional type passed to the serializer; the type the serializer
      returns is the one attached to the result.

  Returns:
    A `RemoteValue` built from the response's `value_ref`, the serializer's
    type, and this executor.
  """
  serialized = executor_service_utils.serialize_value(value, type_spec)
  value_proto, type_spec = serialized
  request = executor_pb2.CreateValueRequest(value=value_proto)

  if self._bidi_stream:
    stream_reply = await self._bidi_stream.send_request(
        executor_pb2.ExecuteRequest(create_value=request))
    reply = stream_reply.create_value
  else:
    reply = self._stub.CreateValue(request)

  py_typecheck.check_type(reply, executor_pb2.CreateValueResponse)
  return RemoteValue(reply.value_ref, type_spec, self)
async def set_cardinalities(
    self, cardinalities: Mapping[placement_literals.PlacementLiteral, int]):
  """Serializes `cardinalities` and sends them in a set-cardinalities request.

  Args:
    cardinalities: Mapping from placement literal to an integer count.

  Returns:
    None. Any response to the request is discarded.
  """
  cardinalities_request = executor_pb2.SetCardinalitiesRequest(
      cardinalities=executor_service_utils.serialize_cardinalities(
          cardinalities))

  if self._bidi_stream is not None:
    await self._bidi_stream.send_request(
        executor_pb2.ExecuteRequest(set_cardinalities=cardinalities_request))
  else:
    _request(self._stub.SetCardinalities, cardinalities_request)
async def _compute(self, value_ref):
  """Requests the value behind `value_ref` and returns it deserialized.

  Without a stream the stub is called directly, and any `grpc.RpcError` it
  raises is handed to `self._handle_grpc_error`.

  Args:
    value_ref: An `executor_pb2.ValueRef`; type-checked before use.

  Returns:
    The first element of what `executor_service_utils.deserialize_value`
    returns for the response's `value` field.
  """
  py_typecheck.check_type(value_ref, executor_pb2.ValueRef)
  compute_request = executor_pb2.ComputeRequest(value_ref=value_ref)

  if self._bidi_stream:
    stream_reply = await self._bidi_stream.send_request(
        executor_pb2.ExecuteRequest(compute=compute_request))
    response = stream_reply.compute
  else:
    try:
      response = self._stub.Compute(compute_request)
    except grpc.RpcError as e:
      # NOTE(review): presumably `_handle_grpc_error` always raises; if it can
      # return, `response` is unbound below -- confirm.
      self._handle_grpc_error(e)

  py_typecheck.check_type(response, executor_pb2.ComputeResponse)
  deserialized, _ = executor_service_utils.deserialize_value(response.value)
  return deserialized
def _dispose(self, value_ref: executor_pb2.ValueRef):
  """Queues `value_ref` for disposal and flushes the batch once it is full.

  The reference is appended to the pending `DisposeRequest`. Nothing is sent
  until the number of queued references reaches `self._dispose_batch_size`;
  at that point the pending request is swapped for a fresh one and sent.

  Args:
    value_ref: The `executor_pb2.ValueRef` to dispose of.
  """
  self._dispose_request.value_ref.append(value_ref)
  queued = len(self._dispose_request.value_ref)
  if queued < self._dispose_batch_size:
    return

  # Detach the full batch and start collecting into a new, empty request.
  full_batch, self._dispose_request = (self._dispose_request,
                                       executor_pb2.DisposeRequest())

  if self._bidi_stream is not None:
    pending_send = self._bidi_stream.send_request(
        executor_pb2.ExecuteRequest(dispose=full_batch))
    # The response is of no interest, so it is not awaited here; scheduling it
    # as a task lets it run at some later point.
    loop = asyncio.get_event_loop()
    loop.create_task(pending_send)
  else:
    _request(self._stub.Dispose, full_batch)
async def create_value(self, value, type_spec=None):
  """Serializes `value` under a trace, sends it, and wraps the reference.

  Args:
    value: The value to serialize with
      `executor_serialization.serialize_value`.
    type_spec: Optional type passed to the serializer; the type the serializer
      returns is the one attached to the result.

  Returns:
    A `RemoteValue` built from the response's `value_ref`, the serializer's
    type, and this executor.
  """

  # Serialization lives in its own function so `tracing.trace` can wrap it.
  @tracing.trace
  def serialize_value():
    return executor_serialization.serialize_value(value, type_spec)

  value_proto, type_spec = serialize_value()
  request = executor_pb2.CreateValueRequest(value=value_proto)

  if self._bidi_stream is not None:
    stream_reply = await self._bidi_stream.send_request(
        executor_pb2.ExecuteRequest(create_value=request))
    reply = stream_reply.create_value
  else:
    reply = _request(self._stub.CreateValue, request)

  py_typecheck.check_type(reply, executor_pb2.CreateValueResponse)
  return RemoteValue(reply.value_ref, type_spec, self)
async def create_call(self, comp, arg=None):
  """Sends a create-call request for `comp` applied to the optional `arg`.

  Args:
    comp: A `RemoteValue` whose `type_signature` is a
      `computation_types.FunctionType`.
    arg: Optional `RemoteValue` argument; `None` sends no argument reference.

  Returns:
    A `RemoteValue` built from the response's `value_ref`, typed with the
    result type of `comp`'s signature.
  """
  py_typecheck.check_type(comp, RemoteValue)
  py_typecheck.check_type(comp.type_signature, computation_types.FunctionType)

  if arg is None:
    argument_ref = None
  else:
    py_typecheck.check_type(arg, RemoteValue)
    argument_ref = arg.value_ref

  request = executor_pb2.CreateCallRequest(
      function_ref=comp.value_ref, argument_ref=argument_ref)

  if self._bidi_stream:
    stream_reply = await self._bidi_stream.send_request(
        executor_pb2.ExecuteRequest(create_call=request))
    reply = stream_reply.create_call
  else:
    reply = self._stub.CreateCall(request)

  py_typecheck.check_type(reply, executor_pb2.CreateCallResponse)
  return RemoteValue(reply.value_ref, comp.type_signature.result, self)
async def create_tuple(self, elements):
  """Sends a create-tuple request built from a container of `RemoteValue`s.

  Args:
    elements: A container accepted by `anonymous_tuple.from_container`; every
      element value must be a `RemoteValue`.

  Returns:
    A `RemoteValue` built from the response's `value_ref`, typed as a
    `NamedTupleType` assembled from the elements' names and type signatures.
  """
  named_elements = anonymous_tuple.to_elements(
      anonymous_tuple.from_container(elements))

  element_protos = []
  element_types = []
  for name, remote in named_elements:
    py_typecheck.check_type(remote, RemoteValue)
    # Falsy names are sent as `None`, and such elements are typed by position.
    element_protos.append(
        executor_pb2.CreateTupleRequest.Element(
            name=(name or None), value_ref=remote.value_ref))
    if name:
      element_types.append((name, remote.type_signature))
    else:
      element_types.append(remote.type_signature)

  tuple_type = computation_types.NamedTupleType(element_types)
  request = executor_pb2.CreateTupleRequest(element=element_protos)

  if self._bidi_stream:
    stream_reply = await self._bidi_stream.send_request(
        executor_pb2.ExecuteRequest(create_tuple=request))
    reply = stream_reply.create_tuple
  else:
    reply = self._stub.CreateTuple(request)

  py_typecheck.check_type(reply, executor_pb2.CreateTupleResponse)
  return RemoteValue(reply.value_ref, tuple_type, self)
async def create_struct(self, elements):
  """Sends a create-struct request built from a container of `RemoteValue`s.

  Args:
    elements: A container accepted by `structure.from_container`; every
      element value must be a `RemoteValue`.

  Returns:
    A `RemoteValue` built from the response's `value_ref`, typed as a
    `StructType` assembled from the elements' names and type signatures.
  """
  struct = structure.from_container(elements)

  element_protos = []
  element_types = []
  for name, remote in structure.iter_elements(struct):
    py_typecheck.check_type(remote, RemoteValue)
    # Falsy names are sent as `None`, and such elements are typed by position.
    element_protos.append(
        executor_pb2.CreateStructRequest.Element(
            name=(name or None), value_ref=remote.value_ref))
    if name:
      element_types.append((name, remote.type_signature))
    else:
      element_types.append(remote.type_signature)

  struct_type = computation_types.StructType(element_types)
  request = executor_pb2.CreateStructRequest(element=element_protos)

  if self._bidi_stream is not None:
    stream_reply = await self._bidi_stream.send_request(
        executor_pb2.ExecuteRequest(create_struct=request))
    reply = stream_reply.create_struct
  else:
    reply = _request(self._stub.CreateStruct, request)

  py_typecheck.check_type(reply, executor_pb2.CreateStructResponse)
  return RemoteValue(reply.value_ref, struct_type, self)
async def create_selection(self, source, index=None, name=None):
  """Sends a create-selection request for one member of a struct value.

  Args:
    source: A `RemoteValue` whose `type_signature` is a
      `computation_types.StructType`.
    index: Optional `int` position to select; when given, `name` must be
      `None`.
    name: `str` attribute name to select; required when `index` is `None`.

  Returns:
    A `RemoteValue` built from the response's `value_ref`, typed with the
    selected member's type.
  """
  py_typecheck.check_type(source, RemoteValue)
  source_type = source.type_signature
  py_typecheck.check_type(source_type, computation_types.StructType)

  if index is None:
    py_typecheck.check_type(name, str)
    selected_type = getattr(source.type_signature, name)
  else:
    py_typecheck.check_type(index, int)
    py_typecheck.check_none(name)
    selected_type = source.type_signature[index]

  request = executor_pb2.CreateSelectionRequest(
      source_ref=source.value_ref, name=name, index=index)

  if self._bidi_stream is not None:
    stream_reply = await self._bidi_stream.send_request(
        executor_pb2.ExecuteRequest(create_selection=request))
    reply = stream_reply.create_selection
  else:
    reply = _request(self._stub.CreateSelection, request)

  py_typecheck.check_type(reply, executor_pb2.CreateSelectionResponse)
  return RemoteValue(reply.value_ref, selected_type, self)
async def create_selection(self, source, index=None, name=None):
  """Sends a create-selection request for one member of a named-tuple value.

  Without a stream the stub is called directly, and any `grpc.RpcError` it
  raises is handed to `self._handle_grpc_error`.

  Args:
    source: A `RemoteValue` whose `type_signature` is a
      `computation_types.NamedTupleType`.
    index: Optional `int` position to select; when given, `name` must be
      `None`.
    name: `str` attribute name to select; required when `index` is `None`.

  Returns:
    A `RemoteValue` built from the response's `value_ref`, typed with the
    selected member's type.
  """
  py_typecheck.check_type(source, RemoteValue)
  source_type = source.type_signature
  py_typecheck.check_type(source_type, computation_types.NamedTupleType)

  if index is None:
    py_typecheck.check_type(name, str)
    selected_type = getattr(source.type_signature, name)
  else:
    py_typecheck.check_type(index, int)
    py_typecheck.check_none(name)
    selected_type = source.type_signature[index]

  request = executor_pb2.CreateSelectionRequest(
      source_ref=source.value_ref, name=name, index=index)

  if self._bidi_stream:
    stream_reply = await self._bidi_stream.send_request(
        executor_pb2.ExecuteRequest(create_selection=request))
    response = stream_reply.create_selection
  else:
    try:
      response = self._stub.CreateSelection(request)
    except grpc.RpcError as e:
      # NOTE(review): presumably `_handle_grpc_error` always raises; if it can
      # return, `response` is unbound below -- confirm.
      self._handle_grpc_error(e)

  py_typecheck.check_type(response, executor_pb2.CreateSelectionResponse)
  return RemoteValue(response.value_ref, selected_type, self)