def step(self):
    with self.update_weights_timer:
        # Broadcast the latest local weights to the remote evaluators.
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        # Collect a training batch, remotely if evaluators are available.
        if self.remote_evaluators:
            # TODO(rliaw): remove when refactoring
            from ray.rllib.agents.ppo.rollout import collect_samples
            samples = collect_samples(self.remote_evaluators,
                                      self.train_batch_size)
        else:
            samples = self.local_evaluator.sample()
        self._check_not_multiagent(samples)

    # Standardize the configured sample fields in place.
    for field in self.standardize_fields:
        value = samples[field]
        standardized = (value - value.mean()) / max(1e-4, value.std())
        samples[field] = standardized

    # Important: don't shuffle RNN sequence elements
    if not self.policy._state_inputs:
        samples.shuffle()

    with self.load_timer:
        # Load the batch onto the per-device buffers of the parallel optimizer.
        tuples = self.policy._get_loss_inputs_dict(samples)
        data_keys = [ph for _, ph in self.policy.loss_inputs()]
        if self.policy._state_inputs:
            state_keys = self.policy._state_inputs + [self.policy._seq_lens]
        else:
            state_keys = []
        tuples_per_device = self.par_opt.load_data(
            self.sess, [tuples[k] for k in data_keys],
            [tuples[k] for k in state_keys])

    with self.grad_timer:
        # Run num_sgd_iter epochs of minibatch SGD over the loaded data.
        num_batches = (
            int(tuples_per_device) // int(self.per_device_batch_size))
        logger.debug("== sgd epochs ==")
        for i in range(self.num_sgd_iter):
            iter_extra_fetches = defaultdict(list)
            permutation = np.random.permutation(num_batches)
            for batch_index in range(num_batches):
                batch_fetches = self.par_opt.optimize(
                    self.sess,
                    permutation[batch_index] * self.per_device_batch_size)
                for k, v in batch_fetches.items():
                    iter_extra_fetches[k].append(v)
            logger.debug("{} {}".format(i, _averaged(iter_extra_fetches)))

    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count
    return _averaged(iter_extra_fetches)
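
# `_averaged` is called above but not defined in this excerpt. The sketch below
# is an assumption of what such a helper could look like: it reduces each list
# of per-minibatch fetch values to its mean and skips entries that are not
# plain numerics (the name and filtering behavior are guesses, not the
# library's actual definition).
import numpy as np


def _averaged(kv):
    """Average each list of collected fetch values, skipping non-numerics."""
    out = {}
    for k, v in kv.items():
        if v and v[0] is not None and not isinstance(v[0], dict):
            out[k] = np.mean(v)
    return out
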
def step(self, postprocess_fn=None):
    with self.update_weights_timer:
        # Broadcast the latest local weights to the remote evaluators.
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        # Collect a training batch, remotely if evaluators are available.
        if self.remote_evaluators:
            # TODO(rliaw): remove when refactoring
            from ray.rllib.agents.ppo.rollout import collect_samples
            samples = collect_samples(self.remote_evaluators,
                                      self.timesteps_per_batch)
        else:
            samples = self.local_evaluator.sample()
        self._check_not_multiagent(samples)

    # Optional hook for in-place postprocessing of the collected samples.
    if postprocess_fn:
        postprocess_fn(samples)

    with self.load_timer:
        # Load the batch onto the per-device buffers of the parallel optimizer.
        tuples_per_device = self.par_opt.load_data(
            self.sess,
            samples.columns([key for key, _ in self.policy.loss_inputs()]))

    with self.grad_timer:
        # Run num_sgd_iter epochs of minibatch SGD over the loaded data.
        all_extra_fetches = defaultdict(list)
        num_batches = (
            int(tuples_per_device) // int(self.per_device_batch_size))
        for i in range(self.num_sgd_iter):
            iter_extra_fetches = defaultdict(list)
            permutation = np.random.permutation(num_batches)
            for batch_index in range(num_batches):
                # TODO(ekl) support ppo's debugging features, e.g.
                # printing the current loss and tracing
                batch_fetches = self.par_opt.optimize(
                    self.sess,
                    permutation[batch_index] * self.per_device_batch_size)
                for k, v in batch_fetches.items():
                    iter_extra_fetches[k] += [v]
            for k, v in iter_extra_fetches.items():
                all_extra_fetches[k] += [v]

    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count
    return all_extra_fetches
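
# The `postprocess_fn` hook above runs on the collected SampleBatch before it
# is loaded onto the devices. A hypothetical callback (the "advantages" field
# name is an assumption) that standardizes one field in place, mirroring the
# `standardize_fields` logic of the other variants:
def standardize_advantages(samples):
    value = samples["advantages"]
    samples["advantages"] = (value - value.mean()) / max(1e-4, value.std())

# Possible usage: optimizer.step(postprocess_fn=standardize_advantages)
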
def step(self):
    with self.update_weights_timer:
        # Broadcast the latest local weights to the remote evaluators.
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        # Collect a training batch, remotely if evaluators are available.
        if self.remote_evaluators:
            # TODO(rliaw): remove when refactoring
            from ray.rllib.agents.ppo.rollout import collect_samples
            samples = collect_samples(self.remote_evaluators,
                                      self.timesteps_per_batch)
        else:
            samples = self.local_evaluator.sample()
        self._check_not_multiagent(samples)

    # Standardize the configured sample fields in place, then shuffle.
    for field in self.standardize_fields:
        value = samples[field]
        standardized = (value - value.mean()) / max(1e-4, value.std())
        samples[field] = standardized
    samples.shuffle()

    with self.load_timer:
        # Load the batch onto the per-device buffers of the parallel optimizer.
        tuples_per_device = self.par_opt.load_data(
            self.sess,
            samples.columns([key for key, _ in self.policy.loss_inputs()]))

    with self.grad_timer:
        # Run num_sgd_iter epochs of minibatch SGD over the loaded data.
        num_batches = (
            int(tuples_per_device) // int(self.per_device_batch_size))
        print("== sgd epochs ==")
        for i in range(self.num_sgd_iter):
            iter_extra_fetches = defaultdict(list)
            permutation = np.random.permutation(num_batches)
            for batch_index in range(num_batches):
                batch_fetches = self.par_opt.optimize(
                    self.sess,
                    permutation[batch_index] * self.per_device_batch_size)
                for k, v in batch_fetches.items():
                    iter_extra_fetches[k].append(v)
            print(i, _averaged(iter_extra_fetches))

    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count
    return _averaged(iter_extra_fetches)
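
# The per-field standardization above clamps the divisor at 1e-4 so that a
# zero-variance field does not produce NaNs. A toy illustration of the same
# arithmetic (illustrative only, not part of the optimizer):
import numpy as np

toy = np.array([1.0, 1.0, 1.0])  # zero-variance field
print((toy - toy.mean()) / max(1e-4, toy.std()))  # [0. 0. 0.] rather than NaN
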
def step(self):
    with self.update_weights_timer:
        # Broadcast the latest local weights to the remote evaluators.
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        # Collect a training batch, remotely if evaluators are available.
        if self.remote_evaluators:
            # TODO(rliaw): remove when refactoring
            from ray.rllib.agents.ppo.rollout import collect_samples
            samples = collect_samples(self.remote_evaluators,
                                      self.train_batch_size)
        else:
            samples = self.local_evaluator.sample()

        # Handle everything as if multiagent
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                      samples.count)

    # Standardize the configured sample fields of each policy batch in place.
    for _, batch in samples.policy_batches.items():
        for field in self.standardize_fields:
            value = batch[field]
            standardized = (value - value.mean()) / max(1e-4, value.std())
            batch[field] = standardized

    for policy_id, policy in self.policies.items():
        # Important: don't shuffle RNN sequence elements
        if (policy_id in samples.policy_batches
                and not policy._state_inputs):
            samples.policy_batches[policy_id].shuffle()

    # Load each policy's batch onto that policy's parallel optimizer.
    num_loaded_tuples = {}
    with self.load_timer:
        for policy_id, batch in samples.policy_batches.items():
            policy = self.policies[policy_id]
            tuples = policy._get_loss_inputs_dict(batch)
            data_keys = [ph for _, ph in policy._loss_inputs]
            if policy._state_inputs:
                state_keys = policy._state_inputs + [policy._seq_lens]
            else:
                state_keys = []
            num_loaded_tuples[policy_id] = (
                self.optimizers[policy_id].load_data(
                    self.sess, [tuples[k] for k in data_keys],
                    [tuples[k] for k in state_keys]))

    # Run num_sgd_iter epochs of minibatch SGD per policy.
    fetches = {}
    with self.grad_timer:
        for policy_id, tuples_per_device in num_loaded_tuples.items():
            optimizer = self.optimizers[policy_id]
            num_batches = (
                int(tuples_per_device) // int(self.per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))
            for i in range(self.num_sgd_iter):
                iter_extra_fetches = defaultdict(list)
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    batch_fetches = optimizer.optimize(
                        self.sess,
                        permutation[batch_index] * self.per_device_batch_size)
                    for k, v in batch_fetches.items():
                        iter_extra_fetches[k].append(v)
                logger.debug("{} {}".format(i, _averaged(iter_extra_fetches)))
            fetches[policy_id] = _averaged(iter_extra_fetches)

    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count
    return fetches
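
# All four variants visit minibatches in a random order by indexing the loaded
# tuples at `permutation[batch_index] * per_device_batch_size`. A toy sketch of
# that offset arithmetic with made-up sizes (the numbers are illustrative, not
# RLlib defaults):
import numpy as np

tuples_per_device = 12
per_device_batch_size = 4
num_batches = tuples_per_device // per_device_batch_size  # 3 minibatches
permutation = np.random.permutation(num_batches)  # e.g. [2, 0, 1]
offsets = [int(permutation[i]) * per_device_batch_size
           for i in range(num_batches)]
print(offsets)  # e.g. [8, 0, 4]: each SGD epoch sweeps every minibatch once
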