def construct_interventions(base_sent,
                            professions,
                            tokenizer,
                            DEVICE,
                            structure=None,
                            number=None,
                            subs=None):
    """Build one Intervention per profession.

    Two modes:
      * structure is None -- gender-bias probe: substitutes are
        [profession, "man", "woman"], candidates are ["he", "she"].
      * structure given -- structural (subject-verb agreement) probe:
        substitutes are [profession, sub-noun], candidates are the
        singular/plural verb forms, ordered by `number`.

    Args:
        base_sent: template sentence passed through to Intervention.
        professions: iterable of profession words to substitute.
        tokenizer: tokenizer passed through to Intervention.
        DEVICE: torch device passed through to Intervention.
        structure: one of the "across_*"/"within_*"/"simple_agreement"
            structure names, or None for the gender probe.
        number: "sing" or "pl"; selects candidate verb ordering.
        subs: per-profession substitute nouns, used only for
            structure == "simple_agreement".

    Returns:
        dict mapping profession -> Intervention. Words the tokenizer
        splits into multiple tokens are skipped (best effort).
    """
    interventions = {}
    all_word_count = 0
    used_word_count = 0
    if structure is None:
        for p in professions:
            all_word_count += 1
            try:
                interventions[p] = Intervention(tokenizer,
                                                base_sent, [p, "man", "woman"],
                                                ["he", "she"],
                                                device=DEVICE)
                used_word_count += 1
            except Exception:
                # Best effort: skip words the tokenizer splits apart.
                pass
        print("\t Only used {}/{} professions due to tokenizer".format(
            used_word_count, all_word_count))
        # BUG FIX: this branch previously returned None, silently
        # discarding every intervention that was just built.
        return interventions
    # else we're doing structural interventions
    if structure.startswith("across") or structure == "simple_agreement":
        candidate_sing = "is"
        candidate_pl = "are"
    elif structure.startswith("within"):
        candidate_sing = "likes"
        candidate_pl = "like"
    else:
        # Fail fast with a clear message instead of an UnboundLocalError
        # on candidate_sing/candidate_pl further down.
        raise ValueError("Unknown structure: {}".format(structure))
    # The sub-noun whose number contrasts with the profession's; its
    # position in the sentence depends on the structure.
    if structure == "across_subj_rel":
        sub = base_sent.split()[-1]
    elif structure.startswith("across_obj_rel"):
        sub = base_sent.split()[-2]
    elif structure.startswith("within_obj_rel"):
        sub = base_sent.split()[1]
    for idx, p in enumerate(professions):
        all_word_count += 1
        if structure == "simple_agreement":
            # No fixed sub-noun in the template; use the caller-supplied one.
            sub = subs[idx]
        try:
            if number == "sing":
                interventions[p] = Intervention(tokenizer,
                                                base_sent, [p, sub],
                                                [candidate_pl, candidate_sing],
                                                device=DEVICE,
                                                structure=structure)
            elif number == "pl":
                interventions[p] = Intervention(tokenizer,
                                                base_sent, [p, sub],
                                                [candidate_sing, candidate_pl],
                                                device=DEVICE,
                                                structure=structure)
            used_word_count += 1
        except Exception:
            # Best effort: skip words the tokenizer splits apart.
            pass

    print("\t Only used {}/{} professions due to tokenizer".format(
        used_word_count, all_word_count))
    return interventions
def construct_interventions(base_sent, tokenizer, DEVICE, gender='female'):
    """Build gender-bias Interventions from a stereotyped-professions file.

    Reads experiment_data/professions_{female,male}_stereo.json, whose
    single data line is a list literal of [word, ...] entries, and builds
    an Intervention per word with substitutes [word, "man", "woman"] and
    candidates ["he", "she"].

    Args:
        base_sent: template sentence passed through to Intervention.
        tokenizer: tokenizer passed through to Intervention.
        DEVICE: torch device passed through to Intervention.
        gender: 'female' selects the female-stereotyped list, anything
            else the male-stereotyped list.

    Returns:
        dict mapping word -> Intervention; multi-token words are skipped.
    """
    # Parse the list literal safely; eval() would execute arbitrary code
    # if the data file were ever tampered with.
    from ast import literal_eval
    interventions = {}
    if gender == 'female':
        filename = 'experiment_data/professions_female_stereo.json'
    else:
        filename = 'experiment_data/professions_male_stereo.json'
    with open(filename, 'r') as f:
        all_word_count = 0
        used_word_count = 0
        for l in f:
            # there is only one line that eval's to an array
            for j in literal_eval(l):
                all_word_count += 1
                biased_word = j[0]
                try:
                    interventions[biased_word] = Intervention(
                        tokenizer,
                        base_sent, [biased_word, "man", "woman"],
                        ["he", "she"],
                        device=DEVICE)
                    used_word_count += 1
                except Exception:
                    # Best effort: skip words the tokenizer splits apart.
                    pass

        print("Only used {}/{} neutral words due to tokenizer".format(
            used_word_count, all_word_count))
    return interventions
def construct_interventions(base_sent, professions, tokenizer, DEVICE):
    """Build a gender-bias Intervention for each profession.

    Each Intervention substitutes [profession, "man", "woman"] into
    base_sent and contrasts the candidates ["he", "she"]. Professions
    the tokenizer splits into several tokens are skipped and reported.

    Returns:
        dict mapping profession -> Intervention.
    """
    interventions = {}
    total = 0
    kept = 0
    for profession in professions:
        total += 1
        try:
            probe = Intervention(tokenizer,
                                 base_sent, [profession, "man", "woman"],
                                 ["he", "she"],
                                 device=DEVICE)
        except:
            # Skip words the tokenizer cannot represent as one token.
            continue
        interventions[profession] = probe
        kept += 1
    print("\t Only used {}/{} professions due to tokenizer".format(
        kept, total))
    return interventions
Code example #4
 def to_intervention(self, tokenizer, stat):
     """Create an Intervention contrasting the female/male pronouns.

     Args:
         tokenizer: tokenizer passed through to Intervention.
         stat: which female-percentage statistic to use, 'bergsma' or
             'bls'.

     Returns:
         Intervention whose candidates pair the occupation continuation
         with the majority-gender pronoun.

     Raises:
         ValueError: if stat is neither 'bergsma' nor 'bls'.
     """
     if stat == 'bergsma':
         pct_female = self.bergsma_pct_female
     elif stat == 'bls':
         pct_female = self.bls_pct_female
     else:
         raise ValueError('Invalid: ' + stat)
     # Majority-female occupations pair the occupation continuation with
     # the female pronoun; otherwise the pairing is flipped.
     female_leaning = pct_female > 50
     female_continuation = (self.continuation_occupation if female_leaning
                            else self.continuation_participant)
     male_continuation = (self.continuation_participant if female_leaning
                          else self.continuation_occupation)
     return Intervention(
         tokenizer=tokenizer,
         base_string=self.base_string,
         substitutes=[self.female_pronoun, self.male_pronoun],
         candidates=[female_continuation, male_continuation])
Code example #5
        candidate1_alt_prob, candidate2_alt_prob = model.get_probabilities_for_examples_multitoken(
            x_alt, intervention.candidates_tok)

    odds_base = candidate2_base_prob / candidate1_base_prob
    odds_alt = candidate2_alt_prob / candidate1_alt_prob
    return odds_alt / odds_base


def topk_indices(arr, k):
    """Return the flat indices of the k largest values in arr.

    Indices refer to the flattened array; ties are broken by lower
    index first (stable ascending sort on the negated values).
    """
    ascending_by_negation = (-arr).argsort(axis=None)
    return ascending_by_negation[:k]


if __name__ == "__main__":
    from transformers import GPT2Tokenizer
    from experiment import Intervention, Model
    from pandas import DataFrame
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = Model(output_attentions=True)

    # Test experiment: two candidate pairs over the same base sentence.
    base = "The doctor asked the nurse a question. {}"
    candidate_pairs = [["asked", "answered"], ["requested", "responded"]]
    interventions = [
        Intervention(tokenizer, base, ["He", "She"], pair)
        for pair in candidate_pairs
    ]

    results = perform_interventions(interventions, model)
    report_interventions_summary_by_layer(results)
def load_structural_interventions(tokenizer, device, structure=None):
    """Build the full list of structural (agreement) Interventions for one
    structure: for each number (sing/pl), each template of that number, and
    each profession noun, pair the profession with a contrasting sub-noun
    and the sing/pl verb candidates.

    NOTE(review): the structure=None default is unusable — structure.startswith
    is called unconditionally, so passing None raises AttributeError. Confirm
    all callers supply a structure.
    """
    grammar = read_grammar('structural/grammar.avg')
    # Choose which noun slot supplies the professions: N1 for across-type
    # structures and simple agreement, N2 for within-type structures.
    if structure.startswith("across") or structure == "simple_agreement":
        professions = {'sing': grammar[("N1", frozenset("s"))],
                       'pl':   grammar[("N1", frozenset("p"))]}
    elif structure.startswith("within"):
        professions = {'sing': grammar[("N2", frozenset("s"))],
                       'pl':   grammar[("N2", frozenset("p"))]}
    # Template skeleton per structure; unmatched structure names leave
    # `templates` unbound (NameError below).
    if structure == "simple_agreement":
        templates = StrTemplates("The {}", structure, grammar)
    elif structure == "across_obj_rel":
        templates = StrTemplates("The {} that the {} {}", structure, grammar)
    elif structure == "across_obj_rel_no_that":
        templates = StrTemplates("The {} the {} {}", structure, grammar)
    elif structure == "across_subj_rel":
        templates = StrTemplates("The {} that {} the {}", structure, grammar)
    elif structure == "within_obj_rel":
        templates = StrTemplates("The {} that the {}", structure, grammar)
    elif structure == "within_obj_rel_no_that":
        templates = StrTemplates("The {} the {}", structure, grammar)
    templates = templates.base_strings
    # NOTE(review): intervention_types is never used below — dead code or
    # a value some caller expected to receive? Confirm.
    intervention_types = ["diffnum_direct", "diffnum_indirect"]

    # build list of interventions
    interventions = []
    # Verb candidates: copular is/are for across-type and simple agreement,
    # likes/like for within-type.
    if structure.startswith("across") or structure == "simple_agreement":
        candidate_sing = "is"; candidate_pl = "are"
    elif structure.startswith("within"):
        candidate_sing = "likes"; candidate_pl = "like"

    for number in ('sing', 'pl'):
        if structure == "simple_agreement":
            # For simple agreement the contrasting sub-noun comes from the
            # opposite-number profession list (same index).
            other_number = "sing" if number == "pl" else "pl"
        for template in templates[number]:
            # Pick the sub-noun position in the template.
            # NOTE(review): the within_obj_rel branches below are
            # unreachable — structure.startswith("within") already matches
            # any "within_obj_rel*" name first. The analogous
            # construct_interventions function uses split()[1] for
            # within_obj_rel; confirm which behavior is intended.
            if structure.startswith("within"):
                sub = template.split()[-1]
            elif structure.startswith("across_obj_rel"):
                sub = template.split()[-2]
            elif structure == "across_subj_rel":
                sub = template.split()[-1]
            elif structure.startswith("within_obj_rel"):
                sub = template.split()[1]
            for idx, p in enumerate(professions[number]):
                if structure == "simple_agreement":
                    sub = professions[other_number][idx]
                try:
                    # Candidate order differs by number: the matching form
                    # first for 'sing', mismatching first for 'pl'.
                    if number == "sing":
                        interventions.append(Intervention(
                            tokenizer, template, [p, sub], 
                            [candidate_sing, candidate_pl],
                            device=device, structure=structure
                        ))
                    elif number == "pl":
                        interventions.append(Intervention(
                            tokenizer, template, [p, sub],
                            [candidate_pl, candidate_sing],
                            device=device, structure=structure
                        ))
                # NOTE(review): bare except silently drops tokenizer
                # failures AND any other error — consider except Exception.
                except:
                    pass

    return interventions