Example #1
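The examples below are lifted from larger decoding/training scripts and omit their import preamble. The block below collects the standard-library and third-party imports they visibly rely on; the commented lines are project-specific names whose module paths are assumptions inferred from usage (a pytorch_pretrained_bert-style package and a local seq2seq_loader module), not verified imports.

import argparse
import codecs
import glob
import json
import logging
import math
import os
import pickle
import random
from collections import defaultdict

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

# Project-specific names used below; module paths are assumptions, not verified:
# from pytorch_pretrained_bert.tokenization import BertTokenizer, WhitespaceTokenizer
# from pytorch_pretrained_bert.modeling import BertForSeq2SeqDecoder, BertForPreTrainingLossMask
# import seq2seq_loader  # provides Preprocess4Seq2seq*, Img2txtDataset, batch_list_to_batch_tensors
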
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")

    # decoding parameters
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument("--input_file", type=str, help="Input file")
    parser.add_argument('--subset',
                        type=int,
                        default=0,
                        help="Decode a subset of the input dataset.")
    parser.add_argument("--output_file", type=str, help="output file")
    parser.add_argument("--split",
                        type=str,
                        default="",
                        help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        help="random seed for initialization")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--batch_size',
                        type=int,
                        default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size',
                        type=int,
                        default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0,
                        help="Length penalty for beam search")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word',
                        type=str,
                        default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--mode',
                        default="s2s",
                        choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length',
                        type=int,
                        default=128,
                        help="maximum length of target sequence")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1."
        )
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    tokenizer.max_len = args.max_seq_length

    pair_num_relation = 0
    bi_uni_pipeline = []
    bi_uni_pipeline.append(
        seq2seq_loader.Preprocess4Seq2seqDecoder(
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            max_tgt_length=args.max_tgt_length,
            new_segment_ids=args.new_segment_ids,
            mode="s2s",
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift))

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[S2S_SOS]"])
    forbid_ignore_set = None
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(tokenizer.convert_tokens_to_ids(w_list))
    print(args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=pair_num_relation,
            type_vocab_size=type_vocab_size,
            task_idx=3,
            mask_word_id=mask_word_id,
            search_beam_size=args.beam_size,
            length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            sos_id=sos_word_id,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set,
            ngram_size=args.ngram_size,
            min_len=args.min_len,
            mode=args.mode,
            max_position_embeddings=args.max_seq_length,
            ffn_type=args.ffn_type,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            pos_shift=args.pos_shift)
        del model_recover

        if args.fp16:
            model.half()
        model.to(device)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        torch.cuda.empty_cache()
        model.eval()
        next_i = 0
        max_src_length = args.max_seq_length - 2 - args.max_tgt_length

        with open(args.input_file, encoding="utf-8") as fin:
            input_lines = [x.strip() for x in fin.readlines()]
            if args.subset > 0:
                logger.info("Decoding subset: %d", args.subset)
                input_lines = input_lines[:args.subset]
        data_tokenizer = WhitespaceTokenizer(
        ) if args.tokenized_input else tokenizer
        input_lines = [
            data_tokenizer.tokenize(x)[:max_src_length] for x in input_lines
        ]
        input_lines = sorted(list(enumerate(input_lines)),
                             key=lambda x: -len(x[1]))
        output_lines = [""] * len(input_lines)
        score_trace_list = [None] * len(input_lines)
        total_batch = math.ceil(len(input_lines) / args.batch_size)

        with tqdm(total=total_batch) as pbar:
            while next_i < len(input_lines):
                _chunk = input_lines[next_i:next_i + args.batch_size]
                buf_id = [x[0] for x in _chunk]
                buf = [x[1] for x in _chunk]
                next_i += args.batch_size
                max_a_len = max([len(x) for x in buf])
                instances = []
                for instance in [(x, max_a_len) for x in buf]:
                    for proc in bi_uni_pipeline:
                        instances.append(proc(instance))
                with torch.no_grad():
                    batch = seq2seq_loader.batch_list_to_batch_tensors(
                        instances)
                    batch = [
                        t.to(device) if t is not None else None for t in batch
                    ]
                    input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch
                    traces = model(input_ids,
                                   token_type_ids,
                                   position_ids,
                                   input_mask,
                                   task_idx=task_idx,
                                   mask_qkv=mask_qkv)
                    if args.beam_size > 1:
                        traces = {k: v.tolist() for k, v in traces.items()}
                        output_ids = traces['pred_seq']
                    else:
                        output_ids = traces.tolist()
                    for i in range(len(buf)):
                        w_ids = output_ids[i]
                        output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                        output_tokens = []
                        for t in output_buf:
                            if t in ("[SEP]", "[PAD]"):
                                break
                            output_tokens.append(t)
                        output_sequence = ' '.join(detokenize(output_tokens))
                        output_lines[buf_id[i]] = output_sequence
                        if args.need_score_traces:
                            score_trace_list[buf_id[i]] = {
                                'scores': traces['scores'][i],
                                'wids': traces['wids'][i],
                                'ptrs': traces['ptrs'][i]
                            }
                pbar.update(1)
        if args.output_file:
            fn_out = args.output_file
        else:
            fn_out = model_recover_path + '.' + args.split
        with open(fn_out, "w", encoding="utf-8") as fout:
            for l in output_lines:
                fout.write(l)
                fout.write("\n")

        if args.need_score_traces:
            with open(fn_out + ".trace.pickle", "wb") as fout_trace:
                pickle.dump({
                    "version": 0.0,
                    "num_samples": len(input_lines)
                }, fout_trace)
                for x in score_trace_list:
                    pickle.dump(x, fout_trace)
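
A note on the post-processing inside the decoding loop above: predicted ids are converted to tokens, cut at the first "[SEP]" or "[PAD]", detokenized, and joined with spaces. The helpers below are a minimal, self-contained sketch of that step; simple_detokenize is a hypothetical stand-in for the detokenize function the script imports (assumed here to merge WordPiece "##" continuations), not the repository's implementation.

def simple_detokenize(tokens):
    # Merge WordPiece continuation pieces ("##xx") back onto the previous word.
    words = []
    for tok in tokens:
        if tok.startswith("##") and words:
            words[-1] += tok[2:]
        else:
            words.append(tok)
    return words

def ids_to_sentence(tokenizer, word_ids):
    # Convert predicted ids to tokens, stop at the first [SEP]/[PAD],
    # then space-join the detokenized words, mirroring the loop above.
    tokens = tokenizer.convert_ids_to_tokens(word_ids)
    kept = []
    for t in tokens:
        if t in ("[SEP]", "[PAD]"):
            break
        kept.append(t)
    return ' '.join(simple_detokenize(kept))
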
Example #2

def main():
    parser = argparse.ArgumentParser()
    # parser.add_argument('--dataset', default='txt', type=str, help='txt -> self-customized')
    # parser.add_argument('--src_lang', default='en', type=str, help='')
    # parser.add_argument('--tgt_lang', default='zh', type=str, help='')
    parser.add_argument(
        '--max_len_en',
        default=25,
        type=int,
        help='maximum length of English in **bilingual** corpus')
    parser.add_argument(
        '--max_len_zh',
        default=25,
        type=int,
        help='maximum length of Chinese in **bilingual** corpus')
    parser.add_argument("--src_file",
                        default='./.pkl',
                        type=str,
                        help="The input data file name.")

    # General
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--bert_model",
        default="bert-base-cased",
        type=str,
        help=
        "Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased."
    )
    parser.add_argument("--xml_vocab",
                        type=str,
                        default='./download_models/xml_vocab.json')
    parser.add_argument("--xml_merge",
                        type=str,
                        default='./download_models/xml_merges.txt')
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=512,
                        help="max position embeddings")

    # For decoding
    #parser.add_argument('--fp16', action='store_true',
    #                   help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        help="random seed for initialization")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size',
                        type=int,
                        default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size',
                        type=int,
                        default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0,
                        help="Length penalty for beam search")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word',
                        type=str,
                        default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--ngram_size', type=int, default=3)

    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--enable_butd',
                        action='store_true',
                        help='set to take in region features')
    parser.add_argument('--output_dir', default='./result', type=str)

    #useless
    parser.add_argument('--split', type=str, default='val')  #wmt?
    parser.add_argument('--len_vis_input',
                        type=int,
                        default=1,
                        help="The length of visual token input region 1")

    with open(
            '/data/private/chenyutong/dataset/concept_count/word_concept_count.pkl',
            'rb') as f:
        word_fre = pickle.load(f)
    word_fre = defaultdict(int, word_fre)

    args = parser.parse_args()

    assert args.batch_size == 1, 'only support batch_size=1'
    args.max_tgt_length = max(args.max_len_en, args.max_len_zh)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer_en = BertTokenizer.from_pretrained(
        args.bert_model,
        do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + '/.pretrained_model')
    if args.max_position_embeddings:
        tokenizer_en.max_len = args.max_position_embeddings
    #tokenizer_en= WhitespaceTokenizer() if args.tokenized_input else tokenizer_en
    tokenizer_zh = XLMTokenizer(args.xml_vocab, args.xml_merge)
    tokenizer_zh.tokenize = lambda x: tokenizer_zh._tokenize(
        x, lang='zh', bypass_tokenizer=False)
    with open(args.xml_vocab, 'r') as f:
        tokenizer_zh.vocab = json.load(f)
    indexer = Indexer(
        [os.path.join(args.bert_model, 'vocab.txt'), args.xml_vocab])
    with open('full_vocab.json', 'w') as f:
        json.dump(indexer.ids_to_tokens, f)
    tokenizers = {'en': tokenizer_en, 'zh': tokenizer_zh}
    print('tokenizer created')

    assert '.pkl' in args.src_file
    with open(args.src_file, 'rb') as f:
        src_data = pickle.load(f)
    # list [pred_id, vocab, vis, pos, distribution]
    # dict {'vgid':{'en':,'zh':,'region_features':[img, conf, fea[i], pos[i],dist]}}
    amp_handle = None
    if args.amp:
        from apex import amp

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 12  # same value with or without new_segment_ids
    mask_word_id, eos_word_ids = indexer(["[MASK]", "[SEP]"])
    forbid_ignore_set = None  #default None
    relax_projection, task_idx_proj = 0, 3
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(indexer(w_list))

    print(args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        #logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model,
            max_position_embeddings=args.max_position_embeddings,
            config_path=args.config_path,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            vocab_size=len(indexer),
            type_vocab_size=type_vocab_size,
            task_idx=3,
            mask_word_id=mask_word_id,  #img2txt
            search_beam_size=args.beam_size,
            length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set,
            ngram_size=args.ngram_size,
            min_len=args.min_len,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input)

        del model_recover

        model.to(device)

        if args.amp:
            model = amp.initialize(model, opt_level='O2')  # opt_level "O2": capital letter O, not zero
        torch.cuda.empty_cache()
        model.eval()

        fout = open(os.path.join(args.output_dir, 'region2txt_output.txt'),
                    'w')
        output_lines = []
        select_ids = [87, 120, 179, 297, 721, 852, 1025]
        for step_val, sd in enumerate(src_data.items()):
            # if step_val>=1:
            #     break
            vgid, input_item = sd
            en, zh = input_item['en'], input_item['zh']
            fout.writelines('\n' + '#' * 10 + '\n')
            fout.writelines('{}\n'.format(vgid))
            fout.writelines('{} coco: word_fre {}  vis_fre {} \n'.format(
                en, input_item['coco_fre']['word'],
                input_item['coco_fre']['vis']))
            fout.writelines('{} aic: word_fre {}  vis_fre {} \n'.format(
                zh, input_item['aic_fre']['word'],
                input_item['aic_fre']['vis']))
            print('step_val {} Process {}'.format(step_val, en))
            for rf in tqdm(input_item['region_features']):
                filename, conf, vis_feats, vis_pe, cls_label = rf
                vis_feats = torch.from_numpy(vis_feats).to(device)
                vis_feats = vis_feats.unsqueeze(0)
                vis_pe = torch.from_numpy(vis_pe).to(device)
                vis_pe = vis_pe.unsqueeze(0)
                cls_label = torch.from_numpy(cls_label).to(device)
                cls_label = cls_label.unsqueeze(0)  #
                # lazy normalization of the coordinates... copy from seq2seq
                w_est = torch.max(vis_pe[:, [0, 2]]) * 1. + 1e-5
                h_est = torch.max(vis_pe[:, [1, 3]]) * 1. + 1e-5
                vis_pe[:, [0, 2]] /= w_est
                vis_pe[:, [1, 3]] /= h_est
                assert h_est > 0, 'should be greater than 0! {}'.format(h_est)
                assert w_est > 0, 'should be greater than 0! {}'.format(w_est)
                rel_area = (vis_pe[:, 3] - vis_pe[:, 1]) * (vis_pe[:, 2] -
                                                            vis_pe[:, 0])
                rel_area.clamp_(0)

                vis_pe = torch.cat(
                    (vis_pe[:, :4], rel_area.view(-1, 1), vis_pe[:, 5:]),
                    -1)  # confident score
                normalized_coord = F.normalize(vis_pe.data[:, :5] - 0.5,
                                               dim=-1)
                vis_pe = torch.cat((F.layer_norm(vis_pe, [6]), \
                    F.layer_norm(cls_label, [1601])), dim=-1) # 1601 hard coded... #BL,H

                vis_feats = vis_feats.unsqueeze(0)
                vis_pe = vis_pe.unsqueeze(0)
                #print('input shape', vis_feats.shape, vis_pe.shape)
                assert not args.new_segment_ids, 'only token type ids 0/1/6 are supported for now'
                tokens = ['[CLS]', '[UNK]', '[SEP]']
                input_ids = indexer(tokens)
                input_ids = np.expand_dims(np.array(input_ids), axis=0)
                input_ids = torch.tensor(input_ids,
                                         dtype=torch.long,
                                         device=device)

                max_len_in_batch = len(tokens) + args.max_tgt_length
                _tril_matrix = torch.tril(
                    torch.ones((max_len_in_batch, max_len_in_batch),
                               dtype=torch.long))
                input_mask = torch.zeros(max_len_in_batch,
                                         max_len_in_batch,
                                         dtype=torch.long,
                                         device=device)
                input_mask[:, :len(tokens)].fill_(1)
                second_st, second_end = len(tokens), max_len_in_batch
                input_mask[second_st:second_end, second_st:second_end].copy_(
                    _tril_matrix[:second_end - second_st, :second_end -
                                 second_st])  #L,L
                input_mask = input_mask.unsqueeze(0)

                position_ids = torch.arange(max_len_in_batch,
                                            dtype=torch.long,
                                            device=device)  #L
                position_ids = position_ids.unsqueeze(0)  # B,L

                predictions = {
                    'en': None,
                    'zh': None,
                    'en2zh': None,
                    'zh2en': None
                }
                for tgt_lang, lang_id in zip(['en', 'zh'], [1, 6]):
                    token_type_ids = [0] * len(
                        tokens) + [lang_id] * args.max_tgt_length
                    token_type_ids = np.expand_dims(np.array(token_type_ids),
                                                    axis=0)
                    token_type_ids = torch.tensor(token_type_ids,
                                                  dtype=torch.long,
                                                  device=device)
                    with torch.no_grad():
                        # print(token_type_ids[0])
                        # print(position_ids[0])
                        # print(input_ids[0])
                        # print(input_mask[0])
                        # input()
                        traces = model(
                            vis_feats=vis_feats,
                            vis_pe=vis_pe,
                            input_ids=input_ids,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            attention_mask=input_mask,
                            search_beam_size=args.beam_size,
                            task_idx=3,
                            mode='img2txt',
                            sample_mode='greedy')  #validation greedy

                    output_sequence = postprocess(traces, args.beam_size,
                                                  tgt_lang, indexer)
                    predictions[tgt_lang] = output_sequence  #truncate

                for langs, lang_ids in zip(['en2zh', 'zh2en'],
                                           [[1, 6], [6, 1]]):
                    src_lang = langs[:2]  #en,zh
                    tgt_lang = langs[-2:]
                    w = predictions[
                        src_lang]  # predictions['en']/ predictions['zh']
                    w_t = tokenizers[src_lang].tokenize(w)
                    tokens = ['[CLS]'] + w_t + ['[SEP]']
                    input_ids = indexer(tokens)
                    token_type_ids = [lang_ids[0]] * len(
                        input_ids) + [lang_ids[1]] * args.max_tgt_length
                    input_ids = np.expand_dims(np.array(input_ids), axis=0)
                    token_type_ids = np.expand_dims(np.array(token_type_ids),
                                                    axis=0)
                    input_ids = torch.tensor(input_ids,
                                             dtype=torch.long,
                                             device=device)
                    token_type_ids = torch.tensor(token_type_ids,
                                                  dtype=torch.long,
                                                  device=device)

                    max_len_in_batch = len(
                        tokens) + args.max_tgt_length  #2+64 = 66
                    position_ids = torch.arange(max_len_in_batch,
                                                dtype=torch.long,
                                                device=device)  #L
                    position_ids = position_ids.unsqueeze(0)  # B,L
                    _tril_matrix = torch.tril(
                        torch.ones((max_len_in_batch, max_len_in_batch),
                                   dtype=torch.long))
                    input_mask = torch.zeros(max_len_in_batch,
                                             max_len_in_batch,
                                             dtype=torch.long,
                                             device=device)
                    input_mask[:, :len(tokens)].fill_(1)
                    second_st, second_end = len(tokens), max_len_in_batch
                    input_mask[second_st:second_end,
                               second_st:second_end].copy_(
                                   _tril_matrix[:second_end -
                                                second_st, :second_end -
                                                second_st])  #L,L
                    input_mask = input_mask.unsqueeze(0)
                    with torch.no_grad():
                        traces = model(
                            vis_feats=None,
                            vis_pe=None,
                            input_ids=input_ids,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            attention_mask=input_mask,
                            search_beam_size=args.beam_size,
                            task_idx=3,
                            mode='txt2txt',
                            sample_mode='greedy')  #validation greedy
                    output_sequence = postprocess(traces, args.beam_size,
                                                  tgt_lang, indexer)
                    predictions[langs] = output_sequence

                #print(predictions)
                fout.writelines(
                    'conf:{:.2f} en:{: <10} fre:{:<5d} en2zh:{: <10} zh:{: <10} fre:{:<5d} zh2en:{: <10} \n'
                    .format(conf, predictions['en'],
                            word_fre['coco'][predictions['en']],
                            predictions['en2zh'], predictions['zh'],
                            word_fre['aic'][predictions['zh']],
                            predictions['zh2en']))

        fout.close()
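
Both the img2txt and txt2txt branches in the example above build the same self-attention mask inline: every position may attend to the whole source prefix, while target positions attend to earlier target positions through a lower-triangular block. The helper below is a sketch that factors out that pattern for illustration; it is not a function from the repository.

import torch

def build_s2s_attention_mask(src_len, tgt_len, device=None):
    # Returns a (1, L, L) long mask with L = src_len + tgt_len:
    # every row can see the src_len prefix, and the target block is
    # lower-triangular so each decoding step sees itself and earlier steps.
    total_len = src_len + tgt_len
    mask = torch.zeros(total_len, total_len, dtype=torch.long, device=device)
    mask[:, :src_len].fill_(1)
    tril = torch.tril(torch.ones(tgt_len, tgt_len, dtype=torch.long, device=device))
    mask[src_len:, src_len:].copy_(tril)
    return mask.unsqueeze(0)  # add the batch dimension, as in the code above

With src_len = len(tokens) and tgt_len = args.max_tgt_length this reproduces the input_mask constructed in both loops above.
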
Example #3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset',
                        default='txt',
                        type=str,
                        help='txt -> self-customized')
    parser.add_argument('--src_lang', default='en', type=str, help='')
    parser.add_argument('--tgt_lang', default='zh', type=str, help='')
    parser.add_argument(
        '--max_len_en',
        default=25,
        type=int,
        help='maximum length of English in **bilingual** corpus')
    parser.add_argument(
        '--max_len_zh',
        default=25,
        type=int,
        help='maximum length of Chinese in **bilingual** corpus')
    parser.add_argument("--src_file",
                        default='./src.txt',
                        type=str,
                        help="The input data file name.")
    parser.add_argument('--corpus', default='txt', type=str)
    parser.add_argument('--en_first',
                        action='store_true',
                        help='always put English as the first sentence')

    # General
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--bert_model",
        default="bert-base-cased",
        type=str,
        help=
        "Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased."
    )
    parser.add_argument("--xml_vocab",
                        type=str,
                        default='./download_models/xml_vocab.json')
    parser.add_argument("--xml_merge",
                        type=str,
                        default='./download_models/xml_merges.txt')
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=512,
                        help="max position embeddings")

    # For decoding
    #parser.add_argument('--fp16', action='store_true',
    #                   help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        help="random seed for initialization")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size',
                        type=int,
                        default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size',
                        type=int,
                        default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0,
                        help="Length penalty for beam search")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word',
                        type=str,
                        default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--ngram_size', type=int, default=3)

    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--enable_butd',
                        action='store_true',
                        help='set to take in region features')
    parser.add_argument('--output_dir', default='./result', type=str)

    #useless
    parser.add_argument('--split', type=str, default='val')  #wmt?
    parser.add_argument('--len_vis_input',
                        type=int,
                        default=100,
                        help="The length of visual token input")

    args = parser.parse_args()
    args.tgt_lang = 'en' if args.src_lang == 'zh' else 'zh'
    # print(sys.getfilesystemencoding())
    # print('这是中文')
    assert args.batch_size == 1, 'only support batch_size=1'
    args.max_tgt_length = args.max_len_en if args.src_lang == 'zh' else args.max_len_zh

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer_en = BertTokenizer.from_pretrained(
        args.bert_model,
        do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + '/.pretrained_model')
    if args.max_position_embeddings:
        tokenizer_en.max_len = args.max_position_embeddings
    #tokenizer_en= WhitespaceTokenizer() if args.tokenized_input else tokenizer_en
    tokenizer_zh = XLMTokenizer(args.xml_vocab, args.xml_merge)
    tokenizer_zh.tokenize = lambda x: tokenizer_zh._tokenize(
        x, lang='zh', bypass_tokenizer=False)
    with open(args.xml_vocab, 'r') as f:
        tokenizer_zh.vocab = json.load(f)
    indexer = Indexer(
        [os.path.join(args.bert_model, 'vocab.txt'), args.xml_vocab])
    with open('full_vocab.json', 'w') as f:
        json.dump(indexer.ids_to_tokens, f)
    tokenizers = {'en': tokenizer_en, 'zh': tokenizer_zh}

    print('tokenizer created')

    if '.txt' in args.src_file:
        with codecs.open(args.src_file, 'r') as f:
            src_lines = f.readlines()
        src_lines = [line.strip() for line in src_lines]
        N_lines = len(src_lines)
    elif 'hdf5' in args.src_file:
        assert 'wmt' == args.corpus
        src_lines = args.src_file
        N_lines = 1999
    else:
        raise ValueError('unsupported src_file format: {}'.format(args.src_file))

    pipeline = seq2seq_loader.Preprocess4Seq2SeqBilingualDecoder(
        corpus=args.corpus,
        file_src=src_lines,
        src_lang=args.src_lang,
        indexer=indexer,
        tokenizers=tokenizers,
        max_len=args.max_len_en + args.max_len_zh + 3,
        max_tgt_length=args.max_tgt_length,
        preprocessed=False if args.corpus == 'txt' else True,
        new_segment_ids=args.new_segment_ids,
        mode='s2s')

    eval_dataset = seq2seq_loader.Txt2txtDataset(
        N_lines=N_lines,
        split=None,
        batch_size=args.batch_size,
        tokenizers=tokenizers,
        max_len=args.max_len_en + args.max_len_zh + 3,
        preprocessed=False if args.corpus == 'txt' else True,
        bi_uni_pipeline=[pipeline],
        s2s_prob=1,
        bi_prob=0)

    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=args.batch_size,
        sampler=SequentialSampler(eval_dataset),
        num_workers=4,
        collate_fn=batch_list_to_batch_tensors,
        pin_memory=True)

    amp_handle = None
    if args.amp:
        from apex import amp

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 12  # same value with or without new_segment_ids
    mask_word_id, eos_word_ids = indexer(["[MASK]", "[SEP]"])
    forbid_ignore_set = None  #default None
    relax_projection, task_idx_proj = 0, 3
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(indexer(w_list))

    print(args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        #logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model,
            max_position_embeddings=args.max_position_embeddings,
            config_path=args.config_path,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            vocab_size=len(indexer),
            type_vocab_size=type_vocab_size,
            task_idx=3,
            mask_word_id=mask_word_id,  #img2txt
            search_beam_size=args.beam_size,
            length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set,
            ngram_size=args.ngram_size,
            min_len=args.min_len,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input)

        del model_recover

        model.to(device)

        if args.amp:
            model = amp.initialize(model, opt_level='O2')  # opt_level "O2": capital letter O, not zero
        torch.cuda.empty_cache()
        model.eval()
        val_iter_bar = tqdm(eval_dataloader)
        output_lines = []
        for step_val, val_iter_output in enumerate(val_iter_bar):
            info_, batch = val_iter_output[0], val_iter_output[1]
            with torch.no_grad():
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, position_ids, input_mask, task_idx = batch
                traces = model(None,
                               None,
                               input_ids,
                               segment_ids,
                               position_ids,
                               input_mask,
                               search_beam_size=args.beam_size,
                               task_idx=task_idx,
                               mode='txt2txt',
                               sample_mode='greedy')  #validation greedy
                # if step_val==0:
                #     print(segment_ids[0])
                #     print(input_ids[0])
                #     print(position_ids[0])
                #     input()

                if args.beam_size > 1:
                    traces = {k: v.tolist() for k, v in traces.items()}
                    output_ids = traces['pred_seq']
                else:
                    output_ids = traces[0].tolist()
                #print(output_ids)
                #input()
                for ii, w_ids in enumerate(output_ids):
                    output_buf = indexer.convert_ids_to_tokens(w_ids)
                    output_tokens = []
                    for t in output_buf:
                        if t in ("[SEP]", "[PAD]"):
                            break
                        output_tokens.append(t)
                        #print(t)
                    if args.tgt_lang == 'en':
                        output_sequence = ' '.join(detokenize(output_tokens))
                        output_sequence = output_sequence.replace(
                            ' @ - @ ', '-')
                    #print(id_,output_sequence)
                    #id_ = step_val*args.batch_size+ii
                    #output_sequence = output_sequence.replace('</w>',' ').replace(' ','')
                    if args.tgt_lang == 'zh':
                        output_sequence = ''.join(
                            detokenize(output_tokens)).replace('</w>',
                                                               '').replace(
                                                                   '[SEP]', '')
                    output_lines.append(output_sequence)

        # with open(os.path.join(args.output_dir,'translation_output.json'),'w') as f:
        #     json.dump(output_lines, f)
        with open(os.path.join(args.output_dir, 'translation_output.txt'),
                  'w') as f:
            for line in output_lines:
                f.writelines(line + '\n')
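
The output cleanup in the loop above is language specific: English hypotheses are space-joined after detokenization and the ' @ - @ ' hyphen encoding is collapsed, while Chinese hypotheses are concatenated without spaces and the XLM end-of-word marker '</w>' (plus any stray '[SEP]') is stripped. A compact sketch of that branch, taking the detokenize function as a parameter rather than assuming its implementation:

def clean_output(output_tokens, tgt_lang, detokenize):
    # Mirror the per-language post-processing used in the decoding loop above.
    if tgt_lang == 'en':
        text = ' '.join(detokenize(output_tokens))
        return text.replace(' @ - @ ', '-')
    # tgt_lang == 'zh': concatenate pieces and strip BPE / special markers
    return ''.join(detokenize(output_tokens)).replace('</w>', '').replace('[SEP]', '')
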
Example #4
def main():
    parser = argparse.ArgumentParser()

    # General
    parser.add_argument(
        "--bert_model",
        default="bert-base-cased",
        type=str,
        help=
        "Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default='tmp',
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_file",
        default="training.log",
        type=str,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument(
        "--do_train",
        action='store_true',
        help="Whether to run training. This should ALWAYS be set to True.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=30,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--global_rank",
                        type=int,
                        default=-1,
                        help="global_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 32-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--len_vis_input',
                        type=int,
                        default=100,
                        help="The length of visual token input")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=20,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='b',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=3,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=4,
                        type=int,
                        help="Number of workers for the data loader.")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")

    # Others for VLP
    parser.add_argument(
        "--src_file",
        default=['/mnt/dat/COCO/annotations/dataset_coco.json'],
        type=str,
        nargs='+',
        help="The input data file name.")
    parser.add_argument('--enable_visdom', action='store_true')
    parser.add_argument('--visdom_port', type=int, default=8888)
    # parser.add_argument('--resnet_model', type=str, default='imagenet_weights/resnet101.pth')
    parser.add_argument('--image_root',
                        type=str,
                        default='/mnt/dat/COCO/images')
    parser.add_argument('--dataset',
                        default='coco',
                        type=str,
                        help='coco | flickr30k | cc')
    parser.add_argument('--split',
                        type=str,
                        nargs='+',
                        default=['train', 'restval'])

    parser.add_argument('--world_size',
                        default=1,
                        type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url',
                        default='file://[PT_OUTPUT_DIR]/nonexistent_file',
                        type=str,
                        help='url used to set up distributed training')
    parser.add_argument(
        '--file_valid_jpgs',
        default='/mnt/dat/COCO/annotations/coco_valid_jpgs.json',
        type=str)
    parser.add_argument('--sche_mode',
                        default='warmup_linear',
                        type=str,
                        help="warmup_linear | warmup_constant | warmup_cosine")
    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--use_num_imgs', default=-1, type=int)
    parser.add_argument('--vis_mask_prob', default=0, type=float)
    parser.add_argument('--max_drop_worst_ratio', default=0, type=float)
    parser.add_argument('--drop_after', default=6, type=int)

    parser.add_argument(
        '--s2s_prob',
        default=1,
        type=float,
        help="Percentage of examples that are bi-uni-directional LM (seq2seq)."
    )
    parser.add_argument(
        '--bi_prob',
        default=0,
        type=float,
        help="Percentage of examples that are bidirectional LM.")
    parser.add_argument('--enable_butd',
                        action='store_true',
                        help='set to take in region features')
    parser.add_argument(
        '--region_bbox_file',
        default=
        'coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5',
        type=str)
    parser.add_argument(
        '--region_det_file_prefix',
        default=
        'feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval',
        type=str)
    parser.add_argument('--tasks', default='img2txt', help='img2txt | vqa2')
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--scst',
                        action='store_true',
                        help='Self-critical sequence training')

    args = parser.parse_args()

    print('global_rank: {}, local rank: {}'.format(args.global_rank,
                                                   args.local_rank))

    args.max_seq_length = args.max_len_b + args.len_vis_input + 3  # +3 for 2x[SEP] and [CLS]
    args.mask_image_regions = (args.vis_mask_prob > 0
                               )  # whether to mask out image regions
    args.dist_url = args.dist_url.replace('[PT_OUTPUT_DIR]', args.output_dir)

    # arguments inspection
    assert (args.tasks in ('img2txt', 'vqa2'))
    assert args.enable_butd, 'only support region attn! featmap attn deprecated'
    if args.scst:
        assert args.dataset == 'coco', 'scst support on coco only!'
        assert args.max_pred == 0 and args.mask_prob == 0, 'no mask for scst!'
        rl_crit = RewardCriterion()

    if args.enable_butd:
        assert (args.len_vis_input == 100)
        args.region_bbox_file = os.path.join(args.image_root,
                                             args.region_bbox_file)
        args.region_det_file_prefix = os.path.join(
            args.image_root, args.region_det_file_prefix) if args.dataset in (
                'cc', 'coco') and args.region_det_file_prefix != '' else ''

    # output config
    os.makedirs(args.output_dir, exist_ok=True)
    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True,
              indent=2)

    logging.basicConfig(
        filename=os.path.join(args.output_dir, args.log_file),
        filemode='w',
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger = logging.getLogger(__name__)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='tcp://localhost:10001',  #args.dist_url,
            world_size=args.world_size,
            rank=args.global_rank)
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # plotting loss, optional
    if args.enable_visdom:
        import visdom
        vis = visdom.Visdom(port=args.visdom_port, env=args.output_dir)
        vis_window = {'iter': None, 'score': None}

    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model,
        do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir +
        '/.pretrained_model_{}'.format(args.global_rank))
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer(
    ) if args.tokenized_input else tokenizer

    if args.do_train:
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_image_regions=args.mask_image_regions,
                mode="s2s",
                len_vis_input=args.len_vis_input,
                vis_mask_prob=args.vis_mask_prob,
                enable_butd=args.enable_butd,
                region_bbox_file=args.region_bbox_file,
                region_det_file_prefix=args.region_det_file_prefix,
                local_rank=args.local_rank,
                load_vqa_ann=(args.tasks == 'vqa2'))
        ]
        bi_uni_pipeline.append(
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_image_regions=args.mask_image_regions,
                mode="bi",
                len_vis_input=args.len_vis_input,
                vis_mask_prob=args.vis_mask_prob,
                enable_butd=args.enable_butd,
                region_bbox_file=args.region_bbox_file,
                region_det_file_prefix=args.region_det_file_prefix,
                local_rank=args.local_rank,
                load_vqa_ann=(args.tasks == 'vqa2')))

        train_dataset = seq2seq_loader.Img2txtDataset(
            args.src_file,
            args.image_root,
            args.split,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_valid_jpgs=args.file_valid_jpgs,
            bi_uni_pipeline=bi_uni_pipeline,
            use_num_imgs=args.use_num_imgs,
            s2s_prob=args.s2s_prob,
            bi_prob=args.bi_prob,
            enable_butd=args.enable_butd,
            tasks=args.tasks)

        if args.world_size == 1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
        else:
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.train_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=batch_list_to_batch_tensors,
            pin_memory=True)

    # note: args.train_batch_size has already been divided by args.gradient_accumulation_steps above
    t_total = int(
        len(train_dataloader) * args.num_train_epochs * 1. /
        args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    relax_projection = 4 if args.relax_projection else 0
    task_idx_proj = 3 if args.tasks == 'img2txt' else 0
    mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[PAD]"])  # index in BERT vocab: 103, 102, 0

    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        assert not args.scst, 'SCST must be initialized from a maximum-likelihood-trained checkpoint'
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size,
            relax_projection=relax_projection,
            config_path=args.config_path,
            task_idx=task_idx_proj,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            cache_dir=args.output_dir +
            '/.pretrained_model_{}'.format(args.global_rank),
            drop_prob=args.drop_prob,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input,
            tasks=args.tasks)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(
                os.path.join(args.output_dir,
                             "model.{0}.bin".format(recover_step)))
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total * 1. /
                                     args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path)
            global_step = 0
        if not args.scst:
            model = BertForPreTrainingLossMask.from_pretrained(
                args.bert_model,
                state_dict=model_recover,
                num_labels=cls_num_labels,
                type_vocab_size=type_vocab_size,
                relax_projection=relax_projection,
                config_path=args.config_path,
                task_idx=task_idx_proj,
                max_position_embeddings=args.max_position_embeddings,
                label_smoothing=args.label_smoothing,
                fp32_embedding=args.fp32_embedding,
                cache_dir=args.output_dir +
                '/.pretrained_model_{}'.format(args.global_rank),
                drop_prob=args.drop_prob,
                enable_butd=args.enable_butd,
                len_vis_input=args.len_vis_input,
                tasks=args.tasks)
        else:
            model = BertForSeq2SeqDecoder.from_pretrained(
                args.bert_model,
                max_position_embeddings=args.max_position_embeddings,
                config_path=args.config_path,
                state_dict=model_recover,
                num_labels=cls_num_labels,
                type_vocab_size=type_vocab_size,
                task_idx=task_idx_proj,
                mask_word_id=mask_word_id,
                search_beam_size=1,
                eos_id=eos_word_ids,
                enable_butd=args.enable_butd,
                len_vis_input=args.len_vis_input)

        del model_recover
        torch.cuda.empty_cache()

    # deprecated
    # from vlp.resnet import resnet
    # cnn = resnet(args.resnet_model, _num_layers=101, _fixed_block=4, pretrained=True) # no finetuning

    if args.fp16:
        model.half()
        # cnn.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    # cnn.to(device)
    if args.local_rank != -1:
        try:
            # from apex.parallel import DistributedDataParallel as DDP
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "DistributedDataParallel is unavailable in this PyTorch build; it is required for distributed training."
            )
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
        # cnn = DDP(cnn)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)
        # cnn = DataParallelImbalance(cnn)

    # Prepare optimizer
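    # Standard BERT fine-tuning practice: apply weight decay to all parameters
    # except biases and LayerNorm weights.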
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
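        # loss_scale == 0 enables dynamic loss scaling; any other value is used
        # as a fixed (static) loss scale for fp16 training.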
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             schedule=args.sche_mode,
                             t_total=t_total)

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(
            os.path.join(args.output_dir,
                         "optim.{0}.bin".format(recover_step)))
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)
        logger.info("  Loader length = %d", len(train_dataloader))

        model.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch,
                              args.num_train_epochs + 1,
                              desc="Epoch"):
            if args.local_rank >= 0:
                train_sampler.set_epoch(i_epoch - 1)
            iter_bar = tqdm(train_dataloader, desc='Iter (loss=X.XXX)')
            nbatches = len(train_dataloader)
            train_loss = []
            pretext_loss = []
            vqa2_loss = []
            scst_reward = []
            for step, batch in enumerate(iter_bar):
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, img, vis_masked_pos, vis_pe, ans_labels = batch

                if args.fp16:
                    img = img.half()
                    vis_pe = vis_pe.half()

                if args.enable_butd:
                    conv_feats = img.data  # Bx100x2048
                    vis_pe = vis_pe.data
                else:
                    conv_feats, _ = cnn(img.data)  # Bx2048x7x7
                    conv_feats = conv_feats.view(conv_feats.size(0),
                                                 conv_feats.size(1),
                                                 -1).permute(0, 2,
                                                             1).contiguous()

                if not args.scst:
                    loss_tuple = model(
                        conv_feats,
                        vis_pe,
                        input_ids,
                        segment_ids,
                        input_mask,
                        lm_label_ids,
                        ans_labels,
                        is_next,
                        masked_pos=masked_pos,
                        masked_weights=masked_weights,
                        task_idx=task_idx,
                        vis_masked_pos=vis_masked_pos,
                        mask_image_regions=args.mask_image_regions,
                        drop_worst_ratio=args.max_drop_worst_ratio
                        if i_epoch > args.drop_after else 0)
                    mean_reward = loss_tuple[0].new(1).fill_(0)
                else:
                    # scst training
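                    # Self-critical sequence training: decode a greedy baseline
                    # without gradients, draw a sampled sequence, and weight the
                    # policy-gradient loss by the reward difference between the
                    # sampled sequence and the greedy baseline.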
                    model.eval()
                    position_ids = torch.arange(
                        input_ids.size(1),
                        dtype=input_ids.dtype,
                        device=input_ids.device).unsqueeze(0).expand_as(
                            input_ids)
                    input_dummy = input_ids[:, :args.len_vis_input +
                                            2]  # +2 for [CLS] and [SEP]
                    greedy_res = input_ids.new(
                        input_ids.size(0),
                        input_ids.size(1) - args.len_vis_input - 2).fill_(0)
                    gen_result = input_ids.new(
                        input_ids.size(0),
                        input_ids.size(1) - args.len_vis_input - 2).fill_(0)

                    with torch.no_grad():
                        greedy_res_raw, _ = model(conv_feats,
                                                  vis_pe,
                                                  input_dummy,
                                                  segment_ids,
                                                  position_ids,
                                                  input_mask,
                                                  task_idx=task_idx,
                                                  sample_mode='greedy')
                        for b in range(greedy_res_raw.size(0)):
                            for idx in range(greedy_res_raw.size(1)):
                                if greedy_res_raw[b][idx] not in [
                                        eos_word_ids, pad_word_ids
                                ]:
                                    greedy_res[b][idx] = greedy_res_raw[b][idx]
                                else:
                                    if greedy_res_raw[b][idx] == eos_word_ids:
                                        greedy_res[b][idx] = eos_word_ids
                                    break
                    model.train()
                    gen_result_raw, sample_logprobs = model(
                        conv_feats,
                        vis_pe,
                        input_dummy,
                        segment_ids,
                        position_ids,
                        input_mask,
                        task_idx=task_idx,
                        sample_mode='sample')
                    for b in range(gen_result_raw.size(0)):
                        for idx in range(gen_result_raw.size(1)):
                            if gen_result_raw[b][idx] not in [
                                    eos_word_ids, pad_word_ids
                            ]:
                                gen_result[b][idx] = gen_result_raw[b][idx]
                            else:
                                if gen_result_raw[b][idx] == eos_word_ids:
                                    gen_result[b][idx] = eos_word_ids
                                break

                    gt_ids = input_ids[:, args.len_vis_input + 2:]
                    reward = get_self_critical_reward(greedy_res,
                                                      gt_ids, gen_result,
                                                      gt_ids.size(0))
                    reward = torch.from_numpy(reward).float().to(
                        gen_result.device)
                    mean_reward = reward.mean()
                    loss = rl_crit(sample_logprobs, gen_result.data, reward)

                    loss_tuple = [
                        loss,
                        loss.new(1).fill_(0.),
                        loss.new(1).fill_(0.)
                    ]

                # disable pretext_loss_deprecated for now
                masked_lm_loss, pretext_loss_deprecated, ans_loss = loss_tuple
                if n_gpu > 1:  # mean() to average on multi-gpu. For dist, this is done through gradient addition.
                    masked_lm_loss = masked_lm_loss.mean()
                    pretext_loss_deprecated = pretext_loss_deprecated.mean()
                    ans_loss = ans_loss.mean()
                loss = masked_lm_loss + pretext_loss_deprecated + ans_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())
                train_loss.append(loss.item())
                pretext_loss.append(pretext_loss_deprecated.item())
                vqa2_loss.append(ans_loss.item())
                scst_reward.append(mean_reward.item())
                if step % 100 == 0:
                    logger.info(
                        "Epoch {}, Iter {}, Loss {:.2f}, Pretext {:.2f}, VQA2 {:.2f}, Mean R {:.3f}\n"
                        .format(i_epoch, step, np.mean(train_loss),
                                np.mean(pretext_loss), np.mean(vqa2_loss),
                                np.mean(scst_reward)))

                if args.enable_visdom:
                    if vis_window['iter'] is None:
                        vis_window['iter'] = vis.line(
                            X=np.tile(
                                np.arange((i_epoch - 1) * nbatches + step,
                                          (i_epoch - 1) * nbatches + step + 1),
                                (1, 1)).T,
                            Y=np.column_stack(
                                (np.asarray([np.mean(train_loss)]), )),
                            opts=dict(title='Training Loss',
                                      xlabel='Training Iteration',
                                      ylabel='Loss',
                                      legend=['total']))
                    else:
                        vis.line(X=np.tile(
                            np.arange((i_epoch - 1) * nbatches + step,
                                      (i_epoch - 1) * nbatches + step + 1),
                            (1, 1)).T,
                                 Y=np.column_stack(
                                     (np.asarray([np.mean(train_loss)]), )),
                                 opts=dict(title='Training Loss',
                                           xlabel='Training Iteration',
                                           ylabel='Loss',
                                           legend=['total']),
                                 win=vis_window['iter'],
                                 update='append')

                # ensure that accumulated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
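                    # The warmup schedule is applied manually only in the fp16
                    # branch (FusedAdam has no built-in warmup); in fp32, BertAdam
                    # already handles warmup internally via warmup/t_total.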
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step/t_total,
                                      args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Save a trained model
            logger.info(
                "** ** * Saving fine-tuned model and optimizer ** ** * ")
            model_to_save = model.module if hasattr(
                model, 'module') else model  # Only save the model itself
            output_model_file = os.path.join(args.output_dir,
                                             "model.{0}.bin".format(i_epoch))
            output_optim_file = os.path.join(args.output_dir,
                                             "optim.{0}.bin".format(i_epoch))
            if args.global_rank in (
                    -1, 0):  # save model if the first device or no dist
                torch.save(
                    copy.deepcopy(model_to_save).cpu().state_dict(),
                    output_model_file)
                # torch.save(optimizer.state_dict(), output_optim_file) # disabled for now; the optimizer state would need to be sanitized and moved back to CPU first

            logger.info("***** CUDA.empty_cache() *****")
            torch.cuda.empty_cache()

            if args.world_size > 1:
                torch.distributed.barrier()
def main():

    args = load_args()
    ss = load_server_socket()

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1."
        )
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    tokenizer.max_len = args.max_seq_length

    pair_num_relation = 0
    bi_uni_pipeline = []
    bi_uni_pipeline.append(
        seq2seq_loader.Preprocess4Seq2seqDecoder(
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            max_tgt_length=args.max_tgt_length,
            new_segment_ids=args.new_segment_ids,
            mode="s2s",
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift))

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[S2S_SOS]"])

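    # Map a '|'-separated list of words into a set of vocab ids; bracketed
    # entries such as [sep] are upper-cased to match BERT's special tokens.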
    def _get_token_id_set(s):
        r = None
        if s:
            w_list = []
            for w in s.split('|'):
                if w.startswith('[') and w.endswith(']'):
                    w_list.append(w.upper())
                else:
                    w_list.append(w)
            r = set(tokenizer.convert_tokens_to_ids(w_list))
        return r

    forbid_ignore_set = _get_token_id_set(args.forbid_ignore_word)
    not_predict_set = _get_token_id_set(args.not_predict_token)
    print(args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=pair_num_relation,
            type_vocab_size=type_vocab_size,
            task_idx=3,
            mask_word_id=mask_word_id,
            search_beam_size=args.beam_size,
            length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            sos_id=sos_word_id,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set,
            not_predict_set=not_predict_set,
            ngram_size=args.ngram_size,
            min_len=args.min_len,
            mode=args.mode,
            max_position_embeddings=args.max_seq_length,
            ffn_type=args.ffn_type,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            pos_shift=args.pos_shift)
        del model_recover

        if args.fp16:
            model.half()
        model.to(device)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        torch.cuda.empty_cache()
        model.eval()

        max_src_length = args.max_seq_length - 2 - args.max_tgt_length

        DONE_SIGNAL = 42  # ASCII code of the asterisk '*'
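        # Minimal request protocol: the client streams UTF-8 text over the
        # socket and terminates the request with a single '*' byte.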
        while True:
            ss.listen(100)
            print("Waiting Connection:")

            cs, addr = ss.accept()
            data = bytes()
            while True:
                recv = cs.recv(1024)
                if 0 < len(recv) and DONE_SIGNAL == recv[-1]:
                    data += recv[:len(recv) - 1]
                    break
                data += recv
            print("Connection with:", addr)
            print("Received:", len(data))

            input_lines = [
                x.strip() for x in data.decode('utf-8').splitlines()
            ]
            if args.subset > 0:
                logger.info("Decoding subset: %d", args.subset)
                input_lines = input_lines[:args.subset]

            data_tokenizer = WhitespaceTokenizer(
            ) if args.tokenized_input else tokenizer
            input_lines = [
                data_tokenizer.tokenize(x)[:max_src_length]
                for x in input_lines
            ]
            input_lines = list(enumerate(input_lines))

            output_lines = [""] * len(input_lines)
            score_trace_list = [None] * len(input_lines)
            next_i = 0
            while next_i < len(input_lines):
                _chunk = input_lines[next_i:next_i + args.batch_size]
                buf_id = [x[0] for x in _chunk]
                buf = [x[1] for x in _chunk]
                next_i += args.batch_size
                max_a_len = max([len(x) for x in buf])
                instances = []
                for instance in [(x, max_a_len) for x in buf]:
                    for proc in bi_uni_pipeline:
                        instances.append(proc(instance))
                with torch.no_grad():
                    batch = seq2seq_loader.batch_list_to_batch_tensors(
                        instances)
                    batch = [
                        t.to(device) if t is not None else None for t in batch
                    ]
                    input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch
                    traces = model(input_ids,
                                   token_type_ids,
                                   position_ids,
                                   input_mask,
                                   task_idx=task_idx,
                                   mask_qkv=mask_qkv)
                    if args.beam_size > 1:
                        traces = {k: v.tolist() for k, v in traces.items()}
                        output_ids = traces['pred_seq']
                    else:
                        output_ids = traces.tolist()

                    qg_result = []
                    for i in range(len(buf)):
                        w_ids = output_ids[i]
                        output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                        output_tokens = []
                        for t in output_buf:
                            if t in ("[SEP]", "[PAD]"):
                                break
                            output_tokens.append(t)
                        output_sequence = ' '.join(detokenize(output_tokens))
                        qg_result.append(output_sequence)
                        output_lines[buf_id[i]] = output_sequence  # keep a copy so the results can also be written to fn_out below
                        if args.need_score_traces:
                            score_trace_list[buf_id[i]] = {
                                'scores': traces['scores'][i],
                                'wids': traces['wids'][i],
                                'ptrs': traces['ptrs'][i]
                            }
                    cs.sendall(ascii_print('\n'.join(qg_result)))
            cs.close()

            if args.output_file:
                fn_out = args.output_file
            else:
                fn_out = model_recover_path + '.' + args.split
            with open(fn_out, "w", encoding="utf-8") as fout:
                for l in output_lines:
                    fout.write(l)
                    fout.write("\n")

            if args.need_score_traces:
                with open(fn_out + ".trace.pickle", "wb") as fout_trace:
                    pickle.dump(
                        {
                            "version": 0.0,
                            "num_samples": len(input_lines)
                        }, fout_trace)
                    for x in score_trace_list:
                        pickle.dump(x, fout_trace)
Example #6
0
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")

    # decoding parameters
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument("--input_file", type=str, help="Input file")
    parser.add_argument('--subset',
                        type=int,
                        default=0,
                        help="Decode a subset of the input dataset.")
    parser.add_argument("--output_file", type=str, help="output file")
    parser.add_argument("--split",
                        type=str,
                        default="",
                        help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        help="random seed for initialization")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--batch_size',
                        type=int,
                        default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size',
                        type=int,
                        default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0,
                        help="Length penalty for beam search")
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument('--topk', type=int, default=10, help="Value of K.")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word',
                        type=str,
                        default=None,
                        help="Ignore the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--mode',
                        default="s2s",
                        choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length',
                        type=int,
                        default=128,
                        help="maximum length of target sequence")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")
    parser.add_argument('--not_predict_token',
                        type=str,
                        default=None,
                        help="Do not predict the tokens during decoding.")

    args = parser.parse_args()

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1."
        )
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # tokenizer = BertTokenizer.from_pretrained(
    #     args.bert_model, do_lower_case=args.do_lower_case)
    tokenizer = BertTokenizer(
        vocab_file=
        '/ps2/intern/clsi/BERT/bert_weights/cased_L-24_H-1024_A-16/vocab.txt',
        do_lower_case=args.do_lower_case)

    tokenizer.max_len = args.max_seq_length

    pair_num_relation = 0
    bi_uni_pipeline = []
    bi_uni_pipeline.append(
        seq2seq_loader.Preprocess4Seq2seqDecoder(
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            max_tgt_length=args.max_tgt_length,
            new_segment_ids=args.new_segment_ids,
            mode="s2s",
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift))

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[S2S_SOS]"])

    def _get_token_id_set(s):
        r = None
        if s:
            w_list = []
            for w in s.split('|'):
                if w.startswith('[') and w.endswith(']'):
                    w_list.append(w.upper())
                else:
                    w_list.append(w)
            r = set(tokenizer.convert_tokens_to_ids(w_list))
        return r

    forbid_ignore_set = _get_token_id_set(args.forbid_ignore_word)
    not_predict_set = _get_token_id_set(args.not_predict_token)
    print(args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=pair_num_relation,
            type_vocab_size=type_vocab_size,
            task_idx=3,
            mask_word_id=mask_word_id,
            search_beam_size=args.beam_size,
            length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            sos_id=sos_word_id,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set,
            not_predict_set=not_predict_set,
            ngram_size=args.ngram_size,
            min_len=args.min_len,
            mode=args.mode,
            max_position_embeddings=args.max_seq_length,
            ffn_type=args.ffn_type,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            pos_shift=args.pos_shift,
            topk=args.topk,
            config_path=args.config_path)
        del model_recover

        if args.fp16:
            model.half()
        model.to(device)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        torch.cuda.empty_cache()
        model.eval()
        next_i = 0
        max_src_length = args.max_seq_length - 2 - args.max_tgt_length

        ## for YFG style json
        # testset = loads_json(args.input_file, 'Load Test Set: '+args.input_file)
        # if args.subset > 0:
        #     logger.info("Decoding subset: %d", args.subset)
        #     testset = testset[:args.subset]

        with open(args.input_file, encoding="utf-8") as fin:
            data = json.load(fin)
        #     input_lines = [x.strip() for x in fin.readlines()]
        #     if args.subset > 0:
        #         logger.info("Decoding subset: %d", args.subset)
        #         input_lines = input_lines[:args.subset]
        # data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
        # input_lines = [data_tokenizer.tokenize(
        #     x)[:max_src_length] for x in input_lines]
        # input_lines = sorted(list(enumerate(input_lines)),
        #                      key=lambda x: -len(x[1]))
        # output_lines = [""] * len(input_lines)
        # score_trace_list = [None] * len(input_lines)
        # total_batch = math.ceil(len(input_lines) / args.batch_size)

        data_tokenizer = WhitespaceTokenizer(
        ) if args.tokenized_input else tokenizer
        PQA_dict = {}  #will store the generated distractors
        dis_tot = 0
        dis_n = 0
        len_tot = 0
        hypothesis = {}
        ##change to process one by one and store the distractors in PQA json form
        ##with tqdm(total=total_batch) as pbar:
        # for example in tqdm(testset):
        #     question_id = str(example['id']['file_id']) + '_' + str(example['id']['question_id'])
        #     if question_id in hypothesis:
        #         continue
        # dis_n += 1
        # if dis_n % 2000 == 0:
        #     logger.info("Already processed: "+str(dis_n))
        counter = 0
        for race_id, example in tqdm(data.items()):
            counter += 1
            if args.subset > 0 and counter >= args.subset:
                break
            eg_dict = {}
            # eg_dict["question_id"] = question_id
            # eg_dict["question"] = ' '.join(example['question'])
            # eg_dict["context"] = ' '.join(example['article'])

            eg_dict["question"] = example['question']
            eg_dict["context"] = example['context']
            label = int(example["label"])
            options = example["options"]
            answer = options[label]
            #new_distractors = []
            pred1 = None
            pred2 = None
            pred3 = None
            #while next_i < len(input_lines):
            #_chunk = input_lines[next_i:next_i + args.batch_size]
            #line = example["context"].strip() + ' ' + example["question"].strip()
            question = example['question']
            question = question.replace('_', ' ')
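            # The decoder source is the tokenized passage followed by the
            # question (with the blank '_' replaced by a space).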
            line = ' '.join(
                nltk.word_tokenize(example['context']) +
                nltk.word_tokenize(question))
            line = [data_tokenizer.tokenize(line)[:max_src_length]]
            # buf_id = [x[0] for x in _chunk]
            # buf = [x[1] for x in _chunk]
            buf = line
            #next_i += args.batch_size
            max_a_len = max([len(x) for x in buf])
            instances = []
            for instance in [(x, max_a_len) for x in buf]:
                for proc in bi_uni_pipeline:
                    instances.append(proc(instance))
            with torch.no_grad():
                batch = seq2seq_loader.batch_list_to_batch_tensors(instances)
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch
                # for i in range(1):
                #try max 10 times
                # if len(new_distractors) >= 3:
                #     break
                traces = model(input_ids,
                               token_type_ids,
                               position_ids,
                               input_mask,
                               task_idx=task_idx,
                               mask_qkv=mask_qkv)
                if args.beam_size > 1:
                    traces = {k: v.tolist() for k, v in traces.items()}
                    output_ids = traces['pred_seq']
                    # print (np.array(output_ids).shape)
                    # print (output_ids)
                else:
                    output_ids = traces.tolist()
                # now only supports single batch decoding!!!
                # will keep the second and third sequence as backup
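                # Keep up to three predictions whose pairwise Jaccard similarity
                # is below 0.5 so the distractors stay mutually distinct; the
                # 2nd and 3rd decoded sequences serve as fallbacks.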
                for i in range(len(buf)):
                    # print (len(buf), buf)
                    backup_1, backup_2 = None, None  # avoid a NameError when fewer than three sequences are decoded
                    for s in range(len(output_ids)):
                        output_seq = output_ids[s]
                        #w_ids = output_ids[i]
                        #output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                        output_buf = tokenizer.convert_ids_to_tokens(
                            output_seq)
                        output_tokens = []
                        for t in output_buf:
                            if t in ("[SEP]", "[PAD]"):
                                break
                            output_tokens.append(t)
                        if s == 1:
                            backup_1 = output_tokens
                        if s == 2:
                            backup_2 = output_tokens
                        if pred1 is None:
                            pred1 = output_tokens
                        elif jaccard_similarity(pred1, output_tokens) < 0.5:
                            if pred2 is None:
                                pred2 = output_tokens
                            elif pred3 is None:
                                if jaccard_similarity(pred2,
                                                      output_tokens) < 0.5:
                                    pred3 = output_tokens
                        if pred1 is not None and pred2 is not None and pred3 is not None:
                            break
                    if pred2 is None:
                        pred2 = backup_1
                        if pred3 is None:
                            pred3 = backup_2
                    elif pred3 is None:
                        pred3 = backup_1
                        # output_sequence = ' '.join(detokenize(output_tokens))
                        # print (output_sequence)
                        # print (output_sequence)
                        # if output_sequence.lower().strip() == answer.lower().strip():
                        #     continue
                        # repeated = False
                        # for cand in new_distractors:
                        #     if output_sequence.lower().strip() == cand.lower().strip():
                        #         repeated = True
                        #         break
                        # if not repeated:
                        #     new_distractors.append(output_sequence.strip())

            #hypothesis[question_id] = [pred1, pred2, pred3]
            new_distractors = [pred1, pred2, pred3]
            # print (new_distractors)
            # dis_tot += len(new_distractors)
            # # fill the missing ones with original distractors
            # for i in range(4):
            #     if len(new_distractors) >= 3:
            #         break
            #     elif i == label:
            #         continue
            #     else:
            #         new_distractors.append(options[i])
            for dis in new_distractors:
                if dis is None:
                    continue
                len_tot += len(dis)
                dis_n += 1
            new_distractors = [
                ' '.join(detokenize(dis)) for dis in new_distractors
                if dis is not None
            ]
            assert len(new_distractors) == 3, "Number of distractors WRONG"
            new_distractors.insert(label, answer)
            #eg_dict["generated_distractors"] = new_distractors
            eg_dict["options"] = new_distractors
            eg_dict["label"] = label
            #PQA_dict[question_id] = eg_dict
            PQA_dict[race_id] = eg_dict

        # reference = {}
        # for example in testset:
        #     question_id = str(example['id']['file_id']) + '_' + str(example['id']['question_id'])
        #     if question_id not in reference.keys():
        #         reference[question_id] = [example['distractor']]
        #     else:
        #         reference[question_id].append(example['distractor'])

        # _ = eval(hypothesis, reference)
        # assert len(PQA_dict) == len(data), "Number of examples WRONG"
        # logger.info("Average number of GENERATED distractor per question: "+str(dis_tot/dis_n))
        logger.info("Average length of distractors: " + str(len_tot / dis_n))
        with open(args.output_file, mode='w', encoding='utf-8') as f:
            json.dump(PQA_dict, f, indent=4)
Example #7
0
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--unilm_model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining unilm model.")
    parser.add_argument("--topic_model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining topic model.")
    parser.add_argument("--topic_data_path",
                        default=None,
                        type=str,
                        help="The file of  topic model data.")
    parser.add_argument("--topic_num", default=50, type=int, help="topic_num.")
    parser.add_argument("--data_path",
                        default=None,
                        type=str,
                        help="The file of  topic model data.")
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument('--topic_mode',
                        default=1,
                        type=float,
                        help="1:idea1 1.1:idea1_wo_theta 2:idea2 ")
    parser.add_argument("--topic_model_dict_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining topic model.")
    # decoding parameters
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument("--input_file", type=str, help="Input file")
    parser.add_argument('--subset',
                        type=int,
                        default=0,
                        help="Decode a subset of the input dataset.")
    parser.add_argument("--output_file", type=str, help="output file")
    parser.add_argument("--split",
                        type=str,
                        default="",
                        help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        help="random seed for initialization")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--batch_size',
                        type=int,
                        default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size',
                        type=int,
                        default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0,
                        help="Length penalty for beam search")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word',
                        type=str,
                        default=None,
                        help="Ignore the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--mode',
                        default="s2s",
                        choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length',
                        type=int,
                        default=128,
                        help="maximum length of target sequence")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")
    parser.add_argument('--not_predict_token',
                        type=str,
                        default=None,
                        help="Do not predict the tokens during decoding.")

    args = parser.parse_args()
    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1."
        )
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    tokenizer.max_len = args.max_seq_length

    pair_num_relation = 0
    bi_uni_pipeline = []
    bi_uni_pipeline.append(
        seq2seq_loader.Preprocess4Seq2seqDecoder(
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            max_tgt_length=args.max_tgt_length,
            new_segment_ids=args.new_segment_ids,
            mode="s2s",
            num_qkv=args.num_qkv,
            s2s_special_token=args.s2s_special_token,
            s2s_add_segment=args.s2s_add_segment,
            s2s_share_segment=args.s2s_share_segment,
            pos_shift=args.pos_shift))
    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        # logger.info("enable fp16 with amp")
    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[S2S_SOS]"])

    def _get_token_id_set(s):
        r = None
        if s:
            w_list = []
            for w in s.split('|'):
                if w.startswith('[') and w.endswith(']'):
                    w_list.append(w.upper())
                else:
                    w_list.append(w)
            r = set(tokenizer.convert_tokens_to_ids(w_list))
        return r

    forbid_ignore_set = _get_token_id_set(args.forbid_ignore_word)
    not_predict_set = _get_token_id_set(args.not_predict_token)
    unilm_model_recover = torch.load(args.unilm_model_recover_path)
    unilm = BertForSeq2SeqDecoder.from_pretrained(
        args.bert_model,
        state_dict=unilm_model_recover,
        num_labels=cls_num_labels,
        num_rel=pair_num_relation,
        type_vocab_size=type_vocab_size,
        task_idx=3,
        mask_word_id=mask_word_id,
        search_beam_size=args.beam_size,
        length_penalty=args.length_penalty,
        eos_id=eos_word_ids,
        sos_id=sos_word_id,
        forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
        forbid_ignore_set=forbid_ignore_set,
        not_predict_set=not_predict_set,
        ngram_size=args.ngram_size,
        min_len=args.min_len,
        mode=args.mode,
        max_position_embeddings=args.max_seq_length,
        ffn_type=args.ffn_type,
        num_qkv=args.num_qkv,
        seg_emb=args.seg_emb,
        pos_shift=args.pos_shift)
    topic_model_recover = torch.load(args.topic_model_recover_path)
    dictionary = Dictionary.load_from_text(args.topic_model_dict_path)
    gsm = GSM(len(dictionary))
    gsm.load_state_dict(topic_model_recover)
    del unilm_model_recover
    del topic_model_recover

    if args.fp16:
        unilm.half()
        gsm.half()
    unilm.to(device)
    gsm.to(device)

    if n_gpu > 1:
        unilm = torch.nn.DataParallel(unilm)
        gsm = torch.nn.DataParallel(gsm)
    torch.cuda.empty_cache()
    unilm.eval()
    gsm.eval()
    next_i = 0
    max_src_length = args.max_seq_length - 2 - args.max_tgt_length

    with open(args.input_file, encoding="utf-8") as fin:
        input_lines = [x.strip() for x in fin.readlines()]
        if args.subset > 0:  # can be ignored when subset == 0 (decode everything)
            # logger.info("Decoding subset: %d", args.subset)
            input_lines = input_lines[:args.subset]
    data_tokenizer = WhitespaceTokenizer(
    ) if args.tokenized_input else tokenizer
    input_lines = [
        data_tokenizer.tokenize(x)[:max_src_length] for x in input_lines
    ]

    input_lines = sorted(
        list(enumerate(input_lines)), key=lambda x: -len(x[1])
    )  # input_lines = [(orig_index, [tokens]), ...], sorted by token length in descending order

    output_lines = [""] * len(input_lines)  # flat list of decoded strings, indexed by original order
    score_trace_list = [None] * len(input_lines)
    total_batch = math.ceil(len(input_lines) / args.batch_size)

    # get topic_model bows
    def detokenize(tk_list):
        r_list = []
        src = " ".join(tk_list)
        src = src.replace("[UNK]", "")
        tk_list = src.split()
        for tk in tk_list:
            if tk.startswith('##') and len(r_list) > 0:
                r_list[-1] = r_list[-1] + tk[2:]
            else:
                r_list.append(tk)
        src = " ".join(r_list)
        src = src.replace("UNK", "")
        r_list = src.split()
        return r_list
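    # Illustrative behaviour: ["play", "##ing", "[UNK]", "ball"] -> ["playing", "ball"]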

    txtLines = []
    for input_line in input_lines:
        textline = " ".join(detokenize(input_line[1]))
        txtLines.append(textline)
    cwd = os.getcwd()
    dictionary = Dictionary.load_from_text(args.topic_model_dict_path)
    dictionary.id2token = {
        v: k
        for k, v in dictionary.token2id.items()
    }  # id2token is empty by default, so rebuild it from token2id
    stopwords = set([
        l.strip('\n').strip()
        for l in open(os.path.join(cwd, 'data/topic_model', 'stopwords.txt'),
                      'r',
                      encoding='utf-8')
    ])
    topic_tokenizer = seq2seq_loader.SpacyTokenizer(stopwords=stopwords)
    docs = topic_tokenizer.tokenize(txtLines)

    # convert to BOW representation
    bows, _docs = [], []
    vocabsize = len(dictionary)
    print("vocabsize", vocabsize)
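    # doc2bow yields (token_id, count) pairs; documents with no in-vocabulary
    # tokens get a dummy count on the last vocab id so bows stays aligned with
    # input_lines.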
    for doc in docs:
        _bow = dictionary.doc2bow(doc)
        if _bow != []:
            _docs.append(list(doc))
            bows.append(_bow)
        else:
            bows.append([(vocabsize - 1, 1)])
    docs = _docs
    with tqdm(total=total_batch) as pbar:
        while next_i < len(input_lines):
            _chunk = input_lines[next_i:next_i +
                                 args.batch_size]  # slicing past the end just returns the remaining items
            buf_id = [x[0] for x in _chunk]
            buf = [x[1] for x in _chunk]
            max_a_len = max([len(x) for x in buf])
            instances = []
            batch_bow = []
            for i in range(next_i, next_i + args.batch_size):
                if i < len(input_lines):
                    bow = torch.zeros(vocabsize)
                    item = list(
                        zip(*bows[i])
                    )  # bow = [[token_id1,token_id2,...],[freq1,freq2,...]]
                    bow[list(item[0])] = torch.tensor(list(item[1])).float()
                    batch_bow.append(bow)
            next_i += args.batch_size
            for instance in [(x, max_a_len) for x in buf]:
                for proc in bi_uni_pipeline:  # proc is Preprocess4Seq2seqDecoder; it builds and pads the decoder inputs
                    instances.append(proc(instance))
            with torch.no_grad():
                batch = seq2seq_loader.batch_list_to_batch_tensors(instances)
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                batch_bow = torch.stack(batch_bow)
                batch_bow = batch_bow.to(device)
                input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch
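                # The GSM topic model turns the bag-of-words batch into a
                # document-topic mixture (theta), a topic-word matrix (beta) and
                # topic embeddings that condition the UniLM decoder below.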
                p_x, mus, log_vars, theta, beta, topic_embedding = gsm(
                    batch_bow)
                traces = unilm(input_ids,
                               theta,
                               beta,
                               topic_embedding,
                               args.topic_mode,
                               token_type_ids,
                               position_ids,
                               input_mask,
                               task_idx=task_idx,
                               mask_qkv=mask_qkv)
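                # cal_ppl presumably accumulates the reconstruction statistics
                # (loss_sum, word_count, ppx_sum, doc_count) reported after decoding.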
                cal_ppl(batch_bow, p_x, log_vars, mus)

                if args.beam_size > 1:
                    traces = {k: v.tolist() for k, v in traces.items()}
                    output_ids = traces['pred_seq']
                else:
                    output_ids = traces.tolist()
                for i in range(len(buf)):
                    w_ids = output_ids[i]
                    output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                    output_tokens = []
                    for t in output_buf:
                        if t in ("[SEP]", "[PAD]"):
                            break
                        output_tokens.append(t)
                    output_sequence = ' '.join(detokenize(output_tokens))
                    output_lines[buf_id[i]] = output_sequence
                    if args.need_score_traces:
                        score_trace_list[buf_id[i]] = {
                            'scores': traces['scores'][i],
                            'wids': traces['wids'][i],
                            'ptrs': traces['ptrs'][i]
                        }
            pbar.update(1)
    print("word_count", word_count)
    ppx = np.exp(loss_sum / word_count)
    ppx_document = np.exp(ppx_sum / doc_count)
    print("ppx", ppx)
    print("ppx_document", ppx_document)
    gsm_module = gsm.module if hasattr(gsm, "module") else gsm  # works with or without the DataParallel wrapper
    topic_words = show_topic_words(gsm_module,
                                   args.topic_num,
                                   device,
                                   dictionary.id2token,
                                   topic_id=None,
                                   topK=10)
    # evaluate_topic_quality(topic_words, docs, dictionary, taskname="unilm", calc4each=False)
    topic_diversity = calc_topic_diversity(topic_words)
    print("topic_diversity", topic_diversity)
    # print('\n'.join([str(lst) for lst in topic_words]))
    # print('='*30)

    if args.output_file:
        fn_out = args.output_file
    else:
        fn_out = args.unilm_model_recover_path + '.' + args.split
    with open(fn_out, "w", encoding="utf-8") as fout:
        for l in output_lines:
            fout.write(l)
            fout.write("\n")

    if args.need_score_traces:
        with open(fn_out + ".trace.pickle", "wb") as fout_trace:
            pickle.dump({
                "version": 0.0,
                "num_samples": len(input_lines)
            }, fout_trace)
            for x in score_trace_list:
                pickle.dump(x, fout_trace)
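
# Minimal invocation sketch (script name is hypothetical; the flags mirror the
# args used above, e.g. --unilm_model_recover_path, --topic_model_recover_path,
# --topic_model_dict_path, --input_file and --output_file):
#   python decode_topic_seq2seq.py --bert_model bert-base-uncased \
#       --unilm_model_recover_path unilm.bin --topic_model_recover_path gsm.bin \
#       --topic_model_dict_path dict.txt --input_file src.txt --output_file out.txt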
Exemple #8
0
def main():
    parser = argparse.ArgumentParser()

    # General
    parser.add_argument(
        "--bert_model",
        default="bert-base-cased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased.",
    )
    parser.add_argument(
        "--config_path", default=None, type=str, help="Bert config file path."
    )
    parser.add_argument(
        "--output_dir",
        default="tmp",
        type=str,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--log_file",
        default="eval.log",
        type=str,
        help="The output file where the log will be written.",
    )
    parser.add_argument(
        "--model_recover_path",
        default=None,
        type=str,
        help="The file of fine-tuned pretraining model.",
    )
    parser.add_argument(
        "--do_train",
        action="store_true",
        help="Whether to run training. This should ALWAYS be set to True.",
    )
    parser.add_argument(
        "--do_lower_case",
        action="store_true",
        help="Set this flag if you are using an uncased model.",
    )
    parser.add_argument(
        "--train_batch_size",
        default=64,
        type=int,
        help="Total batch size for training.",
    )
    parser.add_argument(
        "--learning_rate",
        default=3e-5,
        type=float,
        help="The initial learning rate for Adam.",
    )
    parser.add_argument(
        "--label_smoothing",
        default=0,
        type=float,
        help="The amount of label smoothing to apply to the LM loss.",
    )
    parser.add_argument(
        "--weight_decay",
        default=0.01,
        type=float,
        help="The weight decay rate for Adam.",
    )
    parser.add_argument(
        "--finetune_decay",
        action="store_true",
        help="Weight decay to the original weights.",
    )
    parser.add_argument(
        "--num_train_epochs",
        default=30,
        type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help="Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.",
    )
    parser.add_argument(
        "--no_cuda", action="store_true", help="Whether not to use CUDA when available"
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local_rank for distributed training on gpus",
    )
    parser.add_argument(
        "--global_rank",
        type=int,
        default=-1,
        help="global_rank for distributed training on gpus",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="random seed for initialization"
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit float precision instead of 32-bit",
    )
    parser.add_argument(
        "--fp32_embedding",
        action="store_true",
        help="Whether to use 32-bit float precision instead of 16-bit for the embeddings",
    )
    parser.add_argument(
        "--loss_scale",
        type=float,
        default=0,
        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n",
    )
    parser.add_argument(
        "--amp", action="store_true", help="Whether to use amp for fp16"
    )
    parser.add_argument(
        "--from_scratch",
        action="store_true",
        help="Initialize parameters with random values (i.e., training from scratch).",
    )
    parser.add_argument(
        "--new_segment_ids",
        action="store_true",
        help="Use new segment ids for bi-uni-directional LM.",
    )
    parser.add_argument(
        "--tokenized_input", action="store_true", help="Whether the input is tokenized."
    )
    parser.add_argument(
        "--len_vis_input",
        type=int,
        default=100,
        help="The length of visual token input",
    )
    parser.add_argument(
        "--max_len_b",
        type=int,
        default=20,
        help="Truncate_config: maximum length of segment B.",
    )
    parser.add_argument(
        "--trunc_seg",
        default="b",
        help="Truncate_config: first truncate segment A/B (option: a, b).",
    )
    parser.add_argument(
        "--always_truncate_tail",
        action="store_true",
        help="Truncate_config: Whether we should always truncate tail.",
    )
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help="Masking probability; the number of predictions can be less than max_pred when the sequence is short.",
    )
    parser.add_argument(
        "--max_pred", type=int, default=3, help="Max tokens of prediction."
    )
    parser.add_argument(
        "--num_workers",
        default=4,
        type=int,
        help="Number of workers for the data loader.",
    )
    parser.add_argument(
        "--max_position_embeddings",
        type=int,
        default=None,
        help="max position embeddings",
    )

    # Others for VLP
    parser.add_argument(
        "--src_file",
        default=["/mnt/dat/COCO/annotations/dataset_coco.json"],
        type=str,
        nargs="+",
        help="The input data file name.",
    )
    parser.add_argument("--enable_visdom", action="store_true")
    parser.add_argument("--visdom_port", type=int, default=8888)
    # parser.add_argument('--resnet_model', type=str, default='imagenet_weights/resnet101.pth')
    parser.add_argument("--image_root", type=str, default="/mnt/dat/COCO/images")
    parser.add_argument(
        "--dataset", default="coco", type=str, help="coco | flickr30k | cc"
    )
    parser.add_argument("--split", type=str, nargs="+", default=["train", "restval"])

    parser.add_argument(
        "--world_size", default=1, type=int, help="number of distributed processes"
    )
    parser.add_argument(
        "--dist_url",
        default="file://[PT_OUTPUT_DIR]/nonexistent_file",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument(
        "--file_valid_jpgs",
        default="/mnt/dat/COCO/annotations/coco_valid_jpgs.json",
        type=str,
    )
    parser.add_argument(
        "--sche_mode",
        default="warmup_linear",
        type=str,
        help="warmup_linear | warmup_constant | warmup_cosine",
    )
    parser.add_argument("--drop_prob", default=0.1, type=float)
    parser.add_argument("--use_num_imgs", default=-1, type=int)
    parser.add_argument("--vis_mask_prob", default=0, type=float)
    parser.add_argument("--max_drop_worst_ratio", default=0, type=float)
    parser.add_argument("--drop_after", default=6, type=int)

    parser.add_argument(
        "--s2s_prob",
        default=1,
        type=float,
        help="Percentage of examples that are bi-uni-directional LM (seq2seq).",
    )
    parser.add_argument(
        "--bi_prob",
        default=0,
        type=float,
        help="Percentage of examples that are bidirectional LM.",
    )
    parser.add_argument(
        "--enable_butd", action="store_true", help="set to take in region features"
    )
    parser.add_argument(
        "--region_bbox_file",
        default="coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5",
        type=str,
    )
    parser.add_argument(
        "--region_det_file_prefix",
        default="feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval",
        type=str,
    )
    parser.add_argument("--tasks", default="img2txt", help="img2txt | vqa2")
    parser.add_argument(
        "--relax_projection",
        action="store_true",
        help="Use different projection layers for tasks.",
    )
    parser.add_argument(
        "--scst", action="store_true", help="Self-critical sequence training"
    )

    args = parser.parse_args()

    print("global_rank: {}, local rank: {}".format(args.global_rank, args.local_rank))

    args.max_seq_length = (
        args.max_len_b + args.len_vis_input + 3
    )  # +3 for 2x[SEP] and [CLS]
    args.mask_image_regions = (
        args.vis_mask_prob > 0
    )  # whether to mask out image regions
    args.dist_url = args.dist_url.replace("[PT_OUTPUT_DIR]", args.output_dir)

    # arguments inspection
    assert args.tasks in ("img2txt", "vqa2")
    assert args.enable_butd == True, "only support region attn! featmap attn deprecated"
    assert (not args.scst) or args.dataset == "coco", "scst support on coco only!"
    if args.scst:
        assert args.dataset == "coco", "scst support on coco only!"
        assert args.max_pred == 0 and args.mask_prob == 0, "no mask for scst!"
        rl_crit = RewardCriterion()

    if args.enable_butd:
        assert args.len_vis_input == 100
        args.region_bbox_file = os.path.join(args.image_root, args.region_bbox_file)
        args.region_det_file_prefix = (
            os.path.join(args.image_root, args.region_det_file_prefix)
            if args.dataset in ("cc", "coco") and args.region_det_file_prefix != ""
            else ""
        )

    # output config
    os.makedirs(args.output_dir, exist_ok=True)
    json.dump(
        args.__dict__,
        open(os.path.join(args.output_dir, "eval_opt.json"), "w"),
        sort_keys=True,
        indent=2,
    )

    logging.basicConfig(
        filename=os.path.join(args.output_dir, args.log_file),
        filemode="w",
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger = logging.getLogger(__name__)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
        )
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(
            backend="nccl",
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.global_rank,
        )
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
            device, n_gpu, bool(args.local_rank != -1), args.fp16
        )
    )

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                args.gradient_accumulation_steps
            )
        )

    args.train_batch_size = int(
        args.train_batch_size / args.gradient_accumulation_steps
    )

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # plotting loss, optional
    if args.enable_visdom:
        import visdom

        vis = visdom.Visdom(port=args.visdom_port, env=args.output_dir)
        vis_window = {"iter": None, "score": None}

    # preprocessing/data loader
    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model,
        do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + "/.pretrained_model_{}".format(args.global_rank),
    )
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer

    if args.do_train:
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    "max_len_b": args.max_len_b,
                    "trunc_seg": args.trunc_seg,
                    "always_truncate_tail": args.always_truncate_tail,
                },
                mask_image_regions=args.mask_image_regions,
                mode="s2s",
                len_vis_input=args.len_vis_input,
                vis_mask_prob=args.vis_mask_prob,
                enable_butd=args.enable_butd,
                region_bbox_file=args.region_bbox_file,
                region_det_file_prefix=args.region_det_file_prefix,
                local_rank=args.local_rank,
                load_vqa_ann=(args.tasks == "vqa2"),
            )
        ]
        bi_uni_pipeline.append(
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    "max_len_b": args.max_len_b,
                    "trunc_seg": args.trunc_seg,
                    "always_truncate_tail": args.always_truncate_tail,
                },
                mask_image_regions=args.mask_image_regions,
                mode="bi",
                len_vis_input=args.len_vis_input,
                vis_mask_prob=args.vis_mask_prob,
                enable_butd=args.enable_butd,
                region_bbox_file=args.region_bbox_file,
                region_det_file_prefix=args.region_det_file_prefix,
                local_rank=args.local_rank,
                load_vqa_ann=(args.tasks == "vqa2"),
            )
        )

        train_dataset = seq2seq_loader.Img2txtDataset(
            args.src_file,
            args.image_root,
            args.split,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_valid_jpgs=args.file_valid_jpgs,
            bi_uni_pipeline=bi_uni_pipeline,
            use_num_imgs=args.use_num_imgs,
            s2s_prob=args.s2s_prob,
            bi_prob=args.bi_prob,
            enable_butd=args.enable_butd,
            tasks=args.tasks,
        )

        if args.world_size == 1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
        else:
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.train_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=batch_list_to_batch_tensors,
            pin_memory=True,
        )

    # note: args.train_batch_size has already been divided by args.gradient_accumulation_steps
    t_total = int(
        len(train_dataloader)
        * args.num_train_epochs
        * 1.0
        / args.gradient_accumulation_steps
    )
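    # t_total counts optimizer updates: batches per epoch x epochs / accumulation steps.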

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp

        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    relax_projection = 4 if args.relax_projection else 0
    task_idx_proj = 3 if args.tasks == "img2txt" else 0
    mask_word_id, eos_word_ids, pad_word_ids = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]", "[PAD]"]
    )  # index in BERT vocab: 103, 102, 0

    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        assert args.scst == False, "must init from maximum likelihood training"
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size,
            relax_projection=relax_projection,
            config_path=args.config_path,
            task_idx=task_idx_proj,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            cache_dir=args.output_dir
            + "/.pretrained_model_{}".format(args.global_rank),
            drop_prob=args.drop_prob,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input,
            tasks=args.tasks,
        )
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(
                os.path.join(args.output_dir, "model.{0}.bin".format(recover_step))
            )
            # recover_step == number of epochs
            global_step = math.floor(
                recover_step * t_total * 1.0 / args.num_train_epochs
            )
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****", args.model_recover_path)
            model_recover = torch.load(args.model_recover_path)
            global_step = 0
        if not args.scst:
            model = BertForPreTrainingLossMask.from_pretrained(
                args.bert_model,
                state_dict=model_recover,
                num_labels=cls_num_labels,
                type_vocab_size=type_vocab_size,
                relax_projection=relax_projection,
                config_path=args.config_path,
                task_idx=task_idx_proj,
                max_position_embeddings=args.max_position_embeddings,
                label_smoothing=args.label_smoothing,
                fp32_embedding=args.fp32_embedding,
                cache_dir=args.output_dir
                + "/.pretrained_model_{}".format(args.global_rank),
                drop_prob=args.drop_prob,
                enable_butd=args.enable_butd,
                len_vis_input=args.len_vis_input,
                tasks=args.tasks,
            )
        else:
            model = BertForSeq2SeqDecoder.from_pretrained(
                args.bert_model,
                max_position_embeddings=args.max_position_embeddings,
                config_path=args.config_path,
                state_dict=model_recover,
                num_labels=cls_num_labels,
                type_vocab_size=type_vocab_size,
                task_idx=task_idx_proj,
                mask_word_id=mask_word_id,
                search_beam_size=1,
                eos_id=eos_word_ids,
                mode="s2s",
                enable_butd=args.enable_butd,
                len_vis_input=args.len_vis_input,
            )

        del model_recover
        torch.cuda.empty_cache()

    # deprecated
    # from vlp.resnet import resnet
    # cnn = resnet(args.resnet_model, _num_layers=101, _fixed_block=4, pretrained=True) # no finetuning

    if args.fp16:
        model.half()
        # cnn.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    # cnn.to(device)
    if args.local_rank != -1:
        try:
            # from apex.parallel import DistributedDataParallel as DDP
            from torch.nn.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )
        # cnn = DDP(cnn)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)
        # cnn = DataParallelImbalance(cnn)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
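    # Standard BERT recipe: biases and LayerNorm parameters are exempt from weight decay.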
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.01,
        },
        {
            "params": [
                p for n, p in param_optimizer if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    if args.fp16:
        try:
            # from apex.optimizers import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            bias_correction=False,
            max_grad_norm=1.0,
        )
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(
                optimizer, static_loss_scale=args.loss_scale
            )
    else:
        optimizer = BertAdam(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            warmup=args.warmup_proportion,
            schedule=args.sche_mode,
            t_total=t_total,
        )

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(
            os.path.join(args.output_dir, "optim.{0}.bin".format(recover_step))
        )
        if hasattr(optim_recover, "state_dict"):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        model.eval()

        losses = []
        for batch in tqdm(train_dataloader):
            # wrangle batch
            batch = [t.to(device) for t in batch]
            (
                input_ids,
                segment_ids,
                input_mask,
                lm_label_ids,
                masked_pos,
                masked_weights,
                is_next,
                task_idx,
                img,
                vis_masked_pos,
                vis_pe,
                ans_labels,
            ) = batch

            if args.fp16:
                img = img.half()
                vis_pe = vis_pe.half()

            if args.enable_butd:
                conv_feats = img.data  # Bx100x2048
                vis_pe = vis_pe.data
            else:
                conv_feats, _ = cnn(img.data)  # Bx2048x7x7
                conv_feats = (
                    conv_feats.view(conv_feats.size(0), conv_feats.size(1), -1)
                    .permute(0, 2, 1)
                    .contiguous()
                )

            # compute loss
            masked_lm_loss, _, _ = model(
                conv_feats,
                vis_pe,
                input_ids,
                segment_ids,
                input_mask,
                lm_label_ids,
                ans_labels,
                is_next,
                masked_pos=masked_pos,
                masked_weights=masked_weights,
                task_idx=task_idx,
                vis_masked_pos=vis_masked_pos,
                mask_image_regions=args.mask_image_regions,
                drop_worst_ratio=args.max_drop_worst_ratio
            )

            # average across multiple GPUs
            if n_gpu > 1:
                masked_lm_loss = masked_lm_loss.mean()

            losses.append(masked_lm_loss.item())
        
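        # Perplexity is exp of the mean masked-LM loss over all evaluation batches.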
        print(args.split, 'perplexity:', np.exp(np.mean(losses)))
Exemple #9
0
def main():
    parser = argparse.ArgumentParser()

    # General
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--bert_model",
        default="bert-base-cased",
        type=str,
        help=
        "Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased"
    )
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=512,
                        help="max position embeddings")

    # For decoding
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        help="random seed for initialization")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size',
                        type=int,
                        default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size',
                        type=int,
                        default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0,
                        help="Length penalty for beam search")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word',
                        type=str,
                        default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--max_tgt_length',
                        type=int,
                        default=20,
                        help="maximum length of target sequence")

    # Others for VLP
    parser.add_argument("--src_file",
                        default='/mnt/dat/COCO/annotations/dataset_coco.json',
                        type=str,
                        help="The input data file name.")
    parser.add_argument('--dataset',
                        default='coco',
                        type=str,
                        help='coco | flickr30k | cc')
    parser.add_argument('--len_vis_input', type=int, default=100)
    # parser.add_argument('--resnet_model', type=str, default='imagenet_weights/resnet101.pth')
    parser.add_argument('--image_root',
                        type=str,
                        default='/mnt/dat/COCO/images')
    parser.add_argument('--split', type=str, default='val')
    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--enable_butd',
                        action='store_true',
                        help='set to take in region features')
    parser.add_argument(
        '--region_bbox_file',
        default=
        'coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5',
        type=str)
    parser.add_argument(
        '--region_det_file_prefix',
        default=
        'feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval',
        type=str)
    parser.add_argument('--file_valid_jpgs', default='', type=str)

    args = parser.parse_args()

    if args.enable_butd:
        assert (args.len_vis_input == 100)
        args.region_bbox_file = os.path.join(args.image_root,
                                             args.region_bbox_file)
        args.region_det_file_prefix = os.path.join(
            args.image_root, args.region_det_file_prefix) if args.dataset in (
                'cc', 'coco',
                'fake_media') and args.region_det_file_prefix != '' else ''

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    args.max_seq_length = args.max_tgt_length + args.len_vis_input + 3  # +3 for 2x[SEP] and [CLS]
    tokenizer.max_len = args.max_seq_length

    bi_uni_pipeline = []
    bi_uni_pipeline.append(
        seq2seq_loader.Preprocess4Seq2seqDecoder(
            list(tokenizer.vocab.keys()),
            tokenizer.convert_tokens_to_ids,
            args.max_seq_length,
            max_tgt_length=args.max_tgt_length,
            new_segment_ids=args.new_segment_ids,
            mode='s2s',
            len_vis_input=args.len_vis_input,
            enable_butd=args.enable_butd,
            region_bbox_file=args.region_bbox_file,
            region_det_file_prefix=args.region_det_file_prefix))

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    mask_word_id, eos_word_ids = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]"])
    forbid_ignore_set = None
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(tokenizer.convert_tokens_to_ids(w_list))
    print(args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model,
            max_position_embeddings=args.max_position_embeddings,
            config_path=args.config_path,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            type_vocab_size=type_vocab_size,
            task_idx=3,
            mask_word_id=mask_word_id,
            search_beam_size=args.beam_size,
            length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set,
            ngram_size=args.ngram_size,
            min_len=args.min_len,
            enable_butd=args.enable_butd,
            len_vis_input=args.len_vis_input)
        del model_recover

        # from vlp.resnet import resnet
        # cnn = resnet(args.resnet_model, _num_layers=101, _fixed_block=4, pretrained=True) # no finetuning

        if args.fp16:
            model.half()
            # cnn.half()
        model.to(device)
        # cnn.to(device)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)
            # cnn = torch.nn.DataParallel(cnn)

        torch.cuda.empty_cache()
        model.eval()
        # cnn.eval()

        eval_lst = []
        with open(args.src_file, "r", encoding='utf-8') as f_src:
            img_dat = json.load(f_src)['images']
            img_idx = 0
            valid_jpgs = None if (args.file_valid_jpgs == '' or args.dataset in \
                ('coco', 'flickr30k')) else json.load(open(args.file_valid_jpgs))
            for src in img_dat:
                if src['split'] == args.split and (
                        valid_jpgs is None or src['filename'] in valid_jpgs):
                    if args.enable_butd:
                        src_tk = os.path.join(args.image_root,
                                              src.get('filepath', 'trainval'),
                                              src['filename'][:-4] + '.npy')
                    else:
                        src_tk = os.path.join(args.image_root,
                                              src.get('filepath', 'trainval'),
                                              src['filename'])
                    if args.dataset == 'coco':
                        imgid = int(src['filename'].split('_')[2][:-4])
                    elif args.dataset == 'cc' or args.dataset == 'fake_media':
                        imgid = int(src['imgid'])
                    elif args.dataset == 'flickr30k':
                        imgid = int(src['filename'].split('.')[0])
                    eval_lst.append(
                        (img_idx, imgid, src_tk))  # id and path for COCO
                    img_idx += 1
        input_lines = eval_lst
        predictions = {}

        print('start the caption evaluation...')
        total_batch = math.ceil(
            len(input_lines) / args.batch_size) * len(SEED_PHRASES)

        # output_lines = [""] * len(input_lines)
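        # One full decoding pass per entry in SEED_PHRASES; each image accumulates
        # one caption per seed phrase in `predictions`.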
        with tqdm(total=total_batch) as pbar:
            for seed in SEED_PHRASES:
                seed_ids = tokenizer.convert_tokens_to_ids(
                    tokenizer.tokenize(seed))
                next_i = 0
                while next_i < len(input_lines):
                    _chunk = input_lines[next_i:next_i + args.batch_size]
                    buf_id = [x[0] for x in _chunk]
                    buf = [x[2] for x in _chunk]
                    next_i += args.batch_size
                    instances = []
                    for instance in [(x, args.len_vis_input) for x in buf]:
                        for proc in bi_uni_pipeline:
                            instances.append(proc(instance))
                    with torch.no_grad():
                        batch = batch_list_to_batch_tensors(instances)
                        batch = [t.to(device) for t in batch]
                        input_ids, token_type_ids, position_ids, input_mask, task_idx, img, vis_pe = batch

                        if args.fp16:
                            img = img.half()
                            vis_pe = vis_pe.half()

                        if args.enable_butd:
                            conv_feats = img.data  # Bx100x2048
                            vis_pe = vis_pe.data
                        else:
                            conv_feats, _ = cnn(img.data)  # Bx2048x7x7
                            conv_feats = conv_feats.view(
                                conv_feats.size(0), conv_feats.size(1),
                                -1).permute(0, 2, 1).contiguous()

                        traces = model(conv_feats,
                                       vis_pe,
                                       input_ids,
                                       token_type_ids,
                                       position_ids,
                                       input_mask,
                                       task_idx=task_idx,
                                       seed_ids=seed_ids)
                        if args.beam_size > 1:
                            traces = {k: v.tolist() for k, v in traces.items()}
                            output_ids = traces['pred_seq']
                        else:
                            output_ids = traces[0].tolist()
                        for i in range(len(buf)):
                            w_ids = output_ids[i]
                            output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                            output_tokens = []
                            for t in output_buf:
                                if t in ("[SEP]", "[PAD]"):
                                    break
                                output_tokens.append(t)
                            output_sequence = ' '.join(
                                detokenize(output_tokens))
                            img_id = input_lines[buf_id[i]][1]
                            if img_id not in predictions:
                                predictions[img_id] = [output_sequence]
                            else:
                                predictions[img_id].append(output_sequence)
                            # output_lines[buf_id[i]] = output_sequence

                    pbar.update(1)

        # predictions = [{'image_id': tup[1], 'caption': output_lines[img_idx]} for img_idx, tup in enumerate(input_lines)]

        # output captions to file
        output_dir = os.path.dirname(args.model_recover_path)
        caption_output_path = os.path.join(
            output_dir,
            args.split + '_captions_beam' + str(args.beam_size) + '.json')

        with open(caption_output_path, 'w') as caption_out_fp:
            json.dump(predictions, caption_out_fp, indent=4)
Exemple #10
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='coco', type=str, nargs='*', help='')
    parser.add_argument('--lang', default='en zh', type=str, nargs='*', help='')
    parser.add_argument('--max_len_en', default=25, type=int, help='maximum length of English in **bilingual** corpus')
    parser.add_argument('--max_len_zh', default=25, type=int, help='maximum length of Chinese in **bilingual** corpus')
    parser.add_argument('--max_len_en_cap', default=25, type=int, help='maximum length of English in **img2txt** corpus')
    parser.add_argument('--max_len_zh_cap', default=25, type=int, help='maximum length of Chinese in **img2txt** corpus')
    parser.add_argument('--len_vis_input', type=int, default=100, help="The length of visual token input")
    parser.add_argument("--src_file", default='$DATA_ROOT/{}/annotations/{}_dataset.json',
                        type=str, help="The input data file name.")
    parser.add_argument('--split', type=str, default='val')
    parser.add_argument('--file_valid_jpgs', default='$DATA_ROOT/{}/annotations/{}_valid_jpgs.json', type=str)
    parser.add_argument('--image_root', type=str, default='$DATA_ROOT/{}/region_feat_gvd_wo_bgd')
    parser.add_argument('--region_bbox_file', default='raw_bbox/{}_detection_vg_100dets_vlp_checkpoint_trainval_bbox', type=str)
    parser.add_argument('--region_det_file_prefix', default='feat_cls_1000/{}_detection_vg_100dets_vlp_checkpoint_trainval', type=str)

    # General
    parser.add_argument("--config_path", default=None, type=str,
                        help="Bert config file path.")
    parser.add_argument("--bert_model", default="bert-base-cased", type=str,
                        help="Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased.")
    parser.add_argument("--xml_vocab",type=str, default='./download_models/xml_vocab.json')
    parser.add_argument("--xml_merge",type=str, default='./download_models/xml_merges.txt')
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument('--max_position_embeddings', type=int, default=512,
                        help="max position embeddings")

    # For decoding
    #parser.add_argument('--fp16', action='store_true',
     #                   help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp', action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument('--seed', type=int, default=123,
                        help="random seed for initialization")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size', type=int, default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--max_tgt_length', type=int, default=20,
                        help="maximum length of target sequence")


    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--enable_butd', action='store_true',
                        help='set to take in region features')
    parser.add_argument('--output_dir', default='./result', type=str)

    args = parser.parse_args()
    dataset = {}
    for d in args.dataset:
        assert d in ['coco','aic','wmt']
        if d == 'coco':
            dataset[d] = {'max_len_a': args.len_vis_input, 'max_len_b': args.max_len_en_cap}
        elif d == 'aic':
            dataset[d] = {'max_len_a': args.len_vis_input, 'max_len_b': args.max_len_zh_cap}
        else:  # d == 'wmt'
            dataset[d] = {'max_len_a': args.max_len_en, 'max_len_b': args.max_len_zh}
        dataset[d]['max_seq_length'] = dataset[d]['max_len_a'] + dataset[d]['max_len_b'] + 3
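        # +3 accounts for [CLS] and the two [SEP] tokens, as in the scripts above.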
    args.dataset = dataset
    lang2cap_max_seq_length = {'zh':args.len_vis_input+args.max_len_zh_cap+3, 'en':args.len_vis_input+args.max_len_en_cap+3}

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    if args.enable_butd:
        assert(args.len_vis_input == 100)
        args.region_bbox_file = os.path.join(args.image_root, args.region_bbox_file)
        args.region_det_file_prefix = os.path.join(args.image_root, args.region_det_file_prefix) 

    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer_en = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir+'/.pretrained_model')
    if args.max_position_embeddings:
        tokenizer_en.max_len = args.max_position_embeddings
    #tokenizer_en= WhitespaceTokenizer() if args.tokenized_input else tokenizer_en
    tokenizers = {'en':tokenizer_en}
    if 'aic' in args.dataset or 'wmt' in args.dataset:
        tokenizer_zh = XLMTokenizer(args.xml_vocab, args.xml_merge)
        tokenizer_zh.tokenize = lambda x: tokenizer_zh._tokenize(x, lang='zh', bypass_tokenizer=True)
        with open(args.xml_vocab,'r') as f:
            tokenizer_zh.vocab = json.load(f)
        tokenizers['zh'] = tokenizer_zh
    if 'coco_g8_lr3e-5_batch512_ft_from_s0.75_b0.25' in args.model_recover_path:
        indexer = Indexer([os.path.join(args.bert_model,'vocab.txt')])
    else:
        indexer = Indexer([os.path.join(args.bert_model,'vocab.txt'), args.xml_vocab])
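    # Indexer presumably builds a joint token-to-id mapping over the listed vocab
    # files; the special-cased checkpoint appears to use the plain BERT vocab only.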

    for corpus in args.dataset:
        if corpus in ['coco','aic']:
            #bilingual
            for lang in args.lang:
                tokenizer = tokenizers[lang]#tokenizers['en'] if corpus=='coco' else tokenizers['zh']
                max_seq_length = lang2cap_max_seq_length[lang]
                decode_pipeline= [seq2seq_loader.Preprocess4Seq2seqDecoder(
                    corpus,  
                    lang,
                    list(tokenizer.vocab.keys()), indexer, 
                    max_len=max_seq_length,
                    max_tgt_length=args.max_tgt_length, new_segment_ids=args.new_segment_ids,
                    mode='s2s', len_vis_input=args.len_vis_input, enable_butd=args.enable_butd,
                    region_bbox_file=args.region_bbox_file.format(corpus.upper(), corpus.lower()), 
                    region_det_file_prefix=args.region_det_file_prefix.format(corpus.upper(), corpus.lower()))]
                eval_dataset = seq2seq_loader.Img2txtDataset(
                                        args.src_file.format(corpus.upper(), corpus.lower()),
                                        args.image_root.format(corpus.upper()), 
                                        args.split, args.batch_size,
                                        tokenizer,
                                        max_seq_length, 
                                        preprocessed=True,
                                        file_valid_jpgs=args.file_valid_jpgs.format(corpus.upper(), corpus.lower()),
                                        bi_uni_pipeline=decode_pipeline, use_num_imgs=-1,
                                        s2s_prob=1, bi_prob=0,
                                        enable_butd=args.enable_butd, tasks='img2txt')
                args.dataset[corpus][lang+'_eval_dataloader'] = torch.utils.data.DataLoader(
                            eval_dataset, batch_size=args.batch_size,
                            sampler=SequentialSampler(eval_dataset), num_workers=4, 
                            collate_fn=batch_list_to_batch_tensors, pin_memory=True)
        else:
            raise NotImplementedError # only support aic and coco now

    amp_handle = None
    if args.amp:
        from apex import amp
    #    amp_handle = amp.init(enable_caching=True)
    #    logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    #type_vocab_size = 6 if args.new_segment_ids else 2
    type_vocab_size = 12 if args.new_segment_ids else 12
    mask_word_id, eos_word_ids = indexer(
        ["[MASK]", "[SEP]"])
    forbid_ignore_set = None #default None
    relax_projection, task_idx_proj = 0, 3
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(indexer(w_list))
    print(args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(args.bert_model,
            max_position_embeddings=args.max_position_embeddings, config_path=args.config_path,
            state_dict=model_recover, num_labels=cls_num_labels,
            vocab_size=len(indexer),
            type_vocab_size=type_vocab_size, task_idx=3, mask_word_id=mask_word_id, #img2txt
            search_beam_size=args.beam_size, length_penalty=args.length_penalty,
            eos_id=eos_word_ids, forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size, min_len=args.min_len,
            enable_butd=args.enable_butd, len_vis_input=args.len_vis_input)

        del model_recover

        model.to(device)

        if args.amp:
            model = amp.initialize(model, opt_level='O2')
        torch.cuda.empty_cache()
        model.eval()
        for corpus in args.dataset:
            for lang in ['en','zh']:
                if not lang+'_eval_dataloader' in args.dataset[corpus]:
                    continue
                print('corpus {} lang {}'.format(corpus, lang))
                output_lines = {}
                val_iter_bar = tqdm(args.dataset[corpus][lang+'_eval_dataloader'])
                for step_val, val_iter_output in enumerate(val_iter_bar):
                    info_, batch = val_iter_output[0],val_iter_output[1]
                    with torch.no_grad():
                        batch = [t.to(device) for t in batch]
                        input_ids, segment_ids, position_ids, input_mask, task_idx, img, vis_pe = batch
                        # if step_val==0:
                        #     print(segment_ids[0][100:])
                        #     input()
                        #input()
                        if args.enable_butd:
                            conv_feats = img.data # Bx100x2048
                            vis_pe = vis_pe.data
                        else:
                            conv_feats, _ = cnn(img.data) # Bx2048x7x7
                            conv_feats = conv_feats.view(conv_feats.size(0), conv_feats.size(1),
                                -1).permute(0,2,1).contiguous()
                        if args.amp:
                            conv_feats = conv_feats.half()
                            vis_pe = vis_pe.half()

                        traces = model(conv_feats, vis_pe, input_ids, segment_ids, position_ids, input_mask, 
                            search_beam_size=args.beam_size, task_idx=task_idx, sample_mode='greedy') #validation greedy
                        if args.beam_size > 1:
                            traces = {k: v.tolist() for k, v in traces.items()}
                            output_ids = traces['pred_seq']
                        else:
                            output_ids = traces[0].tolist()
                        for ii,w_ids in enumerate(output_ids):
                            output_buf = indexer.convert_ids_to_tokens(w_ids)
                            output_tokens = []
                            for t in output_buf:
                                if t in ("[SEP]", "[PAD]"):
                                    break
                                output_tokens.append(t)
                            output_sequence = ' '.join(detokenize(output_tokens))
                            if corpus=='coco':
                                id_ = int(info_[ii][2].split('_')[2])
                            else:
                                id_ = info_[ii][2]
                            #print(id_,output_sequence)
                            output_lines[id_] = output_sequence
                predictions = [{'image_id': ids_, 'caption': output_lines[ids_]} for ids_ in output_lines]
                with open(os.path.join(args.output_dir,'{}_{}_{}_predictions.json').format(args.split, corpus, lang),'w') as f:
                    json.dump(predictions, f)
                if (corpus=='coco' and lang=='en') or (corpus=='aic' and lang=='zh'):
                    print('Begin evaluating '+corpus)
                    lang_stats = language_eval(corpus, predictions, 
                        args.model_recover_path.split('/')[-2]+'-'+args.split+'-'+args.model_recover_path.split('/')[-1].split('.')[-2], 
                        args.split,
                        ['Bleu','METEOR','Rouge','CIDEr'])
                    with open(os.path.join(args.output_dir,'{}_{}_{}_scores.json').format(args.split, corpus,lang),'w') as f:
                        json.dump(lang_stats, f)
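
The detokenize helper used throughout these examples is imported elsewhere and does not appear in this listing; under that assumption, a minimal stand-in that merges WordPiece '##' continuations back into whole words could look like this:

def detokenize(tokens):
    """Minimal stand-in: glue WordPiece '##' pieces onto the previous token."""
    words = []
    for tok in tokens:
        if tok.startswith('##') and words:
            words[-1] += tok[2:]
        else:
            words.append(tok)
    return words

# ' '.join(detokenize(['a', 'cat', 'sit', '##ting'])) -> 'a cat sitting'
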
Example #11
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")

    # decoding parameters
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument("--input_file", type=str, help="Input file")
    parser.add_argument("--output_file", type=str, help="output file")
    parser.add_argument("--split",
                        type=str,
                        default="",
                        help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        help="random seed for initialization")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size',
                        type=int,
                        default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size',
                        type=int,
                        default=1,
                        help="Beam size for searching")
    parser.add_argument('--top_k',
                        type=int,
                        default=1,
                        help="Top k for output")
    parser.add_argument('--top_kk',
                        type=int,
                        default=0,
                        help="Top k sample method for output")
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0,
                        help="Length penalty for beam search")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word',
                        type=str,
                        default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--mode',
                        default="s2s",
                        choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length',
                        type=int,
                        default=128,
                        help="maximum length of target sequence")

    args = parser.parse_args()

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1."
        )
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    tokenizer.max_len = args.max_seq_length

    pair_num_relation = 0
    bi_uni_pipeline = []
    if args.mode == "s2s" or args.mode == "both":
        bi_uni_pipeline.append(
            seq2seq_loader.Preprocess4Seq2seqDecoder(
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                max_tgt_length=args.max_tgt_length,
                new_segment_ids=args.new_segment_ids,
                mode="s2s"))
    if args.mode == "l2r" or args.mode == "both":
        bi_uni_pipeline.append(
            seq2seq_loader.Preprocess4Seq2seqDecoder(
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                max_tgt_length=args.max_tgt_length,
                new_segment_ids=args.new_segment_ids,
                mode="l2r"))

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
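    # checkpoints trained with --new_segment_ids use 6 segment types; plain BERT checkpoints use 2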
    type_vocab_size = 6 if args.new_segment_ids else 2
    mask_word_id, eos_word_ids = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]"])
    forbid_ignore_set = None
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(tokenizer.convert_tokens_to_ids(w_list))
    print(args.model_recover_path)
    for model_recover_path in glob.glob(args.model_recover_path.strip()):
        logger.info("***** Recover model: %s *****", model_recover_path)
        model_recover = torch.load(model_recover_path)
        model = BertForSeq2SeqDecoder.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=pair_num_relation,
            type_vocab_size=type_vocab_size,
            task_idx=3,
            mask_word_id=mask_word_id,
            search_beam_size=args.beam_size,
            length_penalty=args.length_penalty,
            eos_id=eos_word_ids,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
            forbid_ignore_set=forbid_ignore_set,
            ngram_size=args.ngram_size,
            min_len=args.min_len,
            mode=args.mode,
            max_position_embeddings=args.max_seq_length,
            top_kk=args.top_kk)
        del model_recover

        if args.fp16:
            model.half()
        model.to(device)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        torch.cuda.empty_cache()
        model.eval()
        next_i = 0
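        # token budget left for the source: total length minus the target budget and 2 reserved positions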
        max_src_length = args.max_seq_length - 2 - args.max_tgt_length

        with open(args.input_file, encoding="utf-8") as fin:
            input_lines = [x.strip() for x in fin.readlines()]
        data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
        input_lines = [
            data_tokenizer.tokenize(x)[:max_src_length] for x in input_lines
        ]
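        # sort by descending source length so each batch holds similar-length inputs;
        # the original line index travels along so outputs can be written back in input order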
        input_lines = sorted(list(enumerate(input_lines)),
                             key=lambda x: -len(x[1]))
        output_lines = [""] * len(input_lines)
        score_trace_list = [None] * len(input_lines)
        total_batch = math.ceil(len(input_lines) / args.batch_size)

        with tqdm(total=total_batch) as pbar:
            while next_i < len(input_lines):
                _chunk = input_lines[next_i:next_i + args.batch_size]
                buf_id = [x[0] for x in _chunk]
                buf = [x[1] for x in _chunk]
                next_i += args.batch_size
                max_a_len = max([len(x) for x in buf])
                instances = []
                for instance in [(x, max_a_len) for x in buf]:
                    for proc in bi_uni_pipeline:
                        instances.append(proc(instance))
                with torch.no_grad():
                    batch = long_loader.batch_list_to_batch_tensors(instances)
                    batch = [t.to(device) for t in batch]
                    input_ids, token_type_ids, position_ids, input_mask, task_idx = batch
                    traces = model(input_ids,
                                   token_type_ids,
                                   position_ids,
                                   input_mask,
                                   task_idx=task_idx)
                    if args.beam_size > 1:
                        traces = {k: v.tolist() for k, v in traces.items()}
                        output_ids = traces['pred_seq']
                    else:
                        output_ids = traces.tolist()
                    for i in range(len(buf)):
                        scores = traces['scores'][i]
                        wids_list = traces['wids'][i]
                        ptrs = traces['ptrs'][i]
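                        # scores/wids/ptrs hold the per-step beam-search scores, word ids and back-pointers;
                        # 102 is the id of [SEP] in the standard BERT vocabulary (end of sequence)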
                        eos_id = 102
                        top_k = args.top_k
                        # first we need to find the eos frame where all symbols are eos
                        # any frames after the eos frame are invalid
                        last_frame_id = len(scores) - 1
                        for _i, wids in enumerate(wids_list):
                            if all(wid == eos_id for wid in wids):
                                last_frame_id = _i
                                break
                        frame_id = -1
                        pos_in_frame = -1
                        seqs = []
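                        # scan every frame up to the EOS frame; whenever a hypothesis ends (or the last frame
                        # is reached), follow the back-pointers in ptrs to rebuild the full sequence and
                        # record it together with its score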
                        for fid in range(last_frame_id + 1):
                            for _i, wid in enumerate(wids_list[fid]):
                                if wid == eos_id or fid == last_frame_id:
                                    s = scores[fid][_i]

                                    frame_id = fid
                                    pos_in_frame = _i

                                    if frame_id != -1 and s < 0:
                                        seq = [
                                            wids_list[frame_id][pos_in_frame]
                                        ]
                                        for _fid in range(frame_id, 0, -1):
                                            pos_in_frame = ptrs[_fid][
                                                pos_in_frame]
                                            seq.append(
                                                wids_list[_fid -
                                                          1][pos_in_frame])
                                        seq.reverse()
                                        seqs.append([seq, s])
                        seqs = sorted(seqs, key=lambda x: x[1], reverse=True)
                        w_idss = [seq[0] for seq in seqs[:top_k]]
                        output_sequences = []
                        for w_ids in w_idss:
                            output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                            output_tokens = []
                            for t in output_buf:
                                if t in ("[SEP]", "[PAD]"):
                                    break
                                output_tokens.append(t)
                            output_sequence = ' '.join(
                                detokenize(output_tokens))
                            output_sequences.append(output_sequence)
                        output_lines[buf_id[i]] = output_sequences
                        if args.need_score_traces:
                            score_trace_list[buf_id[i]] = {
                                'scores': traces['scores'][i],
                                'wids': traces['wids'][i],
                                'ptrs': traces['ptrs'][i]
                            }
                pbar.update(1)
        if args.output_file:
            fn_out = args.output_file
        else:
            fn_out = model_recover_path + '.' + args.split
        with open(fn_out, "w", encoding="utf-8") as fout:
            for l in output_lines:
                fout.write('\t'.join(l))
                fout.write("\n")

        if args.need_score_traces:
            with open(fn_out + ".trace.pickle", "wb") as fout_trace:
                pickle.dump({
                    "version": 0.0,
                    "num_samples": len(input_lines)
                }, fout_trace)
                for x in score_trace_list:
                    pickle.dump(x, fout_trace)
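
For reference, the trace pickle written above stores a header dict followed by one record per input line; a small reader that simply mirrors that write order (a sketch, not part of the original script) would be:

import pickle

def read_score_traces(path):
    # The writer above first dumps {"version": ..., "num_samples": ...},
    # then one {'scores', 'wids', 'ptrs'} record per input line, in order.
    with open(path, "rb") as f:
        header = pickle.load(f)
        records = [pickle.load(f) for _ in range(header["num_samples"])]
    return header, records
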
Example #12
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--enable_vis',
                        action='store_true',
                        help='whether to visualize visual embedding')
    parser.add_argument('--num_concept',
                        default=-1,
                        type=int,
                        help='number of concepts to visualize')
    parser.add_argument('--num_vis',
                        default=1,
                        type=int,
                        help='number of visual embeddings per concept')
    parser.add_argument('--dataset',
                        default='txt',
                        type=str,
                        help='txt -> self-customized')
    parser.add_argument('--src_lang', default='en', type=str, help='')
    parser.add_argument('--tgt_lang', default='zh', type=str, help='')
    parser.add_argument(
        '--max_len_en',
        default=25,
        type=int,
        help='maximum length of English in **bilingual** corpus')
    parser.add_argument(
        '--max_len_zh',
        default=25,
        type=int,
        help='maximum length of Chinese in **bilingual** corpus')
    #parser.add_argument("--vocab_file", default='./src.txt', type=str, help="The input data file name.")
    parser.add_argument('--vocab_file', type=str, required=True, nargs='+')
    parser.add_argument('--en_first',
                        action='store_true',
                        help='always to put english as the first sentence')

    # General
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--bert_model",
        default="bert-base-cased",
        type=str,
        help=
        "Bert pre-trained model selected in the list: bert-base-cased, bert-large-cased."
    )
    parser.add_argument("--xml_vocab",
                        type=str,
                        default='./download_models/xml_vocab.json')
    parser.add_argument("--xml_merge",
                        type=str,
                        default='./download_models/xml_merges.txt')
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=512,
                        help="max position embeddings")

    # For decoding
    #parser.add_argument('--fp16', action='store_true',
    #                   help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument('--seed',
                        type=int,
                        default=123,
                        help="random seed for initialization")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size',
                        type=int,
                        default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size',
                        type=int,
                        default=1,
                        help="Beam size for searching")
    parser.add_argument('--length_penalty',
                        type=float,
                        default=0,
                        help="Length penalty for beam search")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word',
                        type=str,
                        default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--ngram_size', type=int, default=3)

    parser.add_argument('--drop_prob', default=0.1, type=float)
    parser.add_argument('--enable_butd',
                        action='store_true',
                        help='set to take in region features')
    parser.add_argument('--output_dir', default='./result', type=str)

    #useless
    parser.add_argument('--split', type=str, default='val')  #wmt?
    parser.add_argument('--len_vis_input',
                        type=int,
                        default=100,
                        help="The length of visual token input")

    args = parser.parse_args()
    print(args.vocab_file)
    if args.enable_vis or '.pkl' in args.vocab_file[0]:
        args.vocab_file = args.vocab_file[0]
        assert '.pkl' in args.vocab_file, args.vocab_file

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    # fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer_en = BertTokenizer.from_pretrained(
        args.bert_model,
        do_lower_case=args.do_lower_case,
        cache_dir=args.output_dir + '/.pretrained_model')
    if args.max_position_embeddings:
        tokenizer_en.max_len = args.max_position_embeddings
    tokenizer_zh = XLMTokenizer(args.xml_vocab, args.xml_merge)
    tokenizer_zh.tokenize = lambda x: tokenizer_zh._tokenize(
        x, lang='zh', bypass_tokenizer=False)
    with open(args.xml_vocab, 'r') as f:
        tokenizer_zh.vocab = json.load(f)
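    # Indexer builds one joint id space from the BERT WordPiece vocab and the XLM vocab;
    # its size is used below as the decoder's vocab_size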
    indexer = Indexer(
        [os.path.join(args.bert_model, 'vocab.txt'), args.xml_vocab])
    with open('full_vocab.json', 'w') as f:
        json.dump(indexer.ids_to_tokens, f)
    tokenizers = {'en': tokenizer_en, 'zh': tokenizer_zh}
    print('tokenizer created')

    concept_list = []
    if args.enable_vis or 'pkl' in args.vocab_file:
        with open(args.vocab_file, 'rb') as f:
            vocab_list = pickle.load(f)
        vocab_list = vocab_list[::-1]
        if args.num_concept == -1:
            args.num_concept = len(vocab_list)
        vocab = []  #[ [En, Zh, Vis1(list), Vis2, Vis3, Vis4,..] *num_concept]
        for key, ls in vocab_list[40:40 + args.num_concept]:
            concept = [ls[0][0], ls[0][1]]  #En, Zh
            concept_list.append(ls[0][0])
            for inst in ls[:args.num_vis]:
                concept.append(inst[2:])  #2,3,4
            vocab.append(concept)
        print('Number of Concept {}'.format(len(vocab)))
        if args.num_concept != -1:
            print(concept_list)
        print('Number of visual instance per concept {}'.format(
            len(vocab[0]) - 2))
    else:
        vocab = []
        for filename in args.vocab_file:
            with open(filename) as f:
                v = f.readlines()
            if args.num_concept == -1:
                args.num_concept = len(v)
            v = [(a.split('    ')[0].strip(), a.split('    ')[-1].strip())
                 for a in v[:args.num_concept]]  #EN ZH
            vocab.extend(v)
        print('Number of vocabulary {} {}'.format(len(vocab), vocab[0]))

    cls_num_labels = 2
    type_vocab_size = 12  # same value whether or not --new_segment_ids is set in this script
    mask_word_id, eos_word_ids = indexer(["[MASK]", "[SEP]"])
    forbid_ignore_set = None  #default None
    relax_projection, task_idx_proj = 0, 3
    print(args.model_recover_path)
    model_recover = torch.load(args.model_recover_path)
    model = BertForSeq2SeqDecoder.from_pretrained(
        args.bert_model,
        max_position_embeddings=args.max_position_embeddings,
        config_path=args.config_path,
        state_dict=model_recover,
        num_labels=2,
        vocab_size=len(indexer),
        type_vocab_size=type_vocab_size,
        task_idx=3,
        mask_word_id=mask_word_id,  #img2txt
        search_beam_size=args.beam_size,
        length_penalty=args.length_penalty,
        eos_id=eos_word_ids,
        forbid_duplicate_ngrams=args.forbid_duplicate_ngrams,
        forbid_ignore_set=forbid_ignore_set,
        ngram_size=args.ngram_size,
        min_len=args.min_len,
        enable_butd=True,
        len_vis_input=args.len_vis_input)
    del model_recover
    model.to(device)
    torch.cuda.empty_cache()
    model.eval()

    N_layer = 12
    embeddings = OrderedDict({
        'en': [[] for i in range(N_layer)],
        'zh': [[] for i in range(N_layer)]
    })
    if args.enable_vis:
        embeddings['vis'] = [[] for i in range(N_layer)]
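    # for every concept pair, encode the English and Chinese word separately (segment id 1 vs 6)
    # and store, per layer, the hidden states averaged over the word's sub-tokens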
    for _, pair in tqdm(enumerate(vocab)):
        for w, lang in zip(pair, ('en', 'zh')):
            segment_id = 1 if lang == 'en' else 6
            w_t = tokenizers[lang].tokenize(w)

            tokens = ['[CLS]'] + w_t + ['[SEP]']
            input_ids = indexer(tokens)
            token_type_ids = [segment_id] * len(input_ids)
            input_ids = np.expand_dims(np.array(input_ids), axis=0)
            token_type_ids = np.expand_dims(np.array(token_type_ids), axis=0)
            input_ids = torch.tensor(input_ids,
                                     dtype=torch.long,
                                     device=device)
            token_type_ids = torch.tensor(token_type_ids,
                                          dtype=torch.long,
                                          device=device)

            output_embeddings = model.compute_embeddings(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                mode='txt2txt')
            #tuple 12 1,L,768
            for i, e in enumerate(output_embeddings):
                e = e.detach().cpu().numpy()
                ave = np.mean(e[0, 1:-1, :], axis=0)  # 768
                embeddings[lang][i].append(ave)

        if args.enable_vis:
            instance_embeddings = [[] for layer_i in range(N_layer)]
            for vis_embed in pair[2:]:
                vis_feats, vis_pe, cls_label = vis_embed[0], vis_embed[1], vis_embed[2]  # 1024
                vis_feats = torch.from_numpy(vis_feats).to(device)
                vis_feats = vis_feats.unsqueeze(0)
                vis_pe = torch.from_numpy(vis_pe).to(device)
                vis_pe = vis_pe.unsqueeze(0)
                cls_label = torch.from_numpy(cls_label).to(device)
                cls_label = cls_label.unsqueeze(0)  #
                # lazy normalization of the coordinates... copy from seq2seq
                w_est = torch.max(vis_pe[:, [0, 2]]) * 1. + 1e-5
                h_est = torch.max(vis_pe[:, [1, 3]]) * 1. + 1e-5
                vis_pe[:, [0, 2]] /= w_est
                vis_pe[:, [1, 3]] /= h_est
                assert h_est > 0, 'should greater than 0! {}'.format(h_est)
                assert w_est > 0, 'should greater than 0! {}'.format(w_est)
                rel_area = (vis_pe[:, 3] - vis_pe[:, 1]) * (vis_pe[:, 2] -
                                                            vis_pe[:, 0])
                rel_area.clamp_(0)
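                # build the visual position embedding: normalized box coordinates, relative area and
                # detector confidence are layer-normed and concatenated with the layer-normed 1601-way class scores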

                vis_pe = torch.cat(
                    (vis_pe[:, :4], rel_area.view(-1, 1), vis_pe[:, 5:]),
                    -1)  # confident score
                normalized_coord = F.normalize(vis_pe.data[:, :5] - 0.5,
                                               dim=-1)
                vis_pe = torch.cat((F.layer_norm(vis_pe, [6]), \
                    F.layer_norm(cls_label, [1601])), dim=-1) # 1601 hard coded... #BL,H

                vis_feats = vis_feats.unsqueeze(0)
                vis_pe = vis_pe.unsqueeze(0)
                #print('input shape', vis_feats.shape, vis_pe.shape)
                segment_id = 0
                tokens = ['[CLS]', '[UNK]', '[SEP]']
                input_ids = indexer(tokens)
                token_type_ids = [segment_id] * len(input_ids)
                input_ids = np.expand_dims(np.array(input_ids), axis=0)
                token_type_ids = np.expand_dims(np.array(token_type_ids),
                                                axis=0)
                input_ids = torch.tensor(input_ids,
                                         dtype=torch.long,
                                         device=device)
                token_type_ids = torch.tensor(token_type_ids,
                                              dtype=torch.long,
                                              device=device)

                vis_embeddings = model.compute_embeddings(
                    vis_feats=vis_feats,
                    vis_pe=vis_pe,
                    input_ids=input_ids,
                    token_type_ids=token_type_ids,
                    mode='img2txt',
                    len_vis_input=1)

                for i, e in enumerate(vis_embeddings):
                    e = e.detach().cpu().numpy()
                    ave = np.mean(e[0, 1:-1, :], axis=0)  # 768
                    instance_embeddings[i].append(ave)
                    #embeddings['vis'][i].append(ave)

            for i, embed_list in enumerate(instance_embeddings):
                if args.num_vis == 1:
                    embeddings['vis'][i].append(embed_list[0])
                else:
                    embeddings['vis'][i].append(embed_list)  #list of array
    # args.output_dir = os.path.join(args.output_dir, 'embedding_vis')
    # if not os.path.exists(args.output_dir):
    #     os.makedirs(args.output_dir)

    for ly in range(N_layer):
        embed = {'en': embeddings['en'][ly], 'zh': embeddings['zh'][ly]}
        if args.enable_vis:
            embed['vis'] = embeddings['vis'][ly]

        #save_numpy(embed,os.path.join(args.output_dir,'hiddenstates_layer_{}.npy'.format(ly)))
        if args.num_vis == 1:
            tSNE_reduce(
                embed, os.path.join(args.output_dir,
                                    'tSNE layer {}'.format(ly)))
            Cosine_Sim(
                embed,
                os.path.join(args.output_dir, 'CosineSim Layer {}'.format(ly)))
            sim_histgram(
                embed,
                os.path.join(args.output_dir, 'Histgram Layer {}'.format(ly)))
        if args.num_vis > 1:
            tSNE_reduce_visual(
                concept_list, embed,
                os.path.join(args.output_dir, 'tSNE layer {}'.format(ly)))
        print('Save layer {}'.format(ly))
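
tSNE_reduce, Cosine_Sim and sim_histgram are plotting/analysis helpers that are not part of this listing; purely as an illustration of what they consume, a minimal cosine-similarity computation over one layer's paired embeddings collected above might look like:

import numpy as np

def mean_pairwise_cosine(embed):
    # embed: one layer's {'en': [vec, ...], 'zh': [vec, ...]} as collected above,
    # where embed['en'][k] and embed['zh'][k] describe the same concept
    en = np.stack(embed['en']).astype(np.float64)
    zh = np.stack(embed['zh']).astype(np.float64)
    en /= np.linalg.norm(en, axis=1, keepdims=True)
    zh /= np.linalg.norm(zh, axis=1, keepdims=True)
    return float(np.mean(np.sum(en * zh, axis=1)))
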
Example #13
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--model_recover_path", default=None, type=str,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")

    # decoding parameters
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--amp', action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument("--input_file", type=str, help="Input file")
    parser.add_argument("--output_file", type=str, help="output file")
    parser.add_argument("--split", type=str, default="",
                        help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--seed', type=int, default=123,
                        help="random seed for initialization")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--batch_size', type=int, default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1,
                        help="Beam size for searching")
    parser.add_argument('--top_k', type=int, default=1,
                        help="Top k for output")
    parser.add_argument('--top_kk', type=int, default=0,
                        help="Top k sample method for output")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search")

    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Forbid the word during forbid_duplicate_ngrams")
    parser.add_argument("--min_len", default=None, type=int)
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--mode', default="s2s",
                        choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length', type=int, default=128,
                        help="maximum length of target sequence")
    
    # evaluate parameters
    parser.add_argument('--do_predict', action='store_true', help="Run decoding/prediction.")
    parser.add_argument("--do_evaluate", action="store_true", help="caculate the scores if have label file")
    parser.add_argument("--label_file", type=str, default="")
    parser.add_argument("--experiment", type=str, default="full", help="full/title/title-l1/hierachical/title-first/title-first-rouge")

    # ranker parameters
    parser.add_argument("--ranker_recover_path", type=str, help="ranker model for extract sentence")
    parser.add_argument("--ranker_max_len", type=int, default=192, help ="max length of the ranker input")
    parser.add_argument("--ranker_batch_size", type=int, default=128)

    args = parser.parse_args()

    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1.")
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError("Maximum tgt length exceeds max seq length - 2.")

    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case)

    tokenizer.max_len = args.max_seq_length

    pair_num_relation = 0
    bi_uni_pipeline = []
    if args.mode == "s2s" or args.mode == "both":
        bi_uni_pipeline.append(seq2seq_loader.Preprocess4Seq2seqDecoder(list(
            tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, max_tgt_length=args.max_tgt_length, new_segment_ids=args.new_segment_ids, mode="s2s"))
    if args.mode == "l2r" or args.mode == "both":
        bi_uni_pipeline.append(seq2seq_loader.Preprocess4Seq2seqDecoder(list(
            tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, max_tgt_length=args.max_tgt_length, new_segment_ids=args.new_segment_ids, mode="l2r"))
    
    if args.experiment == "segsep":
        bi_uni_pipeline = []
        bi_uni_pipeline.append(Preprocess4SegSepDecoder(list(
            tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, max_tgt_length=args.max_tgt_length, new_segment_ids=args.new_segment_ids, mode="s2s"))

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    cls_num_labels = 2
    type_vocab_size = 6 if args.new_segment_ids else 2
    if args.experiment == "segsep":
        type_vocab_size = 11
    mask_word_id, eos_word_ids = tokenizer.convert_tokens_to_ids(
        ["[MASK]", "[SEP]"])
    forbid_ignore_set = None
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(tokenizer.convert_tokens_to_ids(w_list))
    print(args.model_recover_path)
    if args.do_predict:
        for model_recover_path in glob.glob(args.model_recover_path.strip()):
            logger.info("***** Recover model: %s *****", model_recover_path)
            model_recover = torch.load(model_recover_path)
            model = BertForSeq2SeqDecoder.from_pretrained(args.bert_model, state_dict=model_recover, num_labels=cls_num_labels, num_rel=pair_num_relation, type_vocab_size=type_vocab_size, task_idx=3, mask_word_id=mask_word_id, search_beam_size=args.beam_size,
                                                        length_penalty=args.length_penalty, eos_id=eos_word_ids, forbid_duplicate_ngrams=args.forbid_duplicate_ngrams, forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size, min_len=args.min_len, mode=args.mode, max_position_embeddings=args.max_seq_length)
            del model_recover

            if args.fp16:
                model.half()
            model.to(device)
            if n_gpu > 1:
                model = torch.nn.DataParallel(model)

            torch.cuda.empty_cache()
            model.eval()
            next_i = 0
            max_src_length = args.max_seq_length - 2 - args.max_tgt_length

            if args.experiment in ["full", "title", "title-l1"]:
                input_lines = EvalDataset(args.input_file, args.experiment).proc()
            elif args.experiment == "single":
                input_lines, map_dict = EvalDataset(args.input_file, args.experiment).proc()
            elif args.experiment == "title-first":
                input_lines = EvalDataset(args.input_file, args.experiment, tokenizer, args.max_seq_length, args.max_seq_length).proc()
            elif args.experiment == "segsep":
                input_lines = EvalDataset(args.input_file, args.experiment, tokenizer, args.max_seq_length, args.max_seq_length).proc()
            elif args.experiment == "heirachical":
                logger.info("***** Recover rank model: %s *****", args.ranker_recover_path)
                # extract sentences before load data
                # load rank model
                rank_model_recover = torch.load(args.ranker_recover_path, map_location="cpu")
                global_step = 0
                rank_model = BertForSentenceRanker.from_pretrained(args.bert_model, state_dict=rank_model_recover, num_labels=2)
                
                # set model for multi GPUs or multi nodes
                if args.fp16:
                    rank_model.half()
                
                rank_model.to(device)

                if n_gpu > 1:
                    rank_model = DataParallelImbalance(rank_model)
                
                DatasetFunc = ScoreEvalDataset
                
                # Load title + sentence pair
                print ("Loading Rank Dataset from ", args.input_file)
                data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
                max_pred = 16
                mask_prob = 0.7
                rank_bi_uni_pipeline = [Preprocess4Seq2cls(max_pred, mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.ranker_max_len, new_segment_ids=args.new_segment_ids, truncate_config={'max_len_a': 64, 'max_len_b': 16, 'trunc_seg': 'a', 'always_truncate_tail': True}, mask_source_words=False, skipgram_prb=0.0, skipgram_size=1, mask_whole_word=False, mode="s2s", has_oracle=False, num_qkv=0, s2s_special_token=False, s2s_add_segment=False, s2s_share_segment=False, pos_shift=False, eval=True)]
                fn_src = args.input_file
                fn_tgt = None
                eval_dataset = DatasetFunc(
                     fn_src, fn_tgt, args.ranker_batch_size, data_tokenizer, args.ranker_max_len, bi_uni_pipeline=rank_bi_uni_pipeline
                )

                eval_sampler = SequentialSampler(eval_dataset)
                _batch_size = args.ranker_batch_size

                eval_dataloader = torch.utils.data.DataLoader(eval_dataset, batch_size=_batch_size, sampler=eval_sampler, num_workers=24, collate_fn=seq2seq_loader.batch_list_to_batch_tensors, pin_memory=False)


                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
                logger.info("***** Runinning ranker *****")
                logger.info("   Batch size = %d", _batch_size)
                logger.info("   Num steps = %d", int(len(eval_dataset)/ args.ranker_batch_size))

                rank_model.to(device)
                rank_model.eval()

                iter_bar = tqdm(eval_dataloader, desc = "Iter: ")
                num_rank_labels = 2
                all_labels = []
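                # score every (title, sentence) instance with the ranker; the argmax class per instance
                # (1 = keep) is collected and used below to select sentences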
                for step, batch in enumerate(iter_bar):
                    batch = [t.to(device) if t is not None else None for t in batch]
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx = batch
                    logits = rank_model(input_ids, task_idx=task_idx, mask_qkv=mask_qkv)
                    labels = torch.argmax(logits.view(-1, num_rank_labels), dim=-1)
                    all_labels.append(labels)
                
                all_labels_results = []
                for label in all_labels:
                    all_labels_results.extend(label.detach().cpu().numpy())
                
                # collect results
                logger.info("**** Collect results ******")
                clu2doc_dict, doc2sent_dict, all_titles, all_sents = eval_dataset.get_maps()
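                # rebuild each document as its title followed by the sentences the ranker labelled 1,
                # then concatenate the documents of each cluster into a single input line for the decoder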
                all_docs = []
                for i, doc in enumerate(doc2sent_dict):
                    text = all_titles[i]
                    sent_idx = doc2sent_dict[doc]
                    for idx in sent_idx:
                        if all_labels_results[idx] == 1:
                            text += ". " + all_sents[idx]
                    all_docs.append(text)
                
                input_lines = []
                for clu in tqdm(clu2doc_dict):
                    doc_idx = clu2doc_dict[clu]
                    input_line  = ""
                    for idx in doc_idx:
                        input_line += all_docs[idx]
                    input_lines.append(input_line)

            elif args.experiment == "title-first-rank":
                logger.info("***** Recover rank model: %s *****", args.ranker_recover_path)
                # extract sentences before load data
                # load rank model
                rank_model_recover = torch.load(args.ranker_recover_path, map_location="cpu")
                global_step = 0
                rank_model = BertForSentenceRanker.from_pretrained(args.bert_model, state_dict=rank_model_recover, num_labels=2)
                
                # set model for multi GPUs or multi nodes
                if args.fp16:
                    rank_model.half()
                
                rank_model.to(device)

                if n_gpu > 1:
                    rank_model = DataParallelImbalance(rank_model)
                
                DatasetFunc = EvalRankDataset
                
                # Load title + sentence pair
                print ("Loading Rank Dataset from ", args.input_file)
                data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
                max_pred = 16
                mask_prob = 0.7
                rank_bi_uni_pipeline = [Preprocess4Seq2cls(max_pred, mask_prob, list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, new_segment_ids=args.new_segment_ids, truncate_config={'max_len_a': 512, 'max_len_b': 16, 'trunc_seg': 'a', 'always_truncate_tail': True}, mask_source_words=False, skipgram_prb=0.0, skipgram_size=1, mask_whole_word=False, mode="s2s", has_oracle=False, num_qkv=0, s2s_special_token=False, s2s_add_segment=False, s2s_share_segment=False, pos_shift=False, eval=True)]
                fn_src = args.input_file
                fn_tgt = None
                eval_dataset = DatasetFunc(
                     fn_src, fn_tgt, args.ranker_batch_size, data_tokenizer, args.max_seq_length, bi_uni_pipeline=rank_bi_uni_pipeline
                )

                eval_sampler = SequentialSampler(eval_dataset)
                _batch_size = args.ranker_batch_size

                eval_dataloader = torch.utils.data.DataLoader(eval_dataset, batch_size=_batch_size, sampler=eval_sampler, num_workers=24, collate_fn=seq2seq_loader.batch_list_to_batch_tensors, pin_memory=False)

                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
                logger.info("***** Runinning ranker *****")
                logger.info("   Batch size = %d", _batch_size)
                logger.info("   Num steps = %d", int(len(eval_dataset)/ args.ranker_batch_size))

                rank_model.to(device)
                rank_model.eval()

                iter_bar = tqdm(eval_dataloader, desc = "Iter: ")
                num_rank_labels = 2
                all_labels = []
                for step, batch in enumerate(iter_bar):
                    batch = [t.to(device) if t is not None else None for t in batch]
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx = batch
                    # print("input_ids", len(input_ids[0]), "segment_ids", len(segment_ids[0]))
                    with torch.no_grad():
                        logits = rank_model(input_ids, task_idx=task_idx, mask_qkv=mask_qkv)
                    labels = logits.view(-1)
                    all_labels.append(labels)

                
                all_labels_results = []
                for label in all_labels:
                    all_labels_results.extend(label.detach().cpu().numpy())
                
                print("test label results")
                print(all_labels_results[0])
                # collect results
                logger.info("**** Collect results ******")
                clu2sent_dict, all_sents, all_titles= eval_dataset.get_maps()
                all_clusters = []
                input_lines = []
                for i, clu_id in enumerate(clu2sent_dict):
                    text = all_titles[clu_id]
                    sent_idx = clu2sent_dict[clu_id]
                    sents_collect = []
                    for idx in sent_idx:
                        sents_collect.append([all_sents[idx], all_labels_results[idx]])
                    sents_collect_sort = sorted(sents_collect, key=lambda x:x[1])
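                    # keep the sentences ordered by the ranker's score and greedily append them after the
                    # title while the title's token count plus the next sentence still fits in max_seq_length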

                    sents_collect = [x[0] for x in sents_collect_sort]

                    text_tk = tokenizer.tokenize(text)
                    j = 0
                    while j < len(sents_collect) and len(text_tk) + len(tokenizer.tokenize(sents_collect[j])) <= args.max_seq_length:
                        text += " " + sents_collect[j]
                        j += 1
                    
                    input_lines.append(text)                
            else:
                input_lines = []
            
            data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
            input_lines = [data_tokenizer.tokenize(
                x)[:max_src_length] for x in input_lines]
            input_lines = sorted(list(enumerate(input_lines)),
                                key=lambda x: -len(x[1]))
            output_lines = [""] * len(input_lines)
            score_trace_list = [None] * len(input_lines)
            total_batch = math.ceil(len(input_lines) / args.batch_size)

            with tqdm(total=total_batch) as pbar:
                while next_i < len(input_lines):
                    _chunk = input_lines[next_i:next_i + args.batch_size]
                    buf_id = [x[0] for x in _chunk]
                    buf = [x[1] for x in _chunk]
                    next_i += args.batch_size
                    max_a_len = max([len(x) for x in buf])
                    instances = []
                    for instance in [(x, max_a_len) for x in buf]:
                        for proc in bi_uni_pipeline:
                            instances.append(proc(instance))
                    with torch.no_grad():
                        batch = seq2seq_loader.batch_list_to_batch_tensors(
                            instances)
                        # print("batch")
                        # print(batch)
                        # print(len(batch))
                        batch = [t.to(device) for t in batch if t is not None]
                        input_ids, token_type_ids, position_ids, input_mask, task_idx = batch
                        traces = model(input_ids, token_type_ids,
                                    position_ids, input_mask, task_idx=task_idx)
                        if args.beam_size > 1:
                            traces = {k: v.tolist() for k, v in traces.items()}
                            output_ids = traces['pred_seq']
                        else:
                            output_ids = traces.tolist()
                        for i in range(len(buf)):
                            scores = traces['scores'][i]
                            wids_list = traces['wids'][i]
                            ptrs = traces['ptrs'][i]
                            eos_id = 102
                            top_k = args.top_k
                            # first we need to find the eos frame where all symbols are eos
                            # any frames after the eos frame are invalid
                            last_frame_id = len(scores) - 1
                            for _i, wids in enumerate(wids_list):
                                if all(wid == eos_id for wid in wids):
                                    last_frame_id = _i
                                    break
                            frame_id = -1
                            pos_in_frame = -1
                            seqs = []
                            for fid in range(last_frame_id + 1):
                                for _i, wid in enumerate(wids_list[fid]):
                                    if wid == eos_id or fid == last_frame_id:
                                        s = scores[fid][_i]

                                        frame_id = fid
                                        pos_in_frame = _i

                                        if frame_id != -1 and s < 0:
                                            seq = [wids_list[frame_id][pos_in_frame]]
                                            for _fid in range(frame_id, 0, -1):
                                                pos_in_frame = ptrs[_fid][pos_in_frame]
                                                seq.append(wids_list[_fid - 1][pos_in_frame])
                                            seq.reverse()
                                            seqs.append([seq, s])
                            seqs = sorted(seqs, key= lambda x:x[1], reverse=True)
                            w_idss = [seq[0] for seq in seqs[:top_k]]
                            output_sequences = []
                            for w_ids in w_idss:
                                output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                                output_tokens = []
                                for t in output_buf:
                                    if t in ("[SEP]", "[PAD]"):
                                        break
                                    output_tokens.append(t)
                                output_sequence = ' '.join(detokenize(output_tokens))
                                output_sequences.append(output_sequence)
                            output_lines[buf_id[i]] = output_sequences
                            if args.need_score_traces:
                                score_trace_list[buf_id[i]] = {
                                    'scores': traces['scores'][i], 'wids': traces['wids'][i], 'ptrs': traces['ptrs'][i]}
                    pbar.update(1)
            # collect instances after split
            results = []
            if args.experiment == "single":
                for clu in map_dict:
                    record = []
                    clu_ixs = map_dict[clu]
                    for i in clu_ixs:
                        record.extend(output_lines[i])
                    record_top10 = Counter(record).most_common(10)
                    record_top10 = [x[0] for x in record_top10]
                    results.append(record_top10)

                output_lines = results

            if args.output_file:
                fn_out = args.output_file
            else:
                fn_out = model_recover_path+'.'+args.split
            with open(fn_out, "w", encoding="utf-8") as fout:
                for l in output_lines:
                    fout.write('\t'.join(l))
                    fout.write("\n")

            if args.need_score_traces:
                with open(fn_out + ".trace.pickle", "wb") as fout_trace:
                    pickle.dump(
                        {"version": 0.0, "num_samples": len(input_lines)}, fout_trace)
                    for x in score_trace_list:
                        pickle.dump(x, fout_trace)
        
        # Evaluate !
        if args.do_evaluate:
            labels = []
            if not os.path.exists(args.label_file):
                raise ValueError("Label file not exists")
            print("Loading label file from {}".format(args.label_file))
            with open(args.label_file) as f:
                for line in tqdm(f.readlines()):
                    line = line.strip().split("\t")
                    labels.append(line)
            results = output_lines

            ks = [1, 5, 10]
            results_dict = {}
            for k in ks:
                acc_cul = 0
                r_cul = 0
                f1_cul = 0
                cnt = 0
                for predict, true_label in zip(tqdm(results), tqdm(labels)):
                    predict = predict[:k]
                    true_label = true_label[:k]
                    if len(predict) > 0 and len(true_label) > 0:
                        acc_cul += acc_score(predict, true_label)
                        r_cul += recall_score(predict, true_label)
                        f1_cul += f1_score(acc_score(predict, true_label), recall_score(predict, true_label))
                        cnt += 1
                    
                results_dict["P@{}".format(k)] = acc_cul*1.000 / cnt
                results_dict["R@{}".format(k)] = r_cul*1.000 / cnt
                results_dict["F1@{}".format(k)] = f1_cul*1.000 / cnt
            
            print(results_dict)
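
acc_score, recall_score and f1_score are imported helpers that do not appear in this listing; set-overlap versions consistent with how they are called above (an assumption about their actual definitions) could be:

def acc_score(predict, true_label):
    # precision-style overlap: predicted items that also appear in the gold list
    return len(set(predict) & set(true_label)) / len(predict)

def recall_score(predict, true_label):
    # recall-style overlap: gold items that were recovered by the prediction
    return len(set(predict) & set(true_label)) / len(true_label)

def f1_score(p, r):
    # harmonic mean of precision and recall, guarding against p + r == 0
    return 0.0 if p + r == 0 else 2 * p * r / (p + r)
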