Ejemplo n.º 1
0
 def test_tuple_input(self):
     """concatenate_sequences on a list of (array, array) tuples.

     Three sequences of lengths 4, 1 and 3 each carry two arrays, with
     feature sizes 2 and 3 respectively. The result is expected to be a
     2-tuple of length-8 (4 + 1 + 3) arrays whose row slices reproduce
     every original array in order.
     """
     in_size0 = 2
     in_size1 = 3
     lengths = (4, 1, 3)

     def make_seqs(feature_size):
         # One float32 array per sequence, values drawn from U(-1, 1).
         return [
             np.random.uniform(-1, 1, size=(n, feature_size)).astype(np.float32)
             for n in lengths
         ]

     seqs_x0 = make_seqs(in_size0)
     seqs_x1 = make_seqs(in_size1)
     # Pair the i-th array of each group into the i-th sequence tuple.
     seqs_x = list(zip(seqs_x0, seqs_x1))

     concat_seqs = stateless_recurrent.concatenate_sequences(seqs_x)
     self.assertIsInstance(concat_seqs, tuple)
     self.assertEqual(len(concat_seqs), 2)
     self.assertEqual(concat_seqs[0].shape, (8, in_size0))
     self.assertEqual(concat_seqs[1].shape, (8, in_size1))

     # Rows [0:4], [4:5] and [5:] hold sequences 0, 1 and 2.
     bounds = ((0, 4), (4, 5), (5, None))
     for concat, originals in zip(concat_seqs, (seqs_x0, seqs_x1)):
         for (start, stop), expected in zip(bounds, originals):
             np.testing.assert_allclose(concat[start:stop].array, expected)
Ejemplo n.º 2
0
 def n_step_forward(self, sequences, recurrent_state, output_mode):
     """Run every layer over a batch of variable-length sequences.

     Recurrent links are called in "sequence mode" (a list with one entry
     per sequence, via ``call_recurrent_link``). Every other layer is called
     on the sequences merged by ``concatenate_sequences``. The data is
     converted between the two representations lazily, only when the kind
     of layer changes.

     Args:
         sequences (list): Non-empty list of input sequences. Each element
             is either an array-like or a tuple of array-likes. For a tuple,
             the length of its first element is taken as the sequence
             length.
         recurrent_state (tuple or None): One state per recurrent child,
             consumed in the order the recurrent links appear in
             ``self._layers``. If ``None``, ``None`` is passed as the state
             of every recurrent link.
         output_mode (str): ``'split'`` to return one output per sequence,
             or ``'concat'`` to return the outputs merged by
             ``concatenate_sequences``.

     Returns:
         tuple: ``(h, new_recurrent_state)``. ``h`` is the output in the
         requested mode. ``new_recurrent_state`` is a tuple of the states
         returned by each recurrent link, in layer order.
     """
     assert sequences
     assert output_mode in ['split', 'concat']
     if recurrent_state is None:
         recurrent_state_queue = [None] * len(self.recurrent_children)
     else:
         assert len(recurrent_state) == len(self.recurrent_children)
         # Reversed so that pop() hands out states in their original order.
         recurrent_state_queue = list(reversed(recurrent_state))
     new_recurrent_state = []
     h = sequences
     seq_mode = True
     # Split points (cumulative lengths of all but the last sequence), used
     # to cut a concatenated batch back into per-sequence pieces. For a tuple
     # sequence, len() would count its arrays rather than its steps, so the
     # length of its first array is used instead.
     sections = np.cumsum(
         [len(x[0]) if isinstance(x, tuple) else len(x)
          for x in sequences[:-1]],
         dtype=np.int32)
     for layer in self._layers:
         if is_recurrent_link(layer):
             # Recurrent links need per-sequence input.
             if not seq_mode:
                 h = split_batched_sequences(h, sections)
                 seq_mode = True
             rs = recurrent_state_queue.pop()
             h, rs = call_recurrent_link(layer, h, rs, output_mode='split')
             new_recurrent_state.append(rs)
         else:
             # Non-recurrent layers run once on the concatenated batch.
             if seq_mode:
                 seq_mode = False
                 h = concatenate_sequences(h)
             if isinstance(h, tuple):
                 h = layer(*h)
             else:
                 h = layer(h)
     # Convert the final output to the representation the caller asked for.
     if not seq_mode and output_mode == 'split':
         h = split_batched_sequences(h, sections)
         seq_mode = True
     elif seq_mode and output_mode == 'concat':
         h = concatenate_sequences(h)
         seq_mode = False
     assert seq_mode is (output_mode == 'split')
     assert not recurrent_state_queue
     assert len(new_recurrent_state) == len(self.recurrent_children)
     return h, tuple(new_recurrent_state)
Ejemplo n.º 3
0
 def test_array_input(self):
     """concatenate_sequences on a plain list of arrays.

     Sequences of lengths 4, 1 and 3 (feature size 2) are expected to be
     stacked into a single (8, 2) result whose consecutive row ranges
     match the inputs in order.
     """
     in_size = 2
     seqs_x = [
         np.random.uniform(-1, 1, size=(length, in_size)).astype(np.float32)
         for length in (4, 1, 3)
     ]
     concat_seqs = stateless_recurrent.concatenate_sequences(seqs_x)
     self.assertEqual(concat_seqs.shape, (8, in_size))
     # Walk cumulative offsets: rows [0:4], [4:5], [5:8].
     start = 0
     for seq in seqs_x:
         stop = start + len(seq)
         np.testing.assert_allclose(concat_seqs[start:stop].array, seq)
         start = stop