def testSplitVWithNonConstAxis(self):
    """Checks SplitV when its axis is fed through a placeholder.

    Runs the same graph in a default session and in a session configured
    by `_get_config()`, then compares the results and inspects the node
    names recorded in the cost graph. The body does nothing when no CUDA
    GPU is available.
    """
    # NOTE(review): another method with this exact name is visible in the
    # same file; if both live in one class, the later definition wins --
    # confirm which one is meant to run.
    if not test.is_gpu_available(cuda_only=True):
      return

    random_seed.set_random_seed(0)
    inputs = random_ops.truncated_normal([1, 784], seed=0)
    conv = _two_layer_model(inputs)
    # The split axis is only known at run time, via feed_dict.
    dim = array_ops.placeholder(dtype='int32')
    sizes = constant_op.constant([50, 10, 4], shape=[3])
    pieces = gen_array_ops._split_v(
        value=conv, size_splits=sizes, axis=dim, num_split=3)
    output = math_ops.reduce_sum(pieces[0])
    feed = {dim: 3}

    # Reference value from a session created without an explicit config.
    with session.Session() as sess:
      output_val_ref = sess.run(output, feed_dict=feed)

    # Same computation under the test config, capturing run metadata.
    metadata = config_pb2.RunMetadata()
    with session.Session(config=_get_config()) as sess:
      output_val = sess.run(output, run_metadata=metadata, feed_dict=feed)

    nodes = [node.name for node in metadata.cost_graph.node]
    num_transposes = sum(1 for name in nodes if _is_transpose(name))

    # Four transposes were initially added in the Expand phase of
    # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
    expected_num_transposes = 2
    self.assertEqual(expected_num_transposes, num_transposes)
    self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
    self._assert_trans_nchw_to_nhwc('SplitV-0-0', nodes)
    self._assert_map_nhwc_to_nchw('SplitV-2', nodes)
    self.assertAllClose(output_val_ref, output_val, atol=1e-3)
  def testSplitVWithNonConstAxis(self):
    """Checks SplitV when its axis is fed through a placeholder.

    Runs the same graph in a default session and in a session configured
    by `_get_config()`, then compares the results and looks for the
    expected 'LayoutOptimizer...' node names in the cost graph. The body
    does nothing when no CUDA GPU is available.
    """
    # NOTE(review): another method with this exact name is visible in the
    # same file; if both live in one class, only the later definition is
    # kept -- confirm the duplication is intended.
    if not test.is_gpu_available(cuda_only=True):
      return

    random_seed.set_random_seed(0)
    inputs = random_ops.truncated_normal([1, 784], seed=0)
    conv = _two_layer_model(inputs)
    # The split axis is only known at run time, via feed_dict.
    dim = array_ops.placeholder(dtype='int32')
    sizes = constant_op.constant([50, 10, 4], shape=[3])
    pieces = gen_array_ops._split_v(
        value=conv, size_splits=sizes, axis=dim, num_split=3)
    output = math_ops.reduce_sum(pieces[0])
    feed = {dim: 3}

    # Reference value from a session created without an explicit config.
    with session.Session() as sess:
      output_val_ref = sess.run(output, feed_dict=feed)

    # Same computation under the test config, capturing run metadata.
    metadata = config_pb2.RunMetadata()
    with session.Session(config=_get_config()) as sess:
      output_val = sess.run(output, run_metadata=metadata, feed_dict=feed)

    nodes = [node.name for node in metadata.cost_graph.node]
    num_transposes = sum(
        1 for name in nodes if name.startswith('LayoutOptimizerTranspose'))

    # Four transposes were initially added in the Expand phase of
    # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
    expected_num_transposes = 2
    self.assertEqual(expected_num_transposes, num_transposes)
    for expected_name in (
        'LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0',
        'LayoutOptimizerTransposeNCHWToNHWC-SplitV-0-0',
        'LayoutOptimizerDimMapNHWCToNCHW_SplitV_2',
    ):
      self.assertIn(expected_name, nodes)
    self.assertAllClose(output_val_ref, output_val, atol=1e-3)