# Shared imports for the decoding/analysis scripts below. The project-local
# helpers (Indexer, seq2seq_loader, postprocess, detokenize, language_eval,
# batch_list_to_batch_tensors, tSNE_reduce, Cosine_Sim, sim_histgram,
# tSNE_reduce_visual, cnn) are assumed to come from this repo; the exact
# import paths for the commented lines are assumptions and may need adjusting.
import argparse
import codecs
import glob
import json
import logging
import os
import pickle
import random
from collections import OrderedDict, defaultdict

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import SequentialSampler
from tqdm import tqdm

# Assumed import paths (VLP-style codebase):
# from pytorch_pretrained_bert.tokenization import BertTokenizer
# from pytorch_pretrained_bert.modeling import BertForSeq2SeqDecoder
# from pytorch_transformers import XLMTokenizer

logger = logging.getLogger(__name__)


# Script 1: decode one word per region feature in each language, then
# round-trip translate each prediction (en->zh and zh->en).
def main():
    parser = argparse.ArgumentParser()
    # parser.add_argument('--dataset', default='txt', type=str, help='txt -> self-customized')
    # parser.add_argument('--src_lang', default='en', type=str, help='')
    # parser.add_argument('--tgt_lang', default='zh', type=str, help='')
    parser.add_argument('--max_len_en', default=25, type=int,
                        help='maximum length of English in **bilingual** corpus')
    parser.add_argument('--max_len_zh', default=25, type=int,
                        help='maximum length of Chinese in **bilingual** corpus')
    parser.add_argument("--src_file", default='./.pkl', type=str,
                        help="The input data file name.")
    # General
    parser.add_argument("--config_path", default=None, type=str,
                        help="Bert config file path.")
    parser.add_argument("--bert_model", default="bert-base-cased", type=str,
                        help="Bert pre-trained model selected in the list: "
                             "bert-base-cased, bert-large-cased.")
    parser.add_argument("--xml_vocab", type=str, default='./download_models/xml_vocab.json')
    parser.add_argument("--xml_merge", type=str, default='./download_models/xml_merges.txt')
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument('--max_position_embeddings', type=int, default=512,
                        help="max position embeddings")
    # For decoding
    # parser.add_argument('--fp16', action='store_true',
    #                     help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16")
    parser.add_argument('--seed', type=int, default=123, help="random seed for initialization")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size', type=int, default=4, help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1, help="Beam size for searching")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search")
    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--enable_butd', action='store_true',
                        help='set to take in region features')
    parser.add_argument('--output_dir', default='./result', type=str)
    # unused here
    parser.add_argument('--split', type=str, default='val')  # wmt?
    parser.add_argument('--len_vis_input', type=int, default=1,
                        help="The length of the visual token input (a single region)")

    with open('/data/private/chenyutong/dataset/concept_count/word_concept_count.pkl', 'rb') as f:
        word_fre = pickle.load(f)
    word_fre = defaultdict(int, word_fre)

    args = parser.parse_args()
    assert args.batch_size == 1, 'only support batch_size=1'
    args.max_tgt_length = max(args.max_len_en, args.max_len_zh)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer_en = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + '/.pretrained_model')
    if args.max_position_embeddings:
        tokenizer_en.max_len = args.max_position_embeddings
    # tokenizer_en = WhitespaceTokenizer() if args.tokenized_input else tokenizer_en
    tokenizer_zh = XLMTokenizer(args.xml_vocab, args.xml_merge)
    tokenizer_zh.tokenize = lambda x: tokenizer_zh._tokenize(x, lang='zh', bypass_tokenizer=False)
    with open(args.xml_vocab, 'r') as f:
        tokenizer_zh.vocab = json.load(f)
    indexer = Indexer([os.path.join(args.bert_model, 'vocab.txt'), args.xml_vocab])
    with open('full_vocab.json', 'w') as f:
        json.dump(indexer.ids_to_tokens, f)
    tokenizers = {'en': tokenizer_en, 'zh': tokenizer_zh}
    print('tokenizer created')

    assert '.pkl' in args.src_file
    with open(args.src_file, 'rb') as f:
        # list [pred_id, vocab, vis, pos, distribution]
        # dict {'vgid': {'en':, 'zh':, 'region_features': [img, conf, fea[i], pos[i], dist]}}
        src_data = pickle.load(f)

    amp_handle = None
    if args.amp:
        from apex import amp

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 12 if args.new_segment_ids else 12  # both branches are 12 by design
    mask_word_id, eos_word_ids = indexer(["[MASK]", "[SEP]"])
    forbid_ignore_set = None  # default None
    relax_projection, task_idx_proj = 0, 3
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(indexer(w_list))
    print(args.model_recover_path)

    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        # logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model,
            max_position_embeddings=args.max_position_embeddings,
            config_path=args.config_path,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            vocab_size=len(indexer),
            type_vocab_size=type_vocab_size,
            task_idx=3,
            mask_word_id=mask_word_id,  # img2txt
            search_beam_size=args.beam_size,
            length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set,
            ngram_size=args.ngram_size,
            min_len=args.min_len,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input)
        del model_recover
        model.to(device)
        if args.amp:
            model = amp.initialize(model, opt_level='O2')  # letter O, not zero
        torch.cuda.empty_cache()
        model.eval()

        fout = open(os.path.join(args.output_dir, 'region2txt_output.txt'), 'w')
        output_lines = []
        select_ids = [87, 120, 179, 297, 721, 852, 1025]
        for step_val, sd in enumerate(src_data.items()):
            # if step_val >= 1:
            #     break
            vgid, input_item = sd
            en, zh = input_item['en'], input_item['zh']
            fout.writelines('\n' + '#' * 10 + '\n')
            fout.writelines('{}\n'.format(vgid))
            fout.writelines('{} coco: word_fre {} vis_fre {} \n'.format(
                en, input_item['coco_fre']['word'], input_item['coco_fre']['vis']))
            fout.writelines('{} aic: word_fre {} vis_fre {} \n'.format(
                zh, input_item['aic_fre']['word'], input_item['aic_fre']['vis']))
            print('step_val {} Process {}'.format(step_val, en))

            for rf in tqdm(input_item['region_features']):
                filename, conf, vis_feats, vis_pe, cls_label = rf
                vis_feats = torch.from_numpy(vis_feats).to(device).unsqueeze(0)
                vis_pe = torch.from_numpy(vis_pe).to(device).unsqueeze(0)
                cls_label = torch.from_numpy(cls_label).to(device).unsqueeze(0)

                # lazy normalization of the coordinates (copied from seq2seq)
                w_est = torch.max(vis_pe[:, [0, 2]]) * 1. + 1e-5
                h_est = torch.max(vis_pe[:, [1, 3]]) * 1. + 1e-5
                vis_pe[:, [0, 2]] /= w_est
                vis_pe[:, [1, 3]] /= h_est
                assert h_est > 0, 'should be greater than 0! {}'.format(h_est)
                assert w_est > 0, 'should be greater than 0! {}'.format(w_est)
                rel_area = (vis_pe[:, 3] - vis_pe[:, 1]) * (vis_pe[:, 2] - vis_pe[:, 0])
                rel_area.clamp_(0)
                # box coords + relative area + confidence score
                vis_pe = torch.cat((vis_pe[:, :4], rel_area.view(-1, 1), vis_pe[:, 5:]), -1)
                normalized_coord = F.normalize(vis_pe.data[:, :5] - 0.5, dim=-1)  # unused
                vis_pe = torch.cat((F.layer_norm(vis_pe, [6]),
                                    F.layer_norm(cls_label, [1601])), dim=-1)  # 1601 hard-coded

                # B, L, H
                vis_feats = vis_feats.unsqueeze(0)
                vis_pe = vis_pe.unsqueeze(0)
                # print('input shape', vis_feats.shape, vis_pe.shape)

                assert args.new_segment_ids == False, 'only support 0 1 6 now'
                tokens = ['[CLS]', '[UNK]', '[SEP]']
                input_ids = indexer(tokens)
                input_ids = np.expand_dims(np.array(input_ids), axis=0)
                input_ids = torch.tensor(input_ids, dtype=torch.long, device=device)
                max_len_in_batch = len(tokens) + args.max_tgt_length
                _tril_matrix = torch.tril(
                    torch.ones((max_len_in_batch, max_len_in_batch), dtype=torch.long))
                input_mask = torch.zeros(max_len_in_batch, max_len_in_batch,
                                         dtype=torch.long, device=device)
                input_mask[:, :len(tokens)].fill_(1)
                second_st, second_end = len(tokens), max_len_in_batch
                input_mask[second_st:second_end, second_st:second_end].copy_(
                    _tril_matrix[:second_end - second_st, :second_end - second_st])  # L, L
                input_mask = input_mask.unsqueeze(0)
                position_ids = torch.arange(max_len_in_batch, dtype=torch.long, device=device)  # L
                position_ids = position_ids.unsqueeze(0)  # B, L

                predictions = {'en': None, 'zh': None, 'en2zh': None, 'zh2en': None}
                # decode the region into a word in each language
                for tgt_lang, lang_id in zip(['en', 'zh'], [1, 6]):
                    token_type_ids = [0] * len(tokens) + [lang_id] * args.max_tgt_length
                    token_type_ids = np.expand_dims(np.array(token_type_ids), axis=0)
                    token_type_ids = torch.tensor(token_type_ids, dtype=torch.long, device=device)
                    with torch.no_grad():
                        traces = model(
                            vis_feats=vis_feats, vis_pe=vis_pe, input_ids=input_ids,
                            token_type_ids=token_type_ids, position_ids=position_ids,
                            attention_mask=input_mask, search_beam_size=args.beam_size,
                            task_idx=3, mode='img2txt', sample_mode='greedy')  # validation: greedy
                    output_sequence = postprocess(traces, args.beam_size, tgt_lang, indexer)
                    predictions[tgt_lang] = output_sequence  # truncate

                # round-trip translate each prediction into the other language
                for langs, lang_ids in zip(['en2zh', 'zh2en'], [[1, 6], [6, 1]]):
                    src_lang = langs[:2]   # 'en' or 'zh'
                    tgt_lang = langs[-2:]
                    w = predictions[src_lang]  # predictions['en'] / predictions['zh']
                    w_t = tokenizers[src_lang].tokenize(w)
                    tokens = ['[CLS]'] + w_t + ['[SEP]']
                    input_ids = indexer(tokens)
                    token_type_ids = [lang_ids[0]] * len(input_ids) + \
                                     [lang_ids[1]] * args.max_tgt_length
                    input_ids = np.expand_dims(np.array(input_ids), axis=0)
                    token_type_ids = np.expand_dims(np.array(token_type_ids), axis=0)
                    input_ids = torch.tensor(input_ids, dtype=torch.long, device=device)
                    token_type_ids = torch.tensor(token_type_ids, dtype=torch.long, device=device)
                    max_len_in_batch = len(tokens) + args.max_tgt_length  # 2+64 = 66
                    position_ids = torch.arange(max_len_in_batch, dtype=torch.long, device=device)  # L
                    position_ids = position_ids.unsqueeze(0)  # B, L
                    _tril_matrix = torch.tril(
                        torch.ones((max_len_in_batch, max_len_in_batch), dtype=torch.long))
                    input_mask = torch.zeros(max_len_in_batch, max_len_in_batch,
                                             dtype=torch.long, device=device)
                    input_mask[:, :len(tokens)].fill_(1)
                    second_st, second_end = len(tokens), max_len_in_batch
                    input_mask[second_st:second_end, second_st:second_end].copy_(
                        _tril_matrix[:second_end - second_st, :second_end - second_st])  # L, L
                    input_mask = input_mask.unsqueeze(0)
                    with torch.no_grad():
                        traces = model(
                            vis_feats=None, vis_pe=None, input_ids=input_ids,
                            token_type_ids=token_type_ids, position_ids=position_ids,
                            attention_mask=input_mask, search_beam_size=args.beam_size,
                            task_idx=3, mode='txt2txt', sample_mode='greedy')  # validation: greedy
                    output_sequence = postprocess(traces, args.beam_size, tgt_lang, indexer)
                    predictions[langs] = output_sequence
                # print(predictions)
                fout.writelines(
                    'conf:{:.2f} en:{: <10} fre:{:<5d} en2zh:{: <10} '
                    'zh:{: <10} fre:{:<5d} zh2en:{: <10} \n'.format(
                        conf, predictions['en'], word_fre['coco'][predictions['en']],
                        predictions['en2zh'], predictions['zh'],
                        word_fre['aic'][predictions['zh']], predictions['zh2en']))
        fout.close()
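# The prefix-causal attention mask above is built inline twice (once for
# img2txt, once for txt2txt). Below is a minimal standalone sketch of the
# same pattern; the helper name is illustrative, not from the repo:
def build_s2s_attention_mask(src_len, max_tgt_len):
    """Prefix tokens are visible to every position; target positions
    additionally attend causally (lower-triangular) to earlier targets."""
    total = src_len + max_tgt_len
    mask = torch.zeros(total, total, dtype=torch.long)
    mask[:, :src_len].fill_(1)  # every position sees the full prefix
    tril = torch.tril(torch.ones(max_tgt_len, max_tgt_len, dtype=torch.long))
    mask[src_len:, src_len:].copy_(tril)  # causal block for the target span
    return mask.unsqueeze(0)  # (1, L, L), matching input_mask above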
# Script 2: batch caption decoding on the coco/aic image-captioning sets,
# followed by COCO-style evaluation.
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='coco', type=str, nargs='*', help='')
    parser.add_argument('--lang', default='en zh', type=str, nargs='*', help='')
    parser.add_argument('--max_len_en', default=25, type=int,
                        help='maximum length of English in **bilingual** corpus')
    parser.add_argument('--max_len_zh', default=25, type=int,
                        help='maximum length of Chinese in **bilingual** corpus')
    parser.add_argument('--max_len_en_cap', default=25, type=int,
                        help='maximum length of English in **img2txt** corpus')
    parser.add_argument('--max_len_zh_cap', default=25, type=int,
                        help='maximum length of Chinese in **img2txt** corpus')
    parser.add_argument('--len_vis_input', type=int, default=100,
                        help="The length of visual token input")
    parser.add_argument("--src_file", default='$DATA_ROOT/{}/annotations/{}_dataset.json',
                        type=str, help="The input data file name.")
    parser.add_argument('--split', type=str, default='val')
    parser.add_argument('--file_valid_jpgs',
                        default='$DATA_ROOT/{}/annotations/{}_valid_jpgs.json', type=str)
    parser.add_argument('--image_root', type=str,
                        default='$DATA_ROOT/{}/region_feat_gvd_wo_bgd')
    parser.add_argument('--region_bbox_file',
                        default='raw_bbox/{}_detection_vg_100dets_vlp_checkpoint_trainval_bbox',
                        type=str)
    parser.add_argument('--region_det_file_prefix',
                        default='feat_cls_1000/{}_detection_vg_100dets_vlp_checkpoint_trainval',
                        type=str)
    # General
    parser.add_argument("--config_path", default=None, type=str,
                        help="Bert config file path.")
    parser.add_argument("--bert_model", default="bert-base-cased", type=str,
                        help="Bert pre-trained model selected in the list: "
                             "bert-base-cased, bert-large-cased.")
    parser.add_argument("--xml_vocab", type=str, default='./download_models/xml_vocab.json')
    parser.add_argument("--xml_merge", type=str, default='./download_models/xml_merges.txt')
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument('--max_position_embeddings', type=int, default=512,
                        help="max position embeddings")
    # For decoding
    # parser.add_argument('--fp16', action='store_true',
    #                     help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16")
    parser.add_argument('--seed', type=int, default=123, help="random seed for initialization")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size', type=int, default=4, help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1, help="Beam size for searching")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search")
    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--max_tgt_length', type=int, default=20,
                        help="maximum length of target sequence")
    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--enable_butd', action='store_true',
                        help='set to take in region features')
    parser.add_argument('--output_dir', default='./result', type=str)
    args = parser.parse_args()

    dataset = {}
    for d in args.dataset:
        assert d in ['coco', 'aic', 'wmt']
        if d == 'coco':
            dataset[d] = {'max_len_a': args.len_vis_input, 'max_len_b': args.max_len_en_cap}
        elif d == 'aic':
            dataset[d] = {'max_len_a': args.len_vis_input, 'max_len_b': args.max_len_zh_cap}
        else:  # d == 'wmt'
            dataset[d] = {'max_len_a': args.max_len_en, 'max_len_b': args.max_len_zh}
        dataset[d]['max_seq_length'] = dataset[d]['max_len_a'] + dataset[d]['max_len_b'] + 3
    args.dataset = dataset
    lang2cap_max_seq_length = {
        'zh': args.len_vis_input + args.max_len_zh_cap + 3,
        'en': args.len_vis_input + args.max_len_en_cap + 3}

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    if args.enable_butd:
        assert args.len_vis_input == 100
        args.region_bbox_file = os.path.join(args.image_root, args.region_bbox_file)
        args.region_det_file_prefix = os.path.join(args.image_root, args.region_det_file_prefix)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer_en = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + '/.pretrained_model')
    if args.max_position_embeddings:
        tokenizer_en.max_len = args.max_position_embeddings
    # tokenizer_en = WhitespaceTokenizer() if args.tokenized_input else tokenizer_en
    tokenizers = {'en': tokenizer_en}
    if 'aic' in args.dataset or 'wmt' in args.dataset:
        tokenizer_zh = XLMTokenizer(args.xml_vocab, args.xml_merge)
        tokenizer_zh.tokenize = lambda x: tokenizer_zh._tokenize(x, lang='zh', bypass_tokenizer=True)
        with open(args.xml_vocab, 'r') as f:
            tokenizer_zh.vocab = json.load(f)
        tokenizers['zh'] = tokenizer_zh
    if 'coco_g8_lr3e-5_batch512_ft_from_s0.75_b0.25' in args.model_recover_path:
        indexer = Indexer([os.path.join(args.bert_model, 'vocab.txt')])
    else:
        indexer = Indexer([os.path.join(args.bert_model, 'vocab.txt'), args.xml_vocab])

    for corpus in args.dataset:
        if corpus in ['coco', 'aic']:  # bilingual
            for lang in args.lang:
                tokenizer = tokenizers[lang]  # tokenizers['en'] if corpus=='coco' else tokenizers['zh']
                max_seq_length = lang2cap_max_seq_length[lang]
                decode_pipeline = [seq2seq_loader.Preprocess4Seq2seqDecoder(
                    corpus, lang, list(tokenizer.vocab.keys()), indexer,
                    max_len=max_seq_length,
                    max_tgt_length=args.max_tgt_length,
                    new_segment_ids=args.new_segment_ids,
                    mode='s2s',
                    len_vis_input=args.len_vis_input,
                    enable_butd=args.enable_butd,
                    region_bbox_file=args.region_bbox_file.format(corpus.upper(), corpus.lower()),
                    region_det_file_prefix=args.region_det_file_prefix.format(
                        corpus.upper(), corpus.lower()))]
                eval_dataset = seq2seq_loader.Img2txtDataset(
                    args.src_file.format(corpus.upper(), corpus.lower()),
                    args.image_root.format(corpus.upper()),
                    args.split, args.batch_size, tokenizer, max_seq_length,
                    preprocessed=True,
                    file_valid_jpgs=args.file_valid_jpgs.format(corpus.upper(), corpus.lower()),
                    bi_uni_pipeline=decode_pipeline,
                    use_num_imgs=-1,
                    s2s_prob=1, bi_prob=0,
                    enable_butd=args.enable_butd, tasks='img2txt')
                args.dataset[corpus][lang + '_eval_dataloader'] = torch.utils.data.DataLoader(
                    eval_dataset, batch_size=args.batch_size,
                    sampler=SequentialSampler(eval_dataset), num_workers=4,
                    collate_fn=batch_list_to_batch_tensors, pin_memory=True)
        else:
            raise NotImplementedError  # only support aic and coco now

    amp_handle = None
    if args.amp:
        from apex import amp
        # amp_handle = amp.init(enable_caching=True)
        # logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    # type_vocab_size = 6 if args.new_segment_ids else 2
    type_vocab_size = 12 if args.new_segment_ids else 12  # both branches are 12 by design
    mask_word_id, eos_word_ids = indexer(["[MASK]", "[SEP]"])
    forbid_ignore_set = None  # default None
    relax_projection, task_idx_proj = 0, 3
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(indexer(w_list))
    print(args.model_recover_path)

    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model,
            max_position_embeddings=args.max_position_embeddings,
            config_path=args.config_path,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            vocab_size=len(indexer),
            type_vocab_size=type_vocab_size,
            task_idx=3,
            mask_word_id=mask_word_id,  # img2txt
            search_beam_size=args.beam_size,
            length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set,
            ngram_size=args.ngram_size,
            min_len=args.min_len,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input)
        del model_recover
        model.to(device)
        if args.amp:
            model = amp.initialize(model, opt_level='O2')  # letter O, not zero
        torch.cuda.empty_cache()
        model.eval()

        for corpus in args.dataset:
            for lang in ['en', 'zh']:
                if not lang + '_eval_dataloader' in args.dataset[corpus]:
                    continue
                print('corpus {} lang {}'.format(corpus, lang))
                output_lines = {}
                val_iter_bar = tqdm(args.dataset[corpus][lang + '_eval_dataloader'])
                for step_val, val_iter_output in enumerate(val_iter_bar):
                    info_, batch = val_iter_output[0], val_iter_output[1]
                    with torch.no_grad():
                        batch = [t.to(device) for t in batch]
                        input_ids, segment_ids, position_ids, input_mask, task_idx, img, vis_pe = batch
                        # if step_val == 0:
                        #     print(segment_ids[0][100:])
                        #     input()
                        if args.enable_butd:
                            conv_feats = img.data  # Bx100x2048
                            vis_pe = vis_pe.data
                        else:
                            # `cnn` is assumed to be defined elsewhere in the repo
                            conv_feats, _ = cnn(img.data)  # Bx2048x7x7
                            conv_feats = conv_feats.view(
                                conv_feats.size(0), conv_feats.size(1),
                                -1).permute(0, 2, 1).contiguous()
                        if args.amp:
                            conv_feats = conv_feats.half()
                            vis_pe = vis_pe.half()
                        traces = model(conv_feats, vis_pe, input_ids, segment_ids,
                                       position_ids, input_mask,
                                       search_beam_size=args.beam_size,
                                       task_idx=task_idx, sample_mode='greedy')  # validation: greedy
                        if args.beam_size > 1:
                            traces = {k: v.tolist() for k, v in traces.items()}
                            output_ids = traces['pred_seq']
                        else:
                            output_ids = traces[0].tolist()
                    for ii, w_ids in enumerate(output_ids):
                        output_buf = indexer.convert_ids_to_tokens(w_ids)
                        output_tokens = []
                        for t in output_buf:
                            if t in ("[SEP]", "[PAD]"):
                                break
                            output_tokens.append(t)
                        output_sequence = ' '.join(detokenize(output_tokens))
                        if corpus == 'coco':
                            id_ = int(info_[ii][2].split('_')[2])
                        else:
                            id_ = info_[ii][2]
                        # print(id_, output_sequence)
                        output_lines[id_] = output_sequence

                predictions = [{'image_id': ids_, 'caption': output_lines[ids_]}
                               for ids_ in output_lines]
                with open(os.path.join(args.output_dir, '{}_{}_{}_predictions.json').format(
                        args.split, corpus, lang), 'w') as f:
                    json.dump(predictions, f)
                if (corpus == 'coco' and lang == 'en') or (corpus == 'aic' and lang == 'zh'):
                    print('Begin evaluating ' + corpus)
                    lang_stats = language_eval(
                        corpus, predictions,
                        args.model_recover_path.split('/')[-2] + '-' + args.split + '-' +
                        args.model_recover_path.split('/')[-1].split('.')[-2],
                        args.split, ['Bleu', 'METEOR', 'Rouge', 'CIDEr'])
                    with open(os.path.join(args.output_dir, '{}_{}_{}_scores.json').format(
                            args.split, corpus, lang), 'w') as f:
                        json.dump(lang_stats, f)
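# Example invocation of the caption-decoding script above (illustrative:
# the script name, checkpoint path, and $DATA_ROOT are placeholders, not
# values from the repo):
#   python decode_img2txt.py --dataset coco aic --lang en zh \
#       --enable_butd --len_vis_input 100 --split val --beam_size 1 \
#       --batch_size 4 --model_recover_path ./checkpoints/model.bin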
# Script 3: text-to-text translation decoding (en<->zh).
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='txt', type=str, help='txt -> self-customized')
    parser.add_argument('--src_lang', default='en', type=str, help='')
    parser.add_argument('--tgt_lang', default='zh', type=str, help='')
    parser.add_argument('--max_len_en', default=25, type=int,
                        help='maximum length of English in **bilingual** corpus')
    parser.add_argument('--max_len_zh', default=25, type=int,
                        help='maximum length of Chinese in **bilingual** corpus')
    parser.add_argument("--src_file", default='./src.txt', type=str,
                        help="The input data file name.")
    parser.add_argument('--corpus', default='txt', type=str)
    parser.add_argument('--en_first', action='store_true',
                        help='always put English as the first sentence')
    # General
    parser.add_argument("--config_path", default=None, type=str,
                        help="Bert config file path.")
    parser.add_argument("--bert_model", default="bert-base-cased", type=str,
                        help="Bert pre-trained model selected in the list: "
                             "bert-base-cased, bert-large-cased.")
    parser.add_argument("--xml_vocab", type=str, default='./download_models/xml_vocab.json')
    parser.add_argument("--xml_merge", type=str, default='./download_models/xml_merges.txt')
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument('--max_position_embeddings', type=int, default=512,
                        help="max position embeddings")
    # For decoding
    # parser.add_argument('--fp16', action='store_true',
    #                     help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16")
    parser.add_argument('--seed', type=int, default=123, help="random seed for initialization")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size', type=int, default=4, help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1, help="Beam size for searching")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search")
    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--enable_butd', action='store_true',
                        help='set to take in region features')
    parser.add_argument('--output_dir', default='./result', type=str)
    # unused here
    parser.add_argument('--split', type=str, default='val')  # wmt?
    parser.add_argument('--len_vis_input', type=int, default=100,
                        help="The length of visual token input")

    args = parser.parse_args()
    args.tgt_lang = 'en' if args.src_lang == 'zh' else 'zh'
    # print(sys.getfilesystemencoding())
    # print('这是中文')  # "This is Chinese" -- encoding sanity check
    assert args.batch_size == 1, 'only support batch_size=1'
    args.max_tgt_length = args.max_len_en if args.src_lang == 'zh' else args.max_len_zh

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer_en = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + '/.pretrained_model')
    if args.max_position_embeddings:
        tokenizer_en.max_len = args.max_position_embeddings
    # tokenizer_en = WhitespaceTokenizer() if args.tokenized_input else tokenizer_en
    tokenizer_zh = XLMTokenizer(args.xml_vocab, args.xml_merge)
    tokenizer_zh.tokenize = lambda x: tokenizer_zh._tokenize(x, lang='zh', bypass_tokenizer=False)
    with open(args.xml_vocab, 'r') as f:
        tokenizer_zh.vocab = json.load(f)
    indexer = Indexer([os.path.join(args.bert_model, 'vocab.txt'), args.xml_vocab])
    with open('full_vocab.json', 'w') as f:
        json.dump(indexer.ids_to_tokens, f)
    tokenizers = {'en': tokenizer_en, 'zh': tokenizer_zh}
    print('tokenizer created')

    if '.txt' in args.src_file:
        with codecs.open(args.src_file, 'r') as f:
            src_lines = f.readlines()
        src_lines = [line.strip() for line in src_lines]
        N_lines = len(src_lines)
    elif 'hdf5' in args.src_file:
        assert 'wmt' == args.corpus
        src_lines = args.src_file
        N_lines = 1999
    else:
        raise NotImplementedError('unsupported src_file format: {}'.format(args.src_file))

    pipeline = seq2seq_loader.Preprocess4Seq2SeqBilingualDecoder(
        corpus=args.corpus,
        file_src=src_lines,
        src_lang=args.src_lang,
        indexer=indexer,
        tokenizers=tokenizers,
        max_len=args.max_len_en + args.max_len_zh + 3,
        max_tgt_length=args.max_tgt_length,
        preprocessed=False if args.corpus == 'txt' else True,
        new_segment_ids=args.new_segment_ids,
        mode='s2s')
    eval_dataset = seq2seq_loader.Txt2txtDataset(
        N_lines=N_lines,
        split=None,
        batch_size=args.batch_size,
        tokenizers=tokenizers,
        max_len=args.max_len_en + args.max_len_zh + 3,
        preprocessed=False if args.corpus == 'txt' else True,
        bi_uni_pipeline=[pipeline],
        s2s_prob=1, bi_prob=0)
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset, batch_size=args.batch_size,
        sampler=SequentialSampler(eval_dataset), num_workers=4,
        collate_fn=batch_list_to_batch_tensors, pin_memory=True)

    amp_handle = None
    if args.amp:
        from apex import amp

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 12 if args.new_segment_ids else 12  # both branches are 12 by design
    mask_word_id, eos_word_ids = indexer(["[MASK]", "[SEP]"])
    forbid_ignore_set = None  # default None
    relax_projection, task_idx_proj = 0, 3
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(indexer(w_list))
    print(args.model_recover_path)

    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        # logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model,
            max_position_embeddings=args.max_position_embeddings,
            config_path=args.config_path,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            vocab_size=len(indexer),
            type_vocab_size=type_vocab_size,
            task_idx=3,
            mask_word_id=mask_word_id,  # img2txt
            search_beam_size=args.beam_size,
            length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set,
            ngram_size=args.ngram_size,
            min_len=args.min_len,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input)
        del model_recover
        model.to(device)
        if args.amp:
            model = amp.initialize(model, opt_level='O2')  # letter O, not zero
        torch.cuda.empty_cache()
        model.eval()

        val_iter_bar = tqdm(eval_dataloader)
        output_lines = []
        for step_val, val_iter_output in enumerate(val_iter_bar):
            info_, batch = val_iter_output[0], val_iter_output[1]
            with torch.no_grad():
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, position_ids, input_mask, task_idx = batch
                traces = model(None, None, input_ids, segment_ids, position_ids, input_mask,
                               search_beam_size=args.beam_size, task_idx=task_idx,
                               mode='txt2txt', sample_mode='greedy')  # validation: greedy
                # if step_val == 0:
                #     print(segment_ids[0])
                #     print(input_ids[0])
                #     print(position_ids[0])
                #     input()
                if args.beam_size > 1:
                    traces = {k: v.tolist() for k, v in traces.items()}
                    output_ids = traces['pred_seq']
                else:
                    output_ids = traces[0].tolist()
            for ii, w_ids in enumerate(output_ids):
                output_buf = indexer.convert_ids_to_tokens(w_ids)
                output_tokens = []
                for t in output_buf:
                    if t in ("[SEP]", "[PAD]"):
                        break
                    output_tokens.append(t)
                if args.tgt_lang == 'en':
                    output_sequence = ' '.join(detokenize(output_tokens))
                    output_sequence = output_sequence.replace(' @ - @ ', '-')
                # id_ = step_val * args.batch_size + ii
                # output_sequence = output_sequence.replace('</w>', ' ').replace(' ', '')
                if args.tgt_lang == 'zh':
                    output_sequence = ''.join(detokenize(output_tokens)) \
                        .replace('</w>', '').replace('[SEP]', '')
                output_lines.append(output_sequence)

        # with open(os.path.join(args.output_dir, 'translation_output.json'), 'w') as f:
        #     json.dump(output_lines, f)
        with open(os.path.join(args.output_dir, 'translation_output.txt'), 'w') as f:
            for line in output_lines:
                f.writelines(line + '\n')
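# The two detokenization branches above differ only in how subwords are
# joined back into a sentence. A compact sketch of the same logic (the
# helper name is illustrative; detokenize is the repo's own function):
def join_subwords(output_tokens, tgt_lang):
    if tgt_lang == 'en':
        # WordPiece output: space-join, then undo the ' @ - @ ' hyphen escaping
        return ' '.join(detokenize(output_tokens)).replace(' @ - @ ', '-')
    # XLM-BPE output for Chinese: concatenate and strip end-of-word markers
    return ''.join(detokenize(output_tokens)).replace('</w>', '').replace('[SEP]', '')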
# Script 4: extract per-layer hidden states for en/zh concept words (and,
# optionally, their region features) and visualize the embedding spaces.
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--enable_vis', action='store_true',
                        help='whether to visualize visual embedding')
    parser.add_argument('--num_concept', default='-1', type=int,
                        help='number of concepts to visualize')
    parser.add_argument('--num_vis', default='1', type=int,
                        help='number of visual embeddings per concept')
    parser.add_argument('--dataset', default='txt', type=str, help='txt -> self-customized')
    parser.add_argument('--src_lang', default='en', type=str, help='')
    parser.add_argument('--tgt_lang', default='zh', type=str, help='')
    parser.add_argument('--max_len_en', default=25, type=int,
                        help='maximum length of English in **bilingual** corpus')
    parser.add_argument('--max_len_zh', default=25, type=int,
                        help='maximum length of Chinese in **bilingual** corpus')
    # parser.add_argument("--vocab_file", default='./src.txt', type=str,
    #                     help="The input data file name.")
    parser.add_argument('--vocab_file', type=str, required=True, nargs='+')
    parser.add_argument('--en_first', action='store_true',
                        help='always put English as the first sentence')
    # General
    parser.add_argument("--config_path", default=None, type=str,
                        help="Bert config file path.")
    parser.add_argument("--bert_model", default="bert-base-cased", type=str,
                        help="Bert pre-trained model selected in the list: "
                             "bert-base-cased, bert-large-cased.")
    parser.add_argument("--xml_vocab", type=str, default='./download_models/xml_vocab.json')
    parser.add_argument("--xml_merge", type=str, default='./download_models/xml_merges.txt')
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument('--max_position_embeddings', type=int, default=512,
                        help="max position embeddings")
    # For decoding
    # parser.add_argument('--fp16', action='store_true',
    #                     help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp', action='store_true', help="Whether to use amp for fp16")
    parser.add_argument('--seed', type=int, default=123, help="random seed for initialization")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size', type=int, default=4, help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1, help="Beam size for searching")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search")
    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--enable_butd', action='store_true',
                        help='set to take in region features')
    parser.add_argument('--output_dir', default='./result', type=str)
    # unused here
    parser.add_argument('--split', type=str, default='val')  # wmt?
    parser.add_argument('--len_vis_input', type=int, default=100,
                        help="The length of visual token input")

    args = parser.parse_args()
    print(args.vocab_file)
    if args.enable_vis or '.pkl' in args.vocab_file[0]:
        args.vocab_file = args.vocab_file[0]
        assert '.pkl' in args.vocab_file, args.vocab_file

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer_en = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + '/.pretrained_model')
    if args.max_position_embeddings:
        tokenizer_en.max_len = args.max_position_embeddings
    tokenizer_zh = XLMTokenizer(args.xml_vocab, args.xml_merge)
    tokenizer_zh.tokenize = lambda x: tokenizer_zh._tokenize(x, lang='zh', bypass_tokenizer=False)
    with open(args.xml_vocab, 'r') as f:
        tokenizer_zh.vocab = json.load(f)
    indexer = Indexer([os.path.join(args.bert_model, 'vocab.txt'), args.xml_vocab])
    with open('full_vocab.json', 'w') as f:
        json.dump(indexer.ids_to_tokens, f)
    tokenizers = {'en': tokenizer_en, 'zh': tokenizer_zh}
    print('tokenizer created')

    concept_list = []
    if args.enable_vis or 'pkl' in args.vocab_file:
        with open(args.vocab_file, 'rb') as f:
            vocab_list = pickle.load(f)
        vocab_list = vocab_list[::-1]
        if args.num_concept == -1:
            args.num_concept = len(vocab_list)
        vocab = []  # [[En, Zh, Vis1(list), Vis2, Vis3, Vis4, ...] * num_concept]
        for key, ls in vocab_list[40:40 + args.num_concept]:
            concept = [ls[0][0], ls[0][1]]  # En, Zh
            concept_list.append(ls[0][0])
            for inst in ls[:args.num_vis]:
                concept.append(inst[2:])  # fields 2, 3, 4
            vocab.append(concept)
        print('Number of Concept {}'.format(len(vocab)))
        if args.num_concept != -1:
            print(concept_list)
        print('Number of visual instance per concept {}'.format(len(vocab[0]) - 2))
        # print('Example {}'.format(vocab[0]))
    else:
        vocab = []
        for filename in args.vocab_file:
            with open(filename) as f:
                v = f.readlines()
            if args.num_concept == -1:
                args.num_concept = len(v)
            v = [(a.split(' ')[0].strip(), a.split(' ')[-1].strip())
                 for a in v[:args.num_concept]]  # EN ZH
            vocab.extend(v)
        print('Number of vocabulary {} {}'.format(len(vocab), vocab[0]))

    cls_num_labels = 2
    type_vocab_size = 12 if args.new_segment_ids else 12  # both branches are 12 by design
    mask_word_id, eos_word_ids = indexer(["[MASK]", "[SEP]"])
    forbid_ignore_set = None  # default None
    relax_projection, task_idx_proj = 0, 3
    print(args.model_recover_path)
    model_recover = torch.load(args.model_recover_path)
    model = BertForSeq2SeqDecoder.from_pretrained(
        args.bert_model,
        max_position_embeddings=args.max_position_embeddings,
        config_path=args.config_path,
        state_dict=model_recover,
        num_labels=2,
        vocab_size=len(indexer),
        type_vocab_size=type_vocab_size,
        task_idx=3,
        mask_word_id=mask_word_id,  # img2txt
        search_beam_size=args.beam_size,
        length_penalty=args.length_penalty,
        eos_id=eos_word_ids,
        forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
        forbid_ignore_set=forbid_ignore_set,
        ngram_size=args.ngram_size,
        min_len=args.min_len,
        enable_butd=True,
        len_vis_input=args.len_vis_input)
    del model_recover
    model.to(device)
    torch.cuda.empty_cache()
    model.eval()

    N_layer = 12
    embeddings = OrderedDict({
        'en': [[] for i in range(N_layer)],
        'zh': [[] for i in range(N_layer)]})
    if args.enable_vis:
        embeddings['vis'] = [[] for i in range(N_layer)]

    for _, pair in tqdm(enumerate(vocab)):
        # text embeddings for the English and Chinese word of the concept
        for w, lang in zip(pair, ('en', 'zh')):
            segment_id = 1 if lang == 'en' else 6
            w_t = tokenizers[lang].tokenize(w)
            tokens = ['[CLS]'] + w_t + ['[SEP]']
            input_ids = indexer(tokens)
            token_type_ids = [segment_id] * len(input_ids)
            input_ids = np.expand_dims(np.array(input_ids), axis=0)
            token_type_ids = np.expand_dims(np.array(token_type_ids), axis=0)
            input_ids = torch.tensor(input_ids, dtype=torch.long, device=device)
            token_type_ids = torch.tensor(token_type_ids, dtype=torch.long, device=device)
            output_embeddings = model.compute_embeddings(
                input_ids=input_ids, token_type_ids=token_type_ids,
                mode='txt2txt')  # tuple of 12 tensors, each (1, L, 768)
            for i, e in enumerate(output_embeddings):
                e = e.detach().cpu().numpy()
                ave = np.mean(e[0, 1:-1, :], axis=0)  # average over subwords -> (768,)
                embeddings[lang][i].append(ave)

        if args.enable_vis:
            instance_embeddings = [[] for layer_i in range(N_layer)]
            for vis_embed in pair[2:]:
                vis_feats, vis_pe, cls_label = vis_embed[0], vis_embed[1], vis_embed[2]  # 1024
                vis_feats = torch.from_numpy(vis_feats).to(device).unsqueeze(0)
                vis_pe = torch.from_numpy(vis_pe).to(device).unsqueeze(0)
                cls_label = torch.from_numpy(cls_label).to(device).unsqueeze(0)

                # lazy normalization of the coordinates (copied from seq2seq)
                w_est = torch.max(vis_pe[:, [0, 2]]) * 1. + 1e-5
                h_est = torch.max(vis_pe[:, [1, 3]]) * 1. + 1e-5
                vis_pe[:, [0, 2]] /= w_est
                vis_pe[:, [1, 3]] /= h_est
                assert h_est > 0, 'should be greater than 0! {}'.format(h_est)
                assert w_est > 0, 'should be greater than 0! {}'.format(w_est)
                rel_area = (vis_pe[:, 3] - vis_pe[:, 1]) * (vis_pe[:, 2] - vis_pe[:, 0])
                rel_area.clamp_(0)
                # box coords + relative area + confidence score
                vis_pe = torch.cat((vis_pe[:, :4], rel_area.view(-1, 1), vis_pe[:, 5:]), -1)
                normalized_coord = F.normalize(vis_pe.data[:, :5] - 0.5, dim=-1)  # unused
                vis_pe = torch.cat((F.layer_norm(vis_pe, [6]),
                                    F.layer_norm(cls_label, [1601])), dim=-1)  # 1601 hard-coded

                # B, L, H
                vis_feats = vis_feats.unsqueeze(0)
                vis_pe = vis_pe.unsqueeze(0)
                # print('input shape', vis_feats.shape, vis_pe.shape)

                segment_id = 0
                tokens = ['[CLS]', '[UNK]', '[SEP]']
                input_ids = indexer(tokens)
                token_type_ids = [segment_id] * len(input_ids)
                input_ids = np.expand_dims(np.array(input_ids), axis=0)
                token_type_ids = np.expand_dims(np.array(token_type_ids), axis=0)
                input_ids = torch.tensor(input_ids, dtype=torch.long, device=device)
                token_type_ids = torch.tensor(token_type_ids, dtype=torch.long, device=device)
                vis_embeddings = model.compute_embeddings(
                    vis_feats=vis_feats, vis_pe=vis_pe,
                    input_ids=input_ids, token_type_ids=token_type_ids,
                    mode='img2txt', len_vis_input=1)
                # print(len(vis_embeddings), vis_embeddings[0].shape)
                for i, e in enumerate(vis_embeddings):
                    e = e.detach().cpu().numpy()
                    ave = np.mean(e[0, 1:-1, :], axis=0)  # (768,)
                    instance_embeddings[i].append(ave)
                    # embeddings['vis'][i].append(ave)

            for i, embed_list in enumerate(instance_embeddings):
                if args.num_vis == 1:
                    embeddings['vis'][i].append(embed_list[0])
                else:
                    embeddings['vis'][i].append(embed_list)  # list of arrays

    # args.output_dir = os.path.join(args.output_dir, 'embedding_vis')
    # if not os.path.exists(args.output_dir):
    #     os.makedirs(args.output_dir)
    for ly in range(N_layer):
        embed = {'en': embeddings['en'][ly], 'zh': embeddings['zh'][ly]}
        if args.enable_vis:
            embed['vis'] = embeddings['vis'][ly]
        # save_numpy(embed, os.path.join(args.output_dir, 'hiddenstates_layer_{}.npy'.format(ly)))
        if args.num_vis == 1:
            tSNE_reduce(embed, os.path.join(args.output_dir, 'tSNE layer {}'.format(ly)))
            Cosine_Sim(embed, os.path.join(args.output_dir, 'CosineSim Layer {}'.format(ly)))
            sim_histgram(embed, os.path.join(args.output_dir, 'Histgram Layer {}'.format(ly)))
        if args.num_vis > 1:
            tSNE_reduce_visual(concept_list, embed,
                               os.path.join(args.output_dir, 'tSNE layer {}'.format(ly)))
        print('Save layer {}'.format(ly))
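# Example invocations of the embedding-visualization script above
# (illustrative; the script and data file names are placeholders):
#   # text-only concept pairs, one "en zh" pair per line:
#   python visualize_embeddings.py --vocab_file en_zh_pairs.txt \
#       --model_recover_path ./checkpoints/model.bin --output_dir ./result
#   # with region features, from a pickled concept list:
#   python visualize_embeddings.py --vocab_file concepts.pkl --enable_vis \
#       --num_concept 200 --num_vis 4 \
#       --model_recover_path ./checkpoints/model.bin --output_dir ./result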