def _tc3():
    # Smoke test: fit an observation-space Q network for the pretrained
    # PPO2 Ant policy via SARSA.
    env = loader.get_env("AntBulletEnv-v0", "ppo2_norm")
    actor = loader.get_original_policy("AntBulletEnv-v0", "ppo2_norm")
    obs = env.reset()
    s_size = len(obs)
    a_size = env.action_space.shape[0]
    q_net = QNetwork(s_size, a_size)
    sarsa = SARSA(env, actor, q_net, attack_type="obs")
    sarsa.train(1000, 100, 5, 10)
def learn_q(env_name, algo, step_num, rollout_num, update_intv, update_iter,
            gamma=0.99, attack_type="state", optimizer=None):
    # Fit a Q network against a frozen pretrained policy with SARSA.
    # attack_type selects what the Q network sees: the raw simulator state
    # ("state") or the policy's observation ("obs").
    env = loader.get_env(env_name, algo)
    actor = loader.get_original_policy(env_name, algo)
    if attack_type == "state":
        env.reset()
        s_size = len(env.unwrapped.state)
    else:
        s_size = env.observation_space.shape[0]
    a_size = env.action_space.shape[0]
    q_net = QNetwork(s_size, a_size)
    sarsa = SARSA(env, actor, q_net, optimizer, attack_type=attack_type)
    sarsa.train(step_num, rollout_num, update_intv, update_iter, gamma)
    return q_net, sarsa.losses
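# A minimal usage sketch for learn_q. The env/algo names and hyperparameters
# below mirror the _tc3 smoke test above; they are illustrative, not a tuned
# configuration:
#
#     q_net, losses = learn_q("AntBulletEnv-v0", "ppo2_norm",
#                             step_num=1000, rollout_num=100,
#                             update_intv=5, update_iter=10,
#                             attack_type="obs")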
def _surrogate():
    # Per-thread worker, defined as a closure over the enclosing attack
    # launcher's arguments (env_name, algo, pi_net, q_net, step_num,
    # rollout_num, thread_number, attack_fn, attack_freq, attack_kwargs).
    # Each worker builds its own env and policy, then runs its share of
    # the observation-attack rollouts.
    env = loader.get_env(env_name, algo, reward_type=None)
    pi = loader.get_original_policy(env_name, algo)
    return simulation_attack_obs(env, pi, pi_net, q_net, step_num,
                                 rollout_num // thread_number,
                                 attack_fn, attack_freq, attack_kwargs)
def _surrogate():
    # Same per-thread pattern as above, but for plain rollouts with a
    # substitute actor network (actor_net) instead of the
    # observation-attack pipeline.
    env = loader.get_env(env_name, algo)
    pi = loader.get_original_policy(env_name, algo)
    return simulation(env, pi, actor_net, step_num,
                      rollout_num // thread_number,
                      attack_fn, attack_freq, attack_kwargs)