def __init__(self, config):
    super(SpanAttentionLayer, self).__init__()
    self.attention = SpanAttention(config)
    self.intermediate = BertIntermediate(config)
    self.output = BertOutput(config)
    init_bert_weights(self.intermediate, config.initializer_range)
    init_bert_weights(self.output, config.initializer_range)
def __init__(self,
             output_feed_forward_hidden_dim: int = 100,
             weighted_entity_threshold: float = None,
             null_embedding: torch.Tensor = None,
             initializer_range: float = 0.02):
    super().__init__()

    # layers for the dot product attention
    self.out_layer_1 = torch.nn.Linear(2, output_feed_forward_hidden_dim)
    self.out_layer_2 = torch.nn.Linear(output_feed_forward_hidden_dim, 1)
    init_bert_weights(self.out_layer_1, initializer_range)
    init_bert_weights(self.out_layer_2, initializer_range)

    self.weighted_entity_threshold = weighted_entity_threshold
    if null_embedding is not None:
        self.register_buffer('null_embedding', null_embedding)
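# A minimal sketch of how this two-layer feed-forward might be used. The
# Linear(2, hidden) input width suggests each (span, candidate) pair is scored
# from two scalar features; the feature names, shapes, and the ReLU below are
# illustrative assumptions, not the library's actual forward pass.
import torch

batch, num_spans, num_candidates, hidden_dim = 2, 3, 5, 100

dot_scores = torch.randn(batch, num_spans, num_candidates)  # e.g. span . entity score
prior = torch.rand(batch, num_spans, num_candidates)        # e.g. candidate prior probability

out_layer_1 = torch.nn.Linear(2, hidden_dim)
out_layer_2 = torch.nn.Linear(hidden_dim, 1)

features = torch.stack([dot_scores, prior], dim=-1)               # (..., 2)
scores = out_layer_2(torch.relu(out_layer_1(features))).squeeze(-1)
print(scores.shape)  # torch.Size([2, 3, 5])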
def __init__(self,
             vocab: Vocabulary,
             entity_linker: Model,
             span_attention_config: Dict[str, int],
             should_init_kg_to_bert_inverse: bool = True,
             freeze: bool = False,
             regularizer: RegularizerApplicator = None):
    super().__init__(vocab, regularizer)

    self.entity_linker = entity_linker
    self.entity_embedding_dim = self.entity_linker.disambiguator.entity_embedding_dim
    self.contextual_embedding_dim = self.entity_linker.disambiguator.contextual_embedding_dim

    self.weighted_entity_layer_norm = BertLayerNorm(self.entity_embedding_dim, eps=1e-5)
    init_bert_weights(self.weighted_entity_layer_norm, 0.02)

    self.dropout = torch.nn.Dropout(0.1)

    # the span attention layers
    assert len(span_attention_config) == 4
    config = BertConfig(
        0,  # vocab size, not used
        hidden_size=span_attention_config['hidden_size'],
        num_hidden_layers=span_attention_config['num_hidden_layers'],
        num_attention_heads=span_attention_config['num_attention_heads'],
        intermediate_size=span_attention_config['intermediate_size']
    )
    self.span_attention_layer = SpanAttentionLayer(config)
    # weights already initialized inside SpanAttentionLayer

    # for the output
    self.output_layer_norm = BertLayerNorm(self.contextual_embedding_dim, eps=1e-5)

    self.kg_to_bert_projection = torch.nn.Linear(
        self.entity_embedding_dim, self.contextual_embedding_dim
    )

    self.should_init_kg_to_bert_inverse = should_init_kg_to_bert_inverse
    self._init_kg_to_bert_projection()

    self._freeze_all = freeze
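# Example of the four-key span_attention_config this constructor asserts on.
# The keys are the ones read above; the specific sizes are illustrative
# assumptions, not values taken from the source.
span_attention_config = {
    'hidden_size': 300,
    'num_hidden_layers': 1,
    'num_attention_heads': 4,
    'intermediate_size': 1024,
}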
def __init__(self,
             embedding_file: str,
             entity_dim: int,
             entity_file: str = None,
             vocab_file: str = None,
             entity_h5_key: str = 'conve_tucker_infersent_bert',
             dropout: float = 0.1,
             pos_embedding_dim: int = 25,
             include_null_embedding: bool = False):
    """
    Pass pos_embedding_dim = None to skip POS embeddings and all the entity
    stuff, using this as a pretrained embedding file with a feed-forward layer.
    """
    super().__init__()

    if pos_embedding_dim is not None:
        # entity_id -> pos abbreviation, e.g.
        #     'cat.n.01' -> 'n'
        # includes special tokens, e.g. '@@PADDING@@' -> '@@PADDING@@'
        entity_to_pos = {}
        with JsonFile(cached_path(entity_file), 'r') as fin:
            for node in fin:
                if node['type'] == 'synset':
                    entity_to_pos[node['id']] = node['pos']
        for special in ['@@PADDING@@', '@@MASK@@', '@@NULL@@', '@@UNKNOWN@@']:
            entity_to_pos[special] = special

        # list of entity ids
        entities = ['@@PADDING@@']
        with open(cached_path(vocab_file), 'r') as fin:
            for line in fin:
                entities.append(line.strip())

        # the map from entity index id -> pos embedding id,
        # will use for POS embedding lookup
        entity_id_to_pos_index = [
            self.POS_MAP[entity_to_pos[ent]] for ent in entities
        ]
        self.register_buffer('entity_id_to_pos_index',
                             torch.tensor(entity_id_to_pos_index))

        self.pos_embeddings = torch.nn.Embedding(len(entities), pos_embedding_dim)
        init_bert_weights(self.pos_embeddings, 0.02)

        self.use_pos = True
    else:
        self.use_pos = False

    # load the embeddings
    with h5py.File(cached_path(embedding_file), 'r') as fin:
        entity_embeddings = fin[entity_h5_key][...]
    self.entity_embeddings = torch.nn.Embedding(entity_embeddings.shape[0],
                                                entity_embeddings.shape[1],
                                                padding_idx=0)
    self.entity_embeddings.weight.data.copy_(
        torch.tensor(entity_embeddings).contiguous())

    if pos_embedding_dim is not None:
        assert entity_embeddings.shape[0] == len(entities)
        concat_dim = entity_embeddings.shape[1] + pos_embedding_dim
    else:
        concat_dim = entity_embeddings.shape[1]

    self.proj_feed_forward = torch.nn.Linear(concat_dim, entity_dim)
    init_bert_weights(self.proj_feed_forward, 0.02)

    self.dropout = torch.nn.Dropout(dropout)

    self.entity_dim = entity_dim

    self.include_null_embedding = include_null_embedding
    if include_null_embedding:
        # a special embedding for null
        entities = ['@@PADDING@@']
        with open(cached_path(vocab_file), 'r') as fin:
            for line in fin:
                entities.append(line.strip())
        self.null_id = entities.index("@@NULL@@")
        self.null_embedding = torch.nn.Parameter(torch.zeros(entity_dim))
        self.null_embedding.data.normal_(mean=0.0, std=0.02)
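# Sketch of the lookup path the buffers above set up. The real forward pass is
# not part of this section, so treat this as an assumption about how the pieces
# fit together: entity ids index both the pretrained entity table and, via
# entity_id_to_pos_index, a learned POS table; the two embeddings are
# concatenated and projected down to entity_dim. All sizes are illustrative.
import torch

num_entities, emb_dim, pos_dim, entity_dim = 10, 16, 4, 8
entity_embeddings = torch.nn.Embedding(num_entities, emb_dim, padding_idx=0)
pos_embeddings = torch.nn.Embedding(num_entities, pos_dim)
entity_id_to_pos_index = torch.randint(0, num_entities, (num_entities,))
proj_feed_forward = torch.nn.Linear(emb_dim + pos_dim, entity_dim)

entity_ids = torch.tensor([[1, 2, 0]])                        # (batch, num_entities)
emb = entity_embeddings(entity_ids)                           # (1, 3, emb_dim)
pos = pos_embeddings(entity_id_to_pos_index[entity_ids])      # (1, 3, pos_dim)
projected = proj_feed_forward(torch.cat([emb, pos], dim=-1))  # (1, 3, entity_dim)
print(projected.shape)  # torch.Size([1, 3, 8])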
def __init__(self, config):
    super(SpanAttention, self).__init__()
    self.attention = SpanWordAttention(config)
    init_bert_weights(self.attention, config.initializer_range, (SpanWordAttention, ))
    self.output = BertSelfOutput(config)
    init_bert_weights(self.output, config.initializer_range)
def __init__(self,
             contextual_embedding_dim,
             entity_embedding_dim: int,
             entity_embeddings: torch.nn.Embedding,
             max_sequence_length: int = 512,
             span_encoder_config: Dict[str, int] = None,
             dropout: float = 0.1,
             output_feed_forward_hidden_dim: int = 100,
             initializer_range: float = 0.02,
             weighted_entity_threshold: float = None,
             null_entity_id: int = None,
             include_null_embedding_in_dot_attention: bool = False):
    """
    Idea: align the BERT and KG vector spaces by learning a mapping between them.
    """
    super().__init__()

    self.span_extractor = SelfAttentiveSpanExtractor(entity_embedding_dim)
    init_bert_weights(self.span_extractor._global_attention._module,
                      initializer_range)

    self.dropout = torch.nn.Dropout(dropout)

    self.bert_to_kg_projector = torch.nn.Linear(
        contextual_embedding_dim, entity_embedding_dim)
    init_bert_weights(self.bert_to_kg_projector, initializer_range)
    self.projected_span_layer_norm = BertLayerNorm(entity_embedding_dim, eps=1e-5)
    init_bert_weights(self.projected_span_layer_norm, initializer_range)

    self.kg_layer_norm = BertLayerNorm(entity_embedding_dim, eps=1e-5)
    init_bert_weights(self.kg_layer_norm, initializer_range)

    # already pretrained, don't init
    self.entity_embeddings = entity_embeddings
    self.entity_embedding_dim = entity_embedding_dim

    # layers for the dot product attention
    if weighted_entity_threshold is not None or include_null_embedding_in_dot_attention:
        if hasattr(self.entity_embeddings, 'get_null_embedding'):
            null_embedding = self.entity_embeddings.get_null_embedding()
        else:
            null_embedding = self.entity_embeddings.weight[null_entity_id, :]
    else:
        null_embedding = None
    self.dot_attention_with_prior = DotAttentionWithPrior(
        output_feed_forward_hidden_dim,
        weighted_entity_threshold,
        null_embedding,
        initializer_range
    )

    self.null_entity_id = null_entity_id
    self.contextual_embedding_dim = contextual_embedding_dim

    if span_encoder_config is None:
        self.span_encoder = None
    else:
        # create BertConfig
        assert len(span_encoder_config) == 4
        config = BertConfig(
            0,  # vocab size, not used
            hidden_size=span_encoder_config['hidden_size'],
            num_hidden_layers=span_encoder_config['num_hidden_layers'],
            num_attention_heads=span_encoder_config['num_attention_heads'],
            intermediate_size=span_encoder_config['intermediate_size']
        )
        self.span_encoder = BertEncoder(config)
        init_bert_weights(self.span_encoder, initializer_range)
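# Minimal sketch of the alignment idea in the docstring above. The forward pass
# is not part of this section, so the exact order of operations is an
# assumption: project BERT span vectors into the KG embedding space, layer-norm
# both sides, then compare spans to candidate entity embeddings with a dot
# product. torch.nn.LayerNorm stands in for BertLayerNorm to keep the sketch
# self-contained; the dimensions are illustrative.
import torch

contextual_dim, entity_dim = 768, 200
bert_to_kg_projector = torch.nn.Linear(contextual_dim, entity_dim)
projected_span_layer_norm = torch.nn.LayerNorm(entity_dim, eps=1e-5)
kg_layer_norm = torch.nn.LayerNorm(entity_dim, eps=1e-5)

span_vectors = torch.randn(2, 5, contextual_dim)    # (batch, num_spans, contextual_dim)
candidate_embs = torch.randn(2, 5, 30, entity_dim)  # (batch, num_spans, num_candidates, entity_dim)

projected_spans = projected_span_layer_norm(bert_to_kg_projector(span_vectors))
normed_candidates = kg_layer_norm(candidate_embs)
scores = (projected_spans.unsqueeze(2) * normed_candidates).sum(-1)
print(scores.shape)  # torch.Size([2, 5, 30])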