예제 #1
0
 def test_tuple_input(self):
     # Scenario: two batched arrays, each the row-wise concatenation of
     # three sequences with 4, 1 and 3 rows (8 rows in total), are passed
     # together as a tuple.
     # Expectation: a list of 3 items, each a 2-tuple holding the matching
     # piece of the first and of the second array.
     row_counts = (4, 1, 3)
     width0 = 2
     width1 = 3

     def make_seqs(width):
         # One random float32 array per sequence, drawn in sequence order.
         return [
             np.random.uniform(-1, 1, size=(rows, width)).astype(np.float32)
             for rows in row_counts
         ]

     expected0 = make_seqs(width0)
     expected1 = make_seqs(width1)
     batched = (
         np.concatenate(expected0, axis=0),
         np.concatenate(expected1, axis=0),
     )
     # Split points are cumulative row offsets: 4 and 4 + 1 = 5.
     result = stateless_recurrent.split_batched_sequences(batched, [4, 5])

     self.assertEqual(len(result), 3)
     self.assertIsInstance(result[0], tuple)
     self.assertEqual(len(result[0]), 2)
     for idx, (want0, want1) in enumerate(zip(expected0, expected1)):
         np.testing.assert_allclose(result[idx][0].array, want0)
         np.testing.assert_allclose(result[idx][1].array, want1)
예제 #2
0
 def n_step_forward(self, sequences, recurrent_state, output_mode):
     """Run every layer in ``self._layers`` over a batch of sequences.

     Layers for which ``is_recurrent_link`` is true are invoked through
     ``call_recurrent_link`` on the per-sequence ("split") form of the
     data. All other layers are called directly on the result of
     ``concatenate_sequences``; a tuple result is unpacked into positional
     arguments. The data is converted between the two forms only when
     the next layer needs the other one.

     Args:
         sequences: Non-empty list of sequences. ``len()`` of each one
             (except the last) determines the split points.
         recurrent_state: ``None``, or a sequence with one state per
             entry of ``self.recurrent_children``. States are consumed
             in the order recurrent layers appear in ``self._layers``.
         output_mode: ``'split'`` or ``'concat'``; selects the form of
             the returned output.

     Returns:
         A pair ``(output, new_recurrent_state)``. The second item is a
         tuple of the states returned by ``call_recurrent_link``, in
         layer order.
     """
     assert sequences
     assert output_mode in ['split', 'concat']

     # Kept reversed so that pop() hands out states in layer order.
     if recurrent_state is not None:
         assert len(recurrent_state) == len(self.recurrent_children)
         pending_states = list(reversed(recurrent_state))
     else:
         pending_states = [None] * len(self.recurrent_children)

     updated_states = []
     out = sequences
     is_split = True  # True while `out` is in per-sequence form
     # Cumulative lengths of all but the last sequence: the indices at
     # which a concatenated batch is cut back into sequences.
     split_points = np.cumsum(
         list(map(len, sequences[:-1])), dtype=np.int32)

     for layer in self._layers:
         if not is_recurrent_link(layer):
             # Non-recurrent layers operate on the concatenated form.
             if is_split:
                 is_split = False
                 out = concatenate_sequences(out)
             out = layer(*out) if isinstance(out, tuple) else layer(out)
             continue

         # Recurrent layers operate on the per-sequence form.
         if not is_split:
             out = split_batched_sequences(out, split_points)
             is_split = True
         state_in = pending_states.pop()
         out, state_out = call_recurrent_link(
             layer, out, state_in, output_mode='split')
         updated_states.append(state_out)

     # Convert to the form the caller asked for, if not already in it.
     if output_mode == 'split' and not is_split:
         out = split_batched_sequences(out, split_points)
         is_split = True
     elif output_mode == 'concat' and is_split:
         out = concatenate_sequences(out)
         is_split = False

     assert is_split is (output_mode == 'split')
     assert not pending_states
     assert len(updated_states) == len(self.recurrent_children)
     return out, tuple(updated_states)
예제 #3
0
 def test_array_input(self):
     # Scenario: a single batched array made by concatenating three
     # sequences with 4, 1 and 3 rows (8 rows in total).
     # Expectation: a list of 3 arrays matching the original sequences.
     width = 2
     expected = [
         np.random.uniform(-1, 1, size=(rows, width)).astype(np.float32)
         for rows in (4, 1, 3)
     ]
     batched = np.concatenate(expected, axis=0)
     self.assertEqual(batched.shape, (8, width))

     # Split points are cumulative row offsets: 4 and 4 + 1 = 5.
     result = stateless_recurrent.split_batched_sequences(batched, [4, 5])

     self.assertEqual(len(result), 3)
     for idx, want in enumerate(expected):
         np.testing.assert_allclose(result[idx].array, want)