def main(parsed_args):
    """Evaluate an LM ensemble, optionally with kNN-LM or while saving a kNN datastore.

    Loads the checkpoint(s) in ``--path``, scores ``--gen-subset`` with
    ``SequenceScorer`` and logs the base-2 loss and perplexity (plus per-word
    probabilities / statistics when requested).

    With ``--knnlm`` the scorer receives a ``KNN_Dstore``. With
    ``--save-knnlm-dstore`` the per-token keys returned by the scorer
    (``hypo['dstore_keys']``) and the corresponding target token ids are
    written to ``<dstore_mmap>_keys.npy`` / ``<dstore_mmap>_vals.npy`` memmaps
    holding ``dstore_size`` rows.

    Args:
        parsed_args: command-line namespace. Its values override the
            checkpoint's stored args, except for the target/shape options
            listed in ``_KEEP_FROM_CHECKPOINT`` below.

    Raises:
        ValueError: if both ``--knnlm`` and ``--save-knnlm-dstore`` are set.
        NotImplementedError: if ``--remove-bpe sentencepiece`` is requested.
    """
    assert parsed_args.path is not None, '--path required for evaluation!'

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(parsed_args.path))
    # NOTE(review): eval() executes the --model-overrides string; only run
    # this script with trusted command lines.
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(os.pathsep),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
    )

    # Every command-line value overrides the checkpoint args except these,
    # which must keep the values the model was trained with.
    _KEEP_FROM_CHECKPOINT = {
        'self_target', 'future_target', 'past_target', 'tokens_per_sample',
        'output_size_dictionary', 'add_bos_token',
    }
    for arg in vars(parsed_args).keys():
        if arg not in _KEEP_FROM_CHECKPOINT:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    logger.info('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch, args=args)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        bpe_cont = args.remove_bpe.rstrip()
        # Ids of dictionary entries that end with the BPE continuation marker.
        bpe_toks = {
            idx
            for idx in range(len(task.source_dictionary))
            if task.source_dictionary[idx].endswith(bpe_cont)
        }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    if args.knnlm and args.save_knnlm_dstore:
        raise ValueError("Cannot use knnlm while trying to build the datastore!")

    if args.knnlm:
        knn_dstore = KNN_Dstore(args)

    dstore_idx = 0
    shape = None  # last written chunk shape; reported at the end
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        if args.save_knnlm_dstore:
            print('keytype being saved:', args.knn_keytype)
            if args.dstore_fp16:
                print('Saving fp16')
                # NOTE(review): int16 cannot represent token ids >= 32768; kept
                # because readers of fp16 datastores presumably expect int16.
                key_dtype, val_dtype = np.float16, np.int16
            else:
                print('Saving fp32')
                # np.int (alias of the builtin int) was removed in NumPy 1.24.
                key_dtype, val_dtype = np.float32, np.int64
            dstore_keys = np.memmap(args.dstore_mmap + '_keys.npy', dtype=key_dtype, mode='w+',
                                    shape=(args.dstore_size, args.decoder_embed_dim))
            dstore_vals = np.memmap(args.dstore_mmap + '_vals.npy', dtype=val_dtype, mode='w+',
                                    shape=(args.dstore_size, 1))

        for sample in t:
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            if args.knnlm:
                hypos = scorer.generate(models, sample, knn_dstore=knn_dstore)
            else:
                hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for hyp_idx, hypos_i in enumerate(hypos):
                hypo = hypos_i[0]
                if args.save_knnlm_dstore:
                    shape = hypo['dstore_keys'].shape
                    # Only full-length windows are stored.
                    if shape[0] == args.tokens_per_sample:
                        if dstore_idx + shape[0] > args.dstore_size:
                            shape = [args.dstore_size - dstore_idx]
                        n = shape[0]
                        # Truncate keys AND values together so both slices
                        # match the remaining datastore space.
                        keys = hypo['dstore_keys'][:n]
                        vals = hypo['tokens'][:n]
                        dstore_keys[dstore_idx:dstore_idx + n] = keys.view(
                            -1, args.decoder_embed_dim).cpu().numpy().astype(key_dtype)
                        dstore_vals[dstore_idx:dstore_idx + n] = vals.view(
                            -1, 1).cpu().numpy().astype(val_dtype)
                        dstore_idx += n
                    else:
                        print('Skipping this one with shape', shape)

                sample_id = sample['id'][hyp_idx]

                tokens = hypo['tokens']
                pos_scores = hypo['positional_scores'].float()

                if args.add_bos_token:
                    assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                    tokens = tokens[1:]
                    pos_scores = pos_scores[1:]

                skipped_toks = 0
                if bpe_toks is not None:
                    # Fold the score of each BPE continuation piece into the
                    # following piece so that each word is scored once. Bound by
                    # the (possibly BOS-stripped) token count, not the original
                    # length, to avoid indexing past the end.
                    for k in range(len(tokens) - 1):
                        if tokens[k].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[k + 1] += pos_scores[k]
                            pos_scores[k] = 0

                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for k in range(len(tokens)):
                        w_ind = tokens[k].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[k].item()))

                            # Score of the next position that was not zeroed
                            # by the BPE merge above, if any.
                            next_prob = None
                            ind = k + 1
                            while ind < len(tokens):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                                pos_scores[k].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        logger.info(
                            str(int(sample_id)) + " "
                            + ('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
                        )

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

    if args.save_knnlm_dstore:
        print("dstore_idx", dstore_idx, "final shape", shape)
        print("Keys", dstore_keys.shape, dstore_keys.dtype)
        print("Vals", dstore_vals.shape, dstore_vals.dtype)
        # Push all written rows to the .npy files on disk.
        dstore_keys.flush()
        dstore_vals.flush()

    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(
        avg_nll_loss, 2**avg_nll_loss))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)
def _main(args, output_file):
    """Generate translations for ``--gen-subset``, score them, and optionally build a kNN datastore.

    Logging and the S-/T-/H-/P-/A-/I-/E- lines go to ``output_file``.

    With ``--save-knnlm-dstore``, ``hypo['dstore_keys']`` and the hypothesis
    token ids are written to ``<dstore_mmap>_keys.npy`` /
    ``<dstore_mmap>_vals.npy`` memmaps of ``dstore_size`` rows. With either
    ``--save-knnlm-dstore`` or ``--knnlm``, the hypothesis tokens are dumped
    one per line to ``<output_tokens_file_prefix>.tgt``.

    Returns:
        The BLEU scorer (``bleu.SacrebleuScorer`` or ``bleu.Scorer``) holding
        the accumulated statistics.

    Raises:
        ValueError: if both ``--knnlm`` and ``--save-knnlm-dstore`` are set.
    """
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairseq_cli.generate')

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() executes the --model-overrides string; trusted input only.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # The datastore shape depends on the model's embedding size and vocab.
    args.vocab_size = len(tgt_dict)
    for arg in vars(_model_args).keys():
        if arg in {'decoder_embed_dim', 'vocab_size'}:
            setattr(args, arg, getattr(_model_args, arg))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    if args.knnlm and args.save_knnlm_dstore:
        raise ValueError("Cannot use knnlm while trying to build the datastore!")

    if args.knnlm:
        knn_dstore = KNN_Dstore(args)

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())

    if args.save_knnlm_dstore:
        print('keytype being saved:', args.knn_keytype)
        if args.dstore_fp16:
            print('Saving fp16')
            # NOTE(review): int16 cannot represent token ids >= 32768; kept
            # because readers of fp16 datastores presumably expect int16.
            key_dtype, val_dtype = np.float16, np.int16
        else:
            print('Saving fp32')
            # np.int (alias of the builtin int) was removed in NumPy 1.24.
            key_dtype, val_dtype = np.float32, np.int64
        dstore_keys = np.memmap(args.dstore_mmap + '_keys.npy', dtype=key_dtype, mode='w+',
                                shape=(args.dstore_size, args.decoder_embed_dim))
        dstore_vals = np.memmap(args.dstore_mmap + '_vals.npy', dtype=val_dtype, mode='w+',
                                shape=(args.dstore_size, 1))

    dstore_idx = 0
    shape = None  # last written chunk shape; reported at the end

    if args.save_knnlm_dstore or args.knnlm:
        # This is only for MT right now, use interactive.py for language modeling.
        # (Compare the task *name*; comparing the task object to a string is always True.)
        assert getattr(args, 'task', None) != 'language_modeling', \
            'use interactive.py for language modeling'
        # 'w+' so the file can be re-read below to count the dumped lines.
        target_tokens_file = open(args.output_tokens_file_prefix + '.tgt', 'w+')

    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            if args.knnlm:
                hypos = task.inference_step(generator, models, sample, prefix_tokens,
                                            knn_dstore=knn_dstore)
            else:
                hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            if args.save_knnlm_dstore:
                for hypos_i in hypos:
                    hypo = hypos_i[0]
                    shape = hypo['dstore_keys'].shape
                    if dstore_idx + shape[0] > args.dstore_size:
                        shape = [args.dstore_size - dstore_idx]
                    n = shape[0]
                    # Truncate keys AND values together so both slices match
                    # the remaining datastore space.
                    keys = hypo['dstore_keys'][:n]
                    vals = hypo['tokens'][:n]
                    dstore_keys[dstore_idx:dstore_idx + n] = keys.view(
                        -1, args.decoder_embed_dim).cpu().numpy().astype(key_dtype)
                    dstore_vals[dstore_idx:dstore_idx + n] = vals.view(
                        -1, 1).cpu().numpy().astype(val_dtype)
                    dstore_idx += n

            if args.save_knnlm_dstore or args.knnlm:
                for hypos_i in hypos:
                    hypo = hypos_i[0]
                    # dump the tokens to a file, used for analysis and interactive printing
                    target_words = [task.target_dictionary[token] for token in hypo['tokens']]
                    target_tokens_file.write('\n'.join(target_words) + '\n')

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str), file=output_file)
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str), file=output_file)

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        score = hypo['score'] / math.log(2)  # convert to base 2
                        print('H-{}\t{}\t{}'.format(sample_id, score, hypo_str), file=output_file)
                        print('P-{}\t{}'.format(
                            sample_id,
                            ' '.join(map(
                                lambda x: '{:.4f}'.format(x),
                                # convert from base e to base 2
                                hypo['positional_scores'].div_(math.log(2)).tolist(),
                            ))
                        ), file=output_file)

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id,
                                ' '.join(['{}-{}'.format(src_idx, tgt_idx)
                                          for src_idx, tgt_idx in alignment])
                            ), file=output_file)

                        if args.print_step:
                            print('I-{}\t{}'.format(sample_id, hypo['steps']), file=output_file)

                        if getattr(args, 'retain_iter_history', False):
                            for step, h in enumerate(hypo['history']):
                                _, h_str, _ = utils.post_process_prediction(
                                    hypo_tokens=h['tokens'].int().cpu(),
                                    src_str=src_str,
                                    alignment=None,
                                    align_dict=None,
                                    tgt_dict=tgt_dict,
                                    remove_bpe=None,
                                )
                                print('E-{}_{}\t{}'.format(sample_id, step, h_str), file=output_file)

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info('Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        logger.info('Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))

    if args.save_knnlm_dstore:
        print("dstore_idx", dstore_idx, "final shape", shape)
        print("Keys", dstore_keys.shape, dstore_keys.dtype)
        print("Vals", dstore_vals.shape, dstore_vals.dtype)
        # Push all written rows to the .npy files on disk.
        dstore_keys.flush()
        dstore_vals.flush()
        target_tokens_file.seek(0)
        num_lines = sum(1 for _ in target_tokens_file)
        if dstore_idx != num_lines:
            print('Warning: size of KNN datastore is {}, does not match number of lines in train tokens file which is {}'.format(dstore_idx, num_lines))

    if args.save_knnlm_dstore or args.knnlm:
        target_tokens_file.close()

    return scorer
def main(args):
    """Read input interactively, decode it, and show the nearest kNN-LM neighbours.

    For every hypothesis, prints the decoded tokens and the ``dists_full``
    tensor. It then prints the 10 entries with the largest ``dists_full``
    values, each with up to 20 preceding training tokens (from
    ``--input-tokens-file``) and the neighbour token in ``[[ ]]``. After that
    come the usual H-/P-/A- lines.

    Raises:
        AssertionError: on inconsistent --sampling/--nbest/--beam or
            --max-sentences/--buffer-size settings, or if a hypothesis has no
            ``dists_full``.
    """
    utils.import_user_module(args)

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() executes the --model-overrides string; trusted input only.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Only the embedding size and vocab size are copied from the checkpoint.
    for arg in vars(_model_args).keys():
        if arg in {'decoder_embed_dim', 'vocab_size'}:
            setattr(args, arg, getattr(_model_args, arg))

    logger.info(args)

    if args.knnlm:
        knn_dstore = KNN_Dstore(args)

    print("Loading training tokens...")
    with open(args.input_tokens_file) as infile:
        train_tokens = infile.read().split()
    # HACK: fixed offset kept from the original experiment; neighbour indices
    # are interpreted relative to this truncated list.
    print("TODO, REMOVE ME\n\n\n !!!!Skipping first training tokens...")
    train_tokens = train_tokens[3072:]  # TODO, remove this

    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Initialize generator
    generator = task.build_generator(args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def encode_fn(x):
        """Apply the tokenizer, then BPE (each only if configured)."""
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        """Undo BPE, then tokenization (reverse order of ``encode_fn``)."""
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )

    if args.buffer_size > 1:
        logger.info('Sentence buffer size: %s', args.buffer_size)
    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info('Type the input sentence and press return:')

    context_size = 20   # training tokens shown before each neighbour
    num_neighbors = 10  # neighbours shown per hypothesis

    start_id = 0
    for inputs in buffered_read(args.input, args.buffer_size):
        results = []
        for batch in make_batches(inputs, args, task, max_positions, encode_fn):
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()

            sample = {
                'net_input': {
                    'src_tokens': src_tokens,
                    'src_lengths': src_lengths,
                },
            }
            if args.knnlm:
                translations = task.inference_step(generator, models, sample, None,
                                                   knn_dstore=knn_dstore)
            else:
                translations = task.inference_step(generator, models, sample, None)
            for i, (batch_id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                results.append((start_id + batch_id, src_tokens_i, hypos))

        # sort output to match input order
        for out_id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
            src_str = ''
            if src_dict is not None:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                print('S-{}\t{}'.format(out_id, src_str))

            # Process top predictions
            for hypo in hypos[:min(len(hypos), args.nbest)]:
                assert hypo['dists_full'] is not None
                dists_full = hypo['dists_full'].float()
                knns_full = hypo['knns_full']
                word_tokens = [task.target_dictionary[token] for token in hypo['tokens']]

                # TODO, trim off padding when its batched
                print("Example:", " ".join(word_tokens))
                print(dists_full)

                # Sort a CPU numpy copy: np.argsort cannot convert CUDA tensors,
                # and on CPU tensors it returns a tensor that rejects [::-1].
                dists_np = dists_full.cpu().numpy()
                best_dist_indices = np.argsort(dists_np)[-num_neighbors:][::-1]
                for j, neighbor_index in enumerate(best_dist_indices):
                    distance = dists_np[neighbor_index]
                    knn_index = int(knns_full[neighbor_index])
                    # Clamp so a neighbour near the start does not wrap the slice.
                    context_start = max(0, knn_index - context_size)
                    print(
                        "Best neighbor {} (distance {:.2f}):".format(j, distance),
                        " ".join(train_tokens[context_start:knn_index]),
                        "[[", train_tokens[knn_index], "]]")

                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )
                hypo_str = decode_fn(hypo_str)
                score = hypo['score'] / math.log(2)  # convert to base 2
                print('H-{}\t{}\t{}'.format(out_id, score, hypo_str))
                print('P-{}\t{}'.format(
                    out_id,
                    ' '.join(map(
                        lambda x: '{:.4f}'.format(x),
                        # convert from base e to base 2
                        hypo['positional_scores'].div_(math.log(2)).tolist(),
                    ))
                ))
                if args.print_alignment:
                    alignment_str = " ".join(["{}-{}".format(src, tgt) for src, tgt in alignment])
                    print('A-{}\t{}'.format(out_id, alignment_str))

        # update running id counter
        start_id += len(inputs)
def main(parsed_args):
    """Encode an LM dataset with a HuggingFace model and optionally save a kNN-LM datastore.

    Each context window produced by ``LMContextWindowDataset`` is re-tokenized
    word by word with the HF tokenizer and truncated to ``model_max_length``
    sub-tokens. It is then run through a masked or causal HF LM
    (``--hf-enc-mode``).

    With ``--save-knnlm-dstore``, the final-layer hidden state at the *last*
    sub-token of each word is stored as a key. Its value is the *next* word's
    id in the fairseq dictionary. Both go to ``<dstore_mmap>_keys.npy`` (float32)
    and ``<dstore_mmap>_vals.npy`` (int64), which hold ``dstore_size`` rows.

    Raises:
        ValueError: for an unsupported ``--hf-enc-mode`` or when both
            ``--knnlm`` and ``--save-knnlm-dstore`` are set.
    """
    if parsed_args.dstore_mmap is not None:
        d = os.path.dirname(parsed_args.dstore_mmap)
        print('mmap from {}'.format(d))
        # An empty dirname means the current directory, which already exists.
        if d and not os.path.exists(d):
            print('making dir')
            # os.makedirs avoids passing the user-supplied path through a shell.
            os.makedirs(d, exist_ok=True)

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load model.
    hf_tokenizer = AutoTokenizer.from_pretrained(parsed_args.hf_model)
    if parsed_args.hf_enc_mode == 'masked':
        hf_model = AutoModelForMaskedLM.from_pretrained(parsed_args.hf_model)
    elif parsed_args.hf_enc_mode == 'causal':
        hf_model = AutoModelForCausalLM.from_pretrained(parsed_args.hf_model)
    else:
        raise ValueError('Unsupported hf_enc_mode {!r}; expected "masked" or "causal".'.format(
            parsed_args.hf_enc_mode))
    if use_cuda:
        hf_model.cuda()
    hf_model.eval()
    device = next(hf_model.parameters()).device

    # Probe whether the tokenizer adds CLS / SEP around plain text.
    check_input_ids = hf_tokenizer('hello world')['input_ids']
    add_cls_token = check_input_ids[0] == hf_tokenizer.cls_token_id
    add_sep_token = check_input_ids[-1] == hf_tokenizer.sep_token_id
    print('add_cls_token = {} {} {}'.format(add_cls_token, hf_tokenizer.cls_token, hf_tokenizer.cls_token_id))
    print('add_sep_token = {} {} {}'.format(add_sep_token, hf_tokenizer.sep_token, hf_tokenizer.sep_token_id))

    # Tokenizers without a pad token (e.g. GPT-2 style) would otherwise put None
    # into the input tensor; padded positions are excluded via attention_mask,
    # so the exact id does not matter.
    input_pad_id = hf_tokenizer.pad_token_id if hf_tokenizer.pad_token_id is not None else 0

    args = copy.deepcopy(parsed_args)

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    task_dataset = task.dataset(args.gen_subset)
    assert args.context_window > 0
    dataset = LMContextWindowDataset(
        dataset=task_dataset,
        tokens_per_sample=args.tokens_per_sample,
        context_window=args.context_window,
        pad_idx=task.source_dictionary.pad(),
    )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    model_max_length = min(hf_tokenizer.model_max_length, parsed_args.hf_max_position)

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(model_max_length),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    if args.knnlm and args.save_knnlm_dstore:
        raise ValueError("Cannot use knnlm while trying to build the datastore!")

    def pad(x, pad_id=-1):
        """Right-pad every list in ``x`` with ``pad_id`` to the longest length."""
        max_len = max([len(xx) for xx in x])
        return [xx + [pad_id] * (max_len - len(xx)) for xx in x]

    def batchify(batch):
        """Turn the per-row lists in ``batch`` into padded long tensors on ``device``."""
        new_batch = {}
        new_batch['input_ids'] = torch.tensor(pad(batch['src_tokens'], input_pad_id), dtype=torch.long, device=device)
        # -1 marks padding; real positions are 0, or 1 on a word's last sub-token.
        new_batch['context_mask'] = torch.tensor(pad(batch['mask'], -1), dtype=torch.long, device=device)
        new_batch['word_id'] = torch.tensor(pad(batch['word_id'], -1), dtype=torch.long, device=device)
        new_batch['target'] = torch.tensor(pad(batch['target'], -1), dtype=torch.long, device=device)
        return new_batch

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        key_dim = None
        if args.save_knnlm_dstore:
            print('keytype being saved:', args.knn_keytype)
            # BART/T5-style configs expose d_model; BERT/GPT-style ones hidden_size.
            key_dim = getattr(hf_model.config, 'd_model', None) or hf_model.config.hidden_size
            dstore_keys = np.memmap(args.dstore_mmap + '_keys.npy', dtype=np.float32, mode='w+',
                                    shape=(args.dstore_size, key_dim))
            # np.int (alias of the builtin int) was removed in NumPy 1.24.
            dstore_vals = np.memmap(args.dstore_mmap + '_vals.npy', dtype=np.int64, mode='w+',
                                    shape=(args.dstore_size, 1))

        if args.save_extra:
            writer = Writer(outdir='demo-out', max_size=args.save_extra_max_size, k=args.k, vec_size=1024)

        dstore_idx = 0
        dstore_full = False
        num_tokens = 0
        shape = None  # last written chunk shape; reported at the end

        for sample in tqdm(t, desc='encode'):
            if 'net_input' not in sample:
                continue

            # Re-attach the final target token so that every window covers its full text.
            all_tokens = torch.cat([sample['net_input']['src_tokens'], sample['target'][:, -1, None]], -1)

            hf_batch = collections.defaultdict(list)
            for tok in all_tokens.tolist():
                tok = [tt for tt in tok if tt != dataset.pad_idx]
                raw_text = [task_dataset.vocab[tt] for tt in tok]
                hf_src_tokens, hf_target, hf_raw_target, hf_raw_text, hf_word_id, hf_mask = [], [], [], [], [], []

                # Every word except the last is an input; its target is the next word.
                for i_w in range(len(raw_text) - 1):
                    w = raw_text[i_w]
                    tok_ = hf_tokenizer.encode(w, add_special_tokens=False)
                    if i_w == 0 and add_cls_token and tok_[0] != hf_tokenizer.cls_token_id:
                        tok_ = [hf_tokenizer.cls_token_id] + tok_
                    if len(hf_src_tokens) + len(tok_) > model_max_length:
                        break
                    hf_src_tokens += tok_
                    hf_raw_text += hf_tokenizer.convert_ids_to_tokens(tok_)
                    hf_word_id += [i_w] * len(tok_)
                    # Only the last sub-token of the word yields a datastore key.
                    hf_mask += [0] * (len(tok_) - 1) + [1]
                    hf_target += [tok[i_w + 1]] * len(tok_)
                    hf_raw_target += [raw_text[i_w + 1]]

                assert len(hf_src_tokens) == len(hf_target)
                assert len(hf_src_tokens) == len(hf_word_id)
                assert len(hf_src_tokens) == len(hf_mask)

                hf_batch['src_tokens'].append(hf_src_tokens)
                hf_batch['target'].append(hf_target)  # This is indexed by KNN-LM tokenizer.
                hf_batch['raw_target'].append(hf_raw_target)
                hf_batch['word_id'].append(hf_word_id)
                hf_batch['mask'].append(hf_mask)
                num_tokens += len(hf_src_tokens)

            hf_batch_ = batchify(hf_batch)
            # Keep padded positions out of attention (matters for bidirectional models).
            attention_mask = (hf_batch_['context_mask'] != -1).long()
            # Inference only: avoids building autograd graphs; .numpy() below
            # would otherwise fail on tensors that require grad.
            with torch.no_grad():
                model_output = hf_model(hf_batch_['input_ids'], attention_mask=attention_mask,
                                        output_hidden_states=True)
            h = model_output['hidden_states'][-1]
            assert h.shape[:2] == hf_batch_['input_ids'].shape[:2]

            if args.save_knnlm_dstore and not dstore_full:
                flat_h = h.reshape(-1, key_dim)
                mask_ = hf_batch_['context_mask'].view(-1) == 1
                keys_ = flat_h[mask_]
                vals_ = hf_batch_['target'].view(-1, 1)[mask_]
                shape = keys_.shape
                if dstore_idx + shape[0] > args.dstore_size:
                    shape = [args.dstore_size - dstore_idx]
                    dstore_full = True
                    keys_ = keys_[:shape[0]]
                    vals_ = vals_[:shape[0]]
                    print('Datastore is full with {} items.'.format(args.dstore_size))
                assert keys_.shape[0] == vals_.shape[0]
                dstore_keys[dstore_idx:shape[0] + dstore_idx] = keys_.cpu().numpy().astype(np.float32)
                dstore_vals[dstore_idx:shape[0] + dstore_idx] = vals_.cpu().numpy().astype(np.int64)
                dstore_idx += shape[0]

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

            # Write saved values to disk.
            if args.save_extra:
                # NOTE(review): `extra` is not defined in this function; this line
                # raises NameError unless a module-level `extra` exists. Confirm.
                writer.update(extra)

    if args.save_knnlm_dstore:
        print("dstore_idx", dstore_idx, "final shape", shape)
        print("Keys", dstore_keys.shape, dstore_keys.dtype)
        print("Vals", dstore_vals.shape, dstore_vals.dtype)
        # Push all written rows to the .npy files on disk.
        dstore_keys.flush()
        dstore_vals.flush()

    logger.info('done with {} tokens'.format(num_tokens))