예제 #1
0
    def _prepare_module(task_id, symbol, ctx_config, data_names, label_names,
                        resume_config):
        """Build an MXNet ``Module`` for a task, optionally restoring params.

        Args:
            task_id: identifier of the task; only used to name the logger.
            symbol: network symbol handed to ``Module``.
            ctx_config: context description forwarded to
                ``Executor._prepare_ctx``.
            data_names: names of the data inputs of ``symbol``.
            label_names: names of the label inputs of ``symbol``.
            resume_config: dict with a string flag ``'is_resume'``. ``'0'``
                means "start from scratch". Any other value means "restore
                from a checkpoint" described by ``resume_config['ckp']``,
                a dict with ``'prefix'`` and ``'epoch'``.

        Returns:
            A ``Module``. When resuming, its arg/aux params are pre-populated
            from ``<params_root_path>/<prefix>-<epoch:04d>.params`` and
            ``params_initialized`` is set.
        """
        # '0' means "not resuming": build an empty Module. (The former
        # `not ... == '0'` test had the two branches swapped.)
        if resume_config['is_resume'] == '0':
            return Module(symbol=symbol,
                          context=Executor._prepare_ctx(ctx_config),
                          data_names=data_names,
                          label_names=label_names,
                          logger=get_logger('mxnet_logger[tid=%s]' % task_id,
                                            log_to_console=False,
                                            log_to_file=True))

        ckp = resume_config['ckp']
        prefix = ckp['prefix']
        epoch = ckp['epoch']
        # NOTE(review): params_root_path is not defined in this function;
        # presumably a module-level setting — confirm.
        params_path = osp.join(params_root_path,
                               '%s-%04d.params' % (prefix, epoch))
        # Copied from MXNet

        # Licensed to the Apache Software Foundation (ASF) under one
        # or more contributor license agreements.  See the NOTICE file
        # distributed with this work for additional information
        # regarding copyright ownership.  The ASF licenses this file
        # to you under the Apache License, Version 2.0 (the
        # "License"); you may not use this file except in compliance
        # with the License.  You may obtain a copy of the License at
        #
        #   http://www.apache.org/licenses/LICENSE-2.0
        #
        # Unless required by applicable law or agreed to in writing,
        # software distributed under the License is distributed on an
        # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
        # KIND, either express or implied.  See the License for the
        # specific language governing permissions and limitations
        # under the License.
        save_dict = nd.load(params_path)
        arg_params = {}
        aux_params = {}
        # Saved keys look like 'arg:<name>' or 'aux:<name>'; route each
        # array into the matching dict by its prefix.
        for key, value in save_dict.items():
            tp, name = key.split(':', 1)
            if tp == 'arg':
                arg_params[name] = value
            elif tp == 'aux':
                aux_params[name] = value
        # Pass data/label names here too; otherwise Module falls back to its
        # own defaults, which may not match this symbol's inputs.
        mod = Module(symbol=symbol,
                     context=Executor._prepare_ctx(ctx_config),
                     data_names=data_names,
                     label_names=label_names,
                     logger=get_logger('mxnet_logger[tid=%s]' % task_id,
                                       log_to_console=False,
                                       log_to_file=True))
        mod._arg_params = arg_params
        mod._aux_params = aux_params
        mod.params_initialized = True
        # TODO: There is a parameter named load_optimizer_states in Module.load
        return mod
예제 #2
0
    def load_check_point(sym_json_path, params_path, ctx_config_tuple,
                         task_id):
        """Rebuild an MXNet ``Module`` from a symbol and a saved params file.

        Args:
            sym_json_path: path of a symbol JSON file, or an already-built
                ``mxnet.sym.Symbol``, which is then used unchanged.
            params_path: path of the parameter file read with ``nd.load``.
            ctx_config_tuple: context description; converted to a list and
                handed to ``generate_ctx``.
            task_id: identifier used only in the logger name.

        Returns:
            A ``Module`` whose arg/aux params are pre-populated from the file
            and whose ``params_initialized`` flag is set.
        """
        ctx_config = list(ctx_config_tuple)
        # Copied from MXNet

        # Licensed to the Apache Software Foundation (ASF) under one
        # or more contributor license agreements.  See the NOTICE file
        # distributed with this work for additional information
        # regarding copyright ownership.  The ASF licenses this file
        # to you under the Apache License, Version 2.0 (the
        # "License"); you may not use this file except in compliance
        # with the License.  You may obtain a copy of the License at
        #
        #   http://www.apache.org/licenses/LICENSE-2.0
        #
        # Unless required by applicable law or agreed to in writing,
        # software distributed under the License is distributed on an
        # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
        # KIND, either express or implied.  See the License for the
        # specific language governing permissions and limitations
        # under the License.
        symbol = (sym_json_path if isinstance(sym_json_path, sym.Symbol)
                  else sym.load(sym_json_path))

        # Keys are '<kind>:<name>'; only the 'arg' and 'aux' kinds are kept.
        buckets = {'arg': {}, 'aux': {}}
        for full_name, array in nd.load(params_path).items():
            kind, param_name = full_name.split(':', 1)
            if kind in buckets:
                buckets[kind][param_name] = array

        context = generate_ctx(ctx_config)
        logger = get_logger('mxnet_logger[tid=%s]' % task_id,
                            log_to_console=False,
                            log_to_file=True)
        mod = Module(symbol=symbol, context=context, logger=logger)
        mod._arg_params = buckets['arg']
        mod._aux_params = buckets['aux']
        mod.params_initialized = True
        # TODO: There is a parameter named load_optimizer_states in Module.load
        return mod
예제 #3
0
파일: run.py 프로젝트: likesundayl/mxserver
import grpc

# Derive the project root from the launch directory: keep the part of
# sys.path[0] that precedes 'flask_server' and append it to sys.path so the
# project-local imports that follow can be resolved.
# NOTE(review): str.index raises ValueError when 'flask_server' does not
# occur in sys.path[0] (e.g. launched from elsewhere); `sys` is presumably
# imported above this view — confirm both.
current_dir = sys.path[0]
index = current_dir.index('flask_server')
module_dir = current_dir[0:index]
sys.path.append(module_dir)

from worker.proto import mxserver_pb2, mxserver_pb2_grpc
from util.conf_parser import mxserver_flask_config
from util.exception_handler import exception_msg
from util.logger_generator import get_logger
from worker.gpu import gpu_monitor
from dispatcher import Dispatcher
from flask_server.zk_register import ZkRegister

# Module-level logger used by the route handlers below (e.g. train).
mxserver_flask_logger = get_logger('mxserver_flask_server')

# Create the dispatcher once at import time and log which type was chosen.
dispatcher = Dispatcher.create_dispatcher()
mxserver_flask_logger.info(
    'The mxserver flask server has created a dispatcher with type: %s' %
    dispatcher.type())

# NOTE(review): Flask is not imported in the visible part of this file —
# presumably imported above; confirm.
app = Flask('MXServer-Flask-Server')


@app.route('/train', methods=['POST'])
def train():
    task_json = request.json
    task_id = task_json['task_id']
    mxserver_flask_logger.info(
        'The mxserver_flask_server receives a request to start a train task with id: %s'
예제 #4
0
 def __init__(self, task_queue):
     """Keep a handle on *task_queue* and set up logging and recorders.

     Args:
         task_queue: queue this service puts work onto (presumably the
             shared queue of executor processes — TODO confirm).
     """
     self._logger = get_logger('mxnet_service')
     # Plain state first; the recorder objects are built afterwards in the
     # same relative order as before.
     self._queue = task_queue
     self._task_dict = {}  # presumably task_id -> task handle; verify
     self._task_config_recorder = TaskConfigRecorder()
     self._user_action_recorder = UserActionRecorder()
예제 #5
0
# -*- coding: utf-8 -*-

# @Author: Terence Wu
# @Time: 26/02/18 上午 11:37
from requests import post
from test_resources import STOP_TEST_URL, STOP_REQUEST_JSON
from util.logger_generator import get_logger
from util.exception_handler import exception_msg


if __name__ == '__main__':
    # Manual smoke test: POST the stop request and log the response.
    logger = get_logger('test_stop_request')
    logger.info('Begin to test API for stopping deep learning training')
    logger.info('Begin to send request to url: %s' % STOP_TEST_URL)
    try:
        response = post(url=STOP_TEST_URL, json=STOP_REQUEST_JSON)
    except Exception as e:
        # Exception, not BaseException: Ctrl-C / sys.exit must still stop
        # the script instead of being logged as a request failure.
        logger.error('Fail! Error message: %s\n' % exception_msg(e))
    else:
        logger.info('Receive a response')
        logger.info('Response\'s status code: %s' % response.status_code)
        logger.info('Response\'s content: %s' % response.content)
예제 #6
0
# -*- coding: utf-8 -*-

# ------------------------------
# Copyright (c) 2017-present Terence Wu
# ------------------------------
import time
from Queue import Empty
from multiprocessing import Process

from util.logger_generator import get_logger

# Module-level logger used by ExecutorProcessManager.run.
_logger = get_logger('executor_process_manager')


class ExecutorProcessManager(Process):
    def __init__(self, task_queue):
        """Create the manager process that drains *task_queue*.

        Args:
            task_queue: queue whose items ``run`` pulls with ``get_nowait``,
                reads ``task_id()`` from, and ``start()``s (ExecutorProcess
                instances, per the log message in ``run``).
        """
        super(ExecutorProcessManager, self).__init__()
        self._task_queue = task_queue

    def run(self):
        _logger.info('The executor_process_manager has been started')
        while True:
            try:
                executor_process = self._task_queue.get_nowait()
                task_id = executor_process.task_id()
                _logger.info(
                    'executor_process_manager gets an ExecutorProcess instance with task_id: %s from task queue, '
                    'now start it' % task_id)
                executor_process.start()
            except Empty:
# -*- coding: utf-8 -*-

# @Author: Terence Wu
# @Time: 26/02/18 上午 11:42
from requests import post
from test_resources import EVALUATE_TEST_URL, TEST_REQUEST_JSON
from util.logger_generator import get_logger
from util.exception_handler import exception_msg

if __name__ == '__main__':
    # Manual smoke test: POST the inference request and log the response.
    logger = get_logger('test_inference_request')
    logger.info('Begin to test API for deep learning inference')
    logger.info('Begin to send request to url: %s' % EVALUATE_TEST_URL)
    try:
        response = post(url=EVALUATE_TEST_URL, json=TEST_REQUEST_JSON)
    except Exception as e:
        # Exception, not BaseException: Ctrl-C / sys.exit must still stop
        # the script instead of being logged as a request failure.
        logger.error('Fail! Error message: %s\n' % exception_msg(e))
    else:
        logger.info('Receive a response')
        logger.info('Response\'s status code: %s' % response.status_code)
        logger.info('Response\'s content: %s' % response.content)
예제 #8
0
# -*- coding: utf-8 -*-

# @Author: Terence Wu
# @Time: 26/02/18 上午 11:36
from requests import post
from test_resources import TRAIN_TEST_URL, TRAIN_REQUEST_JSON
from util.logger_generator import get_logger
from util.exception_handler import exception_msg


if __name__ == '__main__':
    # Manual smoke test: POST the training request and log the response.
    logger = get_logger('test_train_request')
    logger.info('Begin to test API for deep learning training')
    logger.info('Begin to send request to url: %s' % TRAIN_TEST_URL)
    try:
        response = post(url=TRAIN_TEST_URL, json=TRAIN_REQUEST_JSON)
    except Exception as e:
        # Exception, not BaseException: Ctrl-C / sys.exit must still stop
        # the script instead of being logged as a request failure.
        logger.error('Fail! Error message: %s\n' % exception_msg(e))
    else:
        logger.info('Receive a response')
        logger.info('Response\'s status code: %s' % response.status_code)
        logger.info('Response\'s content: %s' % response.content)
예제 #9
0
# -*- coding: utf-8 -*-

# ------------------------------
# Copyright (c) 2017-present Terence Wu
# ------------------------------
from multiprocessing import Process

from util.exception_handler import exception_msg
from util.logger_generator import get_logger
from util.time_getter import get_time
from worker.db.mongo_connector import TaskProgressRecorder
from worker.mxnet_extension.core.executor import Executor
from worker.mxnet_extension.io.data_manager import DataManager
from worker.task_desc_parser import parse_task_desc, get_data_config

# Module-level logger for the executor_process module.
_logger = get_logger('executor_process')


class ExecutorProcess(Process):
    def __init__(self, process_id, task_desc):
        """Prepare a child process for one task.

        Args:
            process_id: id of the task; returned by ``task_id()``.
            task_desc: task description (presumably consumed by
                ``parse_task_desc`` in ``run`` — TODO confirm).
        """
        super(ExecutorProcess, self).__init__()
        self._process_id = process_id
        self._task_desc = task_desc
        self._task_progress_list = []
        self._task_progress_recorder = TaskProgressRecorder()

    def task_id(self):
        """Return the task id, i.e. the ``process_id`` given to ``__init__``."""
        return self._process_id

    def run(self):
        self._task_progress_recorder.insert_one({
예제 #10
0
파일: run.py 프로젝트: likesundayl/mxserver
# Make the project root importable before the project-local imports below.
# NOTE(review): module_dir is computed earlier in the file, outside this
# view (presumably the project root derived from sys.path[0]) — confirm.
sys.path.append(module_dir)

from util.logger_generator import get_logger
from worker.mxnet_extension.core.executor_process_manager import ExecutorProcessManager
from worker.proto import mxserver_pb2_grpc
from worker.rpc.mxnet_service import MXNetService
from util.conf_parser import mxserver_mxnet_config, mxserver_rpc_config, mxserver_task_queue_config
from util.exception_handler import exception_msg
from worker.zk_register import ZkRegister

# Add rcnn package to sys.path
sys.path.append(mxserver_mxnet_config['rcnn-path'])
# Function-call form prints exactly the same as the old `print sys.path`
# statement on Python 2, and is also valid syntax on Python 3.
print(sys.path)

if __name__ == '__main__':
    main_logger = get_logger('mxserver_worker_logger')
    try:
        if ZkRegister.use_zk():
            main_logger.info('The mxserver worker is trying to register to ZooKeeper')
        zk_register = ZkRegister()
        zk_register.register_worker_to_zk()
        if ZkRegister.use_zk():
            main_logger.info('The mxserver worker has registered to ZooKeeper')
    except BaseException as e:
        main_logger.error('The mxserver worker can not register to ZooKeeper! System exists! Error message: \n%s'
                          % exception_msg(e))
        sys.exit('Failed to register to ZooKeeper')
    task_queue = Queue(int(mxserver_task_queue_config['queue-max-size']))
    try:
        executor_process_manager = ExecutorProcessManager(task_queue=task_queue)
        executor_process_manager.start()
예제 #11
0
# -*- coding: utf-8 -*-

# @Author: Terence Wu
# @Time: 12/03/18 上午 09:34
from requests import post
from config import FLASK_HOST, FLASK_PORT
from util.logger_generator import get_logger

# NOTE(review): presumably the id of the deployed model to classify with;
# it is empty here and must be filled in before running — confirm.
DEPLOY_ID = ''

if __name__ == '__main__':
    logger = get_logger('classify_logger')
    # The format arguments must be parenthesized: without them `%` binds
    # first, giving `('...' % FLASK_HOST), FLASK_PORT` -- a TypeError for a
    # string host (two %s, one argument) or, at best, a tuple, not a URL.
    url = 'http://%s:%s/classify' % (FLASK_HOST, FLASK_PORT)