def export_model(self, name_behavior_id: str) -> None:
    """
    Export the model belonging to the policy registered under the given
    behavior name.

    :param name_behavior_id: Behavior name used to look up the policy.
    """
    policy = self.get_policy(name_behavior_id)
    # Serialization is driven by the policy's own model path and brain name.
    export_policy_model(
        SerializationSettings(policy.model_path, policy.brain.brain_name),
        policy.graph,
        policy.sess,
    )
def test_load_save(tmp_path):
    """Checkpoint a policy, reload it, and init a new policy from it."""
    first_run_dir = os.path.join(tmp_path, "runid1")
    second_run_dir = os.path.join(tmp_path, "runid2")
    settings = TrainerSettings()

    # Train (mock) a policy up to step 2000 and checkpoint it.
    original = create_policy_mock(settings, model_path=first_run_dir)
    original.initialize_or_load()
    original._set_step(2000)
    mock_brain_name = "MockBrain"
    checkpoint_path = f"{original.model_path}/{mock_brain_name}-2000"
    original.checkpoint(
        checkpoint_path, SerializationSettings(original.model_path, mock_brain_name)
    )
    assert os.listdir(tmp_path)

    # Try load from this path
    reloaded = create_policy_mock(settings, model_path=first_run_dir, load=True, seed=1)
    reloaded.initialize_or_load()
    _compare_two_policies(original, reloaded)
    assert reloaded.get_current_step() == 2000

    # Try initialize from path 1
    settings.output_path = second_run_dir
    settings.init_path = first_run_dir
    initialized = create_policy_mock(
        settings, model_path=first_run_dir, load=False, seed=2
    )
    initialized.initialize_or_load()
    _compare_two_policies(reloaded, initialized)
    # Assert that the steps are 0.
    assert initialized.get_current_step() == 0
def _checkpoint(self) -> NNCheckpoint:
    """
    Checkpoints the policy associated with this trainer and registers the
    new checkpoint with the NNCheckpointManager.

    :return: Metadata (step, file path, mean reward, timestamp) describing
        the checkpoint that was written.
    """
    # len() works directly on the mapping; .keys() was redundant.
    if len(self.policies) > 1:
        logger.warning(
            "Trainer has multiple policies, but default behavior only saves the first."
        )
    # Only the first policy is checkpointed (see warning above);
    # next(iter(...)) avoids materializing the whole values list.
    policy = next(iter(self.policies.values()))
    model_path = policy.model_path
    settings = SerializationSettings(model_path, self.brain_name)
    checkpoint_path = os.path.join(model_path, f"{self.brain_name}-{self.step}")
    policy.checkpoint(checkpoint_path, settings)
    new_checkpoint = NNCheckpoint(
        int(self.step),
        f"{checkpoint_path}.nn",
        self._policy_mean_reward(),
        time.time(),
    )
    NNCheckpointManager.add_checkpoint(
        self.brain_name, new_checkpoint, self.trainer_settings.keep_checkpoints
    )
    return new_checkpoint
def test_policy_conversion(dummy_config, tmpdir, rnn, visual, discrete):
    """Export a saved policy and sanity-check the resulting .nn file."""
    tf.reset_default_graph()
    dummy_config["output_path"] = os.path.join(tmpdir, "test")
    policy = create_policy_mock(
        dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    policy.save_model(1000)
    # NOTE(review): other call sites pass a bare brain name as the second
    # argument; here a tmpdir-joined path is used — confirm this is intentional.
    settings = SerializationSettings(
        policy.model_path, os.path.join(tmpdir, policy.brain.brain_name)
    )
    export_policy_model(settings, policy.graph, policy.sess)
    # These checks taken from test_barracuda_converter
    exported_file = os.path.join(tmpdir, "test.nn")
    assert os.path.isfile(exported_file)
    assert os.path.getsize(exported_file) > 100
def test_checkpoint_writes_tf_and_nn_checkpoints(export_policy_model_mock):
    """Checkpointing saves a TF .ckpt and triggers a .nn export."""
    brain = basic_mock_brain()
    seed = 4  # moving up in the world
    step_count = 5
    policy = FakePolicy(seed, brain, TrainerSettings(), "output")
    policy.get_current_step = MagicMock(return_value=step_count)
    policy.saver = MagicMock()

    ckpt_path = f"output/{brain.brain_name}-{step_count}"
    policy.checkpoint(ckpt_path, SerializationSettings("output", brain.brain_name))

    # TF checkpoint written via the saver, then the model exported to .nn.
    policy.saver.save.assert_called_once_with(policy.sess, f"{ckpt_path}.ckpt")
    export_policy_model_mock.assert_called_once_with(
        ckpt_path,
        SerializationSettings("output", brain.brain_name),
        policy.graph,
        policy.sess,
    )
def test_policy_conversion(tmpdir, rnn, visual, discrete):
    """Checkpoint a fresh policy and verify the exported .nn file exists."""
    tf.reset_default_graph()
    policy = create_policy_mock(
        TrainerSettings(),
        use_rnn=rnn,
        model_path=os.path.join(tmpdir, "test"),
        use_discrete=discrete,
        use_visual=visual,
    )
    ckpt = f"{tmpdir}/MockBrain-1"
    policy.checkpoint(ckpt, SerializationSettings(policy.model_path, "MockBrain"))
    # These checks taken from test_barracuda_converter
    nn_file = ckpt + ".nn"
    assert os.path.isfile(nn_file)
    assert os.path.getsize(nn_file) > 100
def save_model(self) -> None:
    """
    Saves the policy associated with this trainer: writes a final checkpoint
    and registers it as the final checkpoint with the NNCheckpointManager.
    """
    # len() works directly on the mapping; .keys() was redundant.
    if len(self.policies) > 1:
        logger.warning(
            "Trainer has multiple policies, but default behavior only saves the first."
        )
    # Only the first policy is saved (see warning above);
    # next(iter(...)) avoids materializing the whole values list.
    policy = next(iter(self.policies.values()))
    settings = SerializationSettings(policy.model_path, self.brain_name)
    model_checkpoint = self._checkpoint()
    # The final export lives at <model_path>.nn rather than a per-step file.
    final_checkpoint = attr.evolve(
        model_checkpoint, file_path=f"{policy.model_path}.nn"
    )
    policy.save(policy.model_path, settings)
    NNCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)