def testSeparableLstmUnsupportedCellType(self):
    """Check that separable_lstm rejects an unknown ``cell_type``.

    Calls ``lstm2d.separable_lstm`` on an NCHW input with a ``cell_type``
    string it should not recognise. Expects ``NotImplementedError`` whose
    message matches "<cell_type> not supported by ndlstm.".
    """
    with self.test_session():
        # NOTE(review): _rand is defined elsewhere in this file; presumably it
        # returns a random array with the given shape (here 2x5x7x11).
        inputs = tf.constant(_rand(2, 5, 7, 11))
        cell_type = "unsupported cell type"
        # The second argument is a regular expression. Escape the trailing
        # period so it matches a literal "." rather than any character.
        # cell_type has no regex metacharacters, so it is concatenated as-is.
        with self.assertRaisesRegexp(
                NotImplementedError,
                cell_type + r" not supported by ndlstm\."):
            lstm2d.separable_lstm(inputs, 8, data_format='NCHW',
                                  cell_type=cell_type)
def testSeparableLstmDimsNCHW(self):
    """Check the output shape of separable_lstm for channels-first input.

    Feeds a (2, 5, 7, 11) tensor with ``data_format='NCHW'`` and an output
    size of 8. Asserts that the evaluated result has shape (2, 8, 7, 11):
    the first and last two dimensions are unchanged and the second becomes 8.
    """
    with self.test_session():
        # NOTE(review): _rand is defined elsewhere in this file; presumably it
        # builds a random array of the requested shape.
        batch_nchw = tf.constant(_rand(2, 5, 7, 11))
        lstm_out = lstm2d.separable_lstm(batch_nchw, 8, data_format='NCHW')
        # Variables created by separable_lstm must be initialised before eval.
        init_op = tf.global_variables_initializer()
        init_op.run()
        evaluated = lstm_out.eval()
        out_shape = tuple(evaluated.shape)
        self.assertEqual(out_shape, (2, 8, 7, 11))
def testSeparableLstmDimsInvalidDataFormat(self):
    """Check that separable_lstm rejects an unsupported ``data_format``.

    Passes ``data_format="CHWN"`` and expects ``ValueError`` whose message
    matches "data_format has to be either NCHW or NHWC.".
    """
    with self.test_session():
        # NOTE(review): _rand is defined elsewhere in this file; presumably it
        # returns a random array of the given shape (2x7x11x5 here).
        inputs = tf.constant(_rand(2, 7, 11, 5))
        # The expected message is a regex; escape the trailing period so it
        # matches a literal "." instead of any character.
        with self.assertRaisesRegexp(
                ValueError,
                r'data_format has to be either NCHW or NHWC\.'):
            lstm2d.separable_lstm(inputs, 8, data_format="CHWN")