def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(parsed_args.path))
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(os.pathsep),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
    )

    for arg in vars(parsed_args).keys():
        if arg not in {
            'self_target', 'future_target', 'past_target',
            'tokens_per_sample', 'output_size_dictionary', 'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model
    # (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    logger.info('num. model params: {}'.format(
        sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch, args=args)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    if args.knnlm and args.save_knnlm_dstore:
        raise ValueError('Cannot use knnlm while trying to build the datastore!')

    if args.knnlm:
        knn_dstore = KNN_Dstore(args)

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        if args.save_knnlm_dstore:
            print('keytype being saved:', args.knn_keytype)
            if args.dstore_fp16:
                print('Saving fp16')
                dstore_keys = np.memmap(
                    args.dstore_mmap + '_keys.npy', dtype=np.float16, mode='w+',
                    shape=(args.dstore_size, args.decoder_embed_dim))
                dstore_vals = np.memmap(
                    args.dstore_mmap + '_vals.npy', dtype=np.int16, mode='w+',
                    shape=(args.dstore_size, 1))
            else:
                print('Saving fp32')
                dstore_keys = np.memmap(
                    args.dstore_mmap + '_keys.npy', dtype=np.float32, mode='w+',
                    shape=(args.dstore_size, args.decoder_embed_dim))
                dstore_vals = np.memmap(
                    args.dstore_mmap + '_vals.npy', dtype=np.int64, mode='w+',
                    shape=(args.dstore_size, 1))

        dstore_idx = 0
        # knn_probs_file = open(args.output_log_probs_file_prefix + '.knn.txt', 'w')
        # orig_probs_file = open(args.output_log_probs_file_prefix + '.orig.txt', 'w')
        if args.knnlm:
            dists_file = open(args.output_log_probs_file_prefix + '.dists.txt', 'w')
            knns_file = open(args.output_log_probs_file_prefix + '.knn_indices.txt', 'w')
        if args.save_knnlm_dstore or args.knnlm:
            tokens_file = open(args.output_tokens_file, 'w')

        for ex_i, sample in enumerate(t):
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            if args.knnlm:
                hypos = scorer.generate(models, sample, knn_dstore=knn_dstore)
            else:
                hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for i, hypos_i in enumerate(hypos):
                # the last hypothesis in each batch is skipped
                if i == len(hypos) - 1:
                    continue
                hypo = hypos_i[0]
                skipped = False
                if args.save_knnlm_dstore:
                    shape = hypo['dstore_keys'].shape
                    if shape[0] == args.tokens_per_sample:
                        # truncate the final chunk so it fits in the datastore
                        if dstore_idx + shape[0] > args.dstore_size:
                            shape = [args.dstore_size - dstore_idx]
                            hypo['dstore_keys'] = hypo['dstore_keys'][:shape[0]]
                        if args.dstore_fp16:
                            dstore_keys[dstore_idx:shape[0] + dstore_idx] = \
                                hypo['dstore_keys'].view(
                                    -1, args.decoder_embed_dim
                                ).cpu().numpy().astype(np.float16)
                            dstore_vals[dstore_idx:shape[0] + dstore_idx] = \
                                hypo['tokens'].view(-1, 1).cpu().numpy().astype(np.int16)
                        else:
                            dstore_keys[dstore_idx:shape[0] + dstore_idx] = \
                                hypo['dstore_keys'].view(
                                    -1, args.decoder_embed_dim
                                ).cpu().numpy().astype(np.float32)
                            dstore_vals[dstore_idx:shape[0] + dstore_idx] = \
                                hypo['tokens'].view(-1, 1).cpu().numpy().astype(np.int64)
                        dstore_idx += shape[0]
                    else:
                        skipped = True
                        print('Skipping this one with shape', shape)

                sample_id = sample['id'][i]

                tokens = hypo['tokens']
                tgt_len = tokens.numel()
                pos_scores = hypo['positional_scores'].float()
                orig_scores = hypo['original_scores'].float()
                yhat_scores = hypo['yhat_scores'].float()
                if args.knnlm:
                    assert hypo['dists_full'] is not None
                    dists_full = hypo['dists_full'].float()
                    knns_full = hypo['knns_full']
                    # knn_probs_file.write('\n'.join([str(prob) for prob in yhat_scores.tolist()]) + '\n')
                    # orig_probs_file.write('\n'.join([str(prob) for prob in orig_scores.tolist()]) + '\n')
                    dists_file.write('\n'.join([
                        str(dists_for_token)
                        for dists_for_token in dists_full.tolist()
                    ]) + '\n')
                    knns_file.write('\n'.join([
                        str(knns_for_token)
                        for knns_for_token in knns_full.tolist()
                    ]) + '\n')

                if args.save_knnlm_dstore or args.knnlm:
                    if not skipped:
                        word_tokens = [
                            task.target_dictionary[token] for token in hypo['tokens']
                        ]
                        tokens_file.write('\n'.join(word_tokens) + '\n')
                        assert len(hypo['yhat_scores'].float().tolist()) == len(word_tokens)

                '''
                doc = spacy.tokens.doc.Doc(
                    nlp.vocab, words=word_tokens, spaces=[True for token in tokens])
                for name, proc in nlp.pipeline:
                    doc = proc(doc)
                '''

                if args.add_bos_token:
                    assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                    tokens = tokens[1:]
                    pos_scores = pos_scores[1:]

                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(tgt_len - 1):
                        if tokens[i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                # inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                # if inf_scores.any():
                #     logger.info(
                #         'skipping tokens with inf scores:',
                #         task.target_dictionary.string(tokens[inf_scores.nonzero()])
                #     )
                #     pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(tokens)):
                        w_ind = tokens[i].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))

                            next_prob = None
                            ind = i + 1
                            while ind < len(tokens):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                                pos_scores[i].item(), next_prob)
                            is_bpe = False
                            w = ''

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

    if args.save_knnlm_dstore:
        print("dstore_idx", dstore_idx, "final shape", shape)
        print("Keys", dstore_keys.shape, dstore_keys.dtype)
        print("Vals", dstore_vals.shape, dstore_vals.dtype)

    # knn_probs_file.close()
    # orig_probs_file.close()
    if args.save_knnlm_dstore or args.knnlm:
        tokens_file.close()

    # Entities
    # mask = torch.tensor([1 if token.ent_type_ else 0 for token in doc], dtype=float)
    # count_entities = mask.sum()
    # if torch.cuda.is_available() and not parsed_args.cpu:
    #     mask = mask.cuda()
    # avg_nll_loss_entities = - (pos_scores * mask).sum() / count_entities.cpu() / math.log(2)

    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(
        avg_nll_loss, 2**avg_nll_loss))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)
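# Reading the datastore back is the mirror image of the memmap writes in main()
# above, e.g. for building a FAISS index offline. A minimal sketch, assuming the
# fp32 layout; the helper name is hypothetical, and dstore_mmap / dstore_size /
# embed_dim must match the values used at save time (use np.float16 / np.int16
# for a datastore written with --dstore-fp16).
def load_dstore(dstore_mmap, dstore_size, embed_dim):
    # mode='r' reopens the arrays written with mode='w+' in main() above
    keys = np.memmap(dstore_mmap + '_keys.npy', dtype=np.float32, mode='r',
                     shape=(dstore_size, embed_dim))
    vals = np.memmap(dstore_mmap + '_vals.npy', dtype=np.int64, mode='r',
                     shape=(dstore_size, 1))
    return keys, vals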
def eval_lm(
    models: List[fairseq.models.FairseqModel],
    source_dictionary: fairseq.data.Dictionary,
    batch_iterator: Iterable,
    post_process: Optional[str] = None,
    output_word_probs: bool = False,
    output_word_stats: bool = False,
    target_dictionary: Optional[fairseq.data.Dictionary] = None,
    softmax_batch: int = 0,
    remove_bos_token: bool = False,
    device: Optional[torch.device] = None,
):
    """
    Args:
        models (List[~fairseq.models.FairseqModel]): list of models to
            evaluate. Models are essentially `nn.Module` instances, but
            must be compatible with fairseq's `SequenceScorer`.
        source_dictionary (~fairseq.data.Dictionary): dictionary for
            applying any relevant post processing or outputing word
            probs/stats.
        batch_iterator (Iterable): yield batches of data
        post_process (Optional[str]): post-process text by removing BPE,
            letter segmentation, etc. Valid options can be found in
            fairseq.data.utils.post_process, although not all options
            are implemented here.
        output_word_probs (Optional[bool]): output words and their
            predicted log probabilities
        output_word_stats (Optional[bool]): output word statistics such
            as word count and average probability
        target_dictionary (Optional[~fairseq.data.Dictionary]): output
            dictionary (defaults to *source_dictionary*)
        softmax_batch (Optional[bool]): if BxT is more than this, will
            batch the softmax over vocab to this amount of tokens, in
            order to fit into GPU memory
        remove_bos_token (Optional[bool]): if True, confirm that the
            first token is the beginning-of-sentence symbol (according
            to the relevant dictionary) and remove it from the output
        device (Optional[torch.device]): device to use for evaluation
            (defaults to device of first model parameter)
    """
    if target_dictionary is None:
        target_dictionary = source_dictionary
    if device is None:
        device = next(models[0].parameters()).device

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(target_dictionary, softmax_batch)

    score_sum = 0.0
    count = 0

    if post_process is not None:
        if post_process in {"subword_nmt", "@@ "}:
            bpe_cont = post_process.rstrip()
            bpe_toks = {
                i
                for i in range(len(source_dictionary))
                if source_dictionary[i].endswith(bpe_cont)
            }
        else:
            raise NotImplementedError(
                f"--post-process={post_process} is not implemented"
            )
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    for sample in batch_iterator:
        if "net_input" not in sample:
            continue

        sample = utils.move_to_cuda(sample, device=device)

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample["ntokens"])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample["id"][i]

            tokens = hypo["tokens"]
            tgt_len = tokens.numel()
            pos_scores = hypo["positional_scores"].float()

            if remove_bos_token:
                assert hypo["tokens"][0].item() == target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            skipped_toks = 0
            if bpe_toks is not None:
                for i in range(tgt_len - 1):
                    if tokens[i].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[i + 1] += pos_scores[i]
                        pos_scores[i] = 0

            inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(float("-inf"))
            if inf_scores.any():
                logger.info(
                    "skipping tokens with inf scores: %s",
                    target_dictionary.string(tokens[inf_scores.nonzero()]),
                )
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks

            if output_word_probs or output_word_stats:
                w = ""
                word_prob = []
                is_bpe = False
                for i in range(len(tokens)):
                    w_ind = tokens[i].item()
                    w += source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[i].item()))

                        next_prob = None
                        ind = i + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                            pos_scores[i].item(), next_prob
                        )
                        is_bpe = False
                        w = ""
                if output_word_probs:
                    logger.info(
                        str(int(sample_id))
                        + " "
                        + "\t".join("{} [{:2f}]".format(x[0], x[1]) for x in word_prob)
                    )

    avg_nll_loss = (
        -score_sum / count / math.log(2) if count > 0 else 0
    )  # convert to base 2
    logger.info(
        "Evaluated {:,} tokens in {:.1f}s ({:.2f} tokens/s)".format(
            gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg if gen_timer.avg > 0 else 0
        )
    )

    if output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)

    return {
        "loss": avg_nll_loss,
        "perplexity": 2 ** avg_nll_loss,
    }
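# A sketch of calling eval_lm() directly; the helper name and argument choices
# are illustrative, not part of fairseq. The batch iterator would typically come
# from task.get_batch_iterator(...).next_epoch_itr(shuffle=False), as in the
# main() functions elsewhere in this file.
def example_eval_lm(checkpoint_path, task, batch_iterator):
    models, _model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [checkpoint_path], task=task,
    )
    results = eval_lm(
        models=models,
        source_dictionary=task.source_dictionary,
        batch_iterator=batch_iterator,
        post_process="subword_nmt",  # merge '@@ '-style BPE continuations into words
    )
    # eval_lm() returns the base-2 loss and the corresponding perplexity
    return results["loss"], results["perplexity"]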
def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    utils.import_user_module(parsed_args)

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(':'),
        # str() added by xxx: after training, model_overrides is sometimes the
        # dict {} rather than the string '{}', so normalize before eval().
        arg_overrides=eval(str(parsed_args.model_overrides)),
        task=task,
    )

    for arg in vars(parsed_args).keys():
        if arg not in {
            'self_target', 'future_target', 'past_target',
            'tokens_per_sample', 'output_size_dictionary', 'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model
    # (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    print('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = set(
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            )
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        for sample in t:
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for i, hypos_i in enumerate(hypos):
                hypo = hypos_i[0]
                sample_id = sample['id'][i]

                tokens = hypo['tokens']
                tgt_len = tokens.numel()
                pos_scores = hypo['positional_scores'].float()

                if args.add_bos_token:
                    assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                    tokens = tokens[1:]
                    pos_scores = pos_scores[1:]

                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(tgt_len - 1):
                        if tokens[i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(tokens[inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(tokens)):
                        w_ind = tokens[i].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))

                            next_prob = None
                            ind = i + 1
                            while ind < len(tokens):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                                pos_scores[i].item(), next_prob)
                            is_bpe = False
                            w = ''

                    if args.output_word_probs:
                        # Modified by xxx: emit 'sample_id|||p1 p2 ...' instead of
                        # the default 'sample_id word [prob]\t...' format.
                        # print(
                        #     str(int(sample_id)) + " "
                        #     + ('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
                        # )
                        print(
                            str(int(sample_id)) + "|||"
                            + (' '.join('{:2f}'.format(x[1]) for x in word_prob))
                        )

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            print(ws)
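# The modified --output-word-probs lines above have the form
# 'sample_id|||p1 p2 p3 ...'. A hypothetical parsing helper (not part of
# fairseq) for consuming that output downstream:
def parse_word_prob_line(line):
    sample_id, _, probs = line.partition('|||')
    return int(sample_id), [float(p) for p in probs.split()]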
def main(cfg: DictConfig, **unused_kwargs):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu
    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    logger.info(cfg)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))

    # reduce tokens per sample by the required context window size
    cfg.task.tokens_per_sample -= cfg.eval_lm.context_window

    # Initialize the task using the current *cfg*
    task = tasks.setup_task(cfg.task)

    # Initialize the model (but not the task) using the checkpoint's *cfg*
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [cfg.common_eval.path],
        arg_overrides=eval(cfg.common_eval.model_overrides),
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
        task=task,
    )

    # Load dataset splits
    gen_subset = cfg.dataset.gen_subset
    task.load_dataset(gen_subset)
    dataset = task.dataset(gen_subset)
    if cfg.eval_lm.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=cfg.task.tokens_per_sample,
            context_window=cfg.eval_lm.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info("{} {} {} examples".format(cfg.task.data, gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model
    # (required by scorer)
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    assert len(models) > 0

    logger.info(
        "num. model params: {}".format(sum(p.numel() for p in models[0].parameters()))
    )

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=cfg.dataset.max_tokens or 36000,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=True,
        num_shards=max(
            cfg.dataset.num_shards,
            cfg.distributed_training.distributed_world_size,
        ),
        shard_id=max(
            cfg.dataset.shard_id,
            cfg.distributed_training.distributed_rank,
        ),
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)

    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, cfg.eval_lm.softmax_batch)

    score_sum = 0.0
    count = 0

    if cfg.common_eval.post_process is not None:
        if cfg.common_eval.post_process == "sentencepiece":
            raise NotImplementedError
        else:
            bpe_cont = cfg.common_eval.post_process.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    wps_meter = TimeMeter()

    for sample in progress:
        if "net_input" not in sample:
            continue

        sample = utils.move_to_cuda(sample) if use_cuda else sample

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample["ntokens"])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample["id"][i]

            tokens = hypo["tokens"]
            tgt_len = tokens.numel()
            pos_scores = hypo["positional_scores"].float()

            if getattr(cfg.task, "add_bos_token", False):
                assert hypo["tokens"][0].item() == task.target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            skipped_toks = 0
            if bpe_toks is not None:
                for i in range(tgt_len - 1):
                    if tokens[i].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[i + 1] += pos_scores[i]
                        pos_scores[i] = 0

            inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(float("-inf"))
            if inf_scores.any():
                logger.info(
                    "skipping tokens with inf scores: %s",
                    task.target_dictionary.string(tokens[inf_scores.nonzero()]),
                )
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks

            if cfg.eval_lm.output_word_probs or cfg.eval_lm.output_word_stats:
                w = ""
                word_prob = []
                is_bpe = False
                for i in range(len(tokens)):
                    w_ind = tokens[i].item()
                    w += task.source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[i].item()))

                        next_prob = None
                        ind = i + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                            pos_scores[i].item(), next_prob
                        )
                        is_bpe = False
                        w = ""
                if cfg.eval_lm.output_word_probs:
                    logger.info(
                        str(int(sample_id))
                        + " "
                        + "\t".join("{} [{:2f}]".format(x[0], x[1]) for x in word_prob)
                    )

        wps_meter.update(sample["ntokens"])
        progress.log({"wps": round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count / math.log(2) if count > 0 else 0  # convert to base 2
    logger.info(
        "Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)".format(
            gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg if gen_timer.avg > 0 else 0
        )
    )
    logger.info(
        "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format(avg_nll_loss, 2 ** avg_nll_loss)
    )

    if cfg.eval_lm.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)
class LMScorer(object):
    def __init__(self, parsed_args):
        self.args = parsed_args
        import_user_module(parsed_args)

        assert parsed_args.path is not None, '--path required for evaluation'

        print(parsed_args)

        self.use_cuda = torch.cuda.is_available() and not parsed_args.cpu

        self.task = tasks.setup_task(parsed_args)

        # Load ensemble
        print('| loading model(s) from {}'.format(parsed_args.path))
        self.models, args = utils.load_ensemble_for_inference(
            parsed_args.path.split(':'),
            self.task,
            model_arg_overrides=eval(parsed_args.model_overrides),
        )

        for model in self.models:
            model.make_generation_fast_()
            if self.use_cuda:
                model.cuda()

        for arg in vars(parsed_args).keys():
            if arg not in {
                'self_target', 'future_target', 'past_target',
                'tokens_per_sample', 'output_size_dictionary',
            }:
                setattr(args, arg, getattr(parsed_args, arg))
        self.task = tasks.setup_task(args)

        self.gen_timer = StopwatchMeter()
        self.scorer = SequenceScorer(self.task.target_dictionary)

    def score_sent(self, line):
        score_dict = self.score([line])
        return score_dict[0]

    def make_batches(self, lines):
        token_lst = [
            self.task.source_dictionary.encode_line(
                line, add_if_not_exist=False).long()
            for line in lines
        ]
        length_lst = torch.LongTensor([tokens.numel() for tokens in token_lst])

        ds = data.TokenBlockDataset(
            token_lst,
            length_lst,
            self.args.tokens_per_sample,
            pad=self.task.dictionary.pad(),
            eos=self.task.dictionary.eos(),
            break_mode='eos',
            include_targets=True,
        )

        add_eos_for_other_targets = (
            self.args.sample_break_mode is not None
            and self.args.sample_break_mode != 'none'
        )

        itr = self.task.get_batch_iterator(
            dataset=data.MonolingualDataset(
                ds,
                ds.sizes,
                self.task.dictionary,
                self.task.target_dictionary,
                add_eos_for_other_targets,
                shuffle=False,
                targets=self.task.targets,
            ),
            max_tokens=self.args.max_tokens or 3000,
            max_sentences=self.args.max_sentences,
            max_positions=utils.resolve_max_positions(
                *[model.max_positions() for model in self.models]
            ),
            num_shards=self.args.num_shards,
            shard_id=self.args.shard_id,
            ignore_invalid_inputs=True,
            num_workers=self.args.num_workers,
        ).next_epoch_itr(shuffle=False)
        return itr

    def score(self, lines):
        batch = self.make_batches(lines)
        sample_score_dict = {}
        # with progress_bar.build_progress_bar(self.args, itr) as t:
        for sample in batch:
            sample_id_lst = sample['id']
            sample = utils.move_to_cuda(sample) if self.use_cuda else sample
            if 'net_input' not in sample:
                continue
            hypos = self.scorer.generate(self.models, sample)
            # print(hypos)
            for sample_id, hypos_i in zip(sample_id_lst, hypos):
                hypo = hypos_i[0]
                pos_scores = hypo['positional_scores']

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print(
                        '| Skipping tokens with inf scores:',
                        self.task.target_dictionary.string(
                            hypo['tokens'][inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                sample_score = pos_scores.sum().cpu()
                count = pos_scores.numel()

                w_lst = []
                word_prob = []
                for i in range(len(hypo['tokens'])):
                    w_ind = hypo['tokens'][i].item()
                    w = self.task.dictionary[w_ind]
                    word_prob.append((w, pos_scores[i].item()))
                    w_lst.append(w)

                sample_score = -sample_score / count
                if not self.args.quiet:
                    if self.args.output_sent:
                        print('H-{}\t{}\t{}'.format(
                            sample_id, sample_score, ' '.join(w_lst)))
                    else:
                        print('H-{}\t{}'.format(sample_id, sample_score))
                sample_score_dict[sample_id.item()] = sample_score.item()
                # print(sample_id, sample_score.item())

        return sample_score_dict
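# Example usage of LMScorer; `parsed_args` would come from fairseq's eval_lm
# argument parser, and the sentences shown are placeholders. score() returns a
# dict mapping sample id to the average negative log-probability per token:
#
#   scorer = LMScorer(parsed_args)
#   single = scorer.score_sent('the quick brown fox')
#   many = scorer.score(['first sentence', 'second sentence'])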
def main(parsed_args, **unused_kwargs):
    assert parsed_args.path is not None, '--path required for evaluation!'

    if torch.cuda.is_available() and not parsed_args.cpu:
        torch.cuda.set_device(parsed_args.device_id)

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(parsed_args.path))
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(os.pathsep),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
        suffix=getattr(parsed_args, "checkpoint_suffix", ""),
    )

    for arg in vars(parsed_args).keys():
        if arg not in {
            'self_target', 'future_target', 'past_target',
            'tokens_per_sample', 'output_size_dictionary', 'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model
    # (required by scorer)
    for model in models:
        model.prepare_for_inference_(args)
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    logger.info('num. model params: {}'.format(
        sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
        data_buffer_size=args.data_buffer_size,
    ).next_epoch_itr(shuffle=False)

    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    print(os.path.dirname(args.jason_test_output))
    checkpoint_utils.verify_checkpoint_directory(os.path.dirname(args.jason_test_output))
    test_loss_writer = open(args.jason_test_output, 'w')
    # test_loss_uid_writer = open(args.jason_test_uid_output, 'w')

    wps_meter = TimeMeter()

    for sample in progress:
        if 'net_input' not in sample:
            continue

        sample = utils.move_to_cuda(sample) if use_cuda else sample

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample['ntokens'])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample['id'][i]

            tokens = hypo['tokens']
            tgt_len = tokens.numel()
            pos_scores = hypo['positional_scores'].float()

            if getattr(args, 'add_bos_token', False):
                assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            skipped_toks = 0
            if bpe_toks is not None:
                for i in range(tgt_len - 1):
                    if tokens[i].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[i + 1] += pos_scores[i]
                        pos_scores[i] = 0

            inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
            if inf_scores.any():
                logger.info(
                    'skipping tokens with inf scores: %s',
                    task.target_dictionary.string(tokens[inf_scores.nonzero()])
                )
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks
            # print(i, pos_scores.size(), pos_scores.cpu()[-3:], pos_scores.sum().cpu(), pos_scores.numel() - skipped_toks)
            # print(parsed_args.jason_test_output_dir)

            # write one base-2 token loss per line
            pos_scores_cpu = pos_scores.cpu()
            output_line = ""
            for j in range(pos_scores_cpu.size()[0]):
                nll_loss_base2 = -pos_scores_cpu[j].item() / math.log(2)
                test_loss_writer.write(f"{nll_loss_base2}\n")
                output_line += f"{nll_loss_base2:.5f},"
            output_line = output_line[:-1] + "\n"
            # test_loss_uid_writer.write(output_line)

            if args.output_word_probs or args.output_word_stats:
                w = ''
                word_prob = []
                is_bpe = False
                for i in range(len(tokens)):
                    w_ind = tokens[i].item()
                    w += task.source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[i].item()))

                        next_prob = None
                        ind = i + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                            pos_scores[i].item(), next_prob)
                        is_bpe = False
                        w = ''
                if args.output_word_probs:
                    logger.info(
                        str(int(sample_id)) + " "
                        + ('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
                    )

        wps_meter.update(sample['ntokens'])
        progress.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg
    ))
    logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(
        avg_nll_loss, 2**avg_nll_loss
    ))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)