def run_gym(
    params: OpenAiGymParameters,
    score_bar,
    embed_rl_dataset: RLDataset,
    gym_env: Env,
    mdnrnn: MemoryNetwork,
    max_embed_seq_len: int,
):
    """Train an RL agent offline on a state-embedded gym environment, using
    the pre-embedded transitions stored in embed_rl_dataset."""
    assert params.rl is not None
    rl_parameters = params.rl

    env_type = params.env
    model_type = params.model_type
    epsilon, epsilon_decay, minimum_epsilon = create_epsilon(
        offline_train=True, rl_parameters=rl_parameters, params=params
    )

    # Fill the replay buffer with the embedded transitions.
    replay_buffer = OpenAIGymMemoryPool(params.max_replay_memory_size)
    for row in embed_rl_dataset.rows:
        replay_buffer.insert_into_memory(**row)

    assert replay_buffer.memory_buffer is not None
    # Min/max embedded state values seen in the buffer, passed to
    # StateEmbedGymEnvironment.
    state_mem = replay_buffer.memory_buffer.state
    state_min_value = torch.min(state_mem).item()
    state_max_value = torch.max(state_mem).item()
    state_embed_env = StateEmbedGymEnvironment(
        gym_env, mdnrnn, max_embed_seq_len, state_min_value, state_max_value
    )
    open_ai_env = OpenAIGymEnvironment(
        state_embed_env,
        epsilon,
        rl_parameters.softmax_policy,
        rl_parameters.gamma,
        epsilon_decay,
        minimum_epsilon,
    )
    rl_trainer = create_trainer(params, open_ai_env)
    rl_predictor = create_predictor(
        rl_trainer, model_type, params.use_gpu, open_ai_env.action_dim
    )

    assert (
        params.run_details.max_steps is not None
        and params.run_details.offline_train_epochs is not None
    ), "Missing data required for offline training: {}".format(str(params.run_details))
    return train_gym_offline_rl(
        gym_env=open_ai_env,
        replay_buffer=replay_buffer,
        model_type=model_type,
        trainer=rl_trainer,
        predictor=rl_predictor,
        test_run_name="{} offline rl state embed".format(env_type),
        score_bar=score_bar,
        max_steps=params.run_details.max_steps,
        avg_over_num_episodes=params.run_details.avg_over_num_episodes,
        offline_train_epochs=params.run_details.offline_train_epochs,
        num_batch_per_epoch=None,
    )
def run_gym(
    params,
    use_gpu,
    score_bar,
    embed_rl_dataset: RLDataset,
    gym_env: Env,
    mdnrnn: MemoryNetwork,
    max_embed_seq_len: int,
):
    rl_parameters = RLParameters(**params["rl"])
    env_type = params["env"]
    model_type = params["model_type"]
    epsilon, epsilon_decay, minimum_epsilon = create_epsilon(
        offline_train=True, rl_parameters=rl_parameters, params=params
    )

    # Fill the replay buffer with the embedded transitions.
    replay_buffer = OpenAIGymMemoryPool(params["max_replay_memory_size"])
    for row in embed_rl_dataset.rows:
        replay_buffer.insert_into_memory(**row)

    # Min/max embedded state values seen in the buffer, passed to
    # StateEmbedGymEnvironment.
    state_mem = torch.cat([m[0] for m in replay_buffer.replay_memory])
    state_min_value = torch.min(state_mem).item()
    state_max_value = torch.max(state_mem).item()
    state_embed_env = StateEmbedGymEnvironment(
        gym_env, mdnrnn, max_embed_seq_len, state_min_value, state_max_value
    )
    open_ai_env = OpenAIGymEnvironment(
        state_embed_env,
        epsilon,
        rl_parameters.softmax_policy,
        rl_parameters.gamma,
        epsilon_decay,
        minimum_epsilon,
    )
    rl_trainer = create_trainer(
        params["model_type"], params, rl_parameters, use_gpu, open_ai_env
    )
    rl_predictor = create_predictor(
        rl_trainer, model_type, use_gpu, open_ai_env.action_dim
    )
    return train_gym_offline_rl(
        open_ai_env,
        replay_buffer,
        model_type,
        rl_trainer,
        rl_predictor,
        "{} offline rl state embed".format(env_type),
        score_bar,
        max_steps=params["run_details"]["max_steps"],
        avg_over_num_episodes=params["run_details"]["avg_over_num_episodes"],
        offline_train_epochs=params["run_details"]["offline_train_epochs"],
        bcq_imitator_hyper_params=None,
    )
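# A minimal sketch of the config dict consumed by the dict-based run_gym above.
# Only the keys this function reads directly are shown; create_epsilon and
# create_trainer may read additional keys. The env id, model_type string, and
# all numeric values are illustrative assumptions, as are the embed_rl_dataset,
# gym_env, and mdnrnn objects assumed to have been built elsewhere.
example_params = {
    "env": "CartPole-v0",
    "model_type": "pytorch_discrete_dqn",
    "max_replay_memory_size": 100000,
    # Expanded into RLParameters(**params["rl"]); gamma and softmax_policy are
    # the only fields run_gym touches directly.
    "rl": {"gamma": 0.99, "softmax_policy": 0},
    "run_details": {
        "max_steps": 200,
        "avg_over_num_episodes": 100,
        "offline_train_epochs": 10,
    },
}
# run_gym(
#     example_params,
#     use_gpu=False,
#     score_bar=None,
#     embed_rl_dataset=embed_rl_dataset,
#     gym_env=gym_env,
#     mdnrnn=mdnrnn,
#     max_embed_seq_len=10,
# )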