# --- Variant 1: pre-WorkerSet API (local/remote evaluators) ---

def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.remote_evaluators:
            if self.straggler_mitigation:
                samples = collect_samples_straggler_mitigation(
                    self.remote_evaluators, self.train_batch_size)
            else:
                samples = collect_samples(
                    self.remote_evaluators, self.sample_batch_size,
                    self.num_envs_per_worker, self.train_batch_size)
            if samples.count > self.train_batch_size * 2:
                logger.info(
                    "Collected more training samples than expected "
                    "(actual={}, train_batch_size={}). ".format(
                        samples.count, self.train_batch_size) +
                    "This may be because you have many workers or "
                    "long episodes in 'complete_episodes' batch mode.")
        else:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                samples.append(self.local_evaluator.sample())
            samples = SampleBatch.concat_samples(samples)

    # Handle everything as if multiagent
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({
            DEFAULT_POLICY_ID: samples
        }, samples.count)

    for policy_id, policy in self.policies.items():
        if policy_id not in samples.policy_batches:
            continue

        batch = samples.policy_batches[policy_id]
        for field in self.standardize_fields:
            value = batch[field]
            standardized = (value - value.mean()) / max(1e-4, value.std())
            batch[field] = standardized

        # Important: don't shuffle RNN sequence elements
        if not policy._state_inputs:
            batch.shuffle()

    num_loaded_tuples = {}
    with self.load_timer:
        for policy_id, batch in samples.policy_batches.items():
            if policy_id not in self.policies:
                continue

            policy = self.policies[policy_id]
            tuples = policy._get_loss_inputs_dict(batch)
            data_keys = [ph for _, ph in policy._loss_inputs]
            if policy._state_inputs:
                state_keys = policy._state_inputs + [policy._seq_lens]
            else:
                state_keys = []
            num_loaded_tuples[policy_id] = (
                self.optimizers[policy_id].load_data(
                    self.sess, [tuples[k] for k in data_keys],
                    [tuples[k] for k in state_keys]))

    fetches = {}
    with self.grad_timer:
        for policy_id, tuples_per_device in num_loaded_tuples.items():
            optimizer = self.optimizers[policy_id]
            num_batches = max(
                1, int(tuples_per_device) // int(self.per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))
            for i in range(self.num_sgd_iter):
                iter_extra_fetches = defaultdict(list)
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    batch_fetches = optimizer.optimize(
                        self.sess,
                        permutation[batch_index] * self.per_device_batch_size)
                    for k, v in batch_fetches.items():
                        iter_extra_fetches[k].append(v)
                logger.debug("{} {}".format(i, _averaged(iter_extra_fetches)))
            fetches[policy_id] = _averaged(iter_extra_fetches)

    self.num_steps_sampled += samples.count
    self.num_steps_trained += tuples_per_device * len(self.devices)
    return fetches
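# `_averaged` above is assumed to reduce each list of per-minibatch fetch
# values to its mean across the SGD minibatches of one epoch. A minimal
# sketch of that assumption (`_averaged_sketch` is a hypothetical stand-in,
# not the library helper; non-numeric entries are simply skipped):

import numpy as np

def _averaged_sketch(kv):
    """Average each list of per-minibatch scalars in a fetch dict."""
    out = {}
    for k, v in kv.items():
        if v and v[0] is not None and not isinstance(v[0], dict):
            out[k] = np.mean(v)
    return out

# Example:
#   _averaged_sketch({"policy_loss": [0.5, 0.3], "kl": [0.01, 0.02]})
#   -> {"policy_loss": 0.4, "kl": 0.015}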
# --- Variant 2: WorkerSet API; stats read under LEARNER_STATS_KEY ---

def step(self):
    with self.update_weights_timer:
        if self.workers.remote_workers():
            weights = ray.put(self.workers.local_worker().get_weights())
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.workers.remote_workers():
            if self.straggler_mitigation:
                samples = collect_samples_straggler_mitigation(
                    self.workers.remote_workers(), self.train_batch_size)
            else:
                samples = collect_samples(
                    self.workers.remote_workers(), self.sample_batch_size,
                    self.num_envs_per_worker, self.train_batch_size)
            if samples.count > self.train_batch_size * 2:
                logger.info(
                    "Collected more training samples than expected "
                    "(actual={}, train_batch_size={}). ".format(
                        samples.count, self.train_batch_size) +
                    "This may be because you have many workers or "
                    "long episodes in 'complete_episodes' batch mode.")
        else:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                samples.append(self.workers.local_worker().sample())
            samples = SampleBatch.concat_samples(samples)

    # Handle everything as if multiagent
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                  samples.count)

    for policy_id, policy in self.policies.items():
        if policy_id not in samples.policy_batches:
            continue

        batch = samples.policy_batches[policy_id]
        for field in self.standardize_fields:
            value = batch[field]
            standardized = (value - value.mean()) / max(1e-4, value.std())
            batch[field] = standardized

        # Important: don't shuffle RNN sequence elements
        if not policy._state_inputs:
            batch.shuffle()

    num_loaded_tuples = {}
    with self.load_timer:
        for policy_id, batch in samples.policy_batches.items():
            if policy_id not in self.policies:
                continue

            policy = self.policies[policy_id]
            tuples = policy._get_loss_inputs_dict(batch)
            data_keys = [ph for _, ph in policy._loss_inputs]
            if policy._state_inputs:
                state_keys = policy._state_inputs + [policy._seq_lens]
            else:
                state_keys = []
            num_loaded_tuples[policy_id] = (
                self.optimizers[policy_id].load_data(
                    self.sess, [tuples[k] for k in data_keys],
                    [tuples[k] for k in state_keys]))

    fetches = {}
    with self.grad_timer:
        for policy_id, tuples_per_device in num_loaded_tuples.items():
            optimizer = self.optimizers[policy_id]
            num_batches = max(
                1, int(tuples_per_device) // int(self.per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))
            for i in range(self.num_sgd_iter):
                iter_extra_fetches = defaultdict(list)
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    batch_fetches = optimizer.optimize(
                        self.sess,
                        permutation[batch_index] * self.per_device_batch_size)
                    for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                        iter_extra_fetches[k].append(v)
                logger.debug("{} {}".format(i, _averaged(iter_extra_fetches)))
            fetches[policy_id] = _averaged(iter_extra_fetches)

    self.num_steps_sampled += samples.count
    self.num_steps_trained += tuples_per_device * len(self.devices)
    self.learner_stats = fetches
    return fetches
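# The SGD loop above walks `num_batches` fixed-size slices of the loaded
# data in a fresh random order each epoch, passing a row offset to
# `optimizer.optimize()`. A standalone sketch of just that indexing
# (function and parameter names are illustrative, not from the library):

import numpy as np

def sgd_minibatch_offsets(num_tuples, per_device_batch_size, num_sgd_iter,
                          seed=None):
    """Yield (epoch, row_offset) pairs in the order the loop above would
    visit them. Illustration only."""
    rng = np.random.default_rng(seed)
    num_batches = max(1, num_tuples // per_device_batch_size)
    for epoch in range(num_sgd_iter):
        for b in rng.permutation(num_batches):
            yield epoch, int(b) * per_device_batch_size

# Example: 400 loaded tuples in device batches of 128 -> 3 offsets
# (0, 128, 256) per epoch, shuffled independently each epoch:
#   for epoch, off in sgd_minibatch_offsets(400, 128, 2, seed=0):
#       print(epoch, off)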
# --- Variant 3: no straggler mitigation; shuffling moved into
# _get_loss_inputs_dict; per-policy step accounting ---

def step(self):
    with self.update_weights_timer:
        if self.workers.remote_workers():
            weights = ray.put(self.workers.local_worker().get_weights())
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.workers.remote_workers():
            samples = collect_samples(
                self.workers.remote_workers(), self.sample_batch_size,
                self.num_envs_per_worker, self.train_batch_size)
            if samples.count > self.train_batch_size * 2:
                logger.info(
                    "Collected more training samples than expected "
                    "(actual={}, train_batch_size={}). ".format(
                        samples.count, self.train_batch_size) +
                    "This may be because you have many workers or "
                    "long episodes in 'complete_episodes' batch mode.")
        else:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                samples.append(self.workers.local_worker().sample())
            samples = SampleBatch.concat_samples(samples)

    # Handle everything as if multiagent
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                  samples.count)

    for policy_id, policy in self.policies.items():
        if policy_id not in samples.policy_batches:
            continue

        batch = samples.policy_batches[policy_id]
        for field in self.standardize_fields:
            value = batch[field]
            standardized = (value - value.mean()) / max(1e-4, value.std())
            batch[field] = standardized

    num_loaded_tuples = {}
    with self.load_timer:
        for policy_id, batch in samples.policy_batches.items():
            if policy_id not in self.policies:
                continue

            policy = self.policies[policy_id]
            policy._debug_vars()
            # Shuffling now happens inside _get_loss_inputs_dict via the
            # shuffle flag, rather than via batch.shuffle() above.
            tuples = policy._get_loss_inputs_dict(
                batch, shuffle=self.shuffle_sequences)
            data_keys = [ph for _, ph in policy._loss_inputs]
            if policy._state_inputs:
                state_keys = policy._state_inputs + [policy._seq_lens]
            else:
                state_keys = []
            num_loaded_tuples[policy_id] = (
                self.optimizers[policy_id].load_data(
                    self.sess, [tuples[k] for k in data_keys],
                    [tuples[k] for k in state_keys]))

    fetches = {}
    with self.grad_timer:
        for policy_id, tuples_per_device in num_loaded_tuples.items():
            optimizer = self.optimizers[policy_id]
            num_batches = max(
                1, int(tuples_per_device) // int(self.per_device_batch_size))
            # (A commented-out debug variant here asserted that the loaded
            # tuples divide evenly into device batches and special-cased
            # V-trace to run a single optimize() pass per epoch.)
            for i in range(self.num_sgd_iter):
                iter_extra_fetches = defaultdict(list)
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    batch_fetches = optimizer.optimize(
                        self.sess,
                        permutation[batch_index] * self.per_device_batch_size)
                    for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                        iter_extra_fetches[k].append(v)
                logger.debug("{} {}".format(i, _averaged(iter_extra_fetches)))
            fetches[policy_id] = _averaged(iter_extra_fetches)

    # Step accounting: use per-policy means (previously samples.count and
    # tuples_per_device * len(self.devices)), unless a custom counter is
    # supplied via compute_num_steps_sampled.
    if self.compute_num_steps_sampled:
        self.num_steps_sampled += self.compute_num_steps_sampled(samples)
    else:
        self.num_steps_sampled += np.mean(
            [b.count for b in samples.policy_batches.values()],
            dtype=np.int64)
    self.num_steps_trained += np.mean(
        list(num_loaded_tuples.values()),
        dtype=np.int64) * len(self.devices)
    self.learner_stats = fetches
    return fetches
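# `compute_num_steps_sampled` above is treated as an optional callable that
# takes the collected MultiAgentBatch and returns an integer step count,
# overriding the mean-based default. A hedged example of supplying one
# (only the attribute name comes from the code above; the counting rule
# here is illustrative, not the default):

def sum_policy_batch_counts(samples):
    # Count every per-policy row once instead of averaging across policies.
    return sum(b.count for b in samples.policy_batches.values())

# optimizer.compute_num_steps_sampled = sum_policy_batch_counts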