def test_signal_handling(self):
  """A stop signal triggers a save, and a new runner restores that state."""
  saveable = DummySaveable()
  # Bump the variable so there is nontrivial state to checkpoint.
  saveable.state['state'].assign_add(1)
  checkpoint_dir = self.get_tempdir()

  # Make register_stop_handler invoke its handler immediately, so the
  # runner behaves as if a stop signal arrived as soon as it starts.
  with mock.patch.object(
      launchpad, 'register_stop_handler') as mock_register_stop_handler:

    def call_immediately(handler):
      handler()

    mock_register_stop_handler.side_effect = call_immediately

    runner = tf2_savers.CheckpointingRunner(
        wrapped=saveable, time_delta_minutes=0, directory=checkpoint_dir)
    # The simulated stop signal makes run() exit via SystemExit.
    with self.assertRaises(SystemExit):
      runner.run()

  # A fresh DummySaveable starts from 0 again...
  saveable = DummySaveable()
  # ...and constructing a runner over the same directory restores the 1.
  tf2_savers.CheckpointingRunner(
      wrapped=saveable, time_delta_minutes=0, directory=checkpoint_dir)
  np.testing.assert_array_equal(saveable.state['state'].numpy(), np.int32(1))
def learner(self, queue: reverb.Client, counter: counting.Counter):
  """The Learning part of the agent."""
  # Build the policy network and create its variables from the observation
  # spec so it is fully initialized before learning starts.
  net = self._network_factory(self._environment_spec.actions)
  tf2_utils.create_variables(net, [self._environment_spec.observations])

  # Stream training batches from the reverb queue.
  experience = datasets.make_reverb_dataset(
      server_address=queue.server_address,
      batch_size=self._batch_size,
      prefetch_size=self._prefetch_size)

  learner_logger = loggers.make_default_logger(
      'learner', steps_key='learner_steps')
  learner_counter = counting.Counter(counter, 'learner')

  impala_learner = learning.IMPALALearner(
      environment_spec=self._environment_spec,
      network=net,
      dataset=experience,
      discount=self._discount,
      learning_rate=self._learning_rate,
      entropy_cost=self._entropy_cost,
      baseline_cost=self._baseline_cost,
      max_abs_reward=self._max_abs_reward,
      max_gradient_norm=self._max_gradient_norm,
      counter=learner_counter,
      logger=learner_logger,
  )

  # Wrap the learner so it is checkpointed every 5 minutes while running.
  return tf2_savers.CheckpointingRunner(
      impala_learner, time_delta_minutes=5, subdirectory='impala_learner')
def test_checkpoint_dir(self):
  """Checkpoints land under <base>/<uid>/checkpoints/default."""
  base_dir = self.get_tempdir()
  runner = tf2_savers.CheckpointingRunner(
      wrapped=DummySaveable(), time_delta_minutes=0, directory=base_dir)
  # The middle path component is a generated id; match it with a pattern.
  pattern = re.compile(f'{base_dir}/[a-z0-9-]*/checkpoints/default')
  self.assertIsNotNone(pattern.fullmatch(runner.get_directory()))
def test_tf_saveable(self):
  """Restoring a saved checkpoint undoes later mutations of the variable."""
  saveable = DummySaveable()
  runner = tf2_savers.CheckpointingRunner(
      saveable, time_delta_minutes=0, directory=self.get_tempdir())
  # Snapshot the initial (zero) state, mutate it, then roll back.
  runner._checkpointer.save()
  saveable._state.assign_add(1)
  runner._checkpointer.restore()
  np.testing.assert_array_equal(saveable._state.numpy(), np.int32(0))
def learner(self, replay: reverb.Client, counter: counting.Counter):
  """The Learning part of the agent."""
  # Build the online network and a structurally identical target network,
  # creating variables for both from the observation spec.
  net = self._network_factory(self._environment_spec.actions)
  target_net = copy.deepcopy(net)
  tf2_utils.create_variables(net, [self._obs_spec])
  tf2_utils.create_variables(target_net, [self._obs_spec])

  # A TF-side reverb client (used by the learner, e.g. for replay access)
  # plus a dataset streaming batches from the same replay server.
  tf_client = reverb.TFClient(replay.server_address)
  sequence_length = self._burn_in_length + self._trace_length + 1
  experience = datasets.make_reverb_dataset(
      server_address=replay.server_address,
      batch_size=self._batch_size,
      prefetch_size=self._prefetch_size)

  learner_counter = counting.Counter(counter, 'learner')
  learner_logger = loggers.make_default_logger(
      'learner', save_data=True, steps_key='learner_steps')

  r2d2_learner = learning.R2D2Learner(
      environment_spec=self._environment_spec,
      network=net,
      target_network=target_net,
      burn_in_length=self._burn_in_length,
      sequence_length=sequence_length,
      dataset=experience,
      reverb_client=tf_client,
      counter=learner_counter,
      logger=learner_logger,
      discount=self._discount,
      target_update_period=self._target_update_period,
      importance_sampling_exponent=self._importance_sampling_exponent,
      learning_rate=self._learning_rate,
      max_replay_size=self._max_replay_size)

  # Checkpoint the learner once an hour while it runs.
  return tf2_savers.CheckpointingRunner(
      wrapped=r2d2_learner, time_delta_minutes=60, subdirectory='r2d2_learner')
def learner(self, replay: reverb.Client, counter: counting.Counter):
  """The Learning part of the agent."""
  # Online network plus a deep-copied target network, both initialized
  # from the environment's observation spec.
  net = self._network_factory(self._env_spec.actions)
  target_net = copy.deepcopy(net)
  tf2_utils.create_variables(net, [self._env_spec.observations])
  tf2_utils.create_variables(target_net, [self._env_spec.observations])

  # Client handle to the replay server and a dataset streaming from it.
  client = reverb.Client(replay.server_address)
  experience = datasets.make_reverb_dataset(
      server_address=replay.server_address,
      batch_size=self._batch_size,
      prefetch_size=self._prefetch_size)

  learner_counter = counting.Counter(counter, 'learner')
  learner_logger = loggers.make_default_logger(
      'learner', steps_key='learner_steps')

  dqn_learner = learning.DQNLearner(
      network=net,
      target_network=target_net,
      discount=self._discount,
      importance_sampling_exponent=self._importance_sampling_exponent,
      learning_rate=self._learning_rate,
      target_update_period=self._target_update_period,
      dataset=experience,
      replay_client=client,
      counter=learner_counter,
      logger=learner_logger)

  # Checkpoint hourly while the learner runs.
  return tf2_savers.CheckpointingRunner(
      dqn_learner, subdirectory='dqn_learner', time_delta_minutes=60)
def counter(self):
  """Creates the master counter process, checkpointed every minute."""
  master_counter = counting.Counter()
  return tf2_savers.CheckpointingRunner(
      master_counter, time_delta_minutes=1, subdirectory='counter')
def counter(self):
  """Creates the master counter process."""
  # Checkpoint the counter every minute under the 'counter' subdirectory.
  master_counter = counting.Counter()
  return tf2_savers.CheckpointingRunner(
      master_counter, time_delta_minutes=1, subdirectory='counter')