Exemplo n.º 1
0
 def test_save_and_new_restore(self):
   """Checks that a newly created Checkpointer restores an existing checkpoint.

   A first checkpointer saves `x` while it is 0. `x` is then incremented and a
   second checkpointer, built over the same directory, restores it. The test
   expects `x` to read 0 again.
   """
   # Pin the value returned by `paths.get_unique_id` for the whole test
   # (presumably so both checkpointers resolve the same location -- confirm).
   with mock.patch.object(paths, 'get_unique_id', return_value=('test',)):
     variable = tf.Variable(0, dtype=tf.int32)
     ckpt_dir = self.get_tempdir()

     first = tf2_savers.Checkpointer(
         objects_to_save={'x': variable},
         time_delta_minutes=0.,
         directory=ckpt_dir)
     first.save()
     variable.assign_add(1)

     # Simulate a preemption: x is changed, and we make a new Checkpointer.
     second = tf2_savers.Checkpointer(
         objects_to_save={'x': variable},
         time_delta_minutes=0.,
         directory=ckpt_dir)
     second.restore()

     # The restore must have rolled the variable back to the saved value.
     np.testing.assert_array_equal(variable.numpy(), np.int32(0))
Exemplo n.º 2
0
  def test_no_checkpoint(self):
    """Checks that `save` does nothing when `enable_checkpointing=False`."""
    step_var = tf.Variable(0)
    disabled = tf2_savers.Checkpointer(
        enable_checkpointing=False,
        objects_to_save={'num_steps': step_var})

    # Each of the ten save attempts must report that nothing was saved.
    attempts = 0
    while attempts < 10:
      self.assertFalse(disabled.save())
      attempts += 1

    # No checkpoint manager is expected to exist on a disabled checkpointer.
    self.assertIsNone(disabled._checkpoint_manager)
Exemplo n.º 3
0
    def __init__(
        self,
        policy_network: snt.Module,
        critic_network: snt.Module,
        discount: float,
        dataset: tf.data.Dataset,
        critic_lr: float = 1e-4,
        checkpoint_interval_minutes: float = 10.0,
        clipping: bool = True,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        checkpoint: bool = True,
        init_observations: Any = None,
    ):
        """Initializes the learner.

        Args:
          policy_network: policy network; stored on the learner.
          critic_network: critic network; its variables are exposed under the
            'critic' key and a squeezed-output wrapper around it is
            snapshotted.
          discount: discount value; stored on the learner.
          dataset: dataset to learn from; an iterator over it is kept.
          critic_lr: learning rate given to the Adam critic optimizer.
          checkpoint_interval_minutes: passed to the checkpointer as
            `time_delta_minutes`; may be fractional.
          clipping: stored on the learner (not otherwise used here).
          counter: counter object; a new `counting.Counter` is created when
            none is given.
          logger: logger object; a `TerminalLogger` labelled 'learner' is
            created when none is given.
          checkpoint: whether to create the checkpointer and snapshotter. When
            False both attributes are left as None.
          init_observations: stored on the learner (not otherwise used here).
        """

        self._policy_network = policy_network
        self._critic_network = critic_network
        self._discount = discount
        self._clipping = clipping
        self._init_observations = init_observations

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        self._num_steps = tf.Variable(0, dtype=tf.int32)

        # Create an iterator over the dataset.
        self._iterator = iter(dataset)

        self._critic_optimizer = snt.optimizers.Adam(critic_lr)

        # Expose the variables.
        self._variables = {
            'critic': self._critic_network.variables,
        }
        # We remove trailing dimensions to keep the same output dimension
        # as existing FQE based on D4PG. i.e.: (batch_size,).
        critic_mean = snt.Sequential(
            [self._critic_network, lambda t: tf.squeeze(t, -1)])
        self._critic_mean = critic_mean

        # Checkpointer and snapshotter only exist when checkpointing is on.
        self._checkpointer = None
        self._snapshotter = None

        if checkpoint:
            # `self.state` is defined outside this method; it is read here,
            # after the attributes above have been set.
            self._checkpointer = tf2_savers.Checkpointer(
                objects_to_save=self.state,
                time_delta_minutes=checkpoint_interval_minutes,
                checkpoint_ttl_seconds=_CHECKPOINT_TTL)
            self._snapshotter = tf2_savers.Snapshotter(
                objects_to_save={'critic': critic_mean},
                time_delta_minutes=60.)
Exemplo n.º 4
0
  def test_save_and_restore(self):
    """Checks that every save is reported and every restore rolls `x` back.

    The variable is saved while it is 0, incremented, then restored. This is
    repeated ten times and `x` must read 0 after each restore.
    """
    counter_var = tf.Variable(0, dtype=tf.int32)
    saver = tf2_savers.Checkpointer(
        objects_to_save={'x': counter_var},
        time_delta_minutes=0.,
        directory=self.get_tempdir())

    expected = np.int32(0)
    for _ in range(10):
      # With a zero time delta each call is expected to actually save.
      self.assertTrue(saver.save())
      counter_var.assign_add(1)
      saver.restore()
      np.testing.assert_array_equal(counter_var.numpy(), expected)
Exemplo n.º 5
0
  def test_save_and_restore_time_based(self):
    """Checks that saving is gated by the configured one-minute time delta.

    `time.time` is mocked to 0, 40 and 70 seconds. Only the last `save` call is
    expected to report a save. Afterwards the variable is incremented and a
    restore must bring it back to 0.
    """
    tracked = tf.Variable(0, dtype=tf.int32)
    saver = tf2_savers.Checkpointer(
        objects_to_save={'x': tracked},
        time_delta_minutes=1.,
        directory=self.get_tempdir())

    # (mocked clock reading in seconds, whether `save` should report a save).
    schedule = ((0., False), (40., False), (70., True))
    with mock.patch.object(time, 'time') as fake_clock:
      for now, expect_saved in schedule:
        fake_clock.return_value = now
        outcome = saver.save()
        if expect_saved:
          self.assertTrue(outcome)
        else:
          self.assertFalse(outcome)

    tracked.assign_add(1)
    saver.restore()
    np.testing.assert_array_equal(tracked.numpy(), np.int32(0))
Exemplo n.º 6
0
  def __init__(self,
               network: snt.Module,
               learning_rate: float,
               dataset: tf.data.Dataset,
               counter: counting.Counter = None,
               logger: loggers.Logger = None,
               checkpoint_subpath: str = '~/acme/'
               ):
    """Initializes the learner.

    Args:
      network: the network being optimized; its trainable variables are
        exposed and the network itself is snapshotted.
      learning_rate: learning rate given to the Adam optimizer.
      dataset: dataset to learn from; an iterator over it is kept.
      counter: Counter object for (potentially distributed) counting. A new
        `counting.Counter` is created when none is given.
      logger: Logger object for writing logs to. A `TerminalLogger` labelled
        'learner' is created when none is given.
      checkpoint_subpath: passed to the checkpointer as its `directory`.
    """

    # The network under training, its optimizer and the exposed variables.
    self._network = network
    self._optimizer = snt.optimizers.Adam(learning_rate)
    self._variables: List[List[tf.Tensor]] = [network.trainable_variables]

    # Get an iterator over the dataset.
    # TODO(b/155086959): Fix type stubs and remove.
    self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

    # Book-keeping: fall back to default counter/logger when not supplied.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

    # Create the checkpointer. `self.state` is defined outside this method and
    # is read only after the attributes above have been set.
    checkpointer_kwargs = dict(
        objects_to_save=self.state,
        time_delta_minutes=10.,
        directory=checkpoint_subpath,
        subdirectory='bc_learner',
    )
    self._checkpointer = tf2_savers.Checkpointer(**checkpointer_kwargs)

    # Create the snapshotter for the network.
    self._snapshotter = tf2_savers.Snapshotter(
        time_delta_minutes=60.,
        objects_to_save={'network': network})
Exemplo n.º 7
0
    def __init__(self,
                 environment_spec: specs.EnvironmentSpec,
                 network: snt.RNNCore,
                 target_network: snt.RNNCore,
                 burn_in_length: int,
                 trace_length: int,
                 replay_period: int,
                 demonstration_dataset: tf.data.Dataset,
                 demonstration_ratio: float,
                 counter: counting.Counter = None,
                 logger: loggers.Logger = None,
                 discount: float = 0.99,
                 batch_size: int = 32,
                 target_update_period: int = 100,
                 importance_sampling_exponent: float = 0.2,
                 epsilon: float = 0.01,
                 learning_rate: float = 1e-3,
                 log_to_bigtable: bool = False,
                 log_name: str = 'agent',
                 checkpoint: bool = True,
                 min_replay_size: int = 1000,
                 max_replay_size: int = 1000000,
                 samples_per_insert: float = 32.0):
        """Builds the replay server, dataset, learner and actor for the agent.

        Replay samples are mixed with sequences derived from a demonstration
        dataset before being batched and handed to an `R2D2Learner`.

        Args:
          environment_spec: environment spec; used for the replay table
            signature, the learner, and to create network variables from its
            observation spec.
          network: online recurrent network; its initial state defines the
            'core_state' extras, and it is wrapped with epsilon-greedy action
            selection for acting.
          target_network: target recurrent network passed to the learner.
          burn_in_length: burn-in length; passed to the learner and part of the
            sequence length.
          trace_length: trace length; part of the sequence length
            (`burn_in_length + trace_length + 1`).
          replay_period: period given to the sequence adder.
          demonstration_dataset: dataset of demonstrations; mapped through
            `_sequence_from_episode` before being mixed with replay data.
          demonstration_ratio: sampling weight of the demonstration dataset;
            the replay dataset gets `1 - demonstration_ratio`.
          counter: forwarded to the learner.
          logger: forwarded to the learner.
          discount: forwarded to the learner.
          batch_size: batch size of the mixed dataset (remainders dropped).
          target_update_period: forwarded to the learner.
          importance_sampling_exponent: forwarded to the learner.
          epsilon: epsilon used for epsilon-greedy action selection.
          learning_rate: forwarded to the learner.
          log_to_bigtable: not referenced in this constructor.
          log_name: not referenced in this constructor.
          checkpoint: passed to the checkpointer as `enable_checkpointing`.
          min_replay_size: with `batch_size` and `replay_period`, determines
            `min_observations` passed to the base class.
          max_replay_size: maximum size of the replay table; also forwarded to
            the learner.
          samples_per_insert: divisor used to compute `observations_per_step`.
        """

        # Extras stored alongside each step: the recurrent core state.
        extra_spec = {
            'core_state': network.initial_state(1),
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
        # Replay table with uniform sampling and first-in-first-out removal.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))
        # The server is kept on the instance; its port is read back below to
        # build the local address used by the clients.
        self._server = reverb.Server([replay_table], port=None)
        address = f'localhost:{self._server.port}'

        sequence_length = burn_in_length + trace_length + 1
        # Component to add things into replay.
        # The same kwargs are reused below when converting demonstrations.
        sequence_kwargs = dict(
            period=replay_period,
            sequence_length=sequence_length,
        )
        adder = adders.SequenceAdder(client=reverb.Client(address),
                                     **sequence_kwargs)

        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(server_address=address,
                                               sequence_length=sequence_length)

        # Combine with demonstration dataset.
        transition = functools.partial(_sequence_from_episode,
                                       extra_spec=extra_spec,
                                       **sequence_kwargs)
        dataset_demos = demonstration_dataset.map(transition)
        # Sample from replay and demonstrations with the given weights.
        dataset = tf.data.experimental.sample_from_datasets(
            [dataset, dataset_demos],
            [1 - demonstration_ratio, demonstration_ratio])

        # Batch and prefetch.
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        # Create the variables of both networks from the observation spec
        # before they are handed to the learner.
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        learner = learning.R2D2Learner(
            environment_spec=environment_spec,
            network=network,
            target_network=target_network,
            burn_in_length=burn_in_length,
            dataset=dataset,
            reverb_client=reverb.TFClient(address),
            counter=counter,
            logger=logger,
            sequence_length=sequence_length,
            discount=discount,
            target_update_period=target_update_period,
            importance_sampling_exponent=importance_sampling_exponent,
            max_replay_size=max_replay_size,
            learning_rate=learning_rate,
            store_lstm_state=False,
        )

        # Checkpoint the learner state; enabled or disabled via `checkpoint`.
        self._checkpointer = tf2_savers.Checkpointer(
            subdirectory='r2d2_learner',
            time_delta_minutes=60,
            objects_to_save=learner.state,
            enable_checkpointing=checkpoint,
        )

        # The snapshotter is created regardless of the `checkpoint` flag.
        self._snapshotter = tf2_savers.Snapshotter(
            objects_to_save={'network': network}, time_delta_minutes=60.)

        # Acting policy: the network followed by epsilon-greedy sampling.
        policy_network = snt.DeepRNN([
            network,
            lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
        ])

        actor = actors.RecurrentActor(policy_network, adder)
        observations_per_step = (float(replay_period * batch_size) /
                                 samples_per_insert)
        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=replay_period *
                         max(batch_size, min_replay_size),
                         observations_per_step=observations_per_step)
Exemplo n.º 8
0
  def __init__(
      self,
      policy_network: snt.Module,
      critic_network: snt.Module,
      target_policy_network: snt.Module,
      target_critic_network: snt.Module,
      discount: float,
      target_update_period: int,
      dataset_iterator: Iterator[reverb.ReplaySample],
      observation_network: types.TensorTransformation = lambda x: x,
      target_observation_network: types.TensorTransformation = lambda x: x,
      policy_optimizer: snt.Optimizer = None,
      critic_optimizer: snt.Optimizer = None,
      clipping: bool = True,
      counter: counting.Counter = None,
      logger: loggers.Logger = None,
      checkpoint: bool = True,
  ):
    """Initializes the learner.

    Args:
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      target_policy_network: the target policy (which lags behind the online
        policy).
      target_critic_network: the target critic.
      discount: discount to use for TD updates.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      dataset_iterator: iterator over the data to learn from, whether fixed or
        from a replay buffer.
      observation_network: an optional online network to process observations
        before the policy and the critic; defaults to the identity.
      target_observation_network: the target observation network; defaults to
        the identity.
      policy_optimizer: the optimizer to be applied to the DPG (policy) loss;
        Adam with learning rate 1e-4 when omitted.
      critic_optimizer: the optimizer to be applied to the distributional
        Bellman loss; Adam with learning rate 1e-4 when omitted.
      clipping: whether to clip gradients by global norm.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      checkpoint: boolean indicating whether to checkpoint the learner. When
        False, neither a checkpointer nor a snapshotter is created.
    """

    # Online networks and their lagging target copies.
    self._policy_network = policy_network
    self._target_policy_network = target_policy_network
    self._critic_network = critic_network
    self._target_critic_network = target_critic_network

    # Wrap the observation transformations as snt.Module's so that they expose
    # variables like the other networks.
    online_obs = tf2_utils.to_sonnet_module(observation_network)
    target_obs = tf2_utils.to_sonnet_module(target_observation_network)
    self._observation_network = online_obs
    self._target_observation_network = target_obs

    # Book-keeping: fall back to default counter/logger when not supplied.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger('learner')

    # Scalar learner parameters.
    self._clipping = clipping
    self._discount = discount

    # Step count, needed to know when the target networks are due an update.
    self._target_update_period = target_update_period
    self._num_steps = tf.Variable(0, dtype=tf.int32)

    # Source of training data.
    self._iterator = dataset_iterator

    # Fall back to Adam optimizers when none are supplied.
    self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
    self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)

    # Exposed variables: the target critic, and the target observation network
    # chained with the target policy.
    exposed_policy = snt.Sequential([target_obs, target_policy_network])
    self._variables = {
        'critic': target_critic_network.variables,
        'policy': exposed_policy.variables,
    }

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online and
    # fill the replay buffer.
    self._timestamp = None

    # Checkpointer and snapshotter stay unset unless checkpointing is enabled.
    self._checkpointer = None
    self._snapshotter = None
    if not checkpoint:
      return

    checkpointed = {
        'counter': self._counter,
        'policy': self._policy_network,
        'critic': self._critic_network,
        'observation': self._observation_network,
        'target_policy': self._target_policy_network,
        'target_critic': self._target_critic_network,
        'target_observation': self._target_observation_network,
        'policy_optimizer': self._policy_optimizer,
        'critic_optimizer': self._critic_optimizer,
        'num_steps': self._num_steps,
    }
    self._checkpointer = tf2_savers.Checkpointer(
        subdirectory='d4pg_learner', objects_to_save=checkpointed)

    # The snapshot holds the policy and the critic followed by a mean head.
    critic_mean = snt.Sequential(
        [self._critic_network, acme_nets.StochasticMeanHead()])
    snapshotted = {
        'policy': self._policy_network,
        'critic': critic_mean,
    }
    self._snapshotter = tf2_savers.Snapshotter(objects_to_save=snapshotted)
Exemplo n.º 9
0
Arquivo: agent.py Projeto: wilixx/acme
  def __init__(
      self,
      environment_spec: specs.EnvironmentSpec,
      network: snt.RNNCore,
      burn_in_length: int,
      trace_length: int,
      replay_period: int,
      counter: counting.Counter = None,
      logger: loggers.Logger = None,
      discount: float = 0.99,
      batch_size: int = 32,
      prefetch_size: int = tf.data.experimental.AUTOTUNE,
      target_update_period: int = 100,
      importance_sampling_exponent: float = 0.2,
      priority_exponent: float = 0.6,
      epsilon: float = 0.01,
      learning_rate: float = 1e-3,
      min_replay_size: int = 1000,
      max_replay_size: int = 1000000,
      samples_per_insert: float = 32.0,
      store_lstm_state: bool = True,
      max_priority_weight: float = 0.9,
      checkpoint: bool = True,
  ):
    """Builds the replay server, dataset, learner and actor for the agent.

    Args:
      environment_spec: environment spec; used for the replay table signature,
        the dataset, the learner, and to create network variables from its
        observation spec.
      network: online recurrent network; its initial state defines the
        'core_state' extras, a deep copy of it becomes the target network, and
        it is wrapped with epsilon-greedy action selection for acting.
      burn_in_length: burn-in length; passed to the learner and part of the
        sequence length.
      trace_length: trace length; part of the sequence length
        (`burn_in_length + trace_length + 1`).
      replay_period: period given to the sequence adder.
      counter: forwarded to the learner.
      logger: forwarded to the learner.
      discount: forwarded to the learner.
      batch_size: batch size given to the replay dataset.
      prefetch_size: prefetch size given to the replay dataset.
      target_update_period: forwarded to the learner.
      importance_sampling_exponent: forwarded to the learner.
      priority_exponent: exponent given to the prioritized replay sampler.
      epsilon: epsilon used for epsilon-greedy action selection.
      learning_rate: forwarded to the learner.
      min_replay_size: with `batch_size` and `replay_period`, determines
        `min_observations` passed to the base class.
      max_replay_size: maximum size of the replay table; also forwarded to the
        learner.
      samples_per_insert: divisor used to compute `observations_per_step`.
      store_lstm_state: forwarded to the learner.
      max_priority_weight: forwarded to the learner.
      checkpoint: passed to the checkpointer as `enable_checkpointing`.
    """

    # Extras stored alongside each step: the recurrent core state.
    extra_spec = {
        'core_state': network.initial_state(1),
    }
    # Remove batch dimensions.
    extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
    # Replay table with prioritized sampling and first-in-first-out removal.
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
        signature=adders.SequenceAdder.signature(environment_spec,
                                                   extra_spec))
    # The server is kept on the instance; its port is read back below to build
    # the local address used by the clients.
    self._server = reverb.Server([replay_table], port=None)
    address = f'localhost:{self._server.port}'

    sequence_length = burn_in_length + trace_length + 1
    # Component to add things into replay.
    adder = adders.SequenceAdder(
        client=reverb.Client(address),
        period=replay_period,
        sequence_length=sequence_length,
    )

    # The dataset object to learn from.
    # The same TF client is shared between the dataset and the learner.
    reverb_client = reverb.TFClient(address)
    dataset = datasets.make_reverb_dataset(
        client=reverb_client,
        environment_spec=environment_spec,
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        extra_spec=extra_spec,
        sequence_length=sequence_length)

    # The target network is a deep copy taken before variables are created for
    # either network.
    target_network = copy.deepcopy(network)
    tf2_utils.create_variables(network, [environment_spec.observations])
    tf2_utils.create_variables(target_network, [environment_spec.observations])

    learner = learning.R2D2Learner(
        environment_spec=environment_spec,
        network=network,
        target_network=target_network,
        burn_in_length=burn_in_length,
        sequence_length=sequence_length,
        dataset=dataset,
        reverb_client=reverb_client,
        counter=counter,
        logger=logger,
        discount=discount,
        target_update_period=target_update_period,
        importance_sampling_exponent=importance_sampling_exponent,
        max_replay_size=max_replay_size,
        learning_rate=learning_rate,
        store_lstm_state=store_lstm_state,
        max_priority_weight=max_priority_weight,
    )

    # Checkpoint the learner state; enabled or disabled via `checkpoint`.
    self._checkpointer = tf2_savers.Checkpointer(
        subdirectory='r2d2_learner',
        time_delta_minutes=60,
        objects_to_save=learner.state,
        enable_checkpointing=checkpoint,
    )
    # The snapshotter is created regardless of the `checkpoint` flag.
    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={'network': network}, time_delta_minutes=60.)

    # Acting policy: the network followed by epsilon-greedy sampling.
    policy_network = snt.DeepRNN([
        network,
        lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
    ])

    actor = actors.RecurrentActor(policy_network, adder)
    observations_per_step = (
        float(replay_period * batch_size) / samples_per_insert)
    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=replay_period * max(batch_size, min_replay_size),
        observations_per_step=observations_per_step)
Exemplo n.º 10
0
  def __init__(self,
               policy_network: snt.Module,
               critic_network: snt.Module,
               target_critic_network: snt.Module,
               discount: float,
               target_update_period: int,
               dataset: tf.data.Dataset,
               critic_optimizer: Optional[snt.Optimizer] = None,
               critic_lr: float = 1e-4,
               checkpoint_interval_minutes: float = 10.0,
               clipping: bool = True,
               counter: Optional[counting.Counter] = None,
               logger: Optional[loggers.Logger] = None,
               checkpoint: bool = True,
               init_observations: Any = None,
               distributional: bool = True,
               vmin: Optional[float] = None,
               vmax: Optional[float] = None,
               ):
    """Initializes the learner.

    Args:
      policy_network: policy network; stored on the learner.
      critic_network: the online critic; a mean-producing wrapper around it is
        kept and snapshotted.
      target_critic_network: the target critic; its variables are exposed
        under the 'critic' key.
      discount: discount value; stored on the learner.
      target_update_period: stored on the learner alongside the step count
        used to track target-network updates.
      dataset: dataset to learn from; an iterator over it is kept.
      critic_optimizer: optimizer for the critic; Adam with learning rate
        `critic_lr` when omitted.
      critic_lr: learning rate of the default Adam optimizer; unused when
        `critic_optimizer` is given.
      checkpoint_interval_minutes: passed to the checkpointer as
        `time_delta_minutes`; may be fractional.
      clipping: stored on the learner (not otherwise used here).
      counter: counter object; a new `counting.Counter` is created when none
        is given.
      logger: logger object; a `TerminalLogger` labelled 'learner' is created
        when none is given.
      checkpoint: whether to create the checkpointer and snapshotter. When
        False both attributes are left as None.
      init_observations: stored on the learner (not otherwise used here).
      distributional: selects the critic wrapper: a `StochasticMeanHead` when
        True, otherwise a squeeze of the last dimension.
      vmin: stored on the learner; triggers a warning when given together with
        `distributional=True`.
      vmax: stored on the learner; triggers a warning when given together with
        `distributional=True`.
    """
    self._policy_network = policy_network
    self._critic_network = critic_network
    self._target_critic_network = target_critic_network
    self._discount = discount
    self._clipping = clipping
    self._init_observations = init_observations
    self._distributional = distributional
    self._vmin = vmin
    self._vmax = vmax

    if (self._distributional and
        (self._vmin is not None or self._vmax is not None)):
      logging.warning('vmin and vmax arguments to FQELearner are ignored when '
                      'distributional=True. They should be provided when '
                      'creating the critic network.')

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

    # Necessary to track when to update target networks.
    self._num_steps = tf.Variable(0, dtype=tf.int32)
    self._target_update_period = target_update_period

    # Create an iterator over the dataset.
    self._iterator = iter(dataset)

    # Create optimizers if they aren't given.
    self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(critic_lr)

    # Expose the variables.
    self._variables = {
        'critic': self._target_critic_network.variables,
    }
    if distributional:
      critic_mean = snt.Sequential(
          [self._critic_network, acme_nets.StochasticMeanHead()])
    else:
      # We remove trailing dimensions to keep the same output dimension
      # as existing FQE based on D4PG. i.e.: (batch_size,).
      critic_mean = snt.Sequential(
          [self._critic_network, lambda t: tf.squeeze(t, -1)])
    self._critic_mean = critic_mean

    # Checkpointer and snapshotter only exist when checkpointing is on.
    self._checkpointer = None
    self._snapshotter = None

    if checkpoint:
      # `self.state` is defined outside this method; it is read here, after
      # the attributes above have been set.
      self._checkpointer = tf2_savers.Checkpointer(
          objects_to_save=self.state,
          time_delta_minutes=checkpoint_interval_minutes,
          checkpoint_ttl_seconds=_CHECKPOINT_TTL)
      self._snapshotter = tf2_savers.Snapshotter(
          objects_to_save={
              'critic': critic_mean,
          },
          time_delta_minutes=60.)
Exemplo n.º 11
0
    def __init__(
        self,
        policy_network: snt.Module,
        critic_network: snt.Module,
        target_policy_network: snt.Module,
        target_critic_network: snt.Module,
        discount: float,
        target_update_period: int,
        dataset_iterator: Iterator[reverb.ReplaySample],
        prior_network: Optional[snt.Module] = None,
        target_prior_network: Optional[snt.Module] = None,
        policy_optimizer: Optional[snt.Optimizer] = None,
        critic_optimizer: Optional[snt.Optimizer] = None,
        prior_optimizer: Optional[snt.Optimizer] = None,
        distillation_cost: Optional[float] = 1e-3,
        entropy_regularizer_cost: Optional[float] = 1e-3,
        num_action_samples: int = 10,
        lambda_: float = 1.0,
        counter: Optional[counting.Counter] = None,
        logger: Optional[loggers.Logger] = None,
        checkpoint: bool = True,
    ):
        """Initializes the learner.

        Args:
          policy_network: the online (optimized) policy.
          critic_network: the online critic.
          target_policy_network: the target policy (which lags behind the
            online policy).
          target_critic_network: the target critic.
          discount: discount to use for TD updates.
          target_update_period: number of learner steps to perform before
            updating the target networks.
          dataset_iterator: iterator over the data to learn from, whether fixed
            or from a replay buffer.
          prior_network: the online (optimized) prior. When given, it is also
            exposed, checkpointed and snapshotted.
          target_prior_network: the target prior (which lags behind the online
            prior).
          policy_optimizer: the optimizer to be applied to the SVG-0 (policy)
            loss; Adam with learning rate 1e-4 when omitted.
          critic_optimizer: the optimizer to be applied to the distributional
            Bellman loss; Adam with learning rate 1e-4 when omitted.
          prior_optimizer: the optimizer to be applied to the prior
            (distillation) loss; Adam with learning rate 1e-4 when omitted.
          distillation_cost: a multiplier to be used when adding distillation
            against the prior to the losses.
          entropy_regularizer_cost: a multiplier used for per state sample
            based entropy added to the actor loss.
          num_action_samples: the number of action samples to use for
            estimating the value function and sample based entropy.
          lambda_: the `lambda` value to be used with retrace.
          counter: counter object used to keep track of steps.
          logger: logger object to be used by learner.
          checkpoint: boolean indicating whether to checkpoint the learner.
        """

        has_prior = prior_network is not None

        # Online networks and their lagging target copies.
        self._policy_network = policy_network
        self._target_policy_network = target_policy_network
        self._critic_network = critic_network
        self._target_critic_network = target_critic_network
        self._prior_network = prior_network
        self._target_prior_network = target_prior_network

        # Loss-related parameters.
        self._discount = discount
        self._lambda = lambda_
        self._num_action_samples = num_action_samples
        self._distillation_cost = distillation_cost
        self._entropy_regularizer_cost = entropy_regularizer_cost

        # Book-keeping: fall back to default counter/logger when not supplied.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger('learner')

        # Step count, needed to know when the target networks are due an update.
        self._target_update_period = target_update_period
        self._num_steps = tf.Variable(0, dtype=tf.int32)

        # Source of training data.
        self._iterator = dataset_iterator

        # Fall back to Adam optimizers when none are supplied.
        self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
        self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
        self._prior_optimizer = prior_optimizer or snt.optimizers.Adam(1e-4)

        # Exposed variables: the online critic and policy, plus the prior when
        # one was provided.
        exposed = {
            'critic': critic_network.variables,
            'policy': policy_network.variables,
        }
        if has_prior:
            exposed['prior'] = prior_network.variables
        self._variables = exposed

        # Checkpointer and snapshotter stay unset unless checkpointing is on.
        self._checkpointer = None
        self._snapshotter = None

        if checkpoint:
            to_checkpoint = {
                'counter': self._counter,
                'policy': self._policy_network,
                'critic': self._critic_network,
                'target_policy': self._target_policy_network,
                'target_critic': self._target_critic_network,
                'policy_optimizer': self._policy_optimizer,
                'critic_optimizer': self._critic_optimizer,
                'num_steps': self._num_steps,
            }
            to_snapshot = {
                'policy': self._policy_network,
                'critic': self._critic_network,
            }
            if has_prior:
                # Prior-related objects are only saved when a prior exists.
                to_checkpoint.update(
                    prior=self._prior_network,
                    target_prior=self._target_prior_network,
                    prior_optimizer=self._prior_optimizer,
                )
                to_snapshot['prior'] = self._prior_network

            self._checkpointer = tf2_savers.Checkpointer(
                subdirectory='svg0_learner', objects_to_save=to_checkpoint)
            self._snapshotter = tf2_savers.Snapshotter(
                objects_to_save=to_snapshot)

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come
        # online and fill the replay buffer.
        self._timestamp = None
Exemplo n.º 12
0
    def __init__(
        self,
        reward_objectives: Sequence[RewardObjective],
        qvalue_objectives: Sequence[QValueObjective],
        policy_network: snt.Module,
        critic_network: snt.Module,
        target_policy_network: snt.Module,
        target_critic_network: snt.Module,
        discount: float,
        num_samples: int,
        target_policy_update_period: int,
        target_critic_update_period: int,
        dataset: tf.data.Dataset,
        observation_network: types.TensorTransformation = tf.identity,
        target_observation_network: types.TensorTransformation = tf.identity,
        policy_loss_module: Optional[losses.MultiObjectiveMPO] = None,
        policy_optimizer: Optional[snt.Optimizer] = None,
        critic_optimizer: Optional[snt.Optimizer] = None,
        dual_optimizer: Optional[snt.Optimizer] = None,
        clipping: bool = True,
        counter: Optional[counting.Counter] = None,
        logger: Optional[loggers.Logger] = None,
        checkpoint: bool = True,
    ):
        """Initializes the learner.

        Args:
          reward_objectives: reward objectives; their count is stored as the
            number of critic heads and their names come first in the list of
            objective names.
          qvalue_objectives: Q-value objectives; None is treated as an empty
            list. Their names follow the reward objectives' names.
          policy_network: the online policy.
          critic_network: the online critic.
          target_policy_network: the target policy.
          target_critic_network: the target critic.
          discount: discount value; stored on the learner.
          num_samples: stored on the learner (not otherwise used here).
          target_policy_update_period: stored on the learner.
          target_critic_update_period: stored on the learner.
          dataset: dataset to learn from; an iterator over it is kept.
          observation_network: online observation transformation; converted
            with `tf2_utils.to_sonnet_module`.
          target_observation_network: target observation transformation;
            converted with `tf2_utils.to_sonnet_module`.
          policy_loss_module: policy loss; when omitted a `MultiObjectiveMPO`
            is built with one `KLConstraint` per objective name and the
            module-level default constants.
          policy_optimizer: policy optimizer; Adam(1e-4) when omitted.
          critic_optimizer: critic optimizer; Adam(1e-4) when omitted.
          dual_optimizer: dual optimizer; Adam(1e-2) when omitted.
          clipping: stored on the learner (not otherwise used here).
          counter: counter object; a new `counting.Counter` when omitted.
          logger: logger object; a default 'learner' logger when omitted.
          checkpoint: whether to create the checkpointer and snapshotter.

        Raises:
          ValueError: if the objective names, in order, differ from the policy
            loss module's `objective_names`.
        """

        # Store online and target networks.
        self._policy_network = policy_network
        self._critic_network = critic_network
        self._target_policy_network = target_policy_network
        self._target_critic_network = target_critic_network

        # Make sure observation networks are snt.Module's so they have variables.
        self._observation_network = tf2_utils.to_sonnet_module(
            observation_network)
        self._target_observation_network = tf2_utils.to_sonnet_module(
            target_observation_network)

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger('learner')

        # Other learner parameters.
        self._discount = discount
        self._num_samples = num_samples
        self._clipping = clipping

        # Necessary to track when to update target networks.
        self._num_steps = tf.Variable(0, dtype=tf.int32)
        self._target_policy_update_period = target_policy_update_period
        self._target_critic_update_period = target_critic_update_period

        # Create an iterator over the dataset.
        # TODO(b/155086959): Fix type stubs and remove.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        # Store objectives
        self._reward_objectives = reward_objectives
        self._qvalue_objectives = qvalue_objectives
        # Tolerate callers passing None for the Q-value objectives.
        if self._qvalue_objectives is None:
            self._qvalue_objectives = []
        self._num_critic_heads = len(self._reward_objectives)  # C
        # Reward objective names first, then Q-value objective names.
        self._objective_names = ([x.name for x in self._reward_objectives] +
                                 [x.name for x in self._qvalue_objectives])

        # Default loss: one KL constraint per objective, in the same order as
        # `self._objective_names`.
        self._policy_loss_module = policy_loss_module or losses.MultiObjectiveMPO(
            epsilons=[
                losses.KLConstraint(name, _DEFAULT_EPSILON)
                for name in self._objective_names
            ],
            epsilon_mean=_DEFAULT_EPSILON_MEAN,
            epsilon_stddev=_DEFAULT_EPSILON_STDDEV,
            init_log_temperature=_DEFAULT_INIT_LOG_TEMPERATURE,
            init_log_alpha_mean=_DEFAULT_INIT_LOG_ALPHA_MEAN,
            init_log_alpha_stddev=_DEFAULT_INIT_LOG_ALPHA_STDDEV)

        # Check that ordering of objectives matches the policy_loss_module's
        if self._objective_names != list(
                self._policy_loss_module.objective_names):
            raise ValueError("Agent's ordering of objectives doesn't match "
                             "the policy loss module's ordering of epsilons.")

        # Create the optimizers.
        self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
        self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
        self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2)

        # Expose the variables.
        # The exposed policy chains the target observation network with the
        # target policy network.
        policy_network_to_expose = snt.Sequential(
            [self._target_observation_network, self._target_policy_network])
        self._variables = {
            'critic': self._target_critic_network.variables,
            'policy': policy_network_to_expose.variables,
        }

        # Create a checkpointer and snapshotter object.
        # Both stay None when `checkpoint` is False.
        self._checkpointer = None
        self._snapshotter = None

        if checkpoint:
            self._checkpointer = tf2_savers.Checkpointer(
                subdirectory='dmpo_learner',
                objects_to_save={
                    'counter': self._counter,
                    'policy': self._policy_network,
                    'critic': self._critic_network,
                    'observation': self._observation_network,
                    'target_policy': self._target_policy_network,
                    'target_critic': self._target_critic_network,
                    'target_observation': self._target_observation_network,
                    'policy_optimizer': self._policy_optimizer,
                    'critic_optimizer': self._critic_optimizer,
                    'dual_optimizer': self._dual_optimizer,
                    'policy_loss_module': self._policy_loss_module,
                    'num_steps': self._num_steps,
                })

            # Only the target observation network chained with the target
            # policy is snapshotted.
            self._snapshotter = tf2_savers.Snapshotter(
                objects_to_save={
                    'policy':
                    snt.Sequential([
                        self._target_observation_network,
                        self._target_policy_network
                    ]),
                })

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp: Optional[float] = None
Exemplo n.º 13
0
Arquivo: agent.py Projeto: wilixx/acme
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.Module,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 32.0,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        n_step: int = 5,
        epsilon: tf.Tensor = None,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
        logger: loggers.Logger = None,
        checkpoint: bool = True,
        checkpoint_subpath: str = '~/acme/',
    ):
        """Assembles a DQN agent: replay server, actor, learner, checkpointer.

        Args:
          environment_spec: description of the actions, observations, etc.
          network: the online Q network (the one being optimized).
          batch_size: batch size for updates.
          prefetch_size: size to prefetch from replay.
          target_update_period: number of learner steps to perform before
            updating the target network.
          samples_per_insert: number of samples to take from replay for every
            insert that is made.
          min_replay_size: minimum replay size before updating.
          max_replay_size: maximum replay size.
          importance_sampling_exponent: power to which importance weights are
            raised before normalizing.
          priority_exponent: exponent used in prioritized sampling.
          n_step: number of steps to squash into a single transition.
          epsilon: probability of taking a random action; when None a
            non-trainable variable holding 0.05 is used.
          learning_rate: learning rate for the q-network update.
          discount: discount to use for TD updates.
          logger: logger object to be used by learner.
          checkpoint: boolean indicating whether to checkpoint the learner.
          checkpoint_subpath: directory for the checkpoint.
        """

        # Replay table that data is added to. Its rate limiter only requires a
        # single item (MinSize(1)), leaving pacing to the Agent interface.
        priority_sampler = reverb.selectors.Prioritized(priority_exponent)
        fifo_remover = reverb.selectors.Fifo()
        min_size_limiter = reverb.rate_limiters.MinSize(1)
        transition_signature = adders.NStepTransitionAdder.signature(
            environment_spec)
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=priority_sampler,
            remover=fifo_remover,
            max_size=max_replay_size,
            rate_limiter=min_size_limiter,
            signature=transition_signature)
        self._server = reverb.Server([replay_table], port=None)
        address = 'localhost:{}'.format(self._server.port)

        # Adder: inserts observations into replay.
        insert_client = reverb.Client(address)
        adder = adders.NStepTransitionAdder(
            client=insert_client, n_step=n_step, discount=discount)

        # Dataset: samples from replay.
        replay_client = reverb.TFClient(address)
        dataset = datasets.make_reverb_dataset(
            client=replay_client,
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            transition_adder=True)

        # Fall back to a constant 0.05 epsilon-greedy policy.
        if epsilon is None:
            epsilon = tf.Variable(0.05, trainable=False)

        def _sample_epsilon_greedy(q_values):
            # Draws an action from the epsilon-greedy distribution over Q.
            return trfl.epsilon_greedy(q_values, epsilon=epsilon).sample()

        policy_network = snt.Sequential([network, _sample_epsilon_greedy])

        # The target network starts as a copy of the online network.
        target_network = copy.deepcopy(network)

        # Make sure variables exist before proceeding (maybe not needed).
        observation_spec = environment_spec.observations
        tf2_utils.create_variables(network, [observation_spec])
        tf2_utils.create_variables(target_network, [observation_spec])

        # Actor: defines how actions are taken.
        actor = actors.FeedForwardActor(policy_network, adder)

        # Learner: updates (and initializes) the parameters.
        learner = learning.DQNLearner(
            network=network,
            target_network=target_network,
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            learning_rate=learning_rate,
            target_update_period=target_update_period,
            dataset=dataset,
            replay_client=replay_client,
            logger=logger,
            checkpoint=checkpoint)

        # Agent-level checkpointer, only built when checkpointing is enabled.
        self._checkpointer = None
        if checkpoint:
            self._checkpointer = tf2_savers.Checkpointer(
                directory=checkpoint_subpath,
                objects_to_save=learner.state,
                subdirectory='dqn_learner',
                time_delta_minutes=60.)

        min_observations = max(batch_size, min_replay_size)
        observations_per_step = float(batch_size) / samples_per_insert
        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=min_observations,
                         observations_per_step=observations_per_step)
Exemplo n.º 14
0
    def __init__(
        self,
        policy_network: snt.Module,
        critic_network: snt.Module,
        target_policy_network: snt.Module,
        target_critic_network: snt.Module,
        discount: float,
        target_update_period: int,
        dataset: tf.data.Dataset,
        observation_network: types.TensorTransformation = lambda x: x,
        target_observation_network: types.TensorTransformation = lambda x: x,
        policy_optimizer: snt.Optimizer = None,
        critic_optimizer: snt.Optimizer = None,
        clipping: bool = True,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        checkpoint: bool = True,
    ):
        """Sets up networks, optimizers, exposed variables and checkpointing.

        Args:
          policy_network: the online (optimized) policy.
          critic_network: the online critic.
          target_policy_network: the target policy (which lags behind the
            online policy).
          target_critic_network: the target critic.
          discount: discount to use for TD updates.
          target_update_period: number of learner steps to perform before
            updating the target networks.
          dataset: dataset to learn from, whether fixed or from a replay
            buffer (see `acme.datasets.reverb.make_dataset` documentation).
          observation_network: an optional online network to process
            observations before the policy and the critic.
          target_observation_network: the target observation network.
          policy_optimizer: the optimizer to be applied to the DPG (policy)
            loss; Adam(1e-4) when not given.
          critic_optimizer: the optimizer to be applied to the critic loss;
            Adam(1e-4) when not given.
          clipping: whether to clip gradients by global norm.
          counter: counter object used to keep track of steps.
          logger: logger object to be used by learner.
          checkpoint: boolean indicating whether to checkpoint the learner.
        """

        # Plain hyperparameters.
        self._discount = discount
        self._clipping = clipping
        self._target_update_period = target_update_period

        # Online networks and their lagging target counterparts.
        self._policy_network = policy_network
        self._critic_network = critic_network
        self._target_policy_network = target_policy_network
        self._target_critic_network = target_critic_network

        # Wrap the observation transformations as snt.Module's so that they
        # expose variables.
        self._observation_network = tf2_utils.to_sonnet_module(
            observation_network)
        self._target_observation_network = tf2_utils.to_sonnet_module(
            target_observation_network)

        # Book-keeping: fall back to a fresh counter / default logger.
        self._counter = counter if counter else counting.Counter()
        self._logger = (logger if logger else
                        loggers.make_default_logger('learner'))

        # Step count, needed to decide when to update the target networks.
        self._num_steps = tf.Variable(0, dtype=tf.int32)

        # Iterator over the training data.
        # TODO(b/155086959): Fix type stubs and remove.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        # Default optimizers when none were supplied.
        self._critic_optimizer = (critic_optimizer if critic_optimizer else
                                  snt.optimizers.Adam(1e-4))
        self._policy_optimizer = (policy_optimizer if policy_optimizer else
                                  snt.optimizers.Adam(1e-4))

        # Variables exposed to the outside: target critic, plus the target
        # observation network chained with the target policy.
        exposed_policy = snt.Sequential(
            [self._target_observation_network, self._target_policy_network])
        self._variables = {
            'critic': self._target_critic_network.variables,
            'policy': exposed_policy.variables,
        }

        # The checkpointer is always constructed; `checkpoint` only toggles
        # its `enable_checkpointing` flag.
        checkpointed_state = {
            'counter': self._counter,
            'policy': self._policy_network,
            'critic': self._critic_network,
            'target_policy': self._target_policy_network,
            'target_critic': self._target_critic_network,
            'policy_optimizer': self._policy_optimizer,
            'critic_optimizer': self._critic_optimizer,
            'num_steps': self._num_steps,
        }
        self._checkpointer = tf2_savers.Checkpointer(
            time_delta_minutes=5,
            objects_to_save=checkpointed_state,
            enable_checkpointing=checkpoint,
        )

        # Timestamps are only recorded once the first learning step is done,
        # so that the time actors need to come online and fill the replay
        # buffer is not included.
        self._timestamp = None
Exemplo n.º 15
0
  def __init__(self,
               policy_network: snt.Module,
               critic_network: snt.Module,
               f_network: snt.Module,
               discount: float,
               dataset: tf.data.Dataset,
               use_tilde_critic: bool,
               tilde_critic_network: Optional[snt.Module] = None,
               tilde_critic_update_period: Optional[int] = None,
               critic_optimizer_class: str = 'OAdam',
               critic_lr: float = 1e-4,
               critic_beta1: float = 0.5,
               critic_beta2: float = 0.9,
               f_optimizer_class: str = 'OAdam',
               f_lr: Optional[float] = None,
               f_lr_multiplier: Optional[float] = 1.0,
               f_beta1: float = 0.5,
               f_beta2: float = 0.9,
               critic_regularizer: float = 0.0,
               f_regularizer: float = 1.0,  # Ignored if use_tilde_critic = True
               critic_ortho_regularizer: float = 0.0,
               f_ortho_regularizer: float = 0.0,
               critic_l2_regularizer: float = 0.0,
               f_l2_regularizer: float = 0.0,
               checkpoint_interval_minutes: float = 10.0,
               clipping: bool = True,
               clipping_action: bool = True,
               bre_check_period: int = 0,  # Bellman residual error check.
               bre_check_num_actions: int = 0,  # Number of sampled actions.
               dev_dataset: Optional[tf.data.Dataset] = None,
               counter: Optional[counting.Counter] = None,
               logger: Optional[loggers.Logger] = None,
               checkpoint: bool = True):
    """Initializes the learner.

    Args:
      policy_network: the policy network; stored on the learner.
      critic_network: the critic network. Its variables are exposed through
        `self._variables`, and it is checkpointed and snapshotted.
      f_network: the `f` network; checkpointed and snapshotted alongside the
        critic.
      discount: discount factor, stored as `self._discount`.
      dataset: dataset to learn from; an iterator over it is created here.
      use_tilde_critic: whether a tilde critic is used. When True,
        `tilde_critic_update_period` must be provided.
      tilde_critic_network: optional tilde critic network; it is only added
        to the checkpoint when it is not None.
      tilde_critic_update_period: update period of the tilde critic
        (presumably counted in learner steps -- confirm against the step
        function).
      critic_optimizer_class: name handed to `_optimizer_class` to pick the
        critic optimizer.
      critic_lr: learning rate of the critic optimizer.
      critic_beta1: `beta1` of the critic optimizer.
      critic_beta2: `beta2` of the critic optimizer.
      f_optimizer_class: name handed to `_optimizer_class` to pick the
        optimizer of the `f` network.
      f_lr: explicit learning rate for the `f` optimizer. Exactly one of
        `f_lr` and `f_lr_multiplier` must be non-None.
      f_lr_multiplier: when `f_lr` is None, the `f` learning rate is
        `f_lr_multiplier * critic_lr`. The default is 1.0, so callers that
        supply `f_lr` must also pass `f_lr_multiplier=None`.
      f_beta1: `beta1` of the `f` optimizer.
      f_beta2: `beta2` of the `f` optimizer.
      critic_regularizer: regularization strength on the critic values.
      f_regularizer: regularization strength on the `f` values; per the
        signature comment it is ignored if `use_tilde_critic` is True.
      critic_ortho_regularizer: orthogonal regularization strength for the
        critic.
      f_ortho_regularizer: orthogonal regularization strength for `f`.
      critic_l2_regularizer: L2 regularization strength for the critic.
      f_l2_regularizer: L2 regularization strength for `f`.
      checkpoint_interval_minutes: passed to the checkpointer as
        `time_delta_minutes`.
      clipping: stored as `self._clipping`.
      clipping_action: stored as `self._clipping_action`.
      bre_check_period: period of the Bellman residual error check.
      bre_check_num_actions: number of sampled actions for that check.
      dev_dataset: development dataset for hyper-parameter selection.
      counter: counter object; a new `counting.Counter` is made if omitted.
      logger: logger object; a `TerminalLogger` is made if omitted.
      checkpoint: whether to create a checkpointer and a snapshotter.

    Raises:
      ValueError: if `use_tilde_critic` is True but
        `tilde_critic_update_period` is None, or if `f_lr` and
        `f_lr_multiplier` are both set or both None.
    """

    self._policy_network = policy_network
    self._critic_network = critic_network
    self._f_network = f_network
    self._discount = discount
    self._clipping = clipping
    self._clipping_action = clipping_action
    self._bre_check_period = bre_check_period
    self._bre_check_num_actions = bre_check_num_actions

    # Development dataset for hyper-parameter selection.
    # NOTE: `_sample_actions` runs here, so it can only rely on the
    # attributes assigned above this line.
    self._dev_dataset = dev_dataset
    self._dev_actions_dataset = self._sample_actions()

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

    # Necessary to track when to update tilde critic networks.
    self._num_steps = tf.Variable(0, dtype=tf.int32)

    self._use_tilde_critic = use_tilde_critic
    self._tilde_critic_network = tilde_critic_network
    self._tilde_critic_update_period = tilde_critic_update_period
    if use_tilde_critic and tilde_critic_update_period is None:
      raise ValueError('tilde_critic_update_period must be provided if '
                       'use_tilde_critic is True.')

    # Batch dataset and create iterator.
    self._iterator = iter(dataset)

    # Build the critic optimizer from its class name and hyper-parameters.
    self._critic_optimizer = _optimizer_class(critic_optimizer_class)(
        critic_lr, beta1=critic_beta1, beta2=critic_beta2)

    # Resolve the `f` learning rate: exactly one of `f_lr` and
    # `f_lr_multiplier` may be given.
    if f_lr is not None:
      if f_lr_multiplier is not None:
        raise ValueError(f'Either f_lr ({f_lr}) or f_lr_multiplier '
                         f'({f_lr_multiplier}) must be None.')
    elif f_lr_multiplier is None:
      # Previously this fell through to `None * critic_lr` and surfaced as
      # an opaque TypeError.
      raise ValueError('Exactly one of f_lr and f_lr_multiplier must be '
                       'provided, but both are None.')
    else:
      f_lr = f_lr_multiplier * critic_lr
    # Prevent unreasonable value in hyper-param search: cap at 1e-2, but
    # never go below the critic learning rate.
    f_lr = max(min(f_lr, 1e-2), critic_lr)

    self._f_optimizer = _optimizer_class(f_optimizer_class)(
        f_lr, beta1=f_beta1, beta2=f_beta2)

    # Regularization on network values.
    self._critic_regularizer = critic_regularizer
    self._f_regularizer = f_regularizer

    # Orthogonal regularization strength.
    self._critic_ortho_regularizer = critic_ortho_regularizer
    self._f_ortho_regularizer = f_ortho_regularizer

    # L2 regularization strength.
    self._critic_l2_regularizer = critic_l2_regularizer
    self._f_l2_regularizer = f_l2_regularizer

    # Expose the variables.
    self._variables = {
        'critic': self._critic_network.variables,
    }

    # Create a checkpointer and a snapshotter object.
    self._checkpointer = None
    self._snapshotter = None

    if checkpoint:
      objects_to_save = {
          'counter': self._counter,
          'critic': self._critic_network,
          'f': self._f_network,
          'tilde_critic': self._tilde_critic_network,
          'critic_optimizer': self._critic_optimizer,
          'f_optimizer': self._f_optimizer,
          'num_steps': self._num_steps,
      }
      # Entries that are None (e.g. a missing tilde critic) are left out of
      # the checkpoint.
      self._checkpointer = tf2_savers.Checkpointer(
          objects_to_save={k: v for k, v in objects_to_save.items()
                           if v is not None},
          time_delta_minutes=checkpoint_interval_minutes,
          checkpoint_ttl_seconds=_CHECKPOINT_TTL)
      self._snapshotter = tf2_savers.Snapshotter(
          objects_to_save={
              'critic': self._critic_network,
              'f': self._f_network,
          },
          time_delta_minutes=60.)
Exemplo n.º 16
0
    def __init__(self,
                 network: snt.Module,
                 discount: float,
                 importance_sampling_exponent: float,
                 learning_rate: float,
                 target_update_period: int,
                 cql_alpha: float,
                 dataset: tf.data.Dataset,
                 huber_loss_parameter: float = 1.,
                 empirical_policy: dict = None,
                 translate_lse: float = 100.,
                 replay_client: reverb.TFClient = None,
                 counter: counting.Counter = None,
                 logger: loggers.Logger = None,
                 checkpoint_subpath: str = '~/acme/'):
        """Sets up the learner's networks, optimizer, logging and savers.

        Args:
          network: the online Q network (the one being optimized). The
            target network is created here as a deep copy of it.
          discount: discount to use for TD updates.
          importance_sampling_exponent: power to which importance weights are
            raised before normalizing.
          learning_rate: learning rate for the q-network update.
          target_update_period: number of learner steps to perform before
            updating the target networks.
          cql_alpha: stored as the float32 constant `self._alpha`
            (presumably the weight of the CQL term -- confirm in the step
            function).
          dataset: dataset to learn from, whether fixed or from a replay
            buffer (see `acme.datasets.reverb.make_dataset` documentation).
          huber_loss_parameter: Quadratic-linear boundary for Huber loss.
          empirical_policy: stored as `self._emp_policy`; its schema is not
            visible from this method.
          translate_lse: stored as the float32 constant `self._tr`.
          replay_client: client to replay to allow for updating priorities.
          counter: Counter object for (potentially distributed) counting.
          logger: Logger object for writing logs to.
          checkpoint_subpath: directory handed to the checkpointer.
        """

        # Agent components: data iterator, networks, optimizer, constants.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types
        self._network = network
        target_network = copy.deepcopy(network)
        self._target_network = target_network
        self._optimizer = snt.optimizers.Adam(learning_rate)
        self._alpha = tf.constant(cql_alpha, dtype=tf.float32)
        self._tr = tf.constant(translate_lse, dtype=tf.float32)
        self._emp_policy = empirical_policy
        self._replay_client = replay_client

        # Hyperparameters.
        self._huber_loss_parameter = huber_loss_parameter
        self._importance_sampling_exponent = importance_sampling_exponent
        self._target_update_period = target_update_period
        self._discount = discount

        # Learner state.
        self._variables: List[List[tf.Tensor]] = [network.trainable_variables]

        # Counting: make sure a `learner_steps` entry is registered up front.
        self._counter = counter if counter else counting.Counter()
        self._counter.increment(learner_steps=0)

        # Logging.
        self._logger = (logger if logger else
                        loggers.TerminalLogger('learner', time_delta=1.))

        # Checkpointer over `self.state` (evaluated here, after the
        # attributes above are in place) and a snapshotter for the online
        # network.
        learner_state = self.state
        self._checkpointer = tf2_savers.Checkpointer(
            objects_to_save=learner_state,
            time_delta_minutes=10.,
            directory=checkpoint_subpath,
            subdirectory='cql_learner')
        self._snapshotter = tf2_savers.Snapshotter(
            objects_to_save={'network': network}, time_delta_minutes=30.)

        # Timestamps are only recorded once the first learning step is done,
        # so that the time actors need to come online and fill the replay
        # buffer is not included.
        self._timestamp = None
Exemplo n.º 17
0
    def __init__(self,
                 policy_network: snt.Module,
                 critic_network: snt.Module,
                 target_policy_network: snt.Module,
                 target_critic_network: snt.Module,
                 discount: float,
                 target_update_period: int,
                 dataset: tf.data.Dataset,
                 observation_network: types.TensorTransformation = lambda x: x,
                 target_observation_network: types.TensorTransformation = (
                     lambda x: x),
                 policy_optimizer: snt.Optimizer = None,
                 critic_optimizer: snt.Optimizer = None,
                 clipping: bool = True,
                 counter: counting.Counter = None,
                 logger: loggers.Logger = None,
                 checkpoint: bool = True,
                 specified_path: str = None):
        """Initializes the learner.

        Args:
          policy_network: the online (optimized) policy.
          critic_network: the online critic.
          target_policy_network: the target policy (which lags behind the
            online policy).
          target_critic_network: the target critic.
          discount: discount to use for TD updates.
          target_update_period: number of learner steps to perform before
            updating the target networks.
          dataset: dataset to learn from, whether fixed or from a replay
            buffer (see `acme.datasets.reverb.make_dataset` documentation).
          observation_network: an optional online network to process
            observations before the policy and the critic.
          target_observation_network: the target observation network.
          policy_optimizer: the optimizer to be applied to the DPG (policy)
            loss.
          critic_optimizer: the optimizer to be applied to the distributional
            Bellman loss.
          clipping: whether to clip gradients by global norm.
          counter: counter object used to keep track of steps.
          logger: logger object to be used by learner.
          checkpoint: boolean indicating whether to checkpoint the learner.
          specified_path: optional path of an existing checkpoint. When given
            and `checkpoint` is True, the freshly created checkpointer
            restores from it. It has no effect when `checkpoint` is False.
        """

        # Store online and target networks.
        self._policy_network = policy_network
        self._critic_network = critic_network
        self._target_policy_network = target_policy_network
        self._target_critic_network = target_critic_network

        # Make sure observation networks are snt.Module's so they have variables.
        self._observation_network = tf2_utils.to_sonnet_module(
            observation_network)
        self._target_observation_network = tf2_utils.to_sonnet_module(
            target_observation_network)

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger('learner')

        # Other learner parameters.
        self._discount = discount
        self._clipping = clipping

        # Necessary to track when to update target networks.
        self._num_steps = tf.Variable(0, dtype=tf.int32)
        self._target_update_period = target_update_period

        # Batch dataset and create iterator.
        # TODO(b/155086959): Fix type stubs and remove.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        # Create optimizers if they aren't given.
        self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
        self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)

        # Expose the variables.
        policy_network_to_expose = snt.Sequential(
            [self._target_observation_network, self._target_policy_network])
        self._variables = {
            'critic': self._target_critic_network.variables,
            'policy': policy_network_to_expose.variables,
        }

        # Create a checkpointer and snapshotter objects.
        self._checkpointer = None
        self._snapshotter = None

        # Always record the requested restore path so the attribute exists
        # regardless of whether checkpointing is enabled.
        self.specified_path = specified_path

        if checkpoint:
            self._checkpointer = tf2_savers.Checkpointer(
                subdirectory='d4pg_learner',
                objects_to_save={
                    'counter': self._counter,
                    'policy': self._policy_network,
                    'critic': self._critic_network,
                    'observation': self._observation_network,
                    'target_policy': self._target_policy_network,
                    'target_critic': self._target_critic_network,
                    'target_observation': self._target_observation_network,
                    'policy_optimizer': self._policy_optimizer,
                    'critic_optimizer': self._critic_optimizer,
                    'num_steps': self._num_steps,
                })

            # Optionally warm-start from an explicitly given checkpoint.
            # NOTE: this reaches into the checkpointer's private
            # `_checkpoint` attribute.
            if self.specified_path is not None:
                self._checkpointer._checkpoint.restore(self.specified_path)
                print('\033[92mspecified_path: ', str(self.specified_path),
                      '\033[0m')

            # Snapshot the policy and the critic followed by a mean head.
            critic_mean = snt.Sequential(
                [self._critic_network,
                 acme_nets.StochasticMeanHead()])
            self._snapshotter = tf2_savers.Snapshotter(objects_to_save={
                'policy': self._policy_network,
                'critic': critic_mean,
            })

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None
Exemplo n.º 18
0
    def __init__(
        self,
        policy_network: snt.Module,
        critic_network: snt.Module,
        target_policy_network: snt.Module,
        target_critic_network: snt.Module,
        discount: float,
        num_samples: int,
        target_policy_update_period: int,
        target_critic_update_period: int,
        dataset: tf.data.Dataset,
        observation_network: types.TensorTransformation = tf.identity,
        target_observation_network: types.TensorTransformation = tf.identity,
        policy_loss_module: Optional[snt.Module] = None,
        policy_optimizer: Optional[snt.Optimizer] = None,
        critic_optimizer: Optional[snt.Optimizer] = None,
        dual_optimizer: Optional[snt.Optimizer] = None,
        clipping: bool = True,
        counter: Optional[counting.Counter] = None,
        logger: Optional[loggers.Logger] = None,
        checkpoint: bool = True,
    ):
        """Sets up networks, loss module, optimizers and checkpointing.

        Args:
          policy_network: the online policy.
          critic_network: the online critic.
          target_policy_network: the target policy.
          target_critic_network: the target critic.
          discount: stored as `self._discount`.
          num_samples: stored as `self._num_samples`.
          target_policy_update_period: stored for deciding when the target
            policy is updated.
          target_critic_update_period: stored for deciding when the target
            critic is updated.
          dataset: dataset to learn from; an iterator over it is created.
          observation_network: online observation transformation, wrapped
            as a Sonnet module.
          target_observation_network: target observation transformation,
            wrapped as a Sonnet module.
          policy_loss_module: policy loss; a `losses.MPO` with built-in
            defaults is created when not given.
          policy_optimizer: policy optimizer; Adam(1e-4) when not given.
          critic_optimizer: critic optimizer; Adam(1e-4) when not given.
          dual_optimizer: dual optimizer; Adam(1e-2) when not given.
          clipping: stored as `self._clipping`.
          counter: counter object; a new one is created when not given.
          logger: logger object; a default logger is created when not given.
          checkpoint: whether to create a checkpointer and a snapshotter.
        """

        # Plain learner parameters.
        self._discount = discount
        self._num_samples = num_samples
        self._clipping = clipping
        self._target_policy_update_period = target_policy_update_period
        self._target_critic_update_period = target_critic_update_period

        # Online networks and their target counterparts.
        self._policy_network = policy_network
        self._critic_network = critic_network
        self._target_policy_network = target_policy_network
        self._target_critic_network = target_critic_network

        # Wrap the observation transformations as snt.Module's so that they
        # expose variables.
        self._observation_network = tf2_utils.to_sonnet_module(
            observation_network)
        self._target_observation_network = tf2_utils.to_sonnet_module(
            target_observation_network)

        # Book-keeping: fall back to a fresh counter / default logger.
        self._counter = counter if counter else counting.Counter()
        self._logger = (logger if logger else
                        loggers.make_default_logger('learner'))

        # Step count, needed to decide when to update the target networks.
        self._num_steps = tf.Variable(0, dtype=tf.int32)

        # Iterator over the training data.
        # TODO(b/155086959): Fix type stubs and remove.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        # Default MPO loss when the caller did not supply one.
        if not policy_loss_module:
            policy_loss_module = losses.MPO(
                epsilon=1e-1,
                epsilon_penalty=1e-3,
                epsilon_mean=1e-3,
                epsilon_stddev=1e-6,
                init_log_temperature=1.,
                init_log_alpha_mean=1.,
                init_log_alpha_stddev=10.)
        self._policy_loss_module = policy_loss_module

        # Default optimizers when none were supplied.
        self._critic_optimizer = (critic_optimizer if critic_optimizer else
                                  snt.optimizers.Adam(1e-4))
        self._policy_optimizer = (policy_optimizer if policy_optimizer else
                                  snt.optimizers.Adam(1e-4))
        self._dual_optimizer = (dual_optimizer if dual_optimizer else
                                snt.optimizers.Adam(1e-2))

        # Variables exposed to the outside: target critic, plus the target
        # observation network chained with the target policy.
        exposed_policy = snt.Sequential(
            [self._target_observation_network, self._target_policy_network])
        self._variables = {
            'critic': self._target_critic_network.variables,
            'policy': exposed_policy.variables,
        }

        # Checkpointer and snapshotter stay None unless requested.
        self._checkpointer = None
        self._snapshotter = None

        if checkpoint:
            checkpointed_state = {
                'counter': self._counter,
                'policy': self._policy_network,
                'critic': self._critic_network,
                'observation': self._observation_network,
                'target_policy': self._target_policy_network,
                'target_critic': self._target_critic_network,
                'target_observation': self._target_observation_network,
                'policy_optimizer': self._policy_optimizer,
                'critic_optimizer': self._critic_optimizer,
                'dual_optimizer': self._dual_optimizer,
                'policy_loss_module': self._policy_loss_module,
                'num_steps': self._num_steps,
            }
            self._checkpointer = tf2_savers.Checkpointer(
                subdirectory='dmpo_learner',
                objects_to_save=checkpointed_state)

            # The snapshot holds the target observation network chained with
            # the target policy.
            snapshot_policy = snt.Sequential([
                self._target_observation_network,
                self._target_policy_network,
            ])
            self._snapshotter = tf2_savers.Snapshotter(
                objects_to_save={'policy': snapshot_policy})

        # Timestamps are only recorded once the first learning step is done,
        # so that the time actors need to come online and fill the replay
        # buffer is not included.
        self._timestamp = None
Exemplo n.º 19
0
    def __init__(
        self,
        network: networks.IQNNetwork,
        target_network: snt.Module,
        discount: float,
        importance_sampling_exponent: float,
        learning_rate: float,
        target_update_period: int,
        dataset: tf.data.Dataset,
        huber_loss_parameter: float = 1.,
        replay_client: Optional[reverb.TFClient] = None,
        counter: Optional[counting.Counter] = None,
        logger: Optional[loggers.Logger] = None,
        checkpoint: bool = True,
    ):
        """Sets up the learner's networks, optimizer, logging and savers.

        Args:
          network: the online Q network (the one being optimized) that
            outputs (q_values, q_logits, atoms).
          target_network: the target Q critic (which lags behind the online
            net).
          discount: discount to use for TD updates.
          importance_sampling_exponent: power to which importance weights are
            raised before normalizing.
          learning_rate: learning rate for the q-network update.
          target_update_period: number of learner steps to perform before
            updating the target networks.
          dataset: dataset to learn from, whether fixed or from a replay
            buffer (see `acme.datasets.reverb.make_reverb_dataset`
            documentation).
          huber_loss_parameter: Quadratic-linear boundary for Huber loss.
          replay_client: client to replay to allow for updating priorities.
          counter: Counter object for (potentially distributed) counting.
          logger: Logger object for writing logs to.
          checkpoint: boolean indicating whether to checkpoint the learner
            or not.
        """

        # Hyperparameters.
        self._huber_loss_parameter = huber_loss_parameter
        self._importance_sampling_exponent = importance_sampling_exponent
        self._target_update_period = target_update_period
        self._discount = discount

        # Agent components: data iterator, networks, optimizer, replay.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types
        self._network = network
        self._target_network = target_network
        self._optimizer = snt.optimizers.Adam(learning_rate)
        self._replay_client = replay_client

        # Learner state.
        self._variables: List[List[tf.Tensor]] = [network.trainable_variables]
        self._num_steps = tf.Variable(0, dtype=tf.int32)

        # Counting and logging, with fallbacks when not supplied.
        self._counter = counter if counter else counting.Counter()
        self._logger = (logger if logger else
                        loggers.TerminalLogger('learner', time_delta=1.))

        # Checkpointer and snapshotter stay None unless requested.
        self._checkpointer = None
        self._snapshotter = None
        if checkpoint:
            checkpointed_state = {
                'network': self._network,
                'target_network': self._target_network,
                'optimizer': self._optimizer,
                'num_steps': self._num_steps
            }
            self._checkpointer = tf2_savers.Checkpointer(
                time_delta_minutes=5, objects_to_save=checkpointed_state)
            self._snapshotter = tf2_savers.Snapshotter(
                objects_to_save={'network': network}, time_delta_minutes=60.)
Exemplo n.º 20
0
    def __init__(self,
                 policy_network: snt.Module,
                 critic_network: snt.Module,
                 dataset: tf.data.Dataset,
                 discount: float,
                 behavior_network: Optional[snt.Module] = None,
                 cwp_network: Optional[snt.Module] = None,
                 policy_optimizer: Optional[snt.Optimizer] = None,
                 critic_optimizer: Optional[snt.Optimizer] = None,
                 target_update_period: int = 100,
                 policy_improvement_modes: str = 'exp',
                 ratio_upper_bound: float = 20.,
                 beta: float = 1.0,
                 cql_alpha: float = 0.0,
                 translate_lse: float = 100.,
                 empirical_policy: Optional[dict] = None,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None,
                 checkpoint_subpath: str = '~/acme/'):
        """Initializes the learner.

        Args:
          policy_network: the online policy network. A deep copy is made to
            serve as the target policy.
          critic_network: the online critic network. A deep copy is made to
            serve as the target critic.
          dataset: dataset to learn from; an iterator over it is created
            here.
          discount: discount factor, stored for use by the update step.
          behavior_network: accepted but not used by this constructor.
          cwp_network: accepted but not used by this constructor.
          policy_optimizer: optimizer for the policy. When None, a fresh
            `snt.optimizers.Adam(1e-4)` is created for this instance.
          critic_optimizer: optimizer for the critic. When None, a fresh
            `snt.optimizers.Adam(1e-4)` is created for this instance.
          target_update_period: stored for use elsewhere in the class
            (presumably the number of learner steps between target-network
            updates -- confirm against the update step).
          policy_improvement_modes: one of 'exp', 'binary', 'all'.
          ratio_upper_bound: stored as given; its use is outside this
            constructor.
          beta: stored as given; its use is outside this constructor.
          cql_alpha: stored as a float32 constant. When non-zero,
            `empirical_policy` must be provided.
          translate_lse: stored as a float32 constant; its use is outside
            this constructor.
          empirical_policy: empirical behavioural policy, required when
            `cql_alpha` is non-zero.
          counter: Counter object for (potentially distributed) counting; a
            fresh `counting.Counter` is created when omitted.
          logger: Logger object for writing logs to; a terminal logger
            labelled 'learner' is created when omitted.
          checkpoint_subpath: directory handed to the checkpointer.

        Raises:
          AssertionError: if `policy_improvement_modes` is not one of the
            supported values, or if `cql_alpha` is non-zero while
            `empirical_policy` is None.
        """

        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types
        # Store the online networks.
        self._policy_network = policy_network
        self._critic_network = critic_network
        # Create the target networks as independent copies of the online ones.
        self._target_policy_network = copy.deepcopy(policy_network)
        self._target_critic_network = copy.deepcopy(critic_network)

        # Build default optimizers per instance. Using an optimizer object as
        # the parameter default would create it once at definition time and
        # share it (and its state) between every learner relying on it.
        if critic_optimizer is None:
            critic_optimizer = snt.optimizers.Adam(1e-4)
        if policy_optimizer is None:
            policy_optimizer = snt.optimizers.Adam(1e-4)
        self._critic_optimizer = critic_optimizer
        self._policy_optimizer = policy_optimizer

        # Internalise the hyperparameters.
        self._discount = discount
        self._target_update_period = target_update_period

        # CRR specific. Kept as an assertion so the exception type seen by
        # existing callers does not change.
        assert policy_improvement_modes in [
            'exp', 'binary', 'all'
        ], 'Policy imp. mode must be one of {exp, binary, all}'
        self._policy_improvement_modes = policy_improvement_modes
        self._beta = beta
        self._ratio_upper_bound = ratio_upper_bound

        # CQL specific.
        self._alpha = tf.constant(cql_alpha, dtype=tf.float32)
        self._tr = tf.constant(translate_lse, dtype=tf.float32)
        if cql_alpha:
            assert empirical_policy is not None, 'Empirical behavioural policy must be specified with non-zero cql_alpha.'
        self._emp_policy = empirical_policy

        # Learner state: expose the target networks' variables.
        self._variables = {
            'critic': self._target_critic_network.variables,
            'policy': self._target_policy_network.variables,
        }

        # Internalise logging/counting objects.
        self._counter = counter or counting.Counter()
        self._counter.increment(learner_steps=0)
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        # Create the checkpointer. `self.state` is defined elsewhere on the
        # class and supplies the objects to checkpoint.
        self._checkpointer = tf2_savers.Checkpointer(
            objects_to_save=self.state,
            time_delta_minutes=10.,
            directory=checkpoint_subpath,
            subdirectory='crr_learner')

        # Create the snapshotter for the online policy and critic.
        objects_to_save = {
            'raw_policy': policy_network,
            'critic': critic_network,
        }
        self._snapshotter = tf2_savers.Snapshotter(
            objects_to_save=objects_to_save, time_delta_minutes=10)
        # Timestamp to keep track of the wall time.
        self._walltime_timestamp = time.time()
Exemplo n.º 21
0
    def __init__(
        self,
        policy_network: snt.Module,
        critic_network: snt.Module,
        sampling_head: snt.Module,
        target_policy_network: snt.Module,
        target_critic_network: snt.Module,
        discount: float,
        target_update_period: int,
        learning_rate: float,
        entropy_coeff: float,
        dataset: tf.data.Dataset,
        observation_network: types.TensorTransformation,
        target_observation_network: types.TensorTransformation,
        policy_optimizer: snt.Optimizer = None,
        critic_optimizer: snt.Optimizer = None,
        clipping: bool = False,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        checkpoint: bool = True,
    ):
        """Initializes the learner.

        Args:
          policy_network: the online (optimized) policy.
          critic_network: the online critic.
          sampling_head: module stored for use elsewhere in the class.
          target_policy_network: the target policy network.
          target_critic_network: the target critic network.
          discount: discount to use for TD updates.
          target_update_period: stored next to the step counter that tracks
            when to update the target networks.
          learning_rate: stored as given; the optimizers created here do not
            read it.
          entropy_coeff: stored for use elsewhere in the class.
          dataset: dataset to learn from, whether fixed or from a replay
            buffer (see `acme.datasets.reverb.make_dataset` documentation).
          observation_network: online network to process observations before
            the policy and the critic; its variables are trained along with
            the policy.
          target_observation_network: target counterpart of
            `observation_network`.
          policy_optimizer: the optimizer to be applied to the policy loss;
            defaults to Adam(1e-4).
          critic_optimizer: the optimizer to be applied to the critic loss;
            defaults to Adam(1e-4).
          clipping: whether to clip gradients by global norm.
          counter: counter object used to keep track of steps.
          logger: logger object to be used by learner.
          checkpoint: boolean indicating whether to checkpoint the learner;
            forwarded to the checkpointer as `enable_checkpointing`.
        """

        # Online networks.
        self._policy_network = policy_network
        self._critic_network = critic_network
        self._observation_network = observation_network

        # Target networks.
        self._target_policy_network = target_policy_network
        self._target_critic_network = target_critic_network
        self._target_observation_network = target_observation_network

        self._sampling_head = sampling_head

        # Book-keeping helpers, with defaults when none are supplied.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger('learner')

        # Scalar settings, stored as given.
        self._learning_rate = learning_rate
        self._entropy_coeff = entropy_coeff
        self._clipping = clipping
        self._discount = discount

        self._ReFER = model.ReFER_module()

        # Step counter and period used to decide when targets get updated.
        self._target_update_period = target_update_period
        self._num_steps = tf.Variable(0, dtype=tf.int32)

        # Iterator over the training data.
        # TODO(b/155086959): Fix type stubs and remove.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        # Fall back to default optimizers for the policy and the critic; the
        # ReFER module always gets its own SGD optimizer.
        self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
        self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
        self._ReFER_optimizer = snt.optimizers.SGD(1e-3)

        # The observation (preprocessing) variables are trained together with
        # the policy variables; the critic is trained on its own.
        encoder_variables = observation_network.trainable_variables
        self._policy_variables = (encoder_variables +
                                  policy_network.trainable_variables)
        self._critic_variables = critic_network.trainable_variables

        checkpointed_objects = {
            'counter': self._counter,
            'policy': self._policy_network,
            'critic': self._critic_network,
            'encoder': self._observation_network,
            'ReFER': self._ReFER,
            'policy_optimizer': self._policy_optimizer,
            'critic_optimizer': self._critic_optimizer,
        }
        self._checkpointer = tf2_savers.Checkpointer(
            time_delta_minutes=10,
            objects_to_save=checkpointed_objects,
            enable_checkpointing=checkpoint,
        )

        # No timestamp is recorded until the first learning step has run, so
        # that the time actors need to come online and fill the replay buffer
        # is not included.
        self._timestamp = None
Exemplo n.º 22
0
    def __init__(self,
                 policy_network: snt.RNNCore,
                 critic_network: networks.CriticDeepRNN,
                 target_policy_network: snt.RNNCore,
                 target_critic_network: networks.CriticDeepRNN,
                 dataset: tf.data.Dataset,
                 accelerator_strategy: Optional[tf.distribute.Strategy] = None,
                 behavior_network: Optional[snt.Module] = None,
                 cwp_network: Optional[snt.Module] = None,
                 policy_optimizer: Optional[snt.Optimizer] = None,
                 critic_optimizer: Optional[snt.Optimizer] = None,
                 discount: float = 0.99,
                 target_update_period: int = 100,
                 num_action_samples_td_learning: int = 1,
                 num_action_samples_policy_weight: int = 4,
                 baseline_reduce_function: str = 'mean',
                 clipping: bool = True,
                 policy_improvement_modes: str = 'exp',
                 ratio_upper_bound: float = 20.,
                 beta: float = 1.0,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None,
                 checkpoint: bool = False):
        """Initializes the learner.

        Args:
          policy_network: the online (optimized) policy.
          critic_network: the online critic.
          target_policy_network: the target policy (which lags behind the
            online policy).
          target_critic_network: the target critic.
          dataset: dataset to learn from, whether fixed or from a replay
            buffer (see `acme.datasets.reverb.make_reverb_dataset`
            documentation). One element is drawn from it here to read the
            sequence length.
          accelerator_strategy: the strategy used to distribute computation,
            whether on a single, or multiple, GPU or TPU; as supported by
            tf.distribute. Defaults to `snt.distribute.Replicator()`.
          behavior_network: the network to snapshot under the `policy` name.
            When None, no `policy` entry is added to the snapshot.
          cwp_network: CWP network to snapshot under the `cwp_policy` name:
            samples actions from the policy and weighs them with the critic,
            then returns the action by sampling from the softmax distribution
            using critic values as logits. Used only for snapshotting, not
            training.
          policy_optimizer: the optimizer to be applied to the policy loss;
            defaults to Adam(1e-4).
          critic_optimizer: the optimizer to be applied to the distributional
            Bellman loss; defaults to Adam(1e-4).
          discount: discount to use for TD updates.
          target_update_period: number of learner steps to perform before
            updating the target networks.
          num_action_samples_td_learning: number of action samples to use to
            estimate expected value of the critic loss w.r.t. stochastic
            policy.
          num_action_samples_policy_weight: number of action samples to use
            to estimate the advantage function for the CRR weighting of the
            policy loss.
          baseline_reduce_function: one of 'mean', 'max', 'min'. Way of
            aggregating values from `num_action_samples` estimates of the
            value function.
          clipping: whether to clip gradients by global norm.
          policy_improvement_modes: one of 'exp', 'binary', 'all'. CRR mode
            which determines how the advantage function is processed before
            being multiplied by the policy loss.
          ratio_upper_bound: if policy_improvement_modes is 'exp', determines
            the upper bound of the weight (i.e. the weight is
              min(exp(advantage / beta), upper_bound)
            ).
          beta: if policy_improvement_modes is 'exp', determines the beta
            (see above).
          counter: counter object used to keep track of steps.
          logger: logger object to be used by learner.
          checkpoint: boolean indicating whether to checkpoint the learner.
            When False, the checkpointer and snapshotter are None.
        """

        # Fall back to the Sonnet replicator when no strategy was supplied.
        self._accelerator_strategy = (
            accelerator_strategy if accelerator_strategy is not None else
            snt.distribute.Replicator())

        # CRR settings, stored as given.
        self._beta = beta
        self._ratio_upper_bound = ratio_upper_bound
        self._policy_improvement_modes = policy_improvement_modes
        self._baseline_reduce_function = baseline_reduce_function
        self._num_action_samples_policy_weight = num_action_samples_policy_weight
        self._num_action_samples_td_learning = num_action_samples_td_learning

        # On TPUs the memory requirement (and therefore the sequence length)
        # has to be known when the graph is compiled. The dataset carries no
        # metadata, so for now the only way to obtain it is to draw a sample;
        # see b/160672927 for the upcoming feature.
        first_sample = next(dataset.as_numpy_iterator())
        self._sequence_length = first_sample.action.shape[1]

        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)
        self._clipping = clipping
        self._discount = discount
        self._target_update_period = target_update_period

        with self._accelerator_strategy.scope():
            # Step counter used to decide when to update target networks.
            self._num_steps = tf.Variable(0, dtype=tf.int32)

            # (Maybe) spread the dataset across multiple accelerators.
            self._iterator = iter(
                self._accelerator_strategy.experimental_distribute_dataset(
                    dataset))

            # Optimizers, using the defaults where none were given.
            self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(
                1e-4)
            self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(
                1e-4)

        # Online and target networks.
        self._policy_network = policy_network
        self._target_policy_network = target_policy_network
        self._critic_network = critic_network
        self._target_critic_network = target_critic_network

        # Variables exposed to the outside come from the target networks.
        self._variables = {
            'critic': self._target_critic_network.variables,
            'policy': self._target_policy_network.variables,
        }

        if checkpoint:
            self._checkpointer = tf2_savers.Checkpointer(
                objects_to_save={
                    'counter': self._counter,
                    'policy': self._policy_network,
                    'critic': self._critic_network,
                    'target_policy': self._target_policy_network,
                    'target_critic': self._target_critic_network,
                    'policy_optimizer': self._policy_optimizer,
                    'critic_optimizer': self._critic_optimizer,
                    'num_steps': self._num_steps,
                },
                time_delta_minutes=30.)

            # Snapshot the policy and critic wrapped with their output heads.
            snapshot_targets = {
                'raw_policy':
                snt.DeepRNN(
                    [policy_network,
                     networks.StochasticSamplingHead()]),
                'critic':
                networks.CriticDeepRNN(
                    [critic_network,
                     networks.StochasticMeanHead()]),
            }
            # Optional extras are only snapshotted when they were supplied.
            optional_targets = (('policy', behavior_network),
                                ('cwp_policy', cwp_network))
            for snapshot_name, module in optional_targets:
                if module is not None:
                    snapshot_targets[snapshot_name] = module
            self._snapshotter = tf2_savers.Snapshotter(
                objects_to_save=snapshot_targets, time_delta_minutes=30)
        else:
            self._checkpointer = None
            self._snapshotter = None

        # Timestamp to keep track of the wall time.
        self._walltime_timestamp = time.time()
Exemplo n.º 23
0
    def __init__(
        self,
        seed: int,
        environment_spec: specs.EnvironmentSpec,
        builder: builders.ActorLearnerBuilder,
        networks: Any,
        policy_network: Any,
        learner_logger: Optional[loggers.Logger] = None,
        workdir: Optional[str] = '~/acme',
        batch_size: int = 256,
        num_sgd_steps_per_step: int = 1,
        prefetch_size: int = 1,
        counter: Optional[counting.Counter] = None,
        checkpoint: bool = True,
    ):
        """Initialize the agent.

        Builds a local replay server, an adder, a prefetching dataset
        iterator, a learner, an optional checkpointer and an actor, all from
        the given builder.

        Args:
          seed: A random seed to use for this layout instance.
          environment_spec: description of the actions, observations, etc.
          builder: builder defining an RL algorithm to train.
          networks: network objects to be passed to the learner.
          policy_network: function that given an observation returns actions.
          learner_logger: logger used by the learner.
          workdir: if provided saves the state of the learner and the counter
            (if the counter is not None) into workdir.
          batch_size: batch size for updates.
          num_sgd_steps_per_step: how many sgd steps a learner does per
            'step' call. For performance reasons (especially to reduce TPU
            host-device transfer times) it is performance-beneficial to do
            multiple sgd updates at once, provided that it does not hurt the
            training, which needs to be verified empirically for each
            environment.
          prefetch_size: buffer size handed to the prefetching iterator.
          counter: counter object used to keep track of steps.
          checkpoint: boolean indicating whether to checkpoint the learner
            and the counter (if the counter is not None).

        Raises:
          ValueError: if `prefetch_size` is negative.
        """
        if prefetch_size < 0:
            raise ValueError(
                f'Prefetch size={prefetch_size} should be non negative')

        rng_key = jax.random.PRNGKey(seed)

        def _without_insert_blocking(table):
            # Same rate limiter settings, except that max_diff is lifted to
            # the largest float so that inserts are never blocked.
            limiter_info = table.info.rate_limiter_info
            unblocked_limiter = reverb.rate_limiters.RateLimiter(
                samples_per_insert=limiter_info.samples_per_insert,
                min_size_to_sample=limiter_info.min_size_to_sample,
                min_diff=limiter_info.min_diff,
                max_diff=sys.float_info.max)
            return table.replace(rate_limiter=unblocked_limiter)

        # LocalLayout agents insert and sample from the same thread, so an
        # insert blocked by a table's rate limiter would hang. Rebuild every
        # replay table with blocking of inserts disabled.
        replay_tables = [
            _without_insert_blocking(table)
            for table in builder.make_replay_tables(environment_spec,
                                                    policy_network)
        ]

        # Start the replay server and connect a client to it.
        replay_server = reverb.Server(replay_tables, port=None)
        replay_client = reverb.Client(f'localhost:{replay_server.port}')

        # The adder stores data, the dataset iterator serves it and the
        # learner consumes it.
        adder = builder.make_adder(replay_client, environment_spec,
                                   policy_network)

        # Prefetch is always applied because it yields an iterator with an
        # additional 'ready' method.
        dataset = utils.prefetch(builder.make_dataset_iterator(replay_client),
                                 buffer_size=prefetch_size)

        def _learner_logger_fn(label, steps_key=None, task_instance=None):
            # Every request gets the single logger passed to this layout.
            return learner_logger

        learner_key, rng_key = jax.random.split(rng_key)
        learner = builder.make_learner(random_key=learner_key,
                                       networks=networks,
                                       dataset=dataset,
                                       logger_fn=_learner_logger_fn,
                                       environment_spec=environment_spec,
                                       replay_client=replay_client,
                                       counter=counter)

        if checkpoint and workdir is not None:
            objects_to_save = {'learner': learner}
            if counter is not None:
                objects_to_save['counter'] = counter
            self._checkpointer = savers.Checkpointer(
                objects_to_save,
                time_delta_minutes=30,
                subdirectory='learner',
                directory=workdir,
                add_uid=(workdir == '~/acme'))
        else:
            self._checkpointer = None

        actor_key, rng_key = jax.random.split(rng_key)
        actor = builder.make_actor(actor_key,
                                   policy_network,
                                   environment_spec,
                                   variable_source=learner,
                                   adder=adder)

        super().__init__(actor=actor,
                         learner=learner,
                         iterator=dataset,
                         replay_tables=replay_tables)

        # Hold a reference to the replay server so it is not garbage
        # collected.
        self._replay_server = replay_server
Exemplo n.º 24
0
Arquivo: agent.py Projeto: dzorlu/acme
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        target_network: snt.RNNCore,
        burn_in_length: int,
        trace_length: int,
        replay_period: int,
        demonstration_generator: iter,
        demonstration_ratio: float,
        model_directory: str,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        batch_size: int = 32,
        target_update_period: int = 100,
        importance_sampling_exponent: float = 0.2,
        epsilon: float = 0.01,
        learning_rate: float = 1e-3,
        log_to_bigtable: bool = False,
        log_name: str = 'agent',
        checkpoint: bool = True,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
    ):
        """Builds an agent that learns from replay mixed with demonstrations.

        Starts a local Reverb server with two tables (regular replay and
        demonstrations), writes every item of `demonstration_generator` into
        the demonstration table, builds a dataset that samples from both
        tables, and wires up an `R2D2Learner` and a recurrent actor.

        Args:
          environment_spec: environment description; used for the table
            signatures and to create the networks' variables.
          network: online recurrent network. It is handed to the learner and
            wrapped with an epsilon-greedy head to form the acting policy.
          target_network: target network handed to the learner.
          burn_in_length: burn-in length handed to the learner; also part of
            the sequence length.
          trace_length: added to `burn_in_length` (plus one) to form the
            sequence length.
          replay_period: period of both sequence adders.
          demonstration_generator: iterable yielding `(timestep, action)`
            pairs. It is fully consumed by this constructor.
          demonstration_ratio: sampling weight of the demonstration dataset;
            the replay dataset gets `1 - demonstration_ratio`.
          model_directory: directory used by the checkpointer and the
            snapshotter.
          counter: counter forwarded to the learner.
          logger: logger forwarded to the learner.
          discount: discount forwarded to the learner.
          batch_size: batch size of the training dataset.
          target_update_period: forwarded to the learner.
          importance_sampling_exponent: forwarded to the learner.
          epsilon: epsilon of the epsilon-greedy acting policy.
          learning_rate: forwarded to the learner.
          log_to_bigtable: not used by this constructor.
          log_name: not used by this constructor.
          checkpoint: forwarded to the checkpointer as
            `enable_checkpointing`.
          min_replay_size: together with `batch_size` and `replay_period`
            determines `min_observations` for the base class.
          max_replay_size: maximum size of both Reverb tables; also forwarded
            to the learner.
          samples_per_insert: used to derive `observations_per_step` for the
            base class.
        """

        extra_spec = {
            'core_state': network.initial_state(1),
        }
        # Replay table.
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(0.8),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))
        # Demonstration table: configured like the replay table, under its
        # own name.
        demonstration_table = reverb.Table(
            name='demonstration_table',
            sampler=reverb.selectors.Prioritized(0.8),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))

        # Launch the server hosting both tables; keeping it on `self` keeps
        # a reference to it for the lifetime of the agent.
        self._server = reverb.Server([replay_table, demonstration_table],
                                     port=None)
        address = f'localhost:{self._server.port}'

        sequence_length = burn_in_length + trace_length + 1

        # Components that add sequences into the replay and demonstration
        # tables; both share the same period and sequence length.
        sequence_kwargs = dict(
            period=replay_period,
            sequence_length=sequence_length,
        )
        adder = adders.SequenceAdder(client=reverb.Client(address),
                                     **sequence_kwargs)
        # Every demonstration item gets a constant priority of 1.
        priority_function = {demonstration_table.name: lambda x: 1.}
        demo_adder = adders.SequenceAdder(client=reverb.Client(address),
                                          priority_fns=priority_function,
                                          **sequence_kwargs)
        # Play the demonstrations and write them into the demonstration
        # table, exhausting the generator.
        # TODO: MAX REPLAY SIZE
        # NOTE(review): this initial value is never read -- the first
        # iteration takes the `add_first` branch and then overwrites it.
        _prev_action = 1  # this has to come from spec
        _add_first = True
        # Include the network's initial state as extras to make the datasets
        # equivalent.
        numpy_state = tf2_utils.to_numpy_squeeze(network.initial_state(1))
        for ts, action in demonstration_generator:
            if _add_first:
                demo_adder.add_first(ts)
                _add_first = False
            else:
                # The action stored with a timestep is the one yielded with
                # the previous timestep.
                demo_adder.add(_prev_action, ts, extras=(numpy_state, ))
            _prev_action = action
            # Reset so the next timestep starts a new episode.
            if ts.last():
                _prev_action = None
                _add_first = True

        # Replay dataset.
        max_in_flight_samples_per_worker = 2 * batch_size if batch_size else 100
        # `sequence_length` is always an int here, so `emit_timesteps`
        # evaluates to False.
        dataset = reverb.ReplayDataset.from_table_signature(
            server_address=address,
            table=adders.DEFAULT_PRIORITY_TABLE,
            max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
            num_workers_per_iterator=
            2,  # memory perf improvement attempt, see https://github.com/deepmind/acme/issues/33
            sequence_length=sequence_length,
            emit_timesteps=sequence_length is None)

        # Demonstration dataset.
        d_dataset = reverb.ReplayDataset.from_table_signature(
            server_address=address,
            table=demonstration_table.name,
            max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
            num_workers_per_iterator=2,
            sequence_length=sequence_length,
            emit_timesteps=sequence_length is None)

        # Mix both sources with weights (1 - ratio) for replay and ratio for
        # demonstrations.
        dataset = tf.data.experimental.sample_from_datasets(
            [dataset, d_dataset],
            [1 - demonstration_ratio, demonstration_ratio])

        # Batch and prefetch.
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        # Create the variables of both networks from the observation spec
        # before handing them to the learner.
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        learner = learning.R2D2Learner(
            environment_spec=environment_spec,
            network=network,
            target_network=target_network,
            burn_in_length=burn_in_length,
            dataset=dataset,
            reverb_client=reverb.TFClient(address),
            counter=counter,
            logger=logger,
            sequence_length=sequence_length,
            discount=discount,
            target_update_period=target_update_period,
            importance_sampling_exponent=importance_sampling_exponent,
            max_replay_size=max_replay_size,
            learning_rate=learning_rate,
            store_lstm_state=False,
        )

        # Checkpoint the learner's state under `model_directory`.
        self._checkpointer = tf2_savers.Checkpointer(
            directory=model_directory,
            subdirectory='r2d2_learner_v1',
            time_delta_minutes=15,
            objects_to_save=learner.state,
            enable_checkpointing=checkpoint,
        )

        # NOTE(review): the snapshotter is given no objects to save --
        # confirm this is intended.
        self._snapshotter = tf2_savers.Snapshotter(objects_to_save=None,
                                                   time_delta_minutes=15000.,
                                                   directory=model_directory)

        # Acting policy: the online network followed by epsilon-greedy
        # sampling over its outputs.
        policy_network = snt.DeepRNN([
            network,
            lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
        ])

        actor = actors.RecurrentActor(policy_network, adder)
        observations_per_step = (float(replay_period * batch_size) /
                                 samples_per_insert)
        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=replay_period *
                         max(batch_size, min_replay_size),
                         observations_per_step=observations_per_step)
Exemplo n.º 25
0
    def __init__(self,
                 value_func: snt.Module,
                 mixture_density: snt.Module,
                 policy_net: snt.Module,
                 discount: float,
                 value_learning_rate: float,
                 density_learning_rate: float,
                 n_sampling: int,
                 density_iter: int,
                 dataset: tf.data.Dataset,
                 counter: counting.Counter = None,
                 logger: loggers.Logger = None,
                 checkpoint: bool = True,
                 checkpoint_interval_minutes: float = 10.0):
        """Initializes the learner.

        Args:
          value_func: value function network.
          mixture_density: mixture density function network.
          policy_net: policy network.
          discount: global discount.
          value_learning_rate: learning rate of the Adam optimizer created
            for the value function update.
          density_learning_rate: learning rate of the Adam optimizer created
            for the mixture_density update.
          n_sampling: number of samples generated in stage 2.
          density_iter: number of iterations for the mixture_density
            function.
          dataset: dataset to learn from.
          counter: Counter object for (potentially distributed) counting; a
            fresh `counting.Counter` is created when omitted.
          logger: Logger object for writing logs to; a terminal logger
            labelled 'learner' is created when omitted.
          checkpoint: boolean indicating whether to checkpoint the learner.
            When False, the checkpointer and snapshotter are None.
          checkpoint_interval_minutes: checkpoint interval in minutes, given
            as a float.
        """

        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        # Settings, stored as given.
        self.density_iter = density_iter
        self.n_sampling = n_sampling
        self.discount = discount

        # Get an iterator over the dataset.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        # Networks and one Adam optimizer for each trained network.
        self.value_func = value_func
        self.mixture_density = mixture_density
        self.policy = policy_net
        self._value_func_optimizer = snt.optimizers.Adam(value_learning_rate)
        self._mixture_density_optimizer = snt.optimizers.Adam(
            density_learning_rate)

        # Exposed variables and the step counter variable.
        self._variables = [
            value_func.trainable_variables,
            mixture_density.trainable_variables,
        ]
        self._num_steps = tf.Variable(0, dtype=tf.int32)

        self._mse = tf.keras.losses.MeanSquaredError()

        # Savers stay disabled (None) unless checkpointing was requested.
        self._checkpointer = None
        self._snapshotter = None

        if checkpoint:
            # `self.state` is defined elsewhere on the class and supplies the
            # objects to checkpoint.
            self._checkpointer = tf2_savers.Checkpointer(
                objects_to_save=self.state,
                time_delta_minutes=checkpoint_interval_minutes,
                checkpoint_ttl_seconds=_CHECKPOINT_TTL)
            self._snapshotter = tf2_savers.Snapshotter(
                objects_to_save={
                    'value_func': value_func,
                    'mixture_density': mixture_density,
                },
                time_delta_minutes=60.)
Exemplo n.º 26
0
    def __init__(self,
                 value_func: snt.Module,
                 instrumental_feature: snt.Module,
                 policy_net: snt.Module,
                 discount: float,
                 value_learning_rate: float,
                 instrumental_learning_rate: float,
                 value_reg: float,
                 instrumental_reg: float,
                 stage1_reg: float,
                 stage2_reg: float,
                 instrumental_iter: int,
                 value_iter: int,
                 dataset: tf.data.Dataset,
                 d_tm1_weight: float = 1.0,
                 counter: counting.Counter = None,
                 logger: loggers.Logger = None,
                 checkpoint: bool = True,
                 checkpoint_interval_minutes: float = 10.0):
        """Initializes the learner.

        Args:
          value_func: value function network. It must provide
            `feature_dim()` and a `_feature` attribute, both of which are
            read here.
          instrumental_feature: instrumental feature network. It must provide
            `feature_dim()`.
          policy_net: policy network.
          discount: global discount.
          value_learning_rate: learning rate of the Adam optimizer created
            for the value function.
          instrumental_learning_rate: learning rate of the Adam optimizer
            created for the instrumental feature network.
          value_reg: L2 regularizer for value net.
          instrumental_reg: L2 regularizer for instrumental net.
          stage1_reg: ridge regularizer for stage 1 regression.
          stage2_reg: ridge regularizer for stage 2 regression.
          instrumental_iter: number of iterations for the instrumental net.
          value_iter: number of iterations for the value function.
          dataset: dataset to learn from.
          d_tm1_weight: weights for terminal state transitions. Ignored in
            this variant.
          counter: Counter object for (potentially distributed) counting. A
            fresh `counting.Counter` is created when omitted.
          logger: Logger object for writing logs to. A `TerminalLogger`
            labelled 'learner' is created when omitted.
          checkpoint: boolean indicating whether to checkpoint the learner.
            When false neither a checkpointer nor a snapshotter is created.
          checkpoint_interval_minutes: checkpoint interval in minutes,
            forwarded as `time_delta_minutes` to the checkpointer.
        """

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        self.stage1_reg = stage1_reg
        self.stage2_reg = stage2_reg
        self.instrumental_iter = instrumental_iter
        self.value_iter = value_iter
        self.discount = discount
        self.value_reg = value_reg
        self.instrumental_reg = instrumental_reg
        # Accepted only so the signature lines up with callers that pass it;
        # the value is deliberately discarded.
        del d_tm1_weight

        # Get an iterator over the dataset.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        self.value_func = value_func
        # Reaches into a private attribute of the value network (presumably
        # its feature sub-module), so `value_func` must define `_feature`.
        self.value_feature = value_func._feature
        self.instrumental_feature = instrumental_feature
        self.policy = policy_net
        # Both optimizers use non-default Adam betas (0.5, 0.9).
        self._value_func_optimizer = snt.optimizers.Adam(value_learning_rate,
                                                         beta1=0.5,
                                                         beta2=0.9)
        self._instrumental_func_optimizer = snt.optimizers.Adam(
            instrumental_learning_rate, beta1=0.5, beta2=0.9)

        # Define additional variables.
        # Zero-initialised float32 matrix of shape
        # (instrumental feature dim, value feature dim).
        self.stage1_weight = tf.Variable(
            tf.zeros(
                (instrumental_feature.feature_dim(), value_func.feature_dim()),
                dtype=tf.float32))
        self._num_steps = tf.Variable(0, dtype=tf.int32)

        self._variables = [
            self.value_func.trainable_variables,
            self.instrumental_feature.trainable_variables,
            self.stage1_weight,
        ]

        # Create a checkpointer object.
        self._checkpointer = None
        self._snapshotter = None

        if checkpoint:
            # NOTE(review): `self.state` is defined outside this method;
            # presumably it reads the attributes assigned above, so this
            # block should stay at the end of the constructor.
            self._checkpointer = tf2_savers.Checkpointer(
                objects_to_save=self.state,
                time_delta_minutes=checkpoint_interval_minutes,
                checkpoint_ttl_seconds=_CHECKPOINT_TTL)
            self._snapshotter = tf2_savers.Snapshotter(
                objects_to_save={
                    'value_func': self.value_func,
                    'instrumental_feature': self.instrumental_feature,
                },
                time_delta_minutes=60.)
Exemplo n.º 27
0
  def __init__(
      self,
      seed: int,
      environment_spec: specs.EnvironmentSpec,
      builder: builders.GenericActorLearnerBuilder,
      networks: Any,
      policy_network: Any,
      workdir: Optional[str] = '~/acme',
      min_replay_size: int = 1000,
      samples_per_insert: float = 256.0,
      batch_size: int = 256,
      num_sgd_steps_per_step: int = 1,
      prefetch_size: int = 1,
      device_prefetch: bool = True,
      counter: Optional[counting.Counter] = None,
      checkpoint: bool = True,
  ):
    """Wires up replay, learner and actor inside a single process.

    Args:
      seed: seed of the JAX PRNG key that is split between the learner and
        the actor.
      environment_spec: description of the actions, observations, etc.; it
        is forwarded to the builder to create the replay tables.
      builder: builder defining an RL algorithm to train. It supplies the
        replay tables, adder, dataset iterator, learner and actor.
      networks: network objects to be passed to the learner.
      policy_network: policy object handed to `builder.make_actor`.
      workdir: directory for checkpoints of the learner and, if it is not
        None, the counter. Passing None disables checkpointing. A unique id
        is requested from the checkpointer only for the default '~/acme'.
      min_replay_size: minimum replay size before updating; the value given
        to the parent class is the larger of this and
        `batch_size * num_sgd_steps_per_step`.
      samples_per_insert: number of samples to take from replay for every
        insert that is made.
      batch_size: batch size for updates.
      num_sgd_steps_per_step: how many sgd steps a learner does per 'step'
        call. Doing several at once can reduce host-device transfer
        overhead, but whether it hurts training has to be checked
        empirically for each environment.
      prefetch_size: buffer size used to prefetch the dataset iterator.
        Prefetching is only applied when this is greater than 1.
      device_prefetch: whether prefetching should target the first JAX
        device; otherwise no device is given.
      counter: counter object used to keep track of steps.
      checkpoint: boolean indicating whether to checkpoint the learner
        and the counter (if the counter is not None).

    Raises:
      ValueError: if `prefetch_size` is negative.
    """
    if prefetch_size < 0:
      raise ValueError(f'Prefetch size={prefetch_size} should be non negative')

    rng = jax.random.PRNGKey(seed)

    # Start the replay server (port=None leaves the port choice to Reverb)
    # and connect a client to it.
    replay_tables = builder.make_replay_tables(environment_spec)
    replay_server = reverb.Server(replay_tables, port=None)
    replay_client = reverb.Client(f'localhost:{replay_server.port}')

    # The adder is what the actor uses to write experience into replay.
    adder = builder.make_adder(replay_client)

    def _behaves_like_queue(table: reverb.Table) -> bool:
      """Returns a truthy value iff the Reverb table is actually a queue."""
      # TODO(sinopalnikov): make it more generic and check for a table that
      # needs special handling on update.
      table_info = replay_client.server_info()[table.name]
      return (table_info.max_times_sampled == 1 and
              table_info.sampler_options.fifo and
              table_info.remover_options.fifo)

    # Stop probing at the first queue-like table.
    is_reverb_queue = False
    for table in replay_tables:
      if _behaves_like_queue(table):
        is_reverb_queue = True
        break

    dataset = builder.make_dataset_iterator(replay_client)
    if prefetch_size > 1:
      if device_prefetch:
        prefetch_device = jax.devices()[0]
      else:
        prefetch_device = None
      dataset = utils.prefetch(
          dataset, buffer_size=prefetch_size, device=prefetch_device)

    learner_key, rng = jax.random.split(rng)
    learner = builder.make_learner(
        random_key=learner_key,
        networks=networks,
        dataset=dataset,
        replay_client=replay_client,
        counter=counter)

    # Checkpointing needs both the flag and somewhere to write to.
    self._checkpointer = None
    if checkpoint and workdir is not None:
      objects_to_save = {'learner': learner}
      if counter is not None:
        objects_to_save['counter'] = counter
      self._checkpointer = savers.Checkpointer(
          objects_to_save,
          time_delta_minutes=30,
          subdirectory='learner',
          directory=workdir,
          add_uid=(workdir == '~/acme'))

    actor_key, rng = jax.random.split(rng)
    actor = builder.make_actor(
        actor_key, policy_network, adder, variable_source=learner)

    def _step_while_samplable():
      # Reverb queue requires special handling on update: custom logic to
      # decide when it is safe to make a learner step. This is only needed for
      # the local agent, where the actor and the learner are running
      # synchronously and the learner will deadlock if it makes a step with
      # no data available.
      ran_learner_step = False
      # Run a number of learner steps (usually gradient steps).
      # TODO(raveman): This is wrong. When running multi-level learners,
      # different levels might have different batch sizes. Find a solution.
      while all(table.can_sample(batch_size) for table in replay_tables):
        learner.step()
        ran_learner_step = True

      if ran_learner_step:
        # "wait=True" to make it more onpolicy
        actor.update(wait=True)

    # The custom update is only installed for queue-backed replay.
    self._custom_update_fn = _step_while_samplable if is_reverb_queue else None

    effective_batch_size = batch_size * num_sgd_steps_per_step
    min_observations = max(effective_batch_size, min_replay_size)
    observations_per_step = float(effective_batch_size) / samples_per_insert
    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=min_observations,
        observations_per_step=observations_per_step)

    # Save the replay so we don't garbage collect it.
    self._replay_server = replay_server