def get_train_iterator(
    self,
    epoch,
    combine=True,
    load_dataset=True,
    data_selector=None,
    shard_batch_itr=True,
):
    """Return an EpochBatchIterator over the training set for a given epoch."""
    if load_dataset:
        logger.info("loading train data for epoch {}".format(epoch))
        self.task.load_dataset(
            self.args.train_subset,
            epoch=epoch,
            combine=combine,
            data_selector=data_selector,
        )
    return self.task.get_batch_iterator(
        dataset=self.task.dataset(self.args.train_subset),
        max_tokens=self.args.max_tokens,
        max_sentences=self.args.max_sentences,
        max_positions=utils.resolve_max_positions(
            self.task.max_positions(),
            self.model.max_positions(),
            self.args.max_tokens,
        ),
        ignore_invalid_inputs=True,
        required_batch_size_multiple=self.args.required_batch_size_multiple,
        seed=self.args.seed,
        num_shards=self.data_parallel_world_size if shard_batch_itr else 1,
        shard_id=self.data_parallel_rank if shard_batch_itr else 0,
        num_workers=self.args.num_workers,
        epoch=epoch,
    )
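# A minimal usage sketch (not part of this file), assuming a fairseq Trainer
# instance has already been built by the training entry point; the name
# `trainer` is hypothetical here. It shows how the returned EpochBatchIterator
# is typically consumed for one epoch of training.
epoch_itr = trainer.get_train_iterator(epoch=1, load_dataset=True)
itr = epoch_itr.next_epoch_itr(shuffle=True)
for samples in itr:
    # each element is a batch dict produced by the task's collater
    trainer.train_step([samples])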
def __init__(self, args, task, models):
    super().__init__()
    self.args = args
    self.task = task
    self.models = nn.ModuleList(models)
    self.src_dict = task.source_dictionary
    self.tgt_dict = task.target_dictionary

    # optimize model for generation
    for model in self.models:
        model.prepare_for_inference_(args)

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    self.align_dict = utils.load_align_dict(getattr(args, 'replace_unk', None))

    self.tokenizer = encoders.build_tokenizer(args)
    self.bpe = encoders.build_bpe(args)

    self.max_positions = utils.resolve_max_positions(
        self.task.max_positions(),
        *[model.max_positions() for model in models]
    )

    # this is useful for determining the device
    self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))
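# A minimal sketch (not part of this class) of how this hub interface is
# typically obtained and used through torch.hub; the model name, tokenizer and
# BPE choices below follow fairseq's published hub examples and are
# illustrative only.
import torch

en2de = torch.hub.load(
    'pytorch/fairseq', 'transformer.wmt19.en-de.single_model',
    tokenizer='moses', bpe='fastbpe',
)
en2de.eval()
print(en2de.translate('Hello world!'))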
def get_valid_iterator(
    self,
    subset,
):
    """Return an EpochBatchIterator over given validation subset for a given epoch."""
    return self.task.get_batch_iterator(
        dataset=self.task.dataset(subset),
        max_tokens=self.args.max_tokens_valid,
        max_sentences=self.args.max_sentences_valid,
        max_positions=utils.resolve_max_positions(
            self.task.max_positions(),
            self.model.max_positions(),
        ),
        ignore_invalid_inputs=self.args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=self.args.required_batch_size_multiple,
        seed=self.args.seed,
        num_shards=self.data_parallel_world_size,
        shard_id=self.data_parallel_rank,
        num_workers=self.args.num_workers,
    )
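# Sketch of the typical consumer (hypothetical names, mirroring how the
# training loop iterates a validation subset); `trainer` and `subset` are
# assumed to exist in the caller.
itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
for sample in itr:
    trainer.valid_step(sample)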
def main(args, override_args=None):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    use_fp16 = args.fp16
    use_cuda = torch.cuda.is_available() and not args.cpu

    if use_cuda:
        torch.cuda.set_device(args.device_id)

    if override_args is not None:
        overrides = vars(override_args)
        overrides.update(eval(getattr(override_args, 'model_overrides', '{}')))
    else:
        overrides = None

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [args.path],
        arg_overrides=overrides,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )
    model = models[0]

    # Move models to GPU
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Print args
    logger.info(model_args)

    # Build criterion
    criterion = task.build_criterion(model_args)
    criterion.eval()

    for subset in args.valid_subset.split(','):
        try:
            task.load_dataset(subset, combine=False, epoch=1)
            dataset = task.dataset(subset)
        except KeyError:
            raise Exception('Cannot find dataset: ' + subset)

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
        )

        log_outputs = []
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            _loss, _sample_size, log_output = task.valid_step(sample, model, criterion)
            progress.log(log_output, step=i)
            log_outputs.append(log_output)

        if args.distributed_world_size > 1:
            log_outputs = distributed_utils.all_gather_list(
                log_outputs,
                max_size=getattr(args, 'all_gather_list_size', 16384),
            )
            log_outputs = list(chain.from_iterable(log_outputs))

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        progress.print(log_output, tag=subset, step=i)
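# Sketch of a typical command-line entry point for the validation main()
# above. The parser helper names are assumptions based on fairseq's options
# module and may differ between versions.
from fairseq import options

def cli_main():
    parser = options.get_validation_parser()
    args = options.parse_args_and_arch(parser)
    main(args)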
def _main(args, output_file):
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('generate')

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    # Fix seed for stochastic decoding
    if args.seed is not None and not args.no_seed_provided:
        np.random.seed(args.seed)
        utils.set_torch_seed(args.seed)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args.path),
        arg_overrides=eval(args.model_overrides),
        task=task,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )

    # Optimize ensemble for generation
    for model in models:
        model.prepare_for_inference_(args)
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(models, args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    scorer = scoring.build_scorer(args, tgt_dict)

    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        prefix_tokens = None
        if args.prefix_size > 0:
            prefix_tokens = sample['target'][:, :args.prefix_size]

        constraints = None
        if "constraints" in sample:
            constraints = sample["constraints"]

        gen_timer.start()
        hypos = task.inference_step(
            generator, models, sample,
            prefix_tokens=prefix_tokens, constraints=constraints,
        )
        num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample['id'].tolist()):
            has_target = sample['target'] is not None

            # Remove padding
            if 'src_tokens' in sample['net_input']:
                src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
            else:
                src_tokens = None

            target_tokens = None
            if has_target:
                target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu()

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens, args.remove_bpe)
                else:
                    src_str = ""
                if has_target:
                    target_str = tgt_dict.string(
                        target_tokens,
                        args.remove_bpe,
                        escape_unk=True,
                        extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                    )

            src_str = decode_fn(src_str)
            if has_target:
                target_str = decode_fn(target_str)

            if not args.quiet:
                if src_dict is not None:
                    print('S-{}\t{}'.format(sample_id, src_str), file=output_file)
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str), file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:args.nbest]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                )
                detok_hypo_str = decode_fn(hypo_str)
                if not args.quiet:
                    score = hypo['score'] / math.log(2)  # convert to base 2
                    # original hypothesis (after tokenization and BPE)
                    print('H-{}\t{}\t{}'.format(sample_id, score, hypo_str), file=output_file)
                    # detokenized hypothesis
                    print('D-{}\t{}\t{}'.format(sample_id, score, detok_hypo_str), file=output_file)
                    print('P-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(
                            lambda x: '{:.4f}'.format(x),
                            # convert from base e to base 2
                            hypo['positional_scores'].div_(math.log(2)).tolist(),
                        ))
                    ), file=output_file)

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id,
                            ' '.join(['{}-{}'.format(src_idx, tgt_idx) for src_idx, tgt_idx in alignment])
                        ), file=output_file)

                    if args.print_step:
                        print('I-{}\t{}'.format(sample_id, hypo['steps']), file=output_file)

                    if getattr(args, 'retain_iter_history', False):
                        for step, h in enumerate(hypo['history']):
                            _, h_str, _ = utils.post_process_prediction(
                                hypo_tokens=h['tokens'].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print('E-{}_{}\t{}'.format(sample_id, step, h_str), file=output_file)

                # Score only the top hypothesis
                if has_target and j == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True)
                        hypo_tokens = tgt_dict.encode_line(detok_hypo_str, add_if_not_exist=True)
                    if hasattr(scorer, 'add_string'):
                        scorer.add_string(target_str, detok_hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({'wps': round(wps_meter.avg)})
        num_sentences += sample["nsentences"] if "nsentences" in sample else sample['id'].numel()

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info('Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        if args.bpe and not args.sacrebleu:
            if args.remove_bpe:
                logger.warning(
                    "BLEU score is being computed by splitting detokenized string on spaces, "
                    "this is probably not what you want. "
                    "Use --sacrebleu for standard 13a BLEU tokenization"
                )
            else:
                logger.warning(
                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, "
                    "not on proper words. Use --sacrebleu for standard 13a BLEU tokenization"
                )
        # use print to be consistent with other main outputs: S-, H-, T-, D- and so on
        print(
            'Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()),
            file=output_file)

    return scorer
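# Sketch (hedged) of the thin wrapper that usually drives _main: write to a
# results file when a results path is set, otherwise to stdout. The exact
# wiring may differ between fairseq versions.
import os
import sys

def main(args):
    if getattr(args, 'results_path', None) is not None:
        os.makedirs(args.results_path, exist_ok=True)
        output_path = os.path.join(args.results_path, 'generate-{}.txt'.format(args.gen_subset))
        with open(output_path, 'w', buffering=1, encoding='utf-8') as h:
            return _main(args, h)
    else:
        return _main(args, sys.stdout)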
def main(args, task=None, model_state=None):
    check_args(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 4000000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    if task is None:
        # Load dataset splits
        task = tasks.setup_task(args)
        task.load_dataset(args.gen_subset)
        logger.info("| {} {} {} examples".format(
            args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    label_path = os.path.join(args.data, "{}.word".format(args.gen_subset))
    labels = []
    with open(label_path, "r") as f:
        for line in f:
            labels.append(line)

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info("| decoding with criterion {}".format(args.criterion))

    # Load ensemble
    if args.load_emissions:
        models, criterions = [], []
    else:
        logger.info("| loading model(s) from {}".format(args.path))
        models, criterions, _ = load_models_and_criterions(
            args.path,
            data_path=args.data,
            arg_overrides=eval(args.model_overrides),  # noqa
            task=task,
            model_state=model_state,
        )
        optimize_models(args, use_cuda, models)

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        trans = criterions[0].asg.trans.data
        args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task, models)

    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(args):
        w2l_decoder = getattr(args, "w2l_decoder", None)
        if w2l_decoder == "viterbi":
            from speech_recognition.w2l_decoder import W2lViterbiDecoder
            return W2lViterbiDecoder(args, task.target_dictionary)
        elif w2l_decoder == "kenlm":
            from speech_recognition.w2l_decoder import W2lKenLMDecoder
            return W2lKenLMDecoder(args, task.target_dictionary)
        elif w2l_decoder == "fairseqlm":
            from speech_recognition.w2l_decoder import W2lFairseqLMDecoder
            return W2lFairseqLMDecoder(args, task.target_dictionary)
        elif w2l_decoder == "ctc_decoder":
            from speech_recognition.ctc_decoder import CTCDecoder
            return CTCDecoder(args, task.target_dictionary)
        else:
            return super().build_generator(args)

    generator = build_generator(args)

    if args.load_emissions:
        generator = ExistingEmissionsDecoder(
            generator, np.load(args.load_emissions, allow_pickle=True))
        logger.info("loaded emissions from " + args.load_emissions)

    num_sentences = 0

    if args.results_path is not None and not os.path.exists(args.results_path):
        os.makedirs(args.results_path)

    max_source_pos = (utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]), )

    if max_source_pos is not None:
        max_source_pos = max_source_pos[0]
        if max_source_pos is not None:
            max_source_pos = max_source_pos[0] - 1

    if args.dump_emissions:
        emissions = {}
    if args.dump_features:
        features = {}
        models[0].bert.proj = None
    else:
        res_files = prepare_result_files(args)
    errs_t = 0
    lengths_t = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, :args.prefix_size]

            gen_timer.start()
            if args.dump_emissions:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    emm = models[0].get_normalized_probs(encoder_out, log_probs=True)
                    emm = emm.transpose(0, 1).cpu().numpy()
                    for i, id in enumerate(sample["id"]):
                        emissions[id.item()] = emm[i]
                    continue
            elif args.dump_features:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
                    for i, id in enumerate(sample["id"]):
                        padding = encoder_out["encoder_padding_mask"][i].cpu().numpy() \
                            if encoder_out["encoder_padding_mask"] is not None else None
                        features[id.item()] = (feat[i], padding)
                    continue
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            tokens_len = 0
            for h in hypos:
                try:
                    h_len = len(h[0]["tokens"])
                    tokens_len += h_len
                except Exception as ex:
                    print(ex)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                speaker = None
                # id = task.dataset(args.gen_subset).ids[int(sample_id)]
                id = sample_id
                toks = sample["target"][i, :] if 'target_label' not in sample \
                    else sample["target_label"][i, :]
                target_tokens = (utils.strip_pad(toks, tgt_dict.pad()).int().cpu())
                # Process top predictions
                errs, length = process_predictions(
                    args, hypos[i], None, tgt_dict, target_tokens,
                    res_files, speaker, id, labels)
                errs_t += errs
                lengths_t += length

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += sample["nsentences"] if "nsentences" in sample \
                else sample["id"].numel()

    wer = None
    if args.dump_emissions:
        emm_arr = []
        for i in range(len(emissions)):
            emm_arr.append(emissions[i])
        np.save(args.dump_emissions, emm_arr)
        logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}")
    elif args.dump_features:
        feat_arr = []
        for i in range(len(features)):
            feat_arr.append(features[i])
        np.save(args.dump_features, feat_arr)
        logger.info(f"saved {len(features)} emissions to {args.dump_features}")
    else:
        if lengths_t > 0:
            wer = errs_t * 100.0 / lengths_t
            logger.info(f"WER: {wer}")

        logger.info("| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}"
                    " sentences/s, {:.2f} tokens/s)".format(
                        num_sentences,
                        gen_timer.n,
                        gen_timer.sum,
                        num_sentences / gen_timer.sum,
                        1.0 / gen_timer.avg,
                    ))
        logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
    return task, wer
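# Minimal illustrative sketch (not from this file) of the word-error-rate
# bookkeeping used above: per-utterance edit-distance errors and reference
# lengths are summed, and WER is reported as a percentage at the end.
def corpus_wer(per_utt_errs, per_utt_lens):
    # per_utt_errs / per_utt_lens are hypothetical lists of ints,
    # one entry per decoded utterance
    total_errs = sum(per_utt_errs)
    total_len = sum(per_utt_lens)
    return total_errs * 100.0 / total_len if total_len > 0 else None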
def main(args):
    start_time = time.time()
    total_translate_time = 0

    utils.import_user_module(args)

    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    logger.info(args)

    # Fix seed for stochastic decoding
    if args.seed is not None and not args.no_seed_provided:
        np.random.seed(args.seed)
        utils.set_torch_seed(args.seed)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),
        task=task,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.prepare_for_inference_(args)
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Initialize generator
    generator = task.build_generator(models, args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def encode_fn(x):
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )

    if args.constraints:
        logger.warning(
            "NOTE: Constrained decoding currently assumes a shared subword vocabulary."
        )

    if args.buffer_size > 1:
        logger.info('Sentence buffer size: %s', args.buffer_size)
    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info('Type the input sentence and press return:')
    start_id = 0
    for inputs in buffered_read(args.input, args.buffer_size):
        results = []
        for batch in make_batches(inputs, args, task, max_positions, encode_fn):
            bsz = batch.src_tokens.size(0)
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            constraints = batch.constraints
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                if constraints is not None:
                    constraints = constraints.cuda()

            sample = {
                'net_input': {
                    'src_tokens': src_tokens,
                    'src_lengths': src_lengths,
                },
            }
            translate_start_time = time.time()
            translations = task.inference_step(generator, models, sample, constraints=constraints)
            translate_time = time.time() - translate_start_time
            total_translate_time += translate_time
            list_constraints = [[] for _ in range(bsz)]
            if args.constraints:
                list_constraints = [unpack_constraints(c) for c in constraints]
            for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                constraints = list_constraints[i]
                results.append((start_id + id, src_tokens_i, hypos, {
                    "constraints": constraints,
                    "time": translate_time / len(translations),
                }))

        # sort output to match input order
        for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
            if src_dict is not None:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                print('S-{}\t{}'.format(id_, src_str))
                print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
                for constraint in info["constraints"]:
                    print("C-{}\t{}".format(
                        id_, tgt_dict.string(constraint, args.remove_bpe)))

            # Process top predictions
            for hypo in hypos[:min(len(hypos), args.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                )
                detok_hypo_str = decode_fn(hypo_str)
                score = hypo['score'] / math.log(2)  # convert to base 2
                # original hypothesis (after tokenization and BPE)
                print('H-{}\t{}\t{}'.format(id_, score, hypo_str))
                # detokenized hypothesis
                print('D-{}\t{}\t{}'.format(id_, score, detok_hypo_str))
                print('P-{}\t{}'.format(
                    id_,
                    ' '.join(map(
                        lambda x: '{:.4f}'.format(x),
                        # convert from base e to base 2
                        hypo['positional_scores'].div_(math.log(2)).tolist(),
                    ))
                ))
                if args.print_alignment:
                    alignment_str = " ".join(
                        ["{}-{}".format(src, tgt) for src, tgt in alignment])
                    print('A-{}\t{}'.format(id_, alignment_str))

        # update running id_ counter
        start_id += len(inputs)

    logger.info("Total time: {:.3f} seconds; translation time: {:.3f}".format(
        time.time() - start_time, total_translate_time))
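# A minimal sketch (an assumption, not the actual fairseq helper) of the
# buffering behaviour that `buffered_read` provides above: read an input file
# (or '-' for stdin) and yield lists of at most `buffer_size` stripped lines.
import fileinput

def buffered_read_sketch(input_path, buffer_size):
    buffer = []
    for line in fileinput.input(files=[input_path]):
        buffer.append(line.strip())
        if len(buffer) >= buffer_size:
            yield buffer
            buffer = []
    if buffer:
        yield buffer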
def main(parsed_args, **unused_kwargs):
    assert parsed_args.path is not None, '--path required for evaluation!'

    if torch.cuda.is_available() and not parsed_args.cpu:
        torch.cuda.set_device(parsed_args.device_id)

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(parsed_args.path))
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(os.pathsep),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
        suffix=getattr(parsed_args, "checkpoint_suffix", ""),
    )

    for arg in vars(parsed_args).keys():
        if arg not in {
            'self_target', 'future_target', 'past_target',
            'tokens_per_sample', 'output_size_dictionary', 'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model
    # (required by scorer)
    for model in models:
        model.prepare_for_inference_(args)
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    logger.info('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    wps_meter = TimeMeter()

    for sample in progress:
        if 'net_input' not in sample:
            continue

        sample = utils.move_to_cuda(sample) if use_cuda else sample

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample['ntokens'])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample['id'][i]

            tokens = hypo['tokens']
            tgt_len = tokens.numel()
            pos_scores = hypo['positional_scores'].float()

            if getattr(args, 'add_bos_token', False):
                assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            skipped_toks = 0
            if bpe_toks is not None:
                for i in range(tgt_len - 1):
                    if tokens[i].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[i + 1] += pos_scores[i]
                        pos_scores[i] = 0

            inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
            if inf_scores.any():
                logger.info(
                    'skipping tokens with inf scores: {}'.format(
                        task.target_dictionary.string(tokens[inf_scores.nonzero()])
                    )
                )
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks

            if args.output_word_probs or args.output_word_stats:
                w = ''
                word_prob = []
                is_bpe = False
                for i in range(len(tokens)):
                    w_ind = tokens[i].item()
                    w += task.source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[i].item()))

                        next_prob = None
                        ind = i + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item(), next_prob)
                        is_bpe = False
                        w = ''
                if args.output_word_probs:
                    logger.info(
                        str(int(sample_id)) + " "
                        + ('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
                    )

        wps_meter.update(sample['ntokens'])
        progress.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg
    ))
    logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(
        avg_nll_loss, 2**avg_nll_loss
    ))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)
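# Small self-contained check (not from this file) of the loss/perplexity
# relation used above: per-token log-probabilities come out of the scorer in
# base e, the average negative log-likelihood is converted to base 2, and
# perplexity is 2 ** loss. The scores below are hypothetical.
import math

log_probs_e = [-1.2, -0.7, -2.3]  # hypothetical per-token scores (natural log)
avg_nll_base2 = -sum(log_probs_e) / len(log_probs_e) / math.log(2)
perplexity = 2 ** avg_nll_base2
print(round(avg_nll_base2, 4), round(perplexity, 2))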