Example #1
    def check_model_var_size(self, use_pixel_change, use_value_replay,
                             use_reward_prediction, var_size):
        """ Check variable size of the model """

        model = UnrealModel(1, -1, use_pixel_change, use_value_replay,
                            use_reward_prediction, 1.0, 1.0, "/cpu:0")
        variables = model.get_vars()
        self.assertEqual(len(variables), var_size)
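
A minimal usage sketch for the fragment above, assuming check_model_var_size is defined on a unittest.TestCase subclass; the test class name and the expected variable counts are illustrative placeholders, not values from the original test suite.

import unittest

class UnrealModelVarSizeTest(unittest.TestCase):
    # check_model_var_size(...) from the fragment above is assumed to be
    # defined on this class.

    def test_base_model_var_size(self):
        # Base A3C network only, no auxiliary task heads (count is a placeholder).
        self.check_model_var_size(False, False, False, var_size=12)

    def test_full_model_var_size(self):
        # Pixel change, value replay and reward prediction heads enabled
        # (count is a placeholder).
        self.check_model_var_size(True, True, True, var_size=20)

if __name__ == '__main__':
    unittest.main()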
Example #2
    def __init__(self, thread_index, global_network, initial_learning_rate,
                 learning_rate_input, grad_applier, env_type, env_name,
                 use_pixel_change, use_value_replay, use_reward_prediction,
                 pixel_change_lambda, entropy_beta, local_t_max, gamma,
                 gamma_pc, experience_history_size, max_global_time_step,
                 device):

        self.thread_index = thread_index
        self.learning_rate_input = learning_rate_input
        self.env_type = env_type
        self.env_name = env_name
        self.use_pixel_change = use_pixel_change
        self.use_value_replay = use_value_replay
        self.use_reward_prediction = use_reward_prediction
        self.local_t_max = local_t_max
        self.gamma = gamma
        self.gamma_pc = gamma_pc
        self.experience_history_size = experience_history_size
        self.max_global_time_step = max_global_time_step
        self.action_size = Environment.get_action_size(env_type, env_name)

        self.local_network = UnrealModel(self.action_size, thread_index,
                                         use_pixel_change, use_value_replay,
                                         use_reward_prediction,
                                         pixel_change_lambda, entropy_beta,
                                         device)
        self.local_network.prepare_loss()

        self.apply_gradients = grad_applier.minimize_local(
            self.local_network.total_loss, global_network.get_vars(),
            self.local_network.get_vars())

        self.sync = self.local_network.sync_from(global_network)
        self.experience = Experience(self.experience_history_size)
        self.local_t = 0
        self.initial_learning_rate = initial_learning_rate
        self.episode_reward = 0
        # For log output
        self.prev_local_t = 0
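
The call to grad_applier.minimize_local wires the thread-local loss to the shared parameters. A conceptual sketch of that pattern, assuming a plain TF1 optimizer rather than the repo's RMSPropApplier (the helper name and the clip_norm default are illustrative):

import tensorflow as tf

def minimize_local_sketch(optimizer, local_loss, global_vars, local_vars,
                          clip_norm=40.0):
    # Gradients are computed against the thread-local copy of the network...
    grads = tf.gradients(local_loss, local_vars)
    # ...clipped to keep asynchronous updates stable (default is an assumption)...
    grads, _ = tf.clip_by_global_norm(grads, clip_norm)
    # ...and applied to the shared/global variables.
    return optimizer.apply_gradients(list(zip(grads, global_vars)))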
Example #3

device = "/cpu:0"

if USE_GPU:
    device = "/gpu:0"

initial_learning_rate = log_uniform(initial_alpha_low,
                                    initial_alpha_high,
                                    initial_alpha_log_rate)
global_t = 0

action_size = len(action_list)


global_network = UnrealModel(action_size,
                             -1,
                             flags.use_pixel_change,
                             flags.use_value_replay,
                             flags.use_reward_prediction,
                             flags.pixel_change_lambda,
                             flags.entropy_beta,
                             device)



learning_rate_input = tf.placeholder("float")

grad_applier = RMSPropApplier(learning_rate=learning_rate_input,
                              decay=flags.rmsp_alpha,
                              momentum=0.0,
                              epsilon=flags.rmsp_epsilon,
                              clip_norm=flags.grad_norm_clip,
                              device=device)

trainers = []

for i in range(flags.parallel_size):
    print('creating trainer', i)
    trainer = Trainer(i,
                      global_network,
                      initial_learning_rate,
                      learning_rate_input,
                      grad_applier,
                      flags.env_type,
                      flags.env_name,
                      flags.use_pixel_change,
                      flags.use_value_replay,
                      flags.use_reward_prediction,
                      flags.pixel_change_lambda,
                      flags.entropy_beta,
                      flags.local_t_max,
                      flags.gamma,
                      flags.gamma_pc,
                      flags.experience_history_size,
                      flags.max_time_step,
                      device)
    trainers.append(trainer)
    print('')

config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
config.gpu_options.allow_growth = True

sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())

# summary for tensorboard
score_input = tf.placeholder(tf.int32)
tf.summary.scalar("score", score_input)

summary_op = tf.summary.merge_all()
summary_writer = tf.summary.FileWriter(flags.log_file, sess.graph)

# init or load checkpoint with saver
saver = tf.train.Saver(global_network.get_vars())

checkpoint = tf.train.get_checkpoint_state(flags.checkpoint_dir)

if checkpoint and checkpoint.model_checkpoint_path:
    saver.restore(sess, checkpoint.model_checkpoint_path)
    print("checkpoint loaded:", checkpoint.model_checkpoint_path)
    tokens = checkpoint.model_checkpoint_path.split("-")
    # set global step
    global_t = int(tokens[1])
    print(">>> global step set: ", global_t)
    # set wall time
    wall_t_fname = flags.checkpoint_dir + '/' + 'wall_t.' + str(global_t)

    with open(wall_t_fname, 'r') as f:
        wall_t = float(f.read())
        next_save_steps = (global_t + flags.save_interval_step) // flags.save_interval_step * flags.save_interval_step

else:
    print("Could not find old checkpoint")
    # set wall time
    wall_t = 0.0
    next_save_steps = flags.save_interval_step


# run training threads
train_threads = []

for i in range(flags.parallel_size):
    train_threads.append(threading.Thread(target=train_function, args=(i, True)))

#signal.signal(signal.SIGINT, signal_handler)

# set start time
start_time = time.time() - wall_t

for t in train_threads:
    t.start()
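
The threads above target a train_function that is not part of this fragment. A minimal sketch consistent with Application.train_function in the next example, assuming it reads the module-level objects defined above (trainers, sess, global_t, summary_writer, summary_op, score_input, start_time, flags):

def train_function(parallel_index, preparing):
    global global_t
    trainer = trainers[parallel_index]
    if preparing:
        trainer.prepare()              # create this thread's environment
    trainer.set_start_time(start_time)
    while global_t <= flags.max_time_step:
        diff_global_t = trainer.process(sess, global_t, summary_writer,
                                        summary_op, score_input)
        global_t += diff_global_t      # racy but matches the original design
    trainer.stop()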
Example #4
class Application(object):
    def __init__(self):
        pass

    def train_function(self, parallel_index, preparing):
        """ Train each environment. """

        trainer = self.trainers[parallel_index]
        if preparing:
            trainer.prepare()

        # set start_time
        trainer.set_start_time(self.start_time)

        while True:
            if self.stop_requested:
                break
            if self.terminate_reqested:
                trainer.stop()
                break
            if self.global_t > flags.max_time_step:
                trainer.stop()
                break
            if parallel_index == 0 and self.global_t > self.next_save_steps:
                # Save checkpoint
                self.save()

            diff_global_t = trainer.process(self.sess, self.global_t,
                                            self.summary_writer,
                                            self.summary_op, self.score_input)

            self.global_t += diff_global_t

    def run(self):
        device = "/cpu:0"
        if USE_GPU:
            device = "/gpu:0"

        initial_learning_rate = log_uniform(flags.initial_alpha_low,
                                            flags.initial_alpha_high,
                                            flags.initial_alpha_log_rate)

        self.global_t = 0

        self.stop_requested = False
        self.terminate_reqested = False

        action_size = Environment.get_action_size(flags.env_type,
                                                  flags.env_name)
        #print(action_size)

        self.global_network = UnrealModel(action_size, -1,
                                          flags.use_pixel_change,
                                          flags.use_value_replay,
                                          flags.use_reward_prediction,
                                          flags.pixel_change_lambda,
                                          flags.entropy_beta, device)
        self.trainers = []

        #print('model')

        learning_rate_input = tf.placeholder("float")

        grad_applier = RMSPropApplier(learning_rate=learning_rate_input,
                                      decay=flags.rmsp_alpha,
                                      momentum=0.0,
                                      epsilon=flags.rmsp_epsilon,
                                      clip_norm=flags.grad_norm_clip,
                                      device=device)

        #print(flags.parallel_size)

        for i in range(flags.parallel_size):
            trainer = Trainer(i, self.global_network, initial_learning_rate,
                              learning_rate_input, grad_applier,
                              flags.env_type, flags.env_name,
                              flags.use_pixel_change, flags.use_value_replay,
                              flags.use_reward_prediction,
                              flags.pixel_change_lambda, flags.entropy_beta,
                              flags.local_t_max, flags.gamma, flags.gamma_pc,
                              flags.experience_history_size,
                              flags.max_time_step, device)
            self.trainers.append(trainer)

        #print(len(self.trainers))

        # prepare session
        config = tf.ConfigProto(log_device_placement=False,
                                allow_soft_placement=True)
        #config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)

        self.sess.run(tf.global_variables_initializer())
        #print('init')

        # summary for tensorboard
        self.score_input = tf.placeholder(tf.int32)
        tf.summary.scalar("score", self.score_input)

        self.summary_op = tf.summary.merge_all()
        self.summary_writer = tf.summary.FileWriter(flags.log_file,
                                                    self.sess.graph)

        # init or load checkpoint with saver
        self.saver = tf.train.Saver(self.global_network.get_vars())

        checkpoint = tf.train.get_checkpoint_state(flags.checkpoint_dir)

        if checkpoint and checkpoint.model_checkpoint_path:

            self.saver.restore(self.sess, checkpoint.model_checkpoint_path)

            print("checkpoint loaded:", checkpoint.model_checkpoint_path)

            tokens = checkpoint.model_checkpoint_path.split("-")
            # set global step
            self.global_t = int(tokens[1])
            print(">>> global step set: ", self.global_t)
            # set wall time
            wall_t_fname = flags.checkpoint_dir + '/' + 'wall_t.' + str(
                self.global_t)
            with open(wall_t_fname, 'r') as f:
                self.wall_t = float(f.read())
                self.next_save_steps = (
                    self.global_t + flags.save_interval_step
                ) // flags.save_interval_step * flags.save_interval_step

        else:
            print("Could not find old checkpoint")
            # set wall time
            self.wall_t = 0.0
            self.next_save_steps = flags.save_interval_step

        #print('checkpoint stuff')

        # run training threads
        self.train_threads = []

        for i in range(flags.parallel_size):

            self.train_threads.append(
                threading.Thread(target=self.train_function, args=(i, True)))

        signal.signal(signal.SIGINT, self.signal_handler)

        # set start time
        self.start_time = time.time() - self.wall_t

        for t in self.train_threads:
            t.start()

        print('Press Ctrl+C to stop')
        signal.pause()

    def save(self):
        """ Save checkpoint.
        Called from thread-0.
        """
        self.stop_requested = True

        # Wait for all other threads to stop
        for (i, t) in enumerate(self.train_threads):
            if i != 0:
                t.join()

        # Save
        if not os.path.exists(flags.checkpoint_dir):
            os.mkdir(flags.checkpoint_dir)

        # Write wall time
        wall_t = time.time() - self.start_time
        wall_t_fname = flags.checkpoint_dir + '/' + 'wall_t.' + str(
            self.global_t)
        with open(wall_t_fname, 'w') as f:
            f.write(str(wall_t))

        print('Start saving.')
        self.saver.save(self.sess,
                        flags.checkpoint_dir + '/' + 'checkpoint',
                        global_step=self.global_t)
        print('End saving.')

        self.stop_requested = False
        self.next_save_steps += flags.save_interval_step

        # Restart other threads
        for i in range(flags.parallel_size):
            if i != 0:
                thread = threading.Thread(target=self.train_function,
                                          args=(i, False))
                self.train_threads[i] = thread
                thread.start()

    def signal_handler(self, signal, frame):
        print('You pressed Ctrl+C!')
        self.terminate_reqested = True
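
A minimal entry-point sketch for the Application class above; the original script's main block is not shown here, so this is only an assumption of how such a class is typically driven in TF1:

import tensorflow as tf

def main(argv):
    app = Application()
    app.run()

if __name__ == '__main__':
    tf.app.run()   # TF1's app.run() parses flags, then calls main(argv)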
Example #5
class Trainer(object):
    def __init__(self, thread_index, global_network, initial_learning_rate,
                 learning_rate_input, grad_applier, env_type, env_name,
                 use_pixel_change, use_value_replay, use_reward_prediction,
                 pixel_change_lambda, entropy_beta, local_t_max, gamma,
                 gamma_pc, experience_history_size, max_global_time_step,
                 device):

        self.thread_index = thread_index
        self.learning_rate_input = learning_rate_input
        self.env_type = env_type
        self.env_name = env_name
        self.use_pixel_change = use_pixel_change
        self.use_value_replay = use_value_replay
        self.use_reward_prediction = use_reward_prediction
        self.local_t_max = local_t_max
        self.gamma = gamma
        self.gamma_pc = gamma_pc
        self.experience_history_size = experience_history_size
        self.max_global_time_step = max_global_time_step
        self.action_size = Environment.get_action_size(env_type, env_name)

        self.local_network = UnrealModel(self.action_size, thread_index,
                                         use_pixel_change, use_value_replay,
                                         use_reward_prediction,
                                         pixel_change_lambda, entropy_beta,
                                         device)
        self.local_network.prepare_loss()

        self.apply_gradients = grad_applier.minimize_local(
            self.local_network.total_loss, global_network.get_vars(),
            self.local_network.get_vars())

        self.sync = self.local_network.sync_from(global_network)
        self.experience = Experience(self.experience_history_size)
        self.local_t = 0
        self.initial_learning_rate = initial_learning_rate
        self.episode_reward = 0
        # For log output
        self.prev_local_t = 0

    def prepare(self):
        print('')
        print('trainer creating env...')
        print('')
        self.environment = Environment.create_environment(
            self.env_type, self.env_name)

    def stop(self):
        self.environment.stop()

    def _anneal_learning_rate(self, global_time_step):
        learning_rate = self.initial_learning_rate * (
            self.max_global_time_step -
            global_time_step) / self.max_global_time_step
        if learning_rate < 0.0:
            learning_rate = 0.0
        return learning_rate

    def choose_action(self, pi_values):
        return np.random.choice(range(len(pi_values)), p=pi_values)

    def _record_score(self, sess, summary_writer, summary_op, score_input,
                      score, global_t):
        summary_str = sess.run(summary_op, feed_dict={score_input: score})
        summary_writer.add_summary(summary_str, global_t)
        summary_writer.flush()

    def set_start_time(self, start_time):
        self.start_time = start_time

    def _fill_experience(self, sess):
        """
    Fill experience buffer until buffer is full.
    """
        prev_state = self.environment.last_state
        last_action = self.environment.last_action
        last_reward = self.environment.last_reward
        last_action_reward = ExperienceFrame.concat_action_and_reward(
            last_action, self.action_size, last_reward)

        pi_, _ = self.local_network.run_base_policy_and_value(
            sess, self.environment.last_state, last_action_reward)
        action = self.choose_action(pi_)

        new_state, reward, terminal, pixel_change = self.environment.process(
            action)

        #print('action:', action, terminal)

        frame = ExperienceFrame(prev_state, reward, action, terminal,
                                pixel_change, last_action, last_reward)
        self.experience.add_frame(frame)

        if terminal:
            self.environment.reset()
        if self.experience.is_full():
            self.environment.reset()
            print("Replay buffer filled")

    def _print_log(self, global_t):
        if (self.thread_index == 0) and (self.local_t - self.prev_local_t >=
                                         PERFORMANCE_LOG_INTERVAL):
            self.prev_local_t += PERFORMANCE_LOG_INTERVAL
            elapsed_time = time.time() - self.start_time
            steps_per_sec = global_t / elapsed_time
            print(
                "### Performance : {} STEPS in {:.0f} sec. {:.0f} STEPS/sec. {:.2f}M STEPS/hour"
                .format(global_t, elapsed_time, steps_per_sec,
                        steps_per_sec * 3600 / 1000000.))

    def _process_base(self, sess, global_t, summary_writer, summary_op,
                      score_input):
        # [Base A3C]
        states = []
        last_action_rewards = []
        actions = []
        rewards = []
        values = []

        terminal_end = False

        start_lstm_state = self.local_network.base_lstm_state_out

        # t_max times loop
        for _ in range(self.local_t_max):
            # Prepare last action reward
            last_action = self.environment.last_action
            last_reward = self.environment.last_reward
            last_action_reward = ExperienceFrame.concat_action_and_reward(
                last_action, self.action_size, last_reward)

            pi_, value_ = self.local_network.run_base_policy_and_value(
                sess, self.environment.last_state, last_action_reward)

            action = self.choose_action(pi_)

            states.append(self.environment.last_state)
            last_action_rewards.append(last_action_reward)
            actions.append(action)
            values.append(value_)

            if (self.thread_index == 0) and (self.local_t % LOG_INTERVAL == 0):
                print("pi={}".format(pi_))
                print(" V={}".format(value_))

            prev_state = self.environment.last_state

            # Process game
            new_state, reward, terminal, pixel_change = self.environment.process(
                action)
            frame = ExperienceFrame(prev_state, reward, action, terminal,
                                    pixel_change, last_action, last_reward)

            # Store to experience
            self.experience.add_frame(frame)

            self.episode_reward += reward

            rewards.append(reward)

            self.local_t += 1

            if terminal:
                terminal_end = True
                print("score={}".format(self.episode_reward))

                self._record_score(sess, summary_writer, summary_op,
                                   score_input, self.episode_reward, global_t)

                self.episode_reward = 0
                self.environment.reset()
                self.local_network.reset_state()
                break

        R = 0.0
        if not terminal_end:
            R = self.local_network.run_base_value(
                sess, new_state, frame.get_action_reward(self.action_size))

        actions.reverse()
        states.reverse()
        rewards.reverse()
        values.reverse()

        batch_si = []
        batch_a = []
        batch_adv = []
        batch_R = []

        for (ai, ri, si, Vi) in zip(actions, rewards, states, values):
            R = ri + self.gamma * R
            adv = R - Vi
            a = np.zeros([self.action_size])
            a[ai] = 1.0

            batch_si.append(si)
            batch_a.append(a)
            batch_adv.append(adv)
            batch_R.append(R)

        batch_si.reverse()
        batch_a.reverse()
        batch_adv.reverse()
        batch_R.reverse()

        return batch_si, last_action_rewards, batch_a, batch_adv, batch_R, start_lstm_state

    def _process_pc(self, sess):
        # [pixel change]
        # Sample 20+1 frame (+1 for last next state)
        pc_experience_frames = self.experience.sample_sequence(
            self.local_t_max + 1)
        # Reverse sequence to calculate from the last
        pc_experience_frames.reverse()

        batch_pc_si = []
        batch_pc_a = []
        batch_pc_R = []
        batch_pc_last_action_reward = []

        pc_R = np.zeros([20, 20], dtype=np.float32)
        if not pc_experience_frames[1].terminal:
            pc_R = self.local_network.run_pc_q_max(
                sess, pc_experience_frames[0].state,
                pc_experience_frames[0].get_last_action_reward(
                    self.action_size))

        for frame in pc_experience_frames[1:]:
            pc_R = frame.pixel_change + self.gamma_pc * pc_R
            a = np.zeros([self.action_size])
            a[frame.action] = 1.0
            last_action_reward = frame.get_last_action_reward(self.action_size)

            batch_pc_si.append(frame.state)
            batch_pc_a.append(a)
            batch_pc_R.append(pc_R)
            batch_pc_last_action_reward.append(last_action_reward)

        batch_pc_si.reverse()
        batch_pc_a.reverse()
        batch_pc_R.reverse()
        batch_pc_last_action_reward.reverse()

        return batch_pc_si, batch_pc_last_action_reward, batch_pc_a, batch_pc_R

    def _process_vr(self, sess):
        # [Value replay]
        # Sample 20+1 frame (+1 for last next state)
        vr_experience_frames = self.experience.sample_sequence(
            self.local_t_max + 1)
        # Reverse sequence to calculate from the last
        vr_experience_frames.reverse()

        batch_vr_si = []
        batch_vr_R = []
        batch_vr_last_action_reward = []

        vr_R = 0.0
        if not vr_experience_frames[1].terminal:
            vr_R = self.local_network.run_vr_value(
                sess, vr_experience_frames[0].state,
                vr_experience_frames[0].get_last_action_reward(
                    self.action_size))

        # t_max times loop
        for frame in vr_experience_frames[1:]:
            vr_R = frame.reward + self.gamma * vr_R
            batch_vr_si.append(frame.state)
            batch_vr_R.append(vr_R)
            last_action_reward = frame.get_last_action_reward(self.action_size)
            batch_vr_last_action_reward.append(last_action_reward)

        batch_vr_si.reverse()
        batch_vr_R.reverse()
        batch_vr_last_action_reward.reverse()

        return batch_vr_si, batch_vr_last_action_reward, batch_vr_R

    def _process_rp(self):
        # [Reward prediction]
        rp_experience_frames = self.experience.sample_rp_sequence()
        # 4 frames

        batch_rp_si = []
        batch_rp_c = []

        for i in range(3):
            batch_rp_si.append(rp_experience_frames[i].state)

        # one hot vector for target reward
        r = rp_experience_frames[3].reward
        rp_c = [0.0, 0.0, 0.0]
        if r == 0:
            rp_c[0] = 1.0  # zero
        elif r > 0:
            rp_c[1] = 1.0  # positive
        else:
            rp_c[2] = 1.0  # negative
        batch_rp_c.append(rp_c)
        return batch_rp_si, batch_rp_c

    def process(self, sess, global_t, summary_writer, summary_op, score_input):
        # Fill experience replay buffer
        if not self.experience.is_full():
            self._fill_experience(sess)
            return 0

        start_local_t = self.local_t

        cur_learning_rate = self._anneal_learning_rate(global_t)

        # Copy weights from shared to local
        sess.run(self.sync)

        # [Base]
        batch_si, batch_last_action_rewards, batch_a, batch_adv, batch_R, start_lstm_state = \
              self._process_base(sess,
                                 global_t,
                                 summary_writer,
                                 summary_op,
                                 score_input)
        feed_dict = {
            self.local_network.base_input: batch_si,
            self.local_network.base_last_action_reward_input:
            batch_last_action_rewards,
            self.local_network.base_a: batch_a,
            self.local_network.base_adv: batch_adv,
            self.local_network.base_r: batch_R,
            self.local_network.base_initial_lstm_state: start_lstm_state,
            # [common]
            self.learning_rate_input: cur_learning_rate
        }

        # [Pixel change]
        if self.use_pixel_change:
            batch_pc_si, batch_pc_last_action_reward, batch_pc_a, batch_pc_R = self._process_pc(
                sess)

            pc_feed_dict = {
                self.local_network.pc_input: batch_pc_si,
                self.local_network.pc_last_action_reward_input:
                batch_pc_last_action_reward,
                self.local_network.pc_a: batch_pc_a,
                self.local_network.pc_r: batch_pc_R
            }
            feed_dict.update(pc_feed_dict)

        # [Value replay]
        if self.use_value_replay:
            batch_vr_si, batch_vr_last_action_reward, batch_vr_R = self._process_vr(
                sess)

            vr_feed_dict = {
                self.local_network.vr_input: batch_vr_si,
                self.local_network.vr_last_action_reward_input:
                batch_vr_last_action_reward,
                self.local_network.vr_r: batch_vr_R
            }
            feed_dict.update(vr_feed_dict)

        # [Reward prediction]
        if self.use_reward_prediction:
            batch_rp_si, batch_rp_c = self._process_rp()
            rp_feed_dict = {
                self.local_network.rp_input: batch_rp_si,
                self.local_network.rp_c_target: batch_rp_c
            }
            feed_dict.update(rp_feed_dict)

        # Calculate gradients and copy them to global network.
        sess.run(self.apply_gradients, feed_dict=feed_dict)

        self._print_log(global_t)

        # Return advanced local step size
        diff_local_t = self.local_t - start_local_t
        return diff_local_t
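
A self-contained sketch of the n-step return and advantage bookkeeping that _process_base performs after its rollout loop; the rewards, values, gamma and bootstrap value below are made-up numbers for illustration only:

gamma = 0.99
rewards = [0.0, 0.0, 1.0]        # collected oldest -> newest
values  = [0.50, 0.55, 0.60]     # V(s_t) estimates, oldest -> newest
R = 0.40                         # bootstrap value of the last next state

batch_R, batch_adv = [], []
for ri, Vi in zip(reversed(rewards), reversed(values)):
    R = ri + gamma * R           # accumulate the discounted return backwards
    batch_R.append(R)
    batch_adv.append(R - Vi)     # advantage = n-step return - value baseline

batch_R.reverse()                # restore oldest -> newest ordering
batch_adv.reverse()
print(batch_R)                   # ~[1.368, 1.382, 1.396]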
Example #6
    def __init__(self, env, task, visualise):
        self.env = env
        self.task = task
        self.ob_shape = [HEIGHT, WIDTH, CHANNEL]
        self.action_n = Environment.get_action_size()
        # define the network stored in ps which is used to sync
        worker_device = '/job:worker/task:{}'.format(task)
        with tf.device(
                tf.train.replica_device_setter(1,
                                               worker_device=worker_device)):
            with tf.variable_scope('global'):
                self.experience = Experience(
                    EXPERIENCE_HISTORY_SIZE)  # exp replay pool
                self.network = UnrealModel(self.action_n, self.env,
                                           self.experience)
                self.global_step = tf.get_variable('global_step',
                                                   dtype=tf.int32,
                                                   initializer=tf.constant(
                                                       0, dtype=tf.int32),
                                                   trainable=False)
        # define the local network which is used to calculate the gradient
        with tf.device(worker_device):
            with tf.variable_scope('local'):
                self.local_network = net = UnrealModel(self.action_n, self.env,
                                                       self.experience)
                net.global_step = self.global_step

            # add summaries for losses and norms
            self.batch_size = tf.to_float(tf.shape(net.base_input)[0])
            base_loss = self.local_network.base_loss
            pc_loss = self.local_network.pc_loss
            rp_loss = self.local_network.rp_loss
            vr_loss = self.local_network.vr_loss
            entropy = tf.reduce_sum(self.local_network.entropy)
            self.loss = base_loss + pc_loss + rp_loss + vr_loss
            grads = tf.gradients(self.loss, net.var_list)
            tf.summary.scalar('model/a3c_loss', base_loss / self.batch_size)
            tf.summary.scalar('model/pc_loss', pc_loss / self.batch_size)
            tf.summary.scalar('model/rp_loss', rp_loss / self.batch_size)
            tf.summary.scalar('model/vr_loss', vr_loss / self.batch_size)
            tf.summary.scalar('model/grad_global_norm', tf.global_norm(grads))
            tf.summary.scalar('model/var_global_norm',
                              tf.global_norm(net.var_list))
            tf.summary.scalar('model/entropy', entropy / self.batch_size)
            tf.summary.image('model/state', net.base_input)
            self.summary_op = tf.summary.merge_all()

            # clip the gradients to avoid gradient explosion
            grads, _ = tf.clip_by_global_norm(grads, GRAD_NORM_CLIP)

            self.sync = tf.group(*[
                v1.assign(v2)
                for v1, v2 in zip(net.var_list, self.network.var_list)
            ])
            grads_and_vars = list(zip(grads, self.network.var_list))
            inc_step = self.global_step.assign_add(tf.to_int32(
                self.batch_size))
            lr = log_uniform(LR_LOW, LR_HIGH)
            opt = tf.train.RMSPropOptimizer(learning_rate=lr,
                                            decay=RMSP_ALPHA,
                                            momentum=0.0,
                                            epsilon=RMSP_EPSILON)
            self.train_op = tf.group(opt.apply_gradients(grads_and_vars),
                                     inc_step)
            self.summary_writer = None
            self.local_step = 0
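
A minimal sketch of how a worker like the one above is typically launched in TF1 distributed training; the host:port addresses, job names and task index are assumptions, not part of the example:

import tensorflow as tf

cluster = tf.train.ClusterSpec({
    'ps':     ['localhost:2222'],
    'worker': ['localhost:2223', 'localhost:2224'],
})
task = 0
server = tf.train.Server(cluster, job_name='worker', task_index=task)
# replica_device_setter(1, ...) in the example places the 'global' variables
# on the single ps task, while ops built under worker_device run on this
# worker; a tf.Session(server.target, ...) would then drive training here.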