Example #1
    def test_snapshot_distribution(self):
        """Test that snapshotter correctly calls saves/restores snapshots."""
        # Create a test network.
        net1 = snt.Sequential([
            networks.LayerNormMLP([10, 10]),
            networks.MultivariateNormalDiagHead(1)
        ])
        spec = specs.Array([10], dtype=np.float32)
        tf2_utils.create_variables(net1, [spec])

        # Save the test network.
        directory = self.get_tempdir()
        objects_to_save = {'net': net1}
        snapshotter = tf2_savers.Snapshotter(objects_to_save,
                                             directory=directory)
        snapshotter.save()

        # Reload the test network.
        net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net'))
        inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

        with tf.GradientTape() as tape:
            dist1 = net1(inputs)
            loss1 = tf.math.reduce_sum(dist1.mean() + dist1.variance())
            grads1 = tape.gradient(loss1, net1.trainable_variables)

        with tf.GradientTape() as tape:
            dist2 = net2(inputs)
            loss2 = tf.math.reduce_sum(dist2.mean() + dist2.variance())
            grads2 = tape.gradient(loss2, net2.trainable_variables)

        assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))
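For orientation, the round trip this test exercises reduces to a few lines: build a Sonnet module, snapshot it, and reload it as a plain SavedModel. A minimal sketch, assuming the standard acme import paths; the module, spec shape, and directory are illustrative, not from the test:

import os

import numpy as np
import sonnet as snt
import tensorflow as tf

from acme import specs
from acme.tf import savers as tf2_savers
from acme.tf import utils as tf2_utils

# Build a module and create its variables from an input spec.
net = snt.Linear(4)
spec = specs.Array([8], dtype=np.float32)
tf2_utils.create_variables(net, [spec])

# Snapshot it; the snapshot is written under the given directory.
snapshotter = tf2_savers.Snapshotter({'net': net}, directory='/tmp/acme_demo')
snapshotter.save()

# Snapshots are ordinary SavedModels and reload without acme or Sonnet.
restored = tf.saved_model.load(os.path.join(snapshotter.directory, 'net'))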
Example #2
File: learning.py  Project: zzp110/acme
    def __init__(self,
                 network: snt.Module,
                 learning_rate: float,
                 dataset: tf.data.Dataset,
                 counter: counting.Counter = None,
                 logger: loggers.Logger = None):
        """Initializes the learner.

    Args:
      network: the online Q network (the one being optimized)
      learning_rate: learning rate for the q-network update.
      dataset: dataset to learn from.
      counter: Counter object for (potentially distributed) counting.
      logger: Logger object for writing logs to.
    """

        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        # Get an iterator over the dataset.
        # TODO(b/155086959): Fix type stubs and remove.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        self._network = network
        self._optimizer = snt.optimizers.Adam(learning_rate)

        self._variables: List[List[tf.Tensor]] = [network.trainable_variables]
        self._num_steps = tf.Variable(0, dtype=tf.int32)

        self._snapshotter = tf2_savers.Snapshotter(
            objects_to_save={'network': network}, time_delta_minutes=60.)
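The snapshotter constructed here is normally driven from the learner's step() method; because time_delta_minutes=60., most calls return immediately and a snapshot is written at most once per hour. A sketch of that driving pattern (method body abridged; self._step stands in for the actual TF update and is not part of this snippet's source):

    def step(self):
        # Run one learner update on a sampled batch (details elided).
        result = self._step()

        # Update the counter and write the results out.
        counts = self._counter.increment(steps=1)
        result.update(counts)
        self._logger.write(result)

        # Ask the snapshotter to save; it rate-limits itself internally.
        self._snapshotter.save()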
Example #3
File: learning.py  Project: zzp110/acme
  def __init__(
      self,
      environment_spec: specs.EnvironmentSpec,
      network: snt.RNNCore,
      dataset: tf.data.Dataset,
      learning_rate: float,
      discount: float = 0.99,
      entropy_cost: float = 0.,
      baseline_cost: float = 1.,
      max_abs_reward: Optional[float] = None,
      max_gradient_norm: Optional[float] = None,
      counter: counting.Counter = None,
      logger: loggers.Logger = None,
  ):

    # Internalise the spec, network, optimizer, and dataset.
    self._env_spec = environment_spec
    self._optimizer = snt.optimizers.Adam(learning_rate=learning_rate)
    self._network = network
    self._variables = network.variables
    # TODO(b/155086959): Fix type stubs and remove.
    self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

    # Hyperparameters.
    self._discount = discount
    self._entropy_cost = entropy_cost
    self._baseline_cost = baseline_cost

    # Set up reward/gradient clipping.
    if max_abs_reward is None:
      max_abs_reward = np.inf
    if max_gradient_norm is None:
      max_gradient_norm = 1e10  # A very large number. Infinity results in NaNs.
    self._max_abs_reward = tf.convert_to_tensor(max_abs_reward)
    self._max_gradient_norm = tf.convert_to_tensor(max_gradient_norm)

    # Set up logging/counting.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={'network': network}, time_delta_minutes=60.)

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online and
    # fill the replay buffer.
    self._timestamp = None
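The two clipping tensors set up above are consumed later in the update: rewards are clipped elementwise and gradients by global norm. A standalone sketch of both operations with toy values (not the IMPALA loss itself):

import tensorflow as tf

max_abs_reward = tf.convert_to_tensor(1.0)
max_gradient_norm = tf.convert_to_tensor(40.0)

# Elementwise reward clipping into [-max_abs_reward, max_abs_reward].
rewards = tf.constant([-3.0, 0.5, 7.0])
clipped_rewards = tf.clip_by_value(rewards, -max_abs_reward, max_abs_reward)

# Global-norm gradient clipping before the optimizer step.
w = tf.Variable([2.0, -2.0])
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(w * w)
grads = tape.gradient(loss, [w])
clipped_grads, _ = tf.clip_by_global_norm(grads, max_gradient_norm)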
Example #4
    def test_force_snapshot(self):
        """Test that the force feature in Snapshotter.save() works correctly."""
        # Create a test network.
        net = snt.Linear(10)
        spec = specs.Array([10], dtype=np.float32)
        tf2_utils.create_variables(net, [spec])

        # Save the test network.
        directory = self.get_tempdir()
        objects_to_save = {'net': net}
        # Very long time_delta_minutes.
        snapshotter = tf2_savers.Snapshotter(objects_to_save,
                                             directory=directory,
                                             time_delta_minutes=1000)
        self.assertTrue(snapshotter.save(force=False))

        # Due to the long time_delta_minutes, only force=True will create a new
        # snapshot. This also checks the default is force=False.
        self.assertFalse(snapshotter.save())
        self.assertTrue(snapshotter.save(force=True))
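One practical use of force=True is flushing a final snapshot at shutdown while leaving periodic saves rate-limited. A sketch of that pattern; num_training_steps and train_one_step are hypothetical placeholders, not from this test:

num_training_steps = 10_000       # hypothetical

for _ in range(num_training_steps):
    train_one_step()              # hypothetical update function
    snapshotter.save()            # no-op until time_delta_minutes elapses

snapshotter.save(force=True)      # always writes one final snapshot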
Example #5
    def test_rnn_snapshot(self):
        """Test that snapshotter correctly calls saves/restores snapshots on RNNs."""
        # Create a test network.
        net = snt.LSTM(10)
        spec = specs.Array([10], dtype=np.float32)
        tf2_utils.create_variables(net, [spec])

        # Test that if you add some postprocessing without rerunning
        # create_variables, it still works.
        wrapped_net = snt.DeepRNN([net, lambda x: x])

        for net1 in [net, wrapped_net]:
            # Save the test network.
            directory = self.get_tempdir()
            objects_to_save = {'net': net1}
            snapshotter = tf2_savers.Snapshotter(objects_to_save,
                                                 directory=directory)
            snapshotter.save()

            # Reload the test network.
            net2 = tf.saved_model.load(
                os.path.join(snapshotter.directory, 'net'))
            inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

            with tf.GradientTape() as tape:
                outputs1, next_state1 = net1(inputs, net1.initial_state(1))
                loss1 = tf.math.reduce_sum(outputs1)
                grads1 = tape.gradient(loss1, net1.trainable_variables)

            with tf.GradientTape() as tape:
                outputs2, next_state2 = net2(inputs, net2.initial_state(1))
                loss2 = tf.math.reduce_sum(outputs2)
                grads2 = tape.gradient(loss2, net2.trainable_variables)

            assert np.allclose(outputs1, outputs2)
            assert np.allclose(tree.flatten(next_state1),
                               tree.flatten(next_state2))
            assert all(
                tree.map_structure(np.allclose, list(grads1), list(grads2)))
Example #6
File: learning.py  Project: vballoli/acme
    def __init__(
        self,
        policy_network: snt.Module,
        critic_network: snt.Module,
        target_policy_network: snt.Module,
        target_critic_network: snt.Module,
        discount: float,
        num_samples: int,
        target_policy_update_period: int,
        target_critic_update_period: int,
        dataset: tf.data.Dataset,
        observation_network: types.TensorTransformation = tf.identity,
        target_observation_network: types.TensorTransformation = tf.identity,
        policy_loss_module: Optional[snt.Module] = None,
        policy_optimizer: Optional[snt.Optimizer] = None,
        critic_optimizer: Optional[snt.Optimizer] = None,
        dual_optimizer: Optional[snt.Optimizer] = None,
        clipping: bool = True,
        counter: Optional[counting.Counter] = None,
        logger: Optional[loggers.Logger] = None,
        checkpoint: bool = True,
    ):

        # Store online and target networks.
        self._policy_network = policy_network
        self._critic_network = critic_network
        self._target_policy_network = target_policy_network
        self._target_critic_network = target_critic_network

        # Make sure observation networks are snt.Module instances so they have variables.
        self._observation_network = tf2_utils.to_sonnet_module(
            observation_network)
        self._target_observation_network = tf2_utils.to_sonnet_module(
            target_observation_network)

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger('learner')

        # Other learner parameters.
        self._discount = discount
        self._num_samples = num_samples
        self._clipping = clipping

        # Necessary to track when to update target networks.
        self._num_steps = tf.Variable(0, dtype=tf.int32)
        self._target_policy_update_period = target_policy_update_period
        self._target_critic_update_period = target_critic_update_period

        # Batch dataset and create iterator.
        # TODO(b/155086959): Fix type stubs and remove.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        self._policy_loss_module = policy_loss_module or losses.MPO(
            epsilon=1e-1,
            epsilon_penalty=1e-3,
            epsilon_mean=1e-3,
            epsilon_stddev=1e-6,
            init_log_temperature=1.,
            init_log_alpha_mean=1.,
            init_log_alpha_stddev=10.)

        # Create the optimizers.
        self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
        self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
        self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2)

        # Expose the variables.
        policy_network_to_expose = snt.Sequential(
            [self._target_observation_network, self._target_policy_network])
        self._variables = {
            'critic': self._target_critic_network.variables,
            'policy': policy_network_to_expose.variables,
        }

        # Create checkpointer and snapshotter objects.
        self._checkpointer = None
        self._snapshotter = None

        if checkpoint:
            self._checkpointer = tf2_savers.Checkpointer(
                subdirectory='dmpo_learner',
                objects_to_save={
                    'counter': self._counter,
                    'policy': self._policy_network,
                    'critic': self._critic_network,
                    'observation': self._observation_network,
                    'target_policy': self._target_policy_network,
                    'target_critic': self._target_critic_network,
                    'target_observation': self._target_observation_network,
                    'policy_optimizer': self._policy_optimizer,
                    'critic_optimizer': self._critic_optimizer,
                    'dual_optimizer': self._dual_optimizer,
                    'policy_loss_module': self._policy_loss_module,
                    'num_steps': self._num_steps,
                })

            self._snapshotter = tf2_savers.Snapshotter(
                objects_to_save={
                    'policy':
                    snt.Sequential([
                        self._target_observation_network,
                        self._target_policy_network
                    ]),
                })

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None
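The _num_steps counter and the two update periods above drive periodic target-network syncs during learning. A simplified, self-contained sketch of that mechanism with a single online/target pair (not the full MPO update):

import sonnet as snt
import tensorflow as tf

online = snt.Linear(4)
target = snt.Linear(4)
online(tf.zeros([1, 3]))   # build variables
target(tf.zeros([1, 3]))

num_steps = tf.Variable(0, dtype=tf.int32)
target_update_period = 100

def maybe_sync_target():
    # Copy online weights into the target every `target_update_period` steps.
    if tf.math.mod(num_steps, target_update_period) == 0:
        for src, dest in zip(online.variables, target.variables):
            dest.assign(src)
    num_steps.assign_add(1)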
Example #7
File: agent.py  Project: yifeijiang/acme
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        burn_in_length: int,
        trace_length: int,
        replay_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        batch_size: int = 32,
        prefetch_size: int = tf.data.experimental.AUTOTUNE,
        target_update_period: int = 100,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        epsilon: float = 0.01,
        learning_rate: float = 1e-3,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
        store_lstm_state: bool = True,
        max_priority_weight: float = 0.9,
        checkpoint: bool = True,
    ):

        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1))
        self._server = reverb.Server([replay_table], port=None)
        address = f'localhost:{self._server.port}'

        sequence_length = burn_in_length + trace_length + 1
        # Component to add things into replay.
        adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=replay_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        reverb_client = reverb.TFClient(address)
        extra_spec = {
            'core_state': network.initial_state(1),
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
        dataset = datasets.make_reverb_dataset(
            client=reverb_client,
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            extra_spec=extra_spec,
            sequence_length=sequence_length)

        target_network = copy.deepcopy(network)
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        learner = learning.R2D2Learner(
            environment_spec=environment_spec,
            network=network,
            target_network=target_network,
            burn_in_length=burn_in_length,
            sequence_length=sequence_length,
            dataset=dataset,
            reverb_client=reverb_client,
            counter=counter,
            logger=logger,
            discount=discount,
            target_update_period=target_update_period,
            importance_sampling_exponent=importance_sampling_exponent,
            max_replay_size=max_replay_size,
            learning_rate=learning_rate,
            store_lstm_state=store_lstm_state,
            max_priority_weight=max_priority_weight,
        )

        self._checkpointer = tf2_savers.Checkpointer(
            subdirectory='r2d2_learner',
            time_delta_minutes=60,
            objects_to_save=learner.state,
            enable_checkpointing=checkpoint,
        )
        self._snapshotter = tf2_savers.Snapshotter(
            objects_to_save={'network': network}, time_delta_minutes=60.)

        policy_network = snt.DeepRNN([
            network,
            lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
        ])

        actor = actors.RecurrentActor(policy_network, adder)
        observations_per_step = (float(replay_period * batch_size) /
                                 samples_per_insert)
        super().__init__(
            actor=actor,
            learner=learner,
            min_observations=replay_period * max(batch_size, min_replay_size),
            observations_per_step=observations_per_step)
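The policy_network above appends an epsilon-greedy sampling head to the Q-network via trfl. For intuition, here is an approximate pure-TensorFlow equivalent of what trfl.epsilon_greedy(qs, epsilon).sample() produces; this is a sketch that ignores legal-action masks:

import tensorflow as tf

def epsilon_greedy_sample(q_values: tf.Tensor, epsilon: float) -> tf.Tensor:
    """Greedy action w.p. 1 - epsilon, uniform random action w.p. epsilon."""
    batch_size = tf.shape(q_values)[0]
    num_actions = tf.shape(q_values)[1]
    greedy = tf.argmax(q_values, axis=-1, output_type=tf.int32)
    uniform = tf.random.uniform([batch_size], maxval=num_actions,
                                dtype=tf.int32)
    explore = tf.random.uniform([batch_size]) < epsilon
    return tf.where(explore, uniform, greedy)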
Example #8
  def __init__(
      self,
      policy_network: snt.Module,
      critic_network: snt.Module,
      target_policy_network: snt.Module,
      target_critic_network: snt.Module,
      discount: float,
      target_update_period: int,
      dataset: tf.data.Dataset,
      observation_network: types.TensorTransformation = lambda x: x,
      target_observation_network: types.TensorTransformation = lambda x: x,
      policy_optimizer: snt.Optimizer = None,
      critic_optimizer: snt.Optimizer = None,
      clipping: bool = True,
      counter: counting.Counter = None,
      logger: loggers.Logger = None,
      checkpoint: bool = True,
  ):
    """Initializes the learner.

    Args:
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      target_policy_network: the target policy (which lags behind the online
        policy).
      target_critic_network: the target critic.
      discount: discount to use for TD updates.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      dataset: dataset to learn from, whether fixed or from a replay buffer
        (see `acme.datasets.reverb.make_dataset` documentation).
      observation_network: an optional online network to process observations
        before the policy and the critic.
      target_observation_network: the target observation network.
      policy_optimizer: the optimizer to be applied to the DPG (policy) loss.
      critic_optimizer: the optimizer to be applied to the distributional
        Bellman loss.
      clipping: whether to clip gradients by global norm.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      checkpoint: boolean indicating whether to checkpoint the learner.
    """

    # Store online and target networks.
    self._policy_network = policy_network
    self._critic_network = critic_network
    self._target_policy_network = target_policy_network
    self._target_critic_network = target_critic_network

    # Make sure observation networks are snt.Module instances so they have variables.
    self._observation_network = tf2_utils.to_sonnet_module(observation_network)
    self._target_observation_network = tf2_utils.to_sonnet_module(
        target_observation_network)

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger('learner')

    # Other learner parameters.
    self._discount = discount
    self._clipping = clipping

    # Necessary to track when to update target networks.
    self._num_steps = tf.Variable(0, dtype=tf.int32)
    self._target_update_period = target_update_period

    # Batch dataset and create iterator.
    # TODO(b/155086959): Fix type stubs and remove.
    self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

    # Create optimizers if they aren't given.
    self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
    self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)

    # Expose the variables.
    policy_network_to_expose = snt.Sequential(
        [self._target_observation_network, self._target_policy_network])
    self._variables = {
        'critic': self._target_critic_network.variables,
        'policy': policy_network_to_expose.variables,
    }

    # Create checkpointer and snapshotter objects.
    self._checkpointer = None
    self._snapshotter = None

    if checkpoint:
      self._checkpointer = tf2_savers.Checkpointer(
          subdirectory='d4pg_learner',
          objects_to_save={
              'counter': self._counter,
              'policy': self._policy_network,
              'critic': self._critic_network,
              'observation': self._observation_network,
              'target_policy': self._target_policy_network,
              'target_critic': self._target_critic_network,
              'target_observation': self._target_observation_network,
              'policy_optimizer': self._policy_optimizer,
              'critic_optimizer': self._critic_optimizer,
              'num_steps': self._num_steps,
          })
      critic_mean = snt.Sequential(
          [self._critic_network, acme_nets.StochasticMeanHead()])
      self._snapshotter = tf2_savers.Snapshotter(
          objects_to_save={
              'policy': self._policy_network,
              'critic': critic_mean,
          })

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online and
    # fill the replay buffer.
    self._timestamp = None
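The _variables dict exposed above is what remote actors poll to stay in sync with the learner. In acme this is served through the learner's get_variables method, roughly along these lines (a sketch, not necessarily the verbatim implementation; tf2_utils.to_numpy converts nested tensors to numpy arrays):

  def get_variables(self, names):
    # Serve the exposed variable collections to actors, keyed by name.
    return [tf2_utils.to_numpy(self._variables[name]) for name in names]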
Example #9
  def __init__(self,
               environment_spec: specs.EnvironmentSpec,
               network: snt.RNNCore,
               target_network: snt.RNNCore,
               burn_in_length: int,
               trace_length: int,
               replay_period: int,
               demonstration_dataset: tf.data.Dataset,
               demonstration_ratio: float,
               counter: counting.Counter = None,
               logger: loggers.Logger = None,
               discount: float = 0.99,
               batch_size: int = 32,
               target_update_period: int = 100,
               importance_sampling_exponent: float = 0.2,
               epsilon: float = 0.01,
               learning_rate: float = 1e-3,
               log_to_bigtable: bool = False,
               log_name: str = 'agent',
               checkpoint: bool = True,
               min_replay_size: int = 1000,
               max_replay_size: int = 1000000,
               samples_per_insert: float = 32.0):

    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1))
    self._server = reverb.Server([replay_table], port=None)
    address = f'localhost:{self._server.port}'

    sequence_length = burn_in_length + trace_length + 1
    # Component to add things into replay.
    sequence_kwargs = dict(
        period=replay_period,
        sequence_length=sequence_length,
    )
    adder = adders.SequenceAdder(
        client=reverb.Client(address), **sequence_kwargs)

    # The dataset object to learn from.
    reverb_client = reverb.TFClient(address)
    extra_spec = {
        'core_state': network.initial_state(1),
    }
    # Remove batch dimensions.
    extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
    dataset = datasets.make_reverb_dataset(
        client=reverb_client,
        environment_spec=environment_spec,
        extra_spec=extra_spec,
        sequence_length=sequence_length)

    # Combine with demonstration dataset.
    transition = functools.partial(_sequence_from_episode,
                                   extra_spec=extra_spec,
                                   **sequence_kwargs)
    dataset_demos = demonstration_dataset.map(transition)
    dataset = tf.data.experimental.sample_from_datasets(
        [dataset, dataset_demos],
        [1 - demonstration_ratio, demonstration_ratio])

    # Batch and prefetch.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    tf2_utils.create_variables(network, [environment_spec.observations])
    tf2_utils.create_variables(target_network, [environment_spec.observations])

    learner = learning.R2D2Learner(
        environment_spec=environment_spec,
        network=network,
        target_network=target_network,
        burn_in_length=burn_in_length,
        dataset=dataset,
        reverb_client=reverb_client,
        counter=counter,
        logger=logger,
        sequence_length=sequence_length,
        discount=discount,
        target_update_period=target_update_period,
        importance_sampling_exponent=importance_sampling_exponent,
        max_replay_size=max_replay_size,
        learning_rate=learning_rate,
        store_lstm_state=False,
    )

    self._checkpointer = tf2_savers.Checkpointer(
        subdirectory='r2d2_learner',
        time_delta_minutes=60,
        objects_to_save=learner.state,
        enable_checkpointing=checkpoint,
    )

    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={'network': network}, time_delta_minutes=60.)

    policy_network = snt.DeepRNN([
        network,
        lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
    ])

    actor = actors.RecurrentActor(policy_network, adder)
    observations_per_step = (float(replay_period * batch_size) /
                             samples_per_insert)
    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=replay_period * max(batch_size, min_replay_size),
        observations_per_step=observations_per_step)
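The demonstration mixing above relies on tf.data's weighted sampling between datasets. A toy, standalone illustration of the same call, with integer elements standing in for replay and demonstration items:

import tensorflow as tf

replay = tf.data.Dataset.from_tensor_slices([0]).repeat()
demos = tf.data.Dataset.from_tensor_slices([1]).repeat()

demonstration_ratio = 0.25
mixed = tf.data.experimental.sample_from_datasets(
    [replay, demos], weights=[1 - demonstration_ratio, demonstration_ratio])

# Roughly a quarter of the sampled elements come from `demos`.
print([int(x) for x in mixed.take(12)])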
Example #10
  def __init__(
      self,
      network: snt.Module,
      target_network: snt.Module,
      discount: float,
      importance_sampling_exponent: float,
      learning_rate: float,
      target_update_period: int,
      dataset: tf.data.Dataset,
      huber_loss_parameter: float = 1.,
      replay_client: reverb.TFClient = None,
      counter: counting.Counter = None,
      logger: loggers.Logger = None,
  ):
    """Initializes the learner.

    Args:
      network: the online Q network (the one being optimized).
      target_network: the target Q critic (which lags behind the online net).
      discount: discount to use for TD updates.
      importance_sampling_exponent: power to which importance weights are raised
        before normalizing.
      learning_rate: learning rate for the q-network update.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      dataset: dataset to learn from, whether fixed or from a replay buffer (see
        `acme.datasets.reverb.make_dataset` documentation).
      huber_loss_parameter: Quadratic-linear boundary for Huber loss.
      replay_client: client to replay to allow for updating priorities.
      counter: Counter object for (potentially distributed) counting.
      logger: Logger object for writing logs to.
    """

    # Internalise agent components (replay buffer, networks, optimizer).
    # TODO(b/155086959): Fix type stubs and remove.
    self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types
    self._network = network
    self._target_network = target_network
    self._optimizer = snt.optimizers.Adam(learning_rate)
    self._replay_client = replay_client

    # Internalise the hyperparameters.
    self._discount = discount
    self._target_update_period = target_update_period
    self._importance_sampling_exponent = importance_sampling_exponent
    self._huber_loss_parameter = huber_loss_parameter

    # Learner state.
    self._variables: List[List[tf.Tensor]] = [network.trainable_variables]
    self._num_steps = tf.Variable(0, dtype=tf.int32)

    # Internalise logging/counting objects.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={'network': network}, time_delta_minutes=60.)

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online and
    # fill the replay buffer.
    self._timestamp = None
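For context, importance_sampling_exponent is the standard prioritized-replay correction: transitions sampled with high probability are down-weighted in the loss. A self-contained sketch of that correction with toy values (not the exact acme code):

import tensorflow as tf

def importance_weights(probs: tf.Tensor, exponent: float) -> tf.Tensor:
    """Normalized importance weights from sampling probabilities."""
    weights = tf.math.pow(probs, -exponent)   # (1 / p_i) ** exponent
    return weights / tf.reduce_max(weights)   # scale so the max weight is 1

probs = tf.constant([0.5, 0.25, 0.125, 0.125])
td_errors = tf.constant([1.0, 2.0, 3.0, 4.0])
loss = tf.reduce_mean(importance_weights(probs, 0.2) * tf.square(td_errors))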