示例#1
0
def test_agent(agent_step):
    """Train and save one CLAC ("MIRL") agent per coefficient in ``CLAC_COEFS``.

    For every coefficient a CLAC model is trained on a fresh
    ``ENVIRONMENT_NAME`` environment for ``NUM_TRAINING_STEPS`` timesteps.
    The environment is then re-randomized ``NUM_RESAMPLES - 1`` times
    (according to ``RANDOMIZATION_LEVEL``), and training continues on the same
    model after each resample.

    Learning results are pickled after every training phase to
    ``FOLDER/results/MIRL_<coef>_<agent_step>_<phase>.pkl``, where ``<phase>``
    is 0 for the initial phase and the resample step afterwards. The final
    model is saved to ``FOLDER/models/MIRL_<coef>_<agent_step>_0``. In both
    names ``<coef>`` is the coefficient with ``.`` replaced by ``p``.

    Args:
        agent_step: Identifier of this agent run. It is embedded in every
            output name and stored in the ``AgentID`` column of each pickled
            results frame. Progress is printed only when it equals 1.
    """
    for mut_coef in CLAC_COEFS:
        _train_clac_for_coef(mut_coef, agent_step)


def _mirl_file_stem(mut_coef, agent_step, phase):
    """Return ``MIRL_<coef>_<agent_step>_<phase>`` with '.' in coef as 'p'."""
    coef_tag = str(mut_coef).replace(".", "p")
    return f"MIRL_{coef_tag}_{agent_step}_{phase}"


def _save_learning_results(learning_results, mut_coef, agent_step, phase):
    """Tag a results frame with ``AgentID`` and pickle it under FOLDER/results."""
    learning_results['AgentID'] = agent_step
    learning_results.to_pickle(
        f"{FOLDER}/results/{_mirl_file_stem(mut_coef, agent_step, phase)}.pkl")


def _resample_env(mirl_env):
    """Re-randomize the environment(s) in ``mirl_env`` per RANDOMIZATION_LEVEL.

    Returns:
        True if ``randomize`` was invoked. False, after printing an error,
        when RANDOMIZATION_LEVEL is not one of the known level names.
    """
    # Level name -> argument passed to each env's ``randomize`` method.
    randomize_modes = {"Normal": 0, "Extreme": 1, "Test": -1}
    mode = randomize_modes.get(RANDOMIZATION_LEVEL)
    if mode is None:
        print("Error resampling unknown value: ", RANDOMIZATION_LEVEL)
        return False
    mirl_env.env_method("randomize", mode)
    return True


def _train_clac_for_coef(mut_coef, agent_step):
    """Train, resample-train and save a CLAC model for one coefficient."""
    if agent_step == 1:
        print(mut_coef, "  ", NUM_TRAINING_STEPS, "  ", ENVIRONMENT_NAME,
              "  ", FOLDER)

    base_env = gym.make(ENVIRONMENT_NAME)
    # Bind the env as a default argument: a bare ``lambda: mirl_env`` looks
    # the name up when called, which breaks once the name is rebound.
    mirl_env = DummyVecEnv([lambda env=base_env: env])
    try:
        mirl_model = CLAC(CLAC_MlpPolicy,
                          mirl_env,
                          mut_inf_coef=mut_coef,
                          coef_schedule=3.3e-4,
                          verbose=1)

        mirl_model, learning_results = mirl_model.learn(
            total_timesteps=NUM_TRAINING_STEPS, log_interval=10)
        _save_learning_results(learning_results, mut_coef, agent_step, 0)

        for resample_step in range(1, NUM_RESAMPLES):
            # Unknown randomization level: skip this phase (error printed).
            if not _resample_env(mirl_env):
                continue

            if agent_step == 1:
                print(mut_coef, "  ", NUM_TRAINING_STEPS, "  ",
                      ENVIRONMENT_NAME, "  ", FOLDER, " resample step ",
                      resample_step)

            # Continue training the same model without resetting its
            # timestep counter.
            mirl_model, learning_results = mirl_model.learn(
                total_timesteps=NUM_TRAINING_STEPS,
                reset_num_timesteps=False,
                log_interval=10)
            _save_learning_results(learning_results, mut_coef, agent_step,
                                   resample_step)

        # NOTE(review): saved with phase "_0" although this is the model after
        # all resample phases; presumably kept so existing loaders find it —
        # confirm before renaming.
        mirl_model.save(
            f"{FOLDER}/models/{_mirl_file_stem(mut_coef, agent_step, 0)}")
    finally:
        # Release the environment before moving on to the next coefficient.
        mirl_env.close()