Example #1
    def _get_context_id(self, context):
        """Get a unique ID for an op-construction context (e.g., a graph).

    If the graph has been encountered before, reuse the same unique ID.
    When encountering a new context (graph), this methods writes a DebugEvent
    proto with the debugged_graph field to the proper DebugEvent file.

    Args:
      context: A context to get the unique ID for. Must be hashable. E.g., a
        Graph object.

    Returns:
      A unique ID for the context.
    """
        # Use the double-checked lock pattern to optimize the common case.
        if context in self._context_to_id:  # 1st check, without lock.
            return self._context_to_id[context]
        graph_is_new = False
        with self._context_lock:
            if context not in self._context_to_id:  # 2nd check, with lock.
                graph_is_new = True
                context_id = _get_id()
                self._context_to_id[context] = context_id
        if graph_is_new:
            self.get_writer().WriteDebuggedGraph(
                debug_event_pb2.DebuggedGraph(
                    graph_id=context_id,
                    graph_name=getattr(context, "name", None),
                    outer_context_id=self._get_outer_context_id(context)))
        return self._context_to_id[context]
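The comment above refers to the double-checked locking idiom: check the dictionary once without the lock (the cheap common case), then check again under the lock before inserting. A minimal standalone sketch of the same idiom follows; the get_or_create_id, _ID_REGISTRY, and _REGISTRY_LOCK names are hypothetical and not part of the tfdbg API.

import threading

# Registry mapping hashable keys to IDs, guarded by a lock for writers.
_ID_REGISTRY = {}
_REGISTRY_LOCK = threading.Lock()


def get_or_create_id(key):
    # 1st check, without the lock: fast path for keys seen before.
    if key in _ID_REGISTRY:
        return _ID_REGISTRY[key]
    with _REGISTRY_LOCK:
        # 2nd check, with the lock: another thread may have inserted the key
        # between the first check and lock acquisition.
        if key not in _ID_REGISTRY:
            _ID_REGISTRY[key] = "id_%d" % len(_ID_REGISTRY)
    return _ID_REGISTRY[key]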
Example #2
  def testRangeReadingGraphExecutionTraces(self, begin, end, expected_begin,
                                           expected_end):
    writer = debug_events_writer.DebugEventsWriter(
        self.dump_root, self.tfdbg_run_id, circular_buffer_size=-1)
    debugged_graph = debug_event_pb2.DebuggedGraph(
        graph_id="graph1", graph_name="graph1")
    writer.WriteDebuggedGraph(debugged_graph)
    for i in range(5):
      op_name = "Op_%d" % i
      graph_op_creation = debug_event_pb2.GraphOpCreation(
          op_name=op_name, graph_id="graph1")
      writer.WriteGraphOpCreation(graph_op_creation)
      trace = debug_event_pb2.GraphExecutionTrace(
          op_name=op_name, tfdbg_context_id="graph1")
      writer.WriteGraphExecutionTrace(trace)
    writer.FlushNonExecutionFiles()
    writer.FlushExecutionFiles()
    writer.Close()

    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
      reader.update()
      traces = reader.graph_execution_traces(begin=begin, end=end)
    self.assertLen(traces, expected_end - expected_begin)
    self.assertEqual(traces[0].op_name, "Op_%d" % expected_begin)
    self.assertEqual(traces[-1].op_name, "Op_%d" % (expected_end - 1))
Example #3
    def testWriteGraphOpCreationAndDebuggedGraphs(self):
        writer = debug_events_writer.DebugEventsWriter(self.dump_root)
        num_op_creations = 10
        for i in range(num_op_creations):
            graph_op_creation = debug_event_pb2.GraphOpCreation()
            graph_op_creation.op_type = "Conv2D"
            graph_op_creation.op_name = "Conv2D_%d" % i
            writer.WriteGraphOpCreation(graph_op_creation)
        debugged_graph = debug_event_pb2.DebuggedGraph()
        debugged_graph.graph_id = "deadbeaf"
        debugged_graph.graph_name = "MyGraph1"
        writer.WriteDebuggedGraph(debugged_graph)
        writer.FlushNonExecutionFiles()

        source_files_paths = glob.glob(os.path.join(self.dump_root,
                                                    "*.graphs"))
        self.assertEqual(len(source_files_paths), 1)
        actuals = ReadDebugEvents(source_files_paths[0])
        self.assertEqual(len(actuals), num_op_creations + 1)
        for i in range(num_op_creations):
            self.assertEqual(actuals[i].graph_op_creation.op_type, "Conv2D")
            self.assertEqual(actuals[i].graph_op_creation.op_name,
                             "Conv2D_%d" % i)
        self.assertEqual(actuals[num_op_creations].debugged_graph.graph_id,
                         "deadbeaf")
Example #4
    def testConcurrentWritesToExecutionFiles(self):
        circular_buffer_size = 5
        writer = debug_events_writer.DebugEventsWriter(self.dump_root,
                                                       self.tfdbg_run_id,
                                                       circular_buffer_size)
        debugged_graph = debug_event_pb2.DebuggedGraph(graph_id="graph1",
                                                       graph_name="graph1")
        writer.WriteDebuggedGraph(debugged_graph)

        execution_state = {"counter": 0, "lock": threading.Lock()}

        def write_execution():
            execution = debug_event_pb2.Execution()
            with execution_state["lock"]:
                execution.op_type = "OpType%d" % execution_state["counter"]
                execution_state["counter"] += 1
            writer.WriteExecution(execution)

        graph_execution_trace_state = {"counter": 0, "lock": threading.Lock()}

        def write_graph_execution_trace():
            with graph_execution_trace_state["lock"]:
                op_name = "Op%d" % graph_execution_trace_state["counter"]
                graph_op_creation = debug_event_pb2.GraphOpCreation(
                    op_type="FooOp", op_name=op_name, graph_id="graph1")
                trace = debug_event_pb2.GraphExecutionTrace(
                    op_name=op_name, tfdbg_context_id="graph1")
                graph_execution_trace_state["counter"] += 1
            writer.WriteGraphOpCreation(graph_op_creation)
            writer.WriteGraphExecutionTrace(trace)

        threads = []
        for i in range(circular_buffer_size * 4):
            if i % 2 == 0:
                target = write_execution
            else:
                target = write_graph_execution_trace
            thread = threading.Thread(target=target)
            thread.start()
            threads.append(thread)
        for thread in threads:
            thread.join()
        writer.FlushNonExecutionFiles()
        writer.FlushExecutionFiles()

        with debug_events_reader.DebugDataReader(self.dump_root) as reader:
            reader.update()
            # Verify the content of the .execution file.
            executions = reader.executions()
            executed_op_types = [execution.op_type for execution in executions]
            self.assertLen(executed_op_types, circular_buffer_size)
            self.assertLen(executed_op_types, len(set(executed_op_types)))

            # Verify the content of the .graph_execution_traces file.
            op_names = [
                trace.op_name for trace in reader.graph_execution_traces()
            ]
            self.assertLen(op_names, circular_buffer_size)
            self.assertLen(op_names, len(set(op_names)))
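The two assertLen checks above rely on circular-buffer semantics: with a buffer size of 5, only the 5 most recently written events survive in each file, no matter how many were written concurrently. A rough sketch of that retention behavior using a plain deque (illustration only; this is not how DebugEventsWriter is implemented internally):

from collections import deque

# Keep only the most recent `circular_buffer_size` items, mirroring how the
# writer's circular buffer retains execution events.
circular_buffer_size = 5
buffer = deque(maxlen=circular_buffer_size)
for i in range(circular_buffer_size * 4):
    buffer.append("OpType%d" % i)
assert len(buffer) == circular_buffer_size
assert list(buffer) == ["OpType%d" % i for i in range(15, 20)]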
Example #5
    def testReadingTwoFileSetsWithTheSameDumpRootSucceeds(self):
        # To simulate a multi-host data dump, we first generate file sets in two
        # different directories, with the same tfdbg_run_id, and then combine them.
        tfdbg_run_id = "foo"
        for i in range(2):
            writer = debug_events_writer.DebugEventsWriter(
                os.path.join(self.dump_root, str(i)),
                tfdbg_run_id,
                circular_buffer_size=-1)
            if i == 0:
                debugged_graph = debug_event_pb2.DebuggedGraph(
                    graph_id="graph1", graph_name="graph1")
                writer.WriteDebuggedGraph(debugged_graph)
                op_name = "Op_0"
                graph_op_creation = debug_event_pb2.GraphOpCreation(
                    op_type="FooOp", op_name=op_name, graph_id="graph1")
                writer.WriteGraphOpCreation(graph_op_creation)
                op_name = "Op_1"
                graph_op_creation = debug_event_pb2.GraphOpCreation(
                    op_type="FooOp", op_name=op_name, graph_id="graph1")
                writer.WriteGraphOpCreation(graph_op_creation)
            for _ in range(10):
                trace = debug_event_pb2.GraphExecutionTrace(
                    op_name="Op_%d" % i, tfdbg_context_id="graph1")
                writer.WriteGraphExecutionTrace(trace)
                writer.FlushNonExecutionFiles()
                writer.FlushExecutionFiles()

        # Move all files from the subdirectory /1 to subdirectory /0.
        dump_root_0 = os.path.join(self.dump_root, "0")
        src_paths = glob.glob(os.path.join(self.dump_root, "1", "*"))
        for src_path in src_paths:
            dst_path = os.path.join(
                dump_root_0,
                # Rename the file set to avoid file name collision.
                re.sub(r"(tfdbg_events\.\d+)", r"\g<1>1",
                       os.path.basename(src_path)))
            os.rename(src_path, dst_path)

        with debug_events_reader.DebugDataReader(dump_root_0) as reader:
            reader.update()
            # Verify the content of the .graph_execution_traces file.
            trace_digests = reader.graph_execution_traces(digest=True)
            self.assertLen(trace_digests, 20)
            for i in range(10):
                trace = reader.read_graph_execution_trace(trace_digests[i])
                self.assertEqual(trace.op_name, "Op_0")
            for i in range(10):
                trace = reader.read_graph_execution_trace(trace_digests[i +
                                                                        10])
                self.assertEqual(trace.op_name, "Op_1")
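The re.sub call above renames each file from the second file set by appending a "1" to the digits that follow "tfdbg_events." in the file name, so both sets can coexist in one directory without colliding. A small illustration of that substitution; the file name below is made up for demonstration and real tfdbg file names may differ:

import re

src_name = "events.out.tfdbg_events.1600000000.12345.graph_execution_traces"
# Append "1" to the "tfdbg_events.<digits>" portion, as in the test above.
dst_name = re.sub(r"(tfdbg_events\.\d+)", r"\g<1>1", src_name)
print(dst_name)
# -> events.out.tfdbg_events.16000000001.12345.graph_execution_traces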
Example #6
    def testConcurrentGraphExecutionTraceUpdateAndRandomRead(self):
        circular_buffer_size = -1
        writer = debug_events_writer.DebugEventsWriter(self.dump_root,
                                                       self.tfdbg_run_id,
                                                       circular_buffer_size)
        debugged_graph = debug_event_pb2.DebuggedGraph(graph_id="graph1",
                                                       graph_name="graph1")
        writer.WriteDebuggedGraph(debugged_graph)

        writer_state = {"counter": 0, "done": False}

        with debug_events_reader.DebugDataReader(self.dump_root) as reader:

            def write_and_update_job():
                while True:
                    if writer_state["done"]:
                        break
                    op_name = "Op%d" % writer_state["counter"]
                    graph_op_creation = debug_event_pb2.GraphOpCreation(
                        op_type="FooOp", op_name=op_name, graph_id="graph1")
                    writer.WriteGraphOpCreation(graph_op_creation)
                    trace = debug_event_pb2.GraphExecutionTrace(
                        op_name=op_name, tfdbg_context_id="graph1")
                    writer.WriteGraphExecutionTrace(trace)
                    writer_state["counter"] += 1
                    writer.FlushNonExecutionFiles()
                    writer.FlushExecutionFiles()
                    reader.update()

            # On the sub-thread, keep writing and reading new GraphExecutionTraces.
            write_and_update_thread = threading.Thread(
                target=write_and_update_job)
            write_and_update_thread.start()
            # On the main thread, do concurrent random read.
            while True:
                digests = reader.graph_execution_traces(digest=True)
                if digests:
                    trace_0 = reader.read_graph_execution_trace(digests[0])
                    self.assertEqual(trace_0.op_name, "Op0")
                    writer_state["done"] = True
                    break
                else:
                    time.sleep(0.1)
                    continue
            write_and_update_thread.join()
Example #7
    def testConcurrentGraphExecutionTraceRandomReads(self):
        circular_buffer_size = -1
        writer = debug_events_writer.DebugEventsWriter(self.dump_root,
                                                       self.tfdbg_run_id,
                                                       circular_buffer_size)
        debugged_graph = debug_event_pb2.DebuggedGraph(graph_id="graph1",
                                                       graph_name="graph1")
        writer.WriteDebuggedGraph(debugged_graph)

        for i in range(100):
            op_name = "Op%d" % i
            graph_op_creation = debug_event_pb2.GraphOpCreation(
                op_type="FooOp", op_name=op_name, graph_id="graph1")
            writer.WriteGraphOpCreation(graph_op_creation)
            trace = debug_event_pb2.GraphExecutionTrace(
                op_name=op_name, tfdbg_context_id="graph1")
            writer.WriteGraphExecutionTrace(trace)
        writer.FlushNonExecutionFiles()
        writer.FlushExecutionFiles()

        reader = debug_events_reader.DebugDataReader(self.dump_root)
        reader.update()
        traces = [None] * 100

        def read_job_1():
            digests = reader.graph_execution_traces(digest=True)
            for i in range(49, -1, -1):
                traces[i] = reader.read_graph_execution_trace(digests[i])

        def read_job_2():
            digests = reader.graph_execution_traces(digest=True)
            for i in range(99, 49, -1):
                traces[i] = reader.read_graph_execution_trace(digests[i])

        thread_1 = threading.Thread(target=read_job_1)
        thread_2 = threading.Thread(target=read_job_2)
        thread_1.start()
        thread_2.start()
        thread_1.join()
        thread_2.join()
        for i in range(100):
            self.assertEqual(traces[i].op_name, "Op%d" % i)
Example #8
  def testWriteGraphOpCreationAndDebuggedGraphs(self):
    writer = debug_events_writer.DebugEventsWriter(self.dump_root)
    num_op_creations = 10
    for i in range(num_op_creations):
      graph_op_creation = debug_event_pb2.GraphOpCreation()
      graph_op_creation.op_type = "Conv2D"
      graph_op_creation.op_name = "Conv2D_%d" % i
      writer.WriteGraphOpCreation(graph_op_creation)
    debugged_graph = debug_event_pb2.DebuggedGraph()
    debugged_graph.graph_id = "deadbeaf"
    debugged_graph.graph_name = "MyGraph1"
    writer.WriteDebuggedGraph(debugged_graph)
    writer.FlushNonExecutionFiles()

    reader = debug_events_reader.DebugEventsReader(self.dump_root)
    actuals = list(reader.graphs_iterator())
    self.assertLen(actuals, num_op_creations + 1)
    for i in range(num_op_creations):
      self.assertEqual(actuals[i].graph_op_creation.op_type, "Conv2D")
      self.assertEqual(actuals[i].graph_op_creation.op_name, "Conv2D_%d" % i)
    self.assertEqual(actuals[num_op_creations].debugged_graph.graph_id,
                     "deadbeaf")