Example #1
File: tdm.py Project: zhackzey/leap
    def update_sampler_and_rollout_function(self):
        self.eval_sampler = MultigoalSimplePathSampler(
            env=self.env,
            policy=self.eval_policy,
            qf=self.qf1,
            max_samples=self.num_steps_per_eval,
            max_path_length=self.max_path_length,
            tau_sampling_function=self._sample_max_tau_for_rollout,
            cycle_taus_for_rollout=self.cycle_taus_for_rollout,
            render=self.render_during_eval,
            observation_key=self.observation_key,
            desired_goal_key=self.desired_goal_key,
        )

        # The rl_algorithm constructor is called before the tdm's, so the
        # rollout function must be initialized here rather than by
        # overriding the function.
        from railrl.samplers.rollout_functions import create_rollout_function, tdm_rollout

        self.train_rollout_function = create_rollout_function(
            tdm_rollout,
            init_tau=self.max_tau_for_rollout,
            cycle_tau=self.cycle_taus_for_rollout,
            decrement_tau=self.cycle_taus_for_rollout,
            observation_key=self.observation_key,
            desired_goal_key=self.desired_goal_key,
        )
        self.eval_rollout_function = self.train_rollout_function
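The comment above explains why the rollout function is bound here: the rl_algorithm constructor runs first, so the TDM cannot simply override a setup method. `create_rollout_function` behaves like a partial-application helper that pre-binds the TDM-specific keyword arguments. A minimal sketch of that idea (the helper body below is an assumption, not the railrl implementation):

import functools

def create_rollout_function(rollout_fn, **fixed_kwargs):
    # Pre-bind the TDM-specific arguments (tau handling, observation and
    # goal dict keys) so the training loop can call the returned function
    # with only env/policy/path-length arguments.
    return functools.partial(rollout_fn, **fixed_kwargs)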
Example #2
    def __init__(self,
                 env,
                 qf,
                 policy,
                 exploration_policy,
                 replay_buffer,
                 obs_normalizer: TorchNormalizer = None,
                 goal_normalizer: TorchNormalizer = None,
                 eval_sampler=None,
                 epsilon=1e-4,
                 num_steps_per_eval=1000,
                 max_path_length=1000,
                 terminate_when_goal_reached=False,
                 pre_activation_weight=1.,
                 **kwargs):
        assert isinstance(replay_buffer, HerReplayBuffer)
        assert eval_sampler is None
        super().__init__(env,
                         qf,
                         policy,
                         exploration_policy,
                         replay_buffer=replay_buffer,
                         eval_sampler=eval_sampler,
                         num_steps_per_eval=num_steps_per_eval,
                         max_path_length=max_path_length,
                         **kwargs)
        self.obs_normalizer = obs_normalizer
        self.goal_normalizer = goal_normalizer
        self.eval_sampler = MultigoalSimplePathSampler(
            env=env,
            policy=self.target_policy,
            max_samples=num_steps_per_eval,
            max_path_length=max_path_length,
            tau_sampling_function=self._sample_tau_for_rollout,
            goal_sampling_function=self._sample_goal_for_rollout,
            cycle_taus_for_rollout=False,
        )
        self.epsilon = epsilon
        assert self.qf_weight_decay == 0
        assert self.residual_gradient_weight == 0
        self.terminate_when_goal_reached = terminate_when_goal_reached
        self.pre_activation_weight = pre_activation_weight
        self._current_path_goal = None
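This constructor stores `obs_normalizer` and `goal_normalizer` (`TorchNormalizer`), but the snippet does not show how they are used. For context, such normalizers typically keep running mean/std statistics and rescale observations and goals before they are fed to the networks. The class below is a hypothetical illustration of that pattern, not railrl's `TorchNormalizer` API:

import numpy as np

class RunningNormalizer:
    # Hypothetical running mean/variance normalizer (illustration only).
    def __init__(self, size, epsilon=1e-4):
        self.mean = np.zeros(size)
        self.var = np.ones(size)
        self.count = epsilon  # avoids division by zero before any update

    def update(self, batch):
        # Merge batch statistics into the running estimates.
        batch = np.atleast_2d(batch)
        n = batch.shape[0]
        delta = batch.mean(axis=0) - self.mean
        total = self.count + n
        self.mean += delta * n / total
        self.var = (self.var * self.count + batch.var(axis=0) * n
                    + delta ** 2 * self.count * n / total) / total
        self.count = total

    def normalize(self, v):
        return (v - self.mean) / np.sqrt(self.var + 1e-8)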
Example #3
    def __init__(
        self,
        max_tau=10,
        epoch_max_tau_schedule=None,
        sample_train_goals_from='replay_buffer',
        sample_rollout_goals_from='environment',
        cycle_taus_for_rollout=False,
    ):
        """
        :param max_tau: Maximum tau (planning horizon) to train with.
        :param epoch_max_tau_schedule: A schedule for the maximum planning
        horizon tau.
        :param sample_train_goals_from: Sampling strategy for goals used in
        training. Can be one of the following strings:
            - environment: Sample from the environment
            - replay_buffer: Sample from the replay_buffer
            - her: Sample from a HER-based replay_buffer
        :param sample_rollout_goals_from: Sampling strategy for goals used
        during rollout. Can be one of the following strings:
            - environment: Sample from the environment
            - replay_buffer: Sample from the replay_buffer
            - fixed: Do not resample the goal; just use the one in the
            environment.
        :param cycle_taus_for_rollout: Decrement the tau passed into the
        policy during rollout?
        """
        assert sample_train_goals_from in [
            'environment', 'replay_buffer', 'her'
        ]
        assert sample_rollout_goals_from in [
            'environment', 'replay_buffer', 'fixed'
        ]
        if epoch_max_tau_schedule is None:
            epoch_max_tau_schedule = ConstantSchedule(max_tau)

        self.max_tau = max_tau
        self.epoch_max_tau_schedule = epoch_max_tau_schedule
        self.sample_train_goals_from = sample_train_goals_from
        self.sample_rollout_goals_from = sample_rollout_goals_from
        self.cycle_taus_for_rollout = cycle_taus_for_rollout
        self._current_path_goal = None
        self._rollout_tau = self.max_tau

        self.policy = MakeUniversal(self.policy)
        self.eval_policy = MakeUniversal(self.eval_policy)
        self.exploration_policy = MakeUniversal(self.exploration_policy)
        self.eval_sampler = MultigoalSimplePathSampler(
            env=self.env,
            policy=self.eval_policy,
            max_samples=self.num_steps_per_eval,
            max_path_length=self.max_path_length,
            discount_sampling_function=self._sample_max_tau_for_rollout,
            goal_sampling_function=self._sample_goal_for_rollout,
            cycle_taus_for_rollout=self.cycle_taus_for_rollout,
        )
        if self.collection_mode == 'online-parallel':
            # TODO(murtaza): What happens to the eval env?
            # see `eval_sampler` definition above.

            self.training_env = RemoteRolloutEnv(
                env=self.env,
                policy=self.eval_policy,
                exploration_policy=self.exploration_policy,
                max_path_length=self.max_path_length,
                normalize_env=self.normalize_env,
                rollout_function=self.rollout,
            )
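`epoch_max_tau_schedule` defaults to `ConstantSchedule(max_tau)`, so the planning horizon stays fixed across epochs unless a schedule object is supplied. A minimal sketch of such schedule objects, assuming a `get_value(t)` interface (the real class lives in the railrl codebase; `LinearSchedule` below is purely hypothetical):

class ConstantSchedule:
    # Returns the same max tau at every epoch.
    def __init__(self, value):
        self._value = value

    def get_value(self, t):
        return self._value


class LinearSchedule:
    # Hypothetical alternative: anneal max tau from `start` to `end`
    # over `num_epochs` epochs.
    def __init__(self, start, end, num_epochs):
        self.start, self.end, self.num_epochs = start, end, num_epochs

    def get_value(self, t):
        frac = min(t / self.num_epochs, 1.0)
        return self.start + frac * (self.end - self.start)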
Example #4
    def __init__(
        self,
        max_tau=10,
        epoch_max_tau_schedule=None,
        vectorized=True,
        cycle_taus_for_rollout=True,
        dense_rewards=False,
        finite_horizon=True,
        tau_sample_strategy='uniform',
        goal_reached_epsilon=1e-3,
        terminate_when_goal_reached=False,
        truncated_geom_factor=2.,
        square_distance=False,
        goal_weights=None,
        normalize_distance=False,
        observation_key=None,
        desired_goal_key=None,
    ):
        """

        :param max_tau: Maximum tau (planning horizon) to train with.
        :param epoch_max_tau_schedule: A schedule for the maximum planning
        horizon tau.
        :param vectorized: Train the QF in vectorized form?
        :param cycle_taus_for_rollout: Decrement the tau passed into the
        policy during rollout?
        :param dense_rewards: If True, always give rewards. Otherwise,
        only give rewards when the episode terminates.
        :param finite_horizon: If True, use a finite horizon formulation:
        give the time as input to the Q-function and terminate.
        :param tau_sample_strategy: Sampling strategy for taus used
        during training. Can be one of the following strings:
            - no_resampling: Do not resample the tau. Use the one from rollout.
            - uniform: Sample uniformly from [0, max_tau]
            - truncated_geometric: Sample from a truncated geometric
            distribution, truncated at max_tau.
            - all_valid: Always use all values from 0 to max_tau
        :param goal_reached_epsilon: Epsilon used to determine if the goal
        has been reached. Used by the `indicator` version of `reward_type` and
        when `terminate_when_goal_reached` is True.
        :param terminate_when_goal_reached: Do you terminate when you have
        reached the goal?
        :param goal_weights: None or the weights for the different goal
        dimensions. These weights are used to compute the distances to the goal.
        """
        assert tau_sample_strategy in [
            'no_resampling',
            'uniform',
            'truncated_geometric',
            'all_valid',
        ]
        if epoch_max_tau_schedule is None:
            epoch_max_tau_schedule = ConstantSchedule(max_tau)

        if not finite_horizon:
            max_tau = 0
            epoch_max_tau_schedule = ConstantSchedule(max_tau)
            cycle_taus_for_rollout = False

        self.max_tau = max_tau
        self.epoch_max_tau_schedule = epoch_max_tau_schedule
        self.vectorized = vectorized
        self.cycle_taus_for_rollout = cycle_taus_for_rollout
        self.dense_rewards = dense_rewards
        self.finite_horizon = finite_horizon
        self.tau_sample_strategy = tau_sample_strategy
        self.goal_reached_epsilon = goal_reached_epsilon
        self.terminate_when_goal_reached = terminate_when_goal_reached
        self.square_distance = square_distance
        self._rollout_tau = np.array([self.max_tau])
        self.truncated_geom_factor = float(truncated_geom_factor)
        self.goal_weights = goal_weights
        if self.goal_weights is not None:
            # In case they were passed in as (e.g.) tuples or list
            self.goal_weights = np.array(self.goal_weights)
            assert self.goal_weights.size == self.env.goal_dim
        self.normalize_distance = normalize_distance

        self.observation_key = observation_key
        self.desired_goal_key = desired_goal_key

        self.eval_sampler = MultigoalSimplePathSampler(
            env=self.env,
            policy=self.eval_policy,
            max_samples=self.num_steps_per_eval,
            max_path_length=self.max_path_length,
            tau_sampling_function=self._sample_max_tau_for_rollout,
            cycle_taus_for_rollout=self.cycle_taus_for_rollout,
            render=self.render_during_eval,
            observation_key=self.observation_key,
            desired_goal_key=self.desired_goal_key,
        )
        self.pretrain_obs = None
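The docstring above lists a `truncated_geometric` strategy for sampling taus during training. Below is a sketch of sampling from a geometric distribution truncated at `max_tau` by renormalizing over the finite support; treating `truncated_geom_factor / max_tau` as the success probability is an assumption for illustration, not necessarily what this class does:

import numpy as np

def sample_truncated_geometric(p, max_tau, size):
    # Geometric(p) restricted to {0, 1, ..., max_tau} and renormalized,
    # so small taus are drawn more often than large ones.
    support = np.arange(max_tau + 1)
    pmf = p * (1 - p) ** support
    pmf /= pmf.sum()
    return np.random.choice(support, size=size, p=pmf)

# Example: max_tau=10, truncated_geom_factor=2.
taus = sample_truncated_geometric(p=2.0 / 10, max_tau=10, size=5)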
Example #5
    def __init__(self,
                 env,
                 exploration_policy,
                 beta_q,
                 beta_q2,
                 beta_v,
                 policy,
                 train_with='both',
                 goal_reached_epsilon=1e-3,
                 learning_rate=1e-3,
                 prioritized_replay=False,
                 always_reset_env=True,
                 finite_horizon=False,
                 max_num_steps_left=0,
                 flip_training_period=100,
                 train_simultaneously=True,
                 policy_and_target_update_period=2,
                 target_policy_noise=0.2,
                 target_policy_noise_clip=0.5,
                 soft_target_tau=0.005,
                 per_beta_schedule=None,
                 **kwargs):
        self.train_simultaneously = train_simultaneously
        assert train_with in ['both', 'off_policy', 'on_policy']
        super().__init__(env, exploration_policy, **kwargs)
        self.eval_sampler = MultigoalSimplePathSampler(
            env=self.env,
            policy=self.eval_policy,
            max_samples=self.num_steps_per_eval,
            max_path_length=self.max_path_length,
            tau_sampling_function=lambda: 0,
            goal_sampling_function=self.env.sample_goal_for_rollout,
            cycle_taus_for_rollout=False,
            render=self.render_during_eval)
        self.goal_reached_epsilon = goal_reached_epsilon
        self.beta_q = beta_q
        self.beta_v = beta_v
        self.beta_q2 = beta_q2
        self.target_beta_q = self.beta_q.copy()
        self.target_beta_q2 = self.beta_q2.copy()
        self.train_with = train_with
        self.policy = policy
        self.target_policy = policy
        self.prioritized_replay = prioritized_replay
        self.flip_training_period = flip_training_period

        self.always_reset_env = always_reset_env
        self.finite_horizon = finite_horizon
        self.max_num_steps_left = max_num_steps_left
        assert max_num_steps_left >= 0

        self.policy_and_target_update_period = policy_and_target_update_period
        self.target_policy_noise = target_policy_noise
        self.target_policy_noise_clip = target_policy_noise_clip
        self.soft_target_tau = soft_target_tau
        if per_beta_schedule is None:
            per_beta_schedule = ConstantSchedule(1.0)
        self.per_beta_schedule = per_beta_schedule

        self.beta_q_optimizer = Adam(self.beta_q.parameters(),
                                     lr=learning_rate)
        self.beta_q2_optimizer = Adam(self.beta_q2.parameters(),
                                      lr=learning_rate)
        self.beta_v_optimizer = Adam(self.beta_v.parameters(),
                                     lr=learning_rate)
        self.policy_optimizer = Adam(
            self.policy.parameters(),
            lr=learning_rate,
        )
        self.q_criterion = nn.BCELoss()
        self.v_criterion = nn.BCELoss()

        # For the multitask env
        self._rollout_goal = None

        self.extra_eval_statistics = OrderedDict()
        for key_not_always_updated in [
                'Policy Gradient Norms',
                'Beta Q Gradient Norms',
                'dQ/da',
        ]:
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    key_not_always_updated,
                    np.zeros(2),
                ))

        self.training_policy = False

        # For debugging
        self.train_batches = []
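This constructor keeps target copies of `beta_q` and `beta_q2` and stores `soft_target_tau` and `policy_and_target_update_period`, which points to TD3-style delayed, Polyak-averaged target updates. A minimal sketch of that standard update (not copied from this class):

def soft_update(source, target, tau):
    # Polyak averaging: target <- tau * source + (1 - tau) * target.
    for src_param, tgt_param in zip(source.parameters(), target.parameters()):
        tgt_param.data.copy_(
            tau * src_param.data + (1.0 - tau) * tgt_param.data
        )

# Called every `policy_and_target_update_period` training steps, e.g.:
# soft_update(self.beta_q, self.target_beta_q, self.soft_target_tau)
# soft_update(self.beta_q2, self.target_beta_q2, self.soft_target_tau)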