def make_hdf5_from_array(cls, array: Union[np.ndarray, pd.Series], output_file: str, num_batches: int =100 , bptt_length = 75):
        '''Tokenize in-memory sequences, store them as one flat token stream in an
        HDF5 file, and return a dataset (``cls``) reading from that file.

        Every sequence is tokenized with the IUPAC ``TAPETokenizer`` and terminated
        with the tokenizer's stop token before being appended to the stream.

        Layout of the written HDF5 file:
            dataset tokenized_sequences: concatenation of all tokenized sequences
                (stop tokens inserted). 1D array of size total_n_tokens
            dataset starting_indices: starting index in tokenized_sequences of each
                sequence. 1D array of size n_sequences

        Args:
            array: iterable of sequence strings (numpy array or pandas Series).
            output_file: path of the HDF5 file to create; an existing file is
                overwritten (opened in mode "w").
            num_batches: forwarded to ``cls`` as second positional argument.
            bptt_length: forwarded to ``cls`` as third positional argument.

        Returns:
            ``cls(output_file, num_batches, bptt_length)``.
        '''

        tokenizer = TAPETokenizer(vocab = 'iupac')
        #load and tokenize
        startidxlist = []
        tokenlist = []
        for seq in array:
            # A sequence starts wherever the token stream currently ends.
            startidxlist.append(len(tokenlist))
            words = tokenizer.tokenize(seq) + [tokenizer.stop_token]
            tokenlist.extend(tokenizer.convert_token_to_id(word) for word in words)

        data = np.array(tokenlist)
        startidx = np.array(startidxlist)
        with h5py.File(output_file, "w") as f:
            f.create_dataset('tokenized_sequences', data=data)
            f.create_dataset('starting_indices', data=startidx)

        # Previously this was cls(output_file, bptt_length): num_batches was
        # silently dropped and bptt_length ended up in the slot where the sibling
        # make_hdf5_from_txt passes num_batches.
        # NOTE(review): assumes the constructor's positional order is
        # (file, num_batches, bptt_length, ...) as used by make_hdf5_from_txt -
        # confirm against __init__.
        return cls(output_file, num_batches, bptt_length)
    def make_hdf5_from_txt(cls, file: str, num_batches: int = 100, output_file: str = None, bptt_length = 75, buffer_size = 1000):
        '''Build an HDF5 token store from a text file holding one sequence per line
        and return a dataset (``cls``) backed by it.

        Each line is right-stripped, tokenized with the IUPAC ``TAPETokenizer`` and
        terminated with the tokenizer's stop token. Two datasets are written:
            tokenized_sequences: all token ids concatenated into one 1D array
            starting_indices: offset of every sequence inside tokenized_sequences

        Args:
            file: path of the input text file.
            num_batches, bptt_length, buffer_size: passed through, in this order,
                to ``cls`` after the output path.
            output_file: path of the HDF5 file to write; when falsy, ``file`` with
                ``.hdf5`` appended is used. An existing file is overwritten.

        Returns:
            ``cls(output_path, num_batches, bptt_length, buffer_size)``.

        Raises:
            FileNotFoundError: if ``file`` does not exist.
        '''
        if not os.path.exists(file):
            raise FileNotFoundError(file)

        tokenizer = TAPETokenizer(vocab = 'iupac')
        token_ids = []
        sequence_offsets = []

        with open(file, 'r') as handle:
            for raw_line in handle:
                # The next sequence begins at the current end of the stream.
                sequence_offsets.append(len(token_ids))
                pieces = tokenizer.tokenize(raw_line.rstrip()) + [tokenizer.stop_token]
                token_ids.extend(tokenizer.convert_token_to_id(piece) for piece in pieces)

        token_array = np.array(token_ids)
        offset_array = np.array(sequence_offsets)

        # Default the output path to sit next to the input file.
        target_path = output_file if output_file else file + '.hdf5'

        with h5py.File(target_path, "w") as h5_out:
            h5_out.create_dataset('tokenized_sequences', data=token_array)
            h5_out.create_dataset('starting_indices', data=offset_array)

        return cls(target_path, num_batches, bptt_length, buffer_size)
    def make_hdf5_from_txt(cls, file: str, num_batches: int, output_file: str = None, bptt_length = 75):
        '''Tokenize sequences from a line-by-line txt file, concatenate them and cut
        the stream into ``num_batches`` contiguous sequences.
        Save as HDF5, and return a dataset (``cls``) with this file as source.

        Layout of the written HDF5 file:
            dataset tokenized_sequences: 2D array of shape
                (tokens_per_batch, num_batches). Column j holds the j-th contiguous
                chunk of the token stream; trailing tokens that do not fill a
                whole row are dropped.

        Args:
            file: path of the input text file, one sequence per line.
            num_batches: number of parallel streams (columns) to cut the data into.
            output_file: path of the HDF5 file to write; when falsy, ``file`` with
                ``.hdf5`` appended is used. An existing file is overwritten.
            bptt_length: forwarded to ``cls``.

        Returns:
            ``cls(output_file, bptt_length)``.

        Raises:
            FileNotFoundError: if ``file`` does not exist.
            ValueError: if ``num_batches`` is smaller than 1.
        '''
        if not os.path.exists(file):
            raise FileNotFoundError(file)
        if num_batches < 1:
            raise ValueError(f'num_batches must be >= 1, got {num_batches}')
        tokenizer = TAPETokenizer(vocab = 'iupac')
        #load and tokenize
        tokenlist = []
        with open(file, 'r') as f:
            for line in f:
                words = tokenizer.tokenize(line.rstrip()) + [tokenizer.stop_token]
                tokenlist.extend(tokenizer.convert_token_to_id(word) for word in words)

        #split into batches (the input file is no longer needed at this point)
        tokensperbatch = len(tokenlist) // num_batches
        data = np.array(tokenlist[:tokensperbatch * num_batches]) #trim
        # reshape(-1, num_batches) would fill row by row and therefore deal
        # consecutive tokens out across the columns, so no column would be a
        # contiguous stream. Cutting into num_batches rows first and transposing
        # keeps the (tokens_per_batch, num_batches) shape while making every
        # column one contiguous chunk.
        # NOTE(review): assumes the reader slices rows as time steps - confirm
        # against the dataset's __getitem__.
        data = np.ascontiguousarray(data.reshape(num_batches, tokensperbatch).T)

        if not output_file:
            output_file = file + '.hdf5'

        with h5py.File(output_file, "w") as f:
            f.create_dataset('tokenized_sequences', data=data)

        return cls(output_file, bptt_length)
    def make_hdf5_from_array(cls, array: Union[np.ndarray, pd.Series], num_batches: int, output_file: str, bptt_length = 75):
        '''Tokenize in-memory sequences, concatenate them and cut the stream into
        ``num_batches`` contiguous sequences.
        Save as HDF5, and return a dataset (``cls``) with this file as source.

        Layout of the written HDF5 file:
            dataset tokenized_sequences: 2D array of shape
                (tokens_per_batch, num_batches). Column j holds the j-th contiguous
                chunk of the token stream; trailing tokens that do not fill a
                whole row are dropped.

        Args:
            array: iterable of sequence strings (numpy array or pandas Series).
            num_batches: number of parallel streams (columns) to cut the data into.
            output_file: path of the HDF5 file to create; an existing file is
                overwritten (opened in mode "w").
            bptt_length: forwarded to ``cls``.

        Returns:
            ``cls(output_file, bptt_length)``.

        Raises:
            ValueError: if ``num_batches`` is smaller than 1.
        '''
        if num_batches < 1:
            raise ValueError(f'num_batches must be >= 1, got {num_batches}')

        tokenizer = TAPETokenizer(vocab = 'iupac')
        #load and tokenize
        tokenlist = []
        for seq in array:
            words = tokenizer.tokenize(seq) + [tokenizer.stop_token]
            tokenlist.extend(tokenizer.convert_token_to_id(word) for word in words)

        #split into batches
        tokensperbatch = len(tokenlist) // num_batches
        data = np.array(tokenlist[:tokensperbatch * num_batches]) #trim
        # reshape(-1, num_batches) would fill row by row and therefore deal
        # consecutive tokens out across the columns, so no column would be a
        # contiguous stream. Cutting into num_batches rows first and transposing
        # keeps the (tokens_per_batch, num_batches) shape while making every
        # column one contiguous chunk.
        # NOTE(review): assumes the reader slices rows as time steps - confirm
        # against the dataset's __getitem__.
        data = np.ascontiguousarray(data.reshape(num_batches, tokensperbatch).T)

        with h5py.File(output_file, "w") as f:
            f.create_dataset('tokenized_sequences', data=data)

        return cls(output_file, bptt_length)
def main_training_loop(args: argparse.Namespace):
    """Train a ProteinAWDLSTMForLM language model and return the best validation loss.

    The model is built from ``args``; training data is read from ``train.hdf5``
    and validation data from ``valid.hdf5`` inside ``args.data``. Every
    ``min(args.update_lr_steps, len(train_loader))`` training steps the model is
    validated. When the validation loss improves, the model and the training
    status are saved to ``args.output_dir``; otherwise a counter is increased
    and, once it reaches ``args.wait_epochs``, the learning rate of the first
    parameter group is multiplied by ``args.lr_step``.

    Training stops early when more than 5 learning-rate decays have happened or,
    if ``args.enforce_walltime`` is set, once 23.5 hours have elapsed.

    Args:
        args: parsed command-line namespace. Fields read here: enforce_walltime,
            output_dir, reset_hidden, experiment_name, data, batch_size, bptt,
            buffer_size, resume, wandb_sweep, optimizer, lr, wdecay, epochs,
            update_lr_steps, log_interval, wait_epochs, lr_step. The whole
            namespace is also passed to ProteinAWDLSTMConfig and training_step.

    Returns:
        The lowest validation loss seen (the initial value 100000000 if no
        validation round improved on it).

    Raises:
        ValueError: if ``args.optimizer`` is neither ``'sgd'`` nor ``'adam'``.
    """
    walltime_limit_seconds = 84600  # 23.5 hours

    # Taken unconditionally so the wall-time check in the loop can never hit an
    # unset variable, whatever happens to args between here and there.
    loop_start_time = time.time()
    if args.enforce_walltime:
        logger.info('Started timing loop')

    # makedirs instead of exists()+mkdir: also creates missing parents and does
    # not fail if the directory appears between the check and the creation.
    os.makedirs(args.output_dir, exist_ok=True)

    # Setup model
    tokenizer = TAPETokenizer(vocab='iupac')
    config = ProteinAWDLSTMConfig(**vars(args))
    config.vocab_size = tokenizer.vocab_size
    if args.reset_hidden:
        logger.info(f'Resetting hidden state after {tokenizer.stop_token}')
        config.reset_token_id = tokenizer.convert_token_to_id(
            tokenizer.stop_token)

    model = ProteinAWDLSTMForLM(config)

    # Training logger. local_rank=-1 means training is not distributed.
    time_stamp = time.strftime("%y-%m-%d-%H-%M-%S", time.gmtime())
    experiment_name = f"{args.experiment_name}_{model.base_model_prefix}_{time_stamp}"
    viz = visualization.get(args.output_dir, experiment_name, local_rank=-1)

    train_data = Hdf5Dataset(os.path.join(args.data, 'train.hdf5'),
                             batch_size=args.batch_size,
                             bptt_length=args.bptt,
                             buffer_size=args.buffer_size)
    val_data = Hdf5Dataset(os.path.join(args.data, 'valid.hdf5'),
                           batch_size=args.batch_size,
                           bptt_length=args.bptt,
                           buffer_size=args.buffer_size)

    logger.info(f'Data loaded. One train epoch = {len(train_data)} steps.')
    logger.info(f'Data loaded. One valid epoch = {len(val_data)} steps.')

    # The datasets receive args.batch_size themselves, hence batch_size=1 here
    # (presumably each dataset item already is a full batch - confirm in
    # Hdf5Dataset).
    train_loader = DataLoader(train_data,
                              batch_size=1,
                              collate_fn=train_data.collate_fn)
    # Each loader uses the collate_fn of its own dataset (the validation loader
    # used to borrow the training dataset's).
    val_loader = DataLoader(val_data,
                            batch_size=1,
                            collate_fn=val_data.collate_fn)

    # The validation iterator and its hidden state live outside the epoch loop
    # so that every validation round continues where the previous one stopped.
    val_iterator = iter(val_loader)
    val_hidden = None

    # Overwrite model when restarting/changing params
    if args.resume:
        logger.info(f'Loading pretrained model in {args.resume}')
        model = ProteinAWDLSTMForLM.from_pretrained(args.resume)

    if args.wandb_sweep:
        # This prevents errors. When the model is partly set up from config, not
        # the command line, the config received from wandb might not match what
        # is in args and ProteinConfig; calling log_config would then throw.
        # This overwrites args and ProteinConfig, so wandb has priority.
        # In case of doubt, check the wandb run against the save_pretrained
        # config.json. They should always agree.
        # Note that this replaces any model loaded via args.resume above.
        logger.info('Receiving config from wandb!')
        import wandb
        from training_utils import override_from_wandb
        override_from_wandb(wandb.config, args, config)
        model = ProteinAWDLSTMForLM(config)

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    weight_decay=args.wdecay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.wdecay)
    else:
        # Fail here with a clear message instead of with an unbound `optimizer`
        # further down.
        raise ValueError(
            f"Unsupported optimizer {args.optimizer!r}, expected 'sgd' or 'adam'"
        )

    model.to(device)
    logger.info('Model set up!')
    num_parameters = sum(p.numel() for p in model.parameters()
                         if p.requires_grad)
    logger.info(f'Model has {num_parameters} trainable parameters')

    if torch.cuda.is_available():
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    else:
        logger.info(f'Running model on {device}, not using nvidia apex')

    # Set up wandb logging; the tape visualizer class takes care of everything.
    viz.log_config(args)
    viz.log_config(model.config.to_dict())
    viz.watch(model)
    logger.info(
        f'Logging experiment as {experiment_name} to wandb/tensorboard')

    # Keep track of best loss
    num_epochs_no_improvement = 0
    stored_loss = 100000000
    learning_rate_steps = 0

    # Validate every update_lr_steps steps, or once per epoch when the training
    # set is smaller than that (ad hoc fix for smaller datasets).
    update_steps = min(args.update_lr_steps, len(train_loader))

    # Large validation sets (> 10000 steps) are validated 1/100 at a time, small
    # ones completely. Works because the plasmodium set is smaller; avoids
    # another command-line argument.
    if len(val_loader) > 10000:
        n_val_steps = len(val_loader) // 100
    else:
        n_val_steps = len(val_loader)

    global_step = 0
    for epoch in range(1, args.epochs + 1):
        logger.info(f'Starting epoch {epoch}')
        viz.log_metrics({'Learning Rate': optimizer.param_groups[0]['lr']},
                        "train", global_step)

        epoch_start_time = time.time()
        start_time = time.time()  # for lr update interval
        hidden = None  # training hidden state, reset at every epoch
        for i, (data, targets) in enumerate(train_loader):
            loss, reg_loss, hidden = training_step(model, data, targets,
                                                   hidden, optimizer, args, i)
            viz.log_metrics(
                {
                    'loss': loss,
                    'regularized loss': reg_loss,
                    'perplexity': math.exp(loss),
                    'regularized perplexity': math.exp(reg_loss)
                }, "train", global_step)
            global_step += 1

            # Every update_steps, evaluate performance and save model/progress
            # in learning rate.
            if global_step % update_steps == 0:
                total_loss = 0
                total_reg_loss = 0
                total_len = 0

                logger.info(
                    f'Step {global_step}, validating for {n_val_steps} Validation steps'
                )

                for j in range(n_val_steps):
                    try:
                        val_inputs, val_targets = next(val_iterator)
                    except StopIteration:
                        # Validation data exhausted: start over from the
                        # beginning with a fresh hidden state. Only
                        # StopIteration is handled so that real data-loading
                        # errors are not silently swallowed.
                        val_iterator = iter(val_loader)
                        logger.info(
                            f'validation step{j}: resetting validation enumerator.'
                        )
                        val_hidden = None
                        val_inputs, val_targets = next(val_iterator)

                    # Separate names for the validation results: reusing
                    # loss/reg_loss/hidden here clobbered the training hidden
                    # state and made the "tr_loss" log show a validation loss.
                    step_loss, step_reg_loss, val_hidden = validation_step(
                        model, val_inputs, val_targets, val_hidden)
                    n_items = len(val_inputs)
                    total_len += n_items
                    total_reg_loss += step_reg_loss * n_items
                    total_loss += step_loss * n_items

                val_reg_loss = total_reg_loss / total_len
                val_loss = total_loss / total_len

                val_metrics = {
                    'loss': val_loss,
                    'perplexity': math.exp(val_loss),
                    'regularized loss': val_reg_loss,
                    'regularized perplexity': math.exp(val_reg_loss)
                }
                viz.log_metrics(val_metrics, "val", global_step)

                # NOTE(review): elapsed covers update_steps training steps plus
                # the validation round, yet it is divided by args.log_interval -
                # confirm that this is the intended denominator.
                elapsed = time.time() - start_time
                logger.info(
                    f'Training step {global_step}, { elapsed / args.log_interval:.3f} s/batch. tr_loss: {loss:.2f}, tr_perplexity {math.exp(loss):.2f} va_loss: {val_loss:.2f}, va_perplexity {math.exp(val_loss):.2f}'
                )
                start_time = time.time()

                if val_loss < stored_loss:
                    num_epochs_no_improvement = 0
                    # Update before saving so the stored training status holds
                    # the loss of the model that is being saved, not the
                    # previous best.
                    stored_loss = val_loss
                    model.save_pretrained(args.output_dir)
                    save_training_status(args.output_dir, epoch, global_step,
                                         num_epochs_no_improvement,
                                         stored_loss, learning_rate_steps)
                    # Also save with apex
                    if torch.cuda.is_available():
                        checkpoint = {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'amp': amp.state_dict()
                        }
                        torch.save(
                            checkpoint,
                            os.path.join(args.output_dir, 'amp_checkpoint.pt'))
                    # Logged on every device, not only when CUDA is available.
                    logger.info(
                        f'New best model with loss {val_loss}, Saving model, training step {global_step}'
                    )
                else:
                    num_epochs_no_improvement += 1
                    logger.info(
                        f'Step {global_step}: No improvement for {num_epochs_no_improvement} pseudo-epochs.'
                    )

                    if num_epochs_no_improvement == args.wait_epochs:
                        optimizer.param_groups[0]['lr'] *= args.lr_step
                        learning_rate_steps += 1
                        num_epochs_no_improvement = 0
                        logger.info(
                            f'Step {global_step}: Decreasing learning rate. learning rate step {learning_rate_steps}.'
                        )
                        viz.log_metrics(
                            {'Learning Rate': optimizer.param_groups[0]['lr']},
                            "train", global_step)

                        # Break early once more than 5 lr steps were taken
                        if learning_rate_steps > 5:
                            logger.info(
                                'Learning rate step limit reached, ending training early'
                            )
                            return stored_loss

            if args.enforce_walltime and (
                    time.time() - loop_start_time) > walltime_limit_seconds:
                logger.info('Wall time limit reached, ending training early')
                return stored_loss

        logger.info(f'Epoch {epoch} training complete')
        logger.info(
            f'Epoch {epoch}, took {time.time() - epoch_start_time:.2f}.\t Train loss: {loss:.2f} \t Train perplexity: {math.exp(loss):.2f}'
        )

    return stored_loss