Beispiel #1
0
    def run(self, net):
        """
        Entry point for a data-collection worker process.

        Builds this worker's environment, primes the initial state, then
        loops: wait on gate_q for a rollout index from the main process,
        collect that rollout, and acknowledge completion on stop_q. The
        loop exits once the main process puts anything on end_q.

        net - torch Module object. The model used to interact with the
            environment (shared with the main process).
        """
        self.net = net
        self.env = SeqEnv(self.hyps['env_type'],
                          self.hyps['seed'],
                          worker_id=None,
                          float_params=try_key(self.hyps,
                                               "float_params",
                                               dict()))
        # Prime the observation deque and remember the starting state so
        # rollouts can resume where the previous one left off.
        self.state_bookmark = next_state(self.env,
                                         self.obs_deque,
                                         obs=None,
                                         reset=True,
                                         preprocess=self.hyps['preprocess'])
        self.ep_rew = 0
        # Eval mode avoids batchnorm statistic updates; gradients are
        # never needed in a collection worker.
        self.net.train(mode=False)
        for param in self.net.parameters():
            param.requires_grad = False
        with torch.no_grad():
            while self.end_q.empty():
                # Blocks until the main process opens the gate
                rollout_idx = self.gate_q.get()
                self.rollout(self.net, rollout_idx, self.hyps)
                # Signal the main process that the data is ready
                self.stop_q.put(rollout_idx)
Beispiel #2
0
 def __init__(self, hyps):
     """
     Store the hyperparameters and construct the evaluation environment.

     hyps - dict of hyperparameters (string keys). Expected entries:
         "n_tsteps" - number of steps to take in the environment
         "n_frame_stack" - number of frames stacked to form one
                     mdp state
         "preprocessor" - callable that preprocesses raw observations
         "env_type" - gym environment id (follows OpenAI's gym api)
         "seed" - random seed for the environment
     """
     self.hyps = hyps
     # Default to 15 evaluation episodes when "n_test_eps" is not argued
     self.n_episodes = try_key(hyps, "n_test_eps", 15)
     # Frame stack buffer; old frames fall off automatically
     self.obs_deque = deque(maxlen=hyps['n_frame_stack'])
     self.env = SequentialEnvironment(**hyps)
Beispiel #3
0
def train(_, hyps, verbose=True):
    """
    Runs the full training loop: spawns rollout-collection worker
    processes, repeatedly updates the network from the shared rollout
    data, evaluates after each update, and logs/saves along the way.

    _ - unused positional argument (kept so existing callers work)
    hyps - dictionary of required hyperparameters
        type: dict
    verbose - unused in this implementation; kept for interface
        compatibility

    Returns:
        float - the best evaluation reward observed during training
    """
    # Hyperparam corrections
    if isinstance(try_key(hyps, "grid_size", None), int):
        hyps['grid_size'] = [hyps['grid_size'], hyps['grid_size']]

    # Preprocessor Type
    hyps['preprocessor'] = getattr(preprocessing, hyps['prep_fxn'])

    hyps['main_path'] = try_key(hyps, "main_path", "./")
    hyps['exp_num'] = get_exp_num(hyps['main_path'], hyps['exp_name'])
    hyps['save_folder'] = get_save_folder(hyps)
    save_folder = hyps['save_folder']
    if not os.path.exists(hyps['save_folder']):
        os.mkdir(hyps['save_folder'])
    hyps['seed'] = try_key(hyps, 'seed', int(time.time()))
    torch.manual_seed(hyps['seed'])
    np.random.seed(hyps['seed'])

    net_save_file = os.path.join(save_folder, "net.p")
    best_net_file = os.path.join(save_folder, "best_net.p")
    optim_save_file = os.path.join(save_folder, "optim.p")
    log_file = os.path.join(save_folder, "log.txt")

    # Append to the existing log when resuming a run. try_key keeps this
    # consistent with the resume check used when loading the net below.
    if try_key(hyps, 'resume', False): log = open(log_file, 'a')
    else: log = open(log_file, 'w')
    keys = sorted(list(hyps.keys()))
    for k in keys:
        log.write(k + ":" + str(hyps[k]) + "\n")

    # Miscellaneous Variable Prep
    logger = Logger()
    shared_len = hyps['n_tsteps'] * hyps['n_rollouts']
    stats_runner = StatsRunner(hyps)
    env = stats_runner.env
    hyps['is_discrete'] = env.is_discrete
    obs = env.reset()
    hyps['state_shape'] = [hyps['n_frame_stack']] + [*obs.shape[1:]]
    if hyps['env_type'] == "Pong-v0":
        # Pong uses only 3 effective actions; the shift maps net outputs
        # onto the env's action ids
        action_size = 3
        hyps['action_shift'] = 1
    else:
        action_size = env.n
        hyps['action_shift'] = 0
    print("Raw Obs Shape:", env.raw_shape)
    print("Obs Shape:,", obs.shape)
    print("State Shape:,", hyps['state_shape'])
    print("Num Samples Per Update:", shared_len)

    # Make Network
    hyps["action_size"] = action_size
    net = globals()[hyps['model']](hyps['state_shape'],
                                   action_size,
                                   bnorm=hyps['use_bnorm'],
                                   **hyps)
    if try_key(hyps, 'resume', False):
        net.load_state_dict(torch.load(net_save_file))
    # base_net is the copy the Updater trains; net is shared with the
    # collection processes and refreshed after every update
    base_net = copy.deepcopy(net)
    net = cuda_if(net)
    net.share_memory()
    base_net = cuda_if(base_net)

    # Prepare Shared Variables
    states_shape = (shared_len, *hyps['state_shape'])
    if env.is_discrete:
        actions = torch.zeros(shared_len).long()
    else:
        actions = torch.zeros((shared_len, env.n)).float()
    shared_data = {
        'states': cuda_if(torch.zeros(states_shape).share_memory_()),
        'deltas': cuda_if(torch.zeros(shared_len).share_memory_()),
        'rewards': cuda_if(torch.zeros(shared_len).share_memory_()),
        'actions': actions.share_memory_(),
        'dones': cuda_if(torch.zeros(shared_len).share_memory_())
    }
    if net.is_recurrent:
        zeros = torch.zeros(shared_len, net.h_size)
        shared_data['h_states'] = cuda_if(zeros.share_memory_())
    n_rollouts = hyps['n_rollouts']
    gate_q = mp.Queue(n_rollouts)
    stop_q = mp.Queue(n_rollouts)
    reward_q = mp.Queue(1)
    reward_q.put(-1)

    # Make Runners
    runners = []
    for i in range(hyps['n_envs']):
        runner = Runner(shared_data, hyps, gate_q, stop_q, reward_q)
        runners.append(runner)

    # Start Data Collection
    print("Making New Processes")
    procs = []
    for i in range(len(runners)):
        proc = mp.Process(target=runners[i].run, args=(net, ))
        procs.append(proc)
        proc.start()
        print(i, "/", len(runners), end='\r')
    # Open the gate once per rollout so the workers begin collecting
    for i in range(n_rollouts):
        gate_q.put(i)

    # Make Updater
    updater = Updater(base_net, hyps)
    if try_key(hyps, 'resume', False):
        updater.optim.load_state_dict(torch.load(optim_save_file))
    updater.optim.zero_grad()
    updater.net.train(mode=True)
    updater.net.req_grads(True)

    # Prepare Decay Precursors
    entr_coef_diff = hyps['entr_coef'] - hyps['entr_coef_low']
    lr_diff = hyps['lr'] - hyps['lr_low']

    # Training Loop
    past_rews = deque([0] * hyps['n_past_rews'])
    last_avg_rew = 0
    best_eval_rew = -np.inf
    epoch = 0
    T = 0
    while T < hyps['max_tsteps']:
        basetime = time.time()
        epoch += 1
        stats_string = ""

        # Collect data (block until every rollout signals completion)
        for i in range(n_rollouts):
            stop_q.get()
        T += shared_len
        s = "Epoch {} - T: {} -- {}".format(epoch, T, hyps['save_folder'])
        print(s)
        stats_string += s + "\n"

        # Reward Stats (peek the running average without consuming it)
        avg_reward = reward_q.get()
        reward_q.put(avg_reward)
        last_avg_rew = avg_reward

        # Calculate the Loss and Update nets
        updater.update_model(shared_data)
        # update all collector nets
        net.load_state_dict(updater.net.state_dict())

        eval_rew = stats_runner.rollout(net)
        if eval_rew > best_eval_rew:
            best_eval_rew = eval_rew
            updater.save_model(best_net_file, None)
        stats_string += "Eval rew: {}\n".format(eval_rew)

        # Resume Data Collection
        for i in range(n_rollouts):
            gate_q.put(i)

        # Decay HyperParameters (linear anneal over max_tsteps)
        if hyps['decay_lr']:
            decay_factor = max((1 - T / (hyps['max_tsteps'])), 0)
            new_lr = decay_factor * lr_diff + hyps['lr_low']
            updater.new_lr(new_lr)
            s = "New lr: " + str(new_lr)
            print(s)
            stats_string += s + "\n"
        if hyps['decay_entr']:
            decay_factor = max((1 - T / (hyps['max_tsteps'])), 0)
            updater.entr_coef = entr_coef_diff * decay_factor
            updater.entr_coef += hyps['entr_coef_low']
            s = "New Entr: " + str(updater.entr_coef)
            print(s)
            stats_string += s + "\n"

        # Periodically save model
        if epoch % 10 == 0:
            updater.save_model(net_save_file, optim_save_file)

        # Print Epoch Data
        past_rews.popleft()
        past_rews.append(avg_reward)
        max_rew, min_rew = deque_maxmin(past_rews)
        rew_avg, rew_std = np.mean(past_rews), np.std(past_rews)
        updater.print_statistics()
        avg_action = shared_data['actions'].float().mean().item()
        s = "Grad Norm: {:.5f} – Avg Action: {:.5f} - Best EvalRew: {:.5f}"
        s = s.format(float(updater.norm), avg_action, best_eval_rew)
        stats_string += s + "\n"
        stats_string += "Avg Rew: " + str(avg_reward) + "\n"
        s = "Past " + str(hyps['n_past_rews']) + " Rews – High: {:.5f}"
        s += " - Low: {:.5f} - Avg: {:.5f} - StD: {:.5f}"
        stats_string += s.format(max_rew, min_rew, rew_avg, rew_std) + "\n"
        updater.log_statistics(log, T, avg_reward, avg_action, best_eval_rew)
        updater.info["EvalRew"] = eval_rew
        updater.info['AvgRew'] = avg_reward
        logger.append(updater.info, x_val=T)

        # Check for memory leaks
        gc.collect()
        max_mem_used = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        s = "Time: " + str(time.time() - basetime)
        stats_string += s + "\n"
        print(stats_string)
        log.write(stats_string + "\n")
        if 'hyp_search_count' in hyps and hyps['hyp_search_count'] > 0\
                                         and hyps['search_id'] != None:
            print("Search:", hyps['search_id'], "/", hyps['hyp_search_count'])
        print("Memory Used: {:.2f} memory\n".format(max_mem_used / 1024))

    logger.make_plots(save_folder + hyps['exp_name'])
    log.write("\nBestRew:" + str(best_eval_rew))
    log.close()

    # Close processes
    for p in procs:
        p.terminate()

    return best_eval_rew
    def train(self, hyps, verbose=True):
        """
        Runs the full curiosity-driven training loop: builds the policy,
        forward-dynamics, and optional inverse/reconstruction nets,
        spawns rollout-collection workers, then alternates between
        updating the models and collecting fresh data.

        hyps - dictionary of required hyperparameters
            type: dict
        verbose - if True, prints extra settings information. New
            keyword with a default, so existing callers are unaffected.
            (Previously `verbose` was referenced without ever being
            defined, raising a NameError when 'audibleTargs' was argued.)

        Returns the best average reward observed during training.
        """

        # Initial settings
        if "randomizeObjs" in hyps:
            assert False, "you mean randomizeObs, not randomizeObjs"
        if "audibleTargs" in hyps and hyps['audibleTargs'] > 0:
            hyps['aud_targs'] = True
            if verbose: print("Using audible targs!")
        countOut = try_key(hyps, 'countOut', 0)
        if countOut and not hyps['endAtOrigin']:
            assert False, "endAtOrigin must be true for countOut setting"

        # Print Hyperparameters To Screen
        items = list(hyps.items())
        for k, v in sorted(items):
            print(k+":", v)

        # Make Save Files
        if "save_folder" in hyps:
            save_folder = hyps['save_folder']
        else:
            save_folder = "./saved_data/"

        if not os.path.exists(save_folder):
            os.mkdir(save_folder)
        base_name = save_folder + hyps['exp_name']
        net_save_file = base_name+"_net.p"
        fwd_save_file = base_name+"_fwd.p"
        best_net_file = base_name+"_best.p"
        optim_save_file = base_name+"_optim.p"
        fwd_optim_file = base_name+"_fwdoptim.p"
        hyps['fwd_emb_file'] = base_name+"_fwdemb.p"
        if hyps['inv_model'] is not None:
            inv_save_file = base_name+"_invnet.p"
            reconinv_optim_file = base_name+"_reconinvoptim.p"
        else:
            inv_save_file = None
            reconinv_optim_file = None
        if hyps['recon_model'] is not None:
            recon_save_file = base_name+"_reconnet.p"
            reconinv_optim_file = base_name+"_reconinvoptim.p"
        else:
            recon_save_file = None
        log_file = base_name+"_log.txt"
        if hyps['resume']: log = open(log_file, 'a')
        else: log = open(log_file, 'w')
        for k, v in sorted(items):
            log.write(k+":"+str(v)+"\n")

        # Miscellaneous Variable Prep
        logger = Logger()
        shared_len = hyps['n_tsteps']*hyps['n_rollouts']
        if "float_params" in hyps:
            # Use the argued float_params. (Previously an explicitly
            # argued "float_params" entry was silently ignored and an
            # empty dict was handed to the environment.)
            float_params = hyps['float_params']
        else:
            float_params = dict()
            try:
                keys = hyps['game_keys']
                hyps['float_params'] = {k:try_key(hyps,k,0) for k in keys}
                if "minObjLoc" not in hyps:
                    hyps['float_params']["minObjLoc"] = 0.27
                    hyps['float_params']["maxObjLoc"] = 0.73
                float_params = hyps['float_params']
            except KeyError:
                # no 'game_keys' argued; fall back to an empty dict
                pass
        env = SeqEnv(hyps['env_type'], hyps['seed'],
                                            worker_id=None,
                                            float_params=float_params)
        hyps['discrete_env'] = hasattr(env.action_space, "n")
        obs = env.reset()
        prepped = hyps['preprocess'](obs)
        hyps['state_shape'] = [hyps['n_frame_stack']*prepped.shape[0],
                              *prepped.shape[1:]]
        if not hyps['discrete_env']:
            action_size = int(np.prod(env.action_space.shape))
        elif hyps['env_type'] == "Pong-v0":
            # Pong only uses 3 effective actions
            action_size = 3
        else:
            action_size = env.action_space.n
        hyps['action_shift'] = (4-action_size)*(hyps['env_type']=="Pong-v0")
        print("Obs Shape:,",obs.shape)
        print("Prep Shape:,",prepped.shape)
        print("State Shape:,",hyps['state_shape'])
        print("Num Samples Per Update:", shared_len)
        if not (hyps['n_cache_refresh'] <= shared_len or hyps['cache_size'] == 0):
            hyps['n_cache_refresh'] = shared_len
        print("Samples Wasted in Update:", shared_len % hyps['batch_size'])
        # Best-effort close; this env existed only to probe shapes
        try: env.close()
        except Exception: pass
        del env

        # Prepare Shared Variables
        shared_data = {
            'states': torch.zeros(shared_len,
                                *hyps['state_shape']).share_memory_(),
            'next_states': torch.zeros(shared_len,
                                *hyps['state_shape']).share_memory_(),
            'dones':torch.zeros(shared_len).share_memory_(),
            'rews':torch.zeros(shared_len).share_memory_(),
            'hs':torch.zeros(shared_len,hyps['h_size']).share_memory_(),
            'next_hs':torch.zeros(shared_len,hyps['h_size']).share_memory_()}
        if hyps['discrete_env']:
            shared_data['actions'] = torch.zeros(shared_len).long().share_memory_()
        else:
            shape = (shared_len, action_size)
            shared_data['actions']=torch.zeros(shape).float().share_memory_()
        shared_data = {k: cuda_if(v) for k,v in shared_data.items()}
        n_rollouts = hyps['n_rollouts']
        gate_q = mp.Queue(n_rollouts)
        stop_q = mp.Queue(n_rollouts)
        end_q = mp.Queue(1)
        reward_q = mp.Queue(1)
        reward_q.put(-1)

        # Make Runners
        runners = []
        for i in range(hyps['n_envs']):
            runner = Runner(shared_data, hyps, gate_q, stop_q,
                                                       end_q,
                                                       reward_q)
            runners.append(runner)

        # Make the Networks
        h_size = hyps['h_size']
        net = hyps['model'](hyps['state_shape'], action_size, h_size,
                                            bnorm=hyps['use_bnorm'],
                                            lnorm=hyps['use_lnorm'],
                                            discrete_env=hyps['discrete_env'])
        # Fwd Dynamics
        hyps['is_recurrent'] = hasattr(net, "fresh_h")
        intl_size = h_size+action_size + hyps['is_recurrent']*h_size
        block = []
        if hyps['fwd_lnorm']:
            # Previously the LayerNorm was discarded by an unconditional
            # reassignment of `block`; it is now prepended as intended.
            block.append(nn.LayerNorm(intl_size))
        block += [nn.Linear(intl_size, h_size),
            nn.ReLU(), nn.Linear(h_size, h_size),
            nn.ReLU(), nn.Linear(h_size, h_size)]
        fwd_net = nn.Sequential(*block)
        # Allows us to argue an h vector along with embedding to
        # forward func
        if hyps['is_recurrent']:
            fwd_net = CatModule(fwd_net)
        if hyps['ensemble']:
            fwd_net = Ensemble(fwd_net)
        fwd_net = cuda_if(fwd_net)

        if hyps['inv_model'] is not None:
            inv_net = hyps['inv_model'](h_size, action_size)
            inv_net = cuda_if(inv_net)
        else:
            inv_net = None
        if hyps['recon_model'] is not None:
            recon_net = hyps['recon_model'](emb_size=h_size,
                                       img_shape=hyps['state_shape'],
                                       fwd_bnorm=hyps['fwd_bnorm'],
                                       deconv_ksizes=hyps['recon_ksizes'])
            recon_net = cuda_if(recon_net)
        else:
            recon_net = None
        if hyps['resume']:
            net.load_state_dict(torch.load(net_save_file))
            fwd_net.load_state_dict(torch.load(fwd_save_file))
            if inv_net is not None:
                inv_net.load_state_dict(torch.load(inv_save_file))
            if recon_net is not None:
                recon_net.load_state_dict(torch.load(recon_save_file))
        # base_net is the copy the Updater trains; net is shared with
        # the collection processes and refreshed after every update
        base_net = copy.deepcopy(net)
        net = cuda_if(net)
        net.share_memory()
        base_net = cuda_if(base_net)
        hyps['is_recurrent'] = hasattr(net, "fresh_h")

        # Start Data Collection
        print("Making New Processes")
        procs = []
        for i in range(len(runners)):
            proc = mp.Process(target=runners[i].run, args=(net,))
            procs.append(proc)
            proc.start()
            print(i, "/", len(runners), end='\r')
        # Open the gate once per rollout so workers begin collecting
        for i in range(n_rollouts):
            gate_q.put(i)

        # Make Updater
        updater = Updater(base_net, fwd_net, hyps, inv_net, recon_net)
        if hyps['resume']:
            updater.optim.load_state_dict(torch.load(optim_save_file))
            updater.fwd_optim.load_state_dict(torch.load(fwd_optim_file))
            if inv_net is not None:
                updater.reconinv_optim.load_state_dict(torch.load(reconinv_optim_file))
        updater.optim.zero_grad()
        updater.net.train(mode=True)
        updater.net.req_grads(True)

        # Prepare Decay Precursors
        entr_coef_diff = hyps['entr_coef'] - hyps['entr_coef_low']
        epsilon_diff = hyps['epsilon'] - hyps['epsilon_low']
        lr_diff = hyps['lr'] - hyps['lr_low']
        gamma_diff = hyps['gamma_high'] - hyps['gamma']

        # Training Loop
        past_rews = deque([0]*hyps['n_past_rews'])
        last_avg_rew = 0
        best_rew_diff = 0
        best_avg_rew = -10000
        best_eval_rew = -10000
        ep_eval_rew = 0
        eval_rew = 0

        epoch = 0
        done_count = 0
        T = 0
        try:
            while T < hyps['max_tsteps']:
                basetime = time.time()
                epoch += 1

                # Collect data (block until every rollout completes)
                for i in range(n_rollouts):
                    stop_q.get()
                T += shared_len

                # Reward Stats (peek the running average)
                avg_reward = reward_q.get()
                reward_q.put(avg_reward)
                last_avg_rew = avg_reward
                done_count += shared_data['dones'].sum().item()
                new_best = False
                if avg_reward > best_avg_rew and done_count > n_rollouts:
                    new_best = True
                    best_avg_rew = avg_reward
                    updater.save_model(best_net_file, fwd_save_file,
                                                         None, None)
                eval_rew = shared_data['rews'].mean()
                if eval_rew > best_eval_rew:
                    best_eval_rew = eval_rew
                    save_names = [net_save_file, fwd_save_file,
                                                 optim_save_file,
                                                 fwd_optim_file,
                                                 inv_save_file,
                                                 recon_save_file,
                                                 reconinv_optim_file]

                    # Insert "_best" before each file extension
                    for i in range(len(save_names)):
                        if save_names[i] is not None:
                            splt = save_names[i].split(".")
                            splt[0] = splt[0]+"_best"
                            save_names[i] = ".".join(splt)
                    updater.save_model(*save_names)
                s = "EvalRew: {:.5f} | BestEvalRew: {:.5f}"
                print(s.format(eval_rew, best_eval_rew))

                # Calculate the Loss and Update nets
                updater.update_model(shared_data)
                net.load_state_dict(updater.net.state_dict()) # update all collector nets

                # Resume Data Collection
                for i in range(n_rollouts):
                    gate_q.put(i)

                # Decay HyperParameters (linear anneal over max_tsteps)
                if hyps['decay_eps']:
                    updater.epsilon = (1-T/(hyps['max_tsteps']))*epsilon_diff + hyps['epsilon_low']
                    print("New Eps:", updater.epsilon)
                if hyps['decay_lr']:
                    new_lr = (1-T/(hyps['max_tsteps']))*lr_diff + hyps['lr_low']
                    updater.new_lr(new_lr)
                    print("New lr:", new_lr)
                if hyps['decay_entr']:
                    updater.entr_coef = entr_coef_diff*(1-T/(hyps['max_tsteps']))+hyps['entr_coef_low']
                    print("New Entr:", updater.entr_coef)
                if hyps['incr_gamma']:
                    updater.gamma = gamma_diff*(T/(hyps['max_tsteps']))+hyps['gamma']
                    print("New Gamma:", updater.gamma)

                # Periodically save model
                if epoch % 10 == 0 or epoch == 1:
                    updater.save_model(net_save_file, fwd_save_file,
                                                      optim_save_file,
                                                      fwd_optim_file,
                                                      inv_save_file,
                                                      recon_save_file,
                                                      reconinv_optim_file)

                # Print Epoch Data
                past_rews.popleft()
                past_rews.append(avg_reward)
                max_rew, min_rew = deque_maxmin(past_rews)
                print("Epoch", epoch, "– T =", T, "-- Folder:", base_name)
                if not hyps['discrete_env']:
                    s = ("{:.5f} | "*net.logsigs.shape[1])
                    s = s.format(*[x.item() for x in torch.exp(net.logsigs[0])])
                    print("Sigmas:", s)
                updater.print_statistics()
                avg_action = shared_data['actions'].float().mean().item()
                print("Grad Norm:",float(updater.norm),"– Avg Action:",avg_action,"– Best AvgRew:",best_avg_rew)
                print("Avg Rew:", avg_reward, "– High:", max_rew, "– Low:", min_rew, end='\n')
                updater.log_statistics(log, T, avg_reward, avg_action, best_avg_rew)
                updater.info['AvgRew'] = avg_reward
                updater.info['EvalRew'] = eval_rew
                logger.append(updater.info, x_val=T)

                # Check for memory leaks
                gc.collect()
                max_mem_used = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
                print("Time:", time.time()-basetime)
                if 'hyp_search_count' in hyps and hyps['hyp_search_count'] > 0 and hyps['search_id'] != None:
                    print("Search:", hyps['search_id'], "/", hyps['hyp_search_count'])
                print("Memory Used: {:.2f} memory\n".format(max_mem_used / 1024))
                # Abort on divergence
                if updater.info["VLoss"] == float('inf') or updater.norm == float('inf'):
                    break
        except KeyboardInterrupt:
            pass

        # Tell workers to exit, then give them a moment before cleanup
        end_q.put(1)
        time.sleep(1)
        logger.make_plots(base_name)
        log.write("\nBestRew:"+str(best_avg_rew))
        log.close()
        # Close processes
        for p in procs:
            p.terminate()
        return best_avg_rew
Beispiel #5
0
# Build the model from the hyperparameters and restore its weights.
print("Making model")
model = getattr(models,hyps['model_class'])(**hyps)
model.to(DEVICE)
try:
    model.load_state_dict(checkpt["state_dict"])
except RuntimeError:
    # State-dict key mismatch: strip the "pavlov" parameters and retry.
    # (Narrowed from a bare `except:`, which also swallowed
    # KeyboardInterrupt and unrelated errors.)
    keys = list(checkpt['state_dict'].keys())
    for k in keys:
        if "pavlov" == k.split(".")[0]:
            del checkpt['state_dict'][k]
    model.load_state_dict(checkpt['state_dict'])
    print("state dict success")
model.eval()

# The forward-dynamics model is only loaded when both flags are argued;
# otherwise a dummy stand-in is used.
fwd_dynamics = try_key(hyps,'use_fwd_dynamics',False) and\
               try_key(hyps,"countOut",False)
if fwd_dynamics:
    fwd_model = getattr(models,hyps['fwd_class'])(**hyps)
    fwd_model.cuda()
    fwd_model.load_state_dict(checkpt['fwd_state_dict'])
    fwd_model.eval()
else:
    fwd_model = DummyFwdModel()

# Rollout accumulators
done = True
sum_rew = 0
n_loops = 0
frames = []
obscatpreds = []
with torch.no_grad():
import locgame.save_io as locio
import ml_utils.save_io as mlio
import ml_utils.analysis as mlanl
from ml_utils.utils import try_key
import pandas as pd
import os
import sys

if __name__ == "__main__":
    # Collect model folders from the command line, expanding any parent
    # directories into the model folders they contain.
    argued_folders = sys.argv[1:]
    model_folders = []
    for folder in argued_folders:
        if not mlio.is_model_folder(folder):
            model_folders += mlio.get_model_folders(folder, True)
        else:
            model_folders += [folder]
    print("Model Folders:", model_folders)
    for model_folder in model_folders:
        checkpts = mlio.get_checkpoints(model_folder)
        if len(checkpts) == 0: continue
        # Seed the table columns from the first checkpoint
        table = mlanl.get_table(mlio.load_checkpoint(checkpts[0]))
        for checkpt in checkpts:
            chkpt = mlio.load_checkpoint(checkpt)
            for k in table.keys():
                # direct membership test (was `k in set(chkpt.keys())`,
                # which built a throwaway set on every lookup)
                if k in chkpt:
                    table[k].append(chkpt[k])
        df = pd.DataFrame(table)
        # NOTE: `chkpt` here is the last checkpoint loaded in the loop
        # above, i.e. the seed comes from the final checkpoint's hyps
        df['seed'] = try_key(chkpt['hyps'], 'seed', -1)
        save_path = os.path.join(model_folder, "model_data.csv")
        df.to_csv(save_path, sep="!", index=False, header=True)
Beispiel #7
0
def train(hyps, verbose=True):
    """
    Run the full training procedure for a transformer model: dataset
    construction, model/optimizer setup, an epoch loop of training and
    validation, and per-epoch checkpointing and logging.

    hyps: dict
        contains all relavent hyperparameters (model_class, lr, l2,
        batch_size, shuffle, masking_task, n_loss_loops, exp_name, ...)
    verbose: bool
        if True, progress information is printed to stdout

    Returns:
        save_dict: dict
            the final epoch's recorded stats with the state dicts and
            hyps removed and the save folder path added
    """
    # Set up save folder and seed every RNG for reproducibility
    hyps['exp_num'] = get_exp_num(hyps['main_path'], hyps['exp_name'])
    hyps['save_folder'] = get_save_folder(hyps)
    if not os.path.exists(hyps['save_folder']):
        os.mkdir(hyps['save_folder'])
    hyps['seed'] = try_key(hyps, 'seed', int(time.time()))
    torch.manual_seed(hyps['seed'])
    np.random.seed(hyps['seed'])
    model_class = hyps['model_class']
    hyps['model_type'] = models.TRANSFORMER_TYPE[model_class]
    if not hyps['init_decs'] and not hyps['gen_decs'] and\
                                not hyps['ordered_preds']:
        s = "WARNING!! You probably want to set ordered preds to True "
        s += "with your current configuration!!"
        print(s)

    if verbose:
        print("Retreiving Dataset")
    if "shuffle_split" not in hyps and hyps['shuffle']:
        hyps['shuffle_split'] = True
    train_data, val_data = datas.get_data(**hyps)
    # Sequence lengths come from the data itself; the decoder length
    # drops one step because targets are shifted by one token
    hyps['enc_slen'] = train_data.X.shape[-1]
    hyps['dec_slen'] = train_data.Y.shape[-1] - 1
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=hyps['batch_size'],
                                               shuffle=hyps['shuffle'])
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=hyps['batch_size'])
    hyps['n_vocab'] = len(train_data.word2idx.keys())

    if verbose:
        print("Making model")
    model = getattr(models, model_class)(**hyps)
    model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=hyps['lr'],
                                 weight_decay=hyps['l2'])
    # Optionally warm-start from an explicit checkpoint
    init_checkpt = try_key(hyps, "init_checkpt", None)
    if init_checkpt is not None and init_checkpt != "":
        if verbose:
            print("Loading state dicts from", init_checkpt)
        checkpt = io.load_checkpoint(init_checkpt)
        model.load_state_dict(checkpt["state_dict"])
        optimizer.load_state_dict(checkpt["optim_dict"])
    lossfxn = nn.CrossEntropyLoss()
    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  factor=0.5,
                                  patience=6,
                                  verbose=True)
    # The embedding loss term only applies to dictionary transformers
    if model.transformer_type != models.DICTIONARY:
        hyps['emb_alpha'] = 0
    if verbose:
        print("Beginning training for {}".format(hyps['save_folder']))
        print("train shape:", train_data.X.shape)
        print("val shape:", val_data.X.shape)
        print("n_vocab:", hyps['n_vocab'])

    record_session(hyps, model)
    if hyps['dataset'] == "WordProblem":
        save_data_structs(train_data.samp_structs)

    # Short run for smoke tests
    if hyps['exp_name'] == "test":
        hyps['n_epochs'] = 2
    epoch = -1
    alpha = hyps['loss_alpha']
    emb_alpha = hyps['emb_alpha']
    print()
    idx2word = train_data.idx2word
    mask_idx = train_data.word2idx["<MASK>"]
    while epoch < hyps['n_epochs']:
        epoch += 1
        print("Epoch:{} | Model:{}".format(epoch, hyps['save_folder']))
        starttime = time.time()
        avg_loss = 0
        avg_acc = 0
        avg_indy_acc = 0
        avg_emb_loss = 0
        mask_avg_loss = 0
        mask_avg_acc = 0
        model.train()
        print("Training...")
        optimizer.zero_grad()
        for b, (x, y) in enumerate(train_loader):
            torch.cuda.empty_cache()
            # Build the decoder targets depending on the model family:
            # autoencoders/dictionaries reconstruct x, others predict y
            if model.transformer_type == models.AUTOENCODER:
                targs = x.data[:, 1:]
                y = x.data
            elif model.transformer_type == models.DICTIONARY:
                targs = x.data[:, 1:]
                emb_targs = model.embeddings(y.to(DEVICE))
                word_to_define = idx2word[y.squeeze()[0].item()]
                y = x.data
            else:
                targs = y.data[:, 1:]
            og_shape = targs.shape
            if hyps['masking_task']:
                x, y, mask = mask_words(x, y, mask_p=hyps['mask_p'])
            # Teacher forcing: decoder input is the sequence minus its
            # final token (targets are the sequence minus its first)
            y = y[:, :-1]
            preds = model(x.to(DEVICE), y.to(DEVICE))

            tot_loss = 0
            if model.transformer_type == models.DICTIONARY:
                emb_preds = preds[1]
                preds = preds[0]
                emb_loss = F.mse_loss(emb_preds, emb_targs.data)
                tot_loss += (emb_alpha) * emb_loss
                avg_emb_loss += emb_loss.item()

            # Periodically print a decoded sample for inspection
            if epoch % 3 == 0 and b == 0:
                if model.transformer_type == models.DICTIONARY:
                    print("Word:", word_to_define)
                whr = torch.where(y[0] == mask_idx)[0]
                endx = y.shape[-1] if len(whr) == 0 else whr[0].item()
                print("y:", [idx2word[a.item()] for a in y[0, :endx]])
                print("t:", [idx2word[a.item()] for a in targs[0, :endx - 1]])
                ms = torch.argmax(preds, dim=-1)
                print("p:", [idx2word[a.item()] for a in ms[0, :endx - 1]])
                del ms
            targs = targs.reshape(-1)

            if hyps['masking_task']:
                print("masking!")
                # Mask loss and acc: evaluated only at masked positions
                preds = preds.reshape(-1, preds.shape[-1])
                mask = mask.reshape(-1).bool()
                idxs = torch.arange(len(mask))[mask]
                mask_preds = preds[idxs]
                mask_targs = targs[idxs]
                mask_loss = lossfxn(mask_preds, mask_targs)
                mask_preds = torch.argmax(mask_preds, dim=-1)
                mask_acc = (mask_preds == mask_targs).sum().float()
                mask_acc = mask_acc / idxs.numel()

                mask_avg_acc += mask_acc.item()
                mask_avg_loss += mask_loss.item()
            else:
                mask_loss = torch.zeros(1).to(DEVICE)
                mask_acc = torch.zeros(1).to(DEVICE)

                mask_avg_acc += mask_acc.item()
                mask_avg_loss += mask_loss.item()

            # Tot loss and acc
            preds = preds.reshape(-1, preds.shape[-1])
            targs = targs.reshape(-1).to(DEVICE)
            if not hyps['masking_task']:
                # Padding positions (<MASK>) are excluded from the loss
                bitmask = (targs != mask_idx)
                loss = (1 - emb_alpha) * lossfxn(preds[bitmask],
                                                 targs[bitmask])
            else:
                loss = lossfxn(preds, targs)

            if hyps['masking_task']:
                temp = ((alpha) * loss + (1 - alpha) * mask_loss)
                tot_loss += temp / hyps['n_loss_loops']
            else:
                tot_loss += loss / hyps['n_loss_loops']
            tot_loss.backward()

            # Gradient accumulation over n_loss_loops batches
            if b % hyps['n_loss_loops'] == 0 or b == len(train_loader) - 1:
                optimizer.step()
                optimizer.zero_grad()

            with torch.no_grad():
                preds = torch.argmax(preds, dim=-1)
                sl = og_shape[-1]
                if not hyps['masking_task']:
                    eq = (preds == targs).float()
                    indy_acc = eq[bitmask].mean()
                    # Treat padding as correct so whole-sequence acc works
                    eq[~bitmask] = 1
                    eq = eq.reshape(og_shape)
                    acc = (eq.sum(-1) == sl).float().mean()
                else:
                    eq = (preds == targs).float().reshape(og_shape)
                    acc = (eq.sum(-1) == sl).float().mean()
                    indy_acc = eq.mean()
                preds = preds.cpu()

            avg_acc += acc.item()
            avg_indy_acc += indy_acc.item()
            avg_loss += loss.item()

            if hyps["masking_task"]:
                s = "Mask Loss:{:.5f} | Acc:{:.5f} | {:.0f}%"
                s = s.format(mask_loss.item(), mask_acc.item(),
                             b / len(train_loader) * 100)
            elif model.transformer_type == models.DICTIONARY:
                s = "Loss:{:.5f} | Acc:{:.5f} | Emb:{:.5f} | {:.0f}%"
                s = s.format(loss.item(), acc.item(), emb_loss.item(),
                             b / len(train_loader) * 100)
            else:
                s = "Loss:{:.5f} | Acc:{:.5f} | {:.0f}%"
                s = s.format(loss.item(), acc.item(),
                             b / len(train_loader) * 100)
            print(s, end=len(s) * " " + "\r")
            if hyps['exp_name'] == "test" and b > 5: break
        print()
        mask_train_loss = mask_avg_loss / len(train_loader)
        mask_train_acc = mask_avg_acc / len(train_loader)
        train_avg_loss = avg_loss / len(train_loader)
        train_avg_acc = avg_acc / len(train_loader)
        train_avg_indy = avg_indy_acc / len(train_loader)
        train_emb_loss = avg_emb_loss / len(train_loader)

        stats_string = "Train - Loss:{:.5f} | Acc:{:.5f} | Indy:{:.5f}\n"
        stats_string = stats_string.format(train_avg_loss, train_avg_acc,
                                           train_avg_indy)
        if hyps['masking_task']:
            stats_string += "Tr. Mask Loss:{:.5f} | Tr. Mask Acc:{:.5f}\n"
            stats_string = stats_string.format(mask_train_loss, mask_train_acc)
        elif model.transformer_type == models.DICTIONARY:
            stats_string += "Train Emb Loss:{:.5f}\n"
            stats_string = stats_string.format(train_emb_loss)
        # ---- Validation ----
        model.eval()
        avg_loss = 0
        avg_acc = 0
        avg_indy_acc = 0
        avg_emb_loss = 0
        mask_avg_loss = 0
        mask_avg_acc = 0
        print("Validating...")
        words = None
        with torch.no_grad():
            # One randomly chosen batch supplies the printed sample below
            rand_word_batch = int(np.random.randint(0, len(val_loader)))
            for b, (x, y) in enumerate(val_loader):
                torch.cuda.empty_cache()
                if model.transformer_type == models.AUTOENCODER:
                    targs = x.data[:, 1:]
                    y = x.data
                elif model.transformer_type == models.DICTIONARY:
                    targs = x.data[:, 1:]
                    emb_targs = model.embeddings(y.to(DEVICE))
                    if b == rand_word_batch or hyps['exp_name'] == "test":
                        words = [idx2word[y.squeeze()[i].item()] for\
                                         i in range(len(y.squeeze()))]
                    y = x.data
                else:
                    targs = y.data[:, 1:]
                og_shape = targs.shape

                if hyps['init_decs']:
                    y = train_data.inits.clone().repeat(len(x), 1)
                if hyps['masking_task']:
                    x, y, mask = mask_words(x, y, mask_p=hyps['mask_p'])
                y = y[:, :-1]
                preds = model(x.to(DEVICE), y.to(DEVICE))

                tot_loss = 0
                if model.transformer_type == models.DICTIONARY:
                    emb_preds = preds[1]
                    preds = preds[0]
                    emb_loss = F.mse_loss(emb_preds, emb_targs)
                    avg_emb_loss += emb_loss.item()

                if hyps['masking_task']:
                    # Mask loss and acc
                    targs = targs.reshape(-1)
                    preds = preds.reshape(-1, preds.shape[-1])
                    mask = mask.reshape(-1).bool()
                    idxs = torch.arange(len(mask))[mask]
                    mask_preds = preds[idxs]
                    mask_targs = targs[idxs]
                    mask_loss = lossfxn(mask_preds, mask_targs)
                    mask_avg_loss += mask_loss.item()
                    mask_preds = torch.argmax(mask_preds, dim=-1)
                    mask_acc = (mask_preds == mask_targs).sum().float()
                    mask_acc = mask_acc / mask_preds.numel()
                    mask_avg_acc += mask_acc.item()
                else:
                    mask_acc = torch.zeros(1).to(DEVICE)
                    mask_loss = torch.zeros(1).to(DEVICE)

                # Tot loss and acc
                preds = preds.reshape(-1, preds.shape[-1])
                targs = targs.reshape(-1).to(DEVICE)
                if not hyps['masking_task']:
                    bitmask = (targs != mask_idx)
                    loss = lossfxn(preds[bitmask], targs[bitmask])
                else:
                    loss = lossfxn(preds, targs)
                preds = torch.argmax(preds, dim=-1)
                sl = og_shape[-1]
                if not hyps['masking_task']:
                    eq = (preds == targs).float()
                    indy_acc = eq[bitmask].mean()
                    eq[~bitmask] = 1
                    eq = eq.reshape(og_shape)
                    acc = (eq.sum(-1) == sl).float().mean()
                else:
                    eq = (preds == targs).float().reshape(og_shape)
                    acc = (eq.sum(-1) == sl).float().mean()
                    indy_acc = eq.mean()
                preds = preds.cpu()

                # Capture one question/target/prediction triple to log
                if b == rand_word_batch or hyps['exp_name'] == "test":
                    rand = int(np.random.randint(0, len(x)))
                    question = x[rand]
                    whr = torch.where(question == mask_idx)[0]
                    endx = len(question) if len(whr) == 0 else whr[0].item()
                    question = question[:endx]
                    pred_samp = preds.reshape(og_shape)[rand]
                    targ_samp = targs.reshape(og_shape)[rand]
                    whr = torch.where(targ_samp == mask_idx)[0]
                    endx = len(targ_samp) if len(whr) == 0 else whr[0].item()
                    targ_samp = targ_samp[:endx]
                    pred_samp = pred_samp[:endx]
                    idx2word = train_data.idx2word
                    question = [idx2word[p.item()] for p in question]
                    pred_samp = [idx2word[p.item()] for p in pred_samp]
                    targ_samp = [idx2word[p.item()] for p in targ_samp]
                    question = " ".join(question)
                    pred_samp = " ".join(pred_samp)
                    targ_samp = " ".join(targ_samp)
                    if words is not None:
                        word_samp = str(words[rand])

                avg_acc += acc.item()
                avg_indy_acc += indy_acc.item()
                avg_loss += loss.item()
                if hyps["masking_task"]:
                    s = "Mask Loss:{:.5f} | Acc:{:.5f} | {:.0f}%"
                    s = s.format(mask_loss.item(), mask_acc.item(),
                                 b / len(val_loader) * 100)
                elif model.transformer_type == models.DICTIONARY:
                    s = "Loss:{:.5f} | Acc:{:.5f} | Emb:{:.5f} | {:.0f}%"
                    s = s.format(loss.item(), acc.item(), emb_loss.item(),
                                 b / len(val_loader) * 100)
                else:
                    s = "Loss:{:.5f} | Acc:{:.5f} | {:.0f}%"
                    s = s.format(loss.item(), acc.item(),
                                 b / len(val_loader) * 100)
                print(s, end=len(s) * " " + "\r")
                if hyps['exp_name'] == "test" and b > 5: break
        print()
        # Free the large validation tensors before checkpointing
        del targs
        del x
        del y
        del eq
        if model.transformer_type == models.DICTIONARY:
            del emb_targs
        torch.cuda.empty_cache()
        mask_val_loss = mask_avg_loss / len(val_loader)
        mask_val_acc = mask_avg_acc / len(val_loader)
        val_avg_loss = avg_loss / len(val_loader)
        val_avg_acc = avg_acc / len(val_loader)
        val_emb_loss = avg_emb_loss / len(val_loader)
        # Fix: the scheduler was built in 'min' mode, so it must monitor
        # the validation loss; stepping on accuracy (which rises during
        # training) caused the LR to decay whenever the model improved.
        scheduler.step(val_avg_loss)
        val_avg_indy = avg_indy_acc / len(val_loader)
        stats_string += "Val - Loss:{:.5f} | Acc:{:.5f} | Indy:{:.5f}\n"
        stats_string = stats_string.format(val_avg_loss, val_avg_acc,
                                           val_avg_indy)
        if hyps['masking_task']:
            stats_string += "Val Mask Loss:{:.5f} | Val Mask Acc:{:.5f}\n"
            # Fix: report the per-batch averages (mask_val_loss/acc), not
            # the raw accumulated sums, mirroring the train-side stats
            stats_string = stats_string.format(mask_val_loss, mask_val_acc)
        elif model.transformer_type == models.DICTIONARY:
            stats_string += "Val Emb Loss:{:.5f}\n"
            stats_string = stats_string.format(val_emb_loss)
        if words is not None:
            stats_string += "Word: " + word_samp + "\n"
        stats_string += "Quest: " + question + "\n"
        stats_string += "Targ: " + targ_samp + "\n"
        stats_string += "Pred: " + pred_samp + "\n"
        optimizer.zero_grad()

        save_dict = {
            "epoch": epoch,
            "hyps": hyps,
            "train_loss": train_avg_loss,
            "train_acc": train_avg_acc,
            "train_indy": train_avg_indy,
            "mask_train_loss": mask_train_loss,
            "mask_train_acc": mask_train_acc,
            "val_loss": val_avg_loss,
            "val_acc": val_avg_acc,
            "val_indy": val_avg_indy,
            "mask_val_loss": mask_val_loss,
            "mask_val_acc": mask_val_acc,
            "state_dict": model.state_dict(),
            "optim_dict": optimizer.state_dict(),
            "word2idx": train_data.word2idx,
            "idx2word": train_data.idx2word,
            "sampled_types": train_data.sampled_types
        }
        save_name = "checkpt"
        save_name = os.path.join(hyps['save_folder'], save_name)
        io.save_checkpt(save_dict,
                        save_name,
                        epoch,
                        ext=".pt",
                        del_prev_sd=hyps['del_prev_sd'])
        stats_string += "Exec time: {}\n".format(time.time() - starttime)
        print(stats_string)
        s = "Epoch:{} | Model:{}\n".format(epoch, hyps['save_folder'])
        stats_string = s + stats_string
        log_file = os.path.join(hyps['save_folder'], "training_log.txt")
        with open(log_file, 'a') as f:
            f.write(str(stats_string) + '\n')
    # Return the final epoch's stats without the heavy entries
    del save_dict['state_dict']
    del save_dict['optim_dict']
    del save_dict['hyps']
    save_dict['save_folder'] = hyps['save_folder']
    return save_dict
Beispiel #8
0
def train(gpu, hyps, verbose=True):
    """
    Run distributed-capable training for a sequence-to-sequence model:
    one process per GPU when multi_gpu is set, otherwise a single
    process. Each epoch trains, validates (rank 0 only), checkpoints,
    and logs.

    gpu: int
        the gpu for this training process
    hyps: dict
        contains all relavent hyperparameters
    verbose: bool
        if True (and this process is rank 0), progress is printed

    Returns:
        dict or None
            rank 0 returns the final epoch's stats (minus state dicts
            and hyps) when checkpoints were saved; all other ranks, and
            batch-size-testing runs, return None
    """
    rank = 0
    if hyps['multi_gpu']:
        # Global rank = gpus-per-node * node index + local gpu index
        rank = hyps['n_gpus']*hyps['node_rank'] + gpu
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            world_size=hyps['world_size'],
            rank=rank)
    verbose = verbose and rank==0
    hyps['rank'] = rank

    torch.cuda.set_device(gpu)
    # When probing batch sizes, nothing is recorded or checkpointed
    test_batch_size = try_key(hyps,"test_batch_size",False)
    if test_batch_size and verbose:
        print("Testing batch size!! No saving will occur!")
    hyps['main_path'] = try_key(hyps,'main_path',"./")
    if "ignore_keys" not in hyps:
        hyps['ignore_keys'] = ["n_epochs", "batch_size",
                                "max_context","rank", "n_loss_loops"]
    checkpt,hyps = get_resume_checkpt(hyps)
    if checkpt is None and rank==0:
        hyps['exp_num']=get_exp_num(hyps['main_path'], hyps['exp_name'])
        hyps['save_folder'] = get_save_folder(hyps)
    if rank>0: hyps['save_folder'] = "placeholder"
    if not os.path.exists(hyps['save_folder']) and\
                    not test_batch_size and rank==0:
        os.mkdir(hyps['save_folder'])
    # Set manual seed (offset by rank so worker processes diverge)
    hyps['seed'] = try_key(hyps, 'seed', int(time.time()))+rank
    torch.manual_seed(hyps['seed'])
    np.random.seed(hyps['seed'])
    hyps['MASK'] = MASK
    hyps['START'] = START
    hyps['STOP'] = STOP

    model_class = hyps['model_class']
    hyps['n_loss_loops'] = try_key(hyps,'n_loss_loops',1)
    if not hyps['init_decs'] and not hyps['ordered_preds'] and verbose:
        s = "WARNING!! You probably want to set ordered preds to True "
        s += "with your current configuration!!"
        print(s)

    if verbose:
        print("Retreiving Dataset")
    if "shuffle_split" not in hyps and hyps['shuffle']:
        hyps['shuffle_split'] = True
    train_data,val_data = datas.get_data(hyps)

    # Sequence lengths and vocab sizes come from the tokenized data
    hyps['enc_slen'] = train_data.X.shape[-1]
    hyps['dec_slen'] = train_data.Y.shape[-1]
    hyps["mask_idx"] = train_data.X_tokenizer.token_to_id(MASK)
    hyps["dec_mask_idx"] = train_data.Y_tokenizer.token_to_id(MASK)
    hyps['n_vocab'] = train_data.X_tokenizer.get_vocab_size()
    hyps['n_vocab_out'] = train_data.Y_tokenizer.get_vocab_size()

    train_loader = datas.VariableLengthSeqLoader(train_data,
                                    samples_per_epoch=1000,
                                    shuffle=hyps['shuffle'])
    val_loader = datas.VariableLengthSeqLoader(val_data,
                                    samples_per_epoch=50,
                                    shuffle=True)

    if verbose:
        print("Making model")
    model = getattr(models,model_class)(**hyps)
    model.cuda(gpu)
    lossfxn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=hyps['lr'],
                                           weight_decay=hyps['l2'])
    if hyps['multi_gpu']:
        # O0 = fp32; amp is used here only for its distributed loss scaling
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level='O0')
        model = DDP(model)
    # Load State Dicts if Resuming Training
    if checkpt is not None:
        if verbose:
            print("Loading state dicts from", hyps['save_folder'])
        model.load_state_dict(checkpt["state_dict"])
        optimizer.load_state_dict(checkpt["optim_dict"])
        epoch = checkpt['epoch']
        if hyps['multi_gpu'] and "amp_dict" in checkpt:
            amp.load_state_dict(checkpt['amp_dict'])
    else:
        epoch = -1
    scheduler = custmods.VaswaniScheduler(optimizer, hyps['emb_size'])
    if verbose:
        print("Beginning training for {}".format(hyps['save_folder']))
        print("train shape:", (len(train_data),*train_data.X.shape[1:]))
        print("val shape:", (len(val_data),*val_data.X.shape[1:]))
    if not test_batch_size:
        record_session(hyps,model)

    # Short run for smoke tests
    if hyps['exp_name'] == "test":
        hyps['n_epochs'] = 2
    mask_idx = train_data.Y_mask_idx
    stop_idx = train_data.Y_stop_idx
    step_num = 0 if checkpt is None else try_key(checkpt,'step_num',0)
    checkpt_steps = 0
    if verbose: print()
    while epoch < hyps['n_epochs']:
        epoch += 1
        if verbose:
            print("Epoch:{} | Step: {} | Model:{}".format(epoch,
                                                 step_num,
                                                 hyps['save_folder']))
            print("Training...")
        starttime = time.time()
        avg_loss = 0
        avg_acc = 0
        avg_indy_acc = 0
        checkpt_loss = 0
        checkpt_acc = 0
        model.train()
        optimizer.zero_grad()
        for b,(x,y) in enumerate(train_loader):
            if test_batch_size:
                x,y = train_loader.get_largest_batch(b)
            torch.cuda.empty_cache()
            # Teacher forcing: targets are shifted one token ahead of
            # the decoder input
            targs = y.data[:,1:]
            og_shape = targs.shape
            y = y[:,:-1]
            logits = model(x.cuda(non_blocking=True),
                           y.cuda(non_blocking=True))
            preds = torch.argmax(logits,dim=-1)

            # Periodically print a decoded sample for inspection
            if epoch % 3 == 0 and b == 0 and verbose:
                inp = x.data[0].cpu().numpy()
                trg = targs.data[0].numpy()
                prd = preds.data[0].cpu().numpy()
                print("Inp:", train_data.X_idxs2tokens(inp))
                print("Targ:", train_data.Y_idxs2tokens(trg))
                print("Pred:", train_data.Y_idxs2tokens(prd))

            # Tot loss (padding positions are excluded via the bitmask)
            logits = logits.reshape(-1,logits.shape[-1])
            targs = targs.reshape(-1).cuda(non_blocking=True)
            bitmask = targs!=mask_idx
            loss = lossfxn(logits[bitmask],targs[bitmask])

            # Gradient accumulation over n_loss_loops batches
            loss = loss/hyps['n_loss_loops']
            if hyps['multi_gpu']:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            if b % hyps['n_loss_loops'] == 0 or b == len(train_loader)-1:
                optimizer.step()
                optimizer.zero_grad()
                step_num += 1
                scheduler.update_lr(step_num)

            with torch.no_grad():
                # Acc: indy_acc is per-token, acc is whole-sequence
                preds = preds.reshape(-1)
                sl = og_shape[-1]
                eq = (preds==targs).float()
                indy_acc = eq[bitmask].mean()
                # Treat padding as correct so sequence-level acc works
                eq[~bitmask] = 1
                eq = eq.reshape(og_shape)
                acc = (eq.sum(-1)==sl).float().mean()
                avg_acc += acc.item()
                avg_indy_acc += indy_acc.item()
                avg_loss += loss.item()
                checkpt_acc += acc.item()
                checkpt_loss += loss.item()

            s = "Loss:{:.5f} | Acc:{:.5f} | {:.0f}%"
            s = s.format(loss.item(), acc.item(),
                                      b/len(train_loader)*100)
            if verbose: print(s, end=len(s)*" " + "\r")
            if hyps['exp_name'] == "test" and b>5: break

        optimizer.zero_grad()
        train_avg_loss = avg_loss/len(train_loader)
        train_avg_acc = avg_acc/len(train_loader)
        train_avg_indy = avg_indy_acc/len(train_loader)

        s = "Ending Step Count: {}\n".format(step_num)
        s = s+"Train - Loss:{:.5f} | Acc:{:.5f} | Indy:{:.5f}\n"
        stats_string = s.format(train_avg_loss, train_avg_acc,
                                                train_avg_indy)

        ###### VALIDATION
        model.eval()
        avg_bleu = 0
        avg_loss = 0
        avg_acc = 0
        avg_indy_acc = 0
        if verbose: print("\nValidating...")
        torch.cuda.empty_cache()
        # Only rank 0 validates; other ranks skip straight to the next epoch
        if rank==0:
            with torch.no_grad():
                # One randomly chosen batch supplies the printed sample
                rand_word_batch = int(np.random.randint(0,
                                         len(val_loader)))
                for b,(x,y) in enumerate(val_loader):
                    if test_batch_size:
                        x,y = val_loader.get_largest_batch(b)
                    targs = y.data[:,1:]
                    og_shape = targs.shape
                    y = y[:,:-1]
                    preds = model(x.cuda(non_blocking=True),
                                  y.cuda(non_blocking=True))

                    # Tot loss and acc
                    preds = preds.reshape(-1,preds.shape[-1])
                    targs = targs.reshape(-1).cuda(non_blocking=True)
                    bitmask = targs!=mask_idx
                    loss = lossfxn(preds[bitmask],targs[bitmask])
                    if hyps['multi_gpu']:
                        loss = loss.mean()
                    preds = torch.argmax(preds,dim=-1)
                    sl = int(og_shape[-1])
                    eq = (preds==targs).float()
                    indy_acc = eq[bitmask].mean()
                    eq[~bitmask] = 1
                    eq = eq.reshape(og_shape)
                    acc = (eq.sum(-1)==sl).float().mean()
                    # BLEU: truncate each sequence at its first STOP token
                    # NOTE(review): np.argmax returns 0 when no STOP token
                    # is present, yielding an empty sequence — confirm the
                    # data always contains STOP
                    bleu_trgs=targs.reshape(og_shape).data.cpu().numpy()
                    trg_ends = np.argmax((bleu_trgs==stop_idx),axis=1)
                    bleu_prds=preds.reshape(og_shape).data.cpu().numpy()
                    prd_ends = np.argmax((bleu_prds==stop_idx),axis=1)
                    btrgs = []
                    bprds = []
                    for i in range(len(bleu_trgs)):
                        temp = bleu_trgs[i,None,:trg_ends[i]].tolist()
                        btrgs.append(temp)
                        bprds.append(bleu_prds[i,:prd_ends[i]].tolist())
                    bleu = corpus_bleu(btrgs,bprds)
                    avg_bleu += bleu
                    avg_acc += acc.item()
                    avg_indy_acc += indy_acc.item()
                    avg_loss += loss.item()

                    # Capture one input/target/prediction triple to log
                    if b == rand_word_batch or hyps['exp_name']=="test":
                        rand = int(np.random.randint(0,len(x)))
                        inp = x.data[rand].cpu().numpy()
                        inp_samp = val_data.X_idxs2tokens(inp)
                        trg = targs.reshape(og_shape)[rand].data.cpu()
                        targ_samp = val_data.Y_idxs2tokens(trg.numpy())
                        prd = preds.reshape(og_shape)[rand].data.cpu()
                        pred_samp = val_data.Y_idxs2tokens(prd.numpy())
                    s="Loss:{:.5f} | Acc:{:.5f} | Bleu:{:.5f} | {:.0f}%"
                    s = s.format(loss.item(), acc.item(), bleu,
                                              b/len(val_loader)*100)
                    if verbose: print(s, end=len(s)*" " + "\r")
                    if hyps['exp_name']=="test" and b > 5: break

            if verbose: print()
            val_avg_bleu = avg_bleu/len(val_loader)
            val_avg_loss = avg_loss/len(val_loader)
            val_avg_acc = avg_acc/len(val_loader)
            val_avg_indy = avg_indy_acc/len(val_loader)
            stats_string += "Val- Loss:{:.5f} | Acc:{:.5f} | "
            stats_string += "Indy:{:.5f}\nVal Bleu: {:.5f}\n"
            stats_string = stats_string.format(val_avg_loss,val_avg_acc,
                                                            val_avg_indy,
                                                            val_avg_bleu)
            stats_string += "Inp: "  + inp_samp  + "\n"
            stats_string += "Targ: " + targ_samp + "\n"
            stats_string += "Pred: " + pred_samp + "\n"
        optimizer.zero_grad()

        if not test_batch_size and rank==0:
            save_dict = {
                "epoch":epoch,
                "step_num":step_num,
                "hyps":hyps,
                "train_loss":train_avg_loss,
                "train_acc":train_avg_acc,
                "train_indy":train_avg_indy,
                "val_bleu":val_avg_bleu,
                "val_loss":val_avg_loss,
                "val_acc":val_avg_acc,
                "val_indy":val_avg_indy,
                "state_dict":model.state_dict(),
                "optim_dict":optimizer.state_dict(),
            }
            if hyps['multi_gpu']: save_dict['amp_dict']=amp.state_dict()
            save_name = "checkpt"
            save_name = os.path.join(hyps['save_folder'],save_name)
            io.save_checkpt(save_dict, save_name, epoch, ext=".pt",
                                del_prev_sd=hyps['del_prev_sd'])
        stats_string += "Exec time: {}\n".format(time.time()-starttime)
        if verbose: print(stats_string)
        s = "Epoch:{} | Model:{}\n".format(epoch, hyps['save_folder'])
        stats_string = s + stats_string
        log_file = os.path.join(hyps['save_folder'],
                                "training_log"+str(rank)+".txt")
        if not test_batch_size:
            with open(log_file,'a') as f:
                f.write(str(stats_string)+'\n')
    # Fix: save_dict only exists when checkpoints were actually saved
    # (rank 0 and not test_batch_size); previously a batch-size test on
    # rank 0 hit a NameError here.
    if rank==0 and not test_batch_size:
        del save_dict['state_dict']
        del save_dict['optim_dict']
        del save_dict['hyps']
        save_dict['save_folder'] = hyps['save_folder']
        return save_dict
    return None