예제 #1
0
    def testZeroBridge(self):
        """Compare ZeroBridge states against build_state references.

        Three StackBidirectionalRNNEncoder instances (little_params1..3) are
        encoded under scopes "1", "2" and "3", and each encoder output is
        wrapped in a ZeroBridge. Every bridge is called once without and once
        with ``beam_size``; the evaluated bridge state must equal the
        evaluated tuple of ``build_state`` references.
        """
        train_mode = tf.contrib.learn.ModeKeys.TRAIN

        # Construction order is kept: all encoders, then all encodings, then
        # all bridges.
        encoders = [
            rnn_encoder.StackBidirectionalRNNEncoder(params, train_mode)
            for params in (BridgeTest.little_params1,
                           BridgeTest.little_params2,
                           BridgeTest.little_params3)
        ]
        encodings = [
            encoder.encode(
                *build_inputs(*BridgeTest.input_shape), scope=scope_name)
            for encoder, scope_name in zip(encoders, ("1", "2", "3"))
        ]
        zero_bridges = [
            bridges.ZeroBridge({}, encoding, train_mode)
            for encoding in encodings
        ]

        state_sizes = (BridgeTest.state_size1,
                       BridgeTest.state_size2,
                       BridgeTest.state_size3)
        # Per bridge: (number of reference states, second build_state arg).
        reference_specs = ((1, True), (3, True), (3, False))

        # First pass: plain call with hidden_state_shape references.
        # Second pass: beam_size call with hidden_state_shape_wb references.
        checks = []
        for with_beam, shape in ((False, BridgeTest.hidden_state_shape),
                                 (True, BridgeTest.hidden_state_shape_wb)):
            for bridge, size, (copies, flag) in zip(
                    zero_bridges, state_sizes, reference_specs):
                if with_beam:
                    actual = bridge(size, beam_size=BridgeTest.beam_size)
                else:
                    actual = bridge(size)
                expected = tuple(
                    build_state(shape, flag) for _ in range(copies))
                checks.append((actual, expected))

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            for actual, expected in checks:
                self.assertAllEqual(sess.run(actual), sess.run(expected))
    def testStackBidirectionalRNNEncoder3(self):
        """Check output and final-state shapes for the little_params3 encoder.

        The encoder output shape must equal ``bi_context_shape``, and the
        flattened final states must have the same per-element shapes as a
        flattened forward/backward pair produced by ``build_state``. Only
        shapes are compared, not values.
        """
        train_mode = tf.contrib.learn.ModeKeys.TRAIN
        bi_encoder = rnn_encoder.StackBidirectionalRNNEncoder(
            RNNEncoderTest.little_params3, train_mode)
        encoded = bi_encoder.encode(
            *build_inputs(*RNNEncoderTest.input_shape))

        # Static shape check; no session is needed for this one.
        self.assertAllEqual(
            encoded.outputs.shape,
            tf.TensorShape(RNNEncoderTest.bi_context_shape))

        actual_states = flatten_final_states(encoded.final_states)
        forward_ref = build_state(RNNEncoderTest.hidden_state_shape, False)
        backward_ref = build_state(RNNEncoderTest.hidden_state_shape, False)
        reference_states = flatten_final_states({
            "forward": forward_ref,
            "backward": backward_ref,
        })

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            actual_shapes = [
                value.shape for value in sess.run(actual_states)]
            reference_shapes = [
                value.shape for value in sess.run(reference_states)]
            self.assertAllEqual(actual_shapes, reference_shapes)
예제 #3
0
 def testPassThroughBridge(self):
     """Check that PassThroughBridge hands back the encoder final states.

     A StackBidirectionalRNNEncoder output is bridged with direction
     "backward" and a UnidirectionalRNNEncoder output with direction
     "forward". Each bridge result must equal a 3-tuple repeating the
     corresponding encoder final state.
     """
     # TODO: the little_params1 StackBidirectionalRNNEncoder case (scope "1",
     # PassThroughBridge built with {} and with {"direction": "forward"},
     # expected state (final_states["forward"],) for state_size1) is still
     # disabled.
     train_mode = tf.contrib.learn.ModeKeys.TRAIN

     # NOTE(review): the bidirectional encoder uses little_params3 but its
     # bridge is called with state_size2, and the unidirectional encoder uses
     # little_params2 with state_size3 -- confirm this pairing is intended.
     bi_encoder = rnn_encoder.StackBidirectionalRNNEncoder(
         BridgeTest.little_params3, train_mode)
     uni_encoder = rnn_encoder.UnidirectionalRNNEncoder(
         BridgeTest.little_params2, train_mode)

     bi_output = bi_encoder.encode(
         *build_inputs(*BridgeTest.input_shape), scope="2")
     uni_output = uni_encoder.encode(
         *build_inputs(*BridgeTest.input_shape), scope="3")

     backward_bridge = bridges.PassThroughBridge(
         {"direction": "backward"}, bi_output, train_mode)
     forward_bridge = bridges.PassThroughBridge(
         {"direction": "forward"}, uni_output, train_mode)

     backward_state = backward_bridge(BridgeTest.state_size2)
     backward_expected = tuple(
         bi_output.final_states["backward"] for _ in range(3))
     forward_state = forward_bridge(BridgeTest.state_size3)
     forward_expected = tuple(
         uni_output.final_states for _ in range(3))

     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         for actual, expected in ((backward_state, backward_expected),
                                  (forward_state, forward_expected)):
             self.assertAllEqual(sess.run(actual), sess.run(expected))