Exemplo n.º 1
0
 def testSeparableLstmUnsupportedCellType(self):
     """Expect separable_lstm to raise NotImplementedError for an unknown cell_type.

     The expected error text embeds ``cell_type`` verbatim, so it is passed
     through ``re.escape``: assertRaisesRegexp interprets its argument as a
     regular expression, and unescaped metacharacters (such as the trailing
     ``.``) would otherwise match more than the literal message.
     """
     import re  # stdlib; used only to escape the expected message

     with self.test_session():
         inputs = tf.constant(_rand(2, 5, 7, 11))
         cell_type = "unsupported cell type"
         expected_message = re.escape(cell_type + " not supported by ndlstm.")
         # NOTE(review): assertRaisesRegexp is a deprecated alias of
         # assertRaisesRegex (removed from unittest in Python 3.12);
         # presumably kept for Python 2 compatibility -- confirm runtime.
         with self.assertRaisesRegexp(NotImplementedError, expected_message):
             lstm2d.separable_lstm(inputs,
                                   8,
                                   data_format='NCHW',
                                   cell_type=cell_type)
Exemplo n.º 2
0
 def testSeparableLstmDimsNCHW(self):
     """Check the evaluated output shape of separable_lstm with NCHW data.

     A (2, 5, 7, 11) input with 8 requested outputs is expected to produce
     a (2, 8, 7, 11) result.
     """
     with self.test_session():
         nchw_input = tf.constant(_rand(2, 5, 7, 11))
         lstm_out = lstm2d.separable_lstm(nchw_input, 8, data_format='NCHW')
         # Variables must be initialized before the graph can be evaluated.
         tf.global_variables_initializer().run()
         evaluated = lstm_out.eval()
         expected_shape = (2, 8, 7, 11)
         self.assertEqual(tuple(evaluated.shape), expected_shape)
Exemplo n.º 3
0
 def testSeparableLstmDimsInvalidDataFormat(self):
     """Expect separable_lstm to raise ValueError for data_format="CHWN".

     The expected message is escaped so that its trailing ``.`` is matched
     literally rather than as a regex wildcard.
     """
     import re  # stdlib; used only to escape the expected message

     with self.test_session():
         inputs = tf.constant(_rand(2, 7, 11, 5))
         with self.assertRaisesRegexp(
                 ValueError,
                 re.escape('data_format has to be either NCHW or NHWC.')):
             lstm2d.separable_lstm(inputs, 8, data_format="CHWN")