def main(args: argparse.Namespace):
    """Build a fairseq Dictionary from the given text files and save it to disk."""
    task = FairseqTask(None)
    dictionary = task.build_dictionary(filenames=args.filenames,
                                       workers=args.workers,
                                       threshold=args.threshold,
                                       nwords=args.nwords,
                                       padding_factor=args.padding_factor)
    dictionary.save(args.dict_out)
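For reference, this snippet can be driven by a small argparse front end like the hypothetical sketch below. The argument names mirror the attributes that main() reads; the defaults are assumptions that follow FairseqTask.build_dictionary's usual defaults, not values taken from the original project, and main()'s own imports (argparse, FairseqTask) are assumed to be in scope.

import argparse


def cli_main():
    # Hypothetical driver for main() above; argument names come from the
    # attributes main() reads, default values are assumptions.
    parser = argparse.ArgumentParser(description="Build and save a fairseq dictionary")
    parser.add_argument("filenames", nargs="+", help="input text files")
    parser.add_argument("--dict-out", required=True, help="path to save the dictionary")
    parser.add_argument("--workers", type=int, default=1)
    parser.add_argument("--threshold", type=int, default=-1)
    parser.add_argument("--nwords", type=int, default=-1)
    parser.add_argument("--padding-factor", type=int, default=8)
    main(parser.parse_args())


if __name__ == "__main__":
    cli_main()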
Example #2
File: train.py  Project: ajesujoba/fairseq
def validate(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    subsets: List[str],
) -> List[Optional[float]]:
    """Evaluate the model on the validation set(s) and return the losses."""

    if cfg.dataset.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(cfg.dataset.fixed_validation_seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)
    valid_losses = []
    for subset_idx, subset in enumerate(subsets):
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(
            shuffle=False,
            set_dataset_epoch=False  # use a fixed valid set
        )
        if cfg.common.tpu:
            itr = utils.tpu_data_loader(itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(cfg.common.tensorboard_logdir
                                if distributed_utils.is_master(
                                    cfg.distributed_training) else None),
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else
                                "simple"),
            wandb_project=(cfg.common.wandb_project
                           if distributed_utils.is_master(
                               cfg.distributed_training) else None),
            wandb_run_name=os.environ.get(
                "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for i, sample in enumerate(progress):
                if (cfg.dataset.max_valid_steps is not None
                        and i > cfg.dataset.max_valid_steps):
                    break
                trainer.valid_step(sample)

        # log validation stats
        # only tracking the best metric on the 1st validation subset
        tracking_best = subset_idx == 0
        stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values(),
                                tracking_best)

        if hasattr(task, "post_validate"):
            task.post_validate(trainer.get_model(), stats, agg)

        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
    return valid_losses
Example #3
def reduce_metrics(self, logging_outputs, criterion):
    return FairseqTask.reduce_metrics(self, logging_outputs, criterion)
Example #4
def valid_step(self, sample, model, criterion):
    return FairseqTask.valid_step(self, sample, model, criterion)
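Examples #3 and #4 are overrides that simply defer to the base class. As a hypothetical sketch of where such overrides live, a minimal custom task could look like the following; the class name and registered task name are assumptions, not taken from any of the projects shown here.

from fairseq.tasks import register_task
from fairseq.tasks.fairseq_task import FairseqTask


@register_task("my_custom_task")  # hypothetical task name
class MyCustomTask(FairseqTask):
    def valid_step(self, sample, model, criterion):
        # defer to the default validation step (Example #4)
        return FairseqTask.valid_step(self, sample, model, criterion)

    def reduce_metrics(self, logging_outputs, criterion):
        # defer to the default metric aggregation (Example #3)
        return FairseqTask.reduce_metrics(self, logging_outputs, criterion)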
Example #5
def train(
    cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr
) -> Tuple[List[Optional[float]], bool]:
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
    )
    update_freq = (
        cfg.optimization.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(cfg.optimization.update_freq)
        else cfg.optimization.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if cfg.common.tpu:
        itr = utils.tpu_data_loader(itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            cfg.common.tensorboard_logdir
            if distributed_utils.is_master(cfg.distributed_training)
            else None
        ),
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
        wandb_project=(
            cfg.common.wandb_project
            if distributed_utils.is_master(cfg.distributed_training)
            else None
        ),
        wandb_run_name=os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
        ),
        azureml_logging=(
            cfg.common.azureml_logging
            if distributed_utils.is_master(cfg.distributed_training)
            else False
        ),
    )
    progress.update_config(_flatten_config(cfg))

    trainer.begin_epoch(epoch_itr.epoch)

    if hasattr(trainer.criterion, "set_epoch"):
        trainer.criterion.set_epoch(epoch_itr.epoch)

    valid_subsets = cfg.dataset.valid_subset.split(",")
    should_stop = False
    num_updates = trainer.get_num_updates()
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
            "train_step-%d" % i
        ):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % cfg.common.log_interval == 0:
                stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        # update the state prior stored in the model for cross-entropy training of hybrid systems
        if hasattr(task, "update_state_prior"):
            task.update_state_prior(trainer.get_model())

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
Example #6
def sari_validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask,
                  epoch_itr, subsets: List[str]) -> List[Optional[float]]:
    from pathlib import Path
    from access.resources.paths import get_data_filepath
    from access.utils.helpers import read_lines
    from access.preprocessors import load_preprocessors, ComposedPreprocessor
    from easse.report import get_all_scores
    from fairseq.data import encoders
    from fairseq_cli.interactive import buffered_read, make_batches
    from fairseq_cli.generate import get_symbols_to_strip_from_output
    from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
    import tempfile

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(cfg.task)

    # TODO: Choose parameters for the preprocessors?
    # read the preprocessors from a pickle file
    # preprocessors = load_preprocessors(Path(cfg.task.data).parent)
    # composed_preprocessor = ComposedPreprocessor(preprocessors)
    # get the path of turkcorpus.valid.complex
    complex_filepath = get_data_filepath('turkcorpus', 'valid', 'complex')
    # make temp dir
    # encoded_complex_filepath = tempfile.mkstemp()[1]
    # encoded_pred_filepath = tempfile.mkstemp()[1]
    pred_filepath = tempfile.mkstemp()[1]
    # use preprocessors to encode complex file
    # composed_preprocessor.encode_file(complex_filepath, encoded_complex_filepath)
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        trainer.get_model().max_positions(),
    )
    parser = options.get_generation_parser(interactive=True)
    # TODO: Take args from fairseq_generate
    gen_args = options.parse_args_and_arch(
        parser, input_args=['/dummy_data', '--beam', '2'])
    # Initialize generator
    generator = task.build_generator([trainer.model], gen_args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(cfg.tokenizer)
    bpe = encoders.build_bpe(cfg.bpe)

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    def encode_fn(x):
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    with open(pred_filepath, 'w') as f:
        start_id = 0
        for inputs in buffered_read(complex_filepath, buffer_size=9999):
            results = []
            for batch in make_batches(inputs, cfg, task, max_positions,
                                      encode_fn):
                bsz = batch.src_tokens.size(0)
                src_tokens = batch.src_tokens
                src_lengths = batch.src_lengths
                constraints = batch.constraints
                if use_cuda:
                    src_tokens = src_tokens.cuda()
                    src_lengths = src_lengths.cuda()
                    if constraints is not None:
                        constraints = constraints.cuda()
                sample = {
                    "net_input": {
                        "src_tokens": src_tokens,
                        "src_lengths": src_lengths,
                    },
                }
                translations = task.inference_step(generator, [trainer.model],
                                                   sample,
                                                   constraints=constraints)
                list_constraints = [[] for _ in range(bsz)]
                if cfg.generation.constraints:
                    list_constraints = [
                        unpack_constraints(c) for c in constraints
                    ]
                for i, (id, hypos) in enumerate(
                        zip(batch.ids.tolist(), translations)):
                    src_tokens_i = utils.strip_pad(src_tokens[i],
                                                   tgt_dict.pad())
                    constraints = list_constraints[i]
                    results.append((
                        start_id + id,
                        src_tokens_i,
                        hypos,
                        {
                            "constraints": constraints,
                        },
                    ))

            # sort output to match input order
            for id_, src_tokens, hypos, info in sorted(results,
                                                       key=lambda x: x[0]):
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens,
                                              cfg.common_eval.post_process)
                    for constraint in info["constraints"]:
                        pass

                # Process top predictions
                for hypo in hypos[:min(len(hypos), cfg.generation.nbest)]:
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo["tokens"].int().cpu(),
                        src_str=src_str,
                        alignment=hypo["alignment"],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=cfg.common_eval.post_process,
                        extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                    )
                    detok_hypo_str = decode_fn(hypo_str)
                    # detokenized hypothesis
                    f.write(f'{detok_hypo_str}\n')
                    if cfg.generation.print_alignment:
                        alignment_str = " ".join([
                            "{}-{}".format(src, tgt) for src, tgt in alignment
                        ])

            # update running id_ counter
            start_id += len(inputs)

        # composed_preprocessor.decode_file(encoded_pred_filepath, pred_filepath)
        f.flush()  # make sure all predictions are written out before scoring
        ref_filepaths = [
            get_data_filepath('turkcorpus', 'valid', 'simple.turk', i)
            for i in range(8)
        ]
        scores = get_all_scores(
            read_lines(complex_filepath), read_lines(pred_filepath),
            [read_lines(ref_filepath) for ref_filepath in ref_filepaths])
        print(f'num_updates={trainer.get_num_updates()}')
        print(f'ts_scores={scores}')
        sari = scores['SARI']
        if not hasattr(trainer, 'best_sari'):
            trainer.best_sari = 0
        if not hasattr(trainer, 'n_validations_since_best'):
            trainer.n_validations_since_best = 0
        if sari > trainer.best_sari:
            trainer.best_sari = sari
            trainer.n_validations_since_best = 0
        else:
            trainer.n_validations_since_best += 1
            print(
                f'SARI did not improve for {trainer.n_validations_since_best} validations'
            )
            # Does not work because the scheduler will reset it to the previous value every time
            # trainer.optimizer.set_lr(0.75 * trainer.optimizer.get_lr())
            if trainer.n_validations_since_best >= cfg.validations_before_sari_early_stopping:
                print(
                    f'Early stopping because SARI did not improve for {trainer.n_validations_since_best} validations'
                )
                trainer.early_stopping = True

            def is_abort(epoch_itr, best_sari):
                if (epoch_itr.epoch >= 2 and best_sari < 19):
                    return True
                if (epoch_itr.epoch >= 5 and best_sari < 22):
                    return True
                if (epoch_itr.epoch >= 10 and best_sari < 25):
                    return True
                return False

            # if is_abort(epoch_itr, best_sari):
            #     print(f'Early stopping because best SARI is too low ({best_sari:.2f}) after {epoch_itr.epoch} epochs.')
            #     # Remove the checkpoint directory as we got nothing interesting
            #     shutil.rmtree(args.save_dir)
            #     # TODO: Abort
    # SARI is a higher-is-better metric; negate it so callers that treat
    # validation "losses" as lower-is-better still select the best checkpoint.
    return [-sari]