示例#1
0
 def MultiRNNCell():
     """Build a stacked (multi-layer) RNN cell.

     Relies on `n_layers`, `d_model`, `rnn_cell` and `tl` being available
     from the enclosing scope. The returned layer chains three stages:

     1. Pass the top stack item through unchanged and split the second item
        into `n_layers` pieces (presumably one recurrent state per layer --
        confirm against the caller).
     2. Run `n_layers` freshly constructed `rnn_cell(n_units=d_model)`
        layers through `tl.SerialWithSideOutputs`.
     3. Pass the top stack item through unchanged and concatenate the next
        `n_layers` items back into a single item.

     Returns:
       A `tl.Serial` layer composed of the three stages above.
     """
     # Stage 1: identity on the first item, split the second into per-layer parts.
     split_stage = tl.Parallel([], tl.Split(n_items=n_layers))
     # Stage 2: one independent cell instance per layer.
     layer_cells = [rnn_cell(n_units=d_model) for _ in range(n_layers)]
     stacked_stage = tl.SerialWithSideOutputs(layer_cells)
     # Stage 3: identity on the first item, re-join the per-layer parts.
     merge_stage = tl.Parallel([], tl.Concatenate(n_items=n_layers))
     return tl.Serial(split_stage, stacked_stage, merge_stage)
示例#2
0
    def test_serial_with_side_outputs_div_div(self):
        """Check SerialWithSideOutputs over two 2-in/2-out Parallel layers.

        Feeds three 1-D arrays of lengths 3, 5 and 2 and asserts that the
        three outputs have shapes (3,), (5,) and (2,), in that order.
        """

        def make_divider_pair():
            # Parallel pair: DivideBy(2.0) on the first input and
            # DivideBy(5.0) on the second.
            return tl.Parallel(DivideBy(2.0), DivideBy(5.0))

        stacked = tl.SerialWithSideOutputs(
            [make_divider_pair(), make_divider_pair()])
        inputs = (
            np.array([1, 2, 3]),
            np.array([10, 20, 30, 40, 50]),
            np.array([100, 200]),
        )
        outputs = stacked(inputs)
        self.assertEqual([out.shape for out in outputs],
                         [(3, ), (5, ), (2, )])