Beispiel #1
0
def _find_action_combos(env):
    """Return ``combos`` from the first object in ``env``'s wrapper chain that has it.

    Starts at ``env`` itself and follows ``.env`` links, stopping at the first
    object that exposes a ``combos`` attribute. Returns ``None`` when the chain
    ends (no further ``.env``) without finding one.
    """
    while not hasattr(env, "combos") and hasattr(env, "env"):
        env = env.env
    return getattr(env, "combos", None)


def save_lucid_model(config, params, *, model_path, metadata_path):
    """Save the policy graph with ``Model.save`` and write accompanying metadata.

    Builds a single-instance environment from ``config`` and constructs the
    baselines policy in a fresh TF graph under the ``"ppo2_model"`` variable
    scope. It then loads ``params`` into the session and saves the graph to
    ``model_path``. The observation placeholder is the input; the policy logits
    and value-function tensors are the outputs. ``Model`` is presumably Lucid's
    modelzoo ``Model`` (inferred from the function name — confirm).

    The metadata is written to ``metadata_path`` via ``save_joblib``. It holds
    the output tensor names, ``env_name``, ``gamma``/``lambda`` (as
    ``gae_gamma``/``gae_lambda``) and the ``combos`` attribute of the first
    env wrapper that has one (else ``None``).

    Args:
        config: Config dict forwarded as keyword arguments to ``create_env``
            and ``get_arch``. It is not mutated. ``num_envs`` is dropped from
            a copy, if present, because exactly one env is created here.
            ``library`` selects the policy backend; it defaults to
            ``"baselines"``, the only supported value.
        params: Parameters passed to ``load_params`` for the built graph.
        model_path: Destination of the saved model; its bytes are read back
            into the return value.
        metadata_path: Destination of the metadata written by ``save_joblib``.

    Returns:
        A dict with ``"model_bytes"`` (raw contents of ``model_path``) merged
        with every metadata entry.

    Raises:
        ValueError: If ``config["library"]`` is anything other than
            ``"baselines"``. This is checked before any environment is built.
    """
    config = config.copy()
    # Exactly one env is created below, so any caller-supplied count is
    # discarded; tolerate its absence instead of raising KeyError.
    config.pop("num_envs", None)
    library = config.get("library", "baselines")
    # Fail fast, before paying for environment / graph construction.
    if library != "baselines":
        raise ValueError(f"Unsupported library: {library}")

    venv = create_env(1, **config)
    try:
        arch = get_arch(**config)

        with tf.Graph().as_default(), tf.Session() as sess:
            observation_space = venv.observation_space
            observations_placeholder = tf.placeholder(
                shape=(None,) + observation_space.shape, dtype=tf.float32
            )

            from baselines.common.policies import build_policy

            with tf.variable_scope("ppo2_model", reuse=tf.AUTO_REUSE):
                policy_fn = build_policy(venv, arch)
                policy = policy_fn(
                    nbatch=None,
                    nsteps=1,
                    sess=sess,
                    # The exported input is declared with value range [0, 1]
                    # (image_value_range below) and is scaled by 255 before it
                    # reaches the policy. NOTE(review): presumably the policy
                    # expects raw 0-255 pixel values — confirm.
                    observ_placeholder=(observations_placeholder * 255),
                )
                pd = policy.pd
                vf = policy.vf
            # Op names are plain strings, usable after the graph is exited.
            logits_name = pd.logits.op.name
            value_name = vf.op.name

            load_params(params, sess=sess)

            Model.save(
                model_path,
                input_name=observations_placeholder.op.name,
                output_names=[logits_name, value_name],
                image_shape=observation_space.shape,
                image_value_range=[0.0, 1.0],
            )

        # Read wrapper attributes while the env is still open.
        action_combos = _find_action_combos(venv)
    finally:
        # Always release the environment, even if graph building or saving
        # fails. NOTE(review): assumes create_env returns an object with a
        # gym-style close() — confirm.
        venv.close()

    metadata = {
        "policy_logits_name": logits_name,
        "value_function_name": value_name,
        "env_name": config.get("env_name"),
        "gae_gamma": config.get("gamma"),
        "gae_lambda": config.get("lambda"),
        "action_combos": action_combos,
    }
    save_joblib(metadata, metadata_path)
    return {
        "model_bytes": read(model_path, cache=False, mode="rb"),
        **metadata,
    }
Beispiel #2
0
def save_trajectories(config, params, *, trajectories_path, num_envs,
                      num_steps, full_resolution):
    """Roll out the policy for ``num_steps`` steps and save the stacked results.

    Obtains a step function from ``get_step_fn`` for ``num_envs`` envs. One
    step is taken and its output discarded before recording starts.
    NOTE(review): presumably this skips an initial/reset step — confirm.

    For each of the next ``num_steps`` steps, the step dict's entries are
    stacked along axis 1 with ``np.stack``:

    * ``"ob"`` → ``"observations"``
    * ``"ac"`` → ``"actions"``
    * ``"reward"`` → ``"rewards"``
    * ``"first"`` → ``"firsts"``
    * ``"ob_full"`` → ``"observations_full"`` (only when ``full_resolution``)

    Args:
        config: Config forwarded to ``get_step_fn``.
        params: Policy parameters forwarded to ``get_step_fn``.
        trajectories_path: Destination for ``save_joblib``.
        num_envs: Number of envs requested from ``get_step_fn``.
        num_steps: Number of recorded steps; must be at least 1.
        full_resolution: Whether to also record ``"ob_full"``.

    Returns:
        ``{"trajectories": result}``, where ``result`` is the dict written to
        ``trajectories_path``.

    Raises:
        ValueError: If ``num_steps < 1``. ``np.stack`` cannot stack an empty
            sequence, so this is checked before ``get_step_fn`` is called.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be at least 1, got {num_steps}")

    def get_and_stack(ds, key, axis=1):
        """Stack ``d[key]`` for every step dict ``d`` in ``ds`` along ``axis``."""
        return np.stack([d[key] for d in ds], axis=axis)

    # Output key -> key in each step dict; insertion order fixes result order.
    output_keys = {
        "observations": "ob",
        "actions": "ac",
        "rewards": "reward",
        "firsts": "first",
    }
    if full_resolution:
        output_keys["observations_full"] = "ob_full"

    with get_step_fn(config,
                     params,
                     num_envs=num_envs,
                     full_resolution=full_resolution) as step_fn:
        step_fn()
        trajectories = [step_fn() for _ in range(num_steps)]
        # Stack while the step function's context is still open, in case the
        # returned arrays depend on it.
        result = {
            name: get_and_stack(trajectories, key)
            for name, key in output_keys.items()
        }

    save_joblib(result, trajectories_path)
    return {"trajectories": result}
Beispiel #3
0
def save_observations(config, params, *, observations_path, num_envs, num_obs,
                      obs_every, full_resolution):
    """Sample observations every ``obs_every`` steps and save them.

    Obtains a step function from ``get_step_fn``. For each of ``num_obs``
    samples it calls the step function ``obs_every`` times and keeps only the
    last step's ``"ob"`` (and ``"ob_full"`` when ``full_resolution``). The
    kept arrays are concatenated along axis 0 with ``np.concatenate``.

    Args:
        config: Config forwarded to ``get_step_fn``.
        params: Policy parameters forwarded to ``get_step_fn``.
        observations_path: Destination for ``save_joblib``.
        num_envs: Number of envs requested from ``get_step_fn``.
        num_obs: Number of samples to keep; must be at least 1.
        obs_every: Steps taken per kept sample; must be at least 1.
        full_resolution: Whether to also collect ``"ob_full"``.

    Returns:
        ``{"observations": ...}``, plus ``"observations_full"`` when
        ``full_resolution``. The same dict is written to
        ``observations_path``.

    Raises:
        ValueError: If ``num_obs < 1`` or ``obs_every < 1``. Otherwise there
            would be nothing to concatenate, or no step result to record.
            Both are checked before ``get_step_fn`` is called.
    """
    if num_obs < 1:
        raise ValueError(f"num_obs must be at least 1, got {num_obs}")
    if obs_every < 1:
        raise ValueError(f"obs_every must be at least 1, got {obs_every}")

    # Output key -> key in each step dict; insertion order fixes result order.
    output_keys = {"observations": "ob"}
    if full_resolution:
        output_keys["observations_full"] = "ob_full"

    with get_step_fn(config,
                     params,
                     num_envs=num_envs,
                     full_resolution=full_resolution) as step_fn:
        collected = {name: [] for name in output_keys}
        for _ in range(num_obs):
            # Only the last of every obs_every steps is kept.
            for _ in range(obs_every):
                step_result = step_fn()
            for name, key in output_keys.items():
                collected[name].append(step_result[key])
        result = {
            name: np.concatenate(arrays, axis=0)
            for name, arrays in collected.items()
        }

    save_joblib(result, observations_path)
    return result