예제 #1
0
def make_default_logger(
    label: str,
    save_data: bool = True,
    time_delta: float = 1.0,
) -> base.Logger:
  """Make a default Acme logger.

  The logger always writes to a terminal logger. When `save_data` is true it
  also writes to a `csv.CSVLogger`. The loggers are combined with
  `aggregators.Dispatcher`. The dispatcher is wrapped in `filters.NoneFilter`,
  and that is wrapped in `filters.TimeFilter` (the outermost wrapper).

  Args:
    label: Name to give to the logger.
    save_data: Whether to additionally write data through `csv.CSVLogger`.
    time_delta: Time (in seconds) between logging events. It is passed to the
      terminal logger and also used as the interval of the outer `TimeFilter`.

  Returns:
    A logger (pipe) object that responds to logger.write(some_dict).
  """
  terminal_logger = terminal.TerminalLogger(label=label, time_delta=time_delta)

  loggers = [terminal_logger]
  if save_data:
    loggers.append(csv.CSVLogger(label))

  # Fan out to every underlying logger, then apply the filters on top.
  logger = aggregators.Dispatcher(loggers)
  logger = filters.NoneFilter(logger)
  logger = filters.TimeFilter(logger, time_delta)
  return logger
예제 #2
0
def make_default_logger(
    logdir: str,
    label: str,
    save_data: bool = True,
    time_delta: float = 0.0,
) -> base.Logger:
  """Make a default Acme logger that persists data under `logdir`.

  Args:
    logdir: Directory passed to the CSV logger, the TF summary logger and the
      `wrapper.CSVDumper` wrapper.
    label: Name to give to the logger. A terminal logger is attached only when
      this string contains the substring 'agent'.
    save_data: Whether to attach `csv.CSVLogger` and
      `tf_summary.TFSummaryLogger`, and to wrap the final logger in
      `wrapper.CSVDumper`.
    time_delta: Time (in seconds) between logging events. It is passed to the
      terminal logger and to the `TimeFilter`.

  Returns:
    A logger (pipe) object that responds to logger.write(some_dict).
  """
  loggers = []

  # TODO: temporarily disable terminal logger for environment.
  # Only labels containing 'agent' get terminal output; other labels (e.g. an
  # environment logger) skip it.
  if 'agent' in label:
    loggers.append(terminal.TerminalLogger(label=label, time_delta=time_delta))

  if save_data:
    loggers.append(csv.CSVLogger(logdir=logdir, label=label))
    loggers.append(tf_summary.TFSummaryLogger(logdir=logdir, label=label))

  # NOTE(review): if `label` lacks 'agent' and save_data is False, `loggers`
  # is empty here; presumably the Dispatcher then drops writes — confirm.
  logger = aggregators.Dispatcher(loggers)
  logger = filters.NoneFilter(logger)
  logger = filters.TimeFilter(logger, time_delta)

  if save_data:
    logger = wrapper.CSVDumper(logger, label=label, logdir=logdir)

  return logger
예제 #3
0
def make_default_logger(
    label: str,
    save_data: bool = True,
    time_delta: float = 1.0,
    asynchronous: bool = False,
    print_fn: Optional[Callable[[str], None]] = None,
    serialize_fn: Optional[Callable[[Mapping[str, Any]], str]] = base.to_numpy,
    steps_key: str = 'steps',
) -> base.Logger:
    """Makes a default Acme logger.

    The wrappers are applied from inside out:
      1. `aggregators.Dispatcher` fans out to the terminal logger, plus a
         `csv.CSVLogger` when `save_data` is true.
      2. `filters.NoneFilter` wraps the dispatcher.
      3. If `asynchronous` is true, `async_logger.AsyncLogger` wraps that.
      4. `filters.TimeFilter` is the outermost wrapper.

    Args:
      label: Name to give to the logger.
      save_data: Whether to persist data (adds a `csv.CSVLogger`).
      time_delta: Time (in seconds) between logging events.
      asynchronous: Whether the write function should block or not.
      print_fn: How to print to terminal. Defaults to `logging.info` when not
        provided (any falsy value also falls back to `logging.info`).
      serialize_fn: An optional function to apply to the write inputs before
        passing them to the various loggers.
      steps_key: Ignored.

    Returns:
      A logger object that responds to logger.write(some_dict).
    """
    del steps_key
    if not print_fn:
        print_fn = logging.info
    terminal_logger = terminal.TerminalLogger(label=label, print_fn=print_fn)

    loggers = [terminal_logger]

    if save_data:
        loggers.append(csv.CSVLogger(label=label))

    # Dispatch to all writers and filter Nones and by time.
    logger = aggregators.Dispatcher(loggers, serialize_fn)
    logger = filters.NoneFilter(logger)
    if asynchronous:
        logger = async_logger.AsyncLogger(logger)
    logger = filters.TimeFilter(logger, time_delta)

    return logger
예제 #4
0
def create_default_logger(
    label: str,
    tf_summary_logdir: str,
    save_data: bool = True,
    step_filter_delta: int = 1,
    time_delta: float = 1.0,
    asynchronous: bool = False,
    print_fn: Optional[Callable[[str], None]] = None,
    serialize_fn: Optional[Callable[[Mapping[str, Any]], str]] = base.to_numpy,
    steps_key: str = 'steps',
    extra_primary_keys: Optional[PrimaryKeyList] = None,
) -> base.Logger:
    """Creates a logger that has TerminalLogger and TF Summary.

    The wrappers are applied from inside out:
      1. `aggregators.Dispatcher` fans out to both underlying loggers.
      2. `filters.NoneFilter` wraps the dispatcher.
      3. If `step_filter_delta` > 1, `StepFilter` wraps that.
      4. `filters.TimeFilter` wraps the result.
      5. If `asynchronous` is true, `async_logger.AsyncLogger` is the outermost
         wrapper.

    Args:
      label: Name to give to the logger.
      tf_summary_logdir: Directory for the `tf_summary.TFSummaryLogger`.
      save_data: Currently unused; accepted for signature compatibility.
      step_filter_delta: When greater than 1, writes go through
        `StepFilter(steps_label=steps_key, delta=step_filter_delta)`.
      time_delta: Interval (in seconds) used by `filters.TimeFilter`.
      asynchronous: Whether to wrap the logger in `async_logger.AsyncLogger`.
      print_fn: Function the terminal logger uses to print. Defaults to the
        built-in `print` when None.
      serialize_fn: An optional function the dispatcher applies to the write
        inputs before passing them to the underlying loggers.
      steps_key: Steps key given to the TF summary logger and the step filter.
      extra_primary_keys: Currently unused; accepted for signature
        compatibility.

    Returns:
      A logger object that responds to logger.write(some_dict).
    """
    # Previously `print_fn` was ignored and `print` was always used; honor the
    # argument while keeping `print` as the default.
    if print_fn is None:
        print_fn = print

    loggers = [
        terminal.TerminalLogger(label, print_fn=print_fn),
        # tensorboard logger
        tf_summary.TFSummaryLogger(
            logdir=tf_summary_logdir,
            label=label,
            steps_key=steps_key,
        ),
    ]

    # aggregate and add modifiers
    logger = aggregators.Dispatcher(loggers, serialize_fn=serialize_fn)
    logger = filters.NoneFilter(logger)
    if step_filter_delta > 1:
        logger = StepFilter(logger,
                            steps_label=steps_key,
                            delta=step_filter_delta)
    logger = filters.TimeFilter(logger, time_delta=time_delta)
    if asynchronous:
        logger = async_logger.AsyncLogger(logger)

    return logger