def __getattr__(self, attr: str) -> Any:
    """Forward unknown attribute lookups to the selected Horovod module.

    Python only calls ``__getattr__`` after normal attribute lookup fails, so any
    name not found on this wrapper is resolved on ``self._poly_hvd_module``.

    Raises:
        AttributeError: if ``attr`` is one of this wrapper's own ``_poly_hvd*``
            attributes and it is not set on the instance (e.g. an instance made
            with ``cls.__new__`` without running ``__init__``). Without this
            guard, reading ``self._poly_hvd_type`` below would re-enter
            ``__getattr__`` and recurse until ``RecursionError``, which also
            breaks ``hasattr``-style probing of the object.

    A ``check`` failure is raised if
    ``det.horovod.hvd.require_horovod_type()`` has not been called yet, or if
    Horovod could not be imported in this process.
    """
    if attr.startswith("_poly_hvd"):
        raise AttributeError(attr)
    check.is_not_none(
        self._poly_hvd_type,
        "You must call det.horovod.hvd.require_horovod_type() before any other calls.",
    )
    check.is_not_none(self._poly_hvd_module, "Horovod could not be imported in this process.")
    return getattr(self._poly_hvd_module, attr)
def metrics_result(self) -> Metrics:
    """Return the TrialController's stored response, requiring it to be a metrics dict.

    Behaves like ``result()`` but additionally fails a check when the stored
    response is not a ``dict`` (the failure message calls this an unexpected
    SkippedWorkload response). Unlike ``result()``, the stored response is
    NOT cleared by this call.
    """
    response = self._response
    check.is_not_none(response, "_respond() was not called by the TrialController.")
    check.is_instance(response, dict, "unexpected SkippedWorkload response.")
    return cast(Metrics, response)
def result(self) -> Response:
    """Take the WorkloadResponse produced by the TrialController.

    The stored response is consumed: ``self._response`` is reset to None, so
    call this only once per send. A check fails if no response is stored.
    """
    # Swap the response out first; if it was None, the state is unchanged and
    # the check below fails exactly as before.
    response, self._response = self._response, None
    check.is_not_none(response, "_respond() was not called by the TrialController.")
    return cast(Response, response)
def run(self) -> None:
    """Batch Timings from ``inbound_queue`` and push batches to ``send_queue``.

    Phase 1: wait for a StartMessage. Timings received before it are dropped.
    A ShutdownMessage in this phase forwards a ShutdownMessage and returns.

    Phase 2: add each Timing to ``current_batch``. The flush timer starts when
    the first Timing arrives. Once more than FLUSH_INTERVAL has elapsed, the
    batch is put on ``send_queue`` via ``convert_to_post_format()`` and
    cleared, and the timer restarts immediately. A batch is therefore sent
    every FLUSH_INTERVAL from then on, even if it is empty. A ShutdownMessage
    flushes the current batch, forwards a ShutdownMessage and returns.
    """
    # Do nothing while we wait for a StartMessage
    while True:
        msg = self.inbound_queue.get()
        if isinstance(msg, StartMessage):
            break
        if isinstance(msg, ShutdownMessage):
            self.send_queue.put(ShutdownMessage())
            return
        # Ignore any Timings that are received before StartMessage

    batch_start_time = None  # type: Optional[float]
    while True:
        # Wait for the next Timing to arrive. If it doesn't arrive before the next flush
        # should happen, we stop waiting for the next Timing and go straight to flushing.
        timeout = None  # type: Optional[float]
        if batch_start_time is not None:
            time_since_flush = time.time() - batch_start_time
            # The interval may already have elapsed between the flush check at the
            # bottom of the previous iteration and this point. A negative timeout
            # makes Queue.get() raise ValueError, so clamp it to zero.
            timeout = max(0.0, self.FLUSH_INTERVAL - time_since_flush)

        try:
            message = self.inbound_queue.get(timeout=timeout)
        except queue.Empty:
            pass
        else:
            if isinstance(message, ShutdownMessage):
                self.send_queue.put(self.current_batch.convert_to_post_format())
                self.send_queue.put(ShutdownMessage())
                return
            elif isinstance(message, Timing):
                if batch_start_time is None:
                    batch_start_time = time.time()
                self.current_batch.add_timing(message)
            else:
                logging.fatal(
                    f"ProfilerAgent.TimingsBatcherThread received a message "
                    f"of unexpected type '{type(message)}' from the "
                    f"inbound_queue. This should never happen - there must "
                    f"be a bug in the code."
                )

        check.is_not_none(
            batch_start_time,
            "batch_start_time should never be None. The inbound_queue.get() "
            "should never return and proceed to this piece of code "
            "without batch_start_time being updated to a real timestamp. If "
            "batch_start_time is None, inbound_queue.get() timeout should be "
            "None and the get() should block until a Timing is received.",
        )
        batch_start_time = cast(float, batch_start_time)
        if time.time() - batch_start_time > self.FLUSH_INTERVAL:
            self.send_queue.put(self.current_batch.convert_to_post_format())
            self.current_batch.clear()
            batch_start_time = time.time()
def _init_device(self) -> None:
    """Select the torch device for this process and store it on ``self.device``.

    ``self.n_gpus`` is set to the number of entries in ``self.env.container_gpus``.
    With Horovod in use (``self.hvd_config.use``) at least one GPU is required,
    and the device is chosen from ``hvd.local_rank()`` and made current via
    ``torch.cuda.set_device``. Otherwise GPU 0 is used when any GPU is present,
    and the CPU is used when none is.
    """
    self.n_gpus = len(self.env.container_gpus)
    if not self.hvd_config.use:
        self.device = torch.device("cuda", 0) if self.n_gpus > 0 else torch.device("cpu")
    else:
        check.gt(self.n_gpus, 0)
        # One Horovod process is launched per GPU, so each process must bind
        # to its own GPU, picked by its local rank.
        self.device = torch.device(hvd.local_rank())
        torch.cuda.set_device(self.device)
    check.is_not_none(self.device)
def to_measurement(self) -> Measurement:
    """Convert this Timing into a Measurement.

    The Measurement's ``timestamp`` is ``self.start_time`` (a POSIX timestamp)
    converted to a UTC-aware datetime. Its ``batch_idx`` is
    ``self.current_batch_idx`` and its ``value`` is ``self.dur``.

    A ``check`` failure is raised if ``start()`` (start time) or ``end()``
    (duration) has not been run first. This does not modify the Timing.
    """
    check.is_not_none(
        self.start_time,
        "Timing has no start time and to_measurement() was called. You probably didn't "
        "run start() before to_measurement().",
    )
    check.is_not_none(
        self.dur,
        "Timing has no duration and to_measurement() was called. You probably didn't "
        "run end() before to_measurement().",
    )
    # cast() only narrows the type for the checker. Keep the narrowed values in
    # locals instead of writing them back onto the instance.
    start_time = cast(float, self.start_time)
    dur = cast(float, self.dur)
    return Measurement(
        timestamp=datetime.fromtimestamp(start_time, timezone.utc),
        batch_idx=self.current_batch_idx,
        value=dur,
    )
def __init__(
    self,
    num_connections: Optional[int] = None,
    ports: Optional[List[int]] = None,
    port_range: Optional[Tuple[int, int]] = None,
) -> None:
    """Create a zmq context and bind sockets, to explicit ports or within a port range.

    Args:
        num_connections: number of sockets to bind when ``ports`` is not given.
            Required in that case; ignored when ``ports`` is non-empty.
        ports: explicit ports to bind. When non-empty, ``port_range`` must be None.
        port_range: range handed to ``_bind_to_random_ports`` when ``ports`` is
            not given. Required in that case.

    After a successful call, ``self.ports`` has one entry per requested
    connection. If validation or binding fails, the zmq context is destroyed
    before the exception propagates, so nothing is leaked.
    """
    self.context = zmq.Context()  # type: ignore
    self.sockets = []  # type: List[zmq.Socket]
    self.ports = []  # type: List[int]

    try:
        if ports:
            check.is_none(port_range)
            self._bind_to_specified_ports(ports=ports)
            check.eq(len(self.ports), len(ports))
        else:
            check.is_not_none(num_connections)
            check.is_not_none(port_range)
            num_connections = cast(int, num_connections)
            port_range = cast(Tuple[int, int], port_range)
            self._bind_to_random_ports(port_range=port_range, num_connections=num_connections)
            check.eq(len(self.ports), num_connections)
    except BaseException:
        # Don't leak the context (or any sockets the _bind_to_* helpers presumably
        # opened on it) when construction fails. Context.destroy() closes all
        # sockets belonging to the context and then terminates it.
        self.context.destroy(linger=0)
        raise