def MultiRNNCell():
  """Multi-layer RNN cell built from `n_layers` stacked `rnn_cell`s.

  `n_layers`, `rnn_cell` and `d_model` are not parameters of this function.
  They are read as free variables, presumably from an enclosing function's
  scope. NOTE(review): confirm at the definition site.

  Returns:
    A `tl.Serial` with three stages: split the state, run the cells via
    `tl.SerialWithSideOutputs`, then re-concatenate the per-layer pieces.
  """
  # Presumably the first stack item passes through unchanged (`[]` branch)
  # while the second is split into `n_layers` pieces. TODO confirm against
  # `tl.Parallel` / `tl.Split` semantics.
  unpack_state = tl.Parallel([], tl.Split(n_items=n_layers))
  # One cell per layer, each constructed with `n_units=d_model`.
  cells = [rnn_cell(n_units=d_model) for _ in range(n_layers)]
  run_cells = tl.SerialWithSideOutputs(cells)
  # Counterpart of `unpack_state`: join the `n_layers` pieces back together.
  pack_state = tl.Parallel([], tl.Concatenate(n_items=n_layers))
  return tl.Serial(unpack_state, run_cells, pack_state)
def test_serial_with_side_outputs_div_div(self):
  """Checks stack routing and arithmetic of SerialWithSideOutputs.

  Each sub-layer divides the top stack item by 2 and the next by 5. With one
  side output per sub-layer, the expected end state is:
    - x1 has gone through both `/2` branches,
    - x2 through one `/5`,
    - x3 through one `/5`,
  all in their original stack order.
  """
  def some_layer():
    return tl.Parallel(DivideBy(2.0), DivideBy(5.0))

  layer = tl.SerialWithSideOutputs([some_layer(), some_layer()])
  xs = (np.array([1, 2, 3]),
        np.array([10, 20, 30, 40, 50]),
        np.array([100, 200]))
  ys = layer(xs)

  output_shapes = [y.shape for y in ys]
  self.assertEqual(output_shapes, [(3,), (5,), (2,)])

  # Shapes alone cannot detect a tensor divided the wrong number of times.
  # Pin the values too. Every expected value is exactly representable in
  # binary floating point, so exact list equality is safe here.
  expected_values = [
      [0.25, 0.5, 0.75],             # x1 / 2 / 2
      [2.0, 4.0, 6.0, 8.0, 10.0],    # x2 / 5
      [20.0, 40.0],                  # x3 / 5
  ]
  self.assertEqual([y.tolist() for y in ys], expected_values)