def test_load_word_pairs(self):
    """Check that ``load_word_pairs`` yields interleaved IDs for the two sides of each pair.

    The first word of every pair is expected at even vocabulary IDs starting at 2.
    The second word is expected at the odd IDs starting at 3.
    """
    first_ids, second_ids = load_word_pairs(
        self.pairs_fname, WhitespaceTokenizer(), self.pairs_vocab, "tokens"
    )

    def to_flat(id_tensors):
        # Each entry must be a single-element tensor; ``.item()`` extracts its value.
        return torch.tensor([t.item() for t in id_tensors])

    # NOTE(review): IDs below 2 are assumed to be reserved by the vocabulary
    # (an earlier comment attributed them to [CLS]/[SEP]) — confirm against
    # how ``self.pairs_vocab`` is built.
    actual_first = to_flat(first_ids)
    expected_first = torch.arange(2, self.num_pairs + 2, step=2)
    assert torch.equal(actual_first, expected_first)

    actual_second = to_flat(second_ids)
    expected_second = torch.arange(3, self.num_pairs + 3, step=2)
    assert torch.equal(actual_second, expected_second)
def __init__(
    self,
    embedding_layer: torch.nn.Embedding,
    seed_word_pairs_file: Union[PathLike, str],
    tokenizer: Tokenizer,
    mitigator_vocab: Optional[Vocabulary] = None,
    namespace: str = "tokens",
):
    """Load the seed word pairs and create an `INLPBiasMitigator`.

    # Parameters

    embedding_layer : `torch.nn.Embedding`
        Accepted but not referenced by this constructor.
        NOTE(review): presumably kept so the signature matches the other
        mitigator wrappers / config schema — confirm before removing.
    seed_word_pairs_file : `Union[PathLike, str]`
        Word-pairs file, forwarded to `load_word_pairs`.
    tokenizer : `Tokenizer`
        Forwarded to `load_word_pairs`.
    mitigator_vocab : `Optional[Vocabulary]`, optional (default=`None`)
        Forwarded to `load_word_pairs`.
    namespace : `str`, optional (default=`"tokens"`)
        Vocabulary namespace, forwarded to `load_word_pairs`.
    """
    first_ids, second_ids = load_word_pairs(
        seed_word_pairs_file, tokenizer, mitigator_vocab, namespace
    )
    # Keep both sides of the seed pairs for later use by the mitigator.
    self.ids1 = first_ids
    self.ids2 = second_ids
    self.mitigator = INLPBiasMitigator()
def __init__(
    self,
    seed_word_pairs_file: Union[PathLike, str],
    tokenizer: Tokenizer,
    direction_vocab: Optional[Vocabulary] = None,
    namespace: str = "tokens",
    noise: float = 1e-10,
):
    """Load the seed word pairs and create a `ClassificationNormalBiasDirection`.

    # Parameters

    seed_word_pairs_file : `Union[PathLike, str]`
        Word-pairs file, forwarded to `load_word_pairs`.
    tokenizer : `Tokenizer`
        Forwarded to `load_word_pairs`.
    direction_vocab : `Optional[Vocabulary]`, optional (default=`None`)
        Forwarded to `load_word_pairs`.
    namespace : `str`, optional (default=`"tokens"`)
        Vocabulary namespace, forwarded to `load_word_pairs`.
    noise : `float`, optional (default=`1e-10`)
        Stored as `self.noise`. It is consumed outside this constructor.
    """
    first_ids, second_ids = load_word_pairs(
        seed_word_pairs_file, tokenizer, direction_vocab, namespace
    )
    self.ids1 = first_ids
    self.ids2 = second_ids
    self.direction = ClassificationNormalBiasDirection()
    self.noise = noise
def __init__(
    self,
    seed_word_pairs_file: Union[PathLike, str],
    tokenizer: Tokenizer,
    direction_vocab: Optional[Vocabulary] = None,
    namespace: str = "tokens",
    requires_grad: bool = False,
    noise: float = 1e-10,
):
    """Load the seed word pairs and create a `TwoMeansBiasDirection`.

    # Parameters

    seed_word_pairs_file : `Union[PathLike, str]`
        Word-pairs file, forwarded to `load_word_pairs`.
    tokenizer : `Tokenizer`
        Forwarded to `load_word_pairs`.
    direction_vocab : `Optional[Vocabulary]`, optional (default=`None`)
        Forwarded to `load_word_pairs`.
    namespace : `str`, optional (default=`"tokens"`)
        Vocabulary namespace, forwarded to `load_word_pairs`.
    requires_grad : `bool`, optional (default=`False`)
        Passed to the `TwoMeansBiasDirection` constructor.
    noise : `float`, optional (default=`1e-10`)
        Stored as `self.noise`. It is consumed outside this constructor.
    """
    first_ids, second_ids = load_word_pairs(
        seed_word_pairs_file, tokenizer, direction_vocab, namespace
    )
    self.ids1 = first_ids
    self.ids2 = second_ids
    self.direction = TwoMeansBiasDirection(requires_grad=requires_grad)
    self.noise = noise
def __init__(
    self,
    bias_direction: BiasDirectionWrapper,
    embedding_layer: torch.nn.Embedding,
    equalize_word_pairs_file: Union[PathLike, str],
    tokenizer: Tokenizer,
    mitigator_vocab: Optional[Vocabulary] = None,
    namespace: str = "tokens",
    requires_grad: bool = True,
):
    """Fix the bias direction up front, load equalize pairs, and create a `HardBiasMitigator`.

    # Parameters

    bias_direction : `BiasDirectionWrapper`
        Stored as `self.bias_direction` and called once with `embedding_layer`.
        The result is kept as `self.predetermined_bias_direction`.
    embedding_layer : `torch.nn.Embedding`
        Argument for the one-off `bias_direction` call.
    equalize_word_pairs_file : `Union[PathLike, str]`
        Word-pairs file, forwarded to `load_word_pairs`.
    tokenizer : `Tokenizer`
        Forwarded to `load_word_pairs`.
    mitigator_vocab : `Optional[Vocabulary]`, optional (default=`None`)
        Forwarded to `load_word_pairs`.
    namespace : `str`, optional (default=`"tokens"`)
        Vocabulary namespace, forwarded to `load_word_pairs`.
    requires_grad : `bool`, optional (default=`True`)
        Passed to the `HardBiasMitigator` constructor.
    """
    # The direction is computed once here instead of being recomputed later.
    self.bias_direction = bias_direction
    self.predetermined_bias_direction = self.bias_direction(embedding_layer)

    first_ids, second_ids = load_word_pairs(
        equalize_word_pairs_file, tokenizer, mitigator_vocab, namespace
    )
    self.ids1 = first_ids
    self.ids2 = second_ids
    self.mitigator = HardBiasMitigator(requires_grad=requires_grad)