def test_load_save(tmp_path):
    """Round-trip a mock policy through TFModelSaver checkpoints.

    Covers three paths:
      1. saving a checkpoint for a policy at step 2000,
      2. reloading it with ``load=True`` (weights and step restored),
      3. initializing a fresh run from it via ``init_path`` (weights
         restored, step reset to 0).
    """
    path1 = os.path.join(tmp_path, "runid1")
    path2 = os.path.join(tmp_path, "runid2")
    trainer_params = TrainerSettings()
    policy = create_policy_mock(trainer_params)
    model_saver = TFModelSaver(trainer_params, path1)
    model_saver.register(policy)
    model_saver.initialize_or_load(policy)
    policy.set_step(2000)

    mock_brain_name = "MockBrain"
    model_saver.save_checkpoint(mock_brain_name, 2000)
    # Inspect the saver's own output directory: listing tmp_path would only
    # prove that the runid1 directory exists, not that anything was saved.
    assert len(os.listdir(path1)) > 0

    # Try load from this path
    model_saver = TFModelSaver(trainer_params, path1, load=True)
    policy2 = create_policy_mock(trainer_params)
    model_saver.register(policy2)
    model_saver.initialize_or_load(policy2)
    _compare_two_policies(policy, policy2)
    assert policy2.get_current_step() == 2000

    # Try initialize from path 1
    trainer_params.init_path = path1
    model_saver = TFModelSaver(trainer_params, path2)
    policy3 = create_policy_mock(trainer_params)
    model_saver.register(policy3)
    model_saver.initialize_or_load(policy3)

    _compare_two_policies(policy2, policy3)
    # Initializing from another run copies weights but not the step counter.
    assert policy3.get_current_step() == 0
def test_register(tmp_path):
    """Check that TFModelSaver.register keeps a policy but not an optimizer.

    Registering a PPOOptimizer mock must leave ``policy`` unset; registering
    a mock policy must populate it.
    """
    saver = TFModelSaver(TrainerSettings(), tmp_path)

    # An optimizer is not a policy, so the saver should not adopt it.
    optimizer_stub = mock.Mock(spec=PPOOptimizer)
    saver.register(optimizer_stub)
    assert saver.policy is None

    # A policy built from fresh settings should be retained.
    saver.register(create_policy_mock(TrainerSettings()))
    assert saver.policy is not None
def test_checkpoint_conversion(tmpdir, rnn, visual, discrete):
    """Saving a checkpoint through TFModelSaver should emit a ``.nn`` file.

    ``rnn``, ``visual`` and ``discrete`` choose the policy variant built by
    ``create_policy_mock`` (presumably supplied by pytest parametrization —
    confirm against the decorators).
    """
    tf.reset_default_graph()
    brain_name = "Mock_Brain"
    model_path = os.path.join(tmpdir, brain_name)
    policy = create_policy_mock(
        TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )

    saver = TFModelSaver(TrainerSettings(), model_path)
    saver.register(policy)
    saver.save_checkpoint(brain_name, 100)

    # The checkpoint is named "<brain>-<step>.nn" inside the model path.
    expected_file = f"{model_path}/{brain_name}-100.nn"
    assert os.path.isfile(expected_file)
def test_policy_conversion(dummy_config, tmpdir, rnn, visual, discrete):
    """Export a saved policy graph and check that a ``.nn`` file is produced.

    NOTE(review): this module defines ``test_policy_conversion`` again later
    on. Python rebinds the name to that later definition, so pytest never
    collects this version. Rename or delete it once it is confirmed the
    APIs it uses (``policy.save_model``, ``policy.brain``) still exist.
    """
    tf.reset_default_graph()
    output_path = os.path.join(tmpdir, "test")
    dummy_config["output_path"] = output_path
    policy = create_policy_mock(
        dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    policy.save_model(1000)

    export_target = os.path.join(tmpdir, policy.brain.brain_name)
    settings = SerializationSettings(policy.model_path, export_target)
    export_policy_model(settings, policy.graph, policy.sess)

    # These checks taken from test_barracuda_converter
    nn_file = os.path.join(tmpdir, "test.nn")
    assert os.path.isfile(nn_file)
    assert os.path.getsize(nn_file) > 100
def test_version_compare(self):
    """_check_model_version should warn only on a version mismatch.

    Checking against a wrong version string must log exactly one WARNING on
    the ``mlagents.trainers`` logger. Checking against the running
    ``__version__`` must not log another.
    """
    trainer_params = TrainerSettings()
    # TemporaryDirectory deletes the directory on exit, so the test leaves
    # nothing behind in the system temp dir.
    with tempfile.TemporaryDirectory() as mock_path:
        policy = create_policy_mock(trainer_params)
        model_saver = TFModelSaver(trainer_params, mock_path)
        model_saver.register(policy)

        # Capture logs only around the version checks so that warnings
        # emitted during setup cannot skew the count.
        with self.assertLogs("mlagents.trainers", level="WARNING") as cm:
            # This is not the right version for sure
            model_saver._check_model_version("0.0.0")
            # Assert that 1 warning has been thrown with incorrect version
            assert len(cm.output) == 1
            # This should be the right version
            model_saver._check_model_version(__version__)
            # Assert that no additional warnings have been thrown with correct version
            assert len(cm.output) == 1
def test_policy_conversion(tmpdir, rnn, visual, discrete):
    """Checkpointing a policy should write a ``.nn`` file larger than 100 bytes.

    ``rnn``, ``visual`` and ``discrete`` choose the policy variant built by
    ``create_policy_mock``.
    """
    tf.reset_default_graph()
    policy = create_policy_mock(
        TrainerSettings(),
        use_rnn=rnn,
        model_path=os.path.join(tmpdir, "test"),
        use_discrete=discrete,
        use_visual=visual,
    )

    checkpoint_path = f"{tmpdir}/MockBrain-1"
    serialization = SerializationSettings(policy.model_path, "MockBrain")
    policy.checkpoint(checkpoint_path, serialization)

    # These checks taken from test_barracuda_converter
    nn_file = checkpoint_path + ".nn"
    assert os.path.isfile(nn_file)
    assert os.path.getsize(nn_file) > 100