Code Example #1
File: savers_test.py  Project: vishalbelsare/acme
    def test_signal_handling(self):
        x = DummySaveable()

        # Increment the value of DummySaveable.
        x.state['state'].assign_add(1)

        directory = self.get_tempdir()

        # Patch launchpad.register_stop_handler so that the registered stop handler is invoked immediately.
        with mock.patch.object(
                launchpad,
                'register_stop_handler') as mock_register_stop_handler:

            def add_handler(fn):
                fn()

            mock_register_stop_handler.side_effect = add_handler

            runner = tf2_savers.CheckpointingRunner(wrapped=x,
                                                    time_delta_minutes=0,
                                                    directory=directory)
            with self.assertRaises(SystemExit):
                runner.run()

        # Recreate DummySaveable(); its tf.Variable is initialized to 0.
        x = DummySaveable()
        # Recreate the CheckpointingRunner, which will restore DummySaveable() to 1.
        tf2_savers.CheckpointingRunner(wrapped=x,
                                       time_delta_minutes=0,
                                       directory=directory)
        # Check that DummySaveable's variable was restored properly.
        np.testing.assert_array_equal(x.state['state'].numpy(), np.int32(1))
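
The tests in these snippets construct a DummySaveable whose definition is not shown. Below is a minimal sketch of what such a helper might look like, assuming it implements the TFSaveable interface from acme.tf.savers (a state property exposing the objects to checkpoint); the class body is an assumption reconstructed from the x.state['state'] and x._state accesses in the tests, not code copied from the project:

import tensorflow as tf

from acme.tf import savers as tf2_savers


class DummySaveable(tf2_savers.TFSaveable):
  """Holds a single tf.Variable so a CheckpointingRunner has state to save."""

  def __init__(self):
    # int32 matches the np.int32 comparisons used in the tests above.
    self._state = tf.Variable(0, dtype=tf.int32)

  @property
  def state(self):
    # Everything returned here is handed to the underlying checkpoint object
    # and saved/restored as a unit.
    return {'state': self._state}
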
Code Example #2
  def learner(self, queue: reverb.Client, counter: counting.Counter):
    """The Learning part of the agent."""
    # Create the networks.
    network = self._network_factory(self._environment_spec.actions)
    tf2_utils.create_variables(network, [self._environment_spec.observations])

    # The dataset object to learn from.
    dataset = datasets.make_reverb_dataset(
        server_address=queue.server_address,
        batch_size=self._batch_size,
        prefetch_size=self._prefetch_size)

    logger = loggers.make_default_logger('learner', steps_key='learner_steps')
    counter = counting.Counter(counter, 'learner')

    # Return the learning agent.
    learner = learning.IMPALALearner(
        environment_spec=self._environment_spec,
        network=network,
        dataset=dataset,
        discount=self._discount,
        learning_rate=self._learning_rate,
        entropy_cost=self._entropy_cost,
        baseline_cost=self._baseline_cost,
        max_abs_reward=self._max_abs_reward,
        max_gradient_norm=self._max_gradient_norm,
        counter=counter,
        logger=logger,
    )

    return tf2_savers.CheckpointingRunner(learner,
                                          time_delta_minutes=5,
                                          subdirectory='impala_learner')
Code Example #3
File: savers_test.py  Project: vishalbelsare/acme
  def test_checkpoint_dir(self):
    directory = self.get_tempdir()
    ckpt_runner = tf2_savers.CheckpointingRunner(
        wrapped=DummySaveable(), time_delta_minutes=0, directory=directory)
    expected_dir_re = f'{directory}/[a-z0-9-]*/checkpoints/default'
    regexp = re.compile(expected_dir_re)
    self.assertIsNotNone(regexp.fullmatch(ckpt_runner.get_directory()))
Code Example #4
File: savers_test.py  Project: zerocurve/acme
  def test_tf_saveable(self):
    x = DummySaveable()

    directory = self.get_tempdir()
    checkpoint_runner = tf2_savers.CheckpointingRunner(
        x, time_delta_minutes=0, directory=directory)
    checkpoint_runner._checkpointer.save()

    x._state.assign_add(1)
    checkpoint_runner._checkpointer.restore()

    np.testing.assert_array_equal(x._state.numpy(), np.int32(0))
Code Example #5
File: agent_distributed.py  Project: deepmind/acme
  def learner(self, replay: reverb.Client, counter: counting.Counter):
    """The Learning part of the agent."""
    # Create the networks.
    network = self._network_factory(self._environment_spec.actions)
    target_network = copy.deepcopy(network)

    tf2_utils.create_variables(network, [self._obs_spec])
    tf2_utils.create_variables(target_network, [self._obs_spec])

    # The dataset object to learn from.
    reverb_client = reverb.TFClient(replay.server_address)
    sequence_length = self._burn_in_length + self._trace_length + 1
    dataset = datasets.make_reverb_dataset(
        server_address=replay.server_address,
        batch_size=self._batch_size,
        prefetch_size=self._prefetch_size)

    counter = counting.Counter(counter, 'learner')
    logger = loggers.make_default_logger(
        'learner', save_data=True, steps_key='learner_steps')
    # Return the learning agent.
    learner = learning.R2D2Learner(
        environment_spec=self._environment_spec,
        network=network,
        target_network=target_network,
        burn_in_length=self._burn_in_length,
        sequence_length=sequence_length,
        dataset=dataset,
        reverb_client=reverb_client,
        counter=counter,
        logger=logger,
        discount=self._discount,
        target_update_period=self._target_update_period,
        importance_sampling_exponent=self._importance_sampling_exponent,
        learning_rate=self._learning_rate,
        max_replay_size=self._max_replay_size)
    return tf2_savers.CheckpointingRunner(
        wrapped=learner, time_delta_minutes=60, subdirectory='r2d2_learner')
Code Example #6
    def learner(self, replay: reverb.Client, counter: counting.Counter):
        """The Learning part of the agent."""

        # Create the networks.
        network = self._network_factory(self._env_spec.actions)
        target_network = copy.deepcopy(network)

        tf2_utils.create_variables(network, [self._env_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [self._env_spec.observations])

        # The dataset object to learn from.
        replay_client = reverb.Client(replay.server_address)
        dataset = datasets.make_reverb_dataset(
            server_address=replay.server_address,
            batch_size=self._batch_size,
            prefetch_size=self._prefetch_size)

        logger = loggers.make_default_logger('learner',
                                             steps_key='learner_steps')

        # Return the learning agent.
        counter = counting.Counter(counter, 'learner')

        learner = learning.DQNLearner(
            network=network,
            target_network=target_network,
            discount=self._discount,
            importance_sampling_exponent=self._importance_sampling_exponent,
            learning_rate=self._learning_rate,
            target_update_period=self._target_update_period,
            dataset=dataset,
            replay_client=replay_client,
            counter=counter,
            logger=logger)
        return tf2_savers.CheckpointingRunner(learner,
                                              subdirectory='dqn_learner',
                                              time_delta_minutes=60)
Code Example #7
  def counter(self):
    return tf2_savers.CheckpointingRunner(counting.Counter(),
                                          time_delta_minutes=1,
                                          subdirectory='counter')
Code Example #8
File: agent_distributed.py  Project: deepmind/acme
  def counter(self):
    """Creates the master counter process."""
   return tf2_savers.CheckpointingRunner(
       counting.Counter(), time_delta_minutes=1, subdirectory='counter')
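
Taken together, the snippets show the usual lifecycle: wrap an object exposing checkpointable state, save that state (explicitly in the tests, periodically or on shutdown in the distributed agents), and construct a runner over the same directory to restore it. The condensed sketch below reuses the hypothetical DummySaveable sketched after Code Example #1 and relies only on calls demonstrated above, including the private _checkpointer handle from Code Example #4; it is an illustration, not code from the project:

import tempfile

import numpy as np

from acme.tf import savers as tf2_savers

directory = tempfile.mkdtemp()

# Write a checkpoint whose variable holds 1.
saveable = DummySaveable()
saveable.state['state'].assign_add(1)
runner = tf2_savers.CheckpointingRunner(
    wrapped=saveable, time_delta_minutes=0, directory=directory)
runner._checkpointer.save()  # Explicit save, as in Code Example #4.

# A fresh object starts at 0; constructing a runner over the same directory
# restores it to 1, mirroring the end of Code Example #1.
restored = DummySaveable()
tf2_savers.CheckpointingRunner(
    wrapped=restored, time_delta_minutes=0, directory=directory)
np.testing.assert_array_equal(restored.state['state'].numpy(), np.int32(1))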