コード例 #1
0
ファイル: cma_es.py プロジェクト: zhmz90/rllab
    def train(self):
        """Optimize the policy parameters with CMA-ES.

        Each iteration asks the CMA-ES optimizer for candidate parameter
        vectors, evaluates every candidate with ``sample_return`` on the
        worker pool, and tells the optimizer the negated first entry of each
        rollout's ``'returns'`` (the optimizer minimizes). When
        ``self.batch_size`` is set, candidates are requested in chunks until
        the summed lengths of the collected ``'returns'`` reach
        ``batch_size``. Diagnostics are logged and a snapshot is saved every
        iteration. After the loop the policy is set to ``es.result()[0]``.
        """
        cur_std = self.sigma0
        cur_mean = self.policy.get_param_values()
        es = cma_es_lib.CMAEvolutionStrategy(cur_mean, cur_std)

        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            plotter.init_plot(self.env, self.policy)

        itr = 0
        while itr < self.n_itr and not es.stop():

            if self.batch_size is None:
                # Sample from multivariate normal distribution.
                xs = np.asarray(es.ask())
                # For each sample, do a rollout.
                infos = stateful_pool.singleton_pool.run_map(
                    sample_return,
                    [(x, self.max_path_length, self.discount) for x in xs])
            else:
                cum_len = 0
                infos = []
                xss = []
                done = False
                while not done:
                    sbs = stateful_pool.singleton_pool.n_parallel * 2
                    # Sample from multivariate normal distribution.
                    # You want to ask for sbs samples here.
                    xs = np.asarray(es.ask(sbs))

                    xss.append(xs)
                    sinfos = stateful_pool.singleton_pool.run_map(
                        sample_return,
                        [(x, self.max_path_length, self.discount) for x in xs])
                    for info in sinfos:
                        infos.append(info)
                        cum_len += len(info['returns'])
                        if cum_len >= self.batch_size:
                            # Enough data collected: merge every chunk of
                            # candidates asked so far and stop sampling.
                            xs = np.concatenate(xss)
                            done = True
                            break

            # Evaluate fitness of samples (negative as it is minimization
            # problem).
            fs = -np.array([info['returns'][0] for info in infos])
            # When batching, you could have generated too many samples compared
            # to the actual evaluations. So we cut it off in this case.
            xs = xs[:len(fs)]
            # Update CMA-ES params based on sample fitness.
            es.tell(xs, fs)

            logger.push_prefix('itr #%d | ' % itr)
            logger.record_tabular('Iteration', itr)
            # NOTE(review): cur_std is never updated inside the loop, so this
            # always reports the initial sigma0 -- confirm whether the
            # optimizer's current step size was intended instead.
            logger.record_tabular('CurStdMean', np.mean(cur_std))
            undiscounted_returns = np.array(
                [info['undiscounted_return'] for info in infos])
            logger.record_tabular('AverageReturn',
                                  np.mean(undiscounted_returns))
            # Standard deviation (previously logged the mean by mistake).
            logger.record_tabular('StdReturn', np.std(undiscounted_returns))
            logger.record_tabular('MaxReturn', np.max(undiscounted_returns))
            logger.record_tabular('MinReturn', np.min(undiscounted_returns))
            # fs holds negated returns, so undo the negation to report the
            # actual average of the first 'returns' entries.
            logger.record_tabular('AverageDiscountedReturn', np.mean(-fs))
            logger.record_tabular(
                'AvgTrajLen',
                np.mean([len(info['returns']) for info in infos]))
            self.env.log_diagnostics(infos)
            self.policy.log_diagnostics(infos)

            logger.save_itr_params(
                itr, dict(
                    itr=itr,
                    policy=self.policy,
                    env=self.env,
                ))
            logger.dump_tabular(with_prefix=False)
            if self.plot:
                plotter.update_plot(self.policy, self.max_path_length)
            logger.pop_prefix()
            # Update iteration.
            itr += 1

        # Set final params.
        self.policy.set_param_values(es.result()[0])
コード例 #2
0
ファイル: batch_polopt.py プロジェクト: hal2001/rllab
 def update_plot(self):
     """Redraw the plot for the current policy; no-op when plotting is off."""
     if not self.plot:
         return
     plotter.update_plot(self.policy, self.max_path_length)
コード例 #3
0
ファイル: batch_polopt.py プロジェクト: andrewliao11/rllab
 def update_plot(self):
     """Ask the plotter to refresh using the current policy, if enabled."""
     plotting_enabled = self.plot
     if not plotting_enabled:
         return
     plotter.update_plot(self.policy, self.max_path_length)
コード例 #4
0
ファイル: cma_es.py プロジェクト: QuantCollective/maml_rl
    def train(self):
        """Optimize the policy parameters with CMA-ES.

        Each iteration asks the CMA-ES optimizer for candidate parameter
        vectors, evaluates every candidate with ``sample_return`` on the
        worker pool, and tells the optimizer the negated first entry of each
        rollout's ``'returns'`` (the optimizer minimizes). When
        ``self.batch_size`` is set, candidates are requested in chunks until
        the summed lengths of the collected ``'returns'`` reach
        ``batch_size``. Diagnostics are logged and a snapshot is saved every
        iteration. After the loop the policy is set to ``es.result()[0]`` and
        the parallel sampler task is terminated.
        """
        cur_std = self.sigma0
        cur_mean = self.policy.get_param_values()
        es = cma_es_lib.CMAEvolutionStrategy(
            cur_mean, cur_std)

        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            plotter.init_plot(self.env, self.policy)

        itr = 0
        while itr < self.n_itr and not es.stop():

            if self.batch_size is None:
                # Sample from multivariate normal distribution.
                xs = np.asarray(es.ask())
                # For each sample, do a rollout.
                infos = stateful_pool.singleton_pool.run_map(
                    sample_return,
                    [(x, self.max_path_length, self.discount) for x in xs])
            else:
                cum_len = 0
                infos = []
                xss = []
                done = False
                while not done:
                    sbs = stateful_pool.singleton_pool.n_parallel * 2
                    # Sample from multivariate normal distribution.
                    # You want to ask for sbs samples here.
                    xs = np.asarray(es.ask(sbs))

                    xss.append(xs)
                    sinfos = stateful_pool.singleton_pool.run_map(
                        sample_return,
                        [(x, self.max_path_length, self.discount) for x in xs])
                    for info in sinfos:
                        infos.append(info)
                        cum_len += len(info['returns'])
                        if cum_len >= self.batch_size:
                            # Enough data collected: merge every chunk of
                            # candidates asked so far and stop sampling.
                            xs = np.concatenate(xss)
                            done = True
                            break

            # Evaluate fitness of samples (negative as it is minimization
            # problem).
            fs = -np.array([info['returns'][0] for info in infos])
            # When batching, you could have generated too many samples compared
            # to the actual evaluations. So we cut it off in this case.
            xs = xs[:len(fs)]
            # Update CMA-ES params based on sample fitness.
            es.tell(xs, fs)

            logger.push_prefix('itr #%d | ' % itr)
            logger.record_tabular('Iteration', itr)
            # NOTE(review): cur_std is never updated inside the loop, so this
            # always reports the initial sigma0 -- confirm whether the
            # optimizer's current step size was intended instead.
            logger.record_tabular('CurStdMean', np.mean(cur_std))
            undiscounted_returns = np.array(
                [info['undiscounted_return'] for info in infos])
            logger.record_tabular('AverageReturn',
                                  np.mean(undiscounted_returns))
            # Standard deviation (previously logged the mean by mistake).
            logger.record_tabular('StdReturn',
                                  np.std(undiscounted_returns))
            logger.record_tabular('MaxReturn',
                                  np.max(undiscounted_returns))
            logger.record_tabular('MinReturn',
                                  np.min(undiscounted_returns))
            # fs holds negated returns, so undo the negation to report the
            # actual average of the first 'returns' entries.
            logger.record_tabular('AverageDiscountedReturn',
                                  np.mean(-fs))
            logger.record_tabular('AvgTrajLen',
                                  np.mean([len(info['returns']) for info in infos]))
            self.env.log_diagnostics(infos)
            self.policy.log_diagnostics(infos)

            logger.save_itr_params(itr, dict(
                itr=itr,
                policy=self.policy,
                env=self.env,
            ))
            logger.dump_tabular(with_prefix=False)
            if self.plot:
                plotter.update_plot(self.policy, self.max_path_length)
            logger.pop_prefix()
            # Update iteration.
            itr += 1

        # Set final params.
        self.policy.set_param_values(es.result()[0])
        parallel_sampler.terminate_task()
コード例 #5
0
ファイル: cem.py プロジェクト: QuantCollective/maml_rl
    def train(self):
        """Optimize the policy parameters with the cross-entropy method.

        Each iteration collects rollouts from the worker pool via
        ``_worker_rollout_policy`` around the current mean/std, keeps the
        ``n_best`` parameter vectors with the highest first ``'returns'``
        entry, and refits the mean and per-dimension std to them. The policy
        is set to the best vector of the iteration, diagnostics are logged,
        and a snapshot is saved. The parallel sampler task is terminated
        after the last iteration.
        """
        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            plotter.init_plot(self.env, self.policy)

        cur_std = self.init_std
        cur_mean = self.policy.get_param_values()
        # K = cur_mean.size
        n_best = max(1, int(self.n_samples * self.best_frac))

        for itr in range(self.n_itr):
            # sample around the current distribution
            # Extra variance is scaled by a factor that decreases linearly
            # with itr and is clamped at 0.
            extra_var_mult = max(1.0 - itr / self.extra_decay_time, 0)
            sample_std = np.sqrt(np.square(cur_std) + np.square(self.extra_std) * extra_var_mult)
            # Collect a fixed number of paths, or enough paths to reach
            # batch_size when one is configured.
            if self.batch_size is None:
                criterion = 'paths'
                threshold = self.n_samples
            else:
                criterion = 'samples'
                threshold = self.batch_size
            infos = stateful_pool.singleton_pool.run_collect(
                _worker_rollout_policy,
                threshold=threshold,
                args=(dict(cur_mean=cur_mean,
                          sample_std=sample_std,
                          max_path_length=self.max_path_length,
                          discount=self.discount,
                          criterion=criterion),)
            )
            # Each collected item is indexed as (parameter vector, path).
            xs = np.asarray([info[0] for info in infos])
            paths = [info[1] for info in infos]

            fs = np.array([path['returns'][0] for path in paths])
            print((xs.shape, fs.shape))
            # Sort descending by fitness and keep the elite candidates.
            best_inds = (-fs).argsort()[:n_best]
            best_xs = xs[best_inds]
            cur_mean = best_xs.mean(axis=0)
            cur_std = best_xs.std(axis=0)
            best_x = best_xs[0]
            logger.push_prefix('itr #%d | ' % itr)
            logger.record_tabular('Iteration', itr)
            logger.record_tabular('CurStdMean', np.mean(cur_std))
            undiscounted_returns = np.array([path['undiscounted_return'] for path in paths])
            logger.record_tabular('AverageReturn',
                                  np.mean(undiscounted_returns))
            # Standard deviation (previously logged the mean by mistake).
            logger.record_tabular('StdReturn',
                                  np.std(undiscounted_returns))
            logger.record_tabular('MaxReturn',
                                  np.max(undiscounted_returns))
            logger.record_tabular('MinReturn',
                                  np.min(undiscounted_returns))
            logger.record_tabular('AverageDiscountedReturn',
                                  np.mean(fs))
            logger.record_tabular('AvgTrajLen',
                                  np.mean([len(path['returns']) for path in paths]))
            logger.record_tabular('NumTrajs',
                                  len(paths))
            self.policy.set_param_values(best_x)
            self.env.log_diagnostics(paths)
            self.policy.log_diagnostics(paths)
            logger.save_itr_params(itr, dict(
                itr=itr,
                policy=self.policy,
                env=self.env,
                cur_mean=cur_mean,
                cur_std=cur_std,
            ))
            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                plotter.update_plot(self.policy, self.max_path_length)
        parallel_sampler.terminate_task()
コード例 #6
0
ファイル: cem.py プロジェクト: ermongroup/MetaIRL
    def train(self):
        """Run cross-entropy-method optimization of the policy parameters.

        Every iteration collects rollouts from the worker pool around the
        current mean/std, refits the mean and per-dimension std to the
        ``n_best`` candidates with the highest first ``'returns'`` entry,
        sets the policy to the best candidate, logs diagnostics and saves a
        snapshot. The parallel sampler task is terminated at the end.
        """
        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            plotter.init_plot(self.env, self.policy)

        cur_std = self.init_std
        cur_mean = self.policy.get_param_values()
        n_best = max(1, int(self.n_samples * self.best_frac))

        for itr in range(self.n_itr):
            # Widen the sampling distribution by an extra variance term whose
            # multiplier decreases with itr and is clamped at 0.
            extra_var_mult = max(1.0 - itr / self.extra_decay_time, 0)
            extra_var = np.square(self.extra_std) * extra_var_mult
            sample_std = np.sqrt(np.square(cur_std) + extra_var)

            # Stop collecting on a path count or on a sample count.
            if self.batch_size is None:
                criterion, threshold = 'paths', self.n_samples
            else:
                criterion, threshold = 'samples', self.batch_size

            worker_kwargs = dict(
                cur_mean=cur_mean,
                sample_std=sample_std,
                max_path_length=self.max_path_length,
                discount=self.discount,
                criterion=criterion,
                n_evals=self.n_evals,
            )
            infos = stateful_pool.singleton_pool.run_collect(
                _worker_rollout_policy,
                threshold=threshold,
                args=(worker_kwargs, ),
            )
            # Each collected item is indexed as (parameter vector, path).
            xs = np.asarray([item[0] for item in infos])
            paths = [item[1] for item in infos]

            fs = np.array([p['returns'][0] for p in paths])
            print((xs.shape, fs.shape))

            # Highest-fitness candidates first; refit the distribution to them.
            elite_order = (-fs).argsort()
            best_xs = xs[elite_order[:n_best]]
            cur_mean = best_xs.mean(axis=0)
            cur_std = best_xs.std(axis=0)
            best_x = best_xs[0]

            record = logger.record_tabular
            logger.push_prefix('itr #%d | ' % itr)
            record('Iteration', itr)
            record('CurStdMean', np.mean(cur_std))
            undiscounted_returns = np.array(
                [p['undiscounted_return'] for p in paths])
            record('AverageReturn', np.mean(undiscounted_returns))
            record('StdReturn', np.std(undiscounted_returns))
            record('MaxReturn', np.max(undiscounted_returns))
            record('MinReturn', np.min(undiscounted_returns))
            record('AverageDiscountedReturn', np.mean(fs))
            record('NumTrajs', len(paths))

            # Flatten paths for the case n_evals > 1.
            paths = list(chain.from_iterable(p['full_paths'] for p in paths))
            traj_lengths = [len(p['returns']) for p in paths]
            record('AvgTrajLen', np.mean(traj_lengths))

            self.policy.set_param_values(best_x)
            self.env.log_diagnostics(paths)
            self.policy.log_diagnostics(paths)

            snapshot = dict(
                itr=itr,
                policy=self.policy,
                env=self.env,
                cur_mean=cur_mean,
                cur_std=cur_std,
            )
            logger.save_itr_params(itr, snapshot)
            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                plotter.update_plot(self.policy, self.max_path_length)

        parallel_sampler.terminate_task()
コード例 #7
0
ファイル: dqn.py プロジェクト: leduckhc/rllab
 def update_plot(self):
     """Refresh the plot with the current policy over one epoch length."""
     if not self.plot:
         return
     plotter.update_plot(self.policy, self.epoch_length)
コード例 #8
0
    def update_plot(self):
        """Redraw the plot for the current policy; no-op when plotting is off."""
        if not self.plot:
            return
        plotter.update_plot(self.policy, self.max_path_length)


#constrained training function for HCMPC with:
#dyn_model: the dynamical model that is trained in the Model-Based Stage
#policy: the initial Model-Free policy
#rau:exploration probability to achieve a faster learning performance
#logdir_HCMPC: directory for HCMPC results
#file_number: log file number
#For My local machine---------> "/home/hendawy/Desktop/2DOF_Robotic_Arm_withSphereObstacle/Rr"
# def constrained_train(self,dyn_model,logdir_HCMPC,file_number,policy=None,rau=None):
#     self.start_worker()
#     self.init_opt()
#     logz.configure_output_dir(logdir_HCMPC,file_number)
#     for itr in range(self.current_itr, self.n_itr):
#             paths = self.sampler.obtain_samples(dyn_model,itr,policy,rau)
#             samples_data,analysis_data = self.sampler.process_samples(itr, paths)
#             self.log_diagnostics(paths)
#             optimization_data=self.optimize_policy(itr, samples_data)
#             if(rau<=0.05):
#                 rau=0
#             else:
#                 rau-=0.02
#             logz.log_tabular('Iteration', analysis_data["Iteration"])
#             # In terms of true environment reward of your rolled out trajectory using the MPC controller
#             logz.log_tabular('AverageDiscountedReturn',analysis_data["AverageDiscountedReturn"])
#             logz.log_tabular('AverageReturns', analysis_data["AverageReturn"])
#             logz.log_tabular('violation_cost', np.mean(samples_data["violation_cost"]))
#             logz.log_tabular('boundary_violation_cost', np.mean(samples_data["boundary_violation_cost"]))
#             logz.log_tabular('success_rate', samples_data["success_rate"])
#             logz.log_tabular('successful_AverageReturn', np.mean(samples_data["successful_AverageReturn"]))
#             logz.log_tabular('ExplainedVariance', analysis_data["ExplainedVariance"])
#             logz.log_tabular('NumTrajs', analysis_data["NumTrajs"])
#             logz.log_tabular('Entropy', analysis_data["Entropy"])
#             logz.log_tabular('Perplexity', analysis_data["Perplexity"])
#             logz.log_tabular('StdReturn', analysis_data["StdReturn"])
#             logz.log_tabular('MaxReturn', analysis_data["MaxReturn"])
#             logz.log_tabular('MinReturn', analysis_data["MinReturn"])
#             logz.log_tabular('LossBefore', optimization_data["LossBefore"])
#             logz.log_tabular('LossAfter', optimization_data["LossAfter"])
#             logz.log_tabular('MeanKLBefore', optimization_data["MeanKLBefore"])
#             logz.log_tabular('MeanKL', optimization_data["MeanKL"])
#             logz.log_tabular('dLoss', optimization_data["dLoss"])
#             logz.dump_tabular()
#             logger.log("saving snapshot...")
#             params = self.get_itr_snapshot(itr, samples_data)
#             self.current_itr = itr + 1
#             params["algo"] = self
#             if self.store_paths:
#                 params["paths"] = samples_data["paths"]
#             logger.save_itr_params(itr, params)
#             logger.log("saved")
#             logger.dump_tabular(with_prefix=False)
#             if self.plot:
#                 self.update_plot()
#                 if self.pause_for_plot:
#                     input("Plotting evaluation run: Press Enter to "
#                               "continue...")

#     self.shutdown_worker()
コード例 #9
0
 def update_plot(self, name):
     """Refresh the plot for the policy stored under ``name``, if enabled."""
     if not self.plot:
         return
     selected_policy = self.policies[name]
     plotter.update_plot(selected_policy, self.max_path_length)