예제 #1
0
 def testMultipleEventsFromDifferentDevicesAndSameTensorName(self):
     """Checks the report when one tensor name alerts on two devices.

     Three alerts for "xent/Log:0" are registered, two of them on the task:0
     device. The report is expected to hold one row per device, with the
     task:0 row listed first.
     """
     alerts = [
         numerics_alert.NumericsAlert(
             "/job:worker/replica:0/task:1/gpu:0", "xent/Log:0", 1434, 0, 1, 1),
         numerics_alert.NumericsAlert(
             "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 0, 1, 1),
         numerics_alert.NumericsAlert(
             "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1634, 1, 1, 1),
     ]
     registry = numerics_alert.NumericsAlertRegistry()
     for alert in alerts:
         registry.register(alert)

     # Both task:0 alerts are expected to be folded into this single row.
     task_0_row = numerics_alert.NumericsAlertReportRow(
         "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 1, 2, 2)
     task_1_row = numerics_alert.NumericsAlertReportRow(
         "/job:worker/replica:0/task:1/gpu:0", "xent/Log:0", 1434, 0, 1, 1)
     self.assertEqual([task_0_row, task_1_row], registry.report())
예제 #2
0
    def testCreateJsonableRegistry(self):
        """Checks the JSON-able form of a registry that holds one alert.

        The result is expected to contain a single entry exposing `device`,
        `tensor` and a `jsonable_history` mapping keyed by "nan", "neg_inf"
        and "pos_inf".
        """
        device_name = "/job:worker/replica:0/task:1/gpu:0"
        tensor_name = "xent/Log:0"
        registry = numerics_alert.NumericsAlertRegistry()
        registry.register(
            numerics_alert.NumericsAlert(
                device_name, tensor_name, 1434, 0, 1, 1))

        jsonable_entries = registry.create_jsonable_registry()
        self.assertEqual(1, len(jsonable_entries))

        entry = jsonable_entries[0]
        self.assertEqual(device_name, entry.device)
        self.assertEqual(tensor_name, entry.tensor)

        # Expected history list for each category, checked in a fixed order.
        expected_histories = (
            ("nan", [0, -1, -1]),
            ("neg_inf", [1, 1434, 1434]),
            ("pos_inf", [1, 1434, 1434]),
        )
        for category, expected in expected_histories:
            self.assertListEqual(
                expected, list(entry.jsonable_history[category]))
예제 #3
0
 def testSingleAlert(self):
     """Checks that one registered alert yields exactly one report row."""
     registry = numerics_alert.NumericsAlertRegistry()
     registry.register(
         numerics_alert.NumericsAlert(
             "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 0, 10,
             10))

     expected_row = numerics_alert.NumericsAlertReportRow(
         "/job:worker/replica:0/task:0/gpu:0", "xent/Log:0", 1234, 0, 1, 1)
     self.assertEqual([expected_row], registry.report())
예제 #4
0
 def testLoadFromJson(self):
     """Checks that a registry built from an initialization list reports it.

     Each initialization entry is a [device, tensor, history-dict] list; the
     report is expected to contain one row per entry.
     """
     device_name = "/job:localhost/replica:0/task:0/cpu:0"
     matmul_entry = [
         device_name,
         "MatMul:0",
         {
             "pos_inf": [0, -1, -1],
             "nan": [1624, 1496818651573005, 1496818690371163],
             "neg_inf": [0, -1, -1],
         },
     ]
     adagrad_entry = [
         device_name,
         "weight/Adagrad:0",
         {
             "pos_inf": [0, -1, -1],
             "nan": [1621, 1496818651607234, 1496818690370891],
             "neg_inf": [0, -1, -1],
         },
     ]
     registry = numerics_alert.NumericsAlertRegistry(
         initialization_list=[matmul_entry, adagrad_entry])

     expected_rows = [
         numerics_alert.NumericsAlertReportRow(
             device_name, "MatMul:0", 1496818651573005, 1624, 0, 0),
         numerics_alert.NumericsAlertReportRow(
             device_name, "weight/Adagrad:0", 1496818651607234, 1621, 0, 0),
     ]
     self.assertEqual(expected_rows, registry.report())
예제 #5
0
 def testMultipleEventsFromSameDeviceAndSameTensor(self):
     """Checks that two alerts for one device/tensor pair share one row."""
     device_name = "/job:worker/replica:0/task:0/gpu:0"
     tensor_name = "xent/Log:0"
     alerts = [
         numerics_alert.NumericsAlert(
             device_name, tensor_name, 1234, 0, 10, 10),
         numerics_alert.NumericsAlert(
             device_name, tensor_name, 1634, 5, 5, 5),
     ]
     registry = numerics_alert.NumericsAlertRegistry()
     for alert in alerts:
         registry.register(alert)

     merged_row = numerics_alert.NumericsAlertReportRow(
         device_name, tensor_name, 1234, 1, 2, 2)
     self.assertEqual([merged_row], registry.report())
예제 #6
0
 def testFilterReport(self):
     """Checks report filtering by device-name and tensor-name patterns.

     Three alerts are registered; each filter is expected to leave exactly
     one matching row in the report.
     """
     task_0_device = "/job:worker/replica:0/task:0/gpu:0"
     task_1_device = "/job:worker/replica:0/task:1/gpu:0"
     alerts = [
         numerics_alert.NumericsAlert(
             task_1_device, "xent/Log:0", 1434, 0, 1, 1),
         numerics_alert.NumericsAlert(
             task_0_device, "xent/Log:0", 1234, 0, 1, 1),
         numerics_alert.NumericsAlert(
             task_0_device, "xent/Mean:0", 1634, 1, 1, 1),
     ]
     registry = numerics_alert.NumericsAlertRegistry()
     for alert in alerts:
         registry.register(alert)

     # Only the task:1 device matches the device-name pattern.
     task_1_rows = [
         numerics_alert.NumericsAlertReportRow(
             task_1_device, "xent/Log:0", 1434, 0, 1, 1),
     ]
     self.assertEqual(
         task_1_rows,
         registry.report(device_name_filter=r".*\/task:1\/.*"))

     # Only the "xent/Mean:0" tensor matches the tensor-name pattern.
     mean_rows = [
         numerics_alert.NumericsAlertReportRow(
             task_0_device, "xent/Mean:0", 1634, 1, 1, 1),
     ]
     self.assertEqual(
         mean_rows,
         registry.report(tensor_name_filter=r".*Mean.*"))
예제 #7
0
  def __init__(self, receive_port, logdir, always_flush=False):
    """Sets up a server that receives health pills and writes them to disk.

    Args:
      receive_port: The port on which health pills from the TensorFlow
        debugger are received.
      logdir: The directory in which events files read by TensorBoard are
        written.
      always_flush: Whether the EventsWriter is flushed after every write.
        Can be used for testing.
    """
    # Debugger events live in their own subdirectory of logdir. TensorBoard
    # reads through each events file of a directory only once, and other
    # non-debugger events files may be written concurrently; sharing a
    # directory could make health pills stop surfacing past some step.
    debugger_dir = os.path.join(
        os.path.expanduser(logdir), constants.DEBUGGER_DATA_DIRECTORY_NAME)

    if not tf.gfile.Exists(debugger_dir):
      try:
        tf.gfile.MakeDirs(debugger_dir)
      except tf.OpError as e:
        tf.logging.fatal(
            "Could not make directory for debugger data: %s. Error: %s",
            debugger_dir, e)
      else:
        tf.logging.info("Created directory for debugger data: %s",
                        debugger_dir)

    self._events_writer_manager = events_writer_manager_lib.EventsWriterManager(
        events_directory=debugger_dir,
        always_flush=always_flush)

    # The first event of the file carries the file version. With version 2,
    # TensorBoard purges events through a path that does not depend on step,
    # which matters because the debugger's notion of step differs from the
    # rest of TensorBoard.
    try:
      version_event = tf.Event(
          wall_time=0, step=0, file_version=constants.EVENTS_VERSION)
      self._events_writer_manager.write_event(version_event)
    except IOError as e:
      tf.logging.error(
          "Writing to %s failed: %s",
          self._events_writer_manager.get_current_file_name(), e)

    # Seed the alert registry from a backup file when one is present.
    self._registry_backup_file_path = os.path.join(
        debugger_dir, constants.ALERT_REGISTRY_BACKUP_FILE_NAME)
    backup_data = None

    if tf.gfile.Exists(self._registry_backup_file_path):
      with tf.gfile.Open(self._registry_backup_file_path, "r") as backup_file:
        try:
          backup_data = json.load(backup_file)
        except ValueError as err:
          # Unparseable backup: log it and start without backup data.
          tf.logging.error(
              "Could not parse contents of %s: %s",
              self._registry_backup_file_path, err)

    self._numerics_alert_registry = numerics_alert.NumericsAlertRegistry(
        initialization_list=backup_data)

    self._numerics_alert_lock = threading.Lock()
    # Pre-bind the writer manager and alert callback so the base servicer
    # only has to supply the remaining handler arguments.
    handler_factory = functools.partial(
        DebuggerDataStreamHandler,
        self._events_writer_manager,
        self._numerics_alert_callback)
    grpc_debug_server.EventListenerBaseServicer.__init__(
        self, receive_port, handler_factory)
예제 #8
0
 def testNoAlert(self):
     """Checks that a registry with no alerts reports no rows."""
     empty_registry = numerics_alert.NumericsAlertRegistry()
     report_rows = empty_registry.report()
     self.assertEqual([], report_rows)
예제 #9
0
 def testCreateEmptyJsonableRegistry(self):
     """Tests that an empty registry yields an empty report and JSON form."""
     registry = numerics_alert.NumericsAlertRegistry()
     self.assertEqual([], registry.report())
     # The test name refers to create_jsonable_registry(), so exercise it
     # too; without this the test only repeats the empty-report check.
     self.assertEqual(0, len(registry.create_jsonable_registry()))
예제 #10
0
    def testRegisterBeyondCapacityObeysCapacity(self):
        """Checks registration against a registry created with capacity=2.

        After two devices are registered, an alert from a third device is
        expected to leave the report unchanged, while a further alert from
        an already-registered device is expected to update that device's row.
        """
        task_0_device = "/job:worker/replica:0/task:0/gpu:0"
        task_1_device = "/job:worker/replica:0/task:1/gpu:0"
        task_2_device = "/job:worker/replica:0/task:2/gpu:0"
        tensor_name = "xent/Log:0"

        task_1_alert = numerics_alert.NumericsAlert(
            task_1_device, tensor_name, 1434, 0, 1, 1)
        task_0_alert = numerics_alert.NumericsAlert(
            task_0_device, tensor_name, 1234, 0, 1, 1)
        task_2_alert = numerics_alert.NumericsAlert(
            task_2_device, tensor_name, 1634, 0, 1, 1)
        task_0_followup = numerics_alert.NumericsAlert(
            task_0_device, tensor_name, 1834, 1, 1, 1)

        registry = numerics_alert.NumericsAlertRegistry(capacity=2)
        registry.register(task_1_alert)
        registry.register(task_0_alert)

        rows_at_capacity = [
            numerics_alert.NumericsAlertReportRow(
                task_0_device, tensor_name, 1234, 0, 1, 1),
            numerics_alert.NumericsAlertReportRow(
                task_1_device, tensor_name, 1434, 0, 1, 1),
        ]
        self.assertEqual(rows_at_capacity, registry.report())

        # A third device would exceed the capacity; the report must not
        # change.
        registry.register(task_2_alert)
        self.assertEqual(rows_at_capacity, registry.report())

        # task:0 is already tracked, so its row is expected to be updated.
        registry.register(task_0_followup)
        rows_after_followup = [
            numerics_alert.NumericsAlertReportRow(
                task_0_device, tensor_name, 1234, 1, 2, 2),
            numerics_alert.NumericsAlertReportRow(
                task_1_device, tensor_name, 1434, 0, 1, 1),
        ]
        self.assertEqual(rows_after_followup, registry.report())