def _make_heartbeat_op(session, device, request_ph):
  """Return a heartbeat op or None if heartbeats are not supported by device.

  Support is checked by running a throwaway `worker_heartbeat('')` op on
  `device`. This happens inside a brand-new graph, using the session returned
  by `_clone_session(session)`, with a 5000 ms run timeout. If the probe
  raises `InvalidArgumentError` or `DeadlineExceededError`, a warning is
  logged and None is returned. Otherwise a heartbeat op fed by `request_ph`
  is created on `device` in the caller's current default graph.

  Args:
    session: session handed to `_clone_session` to obtain the probe session.
    device: device string the probe op and the returned op are placed on.
    request_ph: request tensor passed to `tpu_ops.worker_heartbeat`.

  Returns:
    The heartbeat op on `device`, or None if the probe failed.
  """
  try:
    # Probe in a separate graph so the test op never pollutes the caller's
    # graph.
    with ops.Graph().as_default(), _clone_session(session) as probe_session:
      with ops.device(device):
        probe_op = tpu_ops.worker_heartbeat('')
        probe_options = config_pb2.RunOptions(timeout_in_ms=5000)
        probe_session.run(probe_op, options=probe_options)
  except errors.InvalidArgumentError:
    logging.warning('Error running heartbeat on %s', device)
    return None
  except errors.DeadlineExceededError:
    logging.warning('Timeout connecting to %s when testing heartbeat', device)
    return None

  # The probe succeeded, so build the real op in the caller's graph.
  with ops.device(device):
    return tpu_ops.worker_heartbeat(request_ph)
def from_devices(session, devices):
  """Construct a heartbeat manager for the given devices.

  Each device is probed with `_make_heartbeat_op`. A device for which that
  helper returns None (its heartbeat probe failed) is logged and left out.
  The returned manager therefore only tracks devices where a heartbeat op
  could be built.

  Args:
    session: session used for probing and handed to the manager.
    devices: device strings to monitor. An empty value is logged as an error
      but still produces a (device-less) manager.

  Returns:
    A `WorkerHeartbeatManager` over the devices that passed the probe, their
    heartbeat ops, and the shared request placeholder.
  """
  if not devices:
    logging.error('Trying to create heartbeat manager with no devices?')

  logging.info('Creating heartbeat manager for %s', devices)
  request_placeholder = array_ops.placeholder(
      name='worker_heartbeat_request', dtype=dtypes.string)

  heartbeat_ops = []
  kept_devices = []
  for device in devices:
    heartbeat_op = _make_heartbeat_op(session, device, request_placeholder)
    if heartbeat_op is None:
      # The probe already failed for this device. Keeping an op for it would
      # presumably make later heartbeat runs hit the same error, so skip it.
      logging.warning('Heartbeat support not available for %s', device)
      continue
    kept_devices.append(device)
    heartbeat_ops.append(heartbeat_op)

  # Devices and ops stay index-aligned because both are appended together.
  return WorkerHeartbeatManager(session, kept_devices, heartbeat_ops,
                                request_placeholder)