Esempio n. 1
0
    def bow_snippets(self, token, snippets=None):
        """Bag-of-words embedding for snippets.

        A flat token is embedded with ``self(token)``. A snippet name is
        embedded as the ``dy.average`` of the embeddings of the sub-tokens in
        the matching snippet's ``sequence`` (the first entry of ``snippets``
        whose ``name`` equals ``token``).

        Args:
            token: a flat token or a snippet name.
            snippets: candidate snippets exposing ``name`` and ``sequence``;
                may be None or empty.

        Returns:
            The token embedding, or the averaged snippet embedding.

        Raises:
            AssertionError: if ``token`` is a snippet name, ``snippets`` is
                non-empty, and no matching snippet with a non-empty sequence
                is found.
        """
        if not snippet_handler.is_snippet(token):
            return self(token)

        # Known issue (seen on a server, not reproducible locally): a token can
        # be classified as a snippet even though no snippets were supplied. In
        # that case fall back to the flat embedding. This is an explicit check
        # rather than catching an AssertionError with a bare ``except``: that
        # pattern is over-broad and breaks under ``python -O``, where asserts
        # are stripped and iterating ``None`` raises TypeError instead.
        # NOTE(review): if ``self(token)`` itself rejects snippet tokens, this
        # fallback can still raise -- confirm the intended behavior.
        if not snippets:
            return self(token)

        snippet_sequence = next(
            (snippet.sequence for snippet in snippets if snippet.name == token),
            [])
        if not snippet_sequence:
            # Explicit raise (not ``assert``) so the check survives ``-O``.
            raise AssertionError(
                "no snippet with a non-empty sequence named %r" % (token,))

        snippet_embeddings = [self(subtoken) for subtoken in snippet_sequence]
        return dy.average(snippet_embeddings)
Esempio n. 2
0
    def __call__(self, token):
        """Embed a single flat (non-snippet) token.

        Lookup order:
          1. tokens in the vocabulary -> their ``token_embedding_matrix`` row;
          2. anonymized tokens (only when ``self.anonymizer`` is set) -> the
             ``entity_embedding_matrix`` row for their anonymization id;
          3. anything else -> the unknown-token row of
             ``token_embedding_matrix``.

        Integer tokens skip the snippet check. Snippet names must be embedded
        with ``bow_snippets`` instead.
        """
        assert isinstance(token, int) or not snippet_handler.is_snippet(token), (
            "embedder should only be called on flat tokens; use bow_snippets "
            "if you are trying to encode snippets")

        if self.in_vocabulary(token):
            return self.token_embedding_matrix[self.vocab_token_lookup(token)]
        if self.anonymizer and self.anonymizer.is_anon_tok(token):
            anon_id = self.anonymizer.get_anon_id(token)
            return self.entity_embedding_matrix[anon_id]
        return self.token_embedding_matrix[self.unknown_token_id]
Esempio n. 3
0
    def __init__(self,
                 token_sequences,
                 filename,
                 params,
                 is_input,
                 anonymizer=None):
        """Build the raw vocabulary for either the input or the output side.

        The functional types and minimum occurrence count are chosen from the
        input or output constants depending on ``is_input``. ``params`` is
        accepted but not used in this method.
        """
        if is_input:
            fn_types, min_count = INPUT_FN_TYPES, MIN_INPUT_OCCUR
        else:
            fn_types, min_count = OUTPUT_FN_TYPES, MIN_OUTPUT_OCCUR

        def _skip_token(tok):
            # Truthy for snippet names and, when an anonymizer is given, for
            # anonymized tokens; presumably Vocabulary leaves such tokens out
            # (passed as ``ignore_fn``) -- confirm against Vocabulary.
            return snippets.is_snippet(tok) or (
                anonymizer and anonymizer.is_anon_tok(tok))

        self.raw_vocab = Vocabulary(
            token_sequences,
            filename,
            functional_types=fn_types,
            min_occur=min_count,
            ignore_fn=_skip_token)

        vocab = self.raw_vocab
        self.tokens = set(vocab.token_to_id.keys())
        self.inorder_tokens = vocab.id_to_token

        # The id-ordered token list must cover the whole vocabulary.
        assert len(self.inorder_tokens) == len(vocab)
Esempio n. 4
0
    def bow_snippets(self, token, snippets=None):
        """Embed a token, averaging sub-token embeddings when it names a snippet.

        Non-snippet tokens are embedded directly via ``self(token)``. For a
        snippet name, the first entry of ``snippets`` whose ``name`` equals
        ``token`` supplies a ``sequence`` of sub-tokens. Each sub-token is
        embedded with ``self`` and the results are combined with
        ``dy.average``. Missing snippets, or a missing or empty match, trip
        an assertion.
        """
        if not snippet_handler.is_snippet(token):
            return self(token)

        assert snippets
        sub_tokens = next(
            (candidate.sequence for candidate in snippets
             if candidate.name == token),
            [])
        assert sub_tokens

        return dy.average(list(map(self, sub_tokens)))
Esempio n. 5
0
    def predict_turn(self,
                     utterance_final_state,
                     input_hidden_states,
                     max_generation_length,
                     gold_query=None,
                     snippets=None,
                     input_sequence=None,
                     feed_gold_tokens=False,
                     training=False,
                     first_utterance=True,
                     gold_copy=None):
        """Decode a single turn and, when teacher forcing, compute its loss.

        Two modes are supported:

        * ``feed_gold_tokens=True``: the decoder is fed ``gold_query`` (plus
          ``self.controller`` and ``gold_copy``). A loss is computed from the
          per-step scores and alignments.
          - When the decoder also returns a truthy ``pick_loss``, the loss
            targets ``gold_copy[1:]`` and ``pick_loss`` is added to it.
          - A falsy loss is replaced by ``dy.zeros((1, 1))``.
          - Unless ``training`` is set, the predicted sequence and its
            per-token accuracy against ``gold_query`` are also computed.
        * ``feed_gold_tokens=False``: the decoder generates on its own and its
          output sequence is used; no loss is computed.

        Returns:
            ``(predicted_sequence, loss, token_accuracy, decoder_states,
            decoder_results)``. ``predicted_sequence`` is [] when it was not
            computed. ``loss`` is None outside teacher forcing.
            ``token_accuracy`` stays 0. unless accuracy was computed.

        TODO: this can probably be split into two methods, one that just
            predicts and another that computes the loss.
        """
        predicted_sequence = []
        loss = None
        token_accuracy = 0.

        if not feed_gold_tokens:
            # Free-running decoding: the model consumes its own predictions.
            decoder_results, pick_loss = self.decoder(
                utterance_final_state,
                input_hidden_states,
                max_generation_length,
                input_sequence=input_sequence,
                snippets=snippets,
                dropout_amount=self.dropout,
                first_utterance=first_utterance)
            predicted_sequence = decoder_results.sequence
            fed_sequence = predicted_sequence
        else:
            # Teacher forcing: the decoder consumes the gold query.
            decoder_results, pick_loss = self.decoder(
                utterance_final_state,
                input_hidden_states,
                max_generation_length,
                gold_sequence=gold_query,
                input_sequence=input_sequence,
                snippets=snippets,
                dropout_amount=self.dropout,
                controller=self.controller,
                first_utterance=first_utterance,
                gold_copy=gold_copy)

            step_scores = [step.scores for step in decoder_results.predictions]
            step_alignments = [
                step.aligned_tokens for step in decoder_results.predictions
            ]

            if pick_loss:
                # Targets come from gold_copy (minus its first element) and
                # the decoder's pick loss is added on top.
                loss = du.compute_loss(gold_copy[1:], step_scores,
                                       step_alignments, get_token_indices)
                loss += pick_loss
            else:
                loss = du.compute_loss(gold_query, step_scores,
                                       step_alignments, get_token_indices)
            if not loss:
                loss = dy.zeros((1, 1))

            if not training:
                predicted_sequence = du.get_seq_from_scores(
                    step_scores, step_alignments)
                token_accuracy = du.per_token_accuracy(gold_query,
                                                       predicted_sequence)

            fed_sequence = gold_query

        # One decoder state per prediction step.
        step_states = [
            pred.decoder_state for pred in decoder_results.predictions
        ]

        if pick_loss:
            fed_sequence = fed_sequence[1:]

        # Align states with fed tokens. fed_sequence ends with EOS, which is
        # not needed when encoding snippets, and the first state contains the
        # BEG encoding, so token i is paired with state i + 1. A snippet token
        # contributes one copy of its state per sub-token.
        token_states = []
        for token, state in zip(fed_sequence[:-1], step_states[1:]):
            if not snippet_handler.is_snippet(token):
                token_states.append(state)
                continue
            snippet_length = next(
                (len(snippet.sequence) for snippet in snippets
                 if snippet.name == token),
                0)
            assert snippet_length > 0
            token_states.extend([state] * snippet_length)

        # NOTE(review): the returned list is the raw per-step states FOLLOWED
        # BY the token-aligned states. The alignment comment above suggests
        # only ``token_states`` may have been intended; confirm against
        # callers before changing this.
        decoder_states = step_states + token_states

        return (predicted_sequence, loss, token_accuracy, decoder_states,
                decoder_results)