Example #1
    def train(self):
        self.start_worker()
        self.init_opt()
        # logz.configure_output_dir("/home/hendawy/Desktop/HumonoidwithTRPOandMappingtojointangles\Trial1",13000)
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                paths = self.sampler.obtain_samples(itr)
                samples_data = self.sampler.process_samples(itr, paths)
                self.log_diagnostics(paths)
                optimization_data = self.optimize_policy(itr, samples_data)
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params["algo"] = self
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                opt_data = self.get_itr_snapshot(itr, samples_data)
                values = opt_data["policy"].get_param_values()
                print("Saving learned TF nn model parameters.")
                with open(
                        '/home/hendawy/Desktop/HumonoidwithTRPOandMappingtojointangles/Trial1/saver%i.save'
                        % itr, 'wb') as f:
                    cPickle.dump(values, f, protocol=cPickle.HIGHEST_PROTOCOL)

                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

        self.shutdown_worker()
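Example #1 pickles the policy's flattened parameter vector (obtained via get_param_values()) to a saver<itr>.save file on every iteration. As a minimal sketch of the reverse direction, assuming only the rllab-style set_param_values() interface that also appears in Examples #18 and #22, such a file could be loaded back into a policy as follows; restore_policy_params is a hypothetical helper, not part of rllab:

import pickle  # the loop above uses cPickle, the Python 2 equivalent


def restore_policy_params(policy, snapshot_path):
    # Load the flat parameter vector written by cPickle.dump(...) above.
    # A file written by Python 2 cPickle may need encoding='latin1' here.
    with open(snapshot_path, 'rb') as f:
        values = pickle.load(f)
    # Push the vector back into the policy (rllab parameterized interface).
    policy.set_param_values(values)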
Example #2
    def train(self):
        self.start_worker()
        self.init_opt()
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                paths = self.obtain_samples(itr)
                samples_data = self.process_samples(itr, paths)
                # TOFIX(eugene) why is this here, and can I get rid of it?
                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                # FIXME(eugene) uncomment this line
                #params["algo"] = self
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

        self.shutdown_worker()
Example #3
    def train(self):
        self.start_worker()
        self.init_opt()
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                paths_n = self.obtain_samples(itr)
                samples_data_n = self.process_samples(itr, paths_n)
                self.log_diagnostics(paths_n)
                # print('Average Return:', np.mean([sum(path["rewards"]) for paths in paths_n for path in paths]))
                self.optimize_agents_policies(itr, samples_data_n)
                if itr and (itr % self.average_period == 0):
                    self.optimize_policy()
                    logger.log("saving snapshot...")
                    params = self.get_itr_snapshot(itr)
                    self.current_itr = itr + 1
                    params["algo"] = self
                    logger.save_itr_params(itr, params)
                    logger.log("saved")
                    logger.dump_tabular(with_prefix=False)
        if (self.n_itr - 1) % self.average_period != 0:
            self.optimize_policy()
            logger.log("saving snapshot...")
            params = self.get_itr_snapshot(self.n_itr - 1)
            params["algo"] = self
            logger.save_itr_params(self.n_itr - 1, params)
            logger.log("saved")
            logger.dump_tabular(with_prefix=False)

        self.shutdown_worker()

        return np.mean(
            [sum(path["rewards"]) for paths in paths_n for path in paths])
Example #4
def custom_train(algo, sess=None):
    created_session = True if (sess is None) else False
    if sess is None:
        sess = tf.Session()
        sess.__enter__()

    rollout_cache = []
    initialize_uninitialized(sess)
    algo.start_worker()
    start_time = time.time()
    for itr in range(algo.start_itr, algo.n_itr):
        itr_start_time = time.time()
        with logger.prefix('itr #%d | ' % itr):
            logger.log("Obtaining samples...")
            paths = algo.obtain_samples(itr)
            logger.log("Processing samples...")
            samples_data = algo.process_samples(itr, paths)
            logger.log("Logging diagnostics...")
            algo.log_diagnostics(paths)
            logger.log("Optimizing policy...")
            algo.optimize_policy(itr, samples_data)
            logger.log("Saving snapshot...")
            params = algo.get_itr_snapshot(itr, samples_data)  # , **kwargs)
            if algo.store_paths:
                params["paths"] = samples_data["paths"]
            logger.save_itr_params(itr, params)
            logger.log("Saved")
            logger.record_tabular('Time', time.time() - start_time)
            logger.record_tabular('ItrTime', time.time() - itr_start_time)
            logger.dump_tabular(with_prefix=False)

    algo.shutdown_worker()
    if created_session:
        sess.close()
Example #5
    def log_diagnostics(self, itr):
        self.pbar.stop()
        self.save_itr_snapshot(itr)

        logger.record_tabular('Iteration', itr)
        logger.record_tabular('CumCompletedTrajs', self._cum_completed_trajs)
        logger.record_tabular('CumCompletedSteps', self._cum_completed_steps)
        logger.record_tabular('CumTotalSteps', (itr + 1) * self._sample_size)
        logger.record_tabular('NewCompletedTrajs', self._new_completed_trajs)
        logger.record_tabular('StepsInTrajWindow',
                              sum(info["Length"] for info in self._traj_infos))

        if self._log_entropy:
            logger.record_tabular('Entropy', self._entropy_ema)
            logger.record_tabular('Perplexity', self._perplexity_ema)

        self._log_infos()

        new_time = time.time()
        samples_per_second = \
            (self._log_interval_itrs * self._sample_size) / (new_time - self._last_time)
        logger.record_tabular('CumTime (s)', new_time - self._start_time)
        logger.record_tabular('SamplesPerSecond', samples_per_second)
        self._last_time = new_time
        logger.dump_tabular(with_prefix=False)

        self._new_completed_trajs = 0
        if itr < self._n_itr - 1:
            logger.log('optimizing over {} iterations'.format(
                self._log_interval_itrs))
            self.pbar = ProgBarCounter(self._log_interval_itrs)
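Almost every snippet on this page uses the same rllab-style tabular logging pattern: logger.record_tabular(key, value) queues values for the current row, logger.dump_tabular() writes the row and clears it, and the logger.prefix(...) context manager tags the free-form logger.log(...) messages. A minimal usage sketch outside any algorithm class, assuming the original rllab import path rllab.misc.logger (forks such as garage relocate this module):

import time

from rllab.misc import logger  # adjust the import to your rllab fork

start_time = time.time()
for itr in range(3):
    with logger.prefix('itr #%d | ' % itr):
        logger.log("collecting samples...")          # prefixed text log line
        logger.record_tabular('Iteration', itr)      # queue tabular values
        logger.record_tabular('Time', time.time() - start_time)
        logger.dump_tabular(with_prefix=False)       # write + clear the row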
Example #6
    def train(self):
        self.start_worker()
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                logger.log('Obtaining samples...')
                paths = self.sampler.obtain_samples(itr)
                logger.log('Processing samples...')
                samples_data = self.sampler.process_samples(itr, paths)
                logger.log('Logging diagnostics...')
                self.log_diagnostics(paths)
                logger.log('Optimizing policy...')
                self.optimize_policy(itr, samples_data)
                logger.log('Saving snapshot...')
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params['algo'] = self
                # Save the trajectories into the param
                if self.store_paths:
                    params['paths'] = samples_data['paths']
                logger.save_itr_params(itr, params)
                logger.log('Saved')
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input('Plotting evaluation run: Press Enter to '
                              'continue...')

        self.shutdown_worker()
Example #7
 def train(self):
     with tf.Session() as sess:
         sess.run(tf.initialize_all_variables())
         self.start_worker()
         start_time = time.time()
         for itr in range(self.start_itr, self.n_itr):
             itr_start_time = time.time()
             with logger.prefix('itr #%d | ' % itr):
                 logger.log("Obtaining samples...")
                 paths = self.obtain_samples(itr)
                 logger.log("Processing samples...")
                 samples_data = self.process_samples(itr, paths)
                 logger.log("Logging diagnostics...")
                 self.log_diagnostics(paths)
                 logger.log("Optimizing policy...")
                 self.optimize_policy(itr, samples_data)
                 logger.log("Saving snapshot...")
                 params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
                 if self.store_paths:
                     params["paths"] = samples_data["paths"]
                 logger.save_itr_params(itr, params)
                 logger.log("Saved")
                 logger.record_tabular('Time', time.time() - start_time)
                 logger.record_tabular('ItrTime', time.time() - itr_start_time)
                 logger.dump_tabular(with_prefix=False)
                 if self.plot:
                     self.update_plot()
                     if self.pause_for_plot:
                         input("Plotting evaluation run: Press Enter to "
                               "continue...")
     self.shutdown_worker()
Example #8
    def train(self, already_init=False):
        self.start_worker()
        if not already_init:
            self.init_opt()
        all_paths = []
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                before_rollouts = time.time()
                paths = self.sampler.obtain_samples(itr)
                samples_data = self.sampler.process_samples(itr, paths)
                after_rollouts = time.time()
                print("rollout: ", after_rollouts - before_rollouts)
                self.log_diagnostics(paths)
                before_update = time.time()
                self.optimize_policy(itr, samples_data)
                after_update = time.time()
                print("update: ", after_update - before_update)
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params["algo"] = self
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                all_paths.append(paths)
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

        self.shutdown_worker()
        return all_paths
Example #9
    def train(self):
        self.start_worker()
        self.init_opt()
        rets = []
        for itr in range(self.start_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                paths = self.obtain_samples(itr)
                print("BatchPolopt:train len(paths)", len(paths))
                samples_data, total_returns_per_episode = self.process_samples(itr, paths)
                rets.append(total_returns_per_episode)
                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

        self.shutdown_worker()
        return rets
Example #10
    def train(self):
        self.start_worker()
        self.init_opt()
        for itr in xrange(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                paths = self.sampler.obtain_samples(itr)
                samples_data = self.sampler.process_samples(itr, paths)
                if self.exp_name:
                    num_traj = len(samples_data["paths"])
                    final_eepts = np.concatenate(
                        [samples_data["paths"][traj]["observations"][-1, 14:20][None, :]
                         for traj in range(num_traj)], axis=0)
                    eepts_path = "/home/ajay/rllab/data/local/{0}/{1}/final_eepts_itr_{2}.pkl".format(
                        self.exp_prefix, self.exp_name, itr)
                    with open(eepts_path, "wb") as f:
                        cPickle.dump(final_eepts, f)
                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params["algo"] = self
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        raw_input("Plotting evaluation run: Press Enter to "
                                  "continue...")

        self.shutdown_worker()
Example #11
    def train(self, continue_learning=False):
        self.start_worker()
        if not continue_learning:
            self.init_opt()
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                paths = self.sampler.obtain_samples(itr)
                samples_data = self.sampler.process_samples(itr, paths)
                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params["algo"] = self
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

        self.shutdown_worker()
Example #12
    def train(self):
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            start_time = time.time()
            self.start_worker()

            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):
                    paths = self.sampler.obtain_samples(itr)
                    samples_data = self.sampler.process_samples(itr, paths)
                    self.log_diagnostics(paths)
                    self.optimize_policy(itr, samples_data)
                    logger.log("saving snapshot...")
                    params = self.get_itr_snapshot(itr, samples_data)
                    #self.current_itr = itr + 1
                    #params["algo"] = self
                    if self.store_paths:
                        params["paths"] = samples_data["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime', time.time() - itr_start_time)
                    logger.dump_tabular(with_prefix=False)
                    if self.plot:
                        self.update_plot()
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                      "continue...")

        self.shutdown_worker()
Example #13
 def train(self, sess=None):
     if sess is None:
         sess = tf.Session()
         sess.__enter__()
     #with tf.Session() as sess:
     sess.run(tf.initialize_all_variables())
     self.start_worker()
     start_time = time.time()
     for itr in range(self.start_itr, self.n_itr):
         itr_start_time = time.time()
         with logger.prefix('itr #%d | ' % itr):
             logger.log("Obtaining samples...")
             paths = self.obtain_samples(itr)
             logger.log("Processing samples...")
             samples_data = self.process_samples(itr, paths)
             logger.log("Logging diagnostics...")
             self.log_diagnostics(paths)
             logger.log("Optimizing policy...")
             self.optimize_policy(itr, samples_data)
             logger.log("Saving snapshot...")
             params = self.get_itr_snapshot(itr,
                                            samples_data)  # , **kwargs)
             if self.store_paths:
                 params["paths"] = samples_data["paths"]
             logger.save_itr_params(itr, params)
             logger.log("Saved")
             logger.record_tabular('Time', time.time() - start_time)
             logger.record_tabular('ItrTime', time.time() - itr_start_time)
             logger.dump_tabular(with_prefix=False)
             if self.plot:
                 self.update_plot()
                 if self.pause_for_plot:
                     input("Plotting evaluation run: Press Enter to "
                           "continue...")
     self.shutdown_worker()
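Several of the TensorFlow examples (for instance #13, #23 and #28) share one session-handling idiom: accept an optional sess argument, create a session only if none was passed, call sess.__enter__() so it becomes the default session for the rest of the method, and close it at the end only if the method created it itself. A stripped-down sketch of just that idiom, using the TF1-style API of the snippets; this is an illustration, not code from any of the repositories above:

import tensorflow as tf  # TF1-style API, matching the examples above


def run_with_optional_session(train_fn, sess=None):
    created_session = sess is None
    if created_session:
        sess = tf.Session()
        sess.__enter__()  # make it the default session for train_fn
    sess.run(tf.global_variables_initializer())

    train_fn()            # anything relying on tf.get_default_session()

    if created_session:   # only close a session we created ourselves
        sess.close()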
Example #14
    def train(self, sess=None):

        if sess is None:
            sess = tf.Session()

        sess.run(tf.global_variables_initializer())

        replay_buffer = SimpleReplayBuffer(env_spec=self._env.spec, max_replay_buffer_size=self._max_pool_size)

        path_length = 0
        episode_rewards = 0
        observation = self._env.reset()

        with sess.as_default():
            self._update_target()

            for ep in range(self._n_epochs):
                mean_loss = 0
                trained_iter = 0
                epoch_rewards = list()
                episode_lengths = list()
                with logger.prefix('Epoch #%d | ' % ep):
                    for ep_iter in pyprind.prog_bar(range(self._epoch_length)):
                        self._env.render()
                        action, _ = self._es.get_action(observation)
                        next_observation, reward, terminal, _ = self._env.step(action)

                        replay_buffer.add_sample(
                            observation=observation,
                            next_observation=next_observation,
                            action=action,
                            terminal=terminal,
                            reward=reward,
                        )

                        episode_rewards += reward
                        path_length += 1

                        observation = next_observation

                        if terminal or path_length >= self._max_path_length:
                            observation = self._env.reset()
                            epoch_rewards.append(episode_rewards)
                            episode_lengths.append(path_length)
                            path_length = 0
                            episode_rewards = 0

                        total_iter = ep * self._epoch_length + ep_iter
                        if replay_buffer.size > self._min_pool_size:
                            batch = replay_buffer.random_batch(self._batch_size)
                            loss = self._do_training(total_iter, batch)
                            mean_loss += loss
                            trained_iter += 1

                        if total_iter % self._target_update_period == 0 and replay_buffer.size > self._min_pool_size:
                            self._update_target()
                    logger.record_tabular('mean-td-error', (mean_loss/self._epoch_length))
                    logger.record_tabular('mean-episode-reward', np.mean(epoch_rewards))
                    logger.record_tabular('mean-episode-length', np.mean(episode_lengths))
                    logger.dump_tabular()
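Example #14 interacts with the replay buffer through three operations only: add_sample(...) stores a transition, size says whether enough data has accumulated, and random_batch(batch_size) draws a training minibatch. The toy buffer below illustrates that interface for reading the example; it is a sketch, not rllab's SimpleReplayBuffer:

import random


class ToyReplayBuffer:
    """Minimal stand-in for the add_sample / random_batch / size interface."""

    def __init__(self, max_size):
        self._max_size = max_size
        self._storage = []
        self._next_idx = 0

    @property
    def size(self):
        return len(self._storage)

    def add_sample(self, **transition):
        # Behave as a ring buffer: overwrite the oldest entry once full.
        if len(self._storage) < self._max_size:
            self._storage.append(transition)
        else:
            self._storage[self._next_idx] = transition
        self._next_idx = (self._next_idx + 1) % self._max_size

    def random_batch(self, batch_size):
        # Uniformly sample transitions and return a dict of lists keyed
        # like the keyword arguments passed to add_sample().
        samples = random.sample(self._storage, batch_size)
        return {key: [s[key] for s in samples] for key in samples[0]}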
Example #15
    def train(self):
        self.start_worker()
        self.init_opt()
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                average_return_list = []
                for i in range(self.num_of_agents):
                    paths = self.sampler_list[i].obtain_samples(itr)
                    samples_data, average_return = self.sampler_list[
                        i].process_samples(itr, paths)
                    average_return_list.append(average_return)
                    # self.log_diagnostics(paths)
                    self.optimize_policy(itr, samples_data, i)
                logger.record_tabular('AverageReturn',
                                      np.max(average_return_list))
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr)
                self.current_itr = itr + 1
                params["algo"] = self
                if self.store_paths:
                    pass
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)

        self.shutdown_worker()
Example #16
    def train(self):
        self.start_worker()
        self.init_opt()
        episode_rewards = []
        episode_lengths = []
        for itr in xrange(self.start_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                paths = self.obtain_samples(itr)
                samples_data = self.process_samples(itr, paths)
                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr,
                                               samples_data)  # , **kwargs)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        raw_input("Plotting evaluation run: Press Enter to "
                                  "continue...")

        self.shutdown_worker()
Example #17
    def log_diagnostics(self, itr, eval_traj_infos, eval_time):
        self.save_itr_snapshot(itr)
        if not eval_traj_infos:
            logger.log("ERROR: had no complete trajectories in eval.")
        steps_in_eval = sum([info["Length"] for info in eval_traj_infos])
        logger.record_tabular('Iteration', itr)
        logger.record_tabular('CumCompletedSteps', itr * self._sample_size)
        logger.record_tabular('StepsInEval', steps_in_eval)
        logger.record_tabular('TrajsInEval', len(eval_traj_infos))

        self._log_infos(eval_traj_infos)

        new_time = time.time()
        log_interval_time = new_time - self._last_time
        new_train_time = log_interval_time - eval_time
        self._cum_train_time += new_train_time
        self._cum_eval_time += eval_time
        self._cum_total_time += log_interval_time
        self._last_time = new_time
        train_speed = float('nan') if itr == 0 else \
            self._log_interval_itrs * self._sample_size / new_train_time

        logger.record_tabular('CumTrainTime', self._cum_train_time)
        logger.record_tabular('CumEvalTime', self._cum_eval_time)
        logger.record_tabular('CumTotalTime', self._cum_total_time)
        logger.record_tabular('SamplesPerSecond', train_speed)

        logger.dump_tabular(with_prefix=False)

        logger.log('optimizing over {} iterations'.format(
            self._log_interval_itrs))
        self.pbar = ProgBarCounter(self._log_interval_itrs)
Example #18
    def train(self):
        sess = tf.get_default_session()
        sess.run(tf.global_variables_initializer())
        if self.init_pol_params is not None:
            self.policy.set_param_values(self.init_pol_params)

        if self.init_qvar_params is not None:
            self.qvar_model.set_params(self.init_qvar_params)

        if self.init_irl_params is not None:
            self.irl_model.set_params(self.init_irl_params)

        if self.init_empw_params is not None:
            self.empw.set_params(self.init_empw_params)

        self.start_worker()
        start_time = time.time()

        returns = []
        rew = []  # stores score at each step
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()

            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)

                logger.log("Processing samples...")
                paths, r = self.compute_irl(paths, itr=itr)
                rew.append(r)
                returns.append(self.log_avg_returns(paths))
                self.compute_qvar(paths, itr=itr)
                self.compute_empw(paths, itr=itr)
                samples_data = self.process_samples(itr, paths)
                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr,
                                               samples_data)  # , **kwargs)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
            if itr % self.target_empw_update == 0 and self.train_empw:  #reward 5
                print('updating target empowerment parameters')
                self.tempw.set_params(self.__empw_params)

        #pickle.dump(rew, open("rewards.p", "wb" )) # uncomment to store rewards in every iteration
        self.shutdown_worker()
        return
Example #19
    def train(self, sess=None):

        sess = self.sess
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        sess.run(tf.global_variables_initializer())
        self.start_worker()
        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):

                logger.log("Obtaining samples...")

                logger.log("Collecting both agent and oracle samples...")
                paths, agent_only_paths = self.obtain_samples(
                    itr, self.oracle_policy)

                logger.log("Processing samples...")
                samples_data = self.process_samples(itr, paths)
                agent_samples_data = self.process_agent_samples(
                    itr, agent_only_paths)

                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                self.log_diagnostics(agent_only_paths)

                #### optimising the policy based on the collected samples
                logger.log("Optimizing policy...")
                self.optimize_agent_policy(itr, agent_samples_data)
                self.optimize_policy(itr, samples_data)

                logger.log("Saving snapshot...")

                params = self.get_itr_snapshot(itr,
                                               samples_data)  # , **kwargs)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]

                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)

                if self.plot:
                    rollout(self.env,
                            self.policy,
                            animated=True,
                            max_path_length=self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to"
                              "continue...")

        self.shutdown_worker()
        if created_session:
            sess.close()
Example #20
    def train(self):
        self.start_worker()
        self.init_opt()
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                paths = self.sampler.obtain_samples(itr)
                samples_data = self.sampler.process_samples(itr, paths)
                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params["algo"] = self
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                                  "continue...")

        self.shutdown_worker()
Example #21
    def train(self):
        self.start_worker()

        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                # TODO: do we use a new rollout on expert data in each itr? for now we can do so but at some point we only have a fixed dataset
                generated_paths = self.sampler.obtain_samples(itr)
                generated_data = self.sampler.process_samples(
                    itr, generated_paths)
                self.log_diagnostics(generated_paths)

                self.optimize_policy(itr, generated_data)

                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, generated_data)
                self.current_itr = itr + 1
                params["algo"] = self
                if self.store_paths:
                    params["paths"] = generated_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")

                logger.dump_tabular(with_prefix=False)

        self.shutdown_worker()
Example #22
    def train(self):
        self.start_worker()
        self.init_opt()
        for itr in range(self.current_itr, self.n_itr):
            if (self.anneal_temp
                    and (itr + 1) % self.anneal_discount_epoch == 0
                    and itr >= self.anneal_temp_start):
                if self.anneal_method == 'loglinear':
                    self.temp *= self.anneal_discount_factor
                elif self.anneal_method == 'linear':
                    self.temp -= self.anneal_discount_factor
                if self.temp < self.temp_min:
                    self.temp = self.temp_min
                logger.log("Current Temperature {:}".format(self.temp))
            with logger.prefix('itr #%d | ' % itr):
                average_return_list = []
                gradient_list = []
                for i in range(self.num_of_agents):
                    paths = self.sampler_list[i].obtain_samples(itr)
                    samples_data, average_return = self.sampler_list[
                        i].process_samples(itr, paths)
                    average_return_list.append(average_return)
                    gradient = self.optimize_policy(itr, samples_data, i)
                    gradient_list.append(gradient)
                logger.log("Update Policy {BEGIN}")
                self.update_policies(gradient_list)
                logger.log("Update Policy {END}")
                logger.record_tabular('AverageReturn',
                                      np.max(average_return_list))
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr)
                self.current_itr = itr + 1
                params["algo"] = self
                if self.store_paths:
                    pass
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
            if self.evolution and (itr + 1) % self.evolution_update_steps == 0:
                logger.log(
                    ">>>>>>>>>>>>>>>>>>>>>>> Evolution START <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<"
                )
                num_of_update = int(self.evolution_ratio * self.num_of_agents)
                sorted_id = np.argsort(average_return_list)
                deleted_id = sorted_id[:num_of_update]
                sampled_id = sorted_id[num_of_update:]
                for i in range(len(deleted_id)):
                    current_id = np.random.choice(sampled_id, 1)
                    current_params = self.policy_list[
                        current_id].get_param_values()
                    current_epsilon = self.evolution_epsilon * (
                        np.random.random(current_params.shape) - 0.5)
                    self.policy_list[deleted_id[i]].set_param_values(
                        current_params + current_epsilon)
                logger.log(
                    ">>>>>>>>>>>>>>>>>>>>>>> Evolution FINISH <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<"
                )

        self.shutdown_worker()
Example #23
    def train(self, sess=None):
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()
        if not self.transfer:
            sess.run(tf.global_variables_initializer())

        # initialize uninitialized variables
        global_vars = tf.global_variables()
        is_initialized = sess.run(
            [tf.is_variable_initialized(var) for var in global_vars])
        not_initialized_vars = [
            v for (v, f) in zip(global_vars, is_initialized) if not f
        ]
        # print([str(i.name) for i in not_initialized_vars]) # only for testing
        if len(not_initialized_vars):
            sess.run(tf.variables_initializer(not_initialized_vars))

        self.start_worker()
        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)
                logger.log("Processing samples...")
                samples_data = self.process_samples(itr, paths)

                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr)  # , **kwargs)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    rollout(self.env,
                            self.policy,
                            animated=True,
                            max_path_length=self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

                params = tf.trainable_variables()
                params_val = sess.run(params)
                for param, param_val in zip(params, params_val):
                    print(param.name + "value: ", param_val)

        self.shutdown_worker()
        if created_session:
            sess.close()
Example #24
    def train_mf(self):
        self.start_worker()
        self.init_opt()
        logz.configure_output_dir(
            "/home/hendawy/Desktop/2DOF_Robotic_Arm_withSphereObstacle/Rr",
            1807)
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                paths = self.sampler.obtain_samples(itr, Constrained=True)
                samples_data, analysis_data = self.sampler.process_samples(
                    itr, paths)
                self.log_diagnostics(paths)
                optimization_data = self.optimize_policy(itr, samples_data)
                logz.log_tabular('Iteration', analysis_data["Iteration"])
                # The return metrics below are measured in terms of the true environment reward of the rolled-out trajectory using the MPC controller
                logz.log_tabular('AverageDiscountedReturn',
                                 analysis_data["AverageDiscountedReturn"])
                logz.log_tabular('AverageReturns',
                                 analysis_data["AverageReturn"])
                logz.log_tabular('violation_cost',
                                 np.mean(samples_data["violation_cost"]))
                logz.log_tabular(
                    'boundary_violation_cost',
                    np.mean(samples_data["boundary_violation_cost"]))
                logz.log_tabular('success_rate', samples_data["success_rate"])
                logz.log_tabular(
                    'successful_AverageReturn',
                    np.mean(samples_data["successful_AverageReturn"]))
                logz.log_tabular('ExplainedVariance',
                                 analysis_data["ExplainedVariance"])
                logz.log_tabular('NumTrajs', analysis_data["NumTrajs"])
                logz.log_tabular('Entropy', analysis_data["Entropy"])
                logz.log_tabular('Perplexity', analysis_data["Perplexity"])
                logz.log_tabular('StdReturn', analysis_data["StdReturn"])
                logz.log_tabular('MaxReturn', analysis_data["MaxReturn"])
                logz.log_tabular('MinReturn', analysis_data["MinReturn"])
                logz.log_tabular('LossBefore', optimization_data["LossBefore"])
                logz.log_tabular('LossAfter', optimization_data["LossAfter"])
                logz.log_tabular('MeanKLBefore',
                                 optimization_data["MeanKLBefore"])
                logz.log_tabular('MeanKL', optimization_data["MeanKL"])
                logz.log_tabular('dLoss', optimization_data["dLoss"])
                logz.dump_tabular()
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params["algo"] = self
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

        self.shutdown_worker()
Example #25
def agent_train(
    algo,
    oracle_policy,
    sess=None,
):
    """
    This is necessary so that we don't wipe away already-initialized policy params.
    Ideally, this should be submitted to RLlab as an option via a pull request and removed from here once that is merged.
    """
    created_session = True if (sess is None) else False
    if sess is None:
        sess = tf.Session()
        sess.__enter__()

    rollout_cache = []
    initialize_uninitialized(sess)
    algo.start_worker()
    start_time = time.time()

    #every time step
    for itr in range(algo.start_itr, algo.n_itr):

        itr_start_time = time.time()
        with logger.prefix('itr #%d | ' % itr):

            #use multiple rollouts/trajectories to obtain samples for TRPO
            logger.log("Obtaining samples...")
            ## obtain samples - for both only agent and all samples (including oracle and agent)
            paths, agent_only_paths = algo.obtain_samples(itr, oracle_policy)

            logger.log("Processing samples...")
            samples_data = algo.process_samples(itr, paths)
            agent_samples_data = algo.process_agent_samples(
                itr, agent_only_paths)

            logger.log("Logging diagnostics...")
            algo.log_diagnostics(paths)

            logger.log("Optimizing policy...")
            ## optimising pi(s) with agent samples data only
            algo.optimize_agent_policy(itr, agent_samples_data)
            ## optimising beta(s) with all samples
            algo.optimize_policy(itr, samples_data)

            logger.log("Saving snapshot...")
            params = algo.get_itr_snapshot(itr, samples_data)  # , **kwargs)
            if algo.store_paths:
                params["paths"] = samples_data["paths"]
            logger.save_itr_params(itr, params)
            logger.log("Saved")
            logger.record_tabular('Time', time.time() - start_time)
            logger.record_tabular('ItrTime', time.time() - itr_start_time)
            logger.dump_tabular(with_prefix=False)

    algo.shutdown_worker()

    if created_session:
        sess.close()
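Examples #4 and #25 call initialize_uninitialized(sess) instead of running tf.global_variables_initializer(), so that variables that already hold values (for example a pre-initialized policy) are not wiped; that is what the docstring above refers to. The helper's body is not shown in either example, but a sketch consistent with the inline version in Example #23 would be:

import tensorflow as tf  # TF1-style API, as in the examples above


def initialize_uninitialized(sess):
    # Initialize only the global variables that do not yet hold a value.
    global_vars = tf.global_variables()
    is_initialized = sess.run(
        [tf.is_variable_initialized(var) for var in global_vars])
    not_initialized_vars = [
        v for (v, flag) in zip(global_vars, is_initialized) if not flag
    ]
    if not_initialized_vars:
        sess.run(tf.variables_initializer(not_initialized_vars))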
Example #26
 def train(self):
     with tf.Session() as sess:
         sess.run(tf.initialize_all_variables())
         if self.qprop:
             pool = SimpleReplayPool(
                 max_pool_size=self.replay_pool_size,
                 observation_dim=self.env.observation_space.flat_dim,
                 action_dim=self.env.action_space.flat_dim,
                 replacement_prob=self.replacement_prob,
             )
         self.start_worker()
         self.init_opt()
         # This initializes the optimizer parameters
         sess.run(tf.initialize_all_variables())
         start_time = time.time()
         for itr in range(self.start_itr, self.n_itr):
             itr_start_time = time.time()
             with logger.prefix('itr #%d | ' % itr):
                 if self.qprop and not self.qprop_enable and \
                         itr >= self.qprop_min_itr:
                     logger.log(
                         "Restarting workers with batch size %d->%d..." %
                         (self.batch_size, self.qprop_batch_size))
                     self.shutdown_worker()
                     self.batch_size = self.qprop_batch_size
                     self.start_worker()
                     if self.qprop_use_qf_baseline:
                         self.baseline = self.qf_baseline
                     self.qprop_enable = True
                 logger.log("Obtaining samples...")
                 paths = self.obtain_samples(itr)
                 logger.log("Processing samples...")
                 samples_data = self.process_samples(itr, paths)
                 logger.log("Logging diagnostics...")
                 self.log_diagnostics(paths)
                 if self.qprop:
                     logger.log("Adding samples to replay pool...")
                     self.add_pool(itr, paths, pool)
                     logger.log("Optimizing critic before policy...")
                     self.optimize_critic(itr, pool)
                 logger.log("Optimizing policy...")
                 self.optimize_policy(itr, samples_data)
                 params = self.get_itr_snapshot(itr,
                                                samples_data)  # , **kwargs)
                 if self.store_paths:
                     params["paths"] = samples_data["paths"]
                 logger.save_itr_params(itr, params)
                 logger.log("Saved")
                 logger.record_tabular('Time', time.time() - start_time)
                 logger.record_tabular('ItrTime',
                                       time.time() - itr_start_time)
                 logger.dump_tabular(with_prefix=False)
                 if self.plot:
                     self.update_plot()
                     if self.pause_for_plot:
                         input("Plotting evaluation run: Press Enter to "
                               "continue...")
     self.shutdown_worker()
Example #27
    def _train(self, env, policy, pool):
        """Perform RL training.

        Args:
            env (`rllab.Env`): Environment used for training
            policy (`Policy`): Policy used for training
            pool (`PoolBase`): Sample pool to add samples to
        """
        self._init_training()
        self.sampler.initialize(env, policy, pool)

        evaluation_env = deep_clone(env) if self._eval_n_episodes else None

        with tf_utils.get_default_session().as_default():
            gt.rename_root('RLAlgorithm')
            gt.reset()
            gt.set_def_unique(False)

            for epoch in gt.timed_for(range(self._n_epochs + 1),
                                      save_itrs=True):
                logger.push_prefix('Epoch #%d | ' % epoch)

                for t in range(self._epoch_length):
                    self.sampler.sample()
                    if not self.sampler.batch_ready():
                        continue
                    gt.stamp('sample')

                    for i in range(self._n_train_repeat):
                        self._do_training(iteration=t +
                                          epoch * self._epoch_length,
                                          batch=self.sampler.random_batch())
                    gt.stamp('train')

                self._evaluate(policy, evaluation_env)
                gt.stamp('eval')

                params = self.get_snapshot(epoch)
                logger.save_itr_params(epoch, params)

                time_itrs = gt.get_times().stamps.itrs
                time_eval = time_itrs['eval'][-1]
                time_total = gt.get_times().total
                time_train = time_itrs.get('train', [0])[-1]
                time_sample = time_itrs.get('sample', [0])[-1]

                logger.record_tabular('time-train', time_train)
                logger.record_tabular('time-eval', time_eval)
                logger.record_tabular('time-sample', time_sample)
                logger.record_tabular('time-total', time_total)
                logger.record_tabular('epoch', epoch)

                self.sampler.log_diagnostics()

                logger.dump_tabular(with_prefix=False)
                logger.pop_prefix()

            self.sampler.terminate()
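Examples #27 and #31 time their epochs with the gtimer package (imported as gt): gt.timed_for(...) wraps the epoch loop, gt.stamp(name) marks the end of a named phase, and gt.get_times() exposes the per-iteration and total timings that the examples push into the tabular log. A minimal sketch of that usage on its own:

import time

import gtimer as gt

# Same structure as the epoch loops above, with sleeps standing in for work.
for epoch in gt.timed_for(range(3), save_itrs=True):
    time.sleep(0.01)          # stand-in for sampling
    gt.stamp('sample')
    time.sleep(0.02)          # stand-in for gradient updates
    gt.stamp('train')

times = gt.get_times()
print(times.stamps.itrs['train'])  # list of per-iteration 'train' durations
print(times.total)                 # total elapsed time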
Example #28
    def train(self, sess=None):
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        sess.run(tf.global_variables_initializer())
        self.start_worker()
        start_time = time.time()
        AvgDisReturn = []
        AvgReturn = []
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)
                #print(paths)
                logger.log("Processing samples...")
                samples_data = self.process_samples(itr, paths)
                # for key in samples_data:
                #     print(key)
                # print(samples_data["rewards"])
                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr,
                                               samples_data)  # , **kwargs)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)

                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                AvgDisReturn.append(
                    float(dict(logger._tabular)["AverageDiscountedReturn"]))
                AvgReturn.append(float(dict(logger._tabular)["AverageReturn"]))
                # for key in dict(logger._tabular):
                #     print(key)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    rollout(self.env,
                            self.policy,
                            animated=True,
                            max_path_length=self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

                store("AvgDisReturn.dat", AvgDisReturn)
                store("AvgReturn.dat", AvgReturn)

        self.shutdown_worker()
        if created_session:
            sess.close()
Example #29
    def train(self, sess=None):
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        global_step = tf.train.get_or_create_global_step()
        global_step_inc = global_step.assign_add(1)

        sess.run(tf.global_variables_initializer())
        self.start_worker()
        start_time = time.time()
        total_timesteps = 0
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                with _MeasureTime('ObtainSamplesTime'):
                    paths = self.obtain_samples(itr)
                logger.log("Processing samples...")
                with _MeasureTime('ProcessPathsTime'):
                    self.process_paths(paths)
                with _MeasureTime('ProcessSamplesTime'):
                    samples_data = self.process_samples(itr, paths)
                timesteps = len(samples_data['observations'])
                total_timesteps += timesteps
                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                with _MeasureTime('OptimizePolicyTime'):
                    self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr,
                                               samples_data)  # , **kwargs)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.record_tabular('Timesteps', timesteps)
                logger.record_tabular('TotalTimesteps', total_timesteps)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    rollout(self.env,
                            self.policy,
                            animated=True,
                            max_path_length=self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

                sess.run(global_step_inc)

        self.shutdown_worker()
        if created_session:
            sess.close()
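The _MeasureTime(...) context manager used in Example #29 is not defined in the snippet. Judging purely from how it is called, it appears to time the enclosed block and record the result under the given tabular key; the sketch below is an assumption along those lines, not the original implementation, and the logger import must match whatever fork the surrounding code uses:

import time
from contextlib import contextmanager

from rllab.misc import logger  # assumption: adjust to the fork's logger module


@contextmanager
def _MeasureTime(key):
    # Record the wall-clock time spent inside the with-block under the given key.
    start = time.time()
    try:
        yield
    finally:
        logger.record_tabular(key, time.time() - start)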
Example #30
    def train(self, sess=None):
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        sess.run(tf.global_variables_initializer())
        self.start_worker()
        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)
                logger.log("Processing samples...")
                samples_data = self.process_samples(itr, paths)
                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data, self._wandb_dict)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr,
                                               samples_data)  # , **kwargs)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    rollout(self.env,
                            self.policy,
                            animated=True,
                            max_path_length=self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
                if self._render:
                    fn = self._gif_header + str(itr) + '.gif'
                    # obtain gym.env from rllab.env
                    render_env(self.env.wrapped_env.env,
                               path=self._gif_dir,
                               filename=fn)
                    if self._log_wandb:
                        full_fn = os.path.join(os.getcwd(), self._gif_dir, fn)
                        wandb.log({
                            "video":
                            wandb.Video(full_fn, fps=60, format="gif")
                        })
                if self._log_wandb:
                    wandb.log(self._wandb_dict)

        self.shutdown_worker()
        if created_session:
            sess.close()
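Example #30 mirrors its metrics to Weights & Biases: values collected in self._wandb_dict during optimization are flushed with wandb.log(...), and a rendered rollout GIF can be attached with wandb.Video(...). A stripped-down sketch of just those logging calls; the project name and metric values are placeholders, and in the real setup wandb.init would normally be called once during experiment setup:

import wandb

wandb.init(project="rllab-trpo-demo")  # placeholder project name

for itr in range(3):
    wandb.log({"Iteration": itr, "AverageReturn": 0.0})  # placeholder metrics

# Attaching a rollout GIF that was rendered to disk, as in Example #30:
# wandb.log({"video": wandb.Video("rollout.gif", fps=60, format="gif")})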
Example #31
    def _train(self, env, policy, pool):
        """Perform RL training.

        Args:
            env (`rllab.Env`): Environment used for training
            policy (`Policy`): Policy used for training
            pool (`PoolBase`): Sample pool to add samples to
        """

        self._init_training(env, policy, pool)
        self.sampler.initialize(env, policy, pool)

        with self._sess.as_default():
            gt.rename_root('RLAlgorithm')
            gt.reset()
            gt.set_def_unique(False)

            for epoch in gt.timed_for(range(self._n_epochs + 1),
                                      save_itrs=True):
                logger.push_prefix('Epoch #%d | ' % epoch)

                for t in range(self._epoch_length):
                    # TODO.codeconsolidation: Add control interval to sampler
                    self.sampler.sample()
                    if not self.sampler.batch_ready():
                        continue
                    gt.stamp('sample')

                    for i in range(self._n_train_repeat):
                        self._do_training(iteration=t +
                                          epoch * self._epoch_length,
                                          batch=self.sampler.random_batch())
                    gt.stamp('train')

                self._evaluate(epoch)

                params = self.get_snapshot(epoch)
                logger.save_itr_params(epoch, params)
                times_itrs = gt.get_times().stamps.itrs

                eval_time = times_itrs['eval'][-1] if epoch > 1 else 0
                total_time = gt.get_times().total
                logger.record_tabular('time-train', times_itrs['train'][-1])
                logger.record_tabular('time-eval', eval_time)
                logger.record_tabular('time-sample', times_itrs['sample'][-1])
                logger.record_tabular('time-total', total_time)
                logger.record_tabular('epoch', epoch)

                self.sampler.log_diagnostics()

                logger.dump_tabular(with_prefix=False)
                logger.pop_prefix()

                gt.stamp('eval')

            self.sampler.terminate()
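
The `gt.stamp` calls above attribute wall-clock time to named phases ('sample', 'train', 'eval') of each epoch via the gtimer package. A simplified standard-library stand-in for that idea (`timed` is an illustrative helper, not gtimer's API; only the phase names mirror the snippet):

import time
from collections import defaultdict

phase_times = defaultdict(float)

def timed(phase, fn, *args, **kwargs):
    """Run fn and charge its wall-clock time to the named phase."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    phase_times[phase] += time.perf_counter() - start
    return result

timed('sample', time.sleep, 0.01)
timed('train', time.sleep, 0.02)
print({phase: round(t, 3) for phase, t in phase_times.items()})
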
Example #32
    def train(self):
        self.start_worker()
        self.init_opt()

        # added, store average returns and std returns
        if self.plot_learning_curve:
            avg_returns = []
            std_returns = []

        # added, make sure we add the first curriculum element
        assert (self.current_itr == 0)

        for itr in range(self.current_itr, self.n_itr):

            # added, update curriculum if necessary
            if isinstance(self.policy, CurriculumPolicy):
                if itr % self.policy.update_freq == 0:
                    if len(self.curriculum_list) > 0:
                        self.curriculum.append(self.curriculum_list.pop(0))

            with logger.prefix('itr #%d | ' % itr):
                paths = self.sampler.obtain_samples(itr)
                samples_data = self.sampler.process_samples(itr, paths)
                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params["algo"] = self
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)

                # added
                if self.plot_learning_curve:
                    cur_paths = samples_data["paths"]
                    total_returns = [
                        sum(path["rewards"]) for path in cur_paths
                    ]
                    avg_returns.append(np.mean(total_returns))
                    std_returns.append(np.std(total_returns))

                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

        self.shutdown_worker()

        # added
        if self.plot_learning_curve:
            return avg_returns, std_returns
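
The learning-curve bookkeeping added above reduces each batch of paths to the mean and standard deviation of the undiscounted returns. A self-contained sketch of that reduction (the toy paths are synthetic; only the "rewards" key is assumed, as elsewhere in these examples):

import numpy as np

def batch_return_stats(paths):
    """Sum rewards per path, then summarize across the batch."""
    total_returns = [float(np.sum(path["rewards"])) for path in paths]
    return np.mean(total_returns), np.std(total_returns)

paths = [{"rewards": np.array([1.0, 0.0, 2.0])},
         {"rewards": np.array([0.5, 0.5])}]
avg, std = batch_return_stats(paths)   # avg == 2.0, std == 1.0
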
    def train(self):

        memory = ReplayMem(
            obs_dim=self.env.observation_space.flat_dim,
            act_dim=self.env.action_space.flat_dim,
            memory_size=self.memory_size)

        itr = 0
        path_length = 0
        path_return = 0
        end = False
        obs = self.env.reset()

        for epoch in range(self.n_epochs):
            logger.push_prefix("epoch #%d | " % epoch)
            logger.log("Training started")
            for epoch_itr in pyprind.prog_bar(range(self.epoch_length)):
                # run the policy
                if end:
                    # reset the environment and strategy when an episode ends
                    obs = self.env.reset()
                    self.strategy.reset()
                    # self.policy.reset()
                    self.strategy_path_returns.append(path_return)
                    path_length = 0
                    path_return = 0
                # note action is sampled from the policy not the target policy
                act = self.strategy.get_action(obs, self.policy)
                nxt, rwd, end, _ = self.env.step(act)

                path_length += 1
                path_return += rwd

                if not end and path_length >= self.max_path_length:
                    end = True
                    if self.include_horizon_terminal:
                        memory.add_sample(obs, act, rwd, end)
                else:
                    memory.add_sample(obs, act, rwd, end)

                obs = nxt

                if memory.size >= self.memory_start_size:
                    for update_time in range(self.n_updates_per_sample):
                        batch = memory.get_batch(self.batch_size)
                        self.do_update(itr, batch)

                itr += 1

            logger.log("Training finished")
            if memory.size >= self.memory_start_size:
                self.evaluate(epoch, memory)
            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
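
ReplayMem itself is not shown in the snippet. A minimal sketch of a buffer with the same `add_sample` / `get_batch` / `size` surface (the dict layout and field names returned by `get_batch` are assumptions for illustration):

import random
from collections import deque

import numpy as np

class MinimalReplayMem:
    def __init__(self, memory_size):
        self._storage = deque(maxlen=memory_size)   # oldest samples fall off the front

    @property
    def size(self):
        return len(self._storage)

    def add_sample(self, obs, act, rwd, end):
        self._storage.append((np.asarray(obs), np.asarray(act), rwd, end))

    def get_batch(self, batch_size):
        batch = random.sample(self._storage, batch_size)
        obs, act, rwd, end = zip(*batch)
        return dict(observations=np.array(obs), actions=np.array(act),
                    rewards=np.array(rwd), terminals=np.array(end))
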
    def train(self):
        sess = tf.get_default_session()
        sess.run(tf.global_variables_initializer())
        if self.init_pol_params is not None:
            self.policy.set_param_values(self.init_pol_params)
        if self.init_irl_params is not None:
            self.irl_model.set_params(self.init_irl_params)
        self.start_worker()
        start_time = time.time()

        returns = []
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)

                logger.log("Processing samples...")
                paths = self.compute_irl(paths, itr=itr)
                returns.append(self.log_avg_returns(paths))
                samples_data = self.process_samples(itr, paths)

                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
        self.shutdown_worker()
        return 
Example #35
def custom_train(algo, sess=None):
    """
    This is necessary so that we don't wipe away already-initialized policy params.
    Ideally, we should submit this upstream as an option to RLlab and remove it from here once merged.
    """
    created_session = True if (sess is None) else False
    if sess is None:
        sess = tf.Session()
        sess.__enter__()

    rollout_cache = []
    initialize_uninitialized(sess)
    algo.start_worker()
    start_time = time.time()
    for itr in range(algo.start_itr, algo.n_itr):
        itr_start_time = time.time()
        with logger.prefix('itr #%d | ' % itr):
            logger.log("Obtaining samples...")
            paths = algo.obtain_samples(itr)
            logger.log("Processing samples...")
            samples_data = algo.process_samples(itr, paths)
            logger.log("Logging diagnostics...")
            algo.log_diagnostics(paths)
            logger.log("Optimizing policy...")
            algo.optimize_policy(itr, samples_data)
            logger.log("Saving snapshot...")
            params = algo.get_itr_snapshot(itr, samples_data)  # , **kwargs)
            if algo.store_paths:
                params["paths"] = samples_data["paths"]
            logger.save_itr_params(itr, params)
            logger.log("Saved")
            logger.record_tabular('Time', time.time() - start_time)
            logger.record_tabular('ItrTime', time.time() - itr_start_time)
            logger.dump_tabular(with_prefix=False)

    algo.shutdown_worker()
    if created_session:
        sess.close()
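
`initialize_uninitialized` is referenced above but not defined in the snippet. One common way to implement such a helper against the TF1-style API used throughout these examples (a sketch under that assumption, not the author's definition):

import tensorflow as tf  # TF1-style API, matching the snippets above

def initialize_uninitialized(sess):
    """Initialize only variables the session has not initialized yet,
    leaving already-loaded policy parameters untouched."""
    global_vars = tf.global_variables()
    init_flags = sess.run([tf.is_variable_initialized(v) for v in global_vars])
    uninit_vars = [v for v, initialized in zip(global_vars, init_flags)
                   if not initialized]
    if uninit_vars:
        sess.run(tf.variables_initializer(uninit_vars))
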
Example #36
    def train(self, sess=None):
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        sess.run(tf.global_variables_initializer())
        self.start_worker()
        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)
                logger.log("Processing samples...")
                samples_data = self.process_samples(itr, paths)
                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    rollout(self.env, self.policy, animated=True, max_path_length=self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
        self.shutdown_worker()
        if created_session:
            sess.close()
Example #37
    def train(self):

        # Bayesian neural network (BNN) initialization.
        # ------------------------------------------------
        batch_size = 1  # Redundant
        n_batches = 5  # Hardcode or annealing scheme \pi_i.

        # MDP observation and action dimensions.
        obs_dim = np.prod(self.env.observation_space.shape)
        act_dim = np.prod(self.env.action_space.shape)

        logger.log("Building BNN model (eta={}) ...".format(self.eta))
        start_time = time.time()

        self.bnn = bnn.BNN(
            n_in=(obs_dim + act_dim),
            n_hidden=self.unn_n_hidden,
            n_out=obs_dim,
            n_batches=n_batches,
            layers_type=self.unn_layers_type,
            trans_func=lasagne.nonlinearities.rectify,
            out_func=lasagne.nonlinearities.linear,
            batch_size=batch_size,
            n_samples=self.snn_n_samples,
            prior_sd=self.prior_sd,
            use_reverse_kl_reg=self.use_reverse_kl_reg,
            reverse_kl_reg_factor=self.reverse_kl_reg_factor,
            #             stochastic_output=self.stochastic_output,
            second_order_update=self.second_order_update,
            learning_rate=self.unn_learning_rate,
            compression=self.compression,
            information_gain=self.information_gain
        )

        logger.log(
            "Model built ({:.1f} sec).".format((time.time() - start_time)))

        if self.use_replay_pool:
            self.pool = SimpleReplayPool(
                max_pool_size=self.replay_pool_size,
                observation_shape=self.env.observation_space.shape,
                action_dim=act_dim
            )
        # ------------------------------------------------

        self.start_worker()
        self.init_opt()
        episode_rewards = []
        episode_lengths = []
        for itr in range(self.start_itr, self.n_itr):
            logger.push_prefix('itr #%d | ' % itr)

            paths = self.obtain_samples(itr)
            samples_data = self.process_samples(itr, paths)

            # Exploration code
            # ----------------
            if self.use_replay_pool:
                # Fill replay pool.
                logger.log("Fitting dynamics model using replay pool ...")
                for path in samples_data['paths']:
                    path_len = len(path['rewards'])
                    for i in range(path_len):
                        obs = path['observations'][i]
                        act = path['actions'][i]
                        rew = path['rewards'][i]
                        term = (i == path_len - 1)
                        self.pool.add_sample(obs, act, rew, term)

                # Now we train the dynamics model using the replay self.pool; only
                # if self.pool is large enough.
                if self.pool.size >= self.min_pool_size:
                    obs_mean, obs_std, act_mean, act_std = self.pool.mean_obs_act()
                    _inputss = []
                    _targetss = []
                    for _ in range(self.n_updates_per_sample):
                        batch = self.pool.random_batch(
                            self.pool_batch_size)
                        obs = (batch['observations'] - obs_mean) / \
                            (obs_std + 1e-8)
                        next_obs = (
                            batch['next_observations'] - obs_mean) / (obs_std + 1e-8)
                        act = (batch['actions'] - act_mean) / \
                            (act_std + 1e-8)
                        _inputs = np.hstack(
                            [obs, act])
                        _targets = next_obs
                        _inputss.append(_inputs)
                        _targetss.append(_targets)

                    old_acc = 0.
                    for _inputs, _targets in zip(_inputss, _targetss):
                        _out = self.bnn.pred_fn(_inputs)
                        old_acc += np.mean(np.square(_out - _targets))
                    old_acc /= len(_inputss)

                    for _inputs, _targets in zip(_inputss, _targetss):
                        self.bnn.train_fn(_inputs, _targets)

                    new_acc = 0.
                    for _inputs, _targets in zip(_inputss, _targetss):
                        _out = self.bnn.pred_fn(_inputs)
                        new_acc += np.mean(np.square(_out - _targets))
                    new_acc /= len(_inputss)

                    logger.record_tabular(
                        'BNN_DynModelSqLossBefore', old_acc)
                    logger.record_tabular(
                        'BNN_DynModelSqLossAfter', new_acc)
            # ----------------

            self.env.log_diagnostics(paths)
            self.policy.log_diagnostics(paths)
            self.baseline.log_diagnostics(paths)
            self.optimize_policy(itr, samples_data)
            logger.log("saving snapshot...")
            params = self.get_itr_snapshot(itr, samples_data)
            paths = samples_data["paths"]
            if self.store_paths:
                params["paths"] = paths
            episode_rewards.extend(sum(p["rewards"]) for p in paths)
            episode_lengths.extend(len(p["rewards"]) for p in paths)
            params["episode_rewards"] = np.array(episode_rewards)
            params["episode_lengths"] = np.array(episode_lengths)
            params["algo"] = self
            logger.save_itr_params(itr, params)
            logger.log("saved")
            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                self.update_plot()
                if self.pause_for_plot:
                    raw_input("Plotting evaluation run: Press Enter to "
                              "continue...")

        self.shutdown_worker()
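
Before feeding replay batches to the BNN, the loop above standardizes observations and actions with the pool's mean and std, adding a small epsilon to avoid division by zero. A self-contained sketch of that normalization step (the arrays are synthetic):

import numpy as np

def normalize(batch, mean, std, eps=1e-8):
    return (batch - mean) / (std + eps)

obs_batch = np.random.randn(32, 4) * 5.0 + 2.0
obs_mean, obs_std = obs_batch.mean(axis=0), obs_batch.std(axis=0)
normed = normalize(obs_batch, obs_mean, obs_std)
assert np.allclose(normed.mean(axis=0), 0.0, atol=1e-6)   # roughly zero-mean
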
Example #38
    def train(self):

        cur_std = self.sigma0
        cur_mean = self.policy.get_param_values()
        es = cma_es_lib.CMAEvolutionStrategy(
            cur_mean, cur_std)

        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            plotter.init_plot(self.env, self.policy)

        cur_std = self.sigma0
        cur_mean = self.policy.get_param_values()

        itr = 0
        while itr < self.n_itr and not es.stop():

            if self.batch_size is None:
                # Sample from multivariate normal distribution.
                xs = es.ask()
                xs = np.asarray(xs)
                # For each sample, do a rollout.
                infos = (
                    stateful_pool.singleton_pool.run_map(sample_return, [(x, self.max_path_length,
                                                                          self.discount) for x in xs]))
            else:
                cum_len = 0
                infos = []
                xss = []
                done = False
                while not done:
                    sbs = stateful_pool.singleton_pool.n_parallel * 2
                    # Sample from multivariate normal distribution.
                    # You want to ask for sbs samples here.
                    xs = es.ask(sbs)
                    xs = np.asarray(xs)

                    xss.append(xs)
                    sinfos = stateful_pool.singleton_pool.run_map(
                        sample_return, [(x, self.max_path_length, self.discount) for x in xs])
                    for info in sinfos:
                        infos.append(info)
                        cum_len += len(info['returns'])
                        if cum_len >= self.batch_size:
                            xs = np.concatenate(xss)
                            done = True
                            break

            # Evaluate fitness of samples (negative as it is minimization
            # problem).
            fs = - np.array([info['returns'][0] for info in infos])
            # When batching, you could have generated too many samples compared
            # to the actual evaluations. So we cut it off in this case.
            xs = xs[:len(fs)]
            # Update CMA-ES params based on sample fitness.
            es.tell(xs, fs)

            logger.push_prefix('itr #%d | ' % itr)
            logger.record_tabular('Iteration', itr)
            logger.record_tabular('CurStdMean', np.mean(cur_std))
            undiscounted_returns = np.array(
                [info['undiscounted_return'] for info in infos])
            logger.record_tabular('AverageReturn',
                                  np.mean(undiscounted_returns))
            logger.record_tabular('StdReturn',
                                  np.std(undiscounted_returns))
            logger.record_tabular('MaxReturn',
                                  np.max(undiscounted_returns))
            logger.record_tabular('MinReturn',
                                  np.min(undiscounted_returns))
            logger.record_tabular('AverageDiscountedReturn',
                                  np.mean(fs))
            logger.record_tabular('AvgTrajLen',
                                  np.mean([len(info['returns']) for info in infos]))
            self.env.log_diagnostics(infos)
            self.policy.log_diagnostics(infos)

            logger.save_itr_params(itr, dict(
                itr=itr,
                policy=self.policy,
                env=self.env,
            ))
            logger.dump_tabular(with_prefix=False)
            if self.plot:
                plotter.update_plot(self.policy, self.max_path_length)
            logger.pop_prefix()
            # Update iteration.
            itr += 1

        # Set final params.
        self.policy.set_param_values(es.result()[0])
        parallel_sampler.terminate_task()
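
The ask/tell loop above is the standard CMA-ES interface; the only RL-specific part is that fitness comes from policy rollouts. A toy version on a quadratic, assuming the widely used `cma` package (the snippet's `cma_es_lib` import is not shown, so the exact library and version are an assumption):

import numpy as np
import cma

es = cma.CMAEvolutionStrategy(np.ones(5), 0.5)     # initial mean and step size
for _ in range(200):
    if es.stop():
        break
    xs = es.ask()                                  # sample candidate parameter vectors
    fs = [float(np.sum(np.square(x))) for x in xs] # fitness to *minimize*
    es.tell(xs, fs)                                # update the search distribution
best_x = es.result[0]                              # best solution found so far
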
    def train(self):
        # TODO - make this a util
        flatten_list = lambda l: [item for sublist in l for item in sublist]

        with tf.Session() as sess:
            # Code for loading a previous policy. Somewhat hacky because it needs to run inside the session.
            if self.load_policy is not None:
                import joblib
                self.policy = joblib.load(self.load_policy)['policy']
            self.init_opt()
            # initialize uninitialized vars  (only initialize vars that were not loaded)
            uninit_vars = []
            for var in tf.global_variables():
                # note - this is hacky; there may be a better way to do this in newer TF.
                try:
                    sess.run(var)
                except tf.errors.FailedPreconditionError:
                    uninit_vars.append(var)
            sess.run(tf.variables_initializer(uninit_vars))

            self.start_worker()
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):
                    logger.log("Sampling set of tasks/goals for this meta-batch...")

                    env = self.env
                    while 'sample_goals' not in dir(env):
                        env = env.wrapped_env
                    learner_env_goals = env.sample_goals(self.meta_batch_size)

                    self.policy.switch_to_init_dist()  # Switch to pre-update policy

                    all_samples_data, all_paths = [], []
                    for step in range(self.num_grad_updates+1):
                        #if step > 0:
                        #    import pdb; pdb.set_trace() # test param_vals functions.
                        logger.log('** Step ' + str(step) + ' **')
                        logger.log("Obtaining samples...")
                        paths = self.obtain_samples(itr, reset_args=learner_env_goals, log_prefix=str(step))
                        all_paths.append(paths)
                        logger.log("Processing samples...")
                        samples_data = {}
                        for key in paths.keys():  # the keys are the tasks
                            # don't log here because this would spam the console with every task.
                            samples_data[key] = self.process_samples(itr, paths[key], log=False)
                        all_samples_data.append(samples_data)
                        # for logging purposes only
                        self.process_samples(itr, flatten_list(paths.values()), prefix=str(step), log=True)
                        logger.log("Logging diagnostics...")
                        self.log_diagnostics(flatten_list(paths.values()), prefix=str(step))
                        if step < self.num_grad_updates:
                            logger.log("Computing policy updates...")
                            self.policy.compute_updated_dists(samples_data)


                    logger.log("Optimizing policy...")
                    # This needs to take all samples_data so that it can construct graph for meta-optimization.
                    self.optimize_policy(itr, all_samples_data)
                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(itr, all_samples_data[-1])  # , **kwargs)
                    if self.store_paths:
                        params["paths"] = all_samples_data[-1]["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime', time.time() - itr_start_time)

                    logger.dump_tabular(with_prefix=False)

                    # The rest is some example plotting code.
                    # Plotting code is useful for visualizing trajectories across a few different tasks.
                    if False and itr % 2 == 0 and self.env.observation_space.shape[0] <= 4: # point-mass
                        logger.log("Saving visualization of paths")
                        for ind in range(min(5, self.meta_batch_size)):
                            plt.clf()
                            plt.plot(learner_env_goals[ind][0], learner_env_goals[ind][1], 'k*', markersize=10)
                            # plt.hold(True)  # removed in Matplotlib 3.x; hold is now always on

                            preupdate_paths = all_paths[0]
                            postupdate_paths = all_paths[-1]

                            pre_points = preupdate_paths[ind][0]['observations']
                            post_points = postupdate_paths[ind][0]['observations']
                            plt.plot(pre_points[:,0], pre_points[:,1], '-r', linewidth=2)
                            plt.plot(post_points[:,0], post_points[:,1], '-b', linewidth=1)

                            pre_points = preupdate_paths[ind][1]['observations']
                            post_points = postupdate_paths[ind][1]['observations']
                            plt.plot(pre_points[:,0], pre_points[:,1], '--r', linewidth=2)
                            plt.plot(post_points[:,0], post_points[:,1], '--b', linewidth=1)

                            pre_points = preupdate_paths[ind][2]['observations']
                            post_points = postupdate_paths[ind][2]['observations']
                            plt.plot(pre_points[:,0], pre_points[:,1], '-.r', linewidth=2)
                            plt.plot(post_points[:,0], post_points[:,1], '-.b', linewidth=1)

                            plt.plot(0,0, 'k.', markersize=5)
                            plt.xlim([-0.8, 0.8])
                            plt.ylim([-0.8, 0.8])
                            plt.legend(['goal', 'preupdate path', 'postupdate path'])
                            plt.savefig(osp.join(logger.get_snapshot_dir(), 'prepost_path'+str(ind)+'.png'))
                    elif False and itr % 2 == 0:  # swimmer or cheetah
                        logger.log("Saving visualization of paths")
                        for ind in range(min(5, self.meta_batch_size)):
                            plt.clf()
                            goal_vel = learner_env_goals[ind]
                            plt.title('Swimmer paths, goal vel='+str(goal_vel))
                            # plt.hold(True)  # removed in Matplotlib 3.x; hold is now always on

                            prepathobs = all_paths[0][ind][0]['observations']
                            postpathobs = all_paths[-1][ind][0]['observations']
                            plt.plot(prepathobs[:,0], prepathobs[:,1], '-r', linewidth=2)
                            plt.plot(postpathobs[:,0], postpathobs[:,1], '--b', linewidth=1)
                            plt.plot(prepathobs[-1,0], prepathobs[-1,1], 'r*', markersize=10)
                            plt.plot(postpathobs[-1,0], postpathobs[-1,1], 'b*', markersize=10)
                            plt.xlim([-1.0, 5.0])
                            plt.ylim([-1.0, 1.0])

                            plt.legend(['preupdate path', 'postupdate path'], loc=2)
                            plt.savefig(osp.join(logger.get_snapshot_dir(), 'swim1d_prepost_itr'+str(itr)+'_id'+str(ind)+'.pdf'))
        self.shutdown_worker()
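
Samples in the meta-learning loop above come back as a dict keyed by task; they are processed per task and also flattened into a single list for aggregate logging. A tiny sketch of that bookkeeping (the toy data is illustrative):

def flatten_list(list_of_lists):
    return [item for sublist in list_of_lists for item in sublist]

paths_by_task = {0: ["path_a", "path_b"], 1: ["path_c"]}
per_task_counts = {task: len(task_paths)
                   for task, task_paths in paths_by_task.items()}   # {0: 2, 1: 1}
all_paths = flatten_list(paths_by_task.values())   # ['path_a', 'path_b', 'path_c']
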
    def train(self):
        pool = SimpleReplayPool(
            max_pool_size=self.replay_pool_size,
            observation_dim=self.env.observation_space.flat_dim,
            action_dim=self.env.action_space.flat_dim,
        )
        self.start_worker()

        self.init_opt()
        itr = 0
        path_length = 0
        path_return = 0
        terminal = False
        observation = self.env.reset()

        sample_policy = pickle.loads(pickle.dumps(self.policy))
        #self.experiment_space = self.env.action_space
        
        for epoch in range(self.n_epochs):
            logger.push_prefix('epoch #%d | ' % epoch)
            logger.log("Training started")
            for epoch_itr in pyprind.prog_bar(range(self.epoch_length)):
                # Execute policy
                if terminal: 
                    observation = self.env.reset()
                    self.es.reset()
                    sample_policy.reset()
                    self.es_path_returns.append(path_return)
                    path_length = 0
                    path_return = 0
                action = self.es.get_action(itr, observation, policy=sample_policy)  # qf=qf)
                
                next_observation, reward, terminal, _ = self.env.step(action, observation)
                path_length += 1
                path_return += reward

                if not terminal and path_length >= self.max_path_length:
                    terminal = True
                    if self.include_horizon_terminal_transitions:
                        pool.add_sample(
                            self.env.observation_space.flatten(observation),
                            self.env.action_space.flatten(action),
                            reward * self.scale_reward,
                            terminal
                        )
                        
                else:
                    pool.add_sample(
                        self.env.observation_space.flatten(observation),
                        self.env.action_space.flatten(action),
                        reward * self.scale_reward,
                        terminal
                    )
                observation = next_observation

                if pool.size >= self.min_pool_size:
                    for update_itr in range(self.n_updates_per_sample):
                        # Train policy
                        batch = pool.random_batch(self.batch_size)
                        self.do_training(itr, batch)
                    sample_policy.set_param_values(self.policy.get_param_values())

                itr += 1
                self.pool = pool

            logger.log("Training finished")
            if pool.size >= self.min_pool_size:
                self.evaluate(epoch, pool)
                params = self.get_epoch_snapshot(epoch)
                logger.save_itr_params(epoch, params)
            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                self.update_plot()
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")
        self.env.terminate()
        self.policy.terminate()
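
The line `sample_policy = pickle.loads(pickle.dumps(self.policy))` above is a deep copy: the behavior policy collects data and is only periodically synced with the learned parameters, so exploration lags the training updates. A minimal sketch of the same pattern (TinyPolicy is illustrative, not the rllab policy class):

import copy

import numpy as np

class TinyPolicy:
    def __init__(self, params):
        self._params = np.asarray(params, dtype=float)
    def get_param_values(self):
        return self._params.copy()
    def set_param_values(self, params):
        self._params = np.asarray(params, dtype=float)

policy = TinyPolicy([0.0, 0.0])
sample_policy = copy.deepcopy(policy)        # plays the role of the pickle round-trip
policy.set_param_values([1.0, -1.0])         # training updates the learner...
assert not np.allclose(sample_policy.get_param_values(), policy.get_param_values())
sample_policy.set_param_values(policy.get_param_values())   # ...then sync
assert np.allclose(sample_policy.get_param_values(), policy.get_param_values())
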
    def _train(self, env, policy, pool):
        """Perform RL training.

        Args:
            env (`rllab.Env`): Environment used for training
            policy (`Policy`): Policy used for training
            pool (`PoolBase`): Sample pool to add samples to
        """

        self._init_training(env, policy, pool)

        with self._sess.as_default():
            observation = env.reset()
            policy.reset()

            path_length = 0
            path_return = 0
            last_path_return = 0
            max_path_return = -np.inf
            n_episodes = 0
            gt.rename_root('RLAlgorithm')
            gt.reset()
            gt.set_def_unique(False)

            for epoch in gt.timed_for(
                    range(self._n_epochs + 1), save_itrs=True):
                logger.push_prefix('Epoch #%d | ' % epoch)

                if self.iter_callback is not None:
                    self.iter_callback(locals(), globals())

                for t in range(self._epoch_length):
                    iteration = t + epoch * self._epoch_length

                    action, _ = policy.get_action(observation)
                    next_ob, reward, terminal, info = env.step(action)
                    path_length += 1
                    path_return += reward

                    self.pool.add_sample(
                        observation,
                        action,
                        reward,
                        terminal,
                        next_ob,
                    )

                    if terminal or path_length >= self._max_path_length:
                        observation = env.reset()
                        policy.reset()
                        path_length = 0
                        max_path_return = max(max_path_return, path_return)
                        last_path_return = path_return

                        path_return = 0
                        n_episodes += 1

                    else:
                        observation = next_ob
                    gt.stamp('sample')

                    if self.pool.size >= self._min_pool_size:
                        for i in range(self._n_train_repeat):
                            batch = self.pool.random_batch(self._batch_size)
                            self._do_training(iteration, batch)

                    gt.stamp('train')

                self._evaluate(epoch)

                params = self.get_snapshot(epoch)
                logger.save_itr_params(epoch, params)
                times_itrs = gt.get_times().stamps.itrs

                eval_time = times_itrs['eval'][-1] if epoch > 1 else 0
                total_time = gt.get_times().total
                logger.record_tabular('time-train', times_itrs['train'][-1])
                logger.record_tabular('time-eval', eval_time)
                logger.record_tabular('time-sample', times_itrs['sample'][-1])
                logger.record_tabular('time-total', total_time)
                logger.record_tabular('epoch', epoch)
                logger.record_tabular('episodes', n_episodes)
                logger.record_tabular('max-path-return', max_path_return)
                logger.record_tabular('last-path-return', last_path_return)
                logger.record_tabular('pool-size', self.pool.size)

                logger.dump_tabular(with_prefix=False)
                logger.pop_prefix()

                gt.stamp('eval')

            env.terminate()
Example #42
    def train(self):
        with tf.Session() as sess:
            if self.load_policy is not None:
                import joblib
                self.policy = joblib.load(self.load_policy)['policy']
            self.init_opt()
            # initialize uninitialized vars (I know, it's ugly)
            uninit_vars = []
            for var in tf.global_variables():
                try:
                    sess.run(var)
                except tf.errors.FailedPreconditionError:
                    uninit_vars.append(var)
            sess.run(tf.variables_initializer(uninit_vars))
            #sess.run(tf.initialize_all_variables())
            self.start_worker()
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):

                    logger.log("Obtaining samples...")
                    paths = self.obtain_samples(itr)
                    logger.log("Processing samples...")
                    samples_data = self.process_samples(itr, paths)
                    logger.log("Logging diagnostics...")
                    self.log_diagnostics(paths)
                    logger.log("Optimizing policy...")
                    self.optimize_policy(itr, samples_data)
                    #new_param_values = self.policy.get_variable_values(self.policy.all_params)

                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
                    if self.store_paths:
                        params["paths"] = samples_data["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime', time.time() - itr_start_time)

                    #import pickle
                    #with open('paths_itr'+str(itr)+'.pkl', 'wb') as f:
                    #    pickle.dump(paths, f)

                    # debugging
                    """
                    if itr % 1 == 0:
                        logger.log("Saving visualization of paths")
                        import matplotlib.pyplot as plt;
                        for ind in range(5):
                            plt.clf(); plt.hold(True)
                            points = paths[ind]['observations']
                            plt.plot(points[:,0], points[:,1], '-r', linewidth=2)
                            plt.xlim([-1.0, 1.0])
                            plt.ylim([-1.0, 1.0])
                            plt.legend(['path'])
                            plt.savefig('/home/cfinn/path'+str(ind)+'.png')
                    """
                    # end debugging

                    logger.dump_tabular(with_prefix=False)
                    if self.plot:
                        self.update_plot()
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                  "continue...")
        self.shutdown_worker()
Example #43
    def train(self):
        # This seems like a rather sequential method
        pool = SimpleReplayPool(
            max_pool_size=self.replay_pool_size,
            observation_dim=self.env.observation_space.flat_dim,
            action_dim=self.env.action_space.flat_dim,
        )
        self.start_worker()

        self.init_opt()
        itr = 0
        path_length = 0
        path_return = 0
        terminal = False
        observation = self.env.reset()

        sample_policy = pickle.loads(pickle.dumps(self.policy))

        for epoch in range(self.n_epochs):
            logger.push_prefix('epoch #%d | ' % epoch)
            logger.log("Training started")
            for epoch_itr in pyprind.prog_bar(range(self.epoch_length)):
                # Execute policy
                if terminal:  # or path_length > self.max_path_length:
                    # Note that if the last time step ends an episode, the very
                    # last state and observation will be ignored and not added
                    # to the replay pool
                    observation = self.env.reset()
                    self.es.reset()
                    sample_policy.reset()
                    self.es_path_returns.append(path_return)
                    path_length = 0
                    path_return = 0
                action = self.es.get_action(itr, observation, policy=sample_policy)  # qf=qf)

                next_observation, reward, terminal, _ = self.env.step(action)
                path_length += 1
                path_return += reward

                if not terminal and path_length >= self.max_path_length:
                    terminal = True
                    # only include the terminal transition in this case if the flag was set
                    if self.include_horizon_terminal_transitions:
                        pool.add_sample(observation, action, reward * self.scale_reward, terminal)
                else:
                    pool.add_sample(observation, action, reward * self.scale_reward, terminal)

                observation = next_observation

                if pool.size >= self.min_pool_size:
                    for update_itr in range(self.n_updates_per_sample):
                        # Train policy
                        batch = pool.random_batch(self.batch_size)
                        self.do_training(itr, batch)
                    sample_policy.set_param_values(self.policy.get_param_values())

                itr += 1

            logger.log("Training finished")
            if pool.size >= self.min_pool_size:
                self.evaluate(epoch, pool)
                params = self.get_epoch_snapshot(epoch)
                logger.save_itr_params(epoch, params)
            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                self.update_plot()
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                              "continue...")
        self.env.terminate()
        self.policy.terminate()
Example #44
    def train(self):
        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            plotter.init_plot(self.env, self.policy)

        cur_std = self.init_std
        cur_mean = self.policy.get_param_values()
        # K = cur_mean.size
        n_best = max(1, int(self.n_samples * self.best_frac))

        for itr in range(self.n_itr):
            # sample around the current distribution
            extra_var_mult = max(1.0 - itr / self.extra_decay_time, 0)
            sample_std = np.sqrt(np.square(cur_std) + np.square(self.extra_std) * extra_var_mult)
            if self.batch_size is None:
                criterion = 'paths'
                threshold = self.n_samples
            else:
                criterion = 'samples'
                threshold = self.batch_size
            infos = stateful_pool.singleton_pool.run_collect(
                _worker_rollout_policy,
                threshold=threshold,
                args=(dict(cur_mean=cur_mean,
                          sample_std=sample_std,
                          max_path_length=self.max_path_length,
                          discount=self.discount,
                          criterion=criterion),)
            )
            xs = np.asarray([info[0] for info in infos])
            paths = [info[1] for info in infos]

            fs = np.array([path['returns'][0] for path in paths])
            print((xs.shape, fs.shape))
            best_inds = (-fs).argsort()[:n_best]
            best_xs = xs[best_inds]
            cur_mean = best_xs.mean(axis=0)
            cur_std = best_xs.std(axis=0)
            best_x = best_xs[0]
            logger.push_prefix('itr #%d | ' % itr)
            logger.record_tabular('Iteration', itr)
            logger.record_tabular('CurStdMean', np.mean(cur_std))
            undiscounted_returns = np.array([path['undiscounted_return'] for path in paths])
            logger.record_tabular('AverageReturn',
                                  np.mean(undiscounted_returns))
            logger.record_tabular('StdReturn',
                                  np.std(undiscounted_returns))
            logger.record_tabular('MaxReturn',
                                  np.max(undiscounted_returns))
            logger.record_tabular('MinReturn',
                                  np.min(undiscounted_returns))
            logger.record_tabular('AverageDiscountedReturn',
                                  np.mean(fs))
            logger.record_tabular('AvgTrajLen',
                                  np.mean([len(path['returns']) for path in paths]))
            logger.record_tabular('NumTrajs',
                                  len(paths))
            self.policy.set_param_values(best_x)
            self.env.log_diagnostics(paths)
            self.policy.log_diagnostics(paths)
            logger.save_itr_params(itr, dict(
                itr=itr,
                policy=self.policy,
                env=self.env,
                cur_mean=cur_mean,
                cur_std=cur_std,
            ))
            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                plotter.update_plot(self.policy, self.max_path_length)
        parallel_sampler.terminate_task()
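
The loop above is the cross-entropy method: sample parameters around the current Gaussian, keep the best fraction, and refit the mean and std to those elites. A toy version on a quadratic objective (self-contained synthetic problem; only the elite-selection and refit steps mirror the snippet):

import numpy as np

def f(x):
    return -float(np.sum(np.square(x)))      # maximize => optimum at the origin

cur_mean, cur_std = np.ones(5) * 3.0, np.ones(5)
n_samples, best_frac = 64, 0.1
n_best = max(1, int(n_samples * best_frac))
for itr in range(50):
    xs = cur_mean + cur_std * np.random.randn(n_samples, cur_mean.size)
    fs = np.array([f(x) for x in xs])
    best_xs = xs[(-fs).argsort()[:n_best]]   # elite fraction, highest return first
    cur_mean = best_xs.mean(axis=0)          # refit the sampling distribution
    cur_std = best_xs.std(axis=0)
print(np.round(cur_mean, 3))                 # should be close to zero
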