示例#1
0
    def test_attn_lstm_embedding(self):
        """Check the output shapes produced by the AttnLSTMEmbedding layer.

        Random test and support feature matrices are fed through the layer;
        each evaluated output must have the same shape as its matching input.
        """
        n_test, n_support = 5, 11
        n_feat = 10
        max_depth = 5

        # Draw the test matrix first, then the support matrix.
        test_array = np.random.rand(n_test, n_feat)
        support_array = np.random.rand(n_support, n_feat)

        with self.test_session() as sess:
            test_tensor = tf.convert_to_tensor(test_array, dtype=tf.float32)
            support_tensor = tf.convert_to_tensor(
                support_array, dtype=tf.float32)

            layer = AttnLSTMEmbedding(n_test, n_support, n_feat, max_depth)
            outputs = layer(test_tensor, support_tensor)

            # Run the global variable initializer before evaluating outputs.
            sess.run(tf.global_variables_initializer())

            test_out = outputs[0].eval()
            support_out = outputs[1].eval()

            expected_test_shape = (n_test, n_feat)
            expected_support_shape = (n_support, n_feat)
            assert test_out.shape == expected_test_shape
            assert support_out.shape == expected_support_shape