Example #1
    def train(self, sess=None):
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        sess.run(tf.global_variables_initializer())

        self.start_worker(sess)
        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                params = self.optimize_policy(itr)
                if self.plot:
                    self.plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
                logger.log("Saving snapshot...")
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('IterTime', time.time() - itr_start_time)
                logger.record_tabular('Time', time.time() - start_time)
                logger.dump_tabular()
        self.shutdown_worker()
        if created_session:
            sess.close()
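
Every example on this page wraps its per-iteration work in with logger.prefix('itr #%d | ' % itr). As a rough mental model only (not the actual rllab/garage implementation), a prefix context manager that tags every log line emitted inside the block could be sketched as:

    from contextlib import contextmanager

    _prefixes = []

    @contextmanager
    def prefix(key):
        # Tag every log line emitted inside the block, e.g. "itr #3 | ...".
        _prefixes.append(key)
        try:
            yield
        finally:
            _prefixes.pop()

    def log(msg):
        print(''.join(_prefixes) + msg)

Nesting two prefix blocks simply concatenates the tags, which is how Example #13 can combine a 'Skill N | ' prefix with the inner algorithm's own 'itr #N | ' prefix.
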
Example #2
    def train(self):
        plotter = Plotter()
        if self.plot:
            plotter.init_plot(self.env, self.policy)
        self.start_worker()
        self.init_opt()
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                paths = self.sampler.obtain_samples(itr)
                samples_data = self.sampler.process_samples(itr, paths)
                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params["algo"] = self
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

        plotter.close()
        self.shutdown_worker()
Example #3
 def optimize_policy(self, itr, samples_data):
     self._top_algo.optimize_policy(itr, samples_data)
     with logger.prefix('ASA | '):
         make_skill, skill_subpath = self.decide_new_skill(samples_data)
         if make_skill:
             new_skill_pol, new_skill_id = self.create_and_train_new_skill(skill_subpath)
             self.integrate_new_skill(new_skill_id, skill_subpath)
Example #4
 def train_once(self, itr, paths):
     itr_start_time = time.time()
     with logger.prefix('itr #%d | ' % itr):
         self.log_diagnostics(paths)
         logger.log("Optimizing policy...")
         self.optimize_policy(itr, paths)
         logger.record_tabular('IterTime', time.time() - itr_start_time)
         logger.dump_tabular()
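
Example #4 uses only the tabular part of the logger. A minimal sketch of the contract these calls assume (again, not the library's implementation): record_tabular buffers key/value pairs for the current iteration, and dump_tabular writes them out as one row and clears the buffer.

    _tabular = {}

    def record_tabular(key, value):
        _tabular[key] = value

    def dump_tabular(with_prefix=True):
        # One row per iteration; with_prefix is accepted for signature
        # parity with the examples but ignored in this sketch.
        print(' | '.join('{}: {}'.format(k, v) for k, v in _tabular.items()))
        _tabular.clear()
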
Example #5
    def train(self, sess=None):
        address = ("localhost", 6000)
        conn = Client(address)
        last_average_return = None
        try:
            created_session = True if (sess is None) else False
            if sess is None:
                sess = tf.Session()
                sess.__enter__()

            sess.run(tf.global_variables_initializer())
            conn.send(ExpLifecycle.START)
            self.start_worker(sess)
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):
                    logger.log("Obtaining samples...")
                    conn.send(ExpLifecycle.OBTAIN_SAMPLES)
                    paths = self.obtain_samples(itr)
                    logger.log("Processing samples...")
                    conn.send(ExpLifecycle.PROCESS_SAMPLES)
                    samples_data = self.process_samples(itr, paths)
                    last_average_return = samples_data["average_return"]
                    logger.log("Logging diagnostics...")
                    self.log_diagnostics(paths)
                    logger.log("Optimizing policy...")
                    conn.send(ExpLifecycle.OPTIMIZE_POLICY)
                    self.optimize_policy(itr, samples_data)
                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(itr, samples_data)
                    if self.store_paths:
                        params["paths"] = samples_data["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime',
                                          time.time() - itr_start_time)
                    logger.dump_tabular(with_prefix=False)
                    if self.plot:
                        conn.send(ExpLifecycle.UPDATE_PLOT)
                        self.plotter.update_plot(self.policy,
                                                 self.max_path_length)
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                  "continue...")

            conn.send(ExpLifecycle.SHUTDOWN)
            self.shutdown_worker()
            if created_session:
                sess.close()
        finally:
            conn.close()
        return last_average_return
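
This variant reports its lifecycle over a multiprocessing.connection.Client socket; ExpLifecycle is an enum of phase markers defined elsewhere in that codebase. The receiving end is not shown, so the monitor below is purely hypothetical (its name, address handling, and printing are assumptions):

    from multiprocessing.connection import Listener

    def monitor(address=("localhost", 6000)):
        # Accept one connection from the training process and print each
        # lifecycle message until the sender closes the socket.
        with Listener(address) as listener:
            with listener.accept() as conn:
                while True:
                    try:
                        event = conn.recv()
                    except EOFError:
                        break
                    print("training phase:", event)
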
Example #6
    def _training_step(self, itr):
        itr_start_time = time.time()

        with logger.prefix('itr #%d | ' % itr):
            self._sampling()

            self._bookkeeping()

            self._memory_selection(itr)
            self._policy_optimization(itr)

            if itr % self.evaluation_interval == 0:
                self._policy_evaluation()

            self._log_diagnostics(itr)

            logger.record_tabular('Time', time.time() - self.start_time)
            logger.record_tabular('ItrTime', time.time() - itr_start_time)
            logger.dump_tabular(with_prefix=False)
Example #7
    def train(self, sess=None):
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        sess.run(tf.global_variables_initializer())
        self.start_worker(sess)
        start_time = time.time()
        last_average_return = None
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)
                logger.log("Processing samples...")
                samples_data = self.process_samples(itr, paths)
                last_average_return = samples_data["average_return"]
                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

        self.shutdown_worker()
        if created_session:
            sess.close()
        return last_average_return
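
Thanks to the created_session check, this train() only closes a session it opened itself. A hypothetical call site that reuses an externally managed session (algo stands for an already constructed instance of this class):

    import tensorflow as tf

    with tf.Session() as sess:
        # The algorithm runs inside the caller's session and does not
        # close it; it only closes sessions it created on its own.
        last_return = algo.train(sess=sess)
        print("final average return:", last_return)
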
Example #8
 def train(self):
     with tf.Session() as sess:
         sess.run(tf.initialize_all_variables())
         self.start_worker(sess)
         start_time = time.time()
         self.num_samples = 0
         for itr in range(self.start_itr, self.n_itr):
             itr_start_time = time.time()
             with logger.prefix('itr #%d | ' % itr):
                 logger.log("Obtaining new samples...")
                 paths = self.obtain_samples(itr)
                 for path in paths:
                     self.num_samples += len(path["rewards"])
                 logger.log("total num samples..." + str(self.num_samples))
                 logger.log("Processing samples...")
                 samples_data = self.process_samples(itr, paths)
                 logger.log("Logging diagnostics...")
                 self.log_diagnostics(paths)
                 logger.log("Optimizing policy...")
                 self.outer_optimize(samples_data)
                 for sub_itr in range(self.n_sub_itr):
                     logger.log("Minibatch Optimizing...")
                     self.inner_optimize(samples_data)
                 logger.log("Saving snapshot...")
                 params = self.get_itr_snapshot(itr,
                                                samples_data)  # , **kwargs)
                 if self.store_paths:
                     params["paths"] = samples_data["paths"]
                 logger.save_itr_params(itr, params)
                 logger.log("Saved")
                 logger.record_tabular('Time', time.time() - start_time)
                 logger.record_tabular('ItrTime',
                                       time.time() - itr_start_time)
                 logger.dump_tabular(with_prefix=False)
                 #if self.plot:
                 #   self.update_plot()
                 #  if self.pause_for_plot:
                 #     input("Plotting evaluation run: Press Enter to "
                 #          "continue...")
     self.shutdown_worker()
Example #9
    def train(self):
        address = ("localhost", 6000)
        conn = Client(address)
        try:
            plotter = Plotter()
            if self.plot:
                plotter.init_plot(self.env, self.policy)
            conn.send(ExpLifecycle.START)
            self.start_worker()
            self.init_opt()
            for itr in range(self.current_itr, self.n_itr):
                with logger.prefix('itr #%d | ' % itr):
                    conn.send(ExpLifecycle.OBTAIN_SAMPLES)
                    paths = self.sampler.obtain_samples(itr)
                    conn.send(ExpLifecycle.PROCESS_SAMPLES)
                    samples_data = self.sampler.process_samples(itr, paths)
                    self.log_diagnostics(paths)
                    conn.send(ExpLifecycle.OPTIMIZE_POLICY)
                    self.optimize_policy(itr, samples_data)
                    logger.log("saving snapshot...")
                    params = self.get_itr_snapshot(itr, samples_data)
                    self.current_itr = itr + 1
                    params["algo"] = self
                    if self.store_paths:
                        params["paths"] = samples_data["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("saved")
                    logger.dump_tabular(with_prefix=False)
                    if self.plot:
                        conn.send(ExpLifecycle.UPDATE_PLOT)
                        plotter.update_plot(self.policy, self.max_path_length)
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                  "continue...")

            conn.send(ExpLifecycle.SHUTDOWN)
            plotter.close()
            self.shutdown_worker()
        finally:
            conn.close()
Example #10
    def train(self, sess=None):
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        sess.run(tf.global_variables_initializer())
        self.start_worker(sess)

        if self.use_target:
            self.f_init_target()

        episode_rewards = []
        episode_policy_losses = []
        episode_qf_losses = []
        epoch_ys = []
        epoch_qs = []
        last_average_return = None

        for epoch in range(self.n_epochs):
            self.success_history.clear()
            with logger.prefix('epoch #%d | ' % epoch):
                for epoch_cycle in range(self.n_epoch_cycles):
                    paths = self.obtain_samples(epoch)
                    samples_data = self.process_samples(epoch, paths)
                    episode_rewards.extend(
                        samples_data["undiscounted_returns"])
                    self.success_history.extend(
                        samples_data["success_history"])
                    self.log_diagnostics(paths)
                    for train_itr in range(self.n_train_steps):
                        if self.replay_buffer.n_transitions_stored >= self.min_buffer_size:  # noqa: E501
                            self.evaluate = True
                            qf_loss, y, q, policy_loss = self.optimize_policy(
                                epoch, samples_data)

                            episode_policy_losses.append(policy_loss)
                            episode_qf_losses.append(qf_loss)
                            epoch_ys.append(y)
                            epoch_qs.append(q)

                    if self.plot:
                        self.plotter.update_plot(self.policy,
                                                 self.max_path_length)
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                  "continue...")

                logger.log("Training finished")
                logger.log("Saving snapshot #{}".format(epoch))
                params = self.get_itr_snapshot(epoch, samples_data)
                logger.save_itr_params(epoch, params)
                logger.log("Saved")
                if self.evaluate:
                    logger.record_tabular('Epoch', epoch)
                    logger.record_tabular('AverageReturn',
                                          np.mean(episode_rewards))
                    logger.record_tabular('StdReturn', np.std(episode_rewards))
                    logger.record_tabular('Policy/AveragePolicyLoss',
                                          np.mean(episode_policy_losses))
                    logger.record_tabular('QFunction/AverageQFunctionLoss',
                                          np.mean(episode_qf_losses))
                    logger.record_tabular('QFunction/AverageQ',
                                          np.mean(epoch_qs))
                    logger.record_tabular('QFunction/MaxQ', np.max(epoch_qs))
                    logger.record_tabular('QFunction/AverageAbsQ',
                                          np.mean(np.abs(epoch_qs)))
                    logger.record_tabular('QFunction/AverageY',
                                          np.mean(epoch_ys))
                    logger.record_tabular('QFunction/MaxY', np.max(epoch_ys))
                    logger.record_tabular('QFunction/AverageAbsY',
                                          np.mean(np.abs(epoch_ys)))
                    if self.input_include_goal:
                        logger.record_tabular('AverageSuccessRate',
                                              np.mean(self.success_history))
                    last_average_return = np.mean(episode_rewards)

                if not self.smooth_return:
                    episode_rewards = []
                    episode_policy_losses = []
                    episode_qf_losses = []
                    epoch_ys = []
                    epoch_qs = []

                logger.dump_tabular(with_prefix=False)

        self.shutdown_worker()
        if created_session:
            sess.close()
        return last_average_return
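
The use_target flag and self.f_init_target suggest a target network, as in DDPG-style learners. Purely as an illustration of what such a function typically builds (a guess, not code from this source), hard-copy-on-init and Polyak-averaged updates can be expressed as TF1 ops:

    import tensorflow as tf

    def make_target_ops(source_vars, target_vars, tau=0.005):
        # init: copy the online weights into the target network;
        # update: let the target slowly track the online network.
        init_ops = [t.assign(s) for s, t in zip(source_vars, target_vars)]
        update_ops = [t.assign(tau * s + (1.0 - tau) * t)
                      for s, t in zip(source_vars, target_vars)]
        return tf.group(*init_ops), tf.group(*update_ops)
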
Example #11
    def train(self, sess=None):
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()
            sess.run(tf.global_variables_initializer())

        # Initialize some missing variables
        uninitialized_vars = []
        for var in tf.all_variables():
            try:
                sess.run(var)
            except tf.errors.FailedPreconditionError:
                print("Uninitialized var: ", var)
                uninitialized_vars.append(var)
        init_new_vars_op = tf.initialize_variables(uninitialized_vars)
        sess.run(init_new_vars_op)

        self.start_worker(sess)
        start_time = time.time()
        last_average_return = None
        samples_total = 0
        for itr in range(self.start_itr, self.n_itr):
            if samples_total >= self.max_samples:
                print("WARNING: Total max num of samples collected: %d >= %d" %
                      (samples_total, self.max_samples))
                break
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)
                samples_total += self.batch_size
                logger.log("Processing samples...")
                samples_data = self.process_samples(itr, paths)
                last_average_return = samples_data["average_return"]
                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                # import pdb; pdb.set_trace()
                if self.store_paths:
                    ## WARN: Beware that data is saved to hdf in float32 by default
                    # see param float_nptype
                    h5u.append_train_iter_data(h5file=self.hdf,
                                               data=samples_data["paths"],
                                               data_group="traj_data/",
                                               teacher_indx=self.teacher_indx,
                                               itr=None,
                                               float_nptype=np.float32)
                    # params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                self.log_env_info(samples_data["env_infos"])
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input(
                            "Plotting evaluation run: Press Enter to continue..."
                        )
                # Showing policy from time to time
                if self.record_every_itr is not None and self.record_every_itr > 0 and itr % self.record_every_itr == 0:
                    self.record_policy(env=self.env,
                                       policy=self.policy,
                                       itr=itr)
                if self.play_every_itr is not None and self.play_every_itr > 0 and itr % self.play_every_itr == 0:
                    self.play_policy(env=self.env, policy=self.policy)

        # Recording a few episodes at the end
        if self.record_end_ep_num is not None:
            for i in range(self.record_end_ep_num):
                self.record_policy(env=self.env,
                                   policy=self.policy,
                                   itr=itr,
                                   postfix="_%02d" % i)

        # Reporting termination criteria
        if itr >= self.n_itr - 1:
            print(
                "TERM CRITERIA: Max number of iterations reached itr: %d , itr_max: %d"
                % (itr, self.n_itr - 1))
        if samples_total >= self.max_samples:
            print(
                "TERM CRITERIA: Total max num of samples collected: %d >= %d" %
                (samples_total, self.max_samples))

        self.shutdown_worker()
        if created_session:
            sess.close()
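
The try/except loop near the top of this example probes every variable to find the uninitialized ones. A more direct TF1 alternative (a sketch assuming graph-mode tf.compat.v1 semantics) queries the initialization status in a single pass:

    import tensorflow as tf

    def initialize_uninitialized(sess):
        # Check which global variables are initialized and run the
        # initializer only for those that are not.
        global_vars = tf.global_variables()
        flags = sess.run([tf.is_variable_initialized(v) for v in global_vars])
        todo = [v for v, ok in zip(global_vars, flags) if not ok]
        if todo:
            sess.run(tf.variables_initializer(todo))
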
Example #12
def run_task(*_):
    # Configure TF session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config).as_default() as tf_session:
        ## Load data from itr_N.pkl
        with open(snapshot_file, 'rb') as file:
            saved_data = dill.load(file)


        ## Construct PathTrie and find missing skill description
        # This is basically ASA.decide_new_skill
        min_length = 3
        max_length = 5
        action_map = {0: 's', 1: 'L', 2: 'R'}
        min_f_score = 1
        max_results = 10
        aggregations = []  # sublist of ['mean', 'most_freq', 'nearest_mean', 'medoid'] or 'all'

        paths = saved_data['paths']
        path_trie = PathTrie(saved_data['hrl_policy'].num_skills)
        for path in paths:
            actions = path['actions'].argmax(axis=1).tolist()
            observations = path['observations']
            path_trie.add_all_subpaths(
                    actions,
                    observations,
                    min_length=min_length,
                    max_length=max_length
            )
        logger.log('Searched {} rollouts'.format(len(paths)))

        frequent_paths = path_trie.items(
                action_map=action_map,
                min_count=10,  # len(paths) * 2
                min_f_score=min_f_score,
                max_results=max_results,
                aggregations=aggregations
        )
        logger.log('Found {} frequent paths: [index, actions, count, f-score]'.format(len(frequent_paths)))
        for i, f_path in enumerate(frequent_paths):
            logger.log('    {:2}: {:{pad}}\t{}\t{:.3f}'.format(
                i,
                f_path['actions_text'],
                f_path['count'],
                f_path['f_score'],
                pad=max_length))

        top_subpath = frequent_paths[0]
        start_obss = top_subpath['start_observations']
        end_obss   = top_subpath['end_observations']



        ## Prepare elements for training
        # Environment
        base_env = saved_data['env'].env.env  # <NormalizedEnv<MinibotEnv instance>>
        skill_learning_env = TfEnv(
                SkillLearningEnv(
                    # base env that was wrapped in HierarchizedEnv (not fully unwrapped - may be normalized!)
                    env=base_env,
                    start_obss=start_obss,
                    end_obss=end_obss
                )
        )

        # Skill policy
        hrl_policy = saved_data['hrl_policy']
        new_skill_policy, new_skill_id = hrl_policy.create_new_skill(
                end_obss=end_obss
        )

        # Baseline - clone baseline specified in low_algo_kwargs, or top-algo's baseline
        low_algo_kwargs = dict(saved_data['low_algo_kwargs'])
        baseline_to_clone = low_algo_kwargs.get('baseline', saved_data['baseline'])
        baseline = Serializable.clone(  # to create blank baseline
                obj=baseline_to_clone,
                name='{}Skill{}'.format(type(baseline_to_clone).__name__, new_skill_id)
        )
        low_algo_kwargs['baseline'] = baseline
        low_algo_cls = saved_data['low_algo_cls']

        # Set custom training params (should've been set in asa_basic_run)
        low_algo_kwargs['batch_size'] = 2500
        low_algo_kwargs['max_path_length'] = 50
        low_algo_kwargs['n_itr'] = 500

        # Algorithm
        algo = low_algo_cls(
            env=skill_learning_env,
            policy=new_skill_policy,
            **low_algo_kwargs
        )

        # Logger parameters
        logger_snapshot_dir_before = logger.get_snapshot_dir()
        logger_snapshot_mode_before = logger.get_snapshot_mode()
        logger_snapshot_gap_before = logger.get_snapshot_gap()
        # No need to change the snapshot dir in this script; that is done in ASA-algo.create_and_train_new_skill()
        # logger.set_snapshot_dir(os.path.join(
        #         logger_snapshot_dir_before,
        #         'skill{}'.format(new_skill_id)
        # ))
        logger.set_snapshot_mode('none')
        logger.set_tensorboard_step_key('Iteration')


        ## Train new skill
        with logger.prefix('Skill {} | '.format(new_skill_id)):
            algo.train(sess=tf_session)



        ## Save new policy and its end_obss (we'll construct skill stopping function
        #  from them manually in asa_resume_with_new_skill.py)
        out_file = os.path.join(logger.get_snapshot_dir(), 'final.pkl')
        with open(out_file, 'wb') as file:
            out_data = {
                    'policy': new_skill_policy,
                    'subpath': top_subpath
            }
            dill.dump(out_data, file)

        # Restore logger parameters
        logger.set_snapshot_dir(logger_snapshot_dir_before)
        logger.set_snapshot_mode(logger_snapshot_mode_before)
        logger.set_snapshot_gap(logger_snapshot_gap_before)
Example #13
    def create_and_train_new_skill(self, skill_subpath):
        """
        Create and train a new skill based on given subpath. The new skill policy and
        ID are returned, and also saved in self._hrl_policy.
        """
        ## Prepare elements for training
        # Environment
        skill_learning_env = TfEnv(
                SkillLearningEnv(
                    # base env that was wrapped in HierarchizedEnv (not fully unwrapped - may be normalized!)
                    env=self.env.env.env,
                    start_obss=skill_subpath['start_observations'],
                    end_obss=skill_subpath['end_observations']
                )
        )

        # Skill policy
        new_skill_pol, new_skill_id = self._hrl_policy.create_new_skill(skill_subpath['end_observations'])  # blank policy to be trained

        # Baseline - clone the baseline specified in low_algo_kwargs, or the top-algo's baseline
        #   We need to clone the baseline, as each skill policy must have its own instance
        la_kwargs = dict(self._low_algo_kwargs)
        baseline_to_clone = la_kwargs.get('baseline', self.baseline)
        baseline = Serializable.clone(  # to create blank baseline
                obj=baseline_to_clone,
                name='{}Skill{}'.format(type(baseline_to_clone).__name__, new_skill_id)
        )
        la_kwargs['baseline'] = baseline

        # Algorithm
        algo = self._low_algo_cls(
                env=skill_learning_env,
                policy=new_skill_pol,
                **la_kwargs
        )

        # Logger parameters
        logger.dump_tabular(with_prefix=False)
        logger.log('Launching training of the new skill')
        logger_snapshot_dir_before = logger.get_snapshot_dir()
        logger_snapshot_mode_before = logger.get_snapshot_mode()
        logger_snapshot_gap_before = logger.get_snapshot_gap()
        logger.set_snapshot_dir(os.path.join(
                logger_snapshot_dir_before,
                'skill{}'.format(new_skill_id)
        ))
        logger.set_snapshot_mode('none')
        # logger.set_snapshot_gap(max(1, np.floor(la_kwargs['n_itr'] / 10)))
        logger.push_tabular_prefix('Skill{}/'.format(new_skill_id))
        logger.set_tensorboard_step_key('Iteration')

        # Train new skill
        with logger.prefix('Skill {} | '.format(new_skill_id)):
            algo.train(sess=self._tf_sess)

        # Restore logger parameters
        logger.pop_tabular_prefix()
        logger.set_snapshot_dir(logger_snapshot_dir_before)
        logger.set_snapshot_mode(logger_snapshot_mode_before)
        logger.set_snapshot_gap(logger_snapshot_gap_before)
        logger.log('Training of the new skill finished')

        return new_skill_pol, new_skill_id
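
Examples #12 and #13 both save and restore the logger's snapshot settings by hand around the inner training run. A small hypothetical helper (assuming the same logger module used above) could wrap that save/restore dance so nested runs cannot leak logger state:

    from contextlib import contextmanager

    @contextmanager
    def temporary_snapshot_config(snapshot_dir, snapshot_mode='none'):
        # Remember the current snapshot settings, apply the temporary ones,
        # and restore the originals even if the inner training raises.
        old_dir = logger.get_snapshot_dir()
        old_mode = logger.get_snapshot_mode()
        old_gap = logger.get_snapshot_gap()
        logger.set_snapshot_dir(snapshot_dir)
        logger.set_snapshot_mode(snapshot_mode)
        try:
            yield
        finally:
            logger.set_snapshot_dir(old_dir)
            logger.set_snapshot_mode(old_mode)
            logger.set_snapshot_gap(old_gap)
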