Пример #1
0
def test_config_file_only_with_invalid_text_data_unknown_value(config_file):
    """A non-boolean ``daemon`` value in the config file is rejected."""
    config_file.write("daemon: unknown\n")
    config_file.flush()
    cli_argv = ['-c', config_file.name]
    with pytest.raises(libs.MultipleConfigSources.ConfigError) as excinfo:
        libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    expected_message = "Config file option daemon must be True or False."
    assert str(excinfo.value) == expected_message
def test_config_file_only_with_invalid_text_data_not_yaml(config_file):
    """A config file whose contents do not load as a dict is rejected."""
    config_file.write("daemon\n")
    config_file.flush()
    cli_argv = ["-c", config_file.name]
    with pytest.raises(libs.MultipleConfigSources.ConfigError) as excinfo:
        libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    expected_message = "Config file %s contents didn't yield dict or not YAML: daemon" % config_file.name
    assert str(excinfo.value) == expected_message
def test_config_file_only_tab_character(config_file):
    """A tab character anywhere in the config file is reported as an error."""
    config_file.write("domain: mydomain.com\nuser:\tthisuser\npasswd: abc")
    config_file.flush()
    cli_argv = ["-c", config_file.name]
    with pytest.raises(libs.MultipleConfigSources.ConfigError) as excinfo:
        libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    expected_message = "Tab character found in config file %s. Must use spaces only!" % config_file.name
    assert str(excinfo.value) == expected_message
def test_config_file_only_with_invalid_binary_data(config_file):
    """Random binary content in the config file is rejected as invalid data."""
    junk = os.urandom(1024)
    config_file.write(junk)
    config_file.flush()
    cli_argv = ["-c", config_file.name]
    with pytest.raises(libs.MultipleConfigSources.ConfigError) as excinfo:
        libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    expected_message = "Unable to read config file %s, invalid data." % config_file.name
    assert str(excinfo.value) == expected_message
Пример #5
0
def test_config_file_and_cli_overlapping_with_incomplete_data(config_file):
    """File + CLI together still lack a password, so validation fails."""
    config_file.write("domain: mydomain3.com")
    config_file.flush()
    cli_argv = ['-c', config_file.name, '-n', 'abc.com', '-u', 'usera']
    with pytest.raises(libs.MultipleConfigSources.ConfigError) as excinfo:
        libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    expected_message = "A domain, username, and password must be specified."
    assert str(excinfo.value) == expected_message
def test_config_file_only_with_invalid_text_data_unknown_option(config_file):
    """An option name the program does not know is reported with the file name."""
    config_file.write("test: true\n")
    config_file.flush()
    cli_argv = ["-c", config_file.name]
    with pytest.raises(libs.MultipleConfigSources.ConfigError) as excinfo:
        libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    expected_message = "Unknown option test in config file %s." % config_file.name
    assert str(excinfo.value) == expected_message
def test_config_file_only_with_invalid_text_data_unknown_value(config_file):
    """``daemon: unknown`` is not a boolean and must be rejected."""
    config_file.write("daemon: unknown\n")
    config_file.flush()
    cli_argv = ["-c", config_file.name]
    with pytest.raises(libs.MultipleConfigSources.ConfigError) as excinfo:
        libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    assert str(excinfo.value) == "Config file option daemon must be True or False."
Пример #8
0
def test_config_file_only_with_invalid_text_data_unknown_option(config_file):
    """Unrecognised option ``test`` in the config file raises ConfigError."""
    config_file.write("test: true\n")
    config_file.flush()
    cli_argv = ['-c', config_file.name]
    with pytest.raises(libs.MultipleConfigSources.ConfigError) as excinfo:
        libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    message = str(excinfo.value)
    assert message == "Unknown option test in config file %s." % config_file.name
Пример #9
0
def test_config_file_only_with_invalid_binary_data(config_file):
    """1 KiB of random bytes cannot be read as a config file."""
    random_payload = os.urandom(1024)
    config_file.write(random_payload)
    config_file.flush()
    cli_argv = ['-c', config_file.name]
    with pytest.raises(libs.MultipleConfigSources.ConfigError) as excinfo:
        libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    message = str(excinfo.value)
    assert message == "Unable to read config file %s, invalid data." % config_file.name
Пример #10
0
def test_config_file_only_tab_character(config_file):
    """Tabs are refused in the config file; only spaces are allowed."""
    config_file.write("domain: mydomain.com\nuser:\tthisuser\npasswd: abc")
    config_file.flush()
    cli_argv = ['-c', config_file.name]
    with pytest.raises(libs.MultipleConfigSources.ConfigError) as excinfo:
        libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    message = str(excinfo.value)
    assert message == "Tab character found in config file %s. Must use spaces only!" % config_file.name
Пример #11
0
def test_config_file_only_with_invalid_text_data_not_yaml(config_file):
    """A bare scalar (``daemon``) is not a mapping and is rejected."""
    config_file.write("daemon\n")
    config_file.flush()
    cli_argv = ['-c', config_file.name]
    with pytest.raises(libs.MultipleConfigSources.ConfigError) as excinfo:
        libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    message = str(excinfo.value)
    assert message == "Config file %s contents didn't yield dict or not YAML: daemon" % config_file.name
Пример #12
0
def test_cli_interval_fail():
    """``-i`` must be numeric and strictly positive."""
    required = ['-n', 'test.com', '-p', 'testpw', '-u', 'testuser', '-i']
    cases = (
        ('shouldBeNum', "Config option 'interval' must be a number."),
        ('0', "Config option 'interval' must be greater than 0."),
    )
    for interval, expected_message in cases:
        cli_argv = required + [interval]
        with pytest.raises(libs.MultipleConfigSources.ConfigError) as excinfo:
            libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
        assert str(excinfo.value) == expected_message
Пример #13
0
def test_config_file_only_with_invalid_text_data_not_yaml_big(config_file):
    """Multi-line ``key value`` text (missing colons) is reported as non-YAML."""
    not_yaml = """
        domain mydomain.com  # i am a comment
        user thisuser#comment
        #another comment
        passwd abc"
    """
    config_file.write(not_yaml)
    config_file.flush()
    cli_argv = ['-c', config_file.name]
    with pytest.raises(libs.MultipleConfigSources.ConfigError) as excinfo:
        libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    expected_prefix = "Config file %s contents not YAML formatted:" % config_file.name
    assert expected_prefix in str(excinfo.value)
def test_config_file_only_with_invalid_text_data_not_yaml_big(config_file):
    """Colon-less ``key value`` lines make the config file fail YAML parsing."""
    malformed = """
        domain mydomain.com  # i am a comment
        user thisuser#comment
        #another comment
        passwd abc"
    """
    config_file.write(malformed)
    config_file.flush()
    cli_argv = ["-c", config_file.name]
    with pytest.raises(libs.MultipleConfigSources.ConfigError) as excinfo:
        libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    message = str(excinfo.value)
    assert ("Config file %s contents not YAML formatted:" % config_file.name) in message
Пример #15
0
def test_quiet_logfile_verbose(capsys, log_file):
    """--quiet + --verbose + --log: only the file handler is installed.

    Nothing reaches stdout/stderr, and every level, DEBUG included, is
    written to the log file.
    """
    cli_argv = ["-n", "x", "-u", "x", "-p", "x", "--quiet", "--verbose", "--log", log_file.name]
    config = libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    with libs.LoggingSetup(config["verbose"], config["log"], config["quiet"]) as setup:
        logging.config.fileConfig(setup.config)  # Setup logging.
    root_handlers = logging.getLogger().handlers
    assert len(root_handlers) == 1
    assert isinstance(root_handlers[0], libs.LoggingSetup.TimedRotatingFileHandler)

    timestamp = log_samples()
    captured_out, captured_err = capsys.readouterr()
    assert captured_out == ""
    assert captured_err == ""

    line_templates = (
        "%s DEBUG    root                           Test debug testing.\n",
        "%s INFO     root                           Test info testing.\n",
        "%s WARNING  root                           Test warn testing.\n",
        "%s ERROR    root                           Test error testing.\n",
        "%s CRITICAL root                           Test critical testing.\n",
    )
    written = log_file.read(1024)
    assert written == "".join(template % timestamp for template in line_templates)
Пример #16
0
def test_logfile(capsys, log_file):
    """Default verbosity with --log: INFO and above go to console and file.

    INFO is echoed on stdout, WARNING and above on stderr, and all of them
    are written to the log file.
    """
    cli_argv = ['-n', 'x', '-u', 'x', '-p', 'x', '--log', log_file.name]
    config = libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    with libs.LoggingSetup(config['verbose'], config['log'], config['quiet']) as setup:
        logging.config.fileConfig(setup.config)  # Setup logging.
    root_handlers = logging.getLogger().handlers
    assert len(root_handlers) == 2
    assert isinstance(root_handlers[0], libs.LoggingSetup.ConsoleHandler)
    assert isinstance(root_handlers[1], libs.LoggingSetup.TimedRotatingFileHandler)

    timestamp = log_samples()
    captured_out, captured_err = capsys.readouterr()
    assert captured_out == "Test info testing.\n"
    assert captured_err == "Test warn testing.\nTest error testing.\nTest critical testing.\n"

    line_templates = (
        "%s INFO     root                           Test info testing.\n",
        "%s WARNING  root                           Test warn testing.\n",
        "%s ERROR    root                           Test error testing.\n",
        "%s CRITICAL root                           Test critical testing.\n",
    )
    written = log_file.read(1024)
    assert written == "".join(template % timestamp for template in line_templates)
Пример #17
0
def test_logfile_multiple_loggers(capsys, log_file):
    """Records from the root logger and a named logger both reach console and file."""
    cli_argv = ["-n", "x", "-u", "x", "-p", "x", "--log", log_file.name]
    config = libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    with libs.LoggingSetup(config["verbose"], config["log"], config["quiet"]) as setup:
        logging.config.fileConfig(setup.config)  # Setup logging.
    root_handlers = logging.getLogger().handlers
    assert len(root_handlers) == 2
    assert isinstance(root_handlers[0], libs.LoggingSetup.ConsoleHandler)
    assert isinstance(root_handlers[1], libs.LoggingSetup.TimedRotatingFileHandler)

    timestamp = log_samples()
    # Pause so the named logger's records carry a different timestamp.
    time.sleep(1)
    timestamp_named = log_samples_named()
    captured_out, captured_err = capsys.readouterr()
    one_round_err = "Test warn testing.\nTest error testing.\nTest critical testing.\n"
    assert captured_out == "Test info testing.\nTest info testing.\n"
    assert captured_err == one_round_err + one_round_err

    root_templates = (
        "%s INFO     root                           Test info testing.\n",
        "%s WARNING  root                           Test warn testing.\n",
        "%s ERROR    root                           Test error testing.\n",
        "%s CRITICAL root                           Test critical testing.\n",
    )
    named_templates = (
        "%s INFO     test_logging                   Test info testing.\n",
        "%s WARNING  test_logging                   Test warn testing.\n",
        "%s ERROR    test_logging                   Test error testing.\n",
        "%s CRITICAL test_logging                   Test critical testing.\n",
    )
    written = log_file.read(1024)
    expected_log = "".join(t % timestamp for t in root_templates)
    expected_log += "".join(t % timestamp_named for t in named_templates)
    assert written == expected_log
def test_config_file_only_with_full_valid_data_and_comments(config_file):
    """YAML comments and extra spacing in the config file are tolerated.

    Options absent from both the file and the CLI keep the defaults listed
    in ``expected``.
    """
    config_file.write(
        """
        domain:    mydomain.com  # i am a comment
        user: thisuser #comment
        #another comment
        passwd: abc
    """
    )
    config_file.flush()
    argv = ["-c", config_file.name]
    expected = dict(
        log=None,
        daemon=False,
        verbose=False,
        interval=60,
        pid=None,
        quiet=False,
        version=False,
        registrar="name.com",
        config=config_file.name,
        help=False,
        # Must match the ``user:`` value written to the file above.
        user="thisuser",
        passwd="abc",
        domain="mydomain.com",
    )
    actual = libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=argv))
    assert expected == actual
Пример #19
0
def test_config_file_only_missing_log_value(config_file):
    """``log:`` followed only by a comment yields ``None`` for the log option."""
    config_file.write(
        "domain: mydomain.com\nuser: thisuser\npasswd: abc\nlog: #True\n")
    config_file.flush()
    argv = ['-c', config_file.name]
    config = libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=argv))
    # Identity check is the idiomatic (PEP 8) way to test for None.
    assert config['log'] is None
Пример #20
0
def test_cli_pass():
    """Domain, user and password given on the CLI alone produce a full config.

    Every other option keeps the default listed in ``expected``.
    """
    argv = ['-n', 'test.com', '-p', 'testpw', '-u', 'testuser']
    expected = dict(log=None, daemon=False, verbose=False, interval=60, pid=None, quiet=False, version=False,
                    registrar='name.com', config=None, help=False,
                    # Must match the ``-u`` value passed above.
                    user='testuser', passwd='testpw', domain='test.com')
    actual = libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=argv))
    assert expected == actual
Пример #21
0
 def __load_config(self):
     '''Read the config file and apply every limit-server setting.

     Stores the integer options qps, qpd, max_user and port from the
     "limit_server" section of ``libs.get_config()`` on this instance.
     '''
     settings = libs.get_config()
     section = "limit_server"
     self.__qps = settings.getint(section, "qps")
     self.__qpd = settings.getint(section, "qpd")
     self.__max_user = settings.getint(section, "max_user")
     self.__port = settings.getint(section, "port")
Пример #22
0
 def __init__(self):
     '''Connect to Redis using the "redis" section of the project config.

     Reads host, port and db_index to open the connection, and stores the
     "prefix" option on the instance.
     '''
     settings = libs.get_config()
     section = 'redis'
     host, port, db_index = (
         settings.get(section, 'host'),
         settings.getint(section, 'port'),
         settings.getint(section, 'db_index'),
     )
     self.__prefix = settings.get(section, 'prefix')
     self.__super = super(KeywordMap, self)
     self.__super.__init__(host=host, port=port, db=db_index)
Пример #23
0
def test_config_file_and_cli_complimentary_with_full_valid_data(config_file):
    """Domain from the file plus user/password from the CLI merge into one config."""
    config_file.write("domain: mydomain.com")
    config_file.flush()
    argv = ['-c', config_file.name, '-u', 'usera', '-p', 'pass']
    expected = dict(log=None, daemon=False, verbose=False, interval=60, pid=None, quiet=False, version=False,
                    registrar='name.com', config=config_file.name, help=False,
                    # Must match the ``-u`` value passed above.
                    user='usera', passwd='pass', domain='mydomain.com')
    actual = libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=argv))
    assert expected == actual
Пример #24
0
def test_quiet(capsys):
    """--quiet installs only a NullHandler, so nothing is printed."""
    cli_argv = ["-n", "x", "-u", "x", "-p", "x", "--quiet"]
    config = libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    with libs.LoggingSetup(config["verbose"], config["log"], config["quiet"]) as setup:
        logging.config.fileConfig(setup.config)  # Setup logging.
    root_handlers = logging.getLogger().handlers
    assert len(root_handlers) == 1
    assert isinstance(root_handlers[0], libs.LoggingSetup.NullHandler)

    log_samples()
    captured_out, captured_err = capsys.readouterr()
    assert captured_out == ""
    assert captured_err == ""
Пример #25
0
def test_verbose(capsys):
    """--verbose adds DEBUG output to stdout; WARNING and above go to stderr."""
    cli_argv = ["-n", "x", "-u", "x", "-p", "x", "--verbose"]
    config = libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    with libs.LoggingSetup(config["verbose"], config["log"], config["quiet"]) as setup:
        logging.config.fileConfig(setup.config)  # Setup logging.
    root_handlers = logging.getLogger().handlers
    assert len(root_handlers) == 1
    assert isinstance(root_handlers[0], libs.LoggingSetup.ConsoleHandler)

    log_samples()
    captured_out, captured_err = capsys.readouterr()
    assert captured_out == "Test debug testing.\nTest info testing.\n"
    assert captured_err == "Test warn testing.\nTest error testing.\nTest critical testing.\n"
Пример #26
0
def test_default(capsys):
    """Default settings: INFO on stdout, WARNING and above on stderr, no DEBUG."""
    cli_argv = ['-n', 'x', '-u', 'x', '-p', 'x']
    config = libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    with libs.LoggingSetup(config['verbose'], config['log'], config['quiet']) as setup:
        logging.config.fileConfig(setup.config)  # Setup logging.
    root_handlers = logging.getLogger().handlers
    assert len(root_handlers) == 1
    assert isinstance(root_handlers[0], libs.LoggingSetup.ConsoleHandler)

    log_samples()
    captured_out, captured_err = capsys.readouterr()
    assert captured_out == "Test info testing.\n"
    assert captured_err == "Test warn testing.\nTest error testing.\nTest critical testing.\n"
Пример #27
0
def test_quiet(capsys):
    """With --quiet the root logger only has a NullHandler and stays silent."""
    cli_argv = ['-n', 'x', '-u', 'x', '-p', 'x', '--quiet']
    config = libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    with libs.LoggingSetup(config['verbose'], config['log'], config['quiet']) as setup:
        logging.config.fileConfig(setup.config)  # Setup logging.
    root_handlers = logging.getLogger().handlers
    assert len(root_handlers) == 1
    assert isinstance(root_handlers[0], libs.LoggingSetup.NullHandler)

    log_samples()
    captured_out, captured_err = capsys.readouterr()
    assert captured_out == ""
    assert captured_err == ""
Пример #28
0
def test_config_file_only_with_full_valid_data(config_file):
    """A config file with domain/user/passwd alone yields a complete config."""
    config_file.write("domain: mydomain.com\nuser: thisuser\npasswd: abc")
    config_file.flush()
    argv = ['-c', config_file.name]
    expected = dict(log=None,
                    daemon=False,
                    verbose=False,
                    interval=60,
                    pid=None,
                    quiet=False,
                    version=False,
                    registrar='name.com',
                    config=config_file.name,
                    help=False,
                    # Must match the ``user:`` value written to the file above.
                    user='thisuser',
                    passwd='abc',
                    domain='mydomain.com')
    actual = libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=argv))
    assert expected == actual
Пример #29
0
def run_server():
    '''Initialize and run the HTTP server.

    Registers the /ad, /cid and /trace handlers, binds to the port from the
    "ad_server" section of the config, forks worker processes, and then
    blocks in the IOLoop.
    '''
    # Three URIs are served.
    application = tornado.web.Application([
        (r"/ad", ad_handler.ADHandler),
        (r"/cid", cid_handler.CategoryHandler),
        (r"/trace", trace_handler.TraceHandler),
    ])

    config = libs.get_config()
    port = config.getint('ad_server', 'port')
    server = tornado.httpserver.HTTPServer(application)
    server.bind(port)
    # Multi-process mode: 0 asks Tornado to fork one process per CPU core.
    server.start(0)

    # Start serving; this call blocks.
    tornado.ioloop.IOLoop.instance().start()
Пример #30
0
class ADHandler(tornado.web.RequestHandler):
    '''Handler for AD (advertisement) requests.

    Request flow:
        The user requests Simeji Server (this code). Simeji Server requests
        the Yahoo Server, which returns its result to Simeji Server, which in
        turn returns it to the user.

        Input: a Category ID provided by Yahoo. Category IDs can be obtained
        through CategoryHandler.

    Advantages:
        The Yahoo Server is transparent to the user, and the client-facing API
        is our own design, so we keep more control. Even if the ad provider
        (Yahoo Server) is replaced some day, nothing changes for users.

    Response:
        JSON built by ``libs.utils.compose_ret``, in the format:
        {
            'errno': integer error code,
            'data': [
                {
                    'rank': rank,
                    'title': title,
                    'description': description,
                    'url': ad link,
                },
                { ... },
                { ... },
                ...,
                    ],
        }

        Error codes:
            0   normal response
            1   invalid request (missing cid/uid/sid, or non-integer lmt)
            2   internal error (unexpected reply code from the limit server)
            3   the YDN request failed or returned an invalid result
            4   qps forbidden
            5   user reached the daily limit (currently never returned)
            6   user traffic-control limit reached (currently never returned)
    '''

    config = libs.get_config()
    __max_limit = config.getint('ad_server', 'max_limit')
    __default_limit = config.getint('ad_server', 'default_limit')
    # Divided by 1000 because socket timeouts are in seconds; the config
    # value is presumably in milliseconds.
    __limit_timeout = config.getint('ad_server', 'limit_timeout') / 1000.0
    __limit_server_addr = (
        config.get('ad_server', 'limit_server_host'),
        config.getint('ad_server', 'limit_server_port'),
    )
    # A single UDP socket, created at class-definition time and shared by
    # every request handled in this process.
    __udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    __udp_socket.settimeout(__limit_timeout)

    def get(self):
        '''Fetch ad data for a Yahoo Category ID and write it as JSON.

        HTTP GET parameters:
            cid: Category ID provided by Yahoo (required).
            uid: ID uniquely identifying the user (required).
            sid: session ID (required).
            lmt: number of ads to return. Defaults to the configured
                 default_limit, is capped at max_limit, and is at least 1.
                 A non-integer value makes the request invalid (errno 1).
            os: client OS. '1' is Android, '2' is iOS. Defaults to '1'.

        Every request writes one line to ``ad_logger``.
        '''
        category_id = self.get_argument('cid', None)
        user_id = self.get_argument('uid', None)
        session_id = self.get_argument('sid', None)
        # Android: 1, iOS: 2. Defaults to Android.
        client_os = self.get_argument('os', '1')
        raw_limit = self.get_argument('lmt', self.__class__.__default_limit)
        try:
            limit = min(self.__class__.__max_limit, max(1, int(raw_limit)))
        except (TypeError, ValueError):
            # Keep the raw value so it still shows up in the access log.
            limit = raw_limit
            limit_ok = False
        else:
            limit_ok = True

        ip = self.request.headers.get('clientip', None)
        ua = self.request.headers.get('User-Agent', None)
        errno = 0  # error code
        msg = ''  # error message
        ad_data = None  # ad data
        if None in (category_id, user_id, session_id) or not limit_ok:
            msg = 'invalid request'
            errno = 1
        else:
            ret = self.__limit_permit(user_id)
            if ret == 0:  # access allowed
                try:
                    ad_data = ydn.request(ip, ua, category_id, limit, client_os)
                except Exception as e:
                    errno = 3
                    msg = str(e)
                else:
                    if ad_data:
                        msg = 'ok'
                        errno = 0
                    else:
                        msg = 'invalid response'
                        errno = 3
            elif ret == 1:  # access denied by the qps limit
                errno = 4
                msg = 'qps forbid'
            else:
                # Limit-server codes 2 (user daily limit -> errno 5) and
                # 3 (user stream limit -> errno 6) are not handled yet, so any
                # other code is reported as an internal error.
                errno = 2
                msg = 'unknown error'

        # Log the first ad's title and description on success.
        title, desc = None, None
        if errno == 0:
            # NOTE(review): assumes a truthy ad_data always has a non-empty
            # "ads" list -- confirm against ydn.request.
            first_ad = ad_data["ads"][0]
            title = first_ad["title"].encode("utf-8")
            desc = first_ad["description"].encode("utf-8")

        log_string = ('ip={ip}\tuid={uid}\tsid={sid}\tcid={cid}\tlmt={lmt}\t'
                      'os={os}\terrno={errno}\tmsg={msg}\trt={rt:.3f}\t'
                      'ua={ua}\ttit={tit}\tdesc={desc}').format(
            ip=ip,
            uid=user_id,
            sid=session_id,
            cid=category_id,
            lmt=limit,
            os=client_os,
            errno=errno,
            msg=msg,
            rt=1000.0 * self.request.request_time(),
            ua=ua,
            tit=title,
            desc=desc,
        )
        ad_logger.info(log_string)

        ret_json = libs.utils.compose_ret(errno, ad_data)
        self.write(ret_json)

    def __limit_permit(self, user_id):
        '''Ask the limit server whether ``user_id`` may make a request.

        Sends the UTF-8 encoded user id over UDP to the configured limit
        server and returns the integer code it replies with (``get`` treats
        0 as allowed and 1 as qps-forbidden).

        Fails open: returns 0 if the server cannot be reached, times out, or
        replies with something that is not an integer.
        '''
        # NOTE(review): this blocking send/recv on a shared socket stalls the
        # IOLoop for up to limit_timeout. A reply that arrives after a
        # timeout can also be read by the next request. Fixing that needs a
        # request id in the limit-server protocol.
        try:
            # Encode explicitly: a non-ASCII unicode id would otherwise make
            # sendto raise and silently bypass rate limiting via fail-open.
            self.__class__.__udp_socket.sendto(
                user_id.encode('utf-8'), self.__class__.__limit_server_addr)
            data, _addr = self.__class__.__udp_socket.recvfrom(1024)
        except Exception as e:
            ad_logger.error('info={}'.format(str(e)))
            return 0

        try:
            return int(data)
        except ValueError:
            ad_logger.error('info=invalid limit server reply {!r}'.format(data))
            return 0
Пример #31
0
import redis
import sys

import libs


class KeywordMap(redis.StrictRedis):
    '''Redis client whose ``get``/``set`` keys are namespaced by a prefix.

    Connection settings (host, port, db_index) and the key prefix are read
    from the "redis" section of ``libs.get_config()``. Only ``get`` and
    ``set`` apply the prefix; other inherited commands use keys as given.
    '''

    def __init__(self):
        conf = libs.get_config()
        host = conf.get('redis', 'host')
        port = conf.getint('redis', 'port')
        db_index = conf.getint('redis', 'db_index')
        self.__prefix = conf.get('redis', 'prefix')
        # Bound super() proxy so the overrides below can reach the
        # un-prefixed StrictRedis methods.
        self.__super = super(KeywordMap, self)
        self.__super.__init__(host=host, port=port, db=db_index)

    def set(self, key, value, *args, **kwargs):
        '''Store ``value`` under ``prefix + key``.

        Any extra positional/keyword arguments (e.g. expiry or NX/XX flags)
        are forwarded unchanged to ``StrictRedis.set``. Its result is
        returned.
        '''
        return self.__super.set(self.__prefix + key, value, *args, **kwargs)

    def get(self, key):
        '''Return the value stored under ``prefix + key``.'''
        return self.__super.get(self.__prefix + key)


if __name__ == '__main__':
    # Load "key<TAB>value" lines from stdin into Redis (under the configured
    # prefix), echoing each key, value and the result of the SET.
    km = KeywordMap()
    for line in sys.stdin:
        line = line.strip()
        if not line:
            continue  # ignore blank lines instead of crashing on unpack
        fields = line.split("\t", 1)
        if len(fields) != 2:
            sys.stderr.write(
                "skipping malformed line (expected key<TAB>value): %r\n" % line)
            continue
        key, value = fields
        # sys.stdout.write produces the same output on Python 2 and 3,
        # unlike the Python-2-only print statement.
        sys.stdout.write("%s %s %s\n" % (key, value, km.set(key, value)))
Пример #32
0
def start_func(config):
    """Build the tokenizer, model and datasets described by ``config``, then train.

    ``config`` maps ConfigEnums member *names* to values; keys that are not
    ConfigEnums members are dropped. The train mode (``mode``) decides:
      - which fields are read;
      - which model, dataset and data-processing function are used;
      - the training function ``mode.value.func``.

    The data is split across ``loaders`` dataset loaders. The training
    function is called once per loader as ``method(new_con, i)``, and its
    returned model and eval loss are fed into the next call. Dataset files
    opened here are closed before the function returns.
    """
    import contextlib

    from global_constants import data_process_func
    from global_constants import ModelEnums, DatasetEnums, TrainModesEnums, ConfigEnums
    me, de, tme, ce = ModelEnums, DatasetEnums, TrainModesEnums, ConfigEnums
    config = {ce[k]: v for k, v in config.items() if k in ce.__members__}
    mode = tme[get_config(config, ce.mode)]
    fields = mode.value.fields
    con = {k: get_config(config, k) for k in fields}
    model_type = me[con[ce.model]]
    load_path = get_config(con, ce.load_path)
    save_path = get_config(con, ce.save_path)

    if save_path is not None:
        if save_path[-1] != '/':
            save_path += '/'
        # Because save_path ends with '/', this resolves to save_path + 'log/'.
        log_path = list(os.path.split(save_path)[:-1])
        log_path.append('log/')
        log_path = '/'.join(log_path)
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        initial_loggers(log_path)

    prepare_logger, cuda_logger, final_logger = loggers.prepare_logger, loggers.cuda_logger, loggers.final_logger
    json_encoder = json.JSONEncoder(ensure_ascii=False, indent=2)
    log_info(
        prepare_logger, 'config loaded:\n' +
        json_encoder.encode({k.name: v
                             for k, v in con.items()}))

    log_info(prepare_logger, 'loading models: ' + load_path)

    tok = tfm.GPT2Tokenizer.from_pretrained(load_path)
    log_info(prepare_logger, 'model loaded')
    log_info(cuda_logger,
             'avaliable cudas {}'.format(torch.cuda.device_count()))

    log_info(cuda_logger, 'Start cuda memory {}'.format(cuda_mem_in_mb()))
    model = model_type.value.from_pretrained(load_path)
    # Logged after from_pretrained so the reading reflects the loaded model.
    log_info(cuda_logger, 'Allocated model {}'.format(cuda_mem_in_mb()))

    dataset_type = de[con[ce.dataset_type]]
    dataset_class = dataset_type.value.class_type
    batch_size = con[ce.batch_size] if ce.batch_size in con else 1
    con[ce.data_func] = data_process_func[mode][model_type][dataset_type](
        max_len=con[ce.max_len], batch_size=batch_size)
    con[ce.dataset_type] = dataset_class
    con[ce.tokenizer] = tok
    con[ce.model] = model
    if ce.gpt2 in con:
        con[ce.gpt2] = tfm.GPT2LMHeadModel.from_pretrained(con[ce.gpt2])
    method = mode.value.func

    # The dataset files are owned here and closed once training has finished.
    with contextlib.ExitStack() as open_files:
        con[ce.idx_file] = open_files.enter_context(open(con[ce.idx_path], 'r'))
        if ce.ent_file in dataset_type.value.fields:
            con[ce.ent_file] = open_files.enter_context(open(con[ce.ent_path], 'r'))
        if ce.sent_file in dataset_type.value.fields:
            con[ce.sent_file] = open_files.enter_context(open(con[ce.sent_path], 'r'))

        # Dataset constructor kwargs, keyed by field *name* (strings).
        dataset_parameters = {k.name: con[k] for k in dataset_type.value.fields}
        # Use the member name so this key matches the string keys above; an
        # Enum member key is not a valid ``**`` keyword.
        ids_key = ce.ids.name
        ids = con[ce.ids]
        if ids == '':
            ids = None

        if ids is not None:
            with open(ids, 'r') as f:
                ids = json.load(f)
            ids = [x.tolist() for x in np.array_split(ids, con[ce.loaders])]

        loaders = []
        for i in range(con[ce.loaders]):
            # NOTE(review): when no ids file is configured ``ids`` is None and
            # this raises TypeError -- confirm whether the ids file is
            # mandatory or None should be passed through to the dataset.
            dataset_parameters[ids_key] = ids[i]
            # NOTE(review): this calls the DatasetEnums member rather than
            # ``dataset_class`` -- confirm the enum defines ``__call__``.
            loaders.append(dataset_type(**dataset_parameters))

        # The eval set takes the ids right after all training ids:
        # eval_len * (loaded length of the first loader) of them.
        first_len = loaders[0].get_loaded_length()[0]
        all_len = sum(x.get_loaded_length()[0] for x in loaders)
        dataset_parameters[ids_key] = list(
            range(all_len, all_len + con[ce.eval_len] * first_len))
        con[ce.eval_set] = dataset_type(**dataset_parameters)

        for i, loader in enumerate(loaders):
            if loader is None:
                break
            new_con = dict(con)
            new_con[ce.dataset] = loader
            new_con[ce.epoch_iter] = len(loader) // batch_size
            new_model, loss = method(new_con, i)
            # Carry the trained model and its eval loss into the next loader.
            con[ce.model] = new_model
            con[ce.prev_eval_loss] = loss
Пример #33
0
                logger.exception(message)
            else:
                logger.error(message)
                logger.error(exc)

        logger.debug("Sleeping for %d seconds" % sleep)
        time.sleep(sleep)


if __name__ == "__main__":
    # Script entry point: install a Ctrl+C handler, merge CLI options with the
    # config file, then configure logging.
    signal.signal(signal.SIGINT,
                  lambda a, b: sys.exit(0))  # Properly handle Control+C

    # Get CLI args/options and parse config file.
    # Configuration problems are reported on stderr and exit with status 1.
    try:
        main_config = libs.get_config(docopt(__doc__, version=__version__))
    except libs.MultipleConfigSources.ConfigError as e:
        print("ERROR: %s" % e, file=sys.stderr)
        sys.exit(1)

    # Initialize logging.
    # umask 0o027: files created from here on (e.g. the log file) get no group
    # write permission and no permissions at all for others.
    umask = 0o027
    os.umask(umask)
    with libs.LoggingSetup(main_config['verbose'], main_config['log'],
                           main_config['quiet']) as cm:
        try:
            logging.config.fileConfig(cm.config)  # Setup logging.
        except IOError:
            # Presumably raised when the configured log file cannot be
            # opened for writing.
            print("ERROR: Unable to write to file %s" % main_config['log'],
                  file=sys.stderr)
            sys.exit(1)
def test_config_file_only_with_nonexistent_file():
    """A ``-c`` path that does not exist is reported as unusable."""
    missing_path = "/tmp/doesNotExist.28520"
    with pytest.raises(libs.MultipleConfigSources.ConfigError) as excinfo:
        libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=["-c", missing_path]))
    expected_message = "Config file /tmp/doesNotExist.28520 does not exist, not a file, or no permission."
    assert str(excinfo.value) == expected_message
def test_config_file_only_with_no_read_permissions():
    """An unreadable config file (``/etc/sudoers`` for non-root) is rejected."""
    unreadable_path = "/etc/sudoers"
    with pytest.raises(libs.MultipleConfigSources.ConfigError) as excinfo:
        libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=["-c", unreadable_path]))
    expected_message = "Unable to read config file /etc/sudoers."
    assert str(excinfo.value) == expected_message
def test_config_file_only_with_directory_instead_of_file():
    """Passing a directory to ``-c`` is rejected like a missing file."""
    directory_path = "/etc"
    with pytest.raises(libs.MultipleConfigSources.ConfigError) as excinfo:
        libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=["-c", directory_path]))
    expected_message = "Config file /etc does not exist, not a file, or no permission."
    assert str(excinfo.value) == expected_message
def test_config_file_only_missing_log_value(config_file):
    """``log:`` followed only by a comment yields ``None`` for the log option."""
    config_file.write("domain: mydomain.com\nuser: thisuser\npasswd: abc\nlog: #True\n")
    config_file.flush()
    argv = ["-c", config_file.name]
    config = libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=argv))
    # Identity check is the idiomatic (PEP 8) way to test for None.
    assert config["log"] is None
Пример #38
0
            if config['verbose']:
                logger.exception(message)
            else:
                logger.error(message)
                logger.error(exc)

        logger.debug("Sleeping for %d seconds" % sleep)
        time.sleep(sleep)


if __name__ == "__main__":
    # Script entry point: install a Ctrl+C handler, merge CLI options with the
    # config file, configure logging, and route uncaught exceptions and
    # process exit through the logging system.
    signal.signal(signal.SIGINT, lambda a, b: sys.exit(0))  # Properly handle Control+C

    # Get CLI args/options and parse config file.
    # Configuration problems are reported on stderr and exit with status 1.
    try:
        main_config = libs.get_config(docopt(__doc__, version=__version__))
    except libs.MultipleConfigSources.ConfigError as e:
        print("ERROR: %s" % e, file=sys.stderr)
        sys.exit(1)

    # Initialize logging.
    # umask 0o027: files created from here on (e.g. the log file) get no group
    # write permission and no permissions at all for others.
    umask = 0o027
    os.umask(umask)
    with libs.LoggingSetup(main_config['verbose'], main_config['log'], main_config['quiet']) as cm:
        try:
            logging.config.fileConfig(cm.config)  # Setup logging.
        except IOError:
            # Presumably raised when the configured log file cannot be
            # opened for writing.
            print("ERROR: Unable to write to file %s" % main_config['log'], file=sys.stderr)
            sys.exit(1)
    # From here on, uncaught exceptions are logged at CRITICAL (with traceback)
    # instead of only being printed, and a shutdown line is logged at exit.
    sys.excepthook = lambda t, v, b: logging.critical("Uncaught exception!", exc_info=(t, v, b))  # Log exceptions.
    atexit.register(lambda: logging.info("%s pid %d shutting down." % (__program__, os.getpid())))  # Log when exiting.
Пример #39
0
Файл: dmp.py Проект: npiaq/dmp
# -*- encoding: utf-8 -*-
'''
DMP (Data Manager Platform)
有如下几个概念:
text    输入文本. 还没有切分成词的列表.
doc     文档. 已经通过 split 函数切分成词的列表了.
bow     bag-of-word. 词袋, 已经切分并去重统计词频的列表, 格式为 (词, 词频).
'''
from gensim.corpora import Dictionary
from gensim.models import LdaModel
import sys
import codecs
import libs


# Module-level configuration, loaded once at import time; DMP.__init__ reads
# its 'dmp' section (topic_num, corpus_file).
config = libs.get_config()


class DMP(object):

    def __init__(self):
        self.dic = None
        self.lda = None
        self.topic_num = config.getint('dmp', 'topic_num')
        self.corpus_file = config.get('dmp', 'corpus_file')

    @staticmethod
    def __text2doc(iterator, sep=u' '):
        '''将文本转换为文档
        通过 split 函数将文本切成词的列表.
Пример #40
0
def test_config_file_only_with_directory_instead_of_file():
    """A directory given as the config file is treated as unusable."""
    cli_argv = ['-c', '/etc']
    with pytest.raises(libs.MultipleConfigSources.ConfigError) as excinfo:
        libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    message = str(excinfo.value)
    assert message == "Config file /etc does not exist, not a file, or no permission."
Пример #41
0
def test_config_file_only_with_no_read_permissions():
    """A config file the test user cannot read produces a read error."""
    cli_argv = ['-c', '/etc/sudoers']
    with pytest.raises(libs.MultipleConfigSources.ConfigError) as excinfo:
        libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    message = str(excinfo.value)
    assert message == "Unable to read config file /etc/sudoers."
Пример #42
0
def test_config_file_only_with_nonexistent_file():
    """A non-existent ``-c`` path raises ConfigError naming that path."""
    cli_argv = ['-c', '/tmp/doesNotExist.28520']
    with pytest.raises(libs.MultipleConfigSources.ConfigError) as excinfo:
        libs.get_config(docopt(uddns_doc, version=uddns_ver, argv=cli_argv))
    message = str(excinfo.value)
    assert message == "Config file /tmp/doesNotExist.28520 does not exist, not a file, or no permission."
Пример #43
0
def test_cli_invalid_options():
    """Giving the ``-d`` flag an argument makes docopt exit with usage."""
    cli_argv = ['-n', 'test.com', '-p', 'testpw', '-u', 'testuser', '-d', 'shouldBeFlag']
    with pytest.raises(SystemExit):
        parsed = docopt(uddns_doc, version=uddns_ver, argv=cli_argv)
        libs.get_config(parsed)