def mli_accept_weights(self, weights: Weights):
    """Push ``weights`` to the remote learner via the ``SetWeights`` RPC.

    The weights are turned into a request iterator with
    ``weights_to_iterator(weights, encode=False)`` and streamed to
    ``self.stub.SetWeights``; the RPC's return value is not used.

    :param weights: the weights the remote learner should adopt.
    :raises ConnectionError: if the gRPC call fails. The original
        ``grpc.RpcError`` is attached as ``__cause__``.
    """
    request_iterator = weights_to_iterator(weights, encode=False)
    try:
        self.stub.SetWeights(request_iterator)
    except grpc.RpcError as e:
        _logger.exception(f"Failed to call SetWeights: {e}")
        # Chain explicitly so the underlying gRPC status/details remain
        # visible as the direct cause in tracebacks.
        raise ConnectionError(f"GRPC error: {e}") from e
예제 #2
0
def test_iterator_and_back():
    """Splitting weights into parts and reassembling them must be lossless.

    The payload is WEIGHTS_PART_SIZE_BYTES bytes of ``a`` followed by two
    bytes fewer of ``b``, so (assuming WEIGHTS_PART_SIZE_BYTES is the part
    size) the round trip crosses a part boundary and ends on a short part.
    """
    full_chunk = b"a" * WEIGHTS_PART_SIZE_BYTES
    short_chunk = b"b" * (WEIGHTS_PART_SIZE_BYTES - 2)
    original = Weights(weights=full_chunk + short_chunk)

    round_tripped = iterator_to_weights(
        request_iterator=weights_to_iterator(input_weights=original,
                                             encode=False),
        decode=False,
    )

    assert round_tripped == original
예제 #3
0
def test_weights_to_iterator_small_limit():
    """A payload of exactly WEIGHTS_PART_SIZE_BYTES yields exactly one part."""
    payload = b"a" * WEIGHTS_PART_SIZE_BYTES
    parts = weights_to_iterator(input_weights=Weights(weights=payload),
                                encode=False)

    first = next(parts, b"")
    assert isinstance(first, WeightsPart)
    assert first.total_bytes == WEIGHTS_PART_SIZE_BYTES
    assert first.byte_index == 0
    assert bytes(first.weights) == payload

    # next() returns the b"" default once the iterator is exhausted, so this
    # asserts that no second part is produced.
    assert next(parts, b"") == b""
예제 #4
0
def test_weights_to_iterator_small():
    """A one-byte payload yields exactly one part starting at byte index 0."""
    payload = b"a"
    parts = weights_to_iterator(input_weights=Weights(weights=payload),
                                encode=False)

    only_part = next(parts, b"")
    assert isinstance(only_part, WeightsPart)
    assert only_part.total_bytes == 1
    assert only_part.byte_index == 0
    assert bytes(only_part.weights) == payload

    # next() returns the b"" default once the iterator is exhausted, so this
    # asserts that no second part is produced.
    assert next(parts, b"") == b""
    def mli_test_weights(self, weights: Weights = None) -> ProposedWeights:
        """Ask the remote learner to evaluate ``weights`` and wrap its verdict.

        Streams ``weights`` (via ``weights_to_iterator(weights, encode=False)``)
        to the ``TestWeights`` RPC and packages the response's scores and vote
        together with the weights that were tested.

        :param weights: the weights to evaluate. Required; ``None`` (or any
            other falsy value) is not supported.
        :return: ProposedWeights holding ``weights`` and the response's
            ``vote_score``, ``test_score`` and ``vote``.
        :raises NotImplementedError: if ``weights`` is ``None``/falsy.
        :raises ConnectionError: if the gRPC call fails; the original
            ``grpc.RpcError`` is attached as ``__cause__``.
        """
        # Guard before the RPC. NotImplementedError subclasses Exception, so
        # callers that caught the previously raised bare Exception still do.
        if not weights:
            raise NotImplementedError(
                "mli_test_weights(None) is not currently supported")

        try:
            response = self.stub.TestWeights(
                weights_to_iterator(weights, encode=False))
        except grpc.RpcError as ex:
            _logger.exception(f"Failed to test_model: {ex}")
            raise ConnectionError(f"GRPC error: {ex}") from ex

        return ProposedWeights(weights=weights,
                               vote_score=response.vote_score,
                               test_score=response.test_score,
                               vote=response.vote)
예제 #6
0
    def ProposeWeights(self, request, context):
        """Train locally and stream the newly proposed weights back as parts.

        Server-streaming handler written as a generator; ``request`` is not
        read. If ``self._check_model(context)`` is falsy, nothing is yielded.

        ``self._learner_mutex`` is held for the whole call, covering both
        training and streaming. It is released in ``finally``, which runs on
        normal completion, on error, and when the generator is closed early.

        On any ``Exception`` the error is logged and ``_count_propose_err`` is
        incremented. The RPC status is set to INTERNAL with ``str(ex)`` as
        details. Nothing is re-raised, so the stream simply ends.
        NOTE(review): holding the mutex while a slow client drains the stream
        blocks other learner calls. Confirm this is intended.
        """
        _count_propose.inc()
        if not self._check_model(context):
            return

        self._learner_mutex.acquire()
        try:
            _logger.debug("Start training...")
            proposed = self.learner.mli_propose_weights()
            _logger.debug("Training done!")

            for part in weights_to_iterator(proposed):
                yield part
        except Exception as ex:  # pylint: disable=W0703
            # Top-level RPC boundary: report the failure through the gRPC
            # context instead of letting it escape the handler.
            _logger.exception(f"Exception in ProposeWeights: {ex} {type(ex)}")
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(str(ex))
            _count_propose_err.inc()
        finally:
            self._learner_mutex.release()