コード例 #1
0
ファイル: server.py プロジェクト: clockfly/ReNomIMG
def model_load_prediction_result(id):
    """Return the latest prediction result for model ``id``.

    If a ``PredictionThread`` is registered for ``id``, its in-memory result
    is returned and its ``need_pull`` flag is cleared. Otherwise the result
    stored with the model is returned; a model that is not STOPPED in storage
    is first switched to STOPPED / STOPPING.

    Raises:
        Exception: if there is no thread and the model is not in storage.
    """
    running = PredictionThread.jobs.get(id)
    if running is not None:
        running.need_pull = False
        return {"result": running.prediction_result}

    stored = storage.fetch_model(id)
    if stored is None:
        raise Exception("Model id {} is not found".format(id))

    # No live thread means nothing is actually running, so force the stored
    # state to STOPPED; the client stops sending requests once it sees STOPPED.
    if stored["state"] != State.STOPPED.value:
        storage.update_model(id, state=State.STOPPED.value,
                             running_state=RunningState.STOPPING.value)
        stored = storage.fetch_model(id)
    return {"result": stored['last_prediction_result']}
コード例 #2
0
ファイル: server.py プロジェクト: clockfly/ReNomIMG
def model_load_best_result(id):
    """Return the best-epoch validation result for model ``id``.

    If a ``TrainThread`` is registered for ``id``, it is told the result was
    returned to the client and its in-memory best result is returned.
    Otherwise the stored result is returned; a model that is not STOPPED in
    storage is first switched to STOPPED / STOPPING.

    Returns ``None`` when there is no thread and the model is not in storage.
    """
    running = TrainThread.jobs.get(id)
    if running is not None:
        running.returned_best_result2client()
        return {"best_result": running.best_epoch_valid_result}

    stored = storage.fetch_model(id)
    if stored is None:
        return

    # No live thread means nothing is actually running, so force the stored
    # state to STOPPED; the client stops sending requests once it sees STOPPED.
    if stored["state"] != State.STOPPED.value:
        storage.update_model(id, state=State.STOPPED.value,
                             running_state=RunningState.STOPPING.value)
        stored = storage.fetch_model(id)
    return {"best_result": stored['best_epoch_valid_result']}
コード例 #3
0
    def _prepare_params(self):
        """Load model and dataset settings for ``self.model_id`` onto ``self``.

        If the stop event is already set, only ``self.updated`` is set to
        True and nothing is read from storage.

        Raises:
            AssertionError: if any key in ``self.common_params`` is missing
                from the model's hyper parameters.
        """
        if self.stop_event.is_set():
            # Watch stop event
            self.updated = True
            return

        params = storage.fetch_model(self.model_id)
        self.task_id = int(params["task_id"])
        self.dataset_id = int(params["dataset_id"])
        self.algorithm_id = int(params["algorithm_id"])
        self.hyper_parameters = params["hyper_parameters"]
        self.last_weight_path = params["last_weight"]
        self.best_weight_path = params["best_epoch_weight"]

        dataset = storage.fetch_dataset(self.dataset_id)
        self.class_map = dataset["class_map"]

        self.common_params = [
            'total_epoch', 'batch_size', 'imsize_w', 'imsize_h', 'train_whole',
            'load_pretrained_weight'
        ]

        # Explicit check rather than a bare ``assert`` so validation is not
        # stripped under ``python -O``; AssertionError is kept so callers see
        # the same exception type as before.
        missing = [k for k in self.common_params
                   if k not in self.hyper_parameters]
        if missing:
            raise AssertionError(
                "Missing hyper parameters: {}".format(", ".join(missing)))

        # Training States
        # TODO: Need getter for json decoding.
        # NOTE(review): 'load_pretrained_weight' is required above but ignored
        # here (always False) — confirm this is intentional.
        self.load_pretrained_weight = False
        self.train_whole = bool(self.hyper_parameters["train_whole"])
        self.imsize = (int(self.hyper_parameters["imsize_w"]),
                       int(self.hyper_parameters["imsize_h"]))
        self.batch_size = int(self.hyper_parameters["batch_size"])
コード例 #4
0
ファイル: server.py プロジェクト: DaikiOnodera/ReNomIMG
def run_prediction(project_id, model_id):
    """Run prediction with a model's best-epoch weight and return a response.

    A ``PredictionThread`` is submitted to a new ``Executor`` and the
    ``[future, thread]`` pair is registered in ``prediction_thread_pool``
    under ``"<project_id>_<model_id>"``. This call blocks until prediction ends.

    Returns:
        ``create_response`` of a JSON body that is either
        ``{"predict_results": ..., "csv": ...}`` or ``{"error_msg": ...}``.
    """
    # Load the model and dataset information needed for prediction.
    try:
        thread_id = "{}_{}".format(project_id, model_id)
        fields = 'hyper_parameters,algorithm,algorithm_params,best_epoch_weight,dataset_def_id'
        data = storage.fetch_model(project_id, model_id, fields=fields)
        # Only class_map is used. The full unpack is kept so a dataset
        # definition with an unexpected number of fields still fails loudly.
        (_def_id, _name, _ratio, _train_imgs, _valid_imgs, class_map,
         _created, _updated) = storage.fetch_dataset_def(data['dataset_def_id'])
        # Predict using the path of the best-epoch h5 weight file.
        with Executor(max_workers=MAX_THREAD_NUM) as prediction_executor:
            th = PredictionThread(thread_id, model_id,
                                  data["hyper_parameters"], data["algorithm"],
                                  data["algorithm_params"],
                                  data["best_epoch_weight"], class_map)
            ft = prediction_executor.submit(th)
            prediction_thread_pool[thread_id] = [ft, th]
        # Re-raises any exception raised by the prediction job.
        ft.result()

        if th.error_msg is not None:
            body = json.dumps({"error_msg": th.error_msg})
        else:
            data = {
                "predict_results": th.predict_results,
                "csv": th.csv_filename,
            }
            body = json.dumps(data)
    except Exception as e:
        traceback.print_exc()
        # e.args can be empty (e.g. ``raise SomeError()``); indexing it would
        # raise again inside this handler and escape as an unhandled error.
        body = json.dumps({"error_msg": e.args[0] if e.args else str(e)})

    ret = create_response(body)
    return ret
コード例 #5
0
ファイル: server.py プロジェクト: DaikiOnodera/ReNomIMG
def run_model(project_id, model_id):
    """
    Create thread(Future object) and submit it to executor.
    The thread is stored to train_thread_pool as a pair of thread_id and thread.

    Blocks until training ends. Afterwards the pool entry is removed, the
    model is marked FINISHED unless it was deleted meanwhile, and
    ``release_mem_pool()`` is called.

    Returns:
        ``create_response`` of ``{"error_msg": ...}`` if the training thread
        or this handler reported an error, otherwise ``{"dummy": ""}``.
    """
    try:
        fields = 'hyper_parameters,algorithm,algorithm_params,dataset_def_id'
        data = storage.fetch_model(project_id, model_id, fields=fields)
        thread_id = "{}_{}".format(project_id, model_id)
        th = TrainThread(thread_id, project_id, model_id,
                         data['dataset_def_id'], data["hyper_parameters"],
                         data['algorithm'], data['algorithm_params'])
        ft = executor.submit(th)
        train_thread_pool[thread_id] = [ft, th]

        try:
            # This will wait for end of thread.
            ft.result()
        except CancelledError:
            # If the model is deleted or stopped,
            # program reaches here.
            pass
        finally:
            # Always unregister, even if training raised, so the pool never
            # keeps a stale finished future for this id.
            train_thread_pool.pop(thread_id, None)
        error_msg = th.error_msg
        # Drop local references to the job before release_mem_pool() so its
        # resources are no longer reachable from here.
        ft = None
        th = None

        model = storage.fetch_model(project_id, model_id, fields='state')
        if model['state'] != STATE_DELETED:
            storage.update_model_state(model_id, STATE_FINISHED)
        release_mem_pool()

        if error_msg is not None:
            body = json.dumps({"error_msg": error_msg})
            ret = create_response(body)
            return ret
        body = json.dumps({"dummy": ""})
        ret = create_response(body)
        return ret

    except Exception as e:
        release_mem_pool()
        traceback.print_exc()
        # e.args can be empty; avoid raising IndexError inside the handler.
        body = json.dumps({"error_msg": e.args[0] if e.args else str(e)})
        ret = create_response(body)
        return ret
コード例 #6
0
ファイル: server.py プロジェクト: DaikiOnodera/ReNomIMG
def get_deployed_model_info(project_id):
    """Return settings and weight file name of a project's deployed model.

    This method will be called from a Python script.

    Returns:
        ``create_response`` of the record returned by ``storage.fetch_model``
        for fields ``algorithm,algorithm_params,hyper_parameters``, with
        ``"filename"`` set to the model's ``best_epoch_weight``; or
        ``{"error_msg": ...}`` on failure.
    """
    try:
        deployed_id = storage.fetch_deployed_model_id(
            project_id)[0]['deploy_model_id']
        ret = storage.fetch_model(project_id, deployed_id, "best_epoch_weight")
        file_name = ret['best_epoch_weight']
        ret = storage.fetch_model(
            project_id, deployed_id,
            "algorithm,algorithm_params,hyper_parameters")
        ret["filename"] = file_name
        body = json.dumps(ret)
        ret = create_response(body)
        return ret
    except Exception as e:
        traceback.print_exc()
        # e.args can be empty; avoid raising IndexError inside the handler.
        body = json.dumps({"error_msg": e.args[0] if e.args else str(e)})
        ret = create_response(body)
        return ret
コード例 #7
0
ファイル: server.py プロジェクト: clockfly/ReNomIMG
def export_csv(model_id):
    """Export the model's last prediction result as a downloadable CSV.

    Each row holds ``path``, ``size`` and ``predictions`` for one image,
    flattened with ``json_normalize``. For classification only the ``"class"``
    entry of each prediction is exported; detection and segmentation export
    the prediction as-is.

    Returns:
        ``static_file`` download of ``prediction.csv``, or an error response
        (HTTP 500) with ``{"error_msg": "<Type>: <message>"}``.
    """
    try:
        model = storage.fetch_model(model_id)
        prediction = model["last_prediction_result"]
        task_id = model["task_id"]

        supported_tasks = (Task.CLASSIFICATION.value, Task.DETECTION.value,
                           Task.SEGMENTATION.value)
        if task_id not in supported_tasks:
            raise Exception("Not supported task id.")
        is_classification = task_id == Task.CLASSIFICATION.value

        ret = []
        for img, size, pred in zip(prediction["img"], prediction["size"],
                                   prediction["prediction"]):
            ret.append({
                'path': img,
                'size': size,
                'predictions': pred["class"] if is_classification else pred
            })

        df = pd.DataFrame.from_dict(json_normalize(ret), orient='columns')
        # NOTE(review): a fixed file name in the working directory means
        # concurrent exports can overwrite each other — consider a temp file.
        df.to_csv('prediction.csv')
        return static_file("prediction.csv", root='.', download=True)

    except Exception as e:
        release_mem_pool()
        traceback.print_exc()
        body = json.dumps({"error_msg": "{}: {}".format(type(e).__name__, str(e))})
        ret = create_response(body, 500)
        return ret
コード例 #8
0
ファイル: server.py プロジェクト: DaikiOnodera/ReNomIMG
def pull_deployed_model(project_id):
    """Send the deployed model's best-epoch weight file as a download.

    This method will be called from a Python script. The file is read from
    ``DB_DIR_TRAINED_WEIGHT`` and offered as ``deployed_model.h5``.

    Returns:
        ``static_file`` response, or ``create_response`` of
        ``{"error_msg": ...}`` on failure.
    """
    try:
        deployed_id = storage.fetch_deployed_model_id(
            project_id)[0]['deploy_model_id']
        ret = storage.fetch_model(project_id, deployed_id, "best_epoch_weight")
        file_name = ret['best_epoch_weight']
        path = DB_DIR_TRAINED_WEIGHT
        return static_file(file_name, root=path, download='deployed_model.h5')
    except Exception as e:
        traceback.print_exc()
        # e.args can be empty; avoid raising IndexError inside the handler.
        body = json.dumps({"error_msg": e.args[0] if e.args else str(e)})
        ret = create_response(body)
        return ret
コード例 #9
0
ファイル: server.py プロジェクト: DaikiOnodera/ReNomIMG
def get_model(project_id, model_id):
    """Return a model record as JSON.

    The optional ``fields`` request parameter is passed to
    ``storage.fetch_model`` when it is non-empty.

    Returns:
        ``create_response`` of the fetched record, or ``{"error_msg": ...}``.
    """
    try:
        kwargs = {}
        if request.params.fields != '':
            kwargs["fields"] = request.params.fields

        data = storage.fetch_model(project_id, model_id, **kwargs)
        body = json.dumps(data)

    except Exception as e:
        traceback.print_exc()
        # e.args can be empty; avoid raising IndexError inside the handler.
        body = json.dumps({"error_msg": e.args[0] if e.args else str(e)})

    ret = create_response(body)
    return ret
コード例 #10
0
ファイル: server.py プロジェクト: DaikiOnodera/ReNomIMG
def delete_model(project_id, model_id):
    """Mark a model deleted, stop its training thread, and remove its weight.

    The steps are:
    1. Set the stored state to ``STATE_DELETED``.
    2. If a training job is registered under ``"<project_id>_<model_id>"``,
       cancel it, or, if it can no longer be cancelled, call ``stop()`` and
       wait for it.
    3. Remove the best-epoch weight file from ``DB_DIR_TRAINED_WEIGHT`` if
       it exists.

    Returns:
        ``None`` on success (no response body is built); on failure,
        ``create_response`` of ``{"error_msg": ...}``.
    """
    try:
        thread_id = "{}_{}".format(project_id, model_id)
        storage.update_model_state(model_id, STATE_DELETED)
        th = train_thread_pool.get(thread_id, None)
        if th is not None:
            # th is a [future, thread] pair; cancel() fails once it has started.
            if not th[0].cancel():
                th[1].stop()
                th[0].result()

        ret = storage.fetch_model(project_id, model_id, "best_epoch_weight")
        file_name = ret.get('best_epoch_weight', None)
        if file_name is not None:
            weight_path = os.path.join(DB_DIR_TRAINED_WEIGHT, file_name)
            if os.path.exists(weight_path):
                os.remove(weight_path)

    except Exception as e:
        traceback.print_exc()
        # e.args can be empty; avoid raising IndexError inside the handler.
        body = json.dumps({"error_msg": e.args[0] if e.args else str(e)})
        ret = create_response(body)
        return ret
コード例 #11
0
ファイル: server.py プロジェクト: DaikiOnodera/ReNomIMG
def progress_model(project_id, model_id):
    """Long-poll the training progress of a model.

    Checks up to 60 times, 0.75 s apart, and responds as soon as the training
    thread registered under ``"<project_id>_<model_id>"`` differs from what
    the client last saw:

    * a new epoch with validation losses: full progress including the best
      epoch's loss history, IoU, mAP and validation result;
    * a new batch, a changed running state, or a weight still downloading:
      batch-level progress with empty/zero epoch fields.

    Query parameters ``last_batch``, ``last_epoch`` and ``running_state`` are
    optional ints (default 0). If any fails to parse, all three fall back to 0.

    Returns:
        ``create_response`` of the progress JSON, ``{"error_msg": ...}`` on
        failure, or ``None`` if nothing changed within the polling window.
    """
    try:
        try:
            req_last_batch = request.params.get("last_batch", None)
            req_last_batch = int(
                req_last_batch) if req_last_batch is not None else 0
            req_last_epoch = request.params.get("last_epoch", None)
            req_last_epoch = int(
                req_last_epoch) if req_last_epoch is not None else 0
            req_running_state = request.params.get("running_state", None)
            req_running_state = int(
                req_running_state) if req_running_state is not None else 0
        except (TypeError, ValueError):
            req_last_batch = 0
            req_last_epoch = 0
            req_running_state = 0

        thread_id = "{}_{}".format(project_id, model_id)
        for _ in range(60):
            time.sleep(0.75)
            entry = train_thread_pool.get(thread_id, None)
            if entry is None:
                continue
            # Only query storage when there is a thread to report on.
            model_state = storage.fetch_model(project_id,
                                              model_id,
                                              fields="state")["state"]
            th = entry[1]
            if not isinstance(th, TrainThread):
                continue

            if th.nth_epoch != req_last_epoch and th.valid_loss_list:
                # A new epoch finished: report the full validation history.
                # Any failure here (e.g. metric lists shorter than the loss
                # list) propagates to the outer handler and is returned as an
                # error response instead of halting the server in a debugger.
                best_epoch = int(np.argmin(th.valid_loss_list))
                body = json.dumps({
                    "total_batch": th.total_batch,
                    "last_batch": th.nth_batch,
                    "last_epoch": th.nth_epoch,
                    "batch_loss": th.last_batch_loss,
                    "running_state": th.running_state,
                    "state": model_state,
                    "validation_loss_list": th.valid_loss_list,
                    "train_loss_list": th.train_loss_list,
                    "best_epoch": best_epoch,
                    "best_epoch_iou": th.valid_iou_list[best_epoch],
                    "best_epoch_map": th.valid_map_list[best_epoch],
                    "best_epoch_validation_result":
                    th.valid_predict_box[best_epoch]
                })
                return create_response(body)

            if (th.nth_batch != req_last_batch
                    or th.running_state != req_running_state
                    or th.weight_existance == WEIGHT_DOWNLOADING):
                body = json.dumps({
                    "total_batch": th.total_batch,
                    "last_batch": th.nth_batch,
                    "last_epoch": th.nth_epoch,
                    "batch_loss": th.last_batch_loss,
                    "running_state": th.running_state,
                    "state": model_state,
                    "validation_loss_list": [],
                    "train_loss_list": [],
                    "best_epoch": 0,
                    "best_epoch_iou": 0,
                    "best_epoch_map": 0,
                    "best_epoch_validation_result": []
                })
                return create_response(body)

    except Exception as e:
        traceback.print_exc()
        # e.args can be empty; avoid raising IndexError inside the handler.
        body = json.dumps({"error_msg": e.args[0] if e.args else str(e)})
        ret = create_response(body)
        return ret
コード例 #12
0
    def _prepare_params(self):
        """Load model, dataset and training settings for ``self.model_id``.

        If the stop event is already set, only ``self.updated`` is set to
        True and nothing is read from storage. Otherwise this method:

        * reads the model record and its dataset;
        * checks that every train/valid image path exists;
        * initialises the training counters and loss lists;
        * builds the augmentation pipeline.

        Raises:
            FileNotFoundError: if any train/valid image path does not exist.
            AssertionError: if any key in ``self.common_params`` is missing
                from the model's hyper parameters.
        """
        if self.stop_event.is_set():
            # Watch stop event
            self.updated = True
            return

        params = storage.fetch_model(self.model_id)
        self.task_id = int(params["task_id"])
        self.dataset_id = int(params["dataset_id"])
        self.algorithm_id = int(params["algorithm_id"])
        self.hyper_parameters = params["hyper_parameters"]
        self.last_weight_path = params["last_weight"]
        self.best_weight_path = params["best_epoch_weight"]

        dataset = storage.fetch_dataset(self.dataset_id)
        self.class_map = dataset["class_map"]
        self.train_img = dataset["train_data"]["img"]
        self.train_target = dataset["train_data"]["target"]
        self.valid_img = dataset["valid_data"]["img"]
        self.valid_target = dataset["valid_data"]["target"]

        for path in self.train_img + self.valid_img:
            if not os.path.exists(path):
                raise FileNotFoundError(
                    "The image file {} is not found.".format(path))

        self.train_dist = None
        self.valid_dist = None

        n_data = len(self.train_img)

        self.common_params = [
            'total_epoch', 'batch_size', 'imsize_w', 'imsize_h', 'train_whole',
            'load_pretrained_weight'
        ]

        # Explicit check rather than a bare ``assert`` so validation is not
        # stripped under ``python -O``; AssertionError is kept so callers see
        # the same exception type as before.
        missing = [k for k in self.common_params
                   if k not in self.hyper_parameters]
        if missing:
            raise AssertionError(
                "Missing hyper parameters: {}".format(", ".join(missing)))

        # Training States
        # TODO: Need getter for json decoding.
        self.load_pretrained_weight = bool(
            self.hyper_parameters["load_pretrained_weight"])
        self.train_whole = bool(self.hyper_parameters["train_whole"])
        self.imsize = (int(self.hyper_parameters["imsize_w"]),
                       int(self.hyper_parameters["imsize_h"]))
        self.batch_size = int(self.hyper_parameters["batch_size"])
        self.total_epoch = int(self.hyper_parameters["total_epoch"])
        self.nth_epoch = 0
        # Number of batches per epoch, counting a final partial batch.
        self.total_batch = int(np.ceil(n_data / self.batch_size))
        self.nth_batch = 0
        self.last_batch_loss = 0
        self.train_loss_list = []
        self.valid_loss_list = []
        self.best_epoch_valid_result = {}

        # Augmentation Setting.
        self.augmentation = Augmentation([
            Shift(10, 10),
            Rotate(),
            Flip(),
            ContrastNorm(),
        ])
コード例 #13
0
ファイル: server.py プロジェクト: clockfly/ReNomIMG
def polling_train(id):
    """Return the training progress of model ``id`` for client polling.

    Three cases:
      * No ``TrainThread`` registered for ``id``: read the model from
        storage, force its state to STOPPED (running_state STOPPING) if it is
        not already, and return the stored progress.
      * The thread is RESERVED or CREATED: wait in 1 s steps (at most 60) until
        it reports an update or leaves those states, then return zeroed
        progress together with the thread's current state.
      * Otherwise: wait in 0.5 s steps (at most 10) for an update, then return
        the thread's live progress.

    Cautions:
        Returns ``None`` (not a dict) when no thread exists and the model is
        not found in storage.
    """
    threads = TrainThread.jobs
    active_train_thread = threads.get(id, None)
    if active_train_thread is None:
        saved_model = storage.fetch_model(id)
        if saved_model is None:
            return

        # No live thread: force the stored state to STOPPED; the client stops
        # sending requests once it sees STOPPED.
        if saved_model["state"] != State.STOPPED.value:
            storage.update_model(id, state=State.STOPPED.value,
                                 running_state=RunningState.STOPPING.value)
            saved_model = storage.fetch_model(id)

        return {
            "state": saved_model["state"],
            "running_state": saved_model["running_state"],
            "total_epoch": saved_model["total_epoch"],
            "nth_epoch": saved_model["nth_epoch"],
            "total_batch": saved_model["total_batch"],
            "nth_batch": saved_model["nth_batch"],
            "last_batch_loss": saved_model["last_batch_loss"],
            "total_valid_batch": 0,
            "nth_valid_batch": 0,
            "best_result_changed": False,
            "train_loss_list": saved_model["train_loss_list"],
            "valid_loss_list": saved_model["valid_loss_list"],
        }
    elif active_train_thread.state == State.RESERVED or \
            active_train_thread.state == State.CREATED:

        for _ in range(60):
            if active_train_thread.state == State.RESERVED or \
                    active_train_thread.state == State.CREATED:
                time.sleep(1)
                # Stop waiting once the thread has something new to report.
                if active_train_thread.updated:
                    active_train_thread.returned2client()
                    break
            else:
                # State moved on (e.g. training started): one last pause, then reply.
                time.sleep(1)
                break

        active_train_thread.consume_error()
        return {
            "state": active_train_thread.state.value,
            "running_state": active_train_thread.running_state.value,
            "total_epoch": 0,
            "nth_epoch": 0,
            "total_batch": 0,
            "nth_batch": 0,
            "last_batch_loss": 0,
            "total_valid_batch": 0,
            "nth_valid_batch": 0,
            "best_result_changed": False,
            "train_loss_list": [],
            "valid_loss_list": [],
        }
    else:
        for _ in range(10):
            time.sleep(0.5)  # Avoid many request.
            if active_train_thread.updated:
                break
            # NOTE(review): consume_error() runs on every non-updated
            # iteration here but only once after the loop in the branch
            # above — confirm which placement is intended.
            active_train_thread.consume_error()
        active_train_thread.returned2client()
        return {
            "state": active_train_thread.state.value,
            "running_state": active_train_thread.running_state.value,
            "total_epoch": active_train_thread.total_epoch,
            "nth_epoch": active_train_thread.nth_epoch,
            "total_batch": active_train_thread.total_batch,
            "nth_batch": active_train_thread.nth_batch,
            "last_batch_loss": active_train_thread.last_batch_loss,
            "total_valid_batch": 0,
            "nth_valid_batch": 0,
            "best_result_changed": active_train_thread.best_valid_changed,
            "train_loss_list": active_train_thread.train_loss_list,
            "valid_loss_list": active_train_thread.valid_loss_list,
        }