def validate(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    subsets: List[str],
) -> List[Optional[float]]:
    """Evaluate the model on the validation set(s) and return the losses.

    Args:
        cfg: Run configuration. The ``dataset``, ``common``, ``checkpoint``
            and ``distributed_training`` groups are read here.
        trainer: Supplies the validation iterator for each subset and runs
            ``valid_step`` on every batch.
        task: Not referenced in this function; kept so the signature stays
            compatible with existing callers.
        epoch_itr: Training epoch iterator; only its ``epoch`` attribute is
            read (for ``begin_valid_epoch`` and the progress bar).
        subsets: Names of the validation subsets to evaluate, in order.

    Returns:
        One entry per subset, in the order of ``subsets``: the value stored
        under ``cfg.checkpoint.best_checkpoint_metric`` in that subset's
        validation stats. The lookup is a plain ``stats[...]`` index, so the
        metric is expected to be present in the stats.
    """
    if cfg.dataset.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(cfg.dataset.fixed_validation_seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)
    max_valid_steps = cfg.dataset.max_valid_steps
    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(
            shuffle=False, set_dataset_epoch=False  # use a fixed valid set
        )
        if cfg.common.tpu:
            itr = utils.tpu_data_loader(itr)

        # TensorBoard / W&B targets are only handed to the master rank;
        # every other rank gets None for both.
        is_master = distributed_utils.is_master(cfg.distributed_training)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                cfg.common.tensorboard_logdir if is_master else None
            ),
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
            wandb_project=(cfg.common.wandb_project if is_master else None),
            # WANDB_NAME from the environment wins; otherwise fall back to
            # the last path component of the checkpoint directory.
            wandb_run_name=os.environ.get(
                "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
            ),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for i, sample in enumerate(progress):
                # ``i`` is zero-based: when it reaches max_valid_steps, exactly
                # that many batches have already been evaluated. Using ``>=``
                # (not ``>``) avoids running one batch more than requested.
                if max_valid_steps is not None and i >= max_valid_steps:
                    break
                trainer.valid_step(sample)

        # log validation stats
        stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
    return valid_losses
def validate(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    cur_step,
    epoch_itr,
    subsets: List[str],
) -> List[Optional[float]]:
    """Run validation over each requested subset and collect the tracked metric.

    Args:
        cfg: Run configuration. The ``dataset``, ``common``, ``checkpoint``
            and ``distributed_training`` groups are read here.
        trainer: Supplies the validation iterator for each subset and runs
            ``valid_step`` on every batch.
        task: Not referenced in this function.
        cur_step: Forwarded unchanged to ``trainer.valid_step`` as the
            ``cur_step`` keyword argument for every batch.
        epoch_itr: Training epoch iterator; only its ``epoch`` attribute is
            read.
        subsets: Names of the validation subsets to evaluate, in order.

    Returns:
        One entry per subset, in the order of ``subsets``: the value stored
        under ``cfg.checkpoint.best_checkpoint_metric`` in that subset's
        validation stats.
    """
    seed = cfg.dataset.fixed_validation_seed
    if seed is not None:
        # Re-seed before every validation pass so runs are comparable.
        utils.set_torch_seed(seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)

    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Build the (unshuffled) batch iterator for this subset.
        batch_iter = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
        if cfg.common.tpu:
            batch_iter = utils.tpu_data_loader(batch_iter)

        # Only the master rank gets TensorBoard / W&B destinations.
        on_master = distributed_utils.is_master(cfg.distributed_training)
        tensorboard_dir = cfg.common.tensorboard_logdir if on_master else None
        wandb_project = cfg.common.wandb_project if on_master else None
        fallback_format = "simple" if cfg.common.no_progress_bar else "tqdm"
        run_name = os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
        )

        bar = progress_bar.progress_bar(
            batch_iter,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=tensorboard_dir,
            default_log_format=fallback_format,
            wandb_project=wandb_project,
            wandb_run_name=run_name,
        )

        # Use a fresh root aggregator so validation metrics do not leak into
        # other aggregators (e.g. the training meters).
        # Note from the original author: a value has to be recorded through
        # metrics.log_scalar("key", val) for it to appear in this aggregator.
        with metrics.aggregate(new_root=True) as aggregator:
            for batch in bar:
                trainer.valid_step(batch, cur_step=cur_step)

        # Log validation stats. Per the original author's note, the smoothed
        # aggregator values form the base mapping and get_valid_stats adds
        # further state values on top of it.
        subset_stats = get_valid_stats(cfg, trainer, aggregator.get_smoothed_values())
        bar.print(subset_stats, tag=subset, step=trainer.get_num_updates())

        tracked_metric = cfg.checkpoint.best_checkpoint_metric
        valid_losses.append(subset_stats[tracked_metric])

    return valid_losses