def ch_per_example(args, scores_in_one_example, prev_context_tokens, dev_ds,
                   prev_example_idx, ans_dic, ans_idx_dic, offset, out_handle):
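    """Write character-level attributions for one example in the Chinese branch.

    Takes the final score vector for the previous example
    (`scores_in_one_example[-1]`), aligns it to the characters of context +
    title via `match`, attaches the predicted label and feature, and writes
    one JSON line to `out_handle`.
    """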
    total_score = scores_in_one_example[-1]
    assert len(prev_context_tokens) == len(total_score)
    token_score_list = []
    for idx in range(len(total_score)):
        token_score_list.append([idx, offset[idx], total_score[idx]])

    prev_example = dev_ds.data[prev_example_idx]
    char_attribution_dict = match(
        prev_example['context'] + prev_example['title'],
        prev_example['sent_token'], token_score_list)
    result = {}
    result['id'] = prev_example['id']
    result['question'] = prev_example['question']
    result['title'] = prev_example['title']
    result['context'] = prev_example['context'] + prev_example['title']
    result['pred_label'] = ans_dic[str(result['id'])]
    result['pred_feature'] = ans_idx_dic[str(result['id'])]

    result['char_attri'] = collections.OrderedDict()
    for token_info in sorted(char_attribution_dict,
                             key=lambda x: x[2],
                             reverse=True):
        result['char_attri'][str(
            token_info[0])] = [str(token_info[1]),
                               float(token_info[2])]

    out_handle.write(json.dumps(result, ensure_ascii=False) + '\n')


def extract_attention_scores(args, atts, input_ids, tokens, sub_word_id_dict,
                             result, offset, out_handle):
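    """Convert attention weights into per-token attributions and write one JSON line.

    For `roberta*` models, the last layer's attention from the [CLS] position
    is averaged over heads and aligned to characters via `match`; for `lstm`,
    the provided attention vector is used token by token.
    """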
    if args.base_model.startswith('roberta'):
        inter_score = atts[-1][:, :, 0, :].mean(1)  # (bsz, seq)
        inter_score = inter_score[0][1:-1]  # remove CLS and SEP
        input_ids = input_ids[0][1:-1]

    elif args.base_model == 'lstm':
        inter_score = atts[0]
        input_ids = input_ids[0]

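    # Sanity check: the number of positive attention scores should equal the
    # number of real tokens (padding positions are expected to score zero).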
    length = (inter_score > 0).cast('int32').sum(-1).tolist()[0]
    assert len(tokens) == length, f"{len(tokens)} != {length}"

    char_attribution_dict = {}
    # Collect scores for each model type
    if args.base_model.startswith('roberta'):
        assert len(inter_score) == len(offset), \
            f"{len(inter_score)} not equal to {len(offset)}"
        sorted_token = []
        for i in range(len(inter_score)):
            sorted_token.append([i, offset[i], inter_score[i]])

        char_attribution_dict = match(result['context'], result['sent_token'],
                                      sorted_token)

        result['char_attri'] = collections.OrderedDict()
        for token_info in sorted(char_attribution_dict,
                                 key=lambda x: x[2],
                                 reverse=True):
            result['char_attri'][str(
                token_info[0])] = [str(token_info[1]),
                                   float(token_info[2])]
        result.pop('sent_token')
    else:
        if args.language == 'ch':
            idx = 0
            for token, score in zip(tokens, inter_score.numpy().tolist()):
                char_attribution_dict[idx] = (token, score)
                idx += 1
        else:
            idx = 0
            for word, sub_word_score in zip(tokens, inter_score.tolist()):
                char_attribution_dict[idx] = (word, sub_word_score)
                idx += 1

        result['char_attri'] = collections.OrderedDict()
        for token_id, token_info in sorted(char_attribution_dict.items(),
                                           key=lambda x: x[1][1],
                                           reverse=True):
            result['char_attri'][token_id] = token_info

    out_handle.write(json.dumps(result, ensure_ascii=False) + '\n')


def en_per_example(inter_score, result, ans_dic, ans_idx_dic, offset,
                   out_handle):
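    """Write character-level attributions for one example in the English branch.

    Aligns `inter_score` to the context characters via `match`, attaches the
    predicted label and feature, and writes one JSON line to `out_handle`.
    """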
    sorted_token = []
    for i in range(len(inter_score)):
        sorted_token.append([i, offset[i], inter_score[i]])
    char_attribution_dict = match(result['context'], result['sent_token'],
                                  sorted_token)

    result['pred_label'] = ans_dic[str(result['id'])]
    result['pred_feature'] = ans_idx_dic[str(result['id'])]
    result['char_attri'] = collections.OrderedDict()
    for token_info in sorted(char_attribution_dict,
                             key=lambda x: x[2],
                             reverse=True):
        result['char_attri'][str(
            token_info[0])] = [str(token_info[1]),
                               float(token_info[2])]
    result.pop('sent_token')

    out_handle.write(json.dumps(result, ensure_ascii=False) + '\n')


def extract_integrated_gradient_scores(args, atts, input_ids, tokens,
                                       sub_word_id_dict, fwd_args, fwd_kwargs,
                                       model, result, pred_label, err_total,
                                       offset, out_handle):
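    """Compute Integrated Gradients attributions for a single example.

    Runs `n_samples` forward/backward passes along the interpolation path from
    a baseline to the input embedding, approximates the path integral of the
    gradients with the trapezoidal rule, logs the completeness error, writes
    one JSON line, and returns the updated `err_total` list.
    """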
    embedded_grads_list = []
    for i in range(args.n_samples):
        probs, _, embedded = model.forward_interpet(*fwd_args,
                                                    **fwd_kwargs,
                                                    noise='integrated',
                                                    i=i,
                                                    n_samples=args.n_samples)
        predicted_class_prob = probs[0][pred_label]
        predicted_class_prob.backward(retain_graph=False)
        embedded_grad = embedded.grad
        model.clear_gradients()
        embedded_grads_list.append(embedded_grad)

        if i == 0:
            baseline_pred_confidence = probs.tolist()[0][pred_label]  # scalar
            baseline_embedded = embedded  # Tensor(1, seq_len, embed_size)
        elif i == args.n_samples - 1:
            pred_confidence = probs.tolist()[0][pred_label]  # scalar
            pred_embedded = embedded  # Tensor(1, seq_len, embed_size)

    embedded_grads_tensor = paddle.to_tensor(embedded_grads_list,
                                             dtype='float32',
                                             place=paddle.CUDAPlace(0),
                                             stop_gradient=True)

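    # Trapezoidal-rule approximation of the path integral: average each pair of
    # adjacent gradient samples, then take the mean over all interpolation steps.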
    trapezoidal_grads = (embedded_grads_tensor[1:] +
                         embedded_grads_tensor[:-1]) / 2
    integral_grads = trapezoidal_grads.sum(0) / trapezoidal_grads.shape[
        0]  # Tensor(1, seq_len, embed_size)

    inter_score = (pred_embedded - baseline_embedded
                   ) * integral_grads  # Tensor(1, seq_len, embed_size)
    inter_score = inter_score.sum(-1)  # Tensor(1, seq_len)

    # Evaluate the IG completeness error: the attribution sum should match the
    # change in predicted-class confidence between baseline and input.
    delta_pred_confidence = pred_confidence - baseline_pred_confidence
    sum_gradient = inter_score.sum().tolist()[0]
    err = (delta_pred_confidence - sum_gradient +
           1e-12) / (delta_pred_confidence + 1e-12)
    err_total.append(np.abs(err))

    print_str = '%s\t%d\t%.3f\t%.3f\t%.3f\t%.3f'
    print_vals = (result['id'], args.n_samples, delta_pred_confidence,
                  sum_gradient, err, np.average(err_total))
    log.debug(print_str % print_vals)

    inter_score.stop_gradient = True

    char_attribution_dict = {}
    if args.base_model.startswith('roberta'):
        inter_score = inter_score[0][1:-1]
        sorted_token = []
        for i in range(len(inter_score)):
            sorted_token.append([i, offset[i], inter_score[i]])
        char_attribution_dict = match(result['context'], result['sent_token'],
                                      sorted_token)

        result['char_attri'] = collections.OrderedDict()
        for token_info in sorted(char_attribution_dict,
                                 key=lambda x: x[2],
                                 reverse=True):
            result['char_attri'][str(
                token_info[0])] = [str(token_info[1]),
                                   float(token_info[2])]
        result.pop('sent_token')

    elif args.base_model == 'lstm':
        inter_score = inter_score[0]
        idx = 0
        for word, sub_word_score in zip(tokens, inter_score.tolist()):
            char_attribution_dict[idx] = (word, sub_word_score)
            idx += 1

        result['char_attri'] = collections.OrderedDict()
        for token_id, token_info in sorted(char_attribution_dict.items(),
                                           key=lambda x: x[1][1],
                                           reverse=True):
            result['char_attri'][token_id] = token_info

    out_handle.write(json.dumps(result, ensure_ascii=False) + '\n')
    return err_total


def extract_integrated_gradient_scores(args, result, fwd_args, fwd_kwargs,
                                       model, q_tokens, t_tokens, out_handle,
                                       SEP_idx, add_idx, q_offset, t_offset,
                                       err_total, pred_label):
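    """Pairwise (query/title) variant of Integrated Gradients extraction.

    Accumulates gradients of the predicted class w.r.t. the query and title
    embeddings over `n_samples` interpolation steps, aggregates them with the
    IG_* helper functions, aligns the scores to characters, and writes one
    JSON line to `out_handle`.
    """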
    embedded_grads_list = []
    q_embedded_grads_list, t_embedded_grads_list = [], []
    for i in range(args.n_samples):
        probs, _, embedded = model.forward_interpret(*fwd_args,
                                                     **fwd_kwargs,
                                                     noise='integrated',
                                                     i=i,
                                                     n_samples=args.n_samples)
        predicted_class_prob = probs[0][pred_label]
        predicted_class_prob.backward(retain_graph=False)

        if args.base_model.startswith('roberta'):
            embedded_grad = embedded.grad
            embedded_grads_list.append(embedded_grad)
        elif args.base_model == 'lstm':
            q_embedded, t_embedded = embedded
            q_embedded_grad = q_embedded.grad
            t_embedded_grad = t_embedded.grad
            q_embedded_grads_list.append(q_embedded_grad)
            t_embedded_grads_list.append(t_embedded_grad)
        model.clear_gradients()
        if i == 0:
            baseline_pred_confidence = probs.tolist()[0][pred_label]  # scalar
            baseline_embedded = embedded  # Tensor(1, seq_len, embed_size)
        elif i == args.n_samples - 1:
            pred_confidence = probs.tolist()[0][pred_label]  # scalar
            pred_embedded = embedded  # Tensor(1, seq_len, embed_size)

    if args.base_model.startswith('roberta'):
        q_inter_score, t_inter_score = IG_roberta_inter_score(
            args, embedded_grads_list, pred_embedded, baseline_embedded,
            pred_confidence, baseline_pred_confidence, SEP_idx, add_idx,
            err_total)
    elif args.base_model == 'lstm':
        q_inter_score = IG_lstm_inter_score(q_embedded_grads_list,
                                            pred_embedded, baseline_embedded,
                                            0)
        t_inter_score = IG_lstm_inter_score(t_embedded_grads_list,
                                            pred_embedded, baseline_embedded,
                                            1)

    q_char_attribution_dict, t_char_attribution_dict = {}, {}
    if args.base_model.startswith('roberta'):
        # Query
        sorted_token = []
        for i in range(len(q_inter_score)):
            sorted_token.append([i, q_offset[i], q_inter_score[i]])
        q_char_attribution_dict = match(result['query'], result['text_q_seg'],
                                        sorted_token)
        result['query_char_attri'] = collections.OrderedDict()
        for token_info in sorted(q_char_attribution_dict,
                                 key=lambda x: x[2],
                                 reverse=True):
            result['query_char_attri'][str(
                token_info[0])] = [str(token_info[1]),
                                   float(token_info[2])]
        result.pop('text_q_seg')

        # Title
        sorted_token = []
        for i in range(len(t_inter_score)):
            sorted_token.append([i, t_offset[i], t_inter_score[i]])
        t_char_attribution_dict = match(result['title'], result['text_t_seg'],
                                        sorted_token)
        result['title_char_attri'] = collections.OrderedDict()
        for token_info in sorted(t_char_attribution_dict,
                                 key=lambda x: x[2],
                                 reverse=True):
            result['title_char_attri'][str(
                token_info[0])] = [str(token_info[1]),
                                   float(token_info[2])]
        result.pop('text_t_seg')
    else:
        idx = 0
        for token, score in zip(q_tokens, q_inter_score.tolist()):
            q_char_attribution_dict[idx] = (token, score)
            idx += 1
        for token, score in zip(t_tokens, t_inter_score.tolist()):
            t_char_attribution_dict[idx] = (token, score)
            idx += 1

        result['query_char_attri'] = collections.OrderedDict()
        result['title_char_attri'] = collections.OrderedDict()
        for token, attri in sorted(q_char_attribution_dict.items(),
                                   key=lambda x: x[1][1],
                                   reverse=True):
            result['query_char_attri'][token] = attri
        for token, attri in sorted(t_char_attribution_dict.items(),
                                   key=lambda x: x[1][1],
                                   reverse=True):
            result['title_char_attri'][token] = attri

    out_handle.write(json.dumps(result, ensure_ascii=False) + '\n')


def extract_attention_scores(args, result, atts, q_tokens, t_tokens,
                             out_handle, SEP_idx, q_offset, t_offset, add_idx):
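    """Attention-based attributions for the pairwise (query/title) case.

    For `roberta*` models, the [CLS] attention row of the last layer is split
    at `SEP_idx` into query and title segments; for `lstm`, the two
    per-segment attention vectors are used directly. Sorted character
    attributions are written as one JSON line.
    """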
    if args.base_model.startswith('roberta'):
        inter_score = atts[-1][:, :, 0, :].mean(1)  # (bsz, seq)
        q_inter_score = inter_score[0][1:SEP_idx]  # remove CLS and SEP
        t_inter_score = inter_score[0][SEP_idx +
                                       add_idx:-1]  # remove CLS and SEP
    elif args.base_model == 'lstm':
        q_inter_score = atts[0][0]
        t_inter_score = atts[1][0]

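    # Sanity check: positive attention scores should cover exactly the query
    # and title tokens (padding positions are expected to score zero).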
    q_length = (q_inter_score > 0).cast('int32').sum(-1)[0]
    t_length = (t_inter_score > 0).cast('int32').sum(-1)[0]
    assert len(q_tokens) == q_length, f"{len(q_tokens)} != {q_length}"
    assert len(t_tokens) == t_length, f"{len(t_tokens)} != {t_length}"

    q_char_attribution_dict, t_char_attribution_dict = {}, {}
    if args.base_model.startswith('roberta'):
        # Query
        sorted_token = []
        for i in range(len(q_inter_score)):
            sorted_token.append([i, q_offset[i], q_inter_score[i]])
        q_char_attribution_dict = match(result['query'], result['text_q_seg'],
                                        sorted_token)
        result['query_char_attri'] = collections.OrderedDict()
        for token_info in sorted(q_char_attribution_dict,
                                 key=lambda x: x[2],
                                 reverse=True):
            result['query_char_attri'][str(
                token_info[0])] = [str(token_info[1]),
                                   float(token_info[2])]
        result.pop('text_q_seg')

        # Title
        sorted_token = []
        for i in range(len(t_inter_score)):
            sorted_token.append([i, t_offset[i], t_inter_score[i]])
        t_char_attribution_dict = match(result['title'], result['text_t_seg'],
                                        sorted_token)
        result['title_char_attri'] = collections.OrderedDict()
        for token_info in sorted(t_char_attribution_dict,
                                 key=lambda x: x[2],
                                 reverse=True):
            result['title_char_attri'][str(
                token_info[0])] = [str(token_info[1]),
                                   float(token_info[2])]
        result.pop('text_t_seg')

    else:
        idx = 0
        for token, score in zip(q_tokens, q_inter_score.tolist()):
            q_char_attribution_dict[idx] = (token, score)
            idx += 1
        for token, score in zip(t_tokens, t_inter_score.tolist()):
            t_char_attribution_dict[idx] = (token, score)
            idx += 1

        result['query_char_attri'] = collections.OrderedDict()
        result['title_char_attri'] = collections.OrderedDict()
        for token, attri in sorted(q_char_attribution_dict.items(),
                                   key=lambda x: x[1][1],
                                   reverse=True):
            result['query_char_attri'][token] = attri
        for token, attri in sorted(t_char_attribution_dict.items(),
                                   key=lambda x: x[1][1],
                                   reverse=True):
            result['title_char_attri'][token] = attri

    out_handle.write(json.dumps(result, ensure_ascii=False) + '\n')