def testComputeShape(self):
    """Check `Flatten.compute_output_shape` on known and partially known shapes."""
    # Each entry pairs an input shape with the flattened shape expected for it.
    cases = (
        ((1, 2, 3, 2), [1, 12]),
        ((None, 3, 2), [None, 6]),
        ((None, 3, None), [None, None]),
    )
    for input_shape, expected in cases:
        output_shape = core_layers.Flatten().compute_output_shape(input_shape)
        self.assertEqual(output_shape.as_list(), expected)
def testFlattenUnknownAxes(self):
    """Flatten inputs whose non-batch axes are not all statically known.

    In both scenarios the runtime output has shape [5, 6], while the static
    shape of the flattened axis is expected to be unknown (None).
    """
    # (static placeholder shape, concrete shape of the array fed at run time)
    scenarios = (
        ((5, None, None), (5, 2, 3)),
        ((5, None, 2), (5, 3, 2)),
    )
    with self.cached_session() as sess:
        for static_shape, fed_shape in scenarios:
            inputs = array_ops.placeholder(shape=static_shape, dtype='float32')
            flat = core_layers.Flatten()(inputs)
            result = sess.run(flat, feed_dict={inputs: np.zeros(fed_shape)})
            self.assertEqual(list(result.shape), [5, 6])
            self.assertEqual(flat.get_shape().as_list(), [5, None])
def testCreateFlatten(self):
    """Apply `Flatten` to 3-D and 4-D placeholders and check both shapes.

    For each input, the shape of the evaluated output and the static shape
    reported by the output tensor are verified.
    """
    with self.cached_session() as sess:
        # Unknown batch size: the static batch dimension stays None.
        inputs_3d = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
        flatten_3d = core_layers.Flatten()
        flat_3d = flatten_3d(inputs_3d)
        feed_3d = {inputs_3d: np.zeros((3, 2, 3))}
        result_3d = sess.run(flat_3d, feed_dict=feed_3d)
        self.assertEqual(list(result_3d.shape), [3, 6])
        self.assertEqual(flat_3d.get_shape().as_list(), [None, 6])

        # Fully defined 4-D input: static and runtime shapes agree.
        inputs_4d = array_ops.placeholder(shape=(1, 2, 3, 2), dtype='float32')
        flatten_4d = core_layers.Flatten()
        flat_4d = flatten_4d(inputs_4d)
        feed_4d = {inputs_4d: np.zeros((1, 2, 3, 2))}
        result_4d = sess.run(flat_4d, feed_dict=feed_4d)
        self.assertEqual(list(result_4d.shape), [1, 12])
        self.assertEqual(flat_4d.get_shape().as_list(), [1, 12])
def testFlatten0D(self):
    """Flatten a rank-1 input of shape (batch,) and expect shape (batch, 1)."""
    inputs = array_ops.placeholder(shape=(None, ), dtype='float32')
    flatten = core_layers.Flatten()
    flat = flatten(inputs)
    with self.cached_session() as sess:
        feed = {inputs: np.zeros((5, ))}
        result = sess.run(flat, feed_dict=feed)
    # Runtime shape first, then the statically inferred shape.
    self.assertEqual(list(result.shape), [5, 1])
    self.assertEqual(flat.shape.as_list(), [None, 1])
def testDataFormat4d(self):
    """Check that `Flatten` output is the same for both 4-D data formats.

    The same data is fed once as `channels_last` with shape (1, 4, 3, 2) and
    once as `channels_first` with shape (1, 2, 4, 3), the latter obtained by
    transposing the former with axes [0, 3, 1, 2]. The two flattened outputs
    must be element-wise equal.
    """
    np_input_channels_last = np.arange(24, dtype='float32').reshape(
        [1, 4, 3, 2])

    # Use cached_session() like the other session-based tests here;
    # test_session() is deprecated in TensorFlow's TestCase.
    with self.cached_session() as sess:
        x = array_ops.placeholder(shape=(1, 4, 3, 2), dtype='float32')
        y = core_layers.Flatten(data_format='channels_last')(x)
        np_output_cl = sess.run(y, feed_dict={x: np_input_channels_last})

        x = array_ops.placeholder(shape=(1, 2, 4, 3), dtype='float32')
        y = core_layers.Flatten(data_format='channels_first')(x)
        # Move the channel axis from last position to position 1.
        np_input_channels_first = np.transpose(np_input_channels_last,
                                               [0, 3, 1, 2])
        np_output_cf = sess.run(y, feed_dict={x: np_input_channels_first})

        self.assertAllEqual(np_output_cl, np_output_cf)
def testFlattenLargeDim(self):
    """Static shape inference for a flattened dimension beyond the int32 range.

    Skipped when `platform.win32_ver()` reports any value (i.e. on Windows).
    """
    on_windows = any(platform.win32_ver())
    if on_windows:
        self.skipTest('values are truncated on windows causing test failures')

    side, channels = 21316, 80
    inputs = array_ops.placeholder(
        shape=(None, side, side, channels), dtype='float32')
    flat = core_layers.Flatten()(inputs)
    # 21316 * 21316 * 80 == 36,349,748,480, which does not fit in int32.
    expected_static_shape = [None, side * side * channels]
    self.assertEqual(flat.shape.as_list(), expected_static_shape)
def testFlattenLargeBatchDim(self):
    """Static shape inference when the batch dimension exceeds the int32 maximum."""
    int32_max = np.iinfo(np.int32).max
    oversized_batch = int32_max + 10
    inputs = array_ops.placeholder(
        shape=(oversized_batch, None, None, 1), dtype='float32')
    flat = core_layers.Flatten()(inputs)
    # The batch dimension must be kept; the flattened dimension is unknown.
    self.assertEqual(flat.shape.as_list(), [oversized_batch, None])