Example 1
    def get_vr_feeder(self, batch):
        """
        Returns feed dictionary for `value replay` loss estimation subgraph.
        """
        if not self.use_off_policy_aac:  # use single pass of network on same off-policy batch
            feeder = feed_dict_from_nested(self.local_network.vr_state_in, batch['state'])
            feeder.update(feed_dict_rnn_context(self.local_network.vr_lstm_state_pl_flatten, batch['context']))
            feeder.update(
                {
                    self.local_network.vr_batch_size: batch['batch_size'],
                    self.local_network.vr_time_length: batch['time_steps'],
                    self.local_network.vr_a_r_in: batch['last_action_reward'],
                    self.vr_target: batch['r']
                }
            )
        else:
            feeder = {self.vr_target: batch['r']}  # redundant actually :)
        return feeder
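
Note: both branches of get_vr_feeder rely on feed_dict_from_nested (and feed_dict_rnn_context) to pair a nested structure of input placeholders with a matching nested batch of values. The snippet below is a minimal, hypothetical sketch of that pairing contract, using plain Python objects in place of TensorFlow placeholders; it illustrates the assumed behaviour rather than the actual BTGym helper.

# Minimal, hypothetical sketch of pairing a nested dict of placeholders
# with a matching nested dict of values (plain objects stand in for
# TensorFlow placeholders; the real BTGym helper may differ).
def nested_feed_dict(placeholders, values):
    feeder = {}
    if isinstance(placeholders, dict):
        for key, sub_ph in placeholders.items():
            feeder.update(nested_feed_dict(sub_ph, values[key]))
    else:
        # Leaf: a single placeholder paired with its value array.
        feeder[placeholders] = values
    return feeder

# Usage with string keys standing in for placeholders:
ph = {'external': 'ph_external', 'internal': 'ph_internal'}
val = {'external': [1.0, 2.0], 'internal': [0.5]}
print(nested_feed_dict(ph, val))
# -> {'ph_external': [1.0, 2.0], 'ph_internal': [0.5]}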
Example 2
    def process(self, sess):
        """
        Grabs an on_policy_rollout produced by the thread runner. If the data is identified as 'train data',
        samples off_policy rollout(s) from replay memory and updates the parameters, writing summaries if any.
        The update is then sent to the parameter server.
        If the on_policy_rollout contains 'test data', no policy update is performed and the learning rate is set to zero,
        while the test data is stored in replay memory.
        """

        # Collect data from child thread runners:
        data = self.get_data()

        # Test or train: if at least one rollout from the parallel runners is a test rollout,
        # set the learning rate to zero for the entire minibatch.
        try:
            is_train = not np.asarray([
                env['state']['metadata']['type'] for env in data['on_policy']
            ]).any()

        except KeyError:
            is_train = True

        # Copy weights from local policy to local target policy:
        if self.use_target_policy and self.local_steps % self.pi_prime_update_period == 0:
            sess.run(self.sync_pi_prime)

        if is_train:
            # If there are no test rollouts, copy weights from shared to local new_policy:
            sess.run(self.sync_pi)

        #self.log.debug('is_train: {}'.format(is_train))

        # Process minibatch for on-policy train step:
        on_policy_rollouts = data['on_policy']
        on_policy_batch = batch_stack([
            r.process(
                gamma=self.model_gamma,
                gae_lambda=self.model_gae_lambda,
                size=self.rollout_length,
                time_flat=self.time_flat,
            ) for r in on_policy_rollouts
        ])
        # Feeder for on-policy AAC loss estimation graph:
        feed_dict = feed_dict_from_nested(self.local_network.on_state_in,
                                          on_policy_batch['state'])
        feed_dict.update(
            feed_dict_rnn_context(self.local_network.on_lstm_state_pl_flatten,
                                  on_policy_batch['context']))
        feed_dict.update({
            self.local_network.on_a_r_in:
            on_policy_batch['last_action_reward'],
            self.local_network.on_batch_size:
            on_policy_batch['batch_size'],
            self.local_network.on_time_length:
            on_policy_batch['time_steps'],
            self.on_pi_act_target:
            on_policy_batch['action'],
            self.on_pi_adv_target:
            on_policy_batch['advantage'],
            self.on_pi_r_target:
            on_policy_batch['r'],
            self.local_network.train_phase:
            is_train,  # Zeroes learn rate, [+ batch_norm]
        })
        if self.use_target_policy:
            feed_dict.update(
                feed_dict_from_nested(self.local_network_prime.on_state_in,
                                      on_policy_batch['state']))
            feed_dict.update(
                feed_dict_rnn_context(
                    self.local_network_prime.on_lstm_state_pl_flatten,
                    on_policy_batch['context']))
            feed_dict.update({
                self.local_network_prime.on_batch_size:
                on_policy_batch['batch_size'],
                self.local_network_prime.on_time_length:
                on_policy_batch['time_steps'],
                self.local_network_prime.on_a_r_in:
                on_policy_batch['last_action_reward']
            })
        if self.use_memory:
            # Process rollouts from replay memory:
            off_policy_rollouts = data['off_policy']
            off_policy_batch = batch_stack([
                r.process(
                    gamma=self.model_gamma,
                    gae_lambda=self.model_gae_lambda,
                    size=self.replay_rollout_length,
                    time_flat=self.time_flat,
                ) for r in off_policy_rollouts
            ])
            # Feeder for off-policy AAC loss estimation graph:
            off_policy_feed_dict = feed_dict_from_nested(
                self.local_network.off_state_in, off_policy_batch['state'])
            off_policy_feed_dict.update(
                feed_dict_rnn_context(
                    self.local_network.off_lstm_state_pl_flatten,
                    off_policy_batch['context']))
            off_policy_feed_dict.update({
                self.local_network.off_a_r_in:
                off_policy_batch['last_action_reward'],
                self.local_network.off_batch_size:
                off_policy_batch['batch_size'],
                self.local_network.off_time_length:
                off_policy_batch['time_steps'],
                self.off_pi_act_target:
                off_policy_batch['action'],
                self.off_pi_adv_target:
                off_policy_batch['advantage'],
                self.off_pi_r_target:
                off_policy_batch['r'],
            })
            if self.use_target_policy:
                off_policy_feed_dict.update(
                    feed_dict_from_nested(
                        self.local_network_prime.off_state_in,
                        off_policy_batch['state']))
                off_policy_feed_dict.update({
                    self.local_network_prime.off_batch_size:
                    off_policy_batch['batch_size'],
                    self.local_network_prime.off_time_length:
                    off_policy_batch['time_steps'],
                    self.local_network_prime.off_a_r_in:
                    off_policy_batch['last_action_reward']
                })
                off_policy_feed_dict.update(
                    feed_dict_rnn_context(
                        self.local_network_prime.off_lstm_state_pl_flatten,
                        off_policy_batch['context']))
            feed_dict.update(off_policy_feed_dict)

            # Update with reward prediction subgraph:
            if self.use_reward_prediction:
                # Rebalanced 50/50 sample for RP:
                rp_rollouts = data['off_policy_rp']
                rp_batch = batch_stack([
                    rp.process_rp(self.rp_reward_threshold)
                    for rp in rp_rollouts
                ])
                feed_dict.update(self.get_rp_feeder(rp_batch))

            # Pixel control ...
            if self.use_pixel_control:
                feed_dict.update(self.get_pc_feeder(off_policy_batch))

            # VR...
            if self.use_value_replay:
                feed_dict.update(self.get_vr_feeder(off_policy_batch))

        # Every worker writes train episode and model summaries:
        ep_summary_feeder = {}

        # Look for train episode summaries from all env runners:
        for stat in data['ep_summary']:
            if stat is not None:
                for key in stat.keys():
                    if key in ep_summary_feeder.keys():
                        ep_summary_feeder[key] += [stat[key]]
                    else:
                        ep_summary_feeder[key] = [stat[key]]
        # Average values among thread_runners, if any, and write episode summary:
        if ep_summary_feeder:
            ep_summary_feed_dict = {
                self.ep_summary[key]: np.average(values)
                for key, values in ep_summary_feeder.items()
            }

            if self.test_mode:
                # Atari:
                fetched_episode_stat = sess.run(
                    self.ep_summary['atari_stat_op'], ep_summary_feed_dict)

            else:
                # BTGym
                fetched_episode_stat = sess.run(
                    self.ep_summary['btgym_stat_op'], ep_summary_feed_dict)

            self.summary_writer.add_summary(fetched_episode_stat,
                                            sess.run(self.global_episode))
            self.summary_writer.flush()

        # Every worker writes test episode summaries:
        test_ep_summary_feeder = {}

        # Look for test episode summaries:
        for stat in data['test_ep_summary']:
            if stat is not None:
                for key in stat.keys():
                    if key in test_ep_summary_feeder.keys():
                        test_ep_summary_feeder[key] += [stat[key]]
                    else:
                        test_ep_summary_feeder[key] = [stat[key]]
        # Average values among thread_runners, if any, and write test episode summary:
        if test_ep_summary_feeder:
            test_ep_summary_feed_dict = {
                self.ep_summary[key]: np.average(values)
                for key, values in test_ep_summary_feeder.items()
            }
            fetched_test_episode_stat = sess.run(
                self.ep_summary['test_btgym_stat_op'],
                test_ep_summary_feed_dict)
            self.summary_writer.add_summary(fetched_test_episode_stat,
                                            sess.run(self.global_episode))
            self.summary_writer.flush()

        write_model_summary = \
            self.local_steps % self.model_summary_freq == 0

        # Look for renderings (chief worker only, always 0-numbered environment):
        if self.task == 0:
            if data['render_summary'][0] is not None:
                render_feed_dict = {
                    self.ep_summary[key]: pic
                    for key, pic in data['render_summary'][0].items()
                }
                renderings = sess.run(self.ep_summary['render_op'],
                                      render_feed_dict)
                #if False:
                #    if self.test_mode:
                #        renderings = sess.run(self.ep_summary['atari_render_op'], render_feed_dict)
                #
                #    else:
                #        renderings = sess.run(self.ep_summary['btgym_render_op'], render_feed_dict)

                self.summary_writer.add_summary(renderings,
                                                sess.run(self.global_episode))
                self.summary_writer.flush()

        #fetches = [self.train_op, self.local_network.debug]  # include policy debug shapes
        fetches = [self.train_op]

        if write_model_summary:
            fetches_last = fetches + [self.model_summary_op, self.inc_step]
        else:
            fetches_last = fetches + [self.inc_step]

        # Do a number of SGD train epochs;
        # when doing more than one epoch, only the last pass writes summaries:
        for _ in range(self.num_epochs - 1):
            sess.run(fetches, feed_dict=feed_dict)

        fetched = sess.run(fetches_last, feed_dict=feed_dict)

        if write_model_summary:
            self.summary_writer.add_summary(tf.Summary.FromString(fetched[-2]),
                                            fetched[-1])
            self.summary_writer.flush()

        self.local_steps += 1
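
Note: in process, each rollout is turned into a dict of arrays by its process method, and the results are merged with batch_stack before being fed to the loss graph. The sketch below shows one plausible way such a merge could work, assuming every processed rollout is a (possibly nested) dict of numpy arrays concatenated along the leading batch axis; scalar entries such as 'batch_size' are left out of the toy data, and the real BTGym utility may differ.

import numpy as np

# Hypothetical sketch of stacking processed rollouts into one minibatch:
# each rollout is assumed to be a (possibly nested) dict of numpy arrays
# with identical structure, concatenated along the leading batch axis.
def stack_batches(rollouts):
    first = rollouts[0]
    if isinstance(first, dict):
        return {key: stack_batches([r[key] for r in rollouts]) for key in first}
    return np.concatenate([np.asarray(r) for r in rollouts], axis=0)

# Usage: two toy "rollouts" of length 3 and 2.
r1 = {'state': {'external': np.zeros((3, 4))}, 'r': np.ones(3)}
r2 = {'state': {'external': np.ones((2, 4))}, 'r': np.zeros(2)}
batch = stack_batches([r1, r2])
print(batch['state']['external'].shape)  # (5, 4)
print(batch['r'].shape)                  # (5,)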
Example 3
    def process(self, sess):
        """
        Grabs an on_policy_rollout produced by the thread runner,
        samples off_policy rollout(s) from replay memory and updates the parameters.
        The update is then sent to the parameter server.
        """

        # Copy weights from local policy to local target policy:
        if self.use_target_policy and self.local_steps % self.pi_prime_update_period == 0:
            sess.run(self.sync_pi_prime)

        # Copy weights from shared to local new_policy:
        sess.run(self.sync_pi)

        # Collect data from child thread runners:
        data = self.get_data()

        # Process minibatch for on-policy train step:
        on_policy_rollouts = data['on_policy']
        on_policy_batch = batch_stack([
            r.process(
                gamma=self.model_gamma,
                gae_lambda=self.model_gae_lambda,
                size=self.rollout_length,
                time_flat=self.time_flat,
            ) for r in on_policy_rollouts
        ])
        # Feeder for on-policy AAC loss estimation graph:
        feed_dict = feed_dict_from_nested(self.local_network.on_state_in,
                                          on_policy_batch['state'])
        feed_dict.update(
            feed_dict_rnn_context(self.local_network.on_lstm_state_pl_flatten,
                                  on_policy_batch['context']))
        feed_dict.update({
            self.local_network.on_a_r_in:
            on_policy_batch['last_action_reward'],
            self.local_network.on_batch_size:
            on_policy_batch['batch_size'],
            self.local_network.on_time_length:
            on_policy_batch['time_steps'],
            self.on_pi_act_target:
            on_policy_batch['action'],
            self.on_pi_adv_target:
            on_policy_batch['advantage'],
            self.on_pi_r_target:
            on_policy_batch['r'],
            self.local_network.train_phase:
            True,
        })
        if self.use_target_policy:
            feed_dict.update(
                feed_dict_from_nested(self.local_network_prime.on_state_in,
                                      on_policy_batch['state']))
            feed_dict.update(
                feed_dict_rnn_context(
                    self.local_network_prime.on_lstm_state_pl_flatten,
                    on_policy_batch['context']))
            feed_dict.update({
                self.local_network_prime.on_batch_size:
                on_policy_batch['batch_size'],
                self.local_network_prime.on_time_length:
                on_policy_batch['time_steps'],
                self.local_network_prime.on_a_r_in:
                on_policy_batch['last_action_reward']
            })
        if self.use_memory:
            # Process rollouts from replay memory:
            off_policy_rollouts = data['off_policy']
            off_policy_batch = batch_stack([
                r.process(
                    gamma=self.model_gamma,
                    gae_lambda=self.model_gae_lambda,
                    size=self.replay_rollout_length,
                    time_flat=self.time_flat,
                ) for r in off_policy_rollouts
            ])
            # Feeder for off-policy AAC loss estimation graph:
            off_policy_feed_dict = feed_dict_from_nested(
                self.local_network.off_state_in, off_policy_batch['state'])
            off_policy_feed_dict.update(
                feed_dict_rnn_context(
                    self.local_network.off_lstm_state_pl_flatten,
                    off_policy_batch['context']))
            off_policy_feed_dict.update({
                self.local_network.off_a_r_in:
                off_policy_batch['last_action_reward'],
                self.local_network.off_batch_size:
                off_policy_batch['batch_size'],
                self.local_network.off_time_length:
                off_policy_batch['time_steps'],
                self.off_pi_act_target:
                off_policy_batch['action'],
                self.off_pi_adv_target:
                off_policy_batch['advantage'],
                self.off_pi_r_target:
                off_policy_batch['r'],
            })
            if self.use_target_policy:
                off_policy_feed_dict.update(
                    feed_dict_from_nested(
                        self.local_network_prime.off_state_in,
                        off_policy_batch['state']))
                off_policy_feed_dict.update({
                    self.local_network_prime.off_batch_size:
                    off_policy_batch['batch_size'],
                    self.local_network_prime.off_time_length:
                    off_policy_batch['time_steps'],
                    self.local_network_prime.off_a_r_in:
                    off_policy_batch['last_action_reward']
                })
                off_policy_feed_dict.update(
                    feed_dict_rnn_context(
                        self.local_network_prime.off_lstm_state_pl_flatten,
                        off_policy_batch['context']))
            feed_dict.update(off_policy_feed_dict)

            # Update with reward prediction subgraph:
            if self.use_reward_prediction:
                # Rebalanced 50/50 sample for RP:
                rp_rollouts = data['off_policy_rp']
                rp_batch = batch_stack([
                    rp.process_rp(self.rp_reward_threshold)
                    for rp in rp_rollouts
                ])
                feed_dict.update(self.get_rp_feeder(rp_batch))

            # Pixel control ...
            if self.use_pixel_control:
                feed_dict.update(self.get_pc_feeder(off_policy_batch))

            # VR...
            if self.use_value_replay:
                feed_dict.update(self.get_vr_feeder(off_policy_batch))

        # Every worker writes episode and model summaries:
        ep_summary_feeder = {}

        # Collect episode summaries from all env runners:
        for stat in data['ep_summary']:
            if stat is not None:
                for key in stat.keys():
                    if key in ep_summary_feeder.keys():
                        ep_summary_feeder[key] += [stat[key]]
                    else:
                        ep_summary_feeder[key] = [stat[key]]
        # Average values among thread_runners, if any, and write episode summary:
        if ep_summary_feeder:
            ep_summary_feed_dict = {
                self.ep_summary[key]: np.average(values)
                for key, values in ep_summary_feeder.items()
            }

            if self.test_mode:
                # Atari:
                fetched_episode_stat = sess.run(
                    self.ep_summary['test_stat_op'], ep_summary_feed_dict)

            else:
                # BTGym
                fetched_episode_stat = sess.run(self.ep_summary['stat_op'],
                                                ep_summary_feed_dict)

            self.summary_writer.add_summary(fetched_episode_stat,
                                            sess.run(self.global_episode))
            self.summary_writer.flush()

        write_model_summary = \
            self.local_steps % self.model_summary_freq == 0

        # Look for renderings (chief worker only, always 0-numbered environment):
        if self.task == 0:
            if data['render_summary'][0] is not None:
                render_feed_dict = {
                    self.ep_summary[key]: pic
                    for key, pic in data['render_summary'][0].items()
                }
                if self.test_mode:
                    renderings = sess.run(self.ep_summary['test_render_op'],
                                          render_feed_dict)

                else:
                    renderings = sess.run(self.ep_summary['render_op'],
                                          render_feed_dict)

                self.summary_writer.add_summary(renderings,
                                                sess.run(self.global_episode))
                self.summary_writer.flush()

        fetches = [self.train_op]

        if write_model_summary:
            fetches_last = fetches + [self.model_summary_op, self.inc_step]
        else:
            fetches_last = fetches + [self.inc_step]

        # Do a number of SGD train epochs;
        # when doing more than one epoch, only the last pass writes summaries:
        for _ in range(self.num_epochs - 1):
            sess.run(fetches, feed_dict=feed_dict)

        fetched = sess.run(fetches_last, feed_dict=feed_dict)

        if write_model_summary:
            self.summary_writer.add_summary(tf.Summary.FromString(fetched[-2]),
                                            fetched[-1])
            self.summary_writer.flush()

        self.local_steps += 1
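
Note: both variants of process finish by averaging episode statistics collected from the thread runners before feeding them to the summary ops. The snippet below is a small, self-contained illustration of that accumulate-then-average pattern using plain numpy; the runner stats and summary keys are made up for the example.

import numpy as np

# Illustration of the accumulate-then-average pattern used for episode
# summaries; runner stats and keys here are made up for the example.
runner_stats = [
    {'total_reward': 1.5, 'episode_length': 200},
    None,  # a runner that produced no episode summary this step
    {'total_reward': 0.5, 'episode_length': 100},
]

ep_summary_feeder = {}
for stat in runner_stats:
    if stat is not None:
        for key, value in stat.items():
            ep_summary_feeder.setdefault(key, []).append(value)

averaged = {key: np.average(values) for key, values in ep_summary_feeder.items()}
print(averaged)  # {'total_reward': 1.0, 'episode_length': 150.0}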