Esempio n. 1
0
    def step(self):
        """Sample one training batch and run ``num_sgd_iter`` SGD passes on it.

        The local evaluator's weights are pushed (through a single shared
        ``ray.put`` object) to every remote evaluator. Samples are then taken
        from the remote evaluators via ``collect_samples`` with
        ``self.train_batch_size``, or from the local evaluator when there are
        no remote evaluators. Multi-agent batches are rejected by
        ``self._check_not_multiagent``.

        Each field in ``self.standardize_fields`` is shifted to zero mean and
        divided by its standard deviation, clamped below at 1e-4. The batch is
        shuffled unless the policy has state inputs. The loss inputs (plus
        state inputs and sequence lengths for stateful policies) are loaded
        onto the devices through ``self.par_opt``. Each SGD pass then calls
        ``self.par_opt.optimize`` once per full per-device minibatch, visiting
        the minibatch offsets in random order.

        Returns:
            ``_averaged`` applied to the fetches gathered during the final SGD
            pass. When ``num_sgd_iter`` is 0 (or no full minibatch was loaded)
            ``_averaged`` receives an empty mapping instead of the method
            failing with ``NameError``.
        """
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                # TODO(rliaw): remove when refactoring
                from ray.rllib.agents.ppo.rollout import collect_samples
                samples = collect_samples(self.remote_evaluators,
                                          self.train_batch_size)
            else:
                samples = self.local_evaluator.sample()
            self._check_not_multiagent(samples)

        for field in self.standardize_fields:
            value = samples[field]
            # Clamp the std so a (near-)constant field does not divide by ~0.
            standardized = (value - value.mean()) / max(1e-4, value.std())
            samples[field] = standardized

        # Important: don't shuffle RNN sequence elements
        if not self.policy._state_inputs:
            samples.shuffle()

        with self.load_timer:
            tuples = self.policy._get_loss_inputs_dict(samples)
            data_keys = [ph for _, ph in self.policy.loss_inputs()]
            if self.policy._state_inputs:
                # Stateful policies also load their state inputs and the
                # sequence-length tensor alongside the data.
                state_keys = (self.policy._state_inputs +
                              [self.policy._seq_lens])
            else:
                state_keys = []
            tuples_per_device = self.par_opt.load_data(
                self.sess, [tuples[k] for k in data_keys],
                [tuples[k] for k in state_keys])

        # Bound before the loop so the return below is well-defined even when
        # num_sgd_iter is 0 (previously this raised NameError).
        iter_extra_fetches = defaultdict(list)
        with self.grad_timer:
            num_batches = (int(tuples_per_device) //
                           int(self.per_device_batch_size))
            logger.debug("== sgd epochs ==")
            for i in range(self.num_sgd_iter):
                iter_extra_fetches = defaultdict(list)
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    batch_fetches = self.par_opt.optimize(
                        self.sess,
                        permutation[batch_index] * self.per_device_batch_size)
                    for k, v in batch_fetches.items():
                        iter_extra_fetches[k].append(v)
                logger.debug("%s %s", i, _averaged(iter_extra_fetches))

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return _averaged(iter_extra_fetches)
Esempio n. 2
0
    def step(self, postprocess_fn=None):
        """Collect one batch of experience and optimize the policy on it.

        Args:
            postprocess_fn: Optional callable. If truthy, it is called with the
                sampled batch (still inside ``sample_timer``) before the batch
                is loaded onto the devices; its return value is ignored.

        Returns:
            A ``defaultdict(list)`` mapping each fetch name to a list holding
            one entry per SGD pass; each entry is the list of that fetch's
            values over the minibatches optimized in that pass.
        """
        # Phase 1: sync weights to the remote evaluators (one shared object).
        with self.update_weights_timer:
            if self.remote_evaluators:
                shared_weights = ray.put(self.local_evaluator.get_weights())
                for evaluator in self.remote_evaluators:
                    evaluator.set_weights.remote(shared_weights)

        # Phase 2: gather a batch, locally or from the remote evaluators.
        with self.sample_timer:
            if not self.remote_evaluators:
                batch = self.local_evaluator.sample()
            else:
                # TODO(rliaw): remove when refactoring
                from ray.rllib.agents.ppo.rollout import collect_samples
                batch = collect_samples(
                    self.remote_evaluators, self.timesteps_per_batch)
            self._check_not_multiagent(batch)

            if postprocess_fn:
                postprocess_fn(batch)

        # Phase 3: push the loss-input columns onto the devices.
        with self.load_timer:
            loss_keys = [name for name, _placeholder
                         in self.policy.loss_inputs()]
            tuples_per_device = self.par_opt.load_data(
                self.sess, batch.columns(loss_keys))

        # Phase 4: several SGD passes over randomly ordered minibatches.
        with self.grad_timer:
            all_extra_fetches = defaultdict(list)
            num_batches = (int(tuples_per_device) //
                           int(self.per_device_batch_size))
            for _ in range(self.num_sgd_iter):
                pass_fetches = defaultdict(list)
                for minibatch_index in np.random.permutation(num_batches):
                    # TODO(ekl) support ppo's debugging features, e.g.
                    # printing the current loss and tracing
                    offset = minibatch_index * self.per_device_batch_size
                    result = self.par_opt.optimize(self.sess, offset)
                    for name, value in result.items():
                        pass_fetches[name].append(value)
                for name, values in pass_fetches.items():
                    all_extra_fetches[name].append(values)

        self.num_steps_sampled += batch.count
        self.num_steps_trained += batch.count
        return all_extra_fetches
Esempio n. 3
0
    def step(self):
        """Sample one training batch and run ``num_sgd_iter`` SGD passes on it.

        The local evaluator's weights are pushed (through a single shared
        ``ray.put`` object) to every remote evaluator. Samples are then taken
        from the remote evaluators via ``collect_samples`` with
        ``self.timesteps_per_batch``, or from the local evaluator when there
        are no remote evaluators. Multi-agent batches are rejected by
        ``self._check_not_multiagent``.

        Each field in ``self.standardize_fields`` is shifted to zero mean and
        divided by its standard deviation, clamped below at 1e-4. The batch is
        shuffled, and the columns named by ``self.policy.loss_inputs()`` are
        loaded onto the devices through ``self.par_opt``. Each SGD pass then
        calls ``self.par_opt.optimize`` once per full per-device minibatch,
        visiting the minibatch offsets in random order. Per-pass progress is
        emitted at DEBUG level on this module's logger rather than printed to
        stdout.

        Returns:
            ``_averaged`` applied to the fetches gathered during the final SGD
            pass. When ``num_sgd_iter`` is 0 (or no full minibatch was loaded)
            ``_averaged`` receives an empty mapping instead of the method
            failing with ``NameError``.
        """
        import logging
        log = logging.getLogger(__name__)

        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                # TODO(rliaw): remove when refactoring
                from ray.rllib.agents.ppo.rollout import collect_samples
                samples = collect_samples(self.remote_evaluators,
                                          self.timesteps_per_batch)
            else:
                samples = self.local_evaluator.sample()
            self._check_not_multiagent(samples)

        for field in self.standardize_fields:
            value = samples[field]
            # Clamp the std so a (near-)constant field does not divide by ~0.
            standardized = (value - value.mean()) / max(1e-4, value.std())
            samples[field] = standardized
        samples.shuffle()

        with self.load_timer:
            tuples_per_device = self.par_opt.load_data(
                self.sess,
                samples.columns([key for key, _ in self.policy.loss_inputs()]))

        # Bound before the loop so the return below is well-defined even when
        # num_sgd_iter is 0 (previously this raised NameError).
        iter_extra_fetches = defaultdict(list)
        with self.grad_timer:
            num_batches = (int(tuples_per_device) //
                           int(self.per_device_batch_size))
            log.debug("== sgd epochs ==")
            for i in range(self.num_sgd_iter):
                iter_extra_fetches = defaultdict(list)
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    batch_fetches = self.par_opt.optimize(
                        self.sess,
                        permutation[batch_index] * self.per_device_batch_size)
                    for k, v in batch_fetches.items():
                        iter_extra_fetches[k].append(v)
                log.debug("%s %s", i, _averaged(iter_extra_fetches))

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return _averaged(iter_extra_fetches)
Esempio n. 4
0
    def step(self):
        """Sample a batch, then run per-policy minibatch SGD on it.

        The local evaluator's weights are pushed (through a single shared
        ``ray.put`` object) to every remote evaluator. A batch is collected via
        ``collect_samples`` with ``self.train_batch_size`` when remote
        evaluators exist, otherwise from the local evaluator. A plain
        ``SampleBatch`` is wrapped as a single-policy ``MultiAgentBatch`` under
        ``DEFAULT_POLICY_ID`` so both cases share one code path.

        For every sampled policy batch:

        - each field in ``self.standardize_fields`` is shifted to zero mean and
          divided by its standard deviation, clamped below at 1e-4;
        - the batch is shuffled unless the policy has state inputs;
        - its loss inputs (plus state inputs and sequence lengths for stateful
          policies) are loaded through ``self.optimizers[policy_id]``;
        - ``self.num_sgd_iter`` SGD passes are run, each calling
          ``optimize`` once per full per-device minibatch in random order.

        Returns:
            dict mapping each sampled policy id to ``_averaged`` applied to the
            fetches from that policy's final SGD pass. When ``num_sgd_iter``
            is 0, ``_averaged`` receives an empty mapping instead of the method
            failing with ``NameError``.
        """
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                # TODO(rliaw): remove when refactoring
                from ray.rllib.agents.ppo.rollout import collect_samples
                samples = collect_samples(self.remote_evaluators,
                                          self.train_batch_size)
            else:
                samples = self.local_evaluator.sample()
            # Handle everything as if multiagent
            if isinstance(samples, SampleBatch):
                samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                          samples.count)

        for batch in samples.policy_batches.values():
            for field in self.standardize_fields:
                value = batch[field]
                # Clamp the std so a (near-)constant field does not divide
                # by ~0.
                standardized = (value - value.mean()) / max(1e-4, value.std())
                batch[field] = standardized

        for policy_id, policy in self.policies.items():
            # Important: don't shuffle RNN sequence elements
            if (policy_id in samples.policy_batches
                    and not policy._state_inputs):
                samples.policy_batches[policy_id].shuffle()

        num_loaded_tuples = {}
        with self.load_timer:
            for policy_id, batch in samples.policy_batches.items():
                policy = self.policies[policy_id]
                tuples = policy._get_loss_inputs_dict(batch)
                data_keys = [ph for _, ph in policy._loss_inputs]
                if policy._state_inputs:
                    # Stateful policies also load their state inputs and the
                    # sequence-length tensor alongside the data.
                    state_keys = policy._state_inputs + [policy._seq_lens]
                else:
                    state_keys = []
                num_loaded_tuples[policy_id] = (
                    self.optimizers[policy_id].load_data(
                        self.sess, [tuples[k] for k in data_keys],
                        [tuples[k] for k in state_keys]))

        fetches = {}
        with self.grad_timer:
            for policy_id, tuples_per_device in num_loaded_tuples.items():
                optimizer = self.optimizers[policy_id]
                num_batches = (int(tuples_per_device) //
                               int(self.per_device_batch_size))
                logger.debug("== sgd epochs for %s ==", policy_id)
                # Reset per policy so the result recorded below always belongs
                # to this policy and is bound even when num_sgd_iter is 0
                # (previously this raised NameError).
                iter_extra_fetches = defaultdict(list)
                for i in range(self.num_sgd_iter):
                    iter_extra_fetches = defaultdict(list)
                    permutation = np.random.permutation(num_batches)
                    for batch_index in range(num_batches):
                        batch_fetches = optimizer.optimize(
                            self.sess, permutation[batch_index] *
                            self.per_device_batch_size)
                        for k, v in batch_fetches.items():
                            iter_extra_fetches[k].append(v)
                    logger.debug("%s %s", i, _averaged(iter_extra_fetches))
                fetches[policy_id] = _averaged(iter_extra_fetches)

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return fetches