コード例 #1
0
    def mli_get_current_weights(self) -> Weights:
        """Fetch the current weights from the remote learner over gRPC.

        Sends an empty request to the ``GetCurrentWeights`` RPC and rebuilds a
        ``Weights`` object from the response with
        ``iterator_to_weights(..., decode=False)``.

        Returns:
            The reassembled ``Weights``.

        Raises:
            ConnectionError: If the RPC fails with ``grpc.RpcError``. The
                original error is chained as ``__cause__``. ``ConnectionError``
                is a subclass of ``Exception``, so existing ``except Exception``
                handlers still catch it.
        """
        request = empty_pb2.Empty()
        try:
            response = self.stub.GetCurrentWeights(request)
            # The response is consumed inside the try. It is presumably a
            # server stream, whose RpcError can surface during iteration rather
            # than at call time.
            return iterator_to_weights(response, decode=False)
        except grpc.RpcError as ex:
            _logger.exception(f"Failed to get_current_weights: {ex}")
            raise ConnectionError(f"Failed to get_current_weights: {ex}") from ex
コード例 #2
0
    def mli_propose_weights(self) -> Weights:
        """Ask the remote learner to propose new weights over gRPC.

        Sends an empty request to the ``ProposeWeights`` RPC and rebuilds a
        ``Weights`` object from the response with
        ``iterator_to_weights(..., decode=False)``.

        Returns:
            The reassembled proposed ``Weights``.

        Raises:
            ConnectionError: If the RPC fails with ``grpc.RpcError``. The
                original error is chained as ``__cause__``.
        """
        request = empty_pb2.Empty()
        try:
            response = self.stub.ProposeWeights(request)
            # The response is consumed inside the try. It is presumably a
            # server stream, whose RpcError can surface during iteration rather
            # than at call time.
            return iterator_to_weights(response, decode=False)
        except grpc.RpcError as ex:
            _logger.exception(f"Failed to propose_weights: {ex}")
            raise ConnectionError(f"GRPC error: {ex}") from ex
コード例 #3
0
def test_iterator_and_back():
    """Round-trip Weights through weights_to_iterator and iterator_to_weights.

    The payload is ``WEIGHTS_PART_SIZE_BYTES`` bytes of ``b"a"`` followed by
    ``WEIGHTS_PART_SIZE_BYTES - 2`` bytes of ``b"b"``. Its length is therefore
    not a multiple of that constant. This presumably exercises a short final
    part, assuming the constant is the chunk size used by
    ``weights_to_iterator`` (confirm).
    """
    chunk_len = WEIGHTS_PART_SIZE_BYTES
    payload = b"a" * chunk_len + b"b" * (chunk_len - 2)
    original = Weights(weights=payload)

    part_stream = weights_to_iterator(input_weights=original, encode=False)
    round_tripped = iterator_to_weights(request_iterator=part_stream, decode=False)

    assert round_tripped == original
コード例 #4
0
def test_in_order_iterator_to_weights():
    """Reassemble weights from one-byte WeightsPart messages given in order.

    Each byte of the payload becomes its own part, tagged with its index and
    the total payload length. The reassembled ``weights`` must equal the
    original bytes.
    """
    payload = b"abc"
    total = len(payload)

    single_byte_parts = [
        WeightsPart(weights=bytes((value,)), byte_index=position, total_bytes=total)
        for position, value in enumerate(payload)
    ]

    reassembled = iterator_to_weights(
        request_iterator=iter(single_byte_parts), decode=False
    )

    assert reassembled.weights == payload
コード例 #5
0
    def SetWeights(self, request_iterator, context):
        """Receive streamed weights and pass them to the learner.

        Always increments ``_count_set``. If ``self._check_model(context)`` is
        falsy, nothing else happens. Otherwise ``self._learner_mutex`` is held
        while the weights are rebuilt with ``iterator_to_weights`` and given to
        ``self.learner.mli_accept_weights``.

        Any exception is logged and reported on ``context`` as
        ``grpc.StatusCode.INTERNAL`` with the error text as details. It is then
        counted in ``_count_set_err``. It is not re-raised.

        Returns:
            ``empty_pb2.Empty()`` in every case.
        """
        _count_set.inc()
        if self._check_model(context):
            self._learner_mutex.acquire()
            try:
                received = iterator_to_weights(request_iterator)
                self.learner.mli_accept_weights(received)
            except Exception as err:  # pylint: disable=W0703
                # RPC boundary: turn any failure into a gRPC status instead of
                # letting it propagate out of the handler.
                _logger.exception(f"Exception in SetWeights: {err} {type(err)}")
                context.set_code(grpc.StatusCode.INTERNAL)
                context.set_details(str(err))
                _count_set_err.inc()
            finally:
                self._learner_mutex.release()
        return empty_pb2.Empty()
コード例 #6
0
    def TestWeights(self, request_iterator, context):
        """Evaluate streamed weights with the learner and return its verdict.

        Always increments ``_count_test`` and creates an empty
        ``ipb2.ProposedWeights`` message. If ``self._check_model(context)`` is
        falsy, that empty message is returned as-is. Otherwise
        ``self._learner_mutex`` is held while the weights are rebuilt with
        ``iterator_to_weights`` and scored by ``self.learner.mli_test_weights``.
        The result's ``vote_score``, ``test_score`` and ``vote`` are then copied
        into the message.

        Any exception is logged and reported on ``context`` as
        ``grpc.StatusCode.INTERNAL`` with the error text as details. It is then
        counted in ``_count_test_err``. It is not re-raised, and any fields
        already copied stay set.

        Returns:
            The ``ipb2.ProposedWeights`` message.
        """
        _count_test.inc()
        verdict = ipb2.ProposedWeights()
        if self._check_model(context):
            self._learner_mutex.acquire()
            try:
                _logger.debug("Test weights...")
                candidate = iterator_to_weights(request_iterator)
                outcome = self.learner.mli_test_weights(candidate)
                verdict.vote_score = outcome.vote_score
                verdict.test_score = outcome.test_score
                verdict.vote = outcome.vote
                _logger.debug("Testing done!")
            except Exception as err:  # pylint: disable=W0703
                # RPC boundary: turn any failure into a gRPC status instead of
                # letting it propagate out of the handler.
                _logger.exception(f"Exception in TestWeights: {err} {type(err)}")
                context.set_code(grpc.StatusCode.INTERNAL)
                context.set_details(str(err))
                _count_test_err.inc()
            finally:
                self._learner_mutex.release()
        return verdict