def main(args: argparse.Namespace):
    """Build a fairseq Dictionary from raw text files and save it to disk."""
    task = FairseqTask(None)
    dictionary = task.build_dictionary(
        filenames=args.filenames,
        workers=args.workers,
        threshold=args.threshold,
        nwords=args.nwords,
        padding_factor=args.padding_factor,
    )
    dictionary.save(args.dict_out)
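# Illustrative only: a minimal CLI wiring for main() above. The argument names
# mirror the attributes main() reads; the parser itself (and the defaults shown)
# is an assumption, not part of the original script.
def cli_main():
    import argparse

    parser = argparse.ArgumentParser(
        description="Build a fairseq dictionary from raw text files."
    )
    parser.add_argument("filenames", nargs="+", help="input text files to count tokens from")
    parser.add_argument("--dict-out", required=True, help="where to write the dictionary")
    parser.add_argument("--workers", type=int, default=1, help="number of counting workers")
    parser.add_argument("--threshold", type=int, default=-1,
                        help="drop tokens seen fewer than this many times (-1 = keep all)")
    parser.add_argument("--nwords", type=int, default=-1, help="cap the vocabulary size (-1 = no cap)")
    parser.add_argument("--padding-factor", type=int, default=8,
                        help="pad the vocabulary size to a multiple of this value")
    main(parser.parse_args())


if __name__ == "__main__":
    cli_main()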
def validate(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    subsets: List[str],
) -> List[Optional[float]]:
    """Evaluate the model on the validation set(s) and return the losses."""

    if cfg.dataset.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(cfg.dataset.fixed_validation_seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)
    valid_losses = []
    for subset_idx, subset in enumerate(subsets):
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(
            shuffle=False, set_dataset_epoch=False  # use a fixed valid set
        )
        if cfg.common.tpu:
            itr = utils.tpu_data_loader(itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                cfg.common.tensorboard_logdir
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
            wandb_project=(
                cfg.common.wandb_project
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            wandb_run_name=os.environ.get(
                "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
            ),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for i, sample in enumerate(progress):
                if (
                    cfg.dataset.max_valid_steps is not None
                    and i > cfg.dataset.max_valid_steps
                ):
                    break
                trainer.valid_step(sample)

        # log validation stats
        # only tracking the best metric on the 1st validation subset
        tracking_best = subset_idx == 0
        stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values(), tracking_best)

        if hasattr(task, "post_validate"):
            task.post_validate(trainer.get_model(), stats, agg)

        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
    return valid_losses
def reduce_metrics(self, logging_outputs, criterion):
    return FairseqTask.reduce_metrics(self, logging_outputs, criterion)


def valid_step(self, sample, model, criterion):
    return FairseqTask.valid_step(self, sample, model, criterion)
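# Illustrative only: the two overrides above presumably live inside a custom task
# class that otherwise defers to FairseqTask. A minimal sketch of that context;
# the class name and registered task name below are hypothetical, and super()
# delegation is the idiomatic equivalent of the explicit FairseqTask.* calls.
from fairseq.tasks import FairseqTask, register_task


@register_task("example_delegating_task")  # hypothetical task name
class ExampleDelegatingTask(FairseqTask):
    def reduce_metrics(self, logging_outputs, criterion):
        # Defer to the base implementation unchanged.
        return super().reduce_metrics(logging_outputs, criterion)

    def valid_step(self, sample, model, criterion):
        return super().valid_step(sample, model, criterion)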
def train(
    cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr
) -> Tuple[List[Optional[float]], bool]:
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
    )
    update_freq = (
        cfg.optimization.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(cfg.optimization.update_freq)
        else cfg.optimization.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if cfg.common.tpu:
        itr = utils.tpu_data_loader(itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            cfg.common.tensorboard_logdir
            if distributed_utils.is_master(cfg.distributed_training)
            else None
        ),
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
        wandb_project=(
            cfg.common.wandb_project
            if distributed_utils.is_master(cfg.distributed_training)
            else None
        ),
        wandb_run_name=os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
        ),
        azureml_logging=(
            cfg.common.azureml_logging
            if distributed_utils.is_master(cfg.distributed_training)
            else False
        ),
    )
    progress.update_config(_flatten_config(cfg))

    trainer.begin_epoch(epoch_itr.epoch)
    if hasattr(trainer.criterion, "set_epoch"):
        trainer.criterion.set_epoch(epoch_itr.epoch)

    valid_subsets = cfg.dataset.valid_subset.split(",")
    should_stop = False
    num_updates = trainer.get_num_updates()
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
            "train_step-%d" % i
        ):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % cfg.common.log_interval == 0:
                stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        # update the state prior stored in the model for
        # cross-entropy training of hybrid systems
        if hasattr(task, "update_state_prior"):
            task.update_state_prior(trainer.get_model())

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
def sari_validate(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    subsets: List[str],
) -> List[Optional[float]]:
    """Generate simplifications for the TurkCorpus validation set, score them
    with EASSE, and return the negated SARI score."""
    from pathlib import Path
    from access.resources.paths import get_data_filepath
    from access.utils.helpers import read_lines
    from access.preprocessors import load_preprocessors, ComposedPreprocessor
    from easse.report import get_all_scores
    from fairseq.data import encoders
    from fairseq_cli.interactive import buffered_read, make_batches
    from fairseq_cli.generate import get_symbols_to_strip_from_output
    from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
    import tempfile

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(cfg.task)

    # TODO: Choose parameters for the preprocessors?
    # Load the preprocessors from pickle files
    # preprocessors = load_preprocessors(Path(cfg.task.data).parent)
    # composed_preprocessor = ComposedPreprocessor(preprocessors)

    # Get the path to turkcorpus.valid.complex
    complex_filepath = get_data_filepath('turkcorpus', 'valid', 'complex')
    # make temp files
    # encoded_complex_filepath = tempfile.mkstemp()[1]
    # encoded_pred_filepath = tempfile.mkstemp()[1]
    pred_filepath = tempfile.mkstemp()[1]

    # use preprocessors to encode complex file
    # composed_preprocessor.encode_file(complex_filepath, encoded_complex_filepath)

    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        trainer.get_model().max_positions(),
    )
    parser = options.get_generation_parser(interactive=True)
    # TODO: Take args from fairseq_generate
    gen_args = options.parse_args_and_arch(
        parser, input_args=['/dummy_data', '--beam', '2']
    )
    # Initialize generator
    generator = task.build_generator([trainer.model], gen_args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(cfg.tokenizer)
    bpe = encoders.build_bpe(cfg.bpe)

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    def encode_fn(x):
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    with open(pred_filepath, 'w') as f:
        start_id = 0
        for inputs in buffered_read(complex_filepath, buffer_size=9999):
            results = []
            for batch in make_batches(inputs, cfg, task, max_positions, encode_fn):
                bsz = batch.src_tokens.size(0)
                src_tokens = batch.src_tokens
                src_lengths = batch.src_lengths
                constraints = batch.constraints
                if use_cuda:
                    src_tokens = src_tokens.cuda()
                    src_lengths = src_lengths.cuda()
                    if constraints is not None:
                        constraints = constraints.cuda()

                sample = {
                    "net_input": {
                        "src_tokens": src_tokens,
                        "src_lengths": src_lengths,
                    },
                }
                translations = task.inference_step(
                    generator, [trainer.model], sample, constraints=constraints
                )
                list_constraints = [[] for _ in range(bsz)]
                if cfg.generation.constraints:
                    list_constraints = [unpack_constraints(c) for c in constraints]
                for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                    src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                    constraints = list_constraints[i]
                    results.append((
                        start_id + id,
                        src_tokens_i,
                        hypos,
                        {"constraints": constraints},
                    ))

            # sort output to match input order
            for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
                    for constraint in info["constraints"]:
                        pass

                # Process top predictions
                for hypo in hypos[: min(len(hypos), cfg.generation.nbest)]:
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo["tokens"].int().cpu(),
                        src_str=src_str,
                        alignment=hypo["alignment"],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=cfg.common_eval.post_process,
                        extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                    )
                    detok_hypo_str = decode_fn(hypo_str)  # detokenized hypothesis
                    f.write(f'{detok_hypo_str}\n')
                    if cfg.generation.print_alignment:
                        alignment_str = " ".join(
                            ["{}-{}".format(src, tgt) for src, tgt in alignment]
                        )

            # update running id_ counter
            start_id += len(inputs)

    # composed_preprocessor.decode_file(encoded_pred_filepath, pred_filepath)
    ref_filepaths = [
        get_data_filepath('turkcorpus', 'valid', 'simple.turk', i) for i in range(8)
    ]
    scores = get_all_scores(
        read_lines(complex_filepath),
        read_lines(pred_filepath),
        [read_lines(ref_filepath) for ref_filepath in ref_filepaths],
    )
    print(f'num_updates={trainer.get_num_updates()}')
    print(f'ts_scores={scores}')
    sari = scores['SARI']
    if not hasattr(trainer, 'best_sari'):
        trainer.best_sari = 0
    if not hasattr(trainer, 'n_validations_since_best'):
        trainer.n_validations_since_best = 0
    if sari > trainer.best_sari:
        trainer.best_sari = sari
        trainer.n_validations_since_best = 0
    else:
        trainer.n_validations_since_best += 1
        print(
            f'SARI did not improve for {trainer.n_validations_since_best} validations'
        )
        # Does not work because the scheduler resets it to the previous value every time
        # trainer.optimizer.set_lr(0.75 * trainer.optimizer.get_lr())
        if trainer.n_validations_since_best >= cfg.validations_before_sari_early_stopping:
            print(
                f'Early stopping because SARI did not improve for '
                f'{trainer.n_validations_since_best} validations'
            )
            trainer.early_stopping = True

    def is_abort(epoch_itr, best_sari):
        if epoch_itr.epoch >= 2 and best_sari < 19:
            return True
        if epoch_itr.epoch >= 5 and best_sari < 22:
            return True
        if epoch_itr.epoch >= 10 and best_sari < 25:
            return True
        return False

    # if is_abort(epoch_itr, best_sari):
    #     print(f'Early stopping because best SARI is too low ({best_sari:.2f}) after {epoch_itr.epoch} epochs.')
    #     # Remove the checkpoint directory as we got nothing interesting
    #     shutil.rmtree(args.save_dir)
    #     # TODO: Abort
    return [-sari]
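# Illustrative only: the SARI value above comes from easse.report.get_all_scores,
# which reports several metrics at once (e.g., BLEU, SARI, FKGL). If only SARI is
# needed, a standalone sketch is shown below; the paths are placeholders and the
# files are assumed to hold one sentence per line, with EASSE installed.
def compute_sari_only(orig_path, pred_path, ref_paths):
    from easse.sari import corpus_sari

    def read(path):
        with open(path, encoding="utf-8") as fin:
            return [line.rstrip("\n") for line in fin]

    return corpus_sari(
        orig_sents=read(orig_path),       # complex source sentences
        sys_sents=read(pred_path),        # system simplifications
        refs_sents=[read(p) for p in ref_paths],  # one sentence list per reference set
    )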