Example #1
class BaseTrainer(object):
    def __init__(self,
               runner,
               global_network,
               initial_learning_rate,
               learning_rate_input,
               grad_applier,
               env_type,
               env_name,
               entropy_beta,
               gamma,
               experience,
               max_global_time_step,
               device,
               value_lambda):
        self.runner = runner
        self.learning_rate_input = learning_rate_input
        self.env_type = env_type
        self.env_name = env_name
        self.gamma = gamma
        self.max_global_time_step = max_global_time_step
        self.action_size = Environment.get_action_size(env_type, env_name)
        self.obs_size = Environment.get_obs_size(env_type, env_name)
        self.global_network = global_network
        self.local_network = UnrealModel(self.action_size,
                                         self.obs_size,
                                         1,
                                         entropy_beta,
                                         device,
                                         value_lambda=value_lambda)

        self.local_network.prepare_loss()
        
        self.apply_gradients = grad_applier.minimize_local(self.local_network.total_loss,
                                                           self.global_network.get_vars(),
                                                           self.local_network.get_vars())
        self.sync = self.local_network.sync_from(self.global_network, name="base_trainer")
        self.experience = experience
        self.local_t = 0
        self.next_log_t = 0
        self.next_performance_t = PERFORMANCE_LOG_INTERVAL
        self.initial_learning_rate = initial_learning_rate
        self.episode_reward = 0
        # trackers for the experience replay creation
        self.last_state = None
        self.last_action = 0
        self.last_reward = 0
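        # per-episode accumulators fed into the tensorboard summary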
        self.ep_ploss = 0.
        self.ep_vloss = 0.
        self.ep_entr = []
        self.ep_grad = []
        self.ep_l = 0
        
    
    def _anneal_learning_rate(self, global_time_step):
        learning_rate = self.initial_learning_rate * (self.max_global_time_step - global_time_step) / self.max_global_time_step
        if learning_rate < 0.0:
            learning_rate = 0.0
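        # Worked example (illustrative values, not the project's defaults): with
        # initial_learning_rate = 7e-4 and max_global_time_step = 1e7, the rate
        # decays linearly from 7e-4 at step 0 to 3.5e-4 at step 5e6 and is
        # clamped to 0.0 once global_time_step reaches max_global_time_step.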
        return learning_rate
        
    def choose_action(self, pi_values):
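        # Example: for pi_values = [0.1, 0.7, 0.2] this samples action index 1
        # about 70% of the time; pi_values must be a valid probability
        # distribution over the action space.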
        return np.random.choice(range(len(pi_values)), p=pi_values)
    
    def set_start_time(self, start_time, global_t):
        self.start_time = start_time
        self.local_t = global_t
        
    def pull_batch_from_queue(self):
        """
        take a rollout from the queue of the thread runner.
        """
        rollout_full = False
        count = 0
        while not rollout_full:
            if count == 0:
                rollout = self.runner.queue.get(timeout=600.0)
                count += 1
            else:
                try:
                    rollout.extend(self.runner.queue.get_nowait())
                    count += 1
                except queue.Empty:
                    #logger.warn("!!! queue was empty !!!")
                    continue
            if count == 5 or rollout.terminal:
                rollout_full = True
        #logger.debug("pulled batch from rollout, length:{}".format(len(rollout.rewards)))
        return rollout
        
    def _print_log(self, global_t):
        elapsed_time = time.time() - self.start_time
        steps_per_sec = global_t / elapsed_time
        logger.info("Performance : {} STEPS in {:.0f} sec. {:.0f} STEPS/sec. {:.2f}M STEPS/hour".format(
            global_t, elapsed_time, steps_per_sec, steps_per_sec * 3600 / 1000000.))
    
    def _add_batch_to_exp(self, batch):
        # if we just started, copy the first state as last state
        if self.last_state is None:
            self.last_state = batch.si[0]
        #logger.debug("adding batch to exp. len:{}".format(len(batch.si)))
        for k in range(len(batch.si)):
            state = batch.si[k]
            action = batch.a[k]#np.argmax(batch.a[k])
            reward = batch.a_r[k][-1]

            self.episode_reward += reward
            features = batch.features[k]
            pixel_change = batch.pc[k]
            #logger.debug("k = {} of {} -- terminal = {}".format(k,len(batch.si), batch.terminal))
            terminal = (k == len(batch.si) - 1 and batch.terminal)
            frame = ExperienceFrame(state, reward, action, terminal, features, pixel_change,
                                    self.last_action, self.last_reward)
            self.experience.add_frame(frame)
            self.last_state = state
            self.last_action = action
            self.last_reward = reward
            
        if terminal:
            total_ep_reward = self.episode_reward
            self.episode_reward = 0
            return total_ep_reward
        else:
            return None
            
    
    def process(self, sess, global_t, summary_writer, summary_op, summary_values, base_lambda):
        sess.run(self.sync)
        cur_learning_rate = self._anneal_learning_rate(global_t)
        # Copy weights from shared to local
        #logger.debug("Syncing to global net -- current learning rate:{}".format(cur_learning_rate))
        #logger.debug("local_t:{} - global_t:{}".format(self.local_t,global_t))


        # get batch from process_rollout
        rollout = self.pull_batch_from_queue()
        batch = process_rollout(rollout, gamma=self.gamma, lambda_=base_lambda)
        self.local_t += len(batch.si)


        #logger.debug("si:{}".format(batch.si.shape))
        feed_dict = {
            self.local_network.base_input: batch.si,
            self.local_network.base_last_action_reward_input: batch.a_r,
            self.local_network.base_a: batch.a,
            self.local_network.base_adv: batch.adv,
            self.local_network.base_r: batch.r,
            #self.local_network.base_initial_lstm_state: batch.features[0],
            # [common]
            self.learning_rate_input: cur_learning_rate
        }
        
        
        # Calculate gradients and copy them to global network.
        [_, grad], policy_loss, value_loss, entropy, baseinput, policy, value = sess.run(
            [self.apply_gradients,
             self.local_network.policy_loss,
             self.local_network.value_loss,
             self.local_network.entropy,
             self.local_network.base_input,
             self.local_network.base_pi,
             self.local_network.base_v],
            feed_dict=feed_dict)
        self.ep_l += batch.si.shape[0]
        self.ep_ploss += policy_loss
        self.ep_vloss += value_loss
        self.ep_entr.append(entropy)

        self.ep_grad.append(grad)
        # add batch to experience replay
        total_ep_reward = self._add_batch_to_exp(batch)
        if total_ep_reward is not None:
            laststate = baseinput[np.newaxis,-1,...]
            summary_str = sess.run(summary_op, feed_dict={summary_values[0]: total_ep_reward,
                                                          summary_values[1]: self.ep_l,
                                                          summary_values[2]: self.ep_ploss/self.ep_l,
                                                          summary_values[3]: self.ep_vloss/self.ep_l,
                                                          summary_values[4]: np.mean(self.ep_entr),
                                                          summary_values[5]: np.mean(self.ep_grad),
                                                          summary_values[6]: cur_learning_rate})#,
                                                          #summary_values[7]: laststate})
            summary_writer.add_summary(summary_str, global_t)
            summary_writer.flush()
                    
            if self.local_t > self.next_performance_t:
                self._print_log(global_t)
                self.next_performance_t += PERFORMANCE_LOG_INTERVAL
                    
            if self.local_t >= self.next_log_t:
                logger.info("localtime={}".format(self.local_t))
                logger.info("action={}".format(self.last_action))
                logger.info("policy={}".format(policy[-1]))
                logger.info("V={}".format(np.mean(value)))
                logger.info("ep score={}".format(total_ep_reward))
                self.next_log_t += LOG_INTERVAL
            
            #try:
            #sess.run(self.sync)
            #except Exception:
            #    logger.warn("--- !! parallel syncing !! ---")
            self.ep_l = 0
            self.ep_ploss = 0.
            self.ep_vloss = 0.
            self.ep_entr = []
            self.ep_grad = []
            
        # Return advanced local step size
        diff_global_t = self.local_t - global_t
        return diff_global_t
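BaseTrainer.process() reads the fields si, a, a_r, adv, r, features, pc and terminal from the object returned by process_rollout(), and AuxTrainer._process_base() in Example #3 builds a Batch from seven positional arguments. The container itself is not shown in these examples; the snippet below is only a minimal sketch inferred from that usage, not the project's actual definition.

from collections import namedtuple

# Hypothetical stand-in for the project's Batch container. Field names and
# order are inferred from how the two trainers use it; pc is given a default
# because AuxTrainer._process_base() does not supply it.
Batch = namedtuple(
    "Batch",
    ["si", "a", "a_r", "adv", "r", "terminal", "features", "pc"],
    defaults=(None,))  # 'defaults' requires Python 3.7+ and applies to pc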
Example #2
class Application(object):
    def __init__(self):
        pass

    def base_train_function(self):
        """ Train routine for base_trainer. """

        trainer = self.base_trainer

        # set start_time
        trainer.set_start_time(self.start_time, self.global_t)

        while True:
            if self.stop_requested:
                break
            if self.terminate_requested:
                break
            if self.global_t > flags.max_time_step:
                break
            if self.global_t > self.next_save_steps:
                # Save checkpoint
                logger.debug("Saving at steps:{}".format(self.global_t))
                #logger.debug(self.next_save_steps)

                self.save()

            diff_global_t = trainer.process(self.sess, self.global_t,
                                            self.summary_writer,
                                            self.summary_op,
                                            self.summary_values,
                                            flags.base_lambda)
            self.global_t += diff_global_t
        logger.warn("exiting training!")
        self.terminate_requested = True
        self.environment.stop()
        #sys.exit(0)
        time.sleep(1)
        os._exit(0)

    def aux_train_function(self, aux_index):
        """ Train routine for aux_trainer. """

        trainer = self.aux_trainers[aux_index]

        while True:
            if self.global_t < 1000:
                # wait until the base trainer has filled some of the experience replay
                continue
            if self.stop_requested:
                # paused while a checkpoint is being saved
                continue
            if self.terminate_requested:
                break
            if self.global_t > flags.max_time_step:
                break

            diff_aux_t = trainer.process(self.sess, self.global_t, self.aux_t,
                                         self.summary_writer,
                                         self.summary_op_aux, self.summary_aux)
            self.aux_t += diff_aux_t

        logger.warn("!!! stopping aux at aux_t:{}".format(self.aux_t))

    def run(self):
        device = "/cpu:0"
        if USE_GPU:
            device = "/gpu:0"
        logger.debug("start App")
        initial_learning_rate = flags.initial_learning_rate

        self.global_t = 0
        self.aux_t = 0
        self.stop_requested = False
        self.terminate_requested = False
        logger.debug("getting action size and observation size...")
        action_size = Environment.get_action_size(flags.env_type,
                                                  flags.env_name)
        obs_size = Environment.get_obs_size(flags.env_type, flags.env_name)
        # Setup Global Network
        logger.debug("loading global model...")
        self.global_network = UnrealModel(
            action_size,
            obs_size,
            -1,
            flags.entropy_beta,
            device,
            use_pixel_change=flags.use_pixel_change,
            use_value_replay=flags.use_value_replay,
            use_reward_prediction=flags.use_reward_prediction,
            use_temporal_coherence=flags.use_temporal_coherence,
            use_proportionality=flags.use_proportionality,
            use_causality=flags.use_causality,
            use_repeatability=flags.use_repeatability,
            value_lambda=flags.value_lambda,
            pixel_change_lambda=flags.pixel_change_lambda,
            temporal_coherence_lambda=flags.temporal_coherence_lambda,
            proportionality_lambda=flags.proportionality_lambda,
            causality_lambda=flags.causality_lambda,
            repeatability_lambda=flags.repeatability_lambda)
        logger.debug("done loading global model")
        learning_rate_input = tf.placeholder("float")

        # Setup gradient calculator (RMSProp by default; the AdamApplier
        # alternative is kept below, commented out)
        grad_applier = RMSPropApplier(
            learning_rate=learning_rate_input,
            #decay=flags.rmsp_alpha,
            momentum=0.0,
            #epsilon=flags.rmsp_epsilon,
            clip_norm=flags.grad_norm_clip,
            device=device)
        #grad_applier = AdamApplier(learning_rate=learning_rate_input,
        #                           clip_norm=flags.grad_norm_clip,
        #                           device=device)
        # Start environment
        self.environment = Environment.create_environment(
            flags.env_type, flags.env_name)
        logger.debug("done loading environment")

        # Setup runner
        self.runner = RunnerThread(flags, self.environment,
                                   self.global_network, action_size, obs_size,
                                   device, visualise)
        logger.debug("done setting up RunnerTread")

        # Setup experience
        self.experience = Experience(flags.experience_history_size)

        #@TODO check device usage: should we build a cluster?
        # Setup Base Network
        self.base_trainer = BaseTrainer(
            self.runner, self.global_network, initial_learning_rate,
            learning_rate_input, grad_applier, flags.env_type, flags.env_name,
            flags.entropy_beta, flags.gamma, self.experience,
            flags.max_time_step, device, flags.value_lambda)

        # Setup Aux Networks
        self.aux_trainers = []
        for k in range(flags.parallel_size):
            self.aux_trainers.append(
                AuxTrainer(
                    self.global_network,
                    k + 2,  #-1 is global, 0 is runnerthread, 1 is base
                    flags.use_base,
                    flags.use_pixel_change,
                    flags.use_value_replay,
                    flags.use_reward_prediction,
                    flags.use_temporal_coherence,
                    flags.use_proportionality,
                    flags.use_causality,
                    flags.use_repeatability,
                    flags.value_lambda,
                    flags.pixel_change_lambda,
                    flags.temporal_coherence_lambda,
                    flags.proportionality_lambda,
                    flags.causality_lambda,
                    flags.repeatability_lambda,
                    flags.aux_initial_learning_rate,
                    learning_rate_input,
                    grad_applier,
                    self.aux_t,
                    flags.env_type,
                    flags.env_name,
                    flags.entropy_beta,
                    flags.local_t_max,
                    flags.gamma,
                    flags.aux_lambda,
                    flags.gamma_pc,
                    self.experience,
                    flags.max_time_step,
                    device))

        # Start tensorflow session
        config = tf.ConfigProto(log_device_placement=False,
                                allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)

        self.sess.run(tf.global_variables_initializer())

        self.init_tensorboard()

        # init or load checkpoint with saver
        self.saver = tf.train.Saver(self.global_network.get_vars())

        checkpoint = tf.train.get_checkpoint_state(flags.checkpoint_dir)
        if CONTINUE_TRAINING and checkpoint and checkpoint.model_checkpoint_path:
            self.saver.restore(self.sess, checkpoint.model_checkpoint_path)
            checkpointpath = checkpoint.model_checkpoint_path.replace(
                "/", "\\")
            logger.info("checkpoint loaded: {}".format(checkpointpath))
            tokens = checkpoint.model_checkpoint_path.split("-")
            # set global step
            self.global_t = int(tokens[1])
            logger.info(">>> global step set: {}".format(self.global_t))
            logger.info(">>> aux step: {}".format(self.aux_t))
            # set wall time
            wall_t_fname = flags.checkpoint_dir + '/' + 'wall_t.' + str(
                self.global_t)
            with open(wall_t_fname, 'r') as f:
                self.wall_t = float(f.read())
                self.next_save_steps = (
                    self.global_t + flags.save_interval_step
                ) // flags.save_interval_step * flags.save_interval_step
                logger.debug("next save steps:{}".format(self.next_save_steps))
        else:
            logger.info("Could not find old checkpoint")
            # set wall time
            self.wall_t = 0.0
            self.next_save_steps = flags.save_interval_step

        signal.signal(signal.SIGINT, self.signal_handler)

        # set start time
        self.start_time = time.time() - self.wall_t
        # Start runner
        self.runner.start_runner(self.sess)
        # Start base_network thread
        self.base_train_thread = threading.Thread(
            target=self.base_train_function, args=())
        self.base_train_thread.start()

        # Start aux_network threads
        self.aux_train_threads = []
        for k in range(flags.parallel_size):
            self.aux_train_threads.append(
                threading.Thread(target=self.aux_train_function, args=(k, )))
            self.aux_train_threads[k].start()

        logger.debug(threading.enumerate())

        logger.info('Press Ctrl+C to stop')
        signal.pause()

    def init_tensorboard(self):
        # tensorboard summary for base
        self.score_input = tf.placeholder(tf.int32)
        self.epl_input = tf.placeholder(tf.int32)
        self.policy_loss = tf.placeholder(tf.float32)
        self.value_loss = tf.placeholder(tf.float32)
        self.base_entropy = tf.placeholder(tf.float32)
        self.base_gradient = tf.placeholder(tf.float32)
        self.base_lr = tf.placeholder(tf.float32)
        #self.laststate = tf.placeholder(tf.float32, [1, flags.vis_w, flags.vis_h, len(flags.vision)], name="laststate")
        score = tf.summary.scalar("env/score", self.score_input)
        epl = tf.summary.scalar("env/ep_length", self.epl_input)
        policy_loss = tf.summary.scalar("base/policy_loss", self.policy_loss)
        value_loss = tf.summary.scalar("base/value_loss", self.value_loss)
        entropy = tf.summary.scalar("base/entropy", self.base_entropy)
        gradient = tf.summary.scalar("base/gradient", self.base_gradient)
        lr = tf.summary.scalar("base/learning_rate", self.base_lr)
        #laststate = tf.summary.image("base/laststate", self.laststate)

        self.summary_values = [
            self.score_input, self.epl_input, self.policy_loss,
            self.value_loss, self.base_entropy, self.base_gradient,
            self.base_lr
        ]  #, self.laststate]
        self.summary_op = tf.summary.merge_all()  # we want to merge model histograms as well here

        # tensorboard summary for aux
        self.summary_aux = []
        aux_losses = []
        self.aux_basep_loss = tf.placeholder(tf.float32)
        self.aux_basev_loss = tf.placeholder(tf.float32)
        self.aux_entropy = tf.placeholder(tf.float32)
        self.aux_gradient = tf.placeholder(tf.float32)
        self.summary_aux.append(self.aux_basep_loss)
        self.summary_aux.append(self.aux_basev_loss)

        aux_losses.append(
            tf.summary.scalar("aux/basep_loss", self.aux_basep_loss))
        aux_losses.append(
            tf.summary.scalar("aux/basev_loss", self.aux_basev_loss))

        if flags.use_pixel_change:
            self.pc_loss = tf.placeholder(tf.float32)
            self.summary_aux.append(self.pc_loss)
            aux_losses.append(tf.summary.scalar("aux/pc_loss", self.pc_loss))
        if flags.use_value_replay:
            self.vr_loss = tf.placeholder(tf.float32)
            self.summary_aux.append(self.vr_loss)
            aux_losses.append(tf.summary.scalar("aux/vr_loss", self.vr_loss))
        if flags.use_reward_prediction:
            self.rp_loss = tf.placeholder(tf.float32)
            self.summary_aux.append(self.rp_loss)
            aux_losses.append(tf.summary.scalar("aux/rp_loss", self.rp_loss))
        if flags.use_temporal_coherence:
            self.tc_loss = tf.placeholder(tf.float32)
            self.summary_aux.append(self.tc_loss)
            aux_losses.append(tf.summary.scalar("aux/tc_loss", self.tc_loss))
        if flags.use_proportionality:
            self.prop_loss = tf.placeholder(tf.float32)
            self.summary_aux.append(self.prop_loss)
            aux_losses.append(
                tf.summary.scalar("aux/prop_loss", self.prop_loss))
        if flags.use_causality:
            self.caus_loss = tf.placeholder(tf.float32)
            self.summary_aux.append(self.caus_loss)
            aux_losses.append(
                tf.summary.scalar("aux/caus_loss", self.caus_loss))
        if flags.use_repeatability:
            self.repeat_loss = tf.placeholder(tf.float32)
            self.summary_aux.append(self.repeat_loss)
            aux_losses.append(
                tf.summary.scalar("aux/repeat_loss", self.repeat_loss))

        # append entropy and gradient last
        self.summary_aux.append(self.aux_entropy)
        self.summary_aux.append(self.aux_gradient)
        aux_losses.append(tf.summary.scalar("aux/entropy", self.aux_entropy))
        aux_losses.append(tf.summary.scalar("aux/gradient", self.aux_gradient))

        self.summary_op_aux = tf.summary.merge(aux_losses)

        #self.summary_op = tf.summary.merge_all()
        tensorboard_path = flags.temp_dir + TRAINING_NAME + "/"
        logger.info("tensorboard path:" + tensorboard_path)
        if not os.path.exists(tensorboard_path):
            os.makedirs(tensorboard_path)
        self.summary_writer = tf.summary.FileWriter(tensorboard_path)
        self.summary_writer.add_graph(self.sess.graph)

    def save(self):
        """ Save checkpoint. 
        Called from base_trainer.
        """
        self.stop_requested = True

        # Save
        if not os.path.exists(flags.checkpoint_dir):
            os.mkdir(flags.checkpoint_dir)

        # Write wall time
        wall_t = time.time() - self.start_time
        wall_t_fname = flags.checkpoint_dir + '/' + 'wall_t.' + str(
            self.global_t)
        with open(wall_t_fname, 'w') as f:
            f.write(str(wall_t))
        #logger.info('Start saving.')
        self.saver.save(self.sess,
                        flags.checkpoint_dir + '/' + 'checkpoint',
                        global_step=self.global_t)
        #logger.info('End saving.')

        self.stop_requested = False
        self.next_save_steps += flags.save_interval_step

    def signal_handler(self, signal, frame):
        logger.warn('Ctrl+C detected, shutting down...')
        logger.info('run name: {} -- terminated'.format(TRAINING_NAME))
        self.terminate_requested = True
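Application.run() and the trainers read their hyperparameters from a module-level flags object. Below is a minimal sketch of the fields touched in Example #2, assuming the usual TensorFlow 1.x tf.app.flags pattern; the flag names are taken from the code above, but every default value is a placeholder, not the project's setting.

import tensorflow as tf

_f = tf.app.flags
_f.DEFINE_string("env_type", "lab", "environment backend (placeholder default)")
_f.DEFINE_string("env_name", "maze", "environment name (placeholder default)")
_f.DEFINE_string("checkpoint_dir", "/tmp/unreal_checkpoints", "checkpoint directory")
_f.DEFINE_string("temp_dir", "/tmp/unreal_log/", "tensorboard log directory")
_f.DEFINE_integer("parallel_size", 1, "number of aux trainer threads")
_f.DEFINE_integer("local_t_max", 20, "length of sampled experience sequences")
_f.DEFINE_integer("max_time_step", 10**7, "total training steps")
_f.DEFINE_integer("save_interval_step", 100000, "steps between checkpoints")
_f.DEFINE_integer("experience_history_size", 2000, "replay buffer size in frames")
_f.DEFINE_float("initial_learning_rate", 7e-4, "base trainer learning rate")
_f.DEFINE_float("aux_initial_learning_rate", 7e-4, "aux trainer learning rate")
_f.DEFINE_float("gamma", 0.99, "discount factor")
_f.DEFINE_float("gamma_pc", 0.9, "discount factor for pixel control")
_f.DEFINE_float("base_lambda", 1.0, "GAE lambda for the base trainer")
_f.DEFINE_float("aux_lambda", 1.0, "GAE lambda for the aux trainers")
_f.DEFINE_float("entropy_beta", 0.001, "entropy regularisation weight")
_f.DEFINE_float("grad_norm_clip", 40.0, "gradient clipping norm")
_f.DEFINE_float("value_lambda", 0.5, "value loss weight")
_f.DEFINE_float("pixel_change_lambda", 0.05, "pixel change loss weight")
_f.DEFINE_float("temporal_coherence_lambda", 1.0, "temporal coherence loss weight")
_f.DEFINE_float("proportionality_lambda", 1.0, "proportionality loss weight")
_f.DEFINE_float("causality_lambda", 1.0, "causality loss weight")
_f.DEFINE_float("repeatability_lambda", 1.0, "repeatability loss weight")
_f.DEFINE_boolean("use_base", True, "train the base A3C head in the aux trainers")
_f.DEFINE_boolean("use_pixel_change", True, "enable the pixel change task")
_f.DEFINE_boolean("use_value_replay", True, "enable the value replay task")
_f.DEFINE_boolean("use_reward_prediction", True, "enable the reward prediction task")
_f.DEFINE_boolean("use_temporal_coherence", False, "enable the temporal coherence prior")
_f.DEFINE_boolean("use_proportionality", False, "enable the proportionality prior")
_f.DEFINE_boolean("use_causality", False, "enable the causality prior")
_f.DEFINE_boolean("use_repeatability", False, "enable the repeatability prior")
flags = _f.FLAGS  # the module-level name the examples refer to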
Example #3
class AuxTrainer(object):
    def __init__(self, global_network, thread_index, use_base,
                 use_pixel_change, use_value_replay, use_reward_prediction,
                 use_temporal_coherence, use_proportionality, use_causality,
                 use_repeatability, value_lambda, pixel_change_lambda,
                 temporal_coherence_lambda, proportionality_lambda,
                 causality_lambda, repeatability_lambda, initial_learning_rate,
                 learning_rate_input, grad_applier, aux_t, env_type, env_name,
                 entropy_beta, local_t_max, gamma, aux_lambda, gamma_pc,
                 experience, max_global_time_step, device):

        self.use_pixel_change = use_pixel_change
        self.use_value_replay = use_value_replay
        self.use_reward_prediction = use_reward_prediction
        self.use_temporal_coherence = use_temporal_coherence
        self.use_proportionality = use_proportionality
        self.use_causality = use_causality
        self.use_repeatability = use_repeatability
        self.learning_rate_input = learning_rate_input
        self.env_type = env_type
        self.env_name = env_name
        self.entropy_beta = entropy_beta
        self.local_t = 0
        self.next_sync_t = 0
        self.next_log_t = 0
        self.local_t_max = local_t_max
        self.gamma = gamma
        self.aux_lambda = aux_lambda
        self.gamma_pc = gamma_pc
        self.experience = experience
        self.max_global_time_step = max_global_time_step
        self.action_size = Environment.get_action_size(env_type, env_name)
        self.obs_size = Environment.get_obs_size(env_type, env_name)
        self.thread_index = thread_index
        self.local_network = UnrealModel(
            self.action_size,
            self.obs_size,
            self.thread_index,
            self.entropy_beta,
            device,
            use_pixel_change=use_pixel_change,
            use_value_replay=use_value_replay,
            use_reward_prediction=use_reward_prediction,
            use_temporal_coherence=use_temporal_coherence,
            use_proportionality=use_proportionality,
            use_causality=use_causality,
            use_repeatability=use_repeatability,
            value_lambda=value_lambda,
            pixel_change_lambda=pixel_change_lambda,
            temporal_coherence_lambda=temporal_coherence_lambda,
            proportionality_lambda=proportionality_lambda,
            causality_lambda=causality_lambda,
            repeatability_lambda=repeatability_lambda,
            for_display=False,
            use_base=use_base)

        self.local_network.prepare_loss()
        self.global_network = global_network

        #logger.debug("ln.total_loss:{}".format(self.local_network.total_loss))

        self.apply_gradients = grad_applier.minimize_local(
            self.local_network.total_loss, self.global_network.get_vars(),
            self.local_network.get_vars())
        self.sync = self.local_network.sync_from(self.global_network,
                                                 name="aux_trainer_{}".format(
                                                     self.thread_index))
        self.initial_learning_rate = initial_learning_rate
        self.episode_reward = 0
        # trackers for the experience replay creation
        self.last_action = np.zeros(self.action_size)
        self.last_reward = 0

        self.aux_losses = []
        self.aux_losses.append(self.local_network.policy_loss)
        self.aux_losses.append(self.local_network.value_loss)
        if self.use_pixel_change:
            self.aux_losses.append(self.local_network.pc_loss)
        if self.use_value_replay:
            self.aux_losses.append(self.local_network.vr_loss)
        if self.use_reward_prediction:
            self.aux_losses.append(self.local_network.rp_loss)
        if self.use_temporal_coherence:
            self.aux_losses.append(self.local_network.tc_loss)
        if self.use_proportionality:
            self.aux_losses.append(self.local_network.prop_loss)
        if self.use_causality:
            self.aux_losses.append(self.local_network.caus_loss)
        if self.use_repeatability:
            self.aux_losses.append(self.local_network.rep_loss)

    def _anneal_learning_rate(self, global_time_step):
        learning_rate = self.initial_learning_rate * (
            self.max_global_time_step -
            global_time_step) / self.max_global_time_step
        if learning_rate < 0.0:
            learning_rate = 0.0
        return learning_rate

    def discount(self, x, gamma):
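        # Reversed discounted cumulative sum:
        # y[t] = x[t] + gamma * x[t+1] + gamma**2 * x[t+2] + ...
        # e.g. discount([1.0, 1.0, 1.0], 0.9) == [2.71, 1.9, 1.0]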
        return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]

    def _process_base(self, sess, policy, gamma, lambda_=1.0):
        # base A3C from experience replay
        experience_frames = self.experience.sample_sequence(
            self.local_t_max + 1)
        batch_si = []
        batch_a = []
        rewards = []
        action_reward = []
        batch_features = []
        values = []
        last_state = experience_frames[0].state
        last_action_reward = experience_frames[0].concat_action_and_reward(
            experience_frames[0].action, self.action_size,
            experience_frames[0].reward)
        policy.set_state(
            np.asarray(experience_frames[0].features).reshape([2, 1, -1]))

        for frame in range(1, len(experience_frames)):
            state = experience_frames[frame].state
            #logger.debug("state:{}".format(state.shape))
            batch_si.append(state)
            action = experience_frames[frame].action
            reward = experience_frames[frame].reward
            a_r = experience_frames[frame].concat_action_and_reward(
                action, self.action_size, reward)
            action_reward.append(a_r)
            batch_a.append(a_r[:-1])
            rewards.append(reward)
            _, value, features = policy.run_base_policy_and_value(
                sess, last_state, last_action_reward)
            batch_features.append(features)
            values.append(value)
            last_state = state
            last_action_reward = action_reward[-1]

        if not experience_frames[-1].terminal:
            r = policy.run_base_value(sess, last_state, last_action_reward)
        else:
            r = 0.

        vpred_t = np.asarray(values + [r])
        rewards_plus_v = np.asarray(rewards + [r])
        batch_r = self.discount(rewards_plus_v, gamma)[:-1]
        delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
        # this formula for the advantage comes from "Generalized Advantage Estimation":
        # https://arxiv.org/abs/1506.02438
        batch_adv = self.discount(delta_t, gamma * lambda_)
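        # In GAE terms: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t) and
        # A_t = sum_l (gamma * lambda)**l * delta_{t+l}, which is exactly the
        # discounted sum computed by self.discount() above.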

        start_features = []  #batch_features[0]

        return Batch(batch_si, batch_a, action_reward, batch_adv, batch_r,
                     experience_frames[-1].terminal, start_features)

    def _process_pc(self, sess):
        # [pixel change]
        # Sample local_t_max+1 frames (+1 for the last next state)
        pc_experience_frames = self.experience.sample_sequence(
            self.local_t_max + 1)
        # Reverse the sequence to calculate from the last frame
        pc_experience_frames.reverse()

        batch_pc_si = []
        batch_pc_a = []
        batch_pc_R = []
        batch_pc_last_action_reward = []

        pc_R = np.zeros([20, 20], dtype=np.float32)
        if not pc_experience_frames[0].terminal:
            pc_R = self.local_network.run_pc_q_max(
                sess, pc_experience_frames[0].state,
                pc_experience_frames[0].get_last_action_reward(
                    self.action_size))
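        # Walk the reversed sequence and accumulate the discounted n-step return
        # of the pixel-change "pseudo-rewards"; pc_R is the 20x20 target map for
        # the pixel-control Q head, bootstrapped from run_pc_q_max above.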

        for frame in pc_experience_frames[1:]:
            pc_R = frame.pixel_change + self.gamma_pc * pc_R
            a = np.zeros([self.action_size])
            a[frame.action] = 1.0
            last_action_reward = frame.get_last_action_reward(self.action_size)

            batch_pc_si.append(frame.state)
            batch_pc_a.append(a)
            batch_pc_R.append(pc_R)
            batch_pc_last_action_reward.append(last_action_reward)

        batch_pc_si.reverse()
        batch_pc_a.reverse()
        batch_pc_R.reverse()
        batch_pc_last_action_reward.reverse()

        return batch_pc_si, batch_pc_last_action_reward, batch_pc_a, batch_pc_R

    def _process_vr(self, sess):
        # [Value replay]
        # Sample local_t_max+1 frames (+1 for the last next state)
        vr_experience_frames = self.experience.sample_sequence(
            self.local_t_max + 1)
        # Reverse the sequence to calculate from the last frame
        vr_experience_frames.reverse()

        batch_vr_si = []
        batch_vr_R = []
        batch_vr_last_action_reward = []

        vr_R = 0.0
        if not vr_experience_frames[0].terminal:
            vr_R = self.local_network.run_vr_value(
                sess, vr_experience_frames[0].state,
                vr_experience_frames[0].get_last_action_reward(
                    self.action_size))

        # loop over the local_t_max sampled frames
        for frame in vr_experience_frames[1:]:
            vr_R = frame.reward + self.gamma * vr_R
            batch_vr_si.append(frame.state)
            batch_vr_R.append(vr_R)
            last_action_reward = frame.get_last_action_reward(self.action_size)
            batch_vr_last_action_reward.append(last_action_reward)

        batch_vr_si.reverse()
        batch_vr_R.reverse()
        batch_vr_last_action_reward.reverse()

        return batch_vr_si, batch_vr_last_action_reward, batch_vr_R

    def _process_rp(self):
        # [Reward prediction]
        # sample 4 successive frames: the first 3 states are the input,
        # the reward of the 4th frame gives the target class
        rp_experience_frames = self.experience.sample_rp_sequence()

        batch_rp_si = []
        batch_rp_c = []

        for i in range(3):
            batch_rp_si.append(rp_experience_frames[i].state)

        # one hot vector for target reward
        r = rp_experience_frames[3].reward
        rp_c = [0.0, 0.0, 0.0]
        if r == 0:
            rp_c[0] = 1.0  # zero
        elif r > 0:
            rp_c[1] = 1.0  # positive
        else:
            rp_c[2] = 1.0  # negative
        batch_rp_c.append(rp_c)
        return batch_rp_si, batch_rp_c

    def _process_robotics(self):
        # [proportionality, causality and repeatability]
        frames1, frames2 = self.experience.sample_b2b_seq_recursive(
            self.local_t_max + 1)
        b_inp1_1 = []
        b_inp1_2 = []
        b_inp2_1 = []
        b_inp2_2 = []
        actioncheck = []
        rewardcheck = []
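        # actioncheck[k] is 1 when the two sampled sequences took the same action
        # at step k (used by the proportionality and repeatability priors), and
        # rewardcheck[k] is 1 when their rewards differ (used by the causality prior).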
        for frame in range(len(frames1) - 1):
            b_inp1_1.append(frames1[frame].state)
            b_inp1_2.append(frames1[frame + 1].state)
            b_inp2_1.append(frames2[frame].state)
            b_inp2_2.append(frames2[frame + 1].state)
            if np.argmax(frames1[frame].action) == np.argmax(
                    frames2[frame].action):
                actioncheck.append(1)
            else:
                actioncheck.append(0)
            if frames1[frame].reward == frames2[frame].reward:
                rewardcheck.append(0)
            else:
                rewardcheck.append(1)
        #logger.debug(actioncheck)

        return b_inp1_1, b_inp1_2, b_inp2_1, b_inp2_2, actioncheck, rewardcheck

    def process(self, sess, global_t, aux_t, summary_writer, summary_op_aux,
                summary_aux):
        sess.run(self.sync)
        cur_learning_rate = self._anneal_learning_rate(global_t)
        """
        if self.local_t >= self.next_sync_t:
            # Copy weights from shared to local
            #logger.debug("aux_t:{} -- local_t:{} -- syncing...".format(aux_t, self.local_t))
            try:
                sess.run(self.sync)
                self.next_sync_t += SYNC_INTERVAL
            except Exception:
                logger.warn("--- !! parallel syncing !! ---")
            #logger.debug("next_sync:{}".format(self.next_sync_t))
        """

        batch = self._process_base(sess, self.local_network, self.gamma,
                                   self.aux_lambda)

        feed_dict = {
            self.local_network.base_input: batch.si,
            self.local_network.base_last_action_reward_input: batch.a_r,
            self.local_network.base_a: batch.a,
            self.local_network.base_adv: batch.adv,
            self.local_network.base_r: batch.r,
            #self.local_network.base_initial_lstm_state: batch.features,
            # [common]
            self.learning_rate_input: cur_learning_rate
        }

        # [Pixel change]
        if self.use_pixel_change:
            batch_pc_si, batch_pc_last_action_reward, batch_pc_a, batch_pc_R = self._process_pc(sess)

            pc_feed_dict = {
                self.local_network.pc_input: batch_pc_si,
                self.local_network.pc_last_action_reward_input:
                batch_pc_last_action_reward,
                self.local_network.pc_a: batch_pc_a,
                self.local_network.pc_r: batch_pc_R,
                # [common]
                self.learning_rate_input: cur_learning_rate
            }
            feed_dict.update(pc_feed_dict)

        # [Value replay]
        if self.use_value_replay:
            batch_vr_si, batch_vr_last_action_reward, batch_vr_R = self._process_vr(sess)

            vr_feed_dict = {
                self.local_network.vr_input: batch_vr_si,
                self.local_network.vr_last_action_reward_input:
                batch_vr_last_action_reward,
                self.local_network.vr_r: batch_vr_R,
                # [common]
                self.learning_rate_input: cur_learning_rate
            }
            feed_dict.update(vr_feed_dict)

        # [Reward prediction]
        if self.use_reward_prediction:
            batch_rp_si, batch_rp_c = self._process_rp()
            rp_feed_dict = {
                self.local_network.rp_input: batch_rp_si,
                self.local_network.rp_c_target: batch_rp_c,
                # [common]
                self.learning_rate_input: cur_learning_rate
            }
            feed_dict.update(rp_feed_dict)

        # [Robotic Priors]
        if self.use_temporal_coherence or self.use_proportionality or self.use_causality or self.use_repeatability:
            bri11, bri12, bri21, bri22, sameact, diffrew = self._process_robotics()

        #logger.debug("sameact:{}".format(sameact))
        #logger.debug("diffrew:{}".format(diffrew))

        if self.use_temporal_coherence:
            tc_feed_dict = {
                self.local_network.tc_input1: np.asarray(bri11),
                self.local_network.tc_input2: np.asarray(bri12)
            }
            feed_dict.update(tc_feed_dict)

        if self.use_proportionality:
            prop_feed_dict = {
                self.local_network.prop_input1_1: np.asarray(bri11),
                self.local_network.prop_input1_2: np.asarray(bri12),
                self.local_network.prop_input2_1: np.asarray(bri21),
                self.local_network.prop_input2_2: np.asarray(bri22),
                self.local_network.prop_actioncheck: np.asarray(sameact)
            }
            feed_dict.update(prop_feed_dict)

        if self.use_causality:
            caus_feed_dict = {
                self.local_network.caus_input1: np.asarray(bri11),
                self.local_network.caus_input2: np.asarray(bri21),
                self.local_network.caus_actioncheck: np.asarray(sameact),
                self.local_network.caus_rewardcheck: np.asarray(diffrew)
            }
            feed_dict.update(caus_feed_dict)

        if self.use_repeatability:
            rep_feed_dict = {
                self.local_network.rep_input1_1: np.asarray(bri11),
                self.local_network.rep_input1_2: np.asarray(bri12),
                self.local_network.rep_input2_1: np.asarray(bri21),
                self.local_network.rep_input2_2: np.asarray(bri22),
                self.local_network.rep_actioncheck: np.asarray(sameact)
            }
            feed_dict.update(rep_feed_dict)

        # Calculate gradients and copy them to the global network.
        [_, grad], losses, entropy = sess.run(
            [self.apply_gradients, self.aux_losses, self.local_network.entropy],
            feed_dict=feed_dict)

        if self.thread_index == 2 and aux_t >= self.next_log_t:
            #logger.debug("losses:{}".format(losses))

            self.next_log_t += LOG_INTERVAL
            feed_dict_aux = {}
            for k in range(len(losses)):
                feed_dict_aux.update({summary_aux[k]: losses[k]})
            feed_dict_aux.update({
                summary_aux[-2]: np.mean(entropy),
                summary_aux[-1]: np.mean(grad)
            })
            summary_str = sess.run(summary_op_aux, feed_dict=feed_dict_aux)
            summary_writer.add_summary(summary_str, aux_t)
            summary_writer.flush()

        self.local_t += len(batch.si)
        return len(batch.si)