def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses.

    Args:
        args: parsed command-line namespace; batch size, sharding, seed and
            invalid-input options are read from it.
        trainer: trainer wrapping the model; its validation meters are reset
            here and updated through ``trainer.valid_step``.
        task: task supplying the datasets and batch iterators.
        epoch_itr: training epoch iterator; only ``epoch_itr.epoch`` is used,
            to label the progress bar.
        subsets: names of the dataset splits to evaluate.

    Returns:
        list: the ``valid_loss`` statistic of each subset, in the order of
        ``subsets``.
    """
    # Keys the trainer already tracks in its own meters; not duplicated here.
    skipped_keys = ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple'
        )

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)

            for k, v in log_output.items():
                if k in skipped_keys:
                    continue
                # Weight loss-like values by the batch's sample size, as
                # train() does, so batches of different sizes do not skew the
                # average. A weight of 1 is used if no 'sample_size' is logged.
                if 'loss' in k:
                    extra_meters[k].update(v, log_output.get('sample_size', 1))
                else:
                    extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats)

        valid_losses.append(stats['valid_loss'])
    return valid_losses
def train(args, trainer, task, epoch_itr, summary_writer=None):
    """Train the model for one epoch.

    Args:
        args: parsed command-line namespace.
        trainer: trainer wrapping the model, optimizer and meters.
        task: task forwarded to ``validate`` for mid-epoch checkpoints.
        epoch_itr: epoch iterator providing this epoch's batches.
        summary_writer: optional object with a ``log_stats(tag, stats, step)``
            method. When ``None`` (the default) no summary logging happens.
    """
    # Update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(fix_batches_to_gpus=args.fix_batches_to_gpus)
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (
            args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        ):
            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        # summary_writer is optional (default None): only log when one is
        # provided, otherwise this would raise AttributeError.
        if summary_writer is not None and num_updates % args.log_interval == 0:
            summary_writer.log_stats('train', stats, num_updates)

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def main(args, task=None, model_state=None):
    """Run speech-recognition inference on ``args.gen_subset``.

    Depending on the arguments this either decodes the subset and computes
    the word error rate, or dumps per-utterance emissions
    (``args.dump_emissions``) or encoder features (``args.dump_features``)
    to a ``.npy`` file instead of decoding.

    Args:
        args: parsed command-line namespace.
        task: an already set-up task. When ``None`` a task is created from
            ``args`` and ``args.gen_subset`` is loaded.
        model_state: optional state forwarded to ``load_models_and_criterions``.

    Returns:
        tuple: ``(task, wer)``. ``wer`` is ``None`` when emissions/features
        were dumped or when no reference tokens were scored.
    """
    check_args(args)

    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 4000000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    if task is None:
        # Load dataset splits
        task = tasks.setup_task(args)
        task.load_dataset(args.gen_subset)

    logger.info(
        "| {} {} {} examples".format(
            args.data, args.gen_subset, len(task.dataset(args.gen_subset))
        )
    )

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info("| decoding with criterion {}".format(args.criterion))

    # Load ensemble
    if args.load_emissions:
        models, criterions = [], []
    else:
        logger.info("| loading model(s) from {}".format(args.path))
        # NOTE: eval() of a command-line string; only safe for trusted input.
        models, criterions, _ = load_models_and_criterions(
            args.path,
            data_path=args.data,
            arg_overrides=eval(args.model_overrides),  # noqa
            task=task,
            model_state=model_state,
        )
        optimize_models(args, use_cuda, models)

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        trans = criterions[0].asg.trans.data
        args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task, models)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(models, args)

    if args.load_emissions:
        generator = ExistingEmissionsDecoder(
            generator, np.load(args.load_emissions, allow_pickle=True)
        )
        logger.info("loaded emissions from " + args.load_emissions)

    num_sentences = 0

    if args.results_path is not None and not os.path.exists(args.results_path):
        os.makedirs(args.results_path)

    if args.dump_emissions:
        emissions = {}
    if args.dump_features:
        features = {}
        models[0].bert.proj = None
    else:
        res_files = prepare_result_files(args)
    errs_t = 0
    lengths_t = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, : args.prefix_size]

            gen_timer.start()
            if args.dump_emissions:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    emm = models[0].get_normalized_probs(encoder_out, log_probs=True)
                    emm = emm.transpose(0, 1).cpu().numpy()
                    for i, utt_id in enumerate(sample["id"]):
                        emissions[utt_id.item()] = emm[i]
                    continue
            elif args.dump_features:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
                    padding_mask = encoder_out["encoder_padding_mask"]
                    for i, utt_id in enumerate(sample["id"]):
                        padding = (
                            padding_mask[i].cpu().numpy()
                            if padding_mask is not None
                            else None
                        )
                        features[utt_id.item()] = (feat[i], padding)
                    continue
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                speaker = None
                # Prefer "target_label" as the reference when the batch has it.
                if "target_label" in sample:
                    toks = sample["target_label"][i, :]
                else:
                    toks = sample["target"][i, :]
                target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
                # Process top predictions
                errs, length = process_predictions(
                    args, hypos[i], None, tgt_dict, target_tokens, res_files,
                    speaker, sample_id,
                )
                errs_t += errs
                lengths_t += length

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += (
                sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
            )

    wer = None
    if args.dump_emissions:
        emm_arr = [emissions[i] for i in range(len(emissions))]
        np.save(args.dump_emissions, emm_arr)
        logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}")
    elif args.dump_features:
        feat_arr = [features[i] for i in range(len(features))]
        np.save(args.dump_features, feat_arr)
        logger.info(f"saved {len(features)} features to {args.dump_features}")
    else:
        if lengths_t > 0:
            wer = errs_t * 100.0 / lengths_t
            logger.info(f"WER: {wer}")

        # The timer may have recorded nothing (e.g. every batch lacked
        # "net_input"); avoid dividing by zero in the throughput figures.
        sentences_per_sec = num_sentences / gen_timer.sum if gen_timer.sum > 0 else 0.0
        tokens_per_sec = (
            1.0 / gen_timer.avg if gen_timer.n > 0 and gen_timer.avg > 0 else 0.0
        )
        logger.info(
            "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
            "sentences/s, {:.2f} tokens/s)".format(
                num_sentences,
                gen_timer.n,
                gen_timer.sum,
                sentences_per_sec,
                tokens_per_sec,
            )
        )
        logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
    return task, wer
def main(args):
    """Translate ``args.gen_subset`` with a model ensemble and score with BLEU.

    Prints source/target/hypothesis lines (unless ``--quiet``), throughput,
    and the BLEU result when references are available.

    Args:
        args: parsed command-line namespace for generation.

    Returns:
        The BLEU scorer (``bleu.SacrebleuScorer`` or ``bleu.Scorer``) after
        every top hypothesis has been added.
    """
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    # NOTE: eval() of a command-line string; only safe for trusted input.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
        bert_ratio=args.bert_ratio if args.change_ratio else None,
        encoder_ratio=args.encoder_ratio if args.change_ratio else None,
        geargs=args,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            # hypos[sentence_index][beam_rank] is a dict per hypothesis.
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(
                            target_tokens, args.remove_bpe, escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions. hypos[i] is this sentence's beam, so
                # it is limited by --nbest alone; len(hypos) is the batch size
                # and must not cap how many hypotheses are emitted. A separate
                # index (j) keeps the sentence index i intact.
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'].int().cpu()
                        if hypo['alignment'] is not None else None,
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id,
                            ' '.join(map(
                                lambda x: '{:.4f}'.format(x),
                                hypo['positional_scores'].tolist(),
                            ))
                        ))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id,
                                ' '.join(map(lambda x: str(utils.item(x)), alignment))
                            ))

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    # Guard the rates against an empty run (nothing generated or timed).
    sentences_per_sec = num_sentences / gen_timer.sum if gen_timer.sum > 0 else 0.0
    tokens_per_sec = 1. / gen_timer.avg if gen_timer.n > 0 and gen_timer.avg > 0 else 0.0
    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, sentences_per_sec, tokens_per_sec))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(
            args.gen_subset, args.beam, scorer.result_string()))

    return scorer
def main(parsed_args):
    """Evaluate a language-model ensemble and print loss and perplexity.

    Scores every sentence of ``args.gen_subset`` with ``SequenceScorer``.
    Optionally it merges BPE pieces into words (``--remove-bpe``) and prints
    per-word probabilities or aggregated word statistics.

    Args:
        parsed_args: parsed command-line namespace. Most of its values
            override the arguments stored in the checkpoint.
    """
    assert parsed_args.path is not None, '--path required for evaluation!'

    utils.import_user_module(parsed_args)

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    # NOTE: eval() of a command-line string; only safe for trusted input.
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(':'),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
    )

    # Copy command-line args over the checkpoint args, except for these
    # options, which keep the checkpoint's values.
    for arg in vars(parsed_args).keys():
        if arg not in {
            'self_target', 'future_target', 'past_target', 'tokens_per_sample',
            'output_size_dictionary', 'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
            # Modules whose name contains 'layer_norm' are cast back to fp32.
            for name, _module in model.named_modules():
                if 'layer_norm' in name:
                    _module.float()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    print('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            # Indices of dictionary entries that end with the BPE continuation marker.
            bpe_toks = {
                idx
                for idx in range(len(task.source_dictionary))
                if task.source_dictionary[idx].endswith(bpe_cont)
            }
            bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        for sample in t:
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for i, hypos_i in enumerate(hypos):
                hypo = hypos_i[0]
                sample_id = sample['id'][i]

                tokens = hypo['tokens']
                pos_scores = hypo['positional_scores'].float()

                if args.add_bos_token:
                    assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                    tokens = tokens[1:]
                    pos_scores = pos_scores[1:]

                skipped_toks = 0
                if bpe_toks is not None:
                    # Fold each BPE continuation piece's score into the next
                    # position. Bound by the (possibly BOS-stripped) length so
                    # pos_scores[j + 1] never indexes past the end.
                    for j in range(len(tokens) - 1):
                        if tokens[j].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[j + 1] += pos_scores[j]
                            pos_scores[j] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print(
                        '| Skipping tokens with inf scores:',
                        task.target_dictionary.string(tokens[inf_scores.nonzero()])
                    )
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for j in range(len(tokens)):
                        w_ind = tokens[j].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[j].item()))

                            # Score of the next position with a non-zero score, if any.
                            next_prob = None
                            ind = j + 1
                            while ind < len(tokens):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                                pos_scores[j].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        # NOTE(review): '{:2f}' sets a minimum width of 2, not two
                        # decimals ('{:.2f}'); output format kept unchanged.
                        print(
                            str(int(sample_id)) + " "
                            + ('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
                        )

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

    # count is 0 when nothing was scored; report NaN instead of dividing by zero.
    avg_nll_loss = -score_sum / count if count > 0 else float('nan')
    tokens_per_sec = 1. / gen_timer.avg if gen_timer.n > 0 and gen_timer.avg > 0 else 0.0
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, tokens_per_sec))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            print(ws)
def eval_tune_loss(args, trainer, task, subset, extra_state):
    """Compute the validation loss on ``subset`` and update early-stopping state.

    Runs the model over every batch of ``subset`` and logs running and final
    validation statistics. The final loss and perplexity are stored in
    ``extra_state["tune_eval"]``, together with the lowest loss seen so far
    and the number of evaluations since that best loss.

    Returns:
        tuple: ``(extra_state, stop_due_to_tune_loss)``. The flag is True when
        ``args.stop_no_best_validate_loss`` is non-negative and more than that
        many evaluations have passed without a new best loss.
    """
    batch_iter = task.get_batch_iterator(
        dataset=task.dataset(subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), trainer.get_model().max_positions()
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.build_progress_bar(
        args=args,
        iterator=batch_iter,
        epoch=extra_state["epoch"],
        prefix=f"valid on '{subset}' subset",
        no_progress_bar="simple",
    )

    # Start from clean validation meters.
    for meter_name in ("valid_loss", "valid_nll_loss"):
        meter = trainer.get_meter(meter_name)
        if meter is not None:
            meter.reset()

    # Keys the trainer already tracks itself.
    ignored_keys = ("loss", "nll_loss", "ntokens", "nsentences", "sample_size")
    extra_meters = defaultdict(lambda: AverageMeter())
    for sample in progress:
        log_output = trainer.valid_step(sample)
        running_stats = get_valid_stats(trainer)
        for key, value in log_output.items():
            if key in ignored_keys:
                continue
            if "loss" in key:
                extra_meters[key].update(value, log_output["sample_size"])
            else:
                extra_meters[key].update(value)
            running_stats[key] = extra_meters[key].avg
        progress.log(running_stats)

    final_stats = get_valid_stats(trainer)
    for key, meter in extra_meters.items():
        final_stats[key] = meter.avg
    progress.print(final_stats)

    # Alias of the same dict: writes below land in extra_state["tune_eval"].
    tune_eval = extra_state["tune_eval"]
    tune_eval["loss"] = final_stats["valid_loss"]
    tune_eval["perplexity"] = final_stats["valid_ppl"]

    best_so_far = tune_eval["lowest_loss"]
    if best_so_far is None or tune_eval["loss"] < best_so_far:
        tune_eval["lowest_loss"] = tune_eval["loss"]
        tune_eval["num_since_best"] = 0
    else:
        tune_eval["num_since_best"] += 1

    stop_due_to_tune_loss = (
        args.stop_no_best_validate_loss >= 0
        and tune_eval["num_since_best"] > args.stop_no_best_validate_loss
    )
    if stop_due_to_tune_loss:
        print(
            f"Stopping training due to eval tune loss stagnation - last best "
            f"eval tune loss of {extra_state['tune_eval']['lowest_loss']} "
            f"(current loss: {extra_state['tune_eval']['loss']}) "
            f"was {extra_state['tune_eval']['num_since_best']} validations ago."
        )
    return extra_state, stop_due_to_tune_loss
def main(args):
    """Translate ``args.gen_subset`` (or score its references) and report BLEU.

    With ``--score-reference`` the references are scored by
    ``SequenceScorer``. Otherwise ``SequenceGenerator`` runs beam
    search/sampling. Per-sentence S/T/H/P/A lines are printed unless
    ``--quiet`` is set, followed by throughput and, when references exist,
    the BLEU result.

    Args:
        args: parsed command-line namespace for generation.
    """
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Data and vocabularies.
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Model ensemble.
    print('| loading model(s) from {}'.format(args.path))
    # NOTE: eval() of a command-line string; only safe for trusted input.
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
    )

    beamable_size = None if args.no_beamable_mm else args.beam
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=beamable_size,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # None when no unknown-word replacement is requested; empty when no
    # alignment dictionary path is given.
    align_dict = utils.load_align_dict(args.replace_unk)

    # Batches for the (possibly sharded) evaluation split.
    positions_limit = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=positions_limit,
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        generator_options = dict(
            beam_size=args.beam,
            minlen=args.min_len,
            stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen,
            unk_penalty=args.unkpen,
            sampling=args.sampling,
            sampling_topk=args.sampling_topk,
            sampling_temperature=args.sampling_temperature,
            diverse_beam_groups=args.diverse_beam_groups,
            diverse_beam_strength=args.diverse_beam_strength,
            match_source_len=args.match_source_len,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
        )
        translator = SequenceGenerator(models, task.target_dictionary, **generator_options)

    if use_cuda:
        translator.cuda()

    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        if args.score_reference:
            translations = translator.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                t,
                maxlen_a=args.max_len_a,
                maxlen_b=args.max_len_b,
                cuda=use_cuda,
                timer=gen_timer,
                prefix_size=args.prefix_size,
            )

        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:
            has_target = target_tokens is not None
            if has_target:
                target_tokens = target_tokens.int().cpu()
            else:
                target_tokens = None

            # Prefer the raw text when an alignment dictionary is in use;
            # otherwise detokenize from ids.
            if align_dict is None:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                if has_target:
                    target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
            else:
                src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))

            for rank, hypo in enumerate(hypos[:args.nbest]):
                raw_alignment = hypo['alignment']
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=None if raw_alignment is None else raw_alignment.int().cpu(),
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                    score_text = ' '.join('{:.4f}'.format(s) for s in hypo['positional_scores'].tolist())
                    print('P-{}\t{}'.format(sample_id, score_text))
                    if args.print_alignment:
                        align_text = ' '.join(str(utils.item(a)) for a in alignment)
                        print('A-{}\t{}'.format(sample_id, align_text))

                # BLEU counts only the best-ranked hypothesis.
                if rank == 0 and has_target:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Re-tokenize the reference so unk replacement and
                        # BPE removal apply to both sides of the comparison.
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str, tgt_dict, add_if_not_exist=True)
                    scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
def validate(args, trainer, dataset, subset, extra_state):
    """Evaluate the model on the validation set and return the average loss.

    Also tracks early-stopping state in ``extra_state["validate"]`` (the lowest
    loss seen and the number of validations since it was reached).

    Args:
        args: parsed command-line namespace.
        trainer: trainer wrapping the model and its meters.
        dataset: object providing ``eval_dataloader`` for ``subset``.
        subset: name of the split to evaluate.
        extra_state: mutable training state. ``extra_state["epoch"]`` is read,
            and ``extra_state["validate"]`` is created or updated.

    Returns:
        tuple: ``(val_loss, val_ppl, stop_due_to_val_loss)``.
    """
    epoch = extra_state["epoch"]
    # Initialize dataloader
    max_positions_valid = (
        trainer.get_model().max_encoder_positions(),
        trainer.get_model().max_decoder_positions(),
    )
    itr = dataset.eval_dataloader(
        subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions_valid,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )
    progress = progress_bar.build_progress_bar(
        args, itr, epoch, prefix=f"valid on '{subset}' subset", no_progress_bar="simple"
    )

    # reset validation loss meters
    for k in ["valid_loss", "valid_nll_loss"]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    for sample in progress:
        log_output = trainer.valid_step(sample)

        # log mid-validation stats
        stats = get_valid_stats(trainer)
        for k, v in log_output.items():
            if k in ["loss", "nll_loss"]:
                continue
            if "loss" in k:
                extra_meters[k].update(v, log_output["sample_size"])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

    # log validation stats
    stats = get_valid_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    val_loss = stats["valid_loss"]
    val_ppl = stats["valid_ppl"]

    if (
        "validate" not in extra_state
        or val_loss < extra_state["validate"]["lowest_loss"]
    ):
        extra_state["validate"] = {"lowest_loss": val_loss, "num_since_best": 0}
    else:
        extra_state["validate"]["num_since_best"] += 1

    stop_due_to_val_loss = False
    if (
        args.stop_no_best_validate_loss >= 0
        and extra_state["validate"]["num_since_best"] > args.stop_no_best_validate_loss
    ):
        stop_due_to_val_loss = True
        # Trailing space after ")" so the message does not read "...)was N".
        print(
            f"Stopping training due to validation score stagnation - last best "
            f"validation loss of {extra_state['validate']['lowest_loss']} (current loss: {val_loss}) "
            f"was {extra_state['validate']['num_since_best']} validations ago."
        )
    return val_loss, val_ppl, stop_due_to_val_loss
def validate(args, trainer, dataset, subset, epoch):
    """Evaluate the model on the validation set and return the average loss.

    Early-stopping state (``lowest_loss``, ``num_since_best``) is stored as
    attributes on whatever object the global name ``validate`` refers to at
    call time, so it persists across calls within the process.

    Args:
        args: parsed command-line namespace.
        trainer: trainer wrapping the model and its meters.
        dataset: object providing ``eval_dataloader`` for ``subset``.
        subset: name of the split to evaluate.
        epoch: current epoch number, used to label the progress bar.

    Returns:
        tuple: ``(val_loss, val_ppl, stop_due_to_val_loss)``.
    """
    # Initialize dataloader
    max_positions_valid = (
        trainer.get_model().max_encoder_positions(),
        trainer.get_model().max_decoder_positions(),
    )
    itr = dataset.eval_dataloader(
        subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions_valid,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )
    progress = progress_bar.build_progress_bar(
        args, itr, epoch,
        prefix=f'valid on \'{subset}\' subset',
        no_progress_bar='simple')

    # reset validation loss meters
    for k in ['valid_loss', 'valid_nll_loss']:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    for sample in progress:
        log_output = trainer.valid_step(sample)

        # log mid-validation stats
        stats = get_valid_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss']:
                continue
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

    # log validation stats
    stats = get_valid_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    val_loss = stats['valid_loss']
    val_ppl = stats['valid_ppl']
    if not hasattr(validate, 'lowest_loss') or val_loss < validate.lowest_loss:
        validate.lowest_loss = val_loss
        validate.num_since_best = 0
    else:
        # Both attributes are set together above; default to 0 defensively.
        validate.num_since_best = getattr(validate, 'num_since_best', 0) + 1

    stop_due_to_val_loss = False
    if (args.stop_no_best_validate_loss >= 0
            and validate.num_since_best > args.stop_no_best_validate_loss):
        stop_due_to_val_loss = True
        # Trailing space after ")" so the message does not read "...)was N".
        print(
            f'Stopping training due to validation score stagnation - last best '
            f'validation loss of {validate.lowest_loss} (current loss: {val_loss}) '
            f'was {validate.num_since_best} validations ago.')
    return val_loss, val_ppl, stop_due_to_val_loss
def setup_epoch(
    args,
    epoch,
    batch_offset,
    trainer,
    dataset,
):
    """Prepare the batch iterator, progress bar and meters for one epoch.

    The torch RNG is seeded with ``args.seed + epoch``. The training meters
    on ``trainer`` are reset.

    Returns:
        tuple: ``(itr, progress, extra_meters)``. ``itr`` yields from
        ``progress`` starting at batch ``batch_offset``, and ``extra_meters``
        is a fresh ``defaultdict`` of ``AverageMeter``.
    """
    # Deriving the seed from the epoch keeps data order reproducible when
    # resuming from a checkpoint.
    seed = args.seed + epoch
    torch.manual_seed(seed)

    # Training and validation position limits may differ (e.g. RNNs can
    # accept longer inputs at test time), so clamp to the training limits.
    src_limit = min(args.max_source_positions, trainer.get_model().max_encoder_positions())
    tgt_limit = min(args.max_target_positions, trainer.get_model().max_decoder_positions())
    max_positions_train = (src_limit, tgt_limit)

    train_loader = dataset.train_dataloader(
        args.train_subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions_train,
        seed=seed,
        epoch=epoch,
        sample_without_replacement=args.sample_without_replacement,
        sort_by_source_size=(epoch <= args.curriculum),
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )
    progress = progress_bar.build_progress_bar(args, train_loader, epoch, no_progress_bar='simple')
    # Skip the batches already consumed before batch_offset.
    itr = itertools.islice(progress, batch_offset, None)

    train_meter_names = ('train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip')
    for name in train_meter_names:
        meter = trainer.get_meter(name)
        if meter is not None:
            meter.reset()

    extra_meters = collections.defaultdict(AverageMeter)
    return itr, progress, extra_meters
def _generate_score(models, args, dataset, dataset_split):
    """Translate ``dataset_split`` with an ensemble and compute BLEU.

    Args:
        models: already-loaded ensemble members.
        args: parsed command-line arguments.
        dataset: object providing ``splits``, ``src_dict``, ``dst_dict`` and
            ``eval_dataloader``.
        dataset_split: name of the split to translate.

    Returns:
        tuple: ``(scorer, num_sentences, gen_timer)``.

    Raises:
        ValueError: if sharding is requested with ``args.shard_id`` outside
            ``[0, args.num_shards)``.
    """
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Models are loaded by the caller; this only reports where they came from.
    if not args.quiet:
        print("| loading model(s) from {}".format(", ".join(args.path)))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam
        )

    # Initialize generator
    model_weights = None
    if args.model_weights:
        model_weights = [float(w.strip()) for w in args.model_weights.split(",")]

    use_char_source = isinstance(models[0], char_source_model.CharSourceModel)
    translator = beam_decode.SequenceGenerator(
        models,
        beam_size=args.beam,
        stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen,
        unk_penalty=args.unkpen,
        word_reward=args.word_reward,
        model_weights=model_weights,
        use_char_source=use_char_source,
    )
    if use_cuda:
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Keep track of translations, initialised to empty strings
    translated_sentences = [""] * len(dataset.splits[dataset_split])

    # Generate and compute BLEU score
    scorer = bleu.Scorer(
        dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk()
    )
    max_positions = min(model.max_encoder_positions() for model in models)
    itr = dataset.eval_dataloader(
        dataset_split,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
    )
    if args.num_shards > 1:
        if args.shard_id < 0 or args.shard_id >= args.num_shards:
            raise ValueError("--shard-id must be between 0 and num_shards")
        itr = data.sharded_iterator(itr, args.num_shards, args.shard_id)

    num_sentences = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        # Keep more detailed timing when invoked from benchmark
        if "keep_detailed_timing" in args:
            gen_timer = pytorch_translate_utils.BucketStopwatchMeter(
                args.increment, args.max_length, args.samples_per_length,
            )
        else:
            gen_timer = StopwatchMeter()
        translations = translator.generate_batched_itr(
            t,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda=use_cuda,
            timer=gen_timer,
        )
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            target_tokens = target_tokens.int().cpu()
            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = dataset.splits[dataset_split].src.get_original_text(sample_id)
                target_str = dataset.splits[dataset_split].dst.get_original_text(
                    sample_id
                )
            else:
                src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
                target_str = dataset.dst_dict.string(
                    target_tokens, args.remove_bpe, escape_unk=True
                )

            if not args.quiet:
                print(f"S-{sample_id}\t{src_str}")
                print(f"T-{sample_id}\t{target_str}")

            # Process top predictions
            for i, hypo in enumerate(hypos[: min(len(hypos), args.nbest)]):
                # A hypothesis may carry no alignment; guard it the same way
                # the other decoders in this file do instead of crashing.
                raw_alignment = hypo["alignment"]
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=(
                        raw_alignment.int().cpu() if raw_alignment is not None else None
                    ),
                    align_dict=align_dict,
                    dst_dict=dataset.dst_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print(f"H-{sample_id}\t{hypo['score']}\t{hypo_str}")
                    if alignment is not None:
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join(map(lambda x: str(utils.item(x)), alignment)),
                            )
                        )

                # Score only the top hypothesis
                if i == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement
                        # and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str, dataset.dst_dict, add_if_not_exist=True
                        )
                    scorer.add(target_tokens, hypo_tokens)
                    translated_sentences[sample_id] = hypo_str

            wps_meter.update(src_tokens.size(0))
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += 1

    # If applicable, save the translations to the output file
    # (e.g. for external evaluation). Write UTF-8 explicitly so that
    # non-ASCII output does not depend on the platform's locale encoding.
    if getattr(args, "translation_output_file", False):
        with open(args.translation_output_file, "w", encoding="utf-8") as out_file:
            for hypo_str in translated_sentences:
                print(hypo_str, file=out_file)

    return scorer, num_sentences, gen_timer
def _valid_translation_progress(args, trainer, task, epoch_itr):
    """Build a non-shuffled progress bar over the 'valid' split for decoding."""
    itr = task.get_batch_iterator(
        dataset=task.dataset('valid'),
        max_tokens=args.max_tokens_valid,
        max_sentences=args.max_sentences_valid,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            trainer.get_model().max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
        noskip=True,
    )[0].next_epoch_itr(shuffle=False)
    return progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, prefix='translate subset',
        no_progress_bar='simple')


def _score_translation_sample(args, task, generator, models, sample, tgt_dict,
                              scorer_dict, key):
    """Decode one batch and add each sentence's top hypothesis to ``scorer_dict[key]``.

    Returns:
        ``sample['nsentences']``, so the caller can count translated sentences.
    """
    hypos = task.inference_step(generator, models, sample, None)
    for i, _sample_id in enumerate(sample['id'].tolist()):
        has_target = sample['target'] is not None
        target_tokens = None
        if has_target:
            # Remove padding and regenerate the reference string from tokens.
            target_tokens = utils.strip_pad(
                sample['target'][i, :], tgt_dict.pad()).int().cpu()
            target_str = tgt_dict.string(
                target_tokens, args.remove_bpe, escape_unk=True)

        # Process top predictions
        for j, hypo in enumerate(hypos[i][:args.nbest]):
            hypo_tokens, hypo_str, _alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str="",
                alignment=hypo['alignment'].int().cpu()
                if hypo['alignment'] is not None else None,
                align_dict=None,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe,
            )
            # Score only the top hypothesis
            if has_target and j == 0:
                if args.remove_bpe is not None:
                    # Convert back to tokens for evaluation without BPE
                    target_tokens = tgt_dict.encode_line(
                        target_str, add_if_not_exist=True)
                scorer = scorer_dict[key]
                if hasattr(scorer, 'add_string'):
                    scorer.add_string(target_str, hypo_str)
                else:
                    scorer.add(target_tokens, hypo_tokens)
    return sample['nsentences']


def validate_translation(args, trainer, task, epoch_itr, generator):
    """Decode the 'valid' split with ``generator`` and compute corpus BLEU.

    If the task has ``eval_lang_pairs``, each batch from the iterator is
    treated as a dict keyed by language pair, and one scorer is kept per
    pair. Otherwise a single scorer is stored under the key ``0``.

    Returns:
        dict: maps each language-pair key (or ``0``) to ``scorer.score()``.
    """
    tgt_dict = task.target_dictionary
    models = [trainer.get_model()]

    multi_pair = hasattr(task, 'eval_lang_pairs')
    keys = list(task.eval_lang_pairs) if multi_pair else [0]

    bleu_dict = {key: None for key in keys}
    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer_dict = {key: bleu.SacrebleuScorer() for key in keys}
    else:
        scorer_dict = {
            key: bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
            for key in keys
        }

    progress = _valid_translation_progress(args, trainer, task, epoch_itr)
    use_cuda = torch.cuda.is_available() and not args.cpu

    num_sentences = 0
    for samples in progress:
        if use_cuda:
            samples = utils.move_to_cuda(samples)
        if multi_pair:
            for key, sample in samples.items():
                num_sentences += _score_translation_sample(
                    args, task, generator, models, sample, tgt_dict,
                    scorer_dict, key)
        else:
            num_sentences += _score_translation_sample(
                args, task, generator, models, samples, tgt_dict,
                scorer_dict, 0)

    print("|valid translated {} sentences".format(num_sentences))
    for key, scorer in scorer_dict.items():
        bleu_dict[key] = scorer.score()
    return bleu_dict
def validate(args, trainer, task, epoch_itr, subsets, generator=None):
    """Evaluate the model on each validation subset and collect checkpoint metrics.

    With ``args.eval_bleu`` set, BLEU is computed once up front through
    ``validate_translation``. It is added as ``"<key>:bleu"`` extra meters to
    every subset's stats. The function then returns the sum of those BLEU
    values instead of the per-subset metric.

    Returns:
        list: ``args.best_checkpoint_metric`` per subset, or a one-element
        list with the summed BLEU when ``args.eval_bleu`` is set.
    """
    bleus = None
    if args.eval_bleu:
        bleus = validate_translation(args, trainer, task, epoch_itr, generator)

    ignored_keys = ('loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size')
    valid_losses = []
    for subset in subsets:
        # Build a non-shuffled iterator over this subset.
        batch_iterator = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
            noskip=True,
        )[0]
        batches = batch_iterator.next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, batches, epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple')

        # Clear the validation loss meters before this subset.
        for meter_name in ('valid_loss', 'valid_nll_loss'):
            meter = trainer.get_meter(meter_name)
            if meter is not None:
                meter.reset()

        extra_meters = collections.defaultdict(lambda: AverageMeter())
        if args.eval_bleu:
            for pair_key, bleu_value in bleus.items():
                extra_meters[str(pair_key) + ":bleu"].update(bleu_value)

        for sample in progress:
            log_output = trainer.valid_step(sample)
            for name, value in log_output.items():
                if name not in ignored_keys:
                    extra_meters[name].update(value)

        # Summarize this subset.
        stats = get_valid_stats(trainer, args, extra_meters)
        if epoch_itr.epoch > args.switch_obj_epoch:
            # Expose per-key ":loss" averages to the trainer, keyed by prefix.
            for name, meter in extra_meters.items():
                if name.endswith(":loss"):
                    trainer.valid_losses[name.split(":")[0]] = meter.avg
        for name, meter in extra_meters.items():
            stats[name] = meter.avg
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        metric = args.best_checkpoint_metric
        if metric == 'loss':
            valid_losses.append(stats[metric].avg)
        else:
            valid_losses.append(stats[metric])

    if args.eval_bleu:
        return [sum(bleus.values())]
    return valid_losses
def train(args, trainer, task, epoch_itr, generator=None,
          filtered_maxpos_indices=None):
    """Train the model for one epoch.

    When language sampling is updated periodically (not per step), the epoch
    is split into ``num_reset`` passes. Each pass covers a consecutive slice
    of ``epoch_itr.frozen_batches``, so the sampling distribution can be
    refreshed between them.

    Returns:
        The epoch iterator. It may have been replaced by a filtered one on the
        data-selection epoch.
    """
    # Update parameters every N batches
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq)
                   else args.update_freq[-1])

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf

    # data selection: reset epoch iter to filter out unselected data
    if epoch_itr.epoch == args.select_by_dds_epoch and args.select_by_dds_epoch > 0:
        epoch_itr, _ = trainer.get_filtered_train_iterator(
            epoch_itr.epoch, filtered_maxpos_indices=filtered_maxpos_indices)

    if (args.update_language_sampling > 0 and args.select_by_dds_epoch < 0
            and not args.data_actor_step_update):
        # Length of each slice of epoch_itr.frozen_batches handled per reset.
        datasize = args.update_language_sampling * args.update_freq[0] + 1
        num_reset = len(epoch_itr.frozen_batches) // datasize
        if num_reset * datasize < len(epoch_itr.frozen_batches):
            num_reset += 1
    else:
        num_reset = 1
        datasize = -1

    for reset_idx in range(num_reset):
        print("resetting at step", reset_idx)
        # Initialize data iterator
        itr = epoch_itr.next_epoch_itr(
            fix_batches_to_gpus=args.fix_batches_to_gpus,
            shuffle=(epoch_itr.epoch >= args.curriculum),
            offset=reset_idx * (args.update_language_sampling * args.update_freq[0] + 1),
            datasize=datasize,
        )
        itr = iterators.GroupedIterator(itr, update_freq)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch, no_progress_bar='simple',
        )

        for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
            # Decide whether the data actor is updated on this step. (A former
            # `data_actor == 'lan' and data_actor_step_update` branch was
            # unreachable: the step-update branch already covers it.)
            if args.extra_data_actor == 'ave_emb':
                update_actor = (i % args.extra_update_language_sampling == 0)
            elif args.data_actor_step_update:
                update_actor = (i % args.update_language_sampling == 0)
            else:
                update_actor = False
            if epoch_itr.epoch > args.select_by_dds_epoch and args.select_by_dds_epoch > 0:
                update_actor = False

            log_output = trainer.train_step(samples, update_actor=update_actor)
            if log_output is None:
                continue

            # update sampling distribution
            if (args.update_language_sampling > 0
                    and i % args.update_language_sampling == 0
                    and args.data_actor != 'ave_emb'
                    and not args.data_actor_step_update):
                if args.data_actor_multilin:
                    trainer.update_language_sampler_multilin(
                        args, epoch=epoch_itr.epoch)
                else:
                    trainer.update_language_sampler(args)

            # log mid-epoch stats
            stats = get_training_stats(trainer)
            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                    continue  # these are already logged above
                if 'loss' in k or k == 'accuracy':
                    extra_meters[k].update(v, log_output['sample_size'])
                else:
                    extra_meters[k].update(v)
                stats[k] = extra_meters[k].avg
            progress.log(stats, tag='train', step=stats['num_updates'])

            # ignore the first mini-batch in words-per-second calculation
            if i == 0:
                trainer.get_meter('wps').reset()

            num_updates = trainer.get_num_updates()
            if (not args.disable_validation and args.save_interval_updates > 0
                    and num_updates % args.save_interval_updates == 0
                    and num_updates > 0):
                valid_losses = validate(args, trainer, task, epoch_itr,
                                        valid_subsets, generator)
                checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                                 valid_losses[0])

            if num_updates >= max_update:
                break

        # The break above only ends the current slice; also skip the remaining
        # resets so training does not continue past --max-update.
        if trainer.get_num_updates() >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
    return epoch_itr
def score(args, trainer, task, epoch_itr, subset):
    """Translate ``subset`` with the current model and return its BLEU score.

    MLPerf EVAL_START/EVAL_STOP markers are emitted around the evaluation.
    When ``args.log_translations`` is set, S/T/H/P lines are written to
    ``translations_epoch{epoch}_{rank}`` in ``args.save_dir``. The file is
    closed even if decoding raises.

    Returns:
        The value of ``compute_bleu(ref_toks, sys_toks, args)``.
    """
    mlperf_print(key=mlperf_compliance.constants.EVAL_START,
                 metadata={'epoch_num': epoch_itr.epoch}, sync=True)
    begin = time.time()

    if subset not in task.datasets.keys():
        task.load_dataset(subset)

    # Generating translations alters the dictionaries, which would mess up the
    # rest of training, so work on copies.
    src_dict = deepcopy(task.source_dictionary)
    tgt_dict = deepcopy(task.target_dictionary)

    model = trainer.get_model()

    # Initialize data iterator
    itr = data.EpochBatchIterator(
        dataset=task.dataset(subset),
        max_tokens=min(2560, args.max_tokens),
        max_sentences=max(
            8, min(math.ceil(1024 / args.distributed_world_size), 128)),
        max_positions=(256, 256),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        seq_len_multiple=args.seq_len_multiple,
        # Use a large growth factor to get fewer buckets.
        # Fewer buckets yield faster eval since batches are filled from single bucket
        # and eval dataset is small.
        bucket_growth_factor=1.2,
        batching_scheme=args.batching_scheme,
        batch_multiple_strategy=args.batch_multiple_strategy,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    translator = SequenceGenerator(
        [model],
        tgt_dict,
        beam_size=args.beam,
        stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen,
        sampling=args.sampling,
        sampling_topk=args.sampling_topk,
        minlen=args.min_len,
    )

    # Generate and compute BLEU
    ref_toks = []
    sys_toks = []
    num_sentences = 0
    has_target = True

    log = None
    if args.log_translations:
        log = open(
            os.path.join(
                args.save_dir,
                'translations_epoch{}_{}'.format(epoch_itr.epoch,
                                                 args.distributed_rank)), 'w+')
    try:
        with progress_bar.build_progress_bar(args, itr) as progress:
            translations = translator.generate_batched_itr(
                progress,
                maxlen_a=args.max_len_a,
                maxlen_b=args.max_len_b,
                cuda=True,
                timer=gen_timer,
                prefix_size=args.prefix_size,
            )

            wps_meter = TimeMeter()
            for sample_id, src_tokens, target_tokens, hypos in translations:
                # Process input and ground truth
                has_target = target_tokens is not None
                target_tokens = target_tokens.int().cpu() if has_target else None

                src_str = src_dict.string(src_tokens, args.remove_bpe)
                if has_target:
                    target_str = tgt_dict.string(target_tokens, args.remove_bpe)

                if log is not None:
                    log.write('S-{}\t{}\n'.format(sample_id, src_str))
                    if has_target:
                        log.write('T-{}\t{}\n'.format(sample_id, target_str))

                # Process top predictions
                for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'].int().cpu()
                        if hypo['alignment'] is not None else None,
                        align_dict=None,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe)

                    if log is not None:
                        log.write('H-{}\t{}\t{}\n'.format(sample_id, hypo['score'],
                                                          hypo_str))
                        log.write('P-{}\t{}\n'.format(
                            sample_id,
                            ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    hypo['positional_scores'].tolist(),
                                ))))

                    # Score only the top hypothesis
                    if has_target and i == 0:
                        src_str = detokenize_subtokenized_sentence(src_str)
                        target_str = detokenize_subtokenized_sentence(target_str)
                        hypo_str = detokenize_subtokenized_sentence(hypo_str)
                        sys_tok = bleu_tokenize(
                            hypo_str.lower() if args.ignore_case else hypo_str)
                        ref_tok = bleu_tokenize(
                            target_str.lower() if args.ignore_case else target_str)
                        sys_toks.append(sys_tok)
                        ref_toks.append(ref_tok)

                wps_meter.update(src_tokens.size(0))
                progress.log({'wps': round(wps_meter.avg)})
                num_sentences += 1
    finally:
        # Close the translation log even if decoding raised.
        if log is not None:
            log.close()

    bleu_score_reference = compute_bleu(ref_toks, sys_toks, args)
    bleu_score_reference_str = '{:.4f}'.format(bleu_score_reference)

    if gen_timer.sum != 0:
        print(
            '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
            .format(num_sentences, gen_timer.n, gen_timer.sum,
                    num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: bleu_score={}'.format(
            subset, args.beam, bleu_score_reference_str))

    print('| Eval completed in: {:.2f}s'.format(time.time() - begin))

    mlperf_print(key=mlperf_compliance.constants.EVAL_STOP,
                 metadata={'epoch_num': epoch_itr.epoch}, sync=True)

    return bleu_score_reference
def main(args):
    """Run inference on ``args.gen_subset`` and report BLEU and exact-match accuracy.

    Exact-match accuracy counts sentences whose top hypothesis string equals
    the reference string. Only sentences that have a reference can match.

    Returns:
        The BLEU scorer after all references/hypotheses have been added.
    """
    utils.import_user_module(args)

    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu
    use_ctc_loss = args.criterion == 'ctc_loss'

    # Setup task, e.g., image captioning
    task = tasks.setup_task(args)
    # Load dataset split
    task.load_dataset(args.gen_subset, combine=True, epoch=0)

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    model_paths = args.path.split(':')
    # NOTE(review): eval() executes arbitrary code from the command line;
    # only acceptable for trusted input.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        model_paths,
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Set dictionaries
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())

    stats = collections.OrderedDict()
    num_sentences = 0
    num_correct = 0
    has_target = True
    with progress_bar.build_progress_bar(
            args,
            itr,
            prefix='inference on \'{}\' subset'.format(args.gen_subset),
            no_progress_bar='simple',
    ) as progress:
        wps_meter = TimeMeter()
        for sample in progress:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            gen_timer.start()
            hypos = task.inference_step(generator, models, sample)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None
                target_tokens = None
                # Reset per sentence so a missing target can never be compared
                # against the previous sentence's reference.
                target_str = None
                if has_target:
                    if use_ctc_loss:
                        target_tokens = sample['target'][i]
                    else:
                        # Remove padding
                        target_tokens = utils.strip_pad(
                            sample['target'][i, :], tgt_dict.pad()).int().cpu()
                    # Regenerate original sentences from tokens.
                    target_str = tgt_dict.string(target_tokens, args.remove_bpe,
                                                 escape_unk=True)

                if not args.quiet and has_target:
                    print('\nT-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                hypo = hypos[i][0]
                hypo_tokens = (hypo['tokens'] if use_ctc_loss
                               else hypo['tokens'].int().cpu())
                hypo_str = tgt_dict.string(hypo_tokens, args.remove_bpe,
                                           escape_unk=True)
                alignment = (hypo['alignment'].int().cpu()
                             if hypo['alignment'] is not None else None)
                if has_target and hypo_str == target_str:
                    num_correct += 1
                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo_str, hypo['score']))
                    print('P-{}\t{}'.format(
                        sample_id,
                        ' '.join(
                            map(
                                lambda x: '{:.4f}'.format(x),
                                hypo['positional_scores'].tolist(),
                            )) if not use_ctc_loss else None))

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id,
                            ' '.join(map(lambda x: str(utils.item(x)), alignment))))

                # Score only the top hypothesis
                if has_target:
                    if hasattr(scorer, 'add_string'):
                        scorer.add_string(target_str, hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            num_sentences += sample['nsentences']
            stats['wps'] = round(wps_meter.avg)
            # Guard against batches that report zero sentences.
            stats['acc'] = num_correct / num_sentences if num_sentences else 0.0
            progress.log(stats, tag='accuracy')
        progress.print(stats, tag='accuracy')

    # Skip throughput when nothing was timed (e.g. empty subset) to avoid
    # dividing by zero.
    if gen_timer.sum != 0:
        print(
            '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
            .format(num_sentences, gen_timer.n, gen_timer.sum,
                    num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam,
                                                      scorer.result_string()))
    return scorer
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch.

    Gradients are accumulated over ``update_freq`` batches before each
    parameter update. ``update_freq`` is ``args.update_freq`` indexed by
    epoch, falling back to its last entry. Every
    ``args.save_interval_updates`` updates, the first validation subset is
    evaluated and a checkpoint is saved. The epoch ends early once
    ``args.max_update`` updates have been made.

    Raises:
        RuntimeError: if ``--enable-parallel-backward-allred-opt`` is combined
            with ``--update-freq > 1``.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr()
    progress = progress_bar.build_progress_bar(args, itr, epoch_itr.epoch, no_progress_bar='simple')

    # update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    # The parallel backward all-reduce optimization does not support
    # accumulating gradients across several batches.
    if args.enable_parallel_backward_allred_opt and update_freq > 1:
        raise RuntimeError(
            '--enable-parallel-backward-allred-opt is incompatible with --update-freq > 1'
        )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    num_batches = len(epoch_itr)
    for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        # Only update parameters on the epoch's last batch or when this batch
        # completes a group of update_freq batches; otherwise accumulate.
        # NOTE(review): last_step compares i against len(itr) while the
        # condition below uses len(epoch_itr); confirm these lengths agree.
        if i < num_batches - 1 and (i + 1) % update_freq > 0:
            # buffer updates according to --update-freq
            trainer.train_step(sample, update_params=False, last_step=(i == len(itr) - 1))
            continue
        else:
            log_output = trainer.train_step(sample, update_params=True, last_step=(i == len(itr) - 1))

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                # auxiliary losses are averaged with weight log_output['sample_size']
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        # Profiling mode: terminate the process once batch index i reaches
        # args.profile.
        if args.profile is not None and i == args.profile:
            import sys
            sys.exit()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip'
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def main(args):
    """Extract soft word alignments for ``args.gen_subset`` and optionally save them.

    Alignments are extracted only when ``args.print_vanilla_alignment`` is
    set. When ``args.decoding_path`` is also set, they are written one per
    line, ordered by sample id, to
    ``{decoding_path}/{gen_subset}.{source_lang}2{target_lang}.align``.
    """
    from collections import defaultdict
    import string
    import time

    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() executes arbitrary code from the command line;
    # only acceptable for trusted input.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()
        model.decoder.alignment_layer = args.alignment_layer

    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Vocabulary ids whose symbol is a substring of string.punctuation; they
    # are passed to the alignment extractors.
    tgt_punc_tokens = None
    if args.print_vanilla_alignment:
        punc = string.punctuation
        src_punc_tokens = [
            w for w in range(len(src_dict)) if src_dict[w] in punc
        ]
        tgt_punc_tokens = [
            w for w in range(len(tgt_dict)) if tgt_dict[w] in punc
        ]
    else:
        src_punc_tokens = None

    save_alignments = args.print_vanilla_alignment and args.decoding_path is not None
    # Alignments grouped by sample id. A dict grows with the data actually
    # seen, instead of preallocating millions of empty lists with a hard cap
    # on the largest sample id.
    align_sents = defaultdict(list)

    print('start time is :', time.strftime("%Y-%m-%d %X"))
    with progress_bar.build_progress_bar(args, itr) as t:
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            if args.print_vanilla_alignment:
                if args.set_shift:
                    alignments = utils.extract_soft_alignment(
                        sample, models[0], src_punc_tokens, tgt_punc_tokens,
                        alignment_task=args.alignment_task)
                else:
                    alignments = utils.extract_soft_alignment_noshift(
                        sample, models[0], src_punc_tokens, tgt_punc_tokens,
                        alignment_task=args.alignment_task)
            else:
                alignments = None

            if save_alignments:
                for sample_id in sample['id'].tolist():
                    align_sents[int(sample_id)].append(alignments[int(sample_id)])
    print('end time is :', time.strftime("%Y-%m-%d %X"))

    if save_alignments:
        out_path = os.path.join(
            args.decoding_path,
            f'{args.gen_subset}.{args.source_lang}2{args.target_lang}.align')
        with open(out_path, 'w') as f:
            # Write in ascending sample-id order.
            for sample_id in sorted(align_sents):
                for sent in align_sents[sample_id]:
                    f.write(str(sent) + '\n')
    print("finished ...")
def train(args, trainer, task, epoch_itr):
    """Run one training epoch with gradient accumulation.

    Parameters are updated only every ``update_freq`` batches and on the
    epoch's final batch; the batches in between are buffered. Validation on
    the first validation subset, followed by a checkpoint, runs every
    ``args.save_interval_updates`` updates. The epoch stops early once
    ``args.max_update`` is reached.
    """
    progress = progress_bar.build_progress_bar(
        args, epoch_itr.next_epoch_itr(), epoch_itr.epoch, no_progress_bar='simple')

    # Pick the accumulation factor for this epoch, reusing the last one
    # once the schedule runs out.
    freq_schedule = args.update_freq
    if epoch_itr.epoch <= len(freq_schedule):
        update_freq = freq_schedule[epoch_itr.epoch - 1]
    else:
        update_freq = freq_schedule[-1]

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    last_index = len(epoch_itr) - 1
    already_logged = ('loss', 'nll_loss', 'sample_size')

    for step, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        accumulate_only = step < last_index and (step + 1) % update_freq > 0
        if accumulate_only:
            # buffer updates according to --update-freq
            trainer.train_step(sample, update_params=False)
            continue
        log_output = trainer.train_step(sample, update_params=True)

        # Mid-epoch logging.
        stats = get_training_stats(trainer)
        for name, value in log_output.items():
            if name in already_logged:
                continue
            meter = extra_meters[name]
            if 'loss' in name:
                meter.update(value, log_output['sample_size'])
            else:
                meter.update(value)
            stats[name] = meter.avg
        progress.log(stats)

        # The first mini-batch is excluded from words-per-second.
        if step == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # End-of-epoch summary.
    stats = get_training_stats(trainer)
    for name, meter in extra_meters.items():
        stats[name] = meter.avg
    progress.print(stats)

    # Fresh training meters for the next epoch.
    for name in ('train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip'):
        meter = trainer.get_meter(name)
        if meter is not None:
            meter.reset()
def main(args):
    """Translate ``args.gen_subset`` with a model ensemble and report BLEU.

    Loads the dataset, either binarized or raw text depending on
    ``args.replace_unk``, and loads the ensemble from ``args.path``. It then
    generates (or scores, with ``--score-reference``) every sentence. Unless
    ``args.quiet`` is set, it prints S-/T-/H-/P-/A- lines per sentence. The
    top hypothesis of each sentence is fed to a BLEU scorer.
    """
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset
    # Raw text is loaded when unknown-word replacement is requested, because
    # the original source/target text is needed (see get_original_text below).
    if args.replace_unk is None:
        dataset = data.load_dataset(
            args.data,
            [args.gen_subset],
            args.source_lang,
            args.target_lang,
        )
    else:
        dataset = data.load_raw_text_dataset(
            args.data,
            [args.gen_subset],
            args.source_lang,
            args.target_lang,
            args.doctopics,
            args.encoder_embed_dim,
        )
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    # Load ensemble
    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models, _ = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
        )

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    # The usable length is the smallest encoder limit across the ensemble.
    max_positions = min(model.max_encoder_positions() for model in models)
    itr = dataset.eval_dataloader(
        args.gen_subset,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
    )
    if args.num_shards > 1:
        if args.shard_id < 0 or args.shard_id >= args.num_shards:
            raise ValueError('--shard-id must be between 0 and num_shards')
        itr = data.sharded_iterator(itr, args.num_shards, args.shard_id)
    # print("SHASHI: I AM HERE")
    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        # Score the reference targets instead of searching for new hypotheses.
        translator = SequenceScorer(models)
    else:
        translator = SequenceGenerator(
            models, beam_size=args.beam, stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
            unk_penalty=args.unkpen)
    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        if args.score_reference:
            translations = translator.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
                cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size)
        wps_meter = TimeMeter()
        # Each item is one sentence: (id, source tokens, target tokens or None, hypotheses).
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None
            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = dataset.splits[args.gen_subset].src.get_original_text(sample_id)
                target_str = dataset.splits[args.gen_subset].dst.get_original_text(sample_id)
            else:
                src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
                target_str = dataset.dst_dict.string(
                    target_tokens, args.remove_bpe, escape_unk=True) if has_target else ''

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu(),
                    align_dict=align_dict,
                    dst_dict=dataset.dst_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    # H: hypothesis with its score, P: per-token positional
                    # scores, A: alignment.
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                    print('P-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(
                            lambda x: '{:.4f}'.format(x),
                            hypo['positional_scores'].tolist(),
                        ))
                    ))
                    print('A-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(lambda x: str(utils.item(x)), alignment))
                    ))

                # Score only the top hypothesis
                if has_target and i == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str, dataset.dst_dict, add_if_not_exist=True)
                    scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    # NOTE(review): 1. / gen_timer.avg will raise if nothing was generated
    # (avg == 0) — confirm StopwatchMeter semantics before relying on this.
    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
def main(args):
    """Decode ``args.gen_subset`` with a speech model ensemble.

    Predictions are decoded with the sentencepiece model found at
    ``<args.data>/spm.model``. Each utterance is passed to
    ``process_predictions`` together with its speaker and utterance id, and
    written to the files returned by ``prepare_result_files``. Throughput
    statistics are logged at the end.
    """
    check_args(args)
    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 30000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    gen_dataset = task.dataset(args.gen_subset)
    logger.info("| {} {} {} examples".format(
        args.data, args.gen_subset, len(gen_dataset)))

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info("| decoding with criterion {}".format(args.criterion))

    # Load ensemble
    logger.info("| loading model(s) from {}".format(args.path))
    # NOTE(review): eval() runs arbitrary code from the command line; consider
    # ast.literal_eval if overrides are always dict literals.
    models, criterions, _model_args = load_models_and_criterions(
        args.path.split(":"),
        arg_overrides=eval(args.model_overrides),  # noqa
        task=task,
    )
    optimize_models(args, use_cuda, models)

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        trans = criterions[0].asg.trans.data
        args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    num_sentences = 0

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.results_path, exist_ok=True)

    sp = spm.SentencePieceProcessor()
    sp.Load(os.path.join(args.data, "spm.model"))

    res_files = prepare_result_files(args)
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                speaker = gen_dataset.speakers[int(sample_id)]
                utt_id = gen_dataset.ids[int(sample_id)]
                target_tokens = (
                    utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu()
                )
                # Process top predictions
                process_predictions(
                    args, hypos[i], sp, tgt_dict, target_tokens, res_files, speaker, utt_id
                )

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += sample["nsentences"]

    # Guard the rates: an empty run leaves the timer at zero.
    sentences_per_sec = num_sentences / gen_timer.sum if gen_timer.sum > 0 else 0.0
    tokens_per_sec = 1.0 / gen_timer.avg if gen_timer.avg > 0 else 0.0
    logger.info("| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
                "sentences/s, {:.2f} tokens/s)".format(
                    num_sentences,
                    gen_timer.n,
                    gen_timer.sum,
                    sentences_per_sec,
                    tokens_per_sec,
                ))
    logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
def main(args):
    """Run speech recognition on ``args.gen_subset`` and write WER/CER results.

    Loads the ensemble, optionally fusing word/subword language models. It
    decodes every utterance and prints T-/H- lines unless ``args.quiet`` is
    set. If requested, it plots attention for the top hypothesis. Decoded and
    scored results are written under ``args.results_path``.

    Returns:
        The ``wer.Scorer`` holding predictions (and evaluations when targets
        are available).
    """
    assert args.path is not None, '--path required for recognition!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset split
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionary
    dictionary = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() runs arbitrary code from the command line.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )
    # NOTE(review): ``del models[i]`` mutates the list being enumerated; this
    # is only safe when the word LM is the last model — confirm with callers.
    for i, m in enumerate(models):
        if hasattr(m, 'is_wordlm') and m.is_wordlm:
            # assume subword LM comes before word LM
            if isinstance(models[i - 1], FairseqLanguageModel):
                models[i - 1] = MultiLevelLanguageModel(
                    m, models[i - 1],
                    subwordlm_weight=args.subwordlm_weight,
                    oov_penalty=args.oov_penalty,
                    open_vocab=not args.disable_open_vocab,
                )
                del models[i]
                print('| LM fusion with Multi-level LM')
            else:
                models[i] = TensorizedLookaheadLanguageModel(
                    m, dictionary,
                    oov_penalty=args.oov_penalty,
                    open_vocab=not args.disable_open_vocab,
                )
                print('| LM fusion with Look-ahead Word LM')
        # assume subword LM comes after E2E models
        elif i == len(models) - 1 and isinstance(m, FairseqLanguageModel):
            print('| LM fusion with Subword LM')
    if args.lm_weight != 0.0:
        print('| using LM fusion with lm-weight={:.2f}'.format(args.lm_weight))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            # LMs have no encoder, so only their target length is constrained.
            *[model.max_positions() if hasattr(model, 'encoder')
              else (None, model.max_positions()) for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    if args.match_source_len:
        print('| The option match_source_len is not applicable to '
              'speech recognition. Ignoring it.')
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Defined up front so the final "Saved attention plots" message never
    # references an unbound name when no plot was produced.
    save_dir = os.path.join(args.results_path, 'attn_plots') if args.print_alignment else None

    # Generate and compute WER
    scorer = wer.Scorer(dictionary, wer_output_filter=args.wer_output_filter)
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(
                generator, models, sample, prefix_tokens, lm_weight=args.lm_weight,
            )
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            # obtain nonpad mask of encoder output to plot attentions
            if args.print_alignment:
                net_input = sample['net_input']
                src_tokens = net_input['src_tokens']
                output_lengths = models[0].encoder.output_lengths(net_input['src_lengths'])
                nonpad_idxs = sequence_mask(
                    output_lengths, models[0].encoder.output_lengths(src_tokens.size(1)))

            for i in range(len(sample['id'])):
                has_target = sample['target'] is not None
                utt_id = sample['utt_id'][i]

                # Retrieve the original sentences
                if has_target:
                    target_str = sample['target_raw_text'][i]
                    if not args.quiet:
                        target_sent = dictionary.tokens_to_sentence(
                            target_str, use_unk_sym=False, bpe_symbol=args.remove_bpe,
                        )
                        print('T-{}\t{}'.format(utt_id, target_sent))

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_str = dictionary.string(hypo['tokens'].int().cpu())  # not removing bpe at this point
                    # hypo_sent is needed for printing and for the top
                    # hypothesis' attention plot (j == 0).
                    if not args.quiet or j == 0:
                        hypo_sent = dictionary.tokens_to_sentence(hypo_str, bpe_symbol=args.remove_bpe)

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(utt_id, hypo_sent, hypo['score']))

                    # Score and obtain attention only the top hypothesis
                    if j == 0:
                        # src_len x tgt_len
                        attention = hypo['attention'][nonpad_idxs[i]].float().cpu() \
                            if args.print_alignment and hypo['attention'] is not None else None
                        if args.print_alignment and attention is not None:
                            os.makedirs(save_dir, exist_ok=True)
                            plot_attention(attention, hypo_sent, utt_id, save_dir)
                        scorer.add_prediction(utt_id, hypo_str, bpe_symbol=args.remove_bpe)
                        if has_target:
                            scorer.add_evaluation(utt_id, target_str, hypo_str, bpe_symbol=args.remove_bpe)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    # Guard the rates: an empty run leaves the timer at zero.
    sentences_per_sec = num_sentences / gen_timer.sum if gen_timer.sum > 0 else 0.0
    tokens_per_sec = 1. / gen_timer.avg if gen_timer.avg > 0 else 0.0
    print('| Recognized {} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, sentences_per_sec, tokens_per_sec))
    if args.print_alignment:
        print('| Saved attention plots in ' + save_dir)

    if has_target:
        assert args.test_text_files is not None
        scorer.add_ordered_utt_list(*args.test_text_files)

    os.makedirs(args.results_path, exist_ok=True)

    fn = 'decoded_char_results.txt'
    with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
        f.write(scorer.print_char_results())
        print('| Decoded char results saved as ' + f.name)

    fn = 'decoded_results.txt'
    with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
        f.write(scorer.print_results())
        print('| Decoded results saved as ' + f.name)

    if has_target:
        header = ' Recognize {} with beam={}: '.format(args.gen_subset, args.beam)
        fn = 'wer'
        with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
            res = 'WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format(
                *(scorer.wer()))
            print('|' + header + res)
            f.write(res + '\n')
            print('| WER saved in ' + f.name)

        fn = 'cer'
        with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
            res = 'CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format(
                *(scorer.cer()))
            print('|' + ' ' * len(header) + res)
            f.write(res + '\n')
            print('| CER saved in ' + f.name)

        fn = 'aligned_results.txt'
        with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
            f.write(scorer.print_aligned_results())
            print('| Aligned results saved as ' + f.name)
    return scorer
def _generate_score(models, args, task, dataset):
    """Translate ``dataset`` with ``models`` and accumulate a BLEU score.

    Optionally writes these artifacts when the corresponding args are set:
    translations, hypotheses, probability scores, a binary hypothesis dataset,
    and a pickled list of translation info.

    Returns:
        (scorer, num_sentences, gen_timer, translation_samples)
    """
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load ensemble
    if not args.quiet:
        print("| loading model(s) from {}".format(", ".join(args.path.split(":"))))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=True,
        )

    translator = build_sequence_generator(args, task, models)
    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    if args.max_examples_to_evaluate > 0:
        pytorch_translate_data.subsample_pair_dataset(dataset, args.max_examples_to_evaluate)

    # Keep track of translations
    # Initialize with empty translations
    # and zero probs scores
    translated_sentences = [""] * len(dataset)
    # NOTE(review): translated_scores is never updated below, so the
    # translation_probs_file output is always exp(0.0) == 1.0 — confirm
    # whether trans_info exposes a score that should be stored here.
    translated_scores = [0.0] * len(dataset)
    hypos_list = []

    collect_output_hypos = getattr(args, "output_hypos_binary_path", False)
    if collect_output_hypos:
        output_hypos_token_arrays = [None] * len(dataset)

    # Generate and compute BLEU score
    dst_dict = task.target_dictionary
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(dst_dict.pad(), dst_dict.eos(), dst_dict.unk())

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    oracle_scorer = None
    if args.report_oracle_bleu:
        oracle_scorer = bleu.Scorer(dst_dict.pad(), dst_dict.eos(), dst_dict.unk())

    rescorer = None
    num_sentences = 0
    translation_samples = []
    translation_info_list = []
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        gen_timer = StopwatchMeter()
        translations = translator.generate_batched_itr(
            t,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda=use_cuda,
            timer=gen_timer,
            # Multilingual models take the target-language token as a 1-token prefix.
            prefix_size=1 if pytorch_translate_data.is_multilingual(args) else 0,
        )

        for trans_info in _iter_translations(
            args, task, dataset, translations, align_dict, rescorer
        ):
            if hasattr(scorer, "add_string"):
                scorer.add_string(trans_info.target_str, trans_info.hypo_str)
            else:
                scorer.add(trans_info.target_tokens, trans_info.hypo_tokens)
            if oracle_scorer is not None:
                oracle_scorer.add(trans_info.target_tokens, trans_info.best_hypo_tokens)

            if getattr(args, "translation_output_file", False):
                translated_sentences[trans_info.sample_id] = trans_info.hypo_str
            if getattr(args, "hypotheses_export_path", False):
                hypos_list.append(trans_info.hypos)
            if collect_output_hypos:
                output_hypos_token_arrays[trans_info.sample_id] = trans_info.best_hypo_tokens
            if args.translation_info_export_path is not None:
                # Strip expensive data from hypotheses before saving
                hypos = [
                    {k: v for k, v in hypo.items() if k in ["tokens", "score"]}
                    for hypo in trans_info.hypos
                ]
                # Make sure everything is on cpu before exporting
                hypos = [
                    {"score": hypo["score"], "tokens": hypo["tokens"].cpu()}
                    for hypo in hypos
                ]
                translation_info_list.append({
                    "src_tokens": trans_info.src_tokens.cpu(),
                    "target_tokens": trans_info.target_tokens,
                    "hypos": hypos,
                })
            translation_samples.append(
                collections.OrderedDict({
                    "sample_id": trans_info.sample_id.item(),
                    "src_str": trans_info.src_str,
                    "target_str": trans_info.target_str,
                    "hypo_str": trans_info.hypo_str,
                })
            )
            wps_meter.update(trans_info.src_tokens.size(0))
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += 1

        # If applicable, save collected hypothesis tokens to binary output file
        if collect_output_hypos:
            output_dataset = pytorch_translate_data.InMemoryNumpyDataset()
            output_dataset.load_from_sequences(output_hypos_token_arrays)
            output_dataset.save(args.output_hypos_binary_path)
        if args.translation_info_export_path is not None:
            # ``with`` guarantees the file is closed even if pickling fails.
            with open(args.translation_info_export_path, "wb") as f:
                pickle.dump(translation_info_list, f)

    # If applicable, save the translations and hypos to the output file
    # For eg. external evaluation
    if getattr(args, "translation_output_file", False):
        with open(args.translation_output_file, "w") as out_file:
            for hypo_str in translated_sentences:
                print(hypo_str, file=out_file)

    if getattr(args, "hypotheses_export_path", False):
        with open(args.hypotheses_export_path, "w") as out_file:
            for hypos in hypos_list:
                for hypo in hypos:
                    print(
                        task.tgt_dict.string(hypo["tokens"], bpe_symbol=args.remove_bpe),
                        file=out_file,
                    )

    if getattr(args, "translation_probs_file", False):
        with open(args.translation_probs_file, "w") as out_file:
            for hypo_score in translated_scores:
                print(np.exp(hypo_score), file=out_file)

    if oracle_scorer is not None:
        print(f"| Oracle BLEU (best hypo in beam): {oracle_scorer.result_string()}")

    return scorer, num_sentences, gen_timer, translation_samples
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses.

    Besides the trainer's own valid meters, this accumulates every extra key
    in ``log_output`` into an ``AverageMeter``. ``word_error`` and
    ``char_error`` are converted into percentage WER/CER meters, weighted by
    the per-batch word/character counts.

    Returns:
        A list with one value of ``args.best_checkpoint_metric`` per subset.
    """
    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple',
        )

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(AverageMeter)

        if callable(getattr(trainer.criterion, 'set_valid_tgt_dataset', None)):
            trainer.criterion.set_valid_tgt_dataset(task.dataset(subset).tgt)

        for sample in progress:
            log_output = trainer.valid_step(sample)

            for k, v in log_output.items():
                if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size',
                    'word_count', 'char_count',
                ]:
                    continue
                if k == 'word_error':
                    word_count = log_output['word_count']
                    # A batch with no words carries zero weight; skipping it
                    # avoids a ZeroDivisionError without changing the average.
                    if word_count > 0:
                        extra_meters['wer'].update(
                            float(v) / word_count * 100, word_count)
                elif k == 'char_error':
                    char_count = log_output['char_count']
                    if char_count > 0:
                        extra_meters['cer'].update(
                            float(v) / char_count * 100, char_count)
                else:
                    extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer, args, extra_meters)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        # For 'loss' the stats entry is a meter object, so its average is taken.
        valid_losses.append(
            stats[args.best_checkpoint_metric].avg
            if args.best_checkpoint_metric == 'loss'
            else stats[args.best_checkpoint_metric]
        )
    return valid_losses
def main(args, task=None, model_state=None):
    """Run wav2letter-style decoding (or emission/feature dumping).

    Depending on ``args`` this does one of three things:
    - dumps model emissions (``--dump-emissions``);
    - dumps encoder features (``--dump-features``);
    - decodes the subset, writes per-utterance results and logs the WER.

    Returns:
        (task, wer) where ``wer`` is None when dumping or when no target
        lengths were accumulated.
    """
    check_args(args)

    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 4000000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    logger.info("| decoding with criterion {}".format(args.criterion))

    task = tasks.setup_task(args)

    # Load ensemble
    if args.load_emissions:
        models, criterions = [], []
        task.load_dataset(args.gen_subset)
    else:
        logger.info("| loading model(s) from {}".format(args.path))
        models, saved_cfg = checkpoint_utils.load_model_ensemble(
            utils.split_paths(args.path),
            arg_overrides=ast.literal_eval(args.model_overrides),
            task=task,
            suffix=args.checkpoint_suffix,
            strict=(args.checkpoint_shard_count == 1),
            num_shards=args.checkpoint_shard_count,
            state=model_state,
        )
        optimize_models(args, use_cuda, models)
        task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info(
        "| {} {} {} examples".format(
            args.data, args.gen_subset, len(task.dataset(args.gen_subset))
        )
    )

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        raise NotImplementedError("asg_loss is currently not supported")
        # trans = criterions[0].asg.trans.data
        # args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task, models)

    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(args):
        # Returns None for unknown decoders; the dump modes never use the
        # generator, so this is reported rather than raised.
        w2l_decoder = getattr(args, "w2l_decoder", None)
        if w2l_decoder == "viterbi":
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(args, task.target_dictionary)
        elif w2l_decoder == "kenlm":
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(args, task.target_dictionary)
        elif w2l_decoder == "fairseqlm":
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(args, task.target_dictionary)
        else:
            print(
                "only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
            )

    # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
    generator = build_generator(args)

    if args.load_emissions:
        # NOTE(review): allow_pickle=True unpickles the file contents; only
        # load emission files from trusted sources.
        generator = ExistingEmissionsDecoder(
            generator, np.load(args.load_emissions, allow_pickle=True)
        )
        logger.info("loaded emissions from " + args.load_emissions)

    num_sentences = 0

    if args.results_path is not None:
        os.makedirs(args.results_path, exist_ok=True)

    if args.dump_emissions:
        emissions = {}
    if args.dump_features:
        features = {}
        models[0].bert.proj = None
    else:
        res_files = prepare_result_files(args)
    errs_t = 0
    lengths_t = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, : args.prefix_size]

            gen_timer.start()
            if args.dump_emissions:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    emm = models[0].get_normalized_probs(encoder_out, log_probs=True)
                    emm = emm.transpose(0, 1).cpu().numpy()
                    for i, sid in enumerate(sample["id"]):
                        emissions[sid.item()] = emm[i]
                    continue
            elif args.dump_features:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
                    for i, sid in enumerate(sample["id"]):
                        padding = (
                            encoder_out["encoder_padding_mask"][i].cpu().numpy()
                            if encoder_out["encoder_padding_mask"] is not None
                            else None
                        )
                        features[sid.item()] = (feat[i], padding)
                    continue
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                speaker = None
                utt_id = sample_id
                toks = (
                    sample["target"][i, :]
                    if "target_label" not in sample
                    else sample["target_label"][i, :]
                )
                target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
                # Process top predictions
                errs, length = process_predictions(
                    args, hypos[i], None, tgt_dict, target_tokens, res_files, speaker, utt_id,
                )
                errs_t += errs
                lengths_t += length

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += (
                sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
            )

    wer = None
    if args.dump_emissions:
        # Assumes sample ids are 0..N-1 so the list is ordered by id.
        emm_arr = [emissions[i] for i in range(len(emissions))]
        np.save(args.dump_emissions, emm_arr)
        logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}")
    elif args.dump_features:
        feat_arr = [features[i] for i in range(len(features))]
        np.save(args.dump_features, feat_arr)
        logger.info(f"saved {len(features)} features to {args.dump_features}")
    else:
        if lengths_t > 0:
            wer = errs_t * 100.0 / lengths_t
            logger.info(f"WER: {wer}")

        # Guard the rates: an empty run leaves the timer at zero.
        sentences_per_sec = num_sentences / gen_timer.sum if gen_timer.sum > 0 else 0.0
        tokens_per_sec = 1.0 / gen_timer.avg if gen_timer.avg > 0 else 0.0
        logger.info(
            "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
            "sentences/s, {:.2f} tokens/s)".format(
                num_sentences,
                gen_timer.n,
                gen_timer.sum,
                sentences_per_sec,
                tokens_per_sec,
            )
        )
        logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
    return task, wer
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch.

    Batches are grouped by the epoch's ``--update-freq`` entry. Smoothed
    'train' metrics are logged after each successful step. Validation and
    checkpointing run every ``--save-interval-updates`` updates. The epoch's
    wall time and average steps/second are printed at the end.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    start_time = time.time()
    step = 0
    # Initialised before the loop so the end-of-epoch log has a step value
    # even when the epoch yields no batches.
    num_updates = trainer.get_num_updates()
    for samples in progress:
        log_output = trainer.train_step(samples)
        num_updates = trainer.get_num_updates()
        step += 1
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(metrics.get_smoothed_values('train'))
        progress.log(stats, tag='train', step=num_updates)

        if (
            not args.disable_validation
            and args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        ):
            print("validate and save_checkpoint")
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    train_epoch_cost = time.time() - start_time
    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)
    # Avoid ZeroDivisionError for an (almost) empty epoch on coarse clocks.
    avg_speed = float(step) / train_epoch_cost if train_epoch_cost > 0 else 0.0
    print("epoch_cost: %.5f s, avg_speed: %.5f steps/s" % (train_epoch_cost, avg_speed))

    # reset epoch-level meters
    metrics.reset_meters('train')
def main(args):
    """Greedily decode ``args.gen_subset`` and print one JSON line per sentence.

    Beam and nbest are forced to 1 and ``max_tokens`` to 10000. Each output
    line holds the source string, the top prediction and their whitespace
    token counts.
    """
    assert args.path is not None, '--path required for generation!'
    args.beam = args.nbest = 1
    args.max_tokens = int(1e4)
    utils.import_user_module(args)

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    src_dict = getattr(task, 'source_dictionary', None)
    tgt_dict = task.target_dictionary

    # NOTE(review): eval() runs arbitrary code from the command line.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(beamable_mm_beam_size=args.beam, need_attn=False)
        model.cuda()

    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    generator = task.build_generator(args)
    with progress_bar.build_progress_bar(args, itr) as t:
        for sample in t:
            sample = utils.move_to_cuda(sample)
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            hypos = task.inference_step(generator, models, sample, prefix_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens, args.remove_bpe)
                else:
                    src_str = ""

                # Process top predictions
                hypo = hypos[i][0]
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=None,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )
                result = dict(
                    src=src_str,
                    pred=hypo_str,
                    src_len=len(src_str.split()),
                    pred_len=len(hypo_str.split()),
                )
                result_line = json.dumps(result)
                print(result_line)
def compute_top_k(
    task,
    models,
    dataset,
    k,
    use_cuda,
    max_tokens=None,
    max_sentences=None,
    progress_bar_args=None,
):
    """Return the ensemble's top-``k`` output probabilities for every target token.

    Each model in ``models`` is run forward (under ``torch.no_grad``) on every
    batch of ``dataset``. The per-model output distributions are averaged, the
    ``k`` largest entries are kept along the last axis, and those ``k`` values
    are L1-normalised so that they sum to 1. Each sentence's rows are truncated
    to its number of non-pad ``prev_output_tokens``.

    Assumes model outputs are shaped (batch, tgt_len, vocab) - TODO confirm.

    Returns:
        (top_k_scores, top_k_indices): two NumPy arrays of shape
        ``(total_target_tokens, k)``, with sentences concatenated in
        sentence-id order.
    """
    scores_by_id = [None] * len(dataset)
    indices_by_id = [None] * len(dataset)

    batches = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=max_tokens,
        max_sentences=max_sentences,
    ).next_epoch_itr(shuffle=False)
    if progress_bar_args is not None:
        batches = progress_bar.build_progress_bar(
            args=progress_bar_args,
            iterator=batches,
            prefix="top-k probs eval",
            no_progress_bar="simple",
        )

    for batch in batches:
        ids = batch["id"]
        prev_tokens = batch["net_input"]["prev_output_tokens"]
        # Lengths are taken before any device move so .numpy() works.
        lengths = (prev_tokens != dataset.tgt_dict.pad()).sum(axis=1).numpy()
        if use_cuda:
            batch = utils.move_to_cuda(batch)

        summed = None
        for model in models:
            with torch.no_grad():
                out = model(**batch["net_input"])
                model_probs = model.get_normalized_probs(out, log_probs=False)
            if summed is not None:
                summed.add_(model_probs)
            else:
                summed = model_probs
        summed.div_(len(models))

        top_vals, top_idx = torch.topk(summed, k=k)
        top_vals = F.normalize(top_vals, p=1, dim=2).cpu()
        top_idx = top_idx.cpu()

        for row, sid in enumerate(ids):
            n_tok = lengths[row]
            scores_by_id[sid] = top_vals[row][:n_tok].numpy()
            indices_by_id[sid] = top_idx[row][:n_tok].numpy()

    assert all(s is not None for s in scores_by_id), "scores not calculated for all examples!"
    assert all(ix is not None for ix in indices_by_id), "indices not calculated for all examples!"

    return np.concatenate(scores_by_id, axis=0), np.concatenate(indices_by_id, axis=0)
def generate(cfg: UnsupGenerateConfig, models, saved_cfg, use_cuda):
    """Decode ``cfg.fairseq.dataset.gen_subset`` and aggregate decoding statistics.

    The decoder is chosen by ``cfg.w2l_decoder`` (viterbi, kenlm, fairseq LM
    or Kaldi). For the Kaldi decoder, all batches are first passed through
    ``gen_hypos`` and the returned objects are stored. They expose
    ``.result()`` (presumably futures - TODO confirm) and are resolved in a
    second pass. Other decoders generate inline.

    Per-sentence results come from ``process_predictions``. Afterwards the
    function optionally sums KenLM scores over all hypotheses
    (``cfg.lm_model``) and computes the edit distance against
    ``cfg.viterbi_transcript``.

    Args:
        cfg: generation config.
        models: model ensemble passed to ``gen_hypos``.
        saved_cfg: checkpoint config; its ``task`` section is used to load the
            dataset, with ``labels`` overwritten from ``cfg``.
        use_cuda: forwarded to ``gen_hypos``.

    Returns:
        GenResult built from (count, errs_t, gen_timer, lengths_hyp_unit_t,
        lengths_hyp_t, lengths_t, lm_score_sum, num_feats, num_sentences,
        num_symbols, vt_err_t, vt_length_t).
    """
    task = tasks.setup_task(cfg.fairseq.task)
    # Use the labels requested for this run rather than the checkpoint's.
    saved_cfg.task.labels = cfg.fairseq.task.labels
    task.load_dataset(cfg.fairseq.dataset.gen_subset, task_cfg=saved_cfg.task)
    # Set dictionary
    tgt_dict = task.target_dictionary
    logger.info("| {} {} {} examples".format(
        cfg.fairseq.task.data,
        cfg.fairseq.dataset.gen_subset,
        len(task.dataset(cfg.fairseq.dataset.gen_subset)),
    ))
    # Load dataset (possibly sharded)
    itr = get_dataset_itr(cfg, task)
    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(cfg: UnsupGenerateConfig):
        # Map cfg.w2l_decoder to a decoder instance; decoder modules are
        # imported lazily so only the selected backend must be installed.
        w2l_decoder = cfg.w2l_decoder
        if w2l_decoder == DecoderType.VITERBI:
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(cfg, task.target_dictionary)
        elif w2l_decoder == DecoderType.KENLM:
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(cfg, task.target_dictionary)
        elif w2l_decoder == DecoderType.FAIRSEQ:
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(cfg, task.target_dictionary)
        elif w2l_decoder == DecoderType.KALDI:
            from examples.speech_recognition.kaldi.kaldi_decoder import KaldiDecoder

            assert cfg.kaldi_decoder_config is not None

            return KaldiDecoder(
                cfg.kaldi_decoder_config,
                cfg.beam,
            )
        else:
            raise NotImplementedError(
                "only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment but found "
                + str(w2l_decoder))

    generator = build_generator(cfg)

    kenlm = None
    # NOTE(review): fairseq_lm is never reassigned in this function, so the
    # `elif fairseq_lm is not None` branch further down is unreachable.
    fairseq_lm = None
    if cfg.lm_model is not None:
        # Rebinds the local name from the module to the loaded model.
        import kenlm

        kenlm = kenlm.Model(cfg.lm_model)

    num_sentences = 0
    if cfg.results_path is not None and not os.path.exists(cfg.results_path):
        os.makedirs(cfg.results_path)

    res_files = prepare_result_files(cfg)
    # Running totals accumulated over every decoded sentence.
    errs_t = 0
    lengths_hyp_t = 0
    lengths_hyp_unit_t = 0
    lengths_t = 0
    count = 0
    num_feats = 0
    all_hyp_pieces = []
    all_hyp_words = []

    # Vocabulary size excluding special symbols and symbols whose name
    # starts with "madeup".
    num_symbols = (
        len([s for s in tgt_dict.symbols if not s.startswith("madeup")])
        - tgt_dict.nspecial)
    # Optional external reference lines, indexed by sample id below.
    targets = None
    if cfg.targets is not None:
        tgt_path = os.path.join(
            cfg.fairseq.task.data, cfg.fairseq.dataset.gen_subset + "." + cfg.targets)
        if os.path.exists(tgt_path):
            with open(tgt_path, "r") as f:
                targets = f.read().splitlines()
    viterbi_transcript = None
    if cfg.viterbi_transcript is not None and len(cfg.viterbi_transcript) > 0:
        logger.info(
            f"loading viterbi transcript from {cfg.viterbi_transcript}")
        with open(cfg.viterbi_transcript, "r") as vf:
            viterbi_transcript = vf.readlines()
            viterbi_transcript = [
                v.rstrip().split() for v in viterbi_transcript
            ]

    gen_timer.start()

    # [start, end) bounds the batch indices processed; as set here it
    # covers every batch.
    start = 0
    end = len(itr)

    hypo_futures = None
    if cfg.w2l_decoder == DecoderType.KALDI:
        # First pass: produce hypotheses for every batch, then replace itr
        # with (hypos, sample) pairs for the second pass below.
        logger.info("Extracting features")
        hypo_futures = []
        samples = []
        with progress_bar.build_progress_bar(cfg.fairseq.common, itr) as t:
            for i, sample in enumerate(t):
                if "net_input" not in sample or i < start or i >= end:
                    continue
                if "padding_mask" not in sample["net_input"]:
                    sample["net_input"]["padding_mask"] = None

                hypos, num_feats = gen_hypos(generator, models, num_feats,
                                             sample, task, use_cuda)
                hypo_futures.append(hypos)
                samples.append(sample)
                if cfg.debug:
                    break
        itr = list(zip(hypo_futures, samples))
        start = 0
        end = len(itr)
        logger.info("Finished extracting features")

    with progress_bar.build_progress_bar(cfg.fairseq.common, itr) as t:
        for i, sample in enumerate(t):
            if i < start or i >= end:
                continue

            if hypo_futures is not None:
                hypos, sample = sample
                hypos = [h.result() for h in hypos]
            else:
                if "net_input" not in sample:
                    continue

                hypos, num_feats = gen_hypos(generator, models, num_feats,
                                             sample, task, use_cuda)

            # NOTE: this rebinds the outer `i`; harmless because the outer
            # value is only read at the top of each outer iteration.
            for i, sample_id in enumerate(sample["id"].tolist()):
                # Reference priority: external targets file, then
                # target_label, then target; otherwise no reference.
                if targets is not None:
                    target_tokens = targets[sample_id]
                elif "target" in sample or "target_label" in sample:
                    toks = (sample["target"][i, :]
                            if "target_label" not in sample else
                            sample["target_label"][i, :])

                    target_tokens = utils.strip_pad(
                        toks, tgt_dict.pad()).int().cpu()
                else:
                    target_tokens = None

                # Process top predictions
                (
                    errs,
                    length_hyp,
                    length,
                    hyp_pieces,
                    hyp_words,
                ) = process_predictions(
                    cfg,
                    hypos[i],
                    tgt_dict,
                    target_tokens,
                    res_files,
                )
                errs_t += errs
                lengths_hyp_t += length_hyp
                # Count unit pieces when available, else words.
                lengths_hyp_unit_t += (len(hyp_pieces)
                                       if len(hyp_pieces) > 0 else
                                       len(hyp_words))
                lengths_t += length
                count += 1
                all_hyp_pieces.append(hyp_pieces)
                all_hyp_words.append(hyp_words)

            num_sentences += (sample["nsentences"]
                              if "nsentences" in sample else
                              sample["id"].numel())

    lm_score_sum = 0
    if kenlm is not None:

        if cfg.unit_lm:
            lm_score_sum = sum(kenlm.score(w) for w in all_hyp_pieces)
        else:
            lm_score_sum = sum(kenlm.score(w) for w in all_hyp_words)
    elif fairseq_lm is not None:
        lm_score_sum = sum(
            fairseq_lm.score([h.split() for h in all_hyp_words])[0])

    vt_err_t = 0
    vt_length_t = 0
    if viterbi_transcript is not None:
        unit_hyps = []
        if cfg.targets is not None and cfg.lexicon is not None:
            # Expand each hypothesis token into its lexicon units.
            lex = {}
            with open(cfg.lexicon, "r") as lf:
                for line in lf:
                    items = line.rstrip().split()
                    lex[items[0]] = items[1:]
            for h in all_hyp_pieces:
                hyp_ws = []
                for w in h.split():
                    assert w in lex, w
                    hyp_ws.extend(lex[w])
                unit_hyps.append(hyp_ws)

        else:
            unit_hyps.extend([h.split() for h in all_hyp_words])

        vt_err_t = sum(
            editdistance.eval(vt, h)
            for vt, h in zip(viterbi_transcript, unit_hyps))

        vt_length_t = sum(len(h) for h in viterbi_transcript)

    if res_files is not None:
        for r in res_files.values():
            r.close()

    gen_timer.stop(lengths_hyp_t)

    return GenResult(
        count,
        errs_t,
        gen_timer,
        lengths_hyp_unit_t,
        lengths_hyp_t,
        lengths_t,
        lm_score_sum,
        num_feats,
        num_sentences,
        num_symbols,
        vt_err_t,
        vt_length_t,
    )
def _to_wandb_stats(stats):
    """Convert a training-stats mapping into plain numbers for ``wandb.log``.

    Numeric values pass through unchanged. ``AverageMeter`` and ``TimeMeter``
    values are reduced to ``.avg`` and ``StopwatchMeter`` values to ``.sum``.
    Any other value type is dropped.
    """
    from numbers import Number
    from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter

    wandb_stats = {}
    for k, stat in stats.items():
        if isinstance(stat, Number):
            wandb_stats[k] = stat
        elif isinstance(stat, (AverageMeter, TimeMeter)):
            wandb_stats[k] = stat.avg
        elif isinstance(stat, StopwatchMeter):
            wandb_stats[k] = stat.sum
    return wandb_stats


def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch, logging per-update stats to wandb.

    Batches are grouped so that ``update_freq`` of them make one update. The
    value is taken from ``args.update_freq`` for this epoch, reusing the last
    entry once the list is exhausted. Batches are not shuffled while
    ``epoch_itr.epoch < args.curriculum``.

    Unless ``args.disable_validation`` is set, every
    ``args.save_interval_updates`` updates it validates on all
    ``args.valid_subset`` subsets and saves a checkpoint. The epoch stops
    early once ``args.max_update`` is reached, and the training meters are
    reset at the end.
    """
    # Imported once up front: a function-level import makes AverageMeter a
    # local name for the whole function, so it must be bound before the
    # defaultdict factory below can ever be called.
    from fairseq.meters import AverageMeter

    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats; the wandb payload is built before the extra
        # meters below are merged into `stats`.
        stats = get_training_stats(trainer)
        wandb.log(_to_wandb_stats(stats))

        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k or k == 'accuracy':
                # weight per-sample quantities by the batch's sample_size
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (
            not args.disable_validation
            and args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        ):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def train(args, trainer, task, epoch_itr):
    """Run one training epoch over ``epoch_itr``.

    ``update_freq`` batches are grouped into each optimizer step. The value
    comes from ``args.update_freq`` for the current epoch, falling back to the
    last entry once the list runs out.

    Stats are logged to the progress bar after every step. Every
    ``args.save_interval_updates`` updates, the first validation subset is
    evaluated and a checkpoint is saved. The epoch ends early at
    ``args.max_update``, and the training meters are reset afterwards.
    """
    freqs = args.update_freq
    # Looked up before next_epoch_itr() is called, exactly as before.
    update_freq = freqs[epoch_itr.epoch - 1] if epoch_itr.epoch <= len(freqs) else freqs[-1]

    batch_groups = iterators.GroupedIterator(
        epoch_itr.next_epoch_itr(fix_batches_to_gpus=args.fix_batches_to_gpus),
        update_freq,
    )
    progress = progress_bar.build_progress_bar(
        args, batch_groups, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    # Keys already covered by get_training_stats().
    skip_keys = ('loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size')

    for step, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        stats = get_training_stats(trainer)
        for key, value in log_output.items():
            if key in skip_keys:
                continue
            meter = extra_meters[key]
            if 'loss' in key:
                meter.update(value, log_output['sample_size'])
            else:
                meter.update(value)
            stats[key] = meter.avg
        progress.log(stats)

        if step == 0:
            # the first mini-batch would skew words-per-second, so restart it
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        interval = args.save_interval_updates
        if interval > 0 and num_updates % interval == 0 and num_updates > 0:
            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    stats = get_training_stats(trainer)
    for key, meter in extra_meters.items():
        stats[key] = meter.avg
    progress.print(stats)

    for name in ('train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip'):
        meter = trainer.get_meter(name)
        if meter is not None:
            meter.reset()
def main(parsed_args):
    """Evaluate a language-model ensemble on ``gen_subset`` and print perplexity.

    The checkpoint's args are loaded with the model(s) and then overridden by
    the command-line args, except for a few target/size options that keep
    their checkpoint values (presumably because they must match the trained
    model - TODO confirm).

    Prints the average per-token negative log-likelihood and the perplexity.
    Optionally prints per-word probabilities (``--output-word-probs``) and
    per-word statistics (``--output-word-stats``).
    """
    assert parsed_args.path is not None, '--path required for evaluation!'

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(
        parsed_args.path.split(':'), task)

    # Command-line values override the checkpoint's, except for these.
    keep_from_checkpoint = {
        'self_target', 'future_target', 'past_target', 'tokens_per_sample',
        'output_size_dictionary'
    }
    for arg in vars(parsed_args).keys():
        if arg not in keep_from_checkpoint:
            setattr(args, arg, getattr(parsed_args, arg))
    # Rebuild the task from the merged args.
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if hasattr(model, 'set_targets'):
            model.set_targets(args, task)

    assert len(models) > 0

    print('num. model params: {}'.format(
        sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        ignore_invalid_inputs=True,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary, target_idx=args.target_idx)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        bpe_cont = args.remove_bpe.rstrip()
        bpe_toks = set(i for i in range(len(task.dictionary))
                       if task.dictionary[i].endswith(bpe_cont))
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']

                # Fold each BPE continuation token's score into the next
                # token and zero its own; skipped_toks removes those
                # positions from the token count.
                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(len(hypo['tokens']) - 1):
                        if hypo['tokens'][i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(
                    float('-inf'))
                if inf_scores.any():
                    print(
                        '| Skipping tokens with inf scores:',
                        task.target_dictionary.string(
                            hypo['tokens'][inf_scores.nonzero()]))
                    # NOTE(review): after this filtering pos_scores no longer
                    # lines up with hypo['tokens'], so the per-word output
                    # below may be misaligned for such hypotheses.
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                # Accumulate as a Python float: summing into a tensor keeps
                # the tensor's dtype (fp16 under --fp16), which can overflow
                # or lose precision over a whole corpus.
                score_sum += utils.item(pos_scores.sum())
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(hypo['tokens'])):
                        w_ind = hypo['tokens'][i].item()
                        w += task.dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            # strip the continuation marker and keep building
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))

                            # score of the next position with a non-zero
                            # score (i.e. the next word's folded score)
                            next_prob = None
                            ind = i + 1
                            while ind < len(hypo['tokens']):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                                pos_scores[i].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        # NOTE(review): '{:2f}' is width 2 with 6 decimals;
                        # '{:.2f}' may have been intended. Left unchanged to
                        # keep the output format stable.
                        print('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            print(ws)
def main(args):
    """Translate ``args.gen_subset`` with a model ensemble and report BLEU.

    Unless ``--quiet``, prints these lines for each sentence:
    S- (source), T- (target, when present), H- (hypothesis with score),
    P- (per-position scores) and, with ``--print-alignment``, A- (alignment).

    With ``--score-reference`` the references are scored by a
    ``SequenceScorer`` instead of being generated. BLEU is accumulated on the
    top hypothesis only. It is printed at the end only if the last processed
    sample had a target, because ``has_target`` is reassigned for every
    sample.
    """
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    # Default batch budget when neither limit was given.
    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() executes arbitrary code from --model-overrides;
    # only pass trusted values.
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models, task.target_dictionary, beam_size=args.beam, minlen=args.min_len,
            stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen, unk_penalty=args.unkpen,
            sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
            diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
        )

    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        if args.score_reference:
            translations = translator.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
                cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
            )

        wps_meter = TimeMeter()
        # Each item is one sentence: (id, source tokens, target tokens or None, hypotheses).
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
            else:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                if has_target:
                    target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                    print('P-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(
                            lambda x: '{:.4f}'.format(x),
                            hypo['positional_scores'].tolist(),
                        ))
                    ))

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id,
                            ' '.join(map(lambda x: str(utils.item(x)), alignment))
                        ))

                # Score only the top hypothesis
                if has_target and i == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str, tgt_dict, add_if_not_exist=True)
                    scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
def main(parsed_args):
    """Score ``gen_subset`` with a language-model ensemble and report perplexity.

    Command-line args are merged over the args stored in the checkpoint. The
    function prints the average per-token negative log-likelihood and the
    perplexity. It can also print per-word probabilities
    (``--output-word-probs``) and per-word statistics
    (``--output-word-stats``). With ``--remove-bpe``, each continuation
    piece's score is folded into the following token, so scores are reported
    per whole word.
    """
    assert parsed_args.path is not None, '--path required for evaluation!'
    print(parsed_args)
    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task)
    # command-line values take precedence over those saved in the checkpoint
    args.__dict__.update(parsed_args.__dict__)
    print(args)

    task.args = args

    subset = args.gen_subset
    task.load_dataset(subset)
    print('| {} {} {} examples'.format(args.data, subset, len(task.dataset(subset))))

    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()

    assert len(models) > 0

    itr = task.get_batch_iterator(
        dataset=task.dataset(subset),
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*(m.max_positions() for m in models)),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        ignore_invalid_inputs=True,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0

    dictionary = task.dictionary
    if args.remove_bpe is None:
        bpe_toks = None
        bpe_len = 0
    else:
        bpe_cont = args.remove_bpe.rstrip()
        bpe_toks = {idx for idx in range(len(dictionary)) if dictionary[idx].endswith(bpe_cont)}
        bpe_len = len(bpe_cont)

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                tokens = hypo['tokens']
                pos_scores = hypo['positional_scores']

                # move each continuation piece's score onto the next token
                skipped_toks = 0
                if bpe_toks is not None:
                    for j in range(len(tokens) - 1):
                        if tokens[j].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[j + 1] += pos_scores[j]
                            pos_scores[j] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(tokens[inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += utils.item(pos_scores.sum())
                count += pos_scores.numel() - skipped_toks

                if not (args.output_word_probs or args.output_word_stats):
                    continue

                word = ''
                word_prob = []
                is_bpe = False
                for j, tok in enumerate(tokens):
                    w_ind = tok.item()
                    word += dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        # drop the continuation marker; the word goes on
                        word = word[:-bpe_len]
                        is_bpe = True
                        continue
                    score = pos_scores[j].item()
                    word_prob.append((word, score))
                    word_stats.setdefault(word, WordStat(word, is_bpe)).add(score)
                    is_bpe = False
                    word = ''
                if args.output_word_probs:
                    print('\t'.join('{} [{:2f}]'.format(w, p) for w, p in word_prob))

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for stat in sorted(word_stats.values(), key=lambda s: s.count, reverse=True):
            print(stat)