def save_lucid_model(config, params, *, model_path, metadata_path):
    """Export a trained policy as a lucid-format model plus a metadata file.

    Builds the policy network in a fresh TF graph/session, loads the given
    parameters, saves the frozen model to ``model_path`` and a joblib metadata
    dict to ``metadata_path``.

    Parameters:
        config: dict of env/model settings (consumed copy; ``num_envs`` is
            dropped, ``library`` selects the policy-building backend).
        params: parameter values fed to ``load_params`` — presumably the
            trained weights; exact format depends on that helper.
        model_path: destination for the saved lucid ``Model``.
        metadata_path: destination for the joblib metadata dict.

    Returns:
        dict with the raw saved model bytes under ``"model_bytes"`` merged
        with all metadata entries.

    Raises:
        ValueError: if ``config["library"]`` is not ``"baselines"``.
    """
    # Work on a copy so the caller's config dict is not mutated.
    config = config.copy()
    config.pop("num_envs")
    library = config.get("library", "baselines")
    # A single-env VecEnv is enough here: we only need spaces/arch, not rollouts.
    venv = create_env(1, **config)
    arch = get_arch(**config)
    with tf.Graph().as_default(), tf.Session() as sess:
        observation_space = venv.observation_space
        # Float placeholder in [0, 1]; scaled to [0, 255] below because the
        # baselines policy presumably expects raw pixel-range inputs — the
        # lucid model's image_value_range is declared as [0, 1] to match.
        observations_placeholder = tf.placeholder(shape=(None, ) + observation_space.shape, dtype=tf.float32)
        if library == "baselines":
            from baselines.common.policies import build_policy
            # Scope name must match the one used at training time so that
            # load_params restores into the same variables.
            with tf.variable_scope("ppo2_model", reuse=tf.AUTO_REUSE):
                policy_fn = build_policy(venv, arch)
                policy = policy_fn(
                    nbatch=None,
                    nsteps=1,
                    sess=sess,
                    observ_placeholder=(observations_placeholder * 255),
                )
            pd = policy.pd  # action probability distribution head
            vf = policy.vf  # value-function head
        else:
            raise ValueError(f"Unsupported library: {library}")
        # Restore trained weights into this session before freezing the graph.
        load_params(params, sess=sess)
        Model.save(
            model_path,
            input_name=observations_placeholder.op.name,
            output_names=[pd.logits.op.name, vf.op.name],
            image_shape=observation_space.shape,
            image_value_range=[0.0, 1.0],
        )
    # Op names are recorded so downstream lucid code can locate the heads.
    metadata = {
        "policy_logits_name": pd.logits.op.name,
        "value_function_name": vf.op.name,
        "env_name": config.get("env_name"),
        "gae_gamma": config.get("gamma"),
        "gae_lambda": config.get("lambda"),
    }
    # Walk the wrapper chain looking for a wrapper exposing `combos`
    # (presumably the discrete action -> button-combo mapping; absent for
    # envs without such a wrapper).
    env = venv
    while hasattr(env, "env") and (not hasattr(env, "combos")):
        env = env.env
    if hasattr(env, "combos"):
        metadata["action_combos"] = env.combos
    else:
        metadata["action_combos"] = None
    save_joblib(metadata, metadata_path)
    return {
        "model_bytes": read(model_path, cache=False, mode="rb"),
        **metadata
    }
def save_trajectories(config, params, *, trajectories_path, num_envs, num_steps, full_resolution):
    """Roll out the policy for ``num_steps`` steps and save stacked trajectories.

    Each per-step dict from ``step_fn`` contributes one time slice; slices are
    stacked along axis 1 so arrays come out shaped (num_envs, num_steps, ...).
    The result dict is written to ``trajectories_path`` via joblib.

    Returns:
        ``{"trajectories": result}`` where ``result`` holds the stacked
        ``observations``, ``actions``, ``rewards`` and ``firsts`` arrays
        (plus ``observations_full`` when ``full_resolution`` is set).
    """
    with get_step_fn(config, params, num_envs=num_envs, full_resolution=full_resolution) as step_fn:
        # Discard the very first step before recording the trajectory.
        step_fn()
        steps = [step_fn() for _ in range(num_steps)]

    def _stacked(key, axis=1):
        # Gather one field across all recorded steps; time becomes axis 1.
        return np.stack([step[key] for step in steps], axis=axis)

    result = {
        "observations": _stacked("ob"),
        "actions": _stacked("ac"),
        "rewards": _stacked("reward"),
        "firsts": _stacked("first"),
    }
    if full_resolution:
        result["observations_full"] = _stacked("ob_full")
    save_joblib(result, trajectories_path)
    return {"trajectories": result}
def save_observations(config, params, *, observations_path, num_envs, num_obs, obs_every, full_resolution):
    """Collect ``num_obs`` observation batches, one every ``obs_every`` steps.

    For each sample the environment is stepped ``obs_every`` times and only
    the final step's observation batch is kept; the kept batches are
    concatenated along axis 0 and saved to ``observations_path`` via joblib.

    Returns:
        dict with ``"observations"`` (and ``"observations_full"`` when
        ``full_resolution`` is set).
    """
    collected = []
    collected_full = []
    with get_step_fn(config, params, num_envs=num_envs, full_resolution=full_resolution) as step_fn:
        for _ in range(num_obs):
            # Advance obs_every steps; only the last step's output is sampled.
            for _ in range(obs_every):
                step_result = step_fn()
            collected.append(step_result["ob"])
            if full_resolution:
                collected_full.append(step_result["ob_full"])
    result = {"observations": np.concatenate(collected, axis=0)}
    if full_resolution:
        result["observations_full"] = np.concatenate(collected_full, axis=0)
    save_joblib(result, observations_path)
    return result