Example #1
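The examples below are taken from a single decoding/evaluation codebase; only the main() functions are shown. A plausible import header, reconstructed from the names used in the code (the tokenizer classes and the project helpers such as Indexer, BertForSeq2SeqDecoder, seq2seq_loader, postprocess and detokenize are assumptions about the surrounding repository, not part of any published API), would be:

import argparse
import codecs
import glob
import json
import os
import pickle
import random
from collections import OrderedDict, defaultdict

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import SequentialSampler
from tqdm import tqdm

# Project-specific imports (module paths are assumptions):
# BertTokenizer, XLMTokenizer, Indexer, BertForSeq2SeqDecoder, seq2seq_loader,
# batch_list_to_batch_tensors, postprocess, detokenize, language_eval, logger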
def main():
    parser = argparse.ArgumentParser()
    # parser.add_argument('--dataset', default='txt', type=str, help='txt -> self-customized')
    # parser.add_argument('--src_lang', default='en', type=str, help='')
    # parser.add_argument('--tgt_lang', default='zh', type=str, help='')
    parser.add_argument(
        '--max_len_en',
        default=25,
        type=int,
        help='maximum length of English in **bilingual** corpus')
    parser.add_argument(
        '--max_len_zh',
        default=25,
        type=int,
        help='maximum length of Chinese in **bilingual** corpus')
    parser.add_argument("--src_file",
                        default='./.pkl',
                        type=str,
                        help="The input data file name.")

    # General
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--bert_model",
        default="bert-base-cased",
        type=str,
        help=
        "Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased."
    )
    parser.add_argument("--xml_vocab",
                        type=str,
                        default='./download_models/xml_vocab.json')
    parser.add_argument("--xml_merge",
                        type=str,
                        default='./download_models/xml_merges.txt')
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=512,
                        help="max position embeddings")

    # For decoding
    #parser.add_argument('--fp16', action='store_true',
    #                   help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        help="random seed for initialization")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size',
                        type=int,
                        default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size',
                        type=int,
                        default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0,
                        help="Length penalty for beam search")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word',
                        type=str,
                        default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--ngram_size', type=int, default=3)

    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--enable_butd',
                        action='store_true',
                        help='set to take in region features')
    parser.add_argument('--output_dir', default='./result', type=str)

    #useless
    parser.add_argument('--split', type=str, default='val')  #wmt?
    parser.add_argument('--len_vis_input',
                        type=int,
                        default=1,
                        help="The length of the visual token input (a single region)")

    with open(
            '/data/private/chenyutong/dataset/concept_count/word_concept_count.pkl',
            'rb') as f:
        word_fre = pickle.load(f)
    word_fre = defaultdict(int, word_fre)
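    # Judging from the lookups further down (word_fre['coco'][...], word_fre['aic'][...]),
    # word_fre maps a corpus name to a word-frequency dict; the pickle path above is
    # machine-specific.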

    args = parser.parse_args()

    assert args.batch_size == 1, 'only support batch_size=1'
    args.max_tgt_length = max(args.max_len_en, args.max_len_zh)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer_en = BertTokenizer.from_pretrained(
        args.bert_model,
        do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + '/.pretrained_model')
    if args.max_position_embeddings:
        tokenizer_en.max_len = args.max_position_embeddings
    #tokenizer_en= WhitespaceTokenizer() if args.tokenized_input else tokenizer_en
    tokenizer_zh = XLMTokenizer(args.xml_vocab, args.xml_merge)
    tokenizer_zh.tokenize = lambda x: tokenizer_zh._tokenize(
        x, lang='zh', bypass_tokenizer=False)
    with open(args.xml_vocab, 'r') as f:
        tokenizer_zh.vocab = json.load(f)
    indexer = Indexer(
        [os.path.join(args.bert_model, 'vocab.txt'), args.xml_vocab])
    with open('full_vocab.json', 'w') as f:
        json.dump(indexer.ids_to_tokens, f)
    tokenizers = {'en': tokenizer_en, 'zh': tokenizer_zh}
    print('tokenizer created')

    assert '.pkl' in args.src_file
    with open(args.src_file, 'rb') as f:
        src_data = pickle.load(f)
    # list [pred_id, vocab, vis, pos, distribution]
    # dict {'vgid':{'en':,'zh':,'region_features':[img, conf, fea[i], pos[i],dist]}}
    amp_handle = None
    if args.amp:
        from apex import amp

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 12 if args.new_segment_ids else 12  # both branches are 12: the extended segment-id vocabulary is always used
    mask_word_id, eos_word_ids = indexer(["[MASK]", "[SEP]"])
    forbid_ignore_set = None  #default None
    relax_projection, task_idx_proj = 0, 3
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(indexer(w_list))

    print(args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        #logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model,
            max_position_embeddings=args.max_position_embeddings,
            config_path=args.config_path,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            vocab_size=len(indexer),
            type_vocab_size=type_vocab_size,
            task_idx=3,
            mask_word_id=mask_word_id,  #img2txt
            search_beam_size=args.beam_size,
            length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set,
            ngram_size=args.ngram_size,
            min_len=args.min_len,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input)

        del model_recover

        model.to(device)

        if args.amp:
            model = amp.initialize(model, opt_level='O2')  # opt_level is the letter 'O', not zero
        torch.cuda.empty_cache()
        model.eval()

        fout = open(os.path.join(args.output_dir, 'region2txt_output.txt'),
                    'w')
        output_lines = []
        select_ids = [87, 120, 179, 297, 721, 852, 1025]
        for step_val, sd in enumerate(src_data.items()):
            # if step_val>=1:
            #     break
            vgid, input_item = sd
            en, zh = input_item['en'], input_item['zh']
            fout.writelines('\n' + '#' * 10 + '\n')
            fout.writelines('{}\n'.format(vgid))
            fout.writelines('{} coco: word_fre {}  vis_fre {} \n'.format(
                en, input_item['coco_fre']['word'],
                input_item['coco_fre']['vis']))
            fout.writelines('{} aic: word_fre {}  vis_fre {} \n'.format(
                zh, input_item['aic_fre']['word'],
                input_item['aic_fre']['vis']))
            print('step_val {} Process {}'.format(step_val, en))
            for rf in tqdm(input_item['region_features']):
                filename, conf, vis_feats, vis_pe, cls_label = rf
                vis_feats = torch.from_numpy(vis_feats).to(device)
                vis_feats = vis_feats.unsqueeze(0)
                vis_pe = torch.from_numpy(vis_pe).to(device)
                vis_pe = vis_pe.unsqueeze(0)
                cls_label = torch.from_numpy(cls_label).to(device)
                cls_label = cls_label.unsqueeze(0)  #
                # lazy normalization of the coordinates... copy from seq2seq
                w_est = torch.max(vis_pe[:, [0, 2]]) * 1. + 1e-5
                h_est = torch.max(vis_pe[:, [1, 3]]) * 1. + 1e-5
                vis_pe[:, [0, 2]] /= w_est
                vis_pe[:, [1, 3]] /= h_est
                assert h_est > 0, 'should be greater than 0! {}'.format(h_est)
                assert w_est > 0, 'should be greater than 0! {}'.format(w_est)
                rel_area = (vis_pe[:, 3] - vis_pe[:, 1]) * (vis_pe[:, 2] -
                                                            vis_pe[:, 0])
                rel_area.clamp_(0)

                vis_pe = torch.cat(
                    (vis_pe[:, :4], rel_area.view(-1, 1), vis_pe[:, 5:]),
                    -1)  # confident score
                normalized_coord = F.normalize(vis_pe.data[:, :5] - 0.5,
                                               dim=-1)
                vis_pe = torch.cat((F.layer_norm(vis_pe, [6]), \
                    F.layer_norm(cls_label, [1601])), dim=-1) # 1601 hard coded... #BL,H

                vis_feats = vis_feats.unsqueeze(0)
                vis_pe = vis_pe.unsqueeze(0)
                #print('input shape', vis_feats.shape, vis_pe.shape)
                assert not args.new_segment_ids, 'only segment ids 0, 1 and 6 are supported now'
                tokens = ['[CLS]', '[UNK]', '[SEP]']
                input_ids = indexer(tokens)
                input_ids = np.expand_dims(np.array(input_ids), axis=0)
                input_ids = torch.tensor(input_ids,
                                         dtype=torch.long,
                                         device=device)

                max_len_in_batch = len(tokens) + args.max_tgt_length
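                # Seq2seq attention mask: every position may attend to the 3-token
                # prompt, while the generated target span attends to itself causally
                # via the lower-triangular matrix below.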
                _tril_matrix = torch.tril(
                    torch.ones((max_len_in_batch, max_len_in_batch),
                               dtype=torch.long))
                input_mask = torch.zeros(max_len_in_batch,
                                         max_len_in_batch,
                                         dtype=torch.long,
                                         device=device)
                input_mask[:, :len(tokens)].fill_(1)
                second_st, second_end = len(tokens), max_len_in_batch
                input_mask[second_st:second_end, second_st:second_end].copy_(
                    _tril_matrix[:second_end - second_st, :second_end -
                                 second_st])  #L,L
                input_mask = input_mask.unsqueeze(0)

                position_ids = torch.arange(max_len_in_batch,
                                            dtype=torch.long,
                                            device=device)  #L
                position_ids = position_ids.unsqueeze(0)  # B,L

                predictions = {
                    'en': None,
                    'zh': None,
                    'en2zh': None,
                    'zh2en': None
                }
                for tgt_lang, lang_id in zip(['en', 'zh'], [1, 6]):
                    token_type_ids = [0] * len(
                        tokens) + [lang_id] * args.max_tgt_length
                    token_type_ids = np.expand_dims(np.array(token_type_ids),
                                                    axis=0)
                    token_type_ids = torch.tensor(token_type_ids,
                                                  dtype=torch.long,
                                                  device=device)
                    with torch.no_grad():
                        # print(token_type_ids[0])
                        # print(position_ids[0])
                        # print(input_ids[0])
                        # print(input_mask[0])
                        # input()
                        traces = model(
                            vis_feats=vis_feats,
                            vis_pe=vis_pe,
                            input_ids=input_ids,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            attention_mask=input_mask,
                            search_beam_size=args.beam_size,
                            task_idx=3,
                            mode='img2txt',
                            sample_mode='greedy')  #validation greedy

                    output_sequence = postprocess(traces, args.beam_size,
                                                  tgt_lang, indexer)
                    predictions[tgt_lang] = output_sequence  #truncate

                for langs, lang_ids in zip(['en2zh', 'zh2en'],
                                           [[1, 6], [6, 1]]):
                    src_lang = langs[:2]  #en,zh
                    tgt_lang = langs[-2:]
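                    # Round-trip translation: the caption just generated in the source
                    # language is fed back through the model in txt2txt mode.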
                    w = predictions[
                        src_lang]  # predictions['en']/ predictions['zh']
                    w_t = tokenizers[src_lang].tokenize(w)
                    tokens = ['[CLS]'] + w_t + ['[SEP]']
                    input_ids = indexer(tokens)
                    token_type_ids = [lang_ids[0]] * len(
                        input_ids) + [lang_ids[1]] * args.max_tgt_length
                    input_ids = np.expand_dims(np.array(input_ids), axis=0)
                    token_type_ids = np.expand_dims(np.array(token_type_ids),
                                                    axis=0)
                    input_ids = torch.tensor(input_ids,
                                             dtype=torch.long,
                                             device=device)
                    token_type_ids = torch.tensor(token_type_ids,
                                                  dtype=torch.long,
                                                  device=device)

                    max_len_in_batch = len(
                        tokens) + args.max_tgt_length  #2+64 = 66
                    position_ids = torch.arange(max_len_in_batch,
                                                dtype=torch.long,
                                                device=device)  #L
                    position_ids = position_ids.unsqueeze(0)  # B,L
                    _tril_matrix = torch.tril(
                        torch.ones((max_len_in_batch, max_len_in_batch),
                                   dtype=torch.long))
                    input_mask = torch.zeros(max_len_in_batch,
                                             max_len_in_batch,
                                             dtype=torch.long,
                                             device=device)
                    input_mask[:, :len(tokens)].fill_(1)
                    second_st, second_end = len(tokens), max_len_in_batch
                    input_mask[second_st:second_end,
                               second_st:second_end].copy_(
                                   _tril_matrix[:second_end -
                                                second_st, :second_end -
                                                second_st])  #L,L
                    input_mask = input_mask.unsqueeze(0)
                    with torch.no_grad():
                        traces = model(
                            vis_feats=None,
                            vis_pe=None,
                            input_ids=input_ids,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            attention_mask=input_mask,
                            search_beam_size=args.beam_size,
                            task_idx=3,
                            mode='txt2txt',
                            sample_mode='greedy')  #validation greedy
                    output_sequence = postprocess(traces, args.beam_size,
                                                  tgt_lang, indexer)
                    predictions[langs] = output_sequence

                #print(predictions)
                fout.writelines(
                    'conf:{:.2f} en:{: <10} fre:{:<5d} en2zh:{: <10} zh:{: <10} fre:{:<5d} zh2en:{: <10} \n'
                    .format(conf, predictions['en'],
                            word_fre['coco'][predictions['en']],
                            predictions['en2zh'], predictions['zh'],
                            word_fre['aic'][predictions['zh']],
                            predictions['zh2en']))

        fout.close()
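The region position features above are normalized "lazily" (per the inline comment, copied from the seq2seq pretraining code). A minimal standalone sketch of that step, assuming a 6-dimensional box/score vector and a 1601-way detector class distribution per region (the helper name is hypothetical), might look like this:

import torch
import torch.nn.functional as F

def normalize_region_pe(vis_pe, cls_label):
    """Sketch of the position-feature normalization used above.

    vis_pe:    FloatTensor [N, 6]    -- x1, y1, x2, y2, (overwritten by area), confidence
    cls_label: FloatTensor [N, 1601] -- detector class distribution per region
    """
    vis_pe = vis_pe.clone()
    # Normalize box coordinates by rough image extents estimated from the boxes.
    w_est = torch.max(vis_pe[:, [0, 2]]) * 1. + 1e-5
    h_est = torch.max(vis_pe[:, [1, 3]]) * 1. + 1e-5
    vis_pe[:, [0, 2]] /= w_est
    vis_pe[:, [1, 3]] /= h_est
    # Relative box area, clamped to be non-negative.
    rel_area = (vis_pe[:, 3] - vis_pe[:, 1]) * (vis_pe[:, 2] - vis_pe[:, 0])
    rel_area.clamp_(0)
    # Keep (x1, y1, x2, y2, area, confidence), then layer-normalize the geometry
    # features and the class distribution separately before concatenating.
    vis_pe = torch.cat((vis_pe[:, :4], rel_area.view(-1, 1), vis_pe[:, 5:]), -1)
    return torch.cat((F.layer_norm(vis_pe, [6]), F.layer_norm(cls_label, [1601])), dim=-1)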
Example #2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='coco', type=str, nargs='*', help='')
    parser.add_argument('--lang', default='en zh', type=str, nargs='*', help='')
    parser.add_argument('--max_len_en', default=25, type=int, help='maximum length of English in **bilingual** corpus')
    parser.add_argument('--max_len_zh', default=25, type=int, help='maximum length of Chinese in **bilingual** corpus')
    parser.add_argument('--max_len_en_cap', default=25, type=int, help='maximum length of English in **img2txt** corpus')
    parser.add_argument('--max_len_zh_cap', default=25, type=int, help='maximum length of Chinese in **img2txt** corpus')
    parser.add_argument('--len_vis_input', type=int, default=100, help="The length of visual token input")
    parser.add_argument("--src_file", default='$DATA_ROOT/{}/annotations/{}_dataset.json',
                        type=str, help="The input data file name.")
    parser.add_argument('--split', type=str, default='val')
    parser.add_argument('--file_valid_jpgs', default='$DATA_ROOT/{}/annotations/{}_valid_jpgs.json', type=str)
    parser.add_argument('--image_root', type=str, default='$DATA_ROOT/{}/region_feat_gvd_wo_bgd')
    parser.add_argument('--region_bbox_file', default='raw_bbox/{}_detection_vg_100dets_vlp_checkpoint_trainval_bbox', type=str)
    parser.add_argument('--region_det_file_prefix', default='feat_cls_1000/{}_detection_vg_100dets_vlp_checkpoint_trainval', type=str)

    # General
    parser.add_argument("--config_path", default=None, type=str,
                        help="Bert config file path.")
    parser.add_argument("--bert_model", default="bert-base-cased", type=str,
                        help="Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased.")
    parser.add_argument("--xml_vocab",type=str, default='./download_models/xml_vocab.json')
    parser.add_argument("--xml_merge",type=str, default='./download_models/xml_merges.txt')
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument('--max_position_embeddings', type=int, default=512,
                        help="max position embeddings")

    # For decoding
    #parser.add_argument('--fp16', action='store_true',
    #                   help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp', action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument('--seed', type=int, default=123,
                        help="random seed for initialization")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size', type=int, default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--max_tgt_length', type=int, default=20,
                        help="maximum length of target sequence")


    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--enable_butd', action='store_true',
                        help='set to take in region features')
    parser.add_argument('--output_dir', default='./result', type=str)

    args = parser.parse_args()
    dataset = {}
    for d in args.dataset:
        assert d in ['coco','aic','wmt']
        if d == 'coco':
            dataset[d] = {'max_len_a': args.len_vis_input, 'max_len_b': args.max_len_en_cap}
        elif d == 'aic':
            dataset[d] = {'max_len_a': args.len_vis_input, 'max_len_b': args.max_len_zh_cap}
        else:  # d == 'wmt'
            dataset[d] = {'max_len_a': args.max_len_en, 'max_len_b': args.max_len_zh}
        dataset[d]['max_seq_length'] = dataset[d]['max_len_a'] + dataset[d]['max_len_b'] + 3
    args.dataset = dataset
    lang2cap_max_seq_length = {'zh':args.len_vis_input+args.max_len_zh_cap+3, 'en':args.len_vis_input+args.max_len_en_cap+3}
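    # Note: each maximum sequence length above is max_len_a + max_len_b + 3, the
    # extra 3 presumably covering the [CLS]/[SEP]/[SEP] special tokens.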

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    if args.enable_butd:
        assert(args.len_vis_input == 100)
        args.region_bbox_file = os.path.join(args.image_root, args.region_bbox_file)
        args.region_det_file_prefix = os.path.join(args.image_root, args.region_det_file_prefix) 

    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer_en = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir+'/.pretrained_model')
    if args.max_position_embeddings:
        tokenizer_en.max_len = args.max_position_embeddings
    #tokenizer_en= WhitespaceTokenizer() if args.tokenized_input else tokenizer_en
    tokenizers = {'en':tokenizer_en}
    if 'aic' in args.dataset or 'wmt' in args.dataset:
        tokenizer_zh = XLMTokenizer(args.xml_vocab, args.xml_merge)
        tokenizer_zh.tokenize = lambda x: tokenizer_zh._tokenize(x, lang='zh', bypass_tokenizer=True)
        with open(args.xml_vocab,'r') as f:
            tokenizer_zh.vocab = json.load(f)
        tokenizers['zh'] = tokenizer_zh
    if 'coco_g8_lr3e-5_batch512_ft_from_s0.75_b0.25' in args.model_recover_path:
        indexer = Indexer([os.path.join(args.bert_model,'vocab.txt')])
    else:
        indexer = Indexer([os.path.join(args.bert_model,'vocab.txt'), args.xml_vocab])

    for corpus in args.dataset:
        if corpus in ['coco','aic']:
            #bilingual
            for lang in args.lang:
                tokenizer = tokenizers[lang]#tokenizers['en'] if corpus=='coco' else tokenizers['zh']
                max_seq_length = lang2cap_max_seq_length[lang]
                decode_pipeline= [seq2seq_loader.Preprocess4Seq2seqDecoder(
                    corpus,  
                    lang,
                    list(tokenizer.vocab.keys()), indexer, 
                    max_len=max_seq_length,
                    max_tgt_length=args.max_tgt_length, new_segment_ids=args.new_segment_ids,
                    mode='s2s', len_vis_input=args.len_vis_input, enable_butd=args.enable_butd,
                    region_bbox_file=args.region_bbox_file.format(corpus.upper(), corpus.lower()), 
                    region_det_file_prefix=args.region_det_file_prefix.format(corpus.upper(), corpus.lower()))]
                eval_dataset = seq2seq_loader.Img2txtDataset(
                                        args.src_file.format(corpus.upper(), corpus.lower()),
                                        args.image_root.format(corpus.upper()), 
                                        args.split, args.batch_size,
                                        tokenizer,
                                        max_seq_length, 
                                        preprocessed=True,
                                        file_valid_jpgs=args.file_valid_jpgs.format(corpus.upper(), corpus.lower()),
                                        bi_uni_pipeline=decode_pipeline, use_num_imgs=-1,
                                        s2s_prob=1, bi_prob=0,
                                        enable_butd=args.enable_butd, tasks='img2txt')
                args.dataset[corpus][lang+'_eval_dataloader'] = torch.utils.data.DataLoader(
                            eval_dataset, batch_size=args.batch_size,
                            sampler=SequentialSampler(eval_dataset), num_workers=4, 
                            collate_fn=batch_list_to_batch_tensors, pin_memory=True)
        else:
            raise NotImplementedError # only support aic and coco now

    amp_handle = None
    if args.amp:
        from apex import amp
    #    amp_handle = amp.init(enable_caching=True)
    #    logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    #type_vocab_size = 6 if args.new_segment_ids else 2
    type_vocab_size = 12 if args.new_segment_ids else 12
    mask_word_id, eos_word_ids = indexer(
        ["[MASK]", "[SEP]"])
    forbid_ignore_set = None #default None
    relax_projection, task_idx_proj = 0, 3
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(indexer(w_list))
    print(args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(args.bert_model,
            max_position_embeddings=args.max_position_embeddings, config_path=args.config_path,
            state_dict=model_recover, num_labels=cls_num_labels,
            vocab_size=len(indexer),
            type_vocab_size=type_vocab_size, task_idx=3, mask_word_id=mask_word_id, #img2txt
            search_beam_size=args.beam_size, length_penalty=args.length_penalty,
            eos_id=eos_word_ids, forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size, min_len=args.min_len,
            enable_butd=args.enable_butd, len_vis_input=args.len_vis_input)

        del model_recover

        model.to(device)

        if args.amp:
            model = amp.initialize(model, opt_level='O2')  # opt_level is the letter 'O', not zero
        torch.cuda.empty_cache()
        model.eval()
        for corpus in args.dataset:
            for lang in ['en','zh']:
                if not lang+'_eval_dataloader' in args.dataset[corpus]:
                    continue
                print('corpus {} lang {}'.format(corpus, lang))
                output_lines = {}
                val_iter_bar = tqdm(args.dataset[corpus][lang+'_eval_dataloader'])
                for step_val, val_iter_output in enumerate(val_iter_bar):
                    info_, batch = val_iter_output[0],val_iter_output[1]
                    with torch.no_grad():
                        batch = [t.to(device) for t in batch]
                        input_ids, segment_ids, position_ids, input_mask, task_idx, img, vis_pe = batch
                        # if step_val==0:
                        #     print(segment_ids[0][100:])
                        #     input()
                        #input()
                        if args.enable_butd:
                            conv_feats = img.data # Bx100x2048
                            vis_pe = vis_pe.data
                        else:
                            conv_feats, _ = cnn(img.data) # Bx2048x7x7
                            conv_feats = conv_feats.view(conv_feats.size(0), conv_feats.size(1),
                                -1).permute(0,2,1).contiguous()
                        if args.amp:
                            conv_feats = conv_feats.half()
                            vis_pe = vis_pe.half()

                        traces = model(conv_feats, vis_pe, input_ids, segment_ids, position_ids, input_mask, 
                            search_beam_size=args.beam_size, task_idx=task_idx, sample_mode='greedy') #validation greedy
                        if args.beam_size > 1:
                            traces = {k: v.tolist() for k, v in traces.items()}
                            output_ids = traces['pred_seq']
                        else:
                            output_ids = traces[0].tolist()
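                        # Convert the predicted ids back to tokens and truncate the
                        # caption at the first [SEP] or [PAD].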
                        for ii,w_ids in enumerate(output_ids):
                            output_buf = indexer.convert_ids_to_tokens(w_ids)
                            output_tokens = []
                            for t in output_buf:
                                if t in ("[SEP]", "[PAD]"):
                                    break
                                output_tokens.append(t)
                            output_sequence = ' '.join(detokenize(output_tokens))
                            if corpus=='coco':
                                id_ = int(info_[ii][2].split('_')[2])
                            else:
                                id_ = info_[ii][2]
                            #print(id_,output_sequence)
                            output_lines[id_] = output_sequence
                predictions = [{'image_id': ids_, 'caption': output_lines[ids_]} for ids_ in output_lines]
                with open(os.path.join(args.output_dir,'{}_{}_{}_predictions.json').format(args.split, corpus, lang),'w') as f:
                    json.dump(predictions, f)
                if (corpus=='coco' and lang=='en') or (corpus=='aic' and lang=='zh'):
                    print('Begin evaluating '+corpus)
                    lang_stats = language_eval(corpus, predictions, 
                        args.model_recover_path.split('/')[-2]+'-'+args.split+'-'+args.model_recover_path.split('/')[-1].split('.')[-2], 
                        args.split,
                        ['Bleu','METEOR','Rouge','CIDEr'])
                    with open(os.path.join(args.output_dir,'{}_{}_{}_scores.json').format(args.split, corpus,lang),'w') as f:
                        json.dump(lang_stats, f)
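The id-to-caption step in the inner loop above can be read as a small helper. A minimal sketch (the function name is hypothetical; indexer.convert_ids_to_tokens and detokenize are the project's own interfaces as used above):

def ids_to_caption(w_ids, indexer, detokenize):
    """Sketch of the decoding step used above: ids -> tokens -> caption string."""
    output_tokens = []
    for t in indexer.convert_ids_to_tokens(w_ids):
        # Generation is cut at the first end-of-sequence or padding token.
        if t in ("[SEP]", "[PAD]"):
            break
        output_tokens.append(t)
    return ' '.join(detokenize(output_tokens))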
Example #3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset',
                        default='txt',
                        type=str,
                        help='txt -> self-customized')
    parser.add_argument('--src_lang', default='en', type=str, help='')
    parser.add_argument('--tgt_lang', default='zh', type=str, help='')
    parser.add_argument(
        '--max_len_en',
        default=25,
        type=int,
        help='maximum length of English in **bilingual** corpus')
    parser.add_argument(
        '--max_len_zh',
        default=25,
        type=int,
        help='maximum length of Chinese in **bilingual** corpus')
    parser.add_argument("--src_file",
                        default='./src.txt',
                        type=str,
                        help="The input data file name.")
    parser.add_argument('--corpus', default='txt', type=str)
    parser.add_argument('--en_first',
                        action='store_true',
                        help='always put English as the first sentence')

    # General
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--bert_model",
        default="bert-base-cased",
        type=str,
        help=
        "Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased."
    )
    parser.add_argument("--xml_vocab",
                        type=str,
                        default='./download_models/xml_vocab.json')
    parser.add_argument("--xml_merge",
                        type=str,
                        default='./download_models/xml_merges.txt')
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=512,
                        help="max position embeddings")

    # For decoding
    #parser.add_argument('--fp16', action='store_true',
    #                   help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        help="random seed for initialization")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size',
                        type=int,
                        default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size',
                        type=int,
                        default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0,
                        help="Length penalty for beam search")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word',
                        type=str,
                        default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--ngram_size', type=int, default=3)

    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--enable_butd',
                        action='store_true',
                        help='set to take in region features')
    parser.add_argument('--output_dir', default='./result', type=str)

    #useless
    parser.add_argument('--split', type=str, default='val')  #wmt?
    parser.add_argument('--len_vis_input',
                        type=int,
                        default=100,
                        help="The length of visual token input")

    args = parser.parse_args()
    args.tgt_lang = 'en' if args.src_lang == 'zh' else 'zh'
    # print(sys.getfilesystemencoding())
    # print('这是中文')
    assert args.batch_size == 1, 'only support batch_size=1'
    args.max_tgt_length = args.max_len_en if args.src_lang == 'zh' else args.max_len_zh

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer_en = BertTokenizer.from_pretrained(
        args.bert_model,
        do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + '/.pretrained_model')
    if args.max_position_embeddings:
        tokenizer_en.max_len = args.max_position_embeddings
    #tokenizer_en= WhitespaceTokenizer() if args.tokenized_input else tokenizer_en
    tokenizer_zh = XLMTokenizer(args.xml_vocab, args.xml_merge)
    tokenizer_zh.tokenize = lambda x: tokenizer_zh._tokenize(
        x, lang='zh', bypass_tokenizer=False)
    with open(args.xml_vocab, 'r') as f:
        tokenizer_zh.vocab = json.load(f)
    indexer = Indexer(
        [os.path.join(args.bert_model, 'vocab.txt'), args.xml_vocab])
    with open('full_vocab.json', 'w') as f:
        json.dump(indexer.ids_to_tokens, f)
    tokenizers = {'en': tokenizer_en, 'zh': tokenizer_zh}

    print('tokenizer created')

    if '.txt' in args.src_file:
        with codecs.open(args.src_file, 'r') as f:
            src_lines = f.readlines()
        src_lines = [line.strip() for line in src_lines]
        N_lines = len(src_lines)
    elif 'hdf5' in args.src_file:
        assert 'wmt' == args.corpus
        src_lines = args.src_file
        N_lines = 1999
    else:
        raise ValueError('unsupported --src_file format: {}'.format(args.src_file))

    pipeline = seq2seq_loader.Preprocess4Seq2SeqBilingualDecoder(
        corpus=args.corpus,
        file_src=src_lines,
        src_lang=args.src_lang,
        indexer=indexer,
        tokenizers=tokenizers,
        max_len=args.max_len_en + args.max_len_zh + 3,
        max_tgt_length=args.max_tgt_length,
        preprocessed=False if args.corpus == 'txt' else True,
        new_segment_ids=args.new_segment_ids,
        mode='s2s')

    eval_dataset = seq2seq_loader.Txt2txtDataset(
        N_lines=N_lines,
        split=None,
        batch_size=args.batch_size,
        tokenizers=tokenizers,
        max_len=args.max_len_en + args.max_len_zh + 3,
        preprocessed=False if args.corpus == 'txt' else True,
        bi_uni_pipeline=[pipeline],
        s2s_prob=1,
        bi_prob=0)

    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=args.batch_size,
        sampler=SequentialSampler(eval_dataset),
        num_workers=4,
        collate_fn=batch_list_to_batch_tensors,
        pin_memory=True)

    amp_handle = None
    if args.amp:
        from apex import amp

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 12 if args.new_segment_ids else 12
    mask_word_id, eos_word_ids = indexer(["[MASK]", "[SEP]"])
    forbid_ignore_set = None  #default None
    relax_projection, task_idx_proj = 0, 3
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(indexer(w_list))

    print(args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        #logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model,
            max_position_embeddings=args.max_position_embeddings,
            config_path=args.config_path,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            vocab_size=len(indexer),
            type_vocab_size=type_vocab_size,
            task_idx=3,
            mask_word_id=mask_word_id,  #img2txt
            search_beam_size=args.beam_size,
            length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set,
            ngram_size=args.ngram_size,
            min_len=args.min_len,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input)

        del model_recover

        model.to(device)

        if args.amp:
            model = amp.initialize(model, opt_level='O2')  # opt_level is the letter 'O', not zero
        torch.cuda.empty_cache()
        model.eval()
        val_iter_bar = tqdm(eval_dataloader)
        output_lines = []
        for step_val, val_iter_output in enumerate(val_iter_bar):
            info_, batch = val_iter_output[0], val_iter_output[1]
            with torch.no_grad():
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, position_ids, input_mask, task_idx = batch
                traces = model(None,
                               None,
                               input_ids,
                               segment_ids,
                               position_ids,
                               input_mask,
                               search_beam_size=args.beam_size,
                               task_idx=task_idx,
                               mode='txt2txt',
                               sample_mode='greedy')  #validation greedy
                # if step_val==0:
                #     print(segment_ids[0])
                #     print(input_ids[0])
                #     print(position_ids[0])
                #     input()

                if args.beam_size > 1:
                    traces = {k: v.tolist() for k, v in traces.items()}
                    output_ids = traces['pred_seq']
                else:
                    output_ids = traces[0].tolist()
                #print(output_ids)
                #input()
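                # Detokenization is language-specific: English output is joined with
                # spaces (with '@ - @' collapsed back to '-'), Chinese output is
                # concatenated with BPE '</w>' markers and '[SEP]' stripped.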
                for ii, w_ids in enumerate(output_ids):
                    output_buf = indexer.convert_ids_to_tokens(w_ids)
                    output_tokens = []
                    for t in output_buf:
                        if t in ("[SEP]", "[PAD]"):
                            break
                        output_tokens.append(t)
                        #print(t)
                    if args.tgt_lang == 'en':
                        output_sequence = ' '.join(detokenize(output_tokens))
                        output_sequence = output_sequence.replace(
                            ' @ - @ ', '-')
                    #print(id_,output_sequence)
                    #id_ = step_val*args.batch_size+ii
                    #output_sequence = output_sequence.replace('</w>',' ').replace(' ','')
                    if args.tgt_lang == 'zh':
                        output_sequence = ''.join(
                            detokenize(output_tokens)).replace('</w>',
                                                               '').replace(
                                                                   '[SEP]', '')
                    output_lines.append(output_sequence)

        # with open(os.path.join(args.output_dir,'translation_output.json'),'w') as f:
        #     json.dump(output_lines, f)
        with open(os.path.join(args.output_dir, 'translation_output.txt'),
                  'w') as f:
            for line in output_lines:
                f.writelines(line + '\n')
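Pulled out of the loop above, the target-language-dependent detokenization could be written as a single helper (a sketch; detokenize is the project's own function as used above):

def tokens_to_sentence(output_tokens, tgt_lang, detokenize):
    """Sketch of the language-dependent detokenization used above."""
    if tgt_lang == 'en':
        # English: whitespace-joined detokenization, restoring hyphenated words.
        return ' '.join(detokenize(output_tokens)).replace(' @ - @ ', '-')
    # Chinese: concatenate subwords, stripping BPE end-of-word markers and any
    # stray separator token.
    return ''.join(detokenize(output_tokens)).replace('</w>', '').replace('[SEP]', '')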
Example #4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--enable_vis',
                        action='store_true',
                        help='whether to visualize visual embeddings')
    parser.add_argument('--num_concept',
                        default=-1,
                        type=int,
                        help='number of concepts to visualize')
    parser.add_argument('--num_vis',
                        default=1,
                        type=int,
                        help='number of visual embeddings per concept')
    parser.add_argument('--dataset',
                        default='txt',
                        type=str,
                        help='txt -> self-customized')
    parser.add_argument('--src_lang', default='en', type=str, help='')
    parser.add_argument('--tgt_lang', default='zh', type=str, help='')
    parser.add_argument(
        '--max_len_en',
        default=25,
        type=int,
        help='maximum length of English in **bilingual** corpus')
    parser.add_argument(
        '--max_len_zh',
        default=25,
        type=int,
        help='maximum length of Chinese in **bilingual** corpus')
    #parser.add_argument("--vocab_file", default='./src.txt', type=str, help="The input data file name.")
    parser.add_argument('--vocab_file', type=str, required=True, nargs='+')
    parser.add_argument('--en_first',
                        action='store_true',
                        help='always put English as the first sentence')

    # General
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--bert_model",
        default="bert-base-cased",
        type=str,
        help=
        "Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased."
    )
    parser.add_argument("--xml_vocab",
                        type=str,
                        default='./download_models/xml_vocab.json')
    parser.add_argument("--xml_merge",
                        type=str,
                        default='./download_models/xml_merges.txt')
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=512,
                        help="max position embeddings")

    # For decoding
    #parser.add_argument('--fp16', action='store_true',
    #                   help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        help="random seed for initialization")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size',
                        type=int,
                        default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size',
                        type=int,
                        default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0,
                        help="Length penalty for beam search")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word',
                        type=str,
                        default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--ngram_size', type=int, default=3)

    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--enable_butd',
                        action='store_true',
                        help='set to take in region features')
    parser.add_argument('--output_dir', default='./result', type=str)

    #useless
    parser.add_argument('--split', type=str, default='val')  #wmt?
    parser.add_argument('--len_vis_input',
                        type=int,
                        default=100,
                        help="The length of visual token input")

    args = parser.parse_args()
    print(args.vocab_file)
    if args.enable_vis or '.pkl' in args.vocab_file[0]:
        args.vocab_file = args.vocab_file[0]
        assert '.pkl' in args.vocab_file, args.vocab_file

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer_en = BertTokenizer.from_pretrained(
        args.bert_model,
        do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + '/.pretrained_model')
    if args.max_position_embeddings:
        tokenizer_en.max_len = args.max_position_embeddings
    tokenizer_zh = XLMTokenizer(args.xml_vocab, args.xml_merge)
    tokenizer_zh.tokenize = lambda x: tokenizer_zh._tokenize(
        x, lang='zh', bypass_tokenizer=False)
    with open(args.xml_vocab, 'r') as f:
        tokenizer_zh.vocab = json.load(f)
    indexer = Indexer(
        [os.path.join(args.bert_model, 'vocab.txt'), args.xml_vocab])
    with open('full_vocab.json', 'w') as f:
        json.dump(indexer.ids_to_tokens, f)
    tokenizers = {'en': tokenizer_en, 'zh': tokenizer_zh}
    print('tokenizer created')

    concept_list = []
    if args.enable_vis or 'pkl' in args.vocab_file:
        with open(args.vocab_file, 'rb') as f:
            vocab_list = pickle.load(f)
        vocab_list = vocab_list[::-1]
        if args.num_concept == -1:
            args.num_concept = len(vocab_list)
        vocab = []  #[ [En, Zh, Vis1(list), Vis2, Vis3, Vis4,..] *num_concept]
        for key, ls in vocab_list[40:40 + args.num_concept]:
            concept = [ls[0][0], ls[0][1]]  #En, Zh
            concept_list.append(ls[0][0])
            for inst in ls[:args.num_vis]:
                concept.append(inst[2:])  #2,3,4
            vocab.append(concept)
        print('Number of Concept {}'.format(len(vocab)))
        if args.num_concept != -1:
            print(concept_list)
        print('Number of visual instances per concept {}'.format(
            len(vocab[0]) - 2))
        #print('Example {}'.format(vocab[0]))
        #input()
    else:
        vocab = []
        for filename in args.vocab_file:
            with open(filename) as f:
                v = f.readlines()
            if args.num_concept == -1:
                args.num_concept = len(v)
            v = [(a.split('    ')[0].strip(), a.split('    ')[-1].strip())
                 for a in v[:args.num_concept]]  #EN ZH
            vocab.extend(v)
        print('Number of vocabulary {} {}'.format(len(vocab), vocab[0]))

    cls_num_labels = 2
    type_vocab_size = 12 if args.new_segment_ids else 12
    mask_word_id, eos_word_ids = indexer(["[MASK]", "[SEP]"])
    forbid_ignore_set = None  #default None
    relax_projection, task_idx_proj = 0, 3
    print(args.model_recover_path)
    model_recover = torch.load(args.model_recover_path)
    model = BertForSeq2SeqDecoder.from_pretrained(
        args.bert_model,
        max_position_embeddings=args.max_position_embeddings,
        config_path=args.config_path,
        state_dict=model_recover,
        num_labels=2,
        vocab_size=len(indexer),
        type_vocab_size=type_vocab_size,
        task_idx=3,
        mask_word_id=mask_word_id,  #img2txt
        search_beam_size=args.beam_size,
        length_penalty=args.length_penalty,
        eos_id=eos_word_ids,
        forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
        forbid_ignore_set=forbid_ignore_set,
        ngram_size=args.ngram_size,
        min_len=args.min_len,
        enable_butd=True,
        len_vis_input=args.len_vis_input)
    del model_recover
    model.to(device)
    torch.cuda.empty_cache()
    model.eval()

    N_layer = 12
    embeddings = OrderedDict({
        'en': [[] for i in range(N_layer)],
        'zh': [[] for i in range(N_layer)]
    })
    if args.enable_vis:
        embeddings['vis'] = [[] for i in range(N_layer)]
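    # embeddings[lang][layer] collects one mean-pooled 768-d vector per concept for
    # each of the 12 transformer layers (plus a 'vis' entry when --enable_vis is set).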
    for _, pair in tqdm(enumerate(vocab)):
        for w, lang in zip(pair, ('en', 'zh')):
            segment_id = 1 if lang == 'en' else 6
            w_t = tokenizers[lang].tokenize(w)

            tokens = ['[CLS]'] + w_t + ['[SEP]']
            input_ids = indexer(tokens)
            token_type_ids = [segment_id] * len(input_ids)
            input_ids = np.expand_dims(np.array(input_ids), axis=0)
            token_type_ids = np.expand_dims(np.array(token_type_ids), axis=0)
            input_ids = torch.tensor(input_ids,
                                     dtype=torch.long,
                                     device=device)
            token_type_ids = torch.tensor(token_type_ids,
                                          dtype=torch.long,
                                          device=device)

            output_embeddings = model.compute_embeddings(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                mode='txt2txt')
            #tuple 12 1,L,768
            for i, e in enumerate(output_embeddings):
                e = e.detach().cpu().numpy()
                ave = np.mean(e[0, 1:-1, :], axis=0)  # 768
                embeddings[lang][i].append(ave)

        if args.enable_vis:
            instance_embeddings = [[] for layer_i in range(N_layer)]
            for vis_embed in pair[2:]:
                vis_feats, vis_pe, cls_label = vis_embed[0], vis_embed[
                    1], vis_embed[2]  #1024
                vis_feats = torch.from_numpy(vis_feats).to(device)
                vis_feats = vis_feats.unsqueeze(0)
                vis_pe = torch.from_numpy(vis_pe).to(device)
                vis_pe = vis_pe.unsqueeze(0)
                cls_label = torch.from_numpy(cls_label).to(device)
                cls_label = cls_label.unsqueeze(0)  #
                # lazy normalization of the coordinates... copy from seq2seq
                w_est = torch.max(vis_pe[:, [0, 2]]) * 1. + 1e-5
                h_est = torch.max(vis_pe[:, [1, 3]]) * 1. + 1e-5
                vis_pe[:, [0, 2]] /= w_est
                vis_pe[:, [1, 3]] /= h_est
                assert h_est > 0, 'should be greater than 0! {}'.format(h_est)
                assert w_est > 0, 'should be greater than 0! {}'.format(w_est)
                rel_area = (vis_pe[:, 3] - vis_pe[:, 1]) * (vis_pe[:, 2] -
                                                            vis_pe[:, 0])
                rel_area.clamp_(0)

                vis_pe = torch.cat(
                    (vis_pe[:, :4], rel_area.view(-1, 1), vis_pe[:, 5:]),
                    -1)  # confident score
                normalized_coord = F.normalize(vis_pe.data[:, :5] - 0.5,
                                               dim=-1)
                vis_pe = torch.cat((F.layer_norm(vis_pe, [6]), \
                    F.layer_norm(cls_label, [1601])), dim=-1) # 1601 hard coded... #BL,H

                vis_feats = vis_feats.unsqueeze(0)
                vis_pe = vis_pe.unsqueeze(0)
                #print('input shape', vis_feats.shape, vis_pe.shape)
                segment_id = 0
                tokens = ['[CLS]', '[UNK]', '[SEP]']
                input_ids = indexer(tokens)
                token_type_ids = [segment_id] * len(input_ids)
                input_ids = np.expand_dims(np.array(input_ids), axis=0)
                token_type_ids = np.expand_dims(np.array(token_type_ids),
                                                axis=0)
                input_ids = torch.tensor(input_ids,
                                         dtype=torch.long,
                                         device=device)
                token_type_ids = torch.tensor(token_type_ids,
                                              dtype=torch.long,
                                              device=device)

                vis_embeddings = model.compute_embeddings(
                    vis_feats=vis_feats,
                    vis_pe=vis_pe,
                    input_ids=input_ids,
                    token_type_ids=token_type_ids,
                    mode='img2txt',
                    len_vis_input=1)
                #print(len(vis_embeddings), vis_embeddings[0].shape)
                #input()

                for i, e in enumerate(vis_embeddings):
                    e = e.detach().cpu().numpy()
                    ave = np.mean(e[0, 1:-1, :], axis=0)  # 768
                    instance_embeddings[i].append(ave)
                    #embeddings['vis'][i].append(ave)

            for i, embed_list in enumerate(instance_embeddings):
                if args.num_vis == 1:
                    embeddings['vis'][i].append(embed_list[0])
                else:
                    embeddings['vis'][i].append(embed_list)  #list of array
    # args.output_dir = os.path.join(args.output_dir, 'embedding_vis')
    # if not os.path.exists(args.output_dir):
    #     os.makedirs(args.output_dir)

    for ly in range(N_layer):
        embed = {'en': embeddings['en'][ly], 'zh': embeddings['zh'][ly]}
        if args.enable_vis:
            embed['vis'] = embeddings['vis'][ly]

        #save_numpy(embed,os.path.join(args.output_dir,'hiddenstates_layer_{}.npy'.format(ly)))
        if args.num_vis == 1:
            tSNE_reduce(
                embed, os.path.join(args.output_dir,
                                    'tSNE layer {}'.format(ly)))
            Cosine_Sim(
                embed,
                os.path.join(args.output_dir, 'CosineSim Layer {}'.format(ly)))
            sim_histgram(
                embed,
                os.path.join(args.output_dir, 'Histgram Layer {}'.format(ly)))
        if args.num_vis > 1:
            tSNE_reduce_visual(
                concept_list, embed,
                os.path.join(args.output_dir, 'tSNE layer {}'.format(ly)))
        print('Save layer {}'.format(ly))
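tSNE_reduce, Cosine_Sim, sim_histgram and tSNE_reduce_visual are project-specific plotting helpers that are not shown here. Purely as an illustration of what a tSNE_reduce-style function might do (assuming scikit-learn and matplotlib are available; this is not the project's actual implementation):

import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def tsne_plot(embed, out_path):
    """Sketch: project per-language embedding lists to 2-D and scatter-plot them."""
    labels, points = [], []
    for lang, vectors in embed.items():  # e.g. {'en': [...], 'zh': [...], 'vis': [...]}
        for v in vectors:
            labels.append(lang)
            points.append(np.asarray(v))
    points = np.stack(points)
    # Perplexity must stay below the number of points for small embedding sets.
    reduced = TSNE(n_components=2, init='random', random_state=0,
                   perplexity=min(30, len(points) - 1)).fit_transform(points)
    plt.figure()
    for lang in sorted(set(labels)):
        idx = [i for i, l in enumerate(labels) if l == lang]
        plt.scatter(reduced[idx, 0], reduced[idx, 1], label=lang, s=8)
    plt.legend()
    plt.savefig(out_path + '.png', bbox_inches='tight')
    plt.close()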