Example #1
def main(args):
    # we should not do this!
    '''
    if args.max_tokens is None:
        args.max_tokens = 6000
    '''
    utils.xpprint(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    utils.xprintln('setup task done!')

    # Load dataset splits
    load_dataset_splits(args, task, ['train'])
    valid_dataset = args.valid_subset.split(',')
    load_dataset_splits(args, task, valid_dataset, shuffle=False)
    utils.xprintln('load dataset done!')

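    # For extractive summarization, the master process sets up a multi-process
    # ROUGE evaluation pool and parameter templates for the valid/test splits.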
    if args.task.startswith('extractive_summarization'):
        if distributed_utils.is_master(args):
            from sum_eval import MultiProcSumEval
            sum_eval_pool = MultiProcSumEval(args.ncpu_eval)
            sum_valid_pool_params = dict(
                article_file=args.raw_valid + '.article',
                summary_file=args.raw_valid + '.summary',
                entity_map_file=None,
                length=-1,
                eval_type='predict',
                topk=args.topk_sent_eval,
                rerank=False,
                with_m=False,
                cmd='-a -c 95 -m -n 4 -w 1.2',
                trigram_block=args.trigram_block,
            )

            sum_test_pool_params = dict(
                article_file=args.raw_test + '.article',
                summary_file=args.raw_test + '.summary',
                entity_map_file=None,
                length=-1,
                eval_type='predict',
                topk=args.topk_sent_eval,
                rerank=False,
                with_m=False,
                cmd='-a -c 95 -m -n 4 -w 1.2',
                trigram_block=args.trigram_block,
            )
            sum_pool_params = dict(valid=sum_valid_pool_params,
                                   test=sum_test_pool_params)

            def make_params(default_dict,
                            result_file,
                            out_rouge_file,
                            rerank=False,
                            with_m=False):
                para_dict = dict(default_dict)
                para_dict['result_file'] = result_file
                para_dict['out_rouge_file'] = out_rouge_file
                para_dict['rerank'] = rerank
                para_dict['with_m'] = with_m
                return para_dict

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))
    # print(model)
    import sys
    sys.stdout.flush()

    # if summarization try to load pretrained model
    # if args.task.startswith('extractive_summarization') or args.task == 'pretrain_document_modeling':
    #     # assume this is a single GPU program
    if args.init_from_pretrained_doc_model:
        task.load_pretrained_model(model, args.pretrained_doc_model_path)
    sys.stdout.flush()

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    max_positions = trainer.get_model().max_positions()
    epoch_itr = trainer.get_train_iterator(epoch=0, load_dataset=False)

    # Load the latest checkpoint if one is available
    # load_checkpoint(args, trainer, epoch_itr)
    # make sure training from a different checkpoint will use different random seed
    cur_dataset = task.dataset('train')
    if hasattr(cur_dataset, 'rng'):
        print('epoch ', epoch_itr.epoch)
        cur_dataset.rng = numpy.random.RandomState(args.seed + epoch_itr.epoch)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
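    # range(10, 9, -1) yields a single value (alpha=10), so this loop runs exactly once.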
    for alpha in range(10, 9, -1):
        # train for one epoch
        # train(args, trainer, task, epoch_itr)

        epoch_itr.next_epoch_itr()

        if epoch_itr.epoch % args.validate_interval == 0:
            if args.task.startswith('extractive_summarization'):
                if distributed_utils.is_master(args):
                    validate_metric(args, trainer, task, epoch_itr,
                                    valid_subsets)
Example #2
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'
    # if args.save_path is not None:
    #     if check_file_exists(args):
    #         return
    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    # print(args)
    utils.xpprint(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(':'),
        task,
        model_arg_overrides=eval(args.model_overrides),
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            # *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
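    # With a RoBERTa vocabulary the printed strings are BPE tokens, so keep a
    # tokenizer around to convert them back to readable text below.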
    if args.isRoberta:
        from pytorch_transformers import RobertaTokenizer
        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    else:
        tokenizer = None
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

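            # If --prefix-size is set, feed the first N target tokens to the
            # generator as a forced decoding prefix.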
            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        if not args.isRoberta:
                            print('S-{}\t{}'.format(sample_id, src_str))
                        else:
                            src_text = ''.join(src_str.strip().split())
                            src_out = tokenizer.convert_tokens_to_string(
                                src_text)
                            print('S-{}\t{}'.format(sample_id, src_out))
                    if has_target:
                        if not args.isRoberta:
                            print('T-{}\t{}'.format(sample_id, target_str))
                        else:
                            tgt_text = ''.join(target_str.strip().split())
                            tgt_out = tokenizer.convert_tokens_to_string(
                                tgt_text)
                            print('T-{}\t{}'.format(sample_id, tgt_out))
                # Process top predictions
                for j, hypo in enumerate(
                        hypos[i][:min(len(hypos[i]), args.nbest)]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'].int().cpu()
                        if hypo['alignment'] is not None else None,
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        if not args.isRoberta:
                            print('H-{}\t{}\t{}'.format(
                                sample_id, hypo['score'], hypo_str))
                        else:
                            hypo_text = ''.join(hypo_str.strip().split())
                            hypo_out = tokenizer.convert_tokens_to_string(
                                hypo_text)
                            print('H-{}\t{}\t{}'.format(
                                sample_id, hypo['score'], hypo_out))
                        print('P-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    hypo['positional_scores'].tolist(),
                                ))))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join(
                                    map(lambda x: str(utils.item(x)),
                                        alignment))))

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))
    return scorer
Example #3
def main(args):
    from fairseq import utils
    utils.xpprint(args)
    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    def build_dictionary(filenames):
        d = dictionary.Dictionary()
        for filename in filenames:
            Tokenizer.add_file_to_dictionary(filename, d, tokenize_line)
        return d

    def build_dictionary_label(filenames):
        d = flexible_dictionary.FlexibleDictionary([('PAD', '<pad>')])
        for filename in filenames:
            Tokenizer.add_file_to_dictionary(filename, d, tokenize_line, append_eos=False)
        return d

    def train_path(lang):
        return '{}{}'.format(args.trainpref, ('.' + lang) if lang else '')

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += f'.{lang}'
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path('dict', lang) + '.txt'

    def dataset_dest_path(output_prefix, lang, extension):
        base = f'{args.destdir}/{output_prefix}'
        lang_part = f'.{args.source_lang}-{args.target_lang}.{lang}' if lang is not None else ''
        return f'{base}{lang_part}.{extension}'

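    # The source side reuses the pretrained GPT-2 vocabulary; the target side is
    # a label dictionary (FlexibleDictionary) built from the training labels.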
    assert args.srcdict is not None, 'where is the GPT-2 dict!'
    if args.srcdict:
        src_dict = gpt2_dictionary.GPT2Dictionary.load(args.srcdict)
        src_dict.save(dict_path(args.source_lang))
        print('load GPT-2 dict from {} | size {}'.format(args.srcdict, len(src_dict)))
    else:
        assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
        src_dict = build_dictionary([train_path(args.source_lang)])
    if target:
        if args.tgtdict:
            tgt_dict = flexible_dictionary.FlexibleDictionary.load(args.tgtdict)
            print('load label dict from {} | size {}'.format(args.tgtdict, len(tgt_dict)))
        else:
            assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
            tgt_dict = build_dictionary_label([train_path(args.target_lang)])
            print('build target dict from {} done'.format(train_path(args.target_lang)))

    src_dict.save(dict_path(args.source_lang))
    if target:
        if not args.joined_dictionary:
            tgt_dict.finalize(
                threshold=args.thresholdtgt,
                nwords=args.nwordstgt,
                padding_factor=1,
            )
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(input_prefix, output_prefix, lang, append_eos=False):
        if lang == args.target_lang:
            dict = flexible_dictionary.FlexibleDictionary.load(dict_path(lang))
        else:
            # dict = bert_dictionary.BertDictionary.load(dict_path(lang))
            dict = gpt2_dictionary.GPT2Dictionary.load(dict_path(lang))

        print('| [{}] Dictionary: {} types | {} types (for real)'.format(lang, len(dict) - 1, len(dict)))

        ds = indexed_dataset.IndexedDatasetBuilder(dataset_dest_path(output_prefix, lang, 'bin'))

        def consumer(tensor):
            ds.add_item(tensor)

        input_file = '{}{}'.format(input_prefix, ('.' + lang) if lang is not None else '')
        if lang == args.target_lang:
            res = Tokenizer.binarize(input_file, dict, consumer, append_eos=append_eos)
            print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
                lang, input_file, res['nseq'], res['ntok'],
                100 * res['nunk'] / res['ntok'], dict.unk_word if hasattr(dict, 'unk_word') else '<no_unk_word>'))
        else:
            # read article
            # from pytorch_pretrained_bert.tokenization import BertTokenizer
            # tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
            from pytorch_transformers import RobertaTokenizer
            tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

            def penn_token2orig_token(sent):
                # -LRB- -RRB- -LSB- -RSB- -LCB- -RCB-
                '''
                penn2orig = {"``":'"', "''": '"',
                             "-LRB-": '(', "-RRB-": ')',
                             "-LSB-":'[', "-RSB-":']',
                             "-LCB-":'{', "-RCB-":'}'}
                '''
                penn2orig = {"-LRB-": '(', "-RRB-": ')',
                             "-LSB-": '[', "-RSB-": ']',
                             "-LCB-": '{', "-RCB-": '}',
                             "-lrb-": '(', "-rrb-": ')',
                             "-lsb-": '[', "-rsb-": ']',
                             "-lcb-": '{', "-rcb-": '}',}
                words = sent.strip().split()
                words = [wd if wd not in penn2orig else penn2orig[wd] for wd in words]
                return ' '.join(words)

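            # Each input line is one article: sentences are separated by <S_SEP>,
            # truncated to max_num_sentences/max_num_words, BPE-tokenized with
            # RoBERTa, and joined with dict.sep_index; any sentence longer than
            # 500 BPE tokens is skipped.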
            num_token, num_unk_token = 0, 0
            num_seq = 0
            skip_line = 0
            for line in open(input_file, encoding='utf8'):
                sents = line.strip().split('<S_SEP>')
                sents = sents[0:args.max_num_sentences]
                sents = [' '.join(sent.strip().split()[0:args.max_num_words]) for sent in sents]
                # print(sents)
                sents = [tokenizer.tokenize(penn_token2orig_token(sent)) for sent in sents]
                article_wids = []
                for i, sent in enumerate(sents):
                    # sometimes there are too many tokens
                    MAXLEN = 500
                    if len(sent) > MAXLEN:
                        # sent = sent[0:MAXLEN]
                        print(' '.join(sent))
                        skip_line += 1
                        print(skip_line)
                        continue
                    if i != 0:
                        article_wids.append(dict.sep_index)
                    wids = tokenizer.convert_tokens_to_ids(sent)
                    # wids_vocab = [dict.index(word) for word in sent]
                    # assert wids == wids_vocab, 'word indices should be the same!'
                    article_wids.extend(wids)
                    for wid in wids:
                        if wid == dict.unk_index:
                            num_unk_token += 1
                        num_token += 1

                num_seq += 1
                tensor = torch.IntTensor(article_wids)
                # print( dict.string_complete(tensor) )
                ds.add_item(tensor)

            print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
                lang, input_file, num_seq, num_token,
                100 * num_unk_token / num_token, dict.unk_word if hasattr(dict, 'unk_word') else '<no_unk_word>'))

        ds.finalize(dataset_dest_path(output_prefix, lang, 'idx'))

    def make_dataset(input_prefix, output_prefix, lang):
        if args.output_format == 'binary':
            make_binary_dataset(input_prefix, output_prefix, lang)
        elif args.output_format == 'raw':
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix + '.{}-{}'.format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

    def make_all(lang):
        if args.trainpref:
            make_dataset(args.trainpref, 'train', lang)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(',')):
                outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
                make_dataset(validpref, outprefix, lang)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(',')):
                outprefix = 'test{}'.format(k) if k > 0 else 'test'
                make_dataset(testpref, outprefix, lang)

    make_all(args.source_lang)
    if target:
        make_all(args.target_lang)

    print('| Wrote preprocessed data to {}'.format(args.destdir))
Example #4
def main(args):
    from fairseq import utils
    utils.xpprint(args)

    import_user_module(args)

    print(args)

    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    task = tasks.get_task(args.task)

    def train_path(lang):
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    if target and not args.tgtdict and os.path.exists(
            dict_path(args.target_lang)):
        raise FileExistsError(dict_path(args.target_lang))

    if args.joined_dictionary:
        assert not args.srcdict or not args.tgtdict, \
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary(
                {
                    train_path(lang)
                    for lang in [args.source_lang, args.target_lang]
                },
                src=True)
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = xlnet_dictionary.XLNetDictionary.load(args.srcdict)
            print('load xlnet dict from {} | size {}'.format(
                args.srcdict, len(src_dict)))
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)],
                                        src=True)

        if target:
            if args.tgtdict:
                tgt_dict = xlnet_dictionary.XLNetDictionary.load(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)],
                                            tgt=True)
        else:
            tgt_dict = None

    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                            num_workers):
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        print('input_prefix', input_prefix)
        print(dict_path(lang))

        dict = xlnet_dictionary.XLNetDictionary.load(dict_path(lang))
        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        from pytorch_transformers import XLNetConfig, XLNetTokenizer
        import torch

        tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')

        def penn_token2orig_token(sent):
            # -LRB- -RRB- -LSB- -RSB- -LCB- -RCB-
            penn2orig = {
                "``": '"',
                "''": '"',
                "-LRB-": '(',
                "-RRB-": ')',
                "-LSB-": '[',
                "-RSB-": ']',
                "-LCB-": '{',
                "-RCB-": '}'
            }
            words = sent.strip().split()
            words = [
                wd if wd not in penn2orig else penn2orig[wd] for wd in words
            ]
            return ' '.join(words)

        num_token, num_unk_token = 0, 0
        num_seq = 0
        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin"))
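        # Each input line is one article: sentences separated by <S_SEP> are
        # tokenized with the XLNet tokenizer and joined with dict.sep_index
        # before being written to the binary dataset.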
        for line in open(input_file, encoding='utf8'):
            sents = line.strip().split('<S_SEP>')
            sents = [
                tokenizer.tokenize(penn_token2orig_token(sent))
                for sent in sents
            ]
            article_wids = []
            for i, sent in enumerate(sents):
                if i != 0:
                    article_wids.append(dict.sep_index)
                wids = tokenizer.convert_tokens_to_ids(sent)
                # wids_vocab = [dict.index(word) for word in sent]
                # assert wids == wids_vocab, 'word indices should be the same!'
                article_wids.extend(wids)
                for wid in wids:
                    if wid == dict.unk_index:
                        num_unk_token += 1
                    num_token += 1

            num_seq += 1
            tensor = torch.IntTensor(article_wids)
            # print( dict.string_complete(tensor) )
            ds.add_item(tensor)

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
            lang, input_file, num_seq, num_token,
            100 * num_unk_token / num_token,
            dict.unk_word if hasattr(dict, 'unk_word') else '<no_unk_word>'))

        #
        #     n_seq_tok = [0, 0]
        #     replaced = Counter()
        #
        #     def merge_result(worker_result):
        #         replaced.update(worker_result["replaced"])
        #         n_seq_tok[0] += worker_result["nseq"]
        #         n_seq_tok[1] += worker_result["ntok"]
        #
        #     input_file = "{}{}".format(
        #         input_prefix, ("." + lang) if lang is not None else ""
        #     )
        #     offsets = Binarizer.find_offsets(input_file, num_workers)
        #     pool = None
        #     if num_workers > 1:
        #         pool = Pool(processes=num_workers - 1)
        #         for worker_id in range(1, num_workers):
        #             prefix = "{}{}".format(output_prefix, worker_id)
        #             pool.apply_async(
        #                 binarize,
        #                 (
        #                     args,
        #                     input_file,
        #                     vocab,
        #                     prefix,
        #                     lang,
        #                     offsets[worker_id],
        #                     offsets[worker_id + 1]
        #                 ),
        #                 callback=merge_result
        #             )
        #         pool.close()
        #
        #     ds = indexed_dataset.IndexedDatasetBuilder(
        #         dataset_dest_file(args, output_prefix, lang, "bin")
        #     )
        #     merge_result(
        #         Binarizer.binarize(
        #             input_file, vocab, lambda t: ds.add_item(t),
        #             offset=0, end=offsets[1]
        #         )
        #     )
        #     if num_workers > 1:
        #         pool.join()
        #         for worker_id in range(1, num_workers):
        #             prefix = "{}{}".format(output_prefix, worker_id)
        #             temp_file_path = dataset_dest_prefix(args, prefix, lang)
        #             ds.merge_file_(temp_file_path)
        #             os.remove(indexed_dataset.data_file_path(temp_file_path))
        #             os.remove(indexed_dataset.index_file_path(temp_file_path))
        #
        #     ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
        #
        #     print(
        #         "| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
        #             lang,
        #             input_file,
        #             n_seq_tok[0],
        #             n_seq_tok[1],
        #             100 * sum(replaced.values()) / n_seq_tok[1],
        #             vocab.unk_word,
        #         )
        #     )

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        if args.output_format == "binary":
            make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                                num_workers)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix +
                ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

    def make_all(lang, vocab):
        if args.trainpref:
            print(args.trainpref, lang)

            make_dataset(vocab,
                         args.trainpref,
                         "train",
                         lang,
                         num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(vocab,
                             validpref,
                             outprefix,
                             lang,
                             num_workers=args.workers)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab,
                             testpref,
                             outprefix,
                             lang,
                             num_workers=args.workers)

    make_all(args.source_lang, src_dict)
    if target:
        make_all(args.target_lang, tgt_dict)

    print("| Wrote preprocessed data to {}".format(args.destdir))

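    # Build a word-alignment dictionary: for every aligned (src, tgt) token pair,
    # keep the most frequent target word per source word and write the mapping to
    # alignment.<src>-<tgt>.txt (typically consumed by --replace-unk at generation time).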
    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        freq_map = {}
        with open(args.alignfile, "r", encoding='utf-8') as align_file:
            with open(src_file_name, "r", encoding='utf-8') as src_file:
                with open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        si = src_dict.encode_line(s, add_if_not_exist=False)
                        ti = tgt_dict.encode_line(t, add_if_not_exist=False)
                        ai = list(map(lambda x: tuple(x.split("-")),
                                      a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
                            if srcidx != src_dict.unk(
                            ) and tgtidx != tgt_dict.unk():
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()

                                if srcidx not in freq_map:
                                    freq_map[srcidx] = {}
                                if tgtidx not in freq_map[srcidx]:
                                    freq_map[srcidx][tgtidx] = 1
                                else:
                                    freq_map[srcidx][tgtidx] += 1

        align_dict = {}
        for srcidx in freq_map.keys():
            align_dict[srcidx] = max(freq_map[srcidx],
                                     key=freq_map[srcidx].get)

        with open(os.path.join(
                args.destdir,
                "alignment.{}-{}.txt".format(args.source_lang,
                                             args.target_lang),
        ),
                  "w",
                  encoding='utf-8') as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
Example #5
def main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Print args
    utils.xpprint(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split,
                          shuffle=False,
                          combine=False,
                          epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator

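    # If no checkpoint_last.pt exists yet, fall back to the pretrained document
    # model as the restore file and reset optimizer/dataloader/meter state.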
    if getattr(args, 'init_from_pretrained_doc_model', False):
        import os
        if not os.path.exists(os.path.join(args.save_dir,
                                           "checkpoint_last.pt")):
            args.restore_file = args.pretrained_doc_model_path
            args.reset_optimizer, args.reset_dataloader, args.reset_meters = True, True, True

    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
    sys.stdout.flush()

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    if isinstance(lr, list):
        lr = min(lr)
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    while (lr > args.min_lr and (epoch_itr.epoch < max_epoch or
                                 (epoch_itr.epoch == max_epoch
                                  and epoch_itr._next_epoch_itr is not None))
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
        if isinstance(lr, list):
            lr = min(lr)

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        reload_dataset = ':' in getattr(args, 'data', '')
        # sharded data: get train iterator for next epoch
        epoch_itr = trainer.get_train_iterator(epoch_itr.epoch,
                                               load_dataset=reload_dataset)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example #6
def main(args, init_distributed=False):
    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 6000

    # print(args)
    utils.xpprint(args)

    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # Initialize distributed training (after data loading)
    if init_distributed:
        import socket
        args.distributed_rank = distributed_utils.distributed_init(args)
        print('| initialized host {} as rank {}'.format(
            socket.gethostname(), args.distributed_rank))
    args.init_distributed = init_distributed
    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    import sys
    sys.stdout.flush()

    # Make a dummy batch to (i) warm the caching allocator and (ii) as a
    # placeholder DistributedDataParallel when there's an uneven number of
    # batches per worker.
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        model.max_positions(),
    )
    dummy_batch = task.dataset('train').get_dummy_batch(
        args.max_tokens, max_positions, batch_size=args.max_sentences)
    oom_batch = task.dataset('train').get_dummy_batch(1, max_positions)

    # Build trainer
    trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
    )

    # print(trainer.get_model().decoder.layers[11].output.LayerNorm.weight.data)

    # Load the latest checkpoint if one is available
    if not load_checkpoint(args, trainer, epoch_itr):
        if args.task == 'abstractive_summarization_bert' or args.task == 'abstractive_summarization_roberta':
            if args.init_from_pretrained_model and args.pretrained_model_path:
                task.load_pretrained_model(model, args.pretrained_model_path)
            elif (getattr(args, 'roberta_decoder', False)
                  and getattr(args, 'roberta_decoder_initialization', False)):
                model.initilize_roberta_decoder()
        trainer.dummy_train_step([dummy_batch])

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
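    # --sep-optim presumably gives the decoder its own optimizer and LR schedule
    # (see dec_lr_step below).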
    if args.sep_optim:
        dec_lr = trainer.get_dec_lr()

    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')

    while (lr > args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
        if args.sep_optim:
            dec_lr = trainer.dec_lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example #7
def main(args):
    from fairseq import utils
    utils.xpprint(args)

    import_user_module(args)

    print(args)

    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    task = tasks.get_task(args.task)

    def train_path(lang):
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    if target and not args.tgtdict and os.path.exists(
            dict_path(args.target_lang)):
        raise FileExistsError(dict_path(args.target_lang))

    if args.joined_dictionary:
        assert not args.srcdict or not args.tgtdict, \
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary(
                {
                    train_path(lang)
                    for lang in [args.source_lang, args.target_lang]
                },
                src=True)
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = roberta_dictionary.RobertaDictionary.load_json(
                args.srcdict)
            # src_dict.save('roberta-vocab/roberta-base-vocab.txt')
            print('load roberta dict from {} | size {}'.format(
                args.srcdict, len(src_dict)))
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)],
                                        src=True)

        if target:
            if args.tgtdict:
                tgt_dict = roberta_dictionary.RobertaDictionary.load_json(
                    args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)],
                                            tgt=True)
        else:
            tgt_dict = None

    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                            num_workers):
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        print('input_prefix', input_prefix)
        print(dict_path(lang))

        dict = roberta_dictionary.RobertaDictionary.load(dict_path(lang))
        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        from pytorch_transformers import RobertaTokenizer
        import torch

        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

        def penn_token2orig_token(sent):
            # -LRB- -RRB- -LSB- -RSB- -LCB- -RCB-
            penn2orig = {
                "``": '"',
                "''": '"',
                "-LRB-": '(',
                "-RRB-": ')',
                "-LSB-": '[',
                "-RSB-": ']',
                "-LCB-": '{',
                "-RCB-": '}'
            }
            words = sent.strip().split()
            words = [
                wd if wd not in penn2orig else penn2orig[wd] for wd in words
            ]
            return ' '.join(words)

        num_token, num_unk_token = 0, 0
        num_seq = 0
        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin"))
        output_ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, 'article_next', "bin"))
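        # Slide over each article: input segments are capped at 512 RoBERTa tokens
        # (including the leading <s>), the following up-to-256 tokens become the
        # 'article_next' target, and segments whose continuation is shorter than
        # 30% of that window are dropped.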
        truncated_number = 512
        output_length = 256

        CLS_TOKEN = '<s>'
        SEP_TOKEN = '</s>'

        for line in open(input_file, encoding='utf8'):
            sents = line.strip().split('<S_SEP>')
            sents = [
                tokenizer.tokenize(penn_token2orig_token(sent))
                for sent in sents
            ]
            article_toks = []
            for i, sent in enumerate(sents):
                if i != 0:
                    article_toks.append(SEP_TOKEN)
                article_toks.extend(sent)
            article_segments = []
            output_segments = []
            tmp_seg = []
            for i, tok in enumerate(article_toks):
                if len(tmp_seg) == 0:
                    tmp_seg.append(CLS_TOKEN)
                tmp_seg.append(tok)
                if tok == SEP_TOKEN:
                    tmp_seg.append(tok)
                if len(tmp_seg) >= truncated_number:
                    tmp_seg = tmp_seg[:truncated_number]
                    if tmp_seg[-1] != SEP_TOKEN:
                        tmp_seg[-1] = SEP_TOKEN
                    tmp_output = article_toks[
                        i + 1:min(i + 1 + output_length, len(article_toks))]
                    if len(tmp_output) < 0.3 * output_length:
                        break
                    article_segments.append(
                        tokenizer.convert_tokens_to_ids(tmp_seg))
                    output_segments.append(
                        tokenizer.convert_tokens_to_ids(tmp_output))
                    tmp_seg = []
            assert len(article_segments) == len(output_segments)
            for i in range(len(article_segments)):
                assert len(article_segments[i]) <= truncated_number
                assert len(output_segments[i]) <= output_length and len(
                    output_segments[i]) >= 0.3 * output_length
                tensor = torch.IntTensor(article_segments[i])
                ds.add_item(tensor)
                output_tensor = torch.IntTensor(output_segments[i])
                output_ds.add_item(output_tensor)

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
        output_ds.finalize(
            dataset_dest_file(args, output_prefix, 'article_next', "idx"))
        print('done!')
        # print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
        #     lang, input_file, num_seq, num_token,
        #     100 * num_unk_token / num_token, dict.unk_word if hasattr(dict, 'unk_word') else '<no_unk_word>'))

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        if args.output_format == "binary":
            make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                                num_workers)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix +
                ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

    def make_all(lang, vocab):
        if args.trainpref:
            print(args.trainpref, lang)
            make_dataset(vocab,
                         args.trainpref,
                         "train",
                         lang,
                         num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(vocab,
                             validpref,
                             outprefix,
                             lang,
                             num_workers=args.workers)
        # if args.testpref:
        #     for k, testpref in enumerate(args.testpref.split(",")):
        #         outprefix = "test{}".format(k) if k > 0 else "test"
        #         make_dataset(vocab, testpref, outprefix, lang, num_workers=args.workers)

    make_all(args.source_lang, src_dict)
    # if target:
    #     make_all(args.target_lang, tgt_dict)

    print("| Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        freq_map = {}
        with open(args.alignfile, "r", encoding='utf-8') as align_file:
            with open(src_file_name, "r", encoding='utf-8') as src_file:
                with open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        si = src_dict.encode_line(s, add_if_not_exist=False)
                        ti = tgt_dict.encode_line(t, add_if_not_exist=False)
                        ai = list(map(lambda x: tuple(x.split("-")),
                                      a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
                            if srcidx != src_dict.unk(
                            ) and tgtidx != tgt_dict.unk():
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()

                                if srcidx not in freq_map:
                                    freq_map[srcidx] = {}
                                if tgtidx not in freq_map[srcidx]:
                                    freq_map[srcidx][tgtidx] = 1
                                else:
                                    freq_map[srcidx][tgtidx] += 1

        align_dict = {}
        for srcidx in freq_map.keys():
            align_dict[srcidx] = max(freq_map[srcidx],
                                     key=freq_map[srcidx].get)

        with open(os.path.join(
                args.destdir,
                "alignment.{}-{}.txt".format(args.source_lang,
                                             args.target_lang),
        ),
                  "w",
                  encoding='utf-8') as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
Example #8
def main(args):
    from fairseq import utils
    utils.xpprint(args)

    import_user_module(args)

    print(args)

    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    task = tasks.get_task(args.task)

    def train_path(lang):
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    if target and not args.tgtdict and os.path.exists(
            dict_path(args.target_lang)):
        raise FileExistsError(dict_path(args.target_lang))

    if args.joined_dictionary:
        assert not args.srcdict or not args.tgtdict, \
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = bert_dictionary.BertDictionary.load(args.srcdict)
        elif args.tgtdict:
            src_dict = bert_dictionary.BertDictionary.load(args.tgtdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary(
                {
                    train_path(lang)
                    for lang in [args.source_lang, args.target_lang]
                },
                src=True)
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = bert_dictionary.BertDictionary.load(args.srcdict)
            print('load bert dict from {} | size {}'.format(
                args.srcdict, len(src_dict)))
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)],
                                        src=True)

        if target:
            if args.tgtdict:
                tgt_dict = bert_dictionary.BertDictionary.load(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)],
                                            tgt=True)
        else:
            tgt_dict = None

    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                            num_workers):
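        # Tokenize each input line with BertTokenizer, join its <S_SEP>-separated
        # sentences with the dictionary's sep index, cut the result into segments
        # of at most article_input tokens, and write two indexed datasets: one
        # for the segments and one ('article_next') for the tokens that follow
        # each segment.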
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        print('input_prefix', input_prefix)
        print(dict_path(lang))

        bert_dict = bert_dictionary.BertDictionary.load(dict_path(lang))
        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        from pytorch_transformers import BertTokenizer
        import torch

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        def penn_token2orig_token(sent):
            # -LRB- -RRB- -LSB- -RSB- -LCB- -RCB-
            penn2orig = {
                "``": '"',
                "''": '"',
                "-LRB-": '(',
                "-RRB-": ')',
                "-LSB-": '[',
                "-RSB-": ']',
                "-LCB-": '{',
                "-RCB-": '}'
            }
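            # e.g. "-LRB- Hello '' world -RRB-" -> '( Hello " world )'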
            words = sent.strip().split()
            words = [penn2orig.get(wd, wd) for wd in words]
            return ' '.join(words)

        num_token, num_unk_token = 0, 0
        num_seq = 0
        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin"))
        output_ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, 'article_next', "bin"))
        article_input = 511
        article_next = 256
        BERT_CLS_ID = tokenizer.convert_tokens_to_ids([BERT_CLS])[0]
        BERT_SEP_ID = tokenizer.convert_tokens_to_ids([BERT_SEP])[0]
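        # 511 leaves room for the [CLS] prepended to each 'article' segment
        # below, giving 512-token BERT inputs; BERT_CLS / BERT_SEP are assumed
        # to be module-level '[CLS]' / '[SEP]' constants defined earlier in
        # this file.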
        for line in open(input_file, encoding='utf8'):
            sents = line.strip().split('<S_SEP>')
            sents = [
                tokenizer.tokenize(penn_token2orig_token(sent))
                for sent in sents
            ]
            article_wids = []
            for i, sent in enumerate(sents):
                if i != 0:
                    article_wids.append(bert_dict.sep_index)
                if len(sent) > article_input:
                    # Convert over-long sentences to ids in article_input-sized
                    # chunks, then stitch the pieces back together.
                    wids = []
                    temp_sent = [
                        sent[x:x + article_input]
                        for x in range(0, len(sent), article_input)
                    ]
                    for se in temp_sent:
                        se_ids = tokenizer.convert_tokens_to_ids(se)
                        wids.extend(se_ids)
                else:
                    wids = tokenizer.convert_tokens_to_ids(sent)
                # wids_vocab = [dict.index(word) for word in sent]
                # assert wids == wids_vocab, 'word indices should be the same!'
                article_wids.extend(wids)
                for wid in wids:
                    if wid == bert_dict.unk_index:
                        num_unk_token += 1
                    num_token += 1

            # Split the concatenated article into consecutive segments of at most
            # article_input tokens; each segment becomes one training example.
            article_segments = [
                article_wids[x:x + article_input]
                for x in range(0, len(article_wids), article_input)
            ]

            # Pair each segment with the tokens that immediately follow it (at
            # most article_next of them); pairs whose continuation is shorter
            # than 30% of article_next are skipped.
            cur_position = 0
            for i in range(len(article_segments)):
                article_seq = article_segments[i]
                cur_position += len(article_seq)
                output_seg = article_wids[
                    cur_position:min(len(article_wids), cur_position +
                                     article_next)]
                if len(output_seg) < 0.3 * article_next:
                    continue
                num_seq += 1
                if len(article_seq) > article_input:
                    print('lang: %s, token len: %d, truncated len: %d' %
                          (lang, len(article_seq), article_input))
                if lang == 'article':
                    # Make sure each article segment starts with [CLS] and ends
                    # with [SEP] (unless its boundary already falls just after a
                    # [SEP]), matching the input format BERT expects.
                    if article_seq[-1] != BERT_SEP_ID:
                        if article_seq[-2] != BERT_SEP_ID:
                            article_seq[-1] = BERT_SEP_ID
                    article_seq = [BERT_CLS_ID] + article_seq

                if len(output_seg) > article_next:
                    print(
                        'lang: article_next, token len: %d, truncated len: %d'
                        % (len(output_seg), article_next))

                tensor = torch.IntTensor(article_seq)
                ds.add_item(tensor)
                output_tensor = torch.IntTensor(output_seg)
                output_ds.add_item(output_tensor)

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
        output_ds.finalize(
            dataset_dest_file(args, output_prefix, 'article_next', "idx"))
        print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
            lang, input_file, num_seq, num_token,
            100 * num_unk_token / num_token,
            bert_dict.unk_word if hasattr(bert_dict, 'unk_word') else '<no_unk_word>'))

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        if args.output_format == "binary":
            make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                                num_workers)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix +
                ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)
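
    # make_all builds the train / valid* / test* splits for one language;
    # multiple comma-separated --validpref/--testpref values produce valid,
    # valid1, ... and test, test1, ...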

    def make_all(lang, vocab):
        if args.trainpref:
            print(args.trainpref, lang)
            make_dataset(vocab,
                         args.trainpref,
                         "train",
                         lang,
                         num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(vocab,
                             validpref,
                             outprefix,
                             lang,
                             num_workers=args.workers)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab,
                             testpref,
                             outprefix,
                             lang,
                             num_workers=args.workers)

    make_all(args.source_lang, src_dict)
    # if target:
    #     make_all(args.target_lang, tgt_dict)

    print("| Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        freq_map = {}
        with open(args.alignfile, "r", encoding='utf-8') as align_file:
            with open(src_file_name, "r", encoding='utf-8') as src_file:
                with open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        si = src_dict.encode_line(s, add_if_not_exist=False)
                        ti = tgt_dict.encode_line(t, add_if_not_exist=False)
                        ai = list(map(lambda x: tuple(x.split("-")),
                                      a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
                            if srcidx != src_dict.unk(
                            ) and tgtidx != tgt_dict.unk():
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()

                                if srcidx not in freq_map:
                                    freq_map[srcidx] = {}
                                if tgtidx not in freq_map[srcidx]:
                                    freq_map[srcidx][tgtidx] = 1
                                else:
                                    freq_map[srcidx][tgtidx] += 1

        align_dict = {}
        for srcidx in freq_map.keys():
            align_dict[srcidx] = max(freq_map[srcidx],
                                     key=freq_map[srcidx].get)

        with open(os.path.join(
                args.destdir,
                "alignment.{}-{}.txt".format(args.source_lang,
                                             args.target_lang),
        ),
                  "w",
                  encoding='utf-8') as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)