def __init__( self, env, task, policy_class, policy_config, log, random_seed=0, model_gamma=0.99, # A3C decay model_gae_lambda=1.00, # GAE lambda model_beta=0.01, # entropy regularizer opt_max_train_steps=10**7, opt_decay_steps=None, opt_end_learn_rate=None, opt_learn_rate=1e-4, opt_decay=0.99, opt_momentum=0.0, opt_epsilon=1e-10, rollout_length=20, episode_summary_freq=2, # every i`th episode env_render_freq=10, # every i`th episode model_summary_freq=100, # every i`th local_step test_mode=False, # gym_atari test mode replay_memory_size=2000, replay_rollout_length=None, use_off_policy_a3c=False, use_reward_prediction=False, use_pixel_control=False, use_value_replay=False, use_rebalanced_replay=False, # simplified form of prioritized replay rebalance_skewness=2, rp_lambda=1, # aux tasks loss weights pc_lambda=0.1, vr_lambda=1, off_a3c_lambda=1, gamma_pc=0.9, # pixel change gamma-decay - not used rp_reward_threshold=0.1, # r.prediction: abs.rewards values bigger than this are considered non-zero rp_sequence_size=4, # r.prediction sampling **kwargs): """ Implementation of the UNREAL algorithm. Below, we will have a modest amount of complexity due to the way TensorFlow handles data parallelism. But overall, we'll define the model, specify its inputs, and describe how the policy gradients step should be computed. """ self.log = log self.random_seed = random_seed np.random.seed(self.random_seed) self.log.debug('U_{}_rnd_seed:{}, log_u_sample_(0,1]x5: {}'.format( task, random_seed, self.log_uniform([1e-10, 1], 5))) self.env = env self.task = task self.policy_class = policy_class self.policy_config = policy_config # A3C specific: self.model_gamma = model_gamma # decay self.model_gae_lambda = model_gae_lambda # general advantage estimator lambda self.model_beta = self.log_uniform(model_beta, 1) # entropy reg. # Optimizer self.opt_max_train_steps = opt_max_train_steps self.opt_learn_rate = self.log_uniform(opt_learn_rate, 1) if opt_end_learn_rate is None: self.opt_end_learn_rate = self.opt_learn_rate else: self.opt_end_learn_rate = opt_end_learn_rate if opt_decay_steps is None: self.opt_decay_steps = self.opt_max_train_steps else: self.opt_decay_steps = opt_decay_steps self.opt_decay = opt_decay self.opt_epsilon = opt_epsilon self.opt_momentum = opt_momentum self.rollout_length = rollout_length # Summaries : self.episode_summary_freq = episode_summary_freq self.env_render_freq = env_render_freq self.model_summary_freq = model_summary_freq # If True - use ATARI gym env.: self.test_mode = test_mode # UNREAL specific: self.off_a3c_lambda = off_a3c_lambda self.rp_lambda = rp_lambda self.pc_lambda = self.log_uniform(pc_lambda, 1) self.vr_lambda = vr_lambda self.gamma_pc = gamma_pc self.replay_memory_size = replay_memory_size if replay_rollout_length is not None: self.replay_rollout_length = replay_rollout_length else: self.replay_rollout_length = rollout_length self.rp_sequence_size = rp_sequence_size self.rp_reward_threshold = rp_reward_threshold # On/off switchers for off-policy training and BaseAAC auxiliary tasks: self.use_off_policy_a3c = use_off_policy_a3c self.use_reward_prediction = use_reward_prediction self.use_pixel_control = use_pixel_control if use_off_policy_a3c: self.use_value_replay = False # v-replay is redundant in this case else: self.use_value_replay = use_value_replay self.use_rebalanced_replay = use_rebalanced_replay self.rebalance_skewness = rebalance_skewness self.use_any_aux_tasks = use_value_replay or use_pixel_control or use_reward_prediction self.use_memory = self.use_any_aux_tasks or self.use_off_policy_a3c # Make replay memory: self.memory = Memory(self.replay_memory_size, self.replay_rollout_length, self.rp_reward_threshold) self.log.info( 'U_{}: learn_rate: {:1.6f}, entropy_beta: {:1.6f}, pc_lambda: {:1.8f}.' .format(self.task, self.opt_learn_rate, self.model_beta, self.pc_lambda)) #self.log.info( # 'U_{}: max_steps: {}, decay_steps: {}, end_rate: {:1.6f},'. # format(self.task, self.opt_max_env_steps, self.opt_decay_steps, self.opt_end_learn_rate)) worker_device = "/job:worker/task:{}/cpu:0".format(task) if self.test_mode: model_input_shape = env.observation_space.shape else: model_input_shape = env.observation_space.spaces[ 'model_input'].shape # Start building graph: with tf.device( tf.train.replica_device_setter(1, worker_device=worker_device)): with tf.variable_scope("global"): self.network = self.policy_class(model_input_shape, env.action_space.n, self.rp_sequence_size, **self.policy_config) self.global_step = tf.get_variable( "global_step", [], tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32), trainable=False) self.global_episode = tf.get_variable( "global_episode", [], tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32), trainable=False) # Increment episode count: inc_episode = self.global_episode.assign_add(1) with tf.device(worker_device): with tf.variable_scope("local"): self.local_network = pi = self.policy_class( model_input_shape, env.action_space.n, self.rp_sequence_size, **self.policy_config) pi.global_step = self.global_step pi.global_episode = self.global_episode pi.inc_episode = inc_episode # Meant for Batch-norm layers: pi.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='.*local.*') self.log.debug( 'U_{}: local_network_upd_ops_collection:\n{}'.format( self.task, pi.update_ops)) self.log.debug('\nU_{}: local_network_var_list_to_save:'.format( self.task)) for v in pi.var_list: self.log.debug('{}: {}'.format(v.name, v.get_shape())) # On-policy A3C loss definition: self.a3c_act_target = tf.placeholder(tf.float32, [None, env.action_space.n], name="a3c_action_pl") self.a3c_adv_target = tf.placeholder(tf.float32, [None], name="a3c_advantage_pl") self.a3c_r_target = tf.placeholder(tf.float32, [None], name="a3c_return_pl") log_prob_tf = tf.nn.log_softmax(pi.a3c_logits) prob_tf = tf.nn.softmax(pi.a3c_logits) # summary only # the "policy gradients" loss: its derivative is precisely the policy gradient # notice that `a3c_action_target` is a placeholder that is provided externally. # `a3c_advantage_target` will contain the advantages, as calculated in process_rollout(): #pi_loss = - tf.reduce_mean( # tf.reduce_sum( # log_prob_tf * self.a3c_act_target, # [1] # ) * self.a3c_adv_target #) neg_log_prob_ac = tf.nn.softmax_cross_entropy_with_logits( logits=pi.a3c_logits, labels=self.a3c_act_target) mean_neg_log_prob_ac = tf.reduce_mean(neg_log_prob_ac) pi_loss = tf.reduce_mean(neg_log_prob_ac * self.a3c_adv_target) # loss of value function: #vf_loss = 0.5 * tf.reduce_sum(tf.square(pi.a3c_vf - self.a3c_r_target)) vf_loss = 0.5 * tf.reduce_mean( tf.square(pi.a3c_vf - self.a3c_r_target)) mean_vf = tf.reduce_mean(pi.a3c_vf) #entropy = - tf.reduce_sum(prob_tf * log_prob_tf) entropy = tf.reduce_mean(self.cat_entropy(pi.a3c_logits)) a3c_bs = tf.to_float(tf.shape( pi.a3c_state_in)[0]) # on-policy batch size a3c_loss = pi_loss + vf_loss - entropy * self.model_beta # Start accumulating total loss: self.loss = a3c_loss # Base summaries: model_summaries = [ tf.summary.scalar("a3c/policy_loss", pi_loss), #tf.summary.histogram("a3c/pi_prob_d", prob_tf), tf.summary.scalar("a3c/value_loss", vf_loss), tf.summary.scalar("a3c/entropy", entropy), tf.summary.scalar("a3c/neg_log_pi", mean_neg_log_prob_ac), tf.summary.scalar("a3c/vf", mean_vf), ] # Off-policy batch size: off_bs = tf.to_float(tf.shape(pi.off_a3c_state_in)[0]) if self.use_rebalanced_replay: # Simplified importance-sampling bias correction: rebalanced_replay_weight = self.rebalance_skewness / off_bs else: rebalanced_replay_weight = 1.0 # Placeholders for off-policy training: self.off_policy_act_target = tf.placeholder( tf.float32, [None, env.action_space.n], name="off_policy_action_pl") self.off_policy_adv_target = tf.placeholder( tf.float32, [None], name="off_policy_advantage_pl") self.off_policy_r_target = tf.placeholder( tf.float32, [None], name="off_policy_return_pl") if self.use_off_policy_a3c: # Off-policy A3C loss graph mirrors on-policy: #off_log_prob_tf = tf.nn.log_softmax(pi.off_a3c_logits) #off_prob_tf = tf.nn.softmax(pi.off_a3c_logits) #off_pi_loss = - tf.reduce_sum( # tf.reduce_sum( # off_log_prob_tf * self.off_policy_action_target, # [1] # ) * self.off_policy_advantage_target #) off_neg_log_prob_ac = tf.nn.softmax_cross_entropy_with_logits( logits=pi.off_a3c_logits, labels=self.off_policy_act_target) off_pi_loss = tf.reduce_mean(off_neg_log_prob_ac * self.off_policy_adv_target) # loss of value function: off_vf_loss = 0.5 * tf.reduce_mean( tf.square(pi.off_a3c_vf - self.off_policy_r_target)) off_entropy = tf.reduce_mean( self.cat_entropy(pi.off_a3c_logits)) off_a3c_loss = off_pi_loss + off_vf_loss - off_entropy * self.model_beta self.loss = self.loss + self.off_a3c_lambda * rebalanced_replay_weight * off_a3c_loss model_summaries += [ tf.summary.scalar("off_a3c/policy_loss", off_pi_loss), tf.summary.scalar("off_a3c/value_loss", off_vf_loss), ] if self.use_pixel_control: # Pixel control loss self.pc_action = tf.placeholder(tf.float32, [None, env.action_space.n], name="pc_action") self.pc_target = tf.placeholder(tf.float32, [None, None, None], name="pc_target") # Get Q-value features for actions been taken and define loss: pc_action_reshaped = tf.reshape(self.pc_action, [-1, 1, 1, env.action_space.n]) pc_q_action = tf.multiply(pi.pc_q, pc_action_reshaped) pc_q_action = tf.reduce_sum(pc_q_action, axis=-1, keep_dims=False) pc_loss = tf.reduce_sum(tf.square(self.pc_target - pc_q_action)) # sum all over self.loss = self.loss + self.pc_lambda * rebalanced_replay_weight * pc_loss # Add specific summary: model_summaries += [ tf.summary.scalar('pixel_control/q_loss', pc_loss) ] if self.use_value_replay: # Value function replay loss: self.vr_target = tf.placeholder(tf.float32, [None], name="vr_target") vr_loss = tf.reduce_mean( tf.square(pi.vr_value - self.vr_target)) self.loss = self.loss + self.vr_lambda * rebalanced_replay_weight * vr_loss model_summaries += [ tf.summary.scalar('v_replay/value_loss', vr_loss) ] if self.use_reward_prediction: # Reward prediction loss: self.rp_target = tf.placeholder(tf.float32, [1, 3], name="rp_target") rp_loss = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits( labels=self.rp_target, logits=pi.rp_logits)) self.loss = self.loss + self.rp_lambda * rp_loss model_summaries += [ tf.summary.scalar('r_predict/class_loss', rp_loss) ] grads = tf.gradients(self.loss, pi.var_list) grads, _ = tf.clip_by_global_norm(grads, 40.0) # copy weights from the parameter server to the local model self.sync = tf.group(*[ v1.assign(v2) for v1, v2 in zip(pi.var_list, self.network.var_list) ]) grads_and_vars = list(zip(grads, self.network.var_list)) self.inc_step = self.global_step.assign_add( tf.shape(pi.a3c_state_in)[0]) #self.inc_step = self.global_step.assign_add(1) # Anneal learning rate: learn_rate = tf.train.polynomial_decay( self.opt_learn_rate, self.global_step + 1, self.opt_decay_steps, self.opt_end_learn_rate, power=1, cycle=True, ) # Each worker gets a different set of adam optimizer parameters opt = tf.train.AdamOptimizer(learn_rate) #opt = tf.train.RMSPropOptimizer( # learning_rate=learn_rate, # decay=0.99, # momentum=0.0, # epsilon=1e-8, #) self.train_op = tf.group(*pi.update_ops, opt.apply_gradients(grads_and_vars), self.inc_step) #self.train_op = tf.group(opt.apply_gradients(grads_and_vars), inc_step) # Add model-wide statistics: model_summaries += [ tf.summary.scalar("global/grad_global_norm", tf.global_norm(grads)), tf.summary.scalar("global/var_global_norm", tf.global_norm(pi.var_list)), tf.summary.scalar("global/opt_learn_rate", learn_rate), tf.summary.scalar("global/total_loss", self.loss), ] self.summary_writer = None self.local_steps = 0 self.log.debug('U_{}: train op defined'.format(self.task)) # Model stat. summary: self.model_summary_op = tf.summary.merge(model_summaries, name='model_summary') # Episode-related summaries: self.ep_summary = dict( # Summary placeholders render_human_pl=tf.placeholder(tf.uint8, [None, None, None, 3]), render_model_input_pl=tf.placeholder(tf.uint8, [None, None, None, 3]), render_episode_pl=tf.placeholder(tf.uint8, [None, None, None, 3]), render_atari_pl=tf.placeholder(tf.uint8, [None, None, None, 1]), total_r_pl=tf.placeholder(tf.float32, ), cpu_time_pl=tf.placeholder(tf.float32, ), final_value_pl=tf.placeholder(tf.float32, ), steps_pl=tf.placeholder(tf.int32, ), ) # Environmnet rendering: self.ep_summary['render_op'] = tf.summary.merge([ tf.summary.image('human', self.ep_summary['render_human_pl']), tf.summary.image('model_input', self.ep_summary['render_model_input_pl']), tf.summary.image('episode', self.ep_summary['render_episode_pl']), ], name='render') # For Atari: self.ep_summary['test_render_op'] = tf.summary.image( "model/state", self.ep_summary['render_atari_pl']) # Episode stat. summary: self.ep_summary['stat_op'] = tf.summary.merge([ tf.summary.scalar('episode/total_reward', self.ep_summary['total_r_pl']), tf.summary.scalar('episode/cpu_time_sec', self.ep_summary['cpu_time_pl']), tf.summary.scalar('episode/final_value', self.ep_summary['final_value_pl']), tf.summary.scalar('episode/env_steps', self.ep_summary['steps_pl']) ], name='episode') self.ep_summary['test_stat_op'] = tf.summary.merge( [ tf.summary.scalar('episode/total_reward', self.ep_summary['total_r_pl']), tf.summary.scalar('episode/steps', self.ep_summary['steps_pl']) ], name='episode_atari') # Make runner: # `rollout_length` represents the number of "local steps": the number of timesteps # we run the policy before we update the parameters. # The larger local steps is, the lower is the variance in our policy gradients estimate # on the one hand; but on the other hand, we get less frequent parameter updates, which # slows down learning. In this code, we found that making local steps be much # smaller than 20 makes the algorithm more difficult to tune and to get to work. self.runner = RunnerThread( env, pi, task, self.rollout_length, # ~20 self.episode_summary_freq, self.env_render_freq, self.test_mode, self.ep_summary) self.log.debug('U_{}: init() done'.format(self.task))
def __init__( self, env, task, policy_config, log, random_seed=None, model_gamma=0.99, # decay model_gae_lambda=1.00, # GAE lambda model_beta=0.01, # entropy regularizer opt_max_train_steps=10**7, opt_decay_steps=None, opt_end_learn_rate=None, opt_learn_rate=1e-4, opt_decay=0.99, opt_momentum=0.0, opt_epsilon=1e-10, rollout_length=20, episode_summary_freq=2, # every i`th environment episode env_render_freq=10, # every i`th environment episode model_summary_freq=100, # every i`th algorithm iteration test_mode=False, # gym_atari test mode replay_memory_size=2000, replay_rollout_length=None, use_off_policy_aac=False, use_reward_prediction=False, use_pixel_control=False, use_value_replay=False, use_rebalanced_replay=False, # simplified form of prioritized replay rebalance_skewness=2, rp_lambda=1, # aux tasks loss weights pc_lambda=0.1, vr_lambda=1, off_aac_lambda=1, gamma_pc=0.9, # pixel change gamma-decay - not used rp_reward_threshold=0.1, # r.prediction: abs.rewards values bigger than this are considered non-zero rp_sequence_size=3, ): # r.prediction sampling """ Args: env: envirionment instance. task: int policy_config: policy estimator class and configuration dictionary log: parent log random_seed: int or None model_gamma: gamma discount factor model_gae_lambda: GAE lambda model_beta: entropy regularization beta opt_max_train_steps: train steps to run opt_decay_steps: learn ratio decay steps opt_end_learn_rate: final lerarn rate opt_learn_rate: start learn rate opt_decay: optimizer decay, if apll. opt_momentum: optimizer momentum, if apll. opt_epsilon: optimizer epsilon rollout_length: on-policy rollout length episode_summary_freq: int, write episode summary for every i'th episode env_render_freq: int, write environment rendering summary for every i'th train step model_summary_freq: int, write model summary for every i'th train step test_mode: True: Atari, False: BTGym replay_memory_size: in number of experiences replay_rollout_length: off-policy rollout length use_off_policy_aac: use full AAC off policy training instead of Value-replay use_reward_prediction: use aux. off-policy reward prediction task use_pixel_control: use aux. off-policy pixel control task use_value_replay: use aux. off-policy value replay task (not used, if use_off_policy_aac=True) use_rebalanced_replay: NOT USED rebalance_skewness: NOT USED rp_lambda: reward prediction loss weight pc_lambda: pixel control loss weight vr_lambda: value replay loss weight off_aac_lambda: off-policy AAC loss weight gamma_pc: NOT USED rp_reward_threshold: reward prediction task classification threshold, above which reward is 'non-zero' rp_sequence_size: reward prediction sample size, in number of experiences """ self.log = log self.random_seed = random_seed if self.random_seed is not None: np.random.seed(self.random_seed) tf.set_random_seed(self.random_seed) self.log.debug('AAC_{}_rnd_seed:{}, log_u_sample_(0,1]x5: {}'.format( task, random_seed, log_uniform([1e-10, 1], 5))) self.env = env self.task = task self.policy_class = policy_config['class_ref'] self.policy_kwargs = policy_config['kwargs'] # AAC specific: self.model_gamma = model_gamma # decay self.model_gae_lambda = model_gae_lambda # general advantage estimator lambda self.model_beta = log_uniform(model_beta, 1) # entropy reg. # Optimizer self.opt_max_train_steps = opt_max_train_steps self.opt_learn_rate = log_uniform(opt_learn_rate, 1) if opt_end_learn_rate is None: self.opt_end_learn_rate = self.opt_learn_rate else: self.opt_end_learn_rate = opt_end_learn_rate if opt_decay_steps is None: self.opt_decay_steps = self.opt_max_train_steps else: self.opt_decay_steps = opt_decay_steps self.opt_decay = opt_decay self.opt_epsilon = opt_epsilon self.opt_momentum = opt_momentum self.rollout_length = rollout_length # Summaries : self.episode_summary_freq = episode_summary_freq self.env_render_freq = env_render_freq self.model_summary_freq = model_summary_freq # If True - use ATARI gym env.: self.test_mode = test_mode # UNREAL specific: self.off_aac_lambda = off_aac_lambda self.rp_lambda = rp_lambda self.pc_lambda = log_uniform(pc_lambda, 1) self.vr_lambda = vr_lambda self.gamma_pc = gamma_pc self.replay_memory_size = replay_memory_size if replay_rollout_length is not None: self.replay_rollout_length = replay_rollout_length else: self.replay_rollout_length = rollout_length self.rp_sequence_size = rp_sequence_size self.rp_reward_threshold = rp_reward_threshold # On/off switchers for off-policy training and auxiliary tasks: self.use_off_policy_aac = use_off_policy_aac self.use_reward_prediction = use_reward_prediction self.use_pixel_control = use_pixel_control if use_off_policy_aac: self.use_value_replay = False # v-replay is redundant in this case else: self.use_value_replay = use_value_replay self.use_rebalanced_replay = use_rebalanced_replay self.rebalance_skewness = rebalance_skewness self.use_any_aux_tasks = use_value_replay or use_pixel_control or use_reward_prediction self.use_memory = self.use_any_aux_tasks or self.use_off_policy_aac # Make replay memory: self.memory = Memory(history_size=self.replay_memory_size, max_sample_size=self.replay_rollout_length, reward_threshold=self.rp_reward_threshold, log=self.log) self.log.info( 'AAC_{}: learn_rate: {:1.6f}, entropy_beta: {:1.6f}, pc_lambda: {:1.8f}.' .format(self.task, self.opt_learn_rate, self.model_beta, self.pc_lambda)) #self.log.info( # 'AAC_{}: max_steps: {}, decay_steps: {}, end_rate: {:1.6f},'. # format(self.task, self.opt_max_train_steps, self.opt_decay_steps, self.opt_end_learn_rate)) worker_device = "/job:worker/task:{}/cpu:0".format(task) # Infer observation space shape: if type(env.observation_space) == BTgymMultiSpace: model_input_shape = env.observation_space.get_shapes() else: model_input_shape = env.observation_space.shape # Start building graph: with tf.device( tf.train.replica_device_setter(1, worker_device=worker_device)): with tf.variable_scope("global"): self.network = self.policy_class( ob_space=model_input_shape, ac_space=env.action_space.n, rp_sequence_size=self.rp_sequence_size, **self.policy_kwargs) self.global_step = tf.get_variable( "global_step", [], tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32), trainable=False) self.global_episode = tf.get_variable( "global_episode", [], tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32), trainable=False) # Increment episode count: inc_episode = self.global_episode.assign_add(1) with tf.device(worker_device): with tf.variable_scope("local"): self.local_network = pi = self.policy_class( ob_space=model_input_shape, ac_space=env.action_space.n, rp_sequence_size=self.rp_sequence_size, **self.policy_kwargs) pi.global_step = self.global_step pi.global_episode = self.global_episode pi.inc_episode = inc_episode # Meant for Batch-norm layers: pi.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='.*local.*') self.log.debug( 'AAC_{}: local_network_upd_ops_collection:\n{}'.format( self.task, pi.update_ops)) self.log.debug('\nAAC_{}: local_network_var_list_to_save:'.format( self.task)) for v in pi.var_list: self.log.debug('{}: {}'.format(v.name, v.get_shape())) # Learning rate annealing: learn_rate = tf.train.polynomial_decay( self.opt_learn_rate, self.global_step + 1, self.opt_decay_steps, self.opt_end_learn_rate, power=1, cycle=False, ) # On-policy AAC loss definition: self.on_pi_act_target = tf.placeholder(tf.float32, [None, env.action_space.n], name="on_policy_action_pl") self.on_pi_adv_target = tf.placeholder( tf.float32, [None], name="on_policy_advantage_pl") self.on_pi_r_target = tf.placeholder(tf.float32, [None], name="on_policy_return_pl") on_pi_loss, on_pi_summaries = aac_loss_def( act_target=self.on_pi_act_target, adv_target=self.on_pi_adv_target, r_target=self.on_pi_r_target, pi_logits=pi.on_logits, pi_vf=pi.on_vf, entropy_beta=self.model_beta, name='on_policy/aac', verbose=True) # Start accumulating total loss: self.loss = on_pi_loss model_summaries = on_pi_summaries # Off-policy batch size: off_bs = tf.to_float(tf.shape(pi.off_state_in)[0]) if self.use_rebalanced_replay: # Simplified importance-sampling bias correction: rebalanced_replay_weight = self.rebalance_skewness / off_bs else: rebalanced_replay_weight = 1.0 # Off policy training: self.off_pi_act_target = tf.placeholder( tf.float32, [None, env.action_space.n], name="off_policy_action_pl") self.off_pi_adv_target = tf.placeholder( tf.float32, [None], name="off_policy_advantage_pl") self.off_pi_r_target = tf.placeholder(tf.float32, [None], name="off_policy_return_pl") if self.use_off_policy_aac: # Off-policy PPO loss graph mirrors on-policy: off_ppo_loss, off_ppo_summaries = aac_loss_def( act_target=self.off_pi_act_target, adv_target=self.off_pi_adv_target, r_target=self.off_pi_r_target, pi_logits=pi.off_logits, pi_vf=pi.off_vf, entropy_beta=self.model_beta, name='off_policy/aac', verbose=False) self.loss = self.loss + self.off_aac_lambda * rebalanced_replay_weight * off_ppo_loss model_summaries += off_ppo_summaries if self.use_pixel_control: # Pixel control loss: self.pc_action = tf.placeholder(tf.float32, [None, env.action_space.n], name="pc_action") self.pc_target = tf.placeholder(tf.float32, [None, None, None], name="pc_target") pc_loss, pc_summaries = pc_loss_def( actions=self.pc_action, targets=self.pc_target, pi_pc_q=pi.pc_q, name='off_policy/pixel_control', verbose=True) self.loss = self.loss + self.pc_lambda * rebalanced_replay_weight * pc_loss # Add specific summary: model_summaries += pc_summaries if self.use_value_replay: # Value function replay loss: self.vr_target = tf.placeholder(tf.float32, [None], name="vr_target") vr_loss, vr_summaries = value_fn_loss_def( r_target=self.vr_target, pi_vf=pi.vr_value, name='off_policy/value_replay', verbose=True) self.loss = self.loss + self.vr_lambda * rebalanced_replay_weight * vr_loss model_summaries += vr_summaries if self.use_reward_prediction: # Reward prediction loss: self.rp_target = tf.placeholder(tf.float32, [1, 3], name="rp_target") rp_loss, rp_summaries = rp_loss_def( rp_targets=self.rp_target, pi_rp_logits=pi.rp_logits, name='off_policy/reward_prediction', verbose=True) self.loss = self.loss + self.rp_lambda * rp_loss model_summaries += rp_summaries grads = tf.gradients(self.loss, pi.var_list) grads, _ = tf.clip_by_global_norm(grads, 40.0) # Copy weights from the parameter server to the local model self.sync = tf.group(*[ v1.assign(v2) for v1, v2 in zip(pi.var_list, self.network.var_list) ]) grads_and_vars = list(zip(grads, self.network.var_list)) self.inc_step = self.global_step.assign_add( tf.shape(pi.on_state_in)[0]) # Each worker gets a different set of adam optimizer parameters opt = tf.train.AdamOptimizer(learn_rate, epsilon=1e-5) #opt = tf.train.RMSPropOptimizer( # learning_rate=learn_rate, # decay=0.99, # momentum=0.0, # epsilon=1e-8, #) #self.train_op = tf.group(*pi.update_ops, opt.apply_gradients(grads_and_vars), self.inc_step) #self.train_op = tf.group(opt.apply_gradients(grads_and_vars), self.inc_step) self.train_op = opt.apply_gradients(grads_and_vars) # Add model-wide statistics: with tf.name_scope('model'): model_summaries += [ tf.summary.scalar("grad_global_norm", tf.global_norm(grads)), tf.summary.scalar("var_global_norm", tf.global_norm(pi.var_list)), tf.summary.scalar("learn_rate", learn_rate), tf.summary.scalar("total_loss", self.loss), ] self.summary_writer = None self.local_steps = 0 self.log.debug('AAC_{}: train op defined'.format(self.task)) # Model stat. summary: self.model_summary_op = tf.summary.merge(model_summaries, name='model_summary') # Episode-related summaries: self.ep_summary = dict( # Summary placeholders render_human_pl=tf.placeholder(tf.uint8, [None, None, None, 3]), render_model_input_pl=tf.placeholder(tf.uint8, [None, None, None, 3]), render_episode_pl=tf.placeholder(tf.uint8, [None, None, None, 3]), render_atari_pl=tf.placeholder(tf.uint8, [None, None, None, 1]), total_r_pl=tf.placeholder(tf.float32, ), cpu_time_pl=tf.placeholder(tf.float32, ), final_value_pl=tf.placeholder(tf.float32, ), steps_pl=tf.placeholder(tf.int32, ), ) # Environmnet rendering: self.ep_summary['render_op'] = tf.summary.merge([ tf.summary.image('human', self.ep_summary['render_human_pl']), tf.summary.image('model_input', self.ep_summary['render_model_input_pl']), tf.summary.image('episode', self.ep_summary['render_episode_pl']), ], name='render') # For Atari: self.ep_summary['test_render_op'] = tf.summary.image( "model/state", self.ep_summary['render_atari_pl']) # Episode stat. summary: self.ep_summary['stat_op'] = tf.summary.merge([ tf.summary.scalar('episode/total_reward', self.ep_summary['total_r_pl']), tf.summary.scalar('episode/cpu_time_sec', self.ep_summary['cpu_time_pl']), tf.summary.scalar('episode/final_value', self.ep_summary['final_value_pl']), tf.summary.scalar('episode/env_steps', self.ep_summary['steps_pl']) ], name='episode') self.ep_summary['test_stat_op'] = tf.summary.merge( [ tf.summary.scalar('episode/total_reward', self.ep_summary['total_r_pl']), tf.summary.scalar('episode/steps', self.ep_summary['steps_pl']) ], name='episode_atari') # Make runner: # `rollout_length` represents the number of "local steps": the number of timesteps # we run the policy before we update the parameters. # The larger local steps is, the lower is the variance in our policy gradients estimate # on the one hand; but on the other hand, we get less frequent parameter updates, which # slows down learning. In this code, we found that making local steps be much # smaller than 20 makes the algorithm more difficult to tune and to get to work. self.runner = RunnerThread( env, pi, task, self.rollout_length, # ~20 self.episode_summary_freq, self.env_render_freq, self.test_mode, self.ep_summary) self.log.debug('AAC_{}: init() done'.format(self.task))
class Unreal(object): """____""" def __init__( self, env, task, policy_class, policy_config, log, random_seed=0, model_gamma=0.99, # A3C decay model_gae_lambda=1.00, # GAE lambda model_beta=0.01, # entropy regularizer opt_max_train_steps=10**7, opt_decay_steps=None, opt_end_learn_rate=None, opt_learn_rate=1e-4, opt_decay=0.99, opt_momentum=0.0, opt_epsilon=1e-10, rollout_length=20, episode_summary_freq=2, # every i`th episode env_render_freq=10, # every i`th episode model_summary_freq=100, # every i`th local_step test_mode=False, # gym_atari test mode replay_memory_size=2000, replay_rollout_length=None, use_off_policy_a3c=False, use_reward_prediction=False, use_pixel_control=False, use_value_replay=False, use_rebalanced_replay=False, # simplified form of prioritized replay rebalance_skewness=2, rp_lambda=1, # aux tasks loss weights pc_lambda=0.1, vr_lambda=1, off_a3c_lambda=1, gamma_pc=0.9, # pixel change gamma-decay - not used rp_reward_threshold=0.1, # r.prediction: abs.rewards values bigger than this are considered non-zero rp_sequence_size=4, # r.prediction sampling **kwargs): """ Implementation of the UNREAL algorithm. Below, we will have a modest amount of complexity due to the way TensorFlow handles data parallelism. But overall, we'll define the model, specify its inputs, and describe how the policy gradients step should be computed. """ self.log = log self.random_seed = random_seed np.random.seed(self.random_seed) self.log.debug('U_{}_rnd_seed:{}, log_u_sample_(0,1]x5: {}'.format( task, random_seed, self.log_uniform([1e-10, 1], 5))) self.env = env self.task = task self.policy_class = policy_class self.policy_config = policy_config # A3C specific: self.model_gamma = model_gamma # decay self.model_gae_lambda = model_gae_lambda # general advantage estimator lambda self.model_beta = self.log_uniform(model_beta, 1) # entropy reg. # Optimizer self.opt_max_train_steps = opt_max_train_steps self.opt_learn_rate = self.log_uniform(opt_learn_rate, 1) if opt_end_learn_rate is None: self.opt_end_learn_rate = self.opt_learn_rate else: self.opt_end_learn_rate = opt_end_learn_rate if opt_decay_steps is None: self.opt_decay_steps = self.opt_max_train_steps else: self.opt_decay_steps = opt_decay_steps self.opt_decay = opt_decay self.opt_epsilon = opt_epsilon self.opt_momentum = opt_momentum self.rollout_length = rollout_length # Summaries : self.episode_summary_freq = episode_summary_freq self.env_render_freq = env_render_freq self.model_summary_freq = model_summary_freq # If True - use ATARI gym env.: self.test_mode = test_mode # UNREAL specific: self.off_a3c_lambda = off_a3c_lambda self.rp_lambda = rp_lambda self.pc_lambda = self.log_uniform(pc_lambda, 1) self.vr_lambda = vr_lambda self.gamma_pc = gamma_pc self.replay_memory_size = replay_memory_size if replay_rollout_length is not None: self.replay_rollout_length = replay_rollout_length else: self.replay_rollout_length = rollout_length self.rp_sequence_size = rp_sequence_size self.rp_reward_threshold = rp_reward_threshold # On/off switchers for off-policy training and BaseAAC auxiliary tasks: self.use_off_policy_a3c = use_off_policy_a3c self.use_reward_prediction = use_reward_prediction self.use_pixel_control = use_pixel_control if use_off_policy_a3c: self.use_value_replay = False # v-replay is redundant in this case else: self.use_value_replay = use_value_replay self.use_rebalanced_replay = use_rebalanced_replay self.rebalance_skewness = rebalance_skewness self.use_any_aux_tasks = use_value_replay or use_pixel_control or use_reward_prediction self.use_memory = self.use_any_aux_tasks or self.use_off_policy_a3c # Make replay memory: self.memory = Memory(self.replay_memory_size, self.replay_rollout_length, self.rp_reward_threshold) self.log.info( 'U_{}: learn_rate: {:1.6f}, entropy_beta: {:1.6f}, pc_lambda: {:1.8f}.' .format(self.task, self.opt_learn_rate, self.model_beta, self.pc_lambda)) #self.log.info( # 'U_{}: max_steps: {}, decay_steps: {}, end_rate: {:1.6f},'. # format(self.task, self.opt_max_env_steps, self.opt_decay_steps, self.opt_end_learn_rate)) worker_device = "/job:worker/task:{}/cpu:0".format(task) if self.test_mode: model_input_shape = env.observation_space.shape else: model_input_shape = env.observation_space.spaces[ 'model_input'].shape # Start building graph: with tf.device( tf.train.replica_device_setter(1, worker_device=worker_device)): with tf.variable_scope("global"): self.network = self.policy_class(model_input_shape, env.action_space.n, self.rp_sequence_size, **self.policy_config) self.global_step = tf.get_variable( "global_step", [], tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32), trainable=False) self.global_episode = tf.get_variable( "global_episode", [], tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32), trainable=False) # Increment episode count: inc_episode = self.global_episode.assign_add(1) with tf.device(worker_device): with tf.variable_scope("local"): self.local_network = pi = self.policy_class( model_input_shape, env.action_space.n, self.rp_sequence_size, **self.policy_config) pi.global_step = self.global_step pi.global_episode = self.global_episode pi.inc_episode = inc_episode # Meant for Batch-norm layers: pi.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='.*local.*') self.log.debug( 'U_{}: local_network_upd_ops_collection:\n{}'.format( self.task, pi.update_ops)) self.log.debug('\nU_{}: local_network_var_list_to_save:'.format( self.task)) for v in pi.var_list: self.log.debug('{}: {}'.format(v.name, v.get_shape())) # On-policy A3C loss definition: self.a3c_act_target = tf.placeholder(tf.float32, [None, env.action_space.n], name="a3c_action_pl") self.a3c_adv_target = tf.placeholder(tf.float32, [None], name="a3c_advantage_pl") self.a3c_r_target = tf.placeholder(tf.float32, [None], name="a3c_return_pl") log_prob_tf = tf.nn.log_softmax(pi.a3c_logits) prob_tf = tf.nn.softmax(pi.a3c_logits) # summary only # the "policy gradients" loss: its derivative is precisely the policy gradient # notice that `a3c_action_target` is a placeholder that is provided externally. # `a3c_advantage_target` will contain the advantages, as calculated in process_rollout(): #pi_loss = - tf.reduce_mean( # tf.reduce_sum( # log_prob_tf * self.a3c_act_target, # [1] # ) * self.a3c_adv_target #) neg_log_prob_ac = tf.nn.softmax_cross_entropy_with_logits( logits=pi.a3c_logits, labels=self.a3c_act_target) mean_neg_log_prob_ac = tf.reduce_mean(neg_log_prob_ac) pi_loss = tf.reduce_mean(neg_log_prob_ac * self.a3c_adv_target) # loss of value function: #vf_loss = 0.5 * tf.reduce_sum(tf.square(pi.a3c_vf - self.a3c_r_target)) vf_loss = 0.5 * tf.reduce_mean( tf.square(pi.a3c_vf - self.a3c_r_target)) mean_vf = tf.reduce_mean(pi.a3c_vf) #entropy = - tf.reduce_sum(prob_tf * log_prob_tf) entropy = tf.reduce_mean(self.cat_entropy(pi.a3c_logits)) a3c_bs = tf.to_float(tf.shape( pi.a3c_state_in)[0]) # on-policy batch size a3c_loss = pi_loss + vf_loss - entropy * self.model_beta # Start accumulating total loss: self.loss = a3c_loss # Base summaries: model_summaries = [ tf.summary.scalar("a3c/policy_loss", pi_loss), #tf.summary.histogram("a3c/pi_prob_d", prob_tf), tf.summary.scalar("a3c/value_loss", vf_loss), tf.summary.scalar("a3c/entropy", entropy), tf.summary.scalar("a3c/neg_log_pi", mean_neg_log_prob_ac), tf.summary.scalar("a3c/vf", mean_vf), ] # Off-policy batch size: off_bs = tf.to_float(tf.shape(pi.off_a3c_state_in)[0]) if self.use_rebalanced_replay: # Simplified importance-sampling bias correction: rebalanced_replay_weight = self.rebalance_skewness / off_bs else: rebalanced_replay_weight = 1.0 # Placeholders for off-policy training: self.off_policy_act_target = tf.placeholder( tf.float32, [None, env.action_space.n], name="off_policy_action_pl") self.off_policy_adv_target = tf.placeholder( tf.float32, [None], name="off_policy_advantage_pl") self.off_policy_r_target = tf.placeholder( tf.float32, [None], name="off_policy_return_pl") if self.use_off_policy_a3c: # Off-policy A3C loss graph mirrors on-policy: #off_log_prob_tf = tf.nn.log_softmax(pi.off_a3c_logits) #off_prob_tf = tf.nn.softmax(pi.off_a3c_logits) #off_pi_loss = - tf.reduce_sum( # tf.reduce_sum( # off_log_prob_tf * self.off_policy_action_target, # [1] # ) * self.off_policy_advantage_target #) off_neg_log_prob_ac = tf.nn.softmax_cross_entropy_with_logits( logits=pi.off_a3c_logits, labels=self.off_policy_act_target) off_pi_loss = tf.reduce_mean(off_neg_log_prob_ac * self.off_policy_adv_target) # loss of value function: off_vf_loss = 0.5 * tf.reduce_mean( tf.square(pi.off_a3c_vf - self.off_policy_r_target)) off_entropy = tf.reduce_mean( self.cat_entropy(pi.off_a3c_logits)) off_a3c_loss = off_pi_loss + off_vf_loss - off_entropy * self.model_beta self.loss = self.loss + self.off_a3c_lambda * rebalanced_replay_weight * off_a3c_loss model_summaries += [ tf.summary.scalar("off_a3c/policy_loss", off_pi_loss), tf.summary.scalar("off_a3c/value_loss", off_vf_loss), ] if self.use_pixel_control: # Pixel control loss self.pc_action = tf.placeholder(tf.float32, [None, env.action_space.n], name="pc_action") self.pc_target = tf.placeholder(tf.float32, [None, None, None], name="pc_target") # Get Q-value features for actions been taken and define loss: pc_action_reshaped = tf.reshape(self.pc_action, [-1, 1, 1, env.action_space.n]) pc_q_action = tf.multiply(pi.pc_q, pc_action_reshaped) pc_q_action = tf.reduce_sum(pc_q_action, axis=-1, keep_dims=False) pc_loss = tf.reduce_sum(tf.square(self.pc_target - pc_q_action)) # sum all over self.loss = self.loss + self.pc_lambda * rebalanced_replay_weight * pc_loss # Add specific summary: model_summaries += [ tf.summary.scalar('pixel_control/q_loss', pc_loss) ] if self.use_value_replay: # Value function replay loss: self.vr_target = tf.placeholder(tf.float32, [None], name="vr_target") vr_loss = tf.reduce_mean( tf.square(pi.vr_value - self.vr_target)) self.loss = self.loss + self.vr_lambda * rebalanced_replay_weight * vr_loss model_summaries += [ tf.summary.scalar('v_replay/value_loss', vr_loss) ] if self.use_reward_prediction: # Reward prediction loss: self.rp_target = tf.placeholder(tf.float32, [1, 3], name="rp_target") rp_loss = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits( labels=self.rp_target, logits=pi.rp_logits)) self.loss = self.loss + self.rp_lambda * rp_loss model_summaries += [ tf.summary.scalar('r_predict/class_loss', rp_loss) ] grads = tf.gradients(self.loss, pi.var_list) grads, _ = tf.clip_by_global_norm(grads, 40.0) # copy weights from the parameter server to the local model self.sync = tf.group(*[ v1.assign(v2) for v1, v2 in zip(pi.var_list, self.network.var_list) ]) grads_and_vars = list(zip(grads, self.network.var_list)) self.inc_step = self.global_step.assign_add( tf.shape(pi.a3c_state_in)[0]) #self.inc_step = self.global_step.assign_add(1) # Anneal learning rate: learn_rate = tf.train.polynomial_decay( self.opt_learn_rate, self.global_step + 1, self.opt_decay_steps, self.opt_end_learn_rate, power=1, cycle=True, ) # Each worker gets a different set of adam optimizer parameters opt = tf.train.AdamOptimizer(learn_rate) #opt = tf.train.RMSPropOptimizer( # learning_rate=learn_rate, # decay=0.99, # momentum=0.0, # epsilon=1e-8, #) self.train_op = tf.group(*pi.update_ops, opt.apply_gradients(grads_and_vars), self.inc_step) #self.train_op = tf.group(opt.apply_gradients(grads_and_vars), inc_step) # Add model-wide statistics: model_summaries += [ tf.summary.scalar("global/grad_global_norm", tf.global_norm(grads)), tf.summary.scalar("global/var_global_norm", tf.global_norm(pi.var_list)), tf.summary.scalar("global/opt_learn_rate", learn_rate), tf.summary.scalar("global/total_loss", self.loss), ] self.summary_writer = None self.local_steps = 0 self.log.debug('U_{}: train op defined'.format(self.task)) # Model stat. summary: self.model_summary_op = tf.summary.merge(model_summaries, name='model_summary') # Episode-related summaries: self.ep_summary = dict( # Summary placeholders render_human_pl=tf.placeholder(tf.uint8, [None, None, None, 3]), render_model_input_pl=tf.placeholder(tf.uint8, [None, None, None, 3]), render_episode_pl=tf.placeholder(tf.uint8, [None, None, None, 3]), render_atari_pl=tf.placeholder(tf.uint8, [None, None, None, 1]), total_r_pl=tf.placeholder(tf.float32, ), cpu_time_pl=tf.placeholder(tf.float32, ), final_value_pl=tf.placeholder(tf.float32, ), steps_pl=tf.placeholder(tf.int32, ), ) # Environmnet rendering: self.ep_summary['render_op'] = tf.summary.merge([ tf.summary.image('human', self.ep_summary['render_human_pl']), tf.summary.image('model_input', self.ep_summary['render_model_input_pl']), tf.summary.image('episode', self.ep_summary['render_episode_pl']), ], name='render') # For Atari: self.ep_summary['test_render_op'] = tf.summary.image( "model/state", self.ep_summary['render_atari_pl']) # Episode stat. summary: self.ep_summary['stat_op'] = tf.summary.merge([ tf.summary.scalar('episode/total_reward', self.ep_summary['total_r_pl']), tf.summary.scalar('episode/cpu_time_sec', self.ep_summary['cpu_time_pl']), tf.summary.scalar('episode/final_value', self.ep_summary['final_value_pl']), tf.summary.scalar('episode/env_steps', self.ep_summary['steps_pl']) ], name='episode') self.ep_summary['test_stat_op'] = tf.summary.merge( [ tf.summary.scalar('episode/total_reward', self.ep_summary['total_r_pl']), tf.summary.scalar('episode/steps', self.ep_summary['steps_pl']) ], name='episode_atari') # Make runner: # `rollout_length` represents the number of "local steps": the number of timesteps # we run the policy before we update the parameters. # The larger local steps is, the lower is the variance in our policy gradients estimate # on the one hand; but on the other hand, we get less frequent parameter updates, which # slows down learning. In this code, we found that making local steps be much # smaller than 20 makes the algorithm more difficult to tune and to get to work. self.runner = RunnerThread( env, pi, task, self.rollout_length, # ~20 self.episode_summary_freq, self.env_render_freq, self.test_mode, self.ep_summary) self.log.debug('U_{}: init() done'.format(self.task)) def log_uniform(self, lo_hi, size): """ Samples from log-uniform distribution in range specified by `lo_hi`. Takes: lo_hi: either scalar or [low_value, high_value] size: sample size Returns: np.array or np.float (if size=1). """ r = np.asarray(lo_hi) try: lo = r[0] hi = r[-1] except: lo = hi = r x = np.random.random(size) log_lo = np.log(lo) log_hi = np.log(hi) v = log_lo * (1 - x) + log_hi * x if size > 1: return np.exp(v) else: return np.exp(v)[0] def cat_entropy(self, logits): a0 = logits - tf.reduce_max(logits, 1, keep_dims=True) ea0 = tf.exp(a0) z0 = tf.reduce_sum(ea0, 1, keep_dims=True) p0 = ea0 / z0 return tf.reduce_sum(p0 * (tf.log(z0) - a0), 1) def start(self, sess, summary_writer): self.runner.start_runner(sess, summary_writer) # starting runner thread self.summary_writer = summary_writer def pull_batch_from_queue(self): """ Self explanatory: take a rollout from the queue of the thread runner. """ rollout = self.runner.queue.get(timeout=600.0) #self.log.debug('Rollout position:{}\nactions:{}\nrewards:{}\nlast_action:{}\nlast_reward:{}\nterminal:{}\n'. # format(rollout.position, rollout.actions, # rollout.rewards, rollout.last_actions, rollout.last_rewards, rollout.terminal)) """ while not rollout.terminal: try: rollout.extend(self.runner.queue.get_nowait()) except queue.Empty: break return rollout """ return rollout def process_rp(self, rp_experience_frames): """ Estimates reward prediction target. Returns feed dictionary for `reward prediction` loss estimation subgraph. """ batch_rp_state = [] batch_rp_target = [] for i in range(self.rp_sequence_size - 1): batch_rp_state.append(rp_experience_frames[i].state) # One hot vector for target reward (i.e. reward taken from last of sampled frames): r = rp_experience_frames[-1].reward rp_t = [0.0, 0.0, 0.0] if r > self.rp_reward_threshold: rp_t[1] = 1.0 # positive [010] elif r < -self.rp_reward_threshold: rp_t[2] = 1.0 # negative [001] else: rp_t[0] = 1.0 # zero [100] batch_rp_target.append(rp_t) feeder = { self.local_network.rp_state_in: batch_rp_state, self.rp_target: batch_rp_target } return feeder def process_vr(self, batch): """ Returns feed dictionary for `value replay` loss estimation subgraph. """ if not self.use_off_policy_a3c: # use single pass of network on same off-policy batch feeder = { pl: value for pl, value in zip( self.local_network.vr_lstm_state_pl_flatten, flatten_nested(batch.features)) } # ...passes lstm context feeder.update({ self.local_network.vr_state_in: batch.si, self.local_network.vr_a_r_in: batch.last_ar, #self.vr_action: batch.a, # don't need those for value fn. estimation #self.vr_advantage: batch.adv, # neither.. self.vr_target: batch.r, }) else: feeder = {self.vr_target: batch.r} # redundant actually :) return feeder def process_pc(self, batch): """ Returns feed dictionary for `pixel control` loss estimation subgraph. """ if not self.use_off_policy_a3c: # use single pass of network on same off-policy batch feeder = { pl: value for pl, value in zip( self.local_network.pc_lstm_state_pl_flatten, flatten_nested(batch.features)) } feeder.update({ self.local_network.pc_state_in: batch.si, self.local_network.pc_a_r_in: batch.last_ar, self.pc_action: batch.a, self.pc_target: batch.pc }) else: feeder = {self.pc_action: batch.a, self.pc_target: batch.pc} return feeder def fill_replay_memory(self, sess): """ Fills replay memory with initial experiences. Supposed to be called by worker() just before training begins. """ if self.use_memory: sess.run(self.sync) while not self.memory.is_full(): rollout = self.pull_batch_from_queue() self.memory.add_rollout(rollout) self.log.info('U_{}: replay memory filled.'.format(self.task)) def process(self, sess): """ Grabs a on_policy_rollout that's been produced by the thread runner, samples off_policy rollout[s] from replay memory and updates the parameters. The update is then sent to the parameter server. """ sess.run(self.sync) # copy weights from shared to local # Get and process on_policy_rollout for A3C train step: on_policy_rollout = self.pull_batch_from_queue() on_policy_batch = on_policy_rollout.process( gamma=self.model_gamma, gae_lambda=self.model_gae_lambda) # Feeder for on-policy A3C loss estimation graph: feed_dict = { pl: value for pl, value in zip(self.local_network.a3c_lstm_state_pl_flatten, flatten_nested(on_policy_batch.features)) } # ..passes lstm context feed_dict.update({ self.local_network.a3c_state_in: on_policy_batch.si, self.local_network.a3c_a_r_in: on_policy_batch.last_ar, self.a3c_act_target: on_policy_batch.a, self.a3c_adv_target: on_policy_batch.adv, self.a3c_r_target: on_policy_batch.r, self.local_network.train_phase: True, }) if self.use_off_policy_a3c or self.use_pixel_control or self.use_value_replay: # Get sample from replay memory: if self.use_rebalanced_replay: off_policy_sample = self.memory.sample_priority( self.replay_rollout_length, skewness=self.rebalance_skewness, exact_size=False) else: off_policy_sample = self.memory.sample_uniform( self.replay_rollout_length) off_policy_rollout = Rollout() off_policy_rollout.add_memory_sample(off_policy_sample) off_policy_batch = off_policy_rollout.process( gamma=self.model_gamma, gae_lambda=self.model_gae_lambda) # Feeder for off-policy A3C loss estimation graph: off_policy_feeder = { pl: value for pl, value in zip( self.local_network.off_a3c_lstm_state_pl_flatten, flatten_nested(off_policy_batch.features)) } off_policy_feeder.update({ self.local_network.off_a3c_state_in: off_policy_batch.si, self.local_network.off_a3c_a_r_in: off_policy_batch.last_ar, self.off_policy_act_target: off_policy_batch.a, self.off_policy_adv_target: off_policy_batch.adv, self.off_policy_r_target: off_policy_batch.r, }) feed_dict.update(off_policy_feeder) # Update with reward prediction subgraph: if self.use_reward_prediction: # Rebalanced 50/50 sample for RP: rp_sample = self.memory.sample_priority(self.rp_sequence_size, skewness=2, exact_size=True) feed_dict.update(self.process_rp(rp_sample)) # Pixel control ... if self.use_pixel_control: feed_dict.update(self.process_pc(off_policy_batch)) # VR... if self.use_value_replay: feed_dict.update(self.process_vr(off_policy_batch)) if self.use_memory: # Save on_policy_rollout to replay memory: self.memory.add_rollout(on_policy_rollout) # Every worker writes model summaries: should_compute_summary =\ self.local_steps % self.model_summary_freq == 0 # self.task == 0 and if should_compute_summary: fetches = [self.model_summary_op, self.train_op, self.global_step] else: fetches = [self.train_op, self.global_step] #print('TRAIN_FEED_DICT:\n', feed_dict) #print('\n=======S=======\n') #for key,value in feed_dict.items(): # try: # print(key,':', value.shape,'\n') # except: # print(key, ':', value, '\n') #print('\n=====E======\n') # And finally... fetched = sess.run(fetches, feed_dict=feed_dict) if should_compute_summary: self.summary_writer.add_summary(tf.Summary.FromString(fetched[0]), fetched[-1]) self.summary_writer.flush() self.local_steps += 1
class Unreal(object): """ Asynchronous Advantage Actor Critic with auxiliary control tasks. This UNREAL implementation borrows heavily from Kosuke Miyoshi code, under Apache License 2.0: https://miyosuda.github.io/ https://github.com/miyosuda/unreal Original A3C code comes from OpenAI repository under MIT licence: https://github.com/openai/universe-starter-agent Papers: https://arxiv.org/abs/1602.01783 https://arxiv.org/abs/1611.05397 """ def __init__( self, env, task, policy_config, log, random_seed=None, model_gamma=0.99, # decay model_gae_lambda=1.00, # GAE lambda model_beta=0.01, # entropy regularizer opt_max_train_steps=10**7, opt_decay_steps=None, opt_end_learn_rate=None, opt_learn_rate=1e-4, opt_decay=0.99, opt_momentum=0.0, opt_epsilon=1e-10, rollout_length=20, episode_summary_freq=2, # every i`th environment episode env_render_freq=10, # every i`th environment episode model_summary_freq=100, # every i`th algorithm iteration test_mode=False, # gym_atari test mode replay_memory_size=2000, replay_rollout_length=None, use_off_policy_aac=False, use_reward_prediction=False, use_pixel_control=False, use_value_replay=False, use_rebalanced_replay=False, # simplified form of prioritized replay rebalance_skewness=2, rp_lambda=1, # aux tasks loss weights pc_lambda=0.1, vr_lambda=1, off_aac_lambda=1, gamma_pc=0.9, # pixel change gamma-decay - not used rp_reward_threshold=0.1, # r.prediction: abs.rewards values bigger than this are considered non-zero rp_sequence_size=3, ): # r.prediction sampling """ Args: env: envirionment instance. task: int policy_config: policy estimator class and configuration dictionary log: parent log random_seed: int or None model_gamma: gamma discount factor model_gae_lambda: GAE lambda model_beta: entropy regularization beta opt_max_train_steps: train steps to run opt_decay_steps: learn ratio decay steps opt_end_learn_rate: final lerarn rate opt_learn_rate: start learn rate opt_decay: optimizer decay, if apll. opt_momentum: optimizer momentum, if apll. opt_epsilon: optimizer epsilon rollout_length: on-policy rollout length episode_summary_freq: int, write episode summary for every i'th episode env_render_freq: int, write environment rendering summary for every i'th train step model_summary_freq: int, write model summary for every i'th train step test_mode: True: Atari, False: BTGym replay_memory_size: in number of experiences replay_rollout_length: off-policy rollout length use_off_policy_aac: use full AAC off policy training instead of Value-replay use_reward_prediction: use aux. off-policy reward prediction task use_pixel_control: use aux. off-policy pixel control task use_value_replay: use aux. off-policy value replay task (not used, if use_off_policy_aac=True) use_rebalanced_replay: NOT USED rebalance_skewness: NOT USED rp_lambda: reward prediction loss weight pc_lambda: pixel control loss weight vr_lambda: value replay loss weight off_aac_lambda: off-policy AAC loss weight gamma_pc: NOT USED rp_reward_threshold: reward prediction task classification threshold, above which reward is 'non-zero' rp_sequence_size: reward prediction sample size, in number of experiences """ self.log = log self.random_seed = random_seed if self.random_seed is not None: np.random.seed(self.random_seed) tf.set_random_seed(self.random_seed) self.log.debug('AAC_{}_rnd_seed:{}, log_u_sample_(0,1]x5: {}'.format( task, random_seed, log_uniform([1e-10, 1], 5))) self.env = env self.task = task self.policy_class = policy_config['class_ref'] self.policy_kwargs = policy_config['kwargs'] # AAC specific: self.model_gamma = model_gamma # decay self.model_gae_lambda = model_gae_lambda # general advantage estimator lambda self.model_beta = log_uniform(model_beta, 1) # entropy reg. # Optimizer self.opt_max_train_steps = opt_max_train_steps self.opt_learn_rate = log_uniform(opt_learn_rate, 1) if opt_end_learn_rate is None: self.opt_end_learn_rate = self.opt_learn_rate else: self.opt_end_learn_rate = opt_end_learn_rate if opt_decay_steps is None: self.opt_decay_steps = self.opt_max_train_steps else: self.opt_decay_steps = opt_decay_steps self.opt_decay = opt_decay self.opt_epsilon = opt_epsilon self.opt_momentum = opt_momentum self.rollout_length = rollout_length # Summaries : self.episode_summary_freq = episode_summary_freq self.env_render_freq = env_render_freq self.model_summary_freq = model_summary_freq # If True - use ATARI gym env.: self.test_mode = test_mode # UNREAL specific: self.off_aac_lambda = off_aac_lambda self.rp_lambda = rp_lambda self.pc_lambda = log_uniform(pc_lambda, 1) self.vr_lambda = vr_lambda self.gamma_pc = gamma_pc self.replay_memory_size = replay_memory_size if replay_rollout_length is not None: self.replay_rollout_length = replay_rollout_length else: self.replay_rollout_length = rollout_length self.rp_sequence_size = rp_sequence_size self.rp_reward_threshold = rp_reward_threshold # On/off switchers for off-policy training and auxiliary tasks: self.use_off_policy_aac = use_off_policy_aac self.use_reward_prediction = use_reward_prediction self.use_pixel_control = use_pixel_control if use_off_policy_aac: self.use_value_replay = False # v-replay is redundant in this case else: self.use_value_replay = use_value_replay self.use_rebalanced_replay = use_rebalanced_replay self.rebalance_skewness = rebalance_skewness self.use_any_aux_tasks = use_value_replay or use_pixel_control or use_reward_prediction self.use_memory = self.use_any_aux_tasks or self.use_off_policy_aac # Make replay memory: self.memory = Memory(history_size=self.replay_memory_size, max_sample_size=self.replay_rollout_length, reward_threshold=self.rp_reward_threshold, log=self.log) self.log.info( 'AAC_{}: learn_rate: {:1.6f}, entropy_beta: {:1.6f}, pc_lambda: {:1.8f}.' .format(self.task, self.opt_learn_rate, self.model_beta, self.pc_lambda)) #self.log.info( # 'AAC_{}: max_steps: {}, decay_steps: {}, end_rate: {:1.6f},'. # format(self.task, self.opt_max_train_steps, self.opt_decay_steps, self.opt_end_learn_rate)) worker_device = "/job:worker/task:{}/cpu:0".format(task) # Infer observation space shape: if type(env.observation_space) == BTgymMultiSpace: model_input_shape = env.observation_space.get_shapes() else: model_input_shape = env.observation_space.shape # Start building graph: with tf.device( tf.train.replica_device_setter(1, worker_device=worker_device)): with tf.variable_scope("global"): self.network = self.policy_class( ob_space=model_input_shape, ac_space=env.action_space.n, rp_sequence_size=self.rp_sequence_size, **self.policy_kwargs) self.global_step = tf.get_variable( "global_step", [], tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32), trainable=False) self.global_episode = tf.get_variable( "global_episode", [], tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32), trainable=False) # Increment episode count: inc_episode = self.global_episode.assign_add(1) with tf.device(worker_device): with tf.variable_scope("local"): self.local_network = pi = self.policy_class( ob_space=model_input_shape, ac_space=env.action_space.n, rp_sequence_size=self.rp_sequence_size, **self.policy_kwargs) pi.global_step = self.global_step pi.global_episode = self.global_episode pi.inc_episode = inc_episode # Meant for Batch-norm layers: pi.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='.*local.*') self.log.debug( 'AAC_{}: local_network_upd_ops_collection:\n{}'.format( self.task, pi.update_ops)) self.log.debug('\nAAC_{}: local_network_var_list_to_save:'.format( self.task)) for v in pi.var_list: self.log.debug('{}: {}'.format(v.name, v.get_shape())) # Learning rate annealing: learn_rate = tf.train.polynomial_decay( self.opt_learn_rate, self.global_step + 1, self.opt_decay_steps, self.opt_end_learn_rate, power=1, cycle=False, ) # On-policy AAC loss definition: self.on_pi_act_target = tf.placeholder(tf.float32, [None, env.action_space.n], name="on_policy_action_pl") self.on_pi_adv_target = tf.placeholder( tf.float32, [None], name="on_policy_advantage_pl") self.on_pi_r_target = tf.placeholder(tf.float32, [None], name="on_policy_return_pl") on_pi_loss, on_pi_summaries = aac_loss_def( act_target=self.on_pi_act_target, adv_target=self.on_pi_adv_target, r_target=self.on_pi_r_target, pi_logits=pi.on_logits, pi_vf=pi.on_vf, entropy_beta=self.model_beta, name='on_policy/aac', verbose=True) # Start accumulating total loss: self.loss = on_pi_loss model_summaries = on_pi_summaries # Off-policy batch size: off_bs = tf.to_float(tf.shape(pi.off_state_in)[0]) if self.use_rebalanced_replay: # Simplified importance-sampling bias correction: rebalanced_replay_weight = self.rebalance_skewness / off_bs else: rebalanced_replay_weight = 1.0 # Off policy training: self.off_pi_act_target = tf.placeholder( tf.float32, [None, env.action_space.n], name="off_policy_action_pl") self.off_pi_adv_target = tf.placeholder( tf.float32, [None], name="off_policy_advantage_pl") self.off_pi_r_target = tf.placeholder(tf.float32, [None], name="off_policy_return_pl") if self.use_off_policy_aac: # Off-policy PPO loss graph mirrors on-policy: off_ppo_loss, off_ppo_summaries = aac_loss_def( act_target=self.off_pi_act_target, adv_target=self.off_pi_adv_target, r_target=self.off_pi_r_target, pi_logits=pi.off_logits, pi_vf=pi.off_vf, entropy_beta=self.model_beta, name='off_policy/aac', verbose=False) self.loss = self.loss + self.off_aac_lambda * rebalanced_replay_weight * off_ppo_loss model_summaries += off_ppo_summaries if self.use_pixel_control: # Pixel control loss: self.pc_action = tf.placeholder(tf.float32, [None, env.action_space.n], name="pc_action") self.pc_target = tf.placeholder(tf.float32, [None, None, None], name="pc_target") pc_loss, pc_summaries = pc_loss_def( actions=self.pc_action, targets=self.pc_target, pi_pc_q=pi.pc_q, name='off_policy/pixel_control', verbose=True) self.loss = self.loss + self.pc_lambda * rebalanced_replay_weight * pc_loss # Add specific summary: model_summaries += pc_summaries if self.use_value_replay: # Value function replay loss: self.vr_target = tf.placeholder(tf.float32, [None], name="vr_target") vr_loss, vr_summaries = value_fn_loss_def( r_target=self.vr_target, pi_vf=pi.vr_value, name='off_policy/value_replay', verbose=True) self.loss = self.loss + self.vr_lambda * rebalanced_replay_weight * vr_loss model_summaries += vr_summaries if self.use_reward_prediction: # Reward prediction loss: self.rp_target = tf.placeholder(tf.float32, [1, 3], name="rp_target") rp_loss, rp_summaries = rp_loss_def( rp_targets=self.rp_target, pi_rp_logits=pi.rp_logits, name='off_policy/reward_prediction', verbose=True) self.loss = self.loss + self.rp_lambda * rp_loss model_summaries += rp_summaries grads = tf.gradients(self.loss, pi.var_list) grads, _ = tf.clip_by_global_norm(grads, 40.0) # Copy weights from the parameter server to the local model self.sync = tf.group(*[ v1.assign(v2) for v1, v2 in zip(pi.var_list, self.network.var_list) ]) grads_and_vars = list(zip(grads, self.network.var_list)) self.inc_step = self.global_step.assign_add( tf.shape(pi.on_state_in)[0]) # Each worker gets a different set of adam optimizer parameters opt = tf.train.AdamOptimizer(learn_rate, epsilon=1e-5) #opt = tf.train.RMSPropOptimizer( # learning_rate=learn_rate, # decay=0.99, # momentum=0.0, # epsilon=1e-8, #) #self.train_op = tf.group(*pi.update_ops, opt.apply_gradients(grads_and_vars), self.inc_step) #self.train_op = tf.group(opt.apply_gradients(grads_and_vars), self.inc_step) self.train_op = opt.apply_gradients(grads_and_vars) # Add model-wide statistics: with tf.name_scope('model'): model_summaries += [ tf.summary.scalar("grad_global_norm", tf.global_norm(grads)), tf.summary.scalar("var_global_norm", tf.global_norm(pi.var_list)), tf.summary.scalar("learn_rate", learn_rate), tf.summary.scalar("total_loss", self.loss), ] self.summary_writer = None self.local_steps = 0 self.log.debug('AAC_{}: train op defined'.format(self.task)) # Model stat. summary: self.model_summary_op = tf.summary.merge(model_summaries, name='model_summary') # Episode-related summaries: self.ep_summary = dict( # Summary placeholders render_human_pl=tf.placeholder(tf.uint8, [None, None, None, 3]), render_model_input_pl=tf.placeholder(tf.uint8, [None, None, None, 3]), render_episode_pl=tf.placeholder(tf.uint8, [None, None, None, 3]), render_atari_pl=tf.placeholder(tf.uint8, [None, None, None, 1]), total_r_pl=tf.placeholder(tf.float32, ), cpu_time_pl=tf.placeholder(tf.float32, ), final_value_pl=tf.placeholder(tf.float32, ), steps_pl=tf.placeholder(tf.int32, ), ) # Environmnet rendering: self.ep_summary['render_op'] = tf.summary.merge([ tf.summary.image('human', self.ep_summary['render_human_pl']), tf.summary.image('model_input', self.ep_summary['render_model_input_pl']), tf.summary.image('episode', self.ep_summary['render_episode_pl']), ], name='render') # For Atari: self.ep_summary['test_render_op'] = tf.summary.image( "model/state", self.ep_summary['render_atari_pl']) # Episode stat. summary: self.ep_summary['stat_op'] = tf.summary.merge([ tf.summary.scalar('episode/total_reward', self.ep_summary['total_r_pl']), tf.summary.scalar('episode/cpu_time_sec', self.ep_summary['cpu_time_pl']), tf.summary.scalar('episode/final_value', self.ep_summary['final_value_pl']), tf.summary.scalar('episode/env_steps', self.ep_summary['steps_pl']) ], name='episode') self.ep_summary['test_stat_op'] = tf.summary.merge( [ tf.summary.scalar('episode/total_reward', self.ep_summary['total_r_pl']), tf.summary.scalar('episode/steps', self.ep_summary['steps_pl']) ], name='episode_atari') # Make runner: # `rollout_length` represents the number of "local steps": the number of timesteps # we run the policy before we update the parameters. # The larger local steps is, the lower is the variance in our policy gradients estimate # on the one hand; but on the other hand, we get less frequent parameter updates, which # slows down learning. In this code, we found that making local steps be much # smaller than 20 makes the algorithm more difficult to tune and to get to work. self.runner = RunnerThread( env, pi, task, self.rollout_length, # ~20 self.episode_summary_freq, self.env_render_freq, self.test_mode, self.ep_summary) self.log.debug('AAC_{}: init() done'.format(self.task)) def start(self, sess, summary_writer): self.runner.start_runner(sess, summary_writer) # starting runner thread self.summary_writer = summary_writer def pull_batch_from_queue(self): """ Self explanatory: take a rollout from the queue of the thread runner. """ rollout = self.runner.queue.get(timeout=600.0) #self.log.debug('Rollout position:{}\nactions:{}\nrewards:{}\nlast_action:{}\nlast_reward:{}\nterminal:{}\n'. # format(rollout.position, rollout.actions, # rollout.rewards, rollout.last_actions, rollout.last_rewards, rollout.terminal)) return rollout def process_rp(self, rp_experience_frames): """ Estimates reward prediction target. Returns feed dictionary for `reward prediction` loss estimation subgraph. """ batch_rp_state = [] batch_rp_target = [] for i in range(self.rp_sequence_size - 1): batch_rp_state.append(rp_experience_frames[i]['state']) # One hot vector for target reward (i.e. reward taken from last of sampled frames): r = rp_experience_frames[-1]['reward'] rp_t = [0.0, 0.0, 0.0] if r > self.rp_reward_threshold: rp_t[1] = 1.0 # positive [010] elif r < -self.rp_reward_threshold: rp_t[2] = 1.0 # negative [001] else: rp_t[0] = 1.0 # zero [100] batch_rp_target.append(rp_t) feeder = { self.local_network.rp_state_in: np.asarray(batch_rp_state), self.rp_target: np.asarray(batch_rp_target) } return feeder def process_vr(self, batch): """ Returns feed dictionary for `value replay` loss estimation subgraph. """ if not self.use_off_policy_aac: # use single pass of network on same off-policy batch feeder = { pl: value for pl, value in zip( self.local_network.vr_lstm_state_pl_flatten, flatten_nested(batch['context'])) } # ...passes lstm context feeder.update({ self.local_network.vr_state_in: batch['state'], self.local_network.vr_a_r_in: batch['last_action_reward'], self.vr_target: batch['r'], }) else: feeder = {self.vr_target: batch['r']} # redundant actually :) return feeder def process_pc(self, batch): """ Returns feed dictionary for `pixel control` loss estimation subgraph. """ if not self.use_off_policy_aac: # use single pass of network on same off-policy batch feeder = { pl: value for pl, value in zip( self.local_network.pc_lstm_state_pl_flatten, flatten_nested(batch['context'])) } feeder.update({ self.local_network.pc_state_in: batch['state'], self.local_network.pc_a_r_in: batch['last_action_reward'], self.pc_action: batch['action'], self.pc_target: batch['pixel_change'] }) else: feeder = { self.pc_action: batch['action'], self.pc_target: batch['pixel_change'] } return feeder def fill_replay_memory(self, sess): """ Fills replay memory with initial experiences. Supposed to be called by parent worker() just before training begins. """ if self.use_memory: sess.run(self.sync) while not self.memory.is_full(): rollout = self.pull_batch_from_queue() self.memory.add_rollout(rollout) self.log.info('AAC_{}: replay memory filled.'.format(self.task)) def process(self, sess): """ Grabs a on_policy_rollout that's been produced by the thread runner, samples off_policy rollout[s] from replay memory and updates the parameters. The update is then sent to the parameter server. """ # Copy weights from shared to local new_policy: sess.run(self.sync) # Get and process rollout for on-policy train step: on_policy_rollout = self.pull_batch_from_queue() on_policy_batch = on_policy_rollout.process( gamma=self.model_gamma, gae_lambda=self.model_gae_lambda) # Feeder for on-policy AAC loss estimation graph: feed_dict = { pl: value for pl, value in zip(self.local_network.on_lstm_state_pl_flatten, flatten_nested(on_policy_batch['context'])) } feed_dict.update({ self.local_network.on_state_in: on_policy_batch['state'], self.local_network.on_a_r_in: on_policy_batch['last_action_reward'], self.on_pi_act_target: on_policy_batch['action'], self.on_pi_adv_target: on_policy_batch['advantage'], self.on_pi_r_target: on_policy_batch['r'], self.local_network.train_phase: True, }) if self.use_off_policy_aac or self.use_pixel_control or self.use_value_replay: # Get sample from replay memory: if self.use_rebalanced_replay: off_policy_sample = self.memory.sample_priority( self.replay_rollout_length, skewness=self.rebalance_skewness, exact_size=False) else: off_policy_sample = self.memory.sample_uniform( self.replay_rollout_length) off_policy_rollout = Rollout() off_policy_rollout.add_memory_sample(off_policy_sample) off_policy_batch = off_policy_rollout.process( gamma=self.model_gamma, gae_lambda=self.model_gae_lambda) # Feeder for off-policy AAC loss estimation graph: off_policy_feeder = { pl: value for pl, value in zip( self.local_network.off_lstm_state_pl_flatten, flatten_nested(off_policy_batch['context'])) } off_policy_feeder.update({ self.local_network.off_state_in: off_policy_batch['state'], self.local_network.off_a_r_in: off_policy_batch['last_action_reward'], self.off_pi_act_target: off_policy_batch['action'], self.off_pi_adv_target: off_policy_batch['advantage'], self.off_pi_r_target: off_policy_batch['r'], }) feed_dict.update(off_policy_feeder) # Update with reward prediction subgraph: if self.use_reward_prediction: # Rebalanced 50/50 sample for RP: rp_sample = self.memory.sample_priority(self.rp_sequence_size, skewness=2, exact_size=True) feed_dict.update(self.process_rp(rp_sample)) # Pixel control ... if self.use_pixel_control: feed_dict.update(self.process_pc(off_policy_batch)) # VR... if self.use_value_replay: feed_dict.update(self.process_vr(off_policy_batch)) if self.use_memory: # Save on_policy_rollout to replay memory: self.memory.add_rollout(on_policy_rollout) # Every worker writes model summaries: should_compute_summary =\ self.local_steps % self.model_summary_freq == 0 fetches = [self.train_op] if should_compute_summary: fetches = [self.train_op, self.model_summary_op, self.inc_step] else: fetches = [self.train_op, self.inc_step] fetched = sess.run(fetches, feed_dict=feed_dict) if should_compute_summary: self.summary_writer.add_summary(tf.Summary.FromString(fetched[-2]), fetched[-1]) self.summary_writer.flush() self.local_steps += 1