Example #1
def _init(self):
    self.global_step = 0
    self.kl_coeff = self.config["kl_coeff"]
    self.model = Runner(self.env_name, 1, self.config, self.logdir, False)
    self.agents = [
        RemoteRunner.remote(self.env_name, 1, self.config, self.logdir,
                            True)
        for _ in range(self.config["num_workers"])
    ]
    self.start_time = time.time()
    if self.config["write_logs"]:
        self.file_writer = tf.summary.FileWriter(self.logdir,
                                                 self.model.sess.graph)
    else:
        self.file_writer = None
    self.saver = tf.train.Saver(max_to_keep=None)
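
This `_init` (and the fuller one in Example #2 below) sets up the standard Ray actor fan-out: the driver keeps one local model plus a pool of remote workers, and each training iteration broadcasts the latest weights through the object store. A minimal, self-contained sketch of that pattern, assuming a current ray installation; ToyWorker and its methods are illustrative stand-ins, not part of this project:

import ray
import numpy as np

ray.init()

@ray.remote
class ToyWorker:
    # Hypothetical stand-in for RemoteRunner: holds a weight vector.
    def __init__(self):
        self.weights = None

    def load_weights(self, weights):
        self.weights = weights

    def rollout(self):
        # A real worker would run episodes here; we just reduce the weights.
        return float(np.sum(self.weights))

workers = [ToyWorker.remote() for _ in range(4)]
weights_ref = ray.put(np.ones(10))  # one copy in the shared object store
ray.get([w.load_weights.remote(weights_ref) for w in workers])
print(ray.get([w.rollout.remote() for w in workers]))  # [10.0, 10.0, 10.0, 10.0]
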
Example #2
File: ppo.py  Project: zcli/ray
import os
import pickle
import time

import numpy as np
import ray
import tensorflow as tf
from tensorflow.python import debug as tf_debug

# The project-local names used below (Runner, RemoteRunner, collect_samples,
# shuffle, DEFAULT_CONFIG, Agent, TrainingResult) come from elsewhere in the
# ray codebase; their import lines are omitted from this excerpt.


class PPOAgent(Agent):
    _agent_name = "PPO"
    _default_config = DEFAULT_CONFIG

    def _init(self):
        self.global_step = 0
        self.kl_coeff = self.config["kl_coeff"]
        self.model = Runner(self.env_creator, 1, self.config, self.logdir,
                            False)
        self.agents = [
            RemoteRunner.remote(self.env_creator, 1, self.config, self.logdir,
                                True)
            for _ in range(self.config["num_workers"])
        ]
        self.start_time = time.time()
        if self.config["write_logs"]:
            self.file_writer = tf.summary.FileWriter(self.logdir,
                                                     self.model.sess.graph)
        else:
            self.file_writer = None
        self.saver = tf.train.Saver(max_to_keep=None)

    def _train(self):
        agents = self.agents
        config = self.config
        model = self.model

        print("===> iteration", self.iteration)

        iter_start = time.time()
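        # Publish the current weights to the object store once, then let every
        # worker pull them; the remote calls below are fire-and-forget.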
        weights = ray.put(model.get_weights())
        [a.load_weights.remote(weights) for a in agents]
        trajectory, total_reward, traj_len_mean = collect_samples(
            agents, config, self.model.observation_filter,
            self.model.reward_filter)
        print("total reward is ", total_reward)
        print("trajectory length mean is ", traj_len_mean)
        print("timesteps:", trajectory["dones"].shape[0])
        if self.file_writer:
            traj_stats = tf.Summary(value=[
                tf.Summary.Value(tag="ppo/rollouts/mean_reward",
                                 simple_value=total_reward),
                tf.Summary.Value(tag="ppo/rollouts/traj_len_mean",
                                 simple_value=traj_len_mean)
            ])
            self.file_writer.add_summary(traj_stats, self.global_step)
        self.global_step += 1

        def standardized(value):
            # Divide by the maximum of value.std() and 1e-4
            # to guard against the case where all values are equal
            return (value - value.mean()) / max(1e-4, value.std())

        if config["use_gae"]:
            trajectory["advantages"] = standardized(trajectory["advantages"])
        else:
            trajectory["returns"] = standardized(trajectory["returns"])

        rollouts_end = time.time()
        print("Computing policy (iterations=" + str(config["num_sgd_iter"]) +
              ", stepsize=" + str(config["sgd_stepsize"]) + "):")
        names = [
            "iter", "total loss", "policy loss", "vf loss", "kl", "entropy"
        ]
        print(("{:>15}" * len(names)).format(*names))
        trajectory = shuffle(trajectory)
        shuffle_end = time.time()
        tuples_per_device = model.load_data(
            trajectory, self.iteration == 0 and config["full_trace_data_load"])
        load_end = time.time()
        rollouts_time = rollouts_end - iter_start
        shuffle_time = shuffle_end - rollouts_end
        load_time = load_end - shuffle_end
        sgd_time = 0
        for i in range(config["num_sgd_iter"]):
            sgd_start = time.time()
            batch_index = 0
            num_batches = (int(tuples_per_device) //
                           int(model.per_device_batch_size))
            loss, policy_loss, vf_loss, kl, entropy = [], [], [], [], []
            permutation = np.random.permutation(num_batches)
            # Prepare to drop into the debugger
            if self.iteration == config["tf_debug_iteration"]:
                model.sess = tf_debug.LocalCLIDebugWrapperSession(model.sess)
            while batch_index < num_batches:
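                # Optionally capture a full TF trace of a single SGD minibatch
                # (controlled by full_trace_nth_sgd_batch) for profiling.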
                full_trace = (i == 0 and self.iteration == 0 and batch_index
                              == config["full_trace_nth_sgd_batch"])
                batch_loss, batch_policy_loss, batch_vf_loss, batch_kl, \
                    batch_entropy = model.run_sgd_minibatch(
                        permutation[batch_index] * model.per_device_batch_size,
                        self.kl_coeff, full_trace,
                        self.file_writer)
                loss.append(batch_loss)
                policy_loss.append(batch_policy_loss)
                vf_loss.append(batch_vf_loss)
                kl.append(batch_kl)
                entropy.append(batch_entropy)
                batch_index += 1
            loss = np.mean(loss)
            policy_loss = np.mean(policy_loss)
            vf_loss = np.mean(vf_loss)
            kl = np.mean(kl)
            entropy = np.mean(entropy)
            sgd_end = time.time()
            print("{:>15}{:15.5e}{:15.5e}{:15.5e}{:15.5e}{:15.5e}".format(
                i, loss, policy_loss, vf_loss, kl, entropy))

            values = []
            if i == config["num_sgd_iter"] - 1:
                metric_prefix = "ppo/sgd/final_iter/"
                values.append(
                    tf.Summary.Value(tag=metric_prefix + "kl_coeff",
                                     simple_value=self.kl_coeff))
                values.extend([
                    tf.Summary.Value(tag=metric_prefix + "mean_entropy",
                                     simple_value=entropy),
                    tf.Summary.Value(tag=metric_prefix + "mean_loss",
                                     simple_value=loss),
                    tf.Summary.Value(tag=metric_prefix + "mean_kl",
                                     simple_value=kl)
                ])
                if self.file_writer:
                    sgd_stats = tf.Summary(value=values)
                    self.file_writer.add_summary(sgd_stats, self.global_step)
            self.global_step += 1
            sgd_time += sgd_end - sgd_start
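        # Adaptive KL penalty: strengthen the penalty if the policy moved too
        # far past kl_target, relax it if the policy barely moved.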
        if kl > 2.0 * config["kl_target"]:
            self.kl_coeff *= 1.5
        elif kl < 0.5 * config["kl_target"]:
            self.kl_coeff *= 0.5

        info = {
            "kl_divergence": kl,
            "kl_coefficient": self.kl_coeff,
            "rollouts_time": rollouts_time,
            "shuffle_time": shuffle_time,
            "load_time": load_time,
            "sgd_time": sgd_time,
            "sample_throughput": len(trajectory["observations"]) / sgd_time
        }

        print("kl div:", kl)
        print("kl coeff:", self.kl_coeff)
        print("rollouts time:", rollouts_time)
        print("shuffle time:", shuffle_time)
        print("load time:", load_time)
        print("sgd time:", sgd_time)
        print("sgd examples/s:", len(trajectory["observations"]) / sgd_time)
        print("total time so far:", time.time() - self.start_time)

        result = TrainingResult(
            episode_reward_mean=total_reward,
            episode_len_mean=traj_len_mean,
            timesteps_this_iter=trajectory["dones"].shape[0],
            info=info)

        return result

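    # Checkpointing: TF variables are saved via tf.train.Saver; the remaining
    # Python-side state (model extras, step counter, KL coefficient,
    # per-worker state) is pickled into a ".extra_data" side file.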
    def _save(self):
        checkpoint_path = self.saver.save(self.model.sess,
                                          os.path.join(self.logdir,
                                                       "checkpoint"),
                                          global_step=self.iteration)
        agent_state = ray.get([a.save.remote() for a in self.agents])
        extra_data = [
            self.model.save(), self.global_step, self.kl_coeff, agent_state
        ]
        with open(checkpoint_path + ".extra_data", "wb") as f:
            pickle.dump(extra_data, f)
        return checkpoint_path

    def _restore(self, checkpoint_path):
        self.saver.restore(self.model.sess, checkpoint_path)
        with open(checkpoint_path + ".extra_data", "rb") as f:
            extra_data = pickle.load(f)
        self.model.restore(extra_data[0])
        self.global_step = extra_data[1]
        self.kl_coeff = extra_data[2]
        ray.get([
            a.restore.remote(o) for (a, o) in zip(self.agents, extra_data[3])
        ])

    def compute_action(self, observation):
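        # update=False so a single inference call does not perturb the
        # observation filter's running statistics.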
        observation = self.model.observation_filter(observation, update=False)
        return self.model.common_policy.compute([observation])[0][0]
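
The end of `_train` implements an adaptive KL penalty: if the measured KL divergence overshoots twice the target, the coefficient grows by 1.5x; if it falls below half the target, it is halved. The same rule isolated as a pure function for clarity (update_kl is a hypothetical helper, not defined in ppo.py; the constants are taken from the code above):

def update_kl(kl_coeff, kl, kl_target):
    # Mirror of the branch at the end of _train: keep the measured KL
    # within a factor of two of kl_target by rescaling the penalty.
    if kl > 2.0 * kl_target:
        return kl_coeff * 1.5
    elif kl < 0.5 * kl_target:
        return kl_coeff * 0.5
    return kl_coeff

assert update_kl(0.2, kl=0.05, kl_target=0.01) == 0.2 * 1.5   # diverged too far
assert update_kl(0.2, kl=0.001, kl_target=0.01) == 0.2 * 0.5  # barely moved
assert update_kl(0.2, kl=0.01, kl_target=0.01) == 0.2         # within band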