Example 1
    def train(self):
        plotter = Plotter()
        if self.plot:
            plotter.init_plot(self.env, self.policy)
        self.start_worker()
        self.init_opt()
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #{} | '.format(itr)):
                paths = self.sampler.obtain_samples(itr)
                samples_data = self.sampler.process_samples(itr, paths)
                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log('Saving snapshot...')
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params['algo'] = self
                if self.store_paths:
                    params['paths'] = samples_data['paths']
                snapshotter.save_itr_params(itr, params)
                logger.log('saved')
                logger.log(tabular)
                if self.plot:
                    plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input('Plotting evaluation run: Press Enter to '
                              'continue...')

        plotter.close()
        self.shutdown_worker()
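Distilled from the example above, the Plotter lifecycle is: init_plot once, update_plot after each policy update, close at the end. A minimal standalone sketch, assuming a garage-compatible env and policy; the loop bounds and the training step are placeholders, not part of the example:

# Hedged sketch of the Plotter lifecycle used in the train() loop above.
from garage.plotter import Plotter

plotter = Plotter()
plotter.init_plot(env, policy)                    # env/policy: placeholders for your objects
for itr in range(n_itr):                          # n_itr: placeholder iteration count
    optimize_policy(itr)                          # placeholder for the actual update step
    plotter.update_plot(policy, max_path_length)  # visualize a rollout with current params
plotter.close()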
Example 2
 def _start_worker(self):
     """Start Plotter and Sampler workers."""
     if self._plot:
         # pylint: disable=import-outside-toplevel
         from garage.plotter import Plotter
         self._plotter = Plotter()
         self._plotter.init_plot(self.get_env_copy(), self._algo.policy)
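This snippet shows only the start-up half; a matching shutdown would presumably close the plotter again. A hedged sketch of that counterpart, mirroring the attribute names above rather than quoting the library:

 def _shutdown_worker(self):
     """Shut down Plotter and Sampler workers (sketch, assumed counterpart)."""
     self._sampler.shutdown_worker()
     if self._plot:
         self._plotter.close()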
Example 3
 def __init__(self,
              env,
              policy,
              n_itr=500,
              max_path_length=500,
              discount=0.99,
              sigma0=1.,
              batch_size=None,
              plot=False,
              **kwargs):
     """
     :param n_itr: Number of iterations.
     :param max_path_length: Maximum length of a single rollout.
     :param batch_size: Number of samples drawn from trajectories of the
      parameter distribution; when set, n_samples is ignored.
     :param discount: Discount factor.
     :param plot: Plot an evaluation run after each iteration.
     :param sigma0: Initial standard deviation of the parameter distribution.
     :return:
     """
     Serializable.quick_init(self, locals())
     self.env = env
     self.policy = policy
     self.plot = plot
     self.sigma0 = sigma0
     self.discount = discount
     self.max_path_length = max_path_length
     self.n_itr = n_itr
     self.batch_size = batch_size
     self.plotter = Plotter()
Example 4
    def train(self):
        plotter = Plotter()
        if self.plot:
            plotter.init_plot(self.env, self.policy)
        self.start_worker()
        self.init_opt()
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                paths = self.sampler.obtain_samples(itr)
                samples_data = self.sampler.process_samples(itr, paths)
                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params["algo"] = self
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

        plotter.close()
        self.shutdown_worker()
Example 5
 def _start_worker(self):
     """Start Plotter and Sampler workers."""
     self._sampler.start_worker()
     if self._plot:
         # pylint: disable=import-outside-toplevel
         from garage.tf.plotter import Plotter
         self._plotter = Plotter(self.get_env_copy(),
                                 self._algo.policy,
                                 sess=tf.compat.v1.get_default_session())
         self._plotter.start()
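The TensorFlow variant differs from the plain garage.plotter.Plotter above: it takes the env, policy, and an active TF session at construction time, and is started explicitly as a worker. A hedged usage sketch outside the runner, with env and policy as placeholders; the session handling and the final close() call are assumptions, not taken from the snippet:

import tensorflow as tf
from garage.tf.plotter import Plotter

with tf.compat.v1.Session() as sess:
    plotter = Plotter(env, policy, sess=sess)  # env/policy: placeholder garage objects
    plotter.start()                            # launch the plotting worker
    ...                                        # training loop would go here
    plotter.close()                            # assumed shutdown, mirroring the plain Plotter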
Example 6
 def __init__(self,
              env,
              policy,
              n_itr=500,
              max_path_length=500,
              discount=0.99,
              init_std=1.,
              n_samples=100,
              batch_size=None,
              best_frac=0.05,
              extra_std=1.,
              extra_decay_time=100,
              plot=False,
              n_evals=1,
              play_every_itr=None,
              play_rollouts_num=3,
              **kwargs):
     """
     :param n_itr: Number of iterations.
     :param max_path_length: Maximum length of a single rollout.
     :param batch_size: Number of samples drawn from trajectories of the
      parameter distribution; when set, n_samples is ignored.
     :param discount: Discount factor.
     :param plot: Plot an evaluation run after each iteration.
     :param init_std: Initial standard deviation of the parameter distribution.
     :param extra_std: Decaying extra standard deviation added to the parameter
      distribution at each iteration.
     :param extra_decay_time: Number of iterations over which the extra std
      decays.
     :param n_samples: Number of samples drawn from the parameter distribution.
     :param best_frac: Fraction of the best sampled parameters kept.
     :param n_evals: Number of evaluations per sample from the parameter
      distribution; the returned score is the mean minus the standard error of
      the evaluations.
     :return:
     """
     Serializable.quick_init(self, locals())
     self.env = env
     self.policy = policy
     self.batch_size = batch_size
     self.plot = plot
     self.extra_decay_time = extra_decay_time
     self.extra_std = extra_std
     self.best_frac = best_frac
     self.n_samples = n_samples
     self.init_std = init_std
     self.discount = discount
     self.max_path_length = max_path_length
     self.n_itr = n_itr
     self.n_evals = n_evals
     self.plotter = Plotter()
     self.play_every_itr = play_every_itr
     self.play_rollouts_num = play_rollouts_num
Example 7
    def train(self):
        address = ('localhost', 6000)
        conn = Client(address)
        try:
            plotter = Plotter()
            if self.plot:
                plotter.init_plot(self.env, self.policy)
            conn.send(ExpLifecycle.START)
            self.start_worker()
            self.init_opt()
            for itr in range(self.current_itr, self.n_itr):
                with logger.prefix('itr #{} | '.format(itr)):
                    conn.send(ExpLifecycle.OBTAIN_SAMPLES)
                    paths = self.sampler.obtain_samples(itr)
                    conn.send(ExpLifecycle.PROCESS_SAMPLES)
                    samples_data = self.sampler.process_samples(itr, paths)
                    self.log_diagnostics(paths)
                    conn.send(ExpLifecycle.OPTIMIZE_POLICY)
                    self.optimize_policy(itr, samples_data)
                    logger.log('saving snapshot...')
                    params = self.get_itr_snapshot(itr, samples_data)
                    self.current_itr = itr + 1
                    params['algo'] = self
                    if self.store_paths:
                        params['paths'] = samples_data['paths']
                    snapshotter.save_itr_params(itr, params)
                    logger.log('saved')
                    logger.log(tabular)
                    if self.plot:
                        conn.send(ExpLifecycle.UPDATE_PLOT)
                        plotter.update_plot(self.policy, self.max_path_length)
                        if self.pause_for_plot:
                            input('Plotting evaluation run: Press Enter to '
                                  'continue...')

            conn.send(ExpLifecycle.SHUTDOWN)
            plotter.close()
            self.shutdown_worker()
        finally:
            conn.close()
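Beyond plotting, the variant above also reports its lifecycle over a multiprocessing connection bound to ('localhost', 6000). The receiving end would be a standard multiprocessing.connection.Listener; a minimal sketch of such a monitor process (the monitor itself is not part of the example, and it simply echoes whatever ExpLifecycle values train() sends):

from multiprocessing.connection import Listener

# Hedged sketch: accept the connection opened by train() and echo its messages.
with Listener(('localhost', 6000)) as listener:
    with listener.accept() as conn:
        try:
            while True:
                print('lifecycle:', conn.recv())  # one ExpLifecycle value per call
        except EOFError:
            pass  # train() closed its end of the connection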
Example 8
    def __init__(self,
                 env_spec,
                 policy,
                 baseline,
                 n_samples,
                 gae_lambda=1,
                 max_path_length=500,
                 discount=0.99,
                 init_std=1,
                 best_frac=0.05,
                 extra_std=1.,
                 extra_decay_time=100,
                 **kwargs):
        self.env_spec = env_spec
        self.policy = policy
        self.baseline = baseline
        self.n_samples = n_samples
        self.extra_decay_time = extra_decay_time
        self.extra_std = extra_std
        self.best_frac = best_frac
        self.init_std = init_std
        self.discount = discount
        self.gae_lambda = gae_lambda
        self.max_path_length = max_path_length
        self.plotter = Plotter()

        # epoch-wise
        self.cur_std = self.init_std
        self.cur_mean = self.policy.get_param_values()
        # epoch-cycle-wise
        self.cur_params = self.cur_mean
        self.all_returns = []
        self.all_params = [self.cur_mean.copy()]
        # fixed
        self.n_best = int(n_samples * best_frac)
        assert self.n_best >= 1, (
            'n_samples is too low. Make sure that n_samples * best_frac >= 1')
        self.n_params = len(self.cur_mean)
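The assertion above guards the elite-selection arithmetic: n_best = int(n_samples * best_frac) must be at least 1, otherwise there are no elite samples to fit the next distribution to. A small worked example of that check:

# Worked example of the n_best computation guarded by the assertion above.
n_samples, best_frac = 100, 0.05
print(int(n_samples * best_frac))  # 5 -> five elite samples are kept per epoch
n_samples = 10
print(int(n_samples * best_frac))  # int(0.5) == 0 -> this would trip the assertion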
Example 9
    def __init__(self,
                 env,
                 policy,
                 qf,
                 es,
                 batch_size=32,
                 n_epochs=200,
                 epoch_length=1000,
                 min_pool_size=10000,
                 replay_pool_size=1000000,
                 discount=0.99,
                 max_path_length=250,
                 qf_weight_decay=0.,
                 qf_update_method='adam',
                 qf_learning_rate=1e-3,
                 policy_weight_decay=0,
                 policy_update_method='adam',
                 policy_learning_rate=1e-4,
                 eval_samples=10000,
                 soft_target=True,
                 soft_target_tau=0.001,
                 n_updates_per_sample=1,
                 scale_reward=1.0,
                 include_horizon_terminal_transitions=False,
                 plot=False,
                 pause_for_plot=False):
        """
        :param env: Environment
        :param policy: Policy
        :param qf: Q function
        :param es: Exploration strategy
        :param batch_size: Number of samples for each minibatch.
        :param n_epochs: Number of epochs. Policy will be evaluated after each
         epoch.
        :param epoch_length: How many timesteps for each epoch.
        :param min_pool_size: Minimum size of the pool to start training.
        :param replay_pool_size: Size of the experience replay pool.
        :param discount: Discount factor for the cumulative return.
        :param max_path_length: Maximum length of a single rollout.
        :param qf_weight_decay: Weight decay factor for parameters of the Q
         function.
        :param qf_update_method: Online optimization method for training Q
         function.
        :param qf_learning_rate: Learning rate for training Q function.
        :param policy_weight_decay: Weight decay factor for parameters of the
         policy.
        :param policy_update_method: Online optimization method for training
         the policy.
        :param policy_learning_rate: Learning rate for training the policy.
        :param eval_samples: Number of samples (timesteps) for evaluating the
         policy.
        :param soft_target_tau: Interpolation parameter for doing the soft
         target update.
        :param n_updates_per_sample: Number of Q function and policy updates
         per new sample obtained.
        :param scale_reward: The scaling factor applied to the rewards when
         training.
        :param include_horizon_terminal_transitions: whether to include
         transitions with terminal=True because the horizon was reached. This
         might make the Q value back up less stable for certain tasks.
        :param plot: Whether to visualize the policy performance after each
         evaluation.
        :param pause_for_plot: Whether to pause before continuing when plotting.
        :return:
        """
        self.env = env
        self.input_dims = configure_dims(env)
        self.policy = policy
        self.qf = qf
        self.es = es
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.epoch_length = epoch_length
        self.min_pool_size = min_pool_size
        self.replay_pool_size = replay_pool_size
        self.discount = discount
        self.max_path_length = max_path_length
        self.qf_weight_decay = qf_weight_decay
        self.qf_update_method = \
            parse_update_method(
                qf_update_method,
                learning_rate=qf_learning_rate,
            )
        self.qf_learning_rate = qf_learning_rate
        self.policy_weight_decay = policy_weight_decay
        self.policy_update_method = \
            parse_update_method(
                policy_update_method,
                learning_rate=policy_learning_rate,
            )
        self.policy_learning_rate = policy_learning_rate
        self.eval_samples = eval_samples
        self.soft_target_tau = soft_target_tau
        self.n_updates_per_sample = n_updates_per_sample
        self.include_horizon_terminal_transitions = \
            include_horizon_terminal_transitions
        self.plot = plot
        self.pause_for_plot = pause_for_plot

        self.qf_loss_averages = []
        self.policy_surr_averages = []
        self.q_averages = []
        self.y_averages = []
        self.paths = []
        self.es_path_returns = []
        self.paths_samples_cnt = 0

        self.scale_reward = scale_reward

        self.opt_info = None
        self.plotter = Plotter()
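For context on soft_target_tau in the example above: with soft_target=True it acts as the interpolation coefficient of the usual DDPG-style soft target-network update, a general convention rather than something quoted from this snippet. A minimal sketch of that rule:

# Generic soft target update that soft_target_tau parameterizes (sketch, not library source).
def soft_update(target_params, source_params, tau=0.001):
    """Move each target parameter a fraction tau toward its source parameter."""
    return [tau * s + (1.0 - tau) * t for t, s in zip(target_params, source_params)]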