def test_sentence_encoder(self):
        """Exercise the SentenceEncoder constructor.

        Calling the constructor without any arguments must raise. After
        that, the class is handed to ``self._run_constructors`` together
        with ``SENTENCE_ENCODER_GOOD`` and ``SENTENCE_ENCODER_BAD``.
        """
        # Passing the class as the callable makes assertRaises invoke it
        # with no arguments, which is expected to fail on purpose.
        self.assertRaises(Exception, SentenceEncoder)

        self._run_constructors(
            SentenceEncoder, SENTENCE_ENCODER_GOOD, SENTENCE_ENCODER_BAD)
# Example no. 2
# 0
    def test_save_and_load(self):
        """Save encoder variables in one session and load them in another.

        The encoder is configured to save to and load from the same
        temporary checkpoint file. Its variables are initialized and saved
        in a first session, then initialized again and loaded in a second
        session. The values read back from both sessions must be equal.
        """
        vocabulary = Vocabulary()
        vocabulary.add_word("a")
        vocabulary.add_word("b")

        # Only the path is needed, so the file is closed immediately. The
        # cleanup hook removes it even when an assertion below fails.
        # NOTE(review): the saver may write additional files next to this
        # path; only the path created here is removed -- confirm.
        with tempfile.NamedTemporaryFile(delete=False) as checkpoint_file:
            checkpoint_path = checkpoint_file.name
        self.addCleanup(os.remove, checkpoint_path)

        # Use the vocabulary built above; previously a fresh Vocabulary()
        # was passed and the populated one was left unused.
        encoder = SentenceEncoder(name="enc",
                                  vocabulary=vocabulary,
                                  data_id="data_id",
                                  embedding_size=10,
                                  rnn_size=20,
                                  max_input_len=30,
                                  save_checkpoint=checkpoint_path,
                                  load_checkpoint=checkpoint_path)

        encoder.input_sequence.register_input()

        # NOTE: This assert needs to be here otherwise the model has
        # no parameters since the sentence encoder is initialized lazily
        self.assertIsInstance(encoder.temporal_states, tf.Tensor)

        encoders_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                               scope="enc")

        # Without this guard the comparison below would pass vacuously if
        # no variables were collected under the "enc" scope.
        self.assertTrue(encoders_variables,
                        "No encoder variables found in scope 'enc'")

        # Both sessions are closed on exit, including on failure.
        with tf.Session() as sess_1, tf.Session() as sess_2:
            sess_1.run(tf.global_variables_initializer())
            encoder.save(sess_1)

            sess_2.run(tf.global_variables_initializer())
            encoder.load(sess_2)

            values_in_sess_1 = sess_1.run(encoders_variables)
            values_in_sess_2 = sess_2.run(encoders_variables)

        self.assertEqual(len(values_in_sess_1), len(values_in_sess_2))

        # array_equal also rejects shape mismatches instead of broadcasting.
        for variable, value_1, value_2 in zip(encoders_variables,
                                              values_in_sess_1,
                                              values_in_sess_2):
            self.assertTrue(
                np.array_equal(value_1, value_2),
                "Variable {} differs after loading".format(variable.name))