예제 #1
0
def main(args, checkpoint_name="best"):
    """Decode ``args.gen_subset`` and print SBLEU, BLEURT and corpus BLEU.

    The function sets up the task and dataset split, builds a BLEURT scorer
    from the ``bleurt-base-128`` checkpoint, restores the model(s) selected by
    ``args.path`` and decodes the split with the configured strategy.  Every
    sample that comes with a reference contributes one
    ``(reference, hypothesis)`` string pair; the pairs are scored at the end.

    ``args.path`` selects where the checkpoint comes from:

    * ``nsml://<session>`` -- ``best.pt`` restored through ``nsml.load``;
    * ``pretrain`` -- the fixed checkpoint below ``nsml.DATASET_PATH``;
    * ``wb://<run>`` -- ``best.pt`` restored through ``wandb.restore``;
    * ``http://<name>`` -- fetched with ``curl`` into ``/tmp/model.pt``;
    * anything else -- colon-separated checkpoint paths loaded as an ensemble.

    Side effects: ``args.max_tokens`` gets a default of 12000 when neither it
    nor ``args.max_sentences`` is set, the task is stored on ``args.taskobj``
    and ``sys.argv`` is truncated to its first element.

    Args:
        args: Parsed command-line namespace (mutated as described above).
        checkpoint_name: Checkpoint name handed to ``nsml.load`` when ``args``
            carries no ``checkpoint_name`` attribute of its own.

    Raises:
        AssertionError: If ``--path`` is missing or the sampling / replace-unk
            options are inconsistent.
        subprocess.CalledProcessError: If the ``curl`` download fails.
    """
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    torch.manual_seed(args.seed)

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))
    args.taskobj = task

    # NOTE(review): argv is cut down before TensorFlow/BLEURT are touched,
    # presumably so their flag parsing does not see this script's options --
    # TODO confirm.
    sys.argv = sys.argv[:1]
    import tensorflow as tf
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as e:
            print(e)
    bleurt_checkpoint = os.path.join(
        cached_path(
            "https://storage.googleapis.com/bleurt-oss/bleurt-base-128.zip",
            extract_compressed_file=True
        ),
        "bleurt-base-128",
    )
    bleurt_scorer = score.BleurtScorer(bleurt_checkpoint)

    # Set dictionaries
    # NOTE(review): only the target dictionary is used, also for decoding the
    # source tokens below -- presumably source and target share a vocabulary.
    tgt_dict = task.target_dictionary

    # Load decoding strategy
    strategy = strategies.setup_strategy(args)

    # Load ensemble
    if args.path.startswith("nsml://"):
        print("| loading nsml checkpoint", args.path)
        import nsml
        session = args.path.replace("nsml://", "")
        model = task.build_model(args)

        def load(dir_path):
            # Callback invoked by nsml.load with the restored directory.
            state = torch.load(os.path.join(dir_path, 'best.pt'))
            model.load_state_dict(state["model"])
            print("loaded")

        # Prefer the name on args; fall back to the keyword argument, which
        # was previously ignored.
        nsml.load(getattr(args, "checkpoint_name", checkpoint_name),
                  load_fn=load, session=session)
        models = [model.cuda()]
    elif args.path == "pretrain":
        from nsml import DATASET_PATH
        from fairseq import checkpoint_utils
        data_token = "en-de"
        pretrained_path = "{}/train/pretrained_models/maskPredict_{}/checkpoint_best.pt".format(
            DATASET_PATH, data_token.split(".")[-1].replace("-", "_"))
        print("| loading", pretrained_path)
        model = task.build_model(args)
        state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_path)
        model.load_state_dict(state["model"], strict=True)
        models = [model.cuda()]
    elif args.path.startswith("wb://"):
        print("| loading wb checkpoint", args.path)
        import wandb
        wandb.restore("best.pt", args.path.replace("wb://", ""), root="/tmp/")
        assert os.path.exists("/tmp/best.pt")
        state = torch.load("/tmp/best.pt")
        model = task.build_model(args)
        model.load_state_dict(state["model"])
        models = [model.cuda()]
    elif args.path.startswith("http://"):
        import subprocess
        print("| loading http checkpoint", args.path)
        url = "http://trains.deeplearn.org:8081/{}".format(args.path.replace("http://", ""))
        # Argument list instead of a shell string: the URL comes from the
        # command line and must not be interpreted by a shell.  check=True
        # stops a failed download from silently loading a stale /tmp/model.pt.
        subprocess.run(["curl", "-o", "/tmp/model.pt", url], check=True)
        state = torch.load("/tmp/model.pt")
        model = task.build_model(args)
        model.load_state_dict(state["model"])
        models = [model.cuda()]
    else:
        print('| loading model(s) from {}'.format(args.path))
        # NOTE(review): eval() runs arbitrary code taken from the command
        # line; ast.literal_eval would be safer if overrides are always plain
        # literals -- confirm before changing.
        models, _ = utils.load_ensemble_for_inference(
            args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides))
        models = [model.cuda() for model in models]

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    results = []  # (reference string, hypothesis string) pairs
    scorer = pybleu.PyBleuScorer()
    has_target = True
    timer = TimeMeter()

    with progress_bar.build_progress_bar(args, itr) as t:

        translations = generate_batched_itr(
            t, strategy, models, tgt_dict,
            length_beam_size=args.length_beam,
            use_gold_target_len=args.gold_target_len,
        )
        for sample_id, src_tokens, target_tokens, hypos in translations:
            has_target = target_tokens is not None
            if has_target:
                target_tokens = target_tokens.int().cpu()

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = dataset.src.get_original_text(sample_id)
                target_str = dataset.tgt.get_original_text(sample_id)
            else:
                src_str = tgt_dict.string(src_tokens, args.remove_bpe)
                if args.dehyphenate:
                    src_str = dehyphenate(src_str)
                if has_target:
                    target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
                    if args.dehyphenate:
                        target_str = dehyphenate(target_str)

            # Without a reference there is nothing to score for this sample.
            if not has_target:
                continue

            _, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypos.int().cpu(),
                src_str=src_str,
                alignment=None,
                align_dict=align_dict,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe,
            )
            if args.dehyphenate:
                hypo_str = dehyphenate(hypo_str)

            if not args.quiet:
                print('H-{}\t{}'.format(sample_id, hypo_str))
                # No alignment is passed in above, so only print one if the
                # post-processing step actually returned it.
                if args.print_alignment and alignment is not None:
                    print('A-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(lambda x: str(utils.item(x)), alignment))
                    ))

            results.append((target_str, hypo_str))

        if has_target:
            print('Time = {}'.format(timer.elapsed_time))
            if results:
                ref, out = zip(*results)
                from fairseq.criterions.lib_sbleu import smoothed_bleu
                sbleu = np.mean([smoothed_bleu(r.split(), h.split()) for r, h in results])
                print("| SBLEU = {:.2f}".format(sbleu))
                bleurt_scores = bleurt_scorer.score(list(ref), list(out))
                print("| BLEURT = {:.4f}".format(np.mean((np.array(bleurt_scores)))))
                print('| Generate {} with beam={}: BLEU4 = {:2.2f}, '.format(args.gen_subset, args.length_beam, scorer.score(ref, out)))
            else:
                # zip(*[]) cannot be unpacked into two names; report instead.
                print('| No reference/hypothesis pairs were collected; skipping scoring')
예제 #2
0
def main(args):
    """Decode ``args.gen_subset`` with a model ensemble and print corpus BLEU.

    The function sets up the task and dataset split, loads the colon-separated
    checkpoints in ``args.path`` as an ensemble, decodes the split with the
    configured strategy and collects one ``(reference, hypothesis)`` string
    pair per sample that has a reference.  ``--quiet`` only suppresses the
    per-sample ``S-``/``T-``/``H-``/``A-`` lines; hypotheses are post-processed
    and scored identically in both modes.

    Side effect: ``args.max_tokens`` gets a default of 12000 when neither it
    nor ``args.max_sentences`` is set.

    Args:
        args: Parsed command-line namespace.

    Raises:
        AssertionError: If ``--path`` is missing or the sampling / replace-unk
            options are inconsistent.
    """
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    torch.manual_seed(args.seed)

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load decoding strategy
    strategy = strategies.setup_strategy(args)

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() runs arbitrary code taken from the command line;
    # ast.literal_eval would be safer if overrides are always plain literals
    # -- confirm before changing.
    models, _ = utils.load_ensemble_for_inference(
        args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides))
    models = [model.cuda() for model in models]

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    results = []  # (reference string, hypothesis string) pairs
    scorer = pybleu.PyBleuScorer()
    has_target = True
    timer = TimeMeter()

    with progress_bar.build_progress_bar(args, itr) as t:

        translations = generate_batched_itr(
            t, strategy, models, tgt_dict,
            length_beam_size=args.length_beam,
            use_gold_target_len=args.gold_target_len,
        )
        for sample_id, src_tokens, target_tokens, hypos in translations:
            has_target = target_tokens is not None
            if has_target:
                target_tokens = target_tokens.int().cpu()

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = dataset.src.get_original_text(sample_id)
                target_str = dataset.tgt.get_original_text(sample_id)
            else:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                if args.dehyphenate:
                    src_str = dehyphenate(src_str)
                if has_target:
                    target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
                    if args.dehyphenate:
                        target_str = dehyphenate(target_str)

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))

            # Without a reference there is no pair to score; previously the
            # quiet path appended an undefined or stale target_str here.
            if not has_target:
                continue

            if not args.quiet:
                print('T-{}\t{}'.format(sample_id, target_str))

            _, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypos.int().cpu(),
                src_str=src_str,
                alignment=None,
                align_dict=align_dict,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe,
            )
            # Applied in quiet mode too: references are dehyphenated above, so
            # hypotheses must be as well or BLEU would depend on --quiet.
            if args.dehyphenate:
                hypo_str = dehyphenate(hypo_str)

            if not args.quiet:
                print('H-{}\t{}'.format(sample_id, hypo_str))
                # No alignment is passed in above, so only print one if the
                # post-processing step actually returned it.
                if args.print_alignment and alignment is not None:
                    print('A-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(lambda x: str(utils.item(x)), alignment))
                    ))
                print()

            results.append((target_str, hypo_str))

        if has_target:
            print('Time = {}'.format(timer.elapsed_time))
            if results:
                ref, out = zip(*results)
                print('| Generate {} with beam={}: BLEU4 = {:2.2f}, '.format(args.gen_subset, args.beam, scorer.score(ref, out)))
            else:
                # zip(*[]) cannot be unpacked into two names; report instead.
                print('| No reference/hypothesis pairs were collected; skipping scoring')
        # Some strategies keep per-run counters; dump them when present.
        if hasattr(strategy, 'nb_sents'):
            print(strategy.nb_sents)
            print(strategy.counts)
            print(strategy.counts/strategy.nb_sents)
예제 #3
0
def main(args, checkpoint_name="best"):
    """Decode a split and print each hypothesis' reward relative to its target.

    For every decoded sample three strings are produced with the target
    dictionary: the dataset target (called the "distilled" target in this
    code), the hypothesis, and the hypothesis with BPE left in place.  Both the
    hypothesis and the dataset target are scored with ``smoothed_bleu``
    against the entry of ``args.original_target`` at the same sample id; the
    difference ``hypothesis - target`` is printed per sample together with
    ``logp`` and the BPE hypothesis, and the mean difference is printed last.

    ``args.path`` selects where the checkpoint comes from:

    * ``nsml://<session>`` -- ``best.pt`` restored through ``nsml.load``;
    * ``pretrain`` -- the fixed checkpoint below ``nsml.DATASET_PATH``;
    * ``wb://<run>`` -- ``best.pt`` restored through ``wandb.restore``;
    * ``http://<name>`` -- fetched with ``curl`` into ``/tmp/model.pt``;
    * anything else -- colon-separated checkpoint paths loaded as an ensemble.

    Side effects: ``args.max_tokens`` gets a default of 12000 when neither it
    nor ``args.max_sentences`` is set, and the task is stored on
    ``args.taskobj``.

    Args:
        args: Parsed command-line namespace; ``args.original_target`` must be
            set (it is opened with ``IndexedCachedDataset``).
        checkpoint_name: Checkpoint name handed to ``nsml.load`` when ``args``
            carries no ``checkpoint_name`` attribute of its own.

    Raises:
        AssertionError: If a required option is missing or inconsistent.
        subprocess.CalledProcessError: If the ``curl`` download fails.
    """
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    torch.manual_seed(args.seed)

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(dataset)))
    args.taskobj = task

    # Set dictionaries
    # NOTE(review): only the target dictionary is needed here; every string
    # below is decoded with it.
    tgt_dict = task.target_dictionary

    # Load decoding strategy
    strategy = strategies.setup_strategy(args)

    # Load ensemble
    if args.path.startswith("nsml://"):
        print("| loading nsml checkpoint", args.path)
        import nsml
        session = args.path.replace("nsml://", "")
        model = task.build_model(args)

        def load(dir_path):
            # Callback invoked by nsml.load with the restored directory.
            state = torch.load(os.path.join(dir_path, 'best.pt'))
            model.load_state_dict(state["model"])
            print("loaded")

        # Prefer the name on args; fall back to the keyword argument, which
        # was previously ignored.
        nsml.load(getattr(args, "checkpoint_name", checkpoint_name),
                  load_fn=load, session=session)
        models = [model.cuda()]
    elif args.path == "pretrain":
        from nsml import DATASET_PATH
        from fairseq import checkpoint_utils
        data_token = "en-de"
        pretrained_path = "{}/train/pretrained_models/maskPredict_{}/checkpoint_best.pt".format(
            DATASET_PATH,
            data_token.split(".")[-1].replace("-", "_"))
        print("| loading", pretrained_path)
        model = task.build_model(args)
        state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_path)
        model.load_state_dict(state["model"], strict=True)
        models = [model.cuda()]
    elif args.path.startswith("wb://"):
        print("| loading wb checkpoint", args.path)
        import wandb
        wandb.restore("best.pt", args.path.replace("wb://", ""), root="/tmp/")
        assert os.path.exists("/tmp/best.pt")
        state = torch.load("/tmp/best.pt")
        model = task.build_model(args)
        model.load_state_dict(state["model"])
        models = [model.cuda()]
    elif args.path.startswith("http://"):
        import subprocess
        print("| loading http checkpoint", args.path)
        url = "http://trains.deeplearn.org:8081/{}".format(
            args.path.replace("http://", ""))
        # Argument list instead of a shell string: the URL comes from the
        # command line and must not be interpreted by a shell.  check=True
        # stops a failed download from silently loading a stale /tmp/model.pt.
        subprocess.run(["curl", "-o", "/tmp/model.pt", url], check=True)
        state = torch.load("/tmp/model.pt")
        model = task.build_model(args)
        model.load_state_dict(state["model"])
        models = [model.cuda()]
    else:
        print('| loading model(s) from {}'.format(args.path))
        # NOTE(review): eval() runs arbitrary code taken from the command
        # line; ast.literal_eval would be safer if overrides are always plain
        # literals -- confirm before changing.
        models, _ = utils.load_ensemble_for_inference(
            args.path.split(':'),
            task,
            model_arg_overrides=eval(args.model_overrides))
        models = [model.cuda() for model in models]

    # The original (pre-distillation, judging by the name) targets are the
    # reference both rewards are measured against, so they are mandatory.
    assert args.original_target
    original_target_dataset = IndexedCachedDataset(args.original_target,
                                                   fix_lua_indexing=True)

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    rel_reward_log = []

    with progress_bar.build_progress_bar(args, itr) as t:

        translations = generate_batched_itr(
            t,
            strategy,
            models,
            tgt_dict,
            length_beam_size=args.length_beam,
            use_gold_target_len=args.gold_target_len)
        for sample_id, src_tokens, target_tokens, hypos, logp in translations:
            # The relative reward needs the dataset target; a sample without
            # one cannot be scored, so skip it instead of decoding None.
            if target_tokens is None:
                continue
            target_tokens = target_tokens.int().cpu()

            distill_str = tgt_dict.string(target_tokens,
                                          args.remove_bpe,
                                          escape_unk=True)
            hypo_str = tgt_dict.string(hypos, args.remove_bpe, escape_unk=True)
            hypo_str_bpe = tgt_dict.string(hypos, None, escape_unk=True)

            # Compute reward
            original_target_dataset.prefetch([sample_id])
            orig_target = tgt_dict.string(original_target_dataset[sample_id],
                                          args.remove_bpe,
                                          escape_unk=True)
            orig_words = orig_target.split()
            hypo_reward = smoothed_bleu(hypo_str.split(), orig_words)
            distill_reward = smoothed_bleu(distill_str.split(), orig_words)
            rel_reward = hypo_reward - distill_reward
            rel_reward_log.append(rel_reward)

            print("{} | {:.4f} | {:.4f} | {}".format(sample_id, rel_reward,
                                                     logp, hypo_str_bpe))
    # np.mean([]) yields nan with a RuntimeWarning; report nan directly.
    mean_rel_reward = np.mean(rel_reward_log) if rel_reward_log else float("nan")
    print("mean rel reward:", mean_rel_reward)