Example 1
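# NOTE: the imports below are an assumption; this snippet relies on numpy and the
# standard queue and time modules, plus project-specific helpers (Environment,
# UnrealModel, ExperienceFrame, process_rollout, logger, PERFORMANCE_LOG_INTERVAL,
# LOG_INTERVAL) whose exact module paths are not shown here and are expected to
# come from the surrounding repository.
import queue
import time

import numpy as np
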
class BaseTrainer(object):
    def __init__(self,
               runner,
               global_network,
               initial_learning_rate,
               learning_rate_input,
               grad_applier,
               env_type,
               env_name,
               entropy_beta,
               gamma,
               experience,
               max_global_time_step,
               device,
               value_lambda):
        self.runner = runner
        self.learning_rate_input = learning_rate_input
        self.env_type = env_type
        self.env_name = env_name
        self.gamma = gamma
        self.max_global_time_step = max_global_time_step
        self.action_size = Environment.get_action_size(env_type, env_name)
        self.obs_size = Environment.get_obs_size(env_type, env_name)
        self.global_network = global_network
        self.local_network = UnrealModel(self.action_size,
                                         self.obs_size,
                                         1,
                                         entropy_beta,
                                         device,
                                         value_lambda=value_lambda)

        self.local_network.prepare_loss()
        
        self.apply_gradients = grad_applier.minimize_local(
            self.local_network.total_loss,
            self.global_network.get_vars(),
            self.local_network.get_vars())
        self.sync = self.local_network.sync_from(self.global_network, name="base_trainer")
        self.experience = experience
        self.local_t = 0
        self.next_log_t = 0
        self.next_performance_t = PERFORMANCE_LOG_INTERVAL
        self.initial_learning_rate = initial_learning_rate
        self.episode_reward = 0
        # trackers for the experience replay creation
        self.last_state = None
        self.last_action = 0
        self.last_reward = 0
        self.ep_ploss = 0.
        self.ep_vloss = 0.
        self.ep_entr = []
        self.ep_grad = []
        self.ep_l = 0
        
    
    def _anneal_learning_rate(self, global_time_step):
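        # Linearly anneal the learning rate from its initial value down to zero
        # over max_global_time_step steps.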
        learning_rate = self.initial_learning_rate * (self.max_global_time_step - global_time_step) / self.max_global_time_step
        if learning_rate < 0.0:
            learning_rate = 0.0
        return learning_rate
        
    def choose_action(self, pi_values):
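        # Sample an action index from the policy's output probability distribution.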
        return np.random.choice(range(len(pi_values)), p=pi_values)
    
    def set_start_time(self, start_time, global_t):
        self.start_time = start_time
        self.local_t = global_t
        
    def pull_batch_from_queue(self):
        """
        take a rollout from the queue of the thread runner.
        """
        rollout_full = False
        count = 0
        while not rollout_full:
            if count == 0:
                rollout = self.runner.queue.get(timeout=600.0)
                count += 1
            else:
                try:
                    rollout.extend(self.runner.queue.get_nowait())
                    count += 1
                except queue.Empty:
                    #logger.warn("!!! queue was empty !!!")
                    continue
            if count == 5 or rollout.terminal:
                rollout_full = True
        #logger.debug("pulled batch from rollout, length:{}".format(len(rollout.rewards)))
        return rollout
        
    def _print_log(self, global_t):
        elapsed_time = time.time() - self.start_time
        steps_per_sec = global_t / elapsed_time
        logger.info("Performance : {} STEPS in {:.0f} sec. {:.0f} STEPS/sec. {:.2f}M STEPS/hour".format(
            global_t, elapsed_time, steps_per_sec, steps_per_sec * 3600 / 1000000.))
    
    def _add_batch_to_exp(self, batch):
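        # Push each transition of the batch into the experience replay buffer,
        # tracking the last state/action/reward needed by ExperienceFrame.
        # Returns the accumulated episode reward when the batch ends an episode,
        # otherwise None.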
        # if we just started, copy the first state as last state
        if self.last_state is None:
            self.last_state = batch.si[0]
        #logger.debug("adding batch to exp. len:{}".format(len(batch.si)))
        for k in range(len(batch.si)):
            state = batch.si[k]
            action = batch.a[k]#np.argmax(batch.a[k])
            reward = batch.a_r[k][-1]

            self.episode_reward += reward
            features = batch.features[k]
            pixel_change = batch.pc[k]
            #logger.debug("k = {} of {} -- terminal = {}".format(k,len(batch.si), batch.terminal))
            if k == len(batch.si)-1 and batch.terminal:
                terminal = True
            else:
                terminal = False
            frame = ExperienceFrame(state, reward, action, terminal, features, pixel_change,
                                    self.last_action, self.last_reward)
            self.experience.add_frame(frame)
            self.last_state = state
            self.last_action = action
            self.last_reward = reward
            
        if terminal:
            total_ep_reward = self.episode_reward
            self.episode_reward = 0
            return total_ep_reward
        else:
            return None
            
    
    def process(self, sess, global_t, summary_writer, summary_op, summary_values, base_lambda):
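        # One training step: sync the local network from the global one, pull a
        # rollout from the runner, compute gradients on the local network and
        # apply them to the global network, then add the batch to the replay
        # buffer and write summaries when an episode finishes.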
        sess.run(self.sync)
        cur_learning_rate = self._anneal_learning_rate(global_t)
        # Copy weights from shared to local
        #logger.debug("Syncing to global net -- current learning rate:{}".format(cur_learning_rate))
        #logger.debug("local_t:{} - global_t:{}".format(self.local_t,global_t))


        # get batch from process_rollout
        rollout = self.pull_batch_from_queue()
        batch = process_rollout(rollout, gamma=self.gamma, lambda_=base_lambda)
        self.local_t += len(batch.si)


        #logger.debug("si:{}".format(batch.si.shape))
        feed_dict = {
            self.local_network.base_input: batch.si,
            self.local_network.base_last_action_reward_input: batch.a_r,
            self.local_network.base_a: batch.a,
            self.local_network.base_adv: batch.adv,
            self.local_network.base_r: batch.r,
            #self.local_network.base_initial_lstm_state: batch.features[0],
            # [common]
            self.learning_rate_input: cur_learning_rate
        }
        
        
        # Calculate gradients and copy them to global network.
        [_, grad], policy_loss, value_loss, entropy, baseinput, policy, value = sess.run(
            [self.apply_gradients,
             self.local_network.policy_loss,
             self.local_network.value_loss,
             self.local_network.entropy,
             self.local_network.base_input,
             self.local_network.base_pi,
             self.local_network.base_v],
            feed_dict=feed_dict)
        self.ep_l += batch.si.shape[0]
        self.ep_ploss += policy_loss
        self.ep_vloss += value_loss
        self.ep_entr.append(entropy)

        self.ep_grad.append(grad)
        # add batch to experience replay
        total_ep_reward = self._add_batch_to_exp(batch)
        if total_ep_reward is not None:
            laststate = baseinput[np.newaxis,-1,...]
            summary_str = sess.run(summary_op, feed_dict={summary_values[0]: total_ep_reward,
                                                          summary_values[1]: self.ep_l,
                                                          summary_values[2]: self.ep_ploss/self.ep_l,
                                                          summary_values[3]: self.ep_vloss/self.ep_l,
                                                          summary_values[4]: np.mean(self.ep_entr),
                                                          summary_values[5]: np.mean(self.ep_grad),
                                                          summary_values[6]: cur_learning_rate})#,
                                                          #summary_values[7]: laststate})
            summary_writer.add_summary(summary_str, global_t)
            summary_writer.flush()
                    
            if self.local_t > self.next_performance_t:
                self._print_log(global_t)
                self.next_performance_t += PERFORMANCE_LOG_INTERVAL
                    
            if self.local_t >= self.next_log_t:
                logger.info("localtime={}".format(self.local_t))
                logger.info("action={}".format(self.last_action))
                logger.info("policy={}".format(policy[-1]))
                logger.info("V={}".format(np.mean(value)))
                logger.info("ep score={}".format(total_ep_reward))
                self.next_log_t += LOG_INTERVAL
            
            #try:
            #sess.run(self.sync)
            #except Exception:
            #    logger.warn("--- !! parallel syncing !! ---")
            self.ep_l = 0
            self.ep_ploss = 0.
            self.ep_vloss = 0.
            self.ep_entr = []
            self.ep_grad = []
            
        # Return advanced local step size
        diff_global_t = self.local_t - global_t
        return diff_global_t
Example 2
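# NOTE: as in Example 1, the imports below are an assumption; besides numpy and
# scipy, this snippet expects project-specific helpers (Environment, UnrealModel,
# Batch, logger, LOG_INTERVAL, and the SYNC_INTERVAL referenced in the disabled
# sync block) to be importable from the surrounding repository.
import numpy as np
import scipy.signal
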
class AuxTrainer(object):
    def __init__(self, global_network, thread_index, use_base,
                 use_pixel_change, use_value_replay, use_reward_prediction,
                 use_temporal_coherence, use_proportionality, use_causality,
                 use_repeatability, value_lambda, pixel_change_lambda,
                 temporal_coherence_lambda, proportionality_lambda,
                 causality_lambda, repeatability_lambda, initial_learning_rate,
                 learning_rate_input, grad_applier, aux_t, env_type, env_name,
                 entropy_beta, local_t_max, gamma, aux_lambda, gamma_pc,
                 experience, max_global_time_step, device):

        self.use_pixel_change = use_pixel_change
        self.use_value_replay = use_value_replay
        self.use_reward_prediction = use_reward_prediction
        self.use_temporal_coherence = use_temporal_coherence
        self.use_proportionality = use_proportionality
        self.use_causality = use_causality
        self.use_repeatability = use_repeatability
        self.learning_rate_input = learning_rate_input
        self.env_type = env_type
        self.env_name = env_name
        self.entropy_beta = entropy_beta
        self.local_t = 0
        self.next_sync_t = 0
        self.next_log_t = 0
        self.local_t_max = local_t_max
        self.gamma = gamma
        self.aux_lambda = aux_lambda
        self.gamma_pc = gamma_pc
        self.experience = experience
        self.max_global_time_step = max_global_time_step
        self.action_size = Environment.get_action_size(env_type, env_name)
        self.obs_size = Environment.get_obs_size(env_type, env_name)
        self.thread_index = thread_index
        self.local_network = UnrealModel(
            self.action_size,
            self.obs_size,
            self.thread_index,
            self.entropy_beta,
            device,
            use_pixel_change=use_pixel_change,
            use_value_replay=use_value_replay,
            use_reward_prediction=use_reward_prediction,
            use_temporal_coherence=use_temporal_coherence,
            use_proportionality=use_proportionality,
            use_causality=use_causality,
            use_repeatability=use_repeatability,
            value_lambda=value_lambda,
            pixel_change_lambda=pixel_change_lambda,
            temporal_coherence_lambda=temporal_coherence_lambda,
            proportionality_lambda=proportionality_lambda,
            causality_lambda=causality_lambda,
            repeatability_lambda=repeatability_lambda,
            for_display=False,
            use_base=use_base)

        self.local_network.prepare_loss()
        self.global_network = global_network

        #logger.debug("ln.total_loss:{}".format(self.local_network.total_loss))

        self.apply_gradients = grad_applier.minimize_local(
            self.local_network.total_loss, self.global_network.get_vars(),
            self.local_network.get_vars())
        self.sync = self.local_network.sync_from(self.global_network,
                                                 name="aux_trainer_{}".format(
                                                     self.thread_index))
        self.initial_learning_rate = initial_learning_rate
        self.episode_reward = 0
        # trackers for the experience replay creation
        self.last_action = np.zeros(self.action_size)
        self.last_reward = 0

        self.aux_losses = []
        self.aux_losses.append(self.local_network.policy_loss)
        self.aux_losses.append(self.local_network.value_loss)
        if self.use_pixel_change:
            self.aux_losses.append(self.local_network.pc_loss)
        if self.use_value_replay:
            self.aux_losses.append(self.local_network.vr_loss)
        if self.use_reward_prediction:
            self.aux_losses.append(self.local_network.rp_loss)
        if self.use_temporal_coherence:
            self.aux_losses.append(self.local_network.tc_loss)
        if self.use_proportionality:
            self.aux_losses.append(self.local_network.prop_loss)
        if self.use_causality:
            self.aux_losses.append(self.local_network.caus_loss)
        if self.use_repeatability:
            self.aux_losses.append(self.local_network.rep_loss)

    def _anneal_learning_rate(self, global_time_step):
        learning_rate = self.initial_learning_rate * (
            self.max_global_time_step -
            global_time_step) / self.max_global_time_step
        if learning_rate < 0.0:
            learning_rate = 0.0
        return learning_rate

    def discount(self, x, gamma):
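        # Discounted cumulative sum: y[t] = x[t] + gamma * y[t+1], computed by
        # running an IIR filter over the time-reversed sequence.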
        return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]

    def _process_base(self, sess, policy, gamma, lambda_=1.0):
        # base A3C from experience replay
        experience_frames = self.experience.sample_sequence(self.local_t_max +
                                                            1)
        batch_si = []
        batch_a = []
        rewards = []
        action_reward = []
        batch_features = []
        values = []
        last_state = experience_frames[0].state
        last_action_reward = experience_frames[0].concat_action_and_reward(
            experience_frames[0].action, self.action_size,
            experience_frames[0].reward)
        policy.set_state(
            np.asarray(experience_frames[0].features).reshape([2, 1, -1]))

        for frame in range(1, len(experience_frames)):
            state = experience_frames[frame].state
            #logger.debug("state:{}".format(state.shape))
            batch_si.append(state)
            action = experience_frames[frame].action
            reward = experience_frames[frame].reward
            a_r = experience_frames[frame].concat_action_and_reward(
                action, self.action_size, reward)
            action_reward.append(a_r)
            batch_a.append(a_r[:-1])
            rewards.append(reward)
            _, value, features = policy.run_base_policy_and_value(
                sess, last_state, last_action_reward)
            batch_features.append(features)
            values.append(value)
            last_state = state
            last_action_reward = action_reward[-1]

        if not experience_frames[-1].terminal:
            r = policy.run_base_value(sess, last_state, last_action_reward)
        else:
            r = 0.

        vpred_t = np.asarray(values + [r])
        rewards_plus_v = np.asarray(rewards + [r])
        batch_r = self.discount(rewards_plus_v, gamma)[:-1]
        delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
        # this formula for the advantage comes "Generalized Advantage Estimation":
        # https://arxiv.org/abs/1506.02438
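        # delta_t[t] = r[t] + gamma * V[t+1] - V[t];  adv[t] = sum_k (gamma*lambda)^k * delta_t[t+k]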
        batch_adv = self.discount(delta_t, gamma * lambda_)

        start_features = []  #batch_features[0]

        return Batch(batch_si, batch_a, action_reward, batch_adv, batch_r,
                     experience_frames[-1].terminal, start_features)

    def _process_pc(self, sess):
        # [pixel change]
        # Sample local_t_max+1 frames (+1 for the last next state)
        pc_experience_frames = self.experience.sample_sequence(
            self.local_t_max + 1)
        # Reverse the sequence to compute returns from the last frame
        pc_experience_frames.reverse()

        batch_pc_si = []
        batch_pc_a = []
        batch_pc_R = []
        batch_pc_last_action_reward = []

        pc_R = np.zeros([20, 20], dtype=np.float32)
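        # pc_R accumulates the discounted pixel-change return for each cell of the
        # 20x20 pixel-change map; bootstrap it from the network's Q_max estimate
        # when the sampled sequence does not end in a terminal frame.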
        if not pc_experience_frames[0].terminal:
            pc_R = self.local_network.run_pc_q_max(
                sess, pc_experience_frames[0].state,
                pc_experience_frames[0].get_last_action_reward(
                    self.action_size))

        for frame in pc_experience_frames[1:]:
            pc_R = frame.pixel_change + self.gamma_pc * pc_R
            a = np.zeros([self.action_size])
            a[frame.action] = 1.0
            last_action_reward = frame.get_last_action_reward(self.action_size)

            batch_pc_si.append(frame.state)
            batch_pc_a.append(a)
            batch_pc_R.append(pc_R)
            batch_pc_last_action_reward.append(last_action_reward)

        batch_pc_si.reverse()
        batch_pc_a.reverse()
        batch_pc_R.reverse()
        batch_pc_last_action_reward.reverse()

        return batch_pc_si, batch_pc_last_action_reward, batch_pc_a, batch_pc_R

    def _process_vr(self, sess):
        # [Value replay]
        # Sample local_t_max+1 frames (+1 for the last next state)
        vr_experience_frames = self.experience.sample_sequence(
            self.local_t_max + 1)
        # Reverse the sequence to compute returns from the last frame
        vr_experience_frames.reverse()

        batch_vr_si = []
        batch_vr_R = []
        batch_vr_last_action_reward = []

        vr_R = 0.0
        if not vr_experience_frames[0].terminal:
            vr_R = self.local_network.run_vr_value(
                sess, vr_experience_frames[0].state,
                vr_experience_frames[0].get_last_action_reward(
                    self.action_size))

        # Loop over the local_t_max sampled frames
        for frame in vr_experience_frames[1:]:
            vr_R = frame.reward + self.gamma * vr_R
            batch_vr_si.append(frame.state)
            batch_vr_R.append(vr_R)
            last_action_reward = frame.get_last_action_reward(self.action_size)
            batch_vr_last_action_reward.append(last_action_reward)

        batch_vr_si.reverse()
        batch_vr_R.reverse()
        batch_vr_last_action_reward.reverse()

        return batch_vr_si, batch_vr_last_action_reward, batch_vr_R

    def _process_rp(self):
        # [Reward prediction]
        rp_experience_frames = self.experience.sample_rp_sequence()
        # 4 frames: the first 3 are the input states, the 4th provides the target reward

        batch_rp_si = []
        batch_rp_c = []

        for i in range(3):
            batch_rp_si.append(rp_experience_frames[i].state)

        # one hot vector for target reward
        r = rp_experience_frames[3].reward
        rp_c = [0.0, 0.0, 0.0]
        if r == 0:
            rp_c[0] = 1.0  # zero
        elif r > 0:
            rp_c[1] = 1.0  # positive
        else:
            rp_c[2] = 1.0  # negative
        batch_rp_c.append(rp_c)
        return batch_rp_si, batch_rp_c

    def _process_robotics(self):
        # [proportionality, causality and repeatability]
        frames1, frames2 = self.experience.sample_b2b_seq_recursive(
            self.local_t_max + 1)
        b_inp1_1 = []
        b_inp1_2 = []
        b_inp2_1 = []
        b_inp2_2 = []
        actioncheck = []
        rewardcheck = []
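        # actioncheck[i] is 1 when the two sampled sequences take the same action
        # at step i; rewardcheck[i] is 1 when their rewards differ at step i.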
        for frame in range(len(frames1) - 1):
            b_inp1_1.append(frames1[frame].state)
            b_inp1_2.append(frames1[frame + 1].state)
            b_inp2_1.append(frames2[frame].state)
            b_inp2_2.append(frames2[frame + 1].state)
            if np.argmax(frames1[frame].action) == np.argmax(
                    frames2[frame].action):
                actioncheck.append(1)
            else:
                actioncheck.append(0)
            if frames1[frame].reward == frames2[frame].reward:
                rewardcheck.append(0)
            else:
                rewardcheck.append(1)
        #logger.debug(actioncheck)

        return b_inp1_1, b_inp1_2, b_inp2_1, b_inp2_2, actioncheck, rewardcheck

    def process(self, sess, global_t, aux_t, summary_writer, summary_op_aux,
                summary_aux):
        sess.run(self.sync)
        cur_learning_rate = self._anneal_learning_rate(global_t)
        """
        if self.local_t >= self.next_sync_t:
            # Copy weights from shared to local
            #logger.debug("aux_t:{} -- local_t:{} -- syncing...".format(aux_t, self.local_t))
            try:
                sess.run(self.sync)
                self.next_sync_t += SYNC_INTERVAL
            except Exception:
                logger.warn("--- !! parallel syncing !! ---")
            #logger.debug("next_sync:{}".format(self.next_sync_t))
        """

        batch = self._process_base(sess, self.local_network, self.gamma,
                                   self.aux_lambda)

        feed_dict = {
            self.local_network.base_input:
            batch.si,
            self.local_network.base_last_action_reward_input:
            batch.a_r,
            self.local_network.base_a:
            batch.a,
            self.local_network.base_adv:
            batch.adv,
            self.local_network.base_r:
            batch.r,
            #self.local_network.base_initial_lstm_state: batch.features,
            # [common]
            self.learning_rate_input:
            cur_learning_rate
        }

        # [Pixel change]
        if self.use_pixel_change:
            batch_pc_si, batch_pc_last_action_reward, batch_pc_a, batch_pc_R = self._process_pc(
                sess)

            pc_feed_dict = {
                self.local_network.pc_input: batch_pc_si,
                self.local_network.pc_last_action_reward_input:
                batch_pc_last_action_reward,
                self.local_network.pc_a: batch_pc_a,
                self.local_network.pc_r: batch_pc_R,
                # [common]
                self.learning_rate_input: cur_learning_rate
            }
            feed_dict.update(pc_feed_dict)

        # [Value replay]
        if self.use_value_replay:
            batch_vr_si, batch_vr_last_action_reward, batch_vr_R = self._process_vr(
                sess)

            vr_feed_dict = {
                self.local_network.vr_input: batch_vr_si,
                self.local_network.vr_last_action_reward_input:
                batch_vr_last_action_reward,
                self.local_network.vr_r: batch_vr_R,
                # [common]
                self.learning_rate_input: cur_learning_rate
            }
            feed_dict.update(vr_feed_dict)

        # [Reward prediction]
        if self.use_reward_prediction:
            batch_rp_si, batch_rp_c = self._process_rp()
            rp_feed_dict = {
                self.local_network.rp_input: batch_rp_si,
                self.local_network.rp_c_target: batch_rp_c,
                # [common]
                self.learning_rate_input: cur_learning_rate
            }
            feed_dict.update(rp_feed_dict)

        # [Robotic Priors]
        if self.use_temporal_coherence or self.use_proportionality or self.use_causality or self.use_repeatability:
            bri11, bri12, bri21, bri22, sameact, diffrew = self._process_robotics()

        #logger.debug("sameact:{}".format(sameact))
        #logger.debug("diffrew:{}".format(diffrew))

        if self.use_temporal_coherence:
            tc_feed_dict = {
                self.local_network.tc_input1: np.asarray(bri11),
                self.local_network.tc_input2: np.asarray(bri12)
            }
            feed_dict.update(tc_feed_dict)

        if self.use_proportionality:
            prop_feed_dict = {
                self.local_network.prop_input1_1: np.asarray(bri11),
                self.local_network.prop_input1_2: np.asarray(bri12),
                self.local_network.prop_input2_1: np.asarray(bri21),
                self.local_network.prop_input2_2: np.asarray(bri22),
                self.local_network.prop_actioncheck: np.asarray(sameact)
            }
            feed_dict.update(prop_feed_dict)

        if self.use_causality:
            caus_feed_dict = {
                self.local_network.caus_input1: np.asarray(bri11),
                self.local_network.caus_input2: np.asarray(bri21),
                self.local_network.caus_actioncheck: np.asarray(sameact),
                self.local_network.caus_rewardcheck: np.asarray(diffrew)
            }
            feed_dict.update(caus_feed_dict)

        if self.use_repeatability:
            rep_feed_dict = {
                self.local_network.rep_input1_1: np.asarray(bri11),
                self.local_network.rep_input1_2: np.asarray(bri12),
                self.local_network.rep_input2_1: np.asarray(bri21),
                self.local_network.rep_input2_2: np.asarray(bri22),
                self.local_network.rep_actioncheck: np.asarray(sameact)
            }
            feed_dict.update(rep_feed_dict)

        # Calculate gradients and copy them to the global network.
        [_, grad], losses, entropy = sess.run(
            [self.apply_gradients, self.aux_losses, self.local_network.entropy],
            feed_dict=feed_dict)

        if self.thread_index == 2 and aux_t >= self.next_log_t:
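            # Only the worker with thread index 2 writes auxiliary-task summaries,
            # at most once per LOG_INTERVAL auxiliary steps.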
            #logger.debug("losses:{}".format(losses))

            self.next_log_t += LOG_INTERVAL
            feed_dict_aux = {}
            for k in range(len(losses)):
                feed_dict_aux.update({summary_aux[k]: losses[k]})
            feed_dict_aux.update({
                summary_aux[-2]: np.mean(entropy),
                summary_aux[-1]: np.mean(grad)
            })
            summary_str = sess.run(summary_op_aux, feed_dict=feed_dict_aux)
            summary_writer.add_summary(summary_str, aux_t)
            summary_writer.flush()

        self.local_t += len(batch.si)
        return len(batch.si)