Code Example #1
def load_fairseq_lm_model_and_dict(checkpoint_path, data_path):
    # Initialize model
    parser = options.get_eval_lm_parser()
    parsed_args = options.parse_args_and_arch(parser, ['--path', checkpoint_path, data_path])
    task = tasks.setup_task(parsed_args)
    models, _ = utils.load_ensemble_for_inference([checkpoint_path], task)
    return models[0], task.dictionary
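
A minimal usage sketch for the helper above, assuming the pre-1.0 fairseq API that load_ensemble_for_inference belongs to; the import line and both paths are placeholders, not part of the original snippet:

from fairseq import options, tasks, utils  # modules the snippet above relies on

model, dictionary = load_fairseq_lm_model_and_dict(
    'checkpoints/checkpoint_best.pt',  # hypothetical checkpoint path
    'data-bin/wikitext-103',           # hypothetical preprocessed data directory
)
model.eval()
print(len(dictionary))  # vocabulary size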
Code Example #2
def cli_main():
    parser = options.get_eval_lm_parser()
    parser.add_argument("--log2",
                        action='store_true',
                        help="Fairseq defaults to natural log")
    args = options.parse_args_and_arch(parser)
    main(args)
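
The --log2 flag defined above is this project's own addition and suggests reporting the loss in bits rather than nats. A minimal sketch of that conversion with a hypothetical value (fairseq reports the average negative log-likelihood in nats):

import math

avg_nll_loss_nats = 3.85                             # hypothetical per-token NLL in nats
avg_nll_loss_bits = avg_nll_loss_nats / math.log(2)  # nats -> bits
perplexity = math.exp(avg_nll_loss_nats)             # equals 2 ** avg_nll_loss_bits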
Code Example #3
File: visualize.py, Project: Zhong-Zhang/fairseq
def cli_main(path, layers):
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    args.max_sentences = 2
    args.tokens_per_sample = 512
    args.context_window = 400
    if path is not None:
        args.path = path

    args.cpu = True
    # args.num_shards = 100

    gl._init()
    gl.set_value('visualize', True)
    gl.set_value('attn_weights', [0 for _ in range(layers)])
    gl.set_value('attn_weight_count', [0 for _ in range(layers)])
    gl.set_value('current_layer', 0)

    distributed_utils.call_main(convert_namespace_to_omegaconf(args), main)

    attn_weight_count = gl.get_value('attn_weight_count')
    attn_weights = gl.get_value('attn_weights')

    attn_weights = [
        x / attn_weight_count[idx] for idx, x in enumerate(attn_weights)
    ]
    torch.save(attn_weights, args.path + "." + 'svd')

    print('attn_weight_counts: ', attn_weight_count[0])
Code Example #4
File: eval_lm.py, Project: walkoncross/fairseq
def cli_main():
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)

    # only override args that are explicitly given on the command line
    override_parser = options.get_validation_parser()
    override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True)

    distributed_utils.call_main(args, main, override_args=override_args)
Code Example #5
File: test_binaries.py, Project: fyabc/fairseq
def eval_lm_main(data_dir):
    eval_lm_parser = options.get_eval_lm_parser()
    eval_lm_args = options.parse_args_and_arch(
        eval_lm_parser,
        [
            data_dir,
            '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
            '--no-progress-bar',
        ],
    )
    eval_lm.main(eval_lm_args)
Code Example #6
    def __init__(self, model_path, dict_path):
        parser = options.get_eval_lm_parser()
        parsed_args = options.parse_args_and_arch(parser,
                                                  input_args=[None],
                                                  parse_known=True)[0]
        parsed_args.path = model_path
        parsed_args.dict = dict_path
        parsed_args.max_sentences = 1  # was "max_sentence" in the original; fairseq reads max_sentences
        parsed_args.gen_subset = 'test'
        parsed_args.raw_text = True
        parsed_args.no_progress_bar = True
        import_user_module(parsed_args)
        print(parsed_args)

        task = tasks.setup_task(parsed_args)
        print('| loading model(s) from {}'.format(parsed_args.path))
        models, args = utils.load_ensemble_for_inference(
            parsed_args.path.split(':'),
            task,
            model_arg_overrides=eval(parsed_args.model_overrides),
        )
        for arg in vars(parsed_args).keys():
            if arg not in {
                    'self_target', 'future_target', 'past_target',
                    'tokens_per_sample', 'output_size_dictionary'
            }:
                setattr(args, arg, getattr(parsed_args, arg))
        task = tasks.setup_task(args)

        self.use_cuda = torch.cuda.is_available() and not parsed_args.cpu
        for model in models:
            model.make_generation_fast_()
            if self.use_cuda:
                model.cuda()
        assert len(models) > 0

        scorer = SequenceScorer(task.target_dictionary)

        self.args = args
        self.task = task
        self.models = models
        self.scorer = scorer
Code Example #7
def cli_main(path, model_overrides, name):
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    args.max_sentences = 2
    args.tokens_per_sample = 512
    args.context_window = 400
    # args.cpu = True
    # args.num_shards = 300

    gl._init()
    gl.set_value('visualize', True)
    gl.set_value('attn_weight_layers', [0 for _ in range(16)])
    gl.set_value('attn_weight_heads', [0 for _ in range(8)])
    gl.set_value('attn_weight_counts', [0 for _ in range(16)])
    gl.set_value('current_layer', 0)

    args.path = path
    args.model_overrides = model_overrides

    distributed_utils.call_main(args, main)

    if name == 'layer':
        attn_weight_layers = gl.get_value('attn_weight_layers')
        attn_weight_counts = gl.get_value('attn_weight_counts')

        attn_weight_layers = [
            x / attn_weight_counts[idx]
            for idx, x in enumerate(attn_weight_layers)
        ]
        torch.save(attn_weight_layers, path + "." + name)
    elif name == 'head':
        attn_weight_heads = gl.get_value('attn_weight_heads')
        attn_weight_counts = gl.get_value('attn_weight_counts')

        attn_weight_heads = [
            x / attn_weight_counts[idx]
            for idx, x in enumerate(attn_weight_heads)
        ]
        torch.save(attn_weight_heads, path + "." + name)
    else:
        exit(0)
Code Example #8
File: eval_lm.py, Project: mbevila/qbert
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                                pos_scores[i].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print('\t'.join('{} [{:2f}]'.format(x[0], x[1])
                                        for x in word_prob))

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss,
                                                      np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(),
                         key=lambda x: x.count,
                         reverse=True):
            print(ws)


if __name__ == '__main__':
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    main(args)
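
For reference, the quantity printed at the end of the excerpt follows directly from the accumulated log-probabilities; a tiny sketch with hypothetical numbers:

import numpy as np

score_sum, count = -15400.0, 4000    # hypothetical summed token log-probs (nats) and token count
avg_nll_loss = -score_sum / count    # average negative log-likelihood per token
perplexity = np.exp(avg_nll_loss)    # what the '| Loss ..., Perplexity ...' line reports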
Code Example #9
File: eval_lm.py, Project: kdrivas/syl_nmt_fairseq
def cli_main():
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    distributed_utils.call_main(args, main)
Code Example #10
File: rerank_utils.py, Project: skeshaw/LoReNMT
def lm_scoring(preprocess_directory, bpe_status, gen_output, pre_gen,
               cur_lm_dict, cur_lm_name, cur_language_model, cur_lm_bpe_code,
               batch_size, lm_score_file, target_lang, source_lang, prefix_len=None):
    if prefix_len is not None:
        assert bpe_status == "different", "bpe status must be different to use prefix len"
    if bpe_status == "no bpe":
        # run lm on output without bpe
        write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
                          gen_output.no_bpe_target, pre_gen+"/rescore_data_no_bpe.de",
                          pre_gen+"/rescore_data_no_bpe.en", pre_gen+"/reference_file_no_bpe")

        preprocess_lm_param = ["--only-source",
                               "--trainpref", pre_gen+"/rescore_data_no_bpe."+target_lang,
                               "--srcdict", cur_lm_dict,
                               "--destdir", preprocess_directory]
        preprocess_parser = options.get_preprocessing_parser()
        input_args = preprocess_parser.parse_args(preprocess_lm_param)
        preprocess.main(input_args)

        eval_lm_param = [preprocess_directory,
                         "--path", cur_language_model,
                         "--output-word-probs",
                         "--batch-size", str(batch_size),
                         "--max-tokens", "1024",
                         "--sample-break-mode", "eos",
                         "--gen-subset", "train"]

        eval_lm_parser = options.get_eval_lm_parser()
        input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)

        with open(lm_score_file, 'w') as f:
            with redirect_stdout(f):
                eval_lm.main(input_args)

    elif bpe_status == "shared":
            preprocess_lm_param = ["--only-source",
                                   "--trainpref", pre_gen+"/rescore_data."+target_lang,
                                   "--srcdict", cur_lm_dict,
                                   "--destdir", preprocess_directory]
            preprocess_parser = options.get_preprocessing_parser()
            input_args = preprocess_parser.parse_args(preprocess_lm_param)
            preprocess.main(input_args)

            eval_lm_param = [preprocess_directory,
                             "--path", cur_language_model,
                             "--output-word-probs",
                             "--batch-size", str(batch_size),
                             "--sample-break-mode", "eos",
                             "--gen-subset", "train"]

            eval_lm_parser = options.get_eval_lm_parser()
            input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)

            with open(lm_score_file, 'w') as f:
                with redirect_stdout(f):
                    eval_lm.main(input_args)

    elif bpe_status == "different":
        rescore_file = pre_gen+"/rescore_data_no_bpe"
        rescore_bpe = pre_gen+"/rescore_data_new_bpe"

        rescore_file += "."
        rescore_bpe += "."

        write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
                          gen_output.no_bpe_target, rescore_file+source_lang,
                          rescore_file+target_lang, pre_gen+"/reference_file_no_bpe",
                          bpe_symbol=None)

        # apply LM bpe to nbest list
        bpe_src_param = ["-c", cur_lm_bpe_code,
                         "--input", rescore_file+target_lang,
                         "--output", rescore_bpe+target_lang]
        subprocess.call(["python",
                         os.path.join(os.path.dirname(__file__),
                                      "subword-nmt/subword_nmt/apply_bpe.py")] + bpe_src_param,
                        shell=False)
        # uncomment to use fastbpe instead of subword-nmt bpe
        # bpe_src_param = [rescore_bpe+target_lang, rescore_file+target_lang, cur_lm_bpe_code]
        # subprocess.call(["/private/home/edunov/fastBPE/fast", "applybpe"] + bpe_src_param, shell=False)

        preprocess_dir = preprocess_directory

        preprocess_lm_param = ["--only-source",
                               "--trainpref", rescore_bpe+target_lang,
                               "--srcdict", cur_lm_dict,
                               "--destdir", preprocess_dir]
        preprocess_parser = options.get_preprocessing_parser()
        input_args = preprocess_parser.parse_args(preprocess_lm_param)
        preprocess.main(input_args)

        eval_lm_param = [preprocess_dir,
                         "--path", cur_language_model,
                         "--output-word-probs",
                         "--batch-size", str(batch_size),
                         "--max-tokens", "1024",
                         "--sample-break-mode", "eos",
                         "--gen-subset", "train"]

        eval_lm_parser = options.get_eval_lm_parser()
        input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)

        with open(lm_score_file, 'w') as f:
            with redirect_stdout(f):
                eval_lm.main(input_args)
Code Example #11
File: eval_lm.py, Project: insop/pytorch-hackathon
def cli_main():
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    main(args)
Code Example #12
File: eval_lm.py, Project: harveenchadha/fairseq
def cli_main():
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)

    distributed_utils.call_main(convert_namespace_to_omegaconf(args), main)
Code Example #13
def get_task_args():
    parser = options.get_eval_lm_parser()
    return options.parse_args_and_arch(parser)
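
get_task_args simply wraps parser construction and parsing; a hypothetical caller might use it as sketched below (attribute names follow the standard eval-lm options):

args = get_task_args()           # parses sys.argv with the eval-lm parser
print(args.path, args.data)      # --path checkpoint(s) and the positional data directory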
Code Example #14
File: eval_lm.py, Project: fyabc/fairseq
                    for i in range(len(hypo['tokens'])):
                        w_ind = hypo['tokens'][i].item()
                        w += task.dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))
                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item())
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            print(ws)


if __name__ == '__main__':
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    main(args)