예제 #1
0
def generate_main(data_dir, extra_flags=None):
    """Run generation with ``checkpoint_last.pt``, first in batch mode, then interactively.

    Args:
        data_dir: data directory; ``checkpoint_last.pt`` is loaded from it.
        extra_flags: optional list of extra command-line flags appended to
            the generation arguments.
    """
    generate_parser = options.get_generation_parser()
    generate_args = options.parse_args_and_arch(
        generate_parser,
        [
            data_dir,
            '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
            '--beam', '3',
            '--batch-size', '64',
            '--max-len-b', '5',
            '--gen-subset', 'valid',
            '--no-progress-bar',
            '--print-alignment',
        ] + (extra_flags or []),
    )

    # evaluate model in batch mode
    generate.main(generate_args)

    # evaluate model interactively
    generate_args.buffer_size = 0
    generate_args.max_sentences = None
    # The interactive pass is fed a fixed input line through sys.stdin.
    # Restore the real stdin in `finally` so a failure inside
    # interactive.main does not leave the process reading from the StringIO.
    orig_stdin = sys.stdin
    sys.stdin = StringIO('h e l l o\n')
    try:
        interactive.main(generate_args)
    finally:
        sys.stdin = orig_stdin
예제 #2
0
def eval_lm_main(data_dir):
    """Evaluate ``checkpoint_last.pt`` from *data_dir* as a language model.

    The data in *data_dir* is used for evaluation and the progress bar is
    disabled.
    """
    checkpoint = os.path.join(data_dir, 'checkpoint_last.pt')
    cli_flags = [data_dir, '--path', checkpoint, '--no-progress-bar']
    parsed = options.parse_args_and_arch(options.get_eval_lm_parser(), cli_flags)
    eval_lm.main(parsed)
예제 #3
0
def train_translation_model(data_dir, arch, extra_flags=None):
    """Train a translation model for a single epoch.

    Args:
        data_dir: data directory; checkpoints are also saved here.
        arch: architecture name passed via ``--arch``.
        extra_flags: optional list of extra flags appended after the defaults.
    """
    base_flags = [
        '--task', 'translation',
        data_dir,
        '--save-dir', data_dir,
        '--arch', arch,
        '--optimizer', 'nag',
        '--lr', '0.05',
        '--max-tokens', '500',
        '--max-epoch', '1',
        '--no-progress-bar',
        '--distributed-world-size', '1',
        '--source-lang', 'in',
        '--target-lang', 'out',
    ]
    cli_args = base_flags + (extra_flags or [])
    parser = options.get_training_parser()
    train.main(options.parse_args_and_arch(parser, cli_args))
예제 #4
0
def train_language_model(data_dir, arch):
    """Train a language model with an adaptive-softmax loss for one epoch.

    Args:
        data_dir: data directory; checkpoints are also saved here.
        arch: architecture name passed via ``--arch``.
    """
    lm_flags = [
        '--task', 'language_modeling',
        data_dir,
        '--arch', arch,
        # optimisation
        '--optimizer', 'nag',
        '--lr', '1.0',
        # loss and decoder configuration
        '--criterion', 'adaptive_loss',
        '--adaptive-softmax-cutoff', '5,10,15',
        '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
        '--decoder-embed-dim', '280',
        # batching and bookkeeping
        '--max-tokens', '500',
        '--tokens-per-sample', '500',
        '--save-dir', data_dir,
        '--max-epoch', '1',
        '--no-progress-bar',
        '--distributed-world-size', '1',
    ]
    parsed_args = options.parse_args_and_arch(options.get_training_parser(), lm_flags)
    train.main(parsed_args)
예제 #5
0

def add_multitask_args(parser):
    """Register the multitask fine-tuning switches on *parser*.

    Adds an argument group titled "Multitask related arguments" holding two
    boolean flags, ``--freeze-embeddings`` and ``--freeze-decoder``.
    """
    group = parser.add_argument_group("Multitask related arguments")
    switches = (
        ("--freeze-embeddings", "Freeze word embeddings when finetuning"),
        ("--freeze-decoder", "Freeze decoder when finetuning"),
    )
    for flag, description in switches:
        group.add_argument(flag, action="store_true", help=description)


if __name__ == '__main__':
    parser = options.get_training_parser()
    add_multitask_args(parser)
    args = options.parse_args_and_arch(parser)

    # Multitask training only supports a single process. Any distributed or
    # multi-process configuration is rejected up front. The distributed /
    # multiprocessing launchers that used to follow each `raise` were
    # unreachable and have been removed.
    uses_distributed = (
        args.distributed_port > 0
        or args.distributed_init_method is not None
        or args.distributed_world_size > 1
    )
    if uses_distributed:
        raise NotImplementedError(
            "Multitask doesn't support multiprocessing yet")

    main(args)
예제 #6
0
def main():
    """Command-line entry point: parse and validate arguments, then generate."""
    args = options.parse_args_and_arch(get_parser_with_args())
    validate_args(args)
    generate(args)
예제 #7
0
def cli_main():
    """Parse reranking options from the command line and run n-best generation/reprocessing."""
    reranking_parser = rerank_options.get_reranking_parser()
    gen_and_reprocess_nbest(options.parse_args_and_arch(reranking_parser))
예제 #8
0
def cli_main():
    """Command-line entry point for decoding with the espresso speech task.

    Builds the generation parser with ``speech_recognition_espresso`` as the
    default task, parses the command line and runs ``main``. Exits with a
    usage error if ``--results-path`` was not supplied.
    """
    parser = options.get_generation_parser(default_task="speech_recognition_espresso")
    args = options.parse_args_and_arch(parser)
    # Validate explicitly instead of with `assert`: assertions are stripped
    # under `python -O`, which would silently skip this check.
    if args.results_path is None:
        parser.error("please specify --results-path")
    main(args)
예제 #9
0
def _preprocess_and_eval_lm(trainpref, cur_lm_dict, preprocess_directory,
                            cur_language_model, batch_size, lm_score_file,
                            max_tokens=None):
    """Binarize *trainpref* with the LM dictionary and score it with the LM.

    Args:
        trainpref: value passed to ``--trainpref`` for preprocessing.
        cur_lm_dict: dictionary file passed to ``--srcdict``.
        preprocess_directory: destination of the binarized data; also the
            data directory read by ``eval_lm``.
        cur_language_model: LM checkpoint passed to ``--path``.
        batch_size: value for ``--batch-size``.
        lm_score_file: file that receives everything ``eval_lm.main`` prints
            to stdout (run with ``--output-word-probs``).
        max_tokens: optional ``--max-tokens`` value; omitted when None.
    """
    preprocess_lm_param = [
        "--only-source", "--trainpref", trainpref,
        "--srcdict", cur_lm_dict, "--destdir", preprocess_directory,
    ]
    preprocess_parser = options.get_preprocessing_parser()
    preprocess.main(preprocess_parser.parse_args(preprocess_lm_param))

    eval_lm_param = [
        preprocess_directory, "--path", cur_language_model,
        "--output-word-probs", "--batch-size", str(batch_size),
    ]
    if max_tokens is not None:
        eval_lm_param += ["--max-tokens", str(max_tokens)]
    eval_lm_param += ["--sample-break-mode", "eos", "--gen-subset", "train"]

    eval_lm_parser = options.get_eval_lm_parser()
    input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)

    with open(lm_score_file, 'w') as f:
        with redirect_stdout(f):
            eval_lm.main(input_args)


def lm_scoring(preprocess_directory,
               bpe_status,
               gen_output,
               pre_gen,
               cur_lm_dict,
               cur_lm_name,
               cur_language_model,
               cur_lm_bpe_code,
               batch_size,
               lm_score_file,
               target_lang,
               source_lang,
               prefix_len=None):
    """Score generated hypotheses with a language model.

    The target-side text is binarized into ``preprocess_directory`` and
    scored with ``eval_lm``. The word-probability output goes to
    ``lm_score_file``. What text is scored depends on ``bpe_status``:

    * ``"no bpe"``: BPE-free hypotheses from ``gen_output`` are written to
      ``<pre_gen>/rescore_data_no_bpe.<lang>`` and scored directly.
    * ``"shared"``: the pre-existing ``<pre_gen>/rescore_data.<target_lang>``
      is scored (this function does not write it).
    * ``"different"``: BPE-free hypotheses are re-encoded with the LM's BPE
      codes (``cur_lm_bpe_code``) via subword-nmt before scoring.

    Any other ``bpe_status`` value does nothing. ``cur_lm_name`` is accepted
    for interface compatibility but not used here.

    Raises:
        AssertionError: if ``prefix_len`` is given and ``bpe_status`` is not
            ``"different"``.
        subprocess.CalledProcessError: if applying the LM BPE fails.
    """
    import sys

    if prefix_len is not None:
        assert bpe_status == "different", "bpe status must be different to use prefix len"

    if bpe_status == "no bpe":
        # Run the LM on output without BPE. Name the files after
        # source_lang/target_lang (as the "different" branch does). The
        # former hard-coded ".de"/".en" suffixes meant the target_lang file
        # read below only existed when target_lang == "en".
        rescore_file = pre_gen + "/rescore_data_no_bpe."
        write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
                          gen_output.no_bpe_target,
                          rescore_file + source_lang,
                          rescore_file + target_lang,
                          pre_gen + "/reference_file_no_bpe")
        _preprocess_and_eval_lm(rescore_file + target_lang, cur_lm_dict,
                                preprocess_directory, cur_language_model,
                                batch_size, lm_score_file, max_tokens=1024)

    elif bpe_status == "shared":
        # Unlike the other branches, no --max-tokens limit is passed here.
        _preprocess_and_eval_lm(pre_gen + "/rescore_data." + target_lang,
                                cur_lm_dict, preprocess_directory,
                                cur_language_model, batch_size, lm_score_file)

    elif bpe_status == "different":
        rescore_file = pre_gen + "/rescore_data_no_bpe."
        rescore_bpe = pre_gen + "/rescore_data_new_bpe."

        write_reprocessed(gen_output.no_bpe_source,
                          gen_output.no_bpe_hypo,
                          gen_output.no_bpe_target,
                          rescore_file + source_lang,
                          rescore_file + target_lang,
                          pre_gen + "/reference_file_no_bpe",
                          bpe_symbol=None)

        # Apply the LM's BPE to the n-best list. Use the running interpreter
        # rather than whatever "python" is on PATH, and fail loudly: if this
        # step fails, the preprocessing below would otherwise read a missing
        # or stale file from an earlier run.
        bpe_src_param = [
            "-c", cur_lm_bpe_code, "--input", rescore_file + target_lang,
            "--output", rescore_bpe + target_lang
        ]
        subprocess.check_call([
            sys.executable,
            os.path.join(os.path.dirname(__file__),
                         "subword-nmt/subword_nmt/apply_bpe.py")
        ] + bpe_src_param,
                              shell=False)
        # uncomment to use fastbpe instead of subword-nmt bpe
        # bpe_src_param = [rescore_bpe+target_lang, rescore_file+target_lang, cur_lm_bpe_code]
        # subprocess.call(["/private/home/edunov/fastBPE/fast", "applybpe"] + bpe_src_param, shell=False)

        _preprocess_and_eval_lm(rescore_bpe + target_lang, cur_lm_dict,
                                preprocess_directory, cur_language_model,
                                batch_size, lm_score_file, max_tokens=1024)
예제 #10
0
def cli_main():
    """Parse language-model evaluation options from the command line and run ``main``."""
    main(options.parse_args_and_arch(options.get_eval_lm_parser()))
예제 #11
0
def _launch_training(args):
    """Run fairseq training for *args*, choosing distributed, multi-GPU spawn or single-GPU mode."""
    if args.distributed_init_method is None:
        distributed_utils.infer_init_method(args)

    if args.distributed_init_method is not None:
        # distributed training
        if torch.cuda.device_count() > 1 and not args.distributed_no_spawn:
            start_rank = args.distributed_rank
            args.distributed_rank = None  # assign automatically
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(args, start_rank),
                nprocs=torch.cuda.device_count(),
            )
        else:
            distributed_main(args.device_id, args)
    elif args.distributed_world_size > 1:
        # fallback for single node with multiple GPUs
        assert args.distributed_world_size <= torch.cuda.device_count()
        port = random.randint(10000, 20000)
        args.distributed_init_method = 'tcp://localhost:{port}'.format(
            port=port)
        args.distributed_rank = None  # set based on device id
        if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
            print(
                '| NOTE: you may get better performance with: --ddp-backend=no_c10d'
            )
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(args, ),
            nprocs=args.distributed_world_size,
        )
    else:
        # single GPU training
        main(args)


def _score_checkpoints(save_dir):
    """Score every '*.pt' file in *save_dir* except checkpoint_last.pt and log the scores.

    Each checkpoint path and its '{entropy: ..., bleu: ...}' line are
    appended to ``final_entropies.txt`` in *save_dir*.

    Returns:
        dict mapping checkpoint path to the score stored by ``run_generation``.
    """
    ckpts = os.listdir(save_dir)
    try:
        ckpts.remove('checkpoint_last.pt')
    except ValueError:
        print("no checkpoint_last.pt in folder", save_dir)

    results = {}
    entropies = {}
    # Opened in append mode. The context manager closes the file even if
    # run_generation raises part-way through the loop.
    with open(os.path.join(save_dir, "final_entropies.txt"), "a+") as f:
        for ckpt in ckpts:
            if '.pt' not in ckpt:
                continue
            path = os.path.join(save_dir, ckpt)
            f.write(path + '\n')
            run_generation(path, results, entropies)

            f.write('{entropy: ' + str(entropies[path]) + ', bleu: ' +
                    str(results[path]) + '}\n')
    return results


def train_main(alpha, beta, save_path):
    """Train a transformer_iwslt_de_en model with the jensen_cross_entropy criterion, then score its checkpoints.

    Args:
        alpha: value for the criterion's ``--alpha`` flag.
        beta: value for the criterion's ``--beta`` flag.
        save_path: checkpoint directory (``--save-dir``).

    Returns:
        dict mapping each scored checkpoint path to its generation score.
    """
    parser = options.get_training_parser()
    input_args = [
        data_set, '--share-decoder-input-output-embed', '--arch',
        'transformer_iwslt_de_en', '--max-tokens', '4000', '--lr', '5e-4',
        '--save-interval', '2', '--max-epoch', '85', '--patience', '5',
        '--optimizer', 'adam', '--adam-betas', '(0.9, 0.98)', '--clip-norm',
        '0.0', '--weight-decay', '0.0001', '--dropout', '0.3',
        '--lr-scheduler', 'inverse_sqrt', '--warmup-updates', '4000',
        '--keep-last-epochs', '4', '--criterion', 'jensen_cross_entropy',
        '--alpha',
        str(alpha), '--beta',
        str(beta), '--use-uniform', '--fp16', '--save-dir', save_path
    ]

    args = options.parse_args_and_arch(parser, input_args=input_args)
    _launch_training(args)
    return _score_checkpoints(args.save_dir)
예제 #12
0
def run_generation(ckpt, results, ents):
    """Decode the ``valid`` subset of ``data_set`` with checkpoint *ckpt* and record its scores.

    Uses beam 10, ``--max-tokens 4000``, sacreBLEU scoring and BPE removal.

    Args:
        ckpt: path of the checkpoint to evaluate.
        results: dict updated in place with ``results[ckpt] = scorer.score()``.
        ents: dict updated in place with the averaged entropy
            ``sum(avg_ent[0]) / sum(avg_ent[1])`` over all batches that carry
            an ``'avg_ent'`` entry, or ``nan`` when none report a token count.
    """
    gen_parser = options.get_generation_parser()
    args = options.parse_args_and_arch(gen_parser,
                                       input_args=[
                                           data_set, '--gen-subset', 'valid',
                                           '--path', ckpt, '--beam', '10',
                                           '--max-tokens', '4000',
                                           '--sacrebleu', '--remove-bpe',
                                           '--log-format', 'none'
                                       ])

    use_cuda = torch.cuda.is_available() and not args.cpu

    utils.import_user_module(args)
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries. getattr's default only covers AttributeError; the
    # except clause handles a source_dictionary property that raises
    # NotImplementedError instead.
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() on model_overrides mirrors fairseq's loaders; it is
    # only acceptable because the argv above is fixed in code, not user input.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    entropies = []
    token_counts = []
    with progress_bar.build_progress_bar(args, itr) as t:
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)

            # 'avg_ent' holds (entropy sum, token count) for the batch when present.
            if 'avg_ent' in sample:
                entropies.append(sample['avg_ent'][0])
                token_counts.append(sample['avg_ent'][1])

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)
                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

    results[ckpt] = scorer.score()
    # Avoid ZeroDivisionError when no batch reported a token count; callers
    # then record nan for this checkpoint instead of aborting.
    total_tokens = sum(token_counts)
    ents[ckpt] = sum(entropies) / total_tokens if total_tokens else float('nan')
    def from_checkpoint(self,
                        checkpoint,
                        roberta_cache_path=None,
                        inspector=None):
        '''
        Initialize model from checkpoint

        Args:
            checkpoint: checkpoint path; if several are joined with ':', the
                directory of the first is searched for config.json, the
                dict.<lang>.txt files and train.rules.json.
            roberta_cache_path: forwarded to ``load_roberta``.
            inspector: forwarded to the constructor call.

        Returns:
            ``self(model, machine_rules, machine_type, src_dict, tgt_dict,
            use_cuda, embeddings=..., inspector=...)``.
            NOTE(review): ``self`` is called as a constructor, so this is
            presumably decorated with ``@classmethod`` (decorator not visible
            here) — confirm.

        Raises:
            AssertionError: if config.json, either dictionary, or
                train.rules.json is missing, or left padding is disabled.
        '''

        # load fairseq task
        parser = options.get_interactive_generation_parser()
        options.add_optimization_args(parser)
        # The parser requires a positional data argument, so pass a
        # placeholder as its own argv token. The previous single string
        # '--data dummy' only parsed because argparse treats a token that
        # contains a space as a positional value.
        args = options.parse_args_and_arch(parser, input_args=['dummy'])

        # Read extra arguments
        model_folder = os.path.dirname(checkpoint.split(':')[0])
        # config with fairseq-preprocess and fairseq-train args
        config_json = f'{model_folder}/config.json'
        assert os.path.isfile(config_json), \
            "Model trained with v0.3.0 or above?"
        with open(config_json) as fid:
            extra_args = json.loads(fid.read())
        prepro_args = extra_args['fairseq_preprocess_args']
        train_args = extra_args['fairseq_train_args']
        # extra args by hand
        args.source_lang = 'en'
        args.target_lang = 'actions'
        args.path = checkpoint
        args.roberta_cache_path = roberta_cache_path
        dim = train_args['--pretrained-embed-dim'][0]
        args.model_overrides = \
            "{'pretrained_embed_dim':%s, 'task': 'translation'}" % dim
        assert bool(args.left_pad_source), "Only left pad supported"

        # dictionaries
        src_dict_path = f'{model_folder}/dict.{args.source_lang}.txt'
        tgt_dict_path = f'{model_folder}/dict.{args.target_lang}.txt'
        assert os.path.isfile(src_dict_path), \
            f"Missing {src_dict_path}.\nModel trained with v0.3.0 or above?"\
            "\ncheck scripts/stack-transformer/update_model_to_v0.3.0.sh"
        assert os.path.isfile(tgt_dict_path), \
            f"Missing {tgt_dict_path}.\nModel trained with v0.3.0 or above?"\
            "\ncheck scripts/stack-transformer/update_model_to_v0.3.0.sh"
        src_dict = Dictionary.load(src_dict_path)
        tgt_dict = Dictionary.load(tgt_dict_path)

        use_cuda = torch.cuda.is_available() and not args.cpu

        # Override task to ensure compatibility with old models and override
        # TODO: Task may not be even needed
        task = TranslationTask(args, src_dict, tgt_dict)
        model = load_models(args, task, use_cuda)

        # Load RoBERTa
        embeddings = PretrainedEmbeddings(
            name=prepro_args['--pretrained-embed'][0],
            bert_layers=[int(x) for x in prepro_args['--bert-layers']]
            if '--bert-layers' in prepro_args else None,
            model=load_roberta(name=prepro_args['--pretrained-embed'][0],
                               roberta_cache_path=args.roberta_cache_path,
                               roberta_use_gpu=use_cuda))

        print("Finished loading models")

        # State machine variables
        machine_rules = f'{model_folder}/train.rules.json'
        assert os.path.isfile(machine_rules), f"Missing {machine_rules}"
        machine_type = prepro_args['--machine-type'][0]

        return self(model,
                    machine_rules,
                    machine_type,
                    src_dict,
                    tgt_dict,
                    use_cuda,
                    embeddings=embeddings,
                    inspector=inspector)
예제 #14
0
def generate_from_script(list_args):
    """Parse generation options from *list_args* and run ``main``.

    Extends the standard generation parser with this script's output and
    decoding flags, creates ``--decode-dir`` if it does not exist, then calls
    ``main(args)``.

    Args:
        list_args: list of command-line argument strings.
    """
    parser = options.get_generation_parser()
    group = parser.add_argument_group('Generation output')
    group.add_argument('--decode-dir',
                       metavar='DIR',
                       default='outputs',
                       help='path to save predictions')
    group.add_argument('--reference-dir',
                       metavar='DIR',
                       default='outputs/reference/valid',
                       help='path to reference files')
    group.add_argument('--usekeys',
                       action='store_true',
                       help='whether to use target key prediction')
    group.add_argument(
        '--context',
        action='store_true',
        help=
        'whether to use previous sentences as context for current sentence decoding'
    )
    group.add_argument(
        '--ngram',
        type=int,
        default=0,
        help='whether to use hard constraints on ngram repetition when decoding'
    )
    group.add_argument(
        '--sepahypo',
        action='store_true',
        help=
        'decode sentence hypothesis independently. sort best for each sentence.'
    )
    # NOTE(review): this help text is a verbatim copy of --sepahypo's;
    # confirm what --naive actually does and describe it.
    group.add_argument(
        '--naive',
        action='store_true',
        help=
        'decode sentence hypothesis independently. sort best for each sentence.'
    )
    parser.add_argument(
        '--outindices',
        required=False,
        type=str,
        help='load set of indices that were out for a category dataset.')
    parser.add_argument('--covpen',
                        type=float,
                        default=0,
                        metavar='D',
                        help='coverage penalty (Gehrmann et al. 2018).')
    group.add_argument(
        '--keystop',
        action='store_true',
        help=
        'whether to use topic prediction to spot EndOfDocument. Makes only sense '
        'with models using topic-key-prediction')

    args = options.parse_args_and_arch(parser, list_args)

    # makedirs also creates missing parent directories (e.g. 'out/run1'),
    # which os.mkdir cannot.
    os.makedirs(args.decode_dir, exist_ok=True)

    main(args)
예제 #15
0
def train_translation_model(data_dir, extra_flags, criterion=None):
    """Train a translation model from raw text files for one epoch.

    Args:
        data_dir: directory holding train.in/train.out and valid.in/valid.out;
            also used as ``--save-dir``.
        extra_flags: extra flags (or None) appended after the defaults.
        criterion: flag list selecting the criterion; when falsy,
            ``--criterion label_smoothed_cross_entropy`` is used.
    """
    def data_file(name):
        return os.path.join(data_dir, name)

    default_criterion = ["--criterion", "label_smoothed_cross_entropy"]
    cli = [
        "--save-dir", data_dir,
        "--train-source-text-file", data_file("train.in"),
        "--train-target-text-file", data_file("train.out"),
        "--eval-source-text-file", data_file("valid.in"),
        "--eval-target-text-file", data_file("valid.out"),
        "--source-max-vocab-size", "26",
        "--target-max-vocab-size", "26",
        "--max-tokens", "500",
        "--optimizer", "sgd",
        "--lr", "0.05",
        "--lr-scheduler", "fixed",
        "--lr-shrink", "0.95",
        "--momentum", "0.0",
        "--clip-norm", "5.0",
        "--sentence-avg",
        "--label-smoothing", "0.1",
        "--beam", "3",
        "--stop-no-best-bleu-eval", "5",
        "--unk-reward", "0.5",
        "--generate-bleu-eval-avg-checkpoints", "10",
        "--generate-bleu-eval-per-epoch",
        "--max-epoch", "1",
        "--stop-time-hr", "1",
        "--no-progress-bar",
        "--distributed-world-size", "1",
        "--source-lang", "in",
        "--target-lang", "out",
    ]
    cli = cli + (extra_flags or []) + (criterion or default_criterion)

    args = options.parse_args_and_arch(train.get_parser_with_args(), cli)
    train.validate_and_set_default_args(args)
    train.main(args)
예제 #16
0
def cli_main():
    """Parse reranking options from the command line and run LM scoring."""
    parsed = options.parse_args_and_arch(rerank_options.get_reranking_parser())
    score_lm(parsed)
예제 #17
0
파일: generate.py 프로젝트: fyabc/fairseq
                    ))

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id,
                            ' '.join(map(lambda x: str(utils.item(x)), alignment))
                        ))

                # Score only the top hypothesis
                if has_target and i == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str, tgt_dict, add_if_not_exist=True)
                    scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))


if __name__ == '__main__':
    # Script entry point: build the generation parser, parse the command
    # line, and run main() with the resulting arguments.
    parser = options.get_generation_parser()
    args = options.parse_args_and_arch(parser)
    main(args)
예제 #18
0
    def __init__(self, data_path, checkpoint_path="checkpoint_best.pt"):
        """Load a model ensemble and generator for interactive generation.

        Args:
            data_path: positional data argument for the generation parser.
            checkpoint_path: checkpoint path; several paths joined with ':'
                are loaded as an ensemble.

        Raises:
            AssertionError: if --sampling is used with --nbest != --beam, or
                --max-sentences exceeds --buffer-size.
        """
        self.parser = options.get_generation_parser(interactive=True)
        # This was previously misspelled "num_wokers". set_defaults accepts
        # any keyword, so the typo only created an unused attribute and left
        # --num-workers at its parser default.
        self.parser.set_defaults(path=checkpoint_path,
            remove_bpe="sentencepiece", dataset_impl="lazy", num_workers=5
        )
        self.args = options.parse_args_and_arch(self.parser, 
            input_args=[data_path]
        )

        utils.import_user_module(self.args)

        # Interactive mode needs at least a one-sentence buffer and a batch limit.
        if self.args.buffer_size < 1:
            self.args.buffer_size = 1
        if self.args.max_tokens is None and self.args.max_sentences is None:
            self.args.max_sentences = 1

        assert not self.args.sampling or self.args.nbest == self.args.beam, \
            '--sampling requires --nbest to be equal to --beam'
        assert not self.args.max_sentences or self.args.max_sentences <= self.args.buffer_size, \
            '--max-sentences/--batch-size cannot be larger than --buffer-size'

        self.use_cuda = torch.cuda.is_available() and not self.args.cpu

        self.task = tasks.setup_task(self.args)

        # NOTE(review): eval() of model_overrides follows fairseq convention;
        # it is only safe while the value comes from trusted configuration.
        self.models, self._model_args = checkpoint_utils.load_model_ensemble(
            self.args.path.split(':'),
            arg_overrides=eval(self.args.model_overrides),
            task=self.task,
        )

        self.src_dict = self.task.source_dictionary
        self.tgt_dict = self.task.target_dictionary

        for model in self.models:
            model.make_generation_fast_(
                beamable_mm_beam_size=None if self.args.no_beamable_mm else self.args.beam,
                need_attn=self.args.print_alignment,
            )
            if self.args.fp16:
                model.half()
            if self.use_cuda:
                model.cuda()

        self.generator = self.task.build_generator(self.args)

        if self.args.remove_bpe == 'gpt2':
            from fairseq.gpt2_bpe.gpt2_encoding import get_encoder
            self.decoder = get_encoder(
                'fairseq/gpt2_bpe/encoder.json',
                'fairseq/gpt2_bpe/vocab.bpe',
            )
            self.encode_fn = lambda x: ' '.join(map(str, self.decoder.encode(x)))
        else:
            self.decoder = None
            self.encode_fn = lambda x: x

        self.align_dict = utils.load_align_dict(self.args.replace_unk)

        self.max_positions = utils.resolve_max_positions(
            self.task.max_positions(),
            *[model.max_positions() for model in self.models]
        )
예제 #19
0
def train_translation_model(
    data_dir,
    arch,
    extra_flags=None,
    task="translation",
    run_validation=False,
    lang_flags=None,
    extra_valid_flags=None,
):
    """Train a model for one epoch and optionally validate ``checkpoint_last.pt``.

    Args:
        data_dir: data directory; also used as ``--save-dir``.
        arch: architecture name passed via ``--arch``.
        extra_flags: extra training flags appended after the defaults.
        task: fairseq task name.
        run_validation: when true, run validation on the ``valid`` subset
            with the last checkpoint after training.
        lang_flags: language flags shared by training and validation;
            defaults to ``--source-lang in --target-lang out``.
        extra_valid_flags: extra flags appended to the validation arguments.
    """
    lang_flags = (
        lang_flags if lang_flags is not None
        else ["--source-lang", "in", "--target-lang", "out"]
    )

    training_cli = [
        "--task", task,
        data_dir,
        "--save-dir", data_dir,
        "--arch", arch,
        "--optimizer", "nag",
        "--lr", "0.05",
        "--max-tokens", "500",
        "--max-epoch", "1",
        "--no-progress-bar",
        "--distributed-world-size", "1",
        "--num-workers", "0",
    ] + lang_flags + (extra_flags or [])
    train.main(options.parse_args_and_arch(options.get_training_parser(), training_cli))

    if not run_validation:
        return

    # test validation
    checkpoint = os.path.join(data_dir, "checkpoint_last.pt")
    validation_cli = [
        "--task", task,
        data_dir,
        "--path", checkpoint,
        "--valid-subset", "valid",
        "--max-tokens", "500",
        "--no-progress-bar",
        "--num-workers", "0",
    ] + lang_flags + (extra_valid_flags or [])
    validate.main(options.parse_args_and_arch(options.get_validation_parser(), validation_cli))
예제 #20
0
if __name__ == '__main__':

    parser = options.get_training_parser()
    parser.add_argument(
        '--config',
        type=str,
        nargs='*',
        help='paths to JSON files of experiment configurations, from high to low priority',
    )
    parser.add_argument('--torch-file-system', action='store_true')
    # First pass: read only --config / --torch-file-system; all other flags
    # are collected in `unknown` and handled by the full parse below.
    pre_parsed_args, unknown = parser.parse_known_args()

    # --config has nargs='*' and no default, so it is None when the flag is
    # omitted. Treat that as "no config files" instead of failing on
    # iteration over None.
    config_dict = {}
    for config_path in pre_parsed_args.config or []:
        config_dict = update_config(config_dict, compose_configs(config_path))

    parser_modifier = modify_factory(config_dict)

    args = options.parse_args_and_arch(parser, modify_parser=parser_modifier)

    update_namespace(args, config_dict)

    # set sharing strategy file system in case /dev/shm/ limits are small
    if args.torch_file_system:
        torch.multiprocessing.set_sharing_strategy('file_system')

    main(args)


예제 #21
0
def cli_main():
    """Parse eval-LM arguments and run ``main`` through ``distributed_utils.call_main``."""
    eval_args = options.parse_args_and_arch(options.get_eval_lm_parser())
    distributed_utils.call_main(eval_args, main)
예제 #22
0
def cli_main():
    """Parse generation arguments extended with XGCN model and extra generation options, then run ``main``."""
    gen_parser = options.get_generation_parser()
    XGCNModel.add_args(gen_parser)
    # Use whatever parser add_gen_args hands back, as the original code did.
    main(options.parse_args_and_arch(add_gen_args(gen_parser)))
예제 #23
0
def train_legacy_masked_language_model(data_dir, arch, extra_args=()):
    """Train a tiny ``cross_lingual_lm`` model on *data_dir* for one epoch.

    Args:
        data_dir: dataset directory; also used as ``--save-dir``.
        arch: value passed to ``--arch``.
        extra_args: additional command-line flags appended last.
    """
    # TODO: langs should be in and out right?
    task_flags = ["--task", "cross_lingual_lm", data_dir, "--arch", arch]
    optimizer_flags = [
        "--optimizer", "adam",
        "--lr-scheduler", "reduce_lr_on_plateau",
        "--lr-shrink", "0.5",
        "--lr", "0.0001",
        "--min-lr", "1e-09",
    ]
    dropout_flags = ["--dropout", "0.1", "--attention-dropout", "0.1"]
    mlm_flags = [
        "--criterion", "legacy_masked_lm_loss",
        "--masked-lm-only",
        "--monolingual-langs", "in,out",
        "--num-segment", "5",
    ]
    # Use a small transformer model for fast training.
    transformer_flags = [
        "--encoder-layers", "1",
        "--encoder-embed-dim", "32",
        "--encoder-attention-heads", "1",
        "--encoder-ffn-embed-dim", "32",
    ]
    misc_flags = [
        "--max-tokens", "500",
        "--tokens-per-sample", "500",
        "--save-dir", data_dir,
        "--max-epoch", "1",
        "--no-progress-bar",
        "--distributed-world-size", "1",
        "--dataset-impl", "raw",
    ]
    argv = (
        task_flags
        + optimizer_flags
        + dropout_flags
        + mlm_flags
        + transformer_flags
        + misc_flags
        + list(extra_args)
    )
    train_args = options.parse_args_and_arch(options.get_training_parser(), argv)
    train.main(train_args)
예제 #24
0
def score_bw(args):
    """STEP 4 of reranking: score the n-best hypotheses with the rescoring model(s).

    For ``args.score_model1`` (and ``args.score_model2`` when it is set),
    ``generate.main`` runs in ``--score-reference`` mode over the matching
    preprocessed directory (right-to-left, backwards or left-to-right), and
    its stdout is written to that model's rescore file. A model is skipped
    when its rescore file already exists, or when it equals
    ``args.gen_model`` and ``args.source_prefix_frac`` is ``None``.
    """

    def _langs(backwards):
        # A backwards scorer reads target -> source.
        if backwards:
            return args.target_lang, args.source_lang
        return args.source_lang, args.target_lang

    scorer1_src, scorer1_tgt = _langs(args.backwards1)
    has_model2 = args.score_model2 is not None
    if has_model2:
        scorer2_src, scorer2_tgt = _langs(args.backwards2)

    no_source_prefix = args.source_prefix_frac is None
    rerank1_is_gen = args.gen_model == args.score_model1 and no_source_prefix
    rerank2_is_gen = args.gen_model == args.score_model2 and no_source_prefix

    (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        _lm_preprocessed_dir,
    ) = rerank_utils.get_directories(
        args.data_dir_name, args.num_rescore, args.gen_subset,
        args.gen_model_name, args.shard_id, args.num_shards,
        args.sampling, args.prefix_len, args.target_prefix_frac,
        args.source_prefix_frac,
    )

    def _score_file(model_name, backwards):
        return rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            model_name,
            target_prefix_frac=args.target_prefix_frac,
            source_prefix_frac=args.source_prefix_frac,
            backwards=backwards,
        )

    score1_file = _score_file(args.model1_name, args.backwards1)
    if has_model2:
        score2_file = _score_file(args.model2_name, args.backwards2)

    def _preprocessed_dir(right_to_left, backwards):
        if right_to_left:
            return right_to_left_preprocessed_dir
        return backwards_preprocessed_dir if backwards else left_to_right_preprocessed_dir

    rerank_data1 = _preprocessed_dir(args.right_to_left1, args.backwards1)

    gen_param = ["--batch-size", "128", "--score-reference", "--gen-subset", "train"]

    def _score_with(data_path, model_path, src, tgt, out_file):
        # Score the reference side and capture generate's stdout in out_file.
        argv = [data_path] + gen_param + [
            "--path", model_path, "--source-lang", src, "--target-lang", tgt,
        ]
        input_args = options.parse_args_and_arch(options.get_generation_parser(), argv)
        with open(out_file, 'w') as f, redirect_stdout(f):
            generate.main(input_args)

    if not rerank1_is_gen and not os.path.isfile(score1_file):
        print("STEP 4: score the translations for model 1")
        _score_with(rerank_data1, args.score_model1, scorer1_src, scorer1_tgt, score1_file)

    if has_model2 and not os.path.isfile(score2_file) and not rerank2_is_gen:
        print("STEP 4: score the translations for model 2")
        rerank_data2 = _preprocessed_dir(args.right_to_left2, args.backwards2)
        _score_with(rerank_data2, args.score_model2, scorer2_src, scorer2_tgt, score2_file)
예제 #25
0
def gen_and_reprocess_nbest(args):
    """Generate (or load) an n-best list and prepare it for rescoring.

    Each step is skipped when its outputs already exist:

    1. Run ``generate.main`` with ``args.gen_model`` to write BPE'd n-best
       predictions to ``<pre_gen>/generate_output_bpe.txt``, unless a
       predefined list is given via ``args.nbest_list``.
    2. Write source/hypothesis/reference text files for the rescoring
       models, including prefix-truncated and right-to-left variants when
       requested (with ``args.diff_bpe`` the text is re-BPE'd by
       ``apply_bpe.py`` instead).
    3. Binarize those files with ``preprocess.main`` into the left-to-right,
       backwards and/or right-to-left preprocessed directories.

    Args:
        args: reranking namespace. ``args.score_dict_dir`` is set to
            ``args.data`` when it is ``None``.

    Returns:
        The ``rerank_utils.BitextOutputFromGen`` parsed from the n-best file.
    """
    if args.score_dict_dir is None:
        args.score_dict_dir = args.data
    if args.prefix_len is not None:
        assert (
            args.right_to_left1 is False
        ), "prefix length not compatible with right to left models"
        assert (
            args.right_to_left2 is False
        ), "prefix length not compatible with right to left models"

    if args.nbest_list is not None:
        # A predefined n-best list is only supported with a single scorer.
        assert args.score_model2 is None

    # A backwards model 1 reads target -> source.
    if args.backwards1:
        scorer1_src = args.target_lang
        scorer1_tgt = args.source_lang
    else:
        scorer1_src = args.source_lang
        scorer1_tgt = args.target_lang

    store_data = os.path.dirname(__file__) + "/rerank_data/" + args.data_dir_name
    os.makedirs(store_data, exist_ok=True)

    (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        lm_preprocessed_dir,
    ) = rerank_utils.get_directories(
        args.data_dir_name,
        args.num_rescore,
        args.gen_subset,
        args.gen_model_name,
        args.shard_id,
        args.num_shards,
        args.sampling,
        args.prefix_len,
        args.target_prefix_frac,
        args.source_prefix_frac,
    )
    assert not (
        args.right_to_left1 and args.backwards1
    ), "backwards right to left not supported"
    assert not (
        args.right_to_left2 and args.backwards2
    ), "backwards right to left not supported"
    assert not (
        args.prefix_len is not None and args.target_prefix_frac is not None
    ), "target prefix frac and target prefix len incompatible"

    # make directory to store generation results
    os.makedirs(pre_gen, exist_ok=True)

    # When a scorer is the generation model itself (and no source prefix is
    # used), STEP 2/3 below are not triggered on its behalf.
    rerank1_is_gen = (
        args.gen_model == args.score_model1 and args.source_prefix_frac is None
    )
    rerank2_is_gen = (
        args.gen_model == args.score_model2 and args.source_prefix_frac is None
    )

    if args.nbest_list is not None:
        rerank2_is_gen = True

    # make directories to store preprocessed nbest list for reranking
    for directory in (
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        lm_preprocessed_dir,
        backwards_preprocessed_dir,
    ):
        os.makedirs(directory, exist_ok=True)

    score1_file = rerank_utils.rescore_file_name(
        pre_gen,
        args.prefix_len,
        args.model1_name,
        target_prefix_frac=args.target_prefix_frac,
        source_prefix_frac=args.source_prefix_frac,
        backwards=args.backwards1,
    )
    if args.score_model2 is not None:
        score2_file = rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            args.model2_name,
            target_prefix_frac=args.target_prefix_frac,
            source_prefix_frac=args.source_prefix_frac,
            backwards=args.backwards2,
        )

    predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"

    using_nbest = args.nbest_list is not None

    if using_nbest:
        print("Using predefined n-best list from interactive.py")
        predictions_bpe_file = args.nbest_list
    elif not os.path.isfile(predictions_bpe_file):
        print("STEP 1: generate predictions using the p(T|S) model with bpe")
        print(args.data)
        param1 = [
            args.data,
            "--path",
            args.gen_model,
            "--shard-id",
            str(args.shard_id),
            "--num-shards",
            str(args.num_shards),
            "--nbest",
            str(args.num_rescore),
            "--beam",
            str(args.num_rescore),
            # The batch size used to be passed twice (args.batch_size, then
            # num_rescore); argparse keeps the last occurrence, so num_rescore
            # was always the effective value. Only that one is passed now.
            "--batch-size",
            str(args.num_rescore),
            "--gen-subset",
            args.gen_subset,
            "--source-lang",
            args.source_lang,
            "--target-lang",
            args.target_lang,
        ]
        if args.sampling:
            param1 += ["--sampling"]

        gen_parser = options.get_generation_parser()
        input_args = options.parse_args_and_arch(gen_parser, param1)

        print(input_args)
        with open(predictions_bpe_file, "w") as f, redirect_stdout(f):
            generate.main(input_args)

    gen_output = rerank_utils.BitextOutputFromGen(
        predictions_bpe_file,
        bpe_symbol=args.post_process,
        nbest=using_nbest,
        prefix_len=args.prefix_len,
        target_prefix_frac=args.target_prefix_frac,
    )

    if args.diff_bpe:
        # Rescorers use a different BPE: write de-BPE'd text, then re-apply
        # args.rescore_bpe_code to produce <pre_gen>/rescore_data.<lang>.
        rerank_utils.write_reprocessed(
            gen_output.no_bpe_source,
            gen_output.no_bpe_hypo,
            gen_output.no_bpe_target,
            pre_gen + "/source_gen_bpe." + args.source_lang,
            pre_gen + "/target_gen_bpe." + args.target_lang,
            pre_gen + "/reference_gen_bpe." + args.target_lang,
        )
        bitext_bpe = args.rescore_bpe_code
        bpe_src_param = [
            "-c",
            bitext_bpe,
            "--input",
            pre_gen + "/source_gen_bpe." + args.source_lang,
            "--output",
            pre_gen + "/rescore_data." + args.source_lang,
        ]
        bpe_tgt_param = [
            "-c",
            bitext_bpe,
            "--input",
            pre_gen + "/target_gen_bpe." + args.target_lang,
            "--output",
            pre_gen + "/rescore_data." + args.target_lang,
        ]

        subprocess.call(
            [
                "python",
                os.path.join(
                    os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
                ),
            ]
            + bpe_src_param,
            shell=False,
        )

        subprocess.call(
            [
                "python",
                os.path.join(
                    os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
                ),
            ]
            + bpe_tgt_param,
            shell=False,
        )

    if (not os.path.isfile(score1_file) and not rerank1_is_gen) or (
        args.score_model2 is not None
        and not os.path.isfile(score2_file)
        and not rerank2_is_gen
    ):
        print(
            "STEP 2: process the output of generate.py so we have clean text files with the translations"
        )

        rescore_file = "/rescore_data"
        # Default prefixes binarized in STEP 3. They must be bound even when
        # the write branch below is skipped: with --diff-bpe the re-BPE'd
        # files already live at <pre_gen>/rescore_data.<lang>, and previously
        # STEP 3 crashed with an UnboundLocalError here.
        bw_rescore_file = rescore_file
        fw_rescore_file = rescore_file
        if args.prefix_len is not None:
            prefix_len_rescore_file = rescore_file + "prefix" + str(args.prefix_len)
        if args.target_prefix_frac is not None:
            target_prefix_frac_rescore_file = (
                rescore_file + "target_prefix_frac" + str(args.target_prefix_frac)
            )
        if args.source_prefix_frac is not None:
            source_prefix_frac_rescore_file = (
                rescore_file + "source_prefix_frac" + str(args.source_prefix_frac)
            )

        if not args.right_to_left1 or not args.right_to_left2:
            if not args.diff_bpe:
                rerank_utils.write_reprocessed(
                    gen_output.source,
                    gen_output.hypo,
                    gen_output.target,
                    pre_gen + rescore_file + "." + args.source_lang,
                    pre_gen + rescore_file + "." + args.target_lang,
                    pre_gen + "/reference_file",
                    bpe_symbol=args.post_process,
                )
                # Target-side prefix variants feed the backwards scorer.
                if args.prefix_len is not None:
                    bw_rescore_file = prefix_len_rescore_file
                    rerank_utils.write_reprocessed(
                        gen_output.source,
                        gen_output.hypo,
                        gen_output.target,
                        pre_gen + prefix_len_rescore_file + "." + args.source_lang,
                        pre_gen + prefix_len_rescore_file + "." + args.target_lang,
                        pre_gen + "/reference_file",
                        prefix_len=args.prefix_len,
                        bpe_symbol=args.post_process,
                    )
                elif args.target_prefix_frac is not None:
                    bw_rescore_file = target_prefix_frac_rescore_file
                    rerank_utils.write_reprocessed(
                        gen_output.source,
                        gen_output.hypo,
                        gen_output.target,
                        pre_gen
                        + target_prefix_frac_rescore_file
                        + "."
                        + args.source_lang,
                        pre_gen
                        + target_prefix_frac_rescore_file
                        + "."
                        + args.target_lang,
                        pre_gen + "/reference_file",
                        bpe_symbol=args.post_process,
                        target_prefix_frac=args.target_prefix_frac,
                    )

                # Source-side prefix variant feeds the forward scorer.
                if args.source_prefix_frac is not None:
                    fw_rescore_file = source_prefix_frac_rescore_file
                    rerank_utils.write_reprocessed(
                        gen_output.source,
                        gen_output.hypo,
                        gen_output.target,
                        pre_gen
                        + source_prefix_frac_rescore_file
                        + "."
                        + args.source_lang,
                        pre_gen
                        + source_prefix_frac_rescore_file
                        + "."
                        + args.target_lang,
                        pre_gen + "/reference_file",
                        bpe_symbol=args.post_process,
                        source_prefix_frac=args.source_prefix_frac,
                    )

        if args.right_to_left1 or args.right_to_left2:
            rerank_utils.write_reprocessed(
                gen_output.source,
                gen_output.hypo,
                gen_output.target,
                pre_gen + "/right_to_left_rescore_data." + args.source_lang,
                pre_gen + "/right_to_left_rescore_data." + args.target_lang,
                pre_gen + "/right_to_left_reference_file",
                right_to_left=True,
                bpe_symbol=args.post_process,
            )

        print("STEP 3: binarize the translations")
        # Note: ``and`` binds tighter than ``or``, so this reads
        # A or (B and C) or D.
        if (
            not args.right_to_left1
            or args.score_model2 is not None
            and not args.right_to_left2
            or not rerank1_is_gen
        ):

            if args.backwards1 or args.backwards2:
                if args.backwards_score_dict_dir is not None:
                    bw_dict = args.backwards_score_dict_dir
                else:
                    bw_dict = args.score_dict_dir
                bw_preprocess_param = [
                    "--source-lang",
                    scorer1_src,
                    "--target-lang",
                    scorer1_tgt,
                    "--trainpref",
                    pre_gen + bw_rescore_file,
                    "--srcdict",
                    bw_dict + "/dict." + scorer1_src + ".txt",
                    "--tgtdict",
                    bw_dict + "/dict." + scorer1_tgt + ".txt",
                    "--destdir",
                    backwards_preprocessed_dir,
                ]
                preprocess_parser = options.get_preprocessing_parser()
                input_args = preprocess_parser.parse_args(bw_preprocess_param)
                preprocess.main(input_args)

            preprocess_param = [
                "--source-lang",
                scorer1_src,
                "--target-lang",
                scorer1_tgt,
                "--trainpref",
                pre_gen + fw_rescore_file,
                "--srcdict",
                args.score_dict_dir + "/dict." + scorer1_src + ".txt",
                "--tgtdict",
                args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
                "--destdir",
                left_to_right_preprocessed_dir,
            ]
            preprocess_parser = options.get_preprocessing_parser()
            input_args = preprocess_parser.parse_args(preprocess_param)
            preprocess.main(input_args)

        if args.right_to_left1 or args.right_to_left2:
            preprocess_param = [
                "--source-lang",
                scorer1_src,
                "--target-lang",
                scorer1_tgt,
                "--trainpref",
                pre_gen + "/right_to_left_rescore_data",
                "--srcdict",
                args.score_dict_dir + "/dict." + scorer1_src + ".txt",
                "--tgtdict",
                args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
                "--destdir",
                right_to_left_preprocessed_dir,
            ]
            preprocess_parser = options.get_preprocessing_parser()
            input_args = preprocess_parser.parse_args(preprocess_param)
            preprocess.main(input_args)

    return gen_output
예제 #26
0
파일: eval_lm.py 프로젝트: stas00/fairseq
def cli_main():
    """Parse eval-LM arguments, convert them to an OmegaConf config and dispatch ``main``."""
    namespace = options.parse_args_and_arch(options.get_eval_lm_parser())
    cfg = convert_namespace_to_omegaconf(namespace)
    distributed_utils.call_main(cfg, main)
예제 #27
0
def cli_main():
    """Parse interactive-generation arguments, convert to OmegaConf and dispatch ``main``."""
    namespace = options.parse_args_and_arch(
        options.get_interactive_generation_parser()
    )
    distributed_utils.call_main(convert_namespace_to_omegaconf(namespace), main)
예제 #28
0
def _quantize_language_model(data_dir,
                             arch,
                             extra_flags=None,
                             run_validation=False):
    """Train a small language model, then re-train it with quantization.

    Three ``train.main`` runs are made, in order:

    1. a baseline run (``--max-epoch 1``) saving to *data_dir*;
    2. a scalar quant-noise run (``--quant-noise-scalar 0.5``, 3 updates);
    3. an iterative product-quantization run restored from
       ``<data_dir>/checkpoint_last.pt`` using
       ``transformer_quantization_config.yaml`` next to this file.

    Args:
        data_dir: dataset directory (also the save dir of runs 1 and 2).
        arch: value passed to ``--arch``.
        extra_flags: extra command-line flags appended to every run.
        run_validation: currently unused; kept for signature compatibility.
    """
    extra = extra_flags or []

    def _model_flags(max_tokens):
        # Task/model/optimizer flags shared by all runs; only the token
        # budget (max tokens and tokens per sample) differs between them.
        return [
            "--task", "language_modeling",
            data_dir,
            "--arch", arch,
            "--optimizer", "adam",
            "--lr", "0.0001",
            "--criterion", "adaptive_loss",
            "--adaptive-softmax-cutoff", "5,10,15",
            "--max-tokens", max_tokens,
            "--tokens-per-sample", max_tokens,
        ]

    # Single-worker, single-process settings shared by all runs.
    runtime_flags = [
        "--no-progress-bar",
        "--distributed-world-size", "1",
        "--ddp-backend", "no_c10d",
        "--num-workers", "0",
    ]

    def _train(flags):
        train_args = options.parse_args_and_arch(
            options.get_training_parser(), flags + extra
        )
        train.main(train_args)

    # Baseline training run.
    _train(
        _model_flags("500")
        + ["--save-dir", data_dir, "--max-epoch", "1"]
        + runtime_flags
    )

    # try scalar quantization
    _train(
        _model_flags("500")
        + ["--save-dir", data_dir, "--max-update", "3"]
        + runtime_flags
        + ["--quant-noise-scalar", "0.5"]
    )

    # try iterative PQ quantization
    # NOTE(review): unlike the runs above this one passes no --save-dir, so
    # its checkpoints go to the parser's default save dir — confirm intended.
    _train(
        _model_flags("50")
        + ["--max-update", "6"]
        + runtime_flags
        + [
            "--restore-file",
            os.path.join(data_dir, "checkpoint_last.pt"),
            "--reset-optimizer",
            "--quantization-config-path",
            os.path.join(os.path.dirname(__file__),
                         "transformer_quantization_config.yaml"),
        ]
    )
예제 #29
0
def cli_main():
    """Parse generation arguments plus ASR evaluation options, then run ``main``."""
    asr_parser = add_asr_eval_argument(options.get_generation_parser())
    main(options.parse_args_and_arch(asr_parser))
예제 #30
0
def cli_main():
    """Parse arguments with the parser from ``make_parser`` and run ``main``."""
    main(options.parse_args_and_arch(make_parser()))
예제 #31
0
def cli_main():
    """Parse interactive generation arguments and run ``main``."""
    interactive_parser = options.get_generation_parser(interactive=True)
    main(options.parse_args_and_arch(interactive_parser))
예제 #32
0
def main():
    """Parse command-line arguments and run ``save_top_k`` with them."""
    save_top_k(options.parse_args_and_arch(get_parser_with_args()))
예제 #33
0
def cli_main():
    """Parse batch generation arguments and run ``main``."""
    gen_args = options.parse_args_and_arch(options.get_generation_parser())
    main(gen_args)
예제 #34
0
파일: utils.py 프로젝트: ys-0-sy/translate
def train_translation_model(
    data_dir,
    extra_flags,
    criterion=None,
    set_empty_data_positional_arg=False,
    set_lang_args=True,
    save_dir: str = None,
):
    """Train a small translation model on the ``train/valid.{in,out}`` files in *data_dir*.

    Args:
        data_dir: directory with the text files; default save dir.
        extra_flags: extra command-line flags (may be ``None``).
        criterion: flags replacing the default label-smoothed cross-entropy
            criterion when given.
        set_empty_data_positional_arg: prepend an empty positional data arg.
        set_lang_args: pass ``--source-lang in --target-lang out``.
        save_dir: checkpoint directory; falls back to *data_dir* when falsy.
    """
    parser = train.get_parser_with_args()

    argv = [""] if set_empty_data_positional_arg else []
    argv = argv + [
        "--save-dir", save_dir if save_dir else data_dir,
        "--train-source-text-file", os.path.join(data_dir, "train.in"),
        "--train-target-text-file", os.path.join(data_dir, "train.out"),
        "--eval-source-text-file", os.path.join(data_dir, "valid.in"),
        "--eval-target-text-file", os.path.join(data_dir, "valid.out"),
        "--source-max-vocab-size", "26",
        "--target-max-vocab-size", "26",
        "--max-tokens", "500",
        "--optimizer", "sgd",
        "--lr", "0.05",
        "--lr-scheduler", "fixed",
        "--lr-shrink", "0.95",
        "--momentum", "0.0",
        "--clip-norm", "5.0",
        "--sentence-avg",
        "--beam", "3",
        "--stop-no-best-bleu-eval", "5",
        "--unk-reward", "0.5",
        "--num-avg-checkpoints", "10",
        "--max-epoch", "1",
        "--stop-time-hr", "1",
        "--no-progress-bar",
        "--distributed-world-size", "1",
        # Use one GPU when any is visible, otherwise run on CPU.
        "--local-num-gpus", "1" if torch.cuda.device_count() >= 1 else "0",
    ]
    if set_lang_args:
        argv = argv + ["--source-lang", "in", "--target-lang", "out"]
    argv = argv + (extra_flags or [])
    argv = argv + (criterion or [
        "--criterion",
        "label_smoothed_cross_entropy",
        "--label-smoothing",
        "0.1",
    ])

    args = options.parse_args_and_arch(parser, argv)
    train.validate_and_set_default_args(args)
    train.main(args)
def model_fn(model_dir):
    """Load ``checkpoint_best.pt`` from *model_dir* and build a generator.

    Args:
        model_dir: directory containing ``checkpoint_best.pt``; also used as
            the task's data directory.

    Returns:
        dict with keys ``translator`` (the ``SequenceGenerator``), ``task``,
        ``max_positions``, ``align_dict``, ``tgt_dict``, ``args`` (the parsed
        args overlaid with the checkpoint's saved args) and ``device``.
    """
    model_name = 'checkpoint_best.pt'
    model_path = os.path.join(model_dir, model_name)

    logger.info('Loading the model')
    # NOTE(review): torch.load unpickles the checkpoint; only point this at
    # trusted model directories.
    with open(model_path, 'rb') as f:
        model_info = torch.load(f, map_location=torch.device('cpu'))

    # Will be overridden by model_info['args'] - need to keep for pre-trained models
    parser = options.get_generation_parser(interactive=True)
    # parse_args_and_arch is called without an explicit argv, so it reads
    # sys.argv: substitute the arguments we want, then always restore the
    # caller's argv, even if parsing raises (argparse exits on bad input).
    argv_copy = list(sys.argv)
    sys.argv[1:] = ['--path', model_path, model_dir]
    try:
        args = options.parse_args_and_arch(parser)
    finally:
        sys.argv = argv_copy

    # Overlay every hyperparameter stored in the checkpoint.
    saved_args = model_info['args']
    for key, value in vars(saved_args).items():
        setattr(args, key, value)

    args.data = [model_dir]
    print(args)

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info('Current device: {}'.format(device))

    models, model_args = utils.load_ensemble_for_inference(
        [model_path], task, model_arg_overrides={}
    )

    # Set dictionaries
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Initialize generator
    translator = SequenceGenerator(
        models, tgt_dict, beam_size=args.beam, minlen=args.min_len,
        stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen, unk_penalty=args.unkpen,
        sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
        diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
    )

    if device.type == 'cuda':
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    # align_dict = utils.load_align_dict(args.replace_unk)
    align_dict = utils.load_align_dict(None)

    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )

    return dict(
        translator=translator,
        task=task,
        max_positions=max_positions,
        align_dict=align_dict,
        tgt_dict=tgt_dict,
        args=args,
        device=device,
    )
예제 #36
0
def cli_main():
    """Parse interactive generation arguments plus a required ``--output-file`` and run ``main``."""
    interactive_parser = options.get_generation_parser(interactive=True)
    interactive_parser.add_argument(
        '--output-file',
        required=True,
        help='Output sentence embeddings',
    )
    main(options.parse_args_and_arch(interactive_parser))