import logging
from typing import Optional, Tuple

import determined as det
from determined import tensorboard


def prepare_tensorboard(
    env: det.EnvContext,
    container_path: Optional[str] = None,
) -> Tuple[tensorboard.TensorboardManager, tensorboard.BatchMetricWriter]:
    # Build the manager that syncs event files to the trial's checkpoint storage.
    tensorboard_mgr = tensorboard.build(
        env.det_cluster_id,
        env.det_experiment_id,
        env.det_trial_id,
        env.experiment_config["checkpoint_storage"],
        container_path,
    )
    try:
        # Prefer the TensorFlow event writer when TensorFlow is installed.
        from determined.tensorboard.metric_writers import tensorflow

        writer: tensorboard.MetricWriter = tensorflow.TFWriter()
    except ModuleNotFoundError:
        # Fall back to the PyTorch writer otherwise.
        logging.warning("TensorFlow writer not found")
        from determined.tensorboard.metric_writers import pytorch

        writer = pytorch.TorchWriter()

    return (
        tensorboard_mgr,
        tensorboard.BatchMetricWriter(writer),
    )
def get_metric_writer() -> tensorboard.BatchMetricWriter:
    try:
        from determined.tensorboard.metric_writers import tensorflow

        writer: tensorboard.MetricWriter = tensorflow.TFWriter()
    except ModuleNotFoundError:
        logging.warning("TensorFlow writer not found")
        from determined.tensorboard.metric_writers import pytorch

        writer = pytorch.TorchWriter()

    return tensorboard.BatchMetricWriter(writer)
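# Hedged aside (not part of the Determined source): the writer selection above
# is a plain import-fallback pattern. A minimal, self-contained sketch of the
# same idea, assuming only the standard library plus an optional TensorFlow
# install:
import logging

try:
    import tensorflow  # noqa: F401  # succeeds only when TensorFlow is installed

    backend = "tensorflow"
except ModuleNotFoundError:
    logging.warning("TensorFlow not found; falling back to the PyTorch writer")
    backend = "pytorch"

print(f"selected metric writer backend: {backend}")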
def prepare_tensorboard(
    env: det.EnvContext,
) -> Tuple[tensorboard.TensorboardManager, tensorboard.BatchMetricWriter]:
    tensorboard_mgr = tensorboard.build(env, env.experiment_config["checkpoint_storage"])
    try:
        # Prefer the PyTorch event writer when torch is installed.
        from determined.tensorboard.metric_writers import pytorch

        writer: tensorboard.MetricWriter = pytorch.TorchWriter()
    except ImportError:
        logging.warning("PyTorch writer not found")
        from determined.tensorboard.metric_writers import tensorflow

        writer = tensorflow.TFWriter()

    return (
        tensorboard_mgr,
        tensorboard.BatchMetricWriter(writer, env.experiment_config.batches_per_step()),
    )
from typing import Any, Dict

import pytorch_lightning as pl

from determined.errors import InvalidModelException
from determined.pytorch import PyTorchTrialContext
from determined.tensorboard.metric_writers import pytorch


def override_unsupported_nud(lm: pl.LightningModule, context: PyTorchTrialContext) -> None:
    writer = pytorch.TorchWriter()

    def lm_print(*args: Any, **kwargs: Any) -> None:
        # Print only from the chief worker to avoid duplicate output under
        # distributed training.
        if context.distributed.get_rank() == 0:
            print(*args, **kwargs)

    def lm_log_dict(a_dict: Dict, *args: Any, **kwargs: Any) -> None:
        if len(args) != 0 or len(kwargs) != 0:
            raise InvalidModelException(
                f"unsupported arguments to LightningModule.log {args} {kwargs}"
            )

        # Forward scalar metrics to the TensorBoard writer, keyed by the
        # current training batch index.
        for metric, value in a_dict.items():
            if type(value) == int or type(value) == float:
                writer.add_scalar(metric, value, context.current_train_batch())

    def lm_log(name: str, value: Any, *args: Any, **kwargs: Any) -> None:
        lm_log_dict({name: value}, *args, **kwargs)

    # Rebind the unsupported LightningModule methods on the instance so user
    # calls are routed through the Determined-aware versions above.
    lm.print = lm_print  # type: ignore
    lm.log = lm_log  # type: ignore
    lm.log_dict = lm_log_dict  # type: ignore
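# Hedged aside (not Determined API): override_unsupported_nud works by
# rebinding attributes on the instance at runtime so they shadow the class
# methods. The same monkey-patching idea on a plain object, as a
# self-contained sketch with hypothetical names:
recorded = {}


class Model:
    def log(self, name: str, value: float) -> None:
        print(f"original log: {name}={value}")


def patched_log(name: str, value: Any) -> None:
    # Intercept scalar logs and record them instead of printing.
    if isinstance(value, (int, float)):
        recorded[name] = value


m = Model()
m.log = patched_log  # type: ignore  # instance attribute shadows Model.log
m.log("loss", 0.25)
assert recorded == {"loss": 0.25}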