Example #1
    def __init__(self, config):
        super(SpanAttentionLayer, self).__init__()
        self.attention = SpanAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
        init_bert_weights(self.intermediate, config.initializer_range)
        init_bert_weights(self.output, config.initializer_range)
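
Every example defers weight initialization to an init_bert_weights helper. Below is a minimal sketch of the standard BERT-style initialization it presumably applies; the repository's actual implementation may differ in details (for instance, the original BERT uses a truncated normal).

import torch

def init_bert_weights(module, initializer_range, extra_modules_without_weights=()):
    # Sketch only: assumed BERT-style init, not the repository's exact code.
    for m in module.modules():
        if isinstance(m, extra_modules_without_weights):
            # caller-declared container types with no weights of their own
            # (see Example #5); their submodules are still visited later
            continue
        if isinstance(m, (torch.nn.Linear, torch.nn.Embedding)):
            m.weight.data.normal_(mean=0.0, std=initializer_range)
        elif isinstance(m, torch.nn.LayerNorm):  # or the BertLayerNorm variant
            m.weight.data.fill_(1.0)
            m.bias.data.zero_()
        if isinstance(m, torch.nn.Linear) and m.bias is not None:
            m.bias.data.zero_()
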
Example #2
    def __init__(self,
                 output_feed_forward_hidden_dim: int = 100,
                 weighted_entity_threshold: float = None,
                 null_embedding: torch.Tensor = None,
                 initializer_range: float = 0.02):

        super().__init__()

        # layers for the dot product attention
        self.out_layer_1 = torch.nn.Linear(2, output_feed_forward_hidden_dim)
        self.out_layer_2 = torch.nn.Linear(output_feed_forward_hidden_dim, 1)
        init_bert_weights(self.out_layer_1, initializer_range)
        init_bert_weights(self.out_layer_2, initializer_range)

        self.weighted_entity_threshold = weighted_entity_threshold
        if null_embedding is not None:
            self.register_buffer('null_embedding', null_embedding)
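
A side note on the register_buffer call above: a buffer is saved in state_dict and moves with .to()/.cuda() like a parameter, but it is excluded from parameters(), so the optimizer never updates the null embedding. A tiny self-contained illustration (BufferDemo is a made-up name, for demonstration only):

import torch

class BufferDemo(torch.nn.Module):
    def __init__(self, null_embedding: torch.Tensor):
        super().__init__()
        # stored in state_dict and moved with the module, but not trained
        self.register_buffer('null_embedding', null_embedding)

demo = BufferDemo(torch.zeros(300))
print(sum(p.numel() for p in demo.parameters()))  # 0: the buffer is not a parameter
print('null_embedding' in demo.state_dict())      # True: it is still checkpointed
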
Example #3
    def __init__(self,
                 vocab: Vocabulary,
                 entity_linker: Model,
                 span_attention_config: Dict[str, int],
                 should_init_kg_to_bert_inverse: bool = True,
                 freeze: bool = False,
                 regularizer: RegularizerApplicator = None):
        super().__init__(vocab, regularizer)

        self.entity_linker = entity_linker
        self.entity_embedding_dim = self.entity_linker.disambiguator.entity_embedding_dim
        self.contextual_embedding_dim = self.entity_linker.disambiguator.contextual_embedding_dim

        self.weighted_entity_layer_norm = BertLayerNorm(self.entity_embedding_dim, eps=1e-5)
        init_bert_weights(self.weighted_entity_layer_norm, 0.02)

        self.dropout = torch.nn.Dropout(0.1)

        # the span attention layers
        assert len(span_attention_config) == 4
        config = BertConfig(
            0,  # vocab size, not used
            hidden_size=span_attention_config['hidden_size'],
            num_hidden_layers=span_attention_config['num_hidden_layers'],
            num_attention_heads=span_attention_config['num_attention_heads'],
            intermediate_size=span_attention_config['intermediate_size']
        )
        self.span_attention_layer = SpanAttentionLayer(config)
        # weights are already initialized inside SpanAttentionLayer

        # for the output!
        self.output_layer_norm = BertLayerNorm(self.contextual_embedding_dim, eps=1e-5)

        self.kg_to_bert_projection = torch.nn.Linear(
            self.entity_embedding_dim, self.contextual_embedding_dim
        )

        self.should_init_kg_to_bert_inverse = should_init_kg_to_bert_inverse
        self._init_kg_to_bert_projection()

        self._freeze_all = freeze
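
For reference, a span_attention_config satisfying the four-key assert above could look like the dict below; the values are illustrative placeholders, not taken from the original training configuration.

span_attention_config = {
    'hidden_size': 300,          # illustrative
    'num_hidden_layers': 1,      # illustrative
    'num_attention_heads': 4,    # illustrative
    'intermediate_size': 1024,   # illustrative
}
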
Example #4
    def __init__(self,
                 embedding_file: str,
                 entity_dim: int,
                 entity_file: str = None,
                 vocab_file: str = None,
                 entity_h5_key: str = 'conve_tucker_infersent_bert',
                 dropout: float = 0.1,
                 pos_embedding_dim: int = 25,
                 include_null_embedding: bool = False):
        """
        pass pos_emedding_dim = None to skip POS embeddings and all the
            entity stuff, using this as a pretrained embedding file
            with feedforward
        """

        super().__init__()

        if pos_embedding_dim is not None:
            # entity_id -> pos abbreviation, e.g.
            # 'cat.n.01' -> 'n'
            # includes special, e.g. '@@PADDING@@' -> '@@PADDING@@'
            entity_to_pos = {}
            with JsonFile(cached_path(entity_file), 'r') as fin:
                for node in fin:
                    if node['type'] == 'synset':
                        entity_to_pos[node['id']] = node['pos']
            for special in [
                    '@@PADDING@@', '@@MASK@@', '@@NULL@@', '@@UNKNOWN@@'
            ]:
                entity_to_pos[special] = special

            # list of entity ids
            entities = ['@@PADDING@@']
            with open(cached_path(vocab_file), 'r') as fin:
                for line in fin:
                    entities.append(line.strip())

            # the map from entity index id -> pos embedding id,
            # will use for POS embedding lookup
            entity_id_to_pos_index = [
                self.POS_MAP[entity_to_pos[ent]] for ent in entities
            ]
            self.register_buffer('entity_id_to_pos_index',
                                 torch.tensor(entity_id_to_pos_index))

            self.pos_embeddings = torch.nn.Embedding(len(entities),
                                                     pos_embedding_dim)
            init_bert_weights(self.pos_embeddings, 0.02)

            self.use_pos = True
        else:
            self.use_pos = False

        # load the embeddings
        with h5py.File(cached_path(embedding_file), 'r') as fin:
            entity_embeddings = fin[entity_h5_key][...]
        self.entity_embeddings = torch.nn.Embedding(entity_embeddings.shape[0],
                                                    entity_embeddings.shape[1],
                                                    padding_idx=0)
        self.entity_embeddings.weight.data.copy_(
            torch.tensor(entity_embeddings).contiguous())

        if pos_embedding_dim is not None:
            assert entity_embeddings.shape[0] == len(entities)
            concat_dim = entity_embeddings.shape[1] + pos_embedding_dim
        else:
            concat_dim = entity_embeddings.shape[1]

        self.proj_feed_forward = torch.nn.Linear(concat_dim, entity_dim)
        init_bert_weights(self.proj_feed_forward, 0.02)

        self.dropout = torch.nn.Dropout(dropout)

        self.entity_dim = entity_dim

        self.include_null_embedding = include_null_embedding
        if include_null_embedding:
            # a special embedding for null
            entities = ['@@PADDING@@']
            with open(cached_path(vocab_file), 'r') as fin:
                for line in fin:
                    entities.append(line.strip())
            self.null_id = entities.index("@@NULL@@")
            self.null_embedding = torch.nn.Parameter(torch.zeros(entity_dim))
            self.null_embedding.data.normal_(mean=0.0, std=0.02)
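
The lookup self.POS_MAP[entity_to_pos[ent]] above assumes a class-level POS_MAP that assigns one POS-embedding row to each WordNet POS abbreviation and to each special token. A hypothetical version consistent with the code above (the actual index assignment in the source may differ):

POS_MAP = {
    '@@PADDING@@': 0,
    'n': 1,   # noun synsets
    'v': 2,   # verb synsets
    'a': 3,   # adjective synsets
    'r': 4,   # adverb synsets
    's': 5,   # adjective satellite synsets
    '@@MASK@@': 6,
    '@@NULL@@': 7,
    '@@UNKNOWN@@': 8,
}
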
Example #5
    def __init__(self, config):
        super(SpanAttention, self).__init__()
        self.attention = SpanWordAttention(config)
        init_bert_weights(self.attention, config.initializer_range, (SpanWordAttention, ))
        self.output = BertSelfOutput(config)
        init_bert_weights(self.output, config.initializer_range)
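
Unlike the other calls in these examples, init_bert_weights here receives a third argument, (SpanWordAttention, ), presumably marking the custom SpanWordAttention container as a module type with no weights of its own, while its Linear submodules are still initialized recursively (compare the sketch after Example #1).
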
Example #6
    def __init__(self,
                 contextual_embedding_dim: int,
                 entity_embedding_dim: int,
                 entity_embeddings: torch.nn.Embedding,
                 max_sequence_length: int = 512,
                 span_encoder_config: Dict[str, int] = None,
                 dropout: float = 0.1,
                 output_feed_forward_hidden_dim: int = 100,
                 initializer_range: float = 0.02,
                 weighted_entity_threshold: float = None,
                 null_entity_id: int = None,
                 include_null_embedding_in_dot_attention: bool = False):
        """
        Idea: Align the bert and KG vector space by learning a mapping between
            them.
        """
        super().__init__()

        self.span_extractor = SelfAttentiveSpanExtractor(entity_embedding_dim)
        init_bert_weights(self.span_extractor._global_attention._module,
                          initializer_range)

        self.dropout = torch.nn.Dropout(dropout)

        self.bert_to_kg_projector = torch.nn.Linear(
            contextual_embedding_dim, entity_embedding_dim)
        init_bert_weights(self.bert_to_kg_projector, initializer_range)
        self.projected_span_layer_norm = BertLayerNorm(entity_embedding_dim, eps=1e-5)
        init_bert_weights(self.projected_span_layer_norm, initializer_range)

        self.kg_layer_norm = BertLayerNorm(entity_embedding_dim, eps=1e-5)
        init_bert_weights(self.kg_layer_norm, initializer_range)

        # already pretrained, don't init
        self.entity_embeddings = entity_embeddings
        self.entity_embedding_dim = entity_embedding_dim

        # layers for the dot product attention
        if weighted_entity_threshold is not None or include_null_embedding_in_dot_attention:
            if hasattr(self.entity_embeddings, 'get_null_embedding'):
                null_embedding = self.entity_embeddings.get_null_embedding()
            else:
                null_embedding = self.entity_embeddings.weight[null_entity_id, :]
        else:
            null_embedding = None
        self.dot_attention_with_prior = DotAttentionWithPrior(
            output_feed_forward_hidden_dim,
            weighted_entity_threshold,
            null_embedding,
            initializer_range
        )
        self.null_entity_id = null_entity_id
        self.contextual_embedding_dim = contextual_embedding_dim

        if span_encoder_config is None:
            self.span_encoder = None
        else:
            # create BertConfig
            assert len(span_encoder_config) == 4
            config = BertConfig(
                0,  # vocab size, not used
                hidden_size=span_encoder_config['hidden_size'],
                num_hidden_layers=span_encoder_config['num_hidden_layers'],
                num_attention_heads=span_encoder_config['num_attention_heads'],
                intermediate_size=span_encoder_config['intermediate_size']
            )
            self.span_encoder = BertEncoder(config)
            init_bert_weights(self.span_encoder, initializer_range)
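
Note that BertConfig is instantiated with a vocabulary size of 0 here (and in Example #3) because the modules built from it, BertEncoder and the attention/intermediate/output layers in the earlier examples, never construct a token-embedding table, so the vocabulary size is irrelevant; the inline "vocab size, not used" comments confirm this.
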