Example #1
    def test_token_to_uid(self):
        alias_database = AliasDatabase(
            token_lookup=self.token_lookup,
            id_map_lookup=self.id_map_lookup,
            id_array_lookup=self.id_array_lookup,
            token_to_entity_lookup=self.token_to_entity_lookup)
        assert alias_database.token_to_uid('Entity1', 'Robert') == 1
        assert alias_database.token_to_uid('Entity1', 'Nelson') == 0
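The assertions above depend on fixture lookups built in the test's setUp, which is not shown. Below is a minimal sketch of what those fixtures could look like, inferred from this test and from the comments in Example #5; the exact contents, and in particular the structure of token_to_entity_lookup, are assumptions rather than code from the source.

import numpy as np

# Hypothetical fixtures: 'Entity1' has the aliases 'Robert Logan' and 'Robby',
# so 'Robert' receives the local uid 1, while an unseen token such as 'Nelson'
# falls back to the padding uid 0.
token_lookup = {'Entity1': [['Robert', 'Logan'], ['Robby']]}
id_map_lookup = {'Entity1': {'Robert': 1, 'Logan': 2, 'Robby': 3}}
id_array_lookup = {'Entity1': np.array([[1, 2], [3, 0]])}
token_to_entity_lookup = {'Robert': {'Entity1'}, 'Logan': {'Entity1'}, 'Robby': {'Entity1'}}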
Example #2
    def __init__(self,
                 alias_database_path: str,
                 mode: str = "generative",
                 token_indexers: Dict[str, TokenIndexer] = None,
                 entity_indexers: Dict[str, TokenIndexer] = None,
                 raw_entity_indexers: Dict[str, TokenIndexer] = None,
                 relation_indexers: Dict[str, TokenIndexer] = None,
                 lazy: bool = False) -> None:
        """
        Parameters
        ----------
        alias_database_path : str
            Path to the alias database.
        mode : str, optional (default="generative")
            One of "discriminative" or "generative", indicating whether generated
            instances are suitable for the discriminative or generative version of
            the model.
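        token_indexers : Dict[str, TokenIndexer], optional
            Indexers for source tokens (default: a single-id indexer named 'tokens').
        entity_indexers : Dict[str, TokenIndexer], optional
            Indexers for entity ids (default: a single-id indexer named 'entity_ids').
        raw_entity_indexers : Dict[str, TokenIndexer], optional
            Indexers for raw entity ids (default: a single-id indexer named
            'raw_entity_ids').
        relation_indexers : Dict[str, TokenIndexer], optional
            Indexers for relations (default: a single-id indexer named 'relations').
        lazy : bool, optional (default=False)
            If True, instances are read lazily.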
        """
        super().__init__(lazy)
        if mode not in {"discriminative", "generative"}:
            raise ConfigurationError(
                "Got mode {}, expected one of 'generative' "
                "or 'discriminative'".format(mode))
        self._mode = mode

        self._token_indexers = token_indexers or {
            'tokens': SingleIdTokenIndexer()
        }
        self._entity_indexers = entity_indexers or {
            'entity_ids': SingleIdTokenIndexer(namespace='entity_ids')
        }
        self._raw_entity_indexers = raw_entity_indexers or {
            'raw_entity_ids': SingleIdTokenIndexer(namespace='raw_entity_ids')
        }
        self._relation_indexers = relation_indexers or {
            'relations': SingleIdTokenIndexer(namespace='relations')
        }
        if 'tokens' not in self._token_indexers or \
                not isinstance(self._token_indexers['tokens'], SingleIdTokenIndexer):
            raise ConfigurationError(
                "EnhancedWikitextReader expects 'token_indexers' to contain "
                "a 'single_id' token indexer called 'tokens'.")
        if 'entity_ids' not in self._entity_indexers or \
                not isinstance(self._entity_indexers['entity_ids'], SingleIdTokenIndexer):
            raise ConfigurationError(
                "EnhancedWikitextReader expects 'entity_indexers' to contain "
                "a 'single_id' token indexer called 'entity_ids'.")
        if 'raw_entity_ids' not in self._raw_entity_indexers or \
                not isinstance(self._raw_entity_indexers['raw_entity_ids'], SingleIdTokenIndexer):
            raise ConfigurationError(
                "EnhancedWikitextReader expects 'raw_entity_indexers' to contain "
                "a 'single_id' token indexer called 'raw_entity_ids'.")
        if 'relations' not in self._relation_indexers or \
                not isinstance(self._relation_indexers['relations'], SingleIdTokenIndexer):
            raise ConfigurationError(
                "EnhancedWikitextReader expects 'relation_indexers' to contain "
                "a 'single_id' token indexer called 'relations'.")
        self._alias_database = AliasDatabase.load(path=alias_database_path)
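A brief usage sketch of the constructor above. The import path and the data file name are assumptions made for illustration; only the class name, keyword arguments, and the alias database fixture path (which also appears in Example #4) come from the source.

# Assumed module path; adjust to the actual package layout.
from kglm.data.dataset_readers.enhanced_wikitext import EnhancedWikitextReader

reader = EnhancedWikitextReader(
    alias_database_path='kglm/tests/fixtures/enhanced-wikitext-test/alias.pkl',
    mode='generative',
    lazy=True)
instances = reader.read('enhanced-wikitext.jsonl')  # placeholder data file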
Example #3
def sample(args: argparse.Namespace):
    model_archive = load_archive(args.model_archive_file,
                                 cuda_device=args.cuda_device)
    config = model_archive.config
    prepare_environment(config)
    model = model_archive.model
    model.eval()

    alias_database = AliasDatabase.load(args.alias_database)

    samples = model.sample(alias_database,
                           batch_size=args.batch_size,
                           length=args.length)
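For context, here is a sketch of the argument parser that would supply the fields sample() reads (model_archive_file, cuda_device, alias_database, batch_size, length). The flag names and default values are assumptions; only the attribute names match what the function above expects.

import argparse

def build_parser() -> argparse.ArgumentParser:
    # Hypothetical CLI wiring for the sample() entry point.
    parser = argparse.ArgumentParser(description='Draw samples from a trained model.')
    parser.add_argument('model_archive_file', type=str)
    parser.add_argument('alias_database', type=str)
    parser.add_argument('--cuda-device', type=int, default=-1)
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--length', type=int, default=100)
    return parser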
Example #4
    def test_load(self):
        # Test that the load function has the expected behavior
        alias_database = AliasDatabase.load(
            'kglm/tests/fixtures/enhanced-wikitext-test/alias.pkl')
        test_entity = 'Q156216'  # Benton County

        # Check that aliases are tokenized properly
        expected_tokenized_aliases = [['Benton', 'County'],
                                      ['Benton', 'County', ',', 'Washington']]
        assert alias_database._token_lookup[
            test_entity] == expected_tokenized_aliases

        # Check that the id map has 4 unique tokens
        assert len(alias_database._id_map_lookup[test_entity]) == 4

        # Check that the first token in each alias has the same local id
        test_id_array = alias_database._id_array_lookup[test_entity]
        assert test_id_array[0, 0] == test_id_array[1, 0]
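As a worked check of the "4 unique tokens" assertion above: flattening the two tokenized aliases listed in the test yields exactly four distinct tokens.

aliases = [['Benton', 'County'], ['Benton', 'County', ',', 'Washington']]
unique_tokens = {token for alias in aliases for token in alias}
assert unique_tokens == {'Benton', 'County', ',', 'Washington'}
assert len(unique_tokens) == 4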
Example #5
    def test_tensorize_and_lookup(self):
        # Tensor fields should be empty when the ``AliasDatabase`` is created
        alias_database = AliasDatabase(
            token_lookup=self.token_lookup,
            id_map_lookup=self.id_map_lookup,
            id_array_lookup=self.id_array_lookup,
            token_to_entity_lookup=self.token_to_entity_lookup)
        assert not alias_database.is_tensorized

        # But should exist after ``AliasDatabase`` is tensorized
        alias_database.tensorize(self.vocab)

        assert alias_database.is_tensorized
        assert alias_database._global_id_lookup != []
        assert alias_database._local_id_lookup != []

        tensor_dict = self.dataset.as_tensor_dict()
        entity_ids = tensor_dict['entity_identifiers']['entity_ids']
        tokens = tensor_dict['tokens']['tokens']
        global_tensor, local_tensor = alias_database.lookup(entity_ids)
        entity_id_tensor = alias_database.reverse_lookup(tokens)

        # The first two dimensions should match the batch_size and sequence length of the index.
        # The next dimensions should be the max number of aliases of all entities (in this case 2
        # for 'Robert Logan' and 'Robby') and the max length of the aliases (again 2 because
        # 'Robert Logan' is two tokens).
        assert global_tensor.shape == (1, 6, 2, 2)
        assert local_tensor.shape == (1, 6, 2, 2)
        assert entity_id_tensor.shape == (1, 6, 1)

        # Check that the global ids match the vocabulary indices
        assert global_tensor[0, 0, 0, 0] == self.vocab.get_token_index(
            'Robert', namespace='tokens')
        assert global_tensor[0, 1, 0, 0] == 0  # Padding since not an alias

        assert local_tensor[0, 0, 0, 0] == 1
        assert local_tensor[0, 1, 0, 0] == 0  # Padding since not an alias

        match_token_idx = self.vocab.get_token_index('Entity1',
                                                     namespace='entity_ids')
        nonmatch_token_idx = self.vocab.get_token_index('Entity2',
                                                        namespace='entity_ids')
        assert entity_id_tensor[0, 0, match_token_idx] == 1
        assert entity_id_tensor[0, 0, nonmatch_token_idx] == 0
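To make the shape comments in the test concrete, here is the arithmetic spelled out: the lookup tensors share the batch and sequence dimensions of the index, then append the maximum number of aliases and the maximum alias length. The variable names below are purely illustrative.

batch_size, sequence_length = 1, 6   # from the indexed test instance
max_num_aliases = 2                  # 'Robert Logan' and 'Robby'
max_alias_length = 2                 # 'Robert Logan' is two tokens
assert (batch_size, sequence_length, max_num_aliases, max_alias_length) == (1, 6, 2, 2)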
Example #6
    def _forward_loop(self,
                      source: Dict[str, torch.Tensor],
                      alias_database: AliasDatabase,
                      mention_type: torch.Tensor,
                      raw_entity_ids: Dict[str, torch.Tensor],
                      entity_ids: Dict[str, torch.Tensor],
                      parent_ids: Dict[str, torch.Tensor],
                      relations: Dict[str, torch.Tensor],
                      shortlist: Dict[str, torch.Tensor],
                      shortlist_inds: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Get the token mask and extract indexed text fields.
        # shape: (batch_size, sequence_length)
        target_mask = get_text_field_mask(source)
        source = source['tokens']
        raw_entity_ids = raw_entity_ids['raw_entity_ids']
        entity_ids = entity_ids['entity_ids']
        parent_ids = parent_ids['entity_ids']
        relations = relations['relations']

        logger.debug('Source & Target shape: %s', source.shape)
        logger.debug('Entity ids shape: %s', entity_ids.shape)
        logger.debug('Relations & Parent ids shape: %s', relations.shape)
        logger.debug('Shortlist shape: %s', shortlist['entity_ids'].shape)
        # Embed source tokens.
        # shape: (batch_size, sequence_length, embedding_dim)
        encoded, alpha_loss, beta_loss = self._encode_source(source)
        splits = [self.token_embedding_dim] + [self.entity_embedding_dim] * 2
        encoded_token, encoded_head, encoded_relation = encoded.split(
            splits, dim=-1)

        # Predict whether or not the next token will be an entity mention, and if so which type.
        mention_type_loss = self._mention_type_loss(
            encoded_token, mention_type, target_mask)
        self._avg_mention_type_loss(float(mention_type_loss))

        # For new mentions, predict which entity (among those in the supplied shortlist) will be
        # mentioned.
        overlap_feature = alias_database.reverse_lookup(source)
        if self._use_shortlist:
            new_entity_loss = self._new_entity_loss(encoded_head + encoded_relation,
                                                    overlap_feature,
                                                    shortlist_inds,
                                                    shortlist,
                                                    target_mask)
        else:
            new_entity_loss = self._new_entity_loss(encoded_head + encoded_relation,
                                                    overlap_feature,
                                                    entity_ids,
                                                    None,
                                                    target_mask)

        self._avg_new_entity_loss(float(new_entity_loss))

        # For derived mentions, first predict which parent(s) to expand...
        knowledge_graph_entity_loss = self._knowledge_graph_entity_loss(encoded_head,
                                                                        encoded_relation,
                                                                        raw_entity_ids,
                                                                        entity_ids,
                                                                        parent_ids,
                                                                        target_mask)
        self._avg_knowledge_graph_entity_loss(
            float(knowledge_graph_entity_loss))

        # Compute total loss
        loss = mention_type_loss + new_entity_loss + knowledge_graph_entity_loss

        # Activation regularization
        if self._alpha:
            loss = loss + self._alpha * alpha_loss
        # Temporal activation regularization (slowness)
        if self._beta:
            loss = loss + self._beta * beta_loss

        return {'loss': loss}
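The alpha_loss and beta_loss returned by _encode_source appear to play the role of the AWD-LSTM style activation-regularization penalties that Example #7 below computes inline. A standalone sketch of those two penalties follows; the helper name is made up for illustration and is not part of the model.

import torch

def activation_regularizers(dropped_output: torch.Tensor,
                            output: torch.Tensor):
    # AR: penalize large magnitudes of the (dropped-out) final-layer activations.
    alpha_loss = dropped_output.pow(2).mean()
    # TAR: penalize large changes between activations at adjacent timesteps.
    beta_loss = (output[:, 1:] - output[:, :-1]).pow(2).mean()
    return alpha_loss, beta_loss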
Example #7
    def _forward_loop(self,
                      source: Dict[str, torch.Tensor],
                      target: Dict[str, torch.Tensor],
                      alias_database: AliasDatabase,
                      entity_ids: Dict[str, torch.Tensor],
                      shortlist: Dict[str, torch.Tensor],
                      shortlist_inds: torch.Tensor,
                      alias_copy_inds: torch.Tensor) -> Dict[str, torch.Tensor]:

        # Get the token mask and unwrap the target tokens.
        target_mask = get_text_field_mask(target)
        target = target['tokens']

        # Embed source tokens.
        source = source['tokens']
        source_embeddings = embedded_dropout(
            embed=self._token_embedder,
            words=source,
            dropout=self._dropoute if self.training else 0)
        source_embeddings = self._locked_dropout(
            source_embeddings, self._dropouti)

        # Embed entities.
        entity_ids = entity_ids['entity_ids']
        # entity_embeddings = embedded_dropout(
        #     embed=self._entity_embedder,
        #     words=entity_ids,
        #     dropout=self._dropoute if self.training else 0)
        # entity_embeddings = self._locked_dropout(entity_embeddings, self._dropouti)

        # # Embed shortlist.
        # shortlist_mask = get_text_field_mask(shortlist)
        # shortlist = shortlist['entity_ids']
        # shortlist_embeddings = embedded_dropout(
        #     embed=self._entity_embedder,
        #     words=shortlist,
        #     dropout=self._dropoute if self.training else 0)

        # Encode source tokens.
        current_input = source_embeddings
        hidden_states = []
        for layer, rnn in enumerate(self.rnns):
            # Retrieve previous hidden state for layer.
            if self._state is not None:
                prev_hidden = self._state['layer_%i' % layer]
            else:
                prev_hidden = None
            # Forward-pass.
            output, hidden = rnn(current_input, prev_hidden)
            output = output.contiguous()
            # Update hidden state for layer.
            hidden = tuple(h.detach() for h in hidden)
            hidden_states.append(hidden)
            # Apply dropout.
            if layer == self._num_layers - 1:
                dropped_output = self._locked_dropout(output, self._dropout)
            else:
                dropped_output = self._locked_dropout(output, self._dropouth)
            current_input = dropped_output
        encoded = current_input
        self._state = {'layer_%i' % i: h for i, h in enumerate(hidden_states)}

        # Predict whether or not the next token will be an entity mention. This corresponds to the
        # case that the entity's id is not a padding token.
        # mention_loss = self._mention_loss(encoded, entity_ids.gt(0), target_mask)

        # Predict which entity (among those in the supplied shortlist) is going to be
        # mentioned.
        # entity_loss = self._entity_loss(encoded,
        #                                 shortlist_inds,
        #                                 target_mask,
        #                                 shortlist_embeddings,
        #                                 shortlist_mask)

        # Predict generation-mode scores. Start by concatenating predicted entity embeddings with
        # the encoder output - then feed through a linear layer.
        # concatenated = torch.cat((encoded, entity_embeddings), dim=-1)
        # condensed = self._fc_condense(concatenated)
        generate_scores = self._fc_generate(encoded)

        # Predict copy-mode scores.
        alias_tokens, alias_inds = alias_database.lookup(entity_ids)
        copy_scores = self._copy_scores(encoded, alias_tokens)

        # Combine scores to get vocab loss
        vocab_loss = self._vocab_loss(generate_scores,
                                      copy_scores,
                                      target,
                                      alias_copy_inds,
                                      target_mask,
                                      alias_inds,
                                      alias_tokens,
                                      entity_ids.gt(0))

        # Compute total loss
        loss = vocab_loss  # + mention_loss + entity_loss

        # Activation regularization
        if self._alpha:
            loss = loss + self._alpha * dropped_output.pow(2).mean()
        # Temporal activation regularization (slowness)
        if self._beta:
            loss = loss + self._beta * \
                (output[:, 1:] - output[:, :-1]).pow(2).mean()

        return {'loss': loss}
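The generate-mode and copy-mode scores above are merged by _vocab_loss, whose body is not shown here. Below is a rough, assumed sketch of the usual way such scores are combined into a single normalized distribution, with the copy candidates flattened into the last dimension; it is illustrative only and not the model's implementation.

import torch
import torch.nn.functional as F

def combine_generate_and_copy(generate_scores: torch.Tensor,
                              copy_scores: torch.Tensor) -> torch.Tensor:
    # generate_scores: (batch, seq, vocab_size)
    # copy_scores: (batch, seq, num_copy_candidates), assumed already flattened
    combined = torch.cat((generate_scores, copy_scores), dim=-1)
    # Normalize jointly so probability mass is shared between generating a
    # vocabulary token and copying an alias token.
    return F.log_softmax(combined, dim=-1)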
Example #8
    def _forward_loop(self,
                      source: Dict[str, torch.Tensor],
                      target: Dict[str, torch.Tensor],
                      alias_database: AliasDatabase,
                      mention_type: torch.Tensor,
                      raw_entity_ids: Dict[str, torch.Tensor],
                      entity_ids: Dict[str, torch.Tensor],
                      parent_ids: Dict[str, torch.Tensor],
                      relations: Dict[str, torch.Tensor],
                      shortlist: Dict[str, torch.Tensor],
                      shortlist_inds: torch.Tensor,
                      alias_copy_inds: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Get the token mask and extract indexed text fields.
        # shape: (batch_size, sequence_length)
        target_mask = get_text_field_mask(target)
        source = source['tokens']
        target = target['tokens']
        raw_entity_ids = raw_entity_ids['raw_entity_ids']
        entity_ids = entity_ids['entity_ids']

        logger.debug('Source & Target shape: %s', source.shape)
        logger.debug('Entity ids shape: %s', entity_ids.shape)
        logger.debug('Shortlist shape: %s', shortlist['entity_ids'].shape)
        # Embed source tokens.
        # shape: (batch_size, sequence_length, embedding_dim)
        encoded, alpha_loss, beta_loss = self._encode_source(source)
        splits = [self.token_embedding_dim] + [self.entity_embedding_dim]
        encoded_token, encoded_head = encoded.split(splits, dim=-1)

        # Predict whether or not the next token will be an entity mention, and if so which type.
        mention_type = mention_type.gt(0).long()  # Map 1, 2 -> 1
        mention_type_loss = self._mention_type_loss(
            encoded_token, mention_type, target_mask)
        self._avg_mention_type_loss(float(mention_type_loss))

        # For new mentions, predict which entity (among those in the supplied shortlist) will be
        # mentioned.
        if self._use_shortlist:
            new_entity_loss = self._new_entity_loss(encoded_head,
                                                    shortlist_inds,
                                                    shortlist,
                                                    target_mask)
        else:
            new_entity_loss = self._new_entity_loss(encoded_head,
                                                    entity_ids,
                                                    None,
                                                    target_mask)

        self._avg_new_entity_loss(float(new_entity_loss))

        # Predict generation-mode scores. Note: these are w.r.t. entity_ids since we need the embedding.
        generate_scores = self._generate_scores(encoded_token, entity_ids)

        # Predict copy-mode scores. Note: these are w.r.t. raw_entity_ids since we need to look up aliases.
        alias_tokens, alias_inds = alias_database.lookup(raw_entity_ids)
        copy_scores = self._copy_scores(encoded_token, alias_tokens)

        # Combine scores to get vocab loss
        vocab_loss, penalized_vocab_loss = self._vocab_loss(generate_scores,
                                                            copy_scores,
                                                            target,
                                                            alias_copy_inds,
                                                            target_mask,
                                                            alias_inds,
                                                            entity_ids.gt(0))
        self._avg_vocab_loss(float(vocab_loss))

        # Compute total loss. Also compute logp (needed for importance sampling evaluation).
        loss = vocab_loss + mention_type_loss + new_entity_loss
        logp = -(vocab_loss + mention_type_loss +
                 new_entity_loss) * target_mask.sum()
        penalized_logp = -(penalized_vocab_loss + mention_type_loss +
                           new_entity_loss) * target_mask.sum()

        # Activation regularization
        if self._alpha:
            loss = loss + self._alpha * alpha_loss
        # Temporal activation regularization (slowness)
        if self._beta:
            loss = loss + self._beta * beta_loss

        return {'loss': loss, 'logp': logp, 'penalized_logp': penalized_logp}
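Since the returned logp is the total log-probability over non-padding tokens (the per-token losses are averaged and then scaled by target_mask.sum()), a corpus-level perplexity could be recovered from it as sketched below. The helper is an assumption for illustration, not code from the model.

import math

def perplexity_from_logp(total_logp: float, num_tokens: int) -> float:
    # Perplexity is the exponentiated negative average log-probability per token.
    return math.exp(-total_logp / num_tokens)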