def testFailedFillMissingShape(self):
    """Checks that shape filling can fail and leave the shape unknown.

    The small model is run once with full tracing so that run_metadata is
    populated. The output op is then copied into a fresh graph, which leaves
    its static shape unknown. Because run_metadata records MatMul under a
    special name, _fill_missing_graph_shape cannot recover the shape, and it
    is expected to stay unknown.
    """
    y = self._BuildSmallModel()
    run_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    run_metadata = config_pb2.RunMetadata()
    # Use the session as a context manager so it is closed after tracing;
    # run_metadata has already been filled in by then.
    with session.Session() as sess:
      sess.run(y, options=run_options, run_metadata=run_metadata)

    graph2 = ops.Graph()
    y2 = copy_elements.copy_op_to_graph(y, graph2, [])
    self.assertEqual('<unknown>', str(y2.get_shape()))
    # run_metadata has special name for MatMul, hence failed to fill shape.
    tfprof_logger._fill_missing_graph_shape(graph2, run_metadata)
    self.assertEqual('<unknown>', str(y2.get_shape()))

  def testFillMissingShape(self):
    """Checks that traced run metadata restores a copied op's static shape.

    The placeholder model is run with full tracing, with 2x2 values fed to
    both placeholders. The output op is then copied into a fresh graph, which
    leaves its shape unknown. After _fill_missing_graph_shape runs with the
    collected run_metadata, the copied op's shape is expected to be (2, 2).
    """
    a, b, y = self._BuildSmallPlaceholderlModel()
    run_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    run_metadata = config_pb2.RunMetadata()
    # Use the session as a context manager so it is closed after tracing;
    # run_metadata has already been filled in by then.
    with session.Session() as sess:
      sess.run(y,
               options=run_options,
               run_metadata=run_metadata,
               feed_dict={a: [[1, 2], [2, 3]],
                          b: [[1, 2], [2, 3]]})

    graph2 = ops.Graph()
    # Use copy_op_to_graph to remove shape information.
    y2 = copy_elements.copy_op_to_graph(y, graph2, [])
    self.assertEqual('<unknown>', str(y2.get_shape()))

    tfprof_logger._fill_missing_graph_shape(graph2, run_metadata)
    self.assertEqual('(2, 2)', str(y2.get_shape()))