def job(
    random_seed: int,
    base_dir: Path,
    theta_min: float,
    theta_max: float,
    theta_dot_min: float,
    theta_dot_max: float,
):
    # Assumes `num_episodes`, `run_ddpg`, and `config` are defined at module
    # level; they are not parameters of this function.
    rng = random.PRNGKey(random_seed)
    rng, train_rng = random.split(rng)
    # One PRNG key per episode for policy evaluation inside the callback.
    callback_rngs = random.split(rng, num_episodes)

    # Single-element lists so the callback closure can overwrite them in place.
    params = [None]
    tracking_params = [None]
    train_reward_per_episode = []
    policy_value_per_episode = []
    episode_lengths = []
    elapsed_per_episode = []

    def callback(info):
        episode = info["episode"]
        params[0] = info["optimizer"].value
        tracking_params[0] = info["tracking_params"]
        policy_value = run_ddpg.eval_policy(callback_rngs[episode],
                                            info["optimizer"].value[0])
        train_reward_per_episode.append(info["reward"])
        policy_value_per_episode.append(policy_value)
        episode_lengths.append(info["episode_length"])
        elapsed_per_episode.append(info["elapsed"])

    # Terminate an episode when the time limit is reached or the state leaves
    # the box [theta_min, theta_max] x [theta_dot_min, theta_dot_max].
    # Note: this predicate uses the (t, s) signature, whereas `main` below
    # passes a predicate taking a single `loop_state` argument.
    run_ddpg.train(
        train_rng,
        num_episodes,
        lambda t, s: lax.bitwise_or(
            lax.ge(t, config.episode_length),
            lax.bitwise_or(
                lax.le(s[0], theta_min),
                lax.bitwise_or(
                    lax.ge(s[0], theta_max),
                    lax.bitwise_or(lax.le(s[1], theta_dot_min),
                                   lax.ge(s[1], theta_dot_max))))),
        callback,
    )

    # Persist the final parameters and per-episode metrics for this seed.
    with (base_dir / f"seed={random_seed}.pkl").open(mode="wb") as f:
        pickle.dump(
            {
                "final_params": params[0],
                "final_tracking_params": tracking_params[0],
                "train_reward_per_episode": train_reward_per_episode,
                "policy_value_per_episode": policy_value_per_episode,
                "episode_lengths": episode_lengths,
                "elapsed_per_episode": elapsed_per_episode,
            }, f)
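# Hypothetical driver for `job` (a sketch, not part of the original script):
# `job` writes one `seed={n}.pkl` file per run, so it is presumably swept over
# several seeds. The seed count, output directory name, and termination bounds
# below are illustrative placeholders; the bounds mirror the commented-out
# predicate in `main` (theta within 0.5 of pi, |theta_dot| below 10).
def run_seed_sweep(base_dir: Path = Path("ddpg_sweep"), num_seeds: int = 8):
    base_dir.mkdir(parents=True, exist_ok=True)
    for seed in range(num_seeds):
        job(
            random_seed=seed,
            base_dir=base_dir,
            theta_min=jp.pi - 0.5,
            theta_max=jp.pi + 0.5,
            theta_dot_min=-10.0,
            theta_dot_max=10.0,
        )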
def main():
    rng = random.PRNGKey(0)
    num_episodes = 10000

    print(f"Loading best seed from {experiment_folder}... ", end="")
    best_seed_data = load_best_seed()
    print("done")

    print("Building support set... ", end="")
    rng, ss_rng = random.split(rng)
    actor_params, _ = best_seed_data["final_params"]
    # Build the support set from the best seed's actor and flatten it into a
    # single (num_states, state_dim) array.
    support_set = build_support_set(ss_rng, actor_params)
    support_set_flat = jp.reshape(support_set, (-1, support_set.shape[-1]))
    # theta_min = jp.min(support_set_flat[:, 0]) - epsilon
    # theta_max = jp.max(support_set_flat[:, 0]) + epsilon
    # theta_dot_min = jp.min(support_set_flat[:, 1]) - epsilon
    # theta_dot_max = jp.max(support_set_flat[:, 1]) + epsilon
    print("done")

    rng, train_rng = random.split(rng)
    callback_rngs = random.split(rng, num_episodes)

    train_reward_per_episode = []
    policy_value_per_episode = []
    episode_lengths = []

    def callback(info):
        episode = info["episode"]
        reward = info["reward"]
        current_actor_params = info["optimizer"].value[0]
        policy_value = run_ddpg.eval_policy(callback_rngs[episode],
                                            current_actor_params)
        print(f"Episode {episode}, "
              f"episode_length = {info['episode_length']}, "
              f"reward = {reward}, "
              f"policy_value = {policy_value}, "
              f"elapsed = {info['elapsed']}")

        train_reward_per_episode.append(reward)
        policy_value_per_episode.append(policy_value)
        episode_lengths.append(info["episode_length"])

        # if episode == num_episodes - 1:
        # if episode % 5000 == 0 or episode == num_episodes - 1:
        #     for rollout in range(5):
        #         states, actions, _ = ddpg.rollout(
        #             random.fold_in(callback_rngs[episode], rollout),
        #             config.env,
        #             policy(current_actor_params),
        #             num_timesteps=250,
        #         )
        #         viz_pendulum_rollout(states, 2 * actions / config.max_torque)

    # Terminate an episode when the time limit is reached or the current
    # (theta, theta_dot) is at least sqrt(max_squared_dist) away from every
    # point in the support set.
    run_ddpg.train(
        train_rng,
        num_episodes,
        # lambda t, s: lax.bitwise_or(
        #     lax.ge(t, config.episode_length),
        #     lax.bitwise_or(
        #         lax.le(s[0], theta_min),
        #         lax.bitwise_or(
        #             lax.ge(s[0], theta_max),
        #             lax.bitwise_or(lax.le(s[1], theta_dot_min),
        #                            lax.ge(s[1], theta_dot_max))))),
        # lambda t, s: lax.bitwise_or(
        #     lax.ge(t, config.episode_length),
        #     lax.bitwise_or(lax.ge(jp.abs(s[1]), 10.0),
        #                    lax.ge(jp.abs(s[0] - jp.pi), 0.5))),
        lambda loop_state: lax.bitwise_or(
            lax.ge(loop_state.episode_length, config.episode_length),
            lax.ge(
                jp.min(
                    jp.sum((support_set_flat[:, :2] - loop_state.state[:2])**2,
                           axis=1)), max_squared_dist)),
        callback,
    )
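# For readability, the support-set termination rule used in `main` could be
# expressed as a named predicate factory. This is a sketch only; it assumes
# `loop_state` exposes `.episode_length` and `.state` exactly as used above,
# and that `config.episode_length` is in scope.
def make_support_set_termination(support_set_flat, max_squared_dist):
    def should_terminate(loop_state):
        # Squared Euclidean distance from the current (theta, theta_dot) to
        # every support-set point, using the first two state dimensions.
        sq_dists = jp.sum((support_set_flat[:, :2] - loop_state.state[:2])**2,
                          axis=1)
        timed_out = lax.ge(loop_state.episode_length, config.episode_length)
        left_support = lax.ge(jp.min(sq_dists), max_squared_dist)
        return lax.bitwise_or(timed_out, left_support)

    return should_terminate


# Entry-point guard (an assumption; the original launcher is not shown in
# this section).
if __name__ == "__main__":
    main()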