def run_a3c(args): """Run A3C experiment.""" GYM_ENV_NAME = args.gym_env.replace('-', '_') GAME_NAME = args.gym_env.replace('NoFrameskip-v4','') # setup folder name and path to folder folder = pathlib.Path(setup_folder(args, GYM_ENV_NAME)) # setup GPU (if applicable) import tensorflow as tf gpu_options = setup_gpu(tf, args.use_gpu, args.gpu_fraction) ###################################################### # setup default device device = "/cpu:0" global_t = 0 rewards = {'train': {}, 'eval': {}} best_model_reward = -(sys.maxsize) if args.load_pretrained_model: class_rewards = {'class_eval': {}} # setup logging info for analysis, see Section 4.2 of the paper sil_dict = { # count number of SIL updates "sil_ctr":{}, # total number of butter D sampled during SIL "sil_a3c_sampled":{}, # total number of buffer D samples (i.e., generated by A3C workers) used during SIL (i.e., passed max op) "sil_a3c_used":{}, # the return of used samples for buffer D "sil_a3c_used_return":{}, # total number of buffer R sampled during SIL "sil_rollout_sampled":{}, # total number of buffer R samples (i.e., generated by refresher worker) used during SIL (i.e., passed max op) "sil_rollout_used":{}, # the return of used samples for buffer R "sil_rollout_used_return":{}, # number of old samples still used (even after refreshing) "sil_old_used":{} } sil_ctr, sil_a3c_sampled, sil_a3c_used, sil_a3c_used_return = 0, 0, 0, 0 sil_rollout_sampled, sil_rollout_used, sil_rollout_used_return = 0, 0, 0 sil_old_used = 0 rollout_dict = { # total number of rollout performed "rollout_ctr": {}, # total number of successful rollout (i.e., Gnew > G) "rollout_added_ctr":{}, # the return of Gnew "rollout_new_return":{}, # the return of G "rollout_old_return":{} } rollout_ctr, rollout_added_ctr = 0, 0 rollout_new_return = 0 # this records the total, avg = this / rollout_added_ctr rollout_old_return = 0 # this records the total, avg = this / rollout_added_ctr # setup file names reward_fname = folder / '{}-a3c-rewards.pkl'.format(GYM_ENV_NAME) sil_fname = folder / '{}-a3c-dict-sil.pkl'.format(GYM_ENV_NAME) rollout_fname = folder / '{}-a3c-dict-rollout.pkl'.format(GYM_ENV_NAME) if args.load_pretrained_model: class_reward_fname = folder / '{}-class-rewards.pkl'.format(GYM_ENV_NAME) sharedmem_fname = folder / '{}-sharedmem.pkl'.format(GYM_ENV_NAME) sharedmem_params_fname = folder / '{}-sharedmem-params.pkl'.format(GYM_ENV_NAME) sharedmem_trees_fname = folder / '{}-sharedmem-trees.pkl'.format(GYM_ENV_NAME) rolloutmem_fname = folder / '{}-rolloutmem.pkl'.format(GYM_ENV_NAME) rolloutmem_params_fname = folder / '{}-rolloutmem-params.pkl'.format(GYM_ENV_NAME) rolloutmem_trees_fname = folder / '{}-rolloutmem-trees.pkl'.format(GYM_ENV_NAME) # for removing older ckpt, save mem space prev_ckpt_t = -1 stop_req = False game_state = GameState(env_id=args.gym_env) action_size = game_state.env.action_space.n game_state.close() del game_state.env del game_state input_shape = (args.input_shape, args.input_shape, 4) ####################################################### # setup global A3C GameACFFNetwork.use_mnih_2015 = args.use_mnih_2015 global_network = GameACFFNetwork( action_size, -1, device, padding=args.padding, in_shape=input_shape) logger.info('A3C Initial Learning Rate={}'.format(args.initial_learn_rate)) # setup pretrained model global_pretrained_model = None local_pretrained_model = None pretrain_graph = None # if use pretrained model to refresh # then must load pretrained model # otherwise, don't load model if args.use_lider and args.nstep_bc > 0: 
assert args.load_pretrained_model, "refreshing with other policies, must load a pre-trained model (TA or BC)" else: assert not args.load_pretrained_model, "refreshing with the current policy, don't load pre-trained models" if args.load_pretrained_model: pretrain_graph, global_pretrained_model = setup_pretrained_model(tf, args, action_size, input_shape, device="/gpu:0" if args.use_gpu else device) assert global_pretrained_model is not None assert pretrain_graph is not None time.sleep(2.0) # setup experience memory shared_memory = None # => this is BufferD rollout_buffer = None # => this is BufferR if args.use_sil: shared_memory = SILReplayMemory( action_size, max_len=args.memory_length, gamma=args.gamma, clip=False if args.unclipped_reward else True, height=input_shape[0], width=input_shape[1], phi_length=input_shape[2], priority=args.priority_memory, reward_constant=args.reward_constant) if args.use_lider and not args.onebuffer: rollout_buffer = SILReplayMemory( action_size, max_len=args.memory_length, gamma=args.gamma, clip=False if args.unclipped_reward else True, height=input_shape[0], width=input_shape[1], phi_length=input_shape[2], priority=args.priority_memory, reward_constant=args.reward_constant) # log memory information shared_memory.log() if args.use_lider and not args.onebuffer: rollout_buffer.log() ############## Setup Thread Workers BEGIN ################ # 17 total number of threads for all experiments assert args.parallel_size ==17, "use 17 workers for all experiments" startIndex = 0 all_workers = [] # a3c and sil learning rate and optimizer learning_rate_input = tf.placeholder(tf.float32, shape=(), name="opt_lr") grad_applier = tf.train.RMSPropOptimizer( learning_rate=learning_rate_input, decay=args.rmsp_alpha, epsilon=args.rmsp_epsilon) setup_common_worker(CommonWorker, args, action_size) # setup SIL worker sil_worker = None if args.use_sil: _device = "/gpu:0" if args.use_gpu else device sil_network = GameACFFNetwork( action_size, startIndex, device=_device, padding=args.padding, in_shape=input_shape) sil_worker = SILTrainingThread(startIndex, global_network, sil_network, args.initial_learn_rate, learning_rate_input, grad_applier, device=_device, batch_size=args.batch_size, use_rollout=args.use_lider, one_buffer=args.onebuffer, sampleR=args.sampleR) all_workers.append(sil_worker) startIndex += 1 # setup refresh worker refresh_worker = None if args.use_lider: _device = "/gpu:0" if args.use_gpu else device refresh_network = GameACFFNetwork( action_size, startIndex, device=_device, padding=args.padding, in_shape=input_shape) refresh_local_pretrained_model = None # if refreshing with other polies if args.nstep_bc > 0: refresh_local_pretrained_model = PretrainedModelNetwork( pretrain_graph, action_size, startIndex, padding=args.padding, in_shape=input_shape, sae=False, tied_weights=False, use_denoising=False, noise_factor=0.3, loss_function='mse', use_slv=False, device=_device) refresh_worker = RefreshThread( thread_index=startIndex, action_size=action_size, env_id=args.gym_env, global_a3c=global_network, local_a3c=refresh_network, update_in_rollout=args.update_in_rollout, nstep_bc=args.nstep_bc, global_pretrained_model=global_pretrained_model, local_pretrained_model=refresh_local_pretrained_model, transformed_bellman = args.transformed_bellman, device=_device, entropy_beta=args.entropy_beta, clip_norm=args.grad_norm_clip, grad_applier=grad_applier, initial_learn_rate=args.initial_learn_rate, learning_rate_input=learning_rate_input) all_workers.append(refresh_worker) startIndex 
+= 1 # setup a3c workers setup_a3c_worker(A3CTrainingThread, args, startIndex) for i in range(startIndex, args.parallel_size): local_network = GameACFFNetwork( action_size, i, device="/cpu:0", padding=args.padding, in_shape=input_shape) a3c_worker = A3CTrainingThread( i, global_network, local_network, args.initial_learn_rate, learning_rate_input, grad_applier, device="/cpu:0", no_op_max=30) all_workers.append(a3c_worker) ############## Setup Thread Workers END ################ # setup config for tensorflow config = tf.ConfigProto( gpu_options=gpu_options, log_device_placement=False, allow_soft_placement=True) # prepare sessions sess = tf.Session(config=config) pretrain_sess = None if global_pretrained_model: pretrain_sess = tf.Session(config=config, graph=pretrain_graph) # initial pretrained model if pretrain_sess: assert args.pretrained_model_folder is not None global_pretrained_model.load( pretrain_sess, args.pretrained_model_folder) sess.run(tf.global_variables_initializer()) if global_pretrained_model: initialize_uninitialized(tf, pretrain_sess, global_pretrained_model) if local_pretrained_model: initialize_uninitialized(tf, pretrain_sess, local_pretrained_model) # summary writer for tensorboard summ_file = args.save_to+'log/a3c/{}/'.format(GYM_ENV_NAME) + str(folder)[58:] # str(folder)[12:] summary_writer = tf.summary.FileWriter(summ_file, sess.graph) # init or load checkpoint with saver root_saver = tf.train.Saver(max_to_keep=1) saver = tf.train.Saver(max_to_keep=1) best_saver = tf.train.Saver(max_to_keep=1) checkpoint = tf.train.get_checkpoint_state(str(folder)+'/model_checkpoints') if checkpoint and checkpoint.model_checkpoint_path: root_saver.restore(sess, checkpoint.model_checkpoint_path) logger.info("checkpoint loaded:{}".format( checkpoint.model_checkpoint_path)) tokens = checkpoint.model_checkpoint_path.split("-") # set global step global_t = int(tokens[-1]) logger.info(">>> global step set: {}".format(global_t)) tmp_t = (global_t // args.eval_freq) * args.eval_freq logger.info(">>> tmp_t: {}".format(tmp_t)) # set wall time wall_t = 0. 
# set up reward files best_reward_file = folder / 'model_best/best_model_reward' with best_reward_file.open('r') as f: best_model_reward = float(f.read()) # restore rewards rewards = restore_dict(reward_fname, global_t) logger.info(">>> restored: rewards") # restore loggings sil_dict = restore_dict(sil_fname, global_t) sil_ctr = sil_dict['sil_ctr'][tmp_t] sil_a3c_sampled = sil_dict['sil_a3c_sampled'][tmp_t] sil_a3c_used = sil_dict['sil_a3c_used'][tmp_t] sil_a3c_used_return = sil_dict['sil_a3c_used_return'][tmp_t] sil_rollout_sampled = sil_dict['sil_rollout_sampled'][tmp_t] sil_rollout_used = sil_dict['sil_rollout_used'][tmp_t] sil_rollout_used_return = sil_dict['sil_rollout_used_return'][tmp_t] sil_old_used = sil_dict['sil_old_used'][tmp_t] logger.info(">>> restored: sil_dict") rollout_dict = restore_dict(rollout_fname, global_t) rollout_ctr = rollout_dict['rollout_ctr'][tmp_t] rollout_added_ctr = rollout_dict['rollout_added_ctr'][tmp_t] rollout_new_return = rollout_dict['rollout_new_return'][tmp_t] rollout_old_return = rollout_dict['rollout_old_return'][tmp_t] logger.info(">>> restored: rollout_dict") if args.load_pretrained_model: class_reward_file = folder / '{}-class-rewards.pkl'.format(GYM_ENV_NAME) class_rewards = restore_dict(class_reward_file, global_t) # restore replay buffers (if saved) if args.checkpoint_buffer: # restore buffer D if args.use_sil and args.priority_memory: shared_memory = restore_buffer(sharedmem_fname, shared_memory, global_t) shared_memory = restore_buffer_trees(sharedmem_trees_fname, shared_memory, global_t) shared_memory = restore_buffer_params(sharedmem_params_fname, shared_memory, global_t) logger.info(">>> restored: shared_memory (Buffer D)") shared_memory.log() # restore buffer R if args.use_lider and not args.onebuffer: rollout_buffer = restore_buffer(rolloutmem_fname, rollout_buffer, global_t) rollout_buffer = restore_buffer_trees(rolloutmem_trees_fname, rollout_buffer, global_t) rollout_buffer = restore_buffer_params(rolloutmem_params_fname, rollout_buffer, global_t) logger.info(">>> restored: rollout_buffer (Buffer R)") rollout_buffer.log() # if all restores okay, remove old ckpt to save storage space prev_ckpt_t = global_t else: logger.warning("Could not find old checkpoint") wall_t = 0.0 prepare_dir(folder, empty=True) prepare_dir(folder / 'model_checkpoints', empty=True) prepare_dir(folder / 'model_best', empty=True) prepare_dir(folder / 'frames', empty=True) lock = threading.Lock() # next saving global_t def next_t(current_t, freq): return np.ceil((current_t + 0.00001) / freq) * freq next_global_t = next_t(global_t, args.eval_freq) next_save_t = next_t( global_t, args.eval_freq*args.checkpoint_freq) step_t = 0 def train_function(parallel_idx, th_ctr, ep_queue, net_updates): nonlocal global_t, step_t, rewards, class_rewards, lock, \ next_save_t, next_global_t, prev_ckpt_t nonlocal shared_memory, rollout_buffer nonlocal sil_dict, sil_ctr, sil_a3c_sampled, sil_a3c_used, sil_a3c_used_return, \ sil_rollout_sampled, sil_rollout_used, sil_rollout_used_return, \ sil_old_used nonlocal rollout_dict, rollout_ctr, rollout_added_ctr, \ rollout_new_return, rollout_old_return parallel_worker = all_workers[parallel_idx] parallel_worker.set_summary_writer(summary_writer) with lock: # Evaluate model before training if not stop_req and global_t == 0 and step_t == 0: rewards['eval'][step_t] = parallel_worker.testing( sess, args.eval_max_steps, global_t, folder, worker=all_workers[-1]) # testing pretrained TA or BC in game if args.load_pretrained_model: assert 
pretrain_sess is not None assert global_pretrained_model is not None class_rewards['class_eval'][step_t] = \ parallel_worker.test_loaded_classifier(global_t=global_t, max_eps=50, # testing 50 episodes sess=pretrain_sess, worker=all_workers[-1], model=global_pretrained_model) # log pretrained model performance class_eval_file = pathlib.Path(args.pretrained_model_folder[:21]+\ str(GAME_NAME)+"/"+str(GAME_NAME)+'-model-eval.txt') class_std = np.std(class_rewards['class_eval'][step_t][-1]) class_mean = np.mean(class_rewards['class_eval'][step_t][-1]) with class_eval_file.open('w') as f: f.write("class_mean: \n" + str(class_mean) + "\n") f.write("class_std: \n" + str(class_std) + "\n") f.write("class_rewards: \n" + str(class_rewards['class_eval'][step_t][-1]) + "\n") checkpt_file = folder / 'model_checkpoints' checkpt_file /= '{}_checkpoint'.format(GYM_ENV_NAME) saver.save(sess, str(checkpt_file), global_step=global_t) save_best_model(rewards['eval'][global_t][0]) # saving worker info to dicts for analysis sil_dict['sil_ctr'][step_t] = sil_ctr sil_dict['sil_a3c_sampled'][step_t] = sil_a3c_sampled sil_dict['sil_a3c_used'][step_t] = sil_a3c_used sil_dict['sil_a3c_used_return'][step_t] = sil_a3c_used_return sil_dict['sil_rollout_sampled'][step_t] = sil_rollout_sampled sil_dict['sil_rollout_used'][step_t] = sil_rollout_used sil_dict['sil_rollout_used_return'][step_t] = sil_rollout_used_return sil_dict['sil_old_used'][step_t] = sil_old_used rollout_dict['rollout_ctr'][step_t] = rollout_ctr rollout_dict['rollout_added_ctr'][step_t] = rollout_added_ctr rollout_dict['rollout_new_return'][step_t] = rollout_new_return rollout_dict['rollout_old_return'][step_t] = rollout_old_return # dump pickle dump_pickle([rewards, sil_dict, rollout_dict], [reward_fname, sil_fname, rollout_fname], global_t) if args.load_pretrained_model: dump_pickle([class_rewards], [class_reward_fname], global_t) logger.info('Dump pickle at step {}'.format(global_t)) # save replay buffer (only works under priority mem) if args.checkpoint_buffer: if shared_memory is not None and args.priority_memory: params = [shared_memory.buff._next_idx, shared_memory.buff._max_priority] trees = [shared_memory.buff._it_sum._value, shared_memory.buff._it_min._value] dump_pickle([shared_memory.buff._storage, params, trees], [sharedmem_fname, sharedmem_params_fname, sharedmem_trees_fname], global_t) logger.info('Saving shared_memory') if rollout_buffer is not None and args.priority_memory: params = [rollout_buffer.buff._next_idx, rollout_buffer.buff._max_priority] trees = [rollout_buffer.buff._it_sum._value, rollout_buffer.buff._it_min._value] dump_pickle([rollout_buffer.buff._storage, params, trees], [rolloutmem_fname, rolloutmem_params_fname, rolloutmem_trees_fname], global_t) logger.info('Saving rollout_buffer') prev_ckpt_t = global_t step_t = 1 # set start_time start_time = time.time() - wall_t parallel_worker.set_start_time(start_time) if parallel_worker.is_sil_thread: sil_interval = 0 # bigger number => slower SIL updates m_repeat = 4 min_mem = args.batch_size * m_repeat sil_train_flag = len(shared_memory) >= min_mem while True: if stop_req: return if global_t >= (args.max_time_step * args.max_time_step_fraction): return if parallel_worker.is_sil_thread: # before sil starts, init local count local_sil_ctr = 0 local_sil_a3c_sampled, local_sil_a3c_used, local_sil_a3c_used_return = 0, 0, 0 local_sil_rollout_sampled, local_sil_rollout_used, local_sil_rollout_used_return = 0, 0, 0 local_sil_old_used = 0 if net_updates.qsize() >= sil_interval \ and 
len(shared_memory) >= min_mem: sil_train_flag = True if sil_train_flag: sil_train_flag = False th_ctr.get() train_out = parallel_worker.sil_train( sess, global_t, shared_memory, m_repeat, rollout_buffer=rollout_buffer) local_sil_ctr, local_sil_a3c_sampled, local_sil_a3c_used, \ local_sil_a3c_used_return, \ local_sil_rollout_sampled, local_sil_rollout_used, \ local_sil_rollout_used_return, \ local_sil_old_used = train_out th_ctr.put(1) with net_updates.mutex: net_updates.queue.clear() if args.use_lider: parallel_worker.record_sil(sil_ctr=sil_ctr, total_used=(sil_a3c_used + sil_rollout_used), num_a3c_used=sil_a3c_used, a3c_used_return=sil_a3c_used_return/(sil_a3c_used+1),#add one in case divide by zero rollout_used=sil_rollout_used, rollout_used_return=sil_rollout_used_return/(sil_rollout_used+1), old_used=sil_old_used, global_t=global_t) if sil_ctr % 200 == 0 and sil_ctr > 0: rollout_buffsize = 0 if not args.onebuffer: rollout_buffsize = len(rollout_buffer) log_data = (sil_ctr, len(shared_memory), rollout_buffsize, sil_a3c_used+sil_rollout_used, args.batch_size*sil_ctr, sil_a3c_used, sil_a3c_used_return/(sil_a3c_used+1), sil_rollout_used, sil_rollout_used_return/(sil_rollout_used+1), sil_old_used) logger.info("SIL: sil_ctr={0:}" " sil_memory_size={1:}" " rollout_buffer_size={2:}" " total_sample_used={3:}/{4:}" " a3c_used={5:}" " a3c_used_return_avg={6:.2f}" " rollout_used={7:}" " rollout_used_return_avg={8:.2f}" " old_used={9:}".format(*log_data)) else: parallel_worker.record_sil(sil_ctr=sil_ctr, total_used=(sil_a3c_used + sil_rollout_used), num_a3c_used=sil_a3c_used, rollout_used=sil_rollout_used, global_t=global_t) if sil_ctr % 200 == 0 and sil_ctr > 0: log_data = (sil_ctr, sil_a3c_used+sil_rollout_used, args.batch_size*sil_ctr, sil_a3c_used, len(shared_memory)) logger.info("SIL: sil_ctr={0:}" " total_sample_used={1:}/{2:}" " a3c_used={3:}" " sil_memory_size={4:}".format(*log_data)) # Adding episodes to SIL memory is centralize to ensure # sampling and updating of priorities does not become a problem # since we add new episodes to SIL at once and during # SIL training it is guaranteed that SIL memory is untouched. 
max = args.parallel_size while not ep_queue.empty(): data = ep_queue.get() parallel_worker.episode.set_data(*data) shared_memory.extend(parallel_worker.episode) parallel_worker.episode.reset() max -= 1 if max <= 0: # This ensures that SIL has a chance to train break diff_global_t = 0 # centralized rollout counting local_rollout_ctr, local_rollout_added_ctr = 0, 0 local_rollout_new_return, local_rollout_old_return = 0, 0 elif parallel_worker.is_refresh_thread: # before refresh starts, init local count diff_global_t = 0 local_rollout_ctr, local_rollout_added_ctr = 0, 0 local_rollout_new_return, local_rollout_old_return = 0, 0 if len(shared_memory) >= 1: th_ctr.get() # randomly sample a state from buffer D sample = shared_memory.sample_one_random() # after sample, flip refreshed to True # TODO: fix this so that only *succesful* refresh is flipped to True # currently counting *all* refresh as True assert sample[-1] == True train_out = parallel_worker.rollout(sess, folder, pretrain_sess, global_t, sample, args.addall, args.max_ep_step, args.nstep_bc, args.update_in_rollout) diff_global_t, episode_end, part_end, local_rollout_ctr, \ local_rollout_added_ctr, add, local_rollout_new_return, \ local_rollout_old_return = train_out th_ctr.put(1) if rollout_ctr % 20 == 0 and rollout_ctr > 0: log_msg = "ROLLOUT: rollout_ctr={} added_rollout_ct={} worker={}".format( rollout_ctr, rollout_added_ctr, parallel_worker.thread_idx) logger.info(log_msg) logger.info("ROLLOUT Gnew: {}, G: {}".format(local_rollout_new_return, local_rollout_old_return)) # should always part_end, i.e., end of episode # and only add if new return is better (if not LiDER-AddAll) if part_end and add: if not args.onebuffer: # directly put into Buffer R rollout_buffer.extend(parallel_worker.episode) else: # Buffer D add sample is centralized when OneBuffer ep_queue.put(parallel_worker.episode.get_data()) parallel_worker.episode.reset() # centralized SIL counting local_sil_ctr = 0 local_sil_a3c_sampled, local_sil_a3c_used, local_sil_a3c_used_return = 0, 0, 0 local_sil_rollout_sampled, local_sil_rollout_used, local_sil_rollout_used_return = 0, 0, 0 local_sil_old_used = 0 # a3c training thread worker else: th_ctr.get() train_out = parallel_worker.train(sess, global_t, rewards) diff_global_t, episode_end, part_end = train_out th_ctr.put(1) if args.use_sil: net_updates.put(1) if part_end: ep_queue.put(parallel_worker.episode.get_data()) parallel_worker.episode.reset() # centralized SIL counting local_sil_ctr = 0 local_sil_a3c_sampled, local_sil_a3c_used, local_sil_a3c_used_return = 0, 0, 0 local_sil_rollout_sampled, local_sil_rollout_used, local_sil_rollout_used_return = 0, 0, 0 local_sil_old_used = 0 # centralized rollout counting local_rollout_ctr, local_rollout_added_ctr = 0, 0 local_rollout_new_return, local_rollout_old_return = 0, 0 # ensure only one thread is updating global_t at a time with lock: global_t += diff_global_t # centralize increasing count for SIL and Rollout sil_ctr += local_sil_ctr sil_a3c_sampled += local_sil_a3c_sampled sil_a3c_used += local_sil_a3c_used sil_a3c_used_return += local_sil_a3c_used_return sil_rollout_sampled += local_sil_rollout_sampled sil_rollout_used += local_sil_rollout_used sil_rollout_used_return += local_sil_rollout_used_return sil_old_used += local_sil_old_used rollout_ctr += local_rollout_ctr rollout_added_ctr += local_rollout_added_ctr rollout_new_return += local_rollout_new_return rollout_old_return += local_rollout_old_return # if during a thread's update, global_t has reached a evaluation 
interval if global_t > next_global_t: next_global_t = next_t(global_t, args.eval_freq) step_t = int(next_global_t - args.eval_freq) # wait for all threads to be done before testing while not stop_req and th_ctr.qsize() < len(all_workers): time.sleep(0.001) step_t = int(next_global_t - args.eval_freq) # Evaluate for 125,000 steps rewards['eval'][step_t] = parallel_worker.testing( sess, args.eval_max_steps, step_t, folder, worker=all_workers[-1]) save_best_model(rewards['eval'][step_t][0]) last_reward = rewards['eval'][step_t][0] # saving worker info to dicts # SIL sil_dict['sil_ctr'][step_t] = sil_ctr sil_dict['sil_a3c_sampled'][step_t] = sil_a3c_sampled sil_dict['sil_a3c_used'][step_t] = sil_a3c_used sil_dict['sil_a3c_used_return'][step_t] = sil_a3c_used_return sil_dict['sil_rollout_sampled'][step_t] = sil_rollout_sampled sil_dict['sil_rollout_used'][step_t] = sil_rollout_used sil_dict['sil_rollout_used_return'][step_t] = sil_rollout_used_return sil_dict['sil_old_used'][step_t] = sil_old_used # ROLLOUT rollout_dict['rollout_ctr'][step_t] = rollout_ctr rollout_dict['rollout_added_ctr'][step_t] = rollout_added_ctr rollout_dict['rollout_new_return'][step_t] = rollout_new_return rollout_dict['rollout_old_return'][step_t] = rollout_old_return # save ckpt after done with eval if global_t > next_save_t: next_save_t = next_t(global_t, args.eval_freq*args.checkpoint_freq) # dump pickle dump_pickle([rewards, sil_dict, rollout_dict], [reward_fname, sil_fname, rollout_fname], global_t) if args.load_pretrained_model: dump_pickle([class_rewards], [class_reward_fname], global_t) logger.info('Dump pickle at step {}'.format(global_t)) # save replay buffer (only works for priority mem for now) if args.checkpoint_buffer: if shared_memory is not None and args.priority_memory: params = [shared_memory.buff._next_idx, shared_memory.buff._max_priority] trees = [shared_memory.buff._it_sum._value, shared_memory.buff._it_min._value] dump_pickle([shared_memory.buff._storage, params, trees], [sharedmem_fname, sharedmem_params_fname, sharedmem_trees_fname], global_t) logger.info('Saved shared_memory') if rollout_buffer is not None and args.priority_memory: params = [rollout_buffer.buff._next_idx, rollout_buffer.buff._max_priority] trees = [rollout_buffer.buff._it_sum._value, rollout_buffer.buff._it_min._value] dump_pickle([rollout_buffer.buff._storage, params, trees], [rolloutmem_fname, rolloutmem_params_fname, rolloutmem_trees_fname], global_t) logger.info('Saved rollout_buffer') # save a3c after saving buffer -- in case saving buffer OOM # so that at least we can revert back to the previous ckpt checkpt_file = folder / 'model_checkpoints' checkpt_file /= '{}_checkpoint'.format(GYM_ENV_NAME) saver.save(sess, str(checkpt_file), global_step=global_t, write_meta_graph=False) logger.info('Saved model ckpt') # if everything saves okay, clean up previous ckpt to save space remove_pickle([reward_fname, sil_fname, rollout_fname], prev_ckpt_t) if args.load_pretrained_model: remove_pickle([class_reward_fname], prev_ckpt_t) remove_pickle([sharedmem_fname, sharedmem_params_fname, sharedmem_trees_fname], prev_ckpt_t) if rollout_buffer is not None and args.priority_memory: remove_pickle([rolloutmem_fname, rolloutmem_params_fname, rolloutmem_trees_fname], prev_ckpt_t) logger.info('Removed ckpt from step {}'.format(prev_ckpt_t)) prev_ckpt_t = global_t def signal_handler(signal, frame): nonlocal stop_req logger.info('You pressed Ctrl+C!') stop_req = True if stop_req and global_t == 0: sys.exit(1) def save_best_model(test_reward): 
nonlocal best_model_reward if test_reward > best_model_reward: best_model_reward = test_reward best_reward_file = folder / 'model_best/best_model_reward' with best_reward_file.open('w') as f: f.write(str(best_model_reward)) best_checkpt_file = folder / 'model_best' best_checkpt_file /= '{}_checkpoint'.format(GYM_ENV_NAME) best_saver.save(sess, str(best_checkpt_file)) train_threads = [] th_ctr = Queue() for i in range(args.parallel_size): th_ctr.put(1) episodes_queue = None net_updates = None if args.use_sil: episodes_queue = Queue() net_updates = Queue() for i in range(args.parallel_size): worker_thread = Thread( target=train_function, args=(i, th_ctr, episodes_queue, net_updates,)) train_threads.append(worker_thread) signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) # set start time start_time = time.time() - wall_t for t in train_threads: t.start() print('Press Ctrl+C to stop') for t in train_threads: t.join() logger.info('Now saving data. Please wait') # write wall time wall_t = time.time() - start_time wall_t_fname = folder / 'wall_t.{}'.format(global_t) with wall_t_fname.open('w') as f: f.write(str(wall_t)) # save final model checkpoint_file = str(folder / '{}_checkpoint_a3c'.format(GYM_ENV_NAME)) root_saver.save(sess, checkpoint_file, global_step=global_t) dump_final_pickle([rewards, sil_dict, rollout_dict], [reward_fname, sil_fname, rollout_fname]) logger.info('Data saved!') # if everything saves okay & is done training (not because of pressed Ctrl+C), # clean up previous ckpt to save space if global_t >= (args.max_time_step * args.max_time_step_fraction): remove_pickle([reward_fname, sil_fname, rollout_fname], prev_ckpt_t) if args.load_pretrained_model: remove_pickle([class_reward_fname], prev_ckpt_t) remove_pickle([sharedmem_fname, sharedmem_params_fname, sharedmem_trees_fname], prev_ckpt_t) if rollout_buffer is not None and args.priority_memory: remove_pickle([rolloutmem_fname, rolloutmem_params_fname, rolloutmem_trees_fname], prev_ckpt_t) logger.info('Done training, removed ckpt from step {}'.format(prev_ckpt_t)) sess.close() if pretrain_sess: pretrain_sess.close()
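# --- Illustrative sketch (not part of the original module) ---
# The evaluation/checkpoint barrier in train_function() relies on the token
# queue th_ctr: each worker get()s a token before a training step and put()s
# it back when the step finishes, so th_ctr.qsize() == len(all_workers) means
# every worker is idle and it is safe to run testing(). A minimal,
# self-contained sketch of that pattern using only the standard library
# (function and variable names below are illustrative, not this module's API):

import threading
import time
from queue import Queue


def _token_gate_sketch(num_workers=4, steps=5):
    """Wait for all workers to go idle before a synchronized action."""
    th_ctr = Queue()
    for _ in range(num_workers):
        th_ctr.put(1)

    def worker():
        for _ in range(steps):
            th_ctr.get()        # token taken: this worker is busy training
            time.sleep(0.01)    # stand-in for one training step
            th_ctr.put(1)       # token returned: this worker is idle

    threads = [threading.Thread(target=worker) for _ in range(num_workers)]
    for t in threads:
        t.start()
    # the evaluator proceeds only once every worker has returned its token
    while th_ctr.qsize() < num_workers:
        time.sleep(0.001)
    for t in threads:
        t.join()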
class SILTrainingThread(CommonWorker): """Asynchronous Actor-Critic Training Thread Class.""" entropy_beta = 0.01 gamma = 0.99 finetune_upper_layers_only = False transformed_bellman = False clip_norm = 0.5 use_grad_cam = False def __init__(self, thread_index, global_net, local_net, initial_learning_rate, learning_rate_input, grad_applier, device=None, batch_size=None, use_rollout=False, one_buffer=False, sampleR=False): """Initialize A3CTrainingThread class.""" assert self.action_size != -1 self.is_sil_thread = True self.thread_idx = thread_index self.initial_learning_rate = initial_learning_rate self.learning_rate_input = learning_rate_input self.local_net = local_net self.batch_size = batch_size self.use_rollout = use_rollout self.one_buffer = one_buffer self.sampleR = sampleR logger.info("===SIL thread_index: {}===".format(self.thread_idx)) logger.info("device: {}".format(device)) logger.info("action_size: {}".format(self.action_size)) logger.info("entropy_beta: {}".format(self.entropy_beta)) logger.info("gamma: {}".format(self.gamma)) logger.info("reward_type: {}".format(self.reward_type)) logger.info("transformed_bellman: {}".format( colored(self.transformed_bellman, "green" if self.transformed_bellman else "red"))) logger.info("clip_norm: {}".format(self.clip_norm)) logger.info("use_grad_cam: {}".format( colored(self.use_grad_cam, "green" if self.use_grad_cam else "red"))) reward_clipped = True if self.reward_type == 'CLIP' else False local_vars = self.local_net.get_vars with tf.device(device): critic_lr = 0.1 entropy_beta = 0 w_loss = 1.0 logger.info("sil batch_size: {}".format(self.batch_size)) logger.info("sil w_loss: {}".format(w_loss)) logger.info("sil critic_lr: {}".format(critic_lr)) logger.info("sil entropy_beta: {}".format(entropy_beta)) self.local_net.prepare_sil_loss(entropy_beta=entropy_beta, w_loss=w_loss, critic_lr=critic_lr) var_refs = [v._ref() for v in local_vars()] self.sil_gradients = tf.gradients(self.local_net.total_loss_sil, var_refs) global_vars = global_net.get_vars with tf.device(device): if self.clip_norm is not None: self.sil_gradients, grad_norm = tf.clip_by_global_norm( self.sil_gradients, self.clip_norm) sil_gradients_global = list(zip(self.sil_gradients, global_vars())) sil_gradients_local = list(zip(self.sil_gradients, local_vars())) self.sil_apply_gradients = grad_applier.apply_gradients( sil_gradients_global) self.sil_apply_gradients_local = grad_applier.apply_gradients( sil_gradients_local) self.sync = self.local_net.sync_from(global_net) self.episode = SILReplayMemory(self.action_size, max_len=None, gamma=self.gamma, clip=reward_clipped, height=self.local_net.in_shape[0], width=self.local_net.in_shape[1], phi_length=self.local_net.in_shape[2], reward_constant=self.reward_constant) # temp_buffer for mixing and re-sample (brown arrow in Figure 1) # initial only when needed (A3CTBSIL & LiDER-OneBuffer does not need temp_buffer) self.temp_buffer = None if (self.use_rollout) and (not self.one_buffer): self.temp_buffer = SILReplayMemory( self.action_size, max_len=self.batch_size * 2, gamma=self.gamma, clip=reward_clipped, height=self.local_net.in_shape[0], width=self.local_net.in_shape[1], phi_length=self.local_net.in_shape[2], priority=True, reward_constant=self.reward_constant) def record_sil(self, sil_ctr=0, total_used=0, num_a3c_used=0, a3c_used_return=0, rollout_used=0, rollout_used_return=0, old_used=0, global_t=0, mode='SIL'): """Record SIL.""" summary = tf.Summary() summary.value.add(tag='{}/sil_ctr'.format(mode), simple_value=float(sil_ctr)) 
summary.value.add(tag='{}/total_num_sample_used'.format(mode), simple_value=float(total_used)) summary.value.add(tag='{}/num_a3c_used'.format(mode), simple_value=float(num_a3c_used)) summary.value.add(tag='{}/a3c_used_return'.format(mode), simple_value=float(a3c_used_return)) summary.value.add(tag='{}/num_rollout_used'.format(mode), simple_value=float(rollout_used)) summary.value.add(tag='{}/rollout_used_return'.format(mode), simple_value=float(rollout_used_return)) summary.value.add(tag='{}/num_old_used'.format(mode), simple_value=float(old_used)) self.writer.add_summary(summary, global_t) self.writer.flush() def sil_train(self, sess, global_t, sil_memory, m, rollout_buffer=None): """Self-imitation learning process.""" # copy weights from shared to local sess.run(self.sync) cur_learning_rate = self._anneal_learning_rate( global_t, self.initial_learning_rate) local_sil_ctr = 0 local_sil_a3c_sampled, local_sil_a3c_used, local_sil_a3c_used_return = 0, 0, 0 local_sil_rollout_sampled, local_sil_rollout_used, local_sil_rollout_used_return = 0, 0, 0 local_sil_old_used = 0 total_used = 0 num_a3c_used = 0 num_rollout_used = 0 num_rollout_sampled = 0 num_old_used = 0 for _ in range(m): d_batch_size, r_batch_size = 0, 0 # A3CTBSIL if not self.use_rollout: d_batch_size = self.batch_size # or LiDER-OneBuffer elif self.use_rollout and self.one_buffer: d_batch_size = self.batch_size # or LiDER else: assert rollout_buffer is not None assert self.temp_buffer is not None self.temp_buffer.reset() if not self.sampleR: # otherwise, LiDER-SampleR d_batch_size = self.batch_size r_batch_size = self.batch_size batch_state, batch_action, batch_returns, batch_fullstate, \ batch_rollout, batch_refresh, weights = ([] for i in range(7)) # sample from buffer D if d_batch_size > 0 and len(sil_memory) > d_batch_size: d_sample = sil_memory.sample(d_batch_size, beta=0.4) d_index_list, d_batch, d_weights = d_sample d_batch_state, d_action, d_batch_returns, \ d_batch_fullstate, d_batch_rollout, d_batch_refresh = d_batch # update priority of sampled experiences self.update_priorities_once(sess, sil_memory, d_index_list, d_batch_state, d_action, d_batch_returns) if self.temp_buffer is not None: # when LiDER d_batch_action = convert_onehot_to_a(d_action) self.temp_buffer.extend_one_priority( d_batch_state, d_batch_fullstate, d_batch_action, d_batch_returns, d_batch_rollout, d_batch_refresh) else: # when A3CTBSIL or LiDER-OneBuffer batch_state.extend(d_batch_state) batch_action.extend(d_action) batch_returns.extend(d_batch_returns) batch_fullstate.extend(d_batch_fullstate) batch_rollout.extend(d_batch_rollout) batch_refresh.extend(d_batch_refresh) weights.extend(d_weights) # sample from buffer R if r_batch_size > 0 and len(rollout_buffer) > r_batch_size: r_sample = rollout_buffer.sample(r_batch_size, beta=0.4) r_index_list, r_batch, r_weights = r_sample r_batch_state, r_action, r_batch_returns, \ r_batch_fullstate, r_batch_rollout, r_batch_refresh = r_batch # update priority of sampled experiences self.update_priorities_once(sess, rollout_buffer, r_index_list, r_batch_state, r_action, r_batch_returns) if self.temp_buffer is not None: # when LiDER r_batch_action = convert_onehot_to_a(r_action) self.temp_buffer.extend_one_priority( r_batch_state, r_batch_fullstate, r_batch_action, r_batch_returns, r_batch_rollout, r_batch_refresh) else: # when A3CTBSIL or LiDER-OneBuffer batch_state.extend(r_batch_state) batch_action.extend(r_action) batch_returns.extend(r_batch_returns) batch_fullstate.extend(r_batch_fullstate)
batch_rollout.extend(r_batch_rollout) batch_refresh.extend(r_batch_refresh) weights.extend(r_weights) # LiDER only: pick 32 out of mixed # (at the beginning the 32 could all from buffer D since rollout has no data yet) # make sure the temp_buffer has been filled with at least size of one batch before sampling if self.temp_buffer is not None and len( self.temp_buffer) >= self.batch_size: sample = self.temp_buffer.sample(self.batch_size, beta=0.4) index_list, batch, weights = sample # overwrite the initial empty list batch_state, batch_action, batch_returns, \ batch_fullstate, batch_rollout, batch_refresh = batch if self.use_rollout: num_rollout_sampled += np.sum(batch_rollout) # sil policy update (if one full batch is sampled) if len(batch_state) == self.batch_size: feed_dict = { self.local_net.s: batch_state, self.local_net.a_sil: batch_action, self.local_net.returns: batch_returns, self.local_net.weights: weights, self.learning_rate_input: cur_learning_rate, } fetch = [ self.local_net.clipped_advs, self.local_net.advs, self.sil_apply_gradients, self.sil_apply_gradients_local, ] adv_clip, adv, _, _ = sess.run(fetch, feed_dict=feed_dict) pos_idx = [i for (i, num) in enumerate(adv) if num > 0] neg_idx = [i for (i, num) in enumerate(adv) if num <= 0] # log number of samples used for SIL updates total_used += len(pos_idx) num_rollout_used += np.sum(np.take(batch_rollout, pos_idx)) num_a3c_used += (len(pos_idx) - np.sum(np.take(batch_rollout, pos_idx))) num_old_used += np.sum(np.take(batch_refresh, pos_idx)) # return for used rollout samples rollout_idx = [ i for (i, num) in enumerate(batch_rollout) if num > 0 ] pos_rollout_idx = np.intersect1d(rollout_idx, pos_idx) if len(pos_rollout_idx) > 0: local_sil_rollout_used_return += np.sum( np.take(adv, pos_rollout_idx)) # return for used a3c samples a3c_idx = [ i for (i, num) in enumerate(batch_rollout) if num <= 0 ] pos_a3c_idx = np.intersect1d(a3c_idx, pos_idx) if len(pos_a3c_idx) > 0: local_sil_a3c_used_return += np.sum( np.take(batch_returns, pos_a3c_idx)) local_sil_ctr += 1 local_sil_a3c_sampled += (self.batch_size * m - num_rollout_sampled) local_sil_rollout_sampled += num_rollout_sampled local_sil_a3c_used += num_a3c_used local_sil_rollout_used += num_rollout_used local_sil_old_used += num_old_used return local_sil_ctr, local_sil_a3c_sampled, local_sil_a3c_used, \ local_sil_a3c_used_return, \ local_sil_rollout_sampled, local_sil_rollout_used, \ local_sil_rollout_used_return, \ local_sil_old_used def update_priorities_once(self, sess, memory, index_list, batch_state, batch_action, batch_returns): """Self-imitation update priorities once.""" # copy weights from shared to local sess.run(self.sync) feed_dict = { self.local_net.s: batch_state, self.local_net.a_sil: batch_action, self.local_net.returns: batch_returns, } fetch = self.local_net.clipped_advs adv_clip = sess.run(fetch, feed_dict=feed_dict) memory.set_weights(index_list, adv_clip)
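# --- Illustrative sketch (not part of the original module) ---
# sil_train() above only lets samples that "pass the max operator" drive the
# update: the SIL loss weights each sample by (R - V)_+, so a batch element
# counts as "used" only when its return exceeds the current value estimate.
# A small numpy sketch of that filtering and of the buffer-D / buffer-R usage
# counters it feeds (names and values below are illustrative):

import numpy as np


def _sil_max_op_sketch(batch_returns, batch_values, batch_rollout):
    """Split a SIL batch into used samples and count buffer-D vs buffer-R usage."""
    adv = np.asarray(batch_returns) - np.asarray(batch_values)    # R - V
    pos_idx = np.where(adv > 0)[0]                                # passed the max operator
    rollout_used = int(np.sum(np.take(batch_rollout, pos_idx)))   # came from buffer R
    a3c_used = len(pos_idx) - rollout_used                        # came from buffer D
    return pos_idx, a3c_used, rollout_used

# Example: returns [5, 1, 3], values [2, 4, 1], rollout flags [1, 0, 0]
# -> pos_idx = [0, 2], a3c_used = 1, rollout_used = 1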
class A3CTrainingThread(CommonWorker): """Asynchronous Actor-Critic Training Thread Class.""" log_interval = 100 perf_log_interval = 1000 local_t_max = 20 entropy_beta = 0.01 gamma = 0.99 shaping_actions = -1 # -1 all actions, 0 exclude noop transformed_bellman = False clip_norm = 0.5 use_grad_cam = False use_sil = False log_idx = 0 reward_constant = 0 def __init__(self, thread_index, global_net, local_net, initial_learning_rate, learning_rate_input, grad_applier, device=None, no_op_max=30): """Initialize A3CTrainingThread class.""" assert self.action_size != -1 self.is_sil_thread = False self.is_refresh_thread = False self.thread_idx = thread_index self.learning_rate_input = learning_rate_input self.local_net = local_net self.no_op_max = no_op_max self.override_num_noops = 0 if self.no_op_max == 0 else None logger.info("===A3C thread_index: {}===".format(self.thread_idx)) logger.info("device: {}".format(device)) logger.info("use_sil: {}".format( colored(self.use_sil, "green" if self.use_sil else "red"))) logger.info("local_t_max: {}".format(self.local_t_max)) logger.info("action_size: {}".format(self.action_size)) logger.info("entropy_beta: {}".format(self.entropy_beta)) logger.info("gamma: {}".format(self.gamma)) logger.info("reward_type: {}".format(self.reward_type)) logger.info("transformed_bellman: {}".format( colored(self.transformed_bellman, "green" if self.transformed_bellman else "red"))) logger.info("clip_norm: {}".format(self.clip_norm)) logger.info("use_grad_cam: {}".format( colored(self.use_grad_cam, "green" if self.use_grad_cam else "red"))) reward_clipped = True if self.reward_type == 'CLIP' else False local_vars = self.local_net.get_vars with tf.device(device): self.local_net.prepare_loss(entropy_beta=self.entropy_beta, critic_lr=0.5) var_refs = [v._ref() for v in local_vars()] self.gradients = tf.gradients(self.local_net.total_loss, var_refs) global_vars = global_net.get_vars with tf.device(device): if self.clip_norm is not None: self.gradients, grad_norm = tf.clip_by_global_norm( self.gradients, self.clip_norm) self.gradients = list(zip(self.gradients, global_vars())) self.apply_gradients = grad_applier.apply_gradients(self.gradients) self.sync = self.local_net.sync_from(global_net) self.game_state = GameState(env_id=self.env_id, display=False, no_op_max=self.no_op_max, human_demo=False, episode_life=True, override_num_noops=self.override_num_noops) self.local_t = 0 self.initial_learning_rate = initial_learning_rate self.episode_reward = 0 self.episode_steps = 0 # variable controlling log output self.prev_local_t = 0 with tf.device(device): if self.use_grad_cam: self.action_meaning = self.game_state.env.unwrapped \ .get_action_meanings() self.local_net.build_grad_cam_grads() if self.use_sil: self.episode = SILReplayMemory( self.action_size, max_len=None, gamma=self.gamma, clip=reward_clipped, height=self.local_net.in_shape[0], width=self.local_net.in_shape[1], phi_length=self.local_net.in_shape[2], reward_constant=self.reward_constant) def train(self, sess, global_t, train_rewards): """Train A3C.""" states = [] fullstates = [] actions = [] rewards = [] values = [] rho = [] terminal_pseudo = False # loss of life terminal_end = False # real terminal # copy weights from shared to local sess.run(self.sync) start_local_t = self.local_t # t_max times loop for i in range(self.local_t_max): state = cv2.resize(self.game_state.s_t, self.local_net.in_shape[:-1], interpolation=cv2.INTER_AREA) fullstate = self.game_state.clone_full_state() pi_, value_, logits_ = 
self.local_net.run_policy_and_value( sess, state) action = self.pick_action(logits_) states.append(state) fullstates.append(fullstate) actions.append(action) values.append(value_) if self.thread_idx == self.log_idx \ and self.local_t % self.log_interval == 0: log_msg1 = "lg={}".format( np.array_str(logits_, precision=4, suppress_small=True)) log_msg2 = "pi={}".format( np.array_str(pi_, precision=4, suppress_small=True)) log_msg3 = "V={:.4f}".format(value_) logger.debug(log_msg1) logger.debug(log_msg2) logger.debug(log_msg3) # process game self.game_state.step(action) # receive game result reward = self.game_state.reward terminal = self.game_state.terminal self.episode_reward += reward if self.use_sil: # save states in episode memory self.episode.add_item(self.game_state.s_t, fullstate, action, reward, terminal) if self.reward_type == 'CLIP': reward = np.sign(reward) rewards.append(reward) self.local_t += 1 self.episode_steps += 1 global_t += 1 # s_t1 -> s_t self.game_state.update() if terminal: terminal_pseudo = True env = self.game_state.env name = 'EpisodicLifeEnv' if get_wrapper_by_name(env, name).was_real_done: # reduce log freq if self.thread_idx == self.log_idx: log_msg = "train: worker={} global_t={} local_t={}".format( self.thread_idx, global_t, self.local_t) score_str = colored( "score={}".format(self.episode_reward), "magenta") steps_str = colored( "steps={}".format(self.episode_steps), "blue") log_msg += " {} {}".format(score_str, steps_str) logger.debug(log_msg) train_rewards['train'][global_t] = (self.episode_reward, self.episode_steps) self.record_summary(score=self.episode_reward, steps=self.episode_steps, episodes=None, global_t=global_t, mode='Train') self.episode_reward = 0 self.episode_steps = 0 terminal_end = True self.game_state.reset(hard_reset=False) break cumsum_reward = 0.0 if not terminal: state = cv2.resize(self.game_state.s_t, self.local_net.in_shape[:-1], interpolation=cv2.INTER_AREA) cumsum_reward = self.local_net.run_value(sess, state) actions.reverse() states.reverse() rewards.reverse() values.reverse() batch_state = [] batch_action = [] batch_adv = [] batch_cumsum_reward = [] # compute and accumulate gradients for (ai, ri, si, vi) in zip(actions, rewards, states, values): if self.transformed_bellman: ri = np.sign(ri) * self.reward_constant + ri cumsum_reward = transform_h(ri + self.gamma * transform_h_inv(cumsum_reward)) else: cumsum_reward = ri + self.gamma * cumsum_reward advantage = cumsum_reward - vi # convert action to one-hot vector a = np.zeros([self.action_size]) a[ai] = 1 batch_state.append(si) batch_action.append(a) batch_adv.append(advantage) batch_cumsum_reward.append(cumsum_reward) cur_learning_rate = self._anneal_learning_rate( global_t, self.initial_learning_rate) feed_dict = { self.local_net.s: batch_state, self.local_net.a: batch_action, self.local_net.advantage: batch_adv, self.local_net.cumulative_reward: batch_cumsum_reward, self.learning_rate_input: cur_learning_rate, } sess.run(self.apply_gradients, feed_dict=feed_dict) t = self.local_t - self.prev_local_t if (self.thread_idx == self.log_idx and t >= self.perf_log_interval): self.prev_local_t += self.perf_log_interval elapsed_time = time.time() - self.start_time steps_per_sec = global_t / elapsed_time logger.info("worker-{}, log_worker-{}".format( self.thread_idx, self.log_idx)) logger.info("Performance : {} STEPS in {:.0f} sec. {:.0f}" " STEPS/sec. 
{:.2f}M STEPS/hour.".format( global_t, elapsed_time, steps_per_sec, steps_per_sec * 3600 / 1000000.)) # return advanced local step size diff_local_t = self.local_t - start_local_t return diff_local_t, terminal_end, terminal_pseudo
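# --- Illustrative sketch (not part of the original module) ---
# When transformed_bellman is enabled, train() (and update_a3c() /
# compute_return_for_state() below) back up returns through transform_h and
# transform_h_inv: cumsum = h(r + gamma * h_inv(cumsum)). Those helpers are
# imported elsewhere in the project; a common choice for h, and presumably
# what is meant here, is the squashing function of Pohlen et al. (2018).
# A hedged, self-contained sketch under that assumption:

import numpy as np


def _transform_h_sketch(z, eps=1e-2):
    """h(z) = sign(z) * (sqrt(|z| + 1) - 1) + eps * z (assumed form)."""
    return np.sign(z) * (np.sqrt(np.abs(z) + 1.0) - 1.0) + eps * z


def _transform_h_inv_sketch(z, eps=1e-2):
    """Inverse of the assumed h; round-trips with _transform_h_sketch."""
    return np.sign(z) * (
        np.square((np.sqrt(1.0 + 4.0 * eps * (np.abs(z) + 1.0 + eps)) - 1.0)
                  / (2.0 * eps)) - 1.0)

# Example: _transform_h_sketch(100.0) ≈ 10.05 and
# _transform_h_inv_sketch(10.05) ≈ 100.0, so large returns are squashed into
# a small numeric range for the critic target while remaining invertible.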
class RefreshThread(CommonWorker): """Rollout Thread Class.""" advice_confidence = 0.8 gamma = 0.99 def __init__(self, thread_index, action_size, env_id, global_a3c, local_a3c, update_in_rollout, nstep_bc, global_pretrained_model, local_pretrained_model, transformed_bellman=False, no_op_max=0, device='/cpu:0', entropy_beta=0.01, clip_norm=None, grad_applier=None, initial_learn_rate=0.007, learning_rate_input=None): """Initialize RolloutThread class.""" self.is_refresh_thread = True self.action_size = action_size self.thread_idx = thread_index self.transformed_bellman = transformed_bellman self.entropy_beta = entropy_beta self.clip_norm = clip_norm self.initial_learning_rate = initial_learn_rate self.learning_rate_input = learning_rate_input self.no_op_max = no_op_max self.override_num_noops = 0 if self.no_op_max == 0 else None logger.info("===REFRESH thread_index: {}===".format(self.thread_idx)) logger.info("device: {}".format(device)) logger.info("action_size: {}".format(self.action_size)) logger.info("reward_type: {}".format(self.reward_type)) logger.info("transformed_bellman: {}".format( colored(self.transformed_bellman, "green" if self.transformed_bellman else "red"))) logger.info("update in rollout: {}".format( colored(update_in_rollout, "green" if update_in_rollout else "red"))) logger.info("N-step BC: {}".format(nstep_bc)) self.reward_clipped = True if self.reward_type == 'CLIP' else False # setup local a3c self.local_a3c = local_a3c self.sync_a3c = self.local_a3c.sync_from(global_a3c) with tf.device(device): local_vars = self.local_a3c.get_vars self.local_a3c.prepare_loss( entropy_beta=self.entropy_beta, critic_lr=0.5) var_refs = [v._ref() for v in local_vars()] self.rollout_gradients = tf.gradients(self.local_a3c.total_loss, var_refs) global_vars = global_a3c.get_vars if self.clip_norm is not None: self.rollout_gradients, grad_norm = tf.clip_by_global_norm( self.rollout_gradients, self.clip_norm) self.rollout_gradients = list(zip(self.rollout_gradients, global_vars())) self.rollout_apply_gradients = grad_applier.apply_gradients(self.rollout_gradients) # setup local pretrained model self.local_pretrained = None if nstep_bc > 0: assert local_pretrained_model is not None assert global_pretrained_model is not None self.local_pretrained = local_pretrained_model self.sync_pretrained = self.local_pretrained.sync_from(global_pretrained_model) # setup env self.rolloutgame = GameState(env_id=env_id, display=False, no_op_max=0, human_demo=False, episode_life=True, override_num_noops=0) self.local_t = 0 self.episode_reward = 0 self.episode_steps = 0 self.action_meaning = self.rolloutgame.env.unwrapped.get_action_meanings() assert self.local_a3c is not None if nstep_bc > 0: assert self.local_pretrained is not None self.episode = SILReplayMemory( self.action_size, max_len=None, gamma=self.gamma, clip=self.reward_clipped, height=self.local_a3c.in_shape[0], width=self.local_a3c.in_shape[1], phi_length=self.local_a3c.in_shape[2], reward_constant=self.reward_constant) def record_rollout(self, score=0, steps=0, old_return=0, new_return=0, global_t=0, rollout_ctr=0, rollout_added_ctr=0, mode='Rollout', confidence=None, episodes=None): """Record rollout summary.""" summary = tf.Summary() summary.value.add(tag='{}/score'.format(mode), simple_value=float(score)) summary.value.add(tag='{}/old_return_from_s'.format(mode), simple_value=float(old_return)) summary.value.add(tag='{}/new_return_from_s'.format(mode), simple_value=float(new_return)) summary.value.add(tag='{}/steps'.format(mode), 
simple_value=float(steps)) summary.value.add(tag='{}/all_rollout_ctr'.format(mode), simple_value=float(rollout_ctr)) summary.value.add(tag='{}/rollout_added_ctr'.format(mode), simple_value=float(rollout_added_ctr)) if confidence is not None: summary.value.add(tag='{}/advice-confidence'.format(mode), simple_value=float(confidence)) if episodes is not None: summary.value.add(tag='{}/episodes'.format(mode), simple_value=float(episodes)) self.writer.add_summary(summary, global_t) self.writer.flush() def compute_return_for_state(self, rewards, terminal): """Compute expected return.""" length = np.shape(rewards)[0] returns = np.empty_like(rewards, dtype=np.float32) if self.reward_clipped: rewards = np.clip(rewards, -1., 1.) else: rewards = np.sign(rewards) * self.reward_constant + rewards for i in reversed(range(length)): if terminal[i]: returns[i] = rewards[i] if self.reward_clipped else transform_h(rewards[i]) else: if self.reward_clipped: returns[i] = rewards[i] + self.gamma * returns[i+1] else: # apply transformed expected return exp_r_t = self.gamma * transform_h_inv(returns[i+1]) returns[i] = transform_h(rewards[i] + exp_r_t) return returns[0] def update_a3c(self, sess, actions, states, rewards, values, global_t): cumsum_reward = 0.0 actions.reverse() states.reverse() rewards.reverse() values.reverse() batch_state = [] batch_action = [] batch_adv = [] batch_cumsum_reward = [] # compute and accumulate gradients for(ai, ri, si, vi) in zip(actions, rewards, states, values): if self.transformed_bellman: ri = np.sign(ri) * self.reward_constant + ri cumsum_reward = transform_h( ri + self.gamma * transform_h_inv(cumsum_reward)) else: cumsum_reward = ri + self.gamma * cumsum_reward advantage = cumsum_reward - vi # convert action to one-hot vector a = np.zeros([self.action_size]) a[ai] = 1 batch_state.append(si) batch_action.append(a) batch_adv.append(advantage) batch_cumsum_reward.append(cumsum_reward) cur_learning_rate = self._anneal_learning_rate(global_t, self.initial_learning_rate ) feed_dict = { self.local_a3c.s: batch_state, self.local_a3c.a: batch_action, self.local_a3c.advantage: batch_adv, self.local_a3c.cumulative_reward: batch_cumsum_reward, self.learning_rate_input: cur_learning_rate, } sess.run(self.rollout_apply_gradients, feed_dict=feed_dict) return batch_adv def rollout(self, a3c_sess, folder, pretrain_sess, global_t, past_state, add_all_rollout, ep_max_steps, nstep_bc, update_in_rollout): """Perform one rollout until terminal.""" a3c_sess.run(self.sync_a3c) if nstep_bc > 0: pretrain_sess.run(self.sync_pretrained) _, fs, old_a, old_return, _, _ = past_state states = [] actions = [] rewards = [] values = [] terminals = [] confidences = [] rollout_ctr, rollout_added_ctr = 0, 0 rollout_new_return, rollout_old_return = 0, 0 terminal_pseudo = False # loss of life terminal_end = False # real terminal add = False self.rolloutgame.reset(hard_reset=True) self.rolloutgame.restore_full_state(fs) # check if restore successful fs_check = self.rolloutgame.clone_full_state() assert fs_check.all() == fs.all() del fs_check start_local_t = self.local_t self.rolloutgame.step(0) # prevent rollout too long, set max_ep_steps to be lower than ALE default # see https://github.com/openai/gym/blob/54f22cf4db2e43063093a1b15d968a57a32b6e90/gym/envs/__init__.py#L635 # but in all games tested, no rollout exceeds ep_max_steps while ep_max_steps > 0: state = cv2.resize(self.rolloutgame.s_t, self.local_a3c.in_shape[:-1], interpolation=cv2.INTER_AREA) fullstate = self.rolloutgame.clone_full_state() if nstep_bc > 0: 
# LiDER-TA or BC model_pi = self.local_pretrained.run_policy(pretrain_sess, state) action, confidence = self.choose_action_with_high_confidence( model_pi, exclude_noop=False) confidences.append(confidence) # not using "confidences" for anything nstep_bc -= 1 else: # LiDER, refresh with current policy pi_, _, logits_ = self.local_a3c.run_policy_and_value(a3c_sess, state) action = self.pick_action(logits_) confidences.append(pi_[action]) value_ = self.local_a3c.run_value(a3c_sess, state) values.append(value_) states.append(state) actions.append(action) self.rolloutgame.step(action) ep_max_steps -= 1 reward = self.rolloutgame.reward terminal = self.rolloutgame.terminal terminals.append(terminal) self.episode_reward += reward self.episode.add_item(self.rolloutgame.s_t, fullstate, action, reward, terminal, from_rollout=True) if self.reward_type == 'CLIP': reward = np.sign(reward) rewards.append(reward) self.local_t += 1 self.episode_steps += 1 global_t += 1 self.rolloutgame.update() if terminal: terminal_pseudo = True env = self.rolloutgame.env name = 'EpisodicLifeEnv' rollout_ctr += 1 terminal_end = get_wrapper_by_name(env, name).was_real_done new_return = self.compute_return_for_state(rewards, terminals) if not add_all_rollout: if new_return > old_return: add = True else: add = True if add: rollout_added_ctr += 1 rollout_new_return += new_return rollout_old_return += old_return # update policy immediate using a good rollout if update_in_rollout: batch_adv = self.update_a3c(a3c_sess, actions, states, rewards, values, global_t) self.episode_reward = 0 self.episode_steps = 0 self.rolloutgame.reset(hard_reset=True) break diff_local_t = self.local_t - start_local_t return diff_local_t, terminal_end, terminal_pseudo, rollout_ctr, \ rollout_added_ctr, add, rollout_new_return, rollout_old_return
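# --- Illustrative sketch (not part of the original module) ---
# rollout() above admits a refreshed trajectory into buffer R only when its
# recomputed return Gnew (compute_return_for_state) beats the stored return G
# of the sampled state (unless LiDER-AddAll). A minimal sketch of that
# backward recursion for the reward-clipped case, assuming the last step is
# terminal as it is when rollout() calls it (names below are illustrative):

import numpy as np


def _return_from_state_sketch(rewards, terminals, gamma=0.99):
    """Discounted return of the first state, computed backwards over a rollout."""
    rewards = np.clip(rewards, -1.0, 1.0)
    returns = np.empty_like(rewards, dtype=np.float32)
    for i in reversed(range(len(rewards))):
        if terminals[i]:
            returns[i] = rewards[i]                       # episode (or life) ends here
        else:
            returns[i] = rewards[i] + gamma * returns[i + 1]
    return returns[0]

# Example: rewards [0., 0., 1.] with only the last step terminal gives
# 0.99 ** 2 ≈ 0.9801; the refreshed episode is added to buffer R only when
# this Gnew exceeds the old return G stored with the sampled state.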