コード例 #1
0
ファイル: base.py プロジェクト: XinYao1994/sedna
    def save(self, model_url="", model_name=None):
        """
        Save the estimator's model locally and optionally upload it.

        If ``self.model_save_path`` is an existing file, it is split into
        its directory and file name: the directory replaces
        ``self.model_save_path`` and the file name is used instead of
        ``model_name``.

        Parameters
        ----------
        model_url: str, upload destination. When empty, no upload is
            attempted. Default: ""
        model_name: str, file name of the saved model. Falls back to
            ``self.model_name`` when falsy. Default: None

        Returns
        -------
        ``model_url`` when the model was uploaded, otherwise the local
        path the estimator saved to.
        """
        file_name = model_name or self.model_name
        if os.path.isfile(self.model_save_path):
            # The configured path points at a file: keep only its folder
            # as the save location and reuse its base name.
            save_dir, file_name = os.path.split(self.model_save_path)
            self.model_save_path = save_dir

        # NOTE(review): presumably prepares the folder without wiping it
        # (clean=False) -- confirm against FileOps.clean_folder.
        FileOps.clean_folder([self.model_save_path], clean=False)
        local_path = FileOps.join_path(self.model_save_path, file_name)
        self.estimator.save(local_path)

        # Nothing to upload: no target given or the estimator wrote nothing.
        if not model_url or not FileOps.exists(local_path):
            return local_path

        FileOps.upload(local_path, model_url)
        return model_url
コード例 #2
0
    def __init__(self, estimator, hard_example_mining: dict = None):
        """
        Set up an incremental learning job around ``estimator``.

        Parameters
        ----------
        estimator: forwarded unchanged to the parent constructor.
        hard_example_mining: dict, optional. Its "method" entry
            (default "IBT") selects the algorithm class registered under
            ``ClassType.HEM`` and its "param" entry (default {}) is passed
            to that class as keyword arguments. When falsy, the result of
            ``self.get_hem_algorithm_from_config()`` is used instead.
        """
        super(IncrementalLearning, self).__init__(estimator=estimator)

        # use in evaluation
        self.model_urls = self.get_parameters("MODEL_URLS")
        self.job_kind = K8sResourceKind.INCREMENTAL_JOB.value
        FileOps.clean_folder([self.config.model_url], clean=False)

        # Stays None unless a hard example mining setting is found below.
        self.hard_example_mining_algorithm = None
        hem_setting = (
            hard_example_mining or self.get_hem_algorithm_from_config()
        )
        if not hem_setting:
            return

        method_name = hem_setting.get("method", "IBT")
        method_kwargs = hem_setting.get("param", {})
        algorithm_cls = ClassFactory.get_cls(ClassType.HEM, method_name)
        self.hard_example_mining_algorithm = algorithm_cls(**method_kwargs)
コード例 #3
0
    def register(self, timeout=300):
        """
        Deprecated, Client proactively subscribes to the aggregation service.

        Creates the aggregation client, replaces ``self.aggregation`` with
        the result of calling it, and does the same for ``self.estimator``
        when that attribute is callable.

        Parameters
        ----------
        timeout: int, passed to the client as ``ping_timeout``. Default: 300
        """
        worker = self.worker_name
        server_uri = self.config.agg_uri
        self.log.info(f"Node {worker} connect to : {server_uri}")

        client = AggregationClient(
            url=server_uri, client_id=worker, ping_timeout=timeout)
        self.node = client

        FileOps.clean_folder([self.config.model_url], clean=False)

        # ``self.aggregation`` holds a factory up to this point; from here
        # on it holds whatever that factory returned.
        aggregation_factory = self.aggregation
        self.aggregation = aggregation_factory()
        self.log.info(f"{worker} model prepared")

        estimator = self.estimator
        if callable(estimator):
            self.estimator = estimator()
コード例 #4
0
ファイル: server.py プロジェクト: XinYao1994/sedna
    def __init__(self,
                 host: str,
                 http_port: int = 8080,
                 workers: int = 1,
                 save_dir=""):
        """
        Build the "knowledgebase" server and register its HTTP routes.

        Parameters
        ----------
        host: str, forwarded to the parent constructor.
        http_port: int, forwarded to the parent constructor. Default: 8080
        workers: int, forwarded to the parent constructor. Default: 1
        save_dir: passed to ``FileOps.clean_folder``; the first element of
            the returned sequence is stored as ``self.save_dir``.
            Default: ""
        """
        servername = "knowledgebase"

        super(KBServer, self).__init__(
            servername=servername,
            host=host,
            http_port=http_port,
            workers=workers,
        )

        prepared_dirs = FileOps.clean_folder([save_dir], clean=False)
        self.save_dir = prepared_dirs[0]
        self.url = f"{self.url}/{servername}"
        self.kb_index = KBResourceConstant.KB_INDEX_NAME.value

        # Every endpoint is served under "/<servername>/...".
        prefix = f"/{servername}"
        routes = [
            APIRoute(f"{prefix}/update", self.update, methods=["POST"]),
            APIRoute(
                f"{prefix}/update/status",
                self.update_status,
                methods=["POST"],
            ),
            APIRoute(
                f"{prefix}/query",
                self.query,
                response_model=TaskItem,
                response_class=JSONResponse,
                methods=["POST"],
            ),
            APIRoute(
                f"{prefix}/file/download",
                self.file_download,
                methods=["GET"],
            ),
            APIRoute(
                f"{prefix}/file/upload",
                self.file_upload,
                methods=["POST"],
            ),
        ]
        self.app = FastAPI(routes=routes, log_level="trace", timeout=600)
コード例 #5
0
import time
import warnings

import cv2
import numpy as np

from sedna.common.config import Context
from sedna.common.file_ops import FileOps
from sedna.core.incremental_learning import IncrementalLearning
from interface import Estimator

# Label names handed to the drawing helper below.
class_names = [
    "person",
    "helmet",
    "helmet_on",
    "helmet_off",
]

# Output locations read from the job parameters; both default to "/tmp".
he_saved_url = Context.get_parameters("HE_SAVED_URL", "/tmp")
rsl_saved_url = Context.get_parameters("RESULT_SAVED_URL", "/tmp")

# NOTE(review): presumably prepares both folders without wiping them
# (clean=False) -- confirm against FileOps.clean_folder.
FileOps.clean_folder([he_saved_url, rsl_saved_url], clean=False)


def draw_boxes(img, labels, scores, bboxes, class_names, colors):
    line_type = 2
    text_thickness = 1
    box_thickness = 1
    #  get color code
    colors = colors.split(",")
    colors_code = []
    for color in colors:
        if color == 'green':
            colors_code.append((0, 255, 0))
        elif color == 'blue':
            colors_code.append((255, 0, 0))
        elif color == 'yellow':