def prepare_tensorboard(
    env: det.EnvContext,
    container_path: Optional[str] = None,
) -> Tuple[tensorboard.TensorboardManager, tensorboard.BatchMetricWriter]:
    """Create the TensorBoard manager and a batched metric writer for this trial.

    Args:
        env: Trial environment. Its ``det_cluster_id``, ``det_experiment_id``,
            ``det_trial_id`` and the experiment's ``checkpoint_storage`` config
            are handed to ``tensorboard.build``.
        container_path: Optional path, forwarded unchanged to ``tensorboard.build``.

    Returns:
        A ``(manager, batch_writer)`` pair. The batch writer wraps a TensorFlow
        metric writer when that can be imported and constructed. Otherwise a
        warning is logged and a PyTorch metric writer is used instead.
    """
    # Read the identifiers in the same order they are passed to the builder.
    cluster_id = env.det_cluster_id
    experiment_id = env.det_experiment_id
    trial_id = env.det_trial_id
    storage_config = env.experiment_config["checkpoint_storage"]
    manager = tensorboard.build(
        cluster_id, experiment_id, trial_id, storage_config, container_path
    )

    metric_writer: tensorboard.MetricWriter
    try:
        from determined.tensorboard.metric_writers import tensorflow as tf_writers

        metric_writer = tf_writers.TFWriter()
    except ModuleNotFoundError:
        # TensorFlow (or its writer module) is unavailable: fall back to PyTorch.
        logging.warning("Tensorflow writer not found")
        from determined.tensorboard.metric_writers import pytorch as torch_writers

        metric_writer = torch_writers.TorchWriter()

    batch_writer = tensorboard.BatchMetricWriter(metric_writer)
    return manager, batch_writer
Esempio n. 2
0
def get_metric_writer() -> tensorboard.BatchMetricWriter:
    """Return a batched metric writer, preferring TensorFlow over PyTorch.

    If importing or constructing the TensorFlow metric writer raises
    ``ModuleNotFoundError``, a warning is logged and a PyTorch metric writer
    is used instead. The chosen writer is wrapped in ``BatchMetricWriter``.
    """
    backend: tensorboard.MetricWriter
    try:
        from determined.tensorboard.metric_writers import tensorflow as tf_writers

        backend = tf_writers.TFWriter()
    except ModuleNotFoundError:
        logging.warning("TensorFlow writer not found")
        from determined.tensorboard.metric_writers import pytorch as torch_writers

        backend = torch_writers.TorchWriter()

    return tensorboard.BatchMetricWriter(backend)
def prepare_tensorboard(
    env: det.EnvContext,
) -> Tuple[tensorboard.TensorboardManager, tensorboard.BatchMetricWriter]:
    """Create the TensorBoard manager and a batched metric writer for this trial.

    Args:
        env: Trial environment. It is passed to ``tensorboard.build`` together
            with the experiment's ``checkpoint_storage`` config. Its
            ``experiment_config.batches_per_step()`` sets the batching of the
            returned metric writer.

    Returns:
        A ``(manager, batch_writer)`` pair. The batch writer wraps a PyTorch
        metric writer when that can be imported and constructed. Otherwise a
        warning is logged and a TensorFlow metric writer is used instead.
    """
    # Local import so this block does not depend on a module-level import.
    import logging

    tensorboard_mgr = tensorboard.build(env, env.experiment_config["checkpoint_storage"])
    try:
        from determined.tensorboard.metric_writers import pytorch

        writer: tensorboard.MetricWriter = pytorch.TorchWriter()

    except ImportError:
        # Report the fallback through logging rather than print(), so it respects
        # the process's logging configuration instead of going straight to stdout.
        logging.warning("PyTorch writer not found")
        from determined.tensorboard.metric_writers import tensorflow

        writer = tensorflow.TFWriter()

    return (
        tensorboard_mgr,
        tensorboard.BatchMetricWriter(writer, env.experiment_config.batches_per_step()),
    )
Esempio n. 4
0
def override_unsupported_nud(lm: pl.LightningModule, context: PyTorchTrialContext) -> None:
    """Patch ``lm.print``, ``lm.log`` and ``lm.log_dict`` with Determined-backed versions.

    Args:
        lm: The LightningModule whose ``print``, ``log`` and ``log_dict``
            attributes are replaced in place.
        context: Trial context. It supplies the worker rank for ``print`` and
            the current training batch index used as the scalar step.

    The replacements behave as follows:

    * ``print(*args, **kwargs)`` forwards to the builtin ``print`` only when
      ``context.distributed.get_rank() == 0``.
    * ``log(name, value, ...)`` is shorthand for ``log_dict({name: value}, ...)``.
    * ``log_dict(a_dict)`` writes each ``int``/``float`` value as a scalar
      through a ``TorchWriter`` at step ``context.current_train_batch()``.
      Subclasses such as ``numpy.float64`` are included; ``bool`` is not.
      Values of any other type are skipped silently.
      NOTE(review): tensor-valued metrics (common in Lightning) fall in the
      skipped category; consider converting them explicitly.

    Raises (from the patched ``log``/``log_dict``):
        InvalidModelException: if any extra positional or keyword arguments are
            passed, because those Lightning logging options are unsupported.
    """
    writer = pytorch.TorchWriter()

    def lm_print(*args: Any, **kwargs: Any) -> None:
        # Only rank 0 prints, presumably to avoid duplicated output across workers.
        if context.distributed.get_rank() == 0:
            print(*args, **kwargs)

    def lm_log_dict(a_dict: Dict, *args: Any, **kwargs: Any) -> None:
        if args or kwargs:
            raise InvalidModelException(
                f"unsupported arguments to LightningModule.log {args} {kwargs}"
            )
        for metric, value in a_dict.items():
            # isinstance (not `type(...) ==`) so float/int subclasses such as
            # numpy.float64 are logged. bool is an int subclass but was never
            # logged, so it stays excluded.
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                writer.add_scalar(metric, value, context.current_train_batch())

    def lm_log(name: str, value: Any, *args: Any, **kwargs: Any) -> None:
        lm_log_dict({name: value}, *args, **kwargs)

    lm.print = lm_print  # type: ignore
    lm.log = lm_log  # type: ignore
    lm.log_dict = lm_log_dict  # type: ignore