コード例 #1
0
  def testComputeShape(self):
    """Flatten.compute_output_shape collapses every non-batch dimension.

    The flattened size is reported as None whenever any of the collapsed
    dimensions is unknown.
    """
    cases = [
        # (input shape, expected flattened shape)
        ((1, 2, 3, 2), [1, 12]),
        ((None, 3, 2), [None, 6]),
        ((None, 3, None), [None, None]),
    ]
    for input_shape, expected in cases:
      out_shape = core_layers.Flatten().compute_output_shape(input_shape)
      self.assertEqual(out_shape.as_list(), expected)
コード例 #2
0
  def testFlattenUnknownAxes(self):
    """Unknown inner dimensions give a static shape of [5, None].

    At run time the fed arrays still flatten to a concrete [5, 6].
    """
    scenarios = (
        # (placeholder shape, fed array shape)
        ((5, None, None), (5, 2, 3)),
        ((5, None, 2), (5, 3, 2)),
    )
    with self.cached_session() as sess:
      for static_shape, fed_shape in scenarios:
        inp = tf.compat.v1.placeholder(shape=static_shape, dtype='float32')
        flat = core_layers.Flatten()(inp)
        result = sess.run(flat, feed_dict={inp: np.zeros(fed_shape)})
        self.assertEqual(list(result.shape), [5, 6])
        self.assertEqual(flat.get_shape().as_list(), [5, None])
コード例 #3
0
  def testCreateFlatten(self):
    """Flatten yields the expected static and run-time output shapes.

    An unknown batch dimension stays unknown in the static shape, while a
    fully-known input flattens to a fully-known static shape.
    """
    cases = (
        # (placeholder shape, fed array shape, run-time shape, static shape)
        ((None, 2, 3), (3, 2, 3), [3, 6], [None, 6]),
        ((1, 2, 3, 2), (1, 2, 3, 2), [1, 12], [1, 12]),
    )
    with self.cached_session() as sess:
      for static_shape, fed_shape, want_runtime, want_static in cases:
        inp = tf.compat.v1.placeholder(shape=static_shape, dtype='float32')
        flat = core_layers.Flatten()(inp)
        result = sess.run(flat, feed_dict={inp: np.zeros(fed_shape)})
        self.assertEqual(list(result.shape), want_runtime)
        self.assertEqual(flat.get_shape().as_list(), want_static)
コード例 #4
0
  def testFlattenLargeDim(self):
    """The static flattened size stays exact when it exceeds the int32 range.

    Skipped on Windows, where the large value gets truncated.
    """
    on_windows = any(platform.win32_ver())
    if on_windows:
      self.skipTest('values are truncated on windows causing test failures')

    height = width = 21316
    channels = 80
    inp = tf.compat.v1.placeholder(
        shape=(None, height, width, channels), dtype='float32')
    flat = core_layers.Flatten()(inp)
    self.assertEqual(flat.shape.as_list(), [None, height * width * channels])
コード例 #5
0
 def testFlatten0D(self):
   """A rank-1 input gains a trailing dimension of size 1 when flattened."""
   inp = tf.compat.v1.placeholder(shape=(None,), dtype='float32')
   flat = core_layers.Flatten()(inp)
   with self.cached_session() as sess:
     result = sess.run(flat, feed_dict={inp: np.zeros((5,))})
   # Run-time shape first, then the statically inferred shape.
   self.assertEqual(list(result.shape), [5, 1])
   self.assertEqual(flat.shape.as_list(), [None, 1])
コード例 #6
0
  def testDataFormat4d(self):
    """channels_first and channels_last Flatten agree on transposed inputs.

    A (1, 4, 3, 2) channels_last array and its (1, 2, 4, 3) channels_first
    transpose must flatten to identical values. In other words,
    Flatten(data_format='channels_first') must emit elements in
    channels_last order.
    """
    np_input_channels_last = np.arange(
        24, dtype='float32').reshape([1, 4, 3, 2])
    # Move the channel axis (last) to position 1 for the channels_first
    # variant. This is pure numpy, so it does not need the session.
    np_input_channels_first = np.transpose(np_input_channels_last,
                                           [0, 3, 1, 2])

    # `test_session` is deprecated in favor of `cached_session`.
    with self.cached_session() as sess:
      x = tf.compat.v1.placeholder(shape=(1, 4, 3, 2), dtype='float32')
      y = core_layers.Flatten(data_format='channels_last')(x)
      np_output_cl = sess.run(y, feed_dict={x: np_input_channels_last})

      x = tf.compat.v1.placeholder(shape=(1, 2, 4, 3), dtype='float32')
      y = core_layers.Flatten(data_format='channels_first')(x)
      np_output_cf = sess.run(y, feed_dict={x: np_input_channels_first})

      self.assertAllEqual(np_output_cl, np_output_cf)
コード例 #7
0
 def testFlattenLargeBatchDim(self):
   """A static batch size above the int32 maximum survives Flatten intact."""
   int32_max = np.iinfo(np.int32).max
   batch_size = int32_max + 10
   inp = tf.compat.v1.placeholder(
       shape=(batch_size, None, None, 1), dtype='float32')
   self.assertEqual(
       core_layers.Flatten()(inp).shape.as_list(), [batch_size, None])