コード例 #1
0
import os
import time
from typing import Callable, Text, Tuple

from absl import logging

import tensorflow.compat.v2 as tf

from tf_agents.agents import tf_agent
from tf_agents.policies import py_tf_eager_policy
from tf_agents.typing import types
from tf_agents.utils import lazy_loader

# `reverb` is an optional dependency that not every user has installed, so it
# is wrapped in a `lazy_loader.LazyLoader` rather than imported eagerly.
reverb = lazy_loader.LazyLoader('reverb', globals(), 'reverb')

# Defaults for the wait functions: sleep 1 second between attempts and retry
# up to 24 * 60 * 60 times (roughly one day in total). Short sleeps favour the
# common case where the awaited event happens quickly; the large retry budget
# tolerates the occasional event that takes much longer.
_WAIT_DEFAULT_SLEEP_TIME_SECS = 1
_WAIT_DEFAULT_NUM_RETRIES = 24 * 60 * 60  # 86400 retries ~= 1 day at 1s each.


def create_train_step() -> tf.Variable:
  """Creates a new scalar training-step counter variable.

  The counter is an int64 scalar (shape `()`) that starts at 0 and is marked
  non-trainable. It is created with
  `tf.VariableAggregation.ONLY_FIRST_REPLICA` aggregation (see TensorFlow's
  `VariableAggregation` docs for how this behaves under a distribution
  strategy).

  Returns:
    A freshly created `tf.Variable` holding the value 0.
  """
  step_counter = tf.Variable(
      initial_value=0,
      dtype=tf.int64,
      shape=(),
      trainable=False,
      aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
  return step_counter
コード例 #2
0
"""Utilities for working with a reverb replay buffer."""

from __future__ import absolute_import
from __future__ import division
# Using Type Annotations.
from __future__ import print_function

from typing import Text, Union

from absl import logging

from tf_agents.typing import types
from tf_agents.utils import lazy_loader

# The reverb package is optional for users, so it is not imported eagerly;
# `lazy_loader.LazyLoader` stands in for the module under the name `reverb`.
reverb = lazy_loader.LazyLoader(
    "reverb",
    globals(),
    "reverb",
)


class ReverbAddEpisodeObserver(object):
    """Observer for writing episodes to the Reverb replay buffer."""
    def __init__(self,
                 py_client: types.ReverbClient,
                 table_name: Text,
                 max_sequence_length: int,
                 priority: Union[float, int] = 1,
                 bypass_partial_episodes: bool = False):
        """Creates an instance of the ReverbAddEpisodeObserver.

    **Note**: This observer is designed to work with py_drivers only, and does
    not support batches.