def mli_accept_weights(self, weights: Weights):
    """Push ``weights`` to the remote learner through the ``SetWeights`` RPC.

    The weights are streamed as parts produced by
    ``weights_to_iterator(weights, encode=False)``.

    Args:
        weights: Weights to send to the remote learner.

    Raises:
        ConnectionError: if the gRPC call fails. The underlying
            ``grpc.RpcError`` is attached as ``__cause__``.
    """
    request_iterator = weights_to_iterator(weights, encode=False)
    try:
        self.stub.SetWeights(request_iterator)
    except grpc.RpcError as e:
        _logger.exception("Failed to call SetWeights: %s", e)
        # Chain explicitly so callers/tracebacks see the gRPC error as the cause.
        raise ConnectionError(f"GRPC error: {e}") from e
def test_iterator_and_back():
    """Weights survive a round trip through weights_to_iterator / iterator_to_weights."""
    # One full part-sized chunk plus a slightly shorter tail: the payload is
    # larger than WEIGHTS_PART_SIZE_BYTES, so it should span more than one part.
    full_chunk = b"a" * WEIGHTS_PART_SIZE_BYTES
    short_chunk = b"b" * (WEIGHTS_PART_SIZE_BYTES - 2)
    original = Weights(weights=full_chunk + short_chunk)

    parts = weights_to_iterator(input_weights=original, encode=False)
    restored = iterator_to_weights(request_iterator=parts, decode=False)

    assert restored == original
def test_weights_to_iterator_small_limit():
    """A payload of exactly WEIGHTS_PART_SIZE_BYTES yields exactly one WeightsPart."""
    payload = b"a" * WEIGHTS_PART_SIZE_BYTES
    parts = weights_to_iterator(input_weights=Weights(weights=payload), encode=False)

    first = next(parts, b"")
    assert isinstance(first, WeightsPart)
    assert first.total_bytes == WEIGHTS_PART_SIZE_BYTES
    assert first.byte_index == 0
    assert bytes(first.weights) == payload

    # b"" is the default next() returns once the iterator is exhausted.
    assert next(parts, b"") == b""
def test_weights_to_iterator_small():
    """A one-byte payload yields a single WeightsPart covering that byte."""
    single_byte = b"a"
    stream = weights_to_iterator(
        input_weights=Weights(weights=single_byte), encode=False
    )

    part = next(stream, b"")
    assert isinstance(part, WeightsPart)
    assert part.total_bytes == 1
    assert part.byte_index == 0
    assert bytes(part.weights) == single_byte

    # Nothing else should follow; b"" signals exhaustion.
    assert next(stream, b"") == b""
def mli_test_weights(self, weights: Weights = None) -> ProposedWeights:
    """Ask the remote learner to evaluate ``weights`` via the ``TestWeights`` RPC.

    The weights are streamed with ``weights_to_iterator(weights, encode=False)``.
    The response's ``vote_score``, ``test_score`` and ``vote`` are copied into
    the returned ``ProposedWeights`` together with the input ``weights``.

    Args:
        weights: Weights to evaluate. A falsy value (including the default
            ``None``) is rejected.

    Returns:
        ProposedWeights built from ``weights`` and the RPC response.

    Raises:
        NotImplementedError: if ``weights`` is falsy (testing the learner's
            current weights is not supported).
        ConnectionError: if the gRPC call fails. The underlying
            ``grpc.RpcError`` is attached as ``__cause__``.
    """
    # Guard clause: validate before touching the network. NotImplementedError
    # subclasses Exception, so callers catching Exception are unaffected.
    if not weights:
        raise NotImplementedError(
            "mli_test_weights(None) is not currently supported")

    try:
        response = self.stub.TestWeights(
            weights_to_iterator(weights, encode=False))
    except grpc.RpcError as ex:
        _logger.exception("Failed to test_model: %s", ex)
        raise ConnectionError(f"GRPC error: {ex}") from ex

    return ProposedWeights(weights=weights,
                           vote_score=response.vote_score,
                           test_score=response.test_score,
                           vote=response.vote)
def ProposeWeights(self, request, context):
    """Train via the local learner and stream the proposed weights back as parts.

    This is a generator: nothing below runs until the first item is requested.
    ``_count_propose`` is incremented on every call and ``_count_propose_err``
    on failure.

    Args:
        request: Incoming request; not read by this method.
        context: gRPC servicer context. It is passed to ``self._check_model``,
            and on failure its status code/details are set.

    Yields:
        Parts produced by ``weights_to_iterator(weights)`` for the weights
        returned by ``self.learner.mli_propose_weights()``.
    """
    _count_propose.inc()
    # If the model check fails, end the stream without yielding anything.
    # NOTE(review): presumably _check_model sets the error status on context — confirm.
    if not self._check_model(context):
        return
    # The mutex is acquired here and only released in `finally`, so it stays
    # held across every `yield` below, i.e. for the whole duration of streaming.
    self._learner_mutex.acquire()
    try:
        _logger.debug("Start training...")
        weights = self.learner.mli_propose_weights()
        _logger.debug("Training done!")
        # Uses weights_to_iterator's default `encode` argument, unlike the
        # client-side calls that pass encode=False.
        # NOTE(review): check this asymmetry is intended for this RPC.
        weights_part_iterator = weights_to_iterator(weights)
        for wp in weights_part_iterator:
            yield wp
    except Exception as ex:  # pylint: disable=W0703
        # Report the failure through the gRPC status instead of raising. Parts
        # already yielded stay sent; the stream simply ends here.
        _logger.exception(f"Exception in ProposeWeights: {ex} {type(ex)}")
        context.set_code(grpc.StatusCode.INTERNAL)
        context.set_details(str(ex))
        _count_propose_err.inc()
    finally:
        # Also runs if the generator is closed early: GeneratorExit is not an
        # Exception subclass, so it skips the handler above but still hits this.
        self._learner_mutex.release()