def validate(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    subsets: List[str],
) -> List[Optional[float]]:
    """Evaluate the model on the validation set(s) and return the losses.

    Args:
        cfg: configuration; this function reads ``cfg.dataset``,
            ``cfg.common``, ``cfg.checkpoint`` and ``cfg.distributed_training``.
        trainer: provides the validation iterators and ``valid_step``.
        task: not used by this function; kept for interface compatibility.
        epoch_itr: only ``epoch_itr.epoch`` is read.
        subsets: names of the validation subsets to evaluate, in order.

    Returns:
        One entry per subset: the value stored under
        ``cfg.checkpoint.best_checkpoint_metric`` in that subset's stats.
    """
    if cfg.dataset.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(cfg.dataset.fixed_validation_seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)
    max_valid_steps = cfg.dataset.max_valid_steps
    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(
            shuffle=False, set_dataset_epoch=False  # use a fixed valid set
        )
        if cfg.common.tpu:
            itr = utils.tpu_data_loader(itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                cfg.common.tensorboard_logdir
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
            wandb_project=(
                cfg.common.wandb_project
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            wandb_run_name=os.environ.get(
                "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
            ),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for i, sample in enumerate(progress):
                # ``i`` is zero-based, so ``>=`` stops after exactly
                # ``max_valid_steps`` batches (``>`` would run one extra).
                if max_valid_steps is not None and i >= max_valid_steps:
                    break
                trainer.valid_step(sample)

        # log validation stats
        stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
    return valid_losses
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch.

    Args:
        args: parsed options (update frequency, logging, validation and
            checkpoint intervals, ``max_update``).
        trainer: runs ``train_step`` and reports the number of updates.
        task: receives ``begin_epoch`` and is forwarded to ``validate``.
        epoch_itr: epoch iterator supplying this epoch's batches.

    Returns:
        None. Stats are logged through the progress bar; checkpoints are
        saved every ``args.save_interval_updates`` updates.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf

    # Bind num_updates up front: if the iterator is empty or every step is
    # skipped (train_step returns None), the end-of-epoch log below would
    # otherwise raise UnboundLocalError.
    num_updates = trainer.get_num_updates()
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        if (
            not args.disable_validation
            and args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        ):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch and return validation losses.

    Args:
        args: parsed options (update frequency, logging, validation subsets).
        trainer: runs ``train_step`` and reports the number of updates.
        task: forwarded to ``validate_and_save``.
        epoch_itr: epoch iterator supplying this epoch's batches.

    Returns:
        A ``(valid_losses, should_stop)`` tuple taken from the last call to
        ``validate_and_save``. If that call never happened (no batches, or
        every step was skipped) the result is ``([None], False)``.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False
    # Bind these before the loop: when the iterator is empty or every step
    # is skipped, the code after the loop would otherwise hit unbound names.
    valid_losses = [None]
    num_updates = trainer.get_num_updates()
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
            "train_step-%d" % i
        ):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
            progress.log(stats, tag="train_inner", step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )

        if should_stop:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
def validate(args, trainer, task, epoch_itr, subsets):
    """Run validation over each subset and collect the checkpoint metric.

    Per-batch stats are logged under the ``valid`` tag while iterating, and
    the trainer's ``_dummy_batch`` is set to ``"DUMMY"`` both before and
    after validation.

    Returns:
        A list with one ``args.best_checkpoint_metric`` value per subset.
    """
    fixed_seed = args.fixed_validation_seed
    if fixed_seed is not None:
        # use the same seed for every validation run
        utils.set_torch_seed(fixed_seed)

    # reset dummy batch only for validation
    trainer._dummy_batch = "DUMMY"

    is_master = distributed_utils.is_master
    valid_losses = []
    for subset in subsets:
        # Build the batch iterator for this subset
        max_positions = utils.resolve_max_positions(
            task.max_positions(),
            trainer.get_model().max_positions(),
        )
        batch_iterator = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=max_positions,
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        )
        itr = batch_iterator.next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(args.tensorboard_logdir if is_master(args) else None),
            default_log_format=("simple" if args.no_progress_bar else "tqdm"),
        )

        # a fresh root aggregator keeps validation metrics out of the
        # other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for step, sample in enumerate(progress):
                trainer.valid_step(sample)
                running = get_training_stats(agg.get_smoothed_values())
                # log through the wrapped bar when the progress bar has one
                plog = progress.log
                if hasattr(progress, "wrapped_bar"):
                    plog = progress.wrapped_bar.log
                plog(running, tag='valid', step=step)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[args.best_checkpoint_metric])

    # reset dummy batch again for continuing training
    trainer._dummy_batch = "DUMMY"
    return valid_losses
def validate(args, trainer, task, epoch_itr, subsets):
    """Run validation over each subset and gather per-head confidences.

    Returns:
        A ``(valid_losses, val_conf)`` tuple. ``valid_losses`` has one
        ``args.best_checkpoint_metric`` value per subset. ``val_conf`` is the
        per-layer confidence structure; it is only filled in and
        post-processed when ``args.head_confidence_method`` is set.
    """
    if args.fixed_validation_seed is not None:
        # use the same seed for every validation run
        utils.set_torch_seed(args.fixed_validation_seed)

    model = trainer.model
    # one accumulator dict per encoder / decoder layer
    encoder_conf = [{"self_attn": []} for _ in range(args.encoder_layers)]
    decoder_conf = [
        {"self_attn": [], "enc_attn": []} for _ in range(args.decoder_layers)
    ]
    val_conf = {"encoder": encoder_conf, "decoder": decoder_conf}

    valid_losses = []
    for subset in subsets:
        # Build the data iterator for this subset
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
        if getattr(args, "tpu", False):
            itr = tpu_data_loader(args, itr)

        logdir = args.tensorboard_logdir if distributed_utils.is_master(args) else None
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=logdir,
            default_log_format=("simple" if args.no_progress_bar else "tqdm"),
        )

        # a fresh root aggregator keeps validation metrics out of the
        # other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)
                # accumulate the confidence of each head for this batch
                if args.head_confidence_method is not None:
                    val_conf = get_batch_confs(model, val_conf, args)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[args.best_checkpoint_metric])

    if args.head_confidence_method is not None:
        val_conf = calc_conf_per_epoch(convert_confs(val_conf, args), args)

    return valid_losses, val_conf
def train(args, trainer, task, epoch_itr, max_update=math.inf):
    """Train the model for one epoch and return validation losses.

    Args:
        args: parsed options (update frequency, logging, validation subsets).
        trainer: runs ``train_step`` and reports the number of updates.
        task: forwarded to ``validate_and_save``.
        epoch_itr: epoch iterator supplying this epoch's batches.
        max_update: stop once the trainer has performed this many updates.

    Returns:
        The losses from the last ``validate_and_save`` call, or ``[None]``
        if that call never happened (no batches, or every step was skipped).
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, 'tpu', False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )
    progress.log_args(args, tag='train')

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(',')
    # Bind these before the loop: when the iterator is empty or every step
    # is skipped, the code after the loop would otherwise hit unbound names.
    valid_losses = [None]
    num_updates = trainer.get_num_updates()
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        end_of_epoch = not itr.has_next()
        valid_losses = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return valid_losses
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses.

    Predictions and labels returned by ``trainer.valid_step`` are collected
    per subset and handed to ``get_valid_stats``. If any batch returns
    ``preds=None``, or the subset yields no predictions at all, ``None`` is
    passed for both instead.

    Args:
        args: parsed options (batching, logging, ``best_checkpoint_metric``).
        trainer: runs ``valid_step`` and exposes the model.
        task: provides datasets and batch iterators.
        epoch_itr: only ``epoch_itr.epoch`` is read.
        subsets: names of the validation subsets to evaluate.

    Returns:
        A list with one ``args.best_checkpoint_metric`` value per subset.
    """
    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                args.tensorboard_logdir if distributed_utils.is_master(args) else None
            ),
            default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        all_preds, all_labels = [], []
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                logging_outputs, preds, labels = trainer.valid_step(sample)
                if preds is None:
                    # one batch without predictions disables collection
                    # for the rest of this subset
                    all_preds, all_labels = None, None
                elif all_preds is not None:
                    # only extend while collection is still active; the
                    # lists may already have been replaced by None
                    all_preds.extend(preds)
                    all_labels.extend(labels)

        if all_preds:
            all_preds = torch.cat(all_preds).cpu().numpy()
            all_labels = torch.cat(all_labels).cpu().numpy()
        else:
            # torch.cat rejects an empty list, so report "no predictions"
            all_preds, all_labels = None, None

        # log validation stats
        stats = get_valid_stats(
            args, trainer, agg.get_smoothed_values(), all_preds, all_labels
        )
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[args.best_checkpoint_metric])
    return valid_losses
def validate(args, trainer, task, epoch_itr, subsets):
    """Run (optionally truncated) validation and collect the checkpoint metric.

    Validation requires a batch size of 1. When ``args.validation_max_size``
    is positive the data is shuffled, the seed must be 1234, and batches past
    the per-worker limit are consumed without being evaluated.

    Returns:
        A list with one ``args.best_checkpoint_metric`` value per subset.
    """
    assert args.max_sentences_valid == 1, 'Val only supports batch size 1!'

    if args.fixed_validation_seed is not None:
        # use the same seed for every validation run
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []

    shuffle = False
    if args.validation_max_size > 0:
        logging.info(f'Validation set truncated to {args.validation_max_size}.')
        shuffle = True
        assert args.seed == 1234, args.seed

    for subset in subsets:
        # Build the batch iterator for this subset
        max_positions = utils.resolve_max_positions(
            task.max_positions(),
            trainer.get_model().max_positions(),
        )
        batch_iterator = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=max_positions,
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            # fixed at 1 here rather than args.required_batch_size_multiple
            required_batch_size_multiple=1,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        )
        itr = batch_iterator.next_epoch_itr(shuffle=shuffle)

        logdir = args.tensorboard_logdir if distributed_utils.is_master(args) else None
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=logdir,
            default_log_format=('simple' if args.no_progress_bar else 'tqdm'),
        )

        # a fresh root aggregator keeps validation metrics out of the
        # other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for i, sample in enumerate(progress, start=1):
                # past the per-worker limit: keep draining, skip evaluation
                if (
                    args.validation_max_size > 0
                    and i > args.validation_max_size / args.distributed_world_size
                ):
                    continue
                trainer.valid_step(
                    sample,
                    validation_topk=args.validation_topk,
                    validation_D=args.validation_D,
                    validation_rounds=args.validation_rounds,
                )

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[args.best_checkpoint_metric])

    return valid_losses
def validate(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    cur_step,
    epoch_itr,
    subsets: List[str],
) -> List[Optional[float]]:
    """Run validation over each subset and collect the checkpoint metric.

    ``cur_step`` is forwarded unchanged to every ``trainer.valid_step`` call.

    Returns:
        A list with one ``cfg.checkpoint.best_checkpoint_metric`` value per
        subset.
    """
    fixed_seed = cfg.dataset.fixed_validation_seed
    if fixed_seed is not None:
        # use the same seed for every validation run
        utils.set_torch_seed(fixed_seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)

    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Build the data iterator for this subset
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
        if cfg.common.tpu:
            itr = utils.tpu_data_loader(itr)

        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                cfg.common.tensorboard_logdir
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            default_log_format=("simple" if cfg.common.no_progress_bar else "tqdm"),
            wandb_project=(
                cfg.common.wandb_project
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            wandb_run_name=os.environ.get(
                "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
            ),
        )

        # a fresh root aggregator keeps validation metrics out of the
        # other aggregators (e.g., train meters).
        # Values have to be registered with metrics.log_scalar("key", val)
        # for them to show up in what agg reports.
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample, cur_step=cur_step)

        # log validation stats: get_valid_stats starts from the values
        # returned by agg.get_smoothed_values() and adds further state values
        stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])

    return valid_losses
def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False):
    """Run scheduled model maintenance for ``update_num``, then the parent train step.

    Before delegating to ``super().train_step`` this checks, in order, whether
    the current update should trigger voxel pruning, voxel splitting,
    rendering of the current sample, and step-size reduction. Each check is
    also gated on ``update_num`` being greater than the matching entry in
    ``self._num_updates``.

    Returns:
        Whatever ``super().train_step`` returns.

    Raises:
        ResetTrainerException: right after ``model.split_voxels()`` has run.
    """
    # Pruning: triggered either periodically (every ``pruning_every_steps``
    # updates, excluding update 0) or at the explicit steps listed in
    # ``steps_to_prune_voxels``; the model must expose ``prune_voxels``.
    if (((self.pruning_every_steps is not None) and \
        (update_num % self.pruning_every_steps == 0) and \
        (update_num > 0)) or \
        ((self.steps_to_prune_voxels is not None) and \
         update_num in self.steps_to_prune_voxels) \
        ) and \
        (update_num > self._num_updates['pv']) and \
        hasattr(model, 'prune_voxels'):
        model.eval()
        if getattr(self.args, "pruning_rerun_train_set", False):
            # Optional forward pass over the training iterator without
            # gradients; per-batch ``other_logs`` are logged under 'track'.
            with torch.no_grad():
                model.clean_caches(reset=True)
                progress = progress_bar.progress_bar(
                    self._unique_trainitr.next_epoch_itr(shuffle=False),
                    prefix=f"pruning based statiscs over training set",
                    tensorboard_logdir=None,
                    default_log_format=self.args.log_format if self.args.log_format is not None else "tqdm")
                for step, inner_sample in enumerate(progress):
                    outs = model(**self._trainer._prepare_sample(self.filter_dummy(inner_sample)))
                    progress.log(stats=outs['other_logs'], tag='track', step=step)

        model.prune_voxels(self.pruning_th, train_stats=getattr(self.args, "pruning_with_train_stats", False))
        self.update_step(update_num, 'pv')

    # Voxel splitting at the configured steps; the exception is raised
    # unconditionally after the split, so nothing below runs on this call.
    if self.steps_to_half_voxels is not None and \
        (update_num in self.steps_to_half_voxels) and \
        (update_num > self._num_updates['sv']):

        model.split_voxels()
        self.update_step(update_num, 'sv')
        raise ResetTrainerException

    # Periodic rendering of the current sample (requires a renderer).
    if self.rendering_every_steps is not None and \
        (update_num % self.rendering_every_steps == 0) and \
        (update_num > 0) and \
        self.renderer is not None and \
        (update_num > self._num_updates['re']):

        # Clone every non-None entry so inference works on copies of the batch.
        sample_clone = {key: sample[key].clone() if sample[key] is not None else None for key in sample }
        outputs = self.inference_step(self.renderer, [model], [sample_clone, 0])[1]
        if getattr(self.args, "distributed_rank", -1) == 0:  # save only for master
            self.renderer.save_images(outputs, update_num)
        # NOTE(review): this iterates ``steps_to_half_voxels`` without a None
        # check and no ``update_step(update_num, 're')`` is visible in this
        # branch -- confirm both are intended.
        self.steps_to_half_voxels = [a for a in self.steps_to_half_voxels if a != update_num]

    # Step-size reduction at the configured steps.
    if self.steps_to_reduce_step is not None and \
        update_num in self.steps_to_reduce_step and \
        (update_num > self._num_updates['rs']):

        model.reduce_stepsize()
        self.update_step(update_num, 'rs')

    self.update_step(update_num, 'step')
    return super().train_step(sample, model, criterion, optimizer, update_num, ignore_grad)
def validate(args, trainer, task, epoch_itr, subsets):
    """Run validation over each subset and collect the checkpoint metric.

    When the stats contain ``fg_gloss0``, the ``fg_gloss<i>`` entries are
    removed from the stats and handed to the criterion as a baseline tensor
    via ``set_valid_baselines``.

    Returns:
        A list with one ``args.best_checkpoint_metric`` value per subset.
    """
    fixed_seed = args.fixed_validation_seed
    if fixed_seed is not None:
        # use the same seed for every validation run
        utils.set_torch_seed(fixed_seed)

    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Build the data iterator for this subset
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
        if getattr(args, "tpu", False):
            itr = utils.tpu_data_loader(itr)

        logdir = args.tensorboard_logdir if distributed_utils.is_master(args) else None
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=logdir,
            default_log_format=("simple" if args.no_progress_bar else "tqdm"),
        )

        # a fresh root aggregator keeps validation metrics out of the
        # other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())

        if "fg_gloss0" in stats:
            criterion = trainer.get_criterion()
            n_groups = criterion.n_groups
            baselines = torch.zeros(n_groups, device='cuda')
            # move each per-group entry out of the stats into the tensor
            for group in range(n_groups):
                name = "fg_gloss{}".format(group)
                baselines[group] = stats[name]
                stats.pop(name, None)
            criterion.set_valid_baselines(baselines)

        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[args.best_checkpoint_metric])

    return valid_losses
def validate(args, trainer, task, epoch_itr, subsets):
    """Run validation over each subset and collect the checkpoint metric.

    A progress message is logged after every 50 processed batches.

    Returns:
        A list with one ``args.best_checkpoint_metric`` value per subset.
    """
    fixed_seed = args.fixed_validation_seed
    if fixed_seed is not None:
        # use the same seed for every validation run
        utils.set_torch_seed(fixed_seed)

    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Build the data iterator for this subset
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
        if getattr(args, "tpu", False):
            itr = utils.tpu_data_loader(itr)

        logdir = args.tensorboard_logdir if distributed_utils.is_master(args) else None
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=logdir,
            default_log_format=("simple" if args.no_progress_bar else "tqdm"),
        )

        # a fresh root aggregator keeps validation metrics out of the
        # other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for count, sample in enumerate(progress, start=1):
                trainer.valid_step(sample)
                if count % 50 == 0:
                    logger.info("Processed {} batches!".format(count))

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[args.best_checkpoint_metric])

    return valid_losses
def main(cfg: DictConfig, override_args=None):
    """Load a checkpoint and report validation metrics for each subset.

    Args:
        cfg: evaluation configuration; a legacy ``Namespace`` is converted
            with ``convert_namespace_to_omegaconf`` first.
        override_args: optional namespace whose attributes (plus its
            evaluated ``model_overrides`` string) are passed as
            ``arg_overrides`` when loading the checkpoint.

    Raises:
        Exception: if a subset in ``cfg.dataset.valid_subset`` cannot be
            loaded (``KeyError`` from the task).
    """
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    if cfg.distributed_training.distributed_world_size > 1:
        data_parallel_world_size = distributed_utils.get_data_parallel_world_size()
        data_parallel_rank = distributed_utils.get_data_parallel_rank()
    else:
        data_parallel_world_size = 1
        data_parallel_rank = 0

    if override_args is not None:
        overrides = vars(override_args)
        # NOTE(security): model_overrides is evaluated with eval(); this is
        # only acceptable because it comes from the operator's own command
        # line. Never feed it untrusted input.
        overrides.update(eval(getattr(override_args, "model_overrides", "{}")))
    else:
        overrides = None

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        [cfg.common_eval.path],
        arg_overrides=overrides,
        suffix=cfg.checkpoint.checkpoint_suffix,
    )
    model = models[0]

    # Move models to GPU (distinct loop name so ``model`` stays models[0])
    for loaded_model in models:
        if use_fp16:
            loaded_model.half()
        if use_cuda:
            loaded_model.cuda()

    # Print args
    logger.info(saved_cfg)

    # Build criterion
    criterion = task.build_criterion(saved_cfg.criterion)
    criterion.eval()

    for subset in cfg.dataset.valid_subset.split(","):
        try:
            task.load_dataset(subset, combine=False, epoch=1, task_cfg=saved_cfg.task)
            dataset = task.dataset(subset)
        except KeyError as err:
            raise Exception("Cannot find dataset: " + subset) from err

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=cfg.dataset.max_tokens,
            max_sentences=cfg.dataset.batch_size,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
            seed=cfg.common.seed,
            num_shards=data_parallel_world_size,
            shard_id=data_parallel_rank,
            num_workers=cfg.dataset.num_workers,
            data_buffer_size=cfg.dataset.data_buffer_size,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
        )

        log_outputs = []
        # Bind the step index per subset so the final print neither fails
        # on an empty subset nor reuses the previous subset's value.
        i = 0
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            _loss, _sample_size, log_output = task.valid_step(sample, model, criterion)
            progress.log(log_output, step=i)
            log_outputs.append(log_output)

        if data_parallel_world_size > 1:
            # gather the per-worker outputs and flatten them into one list
            log_outputs = distributed_utils.all_gather_list(
                log_outputs,
                max_size=cfg.common.all_gather_list_size,
                group=distributed_utils.get_data_parallel_group(),
            )
            log_outputs = list(chain.from_iterable(log_outputs))

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        progress.print(log_output, tag=subset, step=i)
def main(args, override_args=None):
    """Load a checkpoint and report validation metrics for each subset.

    Args:
        args: parsed evaluation options (checkpoint path, batching, logging,
            distributed settings).
        override_args: optional namespace whose attributes (plus its
            evaluated ``model_overrides`` string) are passed as
            ``arg_overrides`` when loading the checkpoint.

    Raises:
        Exception: if a subset in ``args.valid_subset`` cannot be loaded
            (``KeyError`` from the task).
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    use_fp16 = args.fp16
    use_cuda = torch.cuda.is_available() and not args.cpu

    if use_cuda:
        torch.cuda.set_device(args.device_id)

    if override_args is not None:
        overrides = vars(override_args)
        # NOTE(security): model_overrides is evaluated with eval(); this is
        # only acceptable because it comes from the operator's own command
        # line. Never feed it untrusted input.
        overrides.update(eval(getattr(override_args, 'model_overrides', '{}')))
    else:
        overrides = None

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [args.path],
        arg_overrides=overrides,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )
    model = models[0]

    # Move models to GPU (distinct loop name so ``model`` stays models[0])
    for loaded_model in models:
        if use_fp16:
            loaded_model.half()
        if use_cuda:
            loaded_model.cuda()

    # Print args
    logger.info(model_args)

    # Build criterion
    criterion = task.build_criterion(model_args)
    criterion.eval()

    for subset in args.valid_subset.split(','):
        try:
            task.load_dataset(subset, combine=False, epoch=1)
            dataset = task.dataset(subset)
        except KeyError as err:
            raise Exception('Cannot find dataset: ' + subset) from err

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
            data_buffer_size=args.data_buffer_size,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
        )

        log_outputs = []
        # Bind the step index per subset so the final print neither fails
        # on an empty subset nor reuses the previous subset's value.
        i = 0
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            _loss, _sample_size, log_output = task.valid_step(sample, model, criterion)
            progress.log(log_output, step=i)
            log_outputs.append(log_output)

        if args.distributed_world_size > 1:
            # gather the per-worker outputs and flatten them into one list
            log_outputs = distributed_utils.all_gather_list(
                log_outputs,
                max_size=getattr(args, 'all_gather_list_size', 16384),
            )
            log_outputs = list(chain.from_iterable(log_outputs))

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        progress.print(log_output, tag=subset, step=i)
def main(cfg: DictConfig, override_args=None, **unused_kwargs):
    """Score a dataset with a language model and log loss and perplexity.

    Args:
        cfg: evaluation configuration; a legacy ``Namespace`` is converted
            with ``convert_namespace_to_omegaconf`` first. Note that
            ``cfg.task.tokens_per_sample`` is reduced in place by
            ``cfg.eval_lm.context_window``.
        override_args: optional namespace whose attributes (plus its
            evaluated ``model_overrides`` string) are passed as
            ``arg_overrides`` when loading the checkpoint.
        **unused_kwargs: accepted and ignored.

    Raises:
        NotImplementedError: if ``cfg.common_eval.remove_bpe`` is
            ``"sentencepiece"``.
    """
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    if override_args is not None:
        overrides = vars(override_args)
        # NOTE(security): model_overrides is evaluated with eval(); this is
        # only acceptable because it comes from the operator's own command
        # line. Never feed it untrusted input.
        overrides.update(eval(getattr(override_args, "model_overrides", "{}")))
    else:
        overrides = None

    logger.info(cfg)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))

    # reduce tokens per sample by the required context window size
    cfg.task.tokens_per_sample -= cfg.eval_lm.context_window

    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [cfg.common_eval.path],
        arg_overrides=overrides,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # Load dataset splits
    gen_subset = cfg.dataset.gen_subset
    task.load_dataset(gen_subset)
    dataset = task.dataset(gen_subset)
    if cfg.eval_lm.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=cfg.task.tokens_per_sample,
            context_window=cfg.eval_lm.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info("{} {} {} examples".format(cfg.task.data, gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    assert len(models) > 0

    logger.info(
        "num. model params: {}".format(sum(p.numel() for p in models[0].parameters()))
    )

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=cfg.dataset.max_tokens or 36000,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=True,
        num_shards=max(
            cfg.dataset.num_shards,
            cfg.distributed_training.distributed_world_size,
        ),
        shard_id=max(
            cfg.dataset.shard_id,
            cfg.distributed_training.distributed_rank,
        ),
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, cfg.eval_lm.softmax_batch)

    score_sum = 0.0
    count = 0

    if cfg.common_eval.remove_bpe is not None:
        if cfg.common_eval.remove_bpe == "sentencepiece":
            raise NotImplementedError
        else:
            # indices of dictionary symbols that end with the BPE
            # continuation marker
            bpe_cont = cfg.common_eval.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    wps_meter = TimeMeter()

    for sample in progress:
        if "net_input" not in sample:
            continue

        sample = utils.move_to_cuda(sample) if use_cuda else sample

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample["ntokens"])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample["id"][i]

            tokens = hypo["tokens"]
            tgt_len = tokens.numel()
            pos_scores = hypo["positional_scores"].float()

            if cfg.task.add_bos_token:
                assert hypo["tokens"][0].item() == task.target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            skipped_toks = 0
            if bpe_toks is not None:
                # fold the score of each continuation piece into the next
                # position (own index name: ``i`` is the hypothesis index)
                for j in range(tgt_len - 1):
                    if tokens[j].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[j + 1] += pos_scores[j]
                        pos_scores[j] = 0

            inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(float("-inf"))
            if inf_scores.any():
                # the message needs a %s placeholder; passing the extra
                # argument without one makes logging report a format error
                logger.info(
                    "skipping tokens with inf scores: %s",
                    task.target_dictionary.string(tokens[inf_scores.nonzero()]),
                )
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks

            if cfg.eval_lm.output_word_probs or cfg.eval_lm.output_word_stats:
                w = ""
                word_prob = []
                is_bpe = False
                for k in range(len(tokens)):
                    w_ind = tokens[k].item()
                    w += task.source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        # strip the continuation marker and keep building
                        # the current word
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[k].item()))

                        # score of the next position with a non-zero score
                        next_prob = None
                        ind = k + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                            pos_scores[k].item(), next_prob
                        )
                        is_bpe = False
                        w = ""
                if cfg.eval_lm.output_word_probs:
                    logger.info(
                        str(int(sample_id))
                        + " "
                        + (
                            "\t".join(
                                "{} [{:2f}]".format(x[0], x[1]) for x in word_prob
                            )
                        )
                    )

        wps_meter.update(sample["ntokens"])
        progress.log({"wps": round(wps_meter.avg)})

    # convert to base 2; guard against an empty evaluation set
    avg_nll_loss = -score_sum / count / math.log(2) if count > 0 else 0
    logger.info(
        "Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)".format(
            gen_timer.n,
            gen_timer.sum,
            # avoid dividing by zero when nothing was timed
            1.0 / gen_timer.avg if gen_timer.avg > 0 else 0,
        )
    )
    logger.info(
        "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format(
            avg_nll_loss, 2 ** avg_nll_loss
        )
    )

    if cfg.eval_lm.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)
def _main(args, output_file):
    """Recognize ``args.gen_subset`` with a model ensemble and report WER/CER.

    Loads the ensemble from ``args.path`` (wrapping word/subword language
    models for LM fusion where present), decodes every batch, prints ``T-``
    and ``H-`` lines to ``output_file`` unless ``args.quiet`` is set, and
    writes the scorer's reports into ``args.results_path``.

    Args:
        args: parsed command-line namespace.
        output_file: stream used both for logging and for the ``T-``/``H-``
            lines.

    Returns:
        The ``wer.Scorer`` that accumulated all predictions.
    """
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('espresso.speech_recognize')
    if output_file is not sys.stdout:  # also print to stdout
        logger.addHandler(logging.StreamHandler(sys.stdout))

    print_options_meaning_changes(args, logger)

    utils.import_user_module(args)

    # Fall back to a token budget when no batch limit was given at all.
    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset split
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionary
    dictionary = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() executes arbitrary code from a command-line string;
    # ast.literal_eval would be the safer parser if overrides are plain literals.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args.path),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )
    # NOTE(review): `del models[i]` below mutates the list being enumerated, so
    # the element that follows a deleted word LM is skipped by this loop.
    for i, m in enumerate(models):
        if hasattr(m, 'is_wordlm') and m.is_wordlm:
            # assume subword LM comes before word LM
            if isinstance(models[i - 1], FairseqLanguageModel):
                models[i - 1] = MultiLevelLanguageModel(
                    m,
                    models[i - 1],
                    subwordlm_weight=args.subwordlm_weight,
                    oov_penalty=args.oov_penalty,
                    open_vocab=not args.disable_open_vocab,
                )
                del models[i]
                logger.info('LM fusion with Multi-level LM')
            else:
                models[i] = TensorizedLookaheadLanguageModel(
                    m,
                    dictionary,
                    oov_penalty=args.oov_penalty,
                    open_vocab=not args.disable_open_vocab,
                )
                logger.info('LM fusion with Look-ahead Word LM')
        # assume subword LM comes after E2E models
        elif i == len(models) - 1 and isinstance(m, FairseqLanguageModel):
            logger.info('LM fusion with Subword LM')
    if args.lm_weight != 0.0:
        logger.info('using LM fusion with lm-weight={:.2f}'.format(
            args.lm_weight))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[
                model.max_positions() if hasattr(model, 'encoder')
                else (None, model.max_positions())
                for model in models
            ]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    # Initialize generator
    if args.match_source_len:
        logger.warning(
            'The option match_source_len is not applicable to speech recognition. Ignoring it.'
        )
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def decode_fn(x):
        """Undo BPE first, then tokenization, skipping whichever is absent."""
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Generate and compute WER
    scorer = wer.Scorer(dictionary, wer_output_filter=args.wer_output_filter)
    num_sentences = 0
    has_target = True
    # Bound up front so the summary log below cannot hit an unbound name when
    # no attention plot was ever produced (empty subset, attention is None).
    save_dir = None
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        prefix_tokens = None
        if args.prefix_size > 0:
            prefix_tokens = sample['target'][:, :args.prefix_size]

        gen_timer.start()
        hypos = task.inference_step(
            generator, models, sample, prefix_tokens, lm_weight=args.lm_weight,
        )
        num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        # obtain nonpad mask of encoder output to plot attentions
        if args.print_alignment:
            net_input = sample['net_input']
            src_tokens = net_input['src_tokens']
            output_lengths = models[0].encoder.output_lengths(
                net_input['src_lengths'])
            nonpad_idxs = sequence_mask(
                output_lengths,
                models[0].encoder.output_lengths(src_tokens.size(1)))

        for i in range(len(sample['id'])):
            has_target = sample['target'] is not None
            utt_id = sample['utt_id'][i]

            # Retrieve the original sentences
            if has_target:
                target_str = sample['target_raw_text'][i]
                if not args.quiet:
                    detok_target_str = decode_fn(target_str)
                    print('T-{}\t{}'.format(utt_id, detok_target_str),
                          file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:args.nbest]):
                hypo_str = dictionary.string(
                    hypo['tokens'].int().cpu(),
                    bpe_symbol=None,
                    extra_symbols_to_ignore={dictionary.pad()},
                )  # not removing bpe at this point
                detok_hypo_str = decode_fn(hypo_str)
                if not args.quiet:
                    score = hypo['score'] / math.log(2)  # convert to base 2
                    print('H-{}\t{}\t{}'.format(utt_id, detok_hypo_str, score),
                          file=output_file)

                # Score and obtain attention only the top hypothesis
                if j == 0:
                    # src_len x tgt_len
                    attention = (
                        hypo['attention'][nonpad_idxs[i]].float().cpu()
                        if args.print_alignment and hypo['attention'] is not None
                        else None
                    )
                    if args.print_alignment and attention is not None:
                        save_dir = os.path.join(args.results_path, 'attn_plots')
                        os.makedirs(save_dir, exist_ok=True)
                        plot_attention(attention, detok_hypo_str, utt_id, save_dir)
                    scorer.add_prediction(utt_id, hypo_str)
                    if has_target:
                        scorer.add_evaluation(utt_id, target_str, hypo_str)

        wps_meter.update(num_generated_tokens)
        progress.log({'wps': round(wps_meter.avg)})
        num_sentences += sample['nsentences']

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info(
        'Recognized {} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    # Only report a plot directory when at least one plot was actually saved.
    if args.print_alignment and save_dir is not None:
        logger.info('Saved attention plots in ' + save_dir)

    if has_target:
        scorer.add_ordered_utt_list(task.datasets[args.gen_subset].tgt.utt_ids)

    fn = 'decoded_char_results.txt'
    with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
        f.write(scorer.print_char_results())
        logger.info('Decoded char results saved as ' + f.name)

    fn = 'decoded_results.txt'
    with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
        f.write(scorer.print_results())
        logger.info('Decoded results saved as ' + f.name)

    if has_target:
        header = 'Recognize {} with beam={}: '.format(args.gen_subset, args.beam)

        fn = 'wer'
        with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
            res = 'WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format(
                *(scorer.wer()))
            logger.info(header + res)
            f.write(res + '\n')
            logger.info('WER saved in ' + f.name)

        fn = 'cer'
        with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
            res = 'CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format(
                *(scorer.cer()))
            # pad with spaces so the CER line aligns under the WER line
            logger.info(' ' * len(header) + res)
            f.write(res + '\n')
            logger.info('CER saved in ' + f.name)

        fn = 'aligned_results.txt'
        with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
            f.write(scorer.print_aligned_results())
            logger.info('Aligned results saved as ' + f.name)
    return scorer
def _main(args, output_file):
    """Score ``args.gen_subset`` with a model ensemble and dump per-item losses.

    Loads the ensemble from ``args.path`` (plus an optional language model
    from ``args.lm_path``), runs ``task.inference_step`` on every batch, and
    writes one value per line to ``args.score_target_file``: each element of
    the returned loss divided by ``ln 2``.

    Args:
        args: parsed command-line namespace.
        output_file: stream that receives the log output.
    """
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=output_file,
    )
    logger = logging.getLogger("fairseq_cli.generate")

    utils.import_user_module(args)

    # Fall back to a token budget when no batch limit was given at all.
    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 12000
    logger.info(args)

    # Fix seed for stochastic decoding
    if args.seed is not None and not args.no_seed_provided:
        np.random.seed(args.seed)
        utils.set_torch_seed(args.seed)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    overrides = ast.literal_eval(args.model_overrides)

    # Load ensemble
    logger.info("loading model(s) from {}".format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args.path),
        arg_overrides=overrides,
        task=task,
        suffix=getattr(args, "checkpoint_suffix", ""),
        strict=(args.checkpoint_shard_count == 1),
        num_shards=args.checkpoint_shard_count,
    )

    if args.lm_path is not None:
        overrides["data"] = args.data

        try:
            lms, _ = checkpoint_utils.load_model_ensemble(
                [args.lm_path],
                arg_overrides=overrides,
                task=None,
            )
        except Exception:
            # Add a hint for the most likely cause, then let the error surface.
            logger.warning(
                f"Failed to load language model! Please make sure that the language model dict is the same "
                f"as target dict and is located in the data dir ({args.data})"
            )
            raise

        assert len(lms) == 1
    else:
        lms = [None]

    # Optimize ensemble for generation
    for model in chain(models, lms):
        if model is None:
            continue
        if args.fp16:
            model.half()
        if use_cuda and not args.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(args)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
        data_buffer_size=args.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=("tqdm" if not args.no_progress_bar else "none"),
    )

    output = []
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if "net_input" not in sample:
            continue
        loss, _ = task.inference_step(models, sample)
        output.append(loss)

    # The `with` block owns the file handle; no explicit close() is needed.
    log2 = math.log(2)  # divisor used to convert each loss to base 2
    with open(args.score_target_file, "w") as f:
        for loss in output:
            for i in range(loss.shape[0]):
                f.write(str(float(loss[i]) / log2) + '\n')
def _main(args, output_file):
    """Translate ``args.gen_subset`` with a model ensemble and score the output.

    For every sentence the function optionally prints ``S-``/``T-`` lines and,
    per hypothesis, ``H-``/``D-``/``P-`` lines (plus the optional ``A-``,
    ``I-``, ``Menc-``, ``Mdec-``, ``C-`` and ``E-`` lines) to ``output_file``.
    The top hypothesis of each sentence is fed to the BLEU scorer.

    Args:
        args: parsed command-line namespace.
        output_file: stream used for both logging and the generated lines.

    Returns:
        The scorer object that accumulated the top hypotheses.
    """
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairseq_cli.generate')

    utils.import_user_module(args)

    # Without any batch limit, default to a token budget.
    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Task and data
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Dictionaries; a task may raise instead of exposing a source dictionary
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Model ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    checkpoint_paths = utils.split_paths(args.path)
    model_overrides = eval(args.model_overrides)
    models, _model_args = checkpoint_utils.load_model_ensemble(
        checkpoint_paths,
        arg_overrides=model_overrides,
        task=task,
    )

    # Prepare every ensemble member for generation
    for member in models:
        member.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            member.half()
        if use_cuda:
            member.cuda()

    # Alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Batch iterator over the (possibly sharded) subset
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[member.max_positions() for member in models])
    batch_iterator = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    )
    itr = batch_iterator.next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    # Generator and its timer
    gen_timer = StopwatchMeter()
    generator = task.build_generator(models, args)

    # Tokenization and BPE handling
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def decode_fn(text):
        """Undo BPE first, then tokenization, skipping whichever is absent."""
        if bpe is not None:
            text = bpe.decode(text)
        if tokenizer is not None:
            text = tokenizer.decode(text)
        return text

    # BLEU scorer
    scorer = (
        bleu.SacrebleuScorer()
        if args.sacrebleu
        else bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    )

    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        if use_cuda:
            sample = utils.move_to_cuda(sample)
        if 'net_input' not in sample:
            continue

        prefix_tokens = (
            sample['target'][:, :args.prefix_size]
            if args.prefix_size > 0
            else None
        )

        gen_timer.start()
        hypos = task.inference_step(generator, models, sample, prefix_tokens)
        num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for sent_idx, sample_id in enumerate(sample['id'].tolist()):
            has_target = sample['target'] is not None

            # Strip padding from source and (if present) target
            src_tokens = utils.strip_pad(
                sample['net_input']['src_tokens'][sent_idx, :], tgt_dict.pad())
            target_tokens = None
            if has_target:
                target_tokens = utils.strip_pad(
                    sample['target'][sent_idx, :], tgt_dict.pad()).int().cpu()

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(
                    args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(
                    args.gen_subset).tgt.get_original_text(sample_id)
            else:
                src_str = (
                    src_dict.string(src_tokens, args.remove_bpe)
                    if src_dict is not None
                    else ""
                )
                if has_target:
                    target_str = tgt_dict.string(
                        target_tokens,
                        args.remove_bpe,
                        escape_unk=True,
                        extra_symbols_to_ignore={generator.eos},
                    )

            src_str = decode_fn(src_str)
            if has_target:
                target_str = decode_fn(target_str)

            if not args.quiet:
                if src_dict is not None:
                    print('S-{}\t{}'.format(sample_id, src_str), file=output_file)
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str), file=output_file)

            # Walk over the n-best list
            for rank, hypo in enumerate(hypos[sent_idx][:args.nbest]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                    extra_symbols_to_ignore={generator.eos},
                )
                detok_hypo_str = decode_fn(hypo_str)

                if not args.quiet:
                    score = hypo['score'] / math.log(2)  # convert to base 2
                    # original hypothesis (after tokenization and BPE)
                    print('H-{}\t{}\t{}'.format(sample_id, score, hypo_str),
                          file=output_file)
                    # detokenized hypothesis
                    print('D-{}\t{}\t{}'.format(sample_id, score, detok_hypo_str),
                          file=output_file)
                    # convert from base e to base 2 (in place on the hypo tensor)
                    base2_scores = hypo['positional_scores'].div_(
                        math.log(2)).tolist()
                    print('P-{}\t{}'.format(
                        sample_id,
                        ' '.join('{:.4f}'.format(x) for x in base2_scores)),
                        file=output_file)

                    if args.print_alignment:
                        alignment_pairs = [
                            '{}-{}'.format(src_idx, tgt_idx)
                            for src_idx, tgt_idx in alignment
                        ]
                        print('A-{}\t{}'.format(sample_id,
                                                ' '.join(alignment_pairs)),
                              file=output_file)

                    if args.print_step:
                        print('I-{}\t{}'.format(sample_id, hypo['steps']),
                              file=output_file)

                    if 'enc_selection' in hypo:
                        print('Menc-{}\t{}'.format(sample_id,
                                                   hypo['enc_selection']),
                              file=output_file)
                    if 'dec_selection' in hypo:
                        print('Mdec-{}\t{}'.format(sample_id,
                                                   hypo['dec_selection']),
                              file=output_file)
                    if args.print_attn_confidence:
                        print('C-{}\t{}'.format(sample_id,
                                                hypo['enc_self_attn_conf']),
                              file=output_file)

                    if getattr(args, 'retain_iter_history', False):
                        for step, past in enumerate(hypo['history']):
                            _, past_str, _ = utils.post_process_prediction(
                                hypo_tokens=past['tokens'].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print('E-{}_{}\t{}'.format(sample_id, step, past_str),
                                  file=output_file)

                # Only the top hypothesis of a sentence with a target is scored
                if rank != 0 or not has_target:
                    continue
                if align_dict is not None or args.remove_bpe is not None:
                    # Convert back to tokens for evaluation with unk replacement and/or without BPE
                    target_tokens = tgt_dict.encode_line(
                        target_str, add_if_not_exist=True)
                    hypo_tokens = tgt_dict.encode_line(
                        detok_hypo_str, add_if_not_exist=True)
                if hasattr(scorer, 'add_string'):
                    scorer.add_string(target_str, detok_hypo_str)
                else:
                    scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({'wps': round(wps_meter.avg)})
        num_sentences += sample['nsentences']

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    summary = (
        'Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    logger.info(summary)

    if has_target:
        if args.bpe and not args.sacrebleu:
            if args.remove_bpe:
                logger.warning(
                    "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
                )
            else:
                logger.warning(
                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words.  Use --sacrebleu for standard 13a BLEU tokenization"
                )
        logger.info('Generate {} with beam={}: {}'.format(
            args.gen_subset, args.beam, scorer.result_string()))

    return scorer
def main(parsed_args, **unused_kwargs):
    """Evaluate a language model ensemble and log loss and perplexity.

    Loads the ensemble from ``parsed_args.path``, scores every batch of
    ``gen_subset`` with a ``SequenceScorer``, and logs the average base-2
    loss and the perplexity. When ``output_word_probs`` /
    ``output_word_stats`` are set, per-word scores and statistics are logged
    as well.

    Args:
        parsed_args: parsed command-line namespace; ``path`` must be set.
        **unused_kwargs: accepted and ignored.

    Raises:
        NotImplementedError: if ``remove_bpe`` is ``'sentencepiece'``.
    """
    assert parsed_args.path is not None, '--path required for evaluation!'

    if torch.cuda.is_available() and not parsed_args.cpu:
        torch.cuda.set_device(parsed_args.device_id)

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(parsed_args.path))
    # NOTE(review): eval() executes arbitrary code from a command-line string;
    # ast.literal_eval would be the safer parser if overrides are plain literals.
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(os.pathsep),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
        suffix=getattr(parsed_args, "checkpoint_suffix", ""),
    )

    # Copy the command-line values onto the checkpoint's args, except for the
    # listed options, which keep the values stored with the checkpoint.
    for arg in vars(parsed_args).keys():
        if arg not in {
            'self_target', 'future_target', 'past_target', 'tokens_per_sample',
            'output_size_dictionary', 'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset,
                                           len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.prepare_for_inference_(args)
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    logger.info('num. model params: {}'.format(
        sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            # Dictionary indices whose symbol ends with the continuation marker
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    wps_meter = TimeMeter()

    for sample in progress:
        if 'net_input' not in sample:
            continue

        sample = utils.move_to_cuda(sample) if use_cuda else sample

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample['ntokens'])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample['id'][i]

            tokens = hypo['tokens']
            pos_scores = hypo['positional_scores'].float()

            if getattr(args, 'add_bos_token', False):
                assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            skipped_toks = 0
            if bpe_toks is not None:
                # Fold the score of each continuation piece into the next
                # position. The bound is taken AFTER the optional BOS strip so
                # that `pos + 1` always stays inside `pos_scores`.
                for pos in range(tokens.numel() - 1):
                    if tokens[pos].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[pos + 1] += pos_scores[pos]
                        pos_scores[pos] = 0

            inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(
                float('-inf'))
            if inf_scores.any():
                # The token string is a %-style logging argument, so the
                # message needs a placeholder to receive it.
                logger.info(
                    'skipping tokens with inf scores: %s',
                    task.target_dictionary.string(
                        tokens[inf_scores.nonzero()]))
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks

            if args.output_word_probs or args.output_word_stats:
                w = ''
                word_prob = []
                is_bpe = False
                for pos in range(len(tokens)):
                    w_ind = tokens[pos].item()
                    w += task.source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        # drop the continuation marker and keep building the word
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[pos].item()))

                        # score of the next position whose score is non-zero
                        next_prob = None
                        ind = pos + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                            pos_scores[pos].item(), next_prob)
                        is_bpe = False
                        w = ''
                if args.output_word_probs:
                    logger.info(
                        str(int(sample_id)) + " " +
                        ('\t'.join('{} [{:2f}]'.format(x[0], x[1])
                                   for x in word_prob)))

        wps_meter.update(sample['ntokens'])
        progress.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(
        avg_nll_loss, 2**avg_nll_loss))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count,
                         reverse=True):
            logger.info(ws)
def _main(args, output_file):
    """Run the ensemble on ``args.gen_subset`` and dump log-probability matrices.

    Each utterance's matrix is written with ``kaldi_io.write_mat`` to the pipe
    ``ark:| copy-matrix ark:- ark:-``, keyed by its utterance id. If a state
    prior is available (from ``args.state_prior_file`` or from the models'
    ``state_prior`` attributes) its log is subtracted from the
    log-probabilities first.

    Args:
        args: parsed command-line namespace.
        output_file: stream that receives the log output.
    """
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger("espresso.dump_posteriors")

    print_options_meaning_changes(args, logger)

    utils.import_user_module(args)

    # Fall back to a token budget when no batch limit was given at all.
    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset split
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Load ensemble
    logger.info("loading model(s) from {}".format(args.path))
    # NOTE(review): eval() executes arbitrary code from a command-line string;
    # ast.literal_eval would be the safer parser if overrides are plain literals.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args.path),
        arg_overrides=eval(args.model_overrides),
        task=task,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )

    # Load state prior for cross-entropy trained systems decoding
    if args.state_prior_file is not None:
        prior = torch.from_numpy(kaldi_io.read_vec_flt(args.state_prior_file))
    else:
        # a list marks "no file given"; per-model priors are collected below
        prior = []

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()
        if isinstance(prior, list) and getattr(model, "state_prior", None) is not None:
            prior.append(model.state_prior.unsqueeze(0))

    if isinstance(prior, list) and len(prior) > 0:
        prior = torch.cat(prior, 0).mean(0)  # average priors across models
        prior = prior / prior.sum()  # re-normalize
    elif isinstance(prior, list):
        prior = None

    if prior is not None:
        if args.fp16:
            prior = prior.half()
        if use_cuda:
            prior = prior.cuda()
        log_prior = prior.log()
    else:
        log_prior = None

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[
                model.max_positions() if hasattr(model, "encoder")
                else (None, model.max_positions())
                for model in models
            ]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=("tqdm" if not args.no_progress_bar else "none"),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(models, args)

    # Generate and dump
    num_sentences = 0
    chunk_width = getattr(task, "chunk_width", None)
    lprobs_wspecifier = "ark:| copy-matrix ark:- ark:-"
    with kaldi_io.open_or_fd(lprobs_wspecifier, "wb") as f:
        if chunk_width is None:  # normal dumping (i.e., no chunking)
            for sample in progress:
                sample = utils.move_to_cuda(sample) if use_cuda else sample
                if "net_input" not in sample:
                    continue

                gen_timer.start()
                lprobs, padding_mask = task.inference_step(
                    generator, models, sample)
                if log_prior is not None:
                    assert lprobs.size(-1) == log_prior.size(0)
                    lprobs = lprobs - log_prior
                # per-utterance count of non-padded positions, if a mask exists
                out_lengths = (~padding_mask).long().sum(
                    dim=1).cpu() if padding_mask is not None else None
                num_processed_frames = sample["ntokens"]
                gen_timer.stop(num_processed_frames)
                num_sentences += sample["nsentences"]

                if out_lengths is not None:
                    for i in range(sample["nsentences"]):
                        length = out_lengths[i]
                        kaldi_io.write_mat(f,
                                           lprobs[i, :length, :].cpu().numpy(),
                                           key=sample["utt_id"][i])
                else:
                    for i in range(sample["nsentences"]):
                        kaldi_io.write_mat(f,
                                           lprobs[i, :, :].cpu().numpy(),
                                           key=sample["utt_id"][i])
        else:  # dumping chunks within the same utterance from left to right
            for sample in progress:  # sample is actually a list of batches
                sample = utils.move_to_cuda(sample) if use_cuda else sample
                utt_id = sample[0]["utt_id"]
                # named `sample_ids` so the builtin `id` is not shadowed
                sample_ids = sample[0]["id"]
                whole_lprobs = None
                for i, chunk_sample in enumerate(sample):
                    if "net_input" not in chunk_sample:
                        continue

                    # every chunk must belong to the same utterances, in order
                    assert chunk_sample["utt_id"] == utt_id and (
                        chunk_sample["id"] == sample_ids).all()
                    gen_timer.start()
                    lprobs, _ = task.inference_step(generator, models,
                                                    chunk_sample)
                    if log_prior is not None:
                        assert lprobs.size(-1) == log_prior.size(0)
                        lprobs = lprobs - log_prior
                    # concatenate chunk outputs along dimension 1
                    if whole_lprobs is None:
                        whole_lprobs = lprobs.cpu()
                    else:
                        whole_lprobs = torch.cat((whole_lprobs, lprobs.cpu()), 1)
                    num_processed_frames = chunk_sample["ntokens"]
                    gen_timer.stop(num_processed_frames)

                    if i == len(sample) - 1:
                        num_sentences += len(utt_id)
                        for j in range(len(utt_id)):
                            truncated_length = models[0].output_lengths(
                                task.dataset(args.gen_subset).src_sizes[sample_ids[j]]
                            )  # length is after possible subsampling by the model
                            mat = whole_lprobs[j, :truncated_length, :]
                            kaldi_io.write_mat(f, mat.numpy(), key=utt_id[j])

    logger.info(
        "Dumped {} utterances ({} frames) in {:.1f}s ({:.2f} sentences/s, {:.2f} frames/s)"
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
def train(args,
          trainer,
          task,
          epoch_itr,
          model,
          experiment_path,
          total_samples=None,
          last_epoch_num=0,
          restore=None):
    """Train the model for one epoch and return validation losses.

    Besides the usual train/validate loop this function applies the head
    swaps read from ``experiment_path`` (if given), accumulates per-head
    confidences when ``args.head_confidence_method`` is set and pickles them
    once per epoch, and calls ``dynamic_mhr`` for the three attention types
    when ``args.dynamic_type`` is also set.

    Args:
        args: parsed command-line namespace.
        trainer: object providing ``train_step`` / ``get_num_updates``.
        task: task object forwarded to ``validate_and_save``.
        epoch_itr: epoch iterator for the training data.
        model: the model being trained.
        experiment_path: path to a JSON file of head swaps, or ``None``.
        total_samples: running count of samples seen so far; ``None`` is
            treated as 0.
        last_epoch_num: indexed by attention type in the ``dynamic_mhr``
            branch, so callers using that branch must pass a mapping.
        restore: indexed by attention type in the ``dynamic_mhr`` branch, so
            callers using that branch must pass a mapping.

    Returns:
        Tuple ``(valid_losses, should_stop, total_samples, restore,
        last_epoch_num)``.
    """
    # The default used to crash on the first arithmetic below.
    if total_samples is None:
        total_samples = 0

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    num_heads = args.decoder_attention_heads
    head_dim = args.decoder_embed_dim // num_heads
    if experiment_path is not None:
        with open(experiment_path, 'r') as f:
            swaps = json.load(f)
        mhr(model, swaps, head_dim, num_heads, epoch_itr.epoch)

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False

    # Per-layer accumulators for the head confidences
    conf = {
        "encoder": [{
            "self_attn": []
        } for i in range(args.encoder_layers)],
        "decoder": [{
            "self_attn": [],
            "enc_attn": []
        } for i in range(args.decoder_layers)]
    }

    # NOTE(review): this initial value divides by 160239 * 50 while the
    # per-step update below uses 160239 * 40 -- confirm which is intended.
    # 160239 is presumably the training-set size; verify against the data.
    batch_regression = 1.0 - (total_samples / (160239 * 50))
    for i, samples in enumerate(progress):
        with metrics.aggregate(
                "train_inner"), torch.autograd.profiler.record_function(
                    "train_step-%d" % i):
            log_output = trainer.train_step(samples,
                                            batch_num=batch_regression)
            if log_output is None:  # OOM, overflow, ...
                continue
        total_samples += model.decoder.layers[0].self_attn.bsz
        batch_regression = 1.0 - (
            total_samples / (160239 * 40)
        )  # need to find more generic way to find total samples and epoch num.

        # Get Confidence for each Head.
        if args.head_confidence_method is not None:
            conf = get_batch_confs(model, conf, args)

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values("train_inner"))
            progress.log(stats, tag="train_inner", step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop, val_conf = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch)

        if should_stop:
            break

    if args.head_confidence_method is not None:
        conf = convert_confs(conf, args)

        path = args.save_dir.replace("checkpoints",
                                     "confs") + "-method={0}".format(
                                         args.head_confidence_method)
        # Tolerate an existing directory, but let real failures (permissions,
        # a file in the way) surface here instead of being swallowed.
        os.makedirs(path, mode=0o775, exist_ok=True)
        with open(path + "/epoch-{0}.pkl".format(epoch_itr.epoch),
                  'wb') as fd:
            pickle.dump(conf, fd, protocol=3)

    if args.dynamic_type is not None and args.head_confidence_method is not None:
        conf = val_conf

        # NOTE(review): all three calls pass args.encoder_layers, including the
        # two decoder ones -- confirm this is intended.
        restore['enc_self_attn'], last_epoch_num[
            'enc_self_attn'] = dynamic_mhr(model,
                                           int(args.start_dynamic_mhr[0]),
                                           "encoder",
                                           "self_attn",
                                           restore['enc_self_attn'],
                                           int(args.dynamic_swap_frequency[0]),
                                           last_epoch_num['enc_self_attn'],
                                           epoch_itr.epoch + 1,
                                           int(args.dynamic_max_switches[0]),
                                           conf[0],
                                           num_heads,
                                           head_dim,
                                           args.encoder_layers,
                                           local_only=False,
                                           d_type=args.dynamic_type[0],
                                           rest=int(args.dynamic_rest[0]),
                                           end_epoch=int(
                                               args.dynamic_end_epoch[0]))

        restore['dec_self_attn'], last_epoch_num[
            'dec_self_attn'] = dynamic_mhr(model,
                                           int(args.start_dynamic_mhr[1]),
                                           "decoder",
                                           "self_attn",
                                           restore['dec_self_attn'],
                                           int(args.dynamic_swap_frequency[1]),
                                           last_epoch_num['dec_self_attn'],
                                           epoch_itr.epoch + 1,
                                           int(args.dynamic_max_switches[1]),
                                           conf[1],
                                           num_heads,
                                           head_dim,
                                           args.encoder_layers,
                                           local_only=False,
                                           d_type=args.dynamic_type[1],
                                           rest=int(args.dynamic_rest[1]),
                                           end_epoch=int(
                                               args.dynamic_end_epoch[1]))

        restore['dec_enc_attn'], last_epoch_num['dec_enc_attn'] = dynamic_mhr(
            model,
            int(args.start_dynamic_mhr[2]),
            "decoder",
            "encoder_attn",
            restore['dec_enc_attn'],
            int(args.dynamic_swap_frequency[2]),
            last_epoch_num['dec_enc_attn'],
            epoch_itr.epoch + 1,
            int(args.dynamic_max_switches[2]),
            conf[2],
            num_heads,
            head_dim,
            args.encoder_layers,
            local_only=False,
            d_type=args.dynamic_type[2],
            rest=int(args.dynamic_rest[2]),
            end_epoch=int(args.dynamic_end_epoch[2]))

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop, total_samples, restore, last_epoch_num
def _main(args, output_file):
    """Run the first loaded model over ``args.pred_subset`` and report metrics.

    For every sentence the source string (``S-<id>``) and the predicted label
    sequence (``P-<id>``) are printed to ``output_file`` unless ``args.quiet``
    is set.  All gold and predicted label sequences are accumulated and passed
    to ``classification_report``; the resulting report is logged.

    Args:
        args: parsed command-line namespace (``path``, ``pred_subset``,
            ``model_overrides``, ``fp16``, ``cpu``, ... are read here).
        output_file: writable text stream used for logging and for the
            ``S-``/``P-`` lines.
    """
    # Local import keeps this block self-contained.
    import ast

    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairseq_cli.predict')

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.pred_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    # ``model_overrides`` comes straight from the command line; parse it as a
    # Python literal instead of eval()-ing it so it cannot execute code.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args.path),
        arg_overrides=ast.literal_eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation TODO
    for model in models:
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.pred_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def decode_fn(x):
        """Undo BPE first, then tokenization, when the respective tool exists."""
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    model = models[0]  # Use first model only, TODO ensembles

    y_true = []
    y_pred = []
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        for i, sample_id in enumerate(sample['id'].tolist()):
            # Remove padding
            src_tokens = utils.strip_pad(
                sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
            target_tokens = utils.strip_pad(
                sample['target'][i, :], tgt_dict.pad()).int().cpu()

            # Predict: get labels and predictions.
            # NOTE(review): the whole batch is re-scored for every sentence and
            # element 0 is taken regardless of ``i`` -- presumably this assumes
            # one sentence per batch; confirm against _predict_with_seqeval.
            _, _, y_true_, y_pred_ = task._predict_with_seqeval(sample, model)
            y_true_ = y_true_[0]
            y_pred_ = y_pred_[0]
            # append to all labels and predictions
            y_true.append(y_true_)
            y_pred.append(y_pred_)

            if src_dict is not None:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
            else:
                src_str = ""
            target_str = tgt_dict.string(
                target_tokens,
                args.remove_bpe,
                escape_unk=True,
            )

            src_str = decode_fn(src_str)
            # target_str is decoded but currently not printed.
            target_str = decode_fn(target_str)
            pred_str = ' '.join(y_pred_)

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str), file=output_file)
                print('P-{}\t{}'.format(sample_id, pred_str), file=output_file)

    clf_report = classification_report(y_true, y_pred)
    logger.info("Eval results on {} subset: \n\n{} ".format(args.pred_subset, clf_report))
def main(cfg: DictConfig, **unused_kwargs):
    """Evaluate a language-model ensemble on ``cfg.dataset.gen_subset``.

    Loads the checkpoint at ``cfg.common_eval.path`` together with its task,
    builds the evaluation data loader, runs ``eval_lm`` and logs the loss and
    perplexity it reports.

    Args:
        cfg: configuration; an ``argparse.Namespace`` is converted with
            ``convert_namespace_to_omegaconf`` first.
        **unused_kwargs: accepted and ignored.

    Returns:
        The ``results`` object returned by ``eval_lm`` (indexed here with the
        keys ``"loss"`` and ``"perplexity"``).
    """
    # Local import keeps this block self-contained.
    import ast

    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    logger.info(cfg)

    if cfg.eval_lm.context_window > 0:
        # reduce tokens per sample by the required context window size
        cfg.task.tokens_per_sample -= cfg.eval_lm.context_window

    # Initialize the task using the current *cfg*
    task = tasks.setup_task(cfg.task)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    # ``model_overrides`` is a user-supplied string; parse it as a Python
    # literal instead of eval()-ing it so it cannot execute arbitrary code.
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [cfg.common_eval.path],
        arg_overrides=ast.literal_eval(cfg.common_eval.model_overrides),
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
        task=task,
    )

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu
    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    # Optimize ensemble for generation and set the source and dest dicts on the model
    # (required by scorer)
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    assert len(models) > 0

    logger.info(
        "num. model params: {:,}".format(sum(p.numel() for p in models[0].parameters()))
    )

    # Load dataset splits
    task.load_dataset(cfg.dataset.gen_subset)
    dataset = task.dataset(cfg.dataset.gen_subset)
    logger.info(
        "{} {} {:,} examples".format(
            cfg.task.data, cfg.dataset.gen_subset, len(dataset)
        )
    )

    itr = task.eval_lm_dataloader(
        dataset=dataset,
        max_tokens=cfg.dataset.max_tokens or 36000,
        batch_size=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]
        ),
        # Take the larger of the dataset-level and distributed settings.
        num_shards=max(
            cfg.dataset.num_shards,
            cfg.distributed_training.distributed_world_size,
        ),
        shard_id=max(
            cfg.dataset.shard_id,
            cfg.distributed_training.distributed_rank,
        ),
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
        context_window=cfg.eval_lm.context_window,
    )

    itr = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    results = eval_lm(
        models=models,
        source_dictionary=task.source_dictionary,
        batch_iterator=itr,
        post_process=cfg.common_eval.post_process,
        output_word_probs=cfg.eval_lm.output_word_probs,
        output_word_stats=cfg.eval_lm.output_word_stats,
        target_dictionary=task.target_dictionary,
        softmax_batch=cfg.eval_lm.softmax_batch,
        remove_bos_token=getattr(cfg.task, "add_bos_token", False),
    )

    logger.info(
        "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format(
            results["loss"], results["perplexity"]
        )
    )

    return results
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch.

    Args:
        args: parsed command-line namespace.
        trainer: object providing ``train_step``, ``get_num_updates`` and
            ``get_model``.
        task: task object; ``begin_epoch`` is called once before the loop.
        epoch_itr: epoch batch iterator for the training data.

    Returns:
        ``True`` when training should end (early stopping triggered or
        ``args.max_update`` reached), ``False`` otherwise.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    # Use the per-epoch update frequency, falling back to the last entry once
    # the epoch index runs past the configured list.
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    should_end_training = False
    # Defined up front so the end-of-epoch logging below still works when the
    # iterator is empty or every step is skipped (train_step returned None).
    num_updates = trainer.get_num_updates()
    for samples in progress:
        with metrics.aggregate('train_inner'):
            try:
                log_output = trainer.train_step(samples)
            except ResetTrainerException:
                # Drop the cached wrapped criterion/model/optimizer and retry
                # the same batch once (presumably they are rebuilt lazily by
                # the trainer -- confirm against the Trainer implementation).
                trainer._wrapped_criterion = None
                trainer._wrapped_model = None
                trainer._optimizer = None
                logger.info("reset the trainer at {}".format(
                    trainer.get_num_updates()))
                log_output = trainer.train_step(samples)

            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        valid_losses = validate_and_save(args, trainer, task, epoch_itr,
                                         valid_subsets)
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            should_end_training = True
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return should_end_training
def validate(args, trainer, task, epoch_itr, subsets, prune=-1):
    """Evaluate the model on the validation set(s) and return the losses.

    On distributed rank 0 (and unless ``args.eval_mode == 'time'``) a template
    file is written into ``args.save_dir`` for each subset.

    Args:
        args: parsed command-line namespace.
        trainer: object providing ``valid_step``, ``get_model`` and
            ``get_num_updates``.
        task: task object providing datasets and the template writers.
        epoch_itr: training epoch iterator; only ``epoch`` is read.
        subsets: iterable of validation subset names.
        prune: when positive, passed to the model's ``set_prune_index`` and
            the resulting index map is installed on the task.

    Returns:
        A list with ``stats[args.best_checkpoint_metric]`` for each subset.
    """
    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                args.tensorboard_logdir if distributed_utils.is_master(args) else None
            ),
            default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
        )

        # added by Junxian
        if prune > 0:
            index_map = trainer.get_model().set_prune_index(prune)
            task.set_index_map(index_map)

        # not write templates for time profiling
        write_template_flag = args.eval_mode != 'time'

        fout = None
        try:
            # only one worker deals with the template file in DDP
            if args.distributed_rank == 0 and write_template_flag:
                print('write template files')
                if args.eval_mode == 'none':
                    template_name = 'templates_{}_{}.txt'.format(
                        epoch_itr.epoch, trainer.get_num_updates())
                else:
                    template_name = 'templates_eval_{}.txt'.format(subset)
                fout = open(os.path.join(args.save_dir, template_name), 'w')

                if prune <= 0:
                    task.write_lambda(fout, trainer.get_model())

            # create a new root metrics aggregator so validation metrics
            # don't pollute other aggregators (e.g., train meters)
            with metrics.aggregate(new_root=True) as agg:
                for sample in progress:
                    trainer.valid_step(sample, split=subset)

                    # added by Junxian
                    # NOTE(review): on rank 0 with eval_mode == 'time' this is
                    # still called with fout=None -- confirm write_template
                    # tolerates that.
                    if args.distributed_rank == 0:
                        task.write_template(sample, trainer.get_model(), fout)
        finally:
            # Close the template file even if validation raises part-way.
            if fout is not None:
                fout.close()

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[args.best_checkpoint_metric])
    return valid_losses
def _main(cfg: DictConfig, output_file):
    """Generate hypotheses for ``cfg.dataset.gen_subset`` and score them.

    Loads the model ensemble from ``cfg.common_eval.path`` (plus an optional
    language model from ``cfg.generation.lm_path``), runs the task's
    ``inference_step`` over every batch and, unless ``cfg.common_eval.quiet``
    is set, prints tab-separated lines to ``output_file`` with these prefixes:
    ``S-`` source, ``T-`` target, ``H-`` hypothesis, ``D-`` detokenized
    hypothesis, ``P-`` positional scores, ``A-`` alignment, ``I-`` steps and
    ``E-`` iteration history.  The top hypothesis of every sentence that has
    a target is added to the scorer.

    Args:
        cfg: run configuration.
        output_file: writable text stream receiving both the log output and
            the generated lines.

    Returns:
        The scorer built by ``scoring.build_scorer``.
    """
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=output_file,
    )
    logger = logging.getLogger("fairseq_cli.generate")

    utils.import_user_module(cfg.common)

    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.max_tokens = 12000
    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Load dataset splits
    task = tasks.setup_task(cfg.task)

    # Set dictionaries
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    overrides = ast.literal_eval(cfg.common_eval.model_overrides)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    if cfg.generation.lm_path is not None:
        overrides["data"] = cfg.task.data

        try:
            lms, _ = checkpoint_utils.load_model_ensemble(
                [cfg.generation.lm_path], arg_overrides=overrides, task=None)
        except:
            # Log a hint about the most likely cause, then re-raise unchanged.
            logger.warning(
                f"Failed to load language model! Please make sure that the language model dict is the same "
                f"as target dict and is located in the data dir ({cfg.task.data})"
            )
            raise

        assert len(lms) == 1
    else:
        # Placeholder so that ``lms[0]`` below is always valid.
        lms = [None]

    # Optimize ensemble for generation
    for model in chain(models, lms):
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()

    extra_gen_cls_kwargs = {
        "lm_model": lms[0],
        "lm_weight": cfg.generation.lm_weight
    }
    generator = task.build_generator(models,
                                     cfg.generation,
                                     extra_gen_cls_kwargs=extra_gen_cls_kwargs)

    # Handle tokenization and BPE
    tokenizer = task.build_tokenizer(cfg.tokenizer)
    bpe = task.build_bpe(cfg.bpe)

    def decode_fn(x):
        # Undo BPE first, then tokenization, when the respective tool exists.
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    scorer = scoring.build_scorer(cfg.scoring, tgt_dict)

    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if "net_input" not in sample:
            continue

        prefix_tokens = None
        if cfg.generation.prefix_size > 0:
            prefix_tokens = sample["target"][:, :cfg.generation.prefix_size]

        constraints = None
        if "constraints" in sample:
            constraints = sample["constraints"]

        gen_timer.start()
        hypos = task.inference_step(
            generator,
            models,
            sample,
            prefix_tokens=prefix_tokens,
            constraints=constraints,
        )
        # Only the first hypothesis of each sentence is counted here.
        num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample["id"].tolist()):
            has_target = sample["target"] is not None

            # Remove padding
            if "src_tokens" in sample["net_input"]:
                src_tokens = utils.strip_pad(
                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad())
            else:
                src_tokens = None

            target_tokens = None
            if has_target:
                target_tokens = (utils.strip_pad(sample["target"][i, :],
                                                 tgt_dict.pad()).int().cpu())

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(
                    cfg.dataset.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(
                    cfg.dataset.gen_subset).tgt.get_original_text(sample_id)
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens,
                                              cfg.common_eval.post_process)
                else:
                    src_str = ""
                if has_target:
                    target_str = tgt_dict.string(
                        target_tokens,
                        cfg.common_eval.post_process,
                        escape_unk=True,
                        extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                    )

            src_str = decode_fn(src_str)
            if has_target:
                target_str = decode_fn(target_str)

            if not cfg.common_eval.quiet:
                if src_dict is not None:
                    print("S-{}\t{}".format(sample_id, src_str),
                          file=output_file)
                if has_target:
                    print("T-{}\t{}".format(sample_id, target_str),
                          file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:cfg.generation.nbest]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=cfg.common_eval.post_process,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                        generator),
                )
                detok_hypo_str = decode_fn(hypo_str)
                if not cfg.common_eval.quiet:
                    score = hypo["score"] / math.log(2)  # convert to base 2
                    # original hypothesis (after tokenization and BPE)
                    print(
                        "H-{}\t{}\t{}".format(sample_id, score, hypo_str),
                        file=output_file,
                    )
                    # detokenized hypothesis
                    print(
                        "D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str),
                        file=output_file,
                    )
                    # NOTE: div_ rescales hypo["positional_scores"] in place.
                    print(
                        "P-{}\t{}".format(
                            sample_id,
                            " ".join(
                                map(
                                    lambda x: "{:.4f}".format(x),
                                    # convert from base e to base 2
                                    hypo["positional_scores"].div_(math.log(2)
                                                                   ).tolist(),
                                )),
                        ),
                        file=output_file,
                    )

                    if cfg.generation.print_alignment == "hard":
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join([
                                    "{}-{}".format(src_idx, tgt_idx)
                                    for src_idx, tgt_idx in alignment
                                ]),
                            ),
                            file=output_file,
                        )
                    if cfg.generation.print_alignment == "soft":
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join([
                                    ",".join(src_probs)
                                    for src_probs in alignment
                                ]),
                            ),
                            file=output_file,
                        )

                    if cfg.generation.print_step:
                        print(
                            "I-{}\t{}".format(sample_id, hypo["steps"]),
                            file=output_file,
                        )

                    if cfg.generation.retain_iter_history:
                        for step, h in enumerate(hypo["history"]):
                            _, h_str, _ = utils.post_process_prediction(
                                hypo_tokens=h["tokens"].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print(
                                "E-{}_{}\t{}".format(sample_id, step, h_str),
                                file=output_file,
                            )

                # Score only the top hypothesis
                if has_target and j == 0:
                    if align_dict is not None or cfg.common_eval.post_process is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tgt_dict.encode_line(
                            target_str, add_if_not_exist=True)
                        hypo_tokens = tgt_dict.encode_line(
                            detok_hypo_str, add_if_not_exist=True)
                    # Scorers exposing add_string take strings; others take tokens.
                    if hasattr(scorer, "add_string"):
                        scorer.add_string(target_str, detok_hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({"wps": round(wps_meter.avg)})
        num_sentences += (sample["nsentences"]
                          if "nsentences" in sample else sample["id"].numel())

    logger.info("NOTE: hypothesis and token scores are output in base 2")
    # NOTE(review): the two divisions below look like they fail when nothing
    # was generated (timer sum/avg presumably 0) -- confirm against StopwatchMeter.
    logger.info(
        "Translated {:,} sentences ({:,} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)"
        .format(
            num_sentences,
            gen_timer.n,
            gen_timer.sum,
            num_sentences / gen_timer.sum,
            1.0 / gen_timer.avg,
        ))
    if has_target:
        if cfg.bpe and not cfg.generation.sacrebleu:
            if cfg.common_eval.post_process:
                logger.warning(
                    "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
                )
            else:
                logger.warning(
                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words.  Use --sacrebleu for standard 13a BLEU tokenization"
                )
        # use print to be consistent with other main outputs: S-, H-, T-, D- and so on
        print(
            "Generate {} with beam={}: {}".format(cfg.dataset.gen_subset,
                                                  cfg.generation.beam,
                                                  scorer.result_string()),
            file=output_file,
        )

    return scorer
def train(args, trainer, task, epoch_itr, max_update=math.inf, model=None):
    """Train the model for one epoch and return validation losses.

    After the first successful training step of the epoch, the
    ``selfattn_norm`` and ``endeattn_norm`` attributes found on the model's
    sub-modules are printed.

    Args:
        args: parsed command-line namespace.
        trainer: object providing ``train_step``, ``get_num_updates`` and
            ``get_model``.
        task: task object; ``begin_epoch`` is called once before the loop.
        epoch_itr: epoch batch iterator for the training data.
        max_update: stop once this many updates have been performed.
        model: module whose sub-modules are scanned for attention norms;
            when ``None`` the trainer's model is used.

    Returns:
        The validation losses from the last ``validate_and_save`` call, or
        ``[None]`` if no training step completed.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    # Use the per-epoch update frequency, falling back to the last entry once
    # the epoch index runs past the configured list.
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')

    # Defined up front so the code after the loop still works when the
    # iterator is empty or every step is skipped (train_step returned None).
    num_updates = trainer.get_num_updates()
    valid_losses = [None]

    for i, samples in enumerate(progress):
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        # NOTE(review): placed at loop level (not under the log-interval
        # check); the original nesting could not be recovered -- confirm.
        if i == 0:
            print('epoch: ', epoch_itr.epoch)
            # ``model`` is optional in the signature; fall back to the
            # trainer's model instead of failing on ``None.modules()``.
            norm_source = model if model is not None else trainer.get_model()
            endeattn_norm = []
            selfattn_norm = []
            for m in norm_source.modules():
                if getattr(m, 'selfattn_norm', None) is not None:
                    selfattn_norm.append(m.selfattn_norm)
                if getattr(m, 'endeattn_norm', None) is not None:
                    endeattn_norm.append(m.endeattn_norm)
            print('self attention norms: ', selfattn_norm)
            print('en/decoder attn norms:', endeattn_norm)

        valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return valid_losses
def validate_iw(args, trainer, task, epoch_itr, subsets, prune=-1, mode='iw'):
    """Run ``trainer.valid_iw_step`` over every batch of each validation subset.

    The pass is skipped entirely (returning ``[0]``) when ``mode`` is
    ``'none'`` or ``'time'``, or when ``args.criterior`` is ``'lm_baseline'``.
    Otherwise sampling is switched off on every subset's dataset, each subset
    is iterated once without shuffling, and the aggregated stats are printed
    under the tag ``'valid_iw'``.

    Args:
        args: parsed command-line namespace.
        trainer: object providing ``valid_iw_step``, ``get_model`` and
            ``get_num_updates``.
        task: task object providing datasets and batch iterators.
        epoch_itr: not used by this function.
        subsets: iterable of validation subset names.
        prune: when positive, a prune index is set on the model before each
            subset and reset once at the end.
        mode: forwarded to ``trainer.valid_iw_step``.

    Returns:
        ``[0]`` when skipped; otherwise the ``valid_losses`` list, which stays
        empty because nothing is appended to it.
    """
    skip_validation = mode in ('none', 'time') or args.criterior == 'lm_baseline'
    if skip_validation:
        return [0]

    # top k instead of sampling to approximate sum of prototypes for evaluation
    for name in subsets:
        task.dataset(name).set_sampling(False)

    fixed_seed = args.fixed_validation_seed
    if fixed_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(fixed_seed)

    valid_losses = []
    for name in subsets:
        # Initialize data iterator
        dataset = task.dataset(name)
        max_positions = utils.resolve_max_positions(
            task.max_positions(),
            trainer.get_model().max_positions(),
        )
        batch_itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=max_positions,
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        )
        epoch_batches = batch_itr.next_epoch_itr(shuffle=False)

        if distributed_utils.is_master(args):
            tb_logdir = args.tensorboard_logdir
        else:
            tb_logdir = None
        if args.no_progress_bar:
            fallback_format = 'simple'
        else:
            fallback_format = 'tqdm'
        bar = progress_bar.progress_bar(
            epoch_batches,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=1,
            prefix=f"valid on '{name}' subset",
            tensorboard_logdir=tb_logdir,
            default_log_format=fallback_format,
        )

        if prune > 0:
            index_map = trainer.get_model().set_prune_index(prune)
            task.set_index_map(index_map)

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for batch in bar:
                trainer.valid_iw_step(batch, mode=mode)

        # log validation stats
        smoothed = agg.get_smoothed_values()
        stats = get_valid_stats(args, trainer, smoothed)
        bar.print(stats, tag='valid_iw', step=trainer.get_num_updates())

        # valid_losses.append(stats[args.best_checkpoint_metric])

    # NOTE(review): reset placed after the subset loop; the original nesting
    # could not be recovered from the collapsed source -- confirm.
    if prune > 0:
        trainer.get_model().reset_prune_index()
        task.reset_index_map()

    return valid_losses
def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask,
          epoch_itr) -> Tuple[List[Optional[float]], bool]:
    """Train the model for one epoch and return validation losses.

    Args:
        cfg: run configuration.
        trainer: trainer driving ``train_step`` and update counting.
        task: the task being trained (forwarded to ``validate_and_save``).
        epoch_itr: epoch batch iterator for the training data.

    Returns:
        A tuple ``(valid_losses, should_stop)`` taken from the last
        ``validate_and_save`` call; ``([None], False)`` when the epoch
        yielded no batches.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
    )
    # Use the per-epoch update frequency, falling back to the last entry once
    # the epoch index runs past the configured list.
    update_freq = (
        cfg.optimization.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(cfg.optimization.update_freq)
        else cfg.optimization.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if cfg.common.tpu:
        itr = utils.tpu_data_loader(itr)
    # Per the is_master checks below, only the master process receives the
    # tensorboard / wandb / azureml settings.
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            cfg.common.tensorboard_logdir
            if distributed_utils.is_master(cfg.distributed_training)
            else None
        ),
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
        wandb_project=(
            cfg.common.wandb_project
            if distributed_utils.is_master(cfg.distributed_training)
            else None
        ),
        wandb_run_name=os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
        ),
        azureml_logging=(
            cfg.common.azureml_logging
            if distributed_utils.is_master(cfg.distributed_training)
            else False
        ),
    )
    progress.update_config(_flatten_config(cfg))

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = cfg.dataset.valid_subset.split(",")
    should_stop = False
    # Defined up front so the return below is valid even when the iterator
    # yields no batches.
    valid_losses = [None]
    num_updates = trainer.get_num_updates()
    logger.info("Start iterating over samples")
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
            "train_step-%d" % i
        ):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % cfg.common.log_interval == 0:
                stats = get_training_stats(
                    metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        # Validation/checkpointing is attempted after every step, including
        # skipped ones, so the end-of-epoch check is not missed.
        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(
        epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
def _main(args, output_file):
    """Decode ``args.gen_subset`` with the model(s) at ``args.path`` and score it.

    The function loads the acoustic model ensemble and, when ``args.lm_path``
    is given, one or two language models for LM fusion. It then runs the
    task's generator over the subset, prints references (``T-``) and
    hypotheses (``H-``) to ``output_file`` unless ``args.quiet`` is set, and
    writes result files into ``args.results_path``.

    Args:
        args: parsed command-line namespace. Note that it is mutated here:
            ``max_tokens`` may receive a default, ``score_reference`` is
            forced to ``False`` and ``print_alignment`` is temporarily
            disabled while the generator is built.
        output_file: stream used both for logging and for the printed
            references / hypotheses.

    Returns:
        The ``wer.Scorer`` that accumulated the predictions (and evaluations,
        when targets are available).
    """
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=output_file,
    )
    logger = logging.getLogger("espresso.speech_recognize")
    if output_file is not sys.stdout:  # also print to stdout
        logger.addHandler(logging.StreamHandler(sys.stdout))

    print_options_meaning_changes(args, logger)

    utils.import_user_module(args)

    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 12000
    logger.info(args)

    # Fix seed for stochastic decoding
    if args.seed is not None and not args.no_seed_provided:
        np.random.seed(args.seed)
        utils.set_torch_seed(args.seed)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset split
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionary
    dictionary = task.target_dictionary

    overrides = ast.literal_eval(args.model_overrides)

    # Load ensemble
    logger.info("loading model(s) from {}".format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args.path),
        arg_overrides=overrides,
        task=task,
        suffix=getattr(args, "checkpoint_suffix", ""),
        strict=(args.checkpoint_shard_count == 1),
        num_shards=args.checkpoint_shard_count,
    )

    if args.lm_path is not None:
        overrides["data"] = args.data

        try:
            lms, _ = checkpoint_utils.load_model_ensemble(
                utils.split_paths(args.lm_path),
                arg_overrides=overrides,
                task=None,
            )
        except Exception:
            # Give a hint about the most likely cause, then propagate.
            logger.warning(
                f"Failed to load language model! Please make sure that the language model dict is the same "
                f"as target dict and is located in the data dir ({args.data})")
            raise

        assert len(lms) == 1 or len(lms) == 2  # Multi-level LM expects two LMs
    else:
        lms = [None]

    # Wrap the loaded LM(s) for fusion. ``lms`` holds at most two entries, so
    # the in-loop ``del lms[i]`` can only remove the final element.
    for i, m in enumerate(lms):
        if m is None:
            continue
        if hasattr(m, "is_wordlm") and m.is_wordlm:
            # assume subword LM comes before word LM
            if i > 0 and isinstance(lms[i - 1], FairseqLanguageModel):
                lms[i - 1] = MultiLevelLanguageModel(
                    m,
                    lms[i - 1],
                    subwordlm_weight=args.subwordlm_weight,
                    oov_penalty=args.oov_penalty,
                    open_vocab=not args.disable_open_vocab,
                )
                del lms[i]
                logger.info("LM fusion with Multi-level LM")
            else:
                lms[i] = TensorizedLookaheadLanguageModel(
                    m,
                    dictionary,
                    oov_penalty=args.oov_penalty,
                    open_vocab=not args.disable_open_vocab,
                )
                logger.info("LM fusion with Look-ahead Word LM")
        else:
            assert isinstance(m, FairseqLanguageModel)
            logger.info("LM fusion with Subword LM")
    if args.lm_weight != 0.0:
        logger.info("using LM fusion with lm-weight={:.2f}".format(
            args.lm_weight))

    # Optimize ensemble for generation
    for model in chain(models, lms):
        if model is None:
            continue
        if args.fp16:
            model.half()
        if use_cuda and not args.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(args)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[
                model.max_positions() if hasattr(model, "encoder")
                else (None, model.max_positions())
                for model in models
            ]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
        data_buffer_size=args.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=("tqdm" if not args.no_progress_bar else "none"),
    )

    # Initialize generator
    if args.match_source_len:
        logger.warning(
            "The option match_source_len is not applicable to speech recognition. Ignoring it."
        )
    gen_timer = StopwatchMeter()

    extra_gen_cls_kwargs = {
        "lm_model": lms[0],
        "lm_weight": args.lm_weight,
        "eos_factor": args.eos_factor,
    }
    args.score_reference = False  # not applicable for ASR
    # print_alignment is switched off only while the generator is built and
    # restored right after, because it is still consulted below for plotting.
    temp_val = args.print_alignment
    args.print_alignment = False  # not applicable for ASR
    generator = task.build_generator(models, args,
                                     extra_gen_cls_kwargs=extra_gen_cls_kwargs)
    args.print_alignment = temp_val

    # Handle tokenization and BPE
    tokenizer = task.build_tokenizer(args)
    bpe = task.build_bpe(args)

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Generate and compute WER
    scorer = wer.Scorer(dictionary, wer_output_filter=args.wer_output_filter)
    num_sentences = 0
    has_target = True
    # Stays None until an attention plot is actually written, so the summary
    # log below never refers to an unbound name.
    save_dir = None
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if "net_input" not in sample:
            continue

        prefix_tokens = None
        if args.prefix_size > 0:
            prefix_tokens = sample["target"][:, :args.prefix_size]

        constraints = None
        if "constraints" in sample:
            constraints = sample["constraints"]

        gen_timer.start()
        hypos = task.inference_step(generator, models, sample,
                                    prefix_tokens=prefix_tokens,
                                    constraints=constraints)
        num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        # obtain nonpad mask of encoder output to plot attentions
        if args.print_alignment:
            net_input = sample["net_input"]
            src_tokens = net_input["src_tokens"]
            output_lengths = models[0].encoder.output_lengths(
                net_input["src_lengths"])
            nonpad_idxs = sequence_mask(
                output_lengths,
                models[0].encoder.output_lengths(src_tokens.size(1)))

        for i in range(len(sample["id"])):
            has_target = sample["target"] is not None
            utt_id = sample["utt_id"][i]

            # Retrieve the original sentences
            if has_target:
                target_str = sample["target_raw_text"][i]
                if not args.quiet:
                    detok_target_str = decode_fn(target_str)
                    print("T-{}\t{}".format(utt_id, detok_target_str),
                          file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:args.nbest]):
                hypo_str = dictionary.string(
                    hypo["tokens"].int().cpu(),
                    bpe_symbol=None,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                        generator),
                )  # not removing bpe at this point
                detok_hypo_str = decode_fn(hypo_str)
                if not args.quiet:
                    score = hypo["score"] / math.log(2)  # convert to base 2
                    print("H-{}\t{}\t{}".format(utt_id, detok_hypo_str, score),
                          file=output_file)

                # Score and obtain attention only the top hypothesis
                if j == 0:
                    if args.print_alignment and hypo["attention"] is not None:
                        # src_len x tgt_len
                        attention = hypo["attention"][nonpad_idxs[i]].float().cpu()
                        save_dir = os.path.join(args.results_path, "attn_plots")
                        os.makedirs(save_dir, exist_ok=True)
                        plot_attention(attention, detok_hypo_str, utt_id,
                                       save_dir)
                    scorer.add_prediction(utt_id, hypo_str)
                    if has_target:
                        scorer.add_evaluation(utt_id, target_str, hypo_str)

        wps_meter.update(num_generated_tokens)
        progress.log({"wps": round(wps_meter.avg)})
        num_sentences += sample[
            "nsentences"] if "nsentences" in sample else sample["id"].numel()

    logger.info("NOTE: hypothesis and token scores are output in base 2")
    # Guard both rates: nothing is timed when no batch was decoded, and the
    # unguarded divisions would raise ZeroDivisionError.
    sentences_per_sec = (num_sentences / gen_timer.sum
                         if gen_timer.sum > 0 else 0.0)
    tokens_per_sec = 1. / gen_timer.avg if gen_timer.avg > 0 else 0.0
    logger.info(
        "Recognized {} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)"
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                sentences_per_sec, tokens_per_sec))
    if args.print_alignment and save_dir is not None:
        logger.info("Saved attention plots in " + save_dir)

    if has_target:
        scorer.add_ordered_utt_list(task.datasets[args.gen_subset].tgt.utt_ids)

    fn = "decoded_char_results.txt"
    with open(os.path.join(args.results_path, fn), "w", encoding="utf-8") as f:
        f.write(scorer.print_char_results())
        logger.info("Decoded char results saved as " + f.name)

    fn = "decoded_results.txt"
    with open(os.path.join(args.results_path, fn), "w", encoding="utf-8") as f:
        f.write(scorer.print_results())
        logger.info("Decoded results saved as " + f.name)

    if has_target:
        header = "Recognize {} with beam={}: ".format(args.gen_subset,
                                                      args.beam)
        fn = "wer"
        with open(os.path.join(args.results_path, fn), "w",
                  encoding="utf-8") as f:
            res = "WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%".format(
                *(scorer.wer()))
            logger.info(header + res)
            f.write(res + "\n")
            logger.info("WER saved in " + f.name)

        fn = "cer"
        with open(os.path.join(args.results_path, fn), "w",
                  encoding="utf-8") as f:
            res = "CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%".format(
                *(scorer.cer()))
            # pad so the CER line aligns under the WER line in the log
            logger.info(" " * len(header) + res)
            f.write(res + "\n")
            logger.info("CER saved in " + f.name)

        fn = "aligned_results.txt"
        with open(os.path.join(args.results_path, fn), "w",
                  encoding="utf-8") as f:
            f.write(scorer.print_aligned_results())
            logger.info("Aligned results saved as " + f.name)
    return scorer