def testFailedFillMissingShape(self):
    """Shape filling must leave the copied tensor's shape unknown here.

    Runs the small model with full tracing, copies `y` into a fresh graph
    (the copy starts with an '<unknown>' static shape), and checks that
    `tfprof_logger._fill_missing_graph_shape` does not fill it in, because the
    traced run_metadata uses a special name for MatMul that does not match.
    """
    y = self._BuildSmallModel()
    run_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    run_metadata = config_pb2.RunMetadata()
    # Only the traced run needs the session; close it afterwards instead of
    # leaking it for the rest of the test process.
    with session.Session() as sess:
      sess.run(y, options=run_options, run_metadata=run_metadata)

    graph2 = ops.Graph()
    y2 = copy_elements.copy_op_to_graph(y, graph2, [])
    self.assertEqual('<unknown>', str(y2.get_shape()))
    # run_metadata has special name for MatMul, hence failed to fill shape.
    tfprof_logger._fill_missing_graph_shape(graph2, run_metadata)
    self.assertEqual('<unknown>', str(y2.get_shape()))
  def testFillMissingShape(self):
    """Shape filling must recover the (2, 2) shape from traced run metadata.

    Runs the placeholder model with full tracing on 2x2 feeds, copies `y` into
    a fresh graph (the copy has an '<unknown>' static shape), then checks that
    `tfprof_logger._fill_missing_graph_shape` restores the shape '(2, 2)'.
    """
    a, b, y = self._BuildSmallPlaceholderlModel()
    run_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    run_metadata = config_pb2.RunMetadata()
    # Only the traced run needs the session; close it afterwards instead of
    # leaking it for the rest of the test process.
    with session.Session() as sess:
      sess.run(y,
               options=run_options,
               run_metadata=run_metadata,
               feed_dict={a: [[1, 2], [2, 3]],
                          b: [[1, 2], [2, 3]]})

    graph2 = ops.Graph()
    # Use copy_op_to_graph to remove shape information.
    y2 = copy_elements.copy_op_to_graph(y, graph2, [])
    self.assertEqual('<unknown>', str(y2.get_shape()))

    tfprof_logger._fill_missing_graph_shape(graph2, run_metadata)
    self.assertEqual('(2, 2)', str(y2.get_shape()))
Example #3
0
  def testOpsCopy(self):
    """Copies `y = a * x + b` into graph2 and checks both graphs agree.

    NOTE(review): `graph1` and `graph2` are not defined in this method;
    presumably they are module-level `ops.Graph()` instances — confirm.
    """

    with graph1.as_default():
      # Initialize a basic expression y = ax + b.
      x = array_ops.placeholder("float")
      a = variables.Variable(3.0)
      b = constant_op.constant(4.0)
      ax = math_ops.multiply(x, a)
      y = math_ops.add(ax, b)
      # Created inside as_default(), so this session runs graph1.
      sess1 = session_lib.Session()
      # Close the session when the test finishes rather than leaking it.
      self.addCleanup(sess1.close)
      # Initialize the Variable.
      variables.global_variables_initializer().run(session=sess1)

    # First, copy the Variable `a` into graph2.
    a1 = copy_elements.copy_variable_to_graph(a, graph2)

    # Initialize a1 in graph2.
    with graph2.as_default():
      sess2 = session_lib.Session()
      self.addCleanup(sess2.close)
      variables.global_variables_initializer().run(session=sess2)

    # Copy y into graph2, passing the already-copied variable a1.
    y1 = copy_elements.copy_op_to_graph(y, graph2, [a1])

    # Copying y also copies its input x; get that instance.
    x1 = copy_elements.get_copied_op(x, graph2)

    # Evaluate y and y1 on the same sample input and check they match.
    v1 = y.eval({x: 5}, session=sess1)
    v2 = y1.eval({x1: 5}, session=sess2)

    # Unlike a bare `assert`, this is not stripped under `python -O` and
    # reports both values on failure.
    self.assertEqual(v1, v2)
Example #4
0
    def testOpsCopy(self):
        """Copies `y = a * x + b` into graph2 and checks both graphs agree.

        NOTE(review): `graph1` and `graph2` are not defined in this method;
        presumably they are module-level `ops.Graph()` instances — confirm.
        """

        with graph1.as_default():
            # Initialize a basic expression y = ax + b.
            x = array_ops.placeholder("float")
            a = variables.Variable(3.0)
            b = constant_op.constant(4.0)
            ax = math_ops.multiply(x, a)
            y = math_ops.add(ax, b)
            # Created inside as_default(), so this session runs graph1.
            sess1 = session_lib.Session()
            # Close the session when the test finishes rather than leaking it.
            self.addCleanup(sess1.close)
            # Initialize the Variable.
            variables.global_variables_initializer().run(session=sess1)

        # First, copy the Variable `a` into graph2.
        a1 = copy_elements.copy_variable_to_graph(a, graph2)

        # Initialize a1 in graph2.
        with graph2.as_default():
            sess2 = session_lib.Session()
            self.addCleanup(sess2.close)
            variables.global_variables_initializer().run(session=sess2)

        # Copy y into graph2, passing the already-copied variable a1.
        y1 = copy_elements.copy_op_to_graph(y, graph2, [a1])

        # Copying y also copies its input x; get that instance.
        x1 = copy_elements.get_copied_op(x, graph2)

        # Evaluate y and y1 on the same sample input and check they match.
        v1 = y.eval({x: 5}, session=sess1)
        v2 = y1.eval({x1: 5}, session=sess2)

        # Unlike a bare `assert`, this is not stripped under `python -O` and
        # reports both values on failure.
        self.assertEqual(v1, v2)