Example 1
    def step(self, postprocess_fn=None):
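        """Broadcast weights to the remote evaluators, sample a batch, load
        it onto the SGD devices, and run num_sgd_iter epochs of minibatch
        SGD, returning the extra fetches from each minibatch."""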
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                # TODO(rliaw): remove when refactoring
                from ray.rllib.ppo.rollout import collect_samples
                samples = collect_samples(self.remote_evaluators,
                                          self.timesteps_per_batch)
            else:
                samples = self.local_evaluator.sample()
            assert isinstance(samples, SampleBatch)

            if postprocess_fn:
                postprocess_fn(samples)

        with self.load_timer:
            tuples_per_device = self.par_opt.load_data(
                self.local_evaluator.sess,
                samples.columns([key for key, _ in self.loss_inputs]))

        with self.grad_timer:
            all_extra_fetches = []
            model = self.local_evaluator
            num_batches = (
                int(tuples_per_device) // int(self.per_device_batch_size))
            for i in range(self.num_sgd_iter):
                iter_extra_fetches = []
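                # Visit the per-device minibatches in a random order each epoch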
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    # TODO(ekl) support ppo's debugging features, e.g.
                    # printing the current loss and tracing
                    batch_fetches = self.par_opt.optimize(
                        self.sess,
                        permutation[batch_index] * self.per_device_batch_size,
                        extra_ops=model.extra_apply_grad_fetches(),
                        extra_feed_dict=model.extra_apply_grad_feed_dict())
                    iter_extra_fetches += [batch_fetches]
                all_extra_fetches += [iter_extra_fetches]

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return all_extra_fetches
Example 2
    def step(self, postprocess_fn=None):
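        """Sync weights, sample a batch, load it onto the SGD devices, and
        run num_sgd_iter epochs of minibatch SGD, returning the extra
        fetches from each minibatch grouped by fetch key."""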
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                # TODO(rliaw): remove when refactoring
                from ray.rllib.ppo.rollout import collect_samples
                samples = collect_samples(self.remote_evaluators,
                                          self.timesteps_per_batch)
            else:
                samples = self.local_evaluator.sample()
            self._check_not_multiagent(samples)

            if postprocess_fn:
                postprocess_fn(samples)

        with self.load_timer:
            tuples_per_device = self.par_opt.load_data(
                self.sess,
                samples.columns([key for key, _ in self.policy.loss_inputs()]))

        with self.grad_timer:
            all_extra_fetches = defaultdict(list)
            num_batches = (
                int(tuples_per_device) // int(self.per_device_batch_size))
            for i in range(self.num_sgd_iter):
                iter_extra_fetches = defaultdict(list)
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    # TODO(ekl) support ppo's debugging features, e.g.
                    # printing the current loss and tracing
                    batch_fetches = self.par_opt.optimize(
                        self.sess,
                        permutation[batch_index] * self.per_device_batch_size)
                    for k, v in batch_fetches.items():
                        iter_extra_fetches[k] += [v]
                for k, v in iter_extra_fetches.items():
                    all_extra_fetches[k] += [v]

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return all_extra_fetches
Example 3
    def _train(self):
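        """Run one PPO training iteration: sync weights to the remote
        evaluators, collect rollouts, standardize the advantages, run
        num_sgd_iter epochs of minibatch SGD, and adapt the KL penalty
        coefficient toward kl_target."""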
        agents = self.remote_evaluators
        config = self.config
        model = self.local_evaluator

        print("===> iteration", self.iteration)

        iter_start = time.time()
        weights = ray.put(model.get_weights())
        [a.set_weights.remote(weights) for a in agents]
        samples = collect_samples(agents, config, self.local_evaluator)

        def standardized(value):
            # Divide by the maximum of value.std() and 1e-4
            # to guard against the case where all values are equal
            return (value - value.mean()) / max(1e-4, value.std())

        samples.data["advantages"] = standardized(samples["advantages"])

        rollouts_end = time.time()
        print("Computing policy (iterations=" + str(config["num_sgd_iter"]) +
              ", stepsize=" + str(config["sgd_stepsize"]) + "):")
        names = [
            "iter", "total loss", "policy loss", "vf loss", "kl", "entropy"
        ]
        print(("{:>15}" * len(names)).format(*names))
        samples.shuffle()
        shuffle_end = time.time()
        tuples_per_device = model.load_data(
            samples, self.iteration == 0 and config["full_trace_data_load"])
        load_end = time.time()
        rollouts_time = rollouts_end - iter_start
        shuffle_time = shuffle_end - rollouts_end
        load_time = load_end - shuffle_end
        sgd_time = 0
        for i in range(config["num_sgd_iter"]):
            sgd_start = time.time()
            batch_index = 0
            num_batches = (int(tuples_per_device) //
                           int(model.per_device_batch_size))
            loss, policy_loss, vf_loss, kl, entropy = [], [], [], [], []
            permutation = np.random.permutation(num_batches)
            # Prepare to drop into the debugger
            if self.iteration == config["tf_debug_iteration"]:
                model.sess = tf_debug.LocalCLIDebugWrapperSession(model.sess)
            while batch_index < num_batches:
                full_trace = (i == 0 and self.iteration == 0 and batch_index
                              == config["full_trace_nth_sgd_batch"])
                batch_loss, batch_policy_loss, batch_vf_loss, batch_kl, \
                    batch_entropy = model.run_sgd_minibatch(
                        permutation[batch_index] * model.per_device_batch_size,
                        self.kl_coeff, full_trace,
                        self.file_writer)
                loss.append(batch_loss)
                policy_loss.append(batch_policy_loss)
                vf_loss.append(batch_vf_loss)
                kl.append(batch_kl)
                entropy.append(batch_entropy)
                batch_index += 1
            loss = np.mean(loss)
            policy_loss = np.mean(policy_loss)
            vf_loss = np.mean(vf_loss)
            kl = np.mean(kl)
            entropy = np.mean(entropy)
            sgd_end = time.time()
            print("{:>15}{:15.5e}{:15.5e}{:15.5e}{:15.5e}{:15.5e}".format(
                i, loss, policy_loss, vf_loss, kl, entropy))

            values = []
            if i == config["num_sgd_iter"] - 1:
                metric_prefix = "ppo/sgd/final_iter/"
                values.append(
                    tf.Summary.Value(tag=metric_prefix + "kl_coeff",
                                     simple_value=self.kl_coeff))
                values.extend([
                    tf.Summary.Value(tag=metric_prefix + "mean_entropy",
                                     simple_value=entropy),
                    tf.Summary.Value(tag=metric_prefix + "mean_loss",
                                     simple_value=loss),
                    tf.Summary.Value(tag=metric_prefix + "mean_kl",
                                     simple_value=kl)
                ])
                if self.file_writer:
                    sgd_stats = tf.Summary(value=values)
                    self.file_writer.add_summary(sgd_stats, self.global_step)
            self.global_step += 1
            sgd_time += sgd_end - sgd_start
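        # Adapt the KL penalty coefficient toward the target KL divergence.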
        if kl > 2.0 * config["kl_target"]:
            self.kl_coeff *= 1.5
        elif kl < 0.5 * config["kl_target"]:
            self.kl_coeff *= 0.5

        info = {
            "kl_divergence": kl,
            "kl_coefficient": self.kl_coeff,
            "rollouts_time": rollouts_time,
            "shuffle_time": shuffle_time,
            "load_time": load_time,
            "sgd_time": sgd_time,
            "sample_throughput": len(samples["observations"]) / sgd_time
        }

        FilterManager.synchronize(self.local_evaluator.filters,
                                  self.remote_evaluators)
        res = self._fetch_metrics_from_remote_evaluators()
        res = res._replace(info=info)

        return res
Example 4
File: ppo.py Project: zcli/ray
    def _train(self):
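        """Run one PPO training iteration, logging rollout reward statistics
        to TensorBoard and standardizing either the advantages (with GAE)
        or the returns before the SGD epochs."""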
        agents = self.agents
        config = self.config
        model = self.model

        print("===> iteration", self.iteration)

        iter_start = time.time()
        weights = ray.put(model.get_weights())
        [a.load_weights.remote(weights) for a in agents]
        trajectory, total_reward, traj_len_mean = collect_samples(
            agents, config, self.model.observation_filter,
            self.model.reward_filter)
        print("total reward is ", total_reward)
        print("trajectory length mean is ", traj_len_mean)
        print("timesteps:", trajectory["dones"].shape[0])
        if self.file_writer:
            traj_stats = tf.Summary(value=[
                tf.Summary.Value(tag="ppo/rollouts/mean_reward",
                                 simple_value=total_reward),
                tf.Summary.Value(tag="ppo/rollouts/traj_len_mean",
                                 simple_value=traj_len_mean)
            ])
            self.file_writer.add_summary(traj_stats, self.global_step)
        self.global_step += 1

        def standardized(value):
            # Divide by the maximum of value.std() and 1e-4
            # to guard against the case where all values are equal
            return (value - value.mean()) / max(1e-4, value.std())

        if config["use_gae"]:
            trajectory["advantages"] = standardized(trajectory["advantages"])
        else:
            trajectory["returns"] = standardized(trajectory["returns"])

        rollouts_end = time.time()
        print("Computing policy (iterations=" + str(config["num_sgd_iter"]) +
              ", stepsize=" + str(config["sgd_stepsize"]) + "):")
        names = [
            "iter", "total loss", "policy loss", "vf loss", "kl", "entropy"
        ]
        print(("{:>15}" * len(names)).format(*names))
        trajectory = shuffle(trajectory)
        shuffle_end = time.time()
        tuples_per_device = model.load_data(
            trajectory, self.iteration == 0 and config["full_trace_data_load"])
        load_end = time.time()
        rollouts_time = rollouts_end - iter_start
        shuffle_time = shuffle_end - rollouts_end
        load_time = load_end - shuffle_end
        sgd_time = 0
        for i in range(config["num_sgd_iter"]):
            sgd_start = time.time()
            batch_index = 0
            num_batches = (int(tuples_per_device) //
                           int(model.per_device_batch_size))
            loss, policy_loss, vf_loss, kl, entropy = [], [], [], [], []
            permutation = np.random.permutation(num_batches)
            # Prepare to drop into the debugger
            if self.iteration == config["tf_debug_iteration"]:
                model.sess = tf_debug.LocalCLIDebugWrapperSession(model.sess)
            while batch_index < num_batches:
                full_trace = (i == 0 and self.iteration == 0 and batch_index
                              == config["full_trace_nth_sgd_batch"])
                batch_loss, batch_policy_loss, batch_vf_loss, batch_kl, \
                    batch_entropy = model.run_sgd_minibatch(
                        permutation[batch_index] * model.per_device_batch_size,
                        self.kl_coeff, full_trace,
                        self.file_writer)
                loss.append(batch_loss)
                policy_loss.append(batch_policy_loss)
                vf_loss.append(batch_vf_loss)
                kl.append(batch_kl)
                entropy.append(batch_entropy)
                batch_index += 1
            loss = np.mean(loss)
            policy_loss = np.mean(policy_loss)
            vf_loss = np.mean(vf_loss)
            kl = np.mean(kl)
            entropy = np.mean(entropy)
            sgd_end = time.time()
            print("{:>15}{:15.5e}{:15.5e}{:15.5e}{:15.5e}{:15.5e}".format(
                i, loss, policy_loss, vf_loss, kl, entropy))

            values = []
            if i == config["num_sgd_iter"] - 1:
                metric_prefix = "ppo/sgd/final_iter/"
                values.append(
                    tf.Summary.Value(tag=metric_prefix + "kl_coeff",
                                     simple_value=self.kl_coeff))
                values.extend([
                    tf.Summary.Value(tag=metric_prefix + "mean_entropy",
                                     simple_value=entropy),
                    tf.Summary.Value(tag=metric_prefix + "mean_loss",
                                     simple_value=loss),
                    tf.Summary.Value(tag=metric_prefix + "mean_kl",
                                     simple_value=kl)
                ])
                if self.file_writer:
                    sgd_stats = tf.Summary(value=values)
                    self.file_writer.add_summary(sgd_stats, self.global_step)
            self.global_step += 1
            sgd_time += sgd_end - sgd_start
        if kl > 2.0 * config["kl_target"]:
            self.kl_coeff *= 1.5
        elif kl < 0.5 * config["kl_target"]:
            self.kl_coeff *= 0.5

        info = {
            "kl_divergence": kl,
            "kl_coefficient": self.kl_coeff,
            "rollouts_time": rollouts_time,
            "shuffle_time": shuffle_time,
            "load_time": load_time,
            "sgd_time": sgd_time,
            "sample_throughput": len(trajectory["observations"]) / sgd_time
        }

        print("kl div:", kl)
        print("kl coeff:", self.kl_coeff)
        print("rollouts time:", rollouts_time)
        print("shuffle time:", shuffle_time)
        print("load time:", load_time)
        print("sgd time:", sgd_time)
        print("sgd examples/s:", len(trajectory["observations"]) / sgd_time)
        print("total time so far:", time.time() - self.start_time)

        result = TrainingResult(
            episode_reward_mean=total_reward,
            episode_len_mean=traj_len_mean,
            timesteps_this_iter=trajectory["dones"].shape[0],
            info=info)

        return result
Example 5
File: ppo.py Project: adgirish/ray
    def _train(self):
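        """Run one PPO training iteration, first warning when
        num_workers * min_steps_per_task exceeds timesteps_per_batch, since
        the surplus output of some sampling tasks would then be wasted."""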
        agents = self.remote_evaluators
        config = self.config
        model = self.local_evaluator

        if (config["num_workers"] * config["min_steps_per_task"] >
                config["timesteps_per_batch"]):
            print(
                "WARNING: num_workers * min_steps_per_task > "
                "timesteps_per_batch. This means that the output of some "
                "tasks will be wasted. Consider decreasing "
                "min_steps_per_task or increasing timesteps_per_batch.")

        print("===> iteration", self.iteration)

        iter_start = time.time()
        weights = ray.put(model.get_weights())
        [a.set_weights.remote(weights) for a in agents]
        samples = collect_samples(agents, config, self.local_evaluator)

        def standardized(value):
            # Divide by the maximum of value.std() and 1e-4
            # to guard against the case where all values are equal
            return (value - value.mean()) / max(1e-4, value.std())

        samples.data["advantages"] = standardized(samples["advantages"])

        rollouts_end = time.time()
        print("Computing policy (iterations=" + str(config["num_sgd_iter"]) +
              ", stepsize=" + str(config["sgd_stepsize"]) + "):")
        names = [
            "iter", "total loss", "policy loss", "vf loss", "kl", "entropy"]
        print(("{:>15}" * len(names)).format(*names))
        samples.shuffle()
        shuffle_end = time.time()
        tuples_per_device = model.load_data(
            samples, self.iteration == 0 and config["full_trace_data_load"])
        load_end = time.time()
        rollouts_time = rollouts_end - iter_start
        shuffle_time = shuffle_end - rollouts_end
        load_time = load_end - shuffle_end
        sgd_time = 0
        for i in range(config["num_sgd_iter"]):
            sgd_start = time.time()
            batch_index = 0
            num_batches = (
                int(tuples_per_device) // int(model.per_device_batch_size))
            loss, policy_loss, vf_loss, kl, entropy = [], [], [], [], []
            permutation = np.random.permutation(num_batches)
            # Prepare to drop into the debugger
            if self.iteration == config["tf_debug_iteration"]:
                model.sess = tf_debug.LocalCLIDebugWrapperSession(model.sess)
            while batch_index < num_batches:
                full_trace = (
                    i == 0 and self.iteration == 0 and
                    batch_index == config["full_trace_nth_sgd_batch"])
                batch_loss, batch_policy_loss, batch_vf_loss, batch_kl, \
                    batch_entropy = model.run_sgd_minibatch(
                        permutation[batch_index] * model.per_device_batch_size,
                        self.kl_coeff, full_trace,
                        self.file_writer)
                loss.append(batch_loss)
                policy_loss.append(batch_policy_loss)
                vf_loss.append(batch_vf_loss)
                kl.append(batch_kl)
                entropy.append(batch_entropy)
                batch_index += 1
            loss = np.mean(loss)
            policy_loss = np.mean(policy_loss)
            vf_loss = np.mean(vf_loss)
            kl = np.mean(kl)
            entropy = np.mean(entropy)
            sgd_end = time.time()
            print(
                "{:>15}{:15.5e}{:15.5e}{:15.5e}{:15.5e}{:15.5e}".format(
                    i, loss, policy_loss, vf_loss, kl, entropy))

            values = []
            if i == config["num_sgd_iter"] - 1:
                metric_prefix = "ppo/sgd/final_iter/"
                values.append(tf.Summary.Value(
                    tag=metric_prefix + "kl_coeff",
                    simple_value=self.kl_coeff))
                values.extend([
                    tf.Summary.Value(
                        tag=metric_prefix + "mean_entropy",
                        simple_value=entropy),
                    tf.Summary.Value(
                        tag=metric_prefix + "mean_loss",
                        simple_value=loss),
                    tf.Summary.Value(
                        tag=metric_prefix + "mean_kl",
                        simple_value=kl)])
                if self.file_writer:
                    sgd_stats = tf.Summary(value=values)
                    self.file_writer.add_summary(sgd_stats, self.global_step)
            self.global_step += 1
            sgd_time += sgd_end - sgd_start
        if kl > 2.0 * config["kl_target"]:
            self.kl_coeff *= 1.5
        elif kl < 0.5 * config["kl_target"]:
            self.kl_coeff *= 0.5

        info = {
            "kl_divergence": kl,
            "kl_coefficient": self.kl_coeff,
            "rollouts_time": rollouts_time,
            "shuffle_time": shuffle_time,
            "load_time": load_time,
            "sgd_time": sgd_time,
            "sample_throughput": len(samples["observations"]) / sgd_time
        }

        FilterManager.synchronize(
            self.local_evaluator.filters, self.remote_evaluators)
        res = self._fetch_metrics_from_remote_evaluators()
        res = res._replace(info=info)
        return res
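
The adaptive KL update that closes each _train() above can also be read in isolation. The helper below is a minimal sketch rather than code from any of the projects shown; update_kl_coeff and its argument names are hypothetical, but the 1.5x / 0.5x scaling around kl_target mirrors the rule used in the examples.

def update_kl_coeff(kl_coeff, kl, kl_target):
    """Scale the KL penalty coefficient toward the target KL divergence."""
    if kl > 2.0 * kl_target:
        # Measured KL is well above target: penalize divergence more strongly.
        return kl_coeff * 1.5
    elif kl < 0.5 * kl_target:
        # Measured KL is well below target: relax the penalty.
        return kl_coeff * 0.5
    return kl_coeff

# Hypothetical usage at the end of a training iteration:
# self.kl_coeff = update_kl_coeff(self.kl_coeff, kl, config["kl_target"])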