def eval_lm_dataloader(
    self,
    dataset,
    max_tokens: Optional[int] = 36000,
    batch_size: Optional[int] = None,
    max_positions: Optional[int] = None,
    num_shards: int = 1,
    shard_id: int = 0,
    num_workers: int = 1,
    data_buffer_size: int = 10,
    # ensures that every evaluated token has access to a context of at least
    # this size, if possible
    context_window: int = 0,
):
    if context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=self.args.tokens_per_sample,
            context_window=context_window,
            pad_idx=self.source_dictionary.pad(),
        )
    return self.get_batch_iterator(
        dataset=dataset,
        max_tokens=max_tokens,
        max_sentences=batch_size,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        num_shards=num_shards,
        shard_id=shard_id,
        num_workers=num_workers,
        data_buffer_size=data_buffer_size,
    ).next_epoch_itr(shuffle=False)
def main(cfg: DictConfig, override_args=None, **unused_kwargs):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu
    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    if override_args is not None:
        overrides = vars(override_args)
        overrides.update(eval(getattr(override_args, "model_overrides", "{}")))
    else:
        overrides = None

    logger.info(cfg)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))

    # reduce tokens per sample by the required context window size
    cfg.task.tokens_per_sample -= cfg.eval_lm.context_window

    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [cfg.common_eval.path],
        arg_overrides=overrides,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # Load dataset splits
    gen_subset = cfg.dataset.gen_subset
    task.load_dataset(gen_subset)
    dataset = task.dataset(gen_subset)
    if cfg.eval_lm.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=cfg.task.tokens_per_sample,
            context_window=cfg.eval_lm.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info("{} {} {} examples".format(cfg.task.data, gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model
    # (required by scorer)
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    assert len(models) > 0

    logger.info(
        "num. model params: {}".format(sum(p.numel() for p in models[0].parameters()))
    )

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=cfg.dataset.max_tokens or 36000,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=True,
        num_shards=max(
            cfg.dataset.num_shards,
            cfg.distributed_training.distributed_world_size,
        ),
        shard_id=max(
            cfg.dataset.shard_id,
            cfg.distributed_training.distributed_rank,
        ),
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, cfg.eval_lm.softmax_batch)

    score_sum = 0.0
    count = 0

    if cfg.common_eval.remove_bpe is not None:
        if cfg.common_eval.remove_bpe == "sentencepiece":
            raise NotImplementedError
        else:
            bpe_cont = cfg.common_eval.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    wps_meter = TimeMeter()

    for sample in progress:
        if "net_input" not in sample:
            continue

        sample = utils.move_to_cuda(sample) if use_cuda else sample

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample["ntokens"])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample["id"][i]

            tokens = hypo["tokens"]
            tgt_len = tokens.numel()
            pos_scores = hypo["positional_scores"].float()

            if cfg.task.add_bos_token:
                assert hypo["tokens"][0].item() == task.target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            skipped_toks = 0
            if bpe_toks is not None:
                for i in range(tgt_len - 1):
                    if tokens[i].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[i + 1] += pos_scores[i]
                        pos_scores[i] = 0

            inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(float("-inf"))
            if inf_scores.any():
                logger.info(
                    "skipping tokens with inf scores: %s",
                    task.target_dictionary.string(tokens[inf_scores.nonzero()]),
                )
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks

            if cfg.eval_lm.output_word_probs or cfg.eval_lm.output_word_stats:
                w = ""
                word_prob = []
                is_bpe = False
                for i in range(len(tokens)):
                    w_ind = tokens[i].item()
                    w += task.source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[i].item()))

                        next_prob = None
                        ind = i + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                            pos_scores[i].item(), next_prob
                        )
                        is_bpe = False
                        w = ""
                if cfg.eval_lm.output_word_probs:
                    logger.info(
                        str(int(sample_id))
                        + " "
                        + ("\t".join("{} [{:2f}]".format(x[0], x[1]) for x in word_prob))
                    )

        wps_meter.update(sample["ntokens"])
        progress.log({"wps": round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info(
        "Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)".format(
            gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg
        )
    )
    logger.info(
        "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format(
            avg_nll_loss, 2 ** avg_nll_loss
        )
    )

    if cfg.eval_lm.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)
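# ---------------------------------------------------------------------------
# Hedged, self-contained sketch (an illustration, not part of the original
# files): the loss/perplexity bookkeeping used by main() above. `pos_scores`
# stands in for hypo["positional_scores"], i.e. per-token natural-log
# probabilities; the values below are made up.
import math

import torch

pos_scores = torch.tensor([-2.31, -0.47, -5.02, -1.10])
score_sum = pos_scores.sum().item()              # summed natural-log likelihood
count = pos_scores.numel()                       # number of scored tokens
avg_nll_loss = -score_sum / count / math.log(2)  # nats -> bits per token
perplexity = 2 ** avg_nll_loss
print("Loss (base 2): {:.4f}, Perplexity: {:.2f}".format(avg_nll_loss, perplexity))
# ---------------------------------------------------------------------------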
def main(parsed_args, **unused_kwargs):
    assert parsed_args.path is not None, '--path required for evaluation!'

    if torch.cuda.is_available() and not parsed_args.cpu:
        torch.cuda.set_device(parsed_args.device_id)

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(parsed_args.path))
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(os.pathsep),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
        suffix=getattr(parsed_args, "checkpoint_suffix", ""),
    )

    for arg in vars(parsed_args).keys():
        if arg not in {
            'self_target',
            'future_target',
            'past_target',
            'tokens_per_sample',
            'output_size_dictionary',
            'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model
    # (required by scorer)
    for model in models:
        model.prepare_for_inference_(args)
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    logger.info('num. model params: {}'.format(
        sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    wps_meter = TimeMeter()

    for sample in progress:
        if 'net_input' not in sample:
            continue

        sample = utils.move_to_cuda(sample) if use_cuda else sample

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample['ntokens'])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample['id'][i]

            tokens = hypo['tokens']
            tgt_len = tokens.numel()
            pos_scores = hypo['positional_scores'].float()

            if getattr(args, 'add_bos_token', False):
                assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            skipped_toks = 0
            if bpe_toks is not None:
                for i in range(tgt_len - 1):
                    if tokens[i].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[i + 1] += pos_scores[i]
                        pos_scores[i] = 0

            inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
            if inf_scores.any():
                logger.info(
                    'skipping tokens with inf scores: %s',
                    task.target_dictionary.string(tokens[inf_scores.nonzero()]),
                )
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks

            if args.output_word_probs or args.output_word_stats:
                w = ''
                word_prob = []
                is_bpe = False
                for i in range(len(tokens)):
                    w_ind = tokens[i].item()
                    w += task.source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[i].item()))

                        next_prob = None
                        ind = i + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                            pos_scores[i].item(), next_prob)
                        is_bpe = False
                        w = ''
                if args.output_word_probs:
                    logger.info(
                        str(int(sample_id)) + " "
                        + ('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
                    )

        wps_meter.update(sample['ntokens'])
        progress.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(
        avg_nll_loss, 2**avg_nll_loss))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)
def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    import_user_module(parsed_args)

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(
        parsed_args.path.split(':'),
        task,
        model_arg_overrides=eval(parsed_args.model_overrides),
    )

    for arg in vars(parsed_args).keys():
        if arg not in {
            'self_target', 'future_target', 'past_target',
            'tokens_per_sample', 'output_size_dictionary',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model
    # (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    print('num. model params: {}'.format(
        sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = set(
                i for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            )
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        for sample in t:
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for hypos_i in hypos:
                hypo = hypos_i[0]

                tokens = hypo['tokens']
                tgt_len = tokens.numel()
                pos_scores = hypo['positional_scores'].float()

                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(tgt_len - 1):
                        if tokens[i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print(
                        '| Skipping tokens with inf scores:',
                        task.target_dictionary.string(tokens[inf_scores.nonzero()]),
                    )
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(tokens)):
                        w_ind = tokens[i].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))

                            next_prob = None
                            ind = i + 1
                            while ind < len(tokens):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                                pos_scores[i].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print('\t'.join(
                            '{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    if args.log2:
        avg_nll_loss /= 0.69314718056
        print('| Loss: {:.4f}, Perplexity: {:.2f}, Cross-entropy: {:.4f}'.format(
            avg_nll_loss, np.power(2, avg_nll_loss),
            avg_nll_loss * (count - len(dataset)) / len(dataset)))
    else:
        print('| Loss: {:.4f}, Perplexity: {:.2f}, Cross-entropy: {:.4f}'.format(
            avg_nll_loss, np.exp(avg_nll_loss),
            avg_nll_loss * (count - len(dataset)) / len(dataset)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            print(ws)
def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(parsed_args.path))
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(os.pathsep),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
    )

    for arg in vars(parsed_args).keys():
        if arg not in {
            'self_target',
            'future_target',
            'past_target',
            'tokens_per_sample',
            'output_size_dictionary',
            'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model
    # (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    logger.info('num. model params: {}'.format(
        sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch, args=args)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    if args.knnlm and args.save_knnlm_dstore:
        raise ValueError("Cannot use knnlm while trying to build the datastore!")

    if args.knnlm:
        knn_dstore = KNN_Dstore(args)

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        if args.save_knnlm_dstore:
            print('keytype being saved:', args.knn_keytype)
            if args.dstore_fp16:
                print('Saving fp16')
                dstore_keys = np.memmap(
                    args.dstore_mmap + '_keys.npy', dtype=np.float16, mode='w+',
                    shape=(args.dstore_size, args.decoder_embed_dim))
                dstore_vals = np.memmap(
                    args.dstore_mmap + '_vals.npy', dtype=np.int16, mode='w+',
                    shape=(args.dstore_size, 1))
            else:
                print('Saving fp32')
                dstore_keys = np.memmap(
                    args.dstore_mmap + '_keys.npy', dtype=np.float32, mode='w+',
                    shape=(args.dstore_size, args.decoder_embed_dim))
                dstore_vals = np.memmap(
                    args.dstore_mmap + '_vals.npy', dtype=np.int, mode='w+',
                    shape=(args.dstore_size, 1))

        dstore_idx = 0
        for ex_i, sample in enumerate(t):
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            if args.knnlm:
                hypos = scorer.generate(models, sample, knn_dstore=knn_dstore)
            else:
                hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for i, hypos_i in enumerate(hypos):
                hypo = hypos_i[0]
                if args.save_knnlm_dstore:
                    shape = hypo['dstore_keys'].shape
                    if shape[0] == args.tokens_per_sample:
                        if dstore_idx + shape[0] > args.dstore_size:
                            shape = [args.dstore_size - dstore_idx]
                            hypo['dstore_keys'] = hypo['dstore_keys'][:shape[0]]
                        if args.dstore_fp16:
                            dstore_keys[dstore_idx:shape[0] + dstore_idx] = hypo[
                                'dstore_keys'].view(
                                -1, args.decoder_embed_dim).cpu().numpy().astype(np.float16)
                            dstore_vals[dstore_idx:shape[0] + dstore_idx] = hypo[
                                'tokens'].view(-1, 1).cpu().numpy().astype(np.int16)
                        else:
                            dstore_keys[dstore_idx:shape[0] + dstore_idx] = hypo[
                                'dstore_keys'].view(
                                -1, args.decoder_embed_dim).cpu().numpy().astype(np.float32)
                            dstore_vals[dstore_idx:shape[0] + dstore_idx] = hypo[
                                'tokens'].view(-1, 1).cpu().numpy().astype(np.int)
                        dstore_idx += shape[0]
                    else:
                        print('Skipping this one with shape', shape)

                sample_id = sample['id'][i]

                tokens = hypo['tokens']
                tgt_len = tokens.numel()
                pos_scores = hypo['positional_scores'].float()

                if args.add_bos_token:
                    assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                    tokens = tokens[1:]
                    pos_scores = pos_scores[1:]

                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(tgt_len - 1):
                        if tokens[i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                # inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                # if inf_scores.any():
                #     logger.info(
                #         'skipping tokens with inf scores:',
                #         task.target_dictionary.string(tokens[inf_scores.nonzero()])
                #     )
                #     pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(tokens)):
                        w_ind = tokens[i].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))

                            next_prob = None
                            ind = i + 1
                            while ind < len(tokens):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                                pos_scores[i].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        logger.info(
                            str(int(sample_id)) + " "
                            + ('\t'.join('{} [{:2f}]'.format(x[0], x[1])
                                         for x in word_prob))
                        )

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

    if args.save_knnlm_dstore:
        print("dstore_idx", dstore_idx, "final shape", shape)
        print("Keys", dstore_keys.shape, dstore_keys.dtype)
        print("Vals", dstore_vals.shape, dstore_vals.dtype)

    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(
        avg_nll_loss, 2**avg_nll_loss))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)
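# ---------------------------------------------------------------------------
# Hedged sketch (an assumption, not part of the original script): reading back
# the datastore written by main() above with np.memmap. The dtypes and shapes
# must match what was used at save time (--dstore-fp16 selects float16 keys and
# int16 vals); the paths and sizes here are hypothetical placeholders.
import numpy as np

dstore_size, decoder_embed_dim = 1000, 1024  # placeholders, not real values
keys = np.memmap('dstore_keys.npy', dtype=np.float16, mode='r',
                 shape=(dstore_size, decoder_embed_dim))
vals = np.memmap('dstore_vals.npy', dtype=np.int16, mode='r',
                 shape=(dstore_size, 1))
print(keys.shape, keys.dtype, vals.shape, vals.dtype)
# ---------------------------------------------------------------------------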
def test(parsed_args):
    # Make sure we didn't screw up the params
    assert parsed_args.path is not None, '--path required for evaluation!'
    assert parsed_args.sample_break_mode == 'eos', 'Sample break mode must be eos!'

    # Print the args
    import_user_module(parsed_args)
    print(parsed_args)

    # Do we use CUDA
    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    # Get the task (Language Modeling)
    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(
        parsed_args.path.split(':'),
        task,
        model_arg_overrides=eval(parsed_args.model_overrides),
    )

    for arg in vars(parsed_args).keys():
        if arg not in {
            'self_target', 'future_target', 'past_target',
            'tokens_per_sample', 'output_size_dictionary',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize model for generation
    assert len(models) > 0
    model = models[0]
    model.make_generation_fast_()
    if args.fp16:
        model.half()
    if use_cuda:
        model.cuda()

    # Make data iterator
    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Iterate over batches of sentences
    # Get the sentence logps for each batch
    all_s_logps = []
    all_n_tokens = 0
    all_n_sentences = 0
    for sample in itr:
        if 'net_input' not in sample:
            continue

        # Move sample to GPU if possible
        sample = utils.move_to_cuda(sample) if use_cuda else sample

        # Number of sentences in this batch
        bsz = sample['nsentences']
        all_n_sentences += bsz

        # Get the softmax outputs for the batch
        # The resultant tensor has shape: BATCH_SZ x N_TOKENS x VOCAB_SZ
        probs = []
        net_input = sample['net_input']
        with torch.no_grad():
            model.eval()
            decoder_out = model.forward(**net_input)
            probs = model.get_normalized_probs(
                decoder_out, log_probs=True, sample=sample).data

        # Make sure we have a softmax-sequence for each sentence in the batch
        assert len(probs) == bsz

        # Assert that the softmax output is correct (each row sums to 1);
        # the reference tensor is created on the same device as probs
        assert torch.allclose(torch.sum(torch.exp(probs), dim=2),
                              torch.ones(probs.shape[:2], device=probs.device))

        # Get the token logps for each sentence from the softmax outputs
        target = sample['target']
        logps = probs.gather(
            dim=2,
            index=target.unsqueeze(-1),
        ).squeeze(2)

        # Iterate over each sentence in the batch
        # and accumulate the sum of logps for each sentence
        start_idxs = [0] * bsz
        for i in range(bsz):
            # Get the token indices / strings
            tokens = utils.strip_pad(target[i, start_idxs[i]:],
                                     task.source_dictionary.pad())
            token_idxs = [tokens[i].item() for i in range(len(tokens))]
            token_strings = [task.source_dictionary[idx] for idx in token_idxs]

            # Maintain total number of tokens
            all_n_tokens += len(tokens)

            # This is the original sentence
            sentence = ' '.join(token_strings)

            # Get the token logps for this sentence
            s_len = len(tokens)
            s_logps = logps[i][:s_len]
            all_s_logps.append(torch.sum(s_logps))

    # Get the average sentence logp over all sentences in the test set
    avg_s_logp = sum(all_s_logps) / all_n_sentences
    print(-1 * avg_s_logp.item())
def main(parsed_args):
    if parsed_args.dstore_mmap is not None:
        d = os.path.dirname(parsed_args.dstore_mmap)
        print('mmap from {}'.format(d))
        if not os.path.exists(d):
            print('making dir')
            os.system('mkdir -p {}'.format(d))

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load model.
    hf_tokenizer = AutoTokenizer.from_pretrained(parsed_args.hf_model)
    if parsed_args.hf_enc_mode == 'masked':
        hf_model = AutoModelForMaskedLM.from_pretrained(parsed_args.hf_model)
    elif parsed_args.hf_enc_mode == 'causal':
        hf_model = AutoModelForCausalLM.from_pretrained(parsed_args.hf_model)

    if use_cuda:
        hf_model.cuda()
    device = next(hf_model.parameters()).device

    check_input_ids = hf_tokenizer('hello world')['input_ids']
    add_cls_token = check_input_ids[0] == hf_tokenizer.cls_token_id
    add_sep_token = check_input_ids[-1] == hf_tokenizer.sep_token_id
    print('add_cls_token = {} {} {}'.format(
        add_cls_token, hf_tokenizer.cls_token, hf_tokenizer.cls_token_id))
    print('add_sep_token = {} {} {}'.format(
        add_sep_token, hf_tokenizer.sep_token, hf_tokenizer.sep_token_id))

    args = copy.deepcopy(parsed_args)

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    task_dataset = task.dataset(args.gen_subset)
    assert args.context_window > 0
    dataset = LMContextWindowDataset(
        dataset=task_dataset,
        tokens_per_sample=args.tokens_per_sample,
        context_window=args.context_window,
        pad_idx=task.source_dictionary.pad(),
    )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    model_max_length = min(hf_tokenizer.model_max_length, parsed_args.hf_max_position)

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[model_max_length]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    # ).next_epoch_itr(shuffle=True)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch, args=args)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    if args.knnlm and args.save_knnlm_dstore:
        raise ValueError("Cannot use knnlm while trying to build the datastore!")

    if args.knnlm:
        knn_dstore = KNN_Dstore(args)

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        if args.save_knnlm_dstore:
            print('keytype being saved:', args.knn_keytype)
            dstore_keys = np.memmap(
                args.dstore_mmap + '_keys.npy', dtype=np.float32, mode='w+',
                shape=(args.dstore_size, hf_model.config.d_model))
            dstore_vals = np.memmap(
                args.dstore_mmap + '_vals.npy', dtype=np.int, mode='w+',
                shape=(args.dstore_size, 1))

        if args.save_extra:
            writer = Writer(outdir='demo-out', max_size=args.save_extra_max_size,
                            k=args.k, vec_size=1024)

        def pad(x, pad_id=-1):
            max_len = max([len(xx) for xx in x])
            x = [xx + [pad_id] * (max_len - len(xx)) for xx in x]
            return x

        def batchify(batch):
            new_batch = {}
            new_batch['input_ids'] = torch.tensor(
                pad(batch['src_tokens'], hf_tokenizer.pad_token_id),
                dtype=torch.long, device=device)
            new_batch['context_mask'] = torch.tensor(
                pad(batch['mask'], -1), dtype=torch.long, device=device)
            new_batch['word_id'] = torch.tensor(
                pad(batch['word_id'], -1), dtype=torch.long, device=device)
            new_batch['target'] = torch.tensor(
                pad(batch['target'], -1), dtype=torch.long, device=device)
            return new_batch

        dstore_idx = 0
        dstore_full = False
        num_tokens = 0

        for ex_i, sample in tqdm(enumerate(t), desc='encode'):
            if 'net_input' not in sample:
                continue

            all_tokens = torch.cat(
                [sample['net_input']['src_tokens'], sample['target'][:, -1, None]], -1)

            hf_batch = collections.defaultdict(list)
            for tok in all_tokens.tolist():
                tok = [tt for tt in tok if tt != dataset.pad_idx]
                raw_text = [task_dataset.vocab[tt] for tt in tok]

                hf_src_tokens, hf_target, hf_raw_target = [], [], []
                hf_raw_text, hf_word_id, hf_mask = [], [], []
                for i_w in range(len(raw_text) - 1):
                    w = raw_text[i_w]
                    tok_ = hf_tokenizer.encode(w, add_special_tokens=False)
                    if i_w == 0 and add_cls_token:
                        if tok_[0] != hf_tokenizer.cls_token_id:
                            tok_ = [hf_tokenizer.cls_token_id] + tok_
                    if len(hf_src_tokens) + len(tok_) > model_max_length:
                        break
                    hf_src_tokens += tok_
                    hf_raw_text += hf_tokenizer.convert_ids_to_tokens(tok_)
                    hf_word_id += [i_w] * len(tok_)
                    hf_mask += [0] * (len(tok_) - 1) + [1]
                    hf_target += [tok[i_w + 1]] * len(tok_)
                    hf_raw_target += [raw_text[i_w + 1]]

                assert len(hf_src_tokens) == len(hf_target)
                assert len(hf_src_tokens) == len(hf_word_id)
                assert len(hf_src_tokens) == len(hf_mask)

                hf_batch['src_tokens'].append(hf_src_tokens)
                hf_batch['target'].append(hf_target)  # This is indexed by KNN-LM tokenizer.
                hf_batch['raw_target'].append(hf_raw_target)
                hf_batch['word_id'].append(hf_word_id)
                hf_batch['mask'].append(hf_mask)

                num_tokens += len(hf_src_tokens)

            hf_batch_ = batchify(hf_batch)

            model_output = hf_model(hf_batch_['input_ids'], output_hidden_states=True)
            h = model_output['hidden_states'][-1]
            assert h.shape[:2] == hf_batch_['input_ids'].shape[:2]

            if args.save_knnlm_dstore and not dstore_full:
                flat_h = h.view(-1, hf_model.config.d_model)
                mask_ = hf_batch_['context_mask'].view(-1) == 1
                keys_ = flat_h[mask_]
                vals_ = hf_batch_['target'].view(-1, 1)[mask_]
                shape = keys_.shape
                if dstore_idx + shape[0] > args.dstore_size:
                    shape = [args.dstore_size - dstore_idx]
                    dstore_full = True
                    keys_ = keys_[:shape[0]]
                    vals_ = vals_[:shape[0]]
                assert keys_.shape[0] == vals_.shape[0]
                dstore_keys[dstore_idx:shape[0] + dstore_idx] = \
                    keys_.cpu().numpy().astype(np.float32)
                dstore_vals[dstore_idx:shape[0] + dstore_idx] = \
                    vals_.cpu().numpy().astype(np.int)
                dstore_idx += shape[0]

                if dstore_full:
                    print('Datastore is full with {} items.'.format(args.dstore_size))

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

    # Write saved values to disk.
    if args.save_extra:
        writer.update(extra)

    if args.save_knnlm_dstore:
        print("dstore_idx", dstore_idx, "final shape", shape)
        print("Keys", dstore_keys.shape, dstore_keys.dtype)
        print("Vals", dstore_vals.shape, dstore_vals.dtype)

    logger.info('done with {} tokens'.format(num_tokens))
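# ---------------------------------------------------------------------------
# Hedged sketch (an illustration, not part of the original script): the
# word-to-subword alignment built inside the encode loop above. Each source
# word is split into HF subwords; the mask keeps only the last subword of each
# word, which is the position whose hidden state becomes a datastore key.
# 'bert-base-uncased' is an arbitrary example checkpoint, not the one the
# script was run with.
from transformers import AutoTokenizer

hf_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
words = ['language', 'modeling', 'rocks']
hf_src_tokens, hf_word_id, hf_mask = [], [], []
for i_w, w in enumerate(words):
    tok_ = hf_tokenizer.encode(w, add_special_tokens=False)
    hf_src_tokens += tok_
    hf_word_id += [i_w] * len(tok_)
    hf_mask += [0] * (len(tok_) - 1) + [1]  # 1 marks the last subword of each word
print(hf_src_tokens, hf_word_id, hf_mask)
# ---------------------------------------------------------------------------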