Example #1
def prepare_config(params):
    # Experiment configuration
    opts = [(k.replace('/', '.'), str(v[0])) for (k, v) in params.items()]
    config = Configuration('parser_config.yaml',
                           '$LVSR/lvsr/configs/schema.yaml', opts)
    logger.info("Config:\n" + pprint.pformat(config, width=120))
    return config
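A hedged usage sketch for Example #1: `Configuration`, `logger`, and `pprint` are assumed to come from the surrounding lvsr module and the standard library, and the parameter values are single-element containers, as the `v[0]` indexing implies.

# Hypothetical input: keys use '/' as a separator and values are
# one-element lists, matching the v[0] indexing above.
params = {'net/dim': [250], 'training/num_batches': [100000]}
config = prepare_config(params)
# opts becomes [('net.dim', '250'), ('training.num_batches', '100000')]
# and is passed to Configuration as string-valued overrides.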
Example #2
def prepare_config(cmd_args):
    # Experiment configuration
    original_cmd_args = dict(cmd_args)
    config = Configuration(cmd_args.pop('config_path'),
                           '$LVSR/lvsr/configs/schema.yaml',
                           cmd_args.pop('config_changes'))
    config['cmd_args'] = original_cmd_args
    logger.info("Config:\n" + pprint.pformat(config, width=120))
    return config
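Example #2 differs from Example #1 in that the config path and the override list travel inside `cmd_args`, and an untouched copy of the arguments is stored back into the config. A sketch of a compatible `cmd_args` dict follows; the key names are taken from the `pop` calls above, while the exact shape of `config_changes` is an assumption.

# Hypothetical cmd_args, e.g. collected from an argparse front end.
cmd_args = {
    'config_path': 'configs/wsj_paper6.yaml',                # illustrative path
    'config_changes': [('training.num_batches', '50000')],   # assumed shape
}
config = prepare_config(cmd_args)
# config['cmd_args'] still holds both keys: the copy is taken
# before 'config_path' and 'config_changes' are popped.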
Example #3
def prepare_config(cmd_args):
    # Experiment configuration
    original_cmd_args = dict(cmd_args)
    config_path = cmd_args.pop('config_path')
    if config_path.lower() == 'none' and cmd_args['load_path']:
        logger.info("Loading config from pickle!")
        # Recover the config that was pickled into the saved model file.
        v = numpy.load(cmd_args['load_path'])
        config_path = pickle.loads("".join(
            numpy.frombuffer(v['_config_pickle'], 'S1')))
    config = Configuration(
        config_path,
        ('$LVSR/lvsr/configs/schema.yaml'
         if cmd_args.pop("validate_config") else None),
        cmd_args.pop('config_changes'))
    config['cmd_args'] = original_cmd_args
    logger.info("Config:\n" + pprint.pformat(config, width=120))
    return config
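Example #3 adds a fallback: when `config_path` is the literal string 'none' and a `load_path` is given, the configuration is recovered from a pickle embedded in the saved model file. The round trip below is a sketch of how such a `_config_pickle` entry could be written and read back, assuming a NumPy `.npz`-style container; the key name comes from the code above, the payload and file name are illustrative, and `tobytes()` is a Python 3 stand-in for the `"".join` used above.

import pickle
import numpy

# Hypothetical round trip for the '_config_pickle' mechanism.
payload = 'wsj_paper6.yaml'   # illustrative stand-in for the stored config
blob = numpy.frombuffer(pickle.dumps(payload), dtype='S1')
numpy.savez('model.npz', _config_pickle=blob)

v = numpy.load('model.npz')
restored = pickle.loads(v['_config_pickle'].tobytes())
assert restored == payload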
Example #4
def get_parser(load_path,
               decoder_type,
               lang,
               tag_char=None,
               mask_path=None,
               **params):
    logger.info("Loading config from piclke!")
    v = numpy.load(load_path)
    config_path = extract_from_data(v['_config_pickle'])

    config = Configuration(config_path, '$LVSR/lvsr/configs/schema.yaml', {})

    if 'input_languages' in config['data']:
        langs = config['data'].pop('input_languages')
    else:
        langs = ['default']

    if tag_char is not None:
        langs_tags = {k: tag_char for k in langs}
    else:
        langs_tags = {k: str(i) for i, k in enumerate(langs)}

    if lang is None:
        lang_id = 0
    else:
        if lang not in langs:
            raise ValueError('Wrong language {}. Available are: {}'.format(
                lang, repr(langs)))
        lang_id = langs.index(lang)
    lang = langs[lang_id]

    if '_dataset_pickle' in v and '_postfix_pickle' in v:
        info_dataset = extract_from_data(v['_dataset_pickle'])
        postfix_manager = extract_from_data(v['_postfix_pickle'])
    else:
        data = MultilangData(langs, **config['data'])
        net_config = dict(config["net"])
        additional_sources = ['labels']
        if 'additional_sources' in net_config:
            additional_sources += net_config['additional_sources']
        data.default_sources = net_config['input_sources'] + additional_sources
        info_dataset = data.info_dataset
        postfix_manager = data.postfix_manager

    logger.info("Recognizer initialization started")
    multi_recognizer = create_recognizer(config, config['net'], langs,
                                         info_dataset, postfix_manager,
                                         load_path, mask_path)
    recognizer = multi_recognizer.children[lang_id]
    recognizer.init_beam_search(0)
    logger.info("Recognizer is initialized")

    required_inputs = recognizer.inputs.keys()

    def parse_sentences(sentences, decoder_type=decoder_type):
        remapped_sentences = remap_sentences(
            sentences, required_inputs, info_dataset,
            postfix_manager.get_lang_postfix(lang), langs_tags[lang])
        try:
            outputs, cost, pos_tags = recognizer.beam_search_multi(
                remapped_sentences,
                decoder_type=decoder_type,
                full_pointers=False,
                validate_solution_function=getattr(info_dataset,
                                                   'validate_solution', None))
        except CandidateNotFoundError:
            # Beam search produced no valid candidate; bail out early,
            # otherwise the comprehension below would index an empty
            # tuple and reference pos_tags before assignment.
            return []
        return [
            output_conll(info_dataset, postfix_manager, sentence, lang,
                         outputs[0][i][1:-1], outputs[1][i][1:-1],
                         pos_tags[i][1:-1])
            for i, sentence in enumerate(sentences)
        ]

    return parse_sentences
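A usage sketch for Example #4: the model path, decoder names, and sentence format below are placeholders, since they depend on how the model was trained and on what `remap_sentences` expects.

# Hypothetical call: all literal values below are illustrative.
parse_sentences = get_parser('models/parser.npz',
                             decoder_type='beam_search',
                             lang='default')
sentences = [['The', 'cat', 'sat']]    # assumed tokenized input format
conll = parse_sentences(sentences)     # one CoNLL block per sentence
# The default decoder can be overridden per call via the keyword default:
conll = parse_sentences(sentences, decoder_type='greedy_search')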