Example #1
def get_parser(desc, default_task="translation"):
    # Before creating the true parser, we need to import optional user module
    # in order to eagerly import custom tasks, optimizers, architectures, etc.
    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    usr_parser.add_argument("--user-dir", default=None)
    usr_args, _ = usr_parser.parse_known_args()
    import_user_module(usr_args)

    parser = argparse.ArgumentParser(allow_abbrev=False)
    gen_parser_from_dataclass(parser, CommonParams())

    from fairseq.registry import REGISTRIES

    for registry_name, REGISTRY in REGISTRIES.items():
        parser.add_argument(
            "--" + registry_name.replace("_", "-"),
            default=REGISTRY["default"],
            choices=REGISTRY["registry"].keys(),
        )

    # Task definitions can be found under fairseq/tasks/
    from fairseq.tasks import TASK_REGISTRY

    parser.add_argument(
        "--task",
        metavar="TASK",
        default=default_task,
        choices=TASK_REGISTRY.keys(),
        help="task",
    )
    return parser
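
A minimal usage sketch for the parser above (assuming get_parser is importable from this module; the argv values below are hypothetical):

# usage sketch; "translation" must be a registered task for the choices check to pass
parser = get_parser("Preprocessing", default_task="translation")
args, _unknown = parser.parse_known_args(["--task", "translation"])
print(args.task)  # -> "translation"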
Example #2
def main(args):

    # Load dataset
    import_user_module(args)
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)

    # Get iterator over batches
    batch_index_iterator = get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=None,
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
        large_sent_first=False)

    # collate each batch of sentences into a single padded batch
    for batch_ids in tqdm(batch_index_iterator):
        samples = [dataset[i] for i in batch_ids]
        dataset.collater(samples)
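
The dataset.collater(samples) call above merges a list of per-sentence samples into one padded batch. A simplified, self-contained sketch of the padding step a fairseq-style collater performs (an illustration, not the library's implementation):

import torch

def collate_tokens(values, pad_idx):
    """Pad 1-D token tensors to the longest length and stack them into a 2-D batch."""
    size = max(v.size(0) for v in values)
    batch = values[0].new_full((len(values), size), pad_idx)
    for i, v in enumerate(values):
        batch[i, :v.size(0)].copy_(v)
    return batch

# collate_tokens([torch.tensor([1, 2]), torch.tensor([3])], pad_idx=0)
# -> tensor([[1, 2], [3, 0]])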
Example #3
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # ========== for bartsv task, rebuild dictionary after model args are loaded ==========
    # assert not hasattr(args, 'node_freq_min'), 'node_freq_min should be read from model args'
    # args.node_freq_min = 5    # temporarily set before model loading, as this is needed in tasks.setup_task(args)
    # =====================================================================================

    # Load dataset splits
    task = tasks.setup_task(args)
    # Note: states are not needed since they will be provided by the state
    # machine
    task.load_dataset(args.gen_subset, state_machine=False)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    try:
        models, _model_args = checkpoint_utils.load_model_ensemble(
            args.path.split(':'),
            arg_overrides=eval(args.model_overrides),
            task=task,
        )
    except Exception:
        # NOTE this is for "bartsv" models when the default "args.node_freq_min" (5) does not match the
        #      model's: loading the model with the above task errors when building the model against the
        #      task's target vocabulary, which would be of a different size
        # TODO handle these cases better (without sacrificing compatibility with other model archs)
        models, _model_args = checkpoint_utils.load_model_ensemble(
            args.path.split(':'),
            arg_overrides=eval(args.model_overrides),
            task=None,
        )

    # ========== for bartsv task, rebuild the dictionary based on model args ==========
    if 'bartsv' in _model_args.arch and args.node_freq_min != _model_args.node_freq_min:
        args.node_freq_min = _model_args.node_freq_min
        # Load dataset splits
        task = tasks.setup_task(args)
        # Note: states are not needed since they will be provided by the state machine
        task.load_dataset(args.gen_subset, state_machine=False)

        # Set dictionaries
        try:
            src_dict = getattr(task, 'source_dictionary', None)
        except NotImplementedError:
            src_dict = None
        tgt_dict = task.target_dictionary
    # ==================================================================================

    # ========== for previous model trained when new arguments were not there ==========
    if not hasattr(_model_args, 'shift_pointer_value'):
        _model_args.shift_pointer_value = 1
    # ==================================================================================

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align
    # dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=None,
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
        # large_sent_first=False        # not in fairseq
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args, _model_args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True

    examples = Examples(args.path, args.results_path, args.gen_subset,
                        args.nbest)

    error_stats = {'num_sub_start': 0}

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                raise Exception("Did not expect empty sample")

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample, args,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                # ==========> NOTE we do not really have the ground truth target (with the same alignments)
                #                  target_str might have <unk> as the target dictionary is only built on training data
                #                  but it doesn't matter. It should not affect the target dictionary!

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    # hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    #     hypo_tokens=hypo['tokens'].int().cpu(),
                    #     src_str=src_str,
                    #     alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                    #     align_dict=align_dict,
                    #     tgt_dict=tgt_dict,
                    #     remove_bpe=args.remove_bpe,
                    #     # FIXME: AMR specific
                    #     split_token="\t",
                    #     line_tokenizer=task.tokenize,
                    # )

                    if 'bartsv' in _model_args.arch:
                        if not tgt_dict[hypo['tokens'][0]].startswith(
                                tgt_dict.bpe.INIT):
                            error_stats['num_sub_start'] += 1

                        try:
                            actions_nopos, actions_pos, actions = post_process_action_pointer_prediction_bartsv(
                                hypo, tgt_dict)
                        except Exception:
                            breakpoint()
                    else:
                        actions_nopos, actions_pos, actions = post_process_action_pointer_prediction(
                            hypo, tgt_dict)

                    if args.clean_arcs:
                        actions_nopos, actions_pos, actions, invalid_idx = clean_pointer_arcs(
                            actions_nopos, actions_pos, actions)

                    # placeholders so that the reporting code below can run
                    hypo_tokens = hypo['tokens'].int().cpu()
                    hypo_str = '\t'.join(actions)
                    alignment = None

                    # update the list of examples
                    examples.append({
                        'actions_nopos': actions_nopos,
                        'actions_pos': actions_pos,
                        'actions': actions,
                        'reference': target_str,
                        'src_str': src_str,
                        'sample_id': sample_id
                    })

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo_str,
                                                    hypo['score']))
                        print('P-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    hypo['positional_scores'].tolist(),
                                ))))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join(
                                    map(lambda x: str(utils.item(x)),
                                        alignment))))

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=False)
                            # NOTE do not modify the tgt dictionary with 'add_if_not_exist=True'!
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    # Save examples to files
    examples.save()

    print('| Error case (handled by manual fix) statistics:')
    print(error_stats)

    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))
    return scorer
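
One caveat in the example above: eval(args.model_overrides) executes arbitrary Python from the command line. A safer drop-in sketch for that call, assuming the override string is a plain dict literal (which is how fairseq documents --model-overrides):

import ast

# parse the overrides without executing code; an empty string becomes {}
overrides = ast.literal_eval(args.model_overrides or "{}")
assert isinstance(overrides, dict)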
Example #4
def parse_args_and_arch(
    parser: argparse.ArgumentParser,
    input_args: List[str] = None,
    parse_known: bool = False,
    suppress_defaults: bool = False,
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None,
):
    """
    Args:
        parser (ArgumentParser): the parser
        input_args (List[str]): strings to parse, defaults to sys.argv
        parse_known (bool): only parse known arguments, similar to
            `ArgumentParser.parse_known_args`
        suppress_defaults (bool): parse while ignoring all default values
        modify_parser (Optional[Callable[[ArgumentParser], None]]):
            function to modify the parser, e.g., to set default values
    """
    if suppress_defaults:
        # Parse args without any default values. This requires us to parse
        # twice, once to identify all the necessary task/model args, and a second
        # time with all defaults set to None.
        args = parse_args_and_arch(
            parser,
            input_args=input_args,
            parse_known=parse_known,
            suppress_defaults=False,
        )
        suppressed_parser = argparse.ArgumentParser(add_help=False, parents=[parser])
        suppressed_parser.set_defaults(**{k: None for k, v in vars(args).items()})
        args = suppressed_parser.parse_args(input_args)
        return argparse.Namespace(
            **{k: v for k, v in vars(args).items() if v is not None}
        )

    from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY, MODEL_REGISTRY

    # Before creating the true parser, we need to import optional user module
    # in order to eagerly import custom tasks, optimizers, architectures, etc.
    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    usr_parser.add_argument("--user-dir", default=None)
    usr_args, _ = usr_parser.parse_known_args(input_args)
    import_user_module(usr_args)

    if modify_parser is not None:
        modify_parser(parser)

    # The parser doesn't know about model/criterion/optimizer-specific args, so
    # we parse twice. First we parse the model/criterion/optimizer, then we
    # parse a second time after adding the *-specific arguments.
    # If input_args is given, we will parse those args instead of sys.argv.
    args, _ = parser.parse_known_args(input_args)

    # Add model-specific args to parser.
    if hasattr(args, "arch"):
        model_specific_group = parser.add_argument_group(
            "Model-specific configuration",
            # Only include attributes which are explicitly given as command-line
            # arguments or which have default values.
            argument_default=argparse.SUPPRESS,
        )
        if args.arch in ARCH_MODEL_REGISTRY:
            ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)
        elif args.arch in MODEL_REGISTRY:
            MODEL_REGISTRY[args.arch].add_args(model_specific_group)
        else:
            raise RuntimeError()

    # Add *-specific args to parser.
    from fairseq.registry import REGISTRIES

    for registry_name, REGISTRY in REGISTRIES.items():
        choice = getattr(args, registry_name, None)
        if choice is not None:
            cls = REGISTRY["registry"][choice]
            if hasattr(cls, "add_args"):
                cls.add_args(parser)
    if hasattr(args, "task"):
        from fairseq.tasks import TASK_REGISTRY

        TASK_REGISTRY[args.task].add_args(parser)
    if getattr(args, "use_bmuf", False):
        # hack to support extra args for block distributed data parallelism
        from fairseq.optim.bmuf import FairseqBMUF

        FairseqBMUF.add_args(parser)

    # Modify the parser a second time, since defaults may have been reset
    if modify_parser is not None:
        modify_parser(parser)

    # Parse a second time.
    if parse_known:
        args, extra = parser.parse_known_args(input_args)
    else:
        args = parser.parse_args(input_args)
        extra = None
    # Post-process args.
    if (
        hasattr(args, "batch_size_valid") and args.batch_size_valid is None
    ) or not hasattr(args, "batch_size_valid"):
        args.batch_size_valid = args.batch_size
    if hasattr(args, "max_tokens_valid") and args.max_tokens_valid is None:
        args.max_tokens_valid = args.max_tokens
    if getattr(args, "memory_efficient_fp16", False):
        args.fp16 = True
    if getattr(args, "memory_efficient_bf16", False):
        args.bf16 = True
    args.tpu = getattr(args, "tpu", False)
    args.bf16 = getattr(args, "bf16", False)
    if args.bf16:
        args.tpu = True
    if args.tpu and args.fp16:
        raise ValueError("Cannot combine --fp16 and --tpu, use --bf16 on TPUs")

    if getattr(args, "seed", None) is None:
        args.seed = 1  # default seed for training
        args.no_seed_provided = True
    else:
        args.no_seed_provided = False

    # Apply architecture configuration.
    if hasattr(args, "arch") and args.arch in ARCH_CONFIG_REGISTRY:
        ARCH_CONFIG_REGISTRY[args.arch](args)

    if parse_known:
        return args, extra
    else:
        return args
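
A hedged end-to-end sketch combining get_parser (Example #1) with parse_args_and_arch. It assumes stock fairseq's options module (which exposes both helpers plus add_model_args to register --arch); the argv values, including the data directory, are hypothetical:

from fairseq import options

parser = options.get_parser("Trainer", default_task="translation")
options.add_model_args(parser)  # registers --arch so the arch-specific parse can run
args = options.parse_args_and_arch(
    parser,
    input_args=["data-bin/dummy", "--task", "translation", "--arch", "transformer"],
)
print(args.arch, args.task)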
Example #5
def main(args):
    import_user_module(args)

    assert (
        args.max_tokens is not None or args.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    metrics.reset()

    np.random.seed(args.seed)
    utils.set_torch_seed(args.seed)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info("task: {} ({})".format(args.task, task.__class__.__name__))
    logger.info("model: {} ({})".format(args.arch, model.__class__.__name__))
    logger.info("criterion: {} ({})".format(args.criterion,
                                            criterion.__class__.__name__))
    logger.info("num. model params: {} (num. trained: {})".format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # ========== initialize the model with pretrained BART parameters ==========
    # for shared embeddings and subtoken split for amr nodes
    if 'bartsv' in args.arch:

        if args.initialize_with_bart:
            logger.info(
                '-' * 10 +
                ' initializing model parameters with pretrained BART model ' +
                '-' * 10)

            new_state_dict = copy.deepcopy(task.bart.model.state_dict())
            # treat the embedding initialization separately later, as the sizes differ
            logger.info(
                '-' * 10 +
                ' delay encoder embeddings, decoder input and output embeddings initialization '
                + '-' * 10)
            ignore_keys = set([
                'encoder.embed_tokens.weight', 'decoder.embed_tokens.weight',
                'decoder.output_projection.weight'
            ])
            for k in ignore_keys:
                del new_state_dict[k]

            if not args.initialize_with_bart_enc:
                logger.info(
                    '-' * 10 +
                    ' do not initialize with BART encoder parameters ' +
                    '-' * 10)
                for k in list(new_state_dict.keys()):
                    if k.startswith('encoder'):
                        del new_state_dict[k]

            if not args.initialize_with_bart_dec:
                logger.info(
                    '-' * 10 +
                    ' do not initialize with BART decoder parameters ' +
                    '-' * 10)
                for k in list(new_state_dict.keys()):
                    if k.startswith('decoder'):
                        del new_state_dict[k]

            model.load_state_dict(new_state_dict, strict=False, args=args)

            # initialize the Bart part embeddings
            bart_vocab_size = task.target_dictionary.bart_vocab_size
            # NOTE we need to prune the pretrained BART embeddings, especially for bart.base
            bart_embed_weight = task.bart.model.encoder.embed_tokens.weight.data[:bart_vocab_size]
            assert len(bart_embed_weight) == bart_vocab_size

            with torch.no_grad():
                model.encoder.embed_tokens.weight[:bart_vocab_size].copy_(
                    bart_embed_weight)
                model.decoder.embed_tokens.weight[:bart_vocab_size].copy_(
                    bart_embed_weight)
                model.decoder.output_projection.weight[:bart_vocab_size].copy_(
                    bart_embed_weight)

        if args.bart_emb_init_composition:
            logger.info(
                '-' * 10 +
                ' initialize extended target embeddings with compositional embeddings '
                'from BART vocabulary ' + '-' * 10)

            # NOTE recomputed here so this block also works when --initialize-with-bart
            #      is off (bart_vocab_size is otherwise only set inside the branch above)
            bart_vocab_size = task.target_dictionary.bart_vocab_size
            symbols = [
                task.target_dictionary[idx]
                for idx in range(bart_vocab_size, len(task.target_dictionary))
            ]
            mapper = MapAvgEmbeddingBART(task.bart,
                                         task.bart.model.decoder.embed_tokens)
            comp_embed_weight, map_all = mapper.map_avg_embeddings(
                symbols, transform=transform_action_symbol, add_noise=False)
            assert len(comp_embed_weight) == len(symbols)

            with torch.no_grad():
                model.encoder.embed_tokens.weight[bart_vocab_size:].copy_(
                    comp_embed_weight)
                model.decoder.embed_tokens.weight[bart_vocab_size:].copy_(
                    comp_embed_weight)
                model.decoder.output_projection.weight[bart_vocab_size:].copy_(
                    comp_embed_weight)

    elif 'bart' in args.arch:

        if args.initialize_with_bart:
            logger.info(
                '-' * 10 +
                ' initializing model parameters with pretrained BART model ' +
                '-' * 10)

            new_state_dict = copy.deepcopy(task.bart.model.state_dict())
            if not args.bart_emb_decoder:
                logger.info('-' * 10 +
                            ' build a separate decoder dictionary embedding ' +
                            '-' * 10)
                if not args.bart_emb_decoder_input:
                    ignore_keys = set([
                        'decoder.embed_tokens.weight',
                        'decoder.output_projection.weight'
                    ])
                else:
                    logger.info(
                        '-' * 10 +
                        ' use BART dictionary embedding for target input ' +
                        '-' * 10)
                    ignore_keys = set(['decoder.output_projection.weight'])
                for k in ignore_keys:
                    del new_state_dict[k]

            if not args.initialize_with_bart_enc:
                logger.info(
                    '-' * 10 +
                    ' do not initialize with BART encoder parameters ' +
                    '-' * 10)
                for k in list(new_state_dict.keys()):
                    if k.startswith('encoder'):
                        del new_state_dict[k]

            if not args.initialize_with_bart_dec:
                logger.info(
                    '-' * 10 +
                    ' do not initialize with BART decoder parameters ' +
                    '-' * 10)
                for k in list(new_state_dict.keys()):
                    if k.startswith('decoder'):
                        del new_state_dict[k]

            model.load_state_dict(new_state_dict, strict=False, args=args)

        # initialize the target embeddings with average of subtoken embeddings in BART vocabulary
        if args.bart_emb_init_composition:
            assert not args.bart_emb_decoder, 'should not use the compositional embeddings on top of BART vocabulary here'
            logger.info(
                '-' * 10 +
                ' initialize target embeddings with compositional embeddings from BART vocabulary '
                + '-' * 10)
            composite_embed = CompositeEmbeddingBART(
                task.bart, task.bart.model.decoder.embed_tokens,
                task.target_dictionary)
            if args.bart_emb_decoder_input:
                # only initialize the decoder output embeddings
                with torch.no_grad():
                    model.decoder.output_projection.weight.copy_(
                        composite_embed.embedding_weight)
            else:
                # initialize both the decoder input and output embeddings
                with torch.no_grad():
                    model.decoder.embed_tokens.weight.copy_(
                        composite_embed.embedding_weight)
                    model.decoder.output_projection.weight.copy_(
                        composite_embed.embedding_weight)

    elif 'roberta' in args.arch:
        # initialize the target embeddings with average of subtoken embeddings in BART vocabulary
        if args.bart_emb_init_composition:
            assert not args.bart_emb_decoder, 'should not use the compositional embeddings on top of RoBERTa vocabulary here'
            logger.info(
                '-' * 10 +
                ' initialize target embeddings with compositional embeddings from RoBERTa vocabulary '
                + '-' * 10)
            composite_embed = CompositeEmbeddingBART(
                task.bart,  # NOTE here "bart" means roberta
                task.bart.model.encoder.sentence_encoder.embed_tokens,
                task.target_dictionary)

            if args.bart_emb_decoder_input:
                # only initialize the decoder output embeddings
                with torch.no_grad():
                    model.decoder.output_projection.weight.copy_(
                        composite_embed.embedding_weight)
            else:
                # initialize both the decoder input and output embeddings
                with torch.no_grad():
                    model.decoder.embed_tokens.weight.copy_(
                        composite_embed.embedding_weight)
                    model.decoder.output_projection.weight.copy_(
                        composite_embed.embedding_weight)

    else:
        raise ValueError(f'unsupported --arch {args.arch!r} for pretrained initialization')
    # ==========================================================================

    # (optionally) Configure quantization
    if args.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=args.quantization_config_path,
            max_epoch=args.max_epoch,
            max_update=args.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if args.model_parallel_size == 1:
        trainer = Trainer(args, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(args, task, model, criterion)

    logger.info("training on {} devices (GPUs/TPUs)".format(
        args.distributed_world_size))
    logger.info(
        "max tokens per GPU = {} and max sentences per GPU = {}".format(
            args.max_tokens, args.batch_size))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        args,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()

    while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
        # train for one epoch
        valid_losses, should_stop = train(args, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))
Example #6
def main(args):
    import_user_module(args)

    print(args)

    # control which preprocessing stages run (they cost time and storage, so we avoid rerunning them)
    run_basic = True
    # this includes:
    # src: build src dictionary, copy the raw data to dir; build src binary data (need to refactor later if unneeded)
    # tgt: split target non-pointer actions and pointer values into separate files; build tgt dictionary
    run_act_states = True
    # this includes:
    # run the state machine reformer to get
    # a) training data: input and output, pointer values;
    # b) states information to facilitate modeling;
    # takes about 1 hour and 13G space on CCC
    run_roberta_emb = True
    # this includes:
    # for src sentences, use pre-trained RoBERTa model to extract contextual embeddings for each word;
    # takes about 10min for RoBERTa base and 30 mins for RoBERTa large and 2-3G space;
    # this needs GPU and only needs to run once for the English sentences, which does not change for different oracles;
    # thus the embeddings are stored separately from the oracles.

    if os.path.exists(args.destdir):
        print(f'binarized actions and states directory {args.destdir} already exists; not rerunning.')
        run_basic = False
        run_act_states = False
    if os.path.exists(args.embdir):
        print(f'pre-trained embedding directory {args.embdir} already exists; not rerunning.')
        run_roberta_emb = False

    os.makedirs(args.destdir, exist_ok=True)
    os.makedirs(args.embdir, exist_ok=True)
    target = not args.only_source

    task = tasks.get_task(args.task)

    # preprocess target actions files, to split '.actions' to '.actions_nopos' and '.actions_pos'
    # when building dictionary on the target actions sequences
    # split the action file into two files, one without arc pointer and one with only arc pointer values
    # and the dictionary is only built on the no pointer actions
    if run_basic:
        assert args.target_lang == 'actions', 'target extension must be "actions"'
        actions_files = [f'{pref}.{args.target_lang}' for pref in (args.trainpref, args.validpref, args.testpref)]
        task.split_actions_pointer_files(actions_files)
        args.target_lang_nopos = 'actions_nopos'    # only build dictionary without pointer values
        args.target_lang_pos = 'actions_pos'

    # set tokenizer
    tokenize = task.tokenize if hasattr(task, 'tokenize') else tokenize_line

    def train_path(lang):
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        assert src ^ tgt

        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
            # the tokenize separator is handled inside the task
        )

    # build dictionary and save

    if run_basic:
        if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
            raise FileExistsError(dict_path(args.source_lang))
        if target and not args.tgtdict and os.path.exists(dict_path(args.target_lang)):
            raise FileExistsError(dict_path(args.target_lang))

        if args.joined_dictionary:
            assert not args.srcdict or not args.tgtdict, \
                "cannot use both --srcdict and --tgtdict with --joined-dictionary"

            if args.srcdict:
                src_dict = task.load_dictionary(args.srcdict)
            elif args.tgtdict:
                src_dict = task.load_dictionary(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
                src_dict = build_dictionary(
                    {train_path(lang) for lang in [args.source_lang, args.target_lang]}, src=True
                )
            tgt_dict = src_dict
        else:
            if args.srcdict:
                src_dict = task.load_dictionary(args.srcdict)
            else:
                assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
                src_dict = build_dictionary([train_path(args.source_lang)], src=True)

            if target:
                if args.tgtdict:
                    tgt_dict = task.load_dictionary(args.tgtdict)
                else:
                    assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                    tgt_dict = build_dictionary([train_path(args.target_lang_nopos)], tgt=True)
            else:
                tgt_dict = None

        src_dict.save(dict_path(args.source_lang))
        if target and tgt_dict is not None:
            tgt_dict.save(dict_path(args.target_lang_nopos))

    # save binarized preprocessed files

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers):
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(
            input_prefix, ("." + lang) if lang is not None else ""
        )
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        vocab,
                        prefix,
                        lang,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                        False,    # note: we turn off appending EOS here
                        tokenize
                    ),
                    callback=merge_result
                )
            pool.close()

        ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, lang, "bin"),
                                          impl=args.dataset_impl, vocab_size=len(vocab), dtype=np.int64)
        merge_result(
            Binarizer.binarize(
                input_file, vocab, lambda t: ds.add_item(t),
                offset=0, end=offsets[1],
                append_eos=False,
                tokenize=tokenize
            )
        )
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        print(
            "| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                n_seq_tok[0],
                n_seq_tok[1],
                100 * sum(replaced.values()) / n_seq_tok[1],
                vocab.unk_word,
            )
        )

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1, dataset_impl=args.dataset_impl):
        if dataset_impl == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)
        else:
            make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers)

    def make_all(lang, vocab, dataset_impl=args.dataset_impl):
        if args.trainpref:
            make_dataset(vocab, args.trainpref, "train", lang, num_workers=args.workers, dataset_impl=dataset_impl)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(vocab, validpref, outprefix, lang, num_workers=args.workers, dataset_impl=dataset_impl)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab, testpref, outprefix, lang, num_workers=args.workers, dataset_impl=dataset_impl)

    # NOTE we do not encode the source sentences with the dictionary, since the source embeddings come
    # directly from RoBERTa; the source dictionary here is therefore unused
    if run_basic:
        make_all(args.source_lang, src_dict, dataset_impl='raw')
        make_all(args.source_lang, src_dict, dataset_impl='mmap')
        # above: kept only so the model runs without larger code changes
        # NOTE there are <unk> in valid and test set for target actions
        # if target:
        #     make_all(args.target_lang_nopos, tgt_dict)

        # NOTE targets (input, output, pointer values) are now all included in the state generation process

        # binarize pointer values and save to file

        # TODO make naming convention clearer
        # assume one training file, one validation file, and one test file
        # for pos_file, split in [(f'{pref}.actions_pos', split) for pref, split in
        #                         [(args.trainpref, 'train'), (args.validpref, 'valid'), (args.testpref, 'test')]]:
        #     out_pref = os.path.join(args.destdir, split)
        #     task.binarize_actions_pointer_file(pos_file, out_pref)

    # save action states information to assist training with auxiliary info
    # assume one training file, one validation file, and one test file
    if run_act_states:
        task_obj = task(args, tgt_dict=tgt_dict)
        for prefix, split in zip([args.trainpref, args.validpref, args.testpref], ['train', 'valid', 'test']):
            en_file = prefix + '.en'
            actions_file = prefix + '.actions'
            out_file_pref = os.path.join(args.destdir, split)
            task_obj.build_actions_states_info(en_file, actions_file, out_file_pref, num_workers=args.workers)

    # save RoBERTa embeddings
    # TODO refactor this code
    if run_roberta_emb:
        make_roberta_embeddings(args, tokenize=tokenize)

    print("| Wrote preprocessed oracle data to {}".format(args.destdir))
    print("| Wrote preprocessed embedding data to {}".format(args.embdir))
Example #7
def get_parser(desc, default_task='translation'):
    # Before creating the true parser, we need to import optional user module
    # in order to eagerly import custom tasks, optimizers, architectures, etc.
    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    usr_parser.add_argument('--user-dir', default=None)
    usr_args, _ = usr_parser.parse_known_args()
    import_user_module(usr_args)

    parser = argparse.ArgumentParser(allow_abbrev=False)
    # fmt: off
    parser.add_argument('--no-progress-bar',
                        action='store_true',
                        help='disable progress bar')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=1000,
        metavar='N',
        help='log progress every N batches (when progress bar is disabled)')
    parser.add_argument('--log-format',
                        default=None,
                        help='log format to use',
                        choices=['json', 'none', 'simple', 'tqdm'])
    parser.add_argument(
        '--tensorboard-logdir',
        metavar='DIR',
        default='',
        help='path to save logs for tensorboard, should match --logdir '
        'of running tensorboard (default: no tensorboard logging)')
    parser.add_argument("--tbmf-wrapper",
                        action="store_true",
                        help="[FB only] ")
    parser.add_argument('--seed',
                        default=1,
                        type=int,
                        metavar='N',
                        help='pseudo random number generator seed')
    parser.add_argument('--cpu',
                        action='store_true',
                        help='use CPU instead of CUDA')
    parser.add_argument('--fp16', action='store_true', help='use FP16')
    parser.add_argument(
        '--memory-efficient-fp16',
        action='store_true',
        help='use a memory-efficient version of FP16 training; implies --fp16')
    parser.add_argument('--fp16-init-scale',
                        default=2**7,
                        type=int,
                        help='default FP16 loss scale')
    parser.add_argument('--fp16-scale-window',
                        type=int,
                        help='number of updates before increasing loss scale')
    parser.add_argument(
        '--fp16-scale-tolerance',
        default=0.0,
        type=float,
        help='pct of updates that can overflow before decreasing the loss scale'
    )
    parser.add_argument(
        '--min-loss-scale',
        default=1e-4,
        type=float,
        metavar='D',
        help='minimum FP16 loss scale, after which training is stopped')
    parser.add_argument('--threshold-loss-scale',
                        type=float,
                        help='threshold FP16 loss scale from below')
    parser.add_argument(
        '--user-dir',
        default=None,
        help=
        'path to a python module containing custom extensions (tasks and/or architectures)'
    )
    parser.add_argument('--profile',
                        default=False,
                        action='store_true',
                        help='enable autograd profiler emit_nvtx')
    parser.add_argument('--model-parallel-size',
                        default=1,
                        type=int,
                        help='total number of GPUs to parallelize model over')
    parser.add_argument('--quantization-config-path',
                        default=None,
                        help='path to quantization config file')
    parser.add_argument('--bf16',
                        default=False,
                        action='store_true',
                        help='use bfloat16; implies --tpu')

    from fairseq.registry import REGISTRIES
    for registry_name, REGISTRY in REGISTRIES.items():
        parser.add_argument(
            '--' + registry_name.replace('_', '-'),
            default=REGISTRY['default'],
            choices=REGISTRY['registry'].keys(),
        )

    # Task definitions can be found under fairseq/tasks/
    from fairseq.tasks import TASK_REGISTRY
    parser.add_argument('--task',
                        metavar='TASK',
                        default=default_task,
                        choices=TASK_REGISTRY.keys(),
                        help='task')
    # fmt: on
    return parser