예제 #1
0
    def testComputeShape(self):
        """Check the static shape reported by Flatten.compute_output_shape.

        Covers a fully defined input shape, an input whose leading dimension
        is unknown, and an input with an unknown trailing dimension.
        """
        # (input shape, expected flattened shape) pairs, checked in order.
        cases = (
            ((1, 2, 3, 2), [1, 12]),
            ((None, 3, 2), [None, 6]),
            ((None, 3, None), [None, None]),
        )
        for input_shape, expected in cases:
            flattened = core_layers.Flatten().compute_output_shape(input_shape)
            self.assertEqual(flattened.as_list(), expected)
예제 #2
0
    def testFlattenUnknownAxes(self):
        """Flatten inputs whose non-leading dimensions are partly unknown.

        For each case the run-time result must have shape [5, 6], while the
        static shape of the output tensor is expected to be [5, None].
        """
        # (placeholder shape, shape of the array fed at run time) pairs.
        cases = (
            ((5, None, None), (5, 2, 3)),
            ((5, None, 2), (5, 3, 2)),
        )
        with self.cached_session() as sess:
            for static_shape, fed_shape in cases:
                inputs = array_ops.placeholder(
                    shape=static_shape, dtype='float32')
                outputs = core_layers.Flatten()(inputs)
                result = sess.run(
                    outputs, feed_dict={inputs: np.zeros(fed_shape)})
                self.assertEqual(list(result.shape), [5, 6])
                self.assertEqual(outputs.get_shape().as_list(), [5, None])
예제 #3
0
    def testCreateFlatten(self):
        """Apply Flatten to placeholders and check run-time and static shapes.

        The first case leaves the leading dimension unknown; the second uses
        a fully defined 4-D shape.
        """
        with self.cached_session() as sess:

            def check(static_shape, fed_shape, run_shape, expected_static):
                # Build the layer on a fresh placeholder, evaluate it on
                # zeros, then compare both the evaluated and static shapes.
                inputs = array_ops.placeholder(
                    shape=static_shape, dtype='float32')
                outputs = core_layers.Flatten()(inputs)
                result = sess.run(
                    outputs, feed_dict={inputs: np.zeros(fed_shape)})
                self.assertEqual(list(result.shape), run_shape)
                self.assertEqual(
                    outputs.get_shape().as_list(), expected_static)

            check((None, 2, 3), (3, 2, 3), [3, 6], [None, 6])
            check((1, 2, 3, 2), (1, 2, 3, 2), [1, 12], [1, 12])
예제 #4
0
 def testFlatten0D(self):
     """Flatten a rank-1 input and expect a trailing dimension of size 1.

     The placeholder has shape (None,); feeding five values must give a
     [5, 1] result, and the static output shape must be [None, 1].
     """
     inputs = array_ops.placeholder(dtype='float32', shape=(None, ))
     outputs = core_layers.Flatten()(inputs)
     with self.cached_session() as sess:
         feed = {inputs: np.zeros((5, ))}
         result = sess.run(outputs, feed_dict=feed)
     # Both checks run after the session context has been exited.
     self.assertEqual(list(result.shape), [5, 1])
     self.assertEqual(outputs.shape.as_list(), [None, 1])
예제 #5
0
    def testDataFormat4d(self):
        """Check that Flatten gives the same values for both data formats.

        A (1, 4, 3, 2) array is flattened with data_format='channels_last'.
        The same data, transposed with axes [0, 3, 1, 2] to shape
        (1, 2, 4, 3), is flattened with data_format='channels_first'. The
        two flattened results must be equal element-wise.
        """
        np_input_channels_last = np.arange(24, dtype='float32').reshape(
            [1, 4, 3, 2])
        # Move the last axis to position 1 to get the channels_first view of
        # the same data.
        np_input_channels_first = np.transpose(np_input_channels_last,
                                               [0, 3, 1, 2])

        # cached_session() replaces the deprecated test_session() and matches
        # the other tests in this file.
        with self.cached_session() as sess:
            x = array_ops.placeholder(shape=(1, 4, 3, 2), dtype='float32')
            y = core_layers.Flatten(data_format='channels_last')(x)
            np_output_cl = sess.run(y, feed_dict={x: np_input_channels_last})

            x = array_ops.placeholder(shape=(1, 2, 4, 3), dtype='float32')
            y = core_layers.Flatten(data_format='channels_first')(x)
            np_output_cf = sess.run(y, feed_dict={x: np_input_channels_first})

            self.assertAllEqual(np_output_cl, np_output_cf)
예제 #6
0
  def testFlattenLargeDim(self):
    """Check the static shape when the flattened size is very large.

    The flattened size 21316 * 21316 * 80 does not fit in a signed 32-bit
    integer. The test is skipped on Windows.
    """
    on_windows = any(platform.win32_ver())
    if on_windows:
      self.skipTest('values are truncated on windows causing test failures')

    side, channels = 21316, 80
    expected_flat = side * side * channels
    inputs = array_ops.placeholder(
        shape=(None, side, side, channels), dtype='float32')
    outputs = core_layers.Flatten()(inputs)
    self.assertEqual(outputs.shape.as_list(), [None, expected_flat])
예제 #7
0
 def testFlattenLargeBatchDim(self):
     """Check the static shape when the batch dimension exceeds int32 max.

     The leading dimension is set to ten more than the largest int32 value
     and the two middle dimensions are unknown; the flattened static shape
     is expected to be [batch, None].
     """
     huge_batch = np.iinfo(np.int32).max + 10
     input_shape = (huge_batch, None, None, 1)
     inputs = array_ops.placeholder(shape=input_shape, dtype='float32')
     outputs = core_layers.Flatten()(inputs)
     self.assertEqual(outputs.shape.as_list(), [huge_batch, None])