Example #1
    def test_save_and_new_restore(self):
        """Tests that a fresh checkpointer correctly restores an existing ckpt."""
        with mock.patch.object(paths, 'get_unique_id') as mock_unique_id:
            mock_unique_id.return_value = ('test',)
            x = tf.Variable(0, dtype=tf.int32)
            directory = self.get_tempdir()
            checkpointer1 = tf2_savers.Checkpointer(
                objects_to_save={'x': x},
                time_delta_minutes=0.,
                directory=directory)
            checkpointer1.save()
            x.assign_add(1)
            # Simulate a preemption: x is changed, and we make a new Checkpointer.
            checkpointer2 = tf2_savers.Checkpointer(
                objects_to_save={'x': x},
                time_delta_minutes=0.,
                directory=directory)
            checkpointer2.restore()
            np.testing.assert_array_equal(x.numpy(), np.int32(0))
Example #2
    def test_no_checkpoint(self):
        """Test that checkpointer does nothing when checkpoint=False."""
        num_steps = tf.Variable(0)
        checkpointer = tf2_savers.Checkpointer(
            objects_to_save={'num_steps': num_steps},
            enable_checkpointing=False)

        for _ in range(10):
            self.assertFalse(checkpointer.save())
        self.assertIsNone(checkpointer._checkpoint_manager)
Example #3
    def test_save_and_restore(self):
        """Test that checkpointer correctly calls save and restore."""

        x = tf.Variable(0, dtype=tf.int32)
        directory = self.get_tempdir()
        checkpointer = tf2_savers.Checkpointer(objects_to_save={'x': x},
                                               time_delta_minutes=0.,
                                               directory=directory)

        for _ in range(10):
            saved = checkpointer.save()
            self.assertTrue(saved)
            x.assign_add(1)
            checkpointer.restore()
            np.testing.assert_array_equal(x.numpy(), np.int32(0))
Example #4
    def test_save_and_restore_time_based(self):
        """Test that checkpointer correctly calls save and restore based on time."""

        x = tf.Variable(0, dtype=tf.int32)
        directory = self.get_tempdir()
        checkpointer = tf2_savers.Checkpointer(objects_to_save={'x': x},
                                               time_delta_minutes=1.,
                                               directory=directory)

        with mock.patch.object(time, 'time') as mock_time:
            mock_time.return_value = 0.
            self.assertFalse(checkpointer.save())

            mock_time.return_value = 40.
            self.assertFalse(checkpointer.save())

            mock_time.return_value = 70.
            self.assertTrue(checkpointer.save())
        x.assign_add(1)
        checkpointer.restore()
        np.testing.assert_array_equal(x.numpy(), np.int32(0))
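
The four tests above cover the whole Checkpointer surface used throughout this page: construct it with objects_to_save, call save() (rate-limited by time_delta_minutes), and call restore(). Below is a minimal usage sketch, not taken from the test suite, that strings these calls together in a toy training loop; the import path and the '/tmp/checkpointer_demo' directory are assumptions.

import tensorflow as tf
from acme.tf import savers as tf2_savers

step = tf.Variable(0, dtype=tf.int32)
checkpointer = tf2_savers.Checkpointer(
    objects_to_save={'step': step},
    directory='/tmp/checkpointer_demo',  # hypothetical location
    time_delta_minutes=5.)               # rate-limit saves to one per 5 minutes

for _ in range(100):
    step.assign_add(1)   # stand-in for one training update
    checkpointer.save()  # returns False until time_delta_minutes has elapsed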
Example #5
    def __init__(
        self,
        policy_network: snt.Module,
        critic_network: snt.Module,
        target_policy_network: snt.Module,
        target_critic_network: snt.Module,
        discount: float,
        num_samples: int,
        target_policy_update_period: int,
        target_critic_update_period: int,
        dataset: tf.data.Dataset,
        observation_network: types.TensorTransformation = tf.identity,
        target_observation_network: types.TensorTransformation = tf.identity,
        policy_loss_module: Optional[snt.Module] = None,
        policy_optimizer: Optional[snt.Optimizer] = None,
        critic_optimizer: Optional[snt.Optimizer] = None,
        dual_optimizer: Optional[snt.Optimizer] = None,
        clipping: bool = True,
        counter: Optional[counting.Counter] = None,
        logger: Optional[loggers.Logger] = None,
        checkpoint: bool = True,
    ):

        # Store online and target networks.
        self._policy_network = policy_network
        self._critic_network = critic_network
        self._target_policy_network = target_policy_network
        self._target_critic_network = target_critic_network

        # Make sure observation networks are snt.Module's so they have variables.
        self._observation_network = tf2_utils.to_sonnet_module(
            observation_network)
        self._target_observation_network = tf2_utils.to_sonnet_module(
            target_observation_network)

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger('learner')

        # Other learner parameters.
        self._discount = discount
        self._num_samples = num_samples
        self._clipping = clipping

        # Necessary to track when to update target networks.
        self._num_steps = tf.Variable(0, dtype=tf.int32)
        self._target_policy_update_period = target_policy_update_period
        self._target_critic_update_period = target_critic_update_period

        # Batch dataset and create iterator.
        # TODO(b/155086959): Fix type stubs and remove.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        self._policy_loss_module = policy_loss_module or losses.MPO(
            epsilon=1e-1,
            epsilon_penalty=1e-3,
            epsilon_mean=1e-3,
            epsilon_stddev=1e-6,
            init_log_temperature=1.,
            init_log_alpha_mean=1.,
            init_log_alpha_stddev=10.)

        # Create the optimizers.
        self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
        self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
        self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2)

        # Expose the variables.
        policy_network_to_expose = snt.Sequential(
            [self._target_observation_network, self._target_policy_network])
        self._variables = {
            'critic': self._target_critic_network.variables,
            'policy': policy_network_to_expose.variables,
        }

        # Create checkpointer and snapshotter objects.
        self._checkpointer = None
        self._snapshotter = None

        if checkpoint:
            self._checkpointer = tf2_savers.Checkpointer(
                subdirectory='dmpo_learner',
                objects_to_save={
                    'counter': self._counter,
                    'policy': self._policy_network,
                    'critic': self._critic_network,
                    'observation': self._observation_network,
                    'target_policy': self._target_policy_network,
                    'target_critic': self._target_critic_network,
                    'target_observation': self._target_observation_network,
                    'policy_optimizer': self._policy_optimizer,
                    'critic_optimizer': self._critic_optimizer,
                    'dual_optimizer': self._dual_optimizer,
                    'policy_loss_module': self._policy_loss_module,
                    'num_steps': self._num_steps,
                })

            self._snapshotter = tf2_savers.Snapshotter(
                objects_to_save={
                    'policy':
                    snt.Sequential([
                        self._target_observation_network,
                        self._target_policy_network
                    ]),
                })

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None
Example #6
  def __init__(
      self,
      environment_spec: specs.EnvironmentSpec,
      network: snt.Module,
      batch_size: int = 256,
      prefetch_size: int = 4,
      target_update_period: int = 100,
      samples_per_insert: float = 32.0,
      min_replay_size: int = 1000,
      max_replay_size: int = 1000000,
      importance_sampling_exponent: float = 0.2,
      priority_exponent: float = 0.6,
      n_step: int = 5,
      epsilon: tf.Tensor = None,
      learning_rate: float = 1e-3,
      discount: float = 0.99,
  ):
    """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      network: the online Q network (the one being optimized)
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      min_replay_size: minimum number of transitions in replay before learning
        updates begin.
      max_replay_size: maximum replay size.
      importance_sampling_exponent: power to which importance weights are raised
        before normalizing.
      priority_exponent: exponent used in prioritized sampling.
      n_step: number of steps to squash into a single transition.
      epsilon: probability of taking a random action under the epsilon-greedy
        policy; defaults to a constant 0.05 if not provided.
      learning_rate: learning rate for the q-network update.
      discount: discount to use for TD updates.
    """

    # Create a replay server to add data to. This uses no limiter behavior in
    # order to allow the Agent interface to handle it.
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(1))
    self._server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay.
    address = f'localhost:{self._server.port}'
    adder = adders.NStepTransitionAdder(
        client=reverb.Client(address),
        n_step=n_step,
        discount=discount)

    # The dataset provides an interface to sample from replay.
    replay_client = reverb.TFClient(address)
    dataset = datasets.make_reverb_dataset(
        client=replay_client,
        environment_spec=environment_spec,
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        transition_adder=True)

    # Use constant 0.05 epsilon greedy policy by default.
    if epsilon is None:
      epsilon = tf.Variable(0.05, trainable=False)
    policy_network = snt.Sequential([
        network,
        lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
    ])

    # Create a target network.
    target_network = copy.deepcopy(network)

    # Ensure that we create the variables before proceeding (maybe not needed).
    tf2_utils.create_variables(network, [environment_spec.observations])
    tf2_utils.create_variables(target_network, [environment_spec.observations])

    # Create the actor which defines how we take actions.
    actor = actors.FeedForwardActor(policy_network, adder)

    # The learner updates the parameters (and initializes them).
    learner = learning.DQNLearner(
        network=network,
        target_network=target_network,
        discount=discount,
        importance_sampling_exponent=importance_sampling_exponent,
        learning_rate=learning_rate,
        target_update_period=target_update_period,
        dataset=dataset,
        replay_client=replay_client)

    self._checkpointer = tf2_savers.Checkpointer(
        objects_to_save=learner.state,
        subdirectory='dqn_learner',
        time_delta_minutes=60.)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(batch_size, min_replay_size),
        observations_per_step=float(batch_size) / samples_per_insert)
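
For context, constructors like the one above are normally driven by an environment loop: the agent acts in the environment, its observations feed the replay table, and learning happens inside the agent's update step. The sketch below shows that wiring under the assumption that this constructor is exposed as acme.agents.tf.dqn.DQN, as in the Acme codebase these snippets resemble; make_environment() is a hypothetical helper returning a dm_env environment with discrete actions.

import sonnet as snt
from acme import EnvironmentLoop, specs
from acme.agents.tf import dqn

environment = make_environment()  # hypothetical helper
spec = specs.make_environment_spec(environment)

# A small feed-forward Q-network sized to the action spec.
network = snt.Sequential([
    snt.Flatten(),
    snt.nets.MLP([64, 64, spec.actions.num_values]),
])

agent = dqn.DQN(environment_spec=spec, network=network)
EnvironmentLoop(environment, agent).run(num_episodes=100)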
Example #7
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        burn_in_length: int,
        trace_length: int,
        replay_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        batch_size: int = 32,
        prefetch_size: int = tf.data.experimental.AUTOTUNE,
        target_update_period: int = 100,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        epsilon: float = 0.01,
        learning_rate: float = 1e-3,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
        store_lstm_state: bool = True,
        max_priority_weight: float = 0.9,
        checkpoint: bool = True,
    ):

        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1))
        self._server = reverb.Server([replay_table], port=None)
        address = f'localhost:{self._server.port}'

        sequence_length = burn_in_length + trace_length + 1
        # Component to add things into replay.
        adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=replay_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        reverb_client = reverb.TFClient(address)
        extra_spec = {
            'core_state': network.initial_state(1),
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
        dataset = datasets.make_reverb_dataset(
            client=reverb_client,
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            extra_spec=extra_spec,
            sequence_length=sequence_length)

        target_network = copy.deepcopy(network)
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        learner = learning.R2D2Learner(
            environment_spec=environment_spec,
            network=network,
            target_network=target_network,
            burn_in_length=burn_in_length,
            sequence_length=sequence_length,
            dataset=dataset,
            reverb_client=reverb_client,
            counter=counter,
            logger=logger,
            discount=discount,
            target_update_period=target_update_period,
            importance_sampling_exponent=importance_sampling_exponent,
            max_replay_size=max_replay_size,
            learning_rate=learning_rate,
            store_lstm_state=store_lstm_state,
            max_priority_weight=max_priority_weight,
        )

        self._checkpointer = tf2_savers.Checkpointer(
            subdirectory='r2d2_learner',
            time_delta_minutes=60,
            objects_to_save=learner.state,
            enable_checkpointing=checkpoint,
        )
        self._snapshotter = tf2_savers.Snapshotter(
            objects_to_save={'network': network}, time_delta_minutes=60.)

        policy_network = snt.DeepRNN([
            network,
            lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
        ])

        actor = actors.RecurrentActor(policy_network, adder)
        observations_per_step = (float(replay_period * batch_size) /
                                 samples_per_insert)
        super().__init__(
            actor=actor,
            learner=learner,
            min_observations=replay_period * max(batch_size, min_replay_size),
            observations_per_step=observations_per_step)
Example #8
  def __init__(
      self,
      policy_network: snt.Module,
      critic_network: snt.Module,
      target_policy_network: snt.Module,
      target_critic_network: snt.Module,
      discount: float,
      target_update_period: int,
      dataset: tf.data.Dataset,
      observation_network: types.TensorTransformation = lambda x: x,
      target_observation_network: types.TensorTransformation = lambda x: x,
      policy_optimizer: snt.Optimizer = None,
      critic_optimizer: snt.Optimizer = None,
      clipping: bool = True,
      counter: counting.Counter = None,
      logger: loggers.Logger = None,
      checkpoint: bool = True,
  ):
    """Initializes the learner.

    Args:
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      target_policy_network: the target policy (which lags behind the online
        policy).
      target_critic_network: the target critic.
      discount: discount to use for TD updates.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      dataset: dataset to learn from, whether fixed or from a replay buffer
        (see `acme.datasets.reverb.make_dataset` documentation).
      observation_network: an optional online network to process observations
        before the policy and the critic.
      target_observation_network: the target observation network.
      policy_optimizer: the optimizer to be applied to the DPG (policy) loss.
      critic_optimizer: the optimizer to be applied to the distributional
        Bellman loss.
      clipping: whether to clip gradients by global norm.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      checkpoint: boolean indicating whether to checkpoint the learner.
    """

    # Store online and target networks.
    self._policy_network = policy_network
    self._critic_network = critic_network
    self._target_policy_network = target_policy_network
    self._target_critic_network = target_critic_network

    # Make sure observation networks are snt.Module's so they have variables.
    self._observation_network = tf2_utils.to_sonnet_module(observation_network)
    self._target_observation_network = tf2_utils.to_sonnet_module(
        target_observation_network)

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger('learner')

    # Other learner parameters.
    self._discount = discount
    self._clipping = clipping

    # Necessary to track when to update target networks.
    self._num_steps = tf.Variable(0, dtype=tf.int32)
    self._target_update_period = target_update_period

    # Batch dataset and create iterator.
    # TODO(b/155086959): Fix type stubs and remove.
    self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

    # Create optimizers if they aren't given.
    self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
    self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)

    # Expose the variables.
    policy_network_to_expose = snt.Sequential(
        [self._target_observation_network, self._target_policy_network])
    self._variables = {
        'critic': self._target_critic_network.variables,
        'policy': policy_network_to_expose.variables,
    }

    # Create checkpointer and snapshotter objects.
    self._checkpointer = None
    self._snapshotter = None

    if checkpoint:
      self._checkpointer = tf2_savers.Checkpointer(
          subdirectory='d4pg_learner',
          objects_to_save={
              'counter': self._counter,
              'policy': self._policy_network,
              'critic': self._critic_network,
              'observation': self._observation_network,
              'target_policy': self._target_policy_network,
              'target_critic': self._target_critic_network,
              'target_observation': self._target_observation_network,
              'policy_optimizer': self._policy_optimizer,
              'critic_optimizer': self._critic_optimizer,
              'num_steps': self._num_steps,
          })
      critic_mean = snt.Sequential(
          [self._critic_network, acme_nets.StochasticMeanHead()])
      self._snapshotter = tf2_savers.Snapshotter(
          objects_to_save={
              'policy': self._policy_network,
              'critic': critic_mean,
          })

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online and
    # fill the replay buffer.
    self._timestamp = None
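
The _num_steps counter and _target_update_period stored above are what the learner's step function uses to decide when to copy online weights into the target networks. A sketch of that pattern, with names mirroring the constructor but logic that is an assumption rather than the actual implementation:

import tensorflow as tf

def maybe_update_targets(online_variables, target_variables,
                         num_steps, target_update_period):
  """Copies online variables into the targets every target_update_period steps."""
  if tf.math.mod(num_steps, target_update_period) == 0:
    for src, dest in zip(online_variables, target_variables):
      dest.assign(src)
  num_steps.assign_add(1)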
Example #9
  def __init__(self,
               environment_spec: specs.EnvironmentSpec,
               network: snt.RNNCore,
               target_network: snt.RNNCore,
               burn_in_length: int,
               trace_length: int,
               replay_period: int,
               demonstration_dataset: tf.data.Dataset,
               demonstration_ratio: float,
               counter: counting.Counter = None,
               logger: loggers.Logger = None,
               discount: float = 0.99,
               batch_size: int = 32,
               target_update_period: int = 100,
               importance_sampling_exponent: float = 0.2,
               epsilon: float = 0.01,
               learning_rate: float = 1e-3,
               log_to_bigtable: bool = False,
               log_name: str = 'agent',
               checkpoint: bool = True,
               min_replay_size: int = 1000,
               max_replay_size: int = 1000000,
               samples_per_insert: float = 32.0):

    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1))
    self._server = reverb.Server([replay_table], port=None)
    address = f'localhost:{self._server.port}'

    sequence_length = burn_in_length + trace_length + 1
    # Component to add things into replay.
    sequence_kwargs = dict(
        period=replay_period,
        sequence_length=sequence_length,
    )
    adder = adders.SequenceAdder(
        client=reverb.Client(address), **sequence_kwargs)

    # The dataset object to learn from.
    reverb_client = reverb.TFClient(address)
    extra_spec = {
        'core_state': network.initial_state(1),
    }
    # Remove batch dimensions.
    extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
    dataset = datasets.make_reverb_dataset(
        client=reverb_client,
        environment_spec=environment_spec,
        extra_spec=extra_spec,
        sequence_length=sequence_length)

    # Combine with demonstration dataset.
    transition = functools.partial(_sequence_from_episode,
                                   extra_spec=extra_spec,
                                   **sequence_kwargs)
    dataset_demos = demonstration_dataset.map(transition)
    dataset = tf.data.experimental.sample_from_datasets(
        [dataset, dataset_demos],
        [1 - demonstration_ratio, demonstration_ratio])

    # Batch and prefetch.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    tf2_utils.create_variables(network, [environment_spec.observations])
    tf2_utils.create_variables(target_network, [environment_spec.observations])

    learner = learning.R2D2Learner(
        environment_spec=environment_spec,
        network=network,
        target_network=target_network,
        burn_in_length=burn_in_length,
        dataset=dataset,
        reverb_client=reverb_client,
        counter=counter,
        logger=logger,
        sequence_length=sequence_length,
        discount=discount,
        target_update_period=target_update_period,
        importance_sampling_exponent=importance_sampling_exponent,
        max_replay_size=max_replay_size,
        learning_rate=learning_rate,
        store_lstm_state=False,
    )

    self._checkpointer = tf2_savers.Checkpointer(
        subdirectory='r2d2_learner',
        time_delta_minutes=60,
        objects_to_save=learner.state,
        enable_checkpointing=checkpoint,
    )

    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={'network': network}, time_delta_minutes=60.)

    policy_network = snt.DeepRNN([
        network,
        lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
    ])

    actor = actors.RecurrentActor(policy_network, adder)
    observations_per_step = (float(replay_period * batch_size) /
                             samples_per_insert)
    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=replay_period * max(batch_size, min_replay_size),
        observations_per_step=observations_per_step)
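
The distinctive step in this last constructor is the mixing of replay data with a demonstration dataset via sample_from_datasets, weighted by demonstration_ratio. A standalone sketch of that mixing step, using toy range datasets in place of the real replay and demonstration streams:

import tensorflow as tf

replay = tf.data.Dataset.range(100)       # stand-in for the replay stream
demos = tf.data.Dataset.range(100, 200)   # stand-in for demonstration episodes
demonstration_ratio = 0.25

mixed = tf.data.experimental.sample_from_datasets(
    [replay, demos], [1 - demonstration_ratio, demonstration_ratio])

# Roughly a quarter of the sampled elements come from the demonstration stream.
for element in mixed.take(8):
  print(element.numpy())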