Example No. 1
  def test_limit_duration_wrapped_env_forwards_calls(self):
    cartpole_env = gym.spec('CartPole-v1').make()
    env = gym_wrapper.GymWrapper(cartpole_env)
    env = wrappers.TimeLimit(env, 10)

    action_spec = env.action_spec()
    self.assertEqual((), action_spec.shape)
    self.assertEqual(0, action_spec.minimum)
    self.assertEqual(1, action_spec.maximum)

    observation_spec = env.observation_spec()
    self.assertEqual((4,), observation_spec.shape)
    high = np.array([
        4.8,
        np.finfo(np.float32).max, 2 / 15.0 * math.pi,
        np.finfo(np.float32).max
    ])
    np.testing.assert_array_almost_equal(-high, observation_spec.minimum)
    np.testing.assert_array_almost_equal(high, observation_spec.maximum)
Example No. 2
    def __init__(self,
                 name,
                 stock_list,
                 feature_set_factories,
                 actor_fc_layers=ACTOR_FC_LAYERS,
                 actor_dropout_layer_params=ACTOR_DROPOUT_LAYER_PARAMS,
                 critic_observation_fc_layer_params=CRITIC_OBS_FC_LAYERS,
                 critic_action_fc_layer_params=CRITIC_ACTION_FC_LAYERS,
                 critic_joint_fc_layer_params=CRITIC_JOINT_FC_LAYERS,
                 actor_alpha=ACTOR_ALPHA,
                 critic_alpha=CRITIC_ALPHA,
                 alpha_alpha=ALPHA_ALPHA,
                 gamma=GAMMA):
        self.name = name
        self.stock_list = stock_list
        self.feature_generator = FeatureGenerator(feature_set_factories)
        self.reset()
        action_space = self.action_space = spaces.Box(low=-1.0,
                                                      high=1.0,
                                                      shape=(1, ),
                                                      dtype=np.float32)
        self.gym_training_env = TrainingStockEnv(stock_list,
                                                 self.feature_generator.copy(),
                                                 action_space,
                                                 self.convert_action)
        self.tf_training_env = tf_py_environment.TFPyEnvironment(
            gym_wrapper.GymWrapper(self.gym_training_env,
                                   discount=gamma,
                                   auto_reset=True))

        self.actor = self.create_actor_network(actor_fc_layers,
                                               actor_dropout_layer_params)
        self.critic = self.create_critic_network(
            critic_observation_fc_layer_params, critic_action_fc_layer_params,
            critic_joint_fc_layer_params)
        self.tf_agent = self.create_sac_agent(self.actor, self.critic,
                                              actor_alpha, critic_alpha,
                                              alpha_alpha, gamma)
        self.eval_policy = self.tf_agent.policy
        self.eval_env = EvalEnv(self.stock_list, self)

        self.tf_agent.initialize()
Example No. 3
  def test_automatic_reset(self):
    cartpole_env = gym.make('CartPole-v1')
    env = gym_wrapper.GymWrapper(cartpole_env)
    env = wrappers.TimeLimit(env, 2)

    # Episode 1
    first_time_step = env.step(0)
    self.assertTrue(first_time_step.is_first())
    mid_time_step = env.step(0)
    self.assertTrue(mid_time_step.is_mid())
    last_time_step = env.step(0)
    self.assertTrue(last_time_step.is_last())

    # Episode 2
    first_time_step = env.step(0)
    self.assertTrue(first_time_step.is_first())
    mid_time_step = env.step(0)
    self.assertTrue(mid_time_step.is_mid())
    last_time_step = env.step(0)
    self.assertTrue(last_time_step.is_last())
Example No. 4
  def test_duration_applied_after_episode_terminates_early(self):
    cartpole_env = gym.make('CartPole-v1')
    env = gym_wrapper.GymWrapper(cartpole_env)
    env = wrappers.TimeLimit(env, 10000)

    # Episode 1 stepped until termination occurs.
    time_step = env.step(np.array(1, dtype=np.int32))
    while not time_step.is_last():
      time_step = env.step(np.array(1, dtype=np.int32))

    self.assertTrue(time_step.is_last())
    env._duration = 2

    # Episode 2 short duration hits step limit.
    first_time_step = env.step(np.array(0, dtype=np.int32))
    self.assertTrue(first_time_step.is_first())
    mid_time_step = env.step(np.array(0, dtype=np.int32))
    self.assertTrue(mid_time_step.is_mid())
    last_time_step = env.step(np.array(0, dtype=np.int32))
    self.assertTrue(last_time_step.is_last())
Example No. 5
  def test_automatic_reset(self):
    cartpole_env = gym.make('CartPole-v1')
    env = gym_wrapper.GymWrapper(cartpole_env)
    env = wrappers.FixedLength(env, 2)

    # Episode 1
    first_time_step = env.step(np.array(0, dtype=np.int32))
    self.assertTrue(first_time_step.is_first())
    mid_time_step = env.step(np.array(0, dtype=np.int32))
    self.assertTrue(mid_time_step.is_mid())
    last_time_step = env.step(np.array(0, dtype=np.int32))
    self.assertTrue(last_time_step.is_last())

    # Episode 2
    first_time_step = env.step(np.array(0, dtype=np.int32))
    self.assertTrue(first_time_step.is_first())
    mid_time_step = env.step(np.array(0, dtype=np.int32))
    self.assertTrue(mid_time_step.is_mid())
    last_time_step = env.step(np.array(0, dtype=np.int32))
    self.assertTrue(last_time_step.is_last())
Example No. 6
    def test_wrapped_cartpole_specs(self):
        # Note we use spec.make on gym envs to avoid getting a TimeLimit wrapper on
        # the environment.
        cartpole_env = gym.spec('CartPole-v1').make()
        env = gym_wrapper.GymWrapper(cartpole_env)

        action_spec = env.action_spec()
        self.assertEqual((), action_spec.shape)
        self.assertEqual(0, action_spec.minimum)
        self.assertEqual(1, action_spec.maximum)

        observation_spec = env.observation_spec()
        self.assertEqual((4, ), observation_spec.shape)
        self.assertEqual(np.float32, observation_spec.dtype)
        high = np.array([
            4.8,
            np.finfo(np.float32).max, 2 / 15.0 * math.pi,
            np.finfo(np.float32).max
        ])
        np.testing.assert_array_almost_equal(-high, observation_spec.minimum)
        np.testing.assert_array_almost_equal(high, observation_spec.maximum)
Example No. 7
def env_load_fn(
        environment_name,  # pylint: disable=dangerous-default-value
        max_episode_steps=50,
        resize_factor=1,
        env_kwargs=dict(action_noise=0., start=(0, 3)),
        goal_env_kwargs=dict(goal=(7, 3)),
        terminate_on_timeout=True):
    """Loads the selected environment and wraps it with the specified wrappers.

  Args:
    environment_name: Name for the environment to load.
    max_episode_steps: If None the max_episode_steps will be set to the default
      step limit defined in the environment's spec. No limit is applied if set
      to 0 or if there is no timestep_limit set in the environment's spec.
    resize_factor: A factor for resizing.
    env_kwargs: Arguments for envs.
    goal_env_kwargs: Arguments for goal envs.
    terminate_on_timeout: Whether to set done = True when the max episode steps
      is reached.

  Returns:
    A PyEnvironmentBase instance.
  """
    gym_env = PointMassEnv(environment_name,
                           resize_factor=resize_factor,
                           **env_kwargs)

    gym_env = GoalConditionedPointWrapper(gym_env, **goal_env_kwargs)
    env = gym_wrapper.GymWrapper(gym_env,
                                 discount=1.0,
                                 auto_reset=True,
                                 simplify_box_bounds=False)

    if max_episode_steps > 0:
        if terminate_on_timeout:
            env = TimeLimitBonus(env, max_episode_steps)
        else:
            env = NonTerminatingTimeLimit(env, max_episode_steps)

    return env
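
A minimal usage sketch of the loader above (assuming 'DrunkSpiderShort' is a valid wall layout for PointMassEnv, as in the similar loader later in this collection, and that tf_agents is installed; the random-policy rollout is illustrative only):

from tf_agents.policies import random_py_policy

# Load the goal-conditioned point-mass environment and roll out a random policy.
py_env = env_load_fn('DrunkSpiderShort', max_episode_steps=50)
policy = random_py_policy.RandomPyPolicy(py_env.time_step_spec(), py_env.action_spec())
time_step = py_env.reset()
while not time_step.is_last():
    time_step = py_env.step(policy.action(time_step).action)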
Example No. 8
def env_load_fn(environment_name,
				 max_episode_steps=None,
				 resize_factor=1,
				 gym_env_wrappers=(GoalConditionedPointWrapper,),
				 terminate_on_timeout=False):
	"""Loads the selected environment and wraps it with the specified wrappers.

	Args:
		environment_name: Name for the environment to load.
		max_episode_steps: If None the max_episode_steps will be set to the default
			step limit defined in the environment's spec. No limit is applied if set
			to 0 or if there is no timestep_limit set in the environment's spec.
		resize_factor: A factor for resizing the environment.
		gym_env_wrappers: Iterable with references to wrapper classes to use
			directly on the gym environment.
		terminate_on_timeout: Whether to set done = True when the max episode
			steps is reached.

	Returns:
		A PyEnvironmentBase instance.
	"""
	gym_env = PointEnv(walls=environment_name, resize_factor=resize_factor)

	for wrapper in gym_env_wrappers:
		gym_env = wrapper(gym_env)
	env = gym_wrapper.GymWrapper(
			gym_env,
			discount=1.0,
			auto_reset=True,
	)

	if max_episode_steps and max_episode_steps > 0:
		if terminate_on_timeout:
			env = wrappers.TimeLimit(env, max_episode_steps)
		else:
			env = NonTerminatingTimeLimit(env, max_episode_steps)

	return tf_py_environment.TFPyEnvironment(env)
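
A minimal usage sketch (the 'FourRooms' wall layout is a placeholder name; the random-policy stepping is illustrative only):

from tf_agents.policies import random_tf_policy

# The loader returns a TFPyEnvironment, so it can be stepped with a TF policy directly.
tf_env = env_load_fn('FourRooms', max_episode_steps=20)
policy = random_tf_policy.RandomTFPolicy(tf_env.time_step_spec(), tf_env.action_spec())
time_step = tf_env.reset()
for _ in range(5):
    time_step = tf_env.step(policy.action(time_step).action)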
Example No. 9
def get_onpolicy_dataset(env_name, tabular_obs, policy_fn, policy_info_spec):
    """Gets target policy."""
    if env_name == 'taxi':
        env = taxi.Taxi(tabular_obs=tabular_obs)
    elif env_name == 'grid':
        env = navigation.GridWalk(tabular_obs=tabular_obs)
    elif env_name == 'lowrank_tree':
        env = tree.Tree(branching=2, depth=3, duplicate=10)
    elif env_name == 'frozenlake':
        env = InfiniteFrozenLake()
    elif env_name == 'low_rank':
        env = low_rank.LowRank()
    else:
        raise ValueError('Unknown environment: %s.' % env_name)

    tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
    tf_policy = common_utils.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                   tf_env.action_spec(),
                                                   policy_fn,
                                                   policy_info_spec,
                                                   emit_log_probability=True)

    return TFAgentsOnpolicyDataset(tf_env, tf_policy)
Example No. 10
def wrap_env(env,
             discount=1.0,
             max_episode_steps=0,
             gym_env_wrappers=(),
             time_limit_wrapper=wrappers.TimeLimit,
             env_wrappers=(),
             spec_dtype_map=None,
             auto_reset=True):
    for wrapper in gym_env_wrappers:
        env = wrapper(env)
    env = gym_wrapper.GymWrapper(env,
                                 discount=discount,
                                 spec_dtype_map=spec_dtype_map,
                                 match_obs_space_dtype=True,
                                 auto_reset=auto_reset,
                                 simplify_box_bounds=True)

    if max_episode_steps > 0:
        env = time_limit_wrapper(env, max_episode_steps)

    for wrapper in env_wrappers:
        env = wrapper(env)

    return env
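
A minimal usage sketch of wrap_env (assuming gym and tf_agents are installed; the CartPole environment and the 200-step limit are illustrative):

import gym
from tf_agents.environments import tf_py_environment

# Wrap a standard Gym environment with a step limit, then batch it for TF-Agents.
py_env = wrap_env(gym.make('CartPole-v1'), discount=0.99, max_episode_steps=200)
tf_env = tf_py_environment.TFPyEnvironment(py_env)
time_step = tf_env.reset()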
Example No. 11
    def test_automatic_reset_after_create(self):
        cartpole_env = gym.spec('CartPole-v1').make()
        env = gym_wrapper.GymWrapper(cartpole_env)

        first_time_step = env.step(0)
        self.assertTrue(first_time_step.is_first())
Example No. 12
  def test_render(self):
    cartpole_env = gym.spec('CartPole-v1').make()
    cartpole_env.render = mock.MagicMock()
    env = gym_wrapper.GymWrapper(cartpole_env)
    env.render()
    self.assertEqual(1, cartpole_env.render.call_count)
Example No. 13
def get_py_env(env_name,
               max_episode_steps=None,
               constant_task=None,
               use_neg_rew=None,
               margin=None):
  """Load an environment.

  Args:
    env_name: (str) name of the environment.
    max_episode_steps: (int) maximum number of steps per episode. Set to None
      to not include a limit.
    constant_task: specifies a fixed task to use for all episodes. Set to None
      to use tasks sampled from the task distribution.
    use_neg_rew: (bool) For the goal-reaching tasks, indicates whether to use
      a (-1, 0) sparse reward (use_neg_reward = True) or a (0, 1) sparse reward.
    margin: (float) For goal-reaching tasks, indicates the desired distance
      to the goal.
  Returns:
    env: the environment, built from a dynamics and task distribution.
    task_distribution: the task distribution used for the environment.
  """
  if "sawyer" in env_name:
    print(("ERROR: Modify utils.py to import sawyer_env and not dm_env. "
           "Currently the sawyer_env import is commented out to prevent "
           "a segfault from occuring when trying to import both sawyer_env "
           "and dm_env"))
    assert False

  if env_name.split("_")[0] == "point":
    _, walls, resize_factor = env_name.split("_")
    dynamics = point_env.PointDynamics(
        walls=walls, resize_factor=int(resize_factor))
    task_distribution = point_env.PointGoalDistribution(dynamics)
  elif env_name.split("_")[0] == "pointTask":
    _, walls, resize_factor = env_name.split("_")
    dynamics = point_env.PointDynamics(
        walls=walls, resize_factor=int(resize_factor))
    task_distribution = point_env.PointTaskDistribution(dynamics)

  elif env_name == "quadruped-run":
    dynamics = dm_env.QuadrupedRunDynamics()
    task_distribution = dm_env.QuadrupedRunTaskDistribution(dynamics)
  elif env_name == "quadruped":
    dynamics = dm_env.QuadrupedRunDynamics()
    task_distribution = dm_env.QuadrupedContinuousTaskDistribution(dynamics)
  elif env_name == "hopper":
    dynamics = dm_env.HopperDynamics()
    task_distribution = dm_env.HopperTaskDistribution(dynamics)
  elif env_name == "hopper-discrete":
    dynamics = dm_env.HopperDynamics()
    task_distribution = dm_env.HopperDiscreteTaskDistribution(dynamics)
  elif env_name == "walker":
    dynamics = dm_env.WalkerDynamics()
    task_distribution = dm_env.WalkerTaskDistribution(dynamics)
  elif env_name == "humanoid":
    dynamics = dm_env.HumanoidDynamics()
    task_distribution = dm_env.HumanoidTaskDistribution(dynamics)

  ### sparse tasks
  elif env_name == "finger":
    dynamics = dm_env.FingerDynamics()
    task_distribution = dm_env.FingerTaskDistribution(
        dynamics, use_neg_rew=use_neg_rew, margin=margin)
  elif env_name == "manipulator":
    dynamics = dm_env.ManipulatorDynamics()
    task_distribution = dm_env.ManipulatorTaskDistribution(dynamics)
  elif env_name == "point-mass":
    dynamics = dm_env.PointMassDynamics()
    task_distribution = dm_env.PointMassTaskDistribution(
        dynamics, use_neg_rew=use_neg_rew, margin=margin)
  elif env_name == "stacker":
    dynamics = dm_env.StackerDynamics()
    task_distribution = dm_env.FingerStackerDistribution(
        dynamics, use_neg_rew=use_neg_rew, margin=margin)
  elif env_name == "swimmer":
    dynamics = dm_env.SwimmerDynamics()
    task_distribution = dm_env.SwimmerTaskDistribution(dynamics)
  elif env_name == "fish":
    dynamics = dm_env.FishDynamics()
    task_distribution = dm_env.FishTaskDistribution(dynamics)

  elif env_name == "sawyer-reach":
    dynamics = sawyer_env.SawyerDynamics()
    task_distribution = sawyer_env.SawyerReachTaskDistribution(dynamics)
  else:
    raise NotImplementedError("Unknown environment: %s" % env_name)
  gym_env = multitask.Environment(
      dynamics, task_distribution, constant_task=constant_task)
  if max_episode_steps is not None:
    # Add a placeholder spec so the TimeLimit wrapper works.
    gym_env.spec = registration.EnvSpec("env-v0")
    gym_env = TimeLimit(gym_env, max_episode_steps)
  wrapped_env = gym_wrapper.GymWrapper(gym_env, discount=1.0, auto_reset=True)
  return wrapped_env, task_distribution
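
A minimal usage sketch of get_py_env (the 'hopper' task name comes from the branches above; wrapping the result in a TFPyEnvironment is an assumption about how the loader is consumed):

from tf_agents.environments import tf_py_environment

# Build the multitask environment and its task distribution, then batch it for TF.
py_env, task_distribution = get_py_env("hopper", max_episode_steps=1000)
tf_env = tf_py_environment.TFPyEnvironment(py_env)
time_step = tf_env.reset()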
Example No. 14
def main(_):
    tf.config.experimental_run_functions_eagerly(FLAGS.eager)

    gym_env, dataset = d4rl_utils.create_d4rl_env_and_dataset(
        task_name=FLAGS.task_name, batch_size=FLAGS.batch_size)

    env = gym_wrapper.GymWrapper(gym_env)
    env = tf_py_environment.TFPyEnvironment(env)

    dataset_iter = iter(dataset)

    tf.random.set_seed(FLAGS.seed)

    hparam_str = utils.make_hparam_string(FLAGS.xm_parameters,
                                          algo_name=FLAGS.algo_name,
                                          seed=FLAGS.seed,
                                          task_name=FLAGS.task_name,
                                          data_name=FLAGS.data_name)
    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))
    result_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'results', hparam_str))

    if FLAGS.algo_name == 'bc':
        model = behavioral_cloning.BehavioralCloning(env.observation_spec(),
                                                     env.action_spec())
    elif FLAGS.algo_name == 'bc_mix':
        model = behavioral_cloning.BehavioralCloning(env.observation_spec(),
                                                     env.action_spec(),
                                                     mixture=True)
    elif 'ddpg' in FLAGS.algo_name:
        model = ddpg.DDPG(env.observation_spec(), env.action_spec())
    elif 'crr' in FLAGS.algo_name:
        model = awr.AWR(env.observation_spec(), env.action_spec(), f='bin_max')
    elif 'awr' in FLAGS.algo_name:
        model = awr.AWR(env.observation_spec(),
                        env.action_spec(),
                        f='exp_mean')
    elif 'bcq' in FLAGS.algo_name:
        model = bcq.BCQ(env.observation_spec(), env.action_spec())
    elif 'asac' in FLAGS.algo_name:
        model = asac.ASAC(env.observation_spec(),
                          env.action_spec(),
                          target_entropy=-env.action_spec().shape[0])
    elif 'sac' in FLAGS.algo_name:
        model = sac.SAC(env.observation_spec(),
                        env.action_spec(),
                        target_entropy=-env.action_spec().shape[0])
    elif 'cql' in FLAGS.algo_name:
        model = cql.CQL(env.observation_spec(),
                        env.action_spec(),
                        target_entropy=-env.action_spec().shape[0])
    elif 'brac' in FLAGS.algo_name:
        if 'fbrac' in FLAGS.algo_name:
            model = fisher_brac.FBRAC(
                env.observation_spec(),
                env.action_spec(),
                target_entropy=-env.action_spec().shape[0],
                f_reg=FLAGS.f_reg,
                reward_bonus=FLAGS.reward_bonus)
        else:
            model = brac.BRAC(env.observation_spec(),
                              env.action_spec(),
                              target_entropy=-env.action_spec().shape[0])

        model_folder = os.path.join(
            FLAGS.save_dir, 'models',
            f'{FLAGS.task_name}_{FLAGS.data_name}_{FLAGS.seed}')
        if not tf.io.gfile.isdir(model_folder):
            bc_pretraining_steps = 1_000_000
            for i in tqdm.tqdm(range(bc_pretraining_steps)):
                info_dict = model.bc.update_step(dataset_iter)

                if i % FLAGS.log_interval == 0:
                    with summary_writer.as_default():
                        for k, v in info_dict.items():
                            tf.summary.scalar(f'training/{k}',
                                              v,
                                              step=i - bc_pretraining_steps)
            # model.bc.policy.save_weights(os.path.join(model_folder, 'model'))
        else:
            model.bc.policy.load_weights(os.path.join(model_folder, 'model'))

    for i in tqdm.tqdm(range(FLAGS.num_updates)):
        with summary_writer.as_default():
            info_dict = model.update_step(dataset_iter)

        if i % FLAGS.log_interval == 0:
            with summary_writer.as_default():
                for k, v in info_dict.items():
                    tf.summary.scalar(f'training/{k}', v, step=i)

        if (i + 1) % FLAGS.eval_interval == 0:
            average_returns, average_length = evaluation.evaluate(env, model)
            if FLAGS.data_name is None:
                average_returns = gym_env.get_normalized_score(
                    average_returns) * 100.0

            with result_writer.as_default():
                tf.summary.scalar('evaluation/returns',
                                  average_returns,
                                  step=i + 1)
                tf.summary.scalar('evaluation/length',
                                  average_length,
                                  step=i + 1)
Example No. 15
def get_env_and_policy(load_dir,
                       env_name,
                       alpha,
                       env_seed=0,
                       tabular_obs=False):
    """Creates a TF environment and a wrapped policy for `env_name`."""
    if env_name == 'taxi':
        env = taxi.Taxi(tabular_obs=tabular_obs)
        env.seed(env_seed)
        policy_fn, policy_info_spec = taxi.get_taxi_policy(load_dir,
                                                           env,
                                                           alpha=alpha,
                                                           py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name == 'grid':
        env = navigation.GridWalk(tabular_obs=tabular_obs)
        env.seed(env_seed)
        policy_fn, policy_info_spec = navigation.get_navigation_policy(
            env, epsilon_explore=0.1 + 0.6 * (1 - alpha), py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name == 'low_rank':
        env = low_rank.LowRank()
        env.seed(env_seed)
        policy_fn, policy_info_spec = low_rank.get_low_rank_policy(
            env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name == 'tree':
        env = tree.Tree(branching=2, depth=10)
        env.seed(env_seed)
        policy_fn, policy_info_spec = tree.get_tree_policy(
            env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name == 'lowrank_tree':
        env = tree.Tree(branching=2, depth=3, duplicate=10)
        env.seed(env_seed)
        policy_fn, policy_info_spec = tree.get_tree_policy(
            env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name.startswith('bandit'):
        num_arms = int(env_name[6:]) if len(env_name) > 6 else 2
        env = bandit.Bandit(num_arms=num_arms)
        env.seed(env_seed)
        policy_fn, policy_info_spec = bandit.get_bandit_policy(
            env, epsilon_explore=1 - alpha, py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name == 'small_tree':
        env = tree.Tree(branching=2, depth=3, loop=True)
        env.seed(env_seed)
        policy_fn, policy_info_spec = tree.get_tree_policy(
            env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name == 'CartPole-v0':
        tf_env, policy = get_env_and_dqn_policy(
            env_name,
            os.path.join(load_dir, 'CartPole-v0', 'train', 'policy'),
            env_seed=env_seed,
            epsilon=0.3 + 0.15 * (1 - alpha))
    elif env_name == 'cartpole':  # Infinite-horizon cartpole.
        tf_env, policy = get_env_and_dqn_policy(
            'CartPole-v0',
            os.path.join(load_dir, 'CartPole-v0-250', 'train', 'policy'),
            env_seed=env_seed,
            epsilon=0.3 + 0.15 * (1 - alpha))
        env = InfiniteCartPole()
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
    elif env_name == 'FrozenLake-v0':
        tf_env, policy = get_env_and_dqn_policy('FrozenLake-v0',
                                                os.path.join(
                                                    load_dir, 'FrozenLake-v0',
                                                    'train', 'policy'),
                                                env_seed=env_seed,
                                                epsilon=0.2 * (1 - alpha),
                                                ckpt_file='ckpt-100000')
    elif env_name == 'frozenlake':  # Infinite-horizon frozenlake.
        tf_env, policy = get_env_and_dqn_policy('FrozenLake-v0',
                                                os.path.join(
                                                    load_dir, 'FrozenLake-v0',
                                                    'train', 'policy'),
                                                env_seed=env_seed,
                                                epsilon=0.2 * (1 - alpha),
                                                ckpt_file='ckpt-100000')
        env = InfiniteFrozenLake()
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
    elif env_name in ['Reacher-v2', 'reacher']:
        if env_name == 'Reacher-v2':
            env = suites.load_mujoco(env_name)
        else:
            env = gym_wrapper.GymWrapper(InfiniteReacher())
        env.seed(env_seed)
        tf_env = tf_py_environment.TFPyEnvironment(env)
        sac_policy = get_sac_policy(tf_env)
        directory = os.path.join(load_dir, 'Reacher-v2', 'train', 'policy')
        policy = load_policy(sac_policy, env_name, directory)
        policy = GaussianPolicy(policy,
                                0.4 - 0.3 * alpha,
                                emit_log_probability=True)
    elif env_name == 'HalfCheetah-v2':
        env = suites.load_mujoco(env_name)
        env.seed(env_seed)
        tf_env = tf_py_environment.TFPyEnvironment(env)
        sac_policy = get_sac_policy(tf_env)
        directory = os.path.join(load_dir, env_name, 'train', 'policy')
        policy = load_policy(sac_policy, env_name, directory)
        policy = GaussianPolicy(policy,
                                0.2 - 0.1 * alpha,
                                emit_log_probability=True)
    else:
        raise ValueError('Unrecognized environment %s.' % env_name)

    return tf_env, policy
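
A minimal usage sketch of get_env_and_policy (uses the 'grid' branch above; the load_dir path is a placeholder and is only consulted by the checkpoint-based environments):

# Build the environment and behavior policy, then take a single policy step.
tf_env, policy = get_env_and_policy('/tmp/policies', 'grid', alpha=0.5)
time_step = tf_env.reset()
action_step = policy.action(time_step)
time_step = tf_env.step(action_step.action)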
Example No. 16
def load_carla_env(env_name='carla-v0',
                   discount=1.0,
                   number_of_vehicles=100,
                   number_of_walkers=0,
                   display_size=256,
                   max_past_step=1,
                   dt=0.1,
                   discrete=False,
                   discrete_acc=[-3.0, 0.0, 3.0],
                   discrete_steer=[-0.2, 0.0, 0.2],
                   continuous_accel_range=[-3.0, 3.0],
                   continuous_steer_range=[-0.3, 0.3],
                   ego_vehicle_filter='vehicle.lincoln*',
                   port=2000,
                   town='Town03',
                   task_mode='random',
                   max_time_episode=500,
                   max_waypt=12,
                   obs_range=32,
                   lidar_bin=0.5,
                   d_behind=12,
                   out_lane_thres=2.0,
                   desired_speed=8,
                   max_ego_spawn_times=200,
                   display_route=True,
                   pixor_size=64,
                   pixor=False,
                   obs_channels=None,
                   action_repeat=1):
    """Loads train and eval environments."""
    env_params = {
        'number_of_vehicles': number_of_vehicles,
        'number_of_walkers': number_of_walkers,
        'display_size': display_size,  # screen size of bird-eye render
        'max_past_step': max_past_step,  # the number of past steps to draw
        'dt': dt,  # time interval between two frames
        'discrete': discrete,  # whether to use discrete control space
        'discrete_acc': discrete_acc,  # discrete value of accelerations
        'discrete_steer': discrete_steer,  # discrete value of steering angles
        'continuous_accel_range':
        continuous_accel_range,  # continuous acceleration range
        'continuous_steer_range':
        continuous_steer_range,  # continuous steering angle range
        'ego_vehicle_filter':
        ego_vehicle_filter,  # filter for defining ego vehicle
        'port': port,  # connection port
        'town': town,  # which town to simulate
        'task_mode':
        task_mode,  # mode of the task, [random, roundabout (only for Town03)]
        'max_time_episode': max_time_episode,  # maximum timesteps per episode
        'max_waypt': max_waypt,  # maximum number of waypoints
        'obs_range': obs_range,  # observation range (meter)
        'lidar_bin': lidar_bin,  # bin size of lidar sensor (meter)
        'd_behind': d_behind,  # distance behind the ego vehicle (meter)
        'out_lane_thres': out_lane_thres,  # threshold for out of lane
        'desired_speed': desired_speed,  # desired speed (m/s)
        'max_ego_spawn_times':
        max_ego_spawn_times,  # maximum times to spawn ego vehicle
        'display_route': display_route,  # whether to render the desired route
        'pixor_size': pixor_size,  # size of the pixor labels
        'pixor': pixor,  # whether to output PIXOR observation
    }

    gym_spec = gym.spec(env_name)
    gym_env = gym_spec.make(params=env_params)

    if obs_channels:
        gym_env = filter_observation_wrapper.FilterObservationWrapper(
            gym_env, obs_channels)

    py_env = gym_wrapper.GymWrapper(
        gym_env,
        discount=discount,
        auto_reset=True,
    )

    eval_py_env = py_env

    if action_repeat > 1:
        py_env = wrappers.ActionRepeat(py_env, action_repeat)

    return py_env, eval_py_env
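
A minimal usage sketch of load_carla_env (assumes the gym package that registers 'carla-v0' is installed and a CARLA server is reachable on the given port; converting to TFPyEnvironment is an assumption about downstream use):

from tf_agents.environments import tf_py_environment

# Build the train/eval CARLA environments and batch the training one for TF-Agents.
py_env, eval_py_env = load_carla_env(env_name='carla-v0', port=2000)
train_env = tf_py_environment.TFPyEnvironment(py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)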
Example No. 17
def env_load_fn(environment_name='DrunkSpiderShort',
                max_episode_steps=50,
                resize_factor=(1, 1),
                terminate_on_timeout=True,
                start=(0, 3),
                goal=(7, 3),
                goal_bounds=[(6, 2), (7, 4)],
                fall_penalty=0.,
                reset_on_fall=False,
                gym_env_wrappers=[],
                gym=False):
    """Loads the selected environment and wraps it with the specified wrappers.

  Args:
    environment_name: Name for the environment to load.
    max_episode_steps: If None the max_episode_steps will be set to the default
      step limit defined in the environment's spec. No limit is applied if set
      to 0 or if there is no timestep_limit set in the environment's spec.
    resize_factor: A factor for resizing.
    terminate_on_timeout: Whether to set done = True when the max episode steps
      is reached.
    start: Start position for the point mass.
    goal: Goal position.
    goal_bounds: Bounds describing the goal region.
    fall_penalty: Penalty (applied as a negative reward) when the agent falls.
    reset_on_fall: Whether to end the episode when the agent falls.
    gym_env_wrappers: Iterable of wrapper classes applied directly to the gym
      environment.
    gym: If True, return the raw gym environment instead of wrapping it.

  Returns:
    A PyEnvironmentBase instance.
  """
    if resize_factor != (1, 1):
        if start:
            start = (start[0] * resize_factor[0], start[1] * resize_factor[1])
        if goal:
            goal = (goal[0] * resize_factor[0], goal[1] * resize_factor[1])
        if goal_bounds:
            goal_bounds = [(g[0] * resize_factor[0], g[1] * resize_factor[1])
                           for g in goal_bounds]

    if 'acnoise' in environment_name.split('-'):
        environment_name = environment_name.split('-')[0]
        gym_env = PointMassAcNoiseEnv(start=start,
                                      env_name=environment_name,
                                      resize_factor=resize_factor)
    elif 'acscale' in environment_name.split('-'):
        environment_name = environment_name.split('-')[0]
        gym_env = PointMassAcScaleEnv(start=start,
                                      env_name=environment_name,
                                      resize_factor=resize_factor)
    else:
        gym_env = PointMassEnv(start=start,
                               env_name=environment_name,
                               resize_factor=resize_factor)

    gym_env = GoalConditionedPointWrapper(gym_env,
                                          goal=goal,
                                          goal_bounds=goal_bounds,
                                          fall_penalty=-abs(fall_penalty),
                                          reset_on_fall=reset_on_fall,
                                          max_episode_steps=max_episode_steps)
    for wrapper in gym_env_wrappers:
        gym_env = wrapper(gym_env)
    if gym:
        return gym_env

    from tf_agents.environments import gym_wrapper
    env = gym_wrapper.GymWrapper(gym_env,
                                 discount=1.0,
                                 auto_reset=True,
                                 simplify_box_bounds=False)

    if max_episode_steps > 0:
        if terminate_on_timeout:
            env = TimeLimitBonus(env, max_episode_steps)
        else:
            env = NonTerminatingTimeLimit(env, max_episode_steps)

    return env
Example No. 18
def main(_):
    tf.config.experimental_run_functions_eagerly(FLAGS.eager)

    gym_env, dataset = d4rl_utils.create_d4rl_env_and_dataset(
        task_name=FLAGS.task_name, batch_size=FLAGS.batch_size)

    env = gym_wrapper.GymWrapper(gym_env)
    env = tf_py_environment.TFPyEnvironment(env)

    dataset_iter = iter(dataset)

    tf.random.set_seed(FLAGS.seed)

    hparam_str = f'{FLAGS.algo_name}_{FLAGS.task_name}_seed={FLAGS.seed}'

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))
    result_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'results', hparam_str))

    if FLAGS.algo_name == 'bc':
        model = behavioral_cloning.BehavioralCloning(env.observation_spec(),
                                                     env.action_spec())
    else:
        model = fisher_brc.FBRC(env.observation_spec(),
                                env.action_spec(),
                                target_entropy=-env.action_spec().shape[0],
                                f_reg=FLAGS.f_reg,
                                reward_bonus=FLAGS.reward_bonus)

        for i in tqdm.tqdm(range(FLAGS.bc_pretraining_steps)):
            info_dict = model.bc.update_step(dataset_iter)

            if i % FLAGS.log_interval == 0:
                with summary_writer.as_default():
                    for k, v in info_dict.items():
                        tf.summary.scalar(f'training/{k}',
                                          v,
                                          step=i - FLAGS.bc_pretraining_steps)

    for i in tqdm.tqdm(range(FLAGS.num_updates)):
        with summary_writer.as_default():
            info_dict = model.update_step(dataset_iter)

        if i % FLAGS.log_interval == 0:
            with summary_writer.as_default():
                for k, v in info_dict.items():
                    tf.summary.scalar(f'training/{k}', v, step=i)

        if (i + 1) % FLAGS.eval_interval == 0:
            average_returns, average_length = evaluation.evaluate(env, model)
            average_returns = gym_env.get_normalized_score(
                average_returns) * 100.0

            with result_writer.as_default():
                tf.summary.scalar('evaluation/returns',
                                  average_returns,
                                  step=i + 1)
                tf.summary.scalar('evaluation/length',
                                  average_length,
                                  step=i + 1)
Example No. 19

window_size = 10

num_eval_episodes = 30

network_shape = (128, 128,)
learning_rate = 1e-3

tau = 0.1
gradient_clipping = 1

policy_dir = 'policy'


train_py_env = gym_wrapper.GymWrapper(StockTradingEnv(df=train_data, window_size=window_size))
eval_py_env = gym_wrapper.GymWrapper(StockTradingEnv(df=eval_data, window_size=window_size, eval=True))

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)


def eval_policy(policy, render=False):
    total_reward = 0
    for _ in range(num_eval_episodes):
        time_step = eval_env.reset()
        episode_reward = 0
        while not time_step.is_last():
            time_step = eval_env.step(policy.action(time_step).action)
            episode_reward += time_step.reward
        total_reward += episode_reward
    # Average episode reward over the evaluation episodes.
    return total_reward / num_eval_episodes
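
A small hedged addition: the evaluation helper above can be sanity-checked with a random policy before any training (illustrative only):

from tf_agents.policies import random_tf_policy

# Establish a random-policy baseline on the eval environment.
random_policy = random_tf_policy.RandomTFPolicy(eval_env.time_step_spec(), eval_env.action_spec())
baseline_reward = eval_policy(random_policy)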
Example No. 20
def main(_):
  tf.config.experimental_run_functions_eagerly(FLAGS.eager)

  def preprocess_fn(dataset):
    return dataset.cache().shuffle(1_000_000, reshuffle_each_iteration=True)

  def state_mask_fn(states):
    if FLAGS.state_mask_dims == 0:
      return states
    assert FLAGS.state_mask_dims <= states.shape[1]
    state_mask_dims = (
        states.shape[1]
        if FLAGS.state_mask_dims == -1 else FLAGS.state_mask_dims)
    if FLAGS.state_mask_index == 'fixed':
      mask_indices = range(states.shape[1] - state_mask_dims, states.shape[1])
    else:
      mask_indices = np.random.permutation(np.arange(
          states.shape[1]))[:state_mask_dims]
    if FLAGS.state_mask_value == 'gaussian':
      mask_values = states[:, mask_indices]
      mask_values = (
          mask_values + np.std(mask_values, axis=0) *
          np.random.normal(size=mask_values.shape))
    elif 'quantize' in FLAGS.state_mask_value:
      mask_values = states[:, mask_indices]
      mask_values = np.around(
          mask_values, decimals=int(FLAGS.state_mask_value[-1]))
    else:
      mask_values = 0
    states[:, mask_indices] = mask_values
    return states

  gym_env, dataset, embed_dataset = d4rl_utils.create_d4rl_env_and_dataset(
      task_name=FLAGS.task_name,
      batch_size=FLAGS.batch_size,
      sliding_window=FLAGS.embed_training_window,
      state_mask_fn=state_mask_fn)

  downstream_embed_dataset = None
  if (FLAGS.downstream_task_name is not None or
      FLAGS.downstream_data_name is not None or
      FLAGS.downstream_data_size is not None):
    downstream_data_name = FLAGS.downstream_data_name
    assert downstream_data_name is None
    gym_env, dataset, downstream_embed_dataset = d4rl_utils.create_d4rl_env_and_dataset(
        task_name=FLAGS.downstream_task_name,
        batch_size=FLAGS.batch_size,
        sliding_window=FLAGS.embed_training_window,
        data_size=FLAGS.downstream_data_size,
        state_mask_fn=state_mask_fn)

    if FLAGS.proportion_downstream_data:
      zipped_dataset = tf.data.Dataset.zip((embed_dataset, downstream_embed_dataset))

      def combine(*elems1_and_2):
        batch_size = tf.shape(elems1_and_2[0][0])[0]
        which = tf.random.uniform([batch_size]) >= FLAGS.proportion_downstream_data
        from1 = tf.where(which)
        from2 = tf.where(tf.logical_not(which))
        new_elems = map(
            lambda x: tf.concat([tf.gather_nd(x[0], from1), tf.gather_nd(x[1], from2)], 0),
            zip(*elems1_and_2))
        return tuple(new_elems)

      embed_dataset = zipped_dataset.map(combine)

  if FLAGS.embed_learner and 'action' in FLAGS.embed_learner:
    assert FLAGS.embed_training_window >= 2
    dataset = downstream_embed_dataset or embed_dataset

  if FLAGS.downstream_mode == 'online':

    downstream_task = FLAGS.downstream_task_name or FLAGS.task_name
    try:
      train_gym_env = gym.make(downstream_task)
    except:
      train_gym_env = gym.make('DM-' + downstream_task)
    train_env = gym_wrapper.GymWrapper(train_gym_env)

    train_env = tf_py_environment.TFPyEnvironment(train_env)

    replay_spec = (
        train_env.observation_spec(),
        train_env.action_spec(),
        train_env.reward_spec(),
        train_env.reward_spec(),  # discount spec
        train_env.observation_spec(),  # next observation spec
    )
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        replay_spec,
        batch_size=1,
        max_length=FLAGS.num_updates,
        dataset_window_shift=1 if get_ctx_length() else None)

    @tf.function
    def add_to_replay(state, action, reward, discount, next_states,
                      replay_buffer=replay_buffer):
      replay_buffer.add_batch((state, action, reward, discount, next_states))

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        train_env.time_step_spec(), train_env.action_spec())

    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=FLAGS.batch_size,
        num_steps=FLAGS.embed_training_window
        if get_ctx_length() else None).prefetch(3)
    dataset = dataset.map(lambda *data: data[0])
  else:
    train_env = None
    replay_buffer = None
    add_to_replay = None
    initial_collect_policy = None

  env = gym_wrapper.GymWrapper(gym_env)
  env = tf_py_environment.TFPyEnvironment(env)

  dataset_iter = iter(dataset)
  embed_dataset_iter = iter(embed_dataset) if embed_dataset else None

  tf.random.set_seed(FLAGS.seed)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'tb'))
  result_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'results'))

  if (FLAGS.state_embed_dim or FLAGS.state_action_embed_dim
     ) and FLAGS.embed_learner and FLAGS.embed_pretraining_steps != 0:
    embed_model = get_embed_model(env)
    if FLAGS.finetune:
      other_embed_model = get_embed_model(env)
      other_embed_model2 = get_embed_model(env)
    else:
      other_embed_model = None
      other_embed_model2 = None
  else:
    embed_model = None
    other_embed_model = None
    other_embed_model2 = None

  config_str = f'{FLAGS.task_name}_{FLAGS.embed_learner}_{FLAGS.state_embed_dim}_{FLAGS.state_embed_dists}_{FLAGS.embed_training_window}_{FLAGS.downstream_input_mode}_{FLAGS.finetune}_{FLAGS.network}_{FLAGS.seed}'
  if FLAGS.embed_learner == 'acl':
    config_str += f'_{FLAGS.predict_actions}_{FLAGS.policy_decoder_on_embeddings}_{FLAGS.reward_decoder_on_embeddings}_{FLAGS.predict_rewards}_{FLAGS.embed_on_input}_{FLAGS.extra_embedder}_{FLAGS.positional_encoding_type}_{FLAGS.direction}'
  elif FLAGS.embed_learner and 'action' in FLAGS.embed_learner:
    config_str += f'_{FLAGS.state_action_embed_dim}_{FLAGS.state_action_fourier_dim}'
  save_dir = os.path.join(FLAGS.save_dir, config_str)

  # Embed pretraining
  if FLAGS.embed_pretraining_steps > 0 and embed_model is not None:
    model_folder = os.path.join(
        save_dir, 'embed_models%d' % FLAGS.embed_pretraining_steps,
        config_str)
    if not tf.io.gfile.isdir(model_folder):
      embed_pretraining_steps = FLAGS.embed_pretraining_steps
      for i in tqdm.tqdm(range(embed_pretraining_steps)):
        embed_dict = embed_model.update_step(embed_dataset_iter)
        if i % FLAGS.log_interval == 0:
          with summary_writer.as_default():
            for k, v in embed_dict.items():
              tf.summary.scalar(f'embed/{k}', v, step=i-embed_pretraining_steps)
              print(k, v)
            print('embed pretraining')
      embed_model.save_weights(os.path.join(model_folder, 'embed'))
    else:
      time.sleep(np.random.randint(5, 20))  # Try to suppress checksum errors.
      embed_model.load_weights(os.path.join(model_folder, 'embed'))

    if other_embed_model and other_embed_model2:
      try:  # Try to suppress checksum errors.
        other_embed_model.load_weights(os.path.join(model_folder, 'embed'))
        other_embed_model2.load_weights(os.path.join(model_folder, 'embed'))
      except:
        embed_model.save_weights(os.path.join(model_folder, 'embed'))
        other_embed_model.load_weights(os.path.join(model_folder, 'embed'))
        other_embed_model2.load_weights(os.path.join(model_folder, 'embed'))

  if FLAGS.algo_name == 'bc':
    hidden_dims = ([] if FLAGS.network == 'none' else
                   (256,) if FLAGS.network == 'small' else
                   (256, 256))
    model = behavioral_cloning.BehavioralCloning(
        env.observation_spec().shape[0],
        env.action_spec(),
        hidden_dims=hidden_dims,
        embed_model=embed_model,
        finetune=FLAGS.finetune)
  elif FLAGS.algo_name == 'latent_bc':
    hidden_dims = ([] if FLAGS.network == 'none' else
                   (256,) if FLAGS.network == 'small' else (256, 256))
    model = latent_behavioral_cloning.LatentBehavioralCloning(
        env.observation_spec().shape[0],
        env.action_spec(),
        hidden_dims=hidden_dims,
        embed_model=embed_model,
        finetune=FLAGS.finetune,
        finetune_primitive=FLAGS.finetune_primitive,
        learning_rate=FLAGS.latent_bc_lr,
        latent_bc_lr_decay=FLAGS.latent_bc_lr_decay,
        kl_regularizer=FLAGS.kl_regularizer)
  elif 'sac' in FLAGS.algo_name:
    model = sac.SAC(
        env.observation_spec().shape[0],
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0],
        embed_model=embed_model,
        other_embed_model=other_embed_model,
        network=FLAGS.network,
        finetune=FLAGS.finetune)
  elif 'brac' in FLAGS.algo_name:
    model = brac.BRAC(
        env.observation_spec().shape[0],
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0],
        embed_model=embed_model,
        other_embed_model=other_embed_model,
        bc_embed_model=other_embed_model2,
        network=FLAGS.network,
        finetune=FLAGS.finetune)

    # Agent pretraining.
    if not tf.io.gfile.isdir(os.path.join(save_dir, 'model')):
      bc_pretraining_steps = 200_000
      for i in tqdm.tqdm(range(bc_pretraining_steps)):
        if get_ctx_length():
          info_dict = model.bc.update_step(embed_dataset_iter)
        else:
          info_dict = model.bc.update_step(dataset_iter)

        if i % FLAGS.log_interval == 0:
          with summary_writer.as_default():
            for k, v in info_dict.items():
              tf.summary.scalar(
                  f'training/{k}', v, step=i - bc_pretraining_steps)
            print('bc pretraining')
      model.bc.policy.save_weights(os.path.join(save_dir, 'model'))
    else:
      model.bc.policy.load_weights(os.path.join(save_dir, 'model'))

  if train_env:
    timestep = train_env.reset()
  else:
    timestep = None

  actor = None
  if hasattr(model, 'actor'):
    actor = model.actor
  elif hasattr(model, 'policy'):
    actor = model.policy

  ctx_states = []
  ctx_actions = []
  ctx_rewards = []
  for i in tqdm.tqdm(range(FLAGS.num_updates)):
    if (train_env and timestep and
        replay_buffer and initial_collect_policy and
        add_to_replay and actor):
      if timestep.is_last():
        timestep = train_env.reset()
      if replay_buffer.num_frames() < FLAGS.num_random_actions:
        policy_step = initial_collect_policy.action(timestep)
        action = policy_step.action
        ctx_states.append(state_mask_fn(timestep.observation.numpy()))
        ctx_actions.append(action)
        ctx_rewards.append(timestep.reward)
      else:
        states = state_mask_fn(timestep.observation.numpy())
        actions = None
        rewards = None
        if get_ctx_length():
          ctx_states.append(states)
          states = tf.stack(ctx_states[-get_ctx_length():], axis=1)
          actions = tf.stack(ctx_actions[-get_ctx_length() + 1:], axis=1)
          rewards = tf.stack(ctx_rewards[-get_ctx_length() + 1:], axis=1)
        if hasattr(model, 'embed_model') and model.embed_model:
          states = model.embed_model(states, actions, rewards)
        action = actor(states, sample=True)
        ctx_actions.append(action)
      next_timestep = train_env.step(action)
      ctx_rewards.append(next_timestep.reward)
      add_to_replay(
          state_mask_fn(timestep.observation.numpy()), action,
          next_timestep.reward, next_timestep.discount,
          state_mask_fn(next_timestep.observation.numpy()))
      timestep = next_timestep

    with summary_writer.as_default():
      if embed_model and FLAGS.embed_pretraining_steps == -1:
        embed_dict = embed_model.update_step(embed_dataset_iter)
        if other_embed_model:
          other_embed_dict = other_embed_model.update_step(embed_dataset_iter)
          embed_dict.update(dict(('other_%s' % k, v) for k, v in other_embed_dict.items()))
      else:
        embed_dict = {}

      if FLAGS.downstream_mode == 'offline':
        if get_ctx_length():
          info_dict = model.update_step(embed_dataset_iter)
        else:
          info_dict = model.update_step(dataset_iter)
      elif i + 1 >= FLAGS.num_random_actions:
        info_dict = model.update_step(dataset_iter)
      else:
        info_dict = {}

    if i % FLAGS.log_interval == 0:
      with summary_writer.as_default():
        for k, v in info_dict.items():
          tf.summary.scalar(f'training/{k}', v, step=i)
        for k, v in embed_dict.items():
          tf.summary.scalar(f'embed/{k}', v, step=i)
          print(k, v)

    if (i + 1) % FLAGS.eval_interval == 0:
      average_returns, average_length = evaluation.evaluate(
          env,
          model,
          ctx_length=get_ctx_length(),
          embed_training_window=(FLAGS.embed_training_window
                                 if FLAGS.embed_learner and
                                 'action' in FLAGS.embed_learner else None),
          state_mask_fn=state_mask_fn if FLAGS.state_mask_eval else None)

      average_returns = gym_env.get_normalized_score(average_returns) * 100.0

      with result_writer.as_default():
        tf.summary.scalar('evaluation/returns', average_returns, step=i+1)
        tf.summary.scalar('evaluation/length', average_length, step=i+1)
        print('evaluation/returns', average_returns)
        print('evaluation/length', average_length)
Example No. 21
def train_eval(
        root_dir,
        load_root_dir=None,
        env_load_fn=None,
        gym_env_wrappers=[],
        monitor=False,
        env_name=None,
        agent_class=None,
        initial_collect_driver_class=None,
        collect_driver_class=None,
        online_driver_class=dynamic_episode_driver.DynamicEpisodeDriver,
        num_global_steps=1000000,
        train_steps_per_iteration=1,
        train_metrics=None,
        eval_metrics=None,
        train_metrics_callback=None,
        # Params for SacAgent args
        actor_fc_layers=(256, 256),
        critic_joint_fc_layers=(256, 256),
        # Safety Critic training args
        train_sc_steps=10,
        train_sc_interval=1000,
        online_critic=False,
        n_envs=None,
        finetune_sc=False,
        # Ensemble Critic training args
        n_critics=30,
        critic_learning_rate=3e-4,
        # Wcpg Critic args
        critic_preprocessing_layer_size=256,
        actor_preprocessing_layer_size=256,
        # Params for train
        batch_size=256,
        # Params for eval
        run_eval=False,
        num_eval_episodes=1,
        max_episode_len=500,
        eval_interval=10000,
        eval_metrics_callback=None,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        keep_rb_checkpoint=False,
        log_interval=1000,
        summary_interval=1000,
        monitor_interval=1000,
        summaries_flush_secs=10,
        early_termination_fn=None,
        debug_summaries=False,
        seed=None,
        eager_debug=False,
        env_metric_factories=None):  # pylint: disable=unused-argument
    """A simple train and eval for SC-SAC."""

    n_envs = n_envs or num_eval_episodes
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    train_metrics = train_metrics or []
    eval_metrics = eval_metrics or []
    sc_metrics = eval_metrics or []

    if online_critic:
        sc_dir = os.path.join(root_dir, 'sc')
        sc_summary_writer = tf.compat.v2.summary.create_file_writer(
            sc_dir, flush_millis=summaries_flush_secs * 1000)
        sc_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=n_envs,
                                           name='SafeAverageReturn'),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=n_envs,
                name='SafeAverageEpisodeLength')
        ] + [tf_py_metric.TFPyMetric(m) for m in sc_metrics]
        sc_tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                lambda: env_load_fn(env_name,
                                    gym_env_wrappers=gym_env_wrappers)
            ] * n_envs))
        if seed:
            sc_tf_env.seed([seed + i for i in range(n_envs)])

    if run_eval:
        eval_dir = os.path.join(root_dir, 'eval')
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=n_envs),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=n_envs),
        ] + [tf_py_metric.TFPyMetric(m) for m in eval_metrics]
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                lambda: env_load_fn(env_name,
                                    gym_env_wrappers=gym_env_wrappers)
            ] * n_envs))
        if seed:
            eval_tf_env.seed([seed + n_envs + i for i in range(n_envs)])

    if monitor:
        vid_path = os.path.join(root_dir, 'rollouts')
        monitor_env_wrapper = misc.monitor_freq(1, vid_path)
        monitor_env = gym.make(env_name)
        for wrapper in gym_env_wrappers:
            monitor_env = wrapper(monitor_env)
        monitor_env = monitor_env_wrapper(monitor_env)
        # auto_reset must be False to ensure Monitor works correctly
        monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        if seed:
            tf_env.seed([seed + 2 * n_envs + i for i in range(n_envs)])
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        logging.debug('obs spec: %s', observation_spec)
        logging.debug('action spec: %s', action_spec)

        if agent_class is not wcpg_agent.WcpgAgent:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                observation_spec,
                action_spec,
                fc_layer_params=actor_fc_layers,
                continuous_projection_net=agents.normal_projection_net)
            critic_net = agents.CriticNetwork(
                (observation_spec, action_spec),
                joint_fc_layer_params=critic_joint_fc_layers)
        else:
            alpha_spec = tensor_spec.BoundedTensorSpec(shape=(),
                                                       dtype=tf.float32,
                                                       minimum=0.,
                                                       maximum=1.,
                                                       name='alpha')
            input_tensor_spec = (observation_spec, action_spec, alpha_spec)
            critic_preprocessing_layers = (
                tf.keras.layers.Dense(critic_preprocessing_layer_size),
                tf.keras.layers.Dense(critic_preprocessing_layer_size),
                tf.keras.layers.Lambda(lambda x: x))
            critic_net = agents.DistributionalCriticNetwork(
                input_tensor_spec,
                joint_fc_layer_params=critic_joint_fc_layers)
            actor_preprocessing_layers = (
                tf.keras.layers.Dense(actor_preprocessing_layer_size),
                tf.keras.layers.Dense(actor_preprocessing_layer_size),
                tf.keras.layers.Lambda(lambda x: x))
            actor_net = agents.WcpgActorNetwork(
                input_tensor_spec,
                preprocessing_layers=actor_preprocessing_layers)

        if agent_class in SAFETY_AGENTS:
            safety_critic_net = agents.CriticNetwork(
                (observation_spec, action_spec),
                joint_fc_layer_params=critic_joint_fc_layers)
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   actor_network=actor_net,
                                   critic_network=critic_net,
                                   safety_critic_network=safety_critic_net,
                                   train_step_counter=global_step,
                                   debug_summaries=debug_summaries)
        elif agent_class is ensemble_sac_agent.EnsembleSacAgent:
            critic_nets, critic_optimizers = [critic_net], [
                tf.keras.optimizers.Adam(critic_learning_rate)
            ]
            for _ in range(n_critics - 1):
                critic_nets.append(
                    agents.CriticNetwork(
                        (observation_spec, action_spec),
                        joint_fc_layer_params=critic_joint_fc_layers))
                critic_optimizers.append(
                    tf.keras.optimizers.Adam(critic_learning_rate))
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   actor_network=actor_net,
                                   critic_network=critic_nets,
                                   critic_optimizers=critic_optimizers,
                                   debug_summaries=debug_summaries)
        else:  # assume SacAgent is being used
            logging.debug(critic_net.input_tensor_spec)
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   actor_network=actor_net,
                                   critic_network=critic_net,
                                   train_step_counter=global_step,
                                   debug_summaries=debug_summaries)

        tf_agent.initialize()

        # Make the replay buffer.
        collect_data_spec = tf_agent.collect_data_spec

        logging.debug('Allocating replay buffer ...')
        # Add to replay buffer and other agent specific observers.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            collect_data_spec, batch_size=1, max_length=1000000)
        logging.debug('RB capacity: %i', replay_buffer.capacity)
        logging.debug('ReplayBuffer Collect data spec: %s', collect_data_spec)
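        # batch_size=1 matches the single (non-parallel) collect env wrapped by
        # tf_env above; the parallel env pools are used only for the
        # safety-critic and eval rollouts.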

        agent_observers = [replay_buffer.add_batch]
        if online_critic:
            online_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
                collect_data_spec,
                batch_size=1,
                max_length=max_episode_len * num_eval_episodes)
            agent_observers.append(online_replay_buffer.add_batch)

            online_rb_ckpt_dir = os.path.join(train_dir,
                                              'online_replay_buffer')
            online_rb_checkpointer = common.Checkpointer(
                ckpt_dir=online_rb_ckpt_dir,
                max_to_keep=1,
                replay_buffer=online_replay_buffer)

            clear_rb = online_replay_buffer.clear

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

        if not online_critic:
            eval_policy = tf_agent.policy
            collect_policy = tf_agent.collect_policy
        else:
            eval_policy = tf_agent.policy  # pylint: disable=protected-access
            collect_policy = tf_agent.collect_policy  # pylint: disable=protected-access
            online_collect_policy = tf_agent._safe_policy

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        if agent_class in SAFETY_AGENTS:
            safety_critic_checkpointer = common.Checkpointer(
                ckpt_dir=sc_dir,
                safety_critic=tf_agent._safety_critic_network,  # pylint: disable=protected-access
                global_step=global_step)
        rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')
        rb_checkpointer = common.Checkpointer(ckpt_dir=rb_ckpt_dir,
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        if load_root_dir:
            load_root_dir = os.path.expanduser(load_root_dir)
            load_train_dir = os.path.join(load_root_dir, 'train')
            misc.load_pi_ckpt(load_train_dir, tf_agent)  # loads tf_agent

        if load_root_dir is None:
            train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()
        if agent_class in SAFETY_AGENTS:
            safety_critic_checkpointer.initialize_or_restore()

        env_metrics = []
        if env_metric_factories:
            for env_metric in env_metric_factories:
                env_metrics.append(
                    tf_py_metric.TFPyMetric(env_metric([py_env.gym])))
                # TODO: get env factory with parallel py envs
                # if run_eval:
                #   eval_metrics.append(env_metric([env.gym for env in eval_tf_env.pyenv._envs]))
                # if online_critic:
                #   sc_metrics.append(env_metric([env.gym for env in sc_tf_env.pyenv._envs]))

        collect_driver = collect_driver_class(tf_env,
                                              collect_policy,
                                              observers=agent_observers +
                                              train_metrics + env_metrics)
        if online_critic:
            logging.debug('online driver class: %s', online_driver_class)
            if online_driver_class is safe_dynamic_episode_driver.SafeDynamicEpisodeDriver:
                online_temp_buffer = episodic_replay_buffer.EpisodicReplayBuffer(
                    collect_data_spec)
                online_temp_buffer_stateful = episodic_replay_buffer.StatefulEpisodicReplayBuffer(
                    online_temp_buffer, num_episodes=num_eval_episodes)
                online_driver = safe_dynamic_episode_driver.SafeDynamicEpisodeDriver(
                    sc_tf_env,
                    online_collect_policy,
                    online_temp_buffer,
                    online_replay_buffer,
                    observers=[online_temp_buffer_stateful.add_batch] +
                    sc_metrics,
                    num_episodes=num_eval_episodes)
            else:
                online_driver = online_driver_class(
                    sc_tf_env,
                    online_collect_policy,
                    observers=[online_replay_buffer.add_batch] + sc_metrics,
                    num_episodes=num_eval_episodes)
            online_driver.run = common.function(online_driver.run)

        if not eager_debug:
            config_saver = gin.tf.GinConfigSaverHook(train_dir,
                                                     summarize_config=True)
            tf.function(config_saver.after_create_session)()

        if agent_class is sac_agent.SacAgent:
            collect_driver.run = common.function(collect_driver.run)
        if eager_debug:
            tf.config.experimental_run_functions_eagerly(True)

        if not rb_checkpointer.checkpoint_exists:
            logging.info('Performing initial collection ...')
            initial_collect_driver_class(tf_env,
                                         initial_collect_policy,
                                         observers=agent_observers +
                                         train_metrics + env_metrics).run()
            last_id = replay_buffer._get_last_id()  # pylint: disable=protected-access
            logging.info('Data saved after initial collection: %d steps',
                         last_id)
            if online_critic:
                last_id = online_replay_buffer._get_last_id()  # pylint: disable=protected-access
                logging.debug(
                    'Data saved in online buffer after initial collection: %d steps',
                    last_id)

        if run_eval:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='EvalMetrics',
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
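        # num_steps=2 yields pairs of adjacent time steps, which is what
        # tf_agent.train() needs to form one-step TD targets from the
        # (s_t, a_t, r_t, s_{t+1}) transition encoded in each pair.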
        iterator = iter(dataset)
        if online_critic:
            online_dataset = online_replay_buffer.as_dataset(
                num_parallel_calls=3,
                sample_batch_size=batch_size,
                num_steps=2).prefetch(3)
            online_iterator = iter(online_dataset)
            critic_metrics = [
                tf.keras.metrics.AUC(name='safety_critic_auc'),
                tf.keras.metrics.TruePositives(name='safety_critic_tp'),
                tf.keras.metrics.FalsePositives(name='safety_critic_fp'),
                tf.keras.metrics.TrueNegatives(name='safety_critic_tn'),
                tf.keras.metrics.FalseNegatives(name='safety_critic_fn'),
                tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc')
            ]
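            # These Keras metrics treat the safety critic as a binary failure
            # classifier; they are handed to tf_agent.train_sc below, where
            # (presumably) its predictions are scored against the
            # task-agnostic failure labels used as safe_rew.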

            @common.function
            def critic_train_step():
                """Builds critic training step."""
                start_time = time.time()
                experience, buf_info = next(online_iterator)
                if env_name.split('-')[0] in SAFETY_ENVS:
                    safe_rew = experience.observation['task_agn_rew'][:, 1]
                else:
                    safe_rew = misc.process_replay_buffer(online_replay_buffer,
                                                          as_tensor=True)
                    safe_rew = tf.gather(safe_rew,
                                         tf.squeeze(buf_info.ids),
                                         axis=1)
                ret = tf_agent.train_sc(experience,
                                        safe_rew,
                                        metrics=critic_metrics,
                                        weights=None)
                logging.debug('critic train step: {} sec'.format(time.time() -
                                                                 start_time))
                return ret

        @common.function
        def train_step():
            experience, _ = next(iterator)
            ret = tf_agent.train(experience)
            return ret

        if not early_termination_fn:
            early_termination_fn = lambda: False

        loss_diverged = False
        # How many consecutive steps was loss diverged for.
        loss_divergence_counter = 0
        mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss')

        if online_critic:
            logging.debug('starting safety critic pretraining')
            safety_eps = tf_agent._safe_policy._safety_threshold
            tf_agent._safe_policy._safety_threshold = 0.6
            resample_counter = online_collect_policy._resample_counter
            mean_resample_ac = tf.keras.metrics.Mean(
                name='mean_unsafe_ac_freq')
            # don't fine-tune safety critic
            if (global_step.numpy() == 0 and load_root_dir is None):
                for _ in range(train_sc_steps):
                    sc_loss, lambda_loss = critic_train_step()  # pylint: disable=unused-variable
            tf_agent._safe_policy._safety_threshold = safety_eps

        logging.debug('starting policy pretraining')
        while (global_step.numpy() <= num_global_steps
               and not early_termination_fn()):
            # Collect and train.
            start_time = time.time()
            current_step = global_step.numpy()

            if online_critic:
                mean_resample_ac(resample_counter.result())
                resample_counter.reset()
                if time_step is None or time_step.is_last():
                    resample_ac_freq = mean_resample_ac.result()
                    mean_resample_ac.reset_states()

            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            logging.debug('policy eval: {} sec'.format(time.time() -
                                                       start_time))

            train_time = time.time()
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
                mean_train_loss(train_loss.loss)
            if current_step == 0:
                logging.debug('train policy: {} sec'.format(time.time() -
                                                            train_time))

            if online_critic and current_step % train_sc_interval == 0:
                batch_time_step = sc_tf_env.reset()
                batch_policy_state = online_collect_policy.get_initial_state(
                    sc_tf_env.batch_size)
                online_driver.run(time_step=batch_time_step,
                                  policy_state=batch_policy_state)
                for _ in range(train_sc_steps):
                    sc_loss, lambda_loss = critic_train_step()  # pylint: disable=unused-variable

                metric_utils.log_metrics(sc_metrics)
                with sc_summary_writer.as_default():
                    for sc_metric in sc_metrics:
                        sc_metric.tf_summaries(train_step=global_step,
                                               step_metrics=sc_metrics[:2])
                    tf.compat.v2.summary.scalar(name='resample_ac_freq',
                                                data=resample_ac_freq,
                                                step=global_step)

            total_loss = mean_train_loss.result()
            mean_train_loss.reset_states()
            # Check for exploding losses.
            if (math.isnan(total_loss) or math.isinf(total_loss)
                    or total_loss > MAX_LOSS):
                loss_divergence_counter += 1
                if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
                    loss_diverged = True
                    logging.debug(
                        'Loss diverged, critic_loss: %s, actor_loss: %s, alpha_loss: %s',
                        train_loss.extra.critic_loss,
                        train_loss.extra.actor_loss,
                        train_loss.extra.alpha_loss)
                    break
            else:
                loss_divergence_counter = 0

            time_acc += time.time() - start_time

            if current_step % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             total_loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            train_results = []
            for train_metric in train_metrics:
                if isinstance(train_metric, (metrics.AverageEarlyFailureMetric,
                                             metrics.AverageFallenMetric,
                                             metrics.AverageSuccessMetric)):
                    # Plot failure as a fn of return
                    train_metric.tf_summaries(train_step=global_step,
                                              step_metrics=train_metrics[:3])
                else:
                    train_metric.tf_summaries(train_step=global_step,
                                              step_metrics=train_metrics[:2])
                train_results.append(
                    (train_metric.name, train_metric.result().numpy()))
            if env_metrics:
                for env_metric in env_metrics:
                    env_metric.tf_summaries(train_step=global_step,
                                            step_metrics=train_metrics[:2])
                    train_results.append(
                        (env_metric.name, env_metric.result().numpy()))
            if online_critic:
                for critic_metric in critic_metrics:
                    train_results.append(
                        (critic_metric.name, critic_metric.result().numpy()))
                    critic_metric.reset_states()
            if train_metrics_callback is not None:
                train_metrics_callback(collections.OrderedDict(train_results),
                                       global_step.numpy())

            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                if agent_class in SAFETY_AGENTS:
                    safety_critic_checkpointer.save(
                        global_step=global_step_val)

            if rb_checkpoint_interval and global_step_val % rb_checkpoint_interval == 0:
                if online_critic:
                    online_rb_checkpointer.save(global_step=global_step_val)
                rb_checkpointer.save(global_step=global_step_val)
            elif online_critic:
                clear_rb()

            if run_eval and global_step_val % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='EvalMetrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step_val)
                metric_utils.log_metrics(eval_metrics)

            if monitor and current_step % monitor_interval == 0:
                monitor_time_step = monitor_py_env.reset()
                monitor_policy_state = eval_policy.get_initial_state(1)
                ep_len = 0
                monitor_start = time.time()
                while not monitor_time_step.is_last():
                    monitor_action = eval_policy.action(
                        monitor_time_step, monitor_policy_state)
                    action, monitor_policy_state = monitor_action.action, monitor_action.state
                    monitor_time_step = monitor_py_env.step(action)
                    ep_len += 1
                monitor_py_env.reset()
                logging.debug(
                    'saved rollout at timestep {}, rollout length: {}, {} sec'.
                    format(global_step_val, ep_len,
                           time.time() - monitor_start))

            logging.debug('iteration time: {} sec'.format(time.time() -
                                                          start_time))

    if not keep_rb_checkpoint:
        misc.cleanup_checkpoints(rb_ckpt_dir)

    if loss_diverged:
        # Raise an error at the very end after the cleanup.
        raise ValueError('Loss diverged to {} at step {}, terminating.'.format(
            total_loss, global_step.numpy()))

    return total_loss
Ejemplo n.º 22
0
num_eval_episodes = 10  # @param {type:"integer"}
eval_interval = 1000  # @param {type:"integer"}

tempdir = '/content/drive/MyDrive/5242/Project' # @param {type:"string"}

#env_name = 'MountainCar-v0'
#train_py_env = suite_gym.load(env_name)
#eval_py_env = suite_gym.load(env_name)
#train_env = tf_py_environment.TFPyEnvironment(train_py_env)
#eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

train_py_env = gym_wrapper.GymWrapper(
      ChangeRewardMountainCarEnv(),
      discount=1,
      spec_dtype_map=None,
      auto_reset=True,
      render_kwargs=None,
  )
eval_py_env = gym_wrapper.GymWrapper(
      ChangeRewardMountainCarEnv(),
      discount=1,
      spec_dtype_map=None,
      auto_reset=True,
      render_kwargs=None,
  )
train_py_env = wrappers.TimeLimit(train_py_env, duration=200)
eval_py_env = wrappers.TimeLimit(eval_py_env, duration=200)
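# MountainCar-v0 normally caps episodes at 200 steps via gym's registry; since
# the env class is constructed directly here (not through gym.make), that cap
# is re-imposed explicitly with the TimeLimit wrappers above (assuming
# ChangeRewardMountainCarEnv keeps the standard MountainCar dynamics).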

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
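# A quick sanity check (a sketch, not part of the original snippet): assuming
# ChangeRewardMountainCarEnv keeps MountainCar's spaces, the wrapped envs
# should report a 2-D float observation spec and a 3-action discrete action
# spec, now exposed as batched TF specs by TFPyEnvironment.
print('observation spec:', train_env.observation_spec())
print('action spec:', train_env.action_spec())
print('time step spec:', train_env.time_step_spec())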
Ejemplo n.º 23
0
        time_step = tf_environment.reset()

        while not time_step.is_last():
            py_environment.render()
            action_step = policy.action(time_step)
            time_step = tf_environment.step(action_step.action)

    py_environment.close()
    tf_environment.close()


df = pd.read_csv(data_path)

env = gym_wrapper.GymWrapper(
    gym.make(env_name,
             df=df,
             window_size=window_size,
             frame_bound=(window_size, len(df))))

train_env = tf_py_environment.TFPyEnvironment(env)
eval_env = tf_py_environment.TFPyEnvironment(env)

q_net = q_network.QNetwork(train_env.observation_spec(),
                           train_env.action_spec(),
                           fc_layer_params=fc_layer_params)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

global_step = tf.compat.v1.train.get_or_create_global_step()

agent = dqn_agent.DqnAgent(train_env.time_step_spec(),
                           train_env.action_spec(),
                           q_network=q_net,
                           optimizer=optimizer,
                           train_step_counter=global_step)
agent.initialize()
Ejemplo n.º 24
0
  def test_obs_dtype(self):
    cartpole_env = gym.spec('CartPole-v1').make()
    env = gym_wrapper.GymWrapper(cartpole_env)
    time_step = env.reset()
    self.assertEqual(env.observation_spec().dtype,
                     time_step.observation.dtype)
Ejemplo n.º 25
0
def main(_):
    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)

    hparam_str = make_hparam_string(seed=FLAGS.seed, env_name=FLAGS.env_name)
    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))
    summary_writer.set_as_default()

    if FLAGS.d4rl:
        d4rl_env = gym.make(FLAGS.env_name)
        gym_spec = gym.spec(FLAGS.env_name)
        if gym_spec.max_episode_steps in [0, None]:  # Add TimeLimit wrapper.
            gym_env = time_limit.TimeLimit(d4rl_env, max_episode_steps=1000)
        else:
            gym_env = d4rl_env
        gym_env.seed(FLAGS.seed)
        env = tf_py_environment.TFPyEnvironment(
            gym_wrapper.GymWrapper(gym_env))

        behavior_dataset = D4rlDataset(
            d4rl_env,
            normalize_states=FLAGS.normalize_states,
            normalize_rewards=FLAGS.normalize_rewards,
            noise_scale=FLAGS.noise_scale,
            bootstrap=FLAGS.bootstrap)
    else:
        env = suite_mujoco.load(FLAGS.env_name)
        env.seed(FLAGS.seed)
        env = tf_py_environment.TFPyEnvironment(env)

        data_file_name = os.path.join(
            FLAGS.data_dir, FLAGS.env_name, '0',
            f'dualdice_{FLAGS.behavior_policy_std}.pckl')
        behavior_dataset = Dataset(data_file_name,
                                   FLAGS.num_trajectories,
                                   normalize_states=FLAGS.normalize_states,
                                   normalize_rewards=FLAGS.normalize_rewards,
                                   noise_scale=FLAGS.noise_scale,
                                   bootstrap=FLAGS.bootstrap)

    tf_dataset = behavior_dataset.with_uniform_sampling(
        FLAGS.sample_batch_size)
    tf_dataset_iter = iter(tf_dataset)

    if FLAGS.d4rl:
        with tf.io.gfile.GFile(FLAGS.d4rl_policy_filename, 'rb') as f:
            policy_weights = pickle.load(f)
        actor = utils.D4rlActor(env,
                                policy_weights,
                                is_dapg='dapg' in FLAGS.d4rl_policy_filename)
    else:
        actor = Actor(env.observation_spec().shape[0], env.action_spec())
        actor.load_weights(behavior_dataset.model_filename)

    policy_returns = utils.estimate_monte_carlo_returns(
        env, FLAGS.discount, actor, FLAGS.target_policy_std,
        FLAGS.num_mc_episodes)
    logging.info('Estimated Per-Step Average Returns=%f', policy_returns)

    if 'fqe' in FLAGS.algo or 'dr' in FLAGS.algo:
        model = QFitter(env.observation_spec().shape[0],
                        env.action_spec().shape[0], FLAGS.lr,
                        FLAGS.weight_decay, FLAGS.tau)
    elif 'mb' in FLAGS.algo:
        model = ModelBased(env.observation_spec().shape[0],
                           env.action_spec().shape[0],
                           learning_rate=FLAGS.lr,
                           weight_decay=FLAGS.weight_decay)
    elif 'dual_dice' in FLAGS.algo:
        model = DualDICE(env.observation_spec().shape[0],
                         env.action_spec().shape[0], FLAGS.weight_decay)
    if 'iw' in FLAGS.algo or 'dr' in FLAGS.algo:
        behavior = BehaviorCloning(env.observation_spec().shape[0],
                                   env.action_spec(), FLAGS.lr,
                                   FLAGS.weight_decay)

    @tf.function
    def get_target_actions(states):
        return actor(tf.cast(behavior_dataset.unnormalize_states(states),
                             env.observation_spec().dtype),
                     std=FLAGS.target_policy_std)[1]

    @tf.function
    def get_target_logprobs(states, actions):
        log_probs = actor(tf.cast(behavior_dataset.unnormalize_states(states),
                                  env.observation_spec().dtype),
                          actions=actions,
                          std=FLAGS.target_policy_std)[2]
        if tf.rank(log_probs) > 1:
            log_probs = tf.reduce_sum(log_probs, -1)
        return log_probs

    min_reward = tf.reduce_min(behavior_dataset.rewards)
    max_reward = tf.reduce_max(behavior_dataset.rewards)
    min_state = tf.reduce_min(behavior_dataset.states, 0)
    max_state = tf.reduce_max(behavior_dataset.states, 0)

    @tf.function
    def update_step():
        (states, actions, next_states, rewards, masks, weights,
         _) = next(tf_dataset_iter)
        initial_actions = get_target_actions(behavior_dataset.initial_states)
        next_actions = get_target_actions(next_states)

        if 'fqe' in FLAGS.algo or 'dr' in FLAGS.algo:
            model.update(states, actions, next_states, next_actions, rewards,
                         masks, weights, FLAGS.discount, min_reward,
                         max_reward)
        elif 'mb' in FLAGS.algo:
            model.update(states, actions, next_states, rewards, masks, weights)
        elif 'dual_dice' in FLAGS.algo:
            model.update(behavior_dataset.initial_states, initial_actions,
                         behavior_dataset.initial_weights, states, actions,
                         next_states, next_actions, masks, weights,
                         FLAGS.discount)

        if 'iw' in FLAGS.algo or 'dr' in FLAGS.algo:
            behavior.update(states, actions, weights)

    gc.collect()

    for i in tqdm.tqdm(range(FLAGS.num_updates), desc='Running Training'):
        update_step()

        if i % FLAGS.eval_interval == 0:
            if 'fqe' in FLAGS.algo:
                pred_returns = model.estimate_returns(
                    behavior_dataset.initial_states,
                    behavior_dataset.initial_weights, get_target_actions)
            elif 'mb' in FLAGS.algo:
                pred_returns = model.estimate_returns(
                    behavior_dataset.initial_states,
                    behavior_dataset.initial_weights, get_target_actions,
                    FLAGS.discount, min_reward, max_reward, min_state,
                    max_state)
            elif FLAGS.algo in ['dual_dice']:
                pred_returns, pred_ratio = model.estimate_returns(
                    iter(tf_dataset))

                tf.summary.scalar('train/pred ratio', pred_ratio, step=i)
            elif 'iw' in FLAGS.algo or 'dr' in FLAGS.algo:
                discount = FLAGS.discount
                _, behavior_log_probs = behavior(behavior_dataset.states,
                                                 behavior_dataset.actions)
                target_log_probs = get_target_logprobs(
                    behavior_dataset.states, behavior_dataset.actions)
                offset = 0.0
                rewards = behavior_dataset.rewards
                if 'dr' in FLAGS.algo:
                    # Doubly-robust is effectively the same as importance-weighting but
                    # transforming rewards at (s,a) to r(s,a) + gamma * V^pi(s') -
                    # Q^pi(s,a) and adding an offset to each trajectory equal to V^pi(s0).
                    offset = model.estimate_returns(
                        behavior_dataset.initial_states,
                        behavior_dataset.initial_weights, get_target_actions)
                    q_values = (model(behavior_dataset.states,
                                      behavior_dataset.actions) /
                                (1 - discount))
                    n_samples = 10
                    next_actions = [
                        get_target_actions(behavior_dataset.next_states)
                        for _ in range(n_samples)
                    ]
                    next_q_values = sum([
                        model(behavior_dataset.next_states, next_action) /
                        (1 - discount) for next_action in next_actions
                    ]) / n_samples
                    rewards = rewards + discount * next_q_values - q_values
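                    # In symbols, each reward is replaced by
                    #   r~(s, a) = r(s, a) + gamma * E_{a'~pi}[Q(s', a')] - Q(s, a),
                    # with the expectation approximated by the n_samples
                    # Monte-Carlo average above; the per-trajectory V(s0) term
                    # is the `offset` added back to the final estimate below.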

                # Now we compute the self-normalized importance weights.
                # Self-normalization happens over trajectories per-step, so we
                # restructure the dataset as [num_trajectories, num_steps].
                num_trajectories = len(behavior_dataset.initial_states)
                max_trajectory_length = np.max(behavior_dataset.steps) + 1
                trajectory_weights = behavior_dataset.initial_weights
                trajectory_starts = np.where(
                    np.equal(behavior_dataset.steps, 0))[0]

                batched_rewards = np.zeros(
                    [num_trajectories, max_trajectory_length])
                batched_masks = np.zeros(
                    [num_trajectories, max_trajectory_length])
                batched_log_probs = np.zeros(
                    [num_trajectories, max_trajectory_length])

                for traj_idx, traj_start in enumerate(trajectory_starts):
                    traj_end = (trajectory_starts[traj_idx + 1] if traj_idx +
                                1 < len(trajectory_starts) else len(rewards))
                    traj_length = traj_end - traj_start
                    batched_rewards[
                        traj_idx, :traj_length] = rewards[traj_start:traj_end]
                    batched_masks[traj_idx, :traj_length] = 1.
                    batched_log_probs[traj_idx, :traj_length] = (
                        -behavior_log_probs[traj_start:traj_end] +
                        target_log_probs[traj_start:traj_end])

                batched_weights = (
                    batched_masks *
                    (discount**np.arange(max_trajectory_length))[None, :])

                clipped_log_probs = np.clip(batched_log_probs, -6., 2.)
                cum_log_probs = batched_masks * np.cumsum(clipped_log_probs,
                                                          axis=1)
                cum_log_probs_offset = np.max(cum_log_probs, axis=0)
                cum_probs = np.exp(cum_log_probs -
                                   cum_log_probs_offset[None, :])
                avg_cum_probs = (
                    np.sum(cum_probs * trajectory_weights[:, None], axis=0) /
                    (1e-10 + np.sum(
                        batched_masks * trajectory_weights[:, None], axis=0)))
                norm_cum_probs = cum_probs / (1e-10 + avg_cum_probs[None, :])

                weighted_rewards = batched_weights * batched_rewards * norm_cum_probs
                trajectory_values = np.sum(weighted_rewards, axis=1)
                avg_trajectory_value = (
                    (1 - discount) *
                    np.sum(trajectory_values * trajectory_weights) /
                    np.sum(trajectory_weights))
                pred_returns = offset + avg_trajectory_value
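                # Putting it together: with w_i the initial weight of
                # trajectory i and rho_{i,t} its self-normalized cumulative
                # importance ratio,
                #   V_hat = offset + (1 - gamma) *
                #           sum_i w_i * sum_t gamma^t * rho_{i,t} * r~_{i,t} / sum_i w_i,
                # i.e. a per-step weighted importance-sampling estimate, made
                # doubly robust when the 'dr' reward correction above is applied.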

            pred_returns = behavior_dataset.unnormalize_rewards(pred_returns)

            tf.summary.scalar('train/pred returns', pred_returns, step=i)
            logging.info('pred returns=%f', pred_returns)

            tf.summary.scalar('train/true minus pred returns',
                              policy_returns - pred_returns,
                              step=i)
            logging.info('true minus pred returns=%f',
                         policy_returns - pred_returns)
Ejemplo n.º 26
0
  def test_render(self):
    cartpole_env = gym.spec('CartPole-v1').make()
    cartpole_env.render = mock.MagicMock()
    env = gym_wrapper.GymWrapper(cartpole_env)
    env.render()
    cartpole_env.render.assert_called_once()
Ejemplo n.º 27
0
def train_eval(
        load_root_dir,
        env_load_fn=None,
        gym_env_wrappers=[],
        monitor=False,
        env_name=None,
        agent_class=None,
        train_metrics_callback=None,
        # SacAgent args
        actor_fc_layers=(256, 256),
        critic_joint_fc_layers=(256, 256),
        # Safety Critic training args
        safety_critic_joint_fc_layers=None,
        safety_critic_lr=3e-4,
        safety_critic_bias_init_val=None,
        safety_critic_kernel_scale=None,
        n_envs=None,
        target_safety=0.2,
        fail_weight=None,
        # Params for train
        num_global_steps=10000,
        batch_size=256,
        # Params for eval
        run_eval=False,
        eval_metrics=[],
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        summary_interval=1000,
        monitor_interval=5000,
        summaries_flush_secs=10,
        debug_summaries=False,
        seed=None):

    if isinstance(agent_class, str):
        assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(
            agent_class)
        agent_class = ALGOS.get(agent_class)

    train_ckpt_dir = osp.join(load_root_dir, 'train')
    rb_ckpt_dir = osp.join(load_root_dir, 'train', 'replay_buffer')

    py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)

    if monitor:
        vid_path = os.path.join(load_root_dir, 'rollouts')
        monitor_env_wrapper = misc.monitor_freq(1, vid_path)
        monitor_env = gym.make(env_name)
        for wrapper in gym_env_wrappers:
            monitor_env = wrapper(monitor_env)
        monitor_env = monitor_env_wrapper(monitor_env)
        # auto_reset must be False to ensure Monitor works correctly
        monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

    if run_eval:
        eval_dir = os.path.join(load_root_dir, 'eval')
        n_envs = n_envs or num_eval_episodes
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(prefix='EvalMetrics',
                                           buffer_size=num_eval_episodes,
                                           batch_size=n_envs),
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='EvalMetrics',
                buffer_size=num_eval_episodes,
                batch_size=n_envs)
        ] + [
            tf_py_metric.TFPyMetric(m, name='EvalMetrics/{}'.format(m.name))
            for m in eval_metrics
        ]
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                lambda: env_load_fn(env_name,
                                    gym_env_wrappers=gym_env_wrappers)
            ] * n_envs))
        if seed:
            seeds = [seed * n_envs + i for i in range(n_envs)]
            try:
                eval_tf_env.pyenv.seed(seeds)
            except Exception:
                pass

    global_step = tf.compat.v1.train.get_or_create_global_step()

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=agents.normal_projection_net)

    critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)

    if agent_class in SAFETY_AGENTS:
        safety_critic_net = agents.CriticNetwork(
            (observation_spec, action_spec),
            joint_fc_layer_params=critic_joint_fc_layers)
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               safety_critic_network=safety_critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)
    else:
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)

    collect_data_spec = tf_agent.collect_data_spec
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, batch_size=1, max_length=1000000)
    replay_buffer = misc.load_rb_ckpt(rb_ckpt_dir, replay_buffer)

    tf_agent, _ = misc.load_agent_ckpt(train_ckpt_dir, tf_agent)
    if agent_class in SAFETY_AGENTS:
        target_safety = target_safety or tf_agent._target_safety
    loaded_train_steps = global_step.numpy()
    logging.info("Loaded agent from %s trained for %d steps", train_ckpt_dir,
                 loaded_train_steps)
    global_step.assign(0)
    tf.summary.experimental.set_step(global_step)

    thresholds = [target_safety, 0.5]
    sc_metrics = [
        tf.keras.metrics.AUC(name='safety_critic_auc'),
        tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                        threshold=0.5),
        tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                       thresholds=thresholds),
        tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                        thresholds=thresholds),
        tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                       thresholds=thresholds),
        tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                        thresholds=thresholds)
    ]
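    # Confusion-matrix metrics are evaluated at two decision thresholds, the
    # agent's target_safety level and 0.5, so the offline safety critic can be
    # inspected as a failure classifier at both operating points.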

    if seed:
        tf.compat.v1.set_random_seed(seed)

    summaries_flush_secs = 10
    timestamp = datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S')
    offline_train_dir = osp.join(train_ckpt_dir, 'offline', timestamp)
    config_saver = gin.tf.GinConfigSaverHook(offline_train_dir,
                                             summarize_config=True)
    tf.function(config_saver.after_create_session)()

    sc_summary_writer = tf.compat.v2.summary.create_file_writer(
        offline_train_dir, flush_millis=summaries_flush_secs * 1000)
    sc_summary_writer.set_as_default()

    if safety_critic_kernel_scale is not None:
        ki = tf.compat.v1.variance_scaling_initializer(
            scale=safety_critic_kernel_scale,
            mode='fan_in',
            distribution='truncated_normal')
    else:
        ki = tf.compat.v1.keras.initializers.VarianceScaling(
            scale=1. / 3., mode='fan_in', distribution='uniform')

    if safety_critic_bias_init_val is not None:
        bi = tf.constant_initializer(safety_critic_bias_init_val)
    else:
        bi = None
    sc_net_off = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=safety_critic_joint_fc_layers,
        kernel_initializer=ki,
        value_bias_initializer=bi,
        name='SafetyCriticOffline')
    sc_net_off.create_variables()
    target_sc_net_off = common.maybe_copy_target_network_with_checks(
        sc_net_off, None, 'TargetSafetyCriticNetwork')
    optimizer = tf.keras.optimizers.Adam(safety_critic_lr)
    sc_net_off_ckpt_dir = os.path.join(offline_train_dir, 'safety_critic')
    sc_checkpointer = common.Checkpointer(
        ckpt_dir=sc_net_off_ckpt_dir,
        safety_critic=sc_net_off,
        target_safety_critic=target_sc_net_off,
        optimizer=optimizer,
        global_step=global_step,
        max_to_keep=5)
    sc_checkpointer.initialize_or_restore()

    resample_counter = py_metrics.CounterMetric('ActionResampleCounter')
    eval_policy = agents.SafeActorPolicyRSVar(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_net,
        safety_critic_network=sc_net_off,
        safety_threshold=target_safety,
        resample_counter=resample_counter,
        training=True)
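    # Evaluation uses a safety-filtered actor policy: going by its
    # safety_threshold and resample_counter arguments, actions whose predicted
    # failure risk exceeds target_safety appear to be re-sampled, with
    # resample_counter recording how often that happens.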

    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       num_steps=2,
                                       sample_batch_size=batch_size //
                                       2).prefetch(3)
    data = iter(dataset)
    full_data = replay_buffer.gather_all()

    fail_mask = tf.cast(full_data.observation['task_agn_rew'], tf.bool)
    fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, fail_mask), full_data)
    init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, full_data.is_first()), full_data)
    before_fail_mask = tf.roll(fail_mask, [-1], axis=[1])
    after_init_mask = tf.roll(full_data.is_first(), [1], axis=[1])
    before_fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, before_fail_mask), full_data)
    after_init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, after_init_mask), full_data)

    filter_mask = tf.squeeze(tf.logical_or(before_fail_mask, fail_mask))
    filter_mask = tf.pad(
        filter_mask, [[0, replay_buffer._max_length - filter_mask.shape[0]]])
    n_failures = tf.reduce_sum(tf.cast(filter_mask, tf.int32)).numpy()

    failure_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec,
        batch_size=1,
        max_length=n_failures,
        dataset_window_shift=1)
    data_utils.copy_rb(replay_buffer, failure_buffer, filter_mask)

    sc_dataset_neg = failure_buffer.as_dataset(num_parallel_calls=3,
                                               sample_batch_size=batch_size //
                                               2,
                                               num_steps=2).prefetch(3)
    neg_data = iter(sc_dataset_neg)
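    # Transitions at (and one step before) failures were copied into
    # failure_buffer above; each safety-critic batch is then assembled from
    # batch_size // 2 ordinary samples and batch_size // 2 failure samples,
    # keeping the rare failure class well represented during training.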

    get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0]
    eval_sc = log_utils.eval_fn(before_fail_step, fail_step, init_step,
                                after_init_step, get_action)

    losses = []
    mean_loss = tf.keras.metrics.Mean(name='mean_ep_loss')
    target_update = train_utils.get_target_updater(sc_net_off,
                                                   target_sc_net_off)

    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        while global_step.numpy() < num_global_steps:
            pos_experience, _ = next(data)
            neg_experience, _ = next(neg_data)
            exp = data_utils.concat_batches(pos_experience, neg_experience,
                                            collect_data_spec)
            boundary_mask = tf.logical_not(exp.is_boundary()[:, 0])
            exp = nest_utils.fast_map_structure(
                lambda *x: tf.boolean_mask(*x, boundary_mask), exp)
            safe_rew = exp.observation['task_agn_rew'][:, 1]
            if fail_weight:
                weights = tf.where(tf.cast(safe_rew, tf.bool),
                                   fail_weight / 0.5, (1 - fail_weight) / 0.5)
            else:
                weights = None
            train_loss, sc_loss, lam_loss = train_step(
                exp,
                safe_rew,
                tf_agent,
                sc_net=sc_net_off,
                target_sc_net=target_sc_net_off,
                metrics=sc_metrics,
                weights=weights,
                target_safety=target_safety,
                optimizer=optimizer,
                target_update=target_update,
                debug_summaries=debug_summaries)
            global_step.assign_add(1)
            global_step_val = global_step.numpy()
            losses.append(
                (train_loss.numpy(), sc_loss.numpy(), lam_loss.numpy()))
            mean_loss(train_loss)
            with tf.name_scope('Losses'):
                tf.compat.v2.summary.scalar(name='sc_loss',
                                            data=sc_loss,
                                            step=global_step_val)
                tf.compat.v2.summary.scalar(name='lam_loss',
                                            data=lam_loss,
                                            step=global_step_val)
                if global_step_val % summary_interval == 0:
                    tf.compat.v2.summary.scalar(name=mean_loss.name,
                                                data=mean_loss.result(),
                                                step=global_step_val)
            if global_step_val % summary_interval == 0:
                with tf.name_scope('Metrics'):
                    for metric in sc_metrics:
                        if len(tf.squeeze(metric.result()).shape) == 0:
                            tf.compat.v2.summary.scalar(name=metric.name,
                                                        data=metric.result(),
                                                        step=global_step_val)
                        else:
                            fmt_str = '_{}'.format(thresholds[0])
                            tf.compat.v2.summary.scalar(
                                name=metric.name + fmt_str,
                                data=metric.result()[0],
                                step=global_step_val)
                            fmt_str = '_{}'.format(thresholds[1])
                            tf.compat.v2.summary.scalar(
                                name=metric.name + fmt_str,
                                data=metric.result()[1],
                                step=global_step_val)
                        metric.reset_states()
            if global_step_val % eval_interval == 0:
                eval_sc(sc_net_off, step=global_step_val)
                if run_eval:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix='EvalMetrics',
                    )
                    if train_metrics_callback is not None:
                        train_metrics_callback(results, global_step_val)
                    metric_utils.log_metrics(eval_metrics)
                    with eval_summary_writer.as_default():
                        for eval_metric in eval_metrics[2:]:
                            eval_metric.tf_summaries(
                                train_step=global_step,
                                step_metrics=eval_metrics[:2])
            if monitor and global_step_val % monitor_interval == 0:
                monitor_time_step = monitor_py_env.reset()
                monitor_policy_state = eval_policy.get_initial_state(1)
                ep_len = 0
                monitor_start = time.time()
                while not monitor_time_step.is_last():
                    monitor_action = eval_policy.action(
                        monitor_time_step, monitor_policy_state)
                    action, monitor_policy_state = monitor_action.action, monitor_action.state
                    monitor_time_step = monitor_py_env.step(action)
                    ep_len += 1
                logging.debug(
                    'saved rollout at timestep %d, rollout length: %d, %4.2f sec',
                    global_step_val, ep_len,
                    time.time() - monitor_start)

            if global_step_val % train_checkpoint_interval == 0:
                sc_checkpointer.save(global_step=global_step_val)
Ejemplo n.º 28
0
def train_eval(
    root_dir,
    load_root_dir=None,
    env_load_fn=None,
    gym_env_wrappers=[],
    monitor=False,
    env_name=None,
    agent_class=None,
    initial_collect_driver_class=None,
    collect_driver_class=None,
    online_driver_class=dynamic_episode_driver.DynamicEpisodeDriver,
    num_global_steps=1000000,
    rb_size=None,
    train_steps_per_iteration=1,
    train_metrics=None,
    eval_metrics=None,
    train_metrics_callback=None,
    # SacAgent args
    actor_fc_layers=(256, 256),
    critic_joint_fc_layers=(256, 256),
    # Safety Critic training args
    sc_rb_size=None,
    target_safety=None,
    train_sc_steps=10,
    train_sc_interval=1000,
    online_critic=False,
    n_envs=None,
    finetune_sc=False,
    pretraining=True,
    lambda_schedule_nsteps=0,
    lambda_initial=0.,
    lambda_final=1.,
    kstep_fail=0,
    # Ensemble Critic training args
    num_critics=None,
    critic_learning_rate=3e-4,
    # Wcpg Critic args
    critic_preprocessing_layer_size=256,
    # Params for train
    batch_size=256,
    # Params for eval
    run_eval=False,
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    keep_rb_checkpoint=False,
    log_interval=1000,
    summary_interval=1000,
    monitor_interval=5000,
    summaries_flush_secs=10,
    early_termination_fn=None,
    debug_summaries=False,
    seed=None,
    eager_debug=False,
    env_metric_factories=None,
    wandb=False):  # pylint: disable=unused-argument
  """Train and eval script for SQRL."""
  if isinstance(agent_class, str):
    assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(agent_class)
    agent_class = ALGOS.get(agent_class)
  n_envs = n_envs or num_eval_episodes
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')

  # =====================================================================#
  #  Setup summary metrics, file writers, and create env                 #
  # =====================================================================#
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
    train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  train_metrics = train_metrics or []
  eval_metrics = eval_metrics or []

  updating_sc = online_critic and (not load_root_dir or finetune_sc)
  logging.debug('updating safety critic: %s', updating_sc)

  if seed:
    tf.compat.v1.set_random_seed(seed)

  if agent_class in SAFETY_AGENTS:
    if online_critic:
      sc_tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
          [lambda: env_load_fn(env_name)] * n_envs
        ))
      if seed:
        seeds = [seed * n_envs + i for i in range(n_envs)]
        try:
          sc_tf_env.pyenv.seed(seeds)
      except Exception:
          pass

  if run_eval:
    eval_dir = os.path.join(root_dir, 'eval')
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
                     tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes, batch_size=n_envs),
                     tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes, batch_size=n_envs),
                   ] + [tf_py_metric.TFPyMetric(m) for m in eval_metrics]
    eval_tf_env = tf_py_environment.TFPyEnvironment(
      parallel_py_environment.ParallelPyEnvironment(
        [lambda: env_load_fn(env_name)] * n_envs
      ))
    if seed:
      try:
        for i, pyenv in enumerate(eval_tf_env.pyenv.envs):
          pyenv.seed(seed * n_envs + i)
      except Exception:
        pass
  elif 'Drunk' in env_name:
    # Just visualizes trajectories in drunk spider environment
    eval_tf_env = tf_py_environment.TFPyEnvironment(
      env_load_fn(env_name))
  else:
    eval_tf_env = None

  if monitor:
    vid_path = os.path.join(root_dir, 'rollouts')
    monitor_env_wrapper = misc.monitor_freq(1, vid_path)
    monitor_env = gym.make(env_name)
    for wrapper in gym_env_wrappers:
      monitor_env = wrapper(monitor_env)
    monitor_env = monitor_env_wrapper(monitor_env)
    # auto_reset must be False to ensure Monitor works correctly
    monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

  global_step = tf.compat.v1.train.get_or_create_global_step()

  with tf.summary.record_if(
          lambda: tf.math.equal(global_step % summary_interval, 0)):
    py_env = env_load_fn(env_name)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
    if seed:
      try:
        for i, pyenv in enumerate(tf_env.pyenv.envs):
          pyenv.seed(seed * n_envs + i)
      except Exception:  # seeding is best-effort; not all envs support it
        pass
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    logging.debug('obs spec: %s', observation_spec)
    logging.debug('action spec: %s', action_spec)

    # =====================================================================#
    #  Setup agent class                                                   #
    # =====================================================================#

    if agent_class == wcpg_agent.WcpgAgent:
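      # WCPG conditions both its distributional critic and its actor on a
      # sampled risk level alpha in [0, 1].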
      alpha_spec = tensor_spec.BoundedTensorSpec(shape=(1,), dtype=tf.float32, minimum=0., maximum=1.,
                                                 name='alpha')
      input_tensor_spec = (observation_spec, action_spec, alpha_spec)
      critic_net = agents.DistributionalCriticNetwork(
        input_tensor_spec, preprocessing_layer_size=critic_preprocessing_layer_size,
        joint_fc_layer_params=critic_joint_fc_layers)
      actor_net = agents.WcpgActorNetwork((observation_spec, alpha_spec), action_spec)
    else:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=agents.normal_projection_net)
      critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)

    if agent_class in SAFETY_AGENTS:
      logging.debug('Making SQRL agent')
      if lambda_schedule_nsteps > 0:
        lambda_update_every_nsteps = num_global_steps // lambda_schedule_nsteps
        step_size = (lambda_final - lambda_initial) / lambda_update_every_nsteps
        lambda_scheduler = lambda lam: common.periodically(
          body=lambda: tf.group(lam.assign(lam + step_size)),
          period=lambda_update_every_nsteps)
      else:
        lambda_scheduler = None
      safety_critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)
      ts = target_safety
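      # Classification metrics for the safety critic are reported at both the
      # target-safety threshold and at 0.5.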
      thresholds = [ts, 0.5]
      sc_metrics = [tf.keras.metrics.AUC(name='safety_critic_auc'),
                    tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                                   thresholds=thresholds),
                    tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                                    thresholds=thresholds),
                    tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                                   thresholds=thresholds),
                    tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                                    thresholds=thresholds),
                    tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                                    threshold=0.5)]
      tf_agent = agent_class(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        safety_critic_network=safety_critic_net,
        train_step_counter=global_step,
        debug_summaries=debug_summaries,
        safety_pretraining=pretraining,
        train_critic_online=online_critic,
        initial_log_lambda=lambda_initial,
        log_lambda=(lambda_scheduler is None),
        lambda_scheduler=lambda_scheduler)
    elif agent_class is ensemble_sac_agent.EnsembleSacAgent:
      critic_nets, critic_optimizers = [critic_net], [tf.keras.optimizers.Adam(critic_learning_rate)]
      for _ in range(num_critics - 1):
        critic_nets.append(agents.CriticNetwork((observation_spec, action_spec),
                                                joint_fc_layer_params=critic_joint_fc_layers))
        critic_optimizers.append(tf.keras.optimizers.Adam(critic_learning_rate))
      tf_agent = agent_class(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_networks=critic_nets,
        critic_optimizers=critic_optimizers,
        debug_summaries=debug_summaries
      )
    else:  # agent is either SacAgent or WcpgAgent
      logging.debug('critic input_tensor_spec: %s', critic_net.input_tensor_spec)
      tf_agent = agent_class(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        train_step_counter=global_step,
        debug_summaries=debug_summaries)

    tf_agent.initialize()

    # =====================================================================#
    #  Setup replay buffer                                                 #
    # =====================================================================#
    collect_data_spec = tf_agent.collect_data_spec

    logging.debug('Allocating replay buffer ...')
    # The buffer's add_batch is registered as a collect observer further below.
    rb_size = rb_size or 1000000
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      collect_data_spec,
      batch_size=1,
      max_length=rb_size)

    logging.debug('RB capacity: %i', replay_buffer.capacity)
    logging.debug('ReplayBuffer Collect data spec: %s', collect_data_spec)

    if agent_class in SAFETY_AGENTS:
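      # Separate buffer holding the transitions used to train the safety critic.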
      sc_rb_size = sc_rb_size or num_eval_episodes * 500
      sc_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, batch_size=1, max_length=sc_rb_size,
        dataset_window_shift=1)

    num_episodes = tf_metrics.NumberOfEpisodes()
    num_env_steps = tf_metrics.EnvironmentSteps()
    return_metric = tf_metrics.AverageReturnMetric(
      buffer_size=num_eval_episodes, batch_size=tf_env.batch_size)
    train_metrics = [
                      num_episodes, num_env_steps,
                      return_metric,
                      tf_metrics.AverageEpisodeLengthMetric(
                        buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
                    ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

    if 'Minitaur' in env_name and not pretraining:
      goal_vel = gin.query_parameter("%GOAL_VELOCITY")
      early_termination_fn = train_utils.MinitaurTerminationFn(
        speed_metric=train_metrics[-2], total_falls_metric=train_metrics[-3],
        env_steps_metric=num_env_steps, goal_speed=goal_vel)

    if env_metric_factories:
      for env_metric in env_metric_factories:
        train_metrics.append(tf_py_metric.TFPyMetric(env_metric(tf_env.pyenv.envs)))
        if run_eval:
          eval_metrics.append(env_metric(list(eval_tf_env.pyenv._envs)))

    # =====================================================================#
    #  Setup collect policies                                              #
    # =====================================================================#
    if not online_critic:
      eval_policy = tf_agent.policy
      collect_policy = tf_agent.collect_policy
      if not pretraining and agent_class in SAFETY_AGENTS:
        collect_policy = tf_agent.safe_policy
    else:
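      # With an online safety critic, pretraining collects with the standard
      # collect policy; afterwards the safety-filtered policy is used.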
      eval_policy = tf_agent.collect_policy if pretraining else tf_agent.safe_policy
      collect_policy = tf_agent.collect_policy if pretraining else tf_agent.safe_policy
      online_collect_policy = tf_agent.safe_policy  # if pretraining else tf_agent.collect_policy
      if pretraining:
        online_collect_policy._training = False

    if not load_root_dir:
      initial_collect_policy = random_tf_policy.RandomTFPolicy(time_step_spec, action_spec)
    else:
      initial_collect_policy = collect_policy
    if agent_class == wcpg_agent.WcpgAgent:
      initial_collect_policy = agents.WcpgPolicyWrapper(initial_collect_policy)

    # =====================================================================#
    #  Setup Checkpointing                                                 #
    # =====================================================================#
    train_checkpointer = common.Checkpointer(
      ckpt_dir=train_dir,
      agent=tf_agent,
      global_step=global_step,
      metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
      ckpt_dir=os.path.join(train_dir, 'policy'),
      policy=eval_policy,
      global_step=global_step)

    rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')
    rb_checkpointer = common.Checkpointer(
      ckpt_dir=rb_ckpt_dir, max_to_keep=1, replay_buffer=replay_buffer)

    if online_critic:
      online_rb_ckpt_dir = os.path.join(train_dir, 'online_replay_buffer')
      online_rb_checkpointer = common.Checkpointer(
        ckpt_dir=online_rb_ckpt_dir,
        max_to_keep=1,
        replay_buffer=sc_buffer)

    # loads agent, replay buffer, and online sc/buffer if online_critic
    if load_root_dir:
      load_root_dir = os.path.expanduser(load_root_dir)
      load_train_dir = os.path.join(load_root_dir, 'train')
      misc.load_agent_ckpt(load_train_dir, tf_agent)
      if len(os.listdir(os.path.join(load_train_dir, 'replay_buffer'))) > 1:
        load_rb_ckpt_dir = os.path.join(load_train_dir, 'replay_buffer')
        misc.load_rb_ckpt(load_rb_ckpt_dir, replay_buffer)
      if online_critic:
        load_online_sc_ckpt_dir = os.path.join(load_root_dir, 'sc')
        load_online_rb_ckpt_dir = os.path.join(load_train_dir,
                                               'online_replay_buffer')
        if osp.exists(load_online_rb_ckpt_dir):
          misc.load_rb_ckpt(load_online_rb_ckpt_dir, sc_buffer)
        if osp.exists(load_online_sc_ckpt_dir):
          misc.load_safety_critic_ckpt(load_online_sc_ckpt_dir,
                                       safety_critic_net)
      elif agent_class in SAFETY_AGENTS:
        offline_run = sorted(os.listdir(os.path.join(load_train_dir, 'offline')))[-1]
        load_sc_ckpt_dir = os.path.join(load_train_dir, 'offline',
                                        offline_run, 'safety_critic')
        if osp.exists(load_sc_ckpt_dir):
          sc_net_off = agents.CriticNetwork(
            (observation_spec, action_spec),
            joint_fc_layer_params=(512, 512),
            name='SafetyCriticOffline')
          sc_net_off.create_variables()
          target_sc_net_off = common.maybe_copy_target_network_with_checks(
            sc_net_off, None, 'TargetSafetyCriticNetwork')
          sc_optimizer = tf.keras.optimizers.Adam(critic_learning_rate)
          _ = misc.load_safety_critic_ckpt(
            load_sc_ckpt_dir, safety_critic_net=sc_net_off,
            target_safety_critic=target_sc_net_off,
            optimizer=sc_optimizer)
          tf_agent._safety_critic_network = sc_net_off
          tf_agent._target_safety_critic_network = target_sc_net_off
          tf_agent._safety_critic_optimizer = sc_optimizer
    else:
      train_checkpointer.initialize_or_restore()
      rb_checkpointer.initialize_or_restore()
      if online_critic:
        online_rb_checkpointer.initialize_or_restore()

    if agent_class in SAFETY_AGENTS:
      sc_dir = os.path.join(root_dir, 'sc')
      safety_critic_checkpointer = common.Checkpointer(
        ckpt_dir=sc_dir,
        safety_critic=tf_agent._safety_critic_network,
        # pylint: disable=protected-access
        target_safety_critic=tf_agent._target_safety_critic_network,
        optimizer=tf_agent._safety_critic_optimizer,
        global_step=global_step)

      if not (load_root_dir and not online_critic):
        safety_critic_checkpointer.initialize_or_restore()

    agent_observers = [replay_buffer.add_batch] + train_metrics
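    # The collect driver steps the training env with the collect policy and
    # forwards every step to the replay buffer and the train metrics.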
    collect_driver = collect_driver_class(
      tf_env, collect_policy, observers=agent_observers)
    collect_driver.run = common.function_in_tf1()(collect_driver.run)

    if online_critic:
      logging.debug('online driver class: %s', online_driver_class)
      online_agent_observers = [num_episodes, num_env_steps,
                                sc_buffer.add_batch]
      online_driver = online_driver_class(
        sc_tf_env, online_collect_policy, observers=online_agent_observers,
        num_episodes=num_eval_episodes)
      online_driver.run = common.function_in_tf1()(online_driver.run)

    if eager_debug:
      tf.config.experimental_run_functions_eagerly(True)
    else:
      config_saver = gin.tf.GinConfigSaverHook(train_dir, summarize_config=True)
      tf.function(config_saver.after_create_session)()

    if global_step == 0:
      logging.info('Performing initial collection ...')
      init_collect_observers = agent_observers
      if agent_class in SAFETY_AGENTS:
        init_collect_observers += [sc_buffer.add_batch]
      initial_collect_driver_class(
        tf_env,
        initial_collect_policy,
        observers=init_collect_observers).run()
      last_id = replay_buffer._get_last_id()  # pylint: disable=protected-access
      logging.info('Data saved after initial collection: %d steps', last_id)
      if agent_class in SAFETY_AGENTS:
        last_id = sc_buffer._get_last_id()  # pylint: disable=protected-access
        logging.debug('Data saved in sc_buffer after initial collection: %d steps', last_id)

    if run_eval:
      results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='EvalMetrics',
      )
      if train_metrics_callback is not None:
        train_metrics_callback(results, global_step.numpy())
      metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    train_step = train_utils.get_train_step(tf_agent, replay_buffer, batch_size)

    if agent_class in SAFETY_AGENTS:
      critic_train_step = train_utils.get_critic_train_step(
        tf_agent, replay_buffer, sc_buffer, batch_size=batch_size,
        updating_sc=updating_sc, metrics=sc_metrics)

    if early_termination_fn is None:
      early_termination_fn = lambda: False

    loss_diverged = False
    # Number of consecutive train steps for which the loss has diverged.
    loss_divergence_counter = 0
    mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss')

    if agent_class in SAFETY_AGENTS:
      resample_counter = collect_policy._resample_counter
      mean_resample_ac = tf.keras.metrics.Mean(name='mean_unsafe_ac_freq')
      sc_metrics.append(mean_resample_ac)

      if online_critic:
        logging.debug('starting safety critic pretraining')
        # Pretrain the safety critic only on a fresh run (skip when resuming).
        if global_step.numpy() == 0:
          for _ in range(train_sc_steps):
            sc_loss, lambda_loss = critic_train_step()
          critic_results = [('sc_loss', sc_loss.numpy()), ('lambda_loss', lambda_loss.numpy())]
          for critic_metric in sc_metrics:
            res = critic_metric.result().numpy()
            if not res.shape:
              critic_results.append((critic_metric.name, res))
            else:
              for r, thresh in zip(res, thresholds):
                name = '_'.join([critic_metric.name, str(thresh)])
                critic_results.append((name, r))
            critic_metric.reset_states()
          if train_metrics_callback:
            train_metrics_callback(collections.OrderedDict(critic_results),
                                   step=global_step.numpy())

    logging.debug('Starting main train loop...')
    curr_ep = []
    global_step_val = global_step.numpy()
    while global_step_val <= num_global_steps and not early_termination_fn():
      start_time = time.time()

      # MEASURE ACTION RESAMPLING FREQUENCY
      if agent_class in SAFETY_AGENTS:
        if pretraining and global_step_val == num_global_steps // 2:
          if online_critic:
            online_collect_policy._training = True
          collect_policy._training = True
        if online_critic or collect_policy._training:
          mean_resample_ac(resample_counter.result())
          resample_counter.reset()
          if time_step is None or time_step.is_last():
            resample_ac_freq = mean_resample_ac.result()
            mean_resample_ac.reset_states()
            tf.compat.v2.summary.scalar(
              name='resample_ac_freq', data=resample_ac_freq, step=global_step)

      # RUN COLLECTION
      time_step, policy_state = collect_driver.run(
        time_step=time_step,
        policy_state=policy_state,
      )

      # grab the last transition written by the collect driver
      traj = replay_buffer._data_table.read(replay_buffer._get_last_id() %
                                            replay_buffer._capacity)
      curr_ep.append(traj)

      if time_step.is_last():
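        # Episode finished: if it ended in failure (task-agnostic reward set),
        # add its (last k) transitions to the safety-critic buffer.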
        if agent_class in SAFETY_AGENTS:
          if time_step.observation['task_agn_rew']:
            if kstep_fail:
              # applies task agn rew. over last k steps
              for i, traj in enumerate(curr_ep[-kstep_fail:]):
                traj.observation['task_agn_rew'] = 1.
                sc_buffer.add_batch(traj)
            else:
              for traj in curr_ep:
                sc_buffer.add_batch(traj)
        curr_ep = []
        if agent_class == wcpg_agent.WcpgAgent:
          collect_policy._alpha = None  # reset WCPG alpha

      if (global_step_val + 1) % log_interval == 0:
        logging.debug('policy eval: %4.2f sec', time.time() - start_time)

      # PERFORMS TRAIN STEP ON ALGORITHM (OFF-POLICY)
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
        mean_train_loss(train_loss.loss)

      current_step = global_step.numpy()
      total_loss = mean_train_loss.result()
      mean_train_loss.reset_states()

      if train_metrics_callback and current_step % summary_interval == 0:
        train_metrics_callback(
          collections.OrderedDict([(k, v.numpy()) for k, v in
                                   train_loss.extra._asdict().items()]),
          step=current_step)
        train_metrics_callback(
          {'train_loss': total_loss.numpy()}, step=current_step)

      # TRAIN AND/OR EVAL SAFETY CRITIC
      if agent_class in SAFETY_AGENTS and current_step % train_sc_interval == 0:
        if online_critic:
          batch_time_step = sc_tf_env.reset()

          # run online critic training collect & update
          batch_policy_state = online_collect_policy.get_initial_state(
            sc_tf_env.batch_size)
          online_driver.run(time_step=batch_time_step,
                            policy_state=batch_policy_state)
        for _ in range(train_sc_steps):
          sc_loss, lambda_loss = critic_train_step()
        # log safety_critic loss results
        critic_results = [('sc_loss', sc_loss.numpy()),
                          ('lambda_loss', lambda_loss.numpy())]
        metric_utils.log_metrics(sc_metrics)
        for critic_metric in sc_metrics:
          res = critic_metric.result().numpy()
          if not res.shape:
            critic_results.append((critic_metric.name, res))
          else:
            for r, thresh in zip(res, thresholds):
              name = '_'.join([critic_metric.name, str(thresh)])
              critic_results.append((name, r))
          critic_metric.reset_states()
        if train_metrics_callback and current_step % summary_interval == 0:
          train_metrics_callback(collections.OrderedDict(critic_results),
                                 step=current_step)

      # Check for exploding losses.
      if (math.isnan(total_loss) or math.isinf(total_loss) or
              total_loss > MAX_LOSS):
        loss_divergence_counter += 1
        if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
          loss_diverged = True
          logging.info('Loss diverged, critic_loss: %s, actor_loss: %s',
                       train_loss.extra.critic_loss,
                       train_loss.extra.actor_loss)
          break
      else:
        loss_divergence_counter = 0

      time_acc += time.time() - start_time

      # LOGGING AND METRICS
      if current_step % log_interval == 0:
        metric_utils.log_metrics(train_metrics)
        logging.info('step = %d, loss = %f', current_step, total_loss)
        steps_per_sec = (current_step - timed_at_step) / time_acc
        logging.info('%4.2f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
          name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = current_step
        time_acc = 0

      train_results = []
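      # Emit summaries for each remaining train metric; failure/success
      # metrics are additionally plotted against episode return.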

      for metric in train_metrics[2:]:
        if isinstance(metric, (metrics.AverageEarlyFailureMetric,
                               metrics.AverageFallenMetric,
                               metrics.AverageSuccessMetric)):
          # Plot failure as a fn of return
          metric.tf_summaries(
            train_step=global_step, step_metrics=[num_env_steps, num_episodes,
                                                  return_metric])
        else:
          metric.tf_summaries(
            train_step=global_step, step_metrics=[num_env_steps, num_episodes])
        train_results.append((metric.name, metric.result().numpy()))

      if train_metrics_callback and current_step % summary_interval == 0:
        train_metrics_callback(collections.OrderedDict(train_results),
                               step=global_step.numpy())

      if current_step % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=current_step)

      if current_step % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=current_step)
        if agent_class in SAFETY_AGENTS:
          safety_critic_checkpointer.save(global_step=current_step)
          if online_critic:
            online_rb_checkpointer.save(global_step=current_step)

      if rb_checkpoint_interval and current_step % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=current_step)

      if wandb and current_step % eval_interval == 0 and 'Drunk' in env_name:
        misc.record_point_mass_episode(eval_tf_env, eval_policy, current_step)
        if online_critic:
          misc.record_point_mass_episode(eval_tf_env, tf_agent.safe_policy,
                                         current_step, 'safe-trajectory')

      if run_eval and current_step % eval_interval == 0:
        eval_results = metric_utils.eager_compute(
          eval_metrics,
          eval_tf_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          train_step=global_step,
          summary_writer=eval_summary_writer,
          summary_prefix='EvalMetrics',
        )
        if train_metrics_callback is not None:
          train_metrics_callback(eval_results, current_step)
        metric_utils.log_metrics(eval_metrics)

        with eval_summary_writer.as_default():
          for eval_metric in eval_metrics[2:]:
            eval_metric.tf_summaries(train_step=global_step,
                                     step_metrics=eval_metrics[:2])

      if monitor and current_step % monitor_interval == 0:
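        # Roll out one full episode in the video-recording Monitor env with
        # the eval policy.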
        monitor_time_step = monitor_py_env.reset()
        monitor_policy_state = eval_policy.get_initial_state(1)
        ep_len = 0
        monitor_start = time.time()
        while not monitor_time_step.is_last():
          monitor_action = eval_policy.action(monitor_time_step, monitor_policy_state)
          action, monitor_policy_state = monitor_action.action, monitor_action.state
          monitor_time_step = monitor_py_env.step(action)
          ep_len += 1
        logging.debug('saved rollout at timestep %d, rollout length: %d, %4.2f sec',
                      current_step, ep_len, time.time() - monitor_start)

      global_step_val = current_step

  if early_termination_fn():
    # Early termination: save any checkpoints not written at the last interval.
    if global_step_val % train_checkpoint_interval != 0:
      train_checkpointer.save(global_step=global_step_val)

    if global_step_val % policy_checkpoint_interval != 0:
      policy_checkpointer.save(global_step=global_step_val)
      if agent_class in SAFETY_AGENTS:
        safety_critic_checkpointer.save(global_step=global_step_val)
        if online_critic:
          online_rb_checkpointer.save(global_step=global_step_val)

    if rb_checkpoint_interval and global_step_val % rb_checkpoint_interval != 0:
      rb_checkpointer.save(global_step=global_step_val)

  if not keep_rb_checkpoint:
    misc.cleanup_checkpoints(rb_ckpt_dir)

  if loss_diverged:
    # Raise an error at the very end after the cleanup.
    raise ValueError('Loss diverged to {} at step {}, terminating.'.format(
      total_loss, global_step.numpy()))

  return total_loss
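
The keyword arguments listed at the top of this example are only the tail of train_eval's signature; the body also references earlier arguments such as root_dir, env_name, agent_class, num_global_steps, and env_load_fn. Below is a minimal, hypothetical launcher sketch, assuming those names are keyword arguments of train_eval and that 'sqrl' is a valid key in the ALGOS registry; the environment name and all values are illustrative only.

# Hypothetical launcher for train_eval (argument names inferred from the body above).
import os


def main():
  root_dir = os.path.expanduser('~/experiments/sqrl_run')  # illustrative path
  train_eval(
      root_dir=root_dir,
      env_name='MyEnv-v0',      # placeholder; must be loadable by env_load_fn
      agent_class='sqrl',       # assumed key in the ALGOS registry
      num_global_steps=100000,
      batch_size=256,
      run_eval=True,
      num_eval_episodes=10,
      eval_interval=1000,
      seed=0)


if __name__ == '__main__':
  main()
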
Ejemplo n.º 29
0
replay_buffer_capacity = 100000

fc_layer_params = (100, )

batch_size = 64
learning_rate = 1e-3
log_interval = 200

num_eval_episodes = 10
# eval_interval = 1000
eval_interval = 50

# env_name = 'CartPole-v0'
env_name = 'oscillator-v0'
train_env = envs.make(env_name)
train_py_env = gym_wrapper.GymWrapper(train_env)
eval_env = envs.make(env_name)
eval_py_env = gym_wrapper.GymWrapper(eval_env)

print('Observation spec: {}'.format(train_py_env.time_step_spec().observation))
print('Action spec: {}'.format(train_py_env.action_spec()))

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

q_net = q_network.QNetwork(train_env.observation_spec(),
                           train_env.action_spec(),
                           fc_layer_params=fc_layer_params)
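# q_net maps each observation to one Q-value per discrete action.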

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
# optimizer = tf.optimizers.Adam(learning_rate=learning_rate)