def test_token_to_uid(self):
    alias_database = AliasDatabase(token_lookup=self.token_lookup,
                                   id_map_lookup=self.id_map_lookup,
                                   id_array_lookup=self.id_array_lookup,
                                   token_to_entity_lookup=self.token_to_entity_lookup)
    assert alias_database.token_to_uid('Entity1', 'Robert') == 1
    assert alias_database.token_to_uid('Entity1', 'Nelson') == 0
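# A minimal sketch (not taken from the test fixtures) of the kind of data the ``self.token_lookup``,
# ``self.id_map_lookup``, ``self.id_array_lookup`` and ``self.token_to_entity_lookup`` attributes used
# by these tests appear to hold, inferred from the assertions: 'Entity1' has the aliases 'Robert Logan'
# and 'Robby', local uids start at 1, and unknown tokens such as 'Nelson' fall back to 0. The exact
# values and the set-valued ``token_to_entity_lookup`` are assumptions.
import numpy as np

token_lookup = {'Entity1': [['Robert', 'Logan'], ['Robby']]}
id_map_lookup = {'Entity1': {'Robert': 1, 'Logan': 2, 'Robby': 3}}
id_array_lookup = {'Entity1': np.array([[1, 2], [3, 0]], dtype=np.int64)}  # rows = aliases, cols = tokens (0-padded)
token_to_entity_lookup = {'Robert': {'Entity1'}, 'Logan': {'Entity1'}, 'Robby': {'Entity1'}}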
def __init__(self,
             alias_database_path: str,
             mode: str = "generative",
             token_indexers: Dict[str, TokenIndexer] = None,
             entity_indexers: Dict[str, TokenIndexer] = None,
             raw_entity_indexers: Dict[str, TokenIndexer] = None,
             relation_indexers: Dict[str, TokenIndexer] = None,
             lazy: bool = False) -> None:
    """
    Parameters
    ----------
    alias_database_path : str
        Path to the alias database.
    mode : str, optional (default="generative")
        One of "discriminative" or "generative", indicating whether generated instances
        are suitable for the discriminative or generative version of the model.
    """
    super().__init__(lazy)
    if mode not in {"discriminative", "generative"}:
        raise ConfigurationError(
            "Got mode {}, expected one of 'generative' "
            "or 'discriminative'".format(mode))
    self._mode = mode

    self._token_indexers = token_indexers or {
        'tokens': SingleIdTokenIndexer()
    }
    self._entity_indexers = entity_indexers or {
        'entity_ids': SingleIdTokenIndexer(namespace='entity_ids')
    }
    self._raw_entity_indexers = raw_entity_indexers or {
        'raw_entity_ids': SingleIdTokenIndexer(namespace='raw_entity_ids')
    }
    self._relation_indexers = relation_indexers or {
        'relations': SingleIdTokenIndexer(namespace='relations')
    }
    if 'tokens' not in self._token_indexers or \
            not isinstance(self._token_indexers['tokens'], SingleIdTokenIndexer):
        raise ConfigurationError(
            "EnhancedWikitextReader expects 'token_indexers' to contain "
            "a 'single_id' token indexer called 'tokens'.")
    if 'entity_ids' not in self._entity_indexers or \
            not isinstance(self._entity_indexers['entity_ids'], SingleIdTokenIndexer):
        raise ConfigurationError(
            "EnhancedWikitextReader expects 'entity_indexers' to contain "
            "a 'single_id' token indexer called 'entity_ids'.")
    if 'raw_entity_ids' not in self._raw_entity_indexers or \
            not isinstance(self._raw_entity_indexers['raw_entity_ids'], SingleIdTokenIndexer):
        raise ConfigurationError(
            "EnhancedWikitextReader expects 'raw_entity_indexers' to contain "
            "a 'single_id' token indexer called 'raw_entity_ids'.")
    if 'relations' not in self._relation_indexers or \
            not isinstance(self._relation_indexers['relations'], SingleIdTokenIndexer):
        raise ConfigurationError(
            "EnhancedWikitextReader expects 'relation_indexers' to contain "
            "a 'single_id' token indexer called 'relations'.")
    self._alias_database = AliasDatabase.load(path=alias_database_path)
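# A minimal usage sketch for the reader constructed above. The class name and data path are
# assumptions (the error messages only refer to an ``EnhancedWikitextReader``); only the
# constructor arguments shown in ``__init__`` are taken from this file. With no indexers
# supplied, the ``SingleIdTokenIndexer`` defaults above are used.
#
# reader = EnhancedWikitextReader(
#     alias_database_path='kglm/tests/fixtures/enhanced-wikitext-test/alias.pkl',
#     mode='discriminative')
# instances = reader.read('path/to/enhanced-wikitext.jsonl')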
def sample(args: argparse.Namespace):
    model_archive = load_archive(args.model_archive_file, cuda_device=args.cuda_device)
    config = model_archive.config
    prepare_environment(config)
    model = model_archive.model
    model.eval()

    alias_database = AliasDatabase.load(args.alias_database)
    samples = model.sample(alias_database,
                           batch_size=args.batch_size,
                           length=args.length)
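# A minimal sketch of the command-line wiring that would feed ``sample`` above. The flag names
# mirror the attributes read from ``args`` (``model_archive_file``, ``alias_database``,
# ``cuda_device``, ``batch_size``, ``length``); the helper name, help strings, and defaults
# are assumptions for illustration only.
import argparse

def _build_sample_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description='Draw samples from a trained model.')
    parser.add_argument('model_archive_file', type=str, help='Path to the model archive (model.tar.gz).')
    parser.add_argument('alias_database', type=str, help='Path to the pickled alias database.')
    parser.add_argument('--cuda-device', type=int, default=-1)
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--length', type=int, default=100)
    return parser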
def test_load(self):
    # Test that the load function has the expected behavior
    alias_database = AliasDatabase.load(
        'kglm/tests/fixtures/enhanced-wikitext-test/alias.pkl')
    test_entity = 'Q156216'  # Benton County

    # Check that aliases are tokenized properly
    expected_tokenized_aliases = [['Benton', 'County'],
                                  ['Benton', 'County', ',', 'Washington']]
    assert alias_database._token_lookup[test_entity] == expected_tokenized_aliases

    # Check that the id map has 4 unique tokens
    assert len(alias_database._id_map_lookup[test_entity]) == 4

    # Check that the first token in each alias has the same local id
    test_id_array = alias_database._id_array_lookup[test_entity]
    assert test_id_array[0, 0] == test_id_array[1, 0]
def test_tensorize_and_lookup(self):
    # Tensor fields should be empty when ``AliasDatabase`` is created
    alias_database = AliasDatabase(token_lookup=self.token_lookup,
                                   id_map_lookup=self.id_map_lookup,
                                   id_array_lookup=self.id_array_lookup,
                                   token_to_entity_lookup=self.token_to_entity_lookup)
    assert not alias_database.is_tensorized

    # But should exist after ``AliasDatabase`` is tensorized
    alias_database.tensorize(self.vocab)
    assert alias_database.is_tensorized
    assert alias_database._global_id_lookup != []
    assert alias_database._local_id_lookup != []

    tensor_dict = self.dataset.as_tensor_dict()
    entity_ids = tensor_dict['entity_identifiers']['entity_ids']
    tokens = tensor_dict['tokens']['tokens']
    global_tensor, local_tensor = alias_database.lookup(entity_ids)
    entity_id_tensor = alias_database.reverse_lookup(tokens)

    # The first two dimensions should match the batch_size and sequence length of the index.
    # The next dimensions should be the max number of aliases of all entities (in this case 2
    # for 'Robert Logan' and 'Robby') and the max length of the aliases (again 2 because
    # 'Robert Logan' is two tokens).
    assert global_tensor.shape == (1, 6, 2, 2)
    assert local_tensor.shape == (1, 6, 2, 2)
    assert entity_id_tensor.shape == (1, 6, 1)

    # Check that the global ids match the vocabulary indices
    assert global_tensor[0, 0, 0, 0] == self.vocab.get_token_index('Robert', namespace='tokens')
    assert global_tensor[0, 1, 0, 0] == 0  # Padding since not an alias
    assert local_tensor[0, 0, 0, 0] == 1
    assert local_tensor[0, 1, 0, 0] == 0  # Padding since not an alias
    match_token_idx = self.vocab.get_token_index('Entity1', namespace='entity_ids')
    nonmatch_token_idx = self.vocab.get_token_index('Entity2', namespace='entity_ids')
    assert entity_id_tensor[0, 0, match_token_idx] == 1
    assert entity_id_tensor[0, 0, nonmatch_token_idx] == 0
def _forward_loop(self,
                  source: Dict[str, torch.Tensor],
                  alias_database: AliasDatabase,
                  mention_type: torch.Tensor,
                  raw_entity_ids: Dict[str, torch.Tensor],
                  entity_ids: Dict[str, torch.Tensor],
                  parent_ids: Dict[str, torch.Tensor],
                  relations: Dict[str, torch.Tensor],
                  shortlist: Dict[str, torch.Tensor],
                  shortlist_inds: torch.Tensor) -> Dict[str, torch.Tensor]:
    # Get the token mask and extract indexed text fields.
    # shape: (batch_size, sequence_length)
    target_mask = get_text_field_mask(source)
    source = source['tokens']
    raw_entity_ids = raw_entity_ids['raw_entity_ids']
    entity_ids = entity_ids['entity_ids']
    parent_ids = parent_ids['entity_ids']
    relations = relations['relations']

    logger.debug('Source & Target shape: %s', source.shape)
    logger.debug('Entity ids shape: %s', entity_ids.shape)
    logger.debug('Relations & Parent ids shape: %s', relations.shape)
    logger.debug('Shortlist shape: %s', shortlist['entity_ids'].shape)

    # Embed source tokens.
    # shape: (batch_size, sequence_length, embedding_dim)
    encoded, alpha_loss, beta_loss = self._encode_source(source)
    splits = [self.token_embedding_dim] + [self.entity_embedding_dim] * 2
    encoded_token, encoded_head, encoded_relation = encoded.split(splits, dim=-1)

    # Predict whether or not the next token will be an entity mention, and if so which type.
    mention_type_loss = self._mention_type_loss(encoded_token, mention_type, target_mask)
    self._avg_mention_type_loss(float(mention_type_loss))

    # For new mentions, predict which entity (among those in the supplied shortlist) will be
    # mentioned.
    overlap_feature = alias_database.reverse_lookup(source)
    if self._use_shortlist:
        new_entity_loss = self._new_entity_loss(encoded_head + encoded_relation,
                                                overlap_feature,
                                                shortlist_inds,
                                                shortlist,
                                                target_mask)
    else:
        new_entity_loss = self._new_entity_loss(encoded_head + encoded_relation,
                                                overlap_feature,
                                                entity_ids,
                                                None,
                                                target_mask)
    self._avg_new_entity_loss(float(new_entity_loss))

    # For derived mentions, first predict which parent(s) to expand...
    knowledge_graph_entity_loss = self._knowledge_graph_entity_loss(encoded_head,
                                                                    encoded_relation,
                                                                    raw_entity_ids,
                                                                    entity_ids,
                                                                    parent_ids,
                                                                    target_mask)
    self._avg_knowledge_graph_entity_loss(float(knowledge_graph_entity_loss))

    # Compute total loss
    loss = mention_type_loss + new_entity_loss + knowledge_graph_entity_loss

    # Activation regularization
    if self._alpha:
        loss = loss + self._alpha * alpha_loss
    # Temporal activation regularization (slowness)
    if self._beta:
        loss = loss + self._beta * beta_loss

    return {'loss': loss}
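# A small, self-contained illustration (with made-up dimensions) of the ``encoded.split(splits, dim=-1)``
# call above: the encoder output is sliced along the feature dimension into a token-sized piece and two
# entity-sized pieces used for the head and relation predictions.
import torch

token_dim, entity_dim = 400, 256  # hypothetical embedding sizes
encoded = torch.randn(2, 5, token_dim + 2 * entity_dim)
encoded_token, encoded_head, encoded_relation = encoded.split([token_dim, entity_dim, entity_dim], dim=-1)
assert encoded_token.shape == (2, 5, token_dim)
assert encoded_head.shape == (2, 5, entity_dim)
assert encoded_relation.shape == (2, 5, entity_dim)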
def _forward_loop(self,
                  source: Dict[str, torch.Tensor],
                  target: Dict[str, torch.Tensor],
                  alias_database: AliasDatabase,
                  entity_ids: Dict[str, torch.Tensor],
                  shortlist: Dict[str, torch.Tensor],
                  shortlist_inds: torch.Tensor,
                  alias_copy_inds: torch.Tensor) -> Dict[str, torch.Tensor]:
    # Get the token mask and unwrap the target tokens.
    target_mask = get_text_field_mask(target)
    target = target['tokens']

    # Embed source tokens.
    source = source['tokens']
    source_embeddings = embedded_dropout(
        embed=self._token_embedder,
        words=source,
        dropout=self._dropoute if self.training else 0)
    source_embeddings = self._locked_dropout(source_embeddings, self._dropouti)

    # Embed entities.
    entity_ids = entity_ids['entity_ids']
    # entity_embeddings = embedded_dropout(
    #     embed=self._entity_embedder,
    #     words=entity_ids,
    #     dropout=self._dropoute if self.training else 0)
    # entity_embeddings = self._locked_dropout(entity_embeddings, self._dropouti)

    # # Embed shortlist.
    # shortlist_mask = get_text_field_mask(shortlist)
    # shortlist = shortlist['entity_ids']
    # shortlist_embeddings = embedded_dropout(
    #     embed=self._entity_embedder,
    #     words=shortlist,
    #     dropout=self._dropoute if self.training else 0)

    # Encode source tokens.
    current_input = source_embeddings
    hidden_states = []
    for layer, rnn in enumerate(self.rnns):
        # Retrieve previous hidden state for layer.
        if self._state is not None:
            prev_hidden = self._state['layer_%i' % layer]
        else:
            prev_hidden = None
        # Forward-pass.
        output, hidden = rnn(current_input, prev_hidden)
        output = output.contiguous()
        # Update hidden state for layer.
        hidden = tuple(h.detach() for h in hidden)
        hidden_states.append(hidden)
        # Apply dropout.
        if layer == self._num_layers - 1:
            dropped_output = self._locked_dropout(output, self._dropout)
        else:
            dropped_output = self._locked_dropout(output, self._dropouth)
        current_input = dropped_output
    encoded = current_input

    self._state = {'layer_%i' % i: h for i, h in enumerate(hidden_states)}

    # Predict whether or not the next token will be an entity mention. This corresponds to the
    # case that the entity's id is not a padding token.
    # mention_loss = self._mention_loss(encoded, entity_ids.gt(0), target_mask)

    # Predict which entity (among those in the supplied shortlist) is going to be mentioned.
    # entity_loss = self._entity_loss(encoded,
    #                                 shortlist_inds,
    #                                 target_mask,
    #                                 shortlist_embeddings,
    #                                 shortlist_mask)

    # Predict generation-mode scores. Start by concatenating predicted entity embeddings with
    # the encoder output - then feed through a linear layer.
    # concatenated = torch.cat((encoded, entity_embeddings), dim=-1)
    # condensed = self._fc_condense(concatenated)
    generate_scores = self._fc_generate(encoded)

    # Predict copy-mode scores.
    alias_tokens, alias_inds = alias_database.lookup(entity_ids)
    copy_scores = self._copy_scores(encoded, alias_tokens)

    # Combine scores to get vocab loss
    vocab_loss = self._vocab_loss(generate_scores,
                                  copy_scores,
                                  target,
                                  alias_copy_inds,
                                  target_mask,
                                  alias_inds,
                                  alias_tokens,
                                  entity_ids.gt(0))

    # Compute total loss
    loss = vocab_loss  # + mention_loss + entity_loss

    # Activation regularization
    if self._alpha:
        loss = loss + self._alpha * dropped_output.pow(2).mean()
    # Temporal activation regularization (slowness)
    if self._beta:
        loss = loss + self._beta * (output[:, 1:] - output[:, :-1]).pow(2).mean()

    return {'loss': loss}
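# A small, self-contained illustration of the two regularizers added to the loss above (AWD-LSTM style):
# activation regularization (AR) penalizes large post-dropout activations, and temporal activation
# regularization (TAR) penalizes large changes between consecutive time steps. The tensors and
# coefficients below are made up.
import torch

alpha, beta = 2.0, 1.0                      # hypothetical regularization strengths
output = torch.randn(4, 10, 32)             # (batch, seq_len, hidden) final-layer RNN output
dropped_output = output * torch.bernoulli(torch.full_like(output, 0.5))

ar_term = alpha * dropped_output.pow(2).mean()
tar_term = beta * (output[:, 1:] - output[:, :-1]).pow(2).mean()
regularized_loss = ar_term + tar_term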
def _forward_loop(self,
                  source: Dict[str, torch.Tensor],
                  target: Dict[str, torch.Tensor],
                  alias_database: AliasDatabase,
                  mention_type: torch.Tensor,
                  raw_entity_ids: Dict[str, torch.Tensor],
                  entity_ids: Dict[str, torch.Tensor],
                  parent_ids: Dict[str, torch.Tensor],
                  relations: Dict[str, torch.Tensor],
                  shortlist: Dict[str, torch.Tensor],
                  shortlist_inds: torch.Tensor,
                  alias_copy_inds: torch.Tensor) -> Dict[str, torch.Tensor]:
    # Get the token mask and extract indexed text fields.
    # shape: (batch_size, sequence_length)
    target_mask = get_text_field_mask(target)
    source = source['tokens']
    target = target['tokens']
    raw_entity_ids = raw_entity_ids['raw_entity_ids']
    entity_ids = entity_ids['entity_ids']

    logger.debug('Source & Target shape: %s', source.shape)
    logger.debug('Entity ids shape: %s', entity_ids.shape)
    logger.debug('Shortlist shape: %s', shortlist['entity_ids'].shape)

    # Embed source tokens.
    # shape: (batch_size, sequence_length, embedding_dim)
    encoded, alpha_loss, beta_loss = self._encode_source(source)
    splits = [self.token_embedding_dim] + [self.entity_embedding_dim]
    encoded_token, encoded_head = encoded.split(splits, dim=-1)

    # Predict whether or not the next token will be an entity mention, and if so which type.
    mention_type = mention_type.gt(0).long()  # Map 1, 2 -> 1
    mention_type_loss = self._mention_type_loss(encoded_token, mention_type, target_mask)
    self._avg_mention_type_loss(float(mention_type_loss))

    # For new mentions, predict which entity (among those in the supplied shortlist) will be
    # mentioned.
    if self._use_shortlist:
        new_entity_loss = self._new_entity_loss(encoded_head,
                                                shortlist_inds,
                                                shortlist,
                                                target_mask)
    else:
        new_entity_loss = self._new_entity_loss(encoded_head,
                                                entity_ids,
                                                None,
                                                target_mask)
    self._avg_new_entity_loss(float(new_entity_loss))

    # Predict generation-mode scores. Note: these are w.r.t. entity_ids since we need the embedding.
    generate_scores = self._generate_scores(encoded_token, entity_ids)

    # Predict copy-mode scores. Note: these are w.r.t. raw_entity_ids since we need to look up aliases.
    alias_tokens, alias_inds = alias_database.lookup(raw_entity_ids)
    copy_scores = self._copy_scores(encoded_token, alias_tokens)

    # Combine scores to get vocab loss
    vocab_loss, penalized_vocab_loss = self._vocab_loss(generate_scores,
                                                        copy_scores,
                                                        target,
                                                        alias_copy_inds,
                                                        target_mask,
                                                        alias_inds,
                                                        entity_ids.gt(0))
    self._avg_vocab_loss(float(vocab_loss))

    # Compute total loss. Also compute logp (needed for importance sampling evaluation).
    loss = vocab_loss + mention_type_loss + new_entity_loss
    logp = -(vocab_loss + mention_type_loss + new_entity_loss) * target_mask.sum()
    penalized_logp = -(penalized_vocab_loss + mention_type_loss + new_entity_loss) * target_mask.sum()

    # Activation regularization
    if self._alpha:
        loss = loss + self._alpha * alpha_loss
    # Temporal activation regularization (slowness)
    if self._beta:
        loss = loss + self._beta * beta_loss

    return {'loss': loss, 'logp': logp, 'penalized_logp': penalized_logp}
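# A small worked example of the log-probability bookkeeping above: if ``vocab_loss``,
# ``mention_type_loss`` and ``new_entity_loss`` are per-token averages, multiplying their negated sum
# by the number of non-padding tokens recovers the total log-probability of the sequence, which is
# what an importance-sampling perplexity estimate needs. The numbers below are made up.
import torch

target_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])   # 4 real tokens, 2 padding
per_token_loss = torch.tensor(2.5)                  # hypothetical mean NLL per token
logp = -per_token_loss * target_mask.sum()          # total log-probability = -10.0
assert logp.item() == -10.0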