def test_limit_duration_wrapped_env_forwards_calls(self):
  cartpole_env = gym.spec('CartPole-v1').make()
  env = gym_wrapper.GymWrapper(cartpole_env)
  env = wrappers.TimeLimit(env, 10)

  action_spec = env.action_spec()
  self.assertEqual((), action_spec.shape)
  self.assertEqual(0, action_spec.minimum)
  self.assertEqual(1, action_spec.maximum)

  observation_spec = env.observation_spec()
  self.assertEqual((4,), observation_spec.shape)
  high = np.array([
      4.8,
      np.finfo(np.float32).max, 2 / 15.0 * math.pi,
      np.finfo(np.float32).max
  ])
  np.testing.assert_array_almost_equal(-high, observation_spec.minimum)
  np.testing.assert_array_almost_equal(high, observation_spec.maximum)
def __init__(self,
             name,
             stock_list,
             feature_set_factories,
             actor_fc_layers=ACTOR_FC_LAYERS,
             actor_dropout_layer_params=ACTOR_DROPOUT_LAYER_PARAMS,
             critic_observation_fc_layer_params=CRITIC_OBS_FC_LAYERS,
             critic_action_fc_layer_params=CRITIC_ACTION_FC_LAYERS,
             critic_joint_fc_layer_params=CRITIC_JOINT_FC_LAYERS,
             actor_alpha=ACTOR_ALPHA,
             critic_alpha=CRITIC_ALPHA,
             alpha_alpha=ALPHA_ALPHA,
             gamma=GAMMA):
  self.name = name
  self.stock_list = stock_list
  self.feature_generator = FeatureGenerator(feature_set_factories)
  self.reset()
  action_space = self.action_space = spaces.Box(
      low=-1.0, high=1.0, shape=(1,), dtype=np.float32)
  self.gym_training_env = TrainingStockEnv(stock_list,
                                           self.feature_generator.copy(),
                                           action_space, self.convert_action)
  self.tf_training_env = tf_py_environment.TFPyEnvironment(
      gym_wrapper.GymWrapper(self.gym_training_env,
                             discount=gamma,
                             auto_reset=True))
  self.actor = self.create_actor_network(actor_fc_layers,
                                         actor_dropout_layer_params)
  self.critic = self.create_critic_network(critic_observation_fc_layer_params,
                                           critic_action_fc_layer_params,
                                           critic_joint_fc_layer_params)
  self.tf_agent = self.create_sac_agent(self.actor, self.critic, actor_alpha,
                                        critic_alpha, alpha_alpha, gamma)
  self.eval_policy = self.tf_agent.policy
  self.eval_env = EvalEnv(self.stock_list, self)
  self.tf_agent.initialize()
def test_automatic_reset(self):
  cartpole_env = gym.make('CartPole-v1')
  env = gym_wrapper.GymWrapper(cartpole_env)
  env = wrappers.TimeLimit(env, 2)

  # Episode 1
  first_time_step = env.step(0)
  self.assertTrue(first_time_step.is_first())
  mid_time_step = env.step(0)
  self.assertTrue(mid_time_step.is_mid())
  last_time_step = env.step(0)
  self.assertTrue(last_time_step.is_last())

  # Episode 2
  first_time_step = env.step(0)
  self.assertTrue(first_time_step.is_first())
  mid_time_step = env.step(0)
  self.assertTrue(mid_time_step.is_mid())
  last_time_step = env.step(0)
  self.assertTrue(last_time_step.is_last())
def test_duration_applied_after_episode_terminates_early(self):
  cartpole_env = gym.make('CartPole-v1')
  env = gym_wrapper.GymWrapper(cartpole_env)
  env = wrappers.TimeLimit(env, 10000)

  # Episode 1 stepped until termination occurs.
  time_step = env.step(np.array(1, dtype=np.int32))
  while not time_step.is_last():
    time_step = env.step(np.array(1, dtype=np.int32))

  self.assertTrue(time_step.is_last())
  env._duration = 2

  # Episode 2 short duration hits step limit.
  first_time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(first_time_step.is_first())
  mid_time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(mid_time_step.is_mid())
  last_time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(last_time_step.is_last())
def test_automatic_reset(self):
  cartpole_env = gym.make('CartPole-v1')
  env = gym_wrapper.GymWrapper(cartpole_env)
  env = wrappers.FixedLength(env, 2)

  # Episode 1
  first_time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(first_time_step.is_first())
  mid_time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(mid_time_step.is_mid())
  last_time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(last_time_step.is_last())

  # Episode 2
  first_time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(first_time_step.is_first())
  mid_time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(mid_time_step.is_mid())
  last_time_step = env.step(np.array(0, dtype=np.int32))
  self.assertTrue(last_time_step.is_last())
def test_wrapped_cartpole_specs(self):
  # Note we use spec.make on gym envs to avoid getting a TimeLimit wrapper on
  # the environment.
  cartpole_env = gym.spec('CartPole-v1').make()
  env = gym_wrapper.GymWrapper(cartpole_env)

  action_spec = env.action_spec()
  self.assertEqual((), action_spec.shape)
  self.assertEqual(0, action_spec.minimum)
  self.assertEqual(1, action_spec.maximum)

  observation_spec = env.observation_spec()
  self.assertEqual((4,), observation_spec.shape)
  self.assertEqual(np.float32, observation_spec.dtype)
  high = np.array([
      4.8,
      np.finfo(np.float32).max, 2 / 15.0 * math.pi,
      np.finfo(np.float32).max
  ])
  np.testing.assert_array_almost_equal(-high, observation_spec.minimum)
  np.testing.assert_array_almost_equal(high, observation_spec.maximum)
def env_load_fn(environment_name,  # pylint: disable=dangerous-default-value
                max_episode_steps=50,
                resize_factor=1,
                env_kwargs=dict(action_noise=0., start=(0, 3)),
                goal_env_kwargs=dict(goal=(7, 3)),
                terminate_on_timeout=True):
  """Loads the selected environment and wraps it with the specified wrappers.

  Args:
    environment_name: Name for the environment to load.
    max_episode_steps: If None the max_episode_steps will be set to the default
      step limit defined in the environment's spec. No limit is applied if set
      to 0 or if there is no timestep_limit set in the environment's spec.
    resize_factor: A factor for resizing.
    env_kwargs: Arguments for envs.
    goal_env_kwargs: Arguments for goal envs.
    terminate_on_timeout: Whether to set done = True when the max episode steps
      is reached.

  Returns:
    A PyEnvironmentBase instance.
  """
  gym_env = PointMassEnv(
      environment_name, resize_factor=resize_factor, **env_kwargs)
  gym_env = GoalConditionedPointWrapper(gym_env, **goal_env_kwargs)
  env = gym_wrapper.GymWrapper(
      gym_env, discount=1.0, auto_reset=True, simplify_box_bounds=False)

  if max_episode_steps > 0:
    if terminate_on_timeout:
      env = TimeLimitBonus(env, max_episode_steps)
    else:
      env = NonTerminatingTimeLimit(env, max_episode_steps)

  return env
def env_load_fn(environment_name,
                max_episode_steps=None,
                resize_factor=1,
                gym_env_wrappers=(GoalConditionedPointWrapper,),
                terminate_on_timeout=False):
  """Loads the selected environment and wraps it with the specified wrappers.

  Args:
    environment_name: Name for the environment to load.
    max_episode_steps: If None the max_episode_steps will be set to the default
      step limit defined in the environment's spec. No limit is applied if set
      to 0 or if there is no timestep_limit set in the environment's spec.
    resize_factor: A factor for resizing.
    gym_env_wrappers: Iterable with references to wrapper classes to use
      directly on the gym environment.
    terminate_on_timeout: Whether to set done = True when the max episode steps
      is reached.

  Returns:
    A TFPyEnvironment instance.
  """
  gym_env = PointEnv(walls=environment_name, resize_factor=resize_factor)
  for wrapper in gym_env_wrappers:
    gym_env = wrapper(gym_env)
  env = gym_wrapper.GymWrapper(
      gym_env,
      discount=1.0,
      auto_reset=True,
  )

  if max_episode_steps > 0:
    if terminate_on_timeout:
      env = wrappers.TimeLimit(env, max_episode_steps)
    else:
      env = NonTerminatingTimeLimit(env, max_episode_steps)

  return tf_py_environment.TFPyEnvironment(env)
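A minimal usage sketch for the loader above. The 'FourRooms' walls name and the step limit are illustrative assumptions, not values taken from the source.

# Hedged example: 'FourRooms' is an assumed walls layout name; any name
# accepted by PointEnv would do.
tf_env = env_load_fn('FourRooms', max_episode_steps=20,
                     terminate_on_timeout=True)
time_step = tf_env.reset()  # TFPyEnvironment returns batched time steps.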
def get_onpolicy_dataset(env_name, tabular_obs, policy_fn, policy_info_spec):
  """Gets an on-policy dataset for the given target policy."""
  if env_name == 'taxi':
    env = taxi.Taxi(tabular_obs=tabular_obs)
  elif env_name == 'grid':
    env = navigation.GridWalk(tabular_obs=tabular_obs)
  elif env_name == 'lowrank_tree':
    env = tree.Tree(branching=2, depth=3, duplicate=10)
  elif env_name == 'frozenlake':
    env = InfiniteFrozenLake()
  elif env_name == 'low_rank':
    env = low_rank.LowRank()
  else:
    raise ValueError('Unknown environment: %s.' % env_name)

  tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
  tf_policy = common_utils.TFAgentsWrappedPolicy(
      tf_env.time_step_spec(),
      tf_env.action_spec(),
      policy_fn,
      policy_info_spec,
      emit_log_probability=True)

  return TFAgentsOnpolicyDataset(tf_env, tf_policy)
def wrap_env(gym_env,
             discount=1.0,
             max_episode_steps=0,
             gym_env_wrappers=(),
             time_limit_wrapper=wrappers.TimeLimit,
             env_wrappers=(),
             spec_dtype_map=None,
             auto_reset=True):
  # The parameter must be the raw gym environment; the original flattened
  # snippet named it `env`, which left `gym_env` undefined in the loop below.
  for wrapper in gym_env_wrappers:
    gym_env = wrapper(gym_env)
  env = gym_wrapper.GymWrapper(gym_env,
                               discount=discount,
                               spec_dtype_map=spec_dtype_map,
                               match_obs_space_dtype=True,
                               auto_reset=auto_reset,
                               simplify_box_bounds=True)

  if max_episode_steps > 0:
    env = time_limit_wrapper(env, max_episode_steps)

  for wrapper in env_wrappers:
    env = wrapper(env)

  return env
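A short sketch of calling wrap_env on a raw gym environment, assuming gym and the TF-Agents wrappers are imported as in the other snippets here; the discount and step limit are arbitrary illustrations.

import gym

# spec().make() avoids gym's own TimeLimit wrapper, as noted in the tests above.
raw_env = gym.spec('CartPole-v1').make()
py_env = wrap_env(raw_env, discount=0.99, max_episode_steps=200)
first_time_step = py_env.reset()
assert first_time_step.is_first()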
def test_automatic_reset_after_create(self):
  cartpole_env = gym.spec('CartPole-v1').make()
  env = gym_wrapper.GymWrapper(cartpole_env)

  first_time_step = env.step(0)
  self.assertTrue(first_time_step.is_first())
def test_render(self):
  cartpole_env = gym.spec('CartPole-v1').make()
  cartpole_env.render = mock.MagicMock()
  env = gym_wrapper.GymWrapper(cartpole_env)
  env.render()
  self.assertEqual(1, cartpole_env.render.call_count)
def get_py_env(env_name,
               max_episode_steps=None,
               constant_task=None,
               use_neg_rew=None,
               margin=None):
  """Load an environment.

  Args:
    env_name: (str) name of the environment.
    max_episode_steps: (int) maximum number of steps per episode. Set to None
      to not include a limit.
    constant_task: specifies a fixed task to use for all episodes. Set to None
      to use tasks sampled from the task distribution.
    use_neg_rew: (bool) For the goal-reaching tasks, indicates whether to use a
      (-1, 0) sparse reward (use_neg_rew = True) or a (0, 1) sparse reward.
    margin: (float) For goal-reaching tasks, indicates the desired distance to
      the goal.

  Returns:
    env: the environment, built from a dynamics and task distribution.
    task_distribution: the task distribution used for the environment.
  """
  if "sawyer" in env_name:
    print(("ERROR: Modify utils.py to import sawyer_env and not dm_env. "
           "Currently the sawyer_env import is commented out to prevent "
           "a segfault from occurring when trying to import both sawyer_env "
           "and dm_env"))
    assert False
  if env_name.split("_")[0] == "point":
    _, walls, resize_factor = env_name.split("_")
    dynamics = point_env.PointDynamics(
        walls=walls, resize_factor=int(resize_factor))
    task_distribution = point_env.PointGoalDistribution(dynamics)
  elif env_name.split("_")[0] == "pointTask":
    _, walls, resize_factor = env_name.split("_")
    dynamics = point_env.PointDynamics(
        walls=walls, resize_factor=int(resize_factor))
    task_distribution = point_env.PointTaskDistribution(dynamics)
  elif env_name == "quadruped-run":
    dynamics = dm_env.QuadrupedRunDynamics()
    task_distribution = dm_env.QuadrupedRunTaskDistribution(dynamics)
  elif env_name == "quadruped":
    dynamics = dm_env.QuadrupedRunDynamics()
    task_distribution = dm_env.QuadrupedContinuousTaskDistribution(dynamics)
  elif env_name == "hopper":
    dynamics = dm_env.HopperDynamics()
    task_distribution = dm_env.HopperTaskDistribution(dynamics)
  elif env_name == "hopper-discrete":
    dynamics = dm_env.HopperDynamics()
    task_distribution = dm_env.HopperDiscreteTaskDistribution(dynamics)
  elif env_name == "walker":
    dynamics = dm_env.WalkerDynamics()
    task_distribution = dm_env.WalkerTaskDistribution(dynamics)
  elif env_name == "humanoid":
    dynamics = dm_env.HumanoidDynamics()
    task_distribution = dm_env.HumanoidTaskDistribution(dynamics)
  ### sparse tasks
  elif env_name == "finger":
    dynamics = dm_env.FingerDynamics()
    task_distribution = dm_env.FingerTaskDistribution(
        dynamics, use_neg_rew=use_neg_rew, margin=margin)
  elif env_name == "manipulator":
    dynamics = dm_env.ManipulatorDynamics()
    task_distribution = dm_env.ManipulatorTaskDistribution(dynamics)
  elif env_name == "point-mass":
    dynamics = dm_env.PointMassDynamics()
    task_distribution = dm_env.PointMassTaskDistribution(
        dynamics, use_neg_rew=use_neg_rew, margin=margin)
  elif env_name == "stacker":
    dynamics = dm_env.StackerDynamics()
    task_distribution = dm_env.FingerStackerDistribution(
        dynamics, use_neg_rew=use_neg_rew, margin=margin)
  elif env_name == "swimmer":
    dynamics = dm_env.SwimmerDynamics()
    task_distribution = dm_env.SwimmerTaskDistribution(dynamics)
  elif env_name == "fish":
    dynamics = dm_env.FishDynamics()
    task_distribution = dm_env.FishTaskDistribution(dynamics)
  elif env_name == "sawyer-reach":
    dynamics = sawyer_env.SawyerDynamics()
    task_distribution = sawyer_env.SawyerReachTaskDistribution(dynamics)
  else:
    raise NotImplementedError("Unknown environment: %s" % env_name)
  gym_env = multitask.Environment(
      dynamics, task_distribution, constant_task=constant_task)
  if max_episode_steps is not None:
    # Add a placeholder spec so the TimeLimit wrapper works.
    gym_env.spec = registration.EnvSpec("env-v0")
    gym_env = TimeLimit(gym_env, max_episode_steps)
  wrapped_env = gym_wrapper.GymWrapper(gym_env, discount=1.0, auto_reset=True)
  return wrapped_env, task_distribution
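A hedged usage sketch for get_py_env. The name 'point_FourRooms_1' follows the 'point_<walls>_<resize_factor>' pattern parsed above, but the specific walls value is an assumption.

# Parsed by get_py_env as walls='FourRooms', resize_factor=1 (assumed layout).
env, task_distribution = get_py_env('point_FourRooms_1', max_episode_steps=100)
time_step = env.reset()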
def main(_):
  tf.config.experimental_run_functions_eagerly(FLAGS.eager)

  gym_env, dataset = d4rl_utils.create_d4rl_env_and_dataset(
      task_name=FLAGS.task_name, batch_size=FLAGS.batch_size)

  env = gym_wrapper.GymWrapper(gym_env)
  env = tf_py_environment.TFPyEnvironment(env)

  dataset_iter = iter(dataset)

  tf.random.set_seed(FLAGS.seed)

  hparam_str = utils.make_hparam_string(
      FLAGS.xm_parameters,
      algo_name=FLAGS.algo_name,
      seed=FLAGS.seed,
      task_name=FLAGS.task_name,
      data_name=FLAGS.data_name)
  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'tb', hparam_str))
  result_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'results', hparam_str))

  if FLAGS.algo_name == 'bc':
    model = behavioral_cloning.BehavioralCloning(env.observation_spec(),
                                                 env.action_spec())
  elif FLAGS.algo_name == 'bc_mix':
    model = behavioral_cloning.BehavioralCloning(env.observation_spec(),
                                                 env.action_spec(),
                                                 mixture=True)
  elif 'ddpg' in FLAGS.algo_name:
    model = ddpg.DDPG(env.observation_spec(), env.action_spec())
  elif 'crr' in FLAGS.algo_name:
    model = awr.AWR(env.observation_spec(), env.action_spec(), f='bin_max')
  elif 'awr' in FLAGS.algo_name:
    model = awr.AWR(env.observation_spec(), env.action_spec(), f='exp_mean')
  elif 'bcq' in FLAGS.algo_name:
    model = bcq.BCQ(env.observation_spec(), env.action_spec())
  elif 'asac' in FLAGS.algo_name:
    model = asac.ASAC(env.observation_spec(), env.action_spec(),
                      target_entropy=-env.action_spec().shape[0])
  elif 'sac' in FLAGS.algo_name:
    model = sac.SAC(env.observation_spec(), env.action_spec(),
                    target_entropy=-env.action_spec().shape[0])
  elif 'cql' in FLAGS.algo_name:
    model = cql.CQL(env.observation_spec(), env.action_spec(),
                    target_entropy=-env.action_spec().shape[0])
  elif 'brac' in FLAGS.algo_name:
    if 'fbrac' in FLAGS.algo_name:
      model = fisher_brac.FBRAC(
          env.observation_spec(),
          env.action_spec(),
          target_entropy=-env.action_spec().shape[0],
          f_reg=FLAGS.f_reg,
          reward_bonus=FLAGS.reward_bonus)
    else:
      model = brac.BRAC(env.observation_spec(), env.action_spec(),
                        target_entropy=-env.action_spec().shape[0])

  model_folder = os.path.join(
      FLAGS.save_dir, 'models',
      f'{FLAGS.task_name}_{FLAGS.data_name}_{FLAGS.seed}')
  # Note: the original snippet had `tf.gfile.io.isdir`, which does not exist;
  # the correct call is `tf.io.gfile.isdir`.
  if not tf.io.gfile.isdir(model_folder):
    bc_pretraining_steps = 1_000_000
    for i in tqdm.tqdm(range(bc_pretraining_steps)):
      info_dict = model.bc.update_step(dataset_iter)
      if i % FLAGS.log_interval == 0:
        with summary_writer.as_default():
          for k, v in info_dict.items():
            tf.summary.scalar(f'training/{k}', v,
                              step=i - bc_pretraining_steps)
    # model.bc.policy.save_weights(os.path.join(model_folder, 'model'))
  else:
    model.bc.policy.load_weights(os.path.join(model_folder, 'model'))

  for i in tqdm.tqdm(range(FLAGS.num_updates)):
    with summary_writer.as_default():
      info_dict = model.update_step(dataset_iter)
    if i % FLAGS.log_interval == 0:
      with summary_writer.as_default():
        for k, v in info_dict.items():
          tf.summary.scalar(f'training/{k}', v, step=i)
    if (i + 1) % FLAGS.eval_interval == 0:
      average_returns, average_length = evaluation.evaluate(env, model)
      if FLAGS.data_name is None:
        average_returns = gym_env.get_normalized_score(average_returns) * 100.0
      with result_writer.as_default():
        tf.summary.scalar('evaluation/returns', average_returns, step=i + 1)
        tf.summary.scalar('evaluation/length', average_length, step=i + 1)
def get_env_and_policy(load_dir,
                       env_name,
                       alpha,
                       env_seed=0,
                       tabular_obs=False):
  if env_name == 'taxi':
    env = taxi.Taxi(tabular_obs=tabular_obs)
    env.seed(env_seed)
    policy_fn, policy_info_spec = taxi.get_taxi_policy(
        load_dir, env, alpha=alpha, py=False)
    tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
    policy = common_lib.TFAgentsWrappedPolicy(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        policy_fn,
        policy_info_spec,
        emit_log_probability=True)
  elif env_name == 'grid':
    env = navigation.GridWalk(tabular_obs=tabular_obs)
    env.seed(env_seed)
    policy_fn, policy_info_spec = navigation.get_navigation_policy(
        env, epsilon_explore=0.1 + 0.6 * (1 - alpha), py=False)
    tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
    policy = common_lib.TFAgentsWrappedPolicy(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        policy_fn,
        policy_info_spec,
        emit_log_probability=True)
  elif env_name == 'low_rank':
    env = low_rank.LowRank()
    env.seed(env_seed)
    policy_fn, policy_info_spec = low_rank.get_low_rank_policy(
        env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False)
    tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
    policy = common_lib.TFAgentsWrappedPolicy(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        policy_fn,
        policy_info_spec,
        emit_log_probability=True)
  elif env_name == 'tree':
    env = tree.Tree(branching=2, depth=10)
    env.seed(env_seed)
    policy_fn, policy_info_spec = tree.get_tree_policy(
        env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False)
    tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
    policy = common_lib.TFAgentsWrappedPolicy(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        policy_fn,
        policy_info_spec,
        emit_log_probability=True)
  elif env_name == 'lowrank_tree':
    env = tree.Tree(branching=2, depth=3, duplicate=10)
    env.seed(env_seed)
    policy_fn, policy_info_spec = tree.get_tree_policy(
        env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False)
    tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
    policy = common_lib.TFAgentsWrappedPolicy(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        policy_fn,
        policy_info_spec,
        emit_log_probability=True)
  elif env_name.startswith('bandit'):
    num_arms = int(env_name[6:]) if len(env_name) > 6 else 2
    env = bandit.Bandit(num_arms=num_arms)
    env.seed(env_seed)
    policy_fn, policy_info_spec = bandit.get_bandit_policy(
        env, epsilon_explore=1 - alpha, py=False)
    tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
    policy = common_lib.TFAgentsWrappedPolicy(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        policy_fn,
        policy_info_spec,
        emit_log_probability=True)
  elif env_name == 'small_tree':
    env = tree.Tree(branching=2, depth=3, loop=True)
    env.seed(env_seed)
    policy_fn, policy_info_spec = tree.get_tree_policy(
        env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False)
    tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
    policy = common_lib.TFAgentsWrappedPolicy(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        policy_fn,
        policy_info_spec,
        emit_log_probability=True)
  elif env_name == 'CartPole-v0':
    tf_env, policy = get_env_and_dqn_policy(
        env_name,
        os.path.join(load_dir, 'CartPole-v0', 'train', 'policy'),
        env_seed=env_seed,
        epsilon=0.3 + 0.15 * (1 - alpha))
  elif env_name == 'cartpole':  # Infinite-horizon cartpole.
    tf_env, policy = get_env_and_dqn_policy(
        'CartPole-v0',
        os.path.join(load_dir, 'CartPole-v0-250', 'train', 'policy'),
        env_seed=env_seed,
        epsilon=0.3 + 0.15 * (1 - alpha))
    env = InfiniteCartPole()
    tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
  elif env_name == 'FrozenLake-v0':
    tf_env, policy = get_env_and_dqn_policy(
        'FrozenLake-v0',
        os.path.join(load_dir, 'FrozenLake-v0', 'train', 'policy'),
        env_seed=env_seed,
        epsilon=0.2 * (1 - alpha),
        ckpt_file='ckpt-100000')
  elif env_name == 'frozenlake':  # Infinite-horizon frozenlake.
    tf_env, policy = get_env_and_dqn_policy(
        'FrozenLake-v0',
        os.path.join(load_dir, 'FrozenLake-v0', 'train', 'policy'),
        env_seed=env_seed,
        epsilon=0.2 * (1 - alpha),
        ckpt_file='ckpt-100000')
    env = InfiniteFrozenLake()
    tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
  elif env_name in ['Reacher-v2', 'reacher']:
    if env_name == 'Reacher-v2':
      env = suites.load_mujoco(env_name)
    else:
      env = gym_wrapper.GymWrapper(InfiniteReacher())
    env.seed(env_seed)
    tf_env = tf_py_environment.TFPyEnvironment(env)
    sac_policy = get_sac_policy(tf_env)
    directory = os.path.join(load_dir, 'Reacher-v2', 'train', 'policy')
    policy = load_policy(sac_policy, env_name, directory)
    policy = GaussianPolicy(
        policy, 0.4 - 0.3 * alpha, emit_log_probability=True)
  elif env_name == 'HalfCheetah-v2':
    env = suites.load_mujoco(env_name)
    env.seed(env_seed)
    tf_env = tf_py_environment.TFPyEnvironment(env)
    sac_policy = get_sac_policy(tf_env)
    directory = os.path.join(load_dir, env_name, 'train', 'policy')
    policy = load_policy(sac_policy, env_name, directory)
    policy = GaussianPolicy(
        policy, 0.2 - 0.1 * alpha, emit_log_probability=True)
  else:
    raise ValueError('Unrecognized environment %s.' % env_name)

  return tf_env, policy
def load_carla_env(env_name='carla-v0',
                   discount=1.0,
                   number_of_vehicles=100,
                   number_of_walkers=0,
                   display_size=256,
                   max_past_step=1,
                   dt=0.1,
                   discrete=False,
                   discrete_acc=[-3.0, 0.0, 3.0],
                   discrete_steer=[-0.2, 0.0, 0.2],
                   continuous_accel_range=[-3.0, 3.0],
                   continuous_steer_range=[-0.3, 0.3],
                   ego_vehicle_filter='vehicle.lincoln*',
                   port=2000,
                   town='Town03',
                   task_mode='random',
                   max_time_episode=500,
                   max_waypt=12,
                   obs_range=32,
                   lidar_bin=0.5,
                   d_behind=12,
                   out_lane_thres=2.0,
                   desired_speed=8,
                   max_ego_spawn_times=200,
                   display_route=True,
                   pixor_size=64,
                   pixor=False,
                   obs_channels=None,
                   action_repeat=1):
  """Loads train and eval environments."""
  env_params = {
      'number_of_vehicles': number_of_vehicles,
      'number_of_walkers': number_of_walkers,
      'display_size': display_size,  # screen size of bird-eye render
      'max_past_step': max_past_step,  # the number of past steps to draw
      'dt': dt,  # time interval between two frames
      'discrete': discrete,  # whether to use discrete control space
      'discrete_acc': discrete_acc,  # discrete value of accelerations
      'discrete_steer': discrete_steer,  # discrete value of steering angles
      'continuous_accel_range': continuous_accel_range,  # continuous acceleration range
      'continuous_steer_range': continuous_steer_range,  # continuous steering angle range
      'ego_vehicle_filter': ego_vehicle_filter,  # filter for defining ego vehicle
      'port': port,  # connection port
      'town': town,  # which town to simulate
      'task_mode': task_mode,  # mode of the task, [random, roundabout (only for Town03)]
      'max_time_episode': max_time_episode,  # maximum timesteps per episode
      'max_waypt': max_waypt,  # maximum number of waypoints
      'obs_range': obs_range,  # observation range (meter)
      'lidar_bin': lidar_bin,  # bin size of lidar sensor (meter)
      'd_behind': d_behind,  # distance behind the ego vehicle (meter)
      'out_lane_thres': out_lane_thres,  # threshold for out of lane
      'desired_speed': desired_speed,  # desired speed (m/s)
      'max_ego_spawn_times': max_ego_spawn_times,  # maximum times to spawn ego vehicle
      'display_route': display_route,  # whether to render the desired route
      'pixor_size': pixor_size,  # size of the pixor labels
      'pixor': pixor,  # whether to output PIXOR observation
  }

  gym_spec = gym.spec(env_name)
  gym_env = gym_spec.make(params=env_params)

  if obs_channels:
    gym_env = filter_observation_wrapper.FilterObservationWrapper(
        gym_env, obs_channels)

  py_env = gym_wrapper.GymWrapper(
      gym_env,
      discount=discount,
      auto_reset=True,
  )
  eval_py_env = py_env
  if action_repeat > 1:
    py_env = wrappers.ActionRepeat(py_env, action_repeat)

  return py_env, eval_py_env
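A minimal sketch of consuming load_carla_env, reusing the TFPyEnvironment wrapping pattern from the other snippets here; the action_repeat value is an arbitrary illustration.

py_env, eval_py_env = load_carla_env(env_name='carla-v0', action_repeat=4)
train_env = tf_py_environment.TFPyEnvironment(py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)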
def env_load_fn(environment_name='DrunkSpiderShort',
                max_episode_steps=50,
                resize_factor=(1, 1),
                terminate_on_timeout=True,
                start=(0, 3),
                goal=(7, 3),
                goal_bounds=[(6, 2), (7, 4)],
                fall_penalty=0.,
                reset_on_fall=False,
                gym_env_wrappers=[],
                gym=False):
  """Loads the selected environment and wraps it with the specified wrappers.

  Args:
    environment_name: Name for the environment to load.
    max_episode_steps: If None the max_episode_steps will be set to the default
      step limit defined in the environment's spec. No limit is applied if set
      to 0 or if there is no timestep_limit set in the environment's spec.
    resize_factor: A factor for resizing.
    terminate_on_timeout: Whether to set done = True when the max episode steps
      is reached.

  Returns:
    A PyEnvironmentBase instance.
  """
  if resize_factor != (1, 1):
    if start:
      start = (start[0] * resize_factor[0], start[1] * resize_factor[1])
    if goal:
      goal = (goal[0] * resize_factor[0], goal[1] * resize_factor[1])
    if goal_bounds:
      goal_bounds = [(g[0] * resize_factor[0], g[1] * resize_factor[1])
                     for g in goal_bounds]

  if 'acnoise' in environment_name.split('-'):
    environment_name = environment_name.split('-')[0]
    gym_env = PointMassAcNoiseEnv(
        start=start, env_name=environment_name, resize_factor=resize_factor)
  elif 'acscale' in environment_name.split('-'):
    environment_name = environment_name.split('-')[0]
    gym_env = PointMassAcScaleEnv(
        start=start, env_name=environment_name, resize_factor=resize_factor)
  else:
    gym_env = PointMassEnv(
        start=start, env_name=environment_name, resize_factor=resize_factor)
  gym_env = GoalConditionedPointWrapper(
      gym_env,
      goal=goal,
      goal_bounds=goal_bounds,
      fall_penalty=-abs(fall_penalty),
      reset_on_fall=reset_on_fall,
      max_episode_steps=max_episode_steps)
  for wrapper in gym_env_wrappers:
    gym_env = wrapper(gym_env)

  if gym:
    return gym_env

  from tf_agents.environments import gym_wrapper
  env = gym_wrapper.GymWrapper(
      gym_env, discount=1.0, auto_reset=True, simplify_box_bounds=False)

  if max_episode_steps > 0:
    if terminate_on_timeout:
      env = TimeLimitBonus(env, max_episode_steps)
    else:
      env = NonTerminatingTimeLimit(env, max_episode_steps)

  return env
def main(_):
  tf.config.experimental_run_functions_eagerly(FLAGS.eager)

  gym_env, dataset = d4rl_utils.create_d4rl_env_and_dataset(
      task_name=FLAGS.task_name, batch_size=FLAGS.batch_size)

  env = gym_wrapper.GymWrapper(gym_env)
  env = tf_py_environment.TFPyEnvironment(env)

  dataset_iter = iter(dataset)

  tf.random.set_seed(FLAGS.seed)

  hparam_str = f'{FLAGS.algo_name}_{FLAGS.task_name}_seed={FLAGS.seed}'

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'tb', hparam_str))
  result_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'results', hparam_str))

  if FLAGS.algo_name == 'bc':
    model = behavioral_cloning.BehavioralCloning(env.observation_spec(),
                                                 env.action_spec())
  else:
    model = fisher_brc.FBRC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0],
        f_reg=FLAGS.f_reg,
        reward_bonus=FLAGS.reward_bonus)

  for i in tqdm.tqdm(range(FLAGS.bc_pretraining_steps)):
    info_dict = model.bc.update_step(dataset_iter)
    if i % FLAGS.log_interval == 0:
      with summary_writer.as_default():
        for k, v in info_dict.items():
          tf.summary.scalar(
              f'training/{k}', v, step=i - FLAGS.bc_pretraining_steps)

  for i in tqdm.tqdm(range(FLAGS.num_updates)):
    with summary_writer.as_default():
      info_dict = model.update_step(dataset_iter)
    if i % FLAGS.log_interval == 0:
      with summary_writer.as_default():
        for k, v in info_dict.items():
          tf.summary.scalar(f'training/{k}', v, step=i)
    if (i + 1) % FLAGS.eval_interval == 0:
      average_returns, average_length = evaluation.evaluate(env, model)
      average_returns = gym_env.get_normalized_score(average_returns) * 100.0
      with result_writer.as_default():
        tf.summary.scalar('evaluation/returns', average_returns, step=i + 1)
        tf.summary.scalar('evaluation/length', average_length, step=i + 1)
window_size = 10
num_eval_episodes = 30
network_shape = (128, 128,)
learning_rate = 1e-3
tau = 0.1
gradient_clipping = 1
policy_dir = 'policy'

train_py_env = gym_wrapper.GymWrapper(
    StockTradingEnv(df=train_data, window_size=window_size))
eval_py_env = gym_wrapper.GymWrapper(
    StockTradingEnv(df=eval_data, window_size=window_size, eval=True))
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)


def eval_policy(policy, render=False):
  total_reward = 0
  for _ in range(num_eval_episodes):
    time_step = eval_env.reset()
    episode_reward = 0
    while not time_step.is_last():
      time_step = eval_env.step(policy.action(time_step).action)
      episode_reward += time_step.reward
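The eval_policy snippet above is cut off mid-loop. A sketch of how such an evaluation loop is typically finished, assuming the caller wants the mean episode reward; the accumulation and return lines are reconstructed, not recovered from the source.

def eval_policy_sketch(policy, render=False):
  # Hedged reconstruction of the truncated function above.
  total_reward = 0.0
  for _ in range(num_eval_episodes):
    time_step = eval_env.reset()
    episode_reward = 0.0
    while not time_step.is_last():
      if render:
        eval_py_env.render()  # render via the underlying py environment
      time_step = eval_env.step(policy.action(time_step).action)
      episode_reward += time_step.reward
    total_reward += episode_reward
  return total_reward / num_eval_episodes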
def main(_):
  tf.config.experimental_run_functions_eagerly(FLAGS.eager)

  def preprocess_fn(dataset):
    return dataset.cache().shuffle(1_000_000, reshuffle_each_iteration=True)

  def state_mask_fn(states):
    if FLAGS.state_mask_dims == 0:
      return states
    assert FLAGS.state_mask_dims <= states.shape[1]
    state_mask_dims = (
        states.shape[1]
        if FLAGS.state_mask_dims == -1 else FLAGS.state_mask_dims)
    if FLAGS.state_mask_index == 'fixed':
      mask_indices = range(states.shape[1] - state_mask_dims, states.shape[1])
    else:
      mask_indices = np.random.permutation(np.arange(
          states.shape[1]))[:state_mask_dims]
    if FLAGS.state_mask_value == 'gaussian':
      mask_values = states[:, mask_indices]
      mask_values = (
          mask_values + np.std(mask_values, axis=0) *
          np.random.normal(size=mask_values.shape))
    elif 'quantize' in FLAGS.state_mask_value:
      mask_values = states[:, mask_indices]
      mask_values = np.around(
          mask_values, decimals=int(FLAGS.state_mask_value[-1]))
    else:
      mask_values = 0
    states[:, mask_indices] = mask_values
    return states

  gym_env, dataset, embed_dataset = d4rl_utils.create_d4rl_env_and_dataset(
      task_name=FLAGS.task_name,
      batch_size=FLAGS.batch_size,
      sliding_window=FLAGS.embed_training_window,
      state_mask_fn=state_mask_fn)

  downstream_embed_dataset = None
  if (FLAGS.downstream_task_name is not None or
      FLAGS.downstream_data_name is not None or
      FLAGS.downstream_data_size is not None):
    downstream_data_name = FLAGS.downstream_data_name
    assert downstream_data_name is None
    gym_env, dataset, downstream_embed_dataset = (
        d4rl_utils.create_d4rl_env_and_dataset(
            task_name=FLAGS.downstream_task_name,
            batch_size=FLAGS.batch_size,
            sliding_window=FLAGS.embed_training_window,
            data_size=FLAGS.downstream_data_size,
            state_mask_fn=state_mask_fn))
    if FLAGS.proportion_downstream_data:
      zipped_dataset = tf.data.Dataset.zip(
          (embed_dataset, downstream_embed_dataset))

      def combine(*elems1_and_2):
        batch_size = tf.shape(elems1_and_2[0][0])[0]
        which = tf.random.uniform(
            [batch_size]) >= FLAGS.proportion_downstream_data
        from1 = tf.where(which)
        from2 = tf.where(tf.logical_not(which))
        new_elems = map(
            lambda x: tf.concat(
                [tf.gather_nd(x[0], from1), tf.gather_nd(x[1], from2)], 0),
            zip(*elems1_and_2))
        return tuple(new_elems)

      embed_dataset = zipped_dataset.map(combine)

  if FLAGS.embed_learner and 'action' in FLAGS.embed_learner:
    assert FLAGS.embed_training_window >= 2
    dataset = downstream_embed_dataset or embed_dataset

  if FLAGS.downstream_mode == 'online':
    downstream_task = FLAGS.downstream_task_name or FLAGS.task_name
    try:
      train_gym_env = gym.make(downstream_task)
    except:
      train_gym_env = gym.make('DM-' + downstream_task)
    train_env = gym_wrapper.GymWrapper(train_gym_env)
    train_env = tf_py_environment.TFPyEnvironment(train_env)

    replay_spec = (
        train_env.observation_spec(),
        train_env.action_spec(),
        train_env.reward_spec(),
        train_env.reward_spec(),  # discount spec
        train_env.observation_spec(),  # next observation spec
    )
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        replay_spec,
        batch_size=1,
        max_length=FLAGS.num_updates,
        dataset_window_shift=1 if get_ctx_length() else None)

    @tf.function
    def add_to_replay(state,
                      action,
                      reward,
                      discount,
                      next_states,
                      replay_buffer=replay_buffer):
      replay_buffer.add_batch((state, action, reward, discount, next_states))

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        train_env.time_step_spec(), train_env.action_spec())

    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=FLAGS.batch_size,
        num_steps=FLAGS.embed_training_window
        if get_ctx_length() else None).prefetch(3)
    dataset = dataset.map(lambda *data: data[0])
  else:
    train_env = None
    replay_buffer = None
    add_to_replay = None
    initial_collect_policy = None

  env = gym_wrapper.GymWrapper(gym_env)
  env = tf_py_environment.TFPyEnvironment(env)

  dataset_iter = iter(dataset)
  embed_dataset_iter = iter(embed_dataset) if embed_dataset else None

  tf.random.set_seed(FLAGS.seed)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'tb'))
  result_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'results'))

  if (FLAGS.state_embed_dim or FLAGS.state_action_embed_dim
     ) and FLAGS.embed_learner and FLAGS.embed_pretraining_steps != 0:
    embed_model = get_embed_model(env)
    if FLAGS.finetune:
      other_embed_model = get_embed_model(env)
      other_embed_model2 = get_embed_model(env)
    else:
      other_embed_model = None
      other_embed_model2 = None
  else:
    embed_model = None
    other_embed_model = None
    other_embed_model2 = None

  config_str = (
      f'{FLAGS.task_name}_{FLAGS.embed_learner}_{FLAGS.state_embed_dim}_'
      f'{FLAGS.state_embed_dists}_{FLAGS.embed_training_window}_'
      f'{FLAGS.downstream_input_mode}_{FLAGS.finetune}_{FLAGS.network}_'
      f'{FLAGS.seed}')
  if FLAGS.embed_learner == 'acl':
    config_str += (
        f'_{FLAGS.predict_actions}_{FLAGS.policy_decoder_on_embeddings}_'
        f'{FLAGS.reward_decoder_on_embeddings}_{FLAGS.predict_rewards}_'
        f'{FLAGS.embed_on_input}_{FLAGS.extra_embedder}_'
        f'{FLAGS.positional_encoding_type}_{FLAGS.direction}')
  elif FLAGS.embed_learner and 'action' in FLAGS.embed_learner:
    config_str += (
        f'_{FLAGS.state_action_embed_dim}_{FLAGS.state_action_fourier_dim}')

  save_dir = os.path.join(FLAGS.save_dir, config_str)

  # Embed pretraining
  if FLAGS.embed_pretraining_steps > 0 and embed_model is not None:
    model_folder = os.path.join(
        save_dir, 'embed_models%d' % FLAGS.embed_pretraining_steps, config_str)
    if not tf.io.gfile.isdir(model_folder):
      embed_pretraining_steps = FLAGS.embed_pretraining_steps
      for i in tqdm.tqdm(range(embed_pretraining_steps)):
        embed_dict = embed_model.update_step(embed_dataset_iter)
        if i % FLAGS.log_interval == 0:
          with summary_writer.as_default():
            for k, v in embed_dict.items():
              tf.summary.scalar(
                  f'embed/{k}', v, step=i - embed_pretraining_steps)
              print(k, v)
          print('embed pretraining')
      embed_model.save_weights(os.path.join(model_folder, 'embed'))
    else:
      time.sleep(np.random.randint(5, 20))  # Try to suppress checksum errors.
      embed_model.load_weights(os.path.join(model_folder, 'embed'))
      if other_embed_model and other_embed_model2:
        try:  # Try to suppress checksum errors.
          other_embed_model.load_weights(os.path.join(model_folder, 'embed'))
          other_embed_model2.load_weights(os.path.join(model_folder, 'embed'))
        except:
          embed_model.save_weights(os.path.join(model_folder, 'embed'))
          other_embed_model.load_weights(os.path.join(model_folder, 'embed'))
          other_embed_model2.load_weights(os.path.join(model_folder, 'embed'))

  if FLAGS.algo_name == 'bc':
    hidden_dims = ([] if FLAGS.network == 'none' else
                   (256,) if FLAGS.network == 'small' else (256, 256))
    model = behavioral_cloning.BehavioralCloning(
        env.observation_spec().shape[0],
        env.action_spec(),
        hidden_dims=hidden_dims,
        embed_model=embed_model,
        finetune=FLAGS.finetune)
  elif FLAGS.algo_name == 'latent_bc':
    hidden_dims = ([] if FLAGS.network == 'none' else
                   (256,) if FLAGS.network == 'small' else (256, 256))
    model = latent_behavioral_cloning.LatentBehavioralCloning(
        env.observation_spec().shape[0],
        env.action_spec(),
        hidden_dims=hidden_dims,
        embed_model=embed_model,
        finetune=FLAGS.finetune,
        finetune_primitive=FLAGS.finetune_primitive,
        learning_rate=FLAGS.latent_bc_lr,
        latent_bc_lr_decay=FLAGS.latent_bc_lr_decay,
        kl_regularizer=FLAGS.kl_regularizer)
  elif 'sac' in FLAGS.algo_name:
    model = sac.SAC(
        env.observation_spec().shape[0],
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0],
        embed_model=embed_model,
        other_embed_model=other_embed_model,
        network=FLAGS.network,
        finetune=FLAGS.finetune)
  elif 'brac' in FLAGS.algo_name:
    model = brac.BRAC(
        env.observation_spec().shape[0],
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0],
        embed_model=embed_model,
        other_embed_model=other_embed_model,
        bc_embed_model=other_embed_model2,
        network=FLAGS.network,
        finetune=FLAGS.finetune)

  # Agent pretraining.
  if not tf.io.gfile.isdir(os.path.join(save_dir, 'model')):
    bc_pretraining_steps = 200_000
    for i in tqdm.tqdm(range(bc_pretraining_steps)):
      if get_ctx_length():
        info_dict = model.bc.update_step(embed_dataset_iter)
      else:
        info_dict = model.bc.update_step(dataset_iter)
      if i % FLAGS.log_interval == 0:
        with summary_writer.as_default():
          for k, v in info_dict.items():
            tf.summary.scalar(
                f'training/{k}', v, step=i - bc_pretraining_steps)
        print('bc pretraining')
    model.bc.policy.save_weights(os.path.join(save_dir, 'model'))
  else:
    model.bc.policy.load_weights(os.path.join(save_dir, 'model'))

  if train_env:
    timestep = train_env.reset()
  else:
    timestep = None
  actor = None
  if hasattr(model, 'actor'):
    actor = model.actor
  elif hasattr(model, 'policy'):
    actor = model.policy

  ctx_states = []
  ctx_actions = []
  ctx_rewards = []
  for i in tqdm.tqdm(range(FLAGS.num_updates)):
    if (train_env and timestep and replay_buffer and initial_collect_policy and
        add_to_replay and actor):
      if timestep.is_last():
        timestep = train_env.reset()
      if replay_buffer.num_frames() < FLAGS.num_random_actions:
        policy_step = initial_collect_policy.action(timestep)
        action = policy_step.action
        ctx_states.append(state_mask_fn(timestep.observation.numpy()))
        ctx_actions.append(action)
        ctx_rewards.append(timestep.reward)
      else:
        states = state_mask_fn(timestep.observation.numpy())
        actions = None
        rewards = None
        if get_ctx_length():
          ctx_states.append(states)
          states = tf.stack(ctx_states[-get_ctx_length():], axis=1)
          actions = tf.stack(ctx_actions[-get_ctx_length() + 1:], axis=1)
          rewards = tf.stack(ctx_rewards[-get_ctx_length() + 1:], axis=1)
        if hasattr(model, 'embed_model') and model.embed_model:
          states = model.embed_model(states, actions, rewards)
        action = actor(states, sample=True)
        ctx_actions.append(action)
      next_timestep = train_env.step(action)
      ctx_rewards.append(next_timestep.reward)
      add_to_replay(
          state_mask_fn(timestep.observation.numpy()), action,
          next_timestep.reward, next_timestep.discount,
          state_mask_fn(next_timestep.observation.numpy()))
      timestep = next_timestep

    with summary_writer.as_default():
      if embed_model and FLAGS.embed_pretraining_steps == -1:
        embed_dict = embed_model.update_step(embed_dataset_iter)
        if other_embed_model:
          other_embed_dict = other_embed_model.update_step(embed_dataset_iter)
          embed_dict.update(
              dict(('other_%s' % k, v) for k, v in other_embed_dict.items()))
      else:
        embed_dict = {}
      if FLAGS.downstream_mode == 'offline':
        if get_ctx_length():
          info_dict = model.update_step(embed_dataset_iter)
        else:
          info_dict = model.update_step(dataset_iter)
      elif i + 1 >= FLAGS.num_random_actions:
        info_dict = model.update_step(dataset_iter)
      else:
        info_dict = {}

    if i % FLAGS.log_interval == 0:
      with summary_writer.as_default():
        for k, v in info_dict.items():
          tf.summary.scalar(f'training/{k}', v, step=i)
        for k, v in embed_dict.items():
          tf.summary.scalar(f'embed/{k}', v, step=i)
          print(k, v)

    if (i + 1) % FLAGS.eval_interval == 0:
      average_returns, average_length = evaluation.evaluate(
          env,
          model,
          ctx_length=get_ctx_length(),
          embed_training_window=(FLAGS.embed_training_window
                                 if FLAGS.embed_learner and
                                 'action' in FLAGS.embed_learner else None),
          state_mask_fn=state_mask_fn if FLAGS.state_mask_eval else None)
      average_returns = gym_env.get_normalized_score(average_returns) * 100.0
      with result_writer.as_default():
        tf.summary.scalar('evaluation/returns', average_returns, step=i + 1)
        tf.summary.scalar('evaluation/length', average_length, step=i + 1)
        print('evaluation/returns', average_returns)
        print('evaluation/length', average_length)
def train_eval(
    root_dir,
    load_root_dir=None,
    env_load_fn=None,
    gym_env_wrappers=[],
    monitor=False,
    env_name=None,
    agent_class=None,
    initial_collect_driver_class=None,
    collect_driver_class=None,
    online_driver_class=dynamic_episode_driver.DynamicEpisodeDriver,
    num_global_steps=1000000,
    train_steps_per_iteration=1,
    train_metrics=None,
    eval_metrics=None,
    train_metrics_callback=None,
    # Params for SacAgent args
    actor_fc_layers=(256, 256),
    critic_joint_fc_layers=(256, 256),
    # Safety Critic training args
    train_sc_steps=10,
    train_sc_interval=1000,
    online_critic=False,
    n_envs=None,
    finetune_sc=False,
    # Ensemble Critic training args
    n_critics=30,
    critic_learning_rate=3e-4,
    # Wcpg Critic args
    critic_preprocessing_layer_size=256,
    actor_preprocessing_layer_size=256,
    # Params for train
    batch_size=256,
    # Params for eval
    run_eval=False,
    num_eval_episodes=1,
    max_episode_len=500,
    eval_interval=10000,
    eval_metrics_callback=None,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    keep_rb_checkpoint=False,
    log_interval=1000,
    summary_interval=1000,
    monitor_interval=1000,
    summaries_flush_secs=10,
    early_termination_fn=None,
    debug_summaries=False,
    seed=None,
    eager_debug=False,
    env_metric_factories=None):  # pylint: disable=unused-argument
  """A simple train and eval for SC-SAC."""
  n_envs = n_envs or num_eval_episodes
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  train_metrics = train_metrics or []
  eval_metrics = eval_metrics or []
  sc_metrics = eval_metrics or []

  if online_critic:
    sc_dir = os.path.join(root_dir, 'sc')
    sc_summary_writer = tf.compat.v2.summary.create_file_writer(
        sc_dir, flush_millis=summaries_flush_secs * 1000)
    sc_metrics = [
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes,
            batch_size=n_envs,
            name='SafeAverageReturn'),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes,
            batch_size=n_envs,
            name='SafeAverageEpisodeLength')
    ] + [tf_py_metric.TFPyMetric(m) for m in sc_metrics]
    sc_tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)]
            * n_envs))
    if seed:
      sc_tf_env.seed([seed + i for i in range(n_envs)])

  if run_eval:
    eval_dir = os.path.join(root_dir, 'eval')
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=n_envs),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=n_envs),
    ] + [tf_py_metric.TFPyMetric(m) for m in eval_metrics]
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)]
            * n_envs))
    if seed:
      eval_tf_env.seed([seed + n_envs + i for i in range(n_envs)])

  if monitor:
    vid_path = os.path.join(root_dir, 'rollouts')
    monitor_env_wrapper = misc.monitor_freq(1, vid_path)
    monitor_env = gym.make(env_name)
    for wrapper in gym_env_wrappers:
      monitor_env = wrapper(monitor_env)
    monitor_env = monitor_env_wrapper(monitor_env)
    # auto_reset must be False to ensure Monitor works correctly
    monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

  global_step = tf.compat.v1.train.get_or_create_global_step()

  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
    if seed:
      tf_env.seed(seed + 2 * n_envs + i for i in range(n_envs))
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    logging.debug('obs spec: %s', observation_spec)
    logging.debug('action spec: %s', action_spec)

    if agent_class:  # is not wcpg_agent.WcpgAgent:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          observation_spec,
          action_spec,
          fc_layer_params=actor_fc_layers,
          continuous_projection_net=agents.normal_projection_net)
      critic_net = agents.CriticNetwork(
          (observation_spec, action_spec),
          joint_fc_layer_params=critic_joint_fc_layers)
    else:
      alpha_spec = tensor_spec.BoundedTensorSpec(
          shape=(), dtype=tf.float32, minimum=0., maximum=1., name='alpha')
      input_tensor_spec = (observation_spec, action_spec, alpha_spec)
      critic_preprocessing_layers = (
          tf.keras.layers.Dense(critic_preprocessing_layer_size),
          tf.keras.layers.Dense(critic_preprocessing_layer_size),
          tf.keras.layers.Lambda(lambda x: x))
      critic_net = agents.DistributionalCriticNetwork(
          input_tensor_spec, joint_fc_layer_params=critic_joint_fc_layers)
      actor_preprocessing_layers = (
          tf.keras.layers.Dense(actor_preprocessing_layer_size),
          tf.keras.layers.Dense(actor_preprocessing_layer_size),
          tf.keras.layers.Lambda(lambda x: x))
      actor_net = agents.WcpgActorNetwork(
          input_tensor_spec, preprocessing_layers=actor_preprocessing_layers)

    if agent_class in SAFETY_AGENTS:
      safety_critic_net = agents.CriticNetwork(
          (observation_spec, action_spec),
          joint_fc_layer_params=critic_joint_fc_layers)
      tf_agent = agent_class(
          time_step_spec,
          action_spec,
          actor_network=actor_net,
          critic_network=critic_net,
          safety_critic_network=safety_critic_net,
          train_step_counter=global_step,
          debug_summaries=debug_summaries)
    elif agent_class is ensemble_sac_agent.EnsembleSacAgent:
      critic_nets, critic_optimizers = [critic_net], [
          tf.keras.optimizers.Adam(critic_learning_rate)
      ]
      for _ in range(n_critics - 1):
        critic_nets.append(
            agents.CriticNetwork((observation_spec, action_spec),
                                 joint_fc_layer_params=critic_joint_fc_layers))
        critic_optimizers.append(
            tf.keras.optimizers.Adam(critic_learning_rate))
      tf_agent = agent_class(
          time_step_spec,
          action_spec,
          actor_network=actor_net,
          critic_network=critic_nets,
          critic_optimizers=critic_optimizers,
          debug_summaries=debug_summaries)
    else:  # assume is using SacAgent
      logging.debug(critic_net.input_tensor_spec)
      tf_agent = agent_class(
          time_step_spec,
          action_spec,
          actor_network=actor_net,
          critic_network=critic_net,
          train_step_counter=global_step,
          debug_summaries=debug_summaries)

    tf_agent.initialize()

    # Make the replay buffer.
    collect_data_spec = tf_agent.collect_data_spec

    logging.debug('Allocating replay buffer ...')
    # Add to replay buffer and other agent specific observers.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, batch_size=1, max_length=1000000)
    logging.debug('RB capacity: %i', replay_buffer.capacity)
    logging.debug('ReplayBuffer Collect data spec: %s', collect_data_spec)

    agent_observers = [replay_buffer.add_batch]
    if online_critic:
      online_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
          collect_data_spec,
          batch_size=1,
          max_length=max_episode_len * num_eval_episodes)
      agent_observers.append(online_replay_buffer.add_batch)
      online_rb_ckpt_dir = os.path.join(train_dir, 'online_replay_buffer')
      online_rb_checkpointer = common.Checkpointer(
          ckpt_dir=online_rb_ckpt_dir,
          max_to_keep=1,
          replay_buffer=online_replay_buffer)
      clear_rb = online_replay_buffer.clear

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

    if not online_critic:
      eval_policy = tf_agent.policy
      collect_policy = tf_agent.collect_policy
    else:
      eval_policy = tf_agent.policy  # pylint: disable=protected-access
      collect_policy = tf_agent.collect_policy  # pylint: disable=protected-access
      online_collect_policy = tf_agent._safe_policy

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        time_step_spec, action_spec)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    if agent_class in SAFETY_AGENTS:
      safety_critic_checkpointer = common.Checkpointer(
          ckpt_dir=sc_dir,
          safety_critic=tf_agent._safety_critic_network,  # pylint: disable=protected-access
          global_step=global_step)
    rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=rb_ckpt_dir, max_to_keep=1, replay_buffer=replay_buffer)

    if load_root_dir:
      load_root_dir = os.path.expanduser(load_root_dir)
      load_train_dir = os.path.join(load_root_dir, 'train')
      misc.load_pi_ckpt(load_train_dir, tf_agent)  # loads tf_agent

    if load_root_dir is None:
      train_checkpointer.initialize_or_restore()
      rb_checkpointer.initialize_or_restore()
      if agent_class in SAFETY_AGENTS:
        safety_critic_checkpointer.initialize_or_restore()

    env_metrics = []
    if env_metric_factories:
      for env_metric in env_metric_factories:
        env_metrics.append(tf_py_metric.TFPyMetric(env_metric([py_env.gym])))
        # TODO: get env factory with parallel py envs
        # if run_eval:
        #   eval_metrics.append(env_metric(
        #       [env.gym for env in eval_tf_env.pyenv._envs]))
        # if online_critic:
        #   sc_metrics.append(env_metric(
        #       [env.gym for env in sc_tf_env.pyenv._envs]))

    collect_driver = collect_driver_class(
        tf_env,
        collect_policy,
        observers=agent_observers + train_metrics + env_metrics)

    if online_critic:
      logging.debug('online driver class: %s', online_driver_class)
      if online_driver_class is safe_dynamic_episode_driver.SafeDynamicEpisodeDriver:
        online_temp_buffer = episodic_replay_buffer.EpisodicReplayBuffer(
            collect_data_spec)
        online_temp_buffer_stateful = (
            episodic_replay_buffer.StatefulEpisodicReplayBuffer(
                online_temp_buffer, num_episodes=num_eval_episodes))
        online_driver = safe_dynamic_episode_driver.SafeDynamicEpisodeDriver(
            sc_tf_env,
            online_collect_policy,
            online_temp_buffer,
            online_replay_buffer,
            observers=[online_temp_buffer_stateful.add_batch] + sc_metrics,
            num_episodes=num_eval_episodes)
      else:
        online_driver = online_driver_class(
            sc_tf_env,
            online_collect_policy,
            observers=[online_replay_buffer.add_batch] + sc_metrics,
            num_episodes=num_eval_episodes)
      online_driver.run = common.function(online_driver.run)

    if not eager_debug:
      config_saver = gin.tf.GinConfigSaverHook(
          train_dir, summarize_config=True)
      tf.function(config_saver.after_create_session)()

    if agent_class is sac_agent.SacAgent:
      collect_driver.run = common.function(collect_driver.run)

    if eager_debug:
      tf.config.experimental_run_functions_eagerly(True)

    if not rb_checkpointer.checkpoint_exists:
      logging.info('Performing initial collection ...')
      initial_collect_driver_class(
          tf_env,
          initial_collect_policy,
          observers=agent_observers + train_metrics + env_metrics).run()
      last_id = replay_buffer._get_last_id()  # pylint: disable=protected-access
      logging.info('Data saved after initial collection: %d steps', last_id)
      if online_critic:
        last_id = online_replay_buffer._get_last_id()  # pylint: disable=protected-access
        logging.debug(
            'Data saved in online buffer after initial collection: %d steps',
            last_id)

    if run_eval:
      results = metric_utils.eager_compute(
          eval_metrics,
          eval_tf_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          train_step=global_step,
          summary_writer=eval_summary_writer,
          summary_prefix='EvalMetrics',
      )
      if eval_metrics_callback is not None:
        eval_metrics_callback(results, global_step.numpy())
      metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)
    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3, sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)
    if online_critic:
      online_dataset = online_replay_buffer.as_dataset(
          num_parallel_calls=3, sample_batch_size=batch_size,
          num_steps=2).prefetch(3)
      online_iterator = iter(online_dataset)
      critic_metrics = [
          tf.keras.metrics.AUC(name='safety_critic_auc'),
          tf.keras.metrics.TruePositives(name='safety_critic_tp'),
          tf.keras.metrics.FalsePositives(name='safety_critic_fp'),
          tf.keras.metrics.TrueNegatives(name='safety_critic_tn'),
          tf.keras.metrics.FalseNegatives(name='safety_critic_fn'),
          tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc')
      ]

      @common.function
      def critic_train_step():
        """Builds critic training step."""
        start_time = time.time()
        experience, buf_info = next(online_iterator)
        if env_name.split('-')[0] in SAFETY_ENVS:
          safe_rew = experience.observation['task_agn_rew'][:, 1]
        else:
          safe_rew = misc.process_replay_buffer(
              online_replay_buffer, as_tensor=True)
          safe_rew = tf.gather(safe_rew, tf.squeeze(buf_info.ids), axis=1)
        ret = tf_agent.train_sc(
            experience, safe_rew, metrics=critic_metrics, weights=None)
        logging.debug('critic train step: {} sec'.format(time.time() -
                                                         start_time))
        return ret

    @common.function
    def train_step():
      experience, _ = next(iterator)
      ret = tf_agent.train(experience)
      return ret

    if not early_termination_fn:
      early_termination_fn = lambda: False

    loss_diverged = False
    # How many consecutive steps was loss diverged for.
    loss_divergence_counter = 0
    mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss')

    if online_critic:
      logging.debug('starting safety critic pretraining')
      safety_eps = tf_agent._safe_policy._safety_threshold
      tf_agent._safe_policy._safety_threshold = 0.6
      resample_counter = online_collect_policy._resample_counter
      mean_resample_ac = tf.keras.metrics.Mean(name='mean_unsafe_ac_freq')
      # don't fine-tune safety critic
      if (global_step.numpy() == 0 and load_root_dir is None):
        for _ in range(train_sc_steps):
          sc_loss, lambda_loss = critic_train_step()  # pylint: disable=unused-variable
      tf_agent._safe_policy._safety_threshold = safety_eps

    logging.debug('starting policy pretraining')
    while (global_step.numpy() <= num_global_steps and
           not early_termination_fn()):
      # Collect and train.
      start_time = time.time()
      current_step = global_step.numpy()

      if online_critic:
        mean_resample_ac(resample_counter.result())
        resample_counter.reset()
        if time_step is None or time_step.is_last():
          resample_ac_freq = mean_resample_ac.result()
          mean_resample_ac.reset_states()

      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      logging.debug('policy eval: {} sec'.format(time.time() - start_time))

      train_time = time.time()
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
        mean_train_loss(train_loss.loss)
      if current_step == 0:
        logging.debug('train policy: {} sec'.format(time.time() - train_time))

      if online_critic and current_step % train_sc_interval == 0:
        batch_time_step = sc_tf_env.reset()
        batch_policy_state = online_collect_policy.get_initial_state(
            sc_tf_env.batch_size)
        online_driver.run(
            time_step=batch_time_step, policy_state=batch_policy_state)
        for _ in range(train_sc_steps):
          sc_loss, lambda_loss = critic_train_step()  # pylint: disable=unused-variable
        metric_utils.log_metrics(sc_metrics)
        with sc_summary_writer.as_default():
          for sc_metric in sc_metrics:
            sc_metric.tf_summaries(
                train_step=global_step, step_metrics=sc_metrics[:2])
          tf.compat.v2.summary.scalar(
              name='resample_ac_freq', data=resample_ac_freq, step=global_step)

      total_loss = mean_train_loss.result()
      mean_train_loss.reset_states()
      # Check for exploding losses.
      if (math.isnan(total_loss) or math.isinf(total_loss) or
          total_loss > MAX_LOSS):
        loss_divergence_counter += 1
        if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
          loss_diverged = True
          logging.debug(
              'Loss diverged, critic_loss: %s, actor_loss: %s, alpha_loss: %s',
              train_loss.extra.critic_loss, train_loss.extra.actor_loss,
              train_loss.extra.alpha_loss)
          break
      else:
        loss_divergence_counter = 0

      time_acc += time.time() - start_time

      if current_step % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step.numpy(), total_loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step.numpy()
        time_acc = 0

      train_results = []
      for train_metric in train_metrics:
        if isinstance(train_metric, (metrics.AverageEarlyFailureMetric,
                                     metrics.AverageFallenMetric,
                                     metrics.AverageSuccessMetric)):
          # Plot failure as a fn of return
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=train_metrics[:3])
        else:
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=train_metrics[:2])
        train_results.append(
            (train_metric.name, train_metric.result().numpy()))
      if env_metrics:
        for env_metric in env_metrics:
          env_metric.tf_summaries(
              train_step=global_step, step_metrics=train_metrics[:2])
          train_results.append(
              (env_metric.name, env_metric.result().numpy()))
      if online_critic:
        for critic_metric in critic_metrics:
          train_results.append(
              (critic_metric.name, critic_metric.result().numpy()))
          critic_metric.reset_states()
      if train_metrics_callback is not None:
        train_metrics_callback(
            collections.OrderedDict(train_results), global_step.numpy())

      global_step_val = global_step.numpy()
      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)
      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)
        if agent_class in SAFETY_AGENTS:
          safety_critic_checkpointer.save(global_step=global_step_val)
      if rb_checkpoint_interval and global_step_val % rb_checkpoint_interval == 0:
        if online_critic:
          online_rb_checkpointer.save(global_step=global_step_val)
        rb_checkpointer.save(global_step=global_step_val)
      elif online_critic:
        clear_rb()

      if run_eval and global_step_val % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='EvalMetrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step_val)
        metric_utils.log_metrics(eval_metrics)

      if monitor and current_step % monitor_interval == 0:
        monitor_time_step = monitor_py_env.reset()
        monitor_policy_state = eval_policy.get_initial_state(1)
        ep_len = 0
        monitor_start = time.time()
        while not monitor_time_step.is_last():
          monitor_action = eval_policy.action(monitor_time_step,
                                              monitor_policy_state)
          action, monitor_policy_state = (monitor_action.action,
                                          monitor_action.state)
          monitor_time_step = monitor_py_env.step(action)
          ep_len += 1
        monitor_py_env.reset()
        logging.debug(
            'saved rollout at timestep {}, rollout length: {}, {} sec'.format(
                global_step_val, ep_len, time.time() - monitor_start))

      logging.debug('iteration time: {} sec'.format(time.time() - start_time))

  if not keep_rb_checkpoint:
    misc.cleanup_checkpoints(rb_ckpt_dir)

  if loss_diverged:
    # Raise an error at the very end after the cleanup.
raise ValueError('Loss diverged to {} at step {}, terminating.'.format( total_loss, global_step.numpy())) return total_loss
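# The exploding-loss guard above is easy to get subtly wrong (the counter must
# reset on any healthy loss), so here is a minimal, self-contained sketch of
# the same pattern factored into a helper. MAX_LOSS and the patience value are
# assumed stand-ins for the module-level constants this trainer references.
import math

MAX_LOSS = 1e8  # assumption; the real module defines its own threshold
TERMINATE_AFTER_DIVERGED_LOSS_STEPS = 20  # assumption


class DivergenceMonitor(object):
  """Tracks consecutive diverged losses and signals when to terminate."""

  def __init__(self, max_loss=MAX_LOSS,
               patience=TERMINATE_AFTER_DIVERGED_LOSS_STEPS):
    self._max_loss = max_loss
    self._patience = patience
    self._counter = 0

  def update(self, loss):
    """Returns True once loss has been NaN/inf/too large for `patience` steps."""
    if math.isnan(loss) or math.isinf(loss) or loss > self._max_loss:
      self._counter += 1
    else:
      self._counter = 0  # any healthy loss resets the streak
    return self._counter > self._patience


monitor = DivergenceMonitor(patience=2)
assert not monitor.update(float('nan'))  # 1st diverged step: keep going
assert not monitor.update(float('inf'))  # 2nd diverged step: keep going
assert monitor.update(float('nan'))      # 3rd consecutive step: terminate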
num_eval_episodes = 10  # @param {type:"integer"}
eval_interval = 1000  # @param {type:"integer"}
tempdir = '/content/drive/MyDrive/5242/Project'  # @param {type:"string"}

# env_name = 'MountainCar-v0'
# train_py_env = suite_gym.load(env_name)
# eval_py_env = suite_gym.load(env_name)
# train_env = tf_py_environment.TFPyEnvironment(train_py_env)
# eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

train_py_env = gym_wrapper.GymWrapper(
    ChangeRewardMountainCarEnv(),
    discount=1,
    spec_dtype_map=None,
    auto_reset=True,
    render_kwargs=None,
)
eval_py_env = gym_wrapper.GymWrapper(
    ChangeRewardMountainCarEnv(),
    discount=1,
    spec_dtype_map=None,
    auto_reset=True,
    render_kwargs=None,
)
train_py_env = wrappers.TimeLimit(train_py_env, duration=200)
eval_py_env = wrappers.TimeLimit(eval_py_env, duration=200)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
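# Optional sanity check (not in the original cell): TF-Agents ships a spec
# validator that steps a py-environment with random actions and verifies the
# emitted time steps match its specs. Run it before the TFPyEnvironment
# conversion so the external stepping does not disturb the wrapped env state.
from tf_agents.environments import utils as env_utils

env_utils.validate_py_environment(train_py_env, episodes=2)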
time_step = tf_environment.reset()
while not time_step.is_last():
  py_environment.render()
  action_step = policy.action(time_step)
  time_step = tf_environment.step(action_step.action)
py_environment.close()
tf_environment.close()

df = pd.read_csv(data_path)
# Build a separate wrapped env for training and evaluation: reusing one
# GymWrapper instance in two TFPyEnvironments would make train and eval share
# (and clobber) the same underlying gym state.
train_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(
    gym.make(env_name, df=df, window_size=window_size,
             frame_bound=(window_size, len(df)))))
eval_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(
    gym.make(env_name, df=df, window_size=window_size,
             frame_bound=(window_size, len(df)))))
q_net = q_network.QNetwork(train_env.observation_spec(),
                           train_env.action_spec(),
                           fc_layer_params=fc_layer_params)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
global_step = tf.compat.v1.train.get_or_create_global_step()
agent = dqn_agent.DqnAgent(train_env.time_step_spec(),
                           train_env.action_spec(),
                           # wiring below is the conventional completion, an
                           # assumption where the original snippet broke off
                           q_network=q_net,
                           optimizer=optimizer,
                           train_step_counter=global_step)
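# Illustrative evaluation helper in the style of the standard TF-Agents DQN
# tutorial (not recovered from the original script): rolls out `policy` in a
# TFPyEnvironment for a few episodes and averages the undiscounted return.
def compute_avg_return(environment, policy, num_episodes=10):
  total_return = 0.0
  for _ in range(num_episodes):
    time_step = environment.reset()
    episode_return = 0.0
    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
    total_return += episode_return
  # Rewards from a TFPyEnvironment are batched tensors of shape [1].
  return (total_return / num_episodes).numpy()[0]

# e.g. compute_avg_return(eval_env, agent.policy, num_episodes=10)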
def test_obs_dtype(self):
  cartpole_env = gym.spec('CartPole-v1').make()
  env = gym_wrapper.GymWrapper(cartpole_env)
  time_step = env.reset()
  self.assertEqual(env.observation_spec().dtype, time_step.observation.dtype)
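# A companion check in the same style, verifying the dtype is also preserved
# on stepped (not just reset) observations. Illustrative addition; it assumes
# it lives in the same TestCase as test_obs_dtype above.
def test_obs_dtype_after_step(self):
  cartpole_env = gym.spec('CartPole-v1').make()
  env = gym_wrapper.GymWrapper(cartpole_env)
  env.reset()
  time_step = env.step(np.array(0, dtype=np.int32))
  self.assertEqual(env.observation_spec().dtype, time_step.observation.dtype)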
def main(_): tf.random.set_seed(FLAGS.seed) np.random.seed(FLAGS.seed) hparam_str = make_hparam_string(seed=FLAGS.seed, env_name=FLAGS.env_name) summary_writer = tf.summary.create_file_writer( os.path.join(FLAGS.save_dir, 'tb', hparam_str)) summary_writer.set_as_default() if FLAGS.d4rl: d4rl_env = gym.make(FLAGS.env_name) gym_spec = gym.spec(FLAGS.env_name) if gym_spec.max_episode_steps in [0, None]: # Add TimeLimit wrapper. gym_env = time_limit.TimeLimit(d4rl_env, max_episode_steps=1000) else: gym_env = d4rl_env gym_env.seed(FLAGS.seed) env = tf_py_environment.TFPyEnvironment( gym_wrapper.GymWrapper(gym_env)) behavior_dataset = D4rlDataset( d4rl_env, normalize_states=FLAGS.normalize_states, normalize_rewards=FLAGS.normalize_rewards, noise_scale=FLAGS.noise_scale, bootstrap=FLAGS.bootstrap) else: env = suite_mujoco.load(FLAGS.env_name) env.seed(FLAGS.seed) env = tf_py_environment.TFPyEnvironment(env) data_file_name = os.path.join( FLAGS.data_dir, FLAGS.env_name, '0', f'dualdice_{FLAGS.behavior_policy_std}.pckl') behavior_dataset = Dataset(data_file_name, FLAGS.num_trajectories, normalize_states=FLAGS.normalize_states, normalize_rewards=FLAGS.normalize_rewards, noise_scale=FLAGS.noise_scale, bootstrap=FLAGS.bootstrap) tf_dataset = behavior_dataset.with_uniform_sampling( FLAGS.sample_batch_size) tf_dataset_iter = iter(tf_dataset) if FLAGS.d4rl: with tf.io.gfile.GFile(FLAGS.d4rl_policy_filename, 'rb') as f: policy_weights = pickle.load(f) actor = utils.D4rlActor(env, policy_weights, is_dapg='dapg' in FLAGS.d4rl_policy_filename) else: actor = Actor(env.observation_spec().shape[0], env.action_spec()) actor.load_weights(behavior_dataset.model_filename) policy_returns = utils.estimate_monte_carlo_returns( env, FLAGS.discount, actor, FLAGS.target_policy_std, FLAGS.num_mc_episodes) logging.info('Estimated Per-Step Average Returns=%f', policy_returns) if 'fqe' in FLAGS.algo or 'dr' in FLAGS.algo: model = QFitter(env.observation_spec().shape[0], env.action_spec().shape[0], FLAGS.lr, FLAGS.weight_decay, FLAGS.tau) elif 'mb' in FLAGS.algo: model = ModelBased(env.observation_spec().shape[0], env.action_spec().shape[0], learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay) elif 'dual_dice' in FLAGS.algo: model = DualDICE(env.observation_spec().shape[0], env.action_spec().shape[0], FLAGS.weight_decay) if 'iw' in FLAGS.algo or 'dr' in FLAGS.algo: behavior = BehaviorCloning(env.observation_spec().shape[0], env.action_spec(), FLAGS.lr, FLAGS.weight_decay) @tf.function def get_target_actions(states): return actor(tf.cast(behavior_dataset.unnormalize_states(states), env.observation_spec().dtype), std=FLAGS.target_policy_std)[1] @tf.function def get_target_logprobs(states, actions): log_probs = actor(tf.cast(behavior_dataset.unnormalize_states(states), env.observation_spec().dtype), actions=actions, std=FLAGS.target_policy_std)[2] if tf.rank(log_probs) > 1: log_probs = tf.reduce_sum(log_probs, -1) return log_probs min_reward = tf.reduce_min(behavior_dataset.rewards) max_reward = tf.reduce_max(behavior_dataset.rewards) min_state = tf.reduce_min(behavior_dataset.states, 0) max_state = tf.reduce_max(behavior_dataset.states, 0) @tf.function def update_step(): (states, actions, next_states, rewards, masks, weights, _) = next(tf_dataset_iter) initial_actions = get_target_actions(behavior_dataset.initial_states) next_actions = get_target_actions(next_states) if 'fqe' in FLAGS.algo or 'dr' in FLAGS.algo: model.update(states, actions, next_states, next_actions, rewards, masks, weights, FLAGS.discount, min_reward, 
max_reward) elif 'mb' in FLAGS.algo: model.update(states, actions, next_states, rewards, masks, weights) elif 'dual_dice' in FLAGS.algo: model.update(behavior_dataset.initial_states, initial_actions, behavior_dataset.initial_weights, states, actions, next_states, next_actions, masks, weights, FLAGS.discount) if 'iw' in FLAGS.algo or 'dr' in FLAGS.algo: behavior.update(states, actions, weights) gc.collect() for i in tqdm.tqdm(range(FLAGS.num_updates), desc='Running Training'): update_step() if i % FLAGS.eval_interval == 0: if 'fqe' in FLAGS.algo: pred_returns = model.estimate_returns( behavior_dataset.initial_states, behavior_dataset.initial_weights, get_target_actions) elif 'mb' in FLAGS.algo: pred_returns = model.estimate_returns( behavior_dataset.initial_states, behavior_dataset.initial_weights, get_target_actions, FLAGS.discount, min_reward, max_reward, min_state, max_state) elif FLAGS.algo in ['dual_dice']: pred_returns, pred_ratio = model.estimate_returns( iter(tf_dataset)) tf.summary.scalar('train/pred ratio', pred_ratio, step=i) elif 'iw' in FLAGS.algo or 'dr' in FLAGS.algo: discount = FLAGS.discount _, behavior_log_probs = behavior(behavior_dataset.states, behavior_dataset.actions) target_log_probs = get_target_logprobs( behavior_dataset.states, behavior_dataset.actions) offset = 0.0 rewards = behavior_dataset.rewards if 'dr' in FLAGS.algo: # Doubly-robust is effectively the same as importance-weighting but # transforming rewards at (s,a) to r(s,a) + gamma * V^pi(s') - # Q^pi(s,a) and adding an offset to each trajectory equal to V^pi(s0). offset = model.estimate_returns( behavior_dataset.initial_states, behavior_dataset.initial_weights, get_target_actions) q_values = (model(behavior_dataset.states, behavior_dataset.actions) / (1 - discount)) n_samples = 10 next_actions = [ get_target_actions(behavior_dataset.next_states) for _ in range(n_samples) ] next_q_values = sum([ model(behavior_dataset.next_states, next_action) / (1 - discount) for next_action in next_actions ]) / n_samples rewards = rewards + discount * next_q_values - q_values # Now we compute the self-normalized importance weights. # Self-normalization happens over trajectories per-step, so we # restructure the dataset as [num_trajectories, num_steps]. num_trajectories = len(behavior_dataset.initial_states) max_trajectory_length = np.max(behavior_dataset.steps) + 1 trajectory_weights = behavior_dataset.initial_weights trajectory_starts = np.where( np.equal(behavior_dataset.steps, 0))[0] batched_rewards = np.zeros( [num_trajectories, max_trajectory_length]) batched_masks = np.zeros( [num_trajectories, max_trajectory_length]) batched_log_probs = np.zeros( [num_trajectories, max_trajectory_length]) for traj_idx, traj_start in enumerate(trajectory_starts): traj_end = (trajectory_starts[traj_idx + 1] if traj_idx + 1 < len(trajectory_starts) else len(rewards)) traj_length = traj_end - traj_start batched_rewards[ traj_idx, :traj_length] = rewards[traj_start:traj_end] batched_masks[traj_idx, :traj_length] = 1. batched_log_probs[traj_idx, :traj_length] = ( -behavior_log_probs[traj_start:traj_end] + target_log_probs[traj_start:traj_end]) batched_weights = ( batched_masks * (discount**np.arange(max_trajectory_length))[None, :]) clipped_log_probs = np.clip(batched_log_probs, -6., 2.) 
cum_log_probs = batched_masks * np.cumsum(clipped_log_probs, axis=1) cum_log_probs_offset = np.max(cum_log_probs, axis=0) cum_probs = np.exp(cum_log_probs - cum_log_probs_offset[None, :]) avg_cum_probs = ( np.sum(cum_probs * trajectory_weights[:, None], axis=0) / (1e-10 + np.sum( batched_masks * trajectory_weights[:, None], axis=0))) norm_cum_probs = cum_probs / (1e-10 + avg_cum_probs[None, :]) weighted_rewards = batched_weights * batched_rewards * norm_cum_probs trajectory_values = np.sum(weighted_rewards, axis=1) avg_trajectory_value = ( (1 - discount) * np.sum(trajectory_values * trajectory_weights) / np.sum(trajectory_weights)) pred_returns = offset + avg_trajectory_value pred_returns = behavior_dataset.unnormalize_rewards(pred_returns) tf.summary.scalar('train/pred returns', pred_returns, step=i) logging.info('pred returns=%f', pred_returns) tf.summary.scalar('train/true minus pred returns', policy_returns - pred_returns, step=i) logging.info('true minus pred returns=%f', policy_returns - pred_returns)
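# Toy NumPy walkthrough of the self-normalized importance weighting computed
# above: two length-3 trajectories with made-up log density ratios. The point
# is the mechanics, i.e. that after dividing by the per-step weighted average,
# the cumulative importance weights average to 1 at every step.
import numpy as np

batched_log_probs = np.array([[0.1, -0.2, 0.05],
                              [-0.3, 0.4, -0.1]])  # log(pi/beta), per step
batched_masks = np.ones_like(batched_log_probs)    # no padding in this toy
trajectory_weights = np.ones(2)                    # uniform bootstrap weights

cum_log_probs = batched_masks * np.cumsum(batched_log_probs, axis=1)
cum_log_probs_offset = np.max(cum_log_probs, axis=0)  # numerical stability
cum_probs = np.exp(cum_log_probs - cum_log_probs_offset[None, :])
avg_cum_probs = (np.sum(cum_probs * trajectory_weights[:, None], axis=0) /
                 np.sum(batched_masks * trajectory_weights[:, None], axis=0))
norm_cum_probs = cum_probs / avg_cum_probs[None, :]
print(norm_cum_probs.mean(axis=0))  # -> [1. 1. 1.]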
def test_render(self):
  cartpole_env = gym.spec('CartPole-v1').make()
  cartpole_env.render = mock.MagicMock()
  env = gym_wrapper.GymWrapper(cartpole_env)
  env.render()
  cartpole_env.render.assert_called_once()
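# An analogous forwarding test for close(), mirroring test_render above.
# Illustrative addition; it assumes GymWrapper forwards close() to the wrapped
# gym environment the same way it forwards render().
def test_close(self):
  cartpole_env = gym.spec('CartPole-v1').make()
  cartpole_env.close = mock.MagicMock()
  env = gym_wrapper.GymWrapper(cartpole_env)
  env.close()
  cartpole_env.close.assert_called_once()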
def train_eval( load_root_dir, env_load_fn=None, gym_env_wrappers=[], monitor=False, env_name=None, agent_class=None, train_metrics_callback=None, # SacAgent args actor_fc_layers=(256, 256), critic_joint_fc_layers=(256, 256), # Safety Critic training args safety_critic_joint_fc_layers=None, safety_critic_lr=3e-4, safety_critic_bias_init_val=None, safety_critic_kernel_scale=None, n_envs=None, target_safety=0.2, fail_weight=None, # Params for train num_global_steps=10000, batch_size=256, # Params for eval run_eval=False, eval_metrics=[], num_eval_episodes=10, eval_interval=1000, # Params for summaries and logging train_checkpoint_interval=10000, summary_interval=1000, monitor_interval=5000, summaries_flush_secs=10, debug_summaries=False, seed=None): if isinstance(agent_class, str): assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format( agent_class) agent_class = ALGOS.get(agent_class) train_ckpt_dir = osp.join(load_root_dir, 'train') rb_ckpt_dir = osp.join(load_root_dir, 'train', 'replay_buffer') py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers) tf_env = tf_py_environment.TFPyEnvironment(py_env) if monitor: vid_path = os.path.join(load_root_dir, 'rollouts') monitor_env_wrapper = misc.monitor_freq(1, vid_path) monitor_env = gym.make(env_name) for wrapper in gym_env_wrappers: monitor_env = wrapper(monitor_env) monitor_env = monitor_env_wrapper(monitor_env) # auto_reset must be False to ensure Monitor works correctly monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False) if run_eval: eval_dir = os.path.join(load_root_dir, 'eval') n_envs = n_envs or num_eval_episodes eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(prefix='EvalMetrics', buffer_size=num_eval_episodes, batch_size=n_envs), tf_metrics.AverageEpisodeLengthMetric( prefix='EvalMetrics', buffer_size=num_eval_episodes, batch_size=n_envs) ] + [ tf_py_metric.TFPyMetric(m, name='EvalMetrics/{}'.format(m.name)) for m in eval_metrics ] eval_tf_env = tf_py_environment.TFPyEnvironment( parallel_py_environment.ParallelPyEnvironment([ lambda: env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers) ] * n_envs)) if seed: seeds = [seed * n_envs + i for i in range(n_envs)] try: eval_tf_env.pyenv.seed(seeds) except: pass global_step = tf.compat.v1.train.get_or_create_global_step() time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=agents.normal_projection_net) critic_net = agents.CriticNetwork( (observation_spec, action_spec), joint_fc_layer_params=critic_joint_fc_layers) if agent_class in SAFETY_AGENTS: safety_critic_net = agents.CriticNetwork( (observation_spec, action_spec), joint_fc_layer_params=critic_joint_fc_layers) tf_agent = agent_class(time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, safety_critic_network=safety_critic_net, train_step_counter=global_step, debug_summaries=False) else: tf_agent = agent_class(time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, train_step_counter=global_step, debug_summaries=False) collect_data_spec = tf_agent.collect_data_spec replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( collect_data_spec, batch_size=1, max_length=1000000) replay_buffer = 
misc.load_rb_ckpt(rb_ckpt_dir, replay_buffer) tf_agent, _ = misc.load_agent_ckpt(train_ckpt_dir, tf_agent) if agent_class in SAFETY_AGENTS: target_safety = target_safety or tf_agent._target_safety loaded_train_steps = global_step.numpy() logging.info("Loaded agent from %s trained for %d steps", train_ckpt_dir, loaded_train_steps) global_step.assign(0) tf.summary.experimental.set_step(global_step) thresholds = [target_safety, 0.5] sc_metrics = [ tf.keras.metrics.AUC(name='safety_critic_auc'), tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc', threshold=0.5), tf.keras.metrics.TruePositives(name='safety_critic_tp', thresholds=thresholds), tf.keras.metrics.FalsePositives(name='safety_critic_fp', thresholds=thresholds), tf.keras.metrics.TrueNegatives(name='safety_critic_tn', thresholds=thresholds), tf.keras.metrics.FalseNegatives(name='safety_critic_fn', thresholds=thresholds) ] if seed: tf.compat.v1.set_random_seed(seed) summaries_flush_secs = 10 timestamp = datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S') offline_train_dir = osp.join(train_ckpt_dir, 'offline', timestamp) config_saver = gin.tf.GinConfigSaverHook(offline_train_dir, summarize_config=True) tf.function(config_saver.after_create_session)() sc_summary_writer = tf.compat.v2.summary.create_file_writer( offline_train_dir, flush_millis=summaries_flush_secs * 1000) sc_summary_writer.set_as_default() if safety_critic_kernel_scale is not None: ki = tf.compat.v1.variance_scaling_initializer( scale=safety_critic_kernel_scale, mode='fan_in', distribution='truncated_normal') else: ki = tf.compat.v1.keras.initializers.VarianceScaling( scale=1. / 3., mode='fan_in', distribution='uniform') if safety_critic_bias_init_val is not None: bi = tf.constant_initializer(safety_critic_bias_init_val) else: bi = None sc_net_off = agents.CriticNetwork( (observation_spec, action_spec), joint_fc_layer_params=safety_critic_joint_fc_layers, kernel_initializer=ki, value_bias_initializer=bi, name='SafetyCriticOffline') sc_net_off.create_variables() target_sc_net_off = common.maybe_copy_target_network_with_checks( sc_net_off, None, 'TargetSafetyCriticNetwork') optimizer = tf.keras.optimizers.Adam(safety_critic_lr) sc_net_off_ckpt_dir = os.path.join(offline_train_dir, 'safety_critic') sc_checkpointer = common.Checkpointer( ckpt_dir=sc_net_off_ckpt_dir, safety_critic=sc_net_off, target_safety_critic=target_sc_net_off, optimizer=optimizer, global_step=global_step, max_to_keep=5) sc_checkpointer.initialize_or_restore() resample_counter = py_metrics.CounterMetric('ActionResampleCounter') eval_policy = agents.SafeActorPolicyRSVar( time_step_spec=time_step_spec, action_spec=action_spec, actor_network=actor_net, safety_critic_network=sc_net_off, safety_threshold=target_safety, resample_counter=resample_counter, training=True) dataset = replay_buffer.as_dataset(num_parallel_calls=3, num_steps=2, sample_batch_size=batch_size // 2).prefetch(3) data = iter(dataset) full_data = replay_buffer.gather_all() fail_mask = tf.cast(full_data.observation['task_agn_rew'], tf.bool) fail_step = nest_utils.fast_map_structure( lambda *x: tf.boolean_mask(*x, fail_mask), full_data) init_step = nest_utils.fast_map_structure( lambda *x: tf.boolean_mask(*x, full_data.is_first()), full_data) before_fail_mask = tf.roll(fail_mask, [-1], axis=[1]) after_init_mask = tf.roll(full_data.is_first(), [1], axis=[1]) before_fail_step = nest_utils.fast_map_structure( lambda *x: tf.boolean_mask(*x, before_fail_mask), full_data) after_init_step = nest_utils.fast_map_structure( lambda *x: 
tf.boolean_mask(*x, after_init_mask), full_data) filter_mask = tf.squeeze(tf.logical_or(before_fail_mask, fail_mask)) filter_mask = tf.pad( filter_mask, [[0, replay_buffer._max_length - filter_mask.shape[0]]]) n_failures = tf.reduce_sum(tf.cast(filter_mask, tf.int32)).numpy() failure_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( collect_data_spec, batch_size=1, max_length=n_failures, dataset_window_shift=1) data_utils.copy_rb(replay_buffer, failure_buffer, filter_mask) sc_dataset_neg = failure_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size // 2, num_steps=2).prefetch(3) neg_data = iter(sc_dataset_neg) get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0] eval_sc = log_utils.eval_fn(before_fail_step, fail_step, init_step, after_init_step, get_action) losses = [] mean_loss = tf.keras.metrics.Mean(name='mean_ep_loss') target_update = train_utils.get_target_updater(sc_net_off, target_sc_net_off) with tf.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): while global_step.numpy() < num_global_steps: pos_experience, _ = next(data) neg_experience, _ = next(neg_data) exp = data_utils.concat_batches(pos_experience, neg_experience, collect_data_spec) boundary_mask = tf.logical_not(exp.is_boundary()[:, 0]) exp = nest_utils.fast_map_structure( lambda *x: tf.boolean_mask(*x, boundary_mask), exp) safe_rew = exp.observation['task_agn_rew'][:, 1] if fail_weight: weights = tf.where(tf.cast(safe_rew, tf.bool), fail_weight / 0.5, (1 - fail_weight) / 0.5) else: weights = None train_loss, sc_loss, lam_loss = train_step( exp, safe_rew, tf_agent, sc_net=sc_net_off, target_sc_net=target_sc_net_off, metrics=sc_metrics, weights=weights, target_safety=target_safety, optimizer=optimizer, target_update=target_update, debug_summaries=debug_summaries) global_step.assign_add(1) global_step_val = global_step.numpy() losses.append( (train_loss.numpy(), sc_loss.numpy(), lam_loss.numpy())) mean_loss(train_loss) with tf.name_scope('Losses'): tf.compat.v2.summary.scalar(name='sc_loss', data=sc_loss, step=global_step_val) tf.compat.v2.summary.scalar(name='lam_loss', data=lam_loss, step=global_step_val) if global_step_val % summary_interval == 0: tf.compat.v2.summary.scalar(name=mean_loss.name, data=mean_loss.result(), step=global_step_val) if global_step_val % summary_interval == 0: with tf.name_scope('Metrics'): for metric in sc_metrics: if len(tf.squeeze(metric.result()).shape) == 0: tf.compat.v2.summary.scalar(name=metric.name, data=metric.result(), step=global_step_val) else: fmt_str = '_{}'.format(thresholds[0]) tf.compat.v2.summary.scalar( name=metric.name + fmt_str, data=metric.result()[0], step=global_step_val) fmt_str = '_{}'.format(thresholds[1]) tf.compat.v2.summary.scalar( name=metric.name + fmt_str, data=metric.result()[1], step=global_step_val) metric.reset_states() if global_step_val % eval_interval == 0: eval_sc(sc_net_off, step=global_step_val) if run_eval: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='EvalMetrics', ) if train_metrics_callback is not None: train_metrics_callback(results, global_step_val) metric_utils.log_metrics(eval_metrics) with eval_summary_writer.as_default(): for eval_metric in eval_metrics[2:]: eval_metric.tf_summaries( train_step=global_step, step_metrics=eval_metrics[:2]) if monitor and global_step_val % monitor_interval == 0: monitor_time_step = monitor_py_env.reset() 
monitor_policy_state = eval_policy.get_initial_state(1) ep_len = 0 monitor_start = time.time() while not monitor_time_step.is_last(): monitor_action = eval_policy.action( monitor_time_step, monitor_policy_state) action, monitor_policy_state = monitor_action.action, monitor_action.state monitor_time_step = monitor_py_env.step(action) ep_len += 1 logging.debug( 'saved rollout at timestep %d, rollout length: %d, %4.2f sec', global_step_val, ep_len, time.time() - monitor_start) if global_step_val % train_checkpoint_interval == 0: sc_checkpointer.save(global_step=global_step_val)
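# Toy illustration of the mask-rolling trick the offline pipeline above uses
# to tag the transition immediately *before* each failure: shifting a boolean
# failure mask left by one marks the preceding step. Values are made up.
import tensorflow as tf

fail_mask = tf.constant([[False, False, True, False, False, True]])
before_fail_mask = tf.roll(fail_mask, [-1], axis=[1])
print(before_fail_mask.numpy())
# [[False  True False False  True False]]
# Caveat: tf.roll wraps around, so a failure on the very first step would
# spuriously mark the final step; episodes opening with a reset step make
# that edge case unlikely in the buffer layout above.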
def train_eval( root_dir, load_root_dir=None, env_load_fn=None, gym_env_wrappers=[], monitor=False, env_name=None, agent_class=None, initial_collect_driver_class=None, collect_driver_class=None, online_driver_class=dynamic_episode_driver.DynamicEpisodeDriver, num_global_steps=1000000, rb_size=None, train_steps_per_iteration=1, train_metrics=None, eval_metrics=None, train_metrics_callback=None, # SacAgent args actor_fc_layers=(256, 256), critic_joint_fc_layers=(256, 256), # Safety Critic training args sc_rb_size=None, target_safety=None, train_sc_steps=10, train_sc_interval=1000, online_critic=False, n_envs=None, finetune_sc=False, pretraining=True, lambda_schedule_nsteps=0, lambda_initial=0., lambda_final=1., kstep_fail=0, # Ensemble Critic training args num_critics=None, critic_learning_rate=3e-4, # Wcpg Critic args critic_preprocessing_layer_size=256, # Params for train batch_size=256, # Params for eval run_eval=False, num_eval_episodes=10, eval_interval=1000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, keep_rb_checkpoint=False, log_interval=1000, summary_interval=1000, monitor_interval=5000, summaries_flush_secs=10, early_termination_fn=None, debug_summaries=False, seed=None, eager_debug=False, env_metric_factories=None, wandb=False): # pylint: disable=unused-argument """train and eval script for SQRL.""" if isinstance(agent_class, str): assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(agent_class) agent_class = ALGOS.get(agent_class) n_envs = n_envs or num_eval_episodes root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') # =====================================================================# # Setup summary metrics, file writers, and create env # # =====================================================================# train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() train_metrics = train_metrics or [] eval_metrics = eval_metrics or [] updating_sc = online_critic and (not load_root_dir or finetune_sc) logging.debug('updating safety critic: %s', updating_sc) if seed: tf.compat.v1.set_random_seed(seed) if agent_class in SAFETY_AGENTS: if online_critic: sc_tf_env = tf_py_environment.TFPyEnvironment( parallel_py_environment.ParallelPyEnvironment( [lambda: env_load_fn(env_name)] * n_envs )) if seed: seeds = [seed * n_envs + i for i in range(n_envs)] try: sc_tf_env.pyenv.seed(seeds) except: pass if run_eval: eval_dir = os.path.join(root_dir, 'eval') eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes, batch_size=n_envs), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes, batch_size=n_envs), ] + [tf_py_metric.TFPyMetric(m) for m in eval_metrics] eval_tf_env = tf_py_environment.TFPyEnvironment( parallel_py_environment.ParallelPyEnvironment( [lambda: env_load_fn(env_name)] * n_envs )) if seed: try: for i, pyenv in enumerate(eval_tf_env.pyenv.envs): pyenv.seed(seed * n_envs + i) except: pass elif 'Drunk' in env_name: # Just visualizes trajectories in drunk spider environment eval_tf_env = tf_py_environment.TFPyEnvironment( env_load_fn(env_name)) else: eval_tf_env = None if monitor: vid_path = os.path.join(root_dir, 'rollouts') monitor_env_wrapper = misc.monitor_freq(1, vid_path) 
monitor_env = gym.make(env_name) for wrapper in gym_env_wrappers: monitor_env = wrapper(monitor_env) monitor_env = monitor_env_wrapper(monitor_env) # auto_reset must be False to ensure Monitor works correctly monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False) global_step = tf.compat.v1.train.get_or_create_global_step() with tf.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): py_env = env_load_fn(env_name) tf_env = tf_py_environment.TFPyEnvironment(py_env) if seed: try: for i, pyenv in enumerate(tf_env.pyenv.envs): pyenv.seed(seed * n_envs + i) except: pass time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() logging.debug('obs spec: %s', observation_spec) logging.debug('action spec: %s', action_spec) # =====================================================================# # Setup agent class # # =====================================================================# if agent_class == wcpg_agent.WcpgAgent: alpha_spec = tensor_spec.BoundedTensorSpec(shape=(1,), dtype=tf.float32, minimum=0., maximum=1., name='alpha') input_tensor_spec = (observation_spec, action_spec, alpha_spec) critic_net = agents.DistributionalCriticNetwork( input_tensor_spec, preprocessing_layer_size=critic_preprocessing_layer_size, joint_fc_layer_params=critic_joint_fc_layers) actor_net = agents.WcpgActorNetwork((observation_spec, alpha_spec), action_spec) else: actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=agents.normal_projection_net) critic_net = agents.CriticNetwork( (observation_spec, action_spec), joint_fc_layer_params=critic_joint_fc_layers) if agent_class in SAFETY_AGENTS: logging.debug('Making SQRL agent') if lambda_schedule_nsteps > 0: lambda_update_every_nsteps = num_global_steps // lambda_schedule_nsteps step_size = (lambda_final - lambda_initial) / lambda_update_every_nsteps lambda_scheduler = lambda lam: common.periodically( body=lambda: tf.group(lam.assign(lam + step_size)), period=lambda_update_every_nsteps) else: lambda_scheduler = None safety_critic_net = agents.CriticNetwork( (observation_spec, action_spec), joint_fc_layer_params=critic_joint_fc_layers) ts = target_safety thresholds = [ts, 0.5] sc_metrics = [tf.keras.metrics.AUC(name='safety_critic_auc'), tf.keras.metrics.TruePositives(name='safety_critic_tp', thresholds=thresholds), tf.keras.metrics.FalsePositives(name='safety_critic_fp', thresholds=thresholds), tf.keras.metrics.TrueNegatives(name='safety_critic_tn', thresholds=thresholds), tf.keras.metrics.FalseNegatives(name='safety_critic_fn', thresholds=thresholds), tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc', threshold=0.5)] tf_agent = agent_class( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, safety_critic_network=safety_critic_net, train_step_counter=global_step, debug_summaries=debug_summaries, safety_pretraining=pretraining, train_critic_online=online_critic, initial_log_lambda=lambda_initial, log_lambda=(lambda_scheduler is None), lambda_scheduler=lambda_scheduler) elif agent_class is ensemble_sac_agent.EnsembleSacAgent: critic_nets, critic_optimizers = [critic_net], [tf.keras.optimizers.Adam(critic_learning_rate)] for _ in range(num_critics - 1): critic_nets.append(agents.CriticNetwork((observation_spec, action_spec), joint_fc_layer_params=critic_joint_fc_layers)) 
critic_optimizers.append(tf.keras.optimizers.Adam(critic_learning_rate)) tf_agent = agent_class( time_step_spec, action_spec, actor_network=actor_net, critic_networks=critic_nets, critic_optimizers=critic_optimizers, debug_summaries=debug_summaries ) else: # agent is either SacAgent or WcpgAgent logging.debug('critic input_tensor_spec: %s', critic_net.input_tensor_spec) tf_agent = agent_class( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, train_step_counter=global_step, debug_summaries=debug_summaries) tf_agent.initialize() # =====================================================================# # Setup replay buffer # # =====================================================================# collect_data_spec = tf_agent.collect_data_spec logging.debug('Allocating replay buffer ...') # Add to replay buffer and other agent specific observers. rb_size = rb_size or 1000000 replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( collect_data_spec, batch_size=1, max_length=rb_size) logging.debug('RB capacity: %i', replay_buffer.capacity) logging.debug('ReplayBuffer Collect data spec: %s', collect_data_spec) if agent_class in SAFETY_AGENTS: sc_rb_size = sc_rb_size or num_eval_episodes * 500 sc_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( collect_data_spec, batch_size=1, max_length=sc_rb_size, dataset_window_shift=1) num_episodes = tf_metrics.NumberOfEpisodes() num_env_steps = tf_metrics.EnvironmentSteps() return_metric = tf_metrics.AverageReturnMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size) train_metrics = [ num_episodes, num_env_steps, return_metric, tf_metrics.AverageEpisodeLengthMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics] if 'Minitaur' in env_name and not pretraining: goal_vel = gin.query_parameter("%GOAL_VELOCITY") early_termination_fn = train_utils.MinitaurTerminationFn( speed_metric=train_metrics[-2], total_falls_metric=train_metrics[-3], env_steps_metric=num_env_steps, goal_speed=goal_vel) if env_metric_factories: for env_metric in env_metric_factories: train_metrics.append(tf_py_metric.TFPyMetric(env_metric(tf_env.pyenv.envs))) if run_eval: eval_metrics.append(env_metric([env for env in eval_tf_env.pyenv._envs])) # =====================================================================# # Setup collect policies # # =====================================================================# if not online_critic: eval_policy = tf_agent.policy collect_policy = tf_agent.collect_policy if not pretraining and agent_class in SAFETY_AGENTS: collect_policy = tf_agent.safe_policy else: eval_policy = tf_agent.collect_policy if pretraining else tf_agent.safe_policy collect_policy = tf_agent.collect_policy if pretraining else tf_agent.safe_policy online_collect_policy = tf_agent.safe_policy # if pretraining else tf_agent.collect_policy if pretraining: online_collect_policy._training = False if not load_root_dir: initial_collect_policy = random_tf_policy.RandomTFPolicy(time_step_spec, action_spec) else: initial_collect_policy = collect_policy if agent_class == wcpg_agent.WcpgAgent: initial_collect_policy = agents.WcpgPolicyWrapper(initial_collect_policy) # =====================================================================# # Setup Checkpointing # # =====================================================================# train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, 
metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=eval_policy, global_step=global_step) rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer') rb_checkpointer = common.Checkpointer( ckpt_dir=rb_ckpt_dir, max_to_keep=1, replay_buffer=replay_buffer) if online_critic: online_rb_ckpt_dir = os.path.join(train_dir, 'online_replay_buffer') online_rb_checkpointer = common.Checkpointer( ckpt_dir=online_rb_ckpt_dir, max_to_keep=1, replay_buffer=sc_buffer) # loads agent, replay buffer, and online sc/buffer if online_critic if load_root_dir: load_root_dir = os.path.expanduser(load_root_dir) load_train_dir = os.path.join(load_root_dir, 'train') misc.load_agent_ckpt(load_train_dir, tf_agent) if len(os.listdir(os.path.join(load_train_dir, 'replay_buffer'))) > 1: load_rb_ckpt_dir = os.path.join(load_train_dir, 'replay_buffer') misc.load_rb_ckpt(load_rb_ckpt_dir, replay_buffer) if online_critic: load_online_sc_ckpt_dir = os.path.join(load_root_dir, 'sc') load_online_rb_ckpt_dir = os.path.join(load_train_dir, 'online_replay_buffer') if osp.exists(load_online_rb_ckpt_dir): misc.load_rb_ckpt(load_online_rb_ckpt_dir, sc_buffer) if osp.exists(load_online_sc_ckpt_dir): misc.load_safety_critic_ckpt(load_online_sc_ckpt_dir, safety_critic_net) elif agent_class in SAFETY_AGENTS: offline_run = sorted(os.listdir(os.path.join(load_train_dir, 'offline')))[-1] load_sc_ckpt_dir = os.path.join(load_train_dir, 'offline', offline_run, 'safety_critic') if osp.exists(load_sc_ckpt_dir): sc_net_off = agents.CriticNetwork( (observation_spec, action_spec), joint_fc_layer_params=(512, 512), name='SafetyCriticOffline') sc_net_off.create_variables() target_sc_net_off = common.maybe_copy_target_network_with_checks( sc_net_off, None, 'TargetSafetyCriticNetwork') sc_optimizer = tf.keras.optimizers.Adam(critic_learning_rate) _ = misc.load_safety_critic_ckpt( load_sc_ckpt_dir, safety_critic_net=sc_net_off, target_safety_critic=target_sc_net_off, optimizer=sc_optimizer) tf_agent._safety_critic_network = sc_net_off tf_agent._target_safety_critic_network = target_sc_net_off tf_agent._safety_critic_optimizer = sc_optimizer else: train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() if online_critic: online_rb_checkpointer.initialize_or_restore() if agent_class in SAFETY_AGENTS: sc_dir = os.path.join(root_dir, 'sc') safety_critic_checkpointer = common.Checkpointer( ckpt_dir=sc_dir, safety_critic=tf_agent._safety_critic_network, # pylint: disable=protected-access target_safety_critic=tf_agent._target_safety_critic_network, optimizer=tf_agent._safety_critic_optimizer, global_step=global_step) if not (load_root_dir and not online_critic): safety_critic_checkpointer.initialize_or_restore() agent_observers = [replay_buffer.add_batch] + train_metrics collect_driver = collect_driver_class( tf_env, collect_policy, observers=agent_observers) collect_driver.run = common.function_in_tf1()(collect_driver.run) if online_critic: logging.debug('online driver class: %s', online_driver_class) online_agent_observers = [num_episodes, num_env_steps, sc_buffer.add_batch] online_driver = online_driver_class( sc_tf_env, online_collect_policy, observers=online_agent_observers, num_episodes=num_eval_episodes) online_driver.run = common.function_in_tf1()(online_driver.run) if eager_debug: tf.config.experimental_run_functions_eagerly(True) else: config_saver = gin.tf.GinConfigSaverHook(train_dir, summarize_config=True) 
tf.function(config_saver.after_create_session)() if global_step == 0: logging.info('Performing initial collection ...') init_collect_observers = list(agent_observers) if agent_class in SAFETY_AGENTS: init_collect_observers += [sc_buffer.add_batch] initial_collect_driver_class( tf_env, initial_collect_policy, observers=init_collect_observers).run() last_id = replay_buffer._get_last_id() # pylint: disable=protected-access logging.info('Data saved after initial collection: %d steps', last_id) if agent_class in SAFETY_AGENTS: last_id = sc_buffer._get_last_id() # pylint: disable=protected-access logging.debug('Data saved in sc_buffer after initial collection: %d steps', last_id) if run_eval: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='EvalMetrics', ) if train_metrics_callback is not None: train_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) timed_at_step = global_step.numpy() time_acc = 0 train_step = train_utils.get_train_step(tf_agent, replay_buffer, batch_size) if agent_class in SAFETY_AGENTS: critic_train_step = train_utils.get_critic_train_step( tf_agent, replay_buffer, sc_buffer, batch_size=batch_size, updating_sc=updating_sc, metrics=sc_metrics) if early_termination_fn is None: early_termination_fn = lambda: False loss_diverged = False # How many consecutive steps was loss diverged for. loss_divergence_counter = 0 mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss') if agent_class in SAFETY_AGENTS: resample_counter = collect_policy._resample_counter mean_resample_ac = tf.keras.metrics.Mean(name='mean_unsafe_ac_freq') sc_metrics.append(mean_resample_ac) if online_critic: logging.debug('starting safety critic pretraining') # don't fine-tune safety critic if global_step.numpy() == 0: for _ in range(train_sc_steps): sc_loss, lambda_loss = critic_train_step() critic_results = [('sc_loss', sc_loss.numpy()), ('lambda_loss', lambda_loss.numpy())] for critic_metric in sc_metrics: res = critic_metric.result().numpy() if not res.shape: critic_results.append((critic_metric.name, res)) else: for r, thresh in zip(res, thresholds): name = '_'.join([critic_metric.name, str(thresh)]) critic_results.append((name, r)) critic_metric.reset_states() if train_metrics_callback: train_metrics_callback(collections.OrderedDict(critic_results), step=global_step.numpy()) logging.debug('Starting main train loop...') curr_ep = [] global_step_val = global_step.numpy() while global_step_val <= num_global_steps and not early_termination_fn(): start_time = time.time() # MEASURE ACTION RESAMPLING FREQUENCY if agent_class in SAFETY_AGENTS: if pretraining and global_step_val == num_global_steps // 2: if online_critic: online_collect_policy._training = True collect_policy._training = True if online_critic or collect_policy._training: mean_resample_ac(resample_counter.result()) resample_counter.reset() if time_step is None or time_step.is_last(): resample_ac_freq = mean_resample_ac.result() mean_resample_ac.reset_states() tf.compat.v2.summary.scalar( name='resample_ac_freq', data=resample_ac_freq, step=global_step) # RUN COLLECTION time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) # get last step taken by step_driver traj = replay_buffer._data_table.read(replay_buffer._get_last_id() % replay_buffer._capacity)
curr_ep.append(traj) if time_step.is_last(): if agent_class in SAFETY_AGENTS: if time_step.observation['task_agn_rew']: if kstep_fail: # applies task agn rew. over last k steps for i, traj in enumerate(curr_ep[-kstep_fail:]): traj.observation['task_agn_rew'] = 1. sc_buffer.add_batch(traj) else: [sc_buffer.add_batch(traj) for traj in curr_ep] curr_ep = [] if agent_class == wcpg_agent.WcpgAgent: collect_policy._alpha = None # reset WCPG alpha if (global_step_val + 1) % log_interval == 0: logging.debug('policy eval: %4.2f sec', time.time() - start_time) # PERFORMS TRAIN STEP ON ALGORITHM (OFF-POLICY) for _ in range(train_steps_per_iteration): train_loss = train_step() mean_train_loss(train_loss.loss) current_step = global_step.numpy() total_loss = mean_train_loss.result() mean_train_loss.reset_states() if train_metrics_callback and current_step % summary_interval == 0: train_metrics_callback( collections.OrderedDict([(k, v.numpy()) for k, v in train_loss.extra._asdict().items()]), step=current_step) train_metrics_callback( {'train_loss': total_loss.numpy()}, step=current_step) # TRAIN AND/OR EVAL SAFETY CRITIC if agent_class in SAFETY_AGENTS and current_step % train_sc_interval == 0: if online_critic: batch_time_step = sc_tf_env.reset() # run online critic training collect & update batch_policy_state = online_collect_policy.get_initial_state( sc_tf_env.batch_size) online_driver.run(time_step=batch_time_step, policy_state=batch_policy_state) for _ in range(train_sc_steps): sc_loss, lambda_loss = critic_train_step() # log safety_critic loss results critic_results = [('sc_loss', sc_loss.numpy()), ('lambda_loss', lambda_loss.numpy())] metric_utils.log_metrics(sc_metrics) for critic_metric in sc_metrics: res = critic_metric.result().numpy() if not res.shape: critic_results.append((critic_metric.name, res)) else: for r, thresh in zip(res, thresholds): name = '_'.join([critic_metric.name, str(thresh)]) critic_results.append((name, r)) critic_metric.reset_states() if train_metrics_callback and current_step % summary_interval == 0: train_metrics_callback(collections.OrderedDict(critic_results), step=current_step) # Check for exploding losses. 
if (math.isnan(total_loss) or math.isinf(total_loss) or total_loss > MAX_LOSS): loss_divergence_counter += 1 if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS: loss_diverged = True logging.info('Loss diverged, critic_loss: %s, actor_loss: %s', train_loss.extra.critic_loss, train_loss.extra.actor_loss) break else: loss_divergence_counter = 0 time_acc += time.time() - start_time # LOGGING AND METRICS if current_step % log_interval == 0: metric_utils.log_metrics(train_metrics) logging.info('step = %d, loss = %f', current_step, total_loss) steps_per_sec = (current_step - timed_at_step) / time_acc logging.info('%4.2f steps/sec', steps_per_sec) tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_sec, step=global_step) timed_at_step = current_step time_acc = 0 train_results = [] for metric in train_metrics[2:]: if isinstance(metric, (metrics.AverageEarlyFailureMetric, metrics.AverageFallenMetric, metrics.AverageSuccessMetric)): # Plot failure as a fn of return metric.tf_summaries( train_step=global_step, step_metrics=[num_env_steps, num_episodes, return_metric]) else: metric.tf_summaries( train_step=global_step, step_metrics=[num_episodes, num_env_steps]) train_results.append((metric.name, metric.result().numpy())) if train_metrics_callback and current_step % summary_interval == 0: train_metrics_callback(collections.OrderedDict(train_results), step=global_step.numpy()) if current_step % train_checkpoint_interval == 0: train_checkpointer.save(global_step=current_step) if current_step % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=current_step) if agent_class in SAFETY_AGENTS: safety_critic_checkpointer.save(global_step=current_step) if online_critic: online_rb_checkpointer.save(global_step=current_step) if rb_checkpoint_interval and current_step % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=current_step) if wandb and current_step % eval_interval == 0 and "Drunk" in env_name: misc.record_point_mass_episode(eval_tf_env, eval_policy, current_step) if online_critic: misc.record_point_mass_episode(eval_tf_env, tf_agent.safe_policy, current_step, 'safe-trajectory') if run_eval and current_step % eval_interval == 0: eval_results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='EvalMetrics', ) if train_metrics_callback is not None: train_metrics_callback(eval_results, current_step) metric_utils.log_metrics(eval_metrics) with eval_summary_writer.as_default(): for eval_metric in eval_metrics[2:]: eval_metric.tf_summaries(train_step=global_step, step_metrics=eval_metrics[:2]) if monitor and current_step % monitor_interval == 0: monitor_time_step = monitor_py_env.reset() monitor_policy_state = eval_policy.get_initial_state(1) ep_len = 0 monitor_start = time.time() while not monitor_time_step.is_last(): monitor_action = eval_policy.action(monitor_time_step, monitor_policy_state) action, monitor_policy_state = monitor_action.action, monitor_action.state monitor_time_step = monitor_py_env.step(action) ep_len += 1 logging.debug('saved rollout at timestep %d, rollout length: %d, %4.2f sec', current_step, ep_len, time.time() - monitor_start) global_step_val = current_step if early_termination_fn(): # Early stopped, save all checkpoints if not saved if global_step_val % train_checkpoint_interval != 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval != 0:
policy_checkpointer.save(global_step=global_step_val) if agent_class in SAFETY_AGENTS: safety_critic_checkpointer.save(global_step=global_step_val) if online_critic: online_rb_checkpointer.save(global_step=global_step_val) if rb_checkpoint_interval and global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) if not keep_rb_checkpoint: misc.cleanup_checkpoints(rb_ckpt_dir) if loss_diverged: # Raise an error at the very end after the cleanup. raise ValueError('Loss diverged to {} at step {}, terminating.'.format( total_loss, global_step.numpy())) return total_loss
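# Hedged usage sketch for train_eval above. All concrete values are
# illustrative assumptions: real runs are configured through gin bindings and
# command-line flags, and `env_load_fn` plus the 'sqrl' registry key are
# assumed to come from the surrounding module (the ALGOS mapping).
import functools
from tf_agents.drivers import dynamic_step_driver

collect_driver_cls = functools.partial(
    dynamic_step_driver.DynamicStepDriver, num_steps=1)
initial_collect_driver_cls = functools.partial(
    dynamic_step_driver.DynamicStepDriver, num_steps=1000)

total_loss = train_eval(
    root_dir='/tmp/sqrl_run',      # hypothetical output directory
    env_name='DrunkSpider-v0',     # hypothetical env id
    env_load_fn=env_load_fn,       # assumed module-level loader
    agent_class='sqrl',            # assumed key into ALGOS
    initial_collect_driver_class=initial_collect_driver_cls,
    collect_driver_class=collect_driver_cls,
    num_global_steps=100000,
    online_critic=True,
    run_eval=True,
    seed=0)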
replay_buffer_capacity = 100000
fc_layer_params = (100,)
batch_size = 64
learning_rate = 1e-3
log_interval = 200
num_eval_episodes = 10
eval_interval = 50

# env_name = 'CartPole-v0'
env_name = 'oscillator-v0'
train_env = envs.make(env_name)
train_py_env = gym_wrapper.GymWrapper(train_env)
eval_env = envs.make(env_name)
eval_py_env = gym_wrapper.GymWrapper(eval_env)
print('Observation spec: {}'.format(train_py_env.time_step_spec().observation))
print('Action spec: {}'.format(train_py_env.action_spec()))
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

q_net = q_network.QNetwork(train_env.observation_spec(),
                           train_env.action_spec(),
                           fc_layer_params=fc_layer_params)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
# optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
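# Sketch of the DqnAgent assembly that conventionally follows this setup,
# mirroring the standard TF-Agents DQN tutorial. The loss function and the
# train-step counter are assumptions, not values recovered from the notebook.
from tf_agents.agents.dqn import dqn_agent
from tf_agents.utils import common

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=tf.Variable(0))
agent.initialize()
agent.train = common.function(agent.train)  # graph-compile the train step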