Example #1
    def train(self):
        plotter = Plotter()
        if self.plot:
            plotter.init_plot(self.env, self.policy)
        self.start_worker()
        self.init_opt()
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                paths = self.sampler.obtain_samples(itr)
                samples_data = self.sampler.process_samples(itr, paths)
                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params["algo"] = self
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

        plotter.close()
        self.shutdown_worker()
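Example #1 above and Examples #2, #6, and #8 below share the same on-policy skeleton: obtain samples, process them, log diagnostics, optimize the policy, save a snapshot, and log timing. The sketch below abstracts that loop; the four callables are hypothetical placeholders, not rllab/garage methods, so treat it as a simplified outline rather than the library implementation.

    # Hedged sketch: the generic per-iteration loop shared by the train()
    # examples on this page. The callables are hypothetical placeholders
    # supplied by the caller; this is not the rllab/garage implementation.
    import time

    def run_training_loop(n_itr, obtain_samples, process_samples,
                          optimize_policy, save_snapshot):
        start_time = time.time()
        for itr in range(n_itr):
            itr_start_time = time.time()
            print('itr #%d | Obtaining samples...' % itr)
            paths = obtain_samples(itr)
            print('itr #%d | Processing samples...' % itr)
            samples_data = process_samples(itr, paths)
            print('itr #%d | Optimizing policy...' % itr)
            optimize_policy(itr, samples_data)
            print('itr #%d | Saving snapshot...' % itr)
            save_snapshot(itr, samples_data)
            print('itr #%d | ItrTime %.2fs, Time %.2fs' %
                  (itr, time.time() - itr_start_time, time.time() - start_time))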
Example #2
    def train(self, sess=None):
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        sess.run(tf.global_variables_initializer())

        self.start_worker(sess)
        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                params = self.optimize_policy(itr, )
                if self.plot:
                    self.plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")
                logger.log("Saving snapshot...")
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('IterTime', time.time() - itr_start_time)
                logger.record_tabular('Time', time.time() - start_time)
                logger.dump_tabular()
        self.shutdown_worker()
        if created_session:
            sess.close()
Example #3
 def train_once(self, itr, paths):
     itr_start_time = time.time()
     with logger.prefix('itr #%d | ' % itr):
         self.log_diagnostics(paths)
         logger.log("Optimizing policy...")
         self.optimize_policy(itr, paths)
         logger.record_tabular('IterTime', time.time() - itr_start_time)
         logger.dump_tabular()
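Most examples on this page wrap their per-iteration work in with logger.prefix('itr #%d | ' % itr): (others use logger.push_prefix/pop_prefix). As a rough mental model only, an assumption rather than the rllab/garage logger source, such a prefix can be implemented as a context manager that pushes a string on entry and pops it on exit:

    # Hedged sketch of a prefix-style logger; an illustrative stand-in, not
    # the rllab/garage logger implementation.
    import contextlib

    class PrefixLogger:
        def __init__(self):
            self._prefixes = []
            self._tabular = {}

        @contextlib.contextmanager
        def prefix(self, text):
            self._prefixes.append(text)   # push on enter
            try:
                yield
            finally:
                self._prefixes.pop()      # pop on exit, even on error

        def log(self, msg):
            print(''.join(self._prefixes) + msg)

        def record_tabular(self, key, value):
            self._tabular[key] = value

        def dump_tabular(self, with_prefix=True):
            head = ''.join(self._prefixes) if with_prefix else ''
            for key, value in self._tabular.items():
                print('%s%s: %s' % (head, key, value))
            self._tabular.clear()

    logger = PrefixLogger()
    with logger.prefix('itr #0 | '):
        logger.log('Optimizing policy...')
        logger.record_tabular('IterTime', 0.12)
        logger.dump_tabular(with_prefix=False)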
Example #4
    def _train(self, env, policy, pool):
        """Perform RL training.

        Args:
            env (`rllab.Env`): Environment used for training
            policy (`Policy`): Policy used for training
            pool (`PoolBase`): Sample pool to add samples to
        """
        self._init_training()
        self.sampler.initialize(env, policy, pool)

        evaluation_env = deep_clone(env) if self._eval_n_episodes else None

        with tf_utils.get_default_session().as_default():
            gt.rename_root('RLAlgorithm')
            gt.reset()
            gt.set_def_unique(False)

            for epoch in gt.timed_for(range(self._n_epochs + 1),
                                      save_itrs=True):
                logger.push_prefix('Epoch #%d | ' % epoch)

                for t in range(self._epoch_length):
                    self.sampler.sample()
                    if not self.sampler.batch_ready():
                        continue
                    gt.stamp('sample')

                    for i in range(self._n_train_repeat):
                        self._do_training(iteration=t +
                                          epoch * self._epoch_length,
                                          batch=self.sampler.random_batch())
                    gt.stamp('train')

                self._evaluate(policy, evaluation_env)
                gt.stamp('eval')

                params = self.get_snapshot(epoch)
                logger.save_itr_params(epoch, params)

                time_itrs = gt.get_times().stamps.itrs
                time_eval = time_itrs['eval'][-1]
                time_total = gt.get_times().total
                time_train = time_itrs.get('train', [0])[-1]
                time_sample = time_itrs.get('sample', [0])[-1]

                logger.record_tabular('time-train', time_train)
                logger.record_tabular('time-eval', time_eval)
                logger.record_tabular('time-sample', time_sample)
                logger.record_tabular('time-total', time_total)
                logger.record_tabular('epoch', epoch)

                self.sampler.log_diagnostics()

                logger.dump_tabular(with_prefix=False)
                logger.pop_prefix()
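Example #4 above (and its near-duplicate, Example #5 below) time each phase with gtimer (imported as gt). A minimal standalone version of that pattern, using only the gtimer calls that appear above, is sketched here; the structure of gt.get_times() is assumed from the example rather than from gtimer's documentation.

    # Hedged sketch of the gtimer usage pattern from Examples #4/#5. Only
    # calls that appear in the example above are used.
    import time

    import gtimer as gt

    gt.rename_root('RLAlgorithm')
    gt.reset()
    gt.set_def_unique(False)

    for epoch in gt.timed_for(range(3), save_itrs=True):
        time.sleep(0.01)   # stand-in for sampling work
        gt.stamp('sample')
        time.sleep(0.02)   # stand-in for training work
        gt.stamp('train')

    time_itrs = gt.get_times().stamps.itrs
    print('last sample time:', time_itrs['sample'][-1])
    print('last train time:', time_itrs['train'][-1])
    print('total:', gt.get_times().total)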
Example #5
    def _train(self, env, policy, pool):
        """Perform RL training.

        Args:
            env (`rllab.Env`): Environment used for training
            policy (`Policy`): Policy used for training
            pool (`PoolBase`): Sample pool to add samples to
        """
        self._init_training()
        self.sampler.initialize(env, policy, pool)

        evaluation_env = deep_clone(env) if self._eval_n_episodes else None

        with tf_utils.get_default_session().as_default():
            gt.rename_root('RLAlgorithm')
            gt.reset()
            gt.set_def_unique(False)

            for epoch in gt.timed_for(
                    range(self._n_epochs + 1), save_itrs=True):
                logger.push_prefix('Epoch #%d | ' % epoch)

                for t in range(self._epoch_length):
                    self.sampler.sample()
                    if not self.sampler.batch_ready():
                        continue
                    gt.stamp('sample')

                    for i in range(self._n_train_repeat):
                        self._do_training(
                            iteration=t + epoch * self._epoch_length,
                            batch=self.sampler.random_batch())
                    gt.stamp('train')

                self._evaluate(policy, evaluation_env)
                gt.stamp('eval')

                params = self.get_snapshot(epoch)
                logger.save_itr_params(epoch, params)

                time_itrs = gt.get_times().stamps.itrs
                time_eval = time_itrs['eval'][-1]
                time_total = gt.get_times().total
                time_train = time_itrs.get('train', [0])[-1]
                time_sample = time_itrs.get('sample', [0])[-1]

                logger.record_tabular('time-train', time_train)
                logger.record_tabular('time-eval', time_eval)
                logger.record_tabular('time-sample', time_sample)
                logger.record_tabular('time-total', time_total)
                logger.record_tabular('epoch', epoch)

                self.sampler.log_diagnostics()

                logger.dump_tabular(with_prefix=False)
                logger.pop_prefix()
Example #6
    def train(self, sess=None):
        address = ("localhost", 6000)
        conn = Client(address)
        last_average_return = None
        try:
            created_session = True if (sess is None) else False
            if sess is None:
                sess = tf.Session()
                sess.__enter__()

            sess.run(tf.global_variables_initializer())
            conn.send(ExpLifecycle.START)
            self.start_worker(sess)
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):
                    logger.log("Obtaining samples...")
                    conn.send(ExpLifecycle.OBTAIN_SAMPLES)
                    paths = self.obtain_samples(itr)
                    logger.log("Processing samples...")
                    conn.send(ExpLifecycle.PROCESS_SAMPLES)
                    samples_data = self.process_samples(itr, paths)
                    last_average_return = samples_data["average_return"]
                    logger.log("Logging diagnostics...")
                    self.log_diagnostics(paths)
                    logger.log("Optimizing policy...")
                    conn.send(ExpLifecycle.OPTIMIZE_POLICY)
                    self.optimize_policy(itr, samples_data)
                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(itr, samples_data)
                    if self.store_paths:
                        params["paths"] = samples_data["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime',
                                          time.time() - itr_start_time)
                    logger.dump_tabular(with_prefix=False)
                    if self.plot:
                        conn.send(ExpLifecycle.UPDATE_PLOT)
                        self.plotter.update_plot(self.policy,
                                                 self.max_path_length)
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                  "continue...")

            conn.send(ExpLifecycle.SHUTDOWN)
            self.shutdown_worker()
            if created_session:
                sess.close()
        finally:
            conn.close()
        return last_average_return
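Example #6 reports its lifecycle over a multiprocessing.connection.Client connected to ("localhost", 6000). The receiving side is not shown; a plausible monitoring process could look like the sketch below, treating the ExpLifecycle messages as opaque objects. The Listener API is standard library; everything else here is an assumption about how such a monitor might behave.

    # Hedged sketch of a receiver for the lifecycle messages sent in Example #6.
    from multiprocessing.connection import Listener

    address = ("localhost", 6000)       # same address as in the example
    with Listener(address) as listener:
        with listener.accept() as conn:
            try:
                while True:
                    msg = conn.recv()   # blocks until the trainer sends an event
                    print("lifecycle event:", msg)
            except EOFError:
                print("trainer closed the connection")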
Example #7
    def _training_step(self, itr):
        itr_start_time = time.time()

        with logger.prefix('itr #%d | ' % itr):
            self._sampling()

            self._bookkeeping()

            self._memory_selection(itr)
            self._policy_optimization(itr)

            if itr % self.evaluation_interval == 0:
                self._policy_evaluation()

            self._log_diagnostics(itr)

            logger.record_tabular('Time', time.time() - self.start_time)
            logger.record_tabular('ItrTime', time.time() - itr_start_time)
            logger.dump_tabular(with_prefix=False)
Example #8
    def train(self, sess=None):
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        sess.run(tf.global_variables_initializer())
        self.start_worker(sess)
        start_time = time.time()
        last_average_return = None
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)
                logger.log("Processing samples...")
                samples_data = self.process_samples(itr, paths)
                last_average_return = samples_data["average_return"]
                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

        self.shutdown_worker()
        if created_session:
            sess.close()
        return last_average_return
Example #9
File: catrpo.py Project: Mee321/HAPG_exp
 def train(self):
     with tf.Session() as sess:
         sess.run(tf.initialize_all_variables())
         self.start_worker(sess)
         start_time = time.time()
         self.num_samples = 0
         for itr in range(self.start_itr, self.n_itr):
             itr_start_time = time.time()
             with logger.prefix('itr #%d | ' % itr):
                 logger.log("Obtaining new samples...")
                 paths = self.obtain_samples(itr)
                 for path in paths:
                     self.num_samples += len(path["rewards"])
                 logger.log("total num samples..." + str(self.num_samples))
                 logger.log("Processing samples...")
                 samples_data = self.process_samples(itr, paths)
                 logger.log("Logging diagnostics...")
                 self.log_diagnostics(paths)
                 logger.log("Optimizing policy...")
                 self.outer_optimize(samples_data)
                 for sub_itr in range(self.n_sub_itr):
                     logger.log("Minibatch Optimizing...")
                     self.inner_optimize(samples_data)
                 logger.log("Saving snapshot...")
                 params = self.get_itr_snapshot(itr,
                                                samples_data)  # , **kwargs)
                 if self.store_paths:
                     params["paths"] = samples_data["paths"]
                 logger.save_itr_params(itr, params)
                 logger.log("Saved")
                 logger.record_tabular('Time', time.time() - start_time)
                 logger.record_tabular('ItrTime',
                                       time.time() - itr_start_time)
                 logger.dump_tabular(with_prefix=False)
                 #if self.plot:
                 #   self.update_plot()
                 #  if self.pause_for_plot:
                 #     input("Plotting evaluation run: Press Enter to "
                 #          "continue...")
     self.shutdown_worker()
Example #10
    def train(self):
        address = ("localhost", 6000)
        conn = Client(address)
        try:
            plotter = Plotter()
            if self.plot:
                plotter.init_plot(self.env, self.policy)
            conn.send(ExpLifecycle.START)
            self.start_worker()
            self.init_opt()
            for itr in range(self.current_itr, self.n_itr):
                with logger.prefix('itr #%d | ' % itr):
                    conn.send(ExpLifecycle.OBTAIN_SAMPLES)
                    paths = self.sampler.obtain_samples(itr)
                    conn.send(ExpLifecycle.PROCESS_SAMPLES)
                    samples_data = self.sampler.process_samples(itr, paths)
                    self.log_diagnostics(paths)
                    conn.send(ExpLifecycle.OPTIMIZE_POLICY)
                    self.optimize_policy(itr, samples_data)
                    logger.log("saving snapshot...")
                    params = self.get_itr_snapshot(itr, samples_data)
                    self.current_itr = itr + 1
                    params["algo"] = self
                    if self.store_paths:
                        params["paths"] = samples_data["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("saved")
                    logger.dump_tabular(with_prefix=False)
                    if self.plot:
                        conn.send(ExpLifecycle.UPDATE_PLOT)
                        plotter.update_plot(self.policy, self.max_path_length)
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                  "continue...")

            conn.send(ExpLifecycle.SHUTDOWN)
            plotter.close()
            self.shutdown_worker()
        finally:
            conn.close()
Example #11
    def train(self, sess=None):
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        sess.run(tf.global_variables_initializer())

        cur_std = self.sigma0
        cur_mean = self.policy.get_param_values()
        es = cma.CMAEvolutionStrategy(cur_mean, cur_std)

        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            self.plotter.init_plot(self.env, self.policy)

        cur_std = self.sigma0
        cur_mean = self.policy.get_param_values()

        itr = 0
        while itr < self.n_itr and not es.stop():

            if self.batch_size is None:
                # Sample from multivariate normal distribution.
                xs = es.ask()
                xs = np.asarray(xs)
                # For each sample, do a rollout.
                infos = (stateful_pool.singleton_pool.run_map(
                    sample_return,
                    [(x, self.max_path_length, self.discount) for x in xs]))
            else:
                cum_len = 0
                infos = []
                xss = []
                done = False
                while not done:
                    sbs = stateful_pool.singleton_pool.n_parallel * 2
                    # Sample from multivariate normal distribution.
                    # You want to ask for sbs samples here.
                    xs = es.ask(sbs)
                    xs = np.asarray(xs)

                    xss.append(xs)
                    sinfos = stateful_pool.singleton_pool.run_map(
                        sample_return,
                        [(x, self.max_path_length, self.discount) for x in xs])
                    for info in sinfos:
                        infos.append(info)
                        cum_len += len(info['returns'])
                        if cum_len >= self.batch_size:
                            xs = np.concatenate(xss)
                            done = True
                            break

            # Evaluate fitness of samples (negative as it is minimization
            # problem).
            fs = -np.array([info['returns'][0] for info in infos])
            # When batching, you could have generated too many samples compared
            # to the actual evaluations. So we cut it off in this case.
            xs = xs[:len(fs)]
            # Update CMA-ES params based on sample fitness.
            es.tell(xs, fs)

            logger.push_prefix('itr #%d | ' % itr)
            logger.record_tabular('Iteration', itr)
            logger.record_tabular('CurStdMean', np.mean(cur_std))
            undiscounted_returns = np.array(
                [info['undiscounted_return'] for info in infos])
            logger.record_tabular('AverageReturn',
                                  np.mean(undiscounted_returns))
            logger.record_tabular('StdReturn', np.std(undiscounted_returns))
            logger.record_tabular('MaxReturn', np.max(undiscounted_returns))
            logger.record_tabular('MinReturn', np.min(undiscounted_returns))
            logger.record_tabular('AverageDiscountedReturn', np.mean(fs))
            logger.record_tabular(
                'AvgTrajLen',
                np.mean([len(info['returns']) for info in infos]))
            self.policy.log_diagnostics(infos)
            logger.save_itr_params(
                itr, dict(
                    itr=itr,
                    policy=self.policy,
                    env=self.env,
                ))
            logger.dump_tabular(with_prefix=False)
            if self.plot:
                self.plotter.update_plot(self.policy, self.max_path_length)
            logger.pop_prefix()
            # Update iteration.
            itr += 1

            # Showing policy from time to time
            if self.play_every_itr is not None and self.play_every_itr > 0 and itr % self.play_every_itr == 0:
                self.play_policy(env=self.env, policy=self.policy, n_rollout=self.play_rollouts_num)

        # Set final params.
        self.policy.set_param_values(es.result()[0])
        parallel_sampler.terminate_task()
        self.plotter.close()
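Example #11 drives CMA-ES through the cma package's ask/tell interface (es.ask(), es.tell(xs, fs), es.stop()). A minimal standalone loop on a toy objective is sketched below; note that newer cma releases expose the best point as es.result[0], while older ones use es.result()[0] as in the example above.

    # Hedged sketch of the ask/tell CMA-ES loop from Example #11, applied to a
    # toy sphere objective instead of policy rollouts.
    import numpy as np
    import cma

    x0 = np.zeros(5)    # initial mean (the policy parameters in the example)
    sigma0 = 0.5        # initial step size (self.sigma0 in the example)
    es = cma.CMAEvolutionStrategy(x0, sigma0)

    while not es.stop():
        xs = es.ask()                                     # sample candidates
        fs = [float(np.sum(np.square(x))) for x in xs]    # lower is better
        es.tell(xs, fs)                                   # update distribution

    # es.result[0] in newer cma versions; es.result()[0] in older ones.
    print('best solution found:', es.result[0])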
Example #12
File: ddpg.py Project: gntoni/garage
    def train(self, sess=None):
        """
        Training process of DDPG algorithm.

        Args:
            sess: A TensorFlow session for executing ops.
        """
        replay_buffer = self.opt_info["replay_buffer"]
        f_init_target = self.opt_info["f_init_target"]
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        # Start plotter
        if self.plot:
            self.plotter = Plotter(self.env, self.actor, sess)
            self.plotter.start()

        sess.run(tf.global_variables_initializer())
        f_init_target()

        observation = self.env.reset()
        if self.es:
            self.es.reset()

        episode_reward = 0.
        episode_step = 0
        episode_rewards = []
        episode_steps = []
        episode_actor_losses = []
        episode_critic_losses = []
        episodes = 0
        epoch_ys = []
        epoch_qs = []

        for epoch in range(self.n_epochs):
            logger.push_prefix('epoch #%d | ' % epoch)
            logger.log("Training started")
            for epoch_cycle in pyprind.prog_bar(range(self.n_epoch_cycles)):
                for rollout in range(self.n_rollout_steps):
                    action = self.es.get_action(rollout, observation,
                                                self.actor)
                    assert action.shape == self.env.action_space.shape

                    next_observation, reward, terminal, info = self.env.step(
                        action)
                    episode_reward += reward
                    episode_step += 1

                    replay_buffer.add_transition(observation, action,
                                                 reward * self.reward_scale,
                                                 terminal, next_observation)

                    observation = next_observation

                    if terminal:
                        episode_rewards.append(episode_reward)
                        episode_steps.append(episode_step)
                        episode_reward = 0.
                        episode_step = 0
                        episodes += 1

                        observation = self.env.reset()
                        if self.es:
                            self.es.reset()

                for train_itr in range(self.n_train_steps):
                    if replay_buffer.size >= self.min_buffer_size:
                        critic_loss, y, q, action_loss = self._learn()

                        episode_actor_losses.append(action_loss)
                        episode_critic_losses.append(critic_loss)
                        epoch_ys.append(y)
                        epoch_qs.append(q)

            logger.log("Training finished")
            if replay_buffer.size >= self.min_buffer_size:
                logger.record_tabular('Epoch', epoch)
                logger.record_tabular('Episodes', episodes)
                logger.record_tabular('AverageReturn',
                                      np.mean(episode_rewards))
                logger.record_tabular('StdReturn', np.std(episode_rewards))
                logger.record_tabular('Policy/AveragePolicyLoss',
                                      np.mean(episode_actor_losses))
                logger.record_tabular('QFunction/AverageQFunctionLoss',
                                      np.mean(episode_critic_losses))
                logger.record_tabular('QFunction/AverageQ', np.mean(epoch_qs))
                logger.record_tabular('QFunction/MaxQ', np.max(epoch_qs))
                logger.record_tabular('QFunction/AverageAbsQ',
                                      np.mean(np.abs(epoch_qs)))
                logger.record_tabular('QFunction/AverageY', np.mean(epoch_ys))
                logger.record_tabular('QFunction/MaxY', np.max(epoch_ys))
                logger.record_tabular('QFunction/AverageAbsY',
                                      np.mean(np.abs(epoch_ys)))

                # Uncomment the following if you want to calculate the average
                # in each epoch
                # episode_rewards = []
                # episode_actor_losses = []
                # episode_critic_losses = []
                # epoch_ys = []
                # epoch_qs = []

            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                self.plotter.update_plot(self.actor, self.n_rollout_steps)
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")

        if self.plot:
            self.plotter.shutdown()
        if created_session:
            sess.close()
Example #13
File: ddpg.py Project: ScapeQin/garage
    def train(self):
        # This seems like a rather sequential method
        input_shapes = dims_to_shapes(self.input_dims)
        pool = ReplayBuffer(
            buffer_shapes=input_shapes,
            max_buffer_size=self.replay_pool_size,
        )
        self.start_worker()

        self.init_opt()
        itr = 0
        path_length = 0
        path_return = 0
        terminal = False
        observation = self.env.reset()

        sample_policy = pickle.loads(pickle.dumps(self.policy))

        for epoch in range(self.n_epochs):
            logger.push_prefix('epoch #%d | ' % epoch)
            logger.log("Training started")
            for epoch_itr in pyprind.prog_bar(range(self.epoch_length)):
                # Execute policy
                if terminal:  # or path_length > self.max_path_length:
                    # Note that if the last time step ends an episode, the very
                    # last state and observation will be ignored and not added
                    # to the replay pool
                    observation = self.env.reset()
                    self.es.reset()
                    sample_policy.reset()
                    self.es_path_returns.append(path_return)
                    path_length = 0
                    path_return = 0
                action = self.es.get_action(
                    itr, observation, policy=sample_policy)

                next_observation, reward, terminal, _ = self.env.step(action)
                path_length += 1
                path_return += reward

                if not terminal and path_length >= self.max_path_length:
                    terminal = True
                    # only include the terminal transition in this case if the
                    # flag was set
                    if self.include_horizon_terminal_transitions:
                        pool.add_transition(
                            observation=observation,
                            action=action,
                            reward=reward * self.scale_reward,
                            terminal=terminal,
                            next_observation=next_observation)
                else:
                    pool.add_transition(
                        observation=observation,
                        action=action,
                        reward=reward * self.scale_reward,
                        terminal=terminal,
                        next_observation=next_observation)

                observation = next_observation

                if pool.size >= self.min_pool_size:
                    for update_itr in range(self.n_updates_per_sample):
                        # Train policy
                        batch = pool.sample(self.batch_size)
                        self.do_training(itr, batch)
                    sample_policy.set_param_values(
                        self.policy.get_param_values())

                itr += 1

            logger.log("Training finished")
            if pool.size >= self.min_pool_size:
                self.evaluate(epoch, pool)
                params = self.get_epoch_snapshot(epoch)
                logger.save_itr_params(epoch, params)
            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                self.update_plot()
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")
        self.env.close()
        self.policy.terminate()
        self.plotter.close()
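Examples #12 through #15 all push transitions into a replay buffer and later draw random minibatches from it (add_transition plus sample or random_batch). The sketch below is a deliberately simple stand-in with the interface used in Example #13, not the garage ReplayBuffer implementation.

    # Hedged sketch of a minimal FIFO replay buffer with the
    # add_transition/sample interface used in Example #13.
    import random
    from collections import deque

    class SimpleReplayBuffer:
        def __init__(self, max_buffer_size):
            # Oldest transitions drop out first once the buffer is full.
            self._storage = deque(maxlen=max_buffer_size)

        @property
        def size(self):
            return len(self._storage)

        def add_transition(self, observation, action, reward, terminal,
                           next_observation):
            self._storage.append(dict(
                observation=observation,
                action=action,
                reward=reward,
                terminal=terminal,
                next_observation=next_observation,
            ))

        def sample(self, batch_size):
            return random.sample(self._storage, batch_size)

    pool = SimpleReplayBuffer(max_buffer_size=1000)
    pool.add_transition(observation=[0.0], action=[1.0], reward=0.5,
                        terminal=False, next_observation=[0.1])
    if pool.size >= 1:
        batch = pool.sample(1)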
Example #14
    def train(self, sess=None):
        """
        Training process of DDPG algorithm.

        Args:
            sess: A TensorFlow session for executing ops.
        """
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        # Start plotter
        if self.plot:
            self.plotter = Plotter(self.env, self.actor, sess)
            self.plotter.start()

        sess.run(tf.global_variables_initializer())
        self.f_init_target()

        observation = self.env.reset()
        if self.es:
            self.es.reset()

        episode_reward = 0.
        episode_step = 0
        episode_rewards = []
        episode_steps = []
        episode_actor_losses = []
        episode_critic_losses = []
        episodes = 0
        epoch_ys = []
        epoch_qs = []

        for epoch in range(self.n_epochs):
            logger.push_prefix('epoch #%d | ' % epoch)
            logger.log("Training started")
            self.success_history.clear()
            for epoch_cycle in pyprind.prog_bar(range(self.n_epoch_cycles)):
                if self.use_her:
                    successes = []
                    for rollout in range(self.n_rollout_steps):
                        o = np.clip(observation["observation"], -self.clip_obs,
                                    self.clip_obs)
                        g = np.clip(observation["desired_goal"],
                                    -self.clip_obs, self.clip_obs)
                        obs_goal = np.concatenate((o, g), axis=-1)
                        action = self.es.get_action(rollout, obs_goal,
                                                    self.actor)

                        next_observation, reward, terminal, info = self.env.step(  # noqa: E501
                            action)
                        if 'is_success' in info:
                            successes.append([info["is_success"]])
                        episode_reward += reward
                        episode_step += 1

                        info_dict = {
                            "info_{}".format(key): info[key].reshape(1)
                            for key in info.keys()
                        }
                        self.replay_buffer.add_transition(
                            observation=observation['observation'],
                            action=action,
                            goal=observation['desired_goal'],
                            achieved_goal=observation['achieved_goal'],
                            **info_dict,
                        )

                        observation = next_observation

                        if rollout == self.n_rollout_steps - 1:
                            self.replay_buffer.add_transition(
                                observation=observation['observation'],
                                achieved_goal=observation['achieved_goal'])

                            episode_rewards.append(episode_reward)
                            episode_steps.append(episode_step)
                            episode_reward = 0.
                            episode_step = 0
                            episodes += 1

                            observation = self.env.reset()
                            if self.es:
                                self.es.reset()

                    successful = np.array(successes)[-1, :]
                    success_rate = np.mean(successful)
                    self.success_history.append(success_rate)

                    for train_itr in range(self.n_train_steps):
                        self.evaluate = True
                        critic_loss, y, q, action_loss = self._learn()

                        episode_actor_losses.append(action_loss)
                        episode_critic_losses.append(critic_loss)
                        epoch_ys.append(y)
                        epoch_qs.append(q)

                    self.f_update_target()
                else:
                    for rollout in range(self.n_rollout_steps):
                        action = self.es.get_action(rollout, observation,
                                                    self.actor)
                        assert action.shape == self.env.action_space.shape

                        next_observation, reward, terminal, info = self.env.step(  # noqa: E501
                            action)
                        episode_reward += reward
                        episode_step += 1

                        self.replay_buffer.add_transition(
                            observation=observation,
                            action=action,
                            reward=reward * self.reward_scale,
                            terminal=terminal,
                            next_observation=next_observation,
                        )

                        observation = next_observation

                        if terminal or rollout == self.n_rollout_steps - 1:
                            episode_rewards.append(episode_reward)
                            episode_steps.append(episode_step)
                            episode_reward = 0.
                            episode_step = 0
                            episodes += 1

                            observation = self.env.reset()
                            if self.es:
                                self.es.reset()

                    for train_itr in range(self.n_train_steps):
                        if self.replay_buffer.size >= self.min_buffer_size:
                            self.evaluate = True
                            critic_loss, y, q, action_loss = self._learn()

                            episode_actor_losses.append(action_loss)
                            episode_critic_losses.append(critic_loss)
                            epoch_ys.append(y)
                            epoch_qs.append(q)

            logger.log("Training finished")
            logger.log("Saving snapshot")
            itr = epoch * self.n_epoch_cycles + epoch_cycle
            params = self.get_itr_snapshot(itr)
            logger.save_itr_params(itr, params)
            logger.log("Saved")
            if self.evaluate:
                logger.record_tabular('Epoch', epoch)
                logger.record_tabular('Episodes', episodes)
                logger.record_tabular('AverageReturn',
                                      np.mean(episode_rewards))
                logger.record_tabular('StdReturn', np.std(episode_rewards))
                logger.record_tabular('Policy/AveragePolicyLoss',
                                      np.mean(episode_actor_losses))
                logger.record_tabular('QFunction/AverageQFunctionLoss',
                                      np.mean(episode_critic_losses))
                logger.record_tabular('QFunction/AverageQ', np.mean(epoch_qs))
                logger.record_tabular('QFunction/MaxQ', np.max(epoch_qs))
                logger.record_tabular('QFunction/AverageAbsQ',
                                      np.mean(np.abs(epoch_qs)))
                logger.record_tabular('QFunction/AverageY', np.mean(epoch_ys))
                logger.record_tabular('QFunction/MaxY', np.max(epoch_ys))
                logger.record_tabular('QFunction/AverageAbsY',
                                      np.mean(np.abs(epoch_ys)))
                if self.use_her:
                    logger.record_tabular('AverageSuccessRate',
                                          np.mean(self.success_history))

                # Uncomment the following if you want to calculate the average
                # in each epoch, better uncomment when self.use_her is True
                # episode_rewards = []
                # episode_actor_losses = []
                # episode_critic_losses = []
                # epoch_ys = []
                # epoch_qs = []

            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                self.plotter.update_plot(self.actor, self.n_rollout_steps)
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")

        if self.plot:
            self.plotter.close()
        if created_session:
            sess.close()
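In the HER branch of Example #14 the dict observation follows the Gym GoalEnv convention ('observation', 'desired_goal', 'achieved_goal') and is clipped and concatenated before being fed to the actor. That preprocessing step in isolation, with clip_obs=200.0 as a hypothetical value not taken from the example:

    # Hedged sketch of the observation/goal preprocessing from Example #14's
    # HER branch. clip_obs here is illustrative, not the example's setting.
    import numpy as np

    def obs_goal_input(observation, clip_obs=200.0):
        """Build the actor input from a GoalEnv-style dict observation."""
        o = np.clip(observation["observation"], -clip_obs, clip_obs)
        g = np.clip(observation["desired_goal"], -clip_obs, clip_obs)
        return np.concatenate((o, g), axis=-1)

    example_obs = {
        "observation": np.array([0.1, 0.2, 0.3]),
        "desired_goal": np.array([1.0, 1.0]),
        "achieved_goal": np.array([0.1, 0.2]),
    }
    print(obs_goal_input(example_obs))  # shape (5,): observation then goal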
Example #15
File: ddpg.py Project: wwxFromTju/garage
    def train(self, sess=None):
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        sess.run(tf.global_variables_initializer())
        self.start_worker(sess)

        if self.use_target:
            self.f_init_target()

        episode_rewards = []
        episode_policy_losses = []
        episode_qf_losses = []
        epoch_ys = []
        epoch_qs = []
        last_average_return = None

        for epoch in range(self.n_epochs):
            self.success_history.clear()
            with logger.prefix('epoch #%d | ' % epoch):
                for epoch_cycle in range(self.n_epoch_cycles):
                    paths = self.obtain_samples(epoch)
                    samples_data = self.process_samples(epoch, paths)
                    episode_rewards.extend(
                        samples_data["undiscounted_returns"])
                    self.success_history.extend(
                        samples_data["success_history"])
                    self.log_diagnostics(paths)
                    for train_itr in range(self.n_train_steps):
                        if self.replay_buffer.n_transitions_stored >= self.min_buffer_size:  # noqa: E501
                            self.evaluate = True
                            qf_loss, y, q, policy_loss = self.optimize_policy(
                                epoch, samples_data)

                            episode_policy_losses.append(policy_loss)
                            episode_qf_losses.append(qf_loss)
                            epoch_ys.append(y)
                            epoch_qs.append(q)

                    if self.plot:
                        self.plotter.update_plot(self.policy,
                                                 self.max_path_length)
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                  "continue...")

                logger.log("Training finished")
                logger.log("Saving snapshot #{}".format(epoch))
                params = self.get_itr_snapshot(epoch, samples_data)
                logger.save_itr_params(epoch, params)
                logger.log("Saved")
                if self.evaluate:
                    logger.record_tabular('Epoch', epoch)
                    logger.record_tabular('AverageReturn',
                                          np.mean(episode_rewards))
                    logger.record_tabular('StdReturn', np.std(episode_rewards))
                    logger.record_tabular('Policy/AveragePolicyLoss',
                                          np.mean(episode_policy_losses))
                    logger.record_tabular('QFunction/AverageQFunctionLoss',
                                          np.mean(episode_qf_losses))
                    logger.record_tabular('QFunction/AverageQ',
                                          np.mean(epoch_qs))
                    logger.record_tabular('QFunction/MaxQ', np.max(epoch_qs))
                    logger.record_tabular('QFunction/AverageAbsQ',
                                          np.mean(np.abs(epoch_qs)))
                    logger.record_tabular('QFunction/AverageY',
                                          np.mean(epoch_ys))
                    logger.record_tabular('QFunction/MaxY', np.max(epoch_ys))
                    logger.record_tabular('QFunction/AverageAbsY',
                                          np.mean(np.abs(epoch_ys)))
                    if self.input_include_goal:
                        logger.record_tabular('AverageSuccessRate',
                                              np.mean(self.success_history))
                    last_average_return = np.mean(episode_rewards)

                if not self.smooth_return:
                    episode_rewards = []
                    episode_policy_losses = []
                    episode_qf_losses = []
                    epoch_ys = []
                    epoch_qs = []

                logger.dump_tabular(with_prefix=False)

        self.shutdown_worker()
        if created_session:
            sess.close()
        return last_average_return
Example #16
    def train(self, sess=None):
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        sess.run(tf.global_variables_initializer())

        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            self.plotter.init_plot(self.env, self.policy)

        cur_std = self.init_std
        cur_mean = self.policy.get_param_values()
        # K = cur_mean.size
        n_best = max(1, int(self.n_samples * self.best_frac))

        for itr in range(self.n_itr):
            # sample around the current distribution
            extra_var_mult = max(1.0 - itr / self.extra_decay_time, 0)
            sample_std = np.sqrt(
                np.square(cur_std) +
                np.square(self.extra_std) * extra_var_mult)
            if self.batch_size is None:
                criterion = 'paths'
                threshold = self.n_samples
            else:
                criterion = 'samples'
                threshold = self.batch_size
            infos = stateful_pool.singleton_pool.run_collect(
                _worker_rollout_policy,
                threshold=threshold,
                args=(dict(
                    cur_mean=cur_mean,
                    sample_std=sample_std,
                    max_path_length=self.max_path_length,
                    discount=self.discount,
                    criterion=criterion,
                    n_evals=self.n_evals), ))
            xs = np.asarray([info[0] for info in infos])
            paths = [info[1] for info in infos]

            fs = np.array([path['returns'][0] for path in paths])
            print((xs.shape, fs.shape))
            best_inds = (-fs).argsort()[:n_best]
            best_xs = xs[best_inds]
            cur_mean = best_xs.mean(axis=0)
            cur_std = best_xs.std(axis=0)
            best_x = best_xs[0]
            logger.push_prefix('itr #%d | ' % itr)
            logger.record_tabular('Iteration', itr)
            logger.record_tabular('CurStdMean', np.mean(cur_std))
            undiscounted_returns = np.array(
                [path['undiscounted_return'] for path in paths])
            logger.record_tabular('AverageReturn',
                                  np.mean(undiscounted_returns))
            logger.record_tabular('StdReturn', np.std(undiscounted_returns))
            logger.record_tabular('MaxReturn', np.max(undiscounted_returns))
            logger.record_tabular('MinReturn', np.min(undiscounted_returns))
            logger.record_tabular('AverageDiscountedReturn', np.mean(fs))
            logger.record_tabular('NumTrajs', len(paths))
            paths = list(chain(
                *[d['full_paths']
                  for d in paths]))  # flatten paths for the case n_evals > 1
            logger.record_tabular(
                'AvgTrajLen',
                np.mean([len(path['returns']) for path in paths]))

            self.policy.set_param_values(best_x)
            self.policy.log_diagnostics(paths)
            logger.save_itr_params(
                itr,
                dict(
                    itr=itr,
                    policy=self.policy,
                    env=self.env,
                    cur_mean=cur_mean,
                    cur_std=cur_std,
                ))
            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                self.plotter.update_plot(self.policy, self.max_path_length)
            
            # Showing policy from time to time
            if self.play_every_itr is not None and self.play_every_itr > 0 and itr % self.play_every_itr == 0:
                self.play_policy(env=self.env, policy=self.policy, n_rollout=self.play_rollouts_num)

        parallel_sampler.terminate_task()
        self.plotter.close()
        if created_session:
            sess.close()
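Example #16 is a cross-entropy-method style search: sample parameters around (cur_mean, cur_std), keep the best best_frac fraction, and refit the mean and std to those elites. A self-contained toy version of that update, with made-up hyperparameters and a quadratic objective standing in for policy returns:

    # Hedged sketch of the cross-entropy method update used in Example #16,
    # on a toy objective. Hyperparameters here are illustrative.
    import numpy as np

    def toy_return(x):
        # Higher is better; the maximum is at x = [1, 1, 1].
        return -float(np.sum(np.square(x - 1.0)))

    n_itr, n_samples, best_frac = 20, 64, 0.1
    cur_mean = np.zeros(3)
    cur_std = np.ones(3)
    n_best = max(1, int(n_samples * best_frac))

    for itr in range(n_itr):
        # Sample around the current distribution.
        xs = cur_mean + cur_std * np.random.randn(n_samples, cur_mean.size)
        fs = np.array([toy_return(x) for x in xs])
        # Keep the elite fraction and refit mean/std, as in Example #16.
        best_inds = (-fs).argsort()[:n_best]
        best_xs = xs[best_inds]
        cur_mean = best_xs.mean(axis=0)
        cur_std = best_xs.std(axis=0)

    print('final mean:', cur_mean)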
Example #17
    def train(self, sess=None):
        created_session = True if (sess is None) else False
        if sess is None:
            sess = tf.Session()
            sess.__enter__()
            sess.run(tf.global_variables_initializer())

        # Initialize some missing variables
        uninitialized_vars = []
        for var in tf.all_variables():
            try:
                sess.run(var)
            except tf.errors.FailedPreconditionError:
                print("Uninitialized var: ", var)
                uninitialized_vars.append(var)
        init_new_vars_op = tf.initialize_variables(uninitialized_vars)
        sess.run(init_new_vars_op)

        self.start_worker(sess)
        start_time = time.time()
        last_average_return = None
        samples_total = 0
        for itr in range(self.start_itr, self.n_itr):
            if samples_total >= self.max_samples:
                print("WARNING: Total max num of samples collected: %d >= %d" %
                      (samples_total, self.max_samples))
                break
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)
                samples_total += self.batch_size
                logger.log("Processing samples...")
                samples_data = self.process_samples(itr, paths)
                last_average_return = samples_data["average_return"]
                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                # import pdb; pdb.set_trace()
                if self.store_paths:
                    ## WARN: Beware that data is saved to hdf in float32 by default
                    # see param float_nptype
                    h5u.append_train_iter_data(h5file=self.hdf,
                                               data=samples_data["paths"],
                                               data_group="traj_data/",
                                               teacher_indx=self.teacher_indx,
                                               itr=None,
                                               float_nptype=np.float32)
                    # params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                self.log_env_info(samples_data["env_infos"])
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input(
                            "Plotting evaluation run: Press Enter to continue..."
                        )
                # Showing policy from time to time
                if self.record_every_itr is not None and self.record_every_itr > 0 and itr % self.record_every_itr == 0:
                    self.record_policy(env=self.env,
                                       policy=self.policy,
                                       itr=itr)
                if self.play_every_itr is not None and self.play_every_itr > 0 and itr % self.play_every_itr == 0:
                    self.play_policy(env=self.env, policy=self.policy)

        # Recording a few episodes at the end
        if self.record_end_ep_num is not None:
            for i in range(self.record_end_ep_num):
                self.record_policy(env=self.env,
                                   policy=self.policy,
                                   itr=itr,
                                   postfix="_%02d" % i)

        # Reporting termination criteria
        if itr >= self.n_itr - 1:
            print(
                "TERM CRITERIA: Max number of iterations reached itr: %d , itr_max: %d"
                % (itr, self.n_itr - 1))
        if samples_total >= self.max_samples:
            print(
                "TERM CRITERIA: Total max num of samples collected: %d >= %d" %
                (samples_total, self.max_samples))

        self.shutdown_worker()
        if created_session:
            sess.close()
    def create_and_train_new_skill(self, skill_subpath):
        """
        Create and train a new skill based on given subpath. The new skill policy and
        ID are returned, and also saved in self._hrl_policy.
        """
        ## Prepare elements for training
        # Environment
        skill_learning_env = TfEnv(
                SkillLearningEnv(
                    # base env that was wrapped in HierarchizedEnv (not fully unwrapped - may be normalized!)
                    env=self.env.env.env,
                    start_obss=skill_subpath['start_observations'],
                    end_obss=skill_subpath['end_observations']
                )
        )

        # Skill policy
        new_skill_pol, new_skill_id = self._hrl_policy.create_new_skill(skill_subpath['end_observations'])  # blank policy to be trained

        # Baseline - clone baseline specified in low_algo_kwargs, or top-algo`s baseline
        #   We need to clone baseline, as each skill policy must have its own instance
        la_kwargs = dict(self._low_algo_kwargs)
        baseline_to_clone = la_kwargs.get('baseline', self.baseline)
        baseline = Serializable.clone(  # to create blank baseline
                obj=baseline_to_clone,
                name='{}Skill{}'.format(type(baseline_to_clone).__name__, new_skill_id)
        )
        la_kwargs['baseline'] = baseline

        # Algorithm
        algo = self._low_algo_cls(
                env=skill_learning_env,
                policy=new_skill_pol,
                **la_kwargs
        )

        # Logger parameters
        logger.dump_tabular(with_prefix=False)
        logger.log('Launching training of the new skill')
        logger_snapshot_dir_before = logger.get_snapshot_dir()
        logger_snapshot_mode_before = logger.get_snapshot_mode()
        logger_snapshot_gap_before = logger.get_snapshot_gap()
        logger.set_snapshot_dir(os.path.join(
                logger_snapshot_dir_before,
                'skill{}'.format(new_skill_id)
        ))
        logger.set_snapshot_mode('none')
        # logger.set_snapshot_gap(max(1, np.floor(la_kwargs['n_itr'] / 10)))
        logger.push_tabular_prefix('Skill{}/'.format(new_skill_id))
        logger.set_tensorboard_step_key('Iteration')

        # Train new skill
        with logger.prefix('Skill {} | '.format(new_skill_id)):
            algo.train(sess=self._tf_sess)

        # Restore logger parameters
        logger.pop_tabular_prefix()
        logger.set_snapshot_dir(logger_snapshot_dir_before)
        logger.set_snapshot_mode(logger_snapshot_mode_before)
        logger.set_snapshot_gap(logger_snapshot_gap_before)
        logger.log('Training of the new skill finished')

        return new_skill_pol, new_skill_id
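The second method in Example #17 manually saves and restores the logger's snapshot settings and tabular prefix around the inner algo.train(...) call. Using only the logger calls that appear in that example, the bookkeeping could be wrapped in a context manager; this is a suggested refactor, not an existing garage utility.

    # Hedged sketch: wrapping the snapshot save/restore from Example #17 in a
    # context manager. Only logger calls shown in the example are used.
    import contextlib
    import os

    @contextlib.contextmanager
    def skill_training_logger(logger, skill_id):
        snapshot_dir = logger.get_snapshot_dir()
        snapshot_mode = logger.get_snapshot_mode()
        snapshot_gap = logger.get_snapshot_gap()
        logger.set_snapshot_dir(os.path.join(snapshot_dir,
                                             'skill{}'.format(skill_id)))
        logger.set_snapshot_mode('none')
        logger.push_tabular_prefix('Skill{}/'.format(skill_id))
        try:
            yield
        finally:
            # Restore everything even if training raises.
            logger.pop_tabular_prefix()
            logger.set_snapshot_dir(snapshot_dir)
            logger.set_snapshot_mode(snapshot_mode)
            logger.set_snapshot_gap(snapshot_gap)

    # Usage (hypothetical, mirroring the example):
    # with skill_training_logger(logger, new_skill_id):
    #     with logger.prefix('Skill {} | '.format(new_skill_id)):
    #         algo.train(sess=self._tf_sess)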