def mli_get_current_weights(self) -> Weights:
    """Fetch the learner's current weights from the remote service.

    Calls the ``GetCurrentWeights`` RPC with an empty request and
    reassembles the returned parts into a single ``Weights`` object via
    ``iterator_to_weights`` (with ``decode=False``).

    Returns:
        The weights reported by the server.

    Raises:
        ConnectionError: if the gRPC call fails. This matches the error
            type raised by ``mli_propose_weights`` for the same kind of
            failure; it is still an ``Exception`` subclass, so callers
            catching ``Exception`` keep working.
    """
    request = empty_pb2.Empty()
    try:
        response = self.stub.GetCurrentWeights(request)
        # The response is consumed inside iterator_to_weights. Keep it in
        # the try so gRPC errors raised while iterating are also caught.
        return iterator_to_weights(response, decode=False)
    except grpc.RpcError as ex:
        _logger.exception(f"Failed to get_current_weights: {ex}")
        # Chain the original RpcError so the gRPC status stays visible.
        raise ConnectionError(f"Failed to get_current_weights: {ex}") from ex
def mli_propose_weights(self) -> Weights:
    """Ask the remote learner to propose new weights.

    Calls the ``ProposeWeights`` RPC with an empty request and reassembles
    the returned parts into a single ``Weights`` object via
    ``iterator_to_weights`` (with ``decode=False``).

    Returns:
        The proposed weights.

    Raises:
        ConnectionError: if the gRPC call fails.
    """
    request = empty_pb2.Empty()
    try:
        response = self.stub.ProposeWeights(request)
        # The response is consumed inside iterator_to_weights. Keep it in
        # the try so gRPC errors raised while iterating are also caught.
        return iterator_to_weights(response, decode=False)
    except grpc.RpcError as ex:
        # The log message previously said "train_model", which pointed
        # anyone reading the logs at the wrong operation.
        _logger.exception(f"Failed to propose_weights: {ex}")
        raise ConnectionError(f"GRPC error: {ex}") from ex
def test_iterator_and_back():
    """Splitting weights into parts and reassembling them yields equal weights."""
    # One chunk of exactly WEIGHTS_PART_SIZE_BYTES followed by a slightly
    # shorter one. Presumably this makes the payload span more than one
    # part and end on a partial part; confirm against weights_to_iterator.
    first_chunk = b"a" * WEIGHTS_PART_SIZE_BYTES
    second_chunk = b"b" * (WEIGHTS_PART_SIZE_BYTES - 2)
    original = Weights(weights=first_chunk + second_chunk)

    parts_stream = weights_to_iterator(input_weights=original, encode=False)
    round_tripped = iterator_to_weights(request_iterator=parts_stream, decode=False)

    assert round_tripped == original
def test_in_order_iterator_to_weights():
    """Single-byte parts supplied in index order reassemble to the original bytes."""
    payload = b"abc"
    total = len(payload)
    # One WeightsPart per byte, tagged with its offset and the total length.
    parts = [
        WeightsPart(weights=payload[idx:idx + 1], byte_index=idx, total_bytes=total)
        for idx in range(total)
    ]

    reassembled = iterator_to_weights(request_iterator=iter(parts), decode=False)

    assert reassembled.weights == payload
def SetWeights(self, request_iterator, context):
    """gRPC handler: load streamed weights into the local learner.

    Always returns an empty message. If ``_check_model`` rejects the call,
    nothing else happens. Any exception while reassembling or accepting the
    weights is logged and reported through ``context`` (status INTERNAL plus
    the exception text) instead of being raised, and the error counter is
    incremented.
    """
    _count_set.inc()

    if self._check_model(context):
        # Serialise access to the learner; always released in `finally`.
        self._learner_mutex.acquire()
        try:
            incoming = iterator_to_weights(request_iterator)
            self.learner.mli_accept_weights(incoming)
        except Exception as ex:  # pylint: disable=W0703
            _logger.exception(f"Exception in SetWeights: {ex} {type(ex)}")
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(str(ex))
            _count_set_err.inc()
        finally:
            self._learner_mutex.release()

    return empty_pb2.Empty()
def TestWeights(self, request_iterator, context):
    """gRPC handler: evaluate streamed candidate weights with the local learner.

    Returns an ``ipb2.ProposedWeights`` reply. If ``_check_model`` rejects
    the call, the reply is returned unfilled. Otherwise ``vote_score``,
    ``test_score`` and ``vote`` are copied from ``mli_test_weights``. On
    any exception the error is logged and reported through ``context``
    (status INTERNAL plus the exception text), and the error counter is
    incremented. In that case the reply is returned with whatever fields
    were set before the failure.
    """
    _count_test.inc()
    reply = ipb2.ProposedWeights()

    if self._check_model(context):
        # Serialise access to the learner; always released in `finally`.
        self._learner_mutex.acquire()
        try:
            _logger.debug("Test weights...")
            candidate = iterator_to_weights(request_iterator)
            outcome = self.learner.mli_test_weights(candidate)
            # Copy the result fields onto the reply in a fixed order.
            for field_name in ("vote_score", "test_score", "vote"):
                setattr(reply, field_name, getattr(outcome, field_name))
            _logger.debug("Testing done!")
        except Exception as ex:  # pylint: disable=W0703
            _logger.exception(f"Exception in TestWeights: {ex} {type(ex)}")
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(str(ex))
            _count_test_err.inc()
        finally:
            self._learner_mutex.release()

    return reply