Example #1
def main(cfg: DictConfig) -> None:
    if isinstance(cfg, argparse.Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    assert (
        cfg.dataset.max_tokens is not None
        or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"
    metrics.reset()

    np.random.seed(cfg.common.seed)
    utils.set_torch_seed(cfg.common.seed)

    if distributed_utils.is_master(cfg.distributed_training):
        checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)

    # Print args
    logger.info(cfg)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(cfg.task)
    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in cfg.dataset.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    assert cfg.criterion, "Please specify criterion to train a model"

    # Build model and criterion
    model = task.build_model(cfg.model)
    criterion = task.build_criterion(cfg.criterion)
    logger.info(model)
    logger.info("task: {}".format(task.__class__.__name__))
    logger.info("model: {}".format(model.__class__.__name__))
    logger.info("criterion: {}".format(criterion.__class__.__name__))
    logger.info("num. model params: {} (num. trained: {})".format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # (optionally) Configure quantization
    if cfg.common.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=cfg.common.quantization_config_path,
            max_epoch=cfg.optimization.max_epoch,
            max_update=cfg.optimization.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if cfg.common.model_parallel_size == 1:
        trainer = Trainer(cfg, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(cfg, task, model, criterion)

    logger.info("training on {} devices (GPUs/TPUs)".format(
        cfg.distributed_training.distributed_world_size))
    logger.info("max tokens per GPU = {} and batch size per GPU = {}".format(
        cfg.dataset.max_tokens,
        cfg.dataset.batch_size,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        cfg.checkpoint,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )

    max_epoch = cfg.optimization.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while epoch_itr.next_epoch_idx <= max_epoch:
        if lr <= cfg.optimization.stop_min_lr:
            logger.info(
                f"stopping training because current learning rate ({lr}) is smaller "
                "than or equal to minimum learning rate "
                f"(--stop-min-lr={cfg.optimization.stop_min_lr})")
            break

        # train for one epoch
        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))
Example #2
def main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        ## Both single-node multi-GPU and multi-node multi-GPU training go through this call.
        ## It invokes init_process_group; since no data has been loaded yet, the timeouts that
        ## earlier fairseq versions hit in multi-node training (workers loading data at different speeds) should not occur.
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args): ## check whether the current GPU is the master (args.distributed_rank == 0)
        checkpoint_utils.verify_checkpoint_directory(args.save_dir) ## verify the checkpoint save directory

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    ## create the corresponding TranslationTask, load the two dictionaries (self.src_dict, self.tgt_dict), and decide on left vs. right padding
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    # dev sets used for validation; each split named valid_sub_split is stored in task.datasets under that name after loading
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args) ## build the network; for translation this is TransformerModel, a FairseqEncoderDecoderModel subclass
    criterion = task.build_criterion(args) ## build the loss; here LabelSmoothedCrossEntropyCriterion
    print(model)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    ##print the number of parameters of each matrix
    #for name, param in model.named_parameters(recurse=True):
    #    print (name, param.numel())
    #exit(0)

    # Build trainer
    # if distributed_world_size > 1, model and criterion are wrapped in models.DistributedFairseqModel
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer) ## generate data iterator, epoch_itr

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    while (
        lr > args.min_lr
        and (epoch_itr.epoch < max_epoch or (epoch_itr.epoch == max_epoch
            and epoch_itr._next_epoch_itr is not None))
        and trainer.get_num_updates() < max_update
    ):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        ## create a fresh epoch data iterator each epoch to walk over all of the training data
        reload_dataset = ':' in getattr(args, 'data', '')
        # sharded data: get train iterator for next epoch
        epoch_itr = trainer.get_train_iterator(epoch_itr.epoch, load_dataset=reload_dataset)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
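
This older (pre-0.10) training loop takes a plain argparse namespace. A minimal sketch of how it is typically launched, assuming the standard options.get_training_parser / parse_args_and_arch helpers; distributed spawning via --distributed-init-method is deliberately omitted here:

def cli_main():
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)
    # single-process launch; real entry points also handle multi-GPU spawning
    main(args, init_distributed=False)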
Example #3
def main(args):
    check_args(args)
    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 30000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    logger.info("| {} {} {} examples".format(
        args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Set dictionary
    tgt_dict = task.target_dictionary

    if args.ctc or args.rnnt:
        tgt_dict.add_symbol("<ctc_blank>")
        if args.ctc:
            logger.info("| decoding a ctc model")
        if args.rnnt:
            logger.info("| decoding a rnnt model")

    # Load ensemble
    logger.info("| loading model(s) from {}".format(args.path))
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(":"),
        task,
        model_arg_overrides=eval(args.model_overrides),  # noqa
    )
    optimize_models(args, use_cuda, models)

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    num_sentences = 0

    if not os.path.exists(args.results_path):
        os.makedirs(args.results_path)

    sp = spm.SentencePieceProcessor()
    sp.Load(os.path.join(args.data, 'spm.model'))

    res_files = prepare_result_files(args)
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                speaker = task.dataset(
                    args.gen_subset).speakers[int(sample_id)]
                id = task.dataset(args.gen_subset).ids[int(sample_id)]
                target_tokens = (utils.strip_pad(sample["target"][i, :],
                                                 tgt_dict.pad()).int().cpu())
                # Process top predictions
                process_predictions(args, hypos[i], sp, tgt_dict,
                                    target_tokens, res_files, speaker, id)

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += sample["nsentences"]

    logger.info("| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}"
                "sentences/s, {:.2f} tokens/s)".format(
                    num_sentences,
                    gen_timer.n,
                    gen_timer.sum,
                    num_sentences / gen_timer.sum,
                    1.0 / gen_timer.avg,
                ))
    logger.info("| Generate {} with beam={}".format(args.gen_subset,
                                                    args.beam))
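
The helper optimize_models used above is not shown. A sketch of what it typically does, mirroring the inline ensemble preparation in Example #4 below:

def optimize_models(args, use_cuda, models):
    """Optimize the loaded ensemble for generation."""
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()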
Example #4
def main(args):
    utils.import_user_module(args)

    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Initialize generator
    generator = task.build_generator(args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def encode_fn(x):
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models])

    if args.buffer_size > 1:
        print('| Sentence buffer size:', args.buffer_size)
    print('| Type the input sentence and press return:')
    start_id = 0
    for inputs in buffered_read(args.input, args.buffer_size):
        results = []
        for batch in make_batches(inputs, args, task, max_positions,
                                  encode_fn):
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()

            sample = {
                'net_input': {
                    'src_tokens': src_tokens,
                    'src_lengths': src_lengths,
                },
            }
            translations = task.inference_step(generator, models, sample)
            for i, (id,
                    hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                results.append((start_id + id, src_tokens_i, hypos))

        # sort output to match input order
        for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
            if src_dict is not None:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                print('S-{}\t{}'.format(id, src_str))

            # Process top predictions
            for hypo in hypos[:min(len(hypos), args.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )
                hypo_str = decode_fn(hypo_str)
                print('H-{}\t{}\t{}'.format(id, hypo['score'], hypo_str))
                print('P-{}\t{}'.format(
                    id, ' '.join(
                        map(lambda x: '{:.4f}'.format(x),
                            hypo['positional_scores'].tolist()))))
                if args.print_alignment:
                    alignment_str = " ".join(
                        ["{}-{}".format(src, tgt) for src, tgt in alignment])
                    print('A-{}\t{}'.format(id, alignment_str))

        # update running id counter
        start_id += len(inputs)
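
The interactive loop above relies on buffered_read (and make_batches), which are defined elsewhere in fairseq's interactive.py. A sketch of buffered_read under that assumption, reading from a file or stdin in chunks of at most buffer_size lines:

import fileinput

def buffered_read(input, buffer_size):
    buffer = []
    with fileinput.input(files=[input], openhook=fileinput.hook_encoded("utf-8")) as h:
        for src_str in h:
            buffer.append(src_str.strip())
            if len(buffer) >= buffer_size:
                yield buffer
                buffer = []
    if len(buffer) > 0:
        yield buffer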
Example #5
def main(args, init_distributed=False):
    utils.import_user_module(args)

    ########### xml initialization
    ##  xml_model/mlm_tlm_xnli15_1024.pth
    reloaded = torch.load(args.xml_model_path)
    params = AttrDict(reloaded['params'])
    print("Supported languages: %s" % ", ".join(params.lang2id.keys()))

    # build dictionary / update parameters
    xml_dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
    params.n_words = len(xml_dico)
    params.bos_index = xml_dico.index(BOS_WORD)
    params.eos_index = xml_dico.index(EOS_WORD)
    params.pad_index = xml_dico.index(PAD_WORD)
    params.unk_index = xml_dico.index(UNK_WORD)
    params.mask_index = xml_dico.index(MASK_WORD)

    xml_model = TransformerModel(params, xml_dico, True, True)
    xml_model.eval()
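    # sanity check: print one position-embedding weight before and after load_state_dict to confirm the XLM weights were actually loaded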
    print(xml_model._modules['position_embeddings']._parameters['weight'][0, 0].item())
    xml_model.load_state_dict(reloaded['model'])
    print(xml_model._modules['position_embeddings']._parameters['weight'][0, 0].item())
    ############## end of xml initialization

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, shuffle=False, epoch=0, xml_dico=xml_dico, xml_params=params, task=args.task)

    # Build model and criterion
    predictor, estimator, xml_estmimator = task.build_model(args)
    # xml_estmimator = None

    for n, p in predictor.named_parameters():
        if not n.startswith('estimator'):
            p.requires_grad = False

    criterion = task.build_criterion(args)
    print(predictor)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in predictor.parameters()),
        sum(p.numel() for p in predictor.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, predictor, estimator, criterion, xml_model=xml_model, xml_estmimator=xml_estmimator)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
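    # sanity check: print one encoder embedding weight before and after loading the predictor checkpoint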
    print(predictor._modules['encoder']._modules['embed_tokens']._parameters['weight'].data[0, 0].item())
    checkpoint_utils.load_predictor_checkpoint(args, trainer)
    print(predictor._modules['encoder']._modules['embed_tokens']._parameters['weight'].data[0, 0].item())

    extra_state, epoch_itr = checkpoint_utils.load_estimator_checkpoint(args, trainer, xml_dico=xml_dico, xml_params=params)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')

    if args.evaluate > 0:
        validate(args, trainer, task, epoch_itr, valid_subsets, evaluate=True)
        return

    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update:
        # train for one epoch
        train_qe(args, trainer, task, epoch_itr)

        # if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
        if not args.disable_validation:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        # if epoch_itr.epoch % args.save_interval == 0:
        #     checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if ':' in getattr(args, 'data', ''):
            # sharded data: get train iterator for next epoch
            epoch_itr = trainer.get_train_iterator(epoch_itr.epoch)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example #6
def main(args):
    utils.import_user_module(args)

    os.makedirs(args.destdir, exist_ok=True)

    logger.addHandler(
        logging.FileHandler(filename=os.path.join(args.destdir,
                                                  'preprocess.log'), ))
    logger.info(args)

    task = tasks.get_task(args.task)

    def train_path(lang):
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames,
                         src=False,
                         tgt=False,
                         custom_bos='<bos>',
                         custom_pad='<pad>',
                         custom_eos='</s>',
                         custom_unk='<unk>'):
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
            custom_bos=custom_bos,
            custom_pad=custom_pad,
            custom_eos=custom_eos,
            custom_unk=custom_unk)

    target = not args.only_source

    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    if target and not args.tgtdict and os.path.exists(
            dict_path(args.target_lang)):
        raise FileExistsError(dict_path(args.target_lang))

    if args.joined_dictionary:
        assert not args.srcdict or not args.tgtdict, \
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary(
                {
                    train_path(lang)
                    for lang in [args.source_lang, args.target_lang]
                },
                src=True)
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)],
                                        src=True,
                                        custom_bos=args.bos,
                                        custom_pad=args.pad,
                                        custom_eos=args.eos,
                                        custom_unk=args.unk)

        if target:
            if args.tgtdict:
                tgt_dict = \
                    task.load_dictionary(args.tgtdict, custom_bos=args.bos, custom_pad=args.pad,
                                         custom_eos=args.eos, custom_unk=args.unk,
                                         add_sentence_limit_words_after=args.tgtdict_add_sentence_limit_words_after)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)],
                                            tgt=True)
        else:
            tgt_dict = None

    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                            num_workers):
        logger.info("[{}] Dictionary: {} types".format(lang, len(vocab)))
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(binarize,
                                 (args, input_file, vocab, prefix, lang,
                                  offsets[worker_id], offsets[worker_id + 1]),
                                 callback=merge_result)
            pool.close()

        # TODO: Use tokenizer to BERT tokenize the data.
        ds = indexed_dataset.make_builder(dataset_dest_file(
            args, output_prefix, lang, "bin"),
                                          impl=args.dataset_impl,
                                          vocab_size=len(vocab))
        merge_result(
            Binarizer.binarize(input_file,
                               vocab,
                               lambda t: ds.add_item(t),
                               offset=0,
                               end=offsets[1]))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        logger.info(
            "[{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                n_seq_tok[0],
                n_seq_tok[1],
                100 * sum(replaced.values()) / n_seq_tok[1],
                vocab.unk_word,
            ))

    def make_binary_alignment_dataset(input_prefix, output_prefix,
                                      num_workers):
        nseq = [0]

        def merge_result(worker_result):
            nseq[0] += worker_result['nseq']

        input_file = input_prefix
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize_alignments,
                    (args, input_file, utils.parse_alignment, prefix,
                     offsets[worker_id], offsets[worker_id + 1]),
                    callback=merge_result)
            pool.close()

        ds = indexed_dataset.make_builder(dataset_dest_file(
            args, output_prefix, None, "bin"),
                                          impl=args.dataset_impl)

        merge_result(
            Binarizer.binarize_alignments(input_file,
                                          utils.parse_alignment,
                                          lambda t: ds.add_item(t),
                                          offset=0,
                                          end=offsets[1]))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, None)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))

        logger.info("[alignments] {}: parsed {} alignments".format(
            input_file, nseq[0]))

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        if args.dataset_impl == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix +
                ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)
        else:
            make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                                num_workers)

    def make_all(lang, vocab):
        if args.trainpref:
            make_dataset(vocab,
                         args.trainpref,
                         "train",
                         lang,
                         num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(vocab,
                             validpref,
                             outprefix,
                             lang,
                             num_workers=args.workers)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab,
                             testpref,
                             outprefix,
                             lang,
                             num_workers=args.workers)

    def make_all_alignments():
        if args.trainpref and os.path.exists(args.trainpref + "." +
                                             args.align_suffix):
            make_binary_alignment_dataset(args.trainpref + "." +
                                          args.align_suffix,
                                          "train.align",
                                          num_workers=args.workers)
        if args.validpref and os.path.exists(args.validpref + "." +
                                             args.align_suffix):
            make_binary_alignment_dataset(args.validpref + "." +
                                          args.align_suffix,
                                          "valid.align",
                                          num_workers=args.workers)
        if args.testpref and os.path.exists(args.testpref + "." +
                                            args.align_suffix):
            make_binary_alignment_dataset(args.testpref + "." +
                                          args.align_suffix,
                                          "test.align",
                                          num_workers=args.workers)

    make_all(args.source_lang, src_dict)
    if target:
        make_all(args.target_lang, tgt_dict)
    if args.align_suffix:
        make_all_alignments()

    logger.info("Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        freq_map = {}
        with open(args.alignfile, "r", encoding='utf-8') as align_file:
            with open(src_file_name, "r", encoding='utf-8') as src_file:
                with open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        si = src_dict.encode_line(s, add_if_not_exist=False)
                        ti = tgt_dict.encode_line(t, add_if_not_exist=False)
                        ai = list(map(lambda x: tuple(x.split("-")),
                                      a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
                            if srcidx != src_dict.unk(
                            ) and tgtidx != tgt_dict.unk():
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()

                                if srcidx not in freq_map:
                                    freq_map[srcidx] = {}
                                if tgtidx not in freq_map[srcidx]:
                                    freq_map[srcidx][tgtidx] = 1
                                else:
                                    freq_map[srcidx][tgtidx] += 1

        align_dict = {}
        for srcidx in freq_map.keys():
            align_dict[srcidx] = max(freq_map[srcidx],
                                     key=freq_map[srcidx].get)

        with open(os.path.join(
                args.destdir,
                "alignment.{}-{}.txt".format(args.source_lang,
                                             args.target_lang),
        ),
                  "w",
                  encoding='utf-8') as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
Example #7
def _main(args, output_file):
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairseq_cli.generate')

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args.path),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(models, args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        prefix_tokens = None
        if args.prefix_size > 0:
            prefix_tokens = sample['target'][:, :args.prefix_size]

        gen_timer.start()
        hypos = task.inference_step(generator, models, sample, prefix_tokens)
        num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample['id'].tolist()):
            has_target = sample['target'] is not None

            # Remove padding
            src_tokens = utils.strip_pad(
                sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
            target_tokens = None
            if has_target:
                target_tokens = utils.strip_pad(sample['target'][i, :],
                                                tgt_dict.pad()).int().cpu()

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(
                    args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(
                    args.gen_subset).tgt.get_original_text(sample_id)
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens, args.remove_bpe)
                else:
                    src_str = ""
                if has_target:
                    target_str = tgt_dict.string(target_tokens,
                                                 args.remove_bpe,
                                                 escape_unk=True,
                                                 extra_symbols_to_ignore={
                                                     generator.eos,
                                                 })

            src_str = decode_fn(src_str)
            if has_target:
                target_str = decode_fn(target_str)

            if not args.quiet:
                if src_dict is not None:
                    print('S-{}\t{}'.format(sample_id, src_str),
                          file=output_file)
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str),
                          file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:args.nbest]):
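                # debug print: dictionary symbol at index 2 (eos in a default fairseq Dictionary)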
                print(tgt_dict[2])
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                    extra_symbols_to_ignore={})
                detok_hypo_str = decode_fn(hypo_str)
                if not args.quiet:
                    score = hypo['score'] / math.log(2)  # convert to base 2
                    # original hypothesis (after tokenization and BPE)
                    print('H-{}\t{}\t{}'.format(sample_id, score, hypo_str),
                          file=output_file)
                    # detokenized hypothesis
                    print('D-{}\t{}\t{}'.format(sample_id, score,
                                                detok_hypo_str),
                          file=output_file)
                    print(
                        'P-{}\t{}'.format(
                            sample_id,
                            ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    # convert from base e to base 2
                                    hypo['positional_scores'].div_(math.log(2)
                                                                   ).tolist(),
                                ))),
                        file=output_file)

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id, ' '.join([
                                '{}-{}'.format(src_idx, tgt_idx)
                                for src_idx, tgt_idx in alignment
                            ])),
                              file=output_file)

                    if args.print_step:
                        print('I-{}\t{}'.format(sample_id, hypo['steps']),
                              file=output_file)

                    if getattr(args, 'retain_iter_history', False):
                        for step, h in enumerate(hypo['history']):
                            _, h_str, _ = utils.post_process_prediction(
                                hypo_tokens=h['tokens'].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print('E-{}_{}\t{}'.format(sample_id, step, h_str),
                                  file=output_file)

                # Score only the top hypothesis
                if has_target and j == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tgt_dict.encode_line(
                            target_str, add_if_not_exist=True)
                        hypo_tokens = tgt_dict.encode_line(
                            detok_hypo_str, add_if_not_exist=True)
                    if hasattr(scorer, 'add_string'):
                        scorer.add_string(target_str, detok_hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({'wps': round(wps_meter.avg)})
        num_sentences += sample['nsentences']

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info(
        'Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        if args.bpe and not args.sacrebleu:
            if args.remove_bpe:
                logger.warning(
                    "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
                )
            else:
                logger.warning(
                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words.  Use --sacrebleu for standard 13a BLEU tokenization"
                )
        logger.info('Generate {} with beam={}: {}'.format(
            args.gen_subset, args.beam, scorer.result_string()))

    return scorer
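
The _main(args, output_file) signature suggests the usual thin public wrapper that picks the output stream. A sketch of that wrapper, assuming os and sys are imported and that --results-path behaves as in stock fairseq generate.py:

def main(args):
    if args.results_path is not None:
        os.makedirs(args.results_path, exist_ok=True)
        output_path = os.path.join(
            args.results_path, 'generate-{}.txt'.format(args.gen_subset))
        with open(output_path, 'w', buffering=1, encoding='utf-8') as h:
            return _main(args, h)
    else:
        return _main(args, sys.stdout)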
Example #8
def main(args):
    import_user_module(args)

    print(args)

    os.makedirs(args.destdir, exist_ok=True)
    # whether to process only the source language; target=True means the target side is processed as well
    target = not args.only_source
    # task is the task type to run; here, initialized from the given args, it is <class 'fairseq.tasks.translation.TranslationTask'>
    task = tasks.get_task(args.task)

    def train_path(lang):
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        # exactly one of src and tgt must be True
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,  # number of worker processes
            threshold=args.thresholdsrc
            if src else args.thresholdtgt,  # map words seen fewer than threshold times to <unk>
            nwords=args.nwordssrc if src else args.nwordstgt,  # number of source/target words to keep
            padding_factor=args.padding_factor,
        )

    # the src and tgt dictionary files must not already exist in destdir
    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    if target and not args.tgtdict and os.path.exists(
            dict_path(args.target_lang)):
        raise FileExistsError(dict_path(args.target_lang))
    # False by default
    if args.copy_ext_dict:
        assert args.joined_dictionary, \
            "--joined-dictionary must be set if --copy-extended-dictionary is specified"
        assert args.workers == 1, \
            "--workers must be set to 1 if --copy-extended-dictionary is specified"

    if args.joined_dictionary:
        assert not args.srcdict or not args.tgtdict, \
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary(
                {
                    train_path(lang)
                    for lang in [args.source_lang, args.target_lang]
                },
                src=True)
        tgt_dict = src_dict
    else:
        # if srcdict is not given, trainpref must be set: the training data is used to build the dictionary
        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            # build the source-side dictionary
            src_dict = build_dictionary([train_path(args.source_lang)],
                                        src=True)

        if target:
            if args.tgtdict:
                tgt_dict = task.load_dictionary(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                # build the target-side dictionary
                tgt_dict = build_dictionary([train_path(args.target_lang)],
                                            tgt=True)
        else:
            tgt_dict = None
    # save the source and target dictionaries, e.g. dict.<source_lang>.txt
    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(vocab,
                            input_prefix,
                            output_prefix,
                            lang,
                            num_workers,
                            copy_src_words=None):
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        n_seq_tok = [0, 0]
        replaced = Counter()
        copied = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            copied.update(worker_result["copied"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:  # todo: not support copy
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(binarize,
                                 (args, input_file, vocab, prefix, lang,
                                  offsets[worker_id], offsets[worker_id + 1]),
                                 callback=merge_result)
            pool.close()

        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin"))
        words_list = []

        def binarize_consumer(ids, words):
            ds.add_item(ids)
            words_list.append(words)

        merge_result(
            Binarizer.binarize(input_file,
                               vocab,
                               binarize_consumer,
                               offset=0,
                               end=offsets[1],
                               copy_ext_dict=args.copy_ext_dict,
                               copy_src_words=copy_src_words))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        print(
            "| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}, {:.3}% <unk> copied from src"
            .format(lang, input_file, n_seq_tok[0], n_seq_tok[1],
                    100 * sum(replaced.values()) / n_seq_tok[1],
                    vocab.unk_word,
                    100 * sum(copied.values()) / n_seq_tok[1]))

        return words_list

    def make_dataset(vocab,
                     input_prefix,
                     output_prefix,
                     lang,
                     num_workers=1,
                     copy_src_words=None):
        if args.output_format == "binary":
            return make_binary_dataset(vocab, input_prefix, output_prefix,
                                       lang, num_workers, copy_src_words)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix +
                ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

            return None

    def make_all(lang, vocab,
                 source_words_list_dict=defaultdict(lambda: None)):
        words_list_dict = defaultdict(lambda: None)

        if args.trainpref:
            words_list_dict["train"] = \
                make_dataset(vocab, args.trainpref, "train", lang, num_workers=args.workers,
                             copy_src_words=source_words_list_dict['train'])
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                words_list_dict["valid"] = \
                    make_dataset(vocab, validpref, outprefix, lang, copy_src_words=source_words_list_dict['valid'])
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                words_list_dict["test"] = \
                    make_dataset(vocab, testpref, outprefix, lang, copy_src_words=source_words_list_dict['test'])
        # returns a defaultdict, e.g. 'train' -> None, 'valid' -> None (or 'test')
        return words_list_dict

    # source_words_list_dict
    source_words_list_dict = make_all(args.source_lang, src_dict)
    if target:
        target_words_list_dict = make_all(args.target_lang, tgt_dict,
                                          source_words_list_dict)

    print("把预处理数据写入到 {}".format(args.destdir))

    if False:  #args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        freq_map = {}
        with open(args.alignfile, "r", encoding='utf-8') as align_file:
            with open(src_file_name, "r", encoding='utf-8') as src_file:
                with open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        si = src_dict.encode_line(s, add_if_not_exist=False)
                        ti = tgt_dict.encode_line(t, add_if_not_exist=False)
                        ai = list(map(lambda x: tuple(x.split("-")),
                                      a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
                            if srcidx != src_dict.unk(
                            ) and tgtidx != tgt_dict.unk():
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()

                                if srcidx not in freq_map:
                                    freq_map[srcidx] = {}
                                if tgtidx not in freq_map[srcidx]:
                                    freq_map[srcidx][tgtidx] = 1
                                else:
                                    freq_map[srcidx][tgtidx] += 1

        align_dict = {}
        for srcidx in freq_map.keys():
            align_dict[srcidx] = max(freq_map[srcidx],
                                     key=freq_map[srcidx].get)

        with open(os.path.join(
                args.destdir,
                "alignment.{}-{}.txt".format(args.source_lang,
                                             args.target_lang),
        ),
                  "w",
                  encoding='utf-8') as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)

    # if args.alignfile:
    #     from fairseq.tokenizer import tokenize_line
    #     import numpy as np
    #     assert args.trainpref, "--trainpref must be set if --alignfile is specified"
    #     src_file_name = train_path(args.source_lang)
    #     tgt_file_name = train_path(args.target_lang)
    #     src_labels_list = []
    #     tgt_labels_list = []
    #     with open(args.alignfile, "r", encoding='utf-8') as align_file:
    #         with open(src_file_name, "r", encoding='utf-8') as src_file:
    #             with open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
    #                 for a, s, t in zip_longest(align_file, src_file, tgt_file):
    #                     src_words = tokenize_line(s)
    #                     tgt_words = tokenize_line(t)
    #                     ai = list(map(lambda x: tuple(x.split("-")), a.split()))
    #                     src_labels = np.ones(len(src_words), int)
    #                     tgt_labels = np.ones(len(tgt_words), int)
    #                     for sai, tai in ai:
    #                         if int(tai) >= len(tgt_words):
    #                             print('Bad case:')
    #                             print(tgt_words)
    #                             print(ai)
    #                             continue
    #                         src_word = src_words[int(sai)]
    #                         tgt_word = tgt_words[int(tai)]
    #                         if src_word == tgt_word:
    #                             src_labels[int(sai)] = 0
    #                             tgt_labels[int(tai)] = 0
    #                     src_labels_list.append(src_labels)
    #                     tgt_labels_list.append(tgt_labels)
    #
    #     save_label_file(os.path.join(args.destdir, "train.label.{}.txt".format(args.source_lang)), src_labels_list)
    #     save_label_file(os.path.join(args.destdir, "train.label.{}.txt".format(args.target_lang)), tgt_labels_list)
    print("预处理完成")
Example #9
0
    def __init__(self, data_path, checkpoint_path, beam, nbest):
        self.parser = options.get_generation_parser(interactive=True)
        self.parser.set_defaults(path=checkpoint_path,
                                 remove_bpe=None,
                                 dataset_impl="lazy",
                                 num_workers=5)
        self.args = options.parse_args_and_arch(self.parser,
                                                input_args=[data_path])
        self.args.beam = beam
        self.args.nbest = nbest

        utils.import_user_module(self.args)

        if self.args.buffer_size < 1:
            self.args.buffer_size = 1
        if self.args.max_tokens is None and self.args.max_sentences is None:
            self.args.max_sentences = 1

        assert not self.args.sampling or self.args.nbest == self.args.beam, \
            '--sampling requires --nbest to be equal to --beam'
        assert not self.args.max_sentences or self.args.max_sentences <= self.args.buffer_size, \
            '--max-sentences/--batch-size cannot be larger than --buffer-size'

        self.use_cuda = torch.cuda.is_available() and not self.args.cpu

        self.task = tasks.setup_task(self.args)

        self.models, self._model_args = checkpoint_utils.load_model_ensemble(
            self.args.path.split(':'),
            arg_overrides=eval(self.args.model_overrides),
            task=self.task,
        )

        self.src_dict = self.task.source_dictionary
        self.tgt_dict = self.task.target_dictionary

        for model in self.models:
            model.make_generation_fast_(
                beamable_mm_beam_size=None
                if self.args.no_beamable_mm else self.args.beam,
                need_attn=self.args.print_alignment,
            )
            if self.args.fp16:
                model.half()
            if self.use_cuda:
                model.cuda()

        self.generator = self.task.build_generator(self.args)

        if self.args.remove_bpe == 'gpt2':
            from fairseq.gpt2_bpe.gpt2_encoding import get_encoder
            self.decoder = get_encoder(
                'fairseq/gpt2_bpe/encoder.json',
                'fairseq/gpt2_bpe/vocab.bpe',
            )
            self.encode_fn = lambda x: ' '.join(
                map(str, self.decoder.encode(x)))
        else:
            self.decoder = None
            self.encode_fn = lambda x: x

        self.align_dict = utils.load_align_dict(self.args.replace_unk)

        self.max_positions = utils.resolve_max_positions(
            self.task.max_positions(),
            *[model.max_positions() for model in self.models])
Example #10
0
def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    utils.import_user_module(parsed_args)

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(':'),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
    )

    for arg in vars(parsed_args).keys():
        if arg not in {
                'self_target',
                'future_target',
                'past_target',
                'tokens_per_sample',
                'output_size_dictionary',
                'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
            for name, _module in model.named_modules():
                if 'layer_norm' in name:
                    _module.float()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    print('num. model params: {}'.format(
        sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = set(i for i in range(len(task.source_dictionary))
                           if task.source_dictionary[i].endswith(bpe_cont))
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0
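    # For example, with --remove-bpe='@@ ' (subword-nmt style BPE), bpe_cont is
    # '@@' after rstrip() and bpe_toks collects the ids of every dictionary
    # symbol ending in '@@', i.e. word-internal subword pieces whose scores get
    # merged into the following token in the scoring loop below.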

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        for sample in t:
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for i, hypos_i in enumerate(hypos):
                hypo = hypos_i[0]
                sample_id = sample['id'][i]

                tokens = hypo['tokens']
                tgt_len = tokens.numel()
                pos_scores = hypo['positional_scores'].float()

                if args.add_bos_token:
                    assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                    tokens = tokens[1:]
                    pos_scores = pos_scores[1:]

                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(tgt_len - 1):
                        if tokens[i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(
                    float('-inf'))
                if inf_scores.any():
                    print(
                        '| Skipping tokens with inf scores:',
                        task.target_dictionary.string(
                            tokens[inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(tokens)):
                        w_ind = tokens[i].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))

                            next_prob = None
                            ind = i + 1
                            while ind < len(tokens):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                                pos_scores[i].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print(
                            str(int(sample_id)) + " " +
                            ('\t'.join('{} [{:2f}]'.format(x[0], x[1])
                                       for x in word_prob)))

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss,
                                                      np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(),
                         key=lambda x: x.count,
                         reverse=True):
            print(ws)
Example #11
0
def main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'
    metrics.reset()

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu and not getattr(
            args, 'tpu', False):
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    utils.set_torch_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info('model {}, criterion {}'.format(args.arch,
                                                criterion.__class__.__name__))
    logger.info('num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # (optionally) Configure quantization
    if args.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=args.quantization_config_path,
            max_epoch=args.max_epoch,
            max_update=args.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if args.model_parallel_size == 1:
        trainer = Trainer(args, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(args, task, model, criterion)

    logger.info('training on {} devices (GPUs/TPUs)'.format(
        args.distributed_world_size))
    logger.info(
        'max tokens per GPU = {} and max sentences per GPU = {}'.format(
            args.max_tokens,
            args.max_sentences,
        ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
    if args.tpu:
        import torch_xla.core.xla_model as xm
        xm.rendezvous('load_checkpoint')  # wait for all workers
        xm.mark_step()

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while (lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch):
        # train for one epoch
        valid_losses, should_stop = train(args, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=(os.pathsep in getattr(args, 'data', '')),
        )
    train_meter.stop()
    logger.info('done training in {:.1f} seconds'.format(train_meter.sum))
Example #12
0
def main(cfg: DictConfig, override_args=None, **unused_kwargs):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    if override_args is not None:
        overrides = vars(override_args)
        overrides.update(eval(getattr(override_args, "model_overrides", "{}")))
    else:
        overrides = None

    logger.info(cfg)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))

    # reduce tokens per sample by the required context window size
    cfg.task.tokens_per_sample -= cfg.eval_lm.context_window

    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [cfg.common_eval.path],
        arg_overrides=overrides,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # Load dataset splits
    gen_subset = cfg.dataset.gen_subset
    task.load_dataset(gen_subset)
    dataset = task.dataset(gen_subset)
    if cfg.eval_lm.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=cfg.task.tokens_per_sample,
            context_window=cfg.eval_lm.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info("{} {} {} examples".format(cfg.task.data, gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    assert len(models) > 0

    logger.info(
        "num. model params: {}".format(sum(p.numel() for p in models[0].parameters()))
    )

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=cfg.dataset.max_tokens or 36000,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=True,
        num_shards=max(
            cfg.dataset.num_shards,
            cfg.distributed_training.distributed_world_size,
        ),
        shard_id=max(
            cfg.dataset.shard_id,
            cfg.distributed_training.distributed_rank,
        ),
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, cfg.eval_lm.softmax_batch)

    score_sum = 0.0
    count = 0

    if cfg.common_eval.post_process is not None:
        if cfg.common_eval.post_process == "sentencepiece":
            raise NotImplementedError
        else:
            bpe_cont = cfg.common_eval.post_process.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    wps_meter = TimeMeter()

    for sample in progress:
        if "net_input" not in sample:
            continue

        sample = utils.move_to_cuda(sample) if use_cuda else sample

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample["ntokens"])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample["id"][i]

            tokens = hypo["tokens"]
            tgt_len = tokens.numel()
            pos_scores = hypo["positional_scores"].float()

            if cfg.task.add_bos_token:
                assert hypo["tokens"][0].item() == task.target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            skipped_toks = 0
            if bpe_toks is not None:
                for i in range(tgt_len - 1):
                    if tokens[i].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[i + 1] += pos_scores[i]
                        pos_scores[i] = 0

            inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(float("-inf"))
            if inf_scores.any():
                logger.info(
                    "skipping tokens with inf scores: %s",
                    task.target_dictionary.string(tokens[inf_scores.nonzero()]),
                )
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks

            if cfg.eval_lm.output_word_probs or cfg.eval_lm.output_word_stats:
                w = ""
                word_prob = []
                is_bpe = False
                for i in range(len(tokens)):
                    w_ind = tokens[i].item()
                    w += task.source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[i].item()))

                        next_prob = None
                        ind = i + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                            pos_scores[i].item(), next_prob
                        )
                        is_bpe = False
                        w = ""
                if cfg.eval_lm.output_word_probs:
                    logger.info(
                        str(int(sample_id))
                        + " "
                        + (
                            "\t".join(
                                "{} [{:2f}]".format(x[0], x[1]) for x in word_prob
                            )
                        )
                    )

        wps_meter.update(sample["ntokens"])
        progress.log({"wps": round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info(
        "Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)".format(
            gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg
        )
    )
    logger.info(
        "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format(
            avg_nll_loss, 2 ** avg_nll_loss
        )
    )

    if cfg.eval_lm.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)
def main(args):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    torch.manual_seed(args.seed)

    # Print args
    print(args)

    # Setup task
    task = tasks.setup_task(args)

    # Build model
    model = task.build_model(args)
    print(model)

    # specify the length of the dummy input for profile
    # for iwslt, the average length is 23, for wmt, that is 30
    dummy_sentence_length_dict = {'iwslt': 23, 'wmt': 30}
    if 'iwslt' in args.arch:
        dummy_sentence_length = dummy_sentence_length_dict['iwslt']
    elif 'wmt' in args.arch:
        dummy_sentence_length = dummy_sentence_length_dict['wmt']
    else:
        raise NotImplementedError

    dummy_src_tokens = [2] + [7] * (dummy_sentence_length - 1)
    dummy_prev = [7] * (dummy_sentence_length - 1) + [2]
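    # Assuming fairseq's default Dictionary special ids (bos=0, pad=1, eos=2,
    # unk=3), id 2 below is the end-of-sentence marker and 7 is just an
    # arbitrary regular token, so these dummies mimic a sentence of the chosen
    # length for profiling.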

    # for device predictor: device dataset generation
    # The temporary format is (SubTransformer, DeviceFeature) -> CacheMisses
    with open(args.lat_dataset_path, 'w') as fid:
        src_tokens_test = torch.tensor([dummy_src_tokens], dtype=torch.long)
        src_lengths_test = torch.tensor([dummy_sentence_length])
        prev_output_tokens_test_with_beam = torch.tensor([dummy_prev] *
                                                         args.beam,
                                                         dtype=torch.long)
        if args.latcpu:
            model.cpu()
            print(
                'Measuring model cache misses on CPU for dataset generation...'
            )
        elif args.latgpu:
            return

        feature_info = utils.get_feature_info()
        fid.write(','.join(feature_info) + ',')
        device_feature_info = [
            'l1d_size', 'l1i_size', 'l2_size', 'l3_size', 'num_cores'
        ]  # unit for cache sizes is KB
        fid.write(','.join(device_feature_info) + ',')
        misses_info = ['l1d_misses', 'l1i_misses', 'l2_misses', 'l3_misses']
        fid.write(','.join(misses_info) + '\n')
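        # Resulting CSV layout: one header row, then one row per sampled
        # (SubTransformer, device) pair:
        #   <model feature columns>,l1d_size,l1i_size,l2_size,l3_size,num_cores,l1d_misses,l1i_misses,l2_misses,l3_misses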

        device_feature_set = [[32, 32, 256, 4096, 1], [32, 32, 512, 4096, 1],
                              [32, 32, 512, 8192, 1], [64, 64, 512, 8192, 1]]
        num_devices = len(device_feature_set)

        for i in range(args.lat_dataset_size):
            print(i)

            config_sam = utils.sample_configs(
                utils.get_all_choices(args),
                reset_rand_seed=False,
                super_decoder_num_layer=args.decoder_layers)
            features = utils.get_config_features(config_sam)
            fid.write(','.join(map(str, features)) + ',')

            device_id = random.randint(0,
                                       num_devices - 1)  # randomly select data
            device_features = device_feature_set[device_id]
            fid.write(','.join(map(str, device_features)) + ',')

            model.set_sample_config(config_sam)
            # dry runs
            for _ in range(dry_run_iters):
                encoder_out_test = model.encoder(src_tokens=src_tokens_test,
                                                 src_lengths=src_lengths_test)

            # pass the model(features&device_features) as parameters into a new python process
            arguments = [
                dynamorio_dir + '/bin64/drrun', '-t', 'drcachesim',
                '-config_file',
                dev_conf_dir + '/dev' + str(device_id) + '.conf', '--',
                'python', '-u', 'encoder_infer.py',
                str(dummy_sentence_length)
            ]
            print(arguments)
            enproc = subprocess.Popen(arguments, stdout=subprocess.PIPE)
            while True:
                line = enproc.stdout.readline()
                if not line:
                    break
                print('line: ' + line.decode('utf-8'))
            encoder_misses = []
            print('Measuring encoder for dataset generation...')
            #TODO run model.encoder(src_tokens=src_tokens_test, src_lengths=src_lengths_test) in simulator and write results into encoder_misses
            print(encoder_misses)
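            # Hedged sketch (an assumption, not part of the original script):
            # one way to fill encoder_misses is to parse the drcachesim summary
            # printed above, assuming it reports per-cache lines of the form
            # '    Misses:  <count>' for the L1D/L1I/L2/L3 caches, e.g.:
            #
            #   import re
            #   miss_re = re.compile(r'Misses:\s+([\d,]+)')
            #   encoder_misses = [int(m.group(1).replace(',', ''))
            #                     for m in map(miss_re.search, captured_lines) if m]
            #
            # where captured_lines would be the decoded stdout lines read from
            # enproc above (in this excerpt they are only printed, not stored).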

            bsz = 1
            new_order = torch.arange(bsz).view(-1, 1).repeat(
                1, args.beam).view(-1).long()

            encoder_out_test_with_beam = model.encoder.reorder_encoder_out(
                encoder_out_test, new_order)

            # dry runs
            for _ in range(dry_run_iters):
                model.decoder(
                    prev_output_tokens=prev_output_tokens_test_with_beam,
                    encoder_out=encoder_out_test_with_beam)

            # decoder is more complicated because we need to deal with incremental states and auto regressive things
            decoder_iterations_dict = {'iwslt': 23, 'wmt': 30}
            if 'iwslt' in args.arch:
                decoder_iterations = decoder_iterations_dict['iwslt']
            elif 'wmt' in args.arch:
                decoder_iterations = decoder_iterations_dict['wmt']

            decoder_misses = []
            print('Measuring decoder for dataset generation...')
            #TODO run the below commands in simulator and write results into decoder_misses
            #incre_states = {}
            #for k_regressive in range(decoder_iterations):
            #    model.decoder(prev_output_tokens=prev_output_tokens_test_with_beam[:, :k_regressive + 1],
            #                  encoder_out=encoder_out_test_with_beam, incremental_state=incre_states)
            print(decoder_misses)

            #TODO Do we need to add the encoder/decoder misses together
            fid.write(','.join(map(str, encoder_misses)) + '\n')  #temp
Example #14
0
def _initialize_fairseq(user_dir):
    logging.info("Setting up fairseq library...")
    if user_dir:
        args = type("", (), {"user_dir": user_dir})()
        fairseq_utils.import_user_module(args)
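# A minimal usage sketch (an assumption, not part of the original example): the
# throwaway `type("", (), {...})()` object only needs a `user_dir` attribute,
# so an argparse.Namespace works just as well:
#
#   from argparse import Namespace
#   fairseq_utils.import_user_module(Namespace(user_dir="/path/to/my_plugins"))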
Example #15
0
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.dataset_impl == 'raw', \
        '--replace-unk requires a raw text dataset (--dataset-impl=raw)'

    if args.results_path is not None:
        os.makedirs(args.results_path, exist_ok=True)
        output_file = open(os.path.join(args.results_path,
                                        "generate-results.txt"),
                           'w',
                           buffering=1)
    else:
        output_file = None

    wandb.init(job_type='generate', config=args)
    wandb.config['limit_samples'] = LIMIT_SAMPLES

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args, file=output_file)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path), file=output_file)
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str),
                              file=output_file)
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str),
                              file=output_file)

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                    hypo_str),
                              file=output_file)
                        print('P-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    hypo['positional_scores'].tolist(),
                                ))),
                              file=output_file)

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join([
                                    '{}-{}'.format(src_idx, tgt_idx)
                                    for src_idx, tgt_idx in alignment
                                ])),
                                  file=output_file)

                        if args.print_step:
                            print('I-{}\t{}'.format(sample_id, hypo['steps']),
                                  file=output_file)

                        if getattr(args, 'retain_iter_history', False):
                            print("\n".join([
                                'E-{}_{}\t{}'.format(
                                    sample_id, step,
                                    utils.post_process_prediction(
                                        h['tokens'].int().cpu(), src_str, None,
                                        None, tgt_dict, None)[1])
                                for step, h in enumerate(hypo['history'])
                            ]),
                                  file=output_file)

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    # NOTE: dataframe_rows must be accumulated while iterating over the
    # hypotheses above; it is not defined in this function as shown.
    wandb.summary['result'] = pandas.DataFrame(
        dataframe_rows, columns=['Source', 'Target', 'Hypothesis', 'Score'])

    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg),
        file=output_file)
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()),
              file=output_file)

    return scorer
def main(args):
    from fairseq import utils
    utils.xpprint(args)

    import_user_module(args)

    print(args)

    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    task = tasks.get_task(args.task)

    def train_path(lang):
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    if target and not args.tgtdict and os.path.exists(
            dict_path(args.target_lang)):
        raise FileExistsError(dict_path(args.target_lang))

    if args.joined_dictionary:
        assert not args.srcdict or not args.tgtdict, \
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary(
                {
                    train_path(lang)
                    for lang in [args.source_lang, args.target_lang]
                },
                src=True)
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = bert_dictionary.BertDictionary.load(args.srcdict)
            print('load bert dict from {} | size {}'.format(
                args.srcdict, len(src_dict)))
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)],
                                        src=True)

        if target:
            if args.tgtdict:
                tgt_dict = bert_dictionary.BertDictionary.load(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)],
                                            tgt=True)
        else:
            tgt_dict = None

    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                            num_workers):
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        print('input_prefix', input_prefix)
        print(dict_path(lang))

        dict = bert_dictionary.BertDictionary.load(dict_path(lang))
        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        from pytorch_transformers import BertTokenizer
        import torch

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        def penn_token2orig_token(sent):
            # -LRB- -RRB- -LSB- -RSB- -LCB- -RCB-
            penn2orig = {
                "``": '"',
                "''": '"',
                "-LRB-": '(',
                "-RRB-": ')',
                "-LSB-": '[',
                "-RSB-": ']',
                "-LCB-": '{',
                "-RCB-": '}'
            }
            words = sent.strip().split()
            words = [
                wd if not wd in penn2orig else penn2orig[wd] for wd in words
            ]
            return ' '.join(words)
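        # e.g. penn_token2orig_token("-LRB- hello '' world -RRB-") -> '( hello " world )'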

        num_token, num_unk_token = 0, 0
        num_seq = 0
        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin"))
        truncated_number = 511 if lang == 'article' else 100
        BERT_CLS_ID = tokenizer.convert_tokens_to_ids([BERT_CLS])[0]
        BERT_SEP_ID = tokenizer.convert_tokens_to_ids([BERT_SEP])[0]
        for line in open(input_file, encoding='utf8'):
            sents = line.strip().split('<S_SEP>')
            sents = [
                tokenizer.tokenize(penn_token2orig_token(sent))
                for sent in sents
            ]
            article_wids = []
            for i, sent in enumerate(sents):
                if i != 0:
                    article_wids.append(dict.sep_index)
                wids = tokenizer.convert_tokens_to_ids(sent)
                # wids_vocab = [dict.index(word) for word in sent]
                # assert wids == wids_vocab, 'word indices should be the same!'
                article_wids.extend(wids)
                for wid in wids:
                    if wid == dict.unk_index:
                        num_unk_token += 1
                    num_token += 1
            num_seq += 1
            if len(article_wids) > truncated_number:
                print('lang: %s, token len: %d, truncated len: %d' %
                      (lang, len(article_wids), truncated_number))
            article_wids = article_wids[:truncated_number]
            if lang == 'article':
                if article_wids[-1] != BERT_SEP_ID:
                    if article_wids[-2] != BERT_SEP_ID:
                        article_wids[-1] = BERT_SEP_ID
                article_wids = [BERT_CLS_ID] + article_wids
            # print(article_wids)
            tensor = torch.IntTensor(article_wids)
            # print( dict.string_complete(tensor) )
            ds.add_item(tensor)

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
            lang, input_file, num_seq, num_token,
            100 * num_unk_token / num_token,
            dict.unk_word if hasattr(dict, 'unk_word') else '<no_unk_word>'))

        #
        #     n_seq_tok = [0, 0]
        #     replaced = Counter()
        #
        #     def merge_result(worker_result):
        #         replaced.update(worker_result["replaced"])
        #         n_seq_tok[0] += worker_result["nseq"]
        #         n_seq_tok[1] += worker_result["ntok"]
        #
        #     input_file = "{}{}".format(
        #         input_prefix, ("." + lang) if lang is not None else ""
        #     )
        #     offsets = Binarizer.find_offsets(input_file, num_workers)
        #     pool = None
        #     if num_workers > 1:
        #         pool = Pool(processes=num_workers - 1)
        #         for worker_id in range(1, num_workers):
        #             prefix = "{}{}".format(output_prefix, worker_id)
        #             pool.apply_async(
        #                 binarize,
        #                 (
        #                     args,
        #                     input_file,
        #                     vocab,
        #                     prefix,
        #                     lang,
        #                     offsets[worker_id],
        #                     offsets[worker_id + 1]
        #                 ),
        #                 callback=merge_result
        #             )
        #         pool.close()
        #
        #     ds = indexed_dataset.IndexedDatasetBuilder(
        #         dataset_dest_file(args, output_prefix, lang, "bin")
        #     )
        #     merge_result(
        #         Binarizer.binarize(
        #             input_file, vocab, lambda t: ds.add_item(t),
        #             offset=0, end=offsets[1]
        #         )
        #     )
        #     if num_workers > 1:
        #         pool.join()
        #         for worker_id in range(1, num_workers):
        #             prefix = "{}{}".format(output_prefix, worker_id)
        #             temp_file_path = dataset_dest_prefix(args, prefix, lang)
        #             ds.merge_file_(temp_file_path)
        #             os.remove(indexed_dataset.data_file_path(temp_file_path))
        #             os.remove(indexed_dataset.index_file_path(temp_file_path))
        #
        #     ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
        #
        #     print(
        #         "| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
        #             lang,
        #             input_file,
        #             n_seq_tok[0],
        #             n_seq_tok[1],
        #             100 * sum(replaced.values()) / n_seq_tok[1],
        #             vocab.unk_word,
        #         )
        #     )

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        if args.output_format == "binary":
            make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                                num_workers)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix +
                ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

    def make_all(lang, vocab):
        if args.trainpref:
            print(args.trainpref, lang)

            make_dataset(vocab,
                         args.trainpref,
                         "train",
                         lang,
                         num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(vocab,
                             validpref,
                             outprefix,
                             lang,
                             num_workers=args.workers)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab,
                             testpref,
                             outprefix,
                             lang,
                             num_workers=args.workers)

    make_all(args.source_lang, src_dict)
    if target:
        make_all(args.target_lang, tgt_dict)

    print("| Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        freq_map = {}
        with open(args.alignfile, "r", encoding='utf-8') as align_file:
            with open(src_file_name, "r", encoding='utf-8') as src_file:
                with open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        si = src_dict.encode_line(s, add_if_not_exist=False)
                        ti = tgt_dict.encode_line(t, add_if_not_exist=False)
                        ai = list(map(lambda x: tuple(x.split("-")),
                                      a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
                            if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk():
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()

                                if srcidx not in freq_map:
                                    freq_map[srcidx] = {}
                                if tgtidx not in freq_map[srcidx]:
                                    freq_map[srcidx][tgtidx] = 1
                                else:
                                    freq_map[srcidx][tgtidx] += 1

        align_dict = {}
        for srcidx in freq_map.keys():
            align_dict[srcidx] = max(freq_map[srcidx],
                                     key=freq_map[srcidx].get)
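        # align_dict now maps each source token id to the target token id it is
        # most frequently aligned with, e.g. freq_map[srcidx] = {tgt_a: 3, tgt_b: 1}
        # yields align_dict[srcidx] = tgt_a.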

        with open(os.path.join(
                args.destdir,
                "alignment.{}-{}.txt".format(args.source_lang,
                                             args.target_lang),
        ),
                  "w",
                  encoding='utf-8') as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
def main(args):
    import_user_module(args)

    print(args)

    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    task = tasks.get_task(args.task)

    def train_path(lang):
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    # group.add_argument("--convert_raw", action="store_true", help="convert_raw")
    # group.add_argument("--convert_with_bpe", action="store_true", help="convert_with_bpe")
    # group.add_argument('--bpe_code', metavar='FILE', help='bpe_code')

    # new_prefix, src_tree_file, tgt_tree_file

    if args.convert_raw:
        print(f'start --- args.convert_raw')
        raise NotImplementedError

    if args.convert_raw_only:
        print('Finish!')
        return

    remove_root = not args.no_remove_root
    take_pos_tag = not args.no_take_pos_tag
    take_nodes = not args.no_take_nodes
    reverse_node = not args.no_reverse_node
    no_collapse = args.no_collapse
    # remove_root =, take_pos_tag =, take_nodes =
    print(f'remove_root: {remove_root}')
    print(f'take_pos_tag: {take_pos_tag}')
    print(f'take_nodes: {take_nodes}')
    print(f'reverse_node: {reverse_node}')
    print(f'no_collapse: {no_collapse}')

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

    def share_dict_path():
        return args.share_dict_txt

    def build_shared_nstack2seq_dictionary(_src_file, _tgt_file):
        d = dictionary.Dictionary()
        print(f'Build dict on src_file: {_src_file}')
        NstackTreeTokenizer.acquire_vocab_multithread(
            _src_file, d, tokenize_line, num_workers=args.workers,
            remove_root=remove_root, take_pos_tag=take_pos_tag, take_nodes=take_nodes,
            no_collapse=no_collapse,
        )
        print(f'Build dict on tgt_file: {_tgt_file}')
        dictionary.Dictionary.add_file_to_dictionary(_tgt_file, d, tokenize_line, num_workers=args.workers)
        d.finalize(
            threshold=args.thresholdsrc,
            nwords=args.nwordssrc,
            padding_factor=args.padding_factor
        )
        print(f'Finish building vocabulary: size {len(d)}')
        return d

    def build_nstack_source_dictionary(_src_file):
        d = dictionary.Dictionary()
        print(f'Build dict on src_file: {_src_file}')
        NstackTreeTokenizer.acquire_vocab_multithread(
            _src_file, d, tokenize_line, num_workers=args.workers,
            remove_root=remove_root, take_pos_tag=take_pos_tag, take_nodes=take_nodes,
            no_collapse=no_collapse,
        )
        d.finalize(
            threshold=args.thresholdsrc,
            nwords=args.nwordssrc,
            padding_factor=args.padding_factor
        )
        print(f'Finish building src vocabulary: size {len(d)}')
        return d

    def build_target_dictionary(_tgt_file):
        # assert src ^ tgt
        print(f'Build dict on tgt: {_tgt_file}')
        d = task.build_dictionary(
            [_tgt_file],
            workers=args.workers,
            threshold=args.thresholdtgt,
            nwords=args.nwordstgt,
            padding_factor=args.padding_factor,
        )
        print(f'Finish building tgt vocabulary: size {len(d)}')
        return d

    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    # if target and not args.tgtdict and os.path.exists(dict_path(args.target_lang)):
    #     raise FileExistsError(dict_path(args.target_lang))

    if args.joined_dictionary:
        assert not args.srcdict or not args.tgtdict, \
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_file = f'{args.trainpref}.{args.source_lang}'
            tgt_file = f'{args.trainpref}.{args.target_lang}'
            src_dict = build_shared_nstack2seq_dictionary(src_file, tgt_file)
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_nstack_source_dictionary(train_path(args.source_lang))

        if target:
            if args.tgtdict:
                tgt_dict = task.load_dictionary(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_target_dictionary(train_path(args.target_lang))
        else:
            tgt_dict = None
        # raise NotImplementedError(f'only allow args.joined_dictionary for now')

    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers):
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))

        input_file = "{}{}".format(
            input_prefix, ("." + lang) if lang is not None else ""
        )
        pool = None

        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin")
        )

        def consumer(tensor):
            ds.add_item(tensor)

        stat = BinarizerDataset.export_binarized_dataset(
            input_file, vocab, consumer, add_if_not_exist=False, num_workers=num_workers,
        )

        ntok = stat['ntok']
        nseq = stat['nseq']
        nunk = stat['nunk']

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        print(
            "| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                nseq,
                ntok,
                100 * nunk / ntok,
                vocab.unk_word,
            )
        )

    def make_binary_nstack_dataset(vocab, input_prefix, output_prefix, lang, num_workers):
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))

        input_file = "{}{}".format(
            input_prefix, ("." + lang) if lang is not None else ""
        )

        dss = {
            modality: NstackSeparateIndexedDatasetBuilder(
                dataset_dest_file_dptree(args, output_prefix, lang, 'bin', modality))
            for modality in NSTACK_KEYS
        }

        def consumer(example):
            for modality, tensor in example.items():
                dss[modality].add_item(tensor)

        stat = NstackTreeMergeBinarizerDataset.export_binarized_separate_dataset(
            input_file, vocab, consumer, add_if_not_exist=False, num_workers=num_workers,
            remove_root=remove_root, take_pos_tag=take_pos_tag, take_nodes=take_nodes, reverse_node=reverse_node,
            no_collapse=no_collapse,
        )
        ntok = stat['ntok']
        nseq = stat['nseq']
        nunk = stat['nunk']

        for modality, ds in dss.items():
            ds.finalize(dataset_dest_file_dptree(args, output_prefix, lang, "idx", modality))

        print(
            "| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                nseq,
                ntok,
                100 * nunk / ntok,
                vocab.unk_word,
            )
        )
        for modality, ds in dss.items():
            print(f'\t{modality}')

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        if args.output_format == "binary":
            make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

    def make_dptree_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        if args.output_format != "binary":
            raise NotImplementedError(f'output format {args.output_format} not impl')

        make_binary_nstack_dataset(vocab, input_prefix, output_prefix, lang, num_workers)

    def make_all(lang, vocab):
        if args.trainpref:
            print('| Warning: skipping train-split binarization here (already done, e.g. for the en-fr target).')
            # make_dataset(vocab, args.trainpref, "train", lang, num_workers=args.workers)

        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(vocab, validpref, outprefix, lang, num_workers=args.eval_workers)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab, testpref, outprefix, lang, num_workers=args.eval_workers)

    def make_all_src(lang, vocab):
        if args.trainpref:
            # print(f'!!!! Warning..... Not during en-fr source because already done!.....')
            make_dptree_dataset(vocab, args.trainpref, "train", lang, num_workers=args.workers)

        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dptree_dataset(vocab, validpref, outprefix, lang, num_workers=args.eval_workers)

        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dptree_dataset(vocab, testpref, outprefix, lang, num_workers=args.eval_workers)

    def make_all_tgt(lang, vocab):
        make_all(lang, vocab)

    make_all_src(args.source_lang, src_dict)
    print('| WARNING: no processing for source.')
    if target:
        make_all_tgt(args.target_lang, tgt_dict)
        # print('Not making target')

    print("| Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        raise NotImplementedError('alignfile Not impl at the moment')
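The make_binary_dataset helper above follows a consumer pattern: a binarizer scans the raw text file, turns each line into a tensor of token ids, hands every tensor to a callback that appends it to a dataset builder, and returns sequence/token/unknown counts at the end. Below is a minimal, self-contained sketch of that pattern; ListDatasetBuilder and binarize_lines are hypothetical stand-ins, not the fairseq IndexedDatasetBuilder or BinarizerDataset classes.

import torch

class ListDatasetBuilder:
    # Hypothetical in-memory stand-in for IndexedDatasetBuilder.
    def __init__(self):
        self.items = []

    def add_item(self, tensor):
        self.items.append(tensor)

def binarize_lines(lines, token_to_id, consumer, unk_id=3):
    # Hypothetical stand-in for BinarizerDataset.export_binarized_dataset:
    # map tokens to ids, feed each sentence tensor to the consumer, count stats.
    stat = {'nseq': 0, 'ntok': 0, 'nunk': 0}
    for line in lines:
        ids = [token_to_id.get(tok, unk_id) for tok in line.split()]
        stat['nseq'] += 1
        stat['ntok'] += len(ids)
        stat['nunk'] += sum(1 for i in ids if i == unk_id)
        consumer(torch.tensor(ids, dtype=torch.long))
    return stat

builder = ListDatasetBuilder()
print(binarize_lines(['hello world', 'hello there'], {'hello': 4, 'world': 5}, builder.add_item))
# -> {'nseq': 2, 'ntok': 4, 'nunk': 1}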
Example #18
def main(cfg: FairseqConfig):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    start_time = time.time()
    total_translate_time = 0

    utils.import_user_module(cfg.common)

    if cfg.interactive.buffer_size < 1:
        cfg.interactive.buffer_size = 1
    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.batch_size = 1

    assert (not cfg.generation.sampling
            or cfg.generation.nbest == cfg.generation.beam
            ), "--sampling requires --nbest to be equal to --beam"
    assert (not cfg.dataset.batch_size
            or cfg.dataset.batch_size <= cfg.interactive.buffer_size
            ), "--batch-size cannot be larger than --buffer-size"

    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(cfg.task)

    if cfg.dataset.aistrigh:
        predictor = PredictText(cfg.dataset.aistrigh[1],
                                cfg.dataset.aistrigh[0],
                                cfg.dataset.aistrigh[2])
        mutate = MutateText(cfg.dataset.aistrigh[1])

    # Load ensemble
    overrides = ast.literal_eval(cfg.common_eval.model_overrides)
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Initialize generator
    generator = task.build_generator(models, cfg.generation)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(cfg.tokenizer)
    bpe = encoders.build_bpe(cfg.bpe)

    def encode_fn(x):
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models])

    if cfg.generation.constraints:
        logger.warning(
            "NOTE: Constrained decoding currently assumes a shared subword vocabulary."
        )

    if cfg.interactive.buffer_size > 1:
        logger.info("Sentence buffer size: %s", cfg.interactive.buffer_size)
    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info("Type the input sentence and press return:")
    start_id = 0
    for inputs in buffered_read(cfg.interactive.input,
                                cfg.interactive.buffer_size):
        results = []
        for batch in make_batches(inputs, cfg, task, max_positions, encode_fn):
            bsz = batch.src_tokens.size(0)
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            constraints = batch.constraints
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                if constraints is not None:
                    constraints = constraints.cuda()

            sample = {
                "net_input": {
                    "src_tokens": src_tokens,
                    "src_lengths": src_lengths,
                },
            }
            translate_start_time = time.time()
            translations = task.inference_step(generator,
                                               models,
                                               sample,
                                               constraints=constraints)
            translate_time = time.time() - translate_start_time
            total_translate_time += translate_time
            list_constraints = [[] for _ in range(bsz)]
            if cfg.generation.constraints:
                list_constraints = [unpack_constraints(c) for c in constraints]
            for i, (id,
                    hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                constraints = list_constraints[i]
                results.append((
                    start_id + id,
                    src_tokens_i,
                    hypos,
                    {
                        "constraints": constraints,
                        "time": translate_time / len(translations),
                    },
                ))

        # sort output to match input order
        for id_, src_tokens, hypos, info in sorted(results,
                                                   key=lambda x: x[0]):
            if src_dict is not None:
                src_str = src_dict.string(src_tokens,
                                          cfg.common_eval.post_process)
                print("S-{}\t{}".format(id_, src_str))
                print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
                for constraint in info["constraints"]:
                    print("C-{}\t{}".format(
                        id_,
                        tgt_dict.string(constraint,
                                        cfg.common_eval.post_process)))

            # Process top predictions
            for hypo in hypos[:min(len(hypos), cfg.generation.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=cfg.common_eval.post_process,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                        generator),
                )
                if cfg.dataset.aistrigh:
                    debpe_hypo_str = bpe.decode(hypo_str)
                    aist_predictions = predictor.predict(debpe_hypo_str)
                    debpe_hypo_str = mutate.mutate_text(aist_predictions)
                    detok_hypo_str = tokenizer.decode(debpe_hypo_str)
                else:
                    detok_hypo_str = decode_fn(hypo_str)
                score = hypo["score"] / math.log(2)  # convert to base 2
                # original hypothesis (after tokenization and BPE)
                print("H-{}\t{}\t{}".format(id_, score, hypo_str))
                # detokenized hypothesis
                print("D-{}\t{}\t{}".format(id_, score, detok_hypo_str))
                print("P-{}\t{}".format(
                    id_,
                    " ".join(
                        map(
                            lambda x: "{:.4f}".format(x),
                            # convert from base e to base 2
                            hypo["positional_scores"].div_(math.log(2)
                                                           ).tolist(),
                        )),
                ))
                if cfg.generation.print_alignment:
                    alignment_str = " ".join(
                        ["{}-{}".format(src, tgt) for src, tgt in alignment])
                    print("A-{}\t{}".format(id_, alignment_str))

        # update running id_ counter
        start_id += len(inputs)

    logger.info("Total time: {:.3f} seconds; translation time: {:.3f}".format(
        time.time() - start_time, total_translate_time))
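The encode_fn/decode_fn pair in this example simply composes the tokenizer and the BPE model in opposite orders (tokenize then BPE on input, undo BPE then detokenize on output), and scores are divided by math.log(2) so they are reported in base 2. A small sketch of that composition with placeholder string transforms; the lambdas below are illustrative stand-ins, not real tokenizer or BPE objects.

def make_codecs(tokenize, detokenize, bpe_encode, bpe_decode):
    # Encode applies the tokenizer then BPE; decode reverses the two steps.
    def encode_fn(x):
        return bpe_encode(tokenize(x))

    def decode_fn(x):
        return detokenize(bpe_decode(x))

    return encode_fn, decode_fn

enc, dec = make_codecs(
    tokenize=str.strip,
    detokenize=str.strip,
    bpe_encode=lambda s: ' '.join(s.split()),  # placeholder "BPE"
    bpe_decode=lambda s: s,
)
assert dec(enc('  hello world  ')) == 'hello world'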
Example #19
def get_parser(desc, default_task='translation'):
    # Before creating the true parser, we need to import optional user module
    # in order to eagerly import custom tasks, optimizers, architectures, etc.
    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    usr_parser.add_argument('--user-dir', default=None)
    usr_args, _ = usr_parser.parse_known_args()
    utils.import_user_module(usr_args)

    parser = configargparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument('--configs', required=False, is_config_file=True)
    # fmt: off
    parser.add_argument('--no-progress-bar',
                        action='store_true',
                        help='disable progress bar')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=1000,
        metavar='N',
        help='log progress every N batches (when progress bar is disabled)')
    parser.add_argument('--log-format',
                        default=None,
                        help='log format to use',
                        choices=['json', 'none', 'simple', 'tqdm'])
    parser.add_argument(
        '--tensorboard-logdir',
        metavar='DIR',
        default='',
        help='path to save logs for tensorboard, should match --logdir '
        'of running tensorboard (default: no tensorboard logging)')
    parser.add_argument('--seed',
                        default=1,
                        type=int,
                        metavar='N',
                        help='pseudo random number generator seed')
    parser.add_argument('--cpu',
                        action='store_true',
                        help='use CPU instead of CUDA')
    parser.add_argument('--fp16', action='store_true', help='use FP16')
    parser.add_argument(
        '--memory-efficient-fp16',
        action='store_true',
        help='use a memory-efficient version of FP16 training; implies --fp16')
    parser.add_argument('--fp16-init-scale',
                        default=2**7,
                        type=int,
                        help='default FP16 loss scale')
    parser.add_argument('--fp16-scale-window',
                        type=int,
                        help='number of updates before increasing loss scale')
    parser.add_argument(
        '--fp16-scale-tolerance',
        default=0.0,
        type=float,
        help='pct of updates that can overflow before decreasing the loss scale'
    )
    parser.add_argument(
        '--min-loss-scale',
        default=1e-4,
        type=float,
        metavar='D',
        help='minimum FP16 loss scale, after which training is stopped')
    parser.add_argument('--threshold-loss-scale',
                        type=float,
                        help='threshold FP16 loss scale from below')
    parser.add_argument(
        '--user-dir',
        default=None,
        help=
        'path to a python module containing custom extensions (tasks and/or architectures)'
    )
    parser.add_argument(
        '--empty-cache-freq',
        default=0,
        type=int,
        help='how often to clear the PyTorch CUDA cache (0 to disable)')
    parser.add_argument(
        '--all-gather-list-size',
        default=16384,
        type=int,
        help='number of bytes reserved for gathering stats from workers')

    from fairseq.registry import REGISTRIES
    for registry_name, REGISTRY in REGISTRIES.items():
        parser.add_argument(
            '--' + registry_name.replace('_', '-'),
            default=REGISTRY['default'],
            choices=REGISTRY['registry'].keys(),
        )

    # Task definitions can be found under fairseq/tasks/
    from fairseq.tasks import TASK_REGISTRY
    parser.add_argument('--task',
                        metavar='TASK',
                        default=default_task,
                        choices=TASK_REGISTRY.keys(),
                        help='task')
    # fmt: on
    return parser
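Assuming fairseq and configargparse are importable, the parser built above can be driven like any argparse parser. A short usage sketch relying only on the defaults declared in get_parser; parse_known_args is used because task- and model-specific flags are added later by other parsers.

parser = get_parser('trainer')
args, extra = parser.parse_known_args([])
print(args.seed, args.log_interval, args.fp16)  # -> 1 1000 False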
Example #20
def main(args, init_distributed=False):
    import_user_module(args)

    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # Initialize distributed training (after data loading)
    if init_distributed:
        import socket
        args.distributed_rank = distributed_utils.distributed_init(args)
        print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Make a dummy batch to (i) warm the caching allocator and (ii) as a
    # placeholder DistributedDataParallel when there's an uneven number of
    # batches per worker.
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        model.max_positions(),
    )
    dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
    oom_batch = task.dataset('train').get_dummy_batch(1, max_positions)

    # Build trainer
    trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
    )

    # Load the latest checkpoint if one is available
    if not load_checkpoint(args, trainer, epoch_itr):
        trainer.dummy_train_step([dummy_batch])

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update:
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
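The while-loop above keeps training until one of three conditions trips: the learning rate falls to --min-lr, the epoch count reaches --max-epoch, or the update count reaches --max-update. A toy sketch of the same control flow with a stub trainer; StubTrainer and run are stand-ins, not fairseq classes.

import math

class StubTrainer:
    # Stand-in trainer: counts updates and halves the LR after every epoch.
    def __init__(self, lr=0.1, updates_per_epoch=100):
        self.lr, self.updates, self.updates_per_epoch = lr, 0, updates_per_epoch

    def get_lr(self):
        return self.lr

    def get_num_updates(self):
        return self.updates

    def train_epoch(self):
        self.updates += self.updates_per_epoch

    def lr_step(self, epoch):
        self.lr *= 0.5
        return self.lr

def run(trainer, min_lr=1e-3, max_epoch=math.inf, max_update=math.inf):
    epoch, lr = 0, trainer.get_lr()
    while lr > min_lr and epoch < max_epoch and trainer.get_num_updates() < max_update:
        epoch += 1
        trainer.train_epoch()
        lr = trainer.lr_step(epoch)
    return epoch

print(run(StubTrainer(), max_epoch=50))  # stops after 7 epochs, once the LR decays below min_lr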
Example #21
def main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update:
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if ':' in getattr(args, 'data', ''):
            # sharded data: get train iterator for next epoch
            epoch_itr = trainer.get_train_iterator(epoch_itr.epoch)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example #22
def main(cfg: DictConfig, override_args=None):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    #assert (
    #    cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    #), "Must specify batch size either with --max-tokens or --batch-size"
    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.max_tokens = 12000

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    task = tasks.setup_task(cfg.task)
    task.load_dataset(cfg.dataset.gen_subset)

    #if override_args is not None:
    #    overrides = vars(override_args)
    #    overrides.update(eval(getattr(override_args, "model_overrides", "{}")))
    #else:
    #    overrides = None
    overrides = ast.literal_eval(cfg.common_eval.model_overrides)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    #models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
    #    [cfg.common_eval.path],
    #    arg_overrides=overrides,
    #    suffix=cfg.checkpoint.checkpoint_suffix,
    #)
    models, model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count)
    model = models[0]

    # Move models to GPU
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Print args
    logger.info(model_args)

    # Build criterion
    criterion = task.build_criterion(model_args.criterion)
    criterion.eval()

    for subset in cfg.dataset.valid_subset.split(","):
        try:
            task.load_dataset(subset, combine=False, epoch=1)
            dataset = task.dataset(subset)
        except KeyError:
            raise Exception("Cannot find dataset: " + subset)

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=cfg.dataset.max_tokens,
            max_sentences=cfg.dataset.batch_size,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
            seed=cfg.common.seed,
            num_shards=cfg.distributed_training.distributed_world_size,
            shard_id=cfg.distributed_training.distributed_rank,
            num_workers=cfg.dataset.num_workers,
            data_buffer_size=cfg.dataset.data_buffer_size,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else
                                "simple"),
        )

        log_outputs = []
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            _loss, _sample_size, log_output = task.valid_step(
                sample, model, criterion)
            progress.log(log_output, step=i)
            log_outputs.append(log_output)

        if cfg.distributed_training.distributed_world_size > 1:
            log_outputs = distributed_utils.all_gather_list(
                log_outputs,
                max_size=cfg.common.all_gather_list_size,
            )
            log_outputs = list(chain.from_iterable(log_outputs))

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        progress.print(log_output, tag=subset, step=i)
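Validation here collects one log dict per batch, optionally all-gathers them across workers, and then reduces them into smoothed values. A tiny sketch of that final reduce step using plain dicts; aggregate_logs is a stand-in for task.reduce_metrics / metrics.aggregate, and a sample_size-weighted mean is only one possible reduction.

def aggregate_logs(log_outputs):
    # Weighted average of per-batch losses, weighting each batch by its sample_size.
    total = sum(log.get('sample_size', 0) for log in log_outputs)
    loss = sum(log.get('loss', 0.0) * log.get('sample_size', 0) for log in log_outputs)
    return {'loss': loss / total if total else 0.0, 'sample_size': total}

print(aggregate_logs([{'loss': 2.0, 'sample_size': 10},
                      {'loss': 4.0, 'sample_size': 30}]))  # -> {'loss': 3.5, 'sample_size': 40}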
Example #23
def main(args, override_args=None):
    utils.import_user_module(args)

    use_cuda = torch.cuda.is_available() and not args.cpu
    eval_task = args.eval_task
    model, task, criterion, generator, model_args = set_up_model(
        args, args.vqvae_path, override_args)
    if eval_task == 'sampling':
        assert args.prior_path is not None
        prior_model, prior_task, prior_criterion, prior_generator, prior_args = set_up_model(
            args, args.prior_path, None)

    dictionary = task.dictionary
    if eval_task == 'code_extract':
        fopt = io.open(os.path.join(args.results_path,
                                    args.gen_subset + ".codes"),
                       "w",
                       encoding='utf-8')
    elif eval_task == 'reconstruct':
        if args.sampling:
            prefix = ".sample"
        else:
            prefix = ".bs"
        if args.code_extract_strategy is not None:
            prefix = prefix + '.' + args.code_extract_strategy
        else:
            prefix += '.orig.sampling'
        if args.prefix_num > 0:
            prefix = prefix + ".prefix_{}".format(args.prefix_num)
        fopt = io.open(os.path.join(
            args.results_path, args.gen_subset + prefix + ".reconstruction"),
                       "w",
                       encoding='utf-8')
        # Generate and compute BLEU score
        if args.sacrebleu:
            scorer = bleu.SacrebleuScorer()
        else:
            scorer = bleu.Scorer(dictionary.pad(), dictionary.eos(),
                                 dictionary.unk())
    elif eval_task == 'sampling':
        fopt = io.open(os.path.join(args.results_path,
                                    args.gen_subset + ".samples"),
                       "w",
                       encoding='utf-8')
    else:
        raise ValueError

    # Initialize generator
    gen_timer = StopwatchMeter()
    num_sentences = 0
    generate_id = 0
    if eval_task != 'sampling':
        # Load valid dataset (we load training data below, based on the latest checkpoint)
        for subset in args.gen_subset.split(','):
            try:
                task.load_dataset(subset, combine=False, epoch=args.shard_id)
                dataset = task.dataset(subset)
            except KeyError:
                raise Exception('Cannot find dataset: ' + subset)

            # Initialize data iterator
            itr = task.get_batch_iterator(
                dataset=dataset,
                max_tokens=args.max_tokens,
                max_sentences=args.max_sentences,
                max_positions=utils.resolve_max_positions(
                    task.max_positions(),
                    model.max_positions(),
                ),
                ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
                required_batch_size_multiple=args.required_batch_size_multiple,
                seed=args.seed,
                num_workers=args.num_workers,
            ).next_epoch_itr(shuffle=False)
            progress = progress_bar.build_progress_bar(
                args,
                itr,
                prefix='valid on \'{}\' subset'.format(subset),
                no_progress_bar='simple')

            log_outputs = []
            wps_meter = TimeMeter()
            all_codes = set()
            all_stats = []
            for jj, sample in enumerate(progress):
                sample = utils.move_to_cuda(sample) if use_cuda else sample
                log_output = {'sample_size': sample['target'].size(0)}
                log_outputs.append(log_output)
                num_sentences += sample['nsentences']

                if eval_task == 'code_extract':
                    codes = task.extract_codes(sample, model)
                    if args.gen_subset == 'valid':
                        all_codes.update(torch.unique(codes).tolist())
                    codes = codes.cpu().numpy()
                elif eval_task == 'reconstruct':
                    prefix_tokens = None if args.prefix_num == 0 else sample[
                        'target'][:, :args.prefix_num]
                    gen_timer.start()
                    hypos, codes, stats = task.reconstruct(
                        sample,
                        model,
                        generator,
                        prefix_tokens,
                        extract_mode=args.code_extract_strategy)
                    if isinstance(codes, list):
                        codes, topk = codes
                    else:
                        topk = None
                    if 'avg_topp' in stats:
                        all_stats.append(stats)
                    all_codes.update(torch.unique(codes).tolist())
                    num_generated_tokens = sum(
                        len(h[0]['tokens']) for h in hypos)
                    gen_timer.stop(num_generated_tokens)
                    wps_meter.update(num_generated_tokens)
                    progress.log({'wps': round(wps_meter.avg)})
                else:
                    raise NotImplementedError

                progress.log(log_output, step=jj)

                for i, sample_id in enumerate(sample['id'].tolist()):
                    tokens = utils.strip_pad(sample['target'][i, :],
                                             dictionary.pad())
                    origin_string = dictionary.string(
                        tokens, bpe_symbol=args.remove_bpe, escape_unk=True)
                    if len(tokens) <= 1:
                        continue
                    bpe_string = dictionary.string(tokens,
                                                   bpe_symbol=None,
                                                   escape_unk=True)
                    fopt.write('T-ori-{}\t{}\n'.format(sample_id,
                                                       origin_string))
                    fopt.write('T-bpe-{}\t{}\n'.format(sample_id, bpe_string))

                    if eval_task == 'code_extract':
                        code = codes[i]
                        fopt.write('C-{}\t{}\n'.format(
                            sample_id, ' '.join([
                                "c-{}".format(x) for x in code.tolist()
                                if x != -1
                            ])))
                    elif eval_task == 'reconstruct':
                        for j, hypo in enumerate(hypos[i][:args.nbest]):
                            code = codes[i]
                            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                                hypo_tokens=hypo['tokens'].int().cpu()
                                [:len(tokens) - 1],
                                src_str="",
                                alignment=hypo['alignment'],
                                align_dict=None,
                                tgt_dict=dictionary,
                                remove_bpe=args.remove_bpe,
                            )
                            _, hypo_str_bpe, _ = utils.post_process_prediction(
                                hypo_tokens=hypo['tokens'].int().cpu()
                                [:len(tokens) - 1],
                                src_str="",
                                alignment=hypo['alignment'],
                                align_dict=None,
                                tgt_dict=dictionary,
                                remove_bpe=None,
                            )

                            fopt.write('H-bpe-{}\t{}\t{}\n'.format(
                                sample_id, hypo['score'], hypo_str_bpe))
                            fopt.write('H-{}\t{}\t{}\n'.format(
                                sample_id, hypo['score'], hypo_str))
                            code_str = ""
                            if topk is not None:
                                prob_str = ""
                                sent_topk = topk[i] / topk[i].sum(1).unsqueeze(
                                    1)
                            for ii, token_code in enumerate(code):
                                code_str += " ".join([
                                    "c{}-{}".format(ii, kk)
                                    for kk in token_code if kk != -1
                                ]) + ' '
                                if topk is not None:
                                    prob_str += " ".join([
                                        "p{}-{:.2f}".format(ii, kk.item())
                                        for kk, mm in zip(
                                            sent_topk[ii], token_code)
                                        if mm != -1
                                    ]) + ' '
                            fopt.write('C-{}\t{}\n'.format(
                                sample_id, code_str))
                            if topk is not None:
                                fopt.write('K-{}\n'.format(prob_str))
                            if hypo['attention'] is not None:
                                hypo_attn = hypo['attention'].cpu().numpy(
                                )  # src_len x tgt_len
                                entropy, max_idx = compute_attn(hypo_attn)
                                baseline_entropy = hypo_attn.shape[
                                    0] * math.log(1. / hypo_attn.shape[0]
                                                  ) * 1.0 / hypo_attn.shape[0]
                                fopt.write(
                                    'A-entropy-baseline-{:.2f}\n'.format(
                                        baseline_entropy))
                                fopt.write('A-entropy-{}\n'.format(" ".join(
                                    ["%.2f" % e for e in entropy])))
                                fopt.write('A-max-attn-pos-{}\n'.format(
                                    " ".join([str(kk) for kk in max_idx])))
                            fopt.write('\n')
                            if args.remove_bpe is not None:
                                # Convert back to tokens for evaluation with unk replacement and/or without BPE
                                tokens = dictionary.encode_line(
                                    origin_string, add_if_not_exist=True)
                            if hasattr(scorer, 'add_string'):
                                scorer.add_string(origin_string, hypo_str)
                            else:
                                scorer.add(tokens, hypo_tokens)
                    else:
                        raise NotImplementedError
                generate_id += len(sample['id'])
                if generate_id % 1000 == 0:
                    print("Processed {} lines!".format(i))
                progress.print(log_outputs[0], tag=subset, step=i)
            print("Total unique active codes = {}".format(len(all_codes)))
            if len(all_stats) > 0:
                avg_topp = sum([stat['avg_topp']
                                for stat in all_stats]) * 1. / len(all_stats)
                max_topp = sum([stat['max_topp']
                                for stat in all_stats]) * 1. / len(all_stats)
                min_topp = sum([stat['min_topp']
                                for stat in all_stats]) * 1. / len(all_stats)
                print(
                    'Max #codes in top 0.9 = {}, min #codes in top 0.9 = {}, avg #codes in top 0.9 = {}'
                    .format(max_topp, min_topp, avg_topp))
        if eval_task == 'reconstruct':
            print(
                '| Reconstructed {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
                .format(num_sentences, gen_timer.n, gen_timer.sum,
                        num_sentences / gen_timer.sum, 1. / gen_timer.avg))
            print('| Reconstruct {} with beam={}: {}'.format(
                args.gen_subset, args.beam, scorer.result_string()))
    else:
        batch_size = 3072 // args.max_len_b
        gen_epochs = args.num_samples // batch_size
        latent_dictionary = prior_task.dictionary
        latent_dictionary_size = len(latent_dictionary)
        for ii in range(gen_epochs):
            dummy_tokens = torch.ones(batch_size, args.max_len_b).long().cuda()
            dummy_lengths = (torch.ones(args.max_len_b) *
                             args.max_len_b).long().cuda()
            dummy_samples = {
                'net_input': {
                    'src_tokens': dummy_tokens,
                    'src_lengths': dummy_lengths,
                },
                'target': dummy_tokens
            }
            prefix_tokens = None
            code_hypos = prior_task.inference_step(prior_generator,
                                                   [prior_model],
                                                   dummy_samples,
                                                   prefix_tokens)
            list_predictions = []
            for jj in range(batch_size):
                code_hypo = code_hypos[jj][0]  # best output
                latent_hypo_tokens, latent_hypo_str, _ = utils.post_process_prediction(
                    hypo_tokens=code_hypo['tokens'].int().cpu(),
                    src_str="",
                    alignment=code_hypo['alignment'],
                    align_dict=None,
                    tgt_dict=latent_dictionary,
                    remove_bpe=None,
                )
                # should have no pad and eos
                list_predictions.append(
                    torch.LongTensor([
                        int(ss) for ss in latent_hypo_str.strip().split()
                    ]).cuda())
            merged_codes = data_utils.collate_tokens(
                list_predictions,
                latent_dictionary_size,
                left_pad=False,
            )
            code_masks = merged_codes.eq(latent_dictionary_size)
            merged_codes = merged_codes.masked_fill_(code_masks, 0)
            hypos, _, _ = task.sampling(dummy_samples, merged_codes,
                                        code_masks, model, generator)
            for tt in range(len(hypos)):
                hypo = hypos[tt][0]
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str="",
                    alignment=hypo['alignment'],
                    align_dict=None,
                    tgt_dict=dictionary,
                    remove_bpe=args.remove_bpe,
                )
                fopt.write('C-{}\t{}\n'.format(
                    generate_id, " ".join([
                        "c-%d" % kk for kk in list_predictions[tt] if kk != -1
                    ])))
                fopt.write('H-{}\t{}\t{}\n'.format(generate_id, hypo['score'],
                                                   hypo_str))
                generate_id += 1

            if generate_id % 1000 == 0:
                print("Sampled {} sentences!".format(generate_id))
    fopt.close()
Example #24
def main(args):
    assert args.path is not None, '--path required for generation!'
    args.beam = args.nbest = 1
    args.max_tokens = int(1e4)

    utils.import_user_module(args)

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    src_dict = getattr(task, 'source_dictionary', None)
    tgt_dict = task.target_dictionary

    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(beamable_mm_beam_size=args.beam,
                                    need_attn=False)
        model.cuda()

    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    generator = task.build_generator(args)

    with progress_bar.build_progress_bar(args, itr) as t:
        for sample in t:
            sample = utils.move_to_cuda(sample)
            if 'net_input' not in sample:
                continue

            prefix_tokens = None

            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)

            for i, sample_id in enumerate(sample['id'].tolist()):
                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())

                if src_dict is not None:
                    src_str = src_dict.string(src_tokens, args.remove_bpe)
                else:
                    src_str = ""

                # Process top predictions
                hypo = hypos[i][0]
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=None,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                result = dict(src=src_str,
                              pred=hypo_str,
                              src_len=len(src_str.split()),
                              pred_len=len(hypo_str.split()))
                result_line = json.dumps(result)
                print(result_line)
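Because this example emits one JSON object per generated sample (src, pred, src_len, pred_len), its output is easy to post-process. A sketch of one such consumer follows; the generations.jsonl file name is hypothetical and simply assumes the printed lines were redirected to a file.

import json

def length_ratio(path='generations.jsonl'):
    # Average predicted/source length ratio over the JSON lines printed above.
    ratios = []
    with open(path, encoding='utf-8') as f:
        for line in f:
            r = json.loads(line)
            if r['src_len'] > 0:
                ratios.append(r['pred_len'] / r['src_len'])
    return sum(ratios) / len(ratios) if ratios else 0.0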
Example #25
def main(args, override_args=None):
    utils.import_user_module(args)

    use_fp16 = args.fp16
    use_cuda = torch.cuda.is_available() and not args.cpu

    if override_args is not None:
        overrides = vars(override_args)
        overrides.update(eval(getattr(override_args, "model_overrides", "{}")))
    else:
        overrides = None

    # Load ensemble
    print("| loading model(s) from {}".format(args.path))
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [args.path], arg_overrides=overrides,
    )
    model = models[0]

    # Move models to GPU
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Print args
    print(model_args)

    # Build criterion
    criterion = task.build_criterion(model_args)
    criterion.eval()

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for subset in args.valid_subset.split(","):
        try:
            task.load_dataset(subset, combine=False, epoch=0)
            dataset = task.dataset(subset)
        except KeyError:
            raise Exception("Cannot find dataset: " + subset)

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=utils.resolve_max_positions(
                task.max_positions(), *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args,
            itr,
            prefix="valid on '{}' subset".format(subset),
            no_progress_bar="simple",
        )

        log_outputs = []
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            _loss, _sample_size, log_output = task.valid_step(sample, model, criterion)
            progress.log(log_output, step=i)
            log_outputs.append(log_output)

        log_output = task.aggregate_logging_outputs(log_outputs, criterion)

        progress.print(log_output, tag=subset, step=i)
Example #26
def get_parser(desc, default_task="translation"):
    # Before creating the true parser, we need to import optional user module
    # in order to eagerly import custom tasks, optimizers, architectures, etc.
    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    usr_parser.add_argument("--user-dir", default=None)
    usr_args, _ = usr_parser.parse_known_args()
    utils.import_user_module(usr_args)

    parser = argparse.ArgumentParser(allow_abbrev=False)
    # fmt: off
    parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                        help='log progress every N batches (when progress bar is disabled)')
    parser.add_argument('--log-format', default=None, help='log format to use',
                        choices=['json', 'none', 'simple', 'tqdm'])
    parser.add_argument('--tensorboard-logdir', metavar='DIR', default='',
                        help='path to save logs for tensorboard, should match --logdir '
                             'of running tensorboard (default: no tensorboard logging)')
    parser.add_argument('--seed', default=None, type=int, metavar='N',
                        help='pseudo random number generator seed')
    parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')
    parser.add_argument('--tpu', action='store_true', help='use TPU instead of CUDA')
    parser.add_argument('--bf16', action='store_true', help='use bfloat16; implies --tpu')
    parser.add_argument('--fp16', action='store_true', help='use FP16')
    parser.add_argument('--memory-efficient-bf16', action='store_true',
                        help='use a memory-efficient version of BF16 training; implies --bf16')
    parser.add_argument('--memory-efficient-fp16', action='store_true',
                        help='use a memory-efficient version of FP16 training; implies --fp16')
    parser.add_argument('--fp16-no-flatten-grads', action='store_true',
                        help='don\'t flatten FP16 grads tensor')
    parser.add_argument('--fp16-init-scale', default=2 ** 7, type=int,
                        help='default FP16 loss scale')
    parser.add_argument('--fp16-scale-window', type=int,
                        help='number of updates before increasing loss scale')
    parser.add_argument('--fp16-scale-tolerance', default=0.0, type=float,
                        help='pct of updates that can overflow before decreasing the loss scale')
    parser.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D',
                        help='minimum FP16 loss scale, after which training is stopped')
    parser.add_argument('--threshold-loss-scale', type=float,
                        help='threshold FP16 loss scale from below')
    parser.add_argument('--user-dir', default=None,
                        help='path to a python module containing custom extensions (tasks and/or architectures)')
    parser.add_argument('--empty-cache-freq', default=0, type=int,
                        help='how often to clear the PyTorch CUDA cache (0 to disable)')
    parser.add_argument('--all-gather-list-size', default=16384, type=int,
                        help='number of bytes reserved for gathering stats from workers')
    parser.add_argument('--model-parallel-size', type=int, metavar='N',
                        default=1,
                        help='total number of GPUs to parallelize model over')
    parser.add_argument('--checkpoint-suffix', default='',
                        help='suffix to add to the checkpoint file name')
    parser.add_argument('--quantization-config-path', default=None,
                        help='path to quantization config file')
    parser.add_argument('--profile', action='store_true', help='enable autograd profiler emit_nvtx')
    parser.add_argument('--use-bert-model', action='store_true',
                        help='pretrained BERT model to calculate BERT loss')
    parser.add_argument("--bos", default="<s>", type=str,
                        help="Specify bos token from the dictionary.")
    parser.add_argument("--pad", default="<pad>", type=str,
                        help="Specify bos token from the dictionary.")
    parser.add_argument("--eos", default="</s>", type=str,
                        help="Specify bos token from the dictionary.")
    parser.add_argument("--unk", default="<unk>", type=str,
                        help="Specify bos token from the dictionary.")
    parser.add_argument("--tgtdict_add_sentence_limit_words_after", action="store_true",
                        help="Add sentence limit words (i.e. bos, eos, pad, unk) after loading tgtdict.")
    parser.add_argument("--rewe", action="store_true",
                        help="Regressing Word Embeddings Regularization")

    from fairseq.registry import REGISTRIES
    for registry_name, REGISTRY in REGISTRIES.items():
        parser.add_argument(
            '--' + registry_name.replace('_', '-'),
            default=REGISTRY['default'],
            choices=REGISTRY['registry'].keys(),
        )

    # Task definitions can be found under fairseq/tasks/
    from fairseq.tasks import TASK_REGISTRY
    parser.add_argument('--task', metavar='TASK', default=default_task,
                        choices=TASK_REGISTRY.keys(),
                        help='task')
    # fmt: on
    return parser
Example #27
def get_parser(desc, default_task='translation'):
    # Before creating the true parser, we need to import optional user module
    # in order to eagerly import custom tasks, optimizers, architectures, etc.
    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    usr_parser.add_argument('--user-dir', default=None)
    usr_args, _ = usr_parser.parse_known_args()
    import_user_module(usr_args)

    parser = argparse.ArgumentParser(allow_abbrev=False)
    # fmt: off
    parser.add_argument('--no-progress-bar',
                        action='store_true',
                        help='disable progress bar')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=1000,
        metavar='N',
        help='log progress every N batches (when progress bar is disabled)')
    parser.add_argument('--log-format',
                        default=None,
                        help='log format to use',
                        choices=['json', 'none', 'simple', 'tqdm'])
    parser.add_argument(
        '--tensorboard-logdir',
        metavar='DIR',
        default='',
        help='path to save logs for tensorboard, should match --logdir '
        'of running tensorboard (default: no tensorboard logging)')
    parser.add_argument('--seed',
                        default=1,
                        type=int,
                        metavar='N',
                        help='pseudo random number generator seed')
    parser.add_argument('--cpu',
                        action='store_true',
                        help='use CPU instead of CUDA')
    parser.add_argument('--fp16', action='store_true', help='use FP16')
    parser.add_argument(
        '--memory-efficient-fp16',
        action='store_true',
        help='use a memory-efficient version of FP16 training; implies --fp16')
    parser.add_argument('--fp16-init-scale',
                        default=2**7,
                        type=int,
                        help='default FP16 loss scale')
    parser.add_argument('--fp16-scale-window',
                        type=int,
                        help='number of updates before increasing loss scale')
    parser.add_argument(
        '--fp16-scale-tolerance',
        default=0.0,
        type=float,
        help='pct of updates that can overflow before decreasing the loss scale'
    )
    parser.add_argument(
        '--min-loss-scale',
        default=1e-4,
        type=float,
        metavar='D',
        help='minimum FP16 loss scale, after which training is stopped')
    parser.add_argument('--threshold-loss-scale',
                        type=float,
                        help='threshold FP16 loss scale from below')
    parser.add_argument(
        '--user-dir',
        default=None,
        help=
        'path to a python module containing custom extensions (tasks and/or architectures)'
    )

    # Task definitions can be found under fairseq/tasks/
    parser.add_argument('--task',
                        metavar='TASK',
                        default=default_task,
                        choices=TASK_REGISTRY.keys(),
                        help='task')
    # fmt: on
    return parser
Example #28
def parse_args_and_arch(
    parser: argparse.ArgumentParser,
    input_args: List[str] = None,
    parse_known: bool = False,
    suppress_defaults: bool = False,
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None,
):
    """
    Args:
        parser (ArgumentParser): the parser
        input_args (List[str]): strings to parse, defaults to sys.argv
        parse_known (bool): only parse known arguments, similar to
            `ArgumentParser.parse_known_args`
        suppress_defaults (bool): parse while ignoring all default values
        modify_parser (Optional[Callable[[ArgumentParser], None]]):
            function to modify the parser, e.g., to set default values
    """
    if suppress_defaults:
        # Parse args without any default values. This requires us to parse
        # twice, once to identify all the necessary task/model args, and a second
        # time with all defaults set to None.
        args = parse_args_and_arch(
            parser,
            input_args=input_args,
            parse_known=parse_known,
            suppress_defaults=False,
        )
        suppressed_parser = argparse.ArgumentParser(add_help=False, parents=[parser])
        suppressed_parser.set_defaults(**{k: None for k, v in vars(args).items()})
        args = suppressed_parser.parse_args(input_args)
        return argparse.Namespace(
            **{k: v for k, v in vars(args).items() if v is not None}
        )

    from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY

    # Before creating the true parser, we need to import optional user module
    # in order to eagerly import custom tasks, optimizers, architectures, etc.
    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    usr_parser.add_argument("--user-dir", default=None)
    usr_args, _ = usr_parser.parse_known_args(input_args)
    utils.import_user_module(usr_args)

    if modify_parser is not None:
        modify_parser(parser)

    # The parser doesn't know about model/criterion/optimizer-specific args, so
    # we parse twice. First we parse the model/criterion/optimizer, then we
    # parse a second time after adding the *-specific arguments.
    # If input_args is given, we will parse those args instead of sys.argv.
    args, _ = parser.parse_known_args(input_args)

    # Add model-specific args to parser.
    if hasattr(args, "arch"):
        model_specific_group = parser.add_argument_group(
            "Model-specific configuration",
            # Only include attributes which are explicitly given as command-line
            # arguments or which have default values.
            argument_default=argparse.SUPPRESS,
        )
        ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)

    # Add *-specific args to parser.
    from fairseq.registry import REGISTRIES

    for registry_name, REGISTRY in REGISTRIES.items():
        choice = getattr(args, registry_name, None)
        if choice is not None:
            cls = REGISTRY["registry"][choice]
            if hasattr(cls, "add_args"):
                cls.add_args(parser)
    if hasattr(args, "task"):
        from fairseq.tasks import TASK_REGISTRY

        TASK_REGISTRY[args.task].add_args(parser)
    if getattr(args, "use_bmuf", False):
        # hack to support extra args for block distributed data parallelism
        from fairseq.optim.bmuf import FairseqBMUF

        FairseqBMUF.add_args(parser)

    # Modify the parser a second time, since defaults may have been reset
    if modify_parser is not None:
        modify_parser(parser)

    # Parse a second time.
    if parse_known:
        args, extra = parser.parse_known_args(input_args)
    else:
        args = parser.parse_args(input_args)
        extra = None

    # Post-process args.
    if hasattr(args, "max_sentences_valid") and args.max_sentences_valid is None:
        args.max_sentences_valid = args.max_sentences
    if hasattr(args, "max_tokens_valid") and args.max_tokens_valid is None:
        args.max_tokens_valid = args.max_tokens
    if getattr(args, "memory_efficient_fp16", False):
        args.fp16 = True
    if getattr(args, "memory_efficient_bf16", False):
        args.bf16 = True
    args.tpu = getattr(args, "tpu", False)
    args.bf16 = getattr(args, "bf16", False)
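    # in this snippet bf16 is treated as TPU-only, so requesting it implies --tpu;
    # fp16, by contrast, is rejected on TPU just below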
    if args.bf16:
        args.tpu = True
    if args.tpu and args.fp16:
        raise ValueError("Cannot combine --fp16 and --tpu, use --bf16 on TPUs")

    if getattr(args, "seed", None) is None:
        args.seed = 1  # default seed for training
        args.no_seed_provided = True
    else:
        args.no_seed_provided = False

    # Apply architecture configuration.
    if hasattr(args, "arch"):
        ARCH_CONFIG_REGISTRY[args.arch](args)

    if parse_known:
        return args, extra
    else:
        return args
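
A hedged usage sketch for the function above: get_training_parser is an assumed helper that builds the base parser (for example, the one from Example #27), and the flag values are illustrative only.

def lower_log_interval(parser):
    # modify_parser runs both before and after the *-specific args are added,
    # so a default set here survives the second registration pass
    parser.set_defaults(log_interval=100)

args, extra = parse_args_and_arch(
    get_training_parser(),   # assumed helper returning the base ArgumentParser
    input_args=["--task", "translation", "--arch", "transformer", "--fp16"],
    parse_known=True,        # unknown flags come back in `extra` instead of erroring
    modify_parser=lower_log_interval,
)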
Example #29
def main(args):
    import_user_module(args)

    print(args)
    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

    def train_path(lang):
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

    if args.joined_dictionary:
        assert not args.srcdict, "cannot combine --srcdict and --joined-dictionary"
        assert not args.tgtdict, "cannot combine --tgtdict and --joined-dictionary"
        src_dict = build_dictionary(
            {
                train_path(lang)
                for lang in [args.source_lang, args.target_lang]
            },
            args.workers,
        )
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = dictionary.Dictionary.load(args.srcdict)
        else:
            assert (args.trainpref
                    ), "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)],
                                        args.workers)
        if target:
            if args.tgtdict:
                tgt_dict = dictionary.Dictionary.load(args.tgtdict)
            else:
                assert (
                    args.trainpref
                ), "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)],
                                            args.workers)

    src_dict.finalize(
        threshold=args.thresholdsrc,
        nwords=args.nwordssrc,
        padding_factor=args.padding_factor,
    )
    src_dict.save(dict_path(args.source_lang))
    if target:
        if not args.joined_dictionary:
            tgt_dict.finalize(
                threshold=args.thresholdtgt,
                nwords=args.nwordstgt,
                padding_factor=args.padding_factor,
            )
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
        dict = dictionary.Dictionary.load(dict_path(lang))
        print("| [{}] Dictionary: {} types".format(lang, len(dict) - 1))
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        offsets = Tokenizer.find_offsets(input_file, num_workers)
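        # find_offsets is expected to return num_workers + 1 byte positions that
        # split input_file into roughly equal chunks; worker i binarizes the range
        # [offsets[i], offsets[i + 1]), and chunk 0 is handled in the main process below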
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        dict,
                        prefix,
                        lang,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, "bin"))
        merge_result(
            Tokenizer.binarize(input_file,
                               dict,
                               lambda t: ds.add_item(t),
                               offset=0,
                               end=offsets[1]))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        print("| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
            lang,
            input_file,
            n_seq_tok[0],
            n_seq_tok[1],
            100 * sum(replaced.values()) / n_seq_tok[1],
            dict.unk_word,
        ))

    def make_dataset(input_prefix, output_prefix, lang, num_workers=1):
        if args.output_format == "binary":
            make_binary_dataset(input_prefix, output_prefix, lang, num_workers)
        elif args.output_format == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix +
                ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

    def make_all(lang):
        if args.trainpref:
            make_dataset(args.trainpref,
                         "train",
                         lang,
                         num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(validpref, outprefix, lang)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(testpref, outprefix, lang)

    make_all(args.source_lang)
    if target:
        make_all(args.target_lang)

    print("| Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        src_dict = dictionary.Dictionary.load(dict_path(args.source_lang))
        tgt_dict = dictionary.Dictionary.load(dict_path(args.target_lang))
        freq_map = {}
        with open(args.alignfile, "r") as align_file:
            with open(src_file_name, "r") as src_file:
                with open(tgt_file_name, "r") as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        si = Tokenizer.tokenize(s,
                                                src_dict,
                                                add_if_not_exist=False)
                        ti = Tokenizer.tokenize(t,
                                                tgt_dict,
                                                add_if_not_exist=False)
                        ai = list(map(lambda x: tuple(x.split("-")),
                                      a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
                            if (srcidx != src_dict.unk()
                                    and tgtidx != tgt_dict.unk()):
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()

                                if srcidx not in freq_map:
                                    freq_map[srcidx] = {}
                                if tgtidx not in freq_map[srcidx]:
                                    freq_map[srcidx][tgtidx] = 1
                                else:
                                    freq_map[srcidx][tgtidx] += 1

        align_dict = {}
        for srcidx in freq_map.keys():
            align_dict[srcidx] = max(freq_map[srcidx],
                                     key=freq_map[srcidx].get)

        with open(
                os.path.join(
                    args.destdir,
                    "alignment.{}-{}.txt".format(args.source_lang,
                                                 args.target_lang),
                ),
                "w",
        ) as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
Example #30
def main(args, override_args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    # Print args
    print(args)
    use_fp16 = args.fp16
    use_cuda = torch.cuda.is_available() and not args.cpu

    if override_args is not None:
        overrides = vars(override_args)
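        # the eval() below interprets the --model-overrides string as a Python dict
        # literal; ast.literal_eval would be a safer choice for untrusted input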
        overrides.update(eval(getattr(override_args, 'model_overrides', '{}')))
    else:
        overrides = None

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [args.path],
        arg_overrides=overrides,
    )
    model = models[0]

    # Move models to GPU
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Print model args
    print(model_args)

    # Build criterion
    criterion = task.build_criterion(model_args)
    criterion.eval()

    # Load the validation dataset(s)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build trainer
    trainer = validator(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Run validation on the requested subsets

    valid_subsets = args.valid_subset.split(',')

    valid_losses = validate(args, trainer, task, valid_subsets)