def main(_):
    # -------------------- configuration ------------------------- #
    tf.logging.set_verbosity(tf.logging.INFO)
    task_name = FLAGS.task_name.lower()
    processors = {
        "sst-2": extract.Sst2Processor,
        "cola": extract.ColaProcessor,
    }
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)
    tf.gfile.MakeDirs(FLAGS.output_dir)

    # ------------------- preprocess dataset -------------------- #
    label_list = processor.get_labels()
    num_labels = len(label_list)
    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    max_seq_length = FLAGS.max_seq_length

    # prepare valid dataset
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
    if not os.path.exists(eval_file):
        extract.save_tfrecord(eval_examples, label_list, max_seq_length, tokenizer, eval_file)
    else:
        print('eval_tfrecord exists')

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Num examples = %d", len(eval_examples))
    # ----------------------- build model --------------------- #

    # sess1
    bert_model = TextModel(FLAGS.bert_config_file, FLAGS.init_checkpoint, max_seq_length, num_labels)
    bert_model.start_session()

    print('Making explanations...')
    # for (i, example) in enumerate(eval_examples[:1]):
    # ==============================================================================

    sentence = eval_examples[FLAGS.sentence_idx] # the input sentence
    tokens_a = tokenizer.tokenize(sentence.text_a)

    a_len = len(tokens_a)
    # input feature of the sentence to BERT model
    feature = extract.convert_single_example(0, sentence, label_list, max_seq_length, tokenizer)

    seg = (FLAGS.seg_start, FLAGS.seg_end, a_len)  # interval is left-closed, right-open: [seg_start, seg_end)

    print(tokens_a)
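    # get_min_max_shap is assumed to search over coalition partitions of the
    # segment and return the extreme Shapley scores with their partitions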
    min_shap, max_shap, min_clist, max_clist = get_min_max_shap(seg, feature, bert_model)

    print(sentence.text_a)
    print(tokens_a)
    print("MAX:", max_shap, max_clist)
    print("MIN:", min_shap, min_clist)
    print("seg:","(%d, %d)"%(seg[0],seg[1]-1)) # 左右都包含
def eval():
    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    label_map, idx2label = bert_data_utils.read_label_map_file(FLAGS.label_map_file)
    features = bert_data_utils.file_based_convert_examples_to_features(
        FLAGS.test_data_file, label_map, FLAGS.max_sequence_length, tokenizer)

    print('\nEvaluating...\n')

    #Evaluation
    graph = tf.Graph()
    with graph.as_default():
        restore_graph_def = tf.GraphDef()
        with open(FLAGS.model_dir + '/frozen_model.pb', 'rb') as pb_f:
            restore_graph_def.ParseFromString(pb_f.read())
        tf.import_graph_def(restore_graph_def, name='')

        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        sess = tf.Session(config=session_conf)
        
        with sess.as_default():
            #tensors we feed
            input_ids = graph.get_operation_by_name('input_ids').outputs[0]
            input_mask = graph.get_operation_by_name('input_mask').outputs[0]
            token_type_ids = graph.get_operation_by_name('segment_ids').outputs[0]
            is_training = graph.get_operation_by_name('is_training').outputs[0]
            
            #tensors we want to evaluate
            probs = graph.get_operation_by_name('loss/probs').outputs[0]
            scores = graph.get_operation_by_name('loss/logits').outputs[0]
            pred_labels = graph.get_operation_by_name('loss/pred_labels').outputs[0]
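            # these names must match the tensor names recorded when the
            # graph was frozen into frozen_model.pb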

            batches = dataloader.batch_iter(list(features), FLAGS.batch_size, 1, shuffle=False)

            #collect the predictions here
            all_predictions = []
            all_topk = []
            for batch in batches:
                feed_input_ids, feed_input_mask, feed_segment_ids = get_feed_data(batch)

                feed_dict = {input_ids: feed_input_ids,
                             input_mask: feed_input_mask,
                             token_type_ids: feed_segment_ids,
                             is_training: False,}

                batch_probs, batch_scores, batch_pred_labels = sess.run([probs, scores, pred_labels],
                                                                        feed_dict)
                batch_pred_label = np.argmax(batch_probs, -1)
                all_predictions = np.concatenate([all_predictions, batch_pred_label])
                temp = np.argsort(-batch_scores, 1)
                all_topk.extend(temp[:, :3].tolist()) #top 3
                

    raw_examples = list(bert_data_utils.get_data_from_file(FLAGS.test_data_file))
    truth_label_ids = np.array([item.label_id for item in features])  # currently unused
    #write predictions to file
    write_predictions(raw_examples, features, all_predictions, all_topk, idx2label)
    def __init__(self, config):
        #self.top_k_num = config['top_k']
        self.model_dir = config['model_dir']
        self.max_seq_length = config['max_seq_length']
        self.vocab_file = config['vocab_file']
        self.label_map_file = config['label_map_file']
        self.model_checkpoints_dir = config['model_checkpoints_dir']
        self.model_pb_path = config['model_pb_path']

        #init label dict and processors
        label2idx, idx2label = bert_data_utils.read_ner_label_map_file(
            self.label_map_file)
        self.idx2label = idx2label
        self.label2idx = label2idx

        #self.label2code = bert_data_utils.read_code_file(self.code_file)

        self.tokenizer = tokenization.FullTokenizer(vocab_file=self.vocab_file,
                                                    do_lower_case=True)

        #init stop set
        self.stop_set = dataloader.get_stopwords_set(STOPWORD_FILE)

        #use default graph
        self.graph = tf.get_default_graph()
        restore_graph_def = tf.GraphDef()
        with open(self.model_pb_path, 'rb') as pb_f:
            restore_graph_def.ParseFromString(pb_f.read())
        tf.import_graph_def(restore_graph_def, name='')

        session_conf = tf.ConfigProto()
        self.sess = tf.Session(config=session_conf)
        # the frozen graph stores its weights as constants, so there are no
        # variables left for global_variables_initializer() to initialize
        self.sess.run(tf.global_variables_initializer())

        #restore model
        #cp_file = tf.train.latest_checkpoint(self.model_checkpoints_dir)
        #saver = tf.train.import_meta_graph('{}.meta'.format(cp_file))
        #saver.restore(self.sess, cp_file)

        #get the placeholders from graph by name
        self.input_ids_tensor = self.graph.get_operation_by_name(
            'input_ids').outputs[0]
        self.input_mask_tensor = self.graph.get_operation_by_name(
            'input_mask').outputs[0]
        self.segment_ids_tensor = self.graph.get_operation_by_name(
            'segment_ids').outputs[0]
        self.is_training_tensor = self.graph.get_operation_by_name(
            'is_training').outputs[0]

        #tensors we want to evaluate
        self.pred_labels_tensor = self.graph.get_operation_by_name(
            'crf_pred_labels').outputs[0]
        self.probabilities_tensor = self.graph.get_operation_by_name(
            'crf_probs').outputs[0]
        self.logits_tensor = self.graph.get_operation_by_name(
            'logits').outputs[0]
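
    def predict(self, input_ids, input_mask, segment_ids):
        # Hedged sketch, not part of the original class: it shows how the
        # placeholders recovered in __init__ could be fed for one inference
        # call. The method name and argument layout are assumptions; the
        # fetched tensors are the ones named above.
        feed_dict = {
            self.input_ids_tensor: input_ids,
            self.input_mask_tensor: input_mask,
            self.segment_ids_tensor: segment_ids,
            self.is_training_tensor: False,
        }
        pred_labels, probs = self.sess.run(
            [self.pred_labels_tensor, self.probabilities_tensor], feed_dict)
        return pred_labels, probs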
    def __init__(self, config):
        #self.top_k_num = config['top_k']
        self.model_dir = config['model_dir']
        self.max_seq_length = config['max_seq_length']
        self.vocab_file = config['vocab_file']
        self.label_map_file = config['label_map_file']

        self.url = config['tf_serving_url']
        self.signature_name = config['signature_name']

        #init label dict and processors
        label2idx, idx2label = bert_data_utils.read_ner_label_map_file(
            self.label_map_file)
        self.idx2label = idx2label
        self.label2idx = label2idx

        self.tokenizer = tokenization.FullTokenizer(vocab_file=self.vocab_file,
                                                    do_lower_case=True)

        #init stop set
        self.stop_set = dataloader.get_stopwords_set(STOPWORD_FILE)
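
    def predict(self, input_ids, input_mask, segment_ids):
        # Hedged sketch, not in the original source: a minimal TF Serving
        # REST request against the configured url and signature_name. The
        # input keys are assumptions and must match the exported SavedModel
        # signature.
        import requests
        payload = {
            'signature_name': self.signature_name,
            'instances': [{
                'input_ids': input_ids,
                'input_mask': input_mask,
                'segment_ids': segment_ids,
            }],
        }
        resp = requests.post(self.url, json=payload)
        resp.raise_for_status()
        return resp.json()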
def eval():
    f = open(FLAGS.save_file, 'w+')
    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    label_map, idx2label = bert_data_utils.read_label_map_file(FLAGS.label_map_file)
    label_name = get_cid_name(FLAGS.cid3_file)
    batch_datas = bert_data_utils.get_data_yield(FLAGS.test_data_file,
                                                 label_map,
                                                 FLAGS.max_sequence_length,
                                                 tokenizer,
                                                 FLAGS.batch_size)

    print('\nEvaluating...\n')

    #Evaluation
    # checkpoint_file = tf.train.latest_checkpoint(FLAGS.model_dir)
    graph = tf.Graph()
    with graph.as_default():
        #restore for tensorflow pb style
        # restore_graph_def = tf.GraphDef()
        # restore_graph_def.ParseFromString(open(FLAGS.model_dir+'/frozen_model.pb', 'rb').read())
        # tf.import_graph_def(restore_graph_def, name='')

        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        sess = tf.Session(config=session_conf)

        #restore for tf checkpoint style
        cp_file = tf.train.latest_checkpoint(FLAGS.model_dir)
        # saver = tf.train.Saver()
        saver = tf.train.import_meta_graph('{}.meta'.format(cp_file))
        saver.restore(sess,cp_file)
        
        with sess.as_default():
            #tensors we feed
            input_ids = graph.get_operation_by_name('input_ids').outputs[0]
            input_mask = graph.get_operation_by_name('input_mask').outputs[0]
            token_type_ids = graph.get_operation_by_name('segment_ids').outputs[0]
            is_training = graph.get_operation_by_name('is_training').outputs[0]
            
            #tensors we want to evaluate
            # precision =  graph.get_operation_by_name('accuracy/precision').outputs[0]
            # recall = graph.get_operation_by_name('accuracy/recall').outputs[0]
            # f1 = graph.get_operation_by_name('accuracy/f1').outputs[0]
            predictions = graph.get_operation_by_name('loss/predictions').outputs[0]


            #collect the predictions here
            for batch in batch_datas:
                feed_input_ids, feed_input_mask, feed_segment_ids, querys = batch

                feed_dict = {input_ids: feed_input_ids,
                             input_mask: feed_input_mask,
                             token_type_ids: feed_segment_ids,
                             is_training: False,}

                batch_predictions = sess.run(predictions, feed_dict)
                for prediction, query in zip(batch_predictions, querys):
                    predictions_sorted = sorted(prediction, reverse=True)
                    index_sorted = np.argsort(-prediction)
                    label_list = []
                    #label_scores = []
                    label_names = []
                    for index, predict in zip(index_sorted, predictions_sorted):
                        if predict >= FLAGS.threshold:
                            label = idx2label[index]
                            label_list.append(label + ':' + str(predict))
                            #label_scores.append(str(predict))
                            label_names.append(label_name[label])
                    if len(label_list) == 0:
                        label_list.append('0:0')
                        #label_scores.append('0')
                        label_names.append(u'填充类')  # fallback "filler" class

                    f.write(query + '\t' + ','.join(label_list) + '\t' + ','.join(label_names) + '\n')
    f.close()
def main(_):
    # -------------------- configuration ------------------------- #
    tf.logging.set_verbosity(tf.logging.INFO)
    task_name = FLAGS.task_name.lower()
    processors = {
        "sst-2": extract.Sst2Processor,
        "cola": extract.ColaProcessor,
    }
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    if FLAGS.init_checkpoint == "":
        if FLAGS.task_name == "sst-2":
            data_dir = "./GLUE_data/SST-2"
            FLAGS.init_checkpoint = data_dir + "/model/model.ckpt-6313"
        else:
            data_dir = "./GLUE_data/CoLA"
            FLAGS.init_checkpoint = data_dir + "/model/model.ckpt-801"

    processor = processors[task_name]()

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    # ------------------- preprocess dataset -------------------- #
    label_list = processor.get_labels()
    num_labels = len(label_list)
    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)
    max_seq_length = FLAGS.max_seq_length

    # prepare valid dataset
    eval_examples = processor.get_dev_examples(data_dir)

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Num examples = %d", len(eval_examples))
    # ----------------------- build model --------------------- #

    # sess1
    bert_model = TextModel(FLAGS.bert_config_file, FLAGS.init_checkpoint,
                           max_seq_length, num_labels)
    bert_model.start_session()

    print('Making explanations...')
    # for (i, example) in enumerate(eval_examples[:1]):
    # ==============================================================================

    sentence = eval_examples[FLAGS.sentence_idx]  # the input sentence
    tokens_a = tokenizer.tokenize(sentence.text_a)

    a_len = len(tokens_a)
    # input feature of the sentence to BERT model
    feature = extract.convert_single_example(0, sentence, label_list,
                                             max_seq_length, tokenizer)

    seg = (FLAGS.seg_start, FLAGS.seg_end, a_len)
    seg_len = seg[1] - seg[0]
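    # p_mask marks the a_len - 1 gaps between adjacent tokens; only gaps
    # inside the selected segment carry trainable merge probabilities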
    p_mask = np.zeros(a_len - 1)
    p_mask[seg[0]:seg[1] - 1] = 1
    print("\nCurrent words:", tokens_a[seg[0]:seg[1]])

    m_cnt = FLAGS.m_cnt
    g_sample_num = FLAGS.g_sample_num

    #=================================================================================================
    g = tf.Graph()
    with g.as_default():
        sess = tf.Session()

        # summary_writer is assumed to be created at module level, e.g. with
        # tf.summary.FileWriter(...)
        summary_writer.add_graph(sess.graph)

        # one trainable logit per gap between adjacent tokens; the sigmoid
        # below turns each logit into the probability that the two
        # neighbouring tokens merge into one coalition
        tmp = [0.0] * (a_len - 1)
        pij_weights = tf.Variable(tmp)  # initial value of ps, before sigmoid

        # pij_weights = tf.Variable(tf.random.normal([a_len-1]))
        pij_weights_ = tf.sigmoid(pij_weights)  # add sigmoid

        pij_masked = tf.where(
            p_mask > 0, pij_weights_,
            tf.zeros_like(pij_weights_))  # freeze pi out of selected seg

        tf.summary.histogram("pij", pij_masked[seg[0]:seg[1]])
        for i in range(seg_len - 1):
            tf.summary.scalar("p_%d" % i, pij_masked[seg[0] + i])

        p_c = pij_masked[seg[0]:seg[1] - 1]
        # prepend a fixed 0 so that the number of ps matches the number of words
        p_seg = tf.concat([[[0.0]], [p_c]], axis=1)[0, :]

        overall_expected = tf.placeholder(shape=[seg_len, 4], dtype=tf.float32)
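        # the four columns of overall_expected are assumed to hold the sampled
        # score expectations for the coalition states combined in phi_c below,
        # weighted by p_seg and (1 - p_seg)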

        phi_c = overall_expected[:, 0] * p_seg\
                +overall_expected[:, 1] * (1 - p_seg)\
                -overall_expected[:, 2] * p_seg\
                -overall_expected[:, 3] * (1 - p_seg)
        g_score = tf.reduce_sum(phi_c)

        if FLAGS.maximize_shap:
            totloss = tf.negative(g_score)
        else:
            totloss = g_score

        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(FLAGS.lr, global_step, 10,
                                                   1)

        my_opt = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                            momentum=0.9)
        train_step = my_opt.minimize(totloss, global_step=global_step)

        tf.summary.scalar("total_loss", totloss)

        merged_summary_op = tf.summary.merge_all()

        #=======================================================================================================================

        init = tf.global_variables_initializer()
        sess.run(init)
        # loss_list = []

        item_list = [i for i in range(a_len)]

        for epoch in range(FLAGS.epoch_num):
            pij = sess.run(pij_masked)  # numpy ndarray

            clist = pij_coals(pij, seg=seg)

            words = []
            for coal in clist:
                if len(coal) == 1:
                    words.append(tokens_a[coal[0]])
                else:
                    words.append([tokens_a[tok_id] for tok_id in coal])

            print('pij', pij, clist)
            print("coalition:", words)

            score_exp_list = []
            for g_ in range(g_sample_num):
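                # each g is one Bernoulli draw per gap from the current pij,
                # inducing a random partition of the segment into coalitions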
                g_sample = g_sample_bern(pij)  # sample g

                g_clist = pij_coals(
                    g_sample, seg=seg)  # partition the coalition based on g
                score_exp_items = []
                score_item = [0.0, 0.0]

                for cIdx, coal in enumerate(g_clist):
                    # new_list, cIdx = get_new_list(item, item_list)

                    if coal[0] < seg[0] or coal[0] >= seg[1]:  # out of the seg
                        continue

                    positions_dict = get_masks_sampleshapley(
                        g_clist, cIdx, a_len, m_cnt)  # sample S
                    positions_dict = exclude_mask(positions_dict, coal, seg)

                    scores_c_s, scores_c_si = compute_scores_seperate(
                        positions_dict, feature, a_len, bert_model.predict)
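                    # scores_c_s / scores_c_si are assumed to be the model
                    # scores with the coalition absent vs. present in each
                    # sampled context S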

                    score_item[0] += np.mean(scores_c_si)
                    score_item[1] += np.mean(scores_c_s)

                score_item[0] /= seg_len
                score_item[1] /= seg_len

                for idx, item in enumerate(item_list[seg[0]:seg[1]]):
                    score_exp = compute_sum(score_item[1], score_item[0],
                                            g_sample, item)
                    score_exp_items.append(score_exp)

                score_exp_list.append(score_exp_items)

            overall_exp_score = cal_overall_exp(score_exp_list)

            in_dict = {overall_expected: overall_exp_score}

            _, _loss, summary_str, lr, g_score_ = sess.run(
                [train_step, totloss, merged_summary_op, learning_rate, g_score],
                feed_dict=in_dict)

            summary_writer.add_summary(summary_str, epoch)

            print('epoch:', epoch, '-->loss:', _loss, '-->learning_rate:', lr,
                  "\n")
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    task_name = FLAGS.task_name.lower()
    processors = {
        "sst-2": extract.Sst2Processor,
        "cola": extract.ColaProcessor,
    }
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))
    if FLAGS.task_name == "sst-2":
        FLAGS.data_dir = "data/sst-2"
        FLAGS.init_checkpoint = "models/sst-2/model.ckpt-6313"
    elif FLAGS.task_name == "cola":
        FLAGS.data_dir = "data/cola"
        FLAGS.init_checkpoint = "models/cola/model.ckpt-801"

    processor = processors[task_name]()

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    # ------------------- preprocess dataset -------------------
    label_list = processor.get_labels()
    num_labels = len(label_list)
    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)
    max_seq_length = FLAGS.max_seq_length

    eval_examples = processor.get_dev_examples(FLAGS.data_dir)

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Num examples = %d", len(eval_examples))

    # ----------------------- build models ---------------------
    tf.reset_default_graph()
    model = TextModel(FLAGS.bert_config_file, FLAGS.init_checkpoint,
                      max_seq_length, num_labels)
    model.start_session()

    method = FLAGS.method
    if method not in ['SampleShapley', 'Singleton']:
        raise ValueError("Unsupported Shapley method: %s" % method)
    print(method)
    print('Making explanations...')

    for (i, example) in enumerate(eval_examples):
        st = time.time()
        print('explaining the {}th sample...'.format(i))
        tokens_a = tokenizer.tokenize(example.text_a)
        a_len = len(tokens_a)
        print(tokens_a)
        print('tokens length is', a_len)
        feature = extract.convert_single_example(i, example, label_list,
                                                 max_seq_length, tokenizer)

        pre_slist = list(range(a_len))
        output_tree = list(range(a_len))
        tree_values = []
        # construct a binary tree by greedily merging, at each level, the
        # adjacent pair with the largest interaction ratio (argmax below)
        for h in range(a_len - 1):
            pre_slen = len(pre_slist)
            totcombs = []
            ratios = []
            stn = {}

            # compute B (joint score minus the sum of member Shapley values)
            # and phi_{a,b,...} for each current node
            tot_values = {}
            for k in range(pre_slen):
                scores = compute_scores(pre_slist, k, feature, a_len,
                                        model.predict, method)
                if len(scores) == 2:
                    b = 0
                    subtree = [b, scores[0:1]]
                else:
                    b = scores[0] - np.sum(scores[1:])
                    subtree = [b, scores[1:]]
                tot_values[k] = subtree

            locs = []
            for j in range(pre_slen - 1):
                coal = turn_list(pre_slist[j]) + turn_list(pre_slist[j + 1])
                now_slist = pre_slist[:j]  # elems before j
                now_slist.append(coal)
                if j + 2 < pre_slen:
                    now_slist = now_slist + pre_slist[j + 2:]  # elems after j+1

                totcombs.append(now_slist)
                # compute shapley values of now pair combination
                score = compute_scores(now_slist, j, feature, a_len,
                                       model.predict, method)
                nowb = score[0] - np.sum(score[1:])
                nowphis = score[1:]

                lt = tot_values[j][1]
                rt = tot_values[j + 1][1]
                avgphis = (nowphis + np.concatenate((lt, rt))) / 2
                len_lt = lt.shape[0]

                b_lt = tot_values[j][0]
                b_rt = tot_values[j + 1][0]
                b_local = nowb - b_lt - b_rt
                contri_lt = b_lt + np.sum(avgphis[:len_lt])
                contri_rt = b_rt + np.sum(avgphis[len_lt:])

                # additional two metrics
                extra_pre_slist_l = list(pre_slist)
                extra_pre_slist_l.pop(j + 1)
                extra_score_l = compute_scores(extra_pre_slist_l, j, feature,
                                               a_len, model.predict, method)
                psi_intra_l = extra_score_l[0] - np.sum(extra_score_l[1:])
                psi_intra_l = psi_intra_l - b_lt

                extra_pre_slist_r = list(pre_slist)
                extra_pre_slist_r.pop(j)
                extra_score_r = compute_scores(extra_pre_slist_r, j, feature,
                                               a_len, model.predict, method)
                psi_intra_r = extra_score_r[0] - np.sum(extra_score_r[1:])
                psi_intra_r = psi_intra_r - b_rt
                psi_intra = (psi_intra_l + psi_intra_r)
                psi_inter = b_local - psi_intra
                t = abs(psi_inter) / (abs(psi_intra) + abs(psi_inter))
                # end additional metrics

                # locs[j] = [B_between, contri_left, contri_right, B_left,
                #            B_right, t, B([S])]; see the stn dict below
                locs.append(
                    [b_local, contri_lt, contri_rt, b_lt, b_rt, t, nowb])

            for j in range(pre_slen - 1):
                loss = 0.0
                if j - 1 >= 0:
                    loss = loss + abs(locs[j - 1][0])

                if j + 2 < pre_slen:
                    loss = loss + abs(locs[j + 1][0])

                all_info = loss + abs(locs[j][0]) + abs(locs[j][1]) + abs(
                    locs[j][2])
                metric = abs(locs[j][0]) / all_info
                sub_metric = loss / all_info

                ratios.append(metric)

                stn[j] = {
                    'r': metric,
                    's': sub_metric,
                    'Bbetween': locs[j][0],
                    'Bl': locs[j][3],
                    'Br': locs[j][4],
                    't': locs[j][5],
                    'B([S])': locs[j][6],
                }
            stn['base_B'] = tot_values
            coalition = np.argmax(np.array(ratios))
            pre_slist = totcombs[coalition]
            stn['maxIdx'] = coalition
            stn['after_slist'] = pre_slist
            print('coalition:', coalition)
            print('after_slist:', pre_slist)

            tree_values.append(stn)

            # generate a new nested list by adding elements into an empty list
            tmp_list = []
            for z in range(len(output_tree)):
                if z == coalition:
                    tmp_list.append(
                        list((output_tree[z], output_tree[z + 1], stn[z])))
                elif z == coalition + 1:
                    continue
                else:
                    tmp_list.append(output_tree[z])
            output_tree = tmp_list.copy()

        save_path = 'binary_trees/' + FLAGS.task_name.upper()
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        save_pkl = save_path + '/stn_' + str(i) + '.pkl'
        with open(save_pkl, "wb") as f:
            contents = {
                "sentence": tokens_a,
                "tree": output_tree,
                "tree_values": tree_values,
            }
            pickle.dump(contents, f)

        print('Time spent is {}'.format(time.time() - st))
def main(_):
    # -------------------- configuration ------------------------- #
    tf.logging.set_verbosity(tf.logging.INFO)
    task_name = FLAGS.task_name.lower()
    processors = {
        "sst-2": extract.Sst2Processor,
        "cola": extract.ColaProcessor,
    }
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    # ------------------- preprocess dataset -------------------- #
    label_list = processor.get_labels()
    num_labels = len(label_list)
    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    max_seq_length = FLAGS.max_seq_length

    # prepare valid dataset
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Num examples = %d", len(eval_examples))
    # ----------------------- build model --------------------- #

    # sess1
    bert_model = TextModel(FLAGS.bert_config_file, FLAGS.init_checkpoint, max_seq_length, num_labels)
    bert_model.start_session()

    print('Making explanations...')
    # for (i, example) in enumerate(eval_examples[:1]):
    # ==============================================================================

    res = []
    res.append({
        "lr": FLAGS.lr,
        "g_sample_num": g_sample_num,
        "m_cnt": m_cnt,
        "epoch_num": FLAGS.epoch_num,
        "maximize": FLAGS.maximize_shap
    })

    count = 0
    for id, sentence in enumerate(eval_examples):
        dic = {}
        tokens_a = tokenizer.tokenize(sentence.text_a)

        dic["id"] = id
        dic["tokens"] = tokens_a

        a_len = len(tokens_a)
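        # min_len / max_len are assumed to be module-level bounds used to
        # skip sentences that are too short or too long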
        if a_len < min_len or a_len > max_len:
            continue
        count += 1
        print(count)

        print(id, tokens_a)

        # sample a random segment [seg[0], seg[1]) of length seg_len
        seg_len = random.choice(seg_len_range)
        seg = [0, 0, a_len]
        seg[0] = random.choice(range(a_len - seg_len))
        seg[1] = seg[0] + seg_len

        # print(seg,"\n\n\n\n")

        # input feature of the sentence to BERT model
        feature = extract.convert_single_example(0, sentence, label_list, max_seq_length, tokenizer)

        # print("\nCurrent words:", tokens_a[seg[0]:seg[1]])

        dic["seg"] = seg

        FLAGS.maximize_shap = True
        opt_res_1 = manage_a_sentence(tokens_a, seg, feature, bert_model)
        FLAGS.maximize_shap = False
        opt_res_2 = manage_a_sentence(tokens_a, seg, feature, bert_model)

        opt_res = []
        for i in range(len(opt_res_1)):
            # opt_res_1 maximized the score (its loss is -g_score) and
            # opt_res_2 minimized it, so the expression below is the
            # max-min gap of g_score
            item = {"p_max": opt_res_1[i]["p"],
                    "p_min": opt_res_2[i]["p"],
                    "loss": -1 * opt_res_1[i]["loss"] - opt_res_2[i]["loss"]}
            opt_res.append(item)

        dic["opt_res"] = opt_res

        min_gt_score, max_gt_score, min_gt_part, max_gt_part = get_min_max_shap(seg, feature, bert_model)
        gt_score = max_gt_score - min_gt_score
        dic["gt_score"] = gt_score

        difference = []
        # l_step is assumed to be a module-level window size: average |loss|
        # over each window of epochs and compare with the ground-truth gap
        for i in range(FLAGS.epoch_num // l_step):
            opt_score = 0
            for j in range(i * l_step, (i + 1) * l_step):
                opt_score += abs(opt_res[j]["loss"])
            opt_score /= l_step

            difference.append(abs(gt_score - opt_score))

        dic["difference"] = difference
        res.append(dic)

        print("gt_score:", gt_score)
        with open('difference_%s_bert.json' % FLAGS.task_name, 'w') as f:
            json.dump(res, f)
def main(_):
    # -------------------- configuration ------------------------- #
    tf.logging.set_verbosity(tf.logging.INFO)
    task_name = FLAGS.task_name.lower()
    processors = {
        "sst-2": extract.Sst2Processor,
        "cola": extract.ColaProcessor,
    }
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    # ------------------- preprocess dataset -------------------- #
    label_list = processor.get_labels()
    num_labels = len(label_list)
    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)
    max_seq_length = FLAGS.max_seq_length

    # prepare valid dataset
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Num examples = %d", len(eval_examples))
    # ----------------------- build model --------------------- #
    # sess1
    bert_models = []
    for layer_id in layers:
        bert_model = TextModel(FLAGS.bert_config_file, FLAGS.init_checkpoint,
                               max_seq_length, num_labels, layer_id)
        bert_models.append(bert_model)

    print('Making explanations...')
    # for (i, example) in enumerate(eval_examples[:1]):
    # ==============================================================================

    res = []
    res.append({
        "lr": FLAGS.lr,
        "g_sample_num": g_sample_num,
        "m_cnt": m_cnt,
        "epoch_num": FLAGS.epoch_num,
        "maximize": FLAGS.maximize_shap
    })

    # with open("difference_sst_elmo.json","r") as f:
    #     res = json.load(f)
    # count = len(res) - 1
    # start = res[-1]["id"] + 1
    start = 0
    count = 0
    for i, sentence in enumerate(eval_examples[start:]):
        id = i + start
        dic = {}

        tokens_a = tokenizer.tokenize(sentence.text_a)
        feature = extract.convert_single_example(0, sentence, label_list,
                                                 max_seq_length, tokenizer)

        dic["id"] = id
        dic["tokens"] = tokens_a

        a_len = len(tokens_a)
        if a_len < min_len or a_len > max_len:
            continue
        count += 1
        print(count)

        print(id, tokens_a)

        seg_len = random.choice(seg_len_range)
        seg = [0, 0, a_len]
        seg[0] = random.choice(range(a_len - seg_len))
        seg[1] = seg[0] + seg_len

        dic["seg"] = seg

        for layer_idx, layer in enumerate(layers):
            # use a distinct loop variable so the sentence-level `id` defined
            # above is not shadowed
            bert_models[layer_idx].start_session()

            print(layer_idx, layer, "\n\n\n\n")
            layer_res = {}
            FLAGS.maximize_shap = True
            opt_res_max = manage_a_sentence(tokens_a, seg, feature,
                                            bert_models[layer_idx])

            FLAGS.maximize_shap = False
            opt_res_min = manage_a_sentence(tokens_a, seg, feature,
                                            bert_models[layer_idx])

            layer_res["opt_res_max"] = opt_res_max
            layer_res["opt_res_min"] = opt_res_min

            dic[layer] = layer_res

            bert_models[layer_idx].close_session()

        res.append(dic)

        with open('interaction_%s_bert_layer.json' % FLAGS.task_name,
                  'w') as f:
            json.dump(res, f)