def get_pc_feeder(self, batch):
    """
    Returns feed dictionary for `pixel control` loss estimation subgraph.
    """
    if not self.use_off_policy_aac:
        # Use single pass of network on same off-policy batch:
        feeder = feed_dict_from_nested(self.local_network.pc_state_in, batch['state'])
        feeder.update(
            feed_dict_rnn_context(self.local_network.pc_lstm_state_pl_flatten, batch['context']))
        feeder.update(
            {
                self.local_network.pc_a_r_in: batch['last_action_reward'],
                self.pc_action: batch['action'],
                self.pc_target: batch['pixel_change']
            }
        )
    else:
        feeder = {
            self.pc_action: batch['action'],
            self.pc_target: batch['pixel_change']
        }
    return feeder
def get_rp_feeder(self, batch):
    """
    Returns feed dictionary for `reward prediction` loss estimation subgraph.
    """
    feeder = feed_dict_from_nested(self.local_network.rp_state_in, batch['state'])
    feeder.update(
        {
            self.rp_target: batch['rp_target'],
            self.local_network.rp_batch_size: batch['batch_size'],
        }
    )
    return feeder
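
# NOTE: illustrative sketch only, not part of this class. The `rp_target` fed above is
# produced by the rollout's `process_rp()`; the sketch below shows a plausible encoding,
# assuming the UNREAL-style three-class reward prediction target (zero / positive / negative
# reward following a short frame sequence). The helper name `make_rp_target_sketch` and its
# exact layout are assumptions, not the actual btgym code.

def make_rp_target_sketch(sequence_rewards, reward_threshold=0.1):
    """One-hot 3-class target for the reward observed at the end of a sampled frame sequence."""
    # `np` is the module-level numpy import already used in this file.
    last_reward = sequence_rewards[-1]
    target = np.zeros(3)
    if last_reward > reward_threshold:
        target[1] = 1.0  # positive reward
    elif last_reward < -reward_threshold:
        target[2] = 1.0  # negative reward
    else:
        target[0] = 1.0  # (near-)zero reward
    return target
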
def get_vr_feeder(self, batch):
    """
    Returns feed dictionary for `value replay` loss estimation subgraph.
    """
    if not self.use_off_policy_aac:
        # Use single pass of network on same off-policy batch:
        feeder = feed_dict_from_nested(self.local_network.vr_state_in, batch['state'])
        feeder.update(
            feed_dict_rnn_context(self.local_network.vr_lstm_state_pl_flatten, batch['context']))
        feeder.update(
            {
                self.local_network.vr_batch_size: batch['batch_size'],
                self.local_network.vr_time_length: batch['time_steps'],
                self.local_network.vr_a_r_in: batch['last_action_reward'],
                self.vr_target: batch['r']
            }
        )
    else:
        feeder = {self.vr_target: batch['r']}  # redundant actually :)
    return feeder
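
# NOTE: illustrative sketch only, not part of this class. The feeders above lean on the
# `feed_dict_from_nested` and `feed_dict_rnn_context` helpers imported from the btgym
# algorithms utilities. The bodies below are a rough mental model of what they do (map a
# nested structure of placeholders onto an identically nested structure of values, and zip
# flattened LSTM state placeholders with recorded context), not the actual library code.

def feed_dict_from_nested_sketch(placeholders, values):
    """Recursively pair a nested dict of placeholders with an identically nested dict of values."""
    feeder = {}
    if isinstance(placeholders, dict):
        for key, sub_placeholder in placeholders.items():
            feeder.update(feed_dict_from_nested_sketch(sub_placeholder, values[key]))
    else:
        feeder[placeholders] = values
    return feeder


def feed_dict_rnn_context_sketch(state_pl_flatten, context_flatten):
    """Zip flattened LSTM state placeholders with matching recorded context arrays."""
    # Assumes the recorded context has already been flattened to match the placeholder list;
    # the real helper also takes care of flattening nested LSTMStateTuple structures.
    return {pl: value for pl, value in zip(state_pl_flatten, context_flatten)}
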
def process(self, sess):
    """
    Grabs an on_policy_rollout that has been produced by the thread runner.
    If the data is identified as 'train data', samples off_policy rollout[s] from replay memory,
    updates the parameters and writes summaries, if any. The update is then sent to the
    parameter server. If the on_policy_rollout contains 'test data', no policy update is
    performed and the learn rate is set to zero; meanwhile, test data are stored in replay memory.
    """
    # Collect data from child thread runners:
    data = self.get_data()

    # Test or train: if at least one rollout from parallel runners is a test rollout -
    # set learn rate to zero for entire minibatch. Doh.
    try:
        is_train = not np.asarray(
            [env['state']['metadata']['type'] for env in data['on_policy']]
        ).any()
    except KeyError:
        is_train = True

    # Copy weights from local policy to local target policy:
    if self.use_target_policy and self.local_steps % self.pi_prime_update_period == 0:
        sess.run(self.sync_pi_prime)

    if is_train:
        # If there are no testing rollouts - copy weights from shared to local new_policy:
        sess.run(self.sync_pi)

    # self.log.debug('is_train: {}'.format(is_train))

    # Process minibatch for on-policy train step:
    on_policy_rollouts = data['on_policy']
    on_policy_batch = batch_stack(
        [
            r.process(
                gamma=self.model_gamma,
                gae_lambda=self.model_gae_lambda,
                size=self.rollout_length,
                time_flat=self.time_flat,
            ) for r in on_policy_rollouts
        ]
    )
    # Feeder for on-policy AAC loss estimation graph:
    feed_dict = feed_dict_from_nested(self.local_network.on_state_in, on_policy_batch['state'])
    feed_dict.update(
        feed_dict_rnn_context(self.local_network.on_lstm_state_pl_flatten, on_policy_batch['context']))
    feed_dict.update(
        {
            self.local_network.on_a_r_in: on_policy_batch['last_action_reward'],
            self.local_network.on_batch_size: on_policy_batch['batch_size'],
            self.local_network.on_time_length: on_policy_batch['time_steps'],
            self.on_pi_act_target: on_policy_batch['action'],
            self.on_pi_adv_target: on_policy_batch['advantage'],
            self.on_pi_r_target: on_policy_batch['r'],
            self.local_network.train_phase: is_train,  # Zeroes learn rate, [+ batch_norm]
        }
    )
    if self.use_target_policy:
        feed_dict.update(
            feed_dict_from_nested(self.local_network_prime.on_state_in, on_policy_batch['state']))
        feed_dict.update(
            feed_dict_rnn_context(self.local_network_prime.on_lstm_state_pl_flatten, on_policy_batch['context']))
        feed_dict.update(
            {
                self.local_network_prime.on_batch_size: on_policy_batch['batch_size'],
                self.local_network_prime.on_time_length: on_policy_batch['time_steps'],
                self.local_network_prime.on_a_r_in: on_policy_batch['last_action_reward']
            }
        )
    if self.use_memory:
        # Process rollouts from replay memory:
        off_policy_rollouts = data['off_policy']
        off_policy_batch = batch_stack(
            [
                r.process(
                    gamma=self.model_gamma,
                    gae_lambda=self.model_gae_lambda,
                    size=self.replay_rollout_length,
                    time_flat=self.time_flat,
                ) for r in off_policy_rollouts
            ]
        )
        # Feeder for off-policy AAC loss estimation graph:
        off_policy_feed_dict = feed_dict_from_nested(
            self.local_network.off_state_in, off_policy_batch['state'])
        off_policy_feed_dict.update(
            feed_dict_rnn_context(self.local_network.off_lstm_state_pl_flatten, off_policy_batch['context']))
        off_policy_feed_dict.update(
            {
                self.local_network.off_a_r_in: off_policy_batch['last_action_reward'],
                self.local_network.off_batch_size: off_policy_batch['batch_size'],
                self.local_network.off_time_length: off_policy_batch['time_steps'],
                self.off_pi_act_target: off_policy_batch['action'],
                self.off_pi_adv_target: off_policy_batch['advantage'],
                self.off_pi_r_target: off_policy_batch['r'],
            }
        )
        if self.use_target_policy:
            off_policy_feed_dict.update(
                feed_dict_from_nested(self.local_network_prime.off_state_in, off_policy_batch['state']))
            off_policy_feed_dict.update(
                {
                    self.local_network_prime.off_batch_size: off_policy_batch['batch_size'],
                    self.local_network_prime.off_time_length: off_policy_batch['time_steps'],
                    self.local_network_prime.off_a_r_in: off_policy_batch['last_action_reward']
                }
            )
            off_policy_feed_dict.update(
                feed_dict_rnn_context(self.local_network_prime.off_lstm_state_pl_flatten, off_policy_batch['context']))
        feed_dict.update(off_policy_feed_dict)

        # Update with reward prediction subgraph:
        if self.use_reward_prediction:
            # Rebalanced 50/50 sample for RP:
            rp_rollouts = data['off_policy_rp']
            rp_batch = batch_stack([rp.process_rp(self.rp_reward_threshold) for rp in rp_rollouts])
            feed_dict.update(self.get_rp_feeder(rp_batch))

        # Pixel control ...
        if self.use_pixel_control:
            feed_dict.update(self.get_pc_feeder(off_policy_batch))

        # VR...
        if self.use_value_replay:
            feed_dict.update(self.get_vr_feeder(off_policy_batch))

    # Every worker writes train episode and model summaries:
    ep_summary_feeder = {}

    # Look for train episode summaries from all env runners:
    for stat in data['ep_summary']:
        if stat is not None:
            for key in stat.keys():
                if key in ep_summary_feeder.keys():
                    ep_summary_feeder[key] += [stat[key]]
                else:
                    ep_summary_feeder[key] = [stat[key]]

    # Average values among thread_runners, if any, and write episode summary:
    if ep_summary_feeder != {}:
        ep_summary_feed_dict = {
            self.ep_summary[key]: np.average(values) for key, values in ep_summary_feeder.items()
        }
        if self.test_mode:
            # Atari:
            fetched_episode_stat = sess.run(self.ep_summary['atari_stat_op'], ep_summary_feed_dict)
        else:
            # BTGym:
            fetched_episode_stat = sess.run(self.ep_summary['btgym_stat_op'], ep_summary_feed_dict)

        self.summary_writer.add_summary(fetched_episode_stat, sess.run(self.global_episode))
        self.summary_writer.flush()

    # Every worker writes test episode summaries:
    test_ep_summary_feeder = {}

    # Look for test episode summaries:
    for stat in data['test_ep_summary']:
        if stat is not None:
            for key in stat.keys():
                if key in test_ep_summary_feeder.keys():
                    test_ep_summary_feeder[key] += [stat[key]]
                else:
                    test_ep_summary_feeder[key] = [stat[key]]

    # Average values among thread_runners, if any, and write episode summary:
    if test_ep_summary_feeder != {}:
        test_ep_summary_feed_dict = {
            self.ep_summary[key]: np.average(values) for key, values in test_ep_summary_feeder.items()
        }
        fetched_test_episode_stat = sess.run(self.ep_summary['test_btgym_stat_op'], test_ep_summary_feed_dict)
        self.summary_writer.add_summary(fetched_test_episode_stat, sess.run(self.global_episode))
        self.summary_writer.flush()

    write_model_summary = self.local_steps % self.model_summary_freq == 0

    # Look for renderings (chief worker only, always 0-numbered environment):
    if self.task == 0:
        if data['render_summary'][0] is not None:
            render_feed_dict = {
                self.ep_summary[key]: pic for key, pic in data['render_summary'][0].items()
            }
            renderings = sess.run(self.ep_summary['render_op'], render_feed_dict)
            # if False:
            #     if self.test_mode:
            #         renderings = sess.run(self.ep_summary['atari_render_op'], render_feed_dict)
            #     else:
            #         renderings = sess.run(self.ep_summary['btgym_render_op'], render_feed_dict)
            self.summary_writer.add_summary(renderings, sess.run(self.global_episode))
            self.summary_writer.flush()

    # fetches = [self.train_op, self.local_network.debug]  # include policy debug shapes
    fetches = [self.train_op]

    if write_model_summary:
        fetches_last = fetches + [self.model_summary_op, self.inc_step]
    else:
        fetches_last = fetches + [self.inc_step]

    # Do a number of SGD train epochs:
    # When doing more than one epoch, we actually use only the last summary:
    for i in range(self.num_epochs - 1):
        fetched = sess.run(fetches, feed_dict=feed_dict)

    fetched = sess.run(fetches_last, feed_dict=feed_dict)

    if write_model_summary:
        self.summary_writer.add_summary(tf.Summary.FromString(fetched[-2]), fetched[-1])
        self.summary_writer.flush()

    self.local_steps += 1
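
# NOTE: usage sketch, not part of this class. `process()` is meant to be called repeatedly
# by the worker's training loop inside a (monitored) TF session. The loop below is an
# assumption about that plumbing (the names `trainer`, `trainer.global_step` and
# `max_train_steps` are hypothetical here), not the actual btgym worker code.

def run_training_loop_sketch(sess, trainer, max_train_steps):
    """Drive `trainer.process()` until the shared global step budget is exhausted."""
    while sess.run(trainer.global_step) < max_train_steps:
        trainer.process(sess)  # one on-policy (+ replay) parameter update
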
def process(self, sess):
    """
    Grabs an on_policy_rollout that has been produced by the thread runner,
    samples off_policy rollout[s] from replay memory and updates the parameters.
    The update is then sent to the parameter server.
    """
    # Copy weights from local policy to local target policy:
    if self.use_target_policy and self.local_steps % self.pi_prime_update_period == 0:
        sess.run(self.sync_pi_prime)

    # Copy weights from shared to local new_policy:
    sess.run(self.sync_pi)

    # Collect data from child thread runners:
    data = self.get_data()

    # Process minibatch for on-policy train step:
    on_policy_rollouts = data['on_policy']
    on_policy_batch = batch_stack(
        [
            r.process(
                gamma=self.model_gamma,
                gae_lambda=self.model_gae_lambda,
                size=self.rollout_length,
                time_flat=self.time_flat,
            ) for r in on_policy_rollouts
        ]
    )
    # Feeder for on-policy AAC loss estimation graph:
    feed_dict = feed_dict_from_nested(self.local_network.on_state_in, on_policy_batch['state'])
    feed_dict.update(
        feed_dict_rnn_context(self.local_network.on_lstm_state_pl_flatten, on_policy_batch['context']))
    feed_dict.update(
        {
            self.local_network.on_a_r_in: on_policy_batch['last_action_reward'],
            self.local_network.on_batch_size: on_policy_batch['batch_size'],
            self.local_network.on_time_length: on_policy_batch['time_steps'],
            self.on_pi_act_target: on_policy_batch['action'],
            self.on_pi_adv_target: on_policy_batch['advantage'],
            self.on_pi_r_target: on_policy_batch['r'],
            self.local_network.train_phase: True,
        }
    )
    if self.use_target_policy:
        feed_dict.update(
            feed_dict_from_nested(self.local_network_prime.on_state_in, on_policy_batch['state']))
        feed_dict.update(
            feed_dict_rnn_context(self.local_network_prime.on_lstm_state_pl_flatten, on_policy_batch['context']))
        feed_dict.update(
            {
                self.local_network_prime.on_batch_size: on_policy_batch['batch_size'],
                self.local_network_prime.on_time_length: on_policy_batch['time_steps'],
                self.local_network_prime.on_a_r_in: on_policy_batch['last_action_reward']
            }
        )
    if self.use_memory:
        # Process rollouts from replay memory:
        off_policy_rollouts = data['off_policy']
        off_policy_batch = batch_stack(
            [
                r.process(
                    gamma=self.model_gamma,
                    gae_lambda=self.model_gae_lambda,
                    size=self.replay_rollout_length,
                    time_flat=self.time_flat,
                ) for r in off_policy_rollouts
            ]
        )
        # Feeder for off-policy AAC loss estimation graph:
        off_policy_feed_dict = feed_dict_from_nested(
            self.local_network.off_state_in, off_policy_batch['state'])
        off_policy_feed_dict.update(
            feed_dict_rnn_context(self.local_network.off_lstm_state_pl_flatten, off_policy_batch['context']))
        off_policy_feed_dict.update(
            {
                self.local_network.off_a_r_in: off_policy_batch['last_action_reward'],
                self.local_network.off_batch_size: off_policy_batch['batch_size'],
                self.local_network.off_time_length: off_policy_batch['time_steps'],
                self.off_pi_act_target: off_policy_batch['action'],
                self.off_pi_adv_target: off_policy_batch['advantage'],
                self.off_pi_r_target: off_policy_batch['r'],
            }
        )
        if self.use_target_policy:
            off_policy_feed_dict.update(
                feed_dict_from_nested(self.local_network_prime.off_state_in, off_policy_batch['state']))
            off_policy_feed_dict.update(
                {
                    self.local_network_prime.off_batch_size: off_policy_batch['batch_size'],
                    self.local_network_prime.off_time_length: off_policy_batch['time_steps'],
                    self.local_network_prime.off_a_r_in: off_policy_batch['last_action_reward']
                }
            )
            off_policy_feed_dict.update(
                feed_dict_rnn_context(self.local_network_prime.off_lstm_state_pl_flatten, off_policy_batch['context']))
        feed_dict.update(off_policy_feed_dict)

        # Update with reward prediction subgraph:
        if self.use_reward_prediction:
            # Rebalanced 50/50 sample for RP:
            rp_rollouts = data['off_policy_rp']
            rp_batch = batch_stack([rp.process_rp(self.rp_reward_threshold) for rp in rp_rollouts])
            feed_dict.update(self.get_rp_feeder(rp_batch))

        # Pixel control ...
        if self.use_pixel_control:
            feed_dict.update(self.get_pc_feeder(off_policy_batch))

        # VR...
        if self.use_value_replay:
            feed_dict.update(self.get_vr_feeder(off_policy_batch))

    # Every worker writes episode and model summaries:
    ep_summary_feeder = {}

    # Collect episode summaries from all env runners:
    for stat in data['ep_summary']:
        if stat is not None:
            for key in stat.keys():
                if key in ep_summary_feeder.keys():
                    ep_summary_feeder[key] += [stat[key]]
                else:
                    ep_summary_feeder[key] = [stat[key]]

    # Average values among thread_runners, if any, and write episode summary:
    if ep_summary_feeder != {}:
        ep_summary_feed_dict = {
            self.ep_summary[key]: np.average(values) for key, values in ep_summary_feeder.items()
        }
        if self.test_mode:
            # Atari:
            fetched_episode_stat = sess.run(self.ep_summary['test_stat_op'], ep_summary_feed_dict)
        else:
            # BTGym:
            fetched_episode_stat = sess.run(self.ep_summary['stat_op'], ep_summary_feed_dict)

        self.summary_writer.add_summary(fetched_episode_stat, sess.run(self.global_episode))
        self.summary_writer.flush()

    write_model_summary = self.local_steps % self.model_summary_freq == 0

    # Look for renderings (chief worker only, always 0-numbered environment):
    if self.task == 0:
        if data['render_summary'][0] is not None:
            render_feed_dict = {
                self.ep_summary[key]: pic for key, pic in data['render_summary'][0].items()
            }
            if self.test_mode:
                renderings = sess.run(self.ep_summary['test_render_op'], render_feed_dict)
            else:
                renderings = sess.run(self.ep_summary['render_op'], render_feed_dict)

            self.summary_writer.add_summary(renderings, sess.run(self.global_episode))
            self.summary_writer.flush()

    fetches = [self.train_op]

    if write_model_summary:
        fetches_last = fetches + [self.model_summary_op, self.inc_step]
    else:
        fetches_last = fetches + [self.inc_step]

    # Do a number of SGD train epochs:
    # When doing more than one epoch, we actually use only the last summary:
    for i in range(self.num_epochs - 1):
        fetched = sess.run(fetches, feed_dict=feed_dict)

    fetched = sess.run(fetches_last, feed_dict=feed_dict)

    if write_model_summary:
        self.summary_writer.add_summary(tf.Summary.FromString(fetched[-2]), fetched[-1])
        self.summary_writer.flush()

    self.local_steps += 1
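
# NOTE: illustrative sketch only, not part of this class. `batch_stack()` used in both
# `process()` variants combines per-rollout dictionaries (as returned by `rollout.process()`)
# into one minibatch by concatenating matching leaves along the batch dimension. The body
# below is a rough approximation of that behaviour, not the actual btgym implementation.

def batch_stack_sketch(batch_list):
    """Concatenate a list of (possibly nested) rollout batch dicts along axis 0, leaf by leaf."""
    template = batch_list[0]
    if isinstance(template, dict):
        return {key: batch_stack_sketch([batch[key] for batch in batch_list]) for key in template.keys()}
    # `np` is the module-level numpy import; scalar leaves such as 'batch_size' or
    # 'time_steps' would need summing or stacking rather than concatenation in the real code.
    return np.concatenate(batch_list, axis=0)
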