def bow_snippets(self, token, snippets=None):
    """Bag-of-words embedding for snippets."""
    if snippet_handler.is_snippet(token):
        # On the server (though not locally), a token is sometimes wrongly
        # predicted to be a snippet and reaches this branch with no snippets
        # available. Rather than failing an assertion, fall back to embedding
        # it as a flat token.
        if not snippets:
            return self(token)

        snippet_sequence = []
        for snippet in snippets:
            if snippet.name == token:
                snippet_sequence = snippet.sequence
                break
        assert snippet_sequence

        snippet_embeddings = [self(subtoken) for subtoken in snippet_sequence]
        return dy.average(snippet_embeddings)
    else:
        return self(token)
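# Usage sketch (hypothetical, not part of the original code; assumes an
# embedder instance, an active DyNet computation graph, and snippet objects
# exposing `.name` and `.sequence` as the method above does):
#
#     dy.renew_cg()
#     vec = embedder.bow_snippets(snippet.name, snippets=[snippet])
#     # -> dy.average over the embeddings of snippet.sequence's subtokens
#     flat = embedder.bow_snippets("show", snippets=None)
#     # -> "show" is not a snippet token, so this is just embedder("show")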
def __call__(self, token):
    assert isinstance(token, int) or not snippet_handler.is_snippet(token), \
        "embedder should only be called on flat tokens; use snippet_bow if " \
        "you are trying to encode snippets"
    if self.in_vocabulary(token):
        return self.token_embedding_matrix[self.vocab_token_lookup(token)]
    elif self.anonymizer and self.anonymizer.is_anon_tok(token):
        return self.entity_embedding_matrix[self.anonymizer.get_anon_id(token)]
    else:
        return self.token_embedding_matrix[self.unknown_token_id]
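# Lookup-order sketch (hypothetical tokens; assumes an active computation
# graph). The three branches above fall through in this order:
#
#     embedder("show")       # in vocabulary -> row of token_embedding_matrix
#     embedder("ENTITY#0")   # anonymized    -> row of entity_embedding_matrix
#     embedder("xyzzy")      # out of vocab  -> the unknown_token_id row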
def __init__(self, token_sequences, filename, params, is_input, anonymizer=None):
    self.raw_vocab = Vocabulary(
        token_sequences,
        filename,
        functional_types=INPUT_FN_TYPES if is_input else OUTPUT_FN_TYPES,
        min_occur=MIN_INPUT_OCCUR if is_input else MIN_OUTPUT_OCCUR,
        ignore_fn=lambda x: snippets.is_snippet(x) or (
            anonymizer and anonymizer.is_anon_tok(x)))

    self.tokens = set(self.raw_vocab.token_to_id.keys())
    self.inorder_tokens = self.raw_vocab.id_to_token

    assert len(self.inorder_tokens) == len(self.raw_vocab)
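# Effect of ignore_fn (sketch with hypothetical tokens): snippet tokens and
# anonymized entity tokens are kept out of the vocabulary entirely, so the
# embedder's __call__ above resolves them through its entity/unknown branches
# rather than a vocabulary lookup:
#
#     ignore_fn("SNIPPET_3")  # True  -> never counted or added to the vocab
#     ignore_fn("city")       # False -> counted toward min_occur as usual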
def bow_snippets(self, token, snippets=None):
    """Bag-of-words embedding for snippets."""
    if snippet_handler.is_snippet(token):
        assert snippets

        snippet_sequence = []
        for snippet in snippets:
            if snippet.name == token:
                snippet_sequence = snippet.sequence
                break
        assert snippet_sequence

        snippet_embeddings = [self(subtoken) for subtoken in snippet_sequence]
        return dy.average(snippet_embeddings)
    else:
        return self(token)
def predict_turn(self,
                 utterance_final_state,
                 input_hidden_states,
                 max_generation_length,
                 gold_query=None,
                 snippets=None,
                 input_sequence=None,
                 feed_gold_tokens=False,
                 training=False,
                 first_utterance=True,
                 gold_copy=None):
    """Gets a prediction for a single turn -- calls the decoder and updates the loss, etc.

    TODO: this can probably be split into two methods, one that just predicts
    and another that computes the loss.
    """
    predicted_sequence = []
    fed_sequence = []
    loss = None
    token_accuracy = 0.

    if feed_gold_tokens:
        decoder_results, pick_loss = self.decoder(
            utterance_final_state,
            input_hidden_states,
            max_generation_length,
            gold_sequence=gold_query,
            input_sequence=input_sequence,
            snippets=snippets,
            dropout_amount=self.dropout,
            controller=self.controller,
            first_utterance=first_utterance,
            gold_copy=gold_copy)

        all_scores = [step.scores for step in decoder_results.predictions]
        all_alignments = [
            step.aligned_tokens for step in decoder_results.predictions]

        # Compute the loss.
        if not pick_loss:
            loss = du.compute_loss(
                gold_query, all_scores, all_alignments, get_token_indices)
        else:
            loss = du.compute_loss(
                gold_copy[1:], all_scores, all_alignments, get_token_indices)
        if pick_loss:
            loss += pick_loss
        if not loss:
            loss = dy.zeros((1, 1))

        if not training:
            predicted_sequence = du.get_seq_from_scores(
                all_scores, all_alignments)
            token_accuracy = du.per_token_accuracy(
                gold_query, predicted_sequence)

        fed_sequence = gold_query
    else:
        decoder_results, pick_loss = self.decoder(
            utterance_final_state,
            input_hidden_states,
            max_generation_length,
            input_sequence=input_sequence,
            snippets=snippets,
            dropout_amount=self.dropout,
            first_utterance=first_utterance)
        predicted_sequence = decoder_results.sequence
        fed_sequence = predicted_sequence

    # fed_sequence contains EOS, which we don't need when encoding snippets.
    # Also ignore the first state, as it contains the BEG encoding.
    decoder_states = [
        pred.decoder_state for pred in decoder_results.predictions]
    if pick_loss:
        fed_sequence = fed_sequence[1:]
    for token, state in zip(fed_sequence[:-1], decoder_states[1:]):
        if snippet_handler.is_snippet(token):
            snippet_length = 0
            for snippet in snippets:
                if snippet.name == token:
                    snippet_length = len(snippet.sequence)
                    break
            assert snippet_length > 0
            decoder_states.extend([state for _ in range(snippet_length)])
        else:
            decoder_states.append(state)

    return (predicted_sequence,
            loss,
            token_accuracy,
            decoder_states,
            decoder_results)
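# State-expansion sketch (hypothetical example): if fed_sequence is
# ["SELECT", "SNIPPET_0", "EOS"] and SNIPPET_0 expands to three subtokens,
# the single decoder state produced at the SNIPPET_0 step is repeated three
# times, so decoder_states ends up aligned one-to-one with the flattened
# (snippet-expanded) token sequence used when encoding snippets later.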