def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, "
                             "bert-base-chinese.")
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. "
                             "Sequences longer than this will be truncated, and sequences shorter "
                             "than this will be padded.")
    parser.add_argument('--ffn_type', default=0, type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv', default=0, type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb', action='store_true',
                        help="Using segment embedding for self-attention.")

    # Decoding parameters
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16")
    parser.add_argument("--input_file", type=str, help="Input file")
    parser.add_argument('--subset', type=int, default=0,
                        help="Decode a subset of the input dataset.")
    parser.add_argument("--output_file", type=str, help="Output file")
    parser.add_argument("--split", type=str, default="", help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--seed', type=int, default=123, help="random seed for initialization")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids', action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--batch_size', type=int, default=4, help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1, help="Beam size for searching")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search")
    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--mode', default="s2s", choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length', type=int, default=128,
                        help="maximum length of target sequence")
    parser.add_argument('--s2s_special_token', action='store_true',
                        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment', action='store_true',
                        help="Additional segment embeddings for the encoder of S2S.")
    parser.add_argument('--s2s_share_segment', action='store_true',
                        help="Sharing segment embeddings for the encoder of S2S "
                             "(used with --s2s_add_segment).")
    parser.add_argument('--pos_shift', action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError("Score trace is only available for beam search with beam size > 1.")
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    tokenizer.max_len = args.max_seq_length

    pair_num_relation = 0
    bi_uni_pipeline = []
    bi_uni_pipeline.append(seq2seq_loader.Preprocess4Seq2seqDecoder(
        list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length,
        max_tgt_length=args.max_tgt_length, new_segment_ids=args.new_segment_ids,
        mode="s2s", num_qkv=args.num_qkv, s2s_special_token=args.s2s_special_token,
        s2s_add_segment=args.s2s_add_segment, s2s_share_segment=args.s2s_share_segment,
        pos_shift=args.pos_shift))

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[S2S_SOS]"])

    forbid_ignore_set = None
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(tokenizer.convert_tokens_to_ids(w_list))

    print(args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model, state_dict=model_recover, num_labels=cls_num_labels,
            num_rel=pair_num_relation, type_vocab_size=type_vocab_size, task_idx=3,
            mask_word_id=mask_word_id, search_beam_size=args.beam_size,
            length_penalty=args.length_penalty, eos_id=eos_word_ids, sos_id=sos_word_id,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size,
            min_len=args.min_len, mode=args.mode,
            max_position_embeddings=args.max_seq_length, ffn_type=args.ffn_type,
            num_qkv=args.num_qkv, seg_emb=args.seg_emb, pos_shift=args.pos_shift)
        del model_recover

        if args.fp16:
            model.half()
        model.to(device)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)
        torch.cuda.empty_cache()
        model.eval()

        next_i = 0
        max_src_length = args.max_seq_length - 2 - args.max_tgt_length

        with open(args.input_file, encoding="utf-8") as fin:
            input_lines = [x.strip() for x in fin.readlines()]
            if args.subset > 0:
                logger.info("Decoding subset: %d", args.subset)
                input_lines = input_lines[:args.subset]
        data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
        # sort by source length (longest first); buf_id restores the original order below
        input_lines = [data_tokenizer.tokenize(x)[:max_src_length] for x in input_lines]
        input_lines = sorted(list(enumerate(input_lines)), key=lambda x: -len(x[1]))
        output_lines = [""] * len(input_lines)
        score_trace_list = [None] * len(input_lines)
        total_batch = math.ceil(len(input_lines) / args.batch_size)

        with tqdm(total=total_batch) as pbar:
            while next_i < len(input_lines):
                _chunk = input_lines[next_i:next_i + args.batch_size]
                buf_id = [x[0] for x in _chunk]
                buf = [x[1] for x in _chunk]
                next_i += args.batch_size
                max_a_len = max([len(x) for x in buf])
                instances = []
                for instance in [(x, max_a_len) for x in buf]:
                    for proc in bi_uni_pipeline:
                        instances.append(proc(instance))
                with torch.no_grad():
                    batch = seq2seq_loader.batch_list_to_batch_tensors(instances)
                    batch = [t.to(device) if t is not None else None for t in batch]
                    input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch
                    traces = model(input_ids, token_type_ids, position_ids, input_mask,
                                   task_idx=task_idx, mask_qkv=mask_qkv)
                    if args.beam_size > 1:
                        traces = {k: v.tolist() for k, v in traces.items()}
                        output_ids = traces['pred_seq']
                    else:
                        output_ids = traces.tolist()
                    for i in range(len(buf)):
                        w_ids = output_ids[i]
                        output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                        output_tokens = []
                        for t in output_buf:
                            if t in ("[SEP]", "[PAD]"):
                                break
                            output_tokens.append(t)
                        output_sequence = ' '.join(detokenize(output_tokens))
                        output_lines[buf_id[i]] = output_sequence
                        if args.need_score_traces:
                            score_trace_list[buf_id[i]] = {
                                'scores': traces['scores'][i],
                                'wids': traces['wids'][i],
                                'ptrs': traces['ptrs'][i]
                            }
                pbar.update(1)

        if args.output_file:
            fn_out = args.output_file
        else:
            fn_out = model_recover_path + '.' + args.split
        with open(fn_out, "w", encoding="utf-8") as fout:
            for l in output_lines:
                fout.write(l)
                fout.write("\n")

        if args.need_score_traces:
            with open(fn_out + ".trace.pickle", "wb") as fout_trace:
                pickle.dump({"version": 0.0, "num_samples": len(input_lines)}, fout_trace)
                for x in score_trace_list:
                    pickle.dump(x, fout_trace)
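

# The decoder above appends one header dict followed by one pickled trace per
# input line to "<fn_out>.trace.pickle". A minimal reader sketch for that
# layout (an assumption drawn from the dump sequence above; `trace_path` is a
# hypothetical argument, not a flag of the script):
def read_score_traces(trace_path):
    import pickle
    with open(trace_path, "rb") as f:
        header = pickle.load(f)  # {"version": 0.0, "num_samples": N}
        # each subsequent pickled object holds 'scores', 'wids', 'ptrs' for one sample
        traces = [pickle.load(f) for _ in range(header["num_samples"])]
    return header, traces
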
def main():
    parser = argparse.ArgumentParser()
    # parser.add_argument('--dataset', default='txt', type=str, help='txt -> self-customized')
    # parser.add_argument('--src_lang', default='en', type=str, help='')
    # parser.add_argument('--tgt_lang', default='zh', type=str, help='')
    parser.add_argument('--max_len_en', default=25, type=int,
                        help='maximum length of English in **bilingual** corpus')
    parser.add_argument('--max_len_zh', default=25, type=int,
                        help='maximum length of Chinese in **bilingual** corpus')
    parser.add_argument("--src_file", default='./.pkl', type=str,
                        help="The input data file name.")

    # General
    parser.add_argument("--config_path", default=None, type=str, help="Bert config file path.")
    parser.add_argument("--bert_model", default="bert-base-cased", type=str,
                        help="Bert pre-trained model selected in the list: "
                             "bert-base-cased, bert-large-cased.")
    parser.add_argument("--xml_vocab", type=str, default='./download_models/xml_vocab.json')
    parser.add_argument("--xml_merge", type=str, default='./download_models/xml_merges.txt')
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument('--max_position_embeddings', type=int, default=512,
                        help="max position embeddings")

    # For decoding
    # parser.add_argument('--fp16', action='store_true',
    #                     help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16")
    parser.add_argument('--seed', type=int, default=123, help="random seed for initialization")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size', type=int, default=4, help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1, help="Beam size for searching")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search")
    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--enable_butd', action='store_true',
                        help='set to take in region features')
    parser.add_argument('--output_dir', default='./result', type=str)  # useless
    parser.add_argument('--split', type=str, default='val')  # wmt?
    parser.add_argument('--len_vis_input', type=int, default=1,
                        help="The length of visual token input, region 1")

    with open('/data/private/chenyutong/dataset/concept_count/word_concept_count.pkl',
              'rb') as f:
        word_fre = pickle.load(f)
        word_fre = defaultdict(int, word_fre)

    args = parser.parse_args()
    assert args.batch_size == 1, 'only support batch_size=1'
    args.max_tgt_length = max(args.max_len_en, args.max_len_zh)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer_en = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + '/.pretrained_model')
    if args.max_position_embeddings:
        tokenizer_en.max_len = args.max_position_embeddings
    # tokenizer_en = WhitespaceTokenizer() if args.tokenized_input else tokenizer_en
    tokenizer_zh = XLMTokenizer(args.xml_vocab, args.xml_merge)
    tokenizer_zh.tokenize = lambda x: tokenizer_zh._tokenize(
        x, lang='zh', bypass_tokenizer=False)
    with open(args.xml_vocab, 'r') as f:
        tokenizer_zh.vocab = json.load(f)
    indexer = Indexer([os.path.join(args.bert_model, 'vocab.txt'), args.xml_vocab])
    with open('full_vocab.json', 'w') as f:
        json.dump(indexer.ids_to_tokens, f)
    tokenizers = {'en': tokenizer_en, 'zh': tokenizer_zh}
    print('tokenizer created')

    assert '.pkl' in args.src_file
    with open(args.src_file, 'rb') as f:
        # list [pred_id, vocab, vis, pos, distribution]
        # dict {'vgid': {'en':, 'zh':, 'region_features': [img, conf, fea[i], pos[i], dist]}}
        src_data = pickle.load(f)

    amp_handle = None
    if args.amp:
        from apex import amp

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 12  # same value with or without --new_segment_ids
    mask_word_id, eos_word_ids = indexer(["[MASK]", "[SEP]"])
    forbid_ignore_set = None  # default None
    relax_projection, task_idx_proj = 0, 3
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(indexer(w_list))

    print(args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        # logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model, max_position_embeddings=args.max_position_embeddings,
            config_path=args.config_path, state_dict=model_recover,
            num_labels=cls_num_labels, vocab_size=len(indexer),
            type_vocab_size=type_vocab_size, task_idx=3, mask_word_id=mask_word_id,  # img2txt
            search_beam_size=args.beam_size, length_penalty=args.length_penalty,
            eos_id=eos_word_ids, forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size,
            min_len=args.min_len, enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input)
        del model_recover

        model.to(device)
        if args.amp:
            model = amp.initialize(model, opt_level='O2')
        torch.cuda.empty_cache()
        model.eval()

        fout = open(os.path.join(args.output_dir, 'region2txt_output.txt'), 'w')
        output_lines = []
        select_ids = [87, 120, 179, 297, 721, 852, 1025]
        for step_val, sd in enumerate(src_data.items()):
            vgid, input_item = sd
            en, zh = input_item['en'], input_item['zh']
            fout.writelines('\n' + '#' * 10 + '\n')
            fout.writelines('{}\n'.format(vgid))
            fout.writelines('{} coco: word_fre {} vis_fre {} \n'.format(
                en, input_item['coco_fre']['word'], input_item['coco_fre']['vis']))
            fout.writelines('{} aic: word_fre {} vis_fre {} \n'.format(
                zh, input_item['aic_fre']['word'], input_item['aic_fre']['vis']))
            print('step_val {} Process {}'.format(step_val, en))
            for rf in tqdm(input_item['region_features']):
                filename, conf, vis_feats, vis_pe, cls_label = rf
                vis_feats = torch.from_numpy(vis_feats).to(device).unsqueeze(0)
                vis_pe = torch.from_numpy(vis_pe).to(device).unsqueeze(0)
                cls_label = torch.from_numpy(cls_label).to(device).unsqueeze(0)

                # lazy normalization of the coordinates (copied from seq2seq)
                w_est = torch.max(vis_pe[:, [0, 2]]) * 1. + 1e-5
                h_est = torch.max(vis_pe[:, [1, 3]]) * 1. + 1e-5
                vis_pe[:, [0, 2]] /= w_est
                vis_pe[:, [1, 3]] /= h_est
                assert h_est > 0, 'should be greater than 0! {}'.format(h_est)
                assert w_est > 0, 'should be greater than 0! {}'.format(w_est)
                rel_area = (vis_pe[:, 3] - vis_pe[:, 1]) * (vis_pe[:, 2] - vis_pe[:, 0])
                rel_area.clamp_(0)

                vis_pe = torch.cat(
                    (vis_pe[:, :4], rel_area.view(-1, 1), vis_pe[:, 5:]), -1)  # confidence score
                normalized_coord = F.normalize(vis_pe.data[:, :5] - 0.5, dim=-1)
                vis_pe = torch.cat((F.layer_norm(vis_pe, [6]),
                                    F.layer_norm(cls_label, [1601])), dim=-1)  # 1601 hard coded...

                # B, L, H
                vis_feats = vis_feats.unsqueeze(0)
                vis_pe = vis_pe.unsqueeze(0)

                assert not args.new_segment_ids, 'only support 0 1 6 now'
                tokens = ['[CLS]', '[UNK]', '[SEP]']
                input_ids = indexer(tokens)
                input_ids = np.expand_dims(np.array(input_ids), axis=0)
                input_ids = torch.tensor(input_ids, dtype=torch.long, device=device)

                max_len_in_batch = len(tokens) + args.max_tgt_length
                _tril_matrix = torch.tril(
                    torch.ones((max_len_in_batch, max_len_in_batch), dtype=torch.long))
                input_mask = torch.zeros(max_len_in_batch, max_len_in_batch,
                                         dtype=torch.long, device=device)
                input_mask[:, :len(tokens)].fill_(1)
                second_st, second_end = len(tokens), max_len_in_batch
                input_mask[second_st:second_end, second_st:second_end].copy_(
                    _tril_matrix[:second_end - second_st, :second_end - second_st])  # L, L
                input_mask = input_mask.unsqueeze(0)
                position_ids = torch.arange(max_len_in_batch, dtype=torch.long, device=device)  # L
                position_ids = position_ids.unsqueeze(0)  # B, L

                predictions = {'en': None, 'zh': None, 'en2zh': None, 'zh2en': None}
                for tgt_lang, lang_id in zip(['en', 'zh'], [1, 6]):
                    token_type_ids = [0] * len(tokens) + [lang_id] * args.max_tgt_length
                    token_type_ids = np.expand_dims(np.array(token_type_ids), axis=0)
                    token_type_ids = torch.tensor(token_type_ids, dtype=torch.long,
                                                  device=device)
                    with torch.no_grad():
                        traces = model(
                            vis_feats=vis_feats, vis_pe=vis_pe, input_ids=input_ids,
                            token_type_ids=token_type_ids, position_ids=position_ids,
                            attention_mask=input_mask, search_beam_size=args.beam_size,
                            task_idx=3, mode='img2txt',
                            sample_mode='greedy')  # validation: greedy
                        output_sequence = postprocess(traces, args.beam_size, tgt_lang, indexer)
                        predictions[tgt_lang] = output_sequence

                # truncate
                for langs, lang_ids in zip(['en2zh', 'zh2en'], [[1, 6], [6, 1]]):
                    src_lang = langs[:2]  # en, zh
                    tgt_lang = langs[-2:]
                    w = predictions[src_lang]  # predictions['en'] / predictions['zh']
                    w_t = tokenizers[src_lang].tokenize(w)
                    tokens = ['[CLS]'] + w_t + ['[SEP]']
                    input_ids = indexer(tokens)
                    token_type_ids = [lang_ids[0]] * len(input_ids) + \
                        [lang_ids[1]] * args.max_tgt_length
                    input_ids = np.expand_dims(np.array(input_ids), axis=0)
                    token_type_ids = np.expand_dims(np.array(token_type_ids), axis=0)
                    input_ids = torch.tensor(input_ids, dtype=torch.long, device=device)
                    token_type_ids = torch.tensor(token_type_ids, dtype=torch.long,
                                                  device=device)
                    max_len_in_batch = len(tokens) + args.max_tgt_length  # e.g. 2 + 64 = 66
                    position_ids = torch.arange(max_len_in_batch, dtype=torch.long,
                                                device=device)  # L
                    position_ids = position_ids.unsqueeze(0)  # B, L
                    _tril_matrix = torch.tril(
                        torch.ones((max_len_in_batch, max_len_in_batch), dtype=torch.long))
                    input_mask = torch.zeros(max_len_in_batch, max_len_in_batch,
                                             dtype=torch.long, device=device)
                    input_mask[:, :len(tokens)].fill_(1)
                    second_st, second_end = len(tokens), max_len_in_batch
                    input_mask[second_st:second_end, second_st:second_end].copy_(
                        _tril_matrix[:second_end - second_st, :second_end - second_st])  # L, L
                    input_mask = input_mask.unsqueeze(0)
                    with torch.no_grad():
                        traces = model(
                            vis_feats=None, vis_pe=None, input_ids=input_ids,
                            token_type_ids=token_type_ids, position_ids=position_ids,
                            attention_mask=input_mask, search_beam_size=args.beam_size,
                            task_idx=3, mode='txt2txt',
                            sample_mode='greedy')  # validation: greedy
                        output_sequence = postprocess(traces, args.beam_size, tgt_lang, indexer)
                        predictions[langs] = output_sequence

                fout.writelines(
                    'conf:{:.2f} en:{: <10} fre:{:<5d} en2zh:{: <10} zh:{: <10} '
                    'fre:{:<5d} zh2en:{: <10} \n'.format(
                        conf, predictions['en'], word_fre['coco'][predictions['en']],
                        predictions['en2zh'], predictions['zh'],
                        word_fre['aic'][predictions['zh']], predictions['zh2en']))
        fout.close()
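

# Both decoding passes above assemble the same seq2seq attention mask: every
# position may attend to the whole source prefix, while generated target
# positions attend causally. A standalone sketch of just that construction
# (src_len/tgt_len are illustrative values, not taken from the script):
def build_s2s_attention_mask(src_len, tgt_len):
    total = src_len + tgt_len
    mask = torch.zeros(total, total, dtype=torch.long)
    mask[:, :src_len].fill_(1)  # all queries see the full source
    tril = torch.tril(torch.ones(tgt_len, tgt_len, dtype=torch.long))
    mask[src_len:, src_len:].copy_(tril)  # target-to-target is lower-triangular
    return mask.unsqueeze(0)  # (1, L, L), matching the unsqueeze above

# e.g. build_s2s_attention_mask(3, 2)[0]: each row of the target block exposes
# one extra target column per step, exactly like the copies from _tril_matrix.
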
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='txt', type=str, help='txt -> self-customized')
    parser.add_argument('--src_lang', default='en', type=str, help='')
    parser.add_argument('--tgt_lang', default='zh', type=str, help='')
    parser.add_argument('--max_len_en', default=25, type=int,
                        help='maximum length of English in **bilingual** corpus')
    parser.add_argument('--max_len_zh', default=25, type=int,
                        help='maximum length of Chinese in **bilingual** corpus')
    parser.add_argument("--src_file", default='./src.txt', type=str,
                        help="The input data file name.")
    parser.add_argument('--corpus', default='txt', type=str)
    parser.add_argument('--en_first', action='store_true',
                        help='always put English as the first sentence')

    # General
    parser.add_argument("--config_path", default=None, type=str, help="Bert config file path.")
    parser.add_argument("--bert_model", default="bert-base-cased", type=str,
                        help="Bert pre-trained model selected in the list: "
                             "bert-base-cased, bert-large-cased.")
    parser.add_argument("--xml_vocab", type=str, default='./download_models/xml_vocab.json')
    parser.add_argument("--xml_merge", type=str, default='./download_models/xml_merges.txt')
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument('--max_position_embeddings', type=int, default=512,
                        help="max position embeddings")

    # For decoding
    # parser.add_argument('--fp16', action='store_true',
    #                     help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16")
    parser.add_argument('--seed', type=int, default=123, help="random seed for initialization")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size', type=int, default=4, help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1, help="Beam size for searching")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search")
    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--enable_butd', action='store_true',
                        help='set to take in region features')
    parser.add_argument('--output_dir', default='./result', type=str)  # useless
    parser.add_argument('--split', type=str, default='val')  # wmt?
    parser.add_argument('--len_vis_input', type=int, default=100,
                        help="The length of visual token input")

    args = parser.parse_args()
    args.tgt_lang = 'en' if args.src_lang == 'zh' else 'zh'
    assert args.batch_size == 1, 'only support batch_size=1'
    args.max_tgt_length = args.max_len_en if args.src_lang == 'zh' else args.max_len_zh

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer_en = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + '/.pretrained_model')
    if args.max_position_embeddings:
        tokenizer_en.max_len = args.max_position_embeddings
    # tokenizer_en = WhitespaceTokenizer() if args.tokenized_input else tokenizer_en
    tokenizer_zh = XLMTokenizer(args.xml_vocab, args.xml_merge)
    tokenizer_zh.tokenize = lambda x: tokenizer_zh._tokenize(
        x, lang='zh', bypass_tokenizer=False)
    with open(args.xml_vocab, 'r') as f:
        tokenizer_zh.vocab = json.load(f)
    indexer = Indexer([os.path.join(args.bert_model, 'vocab.txt'), args.xml_vocab])
    with open('full_vocab.json', 'w') as f:
        json.dump(indexer.ids_to_tokens, f)
    tokenizers = {'en': tokenizer_en, 'zh': tokenizer_zh}
    print('tokenizer created')

    if '.txt' in args.src_file:
        with codecs.open(args.src_file, 'r') as f:
            src_lines = f.readlines()
        src_lines = [line.strip() for line in src_lines]
        N_lines = len(src_lines)
    elif 'hdf5' in args.src_file:
        assert 'wmt' == args.corpus
        src_lines = args.src_file
        N_lines = 1999
    else:
        raise ValueError('unsupported --src_file format: {}'.format(args.src_file))

    pipeline = seq2seq_loader.Preprocess4Seq2SeqBilingualDecoder(
        corpus=args.corpus, file_src=src_lines, src_lang=args.src_lang, indexer=indexer,
        tokenizers=tokenizers, max_len=args.max_len_en + args.max_len_zh + 3,
        max_tgt_length=args.max_tgt_length,
        preprocessed=False if args.corpus == 'txt' else True,
        new_segment_ids=args.new_segment_ids, mode='s2s')
    eval_dataset = seq2seq_loader.Txt2txtDataset(
        N_lines=N_lines, split=None, batch_size=args.batch_size, tokenizers=tokenizers,
        max_len=args.max_len_en + args.max_len_zh + 3,
        preprocessed=False if args.corpus == 'txt' else True,
        bi_uni_pipeline=[pipeline], s2s_prob=1, bi_prob=0)
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset, batch_size=args.batch_size, sampler=SequentialSampler(eval_dataset),
        num_workers=4, collate_fn=batch_list_to_batch_tensors, pin_memory=True)

    amp_handle = None
    if args.amp:
        from apex import amp

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 12  # same value with or without --new_segment_ids
    mask_word_id, eos_word_ids = indexer(["[MASK]", "[SEP]"])
    forbid_ignore_set = None  # default None
    relax_projection, task_idx_proj = 0, 3
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(indexer(w_list))

    print(args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        # logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model, max_position_embeddings=args.max_position_embeddings,
            config_path=args.config_path, state_dict=model_recover,
            num_labels=cls_num_labels, vocab_size=len(indexer),
            type_vocab_size=type_vocab_size, task_idx=3, mask_word_id=mask_word_id,  # img2txt
            search_beam_size=args.beam_size, length_penalty=args.length_penalty,
            eos_id=eos_word_ids, forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size,
            min_len=args.min_len, enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input)
        del model_recover

        model.to(device)
        if args.amp:
            model = amp.initialize(model, opt_level='O2')
        torch.cuda.empty_cache()
        model.eval()

        val_iter_bar = tqdm(eval_dataloader)
        output_lines = []
        for step_val, val_iter_output in enumerate(val_iter_bar):
            info_, batch = val_iter_output[0], val_iter_output[1]
            with torch.no_grad():
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, position_ids, input_mask, task_idx = batch
                traces = model(None, None, input_ids, segment_ids, position_ids, input_mask,
                               search_beam_size=args.beam_size, task_idx=task_idx,
                               mode='txt2txt', sample_mode='greedy')  # validation: greedy
                if args.beam_size > 1:
                    traces = {k: v.tolist() for k, v in traces.items()}
                    output_ids = traces['pred_seq']
                else:
                    output_ids = traces[0].tolist()
                for ii, w_ids in enumerate(output_ids):
                    output_buf = indexer.convert_ids_to_tokens(w_ids)
                    output_tokens = []
                    for t in output_buf:
                        if t in ("[SEP]", "[PAD]"):
                            break
                        output_tokens.append(t)
                    if args.tgt_lang == 'en':
                        output_sequence = ' '.join(detokenize(output_tokens))
                        output_sequence = output_sequence.replace(' @ - @ ', '-')
                    if args.tgt_lang == 'zh':
                        output_sequence = ''.join(detokenize(output_tokens)) \
                            .replace('</w>', '').replace('[SEP]', '')
                    output_lines.append(output_sequence)

        # with open(os.path.join(args.output_dir, 'translation_output.json'), 'w') as f:
        #     json.dump(output_lines, f)
        with open(os.path.join(args.output_dir, 'translation_output.txt'), 'w') as f:
            for line in output_lines:
                f.writelines(line + '\n')
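

# The en/zh branches above differ only in surface detokenization. A minimal
# sketch of that string logic in isolation (assumes `tokens` is the already
# detokenized word list; the real run feeds detokenize(output_tokens)):
def join_output_tokens(tokens, tgt_lang):
    if tgt_lang == 'en':
        # English: space-joined, then undo the ' @ - @ ' hyphen escaping
        return ' '.join(tokens).replace(' @ - @ ', '-')
    # Chinese: concatenate without spaces, stripping XLM '</w>' end-of-word
    # markers and any stray [SEP]
    return ''.join(tokens).replace('</w>', '').replace('[SEP]', '')

# e.g. join_output_tokens(['三', '个</w>', '人</w>'], 'zh') -> '三个人'
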
def main():
    parser = argparse.ArgumentParser()

    # General
    parser.add_argument("--bert_model", default="bert-base-cased", type=str,
                        help="Bert pre-trained model selected in the list: "
                             "bert-base-cased, bert-large-cased.")
    parser.add_argument("--config_path", default=None, type=str, help="Bert config file path.")
    parser.add_argument("--output_dir", default='tmp', type=str,
                        help="The output directory where the model predictions and "
                             "checkpoints will be written.")
    parser.add_argument("--log_file", default="training.log", type=str,
                        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training. This should ALWAYS be set to True.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", default=64, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate", default=3e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing", default=0, type=float,
                        help="Label smoothing coefficient.")
    parser.add_argument("--weight_decay", default=0.01, type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay", action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs", default=30, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate "
                             "warmup for. E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--global_rank", type=int, default=-1,
                        help="global_rank for distributed training on gpus")
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing "
                             "a backward/update pass.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--fp32_embedding', action='store_true',
                        help="Whether to use 32-bit float precision instead of 16-bit "
                             "for embeddings")
    parser.add_argument('--loss_scale', type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. "
                             "Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16")
    parser.add_argument('--from_scratch', action='store_true',
                        help="Initialize parameters with random values "
                             "(i.e., training from scratch).")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--len_vis_input', type=int, default=100,
                        help="The length of visual token input")
    parser.add_argument('--max_len_b', type=int, default=20,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument('--trunc_seg', default='b',
                        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument('--always_truncate_tail', action='store_true',
                        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument("--mask_prob", default=0.15, type=float,
                        help="Number of prediction is sometimes less than max_pred "
                             "when sequence is short.")
    parser.add_argument('--max_pred', type=int, default=3, help="Max tokens of prediction.")
    parser.add_argument("--num_workers", default=4, type=int,
                        help="Number of workers for the data loader.")
    parser.add_argument('--max_position_embeddings', type=int, default=None,
                        help="max position embeddings")

    # Others for VLP
    parser.add_argument("--src_file", default=['/mnt/dat/COCO/annotations/dataset_coco.json'],
                        type=str, nargs='+', help="The input data file name.")
    parser.add_argument('--enable_visdom', action='store_true')
    parser.add_argument('--visdom_port', type=int, default=8888)
    # parser.add_argument('--resnet_model', type=str, default='imagenet_weights/resnet101.pth')
    parser.add_argument('--image_root', type=str, default='/mnt/dat/COCO/images')
    parser.add_argument('--dataset', default='coco', type=str, help='coco | flickr30k | cc')
    parser.add_argument('--split', type=str, nargs='+', default=['train', 'restval'])
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='file://[PT_OUTPUT_DIR]/nonexistent_file',
                        type=str, help='url used to set up distributed training')
    parser.add_argument('--file_valid_jpgs',
                        default='/mnt/dat/COCO/annotations/coco_valid_jpgs.json', type=str)
    parser.add_argument('--sche_mode', default='warmup_linear', type=str,
                        help="warmup_linear | warmup_constant | warmup_cosine")
    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--use_num_imgs', default=-1, type=int)
    parser.add_argument('--vis_mask_prob', default=0, type=float)
    parser.add_argument('--max_drop_worst_ratio', default=0, type=float)
    parser.add_argument('--drop_after', default=6, type=int)
    parser.add_argument('--s2s_prob', default=1, type=float,
                        help="Percentage of examples that are bi-uni-directional LM (seq2seq).")
    parser.add_argument('--bi_prob', default=0, type=float,
                        help="Percentage of examples that are bidirectional LM.")
    parser.add_argument('--enable_butd', action='store_true',
                        help='set to take in region features')
    parser.add_argument('--region_bbox_file',
                        default='coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5',
                        type=str)
    parser.add_argument('--region_det_file_prefix',
                        default='feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval',
                        type=str)
    parser.add_argument('--tasks', default='img2txt', help='img2txt | vqa2')
    parser.add_argument('--relax_projection', action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--scst', action='store_true', help='Self-critical sequence training')

    args = parser.parse_args()
    print('global_rank: {}, local rank: {}'.format(args.global_rank, args.local_rank))

    args.max_seq_length = args.max_len_b + args.len_vis_input + 3  # +3 for 2x [SEP] and [CLS]
    args.mask_image_regions = (args.vis_mask_prob > 0)  # whether to mask out image regions
    args.dist_url = args.dist_url.replace('[PT_OUTPUT_DIR]', args.output_dir)

    # arguments inspection
    assert args.tasks in ('img2txt', 'vqa2')
    assert args.enable_butd, 'only support region attn! featmap attn deprecated'
    assert (not args.scst) or args.dataset == 'coco', 'scst support on coco only!'
    if args.scst:
        assert args.max_pred == 0 and args.mask_prob == 0, 'no mask for scst!'
        rl_crit = RewardCriterion()

    if args.enable_butd:
        assert args.len_vis_input == 100
        args.region_bbox_file = os.path.join(args.image_root, args.region_bbox_file)
        args.region_det_file_prefix = os.path.join(
            args.image_root, args.region_det_file_prefix) if args.dataset in (
                'cc', 'coco') and args.region_det_file_prefix != '' else ''

    # output config
    os.makedirs(args.output_dir, exist_ok=True)
    json.dump(args.__dict__, open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True, indent=2)

    logging.basicConfig(
        filename=os.path.join(args.output_dir, args.log_file), filemode='w',
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
    logger = logging.getLogger(__name__)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which takes care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='tcp://localhost:10001',  # args.dist_url
            world_size=args.world_size, rank=args.global_rank)
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # plotting loss, optional
    if args.enable_visdom:
        import visdom
        vis = visdom.Visdom(port=args.visdom_port, env=args.output_dir)
        vis_window = {'iter': None, 'score': None}

    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank))
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer

    if args.do_train:
        bi_uni_pipeline = [seq2seq_loader.Preprocess4Seq2seq(
            args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids, args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={'max_len_b': args.max_len_b, 'trunc_seg': args.trunc_seg,
                             'always_truncate_tail': args.always_truncate_tail},
            mask_image_regions=args.mask_image_regions, mode="s2s",
            len_vis_input=args.len_vis_input, vis_mask_prob=args.vis_mask_prob,
            enable_butd=args.enable_butd, region_bbox_file=args.region_bbox_file,
            region_det_file_prefix=args.region_det_file_prefix,
            local_rank=args.local_rank, load_vqa_ann=(args.tasks == 'vqa2'))]
        bi_uni_pipeline.append(seq2seq_loader.Preprocess4Seq2seq(
            args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids, args.max_seq_length,
            new_segment_ids=args.new_segment_ids,
            truncate_config={'max_len_b': args.max_len_b, 'trunc_seg': args.trunc_seg,
                             'always_truncate_tail': args.always_truncate_tail},
            mask_image_regions=args.mask_image_regions, mode="bi",
            len_vis_input=args.len_vis_input, vis_mask_prob=args.vis_mask_prob,
            enable_butd=args.enable_butd, region_bbox_file=args.region_bbox_file,
            region_det_file_prefix=args.region_det_file_prefix,
            local_rank=args.local_rank, load_vqa_ann=(args.tasks == 'vqa2')))

        train_dataset = seq2seq_loader.Img2txtDataset(
            args.src_file, args.image_root, args.split, args.train_batch_size,
            data_tokenizer, args.max_seq_length, file_valid_jpgs=args.file_valid_jpgs,
            bi_uni_pipeline=bi_uni_pipeline, use_num_imgs=args.use_num_imgs,
            s2s_prob=args.s2s_prob, bi_prob=args.bi_prob,
            enable_butd=args.enable_butd, tasks=args.tasks)

        if args.world_size == 1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
        else:
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.train_batch_size, sampler=train_sampler,
            num_workers=args.num_workers, collate_fn=batch_list_to_batch_tensors,
            pin_memory=True)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    t_total = int(len(train_dataloader) * args.num_train_epochs * 1. /
                  args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    relax_projection = 4 if args.relax_projection else 0
    task_idx_proj = 3 if args.tasks == 'img2txt' else 0
    mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[PAD]"])  # index in BERT vocab: 103, 102, 0

    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        assert not args.scst, 'must init from maximum likelihood training'
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model, state_dict=_state_dict, num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size, relax_projection=relax_projection,
            config_path=args.config_path, task_idx=task_idx_proj,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding,
            cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank),
            drop_prob=args.drop_prob, enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input, tasks=args.tasks)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(
                os.path.join(args.output_dir, "model.{0}.bin".format(recover_step)))
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total * 1. / args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****", args.model_recover_path)
            model_recover = torch.load(args.model_recover_path)
            global_step = 0
        if not args.scst:
            model = BertForPreTrainingLossMask.from_pretrained(
                args.bert_model, state_dict=model_recover, num_labels=cls_num_labels,
                type_vocab_size=type_vocab_size, relax_projection=relax_projection,
                config_path=args.config_path, task_idx=task_idx_proj,
                max_position_embeddings=args.max_position_embeddings,
                label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding,
                cache_dir=args.output_dir + '/.pretrained_model_{}'.format(args.global_rank),
                drop_prob=args.drop_prob, enable_butd=args.enable_butd,
                len_vis_input=args.len_vis_input, tasks=args.tasks)
        else:
            model = BertForSeq2SeqDecoder.from_pretrained(
                args.bert_model, max_position_embeddings=args.max_position_embeddings,
                config_path=args.config_path, state_dict=model_recover,
                num_labels=cls_num_labels, type_vocab_size=type_vocab_size,
                task_idx=task_idx_proj, mask_word_id=mask_word_id, search_beam_size=1,
                eos_id=eos_word_ids, enable_butd=args.enable_butd,
                len_vis_input=args.len_vis_input)
        del model_recover
        torch.cuda.empty_cache()

    # deprecated
    # from vlp.resnet import resnet
    # cnn = resnet(args.resnet_model, _num_layers=101, _fixed_block=4,
    #              pretrained=True)  # no finetuning

    if args.fp16:
        model.half()
        # cnn.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    # cnn.to(device)
    if args.local_rank != -1:
        try:
            # from apex.parallel import DistributedDataParallel as DDP
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex "
                              "to use distributed and fp16 training.")
        model = DDP(model, device_ids=[args.local_rank],
                    output_device=args.local_rank, find_unused_parameters=True)
        # cnn = DDP(cnn)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)
        # cnn = DataParallelImbalance(cnn)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex "
                              "to use distributed and fp16 training.")
        optimizer = FusedAdam(optimizer_grouped_parameters, lr=args.learning_rate,
                              bias_correction=False, max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer, static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate,
                             warmup=args.warmup_proportion, schedule=args.sche_mode,
                             t_total=t_total)

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(
            os.path.join(args.output_dir, "optim.{0}.bin".format(recover_step)))
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)
        logger.info("  Loader length = %d", len(train_dataloader))
        model.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch, args.num_train_epochs + 1, desc="Epoch"):
            if args.local_rank >= 0:
                train_sampler.set_epoch(i_epoch - 1)
            iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)')
            nbatches = len(train_dataloader)
            train_loss = []
            pretext_loss = []
            vqa2_loss = []
            scst_reward = []
            for step, batch in enumerate(iter_bar):
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, \
                    masked_weights, is_next, task_idx, img, vis_masked_pos, \
                    vis_pe, ans_labels = batch
                if args.fp16:
                    img = img.half()
                    vis_pe = vis_pe.half()

                if args.enable_butd:
                    conv_feats = img.data  # Bx100x2048
                    vis_pe = vis_pe.data
                else:
                    # requires the deprecated ResNet path commented out above
                    conv_feats, _ = cnn(img.data)  # Bx2048x7x7
                    conv_feats = conv_feats.view(conv_feats.size(0), conv_feats.size(1),
                                                 -1).permute(0, 2, 1).contiguous()

                if not args.scst:
                    loss_tuple = model(
                        conv_feats, vis_pe, input_ids, segment_ids, input_mask,
                        lm_label_ids, ans_labels, is_next, masked_pos=masked_pos,
                        masked_weights=masked_weights, task_idx=task_idx,
                        vis_masked_pos=vis_masked_pos,
                        mask_image_regions=args.mask_image_regions,
                        drop_worst_ratio=args.max_drop_worst_ratio
                        if i_epoch > args.drop_after else 0)
                    mean_reward = loss_tuple[0].new(1).fill_(0)
                else:
                    # scst training
                    model.eval()
                    position_ids = torch.arange(
                        input_ids.size(1), dtype=input_ids.dtype,
                        device=input_ids.device).unsqueeze(0).expand_as(input_ids)
                    input_dummy = input_ids[:, :args.len_vis_input + 2]  # +2 for [CLS] and [SEP]
                    greedy_res = input_ids.new(
                        input_ids.size(0),
                        input_ids.size(1) - args.len_vis_input - 2).fill_(0)
                    gen_result = input_ids.new(
                        input_ids.size(0),
                        input_ids.size(1) - args.len_vis_input - 2).fill_(0)

                    with torch.no_grad():
                        greedy_res_raw, _ = model(
                            conv_feats, vis_pe, input_dummy, segment_ids, position_ids,
                            input_mask, task_idx=task_idx, sample_mode='greedy')
                        for b in range(greedy_res_raw.size(0)):
                            for idx in range(greedy_res_raw.size(1)):
                                if greedy_res_raw[b][idx] not in [eos_word_ids, pad_word_ids]:
                                    greedy_res[b][idx] = greedy_res_raw[b][idx]
                                else:
                                    if greedy_res_raw[b][idx] == eos_word_ids:
                                        greedy_res[b][idx] = eos_word_ids
                                    break
                    model.train()
                    gen_result_raw, sample_logprobs = model(
                        conv_feats, vis_pe, input_dummy, segment_ids, position_ids,
                        input_mask, task_idx=task_idx, sample_mode='sample')
                    for b in range(gen_result_raw.size(0)):
                        for idx in range(gen_result_raw.size(1)):
                            if gen_result_raw[b][idx] not in [eos_word_ids, pad_word_ids]:
                                gen_result[b][idx] = gen_result_raw[b][idx]
                            else:
                                if gen_result_raw[b][idx] == eos_word_ids:
                                    gen_result[b][idx] = eos_word_ids
                                break

                    gt_ids = input_ids[:, args.len_vis_input + 2:]
                    reward = get_self_critical_reward(greedy_res, gt_ids, gen_result,
                                                      gt_ids.size(0))
                    reward = torch.from_numpy(reward).float().to(gen_result.device)
                    mean_reward = reward.mean()
                    loss = rl_crit(sample_logprobs, gen_result.data, reward)

                    # disable pretext_loss_deprecated for now
                    loss_tuple = [loss, loss.new(1).fill_(0.), loss.new(1).fill_(0.)]

                masked_lm_loss, pretext_loss_deprecated, ans_loss = loss_tuple
                if n_gpu > 1:
                    # mean() to average on multi-gpu. For dist, this is done
                    # through gradient addition.
                    masked_lm_loss = masked_lm_loss.mean()
                    pretext_loss_deprecated = pretext_loss_deprecated.mean()
                    ans_loss = ans_loss.mean()
                loss = masked_lm_loss + pretext_loss_deprecated + ans_loss

                # logging for each step (i.e., before normalization by
                # args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())
                train_loss.append(loss.item())
                pretext_loss.append(pretext_loss_deprecated.item())
                vqa2_loss.append(ans_loss.item())
                scst_reward.append(mean_reward.item())
                if step % 100 == 0:
                    logger.info(
                        "Epoch {}, Iter {}, Loss {:.2f}, Pretext {:.2f}, "
                        "VQA2 {:.2f}, Mean R {:.3f}\n".format(
                            i_epoch, step, np.mean(train_loss), np.mean(pretext_loss),
                            np.mean(vqa2_loss), np.mean(scst_reward)))

                if args.enable_visdom:
                    if vis_window['iter'] is None:
                        vis_window['iter'] = vis.line(
                            X=np.tile(np.arange((i_epoch - 1) * nbatches + step,
                                                (i_epoch - 1) * nbatches + step + 1),
                                      (1, 1)).T,
                            Y=np.column_stack((np.asarray([np.mean(train_loss)]),)),
                            opts=dict(title='Training Loss', xlabel='Training Iteration',
                                      ylabel='Loss', legend=['total']))
                    else:
                        vis.line(
                            X=np.tile(np.arange((i_epoch - 1) * nbatches + step,
                                                (i_epoch - 1) * nbatches + step + 1),
                                      (1, 1)).T,
                            Y=np.column_stack((np.asarray([np.mean(train_loss)]),)),
                            opts=dict(title='Training Loss', xlabel='Training Iteration',
                                      ylabel='Loss', legend=['total']),
                            win=vis_window['iter'], update='append')

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step / t_total, args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Save a trained model
            logger.info("** ** * Saving fine-tuned model and optimizer ** ** * ")
            model_to_save = model.module if hasattr(
                model, 'module') else model  # only save the model itself
            output_model_file = os.path.join(args.output_dir,
                                             "model.{0}.bin".format(i_epoch))
            output_optim_file = os.path.join(args.output_dir,
                                             "optim.{0}.bin".format(i_epoch))
            if args.global_rank in (-1, 0):  # save model if the first device or no dist
                torch.save(copy.deepcopy(model_to_save).cpu().state_dict(),
                           output_model_file)
                # torch.save(optimizer.state_dict(), output_optim_file)
                # ^ disabled for now; need to sanitize state and ship everything back to cpu
            logger.info("***** CUDA.empty_cache() *****")
            torch.cuda.empty_cache()
            if args.world_size > 1:
                torch.distributed.barrier()
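

# The manual lr update above multiplies args.learning_rate by warmup_linear().
# A sketch of that schedule, assuming the classic pytorch_pretrained_bert
# definition (linear ramp over the warmup fraction, then linear decay toward 0):
def warmup_linear_sketch(x, warmup=0.002):
    # x is the completed fraction of training, i.e. global_step / t_total
    if x < warmup:
        return x / warmup
    return 1.0 - x
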
def main():
    args = load_args()
    ss = load_server_socket()

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError("Score trace is only available for beam search with beam size > 1.")
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    tokenizer.max_len = args.max_seq_length

    pair_num_relation = 0
    bi_uni_pipeline = []
    bi_uni_pipeline.append(seq2seq_loader.Preprocess4Seq2seqDecoder(
        list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length,
        max_tgt_length=args.max_tgt_length, new_segment_ids=args.new_segment_ids,
        mode="s2s", num_qkv=args.num_qkv, s2s_special_token=args.s2s_special_token,
        s2s_add_segment=args.s2s_add_segment, s2s_share_segment=args.s2s_share_segment,
        pos_shift=args.pos_shift))

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[S2S_SOS]"])

    def _get_token_id_set(s):
        r = None
        if s:
            w_list = []
            for w in s.split('|'):
                if w.startswith('[') and w.endswith(']'):
                    w_list.append(w.upper())
                else:
                    w_list.append(w)
            r = set(tokenizer.convert_tokens_to_ids(w_list))
        return r

    forbid_ignore_set = _get_token_id_set(args.forbid_ignore_word)
    not_predict_set = _get_token_id_set(args.not_predict_token)

    print(args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model, state_dict=model_recover, num_labels=cls_num_labels,
            num_rel=pair_num_relation, type_vocab_size=type_vocab_size, task_idx=3,
            mask_word_id=mask_word_id, search_beam_size=args.beam_size,
            length_penalty=args.length_penalty, eos_id=eos_word_ids, sos_id=sos_word_id,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set, not_predict_set=not_predict_set,
            ngram_size=args.ngram_size, min_len=args.min_len, mode=args.mode,
            max_position_embeddings=args.max_seq_length, ffn_type=args.ffn_type,
            num_qkv=args.num_qkv, seg_emb=args.seg_emb, pos_shift=args.pos_shift)
        del model_recover

        if args.fp16:
            model.half()
        model.to(device)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)
        torch.cuda.empty_cache()
        model.eval()

        max_src_length = args.max_seq_length - 2 - args.max_tgt_length
        DONE_SIGNAL = 42  # byte value of ASCII '*'

        while True:
            ss.listen(100)
            print("Waiting Connection:")
            cs, addr = ss.accept()
            data = bytes()
            while True:
                recv = cs.recv(1024)
                if 0 < len(recv) and DONE_SIGNAL == recv[-1]:
                    data += recv[:len(recv) - 1]
                    break
                data += recv
            print("Connection with:", addr)
            print("Received:", len(data))

            input_lines = [x.strip() for x in data.decode('utf-8').splitlines()]
            if args.subset > 0:
                logger.info("Decoding subset: %d", args.subset)
                input_lines = input_lines[:args.subset]
            data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
            input_lines = [data_tokenizer.tokenize(x)[:max_src_length] for x in input_lines]
            input_lines = list(enumerate(input_lines))
            # note: output_lines is never filled in the server path; results are
            # streamed back through qg_result below
            output_lines = [""] * len(input_lines)
            score_trace_list = [None] * len(input_lines)
            next_i = 0
            while next_i < len(input_lines):
                _chunk = input_lines[next_i:next_i + args.batch_size]
                buf_id = [x[0] for x in _chunk]
                buf = [x[1] for x in _chunk]
                next_i += args.batch_size
                max_a_len = max([len(x) for x in buf])
                instances = []
                for instance in [(x, max_a_len) for x in buf]:
                    for proc in bi_uni_pipeline:
                        instances.append(proc(instance))
                with torch.no_grad():
                    batch = seq2seq_loader.batch_list_to_batch_tensors(instances)
                    batch = [t.to(device) if t is not None else None for t in batch]
                    input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch
                    traces = model(input_ids, token_type_ids, position_ids, input_mask,
                                   task_idx=task_idx, mask_qkv=mask_qkv)
                    if args.beam_size > 1:
                        traces = {k: v.tolist() for k, v in traces.items()}
                        output_ids = traces['pred_seq']
                    else:
                        output_ids = traces.tolist()
                    qg_result = []
                    for i in range(len(buf)):
                        w_ids = output_ids[i]
                        output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                        output_tokens = []
                        for t in output_buf:
                            if t in ("[SEP]", "[PAD]"):
                                break
                            output_tokens.append(t)
                        output_sequence = ' '.join(detokenize(output_tokens))
                        qg_result.append(output_sequence)
                        if args.need_score_traces:
                            score_trace_list[buf_id[i]] = {
                                'scores': traces['scores'][i],
                                'wids': traces['wids'][i],
                                'ptrs': traces['ptrs'][i]
                            }
                # ascii_print() prepares the reply bytes (defined outside this section)
                cs.sendall(ascii_print('\n'.join(qg_result)))
            cs.close()

            if args.output_file:
                fn_out = args.output_file
            else:
                fn_out = model_recover_path + '.' + args.split
            with open(fn_out, "w", encoding="utf-8") as fout:
                for l in output_lines:
                    fout.write(l)
                    fout.write("\n")

            if args.need_score_traces:
                with open(fn_out + ".trace.pickle", "wb") as fout_trace:
                    pickle.dump({"version": 0.0, "num_samples": len(input_lines)},
                                fout_trace)
                    for x in score_trace_list:
                        pickle.dump(x, fout_trace)
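

# The server above accumulates bytes until it sees a trailing '*' (byte 42) and
# replies with one decoded sequence per line before closing the connection. A
# minimal client sketch for that protocol (host/port are illustrative; the real
# socket comes from load_server_socket(), which is defined outside this section):
import socket

def query_decoder(lines, host='localhost', port=8000):
    payload = '\n'.join(lines).encode('utf-8') + b'*'  # '*' marks end of request
    with socket.create_connection((host, port)) as sock:
        sock.sendall(payload)
        chunks = []
        while True:
            chunk = sock.recv(1024)
            if not chunk:  # server closed the connection after replying
                break
            chunks.append(chunk)
    return b''.join(chunks).decode('utf-8').splitlines()
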
def main(): parser = argparse.ArgumentParser() # Required parameters parser.add_argument( "--bert_model", default=None, type=str, required=True, help="Bert pre-trained model selected in the list: bert-base-uncased, " "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese." ) parser.add_argument("--model_recover_path", default=None, type=str, help="The file of fine-tuned pretraining model.") parser.add_argument( "--max_seq_length", default=512, type=int, help= "The maximum total input sequence length after WordPiece tokenization. \n" "Sequences longer than this will be truncated, and sequences shorter \n" "than this will be padded.") parser.add_argument('--ffn_type', default=0, type=int, help="0: default mlp; 1: W((Wx+b) elem_prod x);") parser.add_argument('--num_qkv', default=0, type=int, help="Number of different <Q,K,V>.") parser.add_argument('--seg_emb', action='store_true', help="Using segment embedding for self-attention.") # decoding parameters parser.add_argument( '--fp16', action='store_true', help="Whether to use 16-bit float precision instead of 32-bit") parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16") parser.add_argument("--input_file", type=str, help="Input file") parser.add_argument('--subset', type=int, default=0, help="Decode a subset of the input dataset.") parser.add_argument("--output_file", type=str, help="output file") parser.add_argument("--split", type=str, default="", help="Data split (train/val/test).") parser.add_argument('--tokenized_input', action='store_true', help="Whether the input is tokenized.") parser.add_argument('--seed', type=int, default=123, help="random seed for initialization") parser.add_argument( "--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.") parser.add_argument('--new_segment_ids', action='store_true', help="Use new segment ids for bi-uni-directional LM.") parser.add_argument('--new_pos_ids', action='store_true', help="Use new position ids for LMs.") parser.add_argument('--batch_size', type=int, default=4, help="Batch size for decoding.") parser.add_argument('--beam_size', type=int, default=1, help="Beam size for searching") parser.add_argument('--length_penalty', type=float, default=0, help="Length penalty for beam search") parser.add_argument("--config_path", default=None, type=str, help="Bert config file path.") parser.add_argument('--topk', type=int, default=10, help="Value of K.") parser.add_argument('--forbid_duplicate_ngrams', action='store_true') parser.add_argument('--forbid_ignore_word', type=str, default=None, help="Ignore the word during forbid_duplicate_ngrams") parser.add_argument("--min_len", default=None, type=int) parser.add_argument('--need_score_traces', action='store_true') parser.add_argument('--ngram_size', type=int, default=3) parser.add_argument('--mode', default="s2s", choices=["s2s", "l2r", "both"]) parser.add_argument('--max_tgt_length', type=int, default=128, help="maximum length of target sequence") parser.add_argument( '--s2s_special_token', action='store_true', help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.") parser.add_argument('--s2s_add_segment', action='store_true', help="Additional segmental for the encoder of S2S.") parser.add_argument( '--s2s_share_segment', action='store_true', help= "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)." 
) parser.add_argument('--pos_shift', action='store_true', help="Using position shift for fine-tuning.") parser.add_argument('--not_predict_token', type=str, default=None, help="Do not predict the tokens during decoding.") args = parser.parse_args() if args.need_score_traces and args.beam_size <= 1: raise ValueError( "Score trace is only available for beam search with beam size > 1." ) if args.max_tgt_length >= args.max_seq_length - 2: raise ValueError("Maximum tgt length exceeds max seq length - 2.") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") n_gpu = torch.cuda.device_count() random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) # tokenizer = BertTokenizer.from_pretrained( # args.bert_model, do_lower_case=args.do_lower_case) tokenizer = BertTokenizer( vocab_file= '/ps2/intern/clsi/BERT/bert_weights/cased_L-24_H-1024_A-16/vocab.txt', do_lower_case=args.do_lower_case) tokenizer.max_len = args.max_seq_length pair_num_relation = 0 bi_uni_pipeline = [] bi_uni_pipeline.append( seq2seq_loader.Preprocess4Seq2seqDecoder( list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, max_tgt_length=args.max_tgt_length, new_segment_ids=args.new_segment_ids, mode="s2s", num_qkv=args.num_qkv, s2s_special_token=args.s2s_special_token, s2s_add_segment=args.s2s_add_segment, s2s_share_segment=args.s2s_share_segment, pos_shift=args.pos_shift)) amp_handle = None if args.fp16 and args.amp: from apex import amp amp_handle = amp.init(enable_caching=True) logger.info("enable fp16 with amp") # Prepare model cls_num_labels = 2 type_vocab_size = 6 + \ (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2 mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids( ["[MASK]", "[SEP]", "[S2S_SOS]"]) def _get_token_id_set(s): r = None if s: w_list = [] for w in s.split('|'): if w.startswith('[') and w.endswith(']'): w_list.append(w.upper()) else: w_list.append(w) r = set(tokenizer.convert_tokens_to_ids(w_list)) return r forbid_ignore_set = _get_token_id_set(args.forbid_ignore_word) not_predict_set = _get_token_id_set(args.not_predict_token) print(args.model_recover_path) for model_recover_path in glob.glob(args.model_recover_path.strip()): logger.info("***** Recover model: %s *****", model_recover_path) model_recover = torch.load(model_recover_path) model = BertForSeq2SeqDecoder.from_pretrained( args.bert_model, state_dict=model_recover, num_labels=cls_num_labels, num_rel=pair_num_relation, type_vocab_size=type_vocab_size, task_idx=3, mask_word_id=mask_word_id, search_beam_size=args.beam_size, length_penalty=args.length_penalty, eos_id=eos_word_ids, sos_id=sos_word_id, forbid_duplicate_ngrams=args.forbid_duplicate_ngrams, forbid_ignore_set=forbid_ignore_set, not_predict_set=not_predict_set, ngram_size=args.ngram_size, min_len=args.min_len, mode=args.mode, max_position_embeddings=args.max_seq_length, ffn_type=args.ffn_type, num_qkv=args.num_qkv, seg_emb=args.seg_emb, pos_shift=args.pos_shift, topk=args.topk, config_path=args.config_path) del model_recover if args.fp16: model.half() model.to(device) if n_gpu > 1: model = torch.nn.DataParallel(model) torch.cuda.empty_cache() model.eval() next_i = 0 max_src_length = args.max_seq_length - 2 - args.max_tgt_length ## for YFG style json # testset = loads_json(args.input_file, 'Load Test Set: '+args.input_file) # if args.subset > 0: # logger.info("Decoding subset: %d", args.subset) # testset = testset[:args.subset] with 
open(args.input_file, encoding="utf-8") as fin: data = json.load(fin) # input_lines = [x.strip() for x in fin.readlines()] # if args.subset > 0: # logger.info("Decoding subset: %d", args.subset) # input_lines = input_lines[:args.subset] # data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer # input_lines = [data_tokenizer.tokenize( # x)[:max_src_length] for x in input_lines] # input_lines = sorted(list(enumerate(input_lines)), # key=lambda x: -len(x[1])) # output_lines = [""] * len(input_lines) # score_trace_list = [None] * len(input_lines) # total_batch = math.ceil(len(input_lines) / args.batch_size) data_tokenizer = WhitespaceTokenizer( ) if args.tokenized_input else tokenizer PQA_dict = {} # will store the generated distractors dis_tot = 0 dis_n = 0 len_tot = 0 hypothesis = {} ## change to process one by one and store the distractors in PQA json form ## with tqdm(total=total_batch) as pbar: # for example in tqdm(testset): # question_id = str(example['id']['file_id']) + '_' + str(example['id']['question_id']) # if question_id in hypothesis: # continue # dis_n += 1 # if dis_n % 2000 == 0: # logger.info("Already processed: "+str(dis_n)) counter = 0 for race_id, example in tqdm(data.items()): counter += 1 if args.subset > 0 and counter > args.subset: break # use > so that exactly args.subset examples are decoded eg_dict = {} # eg_dict["question_id"] = question_id # eg_dict["question"] = ' '.join(example['question']) # eg_dict["context"] = ' '.join(example['article']) eg_dict["question"] = example['question'] eg_dict["context"] = example['context'] label = int(example["label"]) options = example["options"] answer = options[label] #new_distractors = [] pred1 = None pred2 = None pred3 = None #while next_i < len(input_lines): #_chunk = input_lines[next_i:next_i + args.batch_size] #line = example["context"].strip() + ' ' + example["question"].strip() question = example['question'] question = question.replace('_', ' ') line = ' '.join( nltk.word_tokenize(example['context']) + nltk.word_tokenize(question)) line = [data_tokenizer.tokenize(line)[:max_src_length]] # buf_id = [x[0] for x in _chunk] # buf = [x[1] for x in _chunk] buf = line #next_i += args.batch_size max_a_len = max([len(x) for x in buf]) instances = [] for instance in [(x, max_a_len) for x in buf]: for proc in bi_uni_pipeline: instances.append(proc(instance)) with torch.no_grad(): batch = seq2seq_loader.batch_list_to_batch_tensors(instances) batch = [ t.to(device) if t is not None else None for t in batch ] input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch # for i in range(1): #try max 10 times # if len(new_distractors) >= 3: # break traces = model(input_ids, token_type_ids, position_ids, input_mask, task_idx=task_idx, mask_qkv=mask_qkv) if args.beam_size > 1: traces = {k: v.tolist() for k, v in traces.items()} output_ids = traces['pred_seq'] # print (np.array(output_ids).shape) # print (output_ids) else: output_ids = traces.tolist() # now only supports single batch decoding!!!
# will keep the second and third sequence as backup for i in range(len(buf)): # print (len(buf), buf) for s in range(len(output_ids)): output_seq = output_ids[s] #w_ids = output_ids[i] #output_buf = tokenizer.convert_ids_to_tokens(w_ids) output_buf = tokenizer.convert_ids_to_tokens( output_seq) output_tokens = [] for t in output_buf: if t in ("[SEP]", "[PAD]"): break output_tokens.append(t) if s == 1: backup_1 = output_tokens if s == 2: backup_2 = output_tokens if pred1 is None: pred1 = output_tokens elif jaccard_similarity(pred1, output_tokens) < 0.5: if pred2 is None: pred2 = output_tokens elif pred3 is None: if jaccard_similarity(pred2, output_tokens) < 0.5: pred3 = output_tokens if pred1 is not None and pred2 is not None and pred3 is not None: break if pred2 is None: pred2 = backup_1 if pred3 is None: pred3 = backup_2 elif pred3 is None: pred3 = backup_1 # output_sequence = ' '.join(detokenize(output_tokens)) # print (output_sequence) # if output_sequence.lower().strip() == answer.lower().strip(): # continue # repeated = False # for cand in new_distractors: # if output_sequence.lower().strip() == cand.lower().strip(): # repeated = True # break # if not repeated: # new_distractors.append(output_sequence.strip()) #hypothesis[question_id] = [pred1, pred2, pred3] new_distractors = [pred1, pred2, pred3] # print (new_distractors) # dis_tot += len(new_distractors) # # fill the missing ones with original distractors # for i in range(4): # if len(new_distractors) >= 3: # break # elif i == label: # continue # else: # new_distractors.append(options[i]) for dis in new_distractors: if dis is not None: len_tot += len(dis) dis_n += 1 new_distractors = [ ' '.join(detokenize(dis)) for dis in new_distractors if dis is not None ] assert len(new_distractors) == 3, "Number of distractors WRONG" new_distractors.insert(label, answer) #eg_dict["generated_distractors"] = new_distractors eg_dict["options"] = new_distractors eg_dict["label"] = label #PQA_dict[question_id] = eg_dict PQA_dict[race_id] = eg_dict # reference = {} # for example in testset: # question_id = str(example['id']['file_id']) + '_' + str(example['id']['question_id']) # if question_id not in reference.keys(): # reference[question_id] = [example['distractor']] # else: # reference[question_id].append(example['distractor']) # _ = eval(hypothesis, reference) # assert len(PQA_dict) == len(data), "Number of examples WRONG" # logger.info("Average number of GENERATED distractor per question: "+str(dis_tot/dis_n)) logger.info("Average length of distractors: " + str(len_tot / dis_n)) with open(args.output_file, mode='w', encoding='utf-8') as f: json.dump(PQA_dict, f, indent=4)
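# jaccard_similarity() is called above but defined elsewhere; a plausible
# token-set implementation consistent with the 0.5 threshold used to filter
# near-duplicate distractors would be (a sketch, not necessarily the exact helper):
def jaccard_similarity(tokens_a, tokens_b):
    a, b = set(tokens_a), set(tokens_b)
    if not a and not b:
        return 1.0
    return len(a & b) / len(a | b)  # |intersection| / |union|

# e.g. jaccard_similarity(['the', 'red', 'car'], ['the', 'blue', 'car']) == 0.5,
# so the second candidate would be rejected as too similar under the "< 0.5" test.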
def main(): parser = argparse.ArgumentParser() # Required parameters parser.add_argument( "--bert_model", default=None, type=str, required=True, help="Bert pre-trained model selected in the list: bert-base-uncased, " "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese." ) parser.add_argument("--unilm_model_recover_path", default=None, type=str, help="The file of fine-tuned pretraining unilm model.") parser.add_argument("--topic_model_recover_path", default=None, type=str, help="The file of fine-tuned pretraining topic model.") parser.add_argument("--topic_data_path", default=None, type=str, help="The file of topic model data.") parser.add_argument("--topic_num", default=50, type=int, help="topic_num.") parser.add_argument("--data_path", default=None, type=str, help="The file of topic model data.") parser.add_argument( "--max_seq_length", default=512, type=int, help= "The maximum total input sequence length after WordPiece tokenization. \n" "Sequences longer than this will be truncated, and sequences shorter \n" "than this will be padded.") parser.add_argument('--ffn_type', default=0, type=int, help="0: default mlp; 1: W((Wx+b) elem_prod x);") parser.add_argument('--num_qkv', default=0, type=int, help="Number of different <Q,K,V>.") parser.add_argument('--seg_emb', action='store_true', help="Using segment embedding for self-attention.") parser.add_argument('--topic_mode', default=1, type=float, help="1:idea1 1.1:idea1_wo_theta 2:idea2 ") parser.add_argument("--topic_model_dict_path", default=None, type=str, help="The file of fine-tuned pretraining topic model.") # decoding parameters parser.add_argument( '--fp16', action='store_true', help="Whether to use 16-bit float precision instead of 32-bit") parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16") parser.add_argument("--input_file", type=str, help="Input file") parser.add_argument('--subset', type=int, default=0, help="Decode a subset of the input dataset.") parser.add_argument("--output_file", type=str, help="output file") parser.add_argument("--split", type=str, default="", help="Data split (train/val/test).") parser.add_argument('--tokenized_input', action='store_true', help="Whether the input is tokenized.") parser.add_argument('--seed', type=int, default=123, help="random seed for initialization") parser.add_argument( "--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.") parser.add_argument('--new_segment_ids', action='store_true', help="Use new segment ids for bi-uni-directional LM.") parser.add_argument('--new_pos_ids', action='store_true', help="Use new position ids for LMs.") parser.add_argument('--batch_size', type=int, default=4, help="Batch size for decoding.") parser.add_argument('--beam_size', type=int, default=1, help="Beam size for searching") parser.add_argument('--length_penalty', type=float, default=0, help="Length penalty for beam search") parser.add_argument('--forbid_duplicate_ngrams', action='store_true') parser.add_argument('--forbid_ignore_word', type=str, default=None, help="Ignore the word during forbid_duplicate_ngrams") parser.add_argument("--min_len", default=None, type=int) parser.add_argument('--need_score_traces', action='store_true') parser.add_argument('--ngram_size', type=int, default=3) parser.add_argument('--mode', default="s2s", choices=["s2s", "l2r", "both"]) parser.add_argument('--max_tgt_length', type=int, default=128, help="maximum length of target sequence") parser.add_argument( 
'--s2s_special_token', action='store_true', help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.") parser.add_argument('--s2s_add_segment', action='store_true', help="Additional segmental for the encoder of S2S.") parser.add_argument( '--s2s_share_segment', action='store_true', help= "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)." ) parser.add_argument('--pos_shift', action='store_true', help="Using position shift for fine-tuning.") parser.add_argument('--not_predict_token', type=str, default=None, help="Do not predict the tokens during decoding.") args = parser.parse_args() if args.need_score_traces and args.beam_size <= 1: raise ValueError( "Score trace is only available for beam search with beam size > 1." ) if args.max_tgt_length >= args.max_seq_length - 2: raise ValueError("Maximum tgt length exceeds max seq length - 2.") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") n_gpu = torch.cuda.device_count() random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) tokenizer.max_len = args.max_seq_length pair_num_relation = 0 bi_uni_pipeline = [] bi_uni_pipeline.append( seq2seq_loader.Preprocess4Seq2seqDecoder( list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, max_tgt_length=args.max_tgt_length, new_segment_ids=args.new_segment_ids, mode="s2s", num_qkv=args.num_qkv, s2s_special_token=args.s2s_special_token, s2s_add_segment=args.s2s_add_segment, s2s_share_segment=args.s2s_share_segment, pos_shift=args.pos_shift)) amp_handle = None if args.fp16 and args.amp: from apex import amp amp_handle = amp.init(enable_caching=True) # logger.info("enable fp16 with amp") # Prepare model cls_num_labels = 2 type_vocab_size = 6 + \ (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2 mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids( ["[MASK]", "[SEP]", "[S2S_SOS]"]) def _get_token_id_set(s): r = None if s: w_list = [] for w in s.split('|'): if w.startswith('[') and w.endswith(']'): w_list.append(w.upper()) else: w_list.append(w) r = set(tokenizer.convert_tokens_to_ids(w_list)) return r forbid_ignore_set = _get_token_id_set(args.forbid_ignore_word) not_predict_set = _get_token_id_set(args.not_predict_token) unilm_model_recover = torch.load(args.unilm_model_recover_path) unilm = BertForSeq2SeqDecoder.from_pretrained( args.bert_model, state_dict=unilm_model_recover, num_labels=cls_num_labels, num_rel=pair_num_relation, type_vocab_size=type_vocab_size, task_idx=3, mask_word_id=mask_word_id, search_beam_size=args.beam_size, length_penalty=args.length_penalty, eos_id=eos_word_ids, sos_id=sos_word_id, forbid_duplicate_ngrams=args.forbid_duplicate_ngrams, forbid_ignore_set=forbid_ignore_set, not_predict_set=not_predict_set, ngram_size=args.ngram_size, min_len=args.min_len, mode=args.mode, max_position_embeddings=args.max_seq_length, ffn_type=args.ffn_type, num_qkv=args.num_qkv, seg_emb=args.seg_emb, pos_shift=args.pos_shift) topic_model_recover = torch.load(args.topic_model_recover_path) dictionary = Dictionary.load_from_text(args.topic_model_dict_path) gsm = GSM(len(dictionary)) gsm.load_state_dict(topic_model_recover) del unilm_model_recover del topic_model_recover if args.fp16: unilm.half() gsm.half() unilm.to(device) gsm.to(device) if n_gpu > 1: unilm = torch.nn.DataParallel(unilm) gsm = torch.nn.DataParallel(gsm) 
torch.cuda.empty_cache() unilm.eval() gsm.eval() next_i = 0 max_src_length = args.max_seq_length - 2 - args.max_tgt_length with open(args.input_file, encoding="utf-8") as fin: input_lines = [x.strip() for x in fin.readlines()] if args.subset > 0: # a value of 0 means no subsetting # logger.info("Decoding subset: %d", args.subset) input_lines = input_lines[:args.subset] data_tokenizer = WhitespaceTokenizer( ) if args.tokenized_input else tokenizer input_lines = [ data_tokenizer.tokenize(x)[:max_src_length] for x in input_lines ] input_lines = sorted( list(enumerate(input_lines)), key=lambda x: -len(x[1]) ) #input_lines = [(ori_index,[tokens]), (ori_index,[tokens])] sorted by token length, longest first output_lines = [""] * len(input_lines) # flat list of strings score_trace_list = [None] * len(input_lines) total_batch = math.ceil(len(input_lines) / args.batch_size) # get topic_model bows def detokenize(tk_list): r_list = [] src = " ".join(tk_list) src = src.replace("[UNK]", "") tk_list = src.split() for tk in tk_list: if tk.startswith('##') and len(r_list) > 0: r_list[-1] = r_list[-1] + tk[2:] else: r_list.append(tk) src = " ".join(r_list) src = src.replace("UNK", "") r_list = src.split() return r_list txtLines = [] for input_line in input_lines: textline = " ".join(detokenize(input_line[1])) txtLines.append(textline) cwd = os.getcwd() dictionary = Dictionary.load_from_text(args.topic_model_dict_path) dictionary.id2token = { v: k for k, v in dictionary.token2id.items() } # because id2token is empty by default, it is a bug. stopwords = set([ l.strip('\n').strip() for l in open(os.path.join(cwd, 'data/topic_model', 'stopwords.txt'), 'r', encoding='utf-8') ]) topic_tokenizer = seq2seq_loader.SpacyTokenizer(stopwords=stopwords) docs = topic_tokenizer.tokenize(txtLines) # convert to BOW representation bows, _docs = [], [] vocabsize = len(dictionary) print("vocabsize", vocabsize) for doc in docs: _bow = dictionary.doc2bow(doc) if _bow != []: _docs.append(list(doc)) bows.append(_bow) else: bows.append([(vocabsize - 1, 1)]) docs = _docs with tqdm(total=total_batch) as pbar: while next_i < len(input_lines): _chunk = input_lines[next_i:next_i + args.batch_size] # slicing past the end simply yields the remaining items; that is how list[a:b] behaves buf_id = [x[0] for x in _chunk] buf = [x[1] for x in _chunk] max_a_len = max([len(x) for x in buf]) instances = [] batch_bow = [] for i in range(next_i, next_i + args.batch_size): if i < len(input_lines): bow = torch.zeros(vocabsize) item = list( zip(*bows[i]) ) # bow = [[token_id1,token_id2,...],[freq1,freq2,...]] bow[list(item[0])] = torch.tensor(list(item[1])).float() batch_bow.append(bow) next_i += args.batch_size for instance in [(x, max_a_len) for x in buf]: for proc in bi_uni_pipeline: # proc is Preprocess4Seq2seqDecoder; in effect it pads each instance instances.append(proc(instance)) with torch.no_grad(): batch = seq2seq_loader.batch_list_to_batch_tensors(instances) batch = [ t.to(device) if t is not None else None for t in batch ] batch_bow = torch.stack(batch_bow) batch_bow = batch_bow.to(device) input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch p_x, mus, log_vars, theta, beta, topic_embedding = gsm( batch_bow) traces = unilm(input_ids, theta, beta, topic_embedding, args.topic_mode, token_type_ids, position_ids, input_mask, task_idx=task_idx, mask_qkv=mask_qkv) cal_ppl(batch_bow, p_x, log_vars, mus) if args.beam_size > 1: traces = {k: v.tolist() for k, v in traces.items()} output_ids = traces['pred_seq'] else: output_ids = traces.tolist() for i in range(len(buf)): w_ids = output_ids[i] output_buf = tokenizer.convert_ids_to_tokens(w_ids) output_tokens = [] for t
in output_buf: if t in ("[SEP]", "[PAD]"): break output_tokens.append(t) output_sequence = ' '.join(detokenize(output_tokens)) output_lines[buf_id[i]] = output_sequence if args.need_score_traces: score_trace_list[buf_id[i]] = { 'scores': traces['scores'][i], 'wids': traces['wids'][i], 'ptrs': traces['ptrs'][i] } pbar.update(1) print("word_count", word_count) ppx = np.exp(loss_sum / word_count) ppx_document = np.exp(ppx_sum / doc_count) print("ppx", ppx) print("ppx_document", ppx_document) topic_words = show_topic_words(gsm.module, args.topic_num, device, dictionary.id2token, topic_id=None, topK=10) # evaluate_topic_quality(topic_words, docs, dictionary, taskname="unilm", calc4each=False) topic_diversity = calc_topic_diversity(topic_words) print("topic_diversity", topic_diversity) # print('\n'.join([str(lst) for lst in topic_words])) # print('='*30) if args.output_file: fn_out = args.output_file else: fn_out = args.unilm_model_recover_path + '.' + args.split with open(fn_out, "w", encoding="utf-8") as fout: for l in output_lines: fout.write(l) fout.write("\n") if args.need_score_traces: with open(fn_out + ".trace.pickle", "wb") as fout_trace: pickle.dump({ "version": 0.0, "num_samples": len(input_lines) }, fout_trace) for x in score_trace_list: pickle.dump(x, fout_trace)
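# The GSM topic model consumes dense bag-of-words vectors built from gensim
# (token_id, freq) pairs; this standalone sketch mirrors the batch_bow loop
# above (vocabsize is a toy value here, not the real dictionary size):
import torch

def to_dense_bow(sparse_bow, vocabsize):
    bow = torch.zeros(vocabsize)
    if sparse_bow:  # e.g. [(3, 2), (7, 1)] -> counts 2.0 and 1.0 at indices 3 and 7
        ids, freqs = zip(*sparse_bow)
        bow[list(ids)] = torch.tensor(freqs).float()
    return bow

# to_dense_bow([(3, 2), (7, 1)], 10) -> tensor([0., 0., 0., 2., 0., 0., 0., 1., 0., 0.])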
def main(): parser = argparse.ArgumentParser() # General parser.add_argument( "--bert_model", default="bert-base-cased", type=str, help="Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased.", ) parser.add_argument( "--config_path", default=None, type=str, help="Bert config file path." ) parser.add_argument( "--output_dir", default="tmp", type=str, help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( "--log_file", default="eval.log", type=str, help="The output file where the log will be written.", ) parser.add_argument( "--model_recover_path", default=None, type=str, help="The file of fine-tuned pretraining model.", ) parser.add_argument( "--do_train", action="store_true", help="Whether to run training. This should ALWAYS be set to True.", ) parser.add_argument( "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.", ) parser.add_argument( "--train_batch_size", default=64, type=int, help="Total batch size for training.", ) parser.add_argument( "--learning_rate", default=3e-5, type=float, help="The initial learning rate for Adam.", ) parser.add_argument( "--label_smoothing", default=0, type=float, help="The label smoothing rate.", ) parser.add_argument( "--weight_decay", default=0.01, type=float, help="The weight decay rate for Adam.", ) parser.add_argument( "--finetune_decay", action="store_true", help="Weight decay to the original weights.", ) parser.add_argument( "--num_train_epochs", default=30, type=int, help="Total number of training epochs to perform.", ) parser.add_argument( "--warmup_proportion", default=0.1, type=float, help="Proportion of training to perform linear learning rate warmup for. " "E.g., 0.1 = 10%% of training.", ) parser.add_argument( "--no_cuda", action="store_true", help="Whether not to use CUDA when available" ) parser.add_argument( "--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus", ) parser.add_argument( "--global_rank", type=int, default=-1, help="global_rank for distributed training on gpus", ) parser.add_argument( "--seed", type=int, default=42, help="random seed for initialization" ) parser.add_argument( "--gradient_accumulation_steps", type=int, default=1, help="Number of update steps to accumulate before performing a backward/update pass.", ) parser.add_argument( "--fp16", action="store_true", help="Whether to use 16-bit float precision instead of 32-bit", ) parser.add_argument( "--fp32_embedding", action="store_true", help="Whether to use 32-bit float precision instead of 16-bit for embeddings", ) parser.add_argument( "--loss_scale", type=float, default=0, help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" "0 (default value): dynamic loss scaling.\n" "Positive power of 2: static loss scaling value.\n", ) parser.add_argument( "--amp", action="store_true", help="Whether to use amp for fp16" ) parser.add_argument( "--from_scratch", action="store_true", help="Initialize parameters with random values (i.e., training from scratch).", ) parser.add_argument( "--new_segment_ids", action="store_true", help="Use new segment ids for bi-uni-directional LM.", ) parser.add_argument( "--tokenized_input", action="store_true", help="Whether the input is tokenized."
) parser.add_argument( "--len_vis_input", type=int, default=100, help="The length of visual token input", ) parser.add_argument( "--max_len_b", type=int, default=20, help="Truncate_config: maximum length of segment B.", ) parser.add_argument( "--trunc_seg", default="b", help="Truncate_config: first truncate segment A/B (option: a, b).", ) parser.add_argument( "--always_truncate_tail", action="store_true", help="Truncate_config: Whether we should always truncate tail.", ) parser.add_argument( "--mask_prob", default=0.15, type=float, help="Number of prediction is sometimes less than max_pred when sequence is short.", ) parser.add_argument( "--max_pred", type=int, default=3, help="Max tokens of prediction." ) parser.add_argument( "--num_workers", default=4, type=int, help="Number of workers for the data loader.", ) parser.add_argument( "--max_position_embeddings", type=int, default=None, help="max position embeddings", ) # Others for VLP parser.add_argument( "--src_file", default=["/mnt/dat/COCO/annotations/dataset_coco.json"], type=str, nargs="+", help="The input data file name.", ) parser.add_argument("--enable_visdom", action="store_true") parser.add_argument("--visdom_port", type=int, default=8888) # parser.add_argument('--resnet_model', type=str, default='imagenet_weights/resnet101.pth') parser.add_argument("--image_root", type=str, default="/mnt/dat/COCO/images") parser.add_argument( "--dataset", default="coco", type=str, help="coco | flickr30k | cc" ) parser.add_argument("--split", type=str, nargs="+", default=["train", "restval"]) parser.add_argument( "--world_size", default=1, type=int, help="number of distributed processes" ) parser.add_argument( "--dist_url", default="file://[PT_OUTPUT_DIR]/nonexistent_file", type=str, help="url used to set up distributed training", ) parser.add_argument( "--file_valid_jpgs", default="/mnt/dat/COCO/annotations/coco_valid_jpgs.json", type=str, ) parser.add_argument( "--sche_mode", default="warmup_linear", type=str, help="warmup_linear | warmup_constant | warmup_cosine", ) parser.add_argument("--drop_prob", default=0.1, type=float) parser.add_argument("--use_num_imgs", default=-1, type=int) parser.add_argument("--vis_mask_prob", default=0, type=float) parser.add_argument("--max_drop_worst_ratio", default=0, type=float) parser.add_argument("--drop_after", default=6, type=int) parser.add_argument( "--s2s_prob", default=1, type=float, help="Percentage of examples that are bi-uni-directional LM (seq2seq).", ) parser.add_argument( "--bi_prob", default=0, type=float, help="Percentage of examples that are bidirectional LM.", ) parser.add_argument( "--enable_butd", action="store_true", help="set to take in region features" ) parser.add_argument( "--region_bbox_file", default="coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5", type=str, ) parser.add_argument( "--region_det_file_prefix", default="feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval", type=str, ) parser.add_argument("--tasks", default="img2txt", help="img2txt | vqa2") parser.add_argument( "--relax_projection", action="store_true", help="Use different projection layers for tasks.", ) parser.add_argument( "--scst", action="store_true", help="Self-critical sequence training" ) args = parser.parse_args() print("global_rank: {}, local rank: {}".format(args.global_rank, args.local_rank)) args.max_seq_length = ( args.max_len_b + args.len_vis_input + 3 ) # +3 for 2x[SEP] and [CLS] args.mask_image_regions = ( args.vis_mask_prob > 0 ) # whether to mask out image regions 
args.dist_url = args.dist_url.replace("[PT_OUTPUT_DIR]", args.output_dir) # arguments inspection assert args.tasks in ("img2txt", "vqa2") assert args.enable_butd == True, "only support region attn! featmap attn deprecated" assert (not args.scst) or args.dataset == "coco", "scst support on coco only!" if args.scst: assert args.dataset == "coco", "scst support on coco only!" assert args.max_pred == 0 and args.mask_prob == 0, "no mask for scst!" rl_crit = RewardCriterion() if args.enable_butd: assert args.len_vis_input == 100 args.region_bbox_file = os.path.join(args.image_root, args.region_bbox_file) args.region_det_file_prefix = ( os.path.join(args.image_root, args.region_det_file_prefix) if args.dataset in ("cc", "coco") and args.region_det_file_prefix != "" else "" ) # output config os.makedirs(args.output_dir, exist_ok=True) json.dump( args.__dict__, open(os.path.join(args.output_dir, "eval_opt.json"), "w"), sort_keys=True, indent=2, ) logging.basicConfig( filename=os.path.join(args.output_dir, args.log_file), filemode="w", format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger = logging.getLogger(__name__) if args.local_rank == -1 or args.no_cuda: device = torch.device( "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" ) n_gpu = torch.cuda.device_count() else: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs torch.distributed.init_process_group( backend="nccl", init_method=args.dist_url, world_size=args.world_size, rank=args.global_rank, ) logger.info( "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( device, n_gpu, bool(args.local_rank != -1), args.fp16 ) ) if args.gradient_accumulation_steps < 1: raise ValueError( "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( args.gradient_accumulation_steps ) ) args.train_batch_size = int( args.train_batch_size / args.gradient_accumulation_steps ) # fix random seed random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) # plotting loss, optional if args.enable_visdom: import visdom vis = visdom.Visdom(port=args.visdom_port, env=args.output_dir) vis_window = {"iter": None, "score": None} # preprocessing/data loader tokenizer = BertTokenizer.from_pretrained( args.bert_model, do_lower_case=args.do_lower_case, cache_dir=args.output_dir + "/.pretrained_model_{}".format(args.global_rank), ) if args.max_position_embeddings: tokenizer.max_len = args.max_position_embeddings data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer if args.do_train: bi_uni_pipeline = [ seq2seq_loader.Preprocess4Seq2seq( args.max_pred, args.mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids, truncate_config={ "max_len_b": args.max_len_b, "trunc_seg": args.trunc_seg, "always_truncate_tail": args.always_truncate_tail, }, mask_image_regions=args.mask_image_regions, mode="s2s", len_vis_input=args.len_vis_input, vis_mask_prob=args.vis_mask_prob, enable_butd=args.enable_butd, region_bbox_file=args.region_bbox_file, region_det_file_prefix=args.region_det_file_prefix, local_rank=args.local_rank, load_vqa_ann=(args.tasks == "vqa2"), ) ] bi_uni_pipeline.append( seq2seq_loader.Preprocess4Seq2seq( args.max_pred, args.mask_prob, 
list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids, truncate_config={ "max_len_b": args.max_len_b, "trunc_seg": args.trunc_seg, "always_truncate_tail": args.always_truncate_tail, }, mask_image_regions=args.mask_image_regions, mode="bi", len_vis_input=args.len_vis_input, vis_mask_prob=args.vis_mask_prob, enable_butd=args.enable_butd, region_bbox_file=args.region_bbox_file, region_det_file_prefix=args.region_det_file_prefix, local_rank=args.local_rank, load_vqa_ann=(args.tasks == "vqa2"), ) ) train_dataset = seq2seq_loader.Img2txtDataset( args.src_file, args.image_root, args.split, args.train_batch_size, data_tokenizer, args.max_seq_length, file_valid_jpgs=args.file_valid_jpgs, bi_uni_pipeline=bi_uni_pipeline, use_num_imgs=args.use_num_imgs, s2s_prob=args.s2s_prob, bi_prob=args.bi_prob, enable_butd=args.enable_butd, tasks=args.tasks, ) if args.world_size == 1: train_sampler = RandomSampler(train_dataset, replacement=False) else: train_sampler = DistributedSampler(train_dataset) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=args.train_batch_size, sampler=train_sampler, num_workers=args.num_workers, collate_fn=batch_list_to_batch_tensors, pin_memory=True, ) # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps) t_total = int( len(train_dataloader) * args.num_train_epochs * 1.0 / args.gradient_accumulation_steps ) amp_handle = None if args.fp16 and args.amp: from apex import amp amp_handle = amp.init(enable_caching=True) logger.info("enable fp16 with amp") # Prepare model recover_step = _get_max_epoch_model(args.output_dir) cls_num_labels = 2 type_vocab_size = 6 if args.new_segment_ids else 2 relax_projection = 4 if args.relax_projection else 0 task_idx_proj = 3 if args.tasks == "img2txt" else 0 mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids( ["[MASK]", "[SEP]", "[PAD]"] ) # index in BERT vocab: 103, 102, 0 if (recover_step is None) and (args.model_recover_path is None): # if _state_dict == {}, the parameters are randomly initialized # if _state_dict == None, the parameters are initialized with bert-init assert args.scst == False, "must init from maximum likelihood training" _state_dict = {} if args.from_scratch else None model = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=_state_dict, num_labels=cls_num_labels, type_vocab_size=type_vocab_size, relax_projection=relax_projection, config_path=args.config_path, task_idx=task_idx_proj, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, cache_dir=args.output_dir + "/.pretrained_model_{}".format(args.global_rank), drop_prob=args.drop_prob, enable_butd=args.enable_butd, len_vis_input=args.len_vis_input, tasks=args.tasks, ) global_step = 0 else: if recover_step: logger.info("***** Recover model: %d *****", recover_step) model_recover = torch.load( os.path.join(args.output_dir, "model.{0}.bin".format(recover_step)) ) # recover_step == number of epochs global_step = math.floor( recover_step * t_total * 1.0 / args.num_train_epochs ) elif args.model_recover_path: logger.info("***** Recover model: %s *****", args.model_recover_path) model_recover = torch.load(args.model_recover_path) global_step = 0 if not args.scst: model = BertForPreTrainingLossMask.from_pretrained( args.bert_model, state_dict=model_recover, num_labels=cls_num_labels, type_vocab_size=type_vocab_size, 
relax_projection=relax_projection, config_path=args.config_path, task_idx=task_idx_proj, max_position_embeddings=args.max_position_embeddings, label_smoothing=args.label_smoothing, fp32_embedding=args.fp32_embedding, cache_dir=args.output_dir + "/.pretrained_model_{}".format(args.global_rank), drop_prob=args.drop_prob, enable_butd=args.enable_butd, len_vis_input=args.len_vis_input, tasks=args.tasks, ) else: model = BertForSeq2SeqDecoder.from_pretrained( args.bert_model, max_position_embeddings=args.max_position_embeddings, config_path=args.config_path, state_dict=model_recover, num_labels=cls_num_labels, type_vocab_size=type_vocab_size, task_idx=task_idx_proj, mask_word_id=mask_word_id, search_beam_size=1, eos_id=eos_word_ids, mode="s2s", enable_butd=args.enable_butd, len_vis_input=args.len_vis_input, ) del model_recover torch.cuda.empty_cache() # deprecated # from vlp.resnet import resnet # cnn = resnet(args.resnet_model, _num_layers=101, _fixed_block=4, pretrained=True) # no finetuning if args.fp16: model.half() # cnn.half() if args.fp32_embedding: model.bert.embeddings.word_embeddings.float() model.bert.embeddings.position_embeddings.float() model.bert.embeddings.token_type_embeddings.float() model.to(device) # cnn.to(device) if args.local_rank != -1: try: # from apex.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training." ) model = DDP( model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True, ) # cnn = DDP(cnn) elif n_gpu > 1: # model = torch.nn.DataParallel(model) model = DataParallelImbalance(model) # cnn = DataParallelImbalance(cnn) # Prepare optimizer param_optimizer = list(model.named_parameters()) no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ { "params": [ p for n, p in param_optimizer if not any(nd in n for nd in no_decay) ], "weight_decay": 0.01, }, { "params": [ p for n, p in param_optimizer if any(nd in n for nd in no_decay) ], "weight_decay": 0.0, }, ] if args.fp16: try: # from apex.optimizers import FP16_Optimizer from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State from apex.optimizers import FusedAdam except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training." 
) optimizer = FusedAdam( optimizer_grouped_parameters, lr=args.learning_rate, bias_correction=False, max_grad_norm=1.0, ) if args.loss_scale == 0: optimizer = FP16_Optimizer_State(optimizer, dynamic_loss_scale=True) else: optimizer = FP16_Optimizer_State( optimizer, static_loss_scale=args.loss_scale ) else: optimizer = BertAdam( optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, schedule=args.sche_mode, t_total=t_total, ) if recover_step: logger.info("***** Recover optimizer: %d *****", recover_step) optim_recover = torch.load( os.path.join(args.output_dir, "optim.{0}.bin".format(recover_step)) ) if hasattr(optim_recover, "state_dict"): optim_recover = optim_recover.state_dict() optimizer.load_state_dict(optim_recover) if args.loss_scale == 0: logger.info("***** Recover optimizer: dynamic_loss_scale *****") optimizer.dynamic_loss_scale = True logger.info("***** CUDA.empty_cache() *****") torch.cuda.empty_cache() if args.do_train: model.eval() losses = [] for batch in tqdm(train_dataloader): # wrangle batch batch = [t.to(device) for t in batch] ( input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, img, vis_masked_pos, vis_pe, ans_labels, ) = batch if args.fp16: img = img.half() vis_pe = vis_pe.half() if args.enable_butd: conv_feats = img.data # Bx100x2048 vis_pe = vis_pe.data else: conv_feats, _ = cnn(img.data) # Bx2048x7x7 conv_feats = ( conv_feats.view(conv_feats.size(0), conv_feats.size(1), -1) .permute(0, 2, 1) .contiguous() ) # compute loss masked_lm_loss, _, _ = model( conv_feats, vis_pe, input_ids, segment_ids, input_mask, lm_label_ids, ans_labels, is_next, masked_pos=masked_pos, masked_weights=masked_weights, task_idx=task_idx, vis_masked_pos=vis_masked_pos, mask_image_regions=args.mask_image_regions, drop_worst_ratio=args.max_drop_worst_ratio ) # average across multiple GPUs if n_gpu > 1: masked_lm_loss = masked_lm_loss.mean() losses.append(masked_lm_loss.item()) print(args.split, 'perplexity:', np.exp(np.mean(losses)))
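# Note on the perplexity printed above: exponentiating the unweighted mean of
# per-batch losses equals corpus perplexity only when every batch predicts the
# same number of tokens. A token-weighted variant (a sketch; it assumes the
# per-batch masked-token counts are available, which the loop above does not
# currently collect) would be:
import numpy as np

def weighted_perplexity(batch_losses, batch_token_counts):
    # total negative log-likelihood divided by total predicted tokens
    total_nll = sum(l * n for l, n in zip(batch_losses, batch_token_counts))
    return float(np.exp(total_nll / sum(batch_token_counts)))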
def main(): parser = argparse.ArgumentParser() # General parser.add_argument("--config_path", default=None, type=str, help="Bert config file path.") parser.add_argument( "--bert_model", default="bert-base-cased", type=str, help= "Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased" ) parser.add_argument("--model_recover_path", default=None, type=str, help="The file of fine-tuned pretraining model.") parser.add_argument('--max_position_embeddings', type=int, default=512, help="max position embeddings") # For decoding parser.add_argument( '--fp16', action='store_true', help="Whether to use 16-bit float precision instead of 32-bit") parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16") parser.add_argument('--seed', type=int, default=123, help="random seed for initialization") parser.add_argument( "--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.") parser.add_argument('--new_segment_ids', action='store_true', help="Use new segment ids for bi-uni-directional LM.") parser.add_argument('--batch_size', type=int, default=4, help="Batch size for decoding.") parser.add_argument('--beam_size', type=int, default=1, help="Beam size for searching") parser.add_argument('--length_penalty', type=float, default=0, help="Length penalty for beam search") parser.add_argument('--forbid_duplicate_ngrams', action='store_true') parser.add_argument('--forbid_ignore_word', type=str, default=None, help="Forbid the word during forbid_duplicate_ngrams") parser.add_argument("--min_len", default=None, type=int) parser.add_argument('--ngram_size', type=int, default=3) parser.add_argument('--max_tgt_length', type=int, default=20, help="maximum length of target sequence") # Others for VLP parser.add_argument("--src_file", default='/mnt/dat/COCO/annotations/dataset_coco.json', type=str, help="The input data file name.") parser.add_argument('--dataset', default='coco', type=str, help='coco | flickr30k | cc') parser.add_argument('--len_vis_input', type=int, default=100) # parser.add_argument('--resnet_model', type=str, default='imagenet_weights/resnet101.pth') parser.add_argument('--image_root', type=str, default='/mnt/dat/COCO/images') parser.add_argument('--split', type=str, default='val') parser.add_argument('--drop_prob', default=0.1, type=float) parser.add_argument('--enable_butd', action='store_true', help='set to take in region features') parser.add_argument( '--region_bbox_file', default= 'coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5', type=str) parser.add_argument( '--region_det_file_prefix', default= 'feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval', type=str) parser.add_argument('--file_valid_jpgs', default='', type=str) args = parser.parse_args() if args.enable_butd: assert (args.len_vis_input == 100) args.region_bbox_file = os.path.join(args.image_root, args.region_bbox_file) args.region_det_file_prefix = os.path.join( args.image_root, args.region_det_file_prefix) if args.dataset in ( 'cc', 'coco', 'fake_media') and args.region_det_file_prefix != '' else '' device = torch.device("cuda" if torch.cuda.is_available() else "cpu") n_gpu = torch.cuda.device_count() # fix random seed random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) args.max_seq_length = args.max_tgt_length + args.len_vis_input + 3 # +3 for 2x[SEP] 
and [CLS] tokenizer.max_len = args.max_seq_length bi_uni_pipeline = [] bi_uni_pipeline.append( seq2seq_loader.Preprocess4Seq2seqDecoder( list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, max_tgt_length=args.max_tgt_length, new_segment_ids=args.new_segment_ids, mode='s2s', len_vis_input=args.len_vis_input, enable_butd=args.enable_butd, region_bbox_file=args.region_bbox_file, region_det_file_prefix=args.region_det_file_prefix)) amp_handle = None if args.fp16 and args.amp: from apex import amp amp_handle = amp.init(enable_caching=True) logger.info("enable fp16 with amp") # Prepare model cls_num_labels = 2 type_vocab_size = 6 if args.new_segment_ids else 2 mask_word_id, eos_word_ids = tokenizer.convert_tokens_to_ids( ["[MASK]", "[SEP]"]) forbid_ignore_set = None if args.forbid_ignore_word: w_list = [] for w in args.forbid_ignore_word.split('|'): if w.startswith('[') and w.endswith(']'): w_list.append(w.upper()) else: w_list.append(w) forbid_ignore_set = set(tokenizer.convert_tokens_to_ids(w_list)) print(args.model_recover_path) for model_recover_path in glob.glob(args.model_recover_path.strip()): logger.info("***** Recover model: %s *****", model_recover_path) model_recover = torch.load(model_recover_path) model = BertForSeq2SeqDecoder.from_pretrained( args.bert_model, max_position_embeddings=args.max_position_embeddings, config_path=args.config_path, state_dict=model_recover, num_labels=cls_num_labels, type_vocab_size=type_vocab_size, task_idx=3, mask_word_id=mask_word_id, search_beam_size=args.beam_size, length_penalty=args.length_penalty, eos_id=eos_word_ids, forbid_duplicate_ngrams=args.forbid_duplicate_ngrams, forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size, min_len=args.min_len, enable_butd=args.enable_butd, len_vis_input=args.len_vis_input) del model_recover # from vlp.resnet import resnet # cnn = resnet(args.resnet_model, _num_layers=101, _fixed_block=4, pretrained=True) # no finetuning if args.fp16: model.half() # cnn.half() model.to(device) # cnn.to(device) if n_gpu > 1: model = torch.nn.DataParallel(model) # cnn = torch.nn.DataParallel(cnn) torch.cuda.empty_cache() model.eval() # cnn.eval() eval_lst = [] with open(args.src_file, "r", encoding='utf-8') as f_src: img_dat = json.load(f_src)['images'] img_idx = 0 valid_jpgs = None if (args.file_valid_jpgs == '' or args.dataset in \ ('coco', 'flickr30k')) else json.load(open(args.file_valid_jpgs)) for src in img_dat: if src['split'] == args.split and ( valid_jpgs is None or src['filename'] in valid_jpgs): if args.enable_butd: src_tk = os.path.join(args.image_root, src.get('filepath', 'trainval'), src['filename'][:-4] + '.npy') else: src_tk = os.path.join(args.image_root, src.get('filepath', 'trainval'), src['filename']) if args.dataset == 'coco': imgid = int(src['filename'].split('_')[2][:-4]) elif args.dataset == 'cc' or args.dataset == 'fake_media': imgid = int(src['imgid']) elif args.dataset == 'flickr30k': imgid = int(src['filename'].split('.')[0]) eval_lst.append( (img_idx, imgid, src_tk)) # id and path for COCO img_idx += 1 input_lines = eval_lst predictions = {} print('start the caption evaluation...') total_batch = math.ceil( len(input_lines) / args.batch_size) * len(SEED_PHRASES) # output_lines = [""] * len(input_lines) with tqdm(total=total_batch) as pbar: for seed in SEED_PHRASES: seed_ids = tokenizer.convert_tokens_to_ids( tokenizer.tokenize(seed)) next_i = 0 while next_i < len(input_lines): _chunk = input_lines[next_i:next_i + args.batch_size] buf_id = [x[0] for x in _chunk] 
buf = [x[2] for x in _chunk] next_i += args.batch_size instances = [] for instance in [(x, args.len_vis_input) for x in buf]: for proc in bi_uni_pipeline: instances.append(proc(instance)) with torch.no_grad(): batch = batch_list_to_batch_tensors(instances) batch = [t.to(device) for t in batch] input_ids, token_type_ids, position_ids, input_mask, task_idx, img, vis_pe = batch if args.fp16: img = img.half() vis_pe = vis_pe.half() if args.enable_butd: conv_feats = img.data # Bx100x2048 vis_pe = vis_pe.data else: conv_feats, _ = cnn(img.data) # Bx2048x7x7 conv_feats = conv_feats.view( conv_feats.size(0), conv_feats.size(1), -1).permute(0, 2, 1).contiguous() traces = model(conv_feats, vis_pe, input_ids, token_type_ids, position_ids, input_mask, task_idx=task_idx, seed_ids=seed_ids) if args.beam_size > 1: traces = {k: v.tolist() for k, v in traces.items()} output_ids = traces['pred_seq'] else: output_ids = traces[0].tolist() for i in range(len(buf)): w_ids = output_ids[i] output_buf = tokenizer.convert_ids_to_tokens(w_ids) output_tokens = [] for t in output_buf: if t in ("[SEP]", "[PAD]"): break output_tokens.append(t) output_sequence = ' '.join( detokenize(output_tokens)) img_id = input_lines[buf_id[i]][1] if img_id not in predictions: predictions[img_id] = [output_sequence] else: predictions[img_id].append(output_sequence) # output_lines[buf_id[i]] = output_sequence pbar.update(1) # predictions = [{'image_id': tup[1], 'caption': output_lines[img_idx]} for img_idx, tup in enumerate(input_lines)] # output captions to file output_dir = os.path.dirname(args.model_recover_path) caption_output_path = os.path.join( output_dir, args.split + '_captions_beam' + str(args.beam_size) + '.json') with open(caption_output_path, 'w') as caption_out_fp: json.dump(predictions, caption_out_fp, indent=4)
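# Because decoding is repeated once per entry in SEED_PHRASES (defined elsewhere
# in this file), each image id in the dumped JSON maps to a list of captions in
# SEED_PHRASES order. A small consumer sketch for the file written above (the
# path is the same caption_output_path; json stringifies the integer image ids):
import json

def load_seeded_captions(caption_output_path):
    with open(caption_output_path) as f:
        predictions = json.load(f)  # {img_id: [caption_for_seed_0, ...]}
    return {int(img_id): caps for img_id, caps in predictions.items()}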
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='coco', type=str, nargs='*', help='')
    parser.add_argument('--lang', default='en zh', type=str, nargs='*', help='')
    parser.add_argument('--max_len_en', default=25, type=int,
                        help='maximum length of English in **bilingual** corpus')
    parser.add_argument('--max_len_zh', default=25, type=int,
                        help='maximum length of Chinese in **bilingual** corpus')
    parser.add_argument('--max_len_en_cap', default=25, type=int,
                        help='maximum length of English in **img2txt** corpus')
    parser.add_argument('--max_len_zh_cap', default=25, type=int,
                        help='maximum length of Chinese in **img2txt** corpus')
    parser.add_argument('--len_vis_input', type=int, default=100,
                        help="The length of visual token input")
    parser.add_argument("--src_file", default='$DATA_ROOT/{}/annotations/{}_dataset.json',
                        type=str, help="The input data file name.")
    parser.add_argument('--split', type=str, default='val')
    parser.add_argument('--file_valid_jpgs',
                        default='$DATA_ROOT/{}/annotations/{}_valid_jpgs.json', type=str)
    parser.add_argument('--image_root', type=str,
                        default='$DATA_ROOT/{}/region_feat_gvd_wo_bgd')
    parser.add_argument('--region_bbox_file',
                        default='raw_bbox/{}_detection_vg_100dets_vlp_checkpoint_trainval_bbox',
                        type=str)
    parser.add_argument('--region_det_file_prefix',
                        default='feat_cls_1000/{}_detection_vg_100dets_vlp_checkpoint_trainval',
                        type=str)

    # General
    parser.add_argument("--config_path", default=None, type=str,
                        help="Bert config file path.")
    parser.add_argument("--bert_model", default="bert-base-cased", type=str,
                        help="Bert pre-trained model selected in the list: "
                             "bert-base-cased, bert-large-cased.")
    parser.add_argument("--xml_vocab", type=str, default='./download_models/xml_vocab.json')
    parser.add_argument("--xml_merge", type=str, default='./download_models/xml_merges.txt')
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of the fine-tuned pretraining model.")
    parser.add_argument('--max_position_embeddings', type=int, default=512,
                        help="max position embeddings")

    # For decoding
    # parser.add_argument('--fp16', action='store_true',
    #                     help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16")
    parser.add_argument('--seed', type=int, default=123,
                        help="random seed for initialization")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size', type=int, default=4, help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1, help="Beam size for searching")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search")
    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Words exempt from forbid_duplicate_ngrams ('|'-separated)")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--max_tgt_length', type=int, default=20,
                        help="maximum length of target sequence")
    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--enable_butd', action='store_true',
                        help='set to take in region features')
    parser.add_argument('--output_dir', default='./result', type=str)
    args = parser.parse_args()

    dataset = {}
    for d in args.dataset:
        assert d in ['coco', 'aic', 'wmt']
        if d == 'coco':
            dataset[d] = {'max_len_a': args.len_vis_input,
                          'max_len_b': args.max_len_en_cap}
        elif d == 'aic':
            dataset[d] = {'max_len_a': args.len_vis_input,
                          'max_len_b': args.max_len_zh_cap}
        else:  # d == 'wmt'
            dataset[d] = {'max_len_a': args.max_len_en,
                          'max_len_b': args.max_len_zh}
        # +3 accounts for the [CLS] and two [SEP] special tokens
        dataset[d]['max_seq_length'] = dataset[d]['max_len_a'] + dataset[d]['max_len_b'] + 3
    args.dataset = dataset
    lang2cap_max_seq_length = {
        'zh': args.len_vis_input + args.max_len_zh_cap + 3,
        'en': args.len_vis_input + args.max_len_en_cap + 3}

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    if args.enable_butd:
        assert args.len_vis_input == 100
        args.region_bbox_file = os.path.join(args.image_root, args.region_bbox_file)
        args.region_det_file_prefix = os.path.join(args.image_root,
                                                   args.region_det_file_prefix)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer_en = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + '/.pretrained_model')
    if args.max_position_embeddings:
        tokenizer_en.max_len = args.max_position_embeddings
    tokenizers = {'en': tokenizer_en}
    if 'aic' in args.dataset or 'wmt' in args.dataset:
        tokenizer_zh = XLMTokenizer(args.xml_vocab, args.xml_merge)
        tokenizer_zh.tokenize = lambda x: tokenizer_zh._tokenize(
            x, lang='zh', bypass_tokenizer=True)
        with open(args.xml_vocab, 'r') as f:
            tokenizer_zh.vocab = json.load(f)
        tokenizers['zh'] = tokenizer_zh

    if 'coco_g8_lr3e-5_batch512_ft_from_s0.75_b0.25' in args.model_recover_path:
        # English-only checkpoint: index with the BERT vocab alone
        indexer = Indexer([os.path.join(args.bert_model, 'vocab.txt')])
    else:
        indexer = Indexer([os.path.join(args.bert_model, 'vocab.txt'), args.xml_vocab])

    for corpus in args.dataset:
        if corpus in ['coco', 'aic']:  # img2txt
            for lang in args.lang:
                tokenizer = tokenizers[lang]
                max_seq_length = lang2cap_max_seq_length[lang]
                decode_pipeline = [seq2seq_loader.Preprocess4Seq2seqDecoder(
                    corpus, lang, list(tokenizer.vocab.keys()), indexer,
                    max_len=max_seq_length, max_tgt_length=args.max_tgt_length,
                    new_segment_ids=args.new_segment_ids, mode='s2s',
                    len_vis_input=args.len_vis_input, enable_butd=args.enable_butd,
                    region_bbox_file=args.region_bbox_file.format(
                        corpus.upper(), corpus.lower()),
                    region_det_file_prefix=args.region_det_file_prefix.format(
                        corpus.upper(), corpus.lower()))]
                eval_dataset = seq2seq_loader.Img2txtDataset(
                    args.src_file.format(corpus.upper(), corpus.lower()),
                    args.image_root.format(corpus.upper()), args.split,
                    args.batch_size, tokenizer, max_seq_length, preprocessed=True,
                    file_valid_jpgs=args.file_valid_jpgs.format(
                        corpus.upper(), corpus.lower()),
                    bi_uni_pipeline=decode_pipeline, use_num_imgs=-1,
                    s2s_prob=1, bi_prob=0, enable_butd=args.enable_butd,
                    tasks='img2txt')
                args.dataset[corpus][lang + '_eval_dataloader'] = \
                    torch.utils.data.DataLoader(
                        eval_dataset, batch_size=args.batch_size,
                        sampler=SequentialSampler(eval_dataset), num_workers=4,
                        collate_fn=batch_list_to_batch_tensors, pin_memory=True)
        else:
            raise NotImplementedError  # only aic and coco are supported for now

    amp_handle = None
    if args.amp:
        from apex import amp
        # amp_handle = amp.init(enable_caching=True)
        # logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 12  # same value with or without new_segment_ids
    mask_word_id, eos_word_ids = indexer(["[MASK]", "[SEP]"])
    forbid_ignore_set = None  # default None
    relax_projection, task_idx_proj = 0, 3
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(indexer(w_list))
    print(args.model_recover_path)

    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model,
            max_position_embeddings=args.max_position_embeddings,
            config_path=args.config_path, state_dict=model_recover,
            num_labels=cls_num_labels, vocab_size=len(indexer),
            type_vocab_size=type_vocab_size, task_idx=3,
            mask_word_id=mask_word_id,  # img2txt
            search_beam_size=args.beam_size, length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size,
            min_len=args.min_len, enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input)
        del model_recover
        model.to(device)
        if args.amp:
            model = amp.initialize(model, opt_level='O2')
        torch.cuda.empty_cache()
        model.eval()

        for corpus in args.dataset:
            for lang in ['en', 'zh']:
                if lang + '_eval_dataloader' not in args.dataset[corpus]:
                    continue
                print('corpus {} lang {}'.format(corpus, lang))
                output_lines = {}
                val_iter_bar = tqdm(args.dataset[corpus][lang + '_eval_dataloader'])
                for step_val, val_iter_output in enumerate(val_iter_bar):
                    info_, batch = val_iter_output[0], val_iter_output[1]
                    with torch.no_grad():
                        batch = [t.to(device) for t in batch]
                        input_ids, segment_ids, position_ids, input_mask, \
                            task_idx, img, vis_pe = batch
                        if args.enable_butd:
                            conv_feats = img.data  # Bx100x2048
                            vis_pe = vis_pe.data
                        else:
                            conv_feats, _ = cnn(img.data)  # Bx2048x7x7
                            conv_feats = conv_feats.view(
                                conv_feats.size(0), conv_feats.size(1),
                                -1).permute(0, 2, 1).contiguous()
                        if args.amp:
                            conv_feats = conv_feats.half()
                            vis_pe = vis_pe.half()
                        traces = model(conv_feats, vis_pe, input_ids, segment_ids,
                                       position_ids, input_mask,
                                       search_beam_size=args.beam_size,
                                       task_idx=task_idx,
                                       sample_mode='greedy')  # greedy for validation
                        if args.beam_size > 1:
                            traces = {k: v.tolist() for k, v in traces.items()}
                            output_ids = traces['pred_seq']
                        else:
                            output_ids = traces[0].tolist()
                        for ii, w_ids in enumerate(output_ids):
                            output_buf = indexer.convert_ids_to_tokens(w_ids)
                            output_tokens = []
                            for t in output_buf:
                                if t in ("[SEP]", "[PAD]"):
                                    break
                                output_tokens.append(t)
                            output_sequence = ' '.join(detokenize(output_tokens))
                            if corpus == 'coco':
                                id_ = int(info_[ii][2].split('_')[2])
                            else:
                                id_ = info_[ii][2]
                            output_lines[id_] = output_sequence
                predictions = [{'image_id': ids_, 'caption': output_lines[ids_]}
                               for ids_ in output_lines]
                with open(os.path.join(args.output_dir,
                                       '{}_{}_{}_predictions.json').format(
                                           args.split, corpus, lang), 'w') as f:
                    json.dump(predictions, f)
                if (corpus == 'coco' and lang == 'en') or \
                        (corpus == 'aic' and lang == 'zh'):
                    print('Begin evaluating ' + corpus)
                    lang_stats = language_eval(
                        corpus, predictions,
                        args.model_recover_path.split('/')[-2] + '-' + args.split
                        + '-' + args.model_recover_path.split('/')[-1].split('.')[-2],
                        args.split, ['Bleu', 'METEOR', 'Rouge', 'CIDEr'])
                    with open(os.path.join(args.output_dir,
                                           '{}_{}_{}_scores.json').format(
                                               args.split, corpus, lang), 'w') as f:
                        json.dump(lang_stats, f)
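# Example invocation of the caption-decoding entry point above (illustrative
# only: the script name and the $DATA_ROOT layout are assumptions; the flags
# are exactly the ones defined in the parser above):
#
#   python decode_img2txt.py \
#       --dataset coco aic --lang en zh \
#       --model_recover_path ./checkpoints/model.10.bin \
#       --enable_butd --new_segment_ids \
#       --beam_size 1 --max_tgt_length 20 \
#       --split val --output_dir ./result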
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: "
                             "bert-base-uncased, bert-large-uncased, bert-base-cased, "
                             "bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of the fine-tuned pretraining model.")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="The maximum total input sequence length after WordPiece "
                             "tokenization. Sequences longer than this will be truncated, "
                             "and sequences shorter than this will be padded.")

    # decoding parameters
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16")
    parser.add_argument("--input_file", type=str, help="Input file")
    parser.add_argument("--output_file", type=str, help="Output file")
    parser.add_argument("--split", type=str, default="",
                        help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--seed', type=int, default=123,
                        help="random seed for initialization")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size', type=int, default=4, help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1, help="Beam size for searching")
    parser.add_argument('--top_k', type=int, default=1,
                        help="Number of top-scoring hypotheses to output")
    parser.add_argument('--top_kk', type=int, default=0,
                        help="Top-k sampling method for output")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search")
    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Words exempt from forbid_duplicate_ngrams ('|'-separated)")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--mode', default="s2s", choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length', type=int, default=128,
                        help="maximum length of target sequence")
    args = parser.parse_args()

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1.")
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case)
    tokenizer.max_len = args.max_seq_length

    pair_num_relation = 0
    bi_uni_pipeline = []
    if args.mode == "s2s" or args.mode == "both":
        bi_uni_pipeline.append(
            seq2seq_loader.Preprocess4Seq2seqDecoder(
                list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids,
                args.max_seq_length, max_tgt_length=args.max_tgt_length,
                new_segment_ids=args.new_segment_ids, mode="s2s"))
    if args.mode == "l2r" or args.mode == "both":
        bi_uni_pipeline.append(
            seq2seq_loader.Preprocess4Seq2seqDecoder(
                list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids,
                args.max_seq_length, max_tgt_length=args.max_tgt_length,
                new_segment_ids=args.new_segment_ids, mode="l2r"))

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    mask_word_id, eos_word_ids = tokenizer.convert_tokens_to_ids(["[MASK]", "[SEP]"])
    forbid_ignore_set = None
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(tokenizer.convert_tokens_to_ids(w_list))
    print(args.model_recover_path)

    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model, state_dict=model_recover, num_labels=cls_num_labels,
            num_rel=pair_num_relation, type_vocab_size=type_vocab_size,
            task_idx=3, mask_word_id=mask_word_id,
            search_beam_size=args.beam_size, length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size,
            min_len=args.min_len, mode=args.mode,
            max_position_embeddings=args.max_seq_length, top_kk=args.top_kk)
        del model_recover

        if args.fp16:
            model.half()
        model.to(device)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)
        torch.cuda.empty_cache()
        model.eval()

        next_i = 0
        max_src_length = args.max_seq_length - 2 - args.max_tgt_length

        with open(args.input_file, encoding="utf-8") as fin:
            input_lines = [x.strip() for x in fin.readlines()]
        data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
        input_lines = [data_tokenizer.tokenize(x)[:max_src_length]
                       for x in input_lines]
        # sort by source length (descending) so batches are roughly uniform
        input_lines = sorted(list(enumerate(input_lines)), key=lambda x: -len(x[1]))
        output_lines = [""] * len(input_lines)
        score_trace_list = [None] * len(input_lines)
        total_batch = math.ceil(len(input_lines) / args.batch_size)

        with tqdm(total=total_batch) as pbar:
            while next_i < len(input_lines):
                _chunk = input_lines[next_i:next_i + args.batch_size]
                buf_id = [x[0] for x in _chunk]
                buf = [x[1] for x in _chunk]
                next_i += args.batch_size
                max_a_len = max([len(x) for x in buf])
                instances = []
                for instance in [(x, max_a_len) for x in buf]:
                    for proc in bi_uni_pipeline:
                        instances.append(proc(instance))
                with torch.no_grad():
                    batch = long_loader.batch_list_to_batch_tensors(instances)
                    batch = [t.to(device) for t in batch]
                    input_ids, token_type_ids, position_ids, input_mask, task_idx = batch
                    traces = model(input_ids, token_type_ids, position_ids,
                                   input_mask, task_idx=task_idx)
                    if args.beam_size > 1:
                        traces = {k: v.tolist() for k, v in traces.items()}
                        output_ids = traces['pred_seq']
                    else:
                        output_ids = traces.tolist()
                    # NOTE: the top-k extraction below reads the beam traces
                    # ('scores'/'wids'/'ptrs'), so it assumes beam_size > 1.
                    for i in range(len(buf)):
                        scores = traces['scores'][i]
                        wids_list = traces['wids'][i]
                        ptrs = traces['ptrs'][i]
                        eos_id = 102  # BERT's [SEP] token id
                        top_k = args.top_k
                        # First find the eos frame where all symbols are eos;
                        # any frames after the eos frame are invalid.
                        last_frame_id = len(scores) - 1
                        for _i, wids in enumerate(wids_list):
                            if all(wid == eos_id for wid in wids):
                                last_frame_id = _i
                                break
                        frame_id = -1
                        pos_in_frame = -1
                        seqs = []
                        for fid in range(last_frame_id + 1):
                            for _i, wid in enumerate(wids_list[fid]):
                                if wid == eos_id or fid == last_frame_id:
                                    s = scores[fid][_i]
                                    frame_id = fid
                                    pos_in_frame = _i
                                    if frame_id != -1 and s < 0:
                                        # walk the backpointers to recover the
                                        # full hypothesis for this end point
                                        seq = [wids_list[frame_id][pos_in_frame]]
                                        for _fid in range(frame_id, 0, -1):
                                            pos_in_frame = ptrs[_fid][pos_in_frame]
                                            seq.append(wids_list[_fid - 1][pos_in_frame])
                                        seq.reverse()
                                        seqs.append([seq, s])
                        seqs = sorted(seqs, key=lambda x: x[1], reverse=True)
                        w_idss = [seq[0] for seq in seqs[:top_k]]
                        output_sequences = []
                        for w_ids in w_idss:
                            output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                            output_tokens = []
                            for t in output_buf:
                                if t in ("[SEP]", "[PAD]"):
                                    break
                                output_tokens.append(t)
                            output_sequence = ' '.join(detokenize(output_tokens))
                            output_sequences.append(output_sequence)
                        output_lines[buf_id[i]] = output_sequences
                        if args.need_score_traces:
                            score_trace_list[buf_id[i]] = {
                                'scores': traces['scores'][i],
                                'wids': traces['wids'][i],
                                'ptrs': traces['ptrs'][i]}
                pbar.update(1)

        if args.output_file:
            fn_out = args.output_file
        else:
            fn_out = model_recover_path + '.' + args.split
        with open(fn_out, "w", encoding="utf-8") as fout:
            for l in output_lines:
                fout.write('\t'.join(l))
                fout.write("\n")

        if args.need_score_traces:
            with open(fn_out + ".trace.pickle", "wb") as fout_trace:
                pickle.dump({"version": 0.0,
                             "num_samples": len(input_lines)}, fout_trace)
                for x in score_trace_list:
                    pickle.dump(x, fout_trace)
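# Example invocation of the text-decoding entry point above (illustrative only:
# the script and data-file names are placeholders; the flags are the ones
# defined in the parser above; --top_k > 1 requires beam search):
#
#   python decode_seq2seq.py \
#       --bert_model bert-base-uncased --do_lower_case \
#       --model_recover_path ./checkpoints/model.bin \
#       --input_file ./data/test.src --output_file ./data/test.hyp \
#       --beam_size 5 --top_k 3 --forbid_duplicate_ngrams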
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--enable_vis', action='store_true',
                        help='whether to visualize visual embedding')
    parser.add_argument('--num_concept', default=-1, type=int,
                        help='number of concepts to visualize')
    parser.add_argument('--num_vis', default=1, type=int,
                        help='number of visual embeddings per concept')
    parser.add_argument('--dataset', default='txt', type=str,
                        help='txt -> self-customized')
    parser.add_argument('--src_lang', default='en', type=str, help='')
    parser.add_argument('--tgt_lang', default='zh', type=str, help='')
    parser.add_argument('--max_len_en', default=25, type=int,
                        help='maximum length of English in **bilingual** corpus')
    parser.add_argument('--max_len_zh', default=25, type=int,
                        help='maximum length of Chinese in **bilingual** corpus')
    # parser.add_argument("--vocab_file", default='./src.txt', type=str,
    #                     help="The input data file name.")
    parser.add_argument('--vocab_file', type=str, required=True, nargs='+')
    parser.add_argument('--en_first', action='store_true',
                        help='always put English as the first sentence')

    # General
    parser.add_argument("--config_path", default=None, type=str,
                        help="Bert config file path.")
    parser.add_argument("--bert_model", default="bert-base-cased", type=str,
                        help="Bert pre-trained model selected in the list: "
                             "bert-base-cased, bert-large-cased.")
    parser.add_argument("--xml_vocab", type=str, default='./download_models/xml_vocab.json')
    parser.add_argument("--xml_merge", type=str, default='./download_models/xml_merges.txt')
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of the fine-tuned pretraining model.")
    parser.add_argument('--max_position_embeddings', type=int, default=512,
                        help="max position embeddings")

    # For decoding
    # parser.add_argument('--fp16', action='store_true',
    #                     help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16")
    parser.add_argument('--seed', type=int, default=123,
                        help="random seed for initialization")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size', type=int, default=4, help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1, help="Beam size for searching")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search")
    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Words exempt from forbid_duplicate_ngrams ('|'-separated)")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--enable_butd', action='store_true',
                        help='set to take in region features')
    parser.add_argument('--output_dir', default='./result', type=str)
    parser.add_argument('--split', type=str, default='val')  # unused; wmt?
    parser.add_argument('--len_vis_input', type=int, default=100,
                        help="The length of visual token input")
    args = parser.parse_args()

    print(args.vocab_file)
    if args.enable_vis or '.pkl' in args.vocab_file[0]:
        args.vocab_file = args.vocab_file[0]
        assert '.pkl' in args.vocab_file, args.vocab_file

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer_en = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + '/.pretrained_model')
    if args.max_position_embeddings:
        tokenizer_en.max_len = args.max_position_embeddings
    tokenizer_zh = XLMTokenizer(args.xml_vocab, args.xml_merge)
    tokenizer_zh.tokenize = lambda x: tokenizer_zh._tokenize(
        x, lang='zh', bypass_tokenizer=False)
    with open(args.xml_vocab, 'r') as f:
        tokenizer_zh.vocab = json.load(f)
    indexer = Indexer([os.path.join(args.bert_model, 'vocab.txt'), args.xml_vocab])
    with open('full_vocab.json', 'w') as f:
        json.dump(indexer.ids_to_tokens, f)
    tokenizers = {'en': tokenizer_en, 'zh': tokenizer_zh}
    print('tokenizer created')

    concept_list = []
    if args.enable_vis or 'pkl' in args.vocab_file:
        with open(args.vocab_file, 'rb') as f:
            vocab_list = pickle.load(f)
        vocab_list = vocab_list[::-1]
        if args.num_concept == -1:
            args.num_concept = len(vocab_list)
        # vocab: [[En, Zh, Vis1(list), Vis2, Vis3, Vis4, ...] * num_concept]
        vocab = []
        # NOTE: the 40-entry offset below is inherited from the original code
        for key, ls in vocab_list[40:40 + args.num_concept]:
            concept = [ls[0][0], ls[0][1]]  # En, Zh
            concept_list.append(ls[0][0])
            for inst in ls[:args.num_vis]:
                concept.append(inst[2:])  # fields 2, 3, 4
            vocab.append(concept)
        print('Number of Concept {}'.format(len(vocab)))
        if args.num_concept != -1:
            print(concept_list)
        print('Number of visual instance per concept {}'.format(len(vocab[0]) - 2))
    else:
        vocab = []
        for filename in args.vocab_file:
            with open(filename) as f:
                v = f.readlines()
            if args.num_concept == -1:
                args.num_concept = len(v)
            v = [(a.split(' ')[0].strip(), a.split(' ')[-1].strip())
                 for a in v[:args.num_concept]]  # EN ZH
            vocab.extend(v)
        print('Number of vocabulary {} {}'.format(len(vocab), vocab[0]))

    cls_num_labels = 2
    type_vocab_size = 12  # same value with or without new_segment_ids
    mask_word_id, eos_word_ids = indexer(["[MASK]", "[SEP]"])
    forbid_ignore_set = None  # default None
    relax_projection, task_idx_proj = 0, 3
    print(args.model_recover_path)

    model_recover = torch.load(args.model_recover_path)
    model = BertForSeq2SeqDecoder.from_pretrained(
        args.bert_model,
        max_position_embeddings=args.max_position_embeddings,
        config_path=args.config_path, state_dict=model_recover,
        num_labels=cls_num_labels, vocab_size=len(indexer),
        type_vocab_size=type_vocab_size, task_idx=3,
        mask_word_id=mask_word_id,  # img2txt
        search_beam_size=args.beam_size, length_penalty=args.length_penalty,
        eos_id=eos_word_ids,
        forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
        forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size,
        min_len=args.min_len, enable_butd=True, len_vis_input=args.len_vis_input)
    del model_recover
    model.to(device)
    torch.cuda.empty_cache()
    model.eval()

    N_layer = 12
    embeddings = OrderedDict({
        'en': [[] for i in range(N_layer)],
        'zh': [[] for i in range(N_layer)]})
    if args.enable_vis:
        embeddings['vis'] = [[] for i in range(N_layer)]

    for _, pair in tqdm(enumerate(vocab)):
        for w, lang in zip(pair, ('en', 'zh')):
            segment_id = 1 if lang == 'en' else 6
            w_t = tokenizers[lang].tokenize(w)
            tokens = ['[CLS]'] + w_t + ['[SEP]']
            input_ids = indexer(tokens)
            token_type_ids = [segment_id] * len(input_ids)
            input_ids = np.expand_dims(np.array(input_ids), axis=0)
            token_type_ids = np.expand_dims(np.array(token_type_ids), axis=0)
            input_ids = torch.tensor(input_ids, dtype=torch.long, device=device)
            token_type_ids = torch.tensor(token_type_ids, dtype=torch.long,
                                          device=device)
            output_embeddings = model.compute_embeddings(
                input_ids=input_ids, token_type_ids=token_type_ids,
                mode='txt2txt')  # tuple of 12 tensors, each (1, L, 768)
            for i, e in enumerate(output_embeddings):
                e = e.detach().cpu().numpy()
                # average over the word pieces, excluding [CLS]/[SEP] -> (768,)
                ave = np.mean(e[0, 1:-1, :], axis=0)
                embeddings[lang][i].append(ave)

        if args.enable_vis:
            instance_embeddings = [[] for layer_i in range(N_layer)]
            for vis_embed in pair[2:]:
                vis_feats, vis_pe, cls_label = vis_embed[0], vis_embed[1], vis_embed[2]
                vis_feats = torch.from_numpy(vis_feats).to(device).unsqueeze(0)
                vis_pe = torch.from_numpy(vis_pe).to(device).unsqueeze(0)
                cls_label = torch.from_numpy(cls_label).to(device).unsqueeze(0)

                # lazy normalization of the coordinates (copied from the seq2seq loader)
                w_est = torch.max(vis_pe[:, [0, 2]]) * 1. + 1e-5
                h_est = torch.max(vis_pe[:, [1, 3]]) * 1. + 1e-5
                vis_pe[:, [0, 2]] /= w_est
                vis_pe[:, [1, 3]] /= h_est
                assert h_est > 0, 'should be greater than 0! {}'.format(h_est)
                assert w_est > 0, 'should be greater than 0! {}'.format(w_est)
                rel_area = (vis_pe[:, 3] - vis_pe[:, 1]) * (vis_pe[:, 2] - vis_pe[:, 0])
                rel_area.clamp_(0)
                vis_pe = torch.cat((vis_pe[:, :4], rel_area.view(-1, 1),
                                    vis_pe[:, 5:]), -1)  # confidence score
                normalized_coord = F.normalize(vis_pe.data[:, :5] - 0.5, dim=-1)
                vis_pe = torch.cat((F.layer_norm(vis_pe, [6]),
                                    F.layer_norm(cls_label, [1601])),
                                   dim=-1)  # 1601 hard coded

                # (B, L, H)
                vis_feats = vis_feats.unsqueeze(0)
                vis_pe = vis_pe.unsqueeze(0)
                segment_id = 0
                tokens = ['[CLS]', '[UNK]', '[SEP]']
                input_ids = indexer(tokens)
                token_type_ids = [segment_id] * len(input_ids)
                input_ids = np.expand_dims(np.array(input_ids), axis=0)
                token_type_ids = np.expand_dims(np.array(token_type_ids), axis=0)
                input_ids = torch.tensor(input_ids, dtype=torch.long, device=device)
                token_type_ids = torch.tensor(token_type_ids, dtype=torch.long,
                                              device=device)
                vis_embeddings = model.compute_embeddings(
                    vis_feats=vis_feats, vis_pe=vis_pe, input_ids=input_ids,
                    token_type_ids=token_type_ids, mode='img2txt', len_vis_input=1)
                for i, e in enumerate(vis_embeddings):
                    e = e.detach().cpu().numpy()
                    ave = np.mean(e[0, 1:-1, :], axis=0)  # (768,)
                    instance_embeddings[i].append(ave)
            for i, embed_list in enumerate(instance_embeddings):
                if args.num_vis == 1:
                    embeddings['vis'][i].append(embed_list[0])
                else:
                    embeddings['vis'][i].append(embed_list)  # list of arrays

    for ly in range(N_layer):
        embed = {'en': embeddings['en'][ly], 'zh': embeddings['zh'][ly]}
        if args.enable_vis:
            embed['vis'] = embeddings['vis'][ly]
        if args.num_vis == 1:
            tSNE_reduce(embed, os.path.join(args.output_dir,
                                            'tSNE layer {}'.format(ly)))
            Cosine_Sim(embed, os.path.join(args.output_dir,
                                           'CosineSim Layer {}'.format(ly)))
            sim_histgram(embed, os.path.join(args.output_dir,
                                             'Histogram Layer {}'.format(ly)))
        if args.num_vis > 1:
            tSNE_reduce_visual(concept_list, embed,
                               os.path.join(args.output_dir,
                                            'tSNE layer {}'.format(ly)))
        print('Save layer {}'.format(ly))
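# Example invocation of the embedding-visualization entry point above
# (illustrative only: the script and vocab-file names are placeholders; the
# flags are the ones defined in the parser above):
#
#   python vis_embeddings.py \
#       --vocab_file ./concept_vocab.pkl --enable_vis \
#       --model_recover_path ./checkpoints/model.bin \
#       --num_concept 50 --num_vis 4 --output_dir ./result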
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: "
                             "bert-base-uncased, bert-large-uncased, bert-base-cased, "
                             "bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of the fine-tuned pretraining model.")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="The maximum total input sequence length after WordPiece "
                             "tokenization. Sequences longer than this will be truncated, "
                             "and sequences shorter than this will be padded.")

    # decoding parameters
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16")
    parser.add_argument("--input_file", type=str, help="Input file")
    parser.add_argument("--output_file", type=str, help="Output file")
    parser.add_argument("--split", type=str, default="",
                        help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--seed', type=int, default=123,
                        help="random seed for initialization")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size', type=int, default=4, help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1, help="Beam size for searching")
    parser.add_argument('--top_k', type=int, default=1,
                        help="Number of top-scoring hypotheses to output")
    parser.add_argument('--top_kk', type=int, default=0,
                        help="Top-k sampling method for output")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search")
    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Words exempt from forbid_duplicate_ngrams ('|'-separated)")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--mode', default="s2s", choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length', type=int, default=128,
                        help="maximum length of target sequence")

    # evaluate parameters
    parser.add_argument('--do_predict', action='store_true', help="do_predict")
    parser.add_argument("--do_evaluate", action="store_true",
                        help="calculate the scores if a label file is given")
    parser.add_argument("--label_file", type=str, default="")
    parser.add_argument("--experiment", type=str, default="full",
                        help="full/title/title-l1/single/hierarchical/"
                             "title-first/title-first-rank/segsep")

    # ranker parameters
    parser.add_argument("--ranker_recover_path", type=str,
                        help="ranker model for extracting sentences")
    parser.add_argument("--ranker_max_len", type=int, default=192,
                        help="max length of the ranker input")
    parser.add_argument("--ranker_batch_size", type=int, default=128)
    args = parser.parse_args()

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1.")
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case)
    tokenizer.max_len = args.max_seq_length

    pair_num_relation = 0
    bi_uni_pipeline = []
    if args.mode == "s2s" or args.mode == "both":
        bi_uni_pipeline.append(seq2seq_loader.Preprocess4Seq2seqDecoder(
            list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids,
            args.max_seq_length, max_tgt_length=args.max_tgt_length,
            new_segment_ids=args.new_segment_ids, mode="s2s"))
    if args.mode == "l2r" or args.mode == "both":
        bi_uni_pipeline.append(seq2seq_loader.Preprocess4Seq2seqDecoder(
            list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids,
            args.max_seq_length, max_tgt_length=args.max_tgt_length,
            new_segment_ids=args.new_segment_ids, mode="l2r"))
    if args.experiment == "segsep":
        # segsep replaces the pipeline with its own decoder preprocessing
        bi_uni_pipeline = [Preprocess4SegSepDecoder(
            list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids,
            args.max_seq_length, max_tgt_length=args.max_tgt_length,
            new_segment_ids=args.new_segment_ids, mode="s2s")]

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    if args.experiment == "segsep":
        type_vocab_size = 11
    mask_word_id, eos_word_ids = tokenizer.convert_tokens_to_ids(["[MASK]", "[SEP]"])
    forbid_ignore_set = None
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(tokenizer.convert_tokens_to_ids(w_list))
    print(args.model_recover_path)

    if args.do_predict:
        for model_recover_path in glob.glob(args.model_recover_path.strip()):
            logger.info("***** Recover model: %s *****", model_recover_path)
            model_recover = torch.load(model_recover_path)
            model = BertForSeq2SeqDecoder.from_pretrained(
                args.bert_model, state_dict=model_recover,
                num_labels=cls_num_labels, num_rel=pair_num_relation,
                type_vocab_size=type_vocab_size, task_idx=3,
                mask_word_id=mask_word_id, search_beam_size=args.beam_size,
                length_penalty=args.length_penalty, eos_id=eos_word_ids,
                forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
                forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size,
                min_len=args.min_len, mode=args.mode,
                max_position_embeddings=args.max_seq_length)
            del model_recover

            if args.fp16:
                model.half()
            model.to(device)
            if n_gpu > 1:
                model = torch.nn.DataParallel(model)
            torch.cuda.empty_cache()
            model.eval()

            next_i = 0
            max_src_length = args.max_seq_length - 2 - args.max_tgt_length

            if args.experiment in ["full", "title", "title-l1"]:
                input_lines = EvalDataset(args.input_file, args.experiment).proc()
            elif args.experiment == "single":
                input_lines, map_dict = EvalDataset(args.input_file,
                                                    args.experiment).proc()
            elif args.experiment == "title-first":
                input_lines = EvalDataset(args.input_file, args.experiment,
                                          tokenizer, args.max_seq_length,
                                          args.max_seq_length).proc()
            elif args.experiment == "segsep":
                input_lines = EvalDataset(args.input_file, args.experiment,
                                          tokenizer, args.max_seq_length,
                                          args.max_seq_length).proc()
            elif args.experiment == "hierarchical":
                logger.info("***** Recover rank model: %s *****",
                            args.ranker_recover_path)
                # extract sentences with the ranker before loading the data
                rank_model_recover = torch.load(args.ranker_recover_path,
                                                map_location="cpu")
                global_step = 0
                rank_model = BertForSentenceRanker.from_pretrained(
                    args.bert_model, state_dict=rank_model_recover, num_labels=2)
                # set model for multi GPUs or multi nodes
                if args.fp16:
                    rank_model.half()
                rank_model.to(device)
                if n_gpu > 1:
                    rank_model = DataParallelImbalance(rank_model)
                DatasetFunc = ScoreEvalDataset
                # load title + sentence pairs
                print("Loading Rank Dataset from", args.input_file)
                data_tokenizer = WhitespaceTokenizer() if args.tokenized_input \
                    else tokenizer
                max_pred = 16
                mask_prob = 0.7
                rank_bi_uni_pipeline = [Preprocess4Seq2cls(
                    max_pred, mask_prob, list(tokenizer.vocab.keys()),
                    tokenizer.convert_tokens_to_ids, args.ranker_max_len,
                    new_segment_ids=args.new_segment_ids,
                    truncate_config={'max_len_a': 64, 'max_len_b': 16,
                                     'trunc_seg': 'a', 'always_truncate_tail': True},
                    mask_source_words=False, skipgram_prb=0.0, skipgram_size=1,
                    mask_whole_word=False, mode="s2s", has_oracle=False, num_qkv=0,
                    s2s_special_token=False, s2s_add_segment=False,
                    s2s_share_segment=False, pos_shift=False, eval=True)]
                fn_src = args.input_file
                fn_tgt = None
                eval_dataset = DatasetFunc(
                    fn_src, fn_tgt, args.ranker_batch_size, data_tokenizer,
                    args.ranker_max_len, bi_uni_pipeline=rank_bi_uni_pipeline)
                eval_sampler = SequentialSampler(eval_dataset)
                _batch_size = args.ranker_batch_size
                eval_dataloader = torch.utils.data.DataLoader(
                    eval_dataset, batch_size=_batch_size, sampler=eval_sampler,
                    num_workers=24,
                    collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
                    pin_memory=False)
                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
                logger.info("***** Running ranker *****")
                logger.info("  Batch size = %d", _batch_size)
                logger.info("  Num steps = %d",
                            int(len(eval_dataset) / args.ranker_batch_size))
                rank_model.to(device)
                rank_model.eval()
                iter_bar = tqdm(eval_dataloader, desc="Iter: ")
                num_rank_labels = 2
                all_labels = []
                for step, batch in enumerate(iter_bar):
                    batch = [t.to(device) if t is not None else None for t in batch]
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, \
                        masked_pos, masked_weights, is_next, task_idx = batch
                    with torch.no_grad():  # inference only; no gradients needed
                        logits = rank_model(input_ids, task_idx=task_idx,
                                            mask_qkv=mask_qkv)
                    labels = torch.argmax(logits.view(-1, num_rank_labels), dim=-1)
                    all_labels.append(labels)
                all_labels_results = []
                for label in all_labels:
                    all_labels_results.extend(label.detach().cpu().numpy())
                # collect results
                logger.info("***** Collect results *****")
                clu2doc_dict, doc2sent_dict, all_titles, all_sents = \
                    eval_dataset.get_maps()
                all_docs = []
                for i, doc in enumerate(doc2sent_dict):
                    text = all_titles[i]
                    sent_idx = doc2sent_dict[doc]
                    for idx in sent_idx:
                        # keep only sentences the ranker labeled positive
                        if all_labels_results[idx] == 1:
                            text += ". " + all_sents[idx]
                    all_docs.append(text)
                input_lines = []
                for clu in tqdm(clu2doc_dict):
                    doc_idx = clu2doc_dict[clu]
                    input_line = ""
                    for idx in doc_idx:
                        input_line += all_docs[idx]
                    input_lines.append(input_line)
            elif args.experiment == "title-first-rank":
                logger.info("***** Recover rank model: %s *****",
                            args.ranker_recover_path)
                # extract sentences with the ranker before loading the data
                rank_model_recover = torch.load(args.ranker_recover_path,
                                                map_location="cpu")
                global_step = 0
                rank_model = BertForSentenceRanker.from_pretrained(
                    args.bert_model, state_dict=rank_model_recover, num_labels=2)
                # set model for multi GPUs or multi nodes
                if args.fp16:
                    rank_model.half()
                rank_model.to(device)
                if n_gpu > 1:
                    rank_model = DataParallelImbalance(rank_model)
                DatasetFunc = EvalRankDataset
                # load title + sentence pairs
                print("Loading Rank Dataset from", args.input_file)
                data_tokenizer = WhitespaceTokenizer() if args.tokenized_input \
                    else tokenizer
                max_pred = 16
                mask_prob = 0.7
                rank_bi_uni_pipeline = [Preprocess4Seq2cls(
                    max_pred, mask_prob, list(tokenizer.vocab.keys()),
                    tokenizer.convert_tokens_to_ids, args.max_seq_length,
                    new_segment_ids=args.new_segment_ids,
                    truncate_config={'max_len_a': 512, 'max_len_b': 16,
                                     'trunc_seg': 'a', 'always_truncate_tail': True},
                    mask_source_words=False, skipgram_prb=0.0, skipgram_size=1,
                    mask_whole_word=False, mode="s2s", has_oracle=False, num_qkv=0,
                    s2s_special_token=False, s2s_add_segment=False,
                    s2s_share_segment=False, pos_shift=False, eval=True)]
                fn_src = args.input_file
                fn_tgt = None
                eval_dataset = DatasetFunc(
                    fn_src, fn_tgt, args.ranker_batch_size, data_tokenizer,
                    args.max_seq_length, bi_uni_pipeline=rank_bi_uni_pipeline)
                eval_sampler = SequentialSampler(eval_dataset)
                _batch_size = args.ranker_batch_size
                eval_dataloader = torch.utils.data.DataLoader(
                    eval_dataset, batch_size=_batch_size, sampler=eval_sampler,
                    num_workers=24,
                    collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
                    pin_memory=False)
                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
                logger.info("***** Running ranker *****")
                logger.info("  Batch size = %d", _batch_size)
                logger.info("  Num steps = %d",
                            int(len(eval_dataset) / args.ranker_batch_size))
                rank_model.to(device)
                rank_model.eval()
                iter_bar = tqdm(eval_dataloader, desc="Iter: ")
                num_rank_labels = 2
                all_labels = []
                for step, batch in enumerate(iter_bar):
                    batch = [t.to(device) if t is not None else None for t in batch]
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, \
                        masked_pos, masked_weights, is_next, task_idx = batch
                    with torch.no_grad():
                        logits = rank_model(input_ids, task_idx=task_idx,
                                            mask_qkv=mask_qkv)
                    labels = logits.view(-1)
                    all_labels.append(labels)
                all_labels_results = []
                for label in all_labels:
                    all_labels_results.extend(label.detach().cpu().numpy())
                print("test label results")
                print(all_labels_results[0])
                # collect results
                logger.info("***** Collect results *****")
                clu2sent_dict, all_sents, all_titles = eval_dataset.get_maps()
                all_clusters = []
                input_lines = []
                for i, clu_id in enumerate(clu2sent_dict):
                    text = all_titles[clu_id]
                    sent_idx = clu2sent_dict[clu_id]
                    sents_collect = []
                    for idx in sent_idx:
                        sents_collect.append([all_sents[idx],
                                              all_labels_results[idx]])
                    # sort by ranker score (ascending, as in the original)
                    sents_collect_sort = sorted(sents_collect, key=lambda x: x[1])
                    sents_collect = [x[0] for x in sents_collect_sort]
                    text_tk = tokenizer.tokenize(text)
                    j = 0
                    # Append sentences while the length budget holds.
                    # NOTE: text_tk is never extended, so the check only counts
                    # title tokens plus the next sentence, not the running total.
                    while (j < len(sents_collect)
                           and len(text_tk) + len(tokenizer.tokenize(sents_collect[j]))
                           <= args.max_seq_length):
                        text += " " + sents_collect[j]
                        j += 1
                    input_lines.append(text)
            else:
                input_lines = []

            data_tokenizer = WhitespaceTokenizer() if args.tokenized_input \
                else tokenizer
            input_lines = [data_tokenizer.tokenize(x)[:max_src_length]
                           for x in input_lines]
            input_lines = sorted(list(enumerate(input_lines)),
                                 key=lambda x: -len(x[1]))
            output_lines = [""] * len(input_lines)
            score_trace_list = [None] * len(input_lines)
            total_batch = math.ceil(len(input_lines) / args.batch_size)

            with tqdm(total=total_batch) as pbar:
                while next_i < len(input_lines):
                    _chunk = input_lines[next_i:next_i + args.batch_size]
                    buf_id = [x[0] for x in _chunk]
                    buf = [x[1] for x in _chunk]
                    next_i += args.batch_size
                    max_a_len = max([len(x) for x in buf])
                    instances = []
                    for instance in [(x, max_a_len) for x in buf]:
                        for proc in bi_uni_pipeline:
                            instances.append(proc(instance))
                    with torch.no_grad():
                        batch = seq2seq_loader.batch_list_to_batch_tensors(instances)
                        batch = [t.to(device) for t in batch if t is not None]
                        input_ids, token_type_ids, position_ids, input_mask, \
                            task_idx = batch
                        traces = model(input_ids, token_type_ids, position_ids,
                                       input_mask, task_idx=task_idx)
                        if args.beam_size > 1:
                            traces = {k: v.tolist() for k, v in traces.items()}
                            output_ids = traces['pred_seq']
                        else:
                            output_ids = traces.tolist()
                        # NOTE: the top-k extraction below reads the beam traces
                        # ('scores'/'wids'/'ptrs'), so it assumes beam_size > 1.
                        for i in range(len(buf)):
                            scores = traces['scores'][i]
                            wids_list = traces['wids'][i]
                            ptrs = traces['ptrs'][i]
                            eos_id = 102  # BERT's [SEP] token id
                            top_k = args.top_k
                            # First find the eos frame where all symbols are eos;
                            # any frames after the eos frame are invalid.
                            last_frame_id = len(scores) - 1
                            for _i, wids in enumerate(wids_list):
                                if all(wid == eos_id for wid in wids):
                                    last_frame_id = _i
                                    break
                            frame_id = -1
                            pos_in_frame = -1
                            seqs = []
                            for fid in range(last_frame_id + 1):
                                for _i, wid in enumerate(wids_list[fid]):
                                    if wid == eos_id or fid == last_frame_id:
                                        s = scores[fid][_i]
                                        frame_id = fid
                                        pos_in_frame = _i
                                        if frame_id != -1 and s < 0:
                                            # walk the backpointers to recover
                                            # the full hypothesis
                                            seq = [wids_list[frame_id][pos_in_frame]]
                                            for _fid in range(frame_id, 0, -1):
                                                pos_in_frame = ptrs[_fid][pos_in_frame]
                                                seq.append(
                                                    wids_list[_fid - 1][pos_in_frame])
                                            seq.reverse()
                                            seqs.append([seq, s])
                            seqs = sorted(seqs, key=lambda x: x[1], reverse=True)
                            w_idss = [seq[0] for seq in seqs[:top_k]]
                            output_sequences = []
                            for w_ids in w_idss:
                                output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                                output_tokens = []
                                for t in output_buf:
                                    if t in ("[SEP]", "[PAD]"):
                                        break
                                    output_tokens.append(t)
                                output_sequence = ' '.join(detokenize(output_tokens))
                                output_sequences.append(output_sequence)
                            output_lines[buf_id[i]] = output_sequences
                            if args.need_score_traces:
                                score_trace_list[buf_id[i]] = {
                                    'scores': traces['scores'][i],
                                    'wids': traces['wids'][i],
                                    'ptrs': traces['ptrs'][i]}
                    pbar.update(1)

            # collect instances after split
            results = []
            if args.experiment == "single":
                for clu in map_dict:
                    record = []
                    clu_ixs = map_dict[clu]
                    for i in clu_ixs:
                        record.extend(output_lines[i])
                    record_top10 = Counter(record).most_common(10)
                    record_top10 = [x[0] for x in record_top10]
                    results.append(record_top10)
                output_lines = results

            if args.output_file:
                fn_out = args.output_file
            else:
                fn_out = model_recover_path + '.' + args.split
            with open(fn_out, "w", encoding="utf-8") as fout:
                for l in output_lines:
                    fout.write('\t'.join(l))
                    fout.write("\n")

            if args.need_score_traces:
                with open(fn_out + ".trace.pickle", "wb") as fout_trace:
                    pickle.dump({"version": 0.0,
                                 "num_samples": len(input_lines)}, fout_trace)
                    for x in score_trace_list:
                        pickle.dump(x, fout_trace)
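    # The evaluation below reports P@k, R@k, and F1@k averaged over examples.
    # Assuming acc_score/recall_score implement precision and recall over the
    # top-k lists (as their usage suggests) and f1_score is the harmonic mean,
    # a hypothetical example: if predict[:5] shares 2 items with a 4-item
    # label list, then P@5 = 2/5 = 0.4, R@5 = 2/4 = 0.5, and
    # F1 = 2*0.4*0.5 / (0.4+0.5) ~= 0.444.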
    # Evaluate!
    if args.do_evaluate:
        # note: assumes --do_predict ran in this invocation, since output_lines
        # is produced by the prediction loop above
        labels = []
        if not os.path.exists(args.label_file):
            raise ValueError("Label file does not exist")
        print("Loading label file from {}".format(args.label_file))
        with open(args.label_file) as f:
            for line in tqdm(f.readlines()):
                line = line.strip().split("\t")
                labels.append(line)
        results = output_lines
        ks = [1, 5, 10]
        results_dict = {}
        for k in ks:
            acc_cul = 0
            r_cul = 0
            f1_cul = 0
            cnt = 0
            for predict, true_label in zip(tqdm(results), tqdm(labels)):
                predict = predict[:k]
                true_label = true_label[:k]
                if len(predict) > 0 and len(true_label) > 0:
                    acc_cul += acc_score(predict, true_label)
                    r_cul += recall_score(predict, true_label)
                    f1_cul += f1_score(acc_score(predict, true_label),
                                       recall_score(predict, true_label))
                    cnt += 1
            results_dict["P@{}".format(k)] = acc_cul * 1.000 / cnt
            results_dict["R@{}".format(k)] = r_cul * 1.000 / cnt
            results_dict["F1@{}".format(k)] = f1_cul * 1.000 / cnt
        print(results_dict)
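# Example invocation of the ranker+decoder entry point above (illustrative
# only: the script and data-file names are placeholders; the flags are the
# ones defined in the parser above):
#
#   python decode_rank_seq2seq.py \
#       --bert_model bert-base-uncased --do_lower_case \
#       --do_predict --do_evaluate \
#       --model_recover_path ./checkpoints/model.bin \
#       --ranker_recover_path ./checkpoints/ranker.bin \
#       --experiment hierarchical \
#       --input_file ./data/test.src --label_file ./data/test.labels \
#       --beam_size 5 --top_k 10 --output_file ./result/test.hyp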