def run(self, net):
    """
    Entry function for a collector worker process. Builds this worker's
    environment, then loops: wait on gate_q for a rollout index from the
    main process, collect that rollout, and signal completion on stop_q.
    The loop ends once anything is placed on end_q.

    net - torch Module object. This is the model used to interact with
        the environment.
    """
    self.net = net
    fparams = try_key(self.hyps, "float_params", dict())
    self.env = SeqEnv(self.hyps['env_type'],
                      self.hyps['seed'],
                      worker_id=None,
                      float_params=fparams)
    # Fresh initial state for this worker's environment
    first_state = next_state(self.env,
                             self.obs_deque,
                             obs=None,
                             reset=True,
                             preprocess=self.hyps['preprocess'])
    self.state_bookmark = first_state
    self.ep_rew = 0
    # Eval mode fixes batchnorm issues; gradients are never needed here
    self.net.train(mode=False)
    for param in self.net.parameters():
        param.requires_grad = False
    with torch.no_grad():
        while self.end_q.empty():
            # Blocks until the main process opens the gate
            rollout_idx = self.gate_q.get()
            self.rollout(self.net, rollout_idx, self.hyps)
            # Signals to main process that data has been collected
            self.stop_q.put(rollout_idx)
def __init__(self, hyps):
    """
    hyps - dict object with all necessary hyperparameters

    Keys (assume string type keys):
        "n_tsteps" - number of steps to be taken in the environment
        "n_frame_stack" - number of frames to stack for creation of
            the mdp state
        "preprocessor" - function to preprocess raw observations
        "env_type" - type of gym environment to be interacted with.
            Follows OpenAI's gym api.
        "seed" - the random seed for the env
    """
    self.hyps = hyps
    # Rolling buffer of the most recent observations used to build states
    frame_cap = self.hyps['n_frame_stack']
    self.obs_deque = deque(maxlen=frame_cap)
    self.env = SequentialEnvironment(**self.hyps)
    # Number of evaluation episodes; defaults to 15 when unspecified
    self.n_episodes = try_key(self.hyps, "n_test_eps", 15)
def train(_, hyps, verbose=True):
    """
    Full training procedure: builds the network and shared-memory rollout
    buffers, spawns Runner processes that collect environment data, then
    repeatedly updates the model from the shared data, decays
    hyperparameters, and logs/saves results.

    _ - unused first positional argument
    hyps - dictionary of required hyperparameters
        type: dict
    verbose - bool (NOTE(review): not referenced in the visible body —
        confirm whether it is consumed elsewhere)

    Returns:
        best_eval_rew - float, best evaluation reward seen during training
    """
    # Hyperparam corrections
    # An int grid_size is expanded to a square [h, w] pair
    if isinstance(try_key(hyps, "grid_size", None), int):
        hyps['grid_size'] = [hyps['grid_size'], hyps['grid_size']]
    # Preprocessor Type
    # NOTE(review): env_type is computed but never used below — verify
    env_type = hyps['env_type'].lower()
    hyps['preprocessor'] = getattr(preprocessing, hyps['prep_fxn'])
    hyps['main_path'] = try_key(hyps, "main_path", "./")
    hyps['exp_num'] = get_exp_num(hyps['main_path'], hyps['exp_name'])
    hyps['save_folder'] = get_save_folder(hyps)
    save_folder = hyps['save_folder']
    if not os.path.exists(hyps['save_folder']):
        os.mkdir(hyps['save_folder'])
    # Seed defaults to wall-clock time when unspecified
    hyps['seed'] = try_key(hyps, 'seed', int(time.time()))
    torch.manual_seed(hyps['seed'])
    np.random.seed(hyps['seed'])
    net_save_file = os.path.join(save_folder, "net.p")
    best_net_file = os.path.join(save_folder, "best_net.p")
    optim_save_file = os.path.join(save_folder, "optim.p")
    log_file = os.path.join(save_folder, "log.txt")
    # Append to the log when resuming, otherwise start fresh
    if hyps['resume']:
        log = open(log_file, 'a')
    else:
        log = open(log_file, 'w')
    # Record every hyperparameter at the top of the log
    keys = sorted(list(hyps.keys()))
    for k in keys:
        log.write(k + ":" + str(hyps[k]) + "\n")

    # Miscellaneous Variable Prep
    logger = Logger()
    # Total number of transitions gathered per update
    shared_len = hyps['n_tsteps'] * hyps['n_rollouts']
    stats_runner = StatsRunner(hyps)
    env = stats_runner.env
    hyps['is_discrete'] = env.is_discrete
    obs = env.reset()
    hyps['state_shape'] = [hyps['n_frame_stack']] + [*obs.shape[1:]]
    # Pong uses a reduced 3-action space shifted by 1
    if hyps['env_type'] == "Pong-v0":
        action_size = 3
        hyps['action_shift'] = 1
    else:
        action_size = env.n
        hyps['action_shift'] = 0
    print("Raw Obs Shape:", env.raw_shape)
    print("Obs Shape:,", obs.shape)
    print("State Shape:,", hyps['state_shape'])
    print("Num Samples Per Update:", shared_len)

    # Make Network
    hyps["action_size"] = action_size
    net = globals()[hyps['model']](hyps['state_shape'], action_size,
                                   bnorm=hyps['use_bnorm'], **hyps)
    if try_key(hyps, 'resume', False):
        net.load_state_dict(torch.load(net_save_file))
    # base_net is the copy the Updater trains; net is shared with workers
    base_net = copy.deepcopy(net)
    net = cuda_if(net)
    net.share_memory()
    base_net = cuda_if(base_net)

    # Prepare Shared Variables
    # These tensors live in shared memory so Runner processes can write
    # rollout data directly into them
    states_shape = (shared_len, *hyps['state_shape'])
    if env.is_discrete:
        actions = torch.zeros(shared_len).long()
    else:
        actions = torch.zeros((shared_len, env.n)).float()
    shared_data = {
        'states': cuda_if(torch.zeros(states_shape).share_memory_()),
        'deltas': cuda_if(torch.zeros(shared_len).share_memory_()),
        'rewards': cuda_if(torch.zeros(shared_len).share_memory_()),
        'actions': actions.share_memory_(),
        'dones': cuda_if(torch.zeros(shared_len).share_memory_())
    }
    if net.is_recurrent:
        zeros = torch.zeros(shared_len, net.h_size)
        shared_data['h_states'] = cuda_if(zeros.share_memory_())
    n_rollouts = hyps['n_rollouts']
    # gate_q hands rollout indices to workers; stop_q signals completion;
    # reward_q carries a single running average reward value
    gate_q = mp.Queue(n_rollouts)
    stop_q = mp.Queue(n_rollouts)
    reward_q = mp.Queue(1)
    reward_q.put(-1)

    # Make Runners
    runners = []
    for i in range(hyps['n_envs']):
        runner = Runner(shared_data, hyps, gate_q, stop_q, reward_q)
        runners.append(runner)

    # Start Data Collection
    print("Making New Processes")
    procs = []
    for i in range(len(runners)):
        proc = mp.Process(target=runners[i].run, args=(net, ))
        procs.append(proc)
        proc.start()
        print(i, "/", len(runners), end='\r')
    # Open the gate once per rollout to kick off the first collection
    for i in range(n_rollouts):
        gate_q.put(i)

    # Make Updater
    updater = Updater(base_net, hyps)
    if hyps['resume']:
        updater.optim.load_state_dict(torch.load(optim_save_file))
    updater.optim.zero_grad()
    updater.net.train(mode=True)
    updater.net.req_grads(True)

    # Prepare Decay Precursors
    # Spans used for linear decay/annealing schedules below
    entr_coef_diff = hyps['entr_coef'] - hyps['entr_coef_low']
    lr_diff = hyps['lr'] - hyps['lr_low']
    gamma_diff = hyps['gamma_high'] - hyps['gamma']

    # Training Loop
    past_rews = deque([0] * hyps['n_past_rews'])
    last_avg_rew = 0
    best_eval_rew = -np.inf
    epoch = 0
    T = 0  # total environment steps consumed so far
    while T < hyps['max_tsteps']:
        basetime = time.time()
        epoch += 1
        stats_string = ""
        # Collect data: wait for every rollout to finish
        for i in range(n_rollouts):
            stop_q.get()
        T += shared_len
        s = "Epoch {} - T: {} -- {}".format(epoch, T, hyps['save_folder'])
        print(s)
        stats_string += s + "\n"
        # Reward Stats (peek the shared running average)
        avg_reward = reward_q.get()
        reward_q.put(avg_reward)
        last_avg_rew = avg_reward

        # Calculate the Loss and Update nets
        updater.update_model(shared_data)
        # update all collector nets
        net.load_state_dict(updater.net.state_dict())
        # Separate evaluation rollout with the freshly updated net
        eval_rew = stats_runner.rollout(net)
        if eval_rew > best_eval_rew:
            best_eval_rew = eval_rew
            updater.save_model(best_net_file, None)
        stats_string += "Eval rew: {}\n".format(eval_rew)

        # Resume Data Collection
        for i in range(n_rollouts):
            gate_q.put(i)

        # Decay HyperParameters (linear anneal toward the *_low values)
        if hyps['decay_lr']:
            decay_factor = max((1 - T / (hyps['max_tsteps'])), 0)
            new_lr = decay_factor * lr_diff + hyps['lr_low']
            updater.new_lr(new_lr)
            s = "New lr: " + str(new_lr)
            print(s)
            stats_string += s + "\n"
        if hyps['decay_entr']:
            decay_factor = max((1 - T / (hyps['max_tsteps'])), 0)
            updater.entr_coef = entr_coef_diff * decay_factor
            updater.entr_coef += hyps['entr_coef_low']
            s = "New Entr: " + str(updater.entr_coef)
            print(s)
            stats_string += s + "\n"

        # Periodically save model
        if epoch % 10 == 0:
            updater.save_model(net_save_file, optim_save_file)

        # Print Epoch Data
        past_rews.popleft()
        past_rews.append(avg_reward)
        max_rew, min_rew = deque_maxmin(past_rews)
        rew_avg, rew_std = np.mean(past_rews), np.std(past_rews)
        updater.print_statistics()
        avg_action = shared_data['actions'].float().mean().item()
        s = "Grad Norm: {:.5f} – Avg Action: {:.5f} - Best EvalRew: {:.5f}"
        s = s.format(float(updater.norm), avg_action, best_eval_rew)
        stats_string += s + "\n"
        stats_string += "Avg Rew: " + str(avg_reward) + "\n"
        s = "Past " + str(hyps['n_past_rews']) + " Rews – High: {:.5f}"
        s += " - Low: {:.5f} - Avg: {:.5f} - StD: {:.5f}"
        stats_string += s.format(max_rew, min_rew, rew_avg, rew_std) + "\n"
        updater.log_statistics(log, T, avg_reward, avg_action, best_eval_rew)
        updater.info["EvalRew"] = eval_rew
        updater.info['AvgRew'] = avg_reward
        logger.append(updater.info, x_val=T)

        # Check for memory leaks
        gc.collect()
        max_mem_used = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        s = "Time: " + str(time.time() - basetime)
        stats_string += s + "\n"
        print(stats_string)
        log.write(stats_string + "\n")
        if 'hyp_search_count' in hyps and hyps['hyp_search_count'] > 0\
                and hyps['search_id'] != None:
            print("Search:", hyps['search_id'], "/", hyps['hyp_search_count'])
        print("Memory Used: {:.2f} memory\n".format(max_mem_used / 1024))

    # NOTE(review): raw concatenation, no path separator — presumably
    # save_folder ends with '/'; confirm against get_save_folder
    logger.make_plots(save_folder + hyps['exp_name'])
    log.write("\nBestRew:" + str(best_eval_rew))
    log.close()
    # Close processes
    for p in procs:
        p.terminate()
    return best_eval_rew
def train(self, hyps):
    """
    Full training procedure with forward-dynamics (and optional inverse /
    reconstruction) models. Spawns Runner processes that collect rollouts
    into shared-memory tensors, then loops: update the networks, decay
    hyperparameters, log statistics, and save checkpoints.

    hyps - dictionary of required hyperparameters
        type: dict

    Returns:
        best_avg_rew - best running-average reward observed
    """
    # Initial settings
    if "randomizeObjs" in hyps:
        assert False, "you mean randomizeObs, not randomizeObjs"
    if "audibleTargs" in hyps and hyps['audibleTargs'] > 0:
        hyps['aud_targs'] = True
        # NOTE(review): `verbose` is not a parameter of this method —
        # presumably a global or outer-scope name; confirm
        if verbose: print("Using audible targs!")
    countOut = try_key(hyps, 'countOut', 0)
    if countOut and not hyps['endAtOrigin']:
        assert False, "endAtOrigin must be true for countOut setting"

    # Print Hyperparameters To Screen
    items = list(hyps.items())
    for k, v in sorted(items):
        print(k+":", v)

    # Make Save Files
    if "save_folder" in hyps:
        save_folder = hyps['save_folder']
    else:
        save_folder = "./saved_data/"
    if not os.path.exists(save_folder):
        os.mkdir(save_folder)
    base_name = save_folder + hyps['exp_name']
    net_save_file = base_name+"_net.p"
    fwd_save_file = base_name+"_fwd.p"
    best_net_file = base_name+"_best.p"
    optim_save_file = base_name+"_optim.p"
    fwd_optim_file = base_name+"_fwdoptim.p"
    hyps['fwd_emb_file'] = base_name+"_fwdemb.p"
    if hyps['inv_model'] is not None:
        inv_save_file = base_name+"_invnet.p"
        reconinv_optim_file = base_name+"_reconinvoptim.p"
    else:
        inv_save_file = None
        reconinv_optim_file = None
    if hyps['recon_model'] is not None:
        recon_save_file = base_name+"_reconnet.p"
        reconinv_optim_file = base_name+"_reconinvoptim.p"
    else:
        recon_save_file = None
    log_file = base_name+"_log.txt"
    # Append when resuming, otherwise truncate
    if hyps['resume']:
        log = open(log_file, 'a')
    else:
        log = open(log_file, 'w')
    for k, v in sorted(items):
        log.write(k+":"+str(v)+"\n")

    # Miscellaneous Variable Prep
    logger = Logger()
    # Transitions gathered per update cycle
    shared_len = hyps['n_tsteps']*hyps['n_rollouts']
    float_params = dict()
    # Build float_params for the environment from game_keys when absent;
    # best-effort: any failure leaves float_params empty
    if "float_params" not in hyps:
        try:
            keys = hyps['game_keys']
            hyps['float_params'] = {k:try_key(hyps,k,0) for k in keys}
            if "minObjLoc" not in hyps:
                hyps['float_params']["minObjLoc"] = 0.27
                hyps['float_params']["maxObjLoc"] = 0.73
            float_params = hyps['float_params']
        except:
            pass

    # Probe environment once to determine obs/action shapes, then discard
    env = SeqEnv(hyps['env_type'], hyps['seed'], worker_id=None,
                 float_params=float_params)
    hyps['discrete_env'] = hasattr(env.action_space, "n")
    obs = env.reset()
    prepped = hyps['preprocess'](obs)
    hyps['state_shape'] = [hyps['n_frame_stack']*prepped.shape[0],
                           *prepped.shape[1:]]
    if not hyps['discrete_env']:
        action_size = int(np.prod(env.action_space.shape))
    elif hyps['env_type'] == "Pong-v0":
        action_size = 3
    else:
        action_size = env.action_space.n
    # Non-zero only for Pong's reduced action space
    hyps['action_shift'] = (4-action_size)*(hyps['env_type']=="Pong-v0")
    print("Obs Shape:,",obs.shape)
    print("Prep Shape:,",prepped.shape)
    print("State Shape:,",hyps['state_shape'])
    print("Num Samples Per Update:", shared_len)
    if not (hyps['n_cache_refresh'] <= shared_len or hyps['cache_size'] == 0):
        hyps['n_cache_refresh'] = shared_len
    print("Samples Wasted in Update:", shared_len % hyps['batch_size'])
    try:
        env.close()
    except:
        pass
    del env

    # Prepare Shared Variables
    # Shared-memory tensors written by the Runner processes
    shared_data = {
        'states': torch.zeros(shared_len,
                              *hyps['state_shape']).share_memory_(),
        'next_states': torch.zeros(shared_len,
                                   *hyps['state_shape']).share_memory_(),
        'dones':torch.zeros(shared_len).share_memory_(),
        'rews':torch.zeros(shared_len).share_memory_(),
        'hs':torch.zeros(shared_len,hyps['h_size']).share_memory_(),
        'next_hs':torch.zeros(shared_len,hyps['h_size']).share_memory_()}
    if hyps['discrete_env']:
        shared_data['actions'] = torch.zeros(shared_len).long().share_memory_()
    else:
        shape = (shared_len, action_size)
        shared_data['actions']=torch.zeros(shape).float().share_memory_()
    shared_data = {k: cuda_if(v) for k,v in shared_data.items()}
    n_rollouts = hyps['n_rollouts']
    # gate_q dispatches rollout indices; stop_q reports completion;
    # end_q tells workers to exit; reward_q holds one running avg reward
    gate_q = mp.Queue(n_rollouts)
    stop_q = mp.Queue(n_rollouts)
    end_q = mp.Queue(1)
    reward_q = mp.Queue(1)
    reward_q.put(-1)

    # Make Runners
    runners = []
    for i in range(hyps['n_envs']):
        runner = Runner(shared_data, hyps, gate_q, stop_q, end_q, reward_q)
        runners.append(runner)

    # Make the Networks
    h_size = hyps['h_size']
    net = hyps['model'](hyps['state_shape'], action_size, h_size,
                        bnorm=hyps['use_bnorm'],
                        lnorm=hyps['use_lnorm'],
                        discrete_env=hyps['discrete_env'])
    # Fwd Dynamics
    hyps['is_recurrent'] = hasattr(net, "fresh_h")
    # Input: state embedding + action (+ hidden state when recurrent)
    intl_size = h_size+action_size + hyps['is_recurrent']*h_size
    # NOTE(review): this LayerNorm block is immediately overwritten by the
    # assignment below, so fwd_lnorm currently has no effect — likely the
    # intent was to prepend it; confirm before changing
    if hyps['fwd_lnorm']:
        block = [nn.LayerNorm(intl_size)]
    block = [nn.Linear(intl_size, h_size), nn.ReLU(),
             nn.Linear(h_size, h_size), nn.ReLU(),
             nn.Linear(h_size, h_size)]
    fwd_net = nn.Sequential(*block)
    # Allows us to argue an h vector along with embedding to
    # forward func
    if hyps['is_recurrent']:
        fwd_net = CatModule(fwd_net)
    if hyps['ensemble']:
        fwd_net = Ensemble(fwd_net)
    fwd_net = cuda_if(fwd_net)
    if hyps['inv_model'] is not None:
        inv_net = hyps['inv_model'](h_size, action_size)
        inv_net = cuda_if(inv_net)
    else:
        inv_net = None
    if hyps['recon_model'] is not None:
        recon_net = hyps['recon_model'](emb_size=h_size,
                                        img_shape=hyps['state_shape'],
                                        fwd_bnorm=hyps['fwd_bnorm'],
                                        deconv_ksizes=hyps['recon_ksizes'])
        recon_net = cuda_if(recon_net)
    else:
        recon_net = None
    if hyps['resume']:
        net.load_state_dict(torch.load(net_save_file))
        fwd_net.load_state_dict(torch.load(fwd_save_file))
        if inv_net is not None:
            inv_net.load_state_dict(torch.load(inv_save_file))
        if recon_net is not None:
            recon_net.load_state_dict(torch.load(recon_save_file))
    # base_net is trained by the Updater; net is shared with the workers
    base_net = copy.deepcopy(net)
    net = cuda_if(net)
    net.share_memory()
    base_net = cuda_if(base_net)
    hyps['is_recurrent'] = hasattr(net, "fresh_h")

    # Start Data Collection
    print("Making New Processes")
    procs = []
    for i in range(len(runners)):
        proc = mp.Process(target=runners[i].run, args=(net,))
        procs.append(proc)
        proc.start()
        print(i, "/", len(runners), end='\r')
    # Kick off the first round of rollouts
    for i in range(n_rollouts):
        gate_q.put(i)

    # Make Updater
    updater = Updater(base_net, fwd_net, hyps, inv_net, recon_net)
    if hyps['resume']:
        updater.optim.load_state_dict(torch.load(optim_save_file))
        updater.fwd_optim.load_state_dict(torch.load(fwd_optim_file))
        if inv_net is not None:
            updater.reconinv_optim.load_state_dict(
                                     torch.load(reconinv_optim_file))
    updater.optim.zero_grad()
    updater.net.train(mode=True)
    updater.net.req_grads(True)

    # Prepare Decay Precursors
    # Spans for linear annealing schedules used inside the loop
    entr_coef_diff = hyps['entr_coef'] - hyps['entr_coef_low']
    epsilon_diff = hyps['epsilon'] - hyps['epsilon_low']
    lr_diff = hyps['lr'] - hyps['lr_low']
    gamma_diff = hyps['gamma_high'] - hyps['gamma']

    # Training Loop
    past_rews = deque([0]*hyps['n_past_rews'])
    last_avg_rew = 0
    best_rew_diff = 0
    best_avg_rew = -10000
    best_eval_rew = -10000
    ep_eval_rew = 0
    eval_rew = 0
    epoch = 0
    done_count = 0  # cumulative episode terminations seen in shared data
    T = 0  # total environment steps consumed
    try:
        while T < hyps['max_tsteps']:
            basetime = time.time()
            epoch += 1
            # Collect data: wait for all rollouts to complete
            for i in range(n_rollouts):
                stop_q.get()
            T += shared_len
            # Reward Stats (peek the shared running average)
            avg_reward = reward_q.get()
            reward_q.put(avg_reward)
            last_avg_rew = avg_reward
            done_count += shared_data['dones'].sum().item()
            new_best = False
            # Only trust avg_reward once enough episodes have finished
            if avg_reward > best_avg_rew and done_count > n_rollouts:
                new_best = True
                best_avg_rew = avg_reward
                updater.save_model(best_net_file, fwd_save_file, None, None)
            eval_rew = shared_data['rews'].mean()
            if eval_rew > best_eval_rew:
                best_eval_rew = eval_rew
                # Derive "<name>_best.<ext>" variants of every save file
                save_names = [net_save_file, fwd_save_file,
                              optim_save_file, fwd_optim_file,
                              inv_save_file, recon_save_file,
                              reconinv_optim_file]
                for i in range(len(save_names)):
                    if save_names[i] is not None:
                        splt = save_names[i].split(".")
                        splt[0] = splt[0]+"_best"
                        save_names[i] = ".".join(splt)
                updater.save_model(*save_names)
            s = "EvalRew: {:.5f} | BestEvalRew: {:.5f}"
            print(s.format(eval_rew, best_eval_rew))

            # Calculate the Loss and Update nets
            updater.update_model(shared_data)
            # update all collector nets
            net.load_state_dict(updater.net.state_dict())

            # Resume Data Collection
            for i in range(n_rollouts):
                gate_q.put(i)

            # Decay HyperParameters (linear schedules in T)
            if hyps['decay_eps']:
                updater.epsilon = (1-T/(hyps['max_tsteps']))*epsilon_diff\
                                                      + hyps['epsilon_low']
                print("New Eps:", updater.epsilon)
            if hyps['decay_lr']:
                new_lr = (1-T/(hyps['max_tsteps']))*lr_diff + hyps['lr_low']
                updater.new_lr(new_lr)
                print("New lr:", new_lr)
            if hyps['decay_entr']:
                updater.entr_coef = entr_coef_diff*\
                                  (1-T/(hyps['max_tsteps']))+hyps['entr_coef_low']
                print("New Entr:", updater.entr_coef)
            if hyps['incr_gamma']:
                updater.gamma = gamma_diff*(T/(hyps['max_tsteps']))+hyps['gamma']
                print("New Gamma:", updater.gamma)

            # Periodically save model
            if epoch % 10 == 0 or epoch == 1:
                updater.save_model(net_save_file, fwd_save_file,
                                   optim_save_file, fwd_optim_file,
                                   inv_save_file, recon_save_file,
                                   reconinv_optim_file)

            # Print Epoch Data
            past_rews.popleft()
            past_rews.append(avg_reward)
            max_rew, min_rew = deque_maxmin(past_rews)
            print("Epoch", epoch, "– T =", T, "-- Folder:", base_name)
            if not hyps['discrete_env']:
                # Report per-dimension action standard deviations
                s = ("{:.5f} | "*net.logsigs.shape[1])
                s = s.format(*[x.item() for x in torch.exp(net.logsigs[0])])
                print("Sigmas:", s)
            updater.print_statistics()
            avg_action = shared_data['actions'].float().mean().item()
            print("Grad Norm:",float(updater.norm),"– Avg Action:",
                  avg_action,"– Best AvgRew:",best_avg_rew)
            print("Avg Rew:", avg_reward, "– High:", max_rew,
                  "– Low:", min_rew, end='\n')
            updater.log_statistics(log, T, avg_reward, avg_action,
                                   best_avg_rew)
            updater.info['AvgRew'] = avg_reward
            updater.info['EvalRew'] = eval_rew
            logger.append(updater.info, x_val=T)

            # Check for memory leaks
            gc.collect()
            max_mem_used = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
            print("Time:", time.time()-basetime)
            if 'hyp_search_count' in hyps and hyps['hyp_search_count'] > 0\
                                         and hyps['search_id'] != None:
                print("Search:", hyps['search_id'], "/",
                      hyps['hyp_search_count'])
            print("Memory Used: {:.2f} memory\n".format(max_mem_used / 1024))
            # Bail out if training has diverged
            if updater.info["VLoss"] == float('inf') or\
                                        updater.norm == float('inf'):
                break
    except KeyboardInterrupt:
        pass

    # Tell workers to exit, then give them a moment before cleanup
    end_q.put(1)
    time.sleep(1)
    logger.make_plots(base_name)
    log.write("\nBestRew:"+str(best_avg_rew))
    log.close()
    # Close processes
    for p in procs:
        p.terminate()
    return best_avg_rew
print("Making model") model = getattr(models,hyps['model_class'])(**hyps) model.to(DEVICE) try: model.load_state_dict(checkpt["state_dict"]) except: keys = list(checkpt['state_dict'].keys()) for k in keys: if "pavlov" == k.split(".")[0]: del checkpt['state_dict'][k] model.load_state_dict(checkpt['state_dict']) print("state dict success") model.eval() fwd_dynamics = try_key(hyps,'use_fwd_dynamics',False) and\ try_key(hyps,"countOut",False) if fwd_dynamics: fwd_model = getattr(models,hyps['fwd_class'])(**hyps) fwd_model.cuda() fwd_model.load_state_dict(checkpt['fwd_state_dict']) fwd_model.eval() else: fwd_model = DummyFwdModel() done = True sum_rew = 0 n_loops = 0 frames = [] obscatpreds = [] with torch.no_grad():
import locgame.save_io as locio
import ml_utils.save_io as mlio
import ml_utils.analysis as mlanl
from ml_utils.utils import try_key
import pandas as pd
import os
import sys

if __name__ == "__main__":
    # Expand each argued path into concrete model folders.
    argued_folders = sys.argv[1:]
    model_folders = []
    for folder in argued_folders:
        if mlio.is_model_folder(folder):
            model_folders += [folder]
        else:
            model_folders += mlio.get_model_folders(folder, True)
    print("Model Folders:", model_folders)

    # Collect one CSV of per-checkpoint stats for each model folder.
    for model_folder in model_folders:
        checkpts = mlio.get_checkpoints(model_folder)
        if not checkpts:
            continue
        # Seed the table's columns from the first checkpoint.
        table = mlanl.get_table(mlio.load_checkpoint(checkpts[0]))
        for checkpt in checkpts:
            chkpt = mlio.load_checkpoint(checkpt)
            for col in table:
                if col in chkpt:
                    table[col].append(chkpt[col])
        df = pd.DataFrame(table)
        # Seed is taken from the last loaded checkpoint's hyps.
        df['seed'] = try_key(chkpt['hyps'], 'seed', -1)
        save_path = os.path.join(model_folder, "model_data.csv")
        df.to_csv(save_path, sep="!", index=False, header=True)
def train(hyps, verbose=True):
    """
    Trains a transformer model on the configured dataset: builds data
    loaders, the model, optimizer, loss, and scheduler, then runs the
    train/validation loop, checkpointing each epoch.

    hyps: dict
        contains all relavent hyperparameters
    verbose: bool
        if true, prints progress information to stdout

    Returns:
        save_dict - dict of final-epoch statistics (state dicts, optim
            dict, and hyps removed; 'save_folder' added)
    """
    # Set manual seed
    hyps['exp_num'] = get_exp_num(hyps['main_path'], hyps['exp_name'])
    hyps['save_folder'] = get_save_folder(hyps)
    if not os.path.exists(hyps['save_folder']):
        os.mkdir(hyps['save_folder'])
    hyps['seed'] = try_key(hyps, 'seed', int(time.time()))
    torch.manual_seed(hyps['seed'])
    np.random.seed(hyps['seed'])
    model_class = hyps['model_class']
    hyps['model_type'] = models.TRANSFORMER_TYPE[model_class]
    if not hyps['init_decs'] and not hyps['gen_decs'] and\
                                 not hyps['ordered_preds']:
        s = "WARNING!! You probably want to set ordered preds to True "
        s += "with your current configuration!!"
        print(s)
    if verbose:
        print("Retreiving Dataset")
    if "shuffle_split" not in hyps and hyps['shuffle']:
        hyps['shuffle_split'] = True
    train_data, val_data = datas.get_data(**hyps)
    hyps['enc_slen'] = train_data.X.shape[-1]
    # Decoder sequence is one shorter: targets are shifted by one token
    hyps['dec_slen'] = train_data.Y.shape[-1] - 1

    train_loader = torch.utils.data.DataLoader(train_data,
                                  batch_size=hyps['batch_size'],
                                  shuffle=hyps['shuffle'])
    val_loader = torch.utils.data.DataLoader(val_data,
                                  batch_size=hyps['batch_size'])
    hyps['n_vocab'] = len(train_data.word2idx.keys())

    if verbose:
        print("Making model")
    model = getattr(models, model_class)(**hyps)
    model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=hyps['lr'],
                                 weight_decay=hyps['l2'])
    # Optionally warm-start from a previous checkpoint
    init_checkpt = try_key(hyps, "init_checkpt", None)
    if init_checkpt is not None and init_checkpt != "":
        if verbose:
            print("Loading state dicts from", init_checkpt)
        checkpt = io.load_checkpoint(init_checkpt)
        model.load_state_dict(checkpt["state_dict"])
        optimizer.load_state_dict(checkpt["optim_dict"])
    lossfxn = nn.CrossEntropyLoss()
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5,
                                  patience=6, verbose=True)
    # Embedding loss only applies to DICTIONARY-type transformers
    if model.transformer_type != models.DICTIONARY:
        hyps['emb_alpha'] = 0
    if verbose:
        print("Beginning training for {}".format(hyps['save_folder']))
        print("train shape:", train_data.X.shape)
        print("val shape:", val_data.X.shape)
        print("n_vocab:", hyps['n_vocab'])
    record_session(hyps, model)
    if hyps['dataset'] == "WordProblem":
        save_data_structs(train_data.samp_structs)
    # Shorten runs in test mode
    if hyps['exp_name'] == "test":
        hyps['n_epochs'] = 2

    epoch = -1
    alpha = hyps['loss_alpha']      # mixes token loss vs mask loss
    emb_alpha = hyps['emb_alpha']   # mixes token loss vs embedding loss
    print()
    idx2word = train_data.idx2word
    mask_idx = train_data.word2idx["<MASK>"]
    while epoch < hyps['n_epochs']:
        epoch += 1
        print("Epoch:{} | Model:{}".format(epoch, hyps['save_folder']))
        starttime = time.time()
        avg_loss = 0
        avg_acc = 0
        avg_indy_acc = 0
        avg_emb_loss = 0
        mask_avg_loss = 0
        mask_avg_acc = 0
        model.train()
        print("Training...")
        optimizer.zero_grad()
        for b, (x, y) in enumerate(train_loader):
            torch.cuda.empty_cache()
            # Targets depend on the transformer type: autoencoders and
            # dictionary models reconstruct x; others predict y
            if model.transformer_type == models.AUTOENCODER:
                targs = x.data[:, 1:]
                y = x.data
            elif model.transformer_type == models.DICTIONARY:
                targs = x.data[:, 1:]
                emb_targs = model.embeddings(y.to(DEVICE))
                word_to_define = idx2word[y.squeeze()[0].item()]
                y = x.data
            else:
                targs = y.data[:, 1:]
            og_shape = targs.shape
            if hyps['masking_task']:
                x, y, mask = mask_words(x, y, mask_p=hyps['mask_p'])
            # Drop last decoder token so inputs align with shifted targets
            y = y[:, :-1]
            preds = model(x.to(DEVICE), y.to(DEVICE))
            tot_loss = 0
            if model.transformer_type == models.DICTIONARY:
                emb_preds = preds[1]
                preds = preds[0]
                emb_loss = F.mse_loss(emb_preds, emb_targs.data)
                tot_loss += (emb_alpha) * emb_loss
                avg_emb_loss += emb_loss.item()
            # Periodically print a decoded sample for inspection
            if epoch % 3 == 0 and b == 0:
                if model.transformer_type == models.DICTIONARY:
                    print("Word:", word_to_define)
                whr = torch.where(y[0] == mask_idx)[0]
                endx = y.shape[-1] if len(whr) == 0 else whr[0].item()
                print("y:", [idx2word[a.item()] for a in y[0, :endx]])
                print("t:", [idx2word[a.item()] for a in targs[0, :endx - 1]])
                ms = torch.argmax(preds, dim=-1)
                print("p:", [idx2word[a.item()] for a in ms[0, :endx - 1]])
                del ms
            targs = targs.reshape(-1)
            if hyps['masking_task']:
                print("masking!")
                # Mask loss and acc (computed only on masked positions)
                preds = preds.reshape(-1, preds.shape[-1])
                mask = mask.reshape(-1).bool()
                idxs = torch.arange(len(mask))[mask]
                mask_preds = preds[idxs]
                mask_targs = targs[idxs]
                mask_loss = lossfxn(mask_preds, mask_targs)
                mask_preds = torch.argmax(mask_preds, dim=-1)
                mask_acc = (mask_preds == mask_targs).sum().float()
                mask_acc = mask_acc / idxs.numel()
                mask_avg_acc += mask_acc.item()
                mask_avg_loss += mask_loss.item()
            else:
                mask_loss = torch.zeros(1).to(DEVICE)
                mask_acc = torch.zeros(1).to(DEVICE)
                mask_avg_acc += mask_acc.item()
                mask_avg_loss += mask_loss.item()
            # Tot loss and acc
            preds = preds.reshape(-1, preds.shape[-1])
            targs = targs.reshape(-1).to(DEVICE)
            if not hyps['masking_task']:
                # Exclude padding (<MASK>) positions from the loss
                bitmask = (targs != mask_idx)
                loss = (1 - emb_alpha) * lossfxn(preds[bitmask],
                                                 targs[bitmask])
            else:
                loss = lossfxn(preds, targs)
            if hyps['masking_task']:
                temp = ((alpha) * loss + (1 - alpha) * mask_loss)
                tot_loss += temp / hyps['n_loss_loops']
            else:
                tot_loss += loss / hyps['n_loss_loops']
            # Gradient accumulation over n_loss_loops mini-batches
            tot_loss.backward()
            if b % hyps['n_loss_loops'] == 0 or b == len(train_loader) - 1:
                optimizer.step()
                optimizer.zero_grad()
            with torch.no_grad():
                preds = torch.argmax(preds, dim=-1)
                sl = og_shape[-1]
                if not hyps['masking_task']:
                    eq = (preds == targs).float()
                    indy_acc = eq[bitmask].mean()
                    # Count padded positions as correct for sequence acc
                    eq[~bitmask] = 1
                    eq = eq.reshape(og_shape)
                    acc = (eq.sum(-1) == sl).float().mean()
                else:
                    eq = (preds == targs).float().reshape(og_shape)
                    acc = (eq.sum(-1) == sl).float().mean()
                    indy_acc = eq.mean()
                preds = preds.cpu()
            avg_acc += acc.item()
            avg_indy_acc += indy_acc.item()
            avg_loss += loss.item()
            if hyps["masking_task"]:
                s = "Mask Loss:{:.5f} | Acc:{:.5f} | {:.0f}%"
                s = s.format(mask_loss.item(), mask_acc.item(),
                             b / len(train_loader) * 100)
            elif model.transformer_type == models.DICTIONARY:
                s = "Loss:{:.5f} | Acc:{:.5f} | Emb:{:.5f} | {:.0f}%"
                s = s.format(loss.item(), acc.item(), emb_loss.item(),
                             b / len(train_loader) * 100)
            else:
                s = "Loss:{:.5f} | Acc:{:.5f} | {:.0f}%"
                s = s.format(loss.item(), acc.item(),
                             b / len(train_loader) * 100)
            print(s, end=len(s) * " " + "\r")
            if hyps['exp_name'] == "test" and b > 5:
                break
        print()
        # Epoch-level training statistics
        mask_train_loss = mask_avg_loss / len(train_loader)
        mask_train_acc = mask_avg_acc / len(train_loader)
        train_avg_loss = avg_loss / len(train_loader)
        train_avg_acc = avg_acc / len(train_loader)
        train_avg_indy = avg_indy_acc / len(train_loader)
        train_emb_loss = avg_emb_loss / len(train_loader)
        stats_string = "Train - Loss:{:.5f} | Acc:{:.5f} | Indy:{:.5f}\n"
        stats_string = stats_string.format(train_avg_loss, train_avg_acc,
                                           train_avg_indy)
        if hyps['masking_task']:
            stats_string += "Tr. Mask Loss:{:.5f} | Tr. Mask Acc:{:.5f}\n"
            stats_string = stats_string.format(mask_train_loss,
                                               mask_train_acc)
        elif model.transformer_type == models.DICTIONARY:
            stats_string += "Train Emb Loss:{:.5f}\n"
            stats_string = stats_string.format(train_emb_loss)

        # Validation pass (no gradients)
        model.eval()
        avg_loss = 0
        avg_acc = 0
        avg_indy_acc = 0
        avg_emb_loss = 0
        mask_avg_loss = 0
        mask_avg_acc = 0
        print("Validating...")
        words = None
        with torch.no_grad():
            # Pick one random batch from which to display a sample
            rand_word_batch = int(np.random.randint(0, len(val_loader)))
            for b, (x, y) in enumerate(val_loader):
                torch.cuda.empty_cache()
                if model.transformer_type == models.AUTOENCODER:
                    targs = x.data[:, 1:]
                    y = x.data
                elif model.transformer_type == models.DICTIONARY:
                    targs = x.data[:, 1:]
                    emb_targs = model.embeddings(y.to(DEVICE))
                    if b == rand_word_batch or hyps['exp_name'] == "test":
                        words = [idx2word[y.squeeze()[i].item()] for\
                                             i in range(len(y.squeeze()))]
                    y = x.data
                else:
                    targs = y.data[:, 1:]
                og_shape = targs.shape
                if hyps['init_decs']:
                    y = train_data.inits.clone().repeat(len(x), 1)
                if hyps['masking_task']:
                    x, y, mask = mask_words(x, y, mask_p=hyps['mask_p'])
                y = y[:, :-1]
                preds = model(x.to(DEVICE), y.to(DEVICE))
                tot_loss = 0
                if model.transformer_type == models.DICTIONARY:
                    emb_preds = preds[1]
                    preds = preds[0]
                    emb_loss = F.mse_loss(emb_preds, emb_targs)
                    avg_emb_loss += emb_loss.item()
                if hyps['masking_task']:
                    # Mask loss and acc
                    targs = targs.reshape(-1)
                    preds = preds.reshape(-1, preds.shape[-1])
                    mask = mask.reshape(-1).bool()
                    idxs = torch.arange(len(mask))[mask]
                    mask_preds = preds[idxs]
                    mask_targs = targs[idxs]
                    mask_loss = lossfxn(mask_preds, mask_targs)
                    mask_avg_loss += mask_loss.item()
                    mask_preds = torch.argmax(mask_preds, dim=-1)
                    mask_acc = (mask_preds == mask_targs).sum().float()
                    mask_acc = mask_acc / mask_preds.numel()
                    mask_avg_acc += mask_acc.item()
                else:
                    mask_acc = torch.zeros(1).to(DEVICE)
                    mask_loss = torch.zeros(1).to(DEVICE)
                # Tot loss and acc
                preds = preds.reshape(-1, preds.shape[-1])
                targs = targs.reshape(-1).to(DEVICE)
                if not hyps['masking_task']:
                    bitmask = (targs != mask_idx)
                    loss = lossfxn(preds[bitmask], targs[bitmask])
                else:
                    loss = lossfxn(preds, targs)
                preds = torch.argmax(preds, dim=-1)
                sl = og_shape[-1]
                if not hyps['masking_task']:
                    eq = (preds == targs).float()
                    indy_acc = eq[bitmask].mean()
                    eq[~bitmask] = 1
                    eq = eq.reshape(og_shape)
                    acc = (eq.sum(-1) == sl).float().mean()
                else:
                    eq = (preds == targs).float().reshape(og_shape)
                    acc = (eq.sum(-1) == sl).float().mean()
                    indy_acc = eq.mean()
                preds = preds.cpu()
                # Decode one random sample (question/target/prediction)
                # for the stats report
                if b == rand_word_batch or hyps['exp_name'] == "test":
                    rand = int(np.random.randint(0, len(x)))
                    question = x[rand]
                    whr = torch.where(question == mask_idx)[0]
                    endx = len(question) if len(whr) == 0 else whr[0].item()
                    question = question[:endx]
                    pred_samp = preds.reshape(og_shape)[rand]
                    targ_samp = targs.reshape(og_shape)[rand]
                    whr = torch.where(targ_samp == mask_idx)[0]
                    endx = len(targ_samp) if len(whr) == 0 else whr[0].item()
                    targ_samp = targ_samp[:endx]
                    pred_samp = pred_samp[:endx]
                    idx2word = train_data.idx2word
                    question = [idx2word[p.item()] for p in question]
                    pred_samp = [idx2word[p.item()] for p in pred_samp]
                    targ_samp = [idx2word[p.item()] for p in targ_samp]
                    question = " ".join(question)
                    pred_samp = " ".join(pred_samp)
                    targ_samp = " ".join(targ_samp)
                    if words is not None:
                        word_samp = str(words[rand])
                avg_acc += acc.item()
                avg_indy_acc += indy_acc.item()
                avg_loss += loss.item()
                if hyps["masking_task"]:
                    s = "Mask Loss:{:.5f} | Acc:{:.5f} | {:.0f}%"
                    s = s.format(mask_loss.item(), mask_acc.item(),
                                 b / len(val_loader) * 100)
                elif model.transformer_type == models.DICTIONARY:
                    s = "Loss:{:.5f} | Acc:{:.5f} | Emb:{:.5f} | {:.0f}%"
                    s = s.format(loss.item(), acc.item(), emb_loss.item(),
                                 b / len(val_loader) * 100)
                else:
                    s = "Loss:{:.5f} | Acc:{:.5f} | {:.0f}%"
                    s = s.format(loss.item(), acc.item(),
                                 b / len(val_loader) * 100)
                print(s, end=len(s) * " " + "\r")
                if hyps['exp_name'] == "test" and b > 5:
                    break
            print()
            # Free per-batch tensors before checkpointing
            del targs
            del x
            del y
            del eq
            if model.transformer_type == models.DICTIONARY:
                del emb_targs
            torch.cuda.empty_cache()
        # Epoch-level validation statistics
        mask_val_loss = mask_avg_loss / len(val_loader)
        mask_val_acc = mask_avg_acc / len(val_loader)
        val_avg_loss = avg_loss / len(val_loader)
        val_avg_acc = avg_acc / len(val_loader)
        val_emb_loss = avg_emb_loss / len(val_loader)
        scheduler.step(val_avg_acc)
        val_avg_indy = avg_indy_acc / len(val_loader)
        stats_string += "Val - Loss:{:.5f} | Acc:{:.5f} | Indy:{:.5f}\n"
        stats_string = stats_string.format(val_avg_loss, val_avg_acc,
                                           val_avg_indy)
        if hyps['masking_task']:
            stats_string += "Val Mask Loss:{:.5f} | Val Mask Acc:{:.5f}\n"
            # NOTE(review): reports the summed mask_avg_loss/acc rather
            # than the per-batch means mask_val_loss/mask_val_acc computed
            # above — confirm whether this is intentional
            stats_string = stats_string.format(mask_avg_loss, mask_avg_acc)
        elif model.transformer_type == models.DICTIONARY:
            stats_string += "Val Emb Loss:{:.5f}\n"
            stats_string = stats_string.format(val_emb_loss)
        if words is not None:
            stats_string += "Word: " + word_samp + "\n"
        stats_string += "Quest: " + question + "\n"
        stats_string += "Targ: " + targ_samp + "\n"
        stats_string += "Pred: " + pred_samp + "\n"
        optimizer.zero_grad()

        # Checkpoint everything needed to resume or analyze this epoch
        save_dict = {
            "epoch": epoch,
            "hyps": hyps,
            "train_loss": train_avg_loss,
            "train_acc": train_avg_acc,
            "train_indy": train_avg_indy,
            "mask_train_loss": mask_train_loss,
            "mask_train_acc": mask_train_acc,
            "val_loss": val_avg_loss,
            "val_acc": val_avg_acc,
            "val_indy": val_avg_indy,
            "mask_val_loss": mask_val_loss,
            "mask_val_acc": mask_val_acc,
            "state_dict": model.state_dict(),
            "optim_dict": optimizer.state_dict(),
            "word2idx": train_data.word2idx,
            "idx2word": train_data.idx2word,
            "sampled_types": train_data.sampled_types
        }
        save_name = "checkpt"
        save_name = os.path.join(hyps['save_folder'], save_name)
        io.save_checkpt(save_dict, save_name, epoch, ext=".pt",
                        del_prev_sd=hyps['del_prev_sd'])
        stats_string += "Exec time: {}\n".format(time.time() - starttime)
        print(stats_string)
        s = "Epoch:{} | Model:{}\n".format(epoch, hyps['save_folder'])
        stats_string = s + stats_string
        log_file = os.path.join(hyps['save_folder'], "training_log.txt")
        with open(log_file, 'a') as f:
            f.write(str(stats_string) + '\n')
        # Strip heavy entries before the dict is returned below
        del save_dict['state_dict']
        del save_dict['optim_dict']
        del save_dict['hyps']
        save_dict['save_folder'] = hyps['save_folder']
    return save_dict
def train(gpu, hyps, verbose=True):
    """
    Full training entry point for one worker: builds the data loaders,
    model, optimizer, and scheduler, then runs train + validation epochs
    with periodic checkpointing. May run standalone or as one rank of a
    multi-gpu DDP group (NCCL backend).

    gpu: int
        the gpu for this training process
    hyps: dict
        contains all relevant hyperparameters
    verbose: bool
        if True (and this process is rank 0), progress is printed

    Returns:
        dict or None
            rank 0 returns the final epoch's stats dict with the heavy
            entries (state dicts, hyps) removed; other ranks — and runs
            where no checkpoint was ever built (test_batch_size runs, or
            resumed runs already past n_epochs) — return None.
    """
    rank = 0
    if hyps['multi_gpu']:
        # Global rank = node offset + local gpu index
        rank = hyps['n_gpus']*hyps['node_rank'] + gpu
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            world_size=hyps['world_size'],
            rank=rank)
    # Only rank 0 prints and saves
    verbose = verbose and rank==0
    hyps['rank'] = rank
    torch.cuda.set_device(gpu)
    test_batch_size = try_key(hyps,"test_batch_size",False)
    if test_batch_size and verbose:
        print("Testing batch size!! No saving will occur!")
    hyps['main_path'] = try_key(hyps,'main_path',"./")
    if "ignore_keys" not in hyps:
        hyps['ignore_keys'] = ["n_epochs", "batch_size", "max_context",
                               "rank", "n_loss_loops"]
    checkpt,hyps = get_resume_checkpt(hyps)
    if checkpt is None and rank==0:
        hyps['exp_num'] = get_exp_num(hyps['main_path'], hyps['exp_name'])
        hyps['save_folder'] = get_save_folder(hyps)
    if rank>0:
        # Non-zero ranks never save, so any folder name works
        hyps['save_folder'] = "placeholder"
    if not os.path.exists(hyps['save_folder']) and\
            not test_batch_size and rank==0:
        os.mkdir(hyps['save_folder'])

    # Set manual seed (offset by rank so workers draw different streams)
    hyps['seed'] = try_key(hyps, 'seed', int(time.time()))+rank
    torch.manual_seed(hyps['seed'])
    np.random.seed(hyps['seed'])

    hyps['MASK'] = MASK
    hyps['START'] = START
    hyps['STOP'] = STOP
    model_class = hyps['model_class']
    hyps['n_loss_loops'] = try_key(hyps,'n_loss_loops',1)
    if not hyps['init_decs'] and not hyps['ordered_preds'] and verbose:
        s = "WARNING!! You probably want to set ordered preds to True "
        s += "with your current configuration!!"
        print(s)

    if verbose:
        print("Retreiving Dataset")
    if "shuffle_split" not in hyps and hyps['shuffle']:
        hyps['shuffle_split'] = True
    train_data,val_data = datas.get_data(hyps)
    # Record sequence lengths and vocab sizes for model construction
    hyps['enc_slen'] = train_data.X.shape[-1]
    hyps['dec_slen'] = train_data.Y.shape[-1]
    hyps["mask_idx"] = train_data.X_tokenizer.token_to_id(MASK)
    hyps["dec_mask_idx"] = train_data.Y_tokenizer.token_to_id(MASK)
    hyps['n_vocab'] = train_data.X_tokenizer.get_vocab_size()
    hyps['n_vocab_out'] = train_data.Y_tokenizer.get_vocab_size()
    train_loader = datas.VariableLengthSeqLoader(train_data,
                                samples_per_epoch=1000,
                                shuffle=hyps['shuffle'])
    val_loader = datas.VariableLengthSeqLoader(val_data,
                                samples_per_epoch=50,
                                shuffle=True)

    if verbose:
        print("Making model")
    model = getattr(models,model_class)(**hyps)
    model.cuda(gpu)
    lossfxn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=hyps['lr'],
                                 weight_decay=hyps['l2'])
    if hyps['multi_gpu']:
        # O0 keeps full fp32; amp is initialized mainly for DDP interop
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level='O0')
        model = DDP(model)

    # Load State Dicts if Resuming Training
    if checkpt is not None:
        if verbose:
            print("Loading state dicts from", hyps['save_folder'])
        model.load_state_dict(checkpt["state_dict"])
        optimizer.load_state_dict(checkpt["optim_dict"])
        epoch = checkpt['epoch']
        if hyps['multi_gpu'] and "amp_dict" in checkpt:
            amp.load_state_dict(checkpt['amp_dict'])
    else:
        epoch = -1
    scheduler = custmods.VaswaniScheduler(optimizer, hyps['emb_size'])

    if verbose:
        print("Beginning training for {}".format(hyps['save_folder']))
        print("train shape:", (len(train_data),*train_data.X.shape[1:]))
        print("val shape:", (len(val_data),*val_data.X.shape[1:]))
    if not test_batch_size:
        record_session(hyps,model)
    if hyps['exp_name'] == "test":
        hyps['n_epochs'] = 2

    mask_idx = train_data.Y_mask_idx
    stop_idx = train_data.Y_stop_idx
    step_num = 0 if checkpt is None else try_key(checkpt,'step_num',0)
    # BUGFIX: initialize so the post-loop cleanup cannot raise NameError
    # when no checkpoint dict was ever built (test_batch_size runs, or
    # resuming a run that already reached n_epochs).
    save_dict = None
    if verbose:
        print()
    while epoch < hyps['n_epochs']:
        epoch += 1
        if verbose:
            print("Epoch:{} | Step: {} | Model:{}".format(epoch, step_num,
                                                    hyps['save_folder']))
            print("Training...")
        starttime = time.time()
        avg_loss = 0
        avg_acc = 0
        avg_indy_acc = 0
        model.train()
        optimizer.zero_grad()
        for b,(x,y) in enumerate(train_loader):
            if test_batch_size:
                # Stress memory with the largest batch the loader can make
                x,y = train_loader.get_largest_batch(b)
            torch.cuda.empty_cache()
            # Teacher forcing: predict y[t+1] from y[:t]
            targs = y.data[:,1:]
            og_shape = targs.shape
            y = y[:,:-1]
            logits = model(x.cuda(non_blocking=True),
                           y.cuda(non_blocking=True))
            preds = torch.argmax(logits,dim=-1)
            if epoch % 3 == 0 and b == 0 and verbose:
                # Periodically decode one sample for a sanity check
                inp = x.data[0].cpu().numpy()
                trg = targs.data[0].numpy()
                prd = preds.data[0].cpu().numpy()
                print("Inp:", train_data.X_idxs2tokens(inp))
                print("Targ:", train_data.Y_idxs2tokens(trg))
                print("Pred:", train_data.Y_idxs2tokens(prd))
            # Tot loss (masked positions are excluded)
            logits = logits.reshape(-1,logits.shape[-1])
            targs = targs.reshape(-1).cuda(non_blocking=True)
            bitmask = targs!=mask_idx
            loss = lossfxn(logits[bitmask],targs[bitmask])
            # Scale for gradient accumulation over n_loss_loops batches
            loss = loss/hyps['n_loss_loops']
            if hyps['multi_gpu']:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            # NOTE(review): steps at b==0 with a single accumulated batch;
            # (b+1) % n_loss_loops may have been intended — confirm.
            if b % hyps['n_loss_loops'] == 0 or b == len(train_loader)-1:
                optimizer.step()
                optimizer.zero_grad()
                step_num += 1
                scheduler.update_lr(step_num)
            with torch.no_grad():
                # Acc: indy is per-token, acc is whole-sequence exact match
                preds = preds.reshape(-1)
                sl = og_shape[-1]
                eq = (preds==targs).float()
                indy_acc = eq[bitmask].mean()
                eq[~bitmask] = 1
                eq = eq.reshape(og_shape)
                acc = (eq.sum(-1)==sl).float().mean()
                avg_acc += acc.item()
                avg_indy_acc += indy_acc.item()
                avg_loss += loss.item()
            s = "Loss:{:.5f} | Acc:{:.5f} | {:.0f}%"
            s = s.format(loss.item(), acc.item(),
                         b/len(train_loader)*100)
            if verbose:
                print(s, end=len(s)*" " + "\r")
            if hyps['exp_name'] == "test" and b>5:
                break
        optimizer.zero_grad()
        train_avg_loss = avg_loss/len(train_loader)
        train_avg_acc = avg_acc/len(train_loader)
        train_avg_indy = avg_indy_acc/len(train_loader)
        s = "Ending Step Count: {}\n".format(step_num)
        s = s+"Train - Loss:{:.5f} | Acc:{:.5f} | Indy:{:.5f}\n"
        stats_string = s.format(train_avg_loss, train_avg_acc,
                                train_avg_indy)

        ###### VALIDATION (rank 0 only)
        model.eval()
        avg_bleu = 0
        avg_loss = 0
        avg_acc = 0
        avg_indy_acc = 0
        if verbose:
            print("\nValidating...")
        torch.cuda.empty_cache()
        if rank==0:
            with torch.no_grad():
                # Pick one random batch to sample a decoded example from
                rand_word_batch = int(np.random.randint(0,
                                                len(val_loader)))
                for b,(x,y) in enumerate(val_loader):
                    if test_batch_size:
                        x,y = val_loader.get_largest_batch(b)
                    targs = y.data[:,1:]
                    og_shape = targs.shape
                    y = y[:,:-1]
                    preds = model(x.cuda(non_blocking=True),
                                  y.cuda(non_blocking=True))
                    # Tot loss and acc
                    preds = preds.reshape(-1,preds.shape[-1])
                    targs = targs.reshape(-1).cuda(non_blocking=True)
                    bitmask = targs!=mask_idx
                    loss = lossfxn(preds[bitmask],targs[bitmask])
                    if hyps['multi_gpu']:
                        loss = loss.mean()
                    preds = torch.argmax(preds,dim=-1)
                    sl = int(og_shape[-1])
                    eq = (preds==targs).float()
                    indy_acc = eq[bitmask].mean()
                    eq[~bitmask] = 1
                    eq = eq.reshape(og_shape)
                    acc = (eq.sum(-1)==sl).float().mean()
                    # BLEU: truncate targets/preds at the first STOP token
                    bleu_trgs = targs.reshape(og_shape).data.cpu().numpy()
                    trg_ends = np.argmax((bleu_trgs==stop_idx),axis=1)
                    bleu_prds = preds.reshape(og_shape).data.cpu().numpy()
                    prd_ends = np.argmax((bleu_prds==stop_idx),axis=1)
                    btrgs = []
                    bprds = []
                    for i in range(len(bleu_trgs)):
                        temp = bleu_trgs[i,None,:trg_ends[i]].tolist()
                        btrgs.append(temp)
                        bprds.append(bleu_prds[i,:prd_ends[i]].tolist())
                    bleu = corpus_bleu(btrgs,bprds)
                    avg_bleu += bleu
                    avg_acc += acc.item()
                    avg_indy_acc += indy_acc.item()
                    avg_loss += loss.item()
                    if b == rand_word_batch or hyps['exp_name']=="test":
                        # Decode one random row for the epoch report
                        rand = int(np.random.randint(0,len(x)))
                        inp = x.data[rand].cpu().numpy()
                        inp_samp = val_data.X_idxs2tokens(inp)
                        trg = targs.reshape(og_shape)[rand].data.cpu()
                        targ_samp = val_data.Y_idxs2tokens(trg.numpy())
                        prd = preds.reshape(og_shape)[rand].data.cpu()
                        pred_samp = val_data.Y_idxs2tokens(prd.numpy())
                    s = "Loss:{:.5f} | Acc:{:.5f} | Bleu:{:.5f} | {:.0f}%"
                    s = s.format(loss.item(), acc.item(), bleu,
                                 b/len(val_loader)*100)
                    if verbose:
                        print(s, end=len(s)*" " + "\r")
                    if hyps['exp_name']=="test" and b > 5:
                        break
                if verbose:
                    print()
                val_avg_bleu = avg_bleu/len(val_loader)
                val_avg_loss = avg_loss/len(val_loader)
                val_avg_acc = avg_acc/len(val_loader)
                val_avg_indy = avg_indy_acc/len(val_loader)
                stats_string += "Val- Loss:{:.5f} | Acc:{:.5f} | "
                stats_string += "Indy:{:.5f}\nVal Bleu: {:.5f}\n"
                stats_string = stats_string.format(val_avg_loss,
                                                   val_avg_acc,
                                                   val_avg_indy,
                                                   val_avg_bleu)
                stats_string += "Inp: " + inp_samp + "\n"
                stats_string += "Targ: " + targ_samp + "\n"
                stats_string += "Pred: " + pred_samp + "\n"
        optimizer.zero_grad()
        if not test_batch_size and rank==0:
            save_dict = {
                "epoch":epoch,
                "step_num":step_num,
                "hyps":hyps,
                "train_loss":train_avg_loss,
                "train_acc":train_avg_acc,
                "train_indy":train_avg_indy,
                "val_bleu":val_avg_bleu,
                "val_loss":val_avg_loss,
                "val_acc":val_avg_acc,
                "val_indy":val_avg_indy,
                "state_dict":model.state_dict(),
                "optim_dict":optimizer.state_dict(),
            }
            if hyps['multi_gpu']:
                save_dict['amp_dict'] = amp.state_dict()
            save_name = "checkpt"
            save_name = os.path.join(hyps['save_folder'],save_name)
            io.save_checkpt(save_dict, save_name, epoch, ext=".pt",
                            del_prev_sd=hyps['del_prev_sd'])
        stats_string += "Exec time: {}\n".format(time.time()-starttime)
        if verbose:
            print(stats_string)
        s = "Epoch:{} | Model:{}\n".format(epoch, hyps['save_folder'])
        stats_string = s + stats_string
        # Each rank appends to its own log file
        log_file = os.path.join(hyps['save_folder'],
                                "training_log"+str(rank)+".txt")
        if not test_batch_size:
            with open(log_file,'a') as f:
                f.write(str(stats_string)+'\n')
    # BUGFIX: only strip/return the stats dict if one was actually built;
    # previously this raised NameError for rank-0 test_batch_size runs
    # and for resumed runs that were already past n_epochs.
    if rank==0 and save_dict is not None:
        del save_dict['state_dict']
        del save_dict['optim_dict']
        del save_dict['hyps']
        save_dict['save_folder'] = hyps['save_folder']
        return save_dict
    return None