Exemplo n.º 1
0
def test_result_grid_future_checkpoint(ray_start_2_cpus, to_object):
    """Convert a trial whose checkpoint is an un-awaited save call into a Result.

    A "__fake" trainable actor is given some info and asked to save, either
    with ``save_to_object`` or ``save`` depending on ``to_object``. The value
    returned by ``.remote()`` is passed straight to the stub trial as a
    MEMORY-mode tracked checkpoint, without calling ``ray.get``.

    ``ResultGrid._trial_to_result`` must then produce a result that:
      * has a ``Checkpoint``;
      * has dict metrics and config;
      * has no metrics dataframe;
      * has a checkpoint whose ``mock_agent.pkl`` holds the stored info.
    """
    fake_cls = get_trainable_cls("__fake")
    stub_trial = Trial("__fake", stub=True)
    stub_trial.config = {"some_config": 1}
    stub_trial.last_result = {"some_result": 2, "config": stub_trial.config}

    # Start the fake trainable as an actor and give it something to persist.
    actor = ray.remote(fake_cls).remote()
    ray.get(actor.set_info.remote({"info": 4}))

    # The save call is deliberately not awaited; the trial receives whatever
    # `.remote()` returned.
    save_fn = actor.save_to_object if to_object else actor.save
    pending_checkpoint = save_fn.remote()

    stub_trial.on_checkpoint(
        _TrackedCheckpoint(
            pending_checkpoint, storage_mode=CheckpointStorage.MEMORY
        )
    )
    stub_trial.pickled_error_file = None
    stub_trial.error_file = None

    # Internal result grid conversion
    converted = ResultGrid(None)._trial_to_result(stub_trial)

    assert isinstance(converted.checkpoint, Checkpoint)
    assert isinstance(converted.metrics, dict)
    assert isinstance(converted.config, dict)
    assert converted.metrics_dataframe is None
    assert converted.config == {"some_config": 1}
    assert converted.metrics["config"] == converted.config

    # Load checkpoint data (see ray.rllib.algorithms.mock.MockTrainer definition)
    with converted.checkpoint.as_directory() as local_dir:
        pickle_path = os.path.join(local_dir, "mock_agent.pkl")
        with open(pickle_path, "rb") as fh:
            stored = pickle.load(fh)
        assert stored["info"] == 4
Exemplo n.º 2
0
    def testBestTrialStr(self):
        """Check that best_trial_str renders a nested config value.

        The value stored under ``config["nested"]["conf"]`` must appear in
        the output both when ``parameter_columns`` is not passed and when it
        explicitly selects the nested path ``"nested/conf"``.
        """
        nested_section = {"conf": "nested_value"}
        config = {"nested": nested_section, "toplevel": "toplevel_value"}

        trial = Trial("", config=config, stub=True)
        trial.last_result = {"metric": 1, "config": config}

        # First call without parameter_columns, then with the nested column
        # requested explicitly.
        for column_kwargs in ({}, {"parameter_columns": ["nested/conf"]}):
            rendered = best_trial_str(trial, "metric", **column_kwargs)
            self.assertIn("nested_value", rendered)