Example #1
    def make_hdf5_from_array(cls, array: Union[np.ndarray, pd.Series], output_file: str, num_batches: int = 100, bptt_length: int = 75):
        '''Tokenize sequences from an array or Series, concatenate them, save as hdf5,
        and return a Dataset with this hdf5 file as source.
        Properties of the hdf5 file:
            dataset tokenized_sequences: concatenation of all tokenized sequences (stop tokens inserted). 1D array of size total_n_tokens
            dataset starting_indices: starting index in tokenized_sequences of each sequence. 1D array of size n_sequences
        '''

        tokenizer = TAPETokenizer(vocab='iupac')
        # tokenize each sequence, append the stop token, and record where each sequence starts
        startidxlist = []
        tokenlist = []
        current_start_idx = 0
        for seq in array:
            startidxlist.append(current_start_idx)
            words = tokenizer.tokenize(seq) + [tokenizer.stop_token]
            for word in words:
                tokenlist.append(tokenizer.convert_token_to_id(word))
            current_start_idx = len(tokenlist)

        data = np.array(tokenlist)
        startidx = np.array(startidxlist)
        with h5py.File(output_file, "w") as f:
            f.create_dataset('tokenized_sequences', data=data)
            f.create_dataset('starting_indices', data=startidx)

        return cls(output_file, bptt_length)
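A minimal read-back sketch for the two datasets described in the docstring above. The file name and the index arithmetic are illustrative; it only assumes h5py and numpy and a file previously written by make_hdf5_from_array.

import h5py
import numpy as np

# Open a file written by make_hdf5_from_array (hypothetical path).
with h5py.File('my_seqs.hdf5', 'r') as f:
    tokens = f['tokenized_sequences'][:]   # 1D array of all token ids
    starts = f['starting_indices'][:]      # start offset of each sequence

# Recover sequence i as the slice between consecutive starting indices.
i = 0
end = starts[i + 1] if i + 1 < len(starts) else len(tokens)
first_sequence_ids = tokens[starts[i]:end]
print(len(starts), 'sequences; first has', len(first_sequence_ids), 'tokens')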
Example #2
    def make_hdf5_from_array(cls, array: Union[np.ndarray, pd.Series], num_batches: int, output_file: str, bptt_length: int = 75):
        '''Tokenize sequences from an array or Series, concatenate them and cut into num_batches columns.
           Save as hdf5, and return Dataset with this hdf5 file as source.
        '''

        tokenizer = TAPETokenizer(vocab='iupac')
        # tokenize every sequence and append the stop token after each one
        tokenlist = []
        for seq in array:
            words = tokenizer.tokenize(seq) + [tokenizer.stop_token]
            for word in words:
                tokenlist.append(tokenizer.convert_token_to_id(word))

        # trim so the token count is divisible by num_batches, then reshape
        tokensperbatch = len(tokenlist) // num_batches
        end = tokensperbatch * num_batches
        tokenlist = tokenlist[0:end]
        data = np.array(tokenlist)
        data = data.reshape(-1, num_batches)  # shape: (tokensperbatch, num_batches)

        with h5py.File(output_file, "w") as f:
            f.create_dataset('tokenized_sequences', data=data)

        return cls(output_file, bptt_length)
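A minimal sketch of how the 2-D array written above could be consumed in bptt-length blocks of rows. The file name and the slicing loop are illustrative, not part of the original Dataset class.

import h5py

bptt_length = 75
with h5py.File('batched_tokens.hdf5', 'r') as f:   # hypothetical path
    data = f['tokenized_sequences'][:]             # shape: (tokensperbatch, num_batches)

# Iterate over consecutive blocks of bptt_length rows, one block per step.
for start in range(0, data.shape[0], bptt_length):
    chunk = data[start:start + bptt_length]        # (<= bptt_length, num_batches)
    # ... feed chunk to the language model here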
Example #3
    def make_hdf5_from_txt(cls, file: str, num_batches: int = 100, output_file: str = None, bptt_length: int = 75, buffer_size: int = 1000):
        '''Tokenize sequences from a line-by-line txt file, concatenate them,
           save as hdf5, and return a Dataset with this hdf5 file as source.
        '''
        if not os.path.exists(file):
            raise FileNotFoundError(file)
        tokenizer = TAPETokenizer(vocab='iupac')
        # tokenize each line, append the stop token, and record where each sequence starts
        startidxlist = []
        tokenlist = []
        current_start_idx = 0

        with open(file, 'r') as f:
            for line in f:
                startidxlist.append(current_start_idx)
                words = tokenizer.tokenize(line.rstrip()) + [tokenizer.stop_token]
                for word in words:
                    tokenlist.append(tokenizer.convert_token_to_id(word))
                current_start_idx = len(tokenlist)

        data = np.array(tokenlist)
        startidx = np.array(startidxlist)
        if not output_file:
            output_file = file + '.hdf5'

        with h5py.File(output_file, "w") as f:
            f.create_dataset('tokenized_sequences', data=data)
            f.create_dataset('starting_indices', data=startidx)

        return cls(output_file, num_batches, bptt_length, buffer_size)
Example #4
    def __init__(self, dataset_sequences):

        self.dataset_sequences = dataset_sequences
        self.model = ProteinBertModel.from_pretrained('bert-base')
        self.tokenizer = TAPETokenizer(
            vocab='iupac'
        )  # iupac is the vocab for TAPE models, use unirep for the UniRep model
Example #5
    def make_hdf5_from_txt(cls, file: str, num_batches: int, output_file: str = None, bptt_length: int = 75):
        '''Tokenize sequences from a line-by-line txt file, concatenate and cut into num_batches columns.
           Save as hdf5, and return Dataset with this hdf5 file as source.
        '''
        if not os.path.exists(file):
            raise FileNotFoundError(file)
        tokenizer = TAPETokenizer(vocab='iupac')
        # tokenize each line and append the stop token after each sequence
        tokenlist = []
        with open(file, 'r') as f:
            for line in f:
                words = tokenizer.tokenize(line.rstrip()) + [tokenizer.stop_token]
                for word in words:
                    tokenlist.append(tokenizer.convert_token_to_id(word))

        # trim so the token count is divisible by num_batches, then reshape
        tokensperbatch = len(tokenlist) // num_batches
        end = tokensperbatch * num_batches
        tokenlist = tokenlist[0:end]
        data = np.array(tokenlist)
        data = data.reshape(-1, num_batches)  # shape: (tokensperbatch, num_batches)

        if not output_file:
            output_file = file + '.hdf5'

        with h5py.File(output_file, "w") as f:
            f.create_dataset('tokenized_sequences', data=data)

        return cls(output_file, bptt_length)
Example #6
def test_basic():
    import torch
    from tape import ProteinBertModel, ProteinBertConfig, TAPETokenizer  # type: ignore

    config = ProteinBertConfig(hidden_size=12,
                               intermediate_size=12 * 4,
                               num_hidden_layers=2)
    model = ProteinBertModel(config)
    tokenizer = TAPETokenizer(vocab='iupac')

    sequence = 'GCTVEDRCLIGMGAILLNGCVIGSGSLVAAGALITQ'
    token_ids = torch.tensor([tokenizer.encode(sequence)])
    output = model(token_ids)
    sequence_output = output[0]  # noqa
    pooled_output = output[1]  # noqa
Example #7
def test_forcedownload():
    model = ProteinBertModel.from_pretrained('bert-base')
    url = BERT_PRETRAINED_MODEL_ARCHIVE_MAP['bert-base']
    filename = url_to_filename(url, get_etag(url))
    wholepath = get_cache() / filename
    oldtime = time.ctime(os.path.getmtime(wholepath))
    model = ProteinBertModel.from_pretrained('bert-base', force_download=True)
    newtime = time.ctime(os.path.getmtime(wholepath))
    assert (newtime != oldtime)
    # Deploy model
    # iupac is the vocab for TAPE models, use unirep for the UniRep model
    tokenizer = TAPETokenizer(vocab='iupac')
    # Pfam Family: Hexapep, Clan: CL0536
    sequence = 'GCTVEDRCLIGMGAILLNGCVIGSGSLVAAGALITQ'
    token_ids = torch.tensor([tokenizer.encode(sequence)])
    model(token_ids)
Example #8
    def __init__(self,
                 data_file: Union[str, Path],
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 batch_size: int = 50,
                 bptt_length: int = 75
                 ):
        super().__init__()

        if isinstance(tokenizer, str):
            tokenizer = TAPETokenizer(vocab=tokenizer)
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.bptt_length = bptt_length


        self.data_file = Path(data_file)
        if not self.data_file.exists():
            raise FileNotFoundError(self.data_file)

        fn = Path(str(data_file) + '.data')  # cache file next to the raw data
        if os.path.exists(fn):
            logger.info('Loading cached dataset...')
            data = torch.load(fn)
        else:
            logger.info('Producing dataset...')
            data = self._concatenate_full_dataset(data_file)
            torch.save(data, fn)
            logger.info(f'Cached dataset at {fn}')

        data = self._batchify(data, self.batch_size)
        self.data = data
        self.start_idx, self.end_idx = self._get_bptt_indices(len(self.data), self.bptt_length)
Example #9
def run_eval_epoch(
        eval_loader: DataLoader,
        runner: ForwardRunner,
        is_master: bool = True) -> typing.List[typing.Dict[str, typing.Any]]:
    torch.set_grad_enabled(False)
    runner.eval()

    save_outputs = []
    from tape import TAPETokenizer
    tokenizer = TAPETokenizer(vocab="iupac")
    data_dict = {
        entry["primary"]: entry["id"]
        for entry in eval_loader.dataset.data
    }

    for batch in tqdm(eval_loader,
                      desc='Evaluation',
                      total=len(eval_loader),
                      disable=not is_master):
        loss, metrics, outputs = runner.forward(
            batch, return_outputs=True)  # type: ignore
        predictions = outputs[1].cpu().numpy()
        targets = batch['targets'].cpu().numpy()

        seqs = []
        ids = []
        for k in range(batch['input_ids'].shape[0]):
            seq = "".join([
                c
                for c in tokenizer.convert_ids_to_tokens(batch['input_ids'][k])
                if c != "<pad>"
            ][1:-1])
            seqs.append(seq)
            ids.append(data_dict[seq])
        for pred, target, seq, seq_id in zip(predictions, targets, seqs, ids):
            save_outputs.append({
                'prediction': pred,
                'target': target,
                'seq': seq,
                'ids': seq_id
            })

        # for pred, target in zip(predictions, targets):
        #     save_outputs.append({'prediction': pred, 'target': target})

    return save_outputs
Example #10
def UniRep_Embed(input_seq):
    T0 = time.time()
    UNIREPEB_ = []
    PID = []
    print("UniRep Embedding...")

    model = UniRepModel.from_pretrained('babbler-1900')
    model = model.to(DEVICE)
    tokenizer = TAPETokenizer(vocab='unirep')

    for key, value in input_seq.items():
        PID.append(key)
        sequence = value
        if len(sequence) == 0:
            print('# WARNING: sequence',
                  key,
                  'has length=0. Skipping.',
                  file=sys.stderr)
            continue
        with torch.no_grad():
            token_ids = torch.tensor([tokenizer.encode(sequence)])
            token_ids = token_ids.to(DEVICE)
            output = model(token_ids)
            unirep_output = output[0]
            unirep_output = torch.squeeze(unirep_output)
            unirep_output = unirep_output.mean(0)
            unirep_output = unirep_output.cpu().numpy()
            UNIREPEB_.append(unirep_output.tolist())
    unirep_feature = pd.DataFrame(UNIREPEB_)
    col = ["UniRep_F" + str(i + 1) for i in range(0, 1900)]
    unirep_feature.columns = col
    unirep_feature = pd.concat([unirep_feature], axis=1)
    unirep_feature.index = PID
    # print(unirep_feature.shape)
    unirep_feature.to_csv("./dataset/unirep_feature.csv")
    print("Getting Deep Representation Learning Features with UniRep is done.")
    print("it took %0.3f mins.\n" % ((time.time() - T0) / 60))

    return unirep_feature
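A minimal call sketch for UniRep_Embed. The dictionary contents are illustrative; it assumes the imports, the DEVICE global, and a writable ./dataset/ directory that the function above relies on.

# Hypothetical input: protein IDs mapped to amino-acid sequences.
sequences = {
    'P1': 'GCTVEDRCLIGMGAILLNGCVIGSGSLVAAGALITQ',
    'P2': 'MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ',
}

features = UniRep_Embed(sequences)   # returns an (n_seqs x 1900) DataFrame indexed by ID
print(features.shape)                # note: the function also writes ./dataset/unirep_feature.csv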
Example #11
    def __init__(self,
                 data_path: Union[str, Path],
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):
        super().__init__()
        
        if isinstance(tokenizer, str):
            tokenizer = TAPETokenizer(vocab=tokenizer)
        self.tokenizer = tokenizer

        self.data_file = Path(data_path)
        if not self.data_file.exists():
            raise FileNotFoundError(self.data_file)
        seqs = []
        with open(self.data_file, 'r') as f:
            for line in f:
                seqs.append(line.rstrip())
        self.data = seqs
Example #12
class encoding_tape(object):
    def __init__(self, dataset_sequences):

        self.dataset_sequences = dataset_sequences
        self.model = ProteinBertModel.from_pretrained('bert-base')
        self.tokenizer = TAPETokenizer(
            vocab='iupac'
        )  # iupac is the vocab for TAPE models, use unirep for the UniRep model

    def apply_encoding(self):

        matrix_encoding = []
        for i in range(len(self.dataset_sequences)):

            try:
                token_ids = torch.tensor([
                    self.tokenizer.encode(
                        self.dataset_sequences['sequence'][i])
                ])
                output = self.model(token_ids)
                sequence_output = output[0]

                matrix_data = []

                for element in sequence_output[0].cpu().detach().numpy():
                    matrix_data.append(element)

                encoding_avg = []

                for k in range(len(matrix_data[0])):
                    array_value = []
                    for j in range(len(matrix_data)):
                        array_value.append(matrix_data[j][k])

                    encoding_avg.append(np.mean(array_value))
                matrix_encoding.append(encoding_avg)
            except Exception:
                # skip sequences that fail to tokenize or encode
                continue

        header = [
            "P_" + str(i + 1) for i in range(len(matrix_encoding[0]))
        ]
        self.dataset_encoding = pd.DataFrame(matrix_encoding, columns=header)
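A minimal usage sketch for encoding_tape. The DataFrame contents are illustrative; it assumes torch, numpy, pandas, ProteinBertModel and TAPETokenizer are imported as in the class above, and that the input frame has the 'sequence' column the class expects.

df = pd.DataFrame({'sequence': ['GCTVEDRCLIGMGAILLNGCVIGSGSLVAAGALITQ',
                                'MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ']})

encoder = encoding_tape(df)
encoder.apply_encoding()
print(encoder.dataset_encoding.shape)  # one averaged BERT embedding per sequence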
Example #13
    def __init__(self,
                 data_file: Union[str, Path],
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 label_column: str = 'target label',
                 sequence_column: str = 'Sequence'):
        super().__init__()
        
        if isinstance(tokenizer, str):
            tokenizer = TAPETokenizer(vocab=tokenizer)
        self.tokenizer = tokenizer

        if not os.path.exists(data_file):
            raise FileNotFoundError(data_file)
        
        
        df = pd.read_csv(data_file, sep='\t')
        self.data = df[sequence_column]
        self.labels = df[label_column]
        #no more pandas from here
        self.data = list(self.data)
        self.labels = list(self.labels)
Example #14
    def __init__(self,
                 data_file: Union[str, Path],
                 tokenizer: Union[str, TAPETokenizer] = 'iupac', label_dict = None):
        super().__init__()
        
        if isinstance(tokenizer, str):
            tokenizer = TAPETokenizer(vocab=tokenizer)
        self.tokenizer = tokenizer

        if not os.path.exists(data_file):
            raise FileNotFoundError(data_file)
        
        if label_dict is None:
            self.label_dict = {'S': 0,'P': 1} #TODO make arg or learn from input data once
        else:
            self.label_dict = label_dict
        
        df = pd.read_csv(data_file, sep='\t')
        self.data = df[df.columns[0]]
        self.labels = df[df.columns[1]]
        #no more pandas from here
        self.data = list(self.data)
        self.labels = list(self.labels)
Example #15
    os.makedirs(f_savepath, exist_ok=True)
    
    all_indices = {}
    
    if rpr == 'protein':
        import torch
        from tape import ProteinBertModel, TAPETokenizer

        unique_protein = list(set(all_protein))
        print(f'n unique protein used to compute repr: {len(unique_protein)}')
        
        unique_prot_to_idx = get_data_to_idx_mapping(all_protein)
        
        # init protein pretrained model
        model = ProteinBertModel.from_pretrained('bert-base')
        tokenizer = TAPETokenizer(vocab='iupac') 
        
        results = Parallel(n_jobs=nworkers)(delayed(get_PROTrepr)(i,x,model,tokenizer,f_savepath) for i,x in enumerate(unique_protein))
        
        all_indices = {}
        for x in results:
            i = x[0]
            prot = x[1]
            _all_idx = unique_prot_to_idx[prot]
            for _idx in _all_idx:
                all_indices[_idx] = i
        
        hp.save_pkl(f'{savepath}all_indices_{rpr}.pkl', all_indices)
        
        z_norma = True
        if z_norma:
Example #16
import torch
from tape import ProteinBertModel, TAPETokenizer
model = ProteinBertModel.from_pretrained('bert-base')
tokenizer = TAPETokenizer(
    vocab='iupac'
)  # iupac is the vocab for TAPE models, use unirep for the UniRep model

# Pfam Family: Hexapep, Clan: CL0536
sequence = 'GCTVEDRCLIGMGAILLNGCVIGSGSLVAAGALITQ'
token_ids = torch.tensor([tokenizer.encode(sequence)])
output = model(token_ids)
sequence_output = output[0]
pooled_output = output[1]
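Mean-pooling over residues, as several examples above do, turns the per-residue output into one fixed-size vector per protein. A short follow-up sketch continuing from the variables above; the pooling choice is illustrative, and the 768-dim hidden size matches the feature count used in Example #18.

# sequence_output: (batch, seq_len, hidden); average over residues to get one vector per sequence.
per_protein_embedding = sequence_output.mean(dim=1)
print(per_protein_embedding.shape)  # expected torch.Size([1, 768]) for bert-base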
Example #17
def DRLF_Embed(fastaFile, outFile, device=-2):

    path = fastaFile
    count = 0
    SSAEMB_ = []
    UNIREPEB_ = []
    ##read Fasta File
    inData = fasta.fasta2csv(path)
    Seqs = inData["Seq"]

    PID_ = []
    ##SSA Embedding

    print("SSA Embedding...")
    lm_embed, lstm_stack, proj = load_model(
        "./src/PretrainedModel/SSA_embed.model", use_cuda=True)

    with open(path, 'rb') as f:
        for name, sequence in fasta.parse_stream(f):

            pid = str(name.decode('utf-8'))
            if len(sequence) == 0:
                print('# WARNING: sequence',
                      pid,
                      'has length=0. Skipping.',
                      file=sys.stderr)
                continue

            PID_.append(pid)

            z = embed_sequence(sequence,
                               lm_embed,
                               lstm_stack,
                               proj,
                               final_only=True,
                               pool='avg',
                               use_cuda=True)

            SSAEMB_.append(z)
            count += 1
            print(sequence,
                  '# {} sequences processed...'.format(count),
                  file=sys.stderr,
                  end='\r')
    print("SSA embedding finished@")

    ssa_feature = pd.DataFrame(SSAEMB_)
    col = ["SSA_F" + str(i + 1) for i in range(0, 121)]
    ssa_feature.columns = col

    print("UniRep Embedding...")
    print("Loading UniRep Model...", file=sys.stderr, end='\r')

    model = UniRepModel.from_pretrained('babbler-1900')
    model = model.to(DEVICE)
    tokenizer = TAPETokenizer(vocab='unirep')

    count = 0
    PID_ = inData["PID"]

    for pid, sequence in zip(PID_, Seqs):

        if len(sequence) == 0:
            print('# WARNING: sequence',
                  pid,
                  'has length=0. Skipping.',
                  file=sys.stderr)
            continue

        with torch.no_grad():
            token_ids = torch.tensor([tokenizer.encode(sequence)])
            token_ids = token_ids.to(DEVICE)
            output = model(token_ids)
            unirep_output = output[0]
            #print(unirep_output.shape)
            unirep_output = torch.squeeze(unirep_output)
            #print(unirep_output.shape)
            unirep_output = unirep_output.mean(0)
            unirep_output = unirep_output.cpu().numpy()

            # print(sequence,len(sequence),unirep_output.shape)
            UNIREPEB_.append(unirep_output.tolist())
            count += 1
            print(sequence,
                  '# {} sequences processed...'.format(count),
                  file=sys.stderr,
                  end='\r')

    unirep_feature = pd.DataFrame(UNIREPEB_)

    col = ["UniRep_avg_F" + str(i + 1) for i in range(0, 1900)]
    unirep_feature.columns = col
    print("UniRep Embedding Finished@!")
    Features = pd.concat([ssa_feature, unirep_feature], axis=1)
    Features.index = PID_
    Features.to_csv(outFile)
    print("Getting Deep Representation Learning Features is done.")

    return Features, inData
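A minimal call sketch for DRLF_Embed. The FASTA and output paths are illustrative; the function itself requires the pretrained SSA model at ./src/PretrainedModel/SSA_embed.model and a CUDA device, as shown above.

# Reads a FASTA file and writes a CSV of concatenated SSA (121-dim) and averaged UniRep (1900-dim) features.
features, raw_data = DRLF_Embed('proteins.fasta', 'drlf_features.csv')
print(features.shape)  # (n_sequences, 121 + 1900)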
Example #18
        if i % 2 == 0:
            labels.append(0)
        else:
            seqs.append(line.rstrip())
        i += 1
    f2.close()
    return labels, seqs
train_labels,train_seqs=readFasta("/content/gdrive/My Drive/ProInFuse/train_pos.txt","/content/gdrive/My Drive/ProInFuse/train_neg.txt") # change to your own path
test_labels,test_seqs=readFasta("/content/gdrive/My Drive/ProInFuse/ind_pos.txt","/content/gdrive/My Drive/ProInFuse/ind_neg.txt") # change to your own path

# use any of the read files

import torch
from tape import ProteinBertModel, TAPETokenizer
model = ProteinBertModel.from_pretrained('bert-base')
tokenizer = TAPETokenizer(vocab='iupac') 


num_of_features = 768
import numpy as np
X=np.zeros((len(train_seqs),num_of_features))
y=np.zeros(len(train_seqs))

ind_X=np.zeros((len(test_seqs),num_of_features))
ind_y=np.zeros(len(test_seqs))

# now lets populate X
i=0
for s in train_seqs:
    #f=extractFeatures(s)
    token_ids = torch.tensor([tokenizer.encode(s)])
Example #19
def main_training_loop(args: argparse.ArgumentParser):
    if args.enforce_walltime:
        loop_start_time = time.time()
        logger.info('Started timing loop')

    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    #Setup Model
    tokenizer = TAPETokenizer(vocab='iupac')
    config = ProteinAWDLSTMConfig(**vars(args))
    config.vocab_size = tokenizer.vocab_size
    if args.reset_hidden:
        logger.info(f'Resetting hidden state after {tokenizer.stop_token}')
        config.reset_token_id = tokenizer.convert_token_to_id(
            tokenizer.stop_token)

    model = ProteinAWDLSTMForLM(config)
    #training logger
    time_stamp = time.strftime("%y-%m-%d-%H-%M-%S", time.gmtime())
    experiment_name = f"{args.experiment_name}_{model.base_model_prefix}_{time_stamp}"
    viz = visualization.get(
        args.output_dir, experiment_name, local_rank=-1
    )  # debug=args.debug could also be passed; local_rank=-1 means training is not distributed, debug makes the experiment a dry run for wandb

    train_data = Hdf5Dataset(os.path.join(args.data, 'train.hdf5'),
                             batch_size=args.batch_size,
                             bptt_length=args.bptt,
                             buffer_size=args.buffer_size)
    val_data = Hdf5Dataset(os.path.join(args.data, 'valid.hdf5'),
                           batch_size=args.batch_size,
                           bptt_length=args.bptt,
                           buffer_size=args.buffer_size)

    logger.info(f'Data loaded. One train epoch = {len(train_data)} steps.')
    logger.info(f'Data loaded. One valid epoch = {len(val_data)} steps.')

    train_loader = DataLoader(train_data,
                              batch_size=1,
                              collate_fn=train_data.collate_fn)
    val_loader = DataLoader(val_data,
                            batch_size=1,
                            collate_fn=train_data.collate_fn)
    # set up validation here so a subsample can be drawn from wherever the iterator stopped, each time it is needed
    val_iterator = enumerate(val_loader)
    val_steps = 0
    hidden = None

    #overwrite model when restarting/changing params
    if args.resume:
        logger.info(f'Loading pretrained model in {args.resume}')
        model = ProteinAWDLSTMForLM.from_pretrained(args.resume)

    if args.wandb_sweep:
        # This prevents errors. When the model is partly set up from config rather than the command line,
        # the config received from wandb might not match what is in args and ProteinConfig,
        # and calling log_config on an inconsistent pair would throw an error.
        # This overwrites args and ProteinConfig, so wandb has priority.
        # If in doubt, compare the wandb run with the save_pretrained config.json; they should always agree.
        logger.info('Receiving config from wandb!')
        import wandb
        from training_utils import override_from_wandb
        override_from_wandb(wandb.config, args, config)
        model = ProteinAWDLSTMForLM(config)

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    weight_decay=args.wdecay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.wdecay)

    model.to(device)
    logger.info('Model set up!')
    num_parameters = sum(p.numel() for p in model.parameters()
                         if p.requires_grad)
    logger.info(f'Model has {num_parameters} trainable parameters')

    if torch.cuda.is_available():
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    else:
        logger.info(f'Running model on {device}, not using nvidia apex')

    #set up wandb logging, tape visualizer class takes care of everything. just login to wandb in the env as usual
    viz.log_config(args)
    viz.log_config(model.config.to_dict())
    viz.watch(model)
    logger.info(
        f'Logging experiment as {experiment_name} to wandb/tensorboard')

    #keep track of best loss
    num_epochs_no_improvement = 0
    stored_loss = 100000000
    learning_rate_steps = 0

    global_step = 0
    for epoch in range(1, args.epochs + 1):
        logger.info(f'Starting epoch {epoch}')
        viz.log_metrics({'Learning Rate': optimizer.param_groups[0]['lr']},
                        "train", global_step)

        epoch_start_time = time.time()
        start_time = time.time()  #for lr update interval
        hidden = None
        for i, batch in enumerate(train_loader):

            data, targets = batch
            loss, reg_loss, hidden = training_step(model, data, targets,
                                                   hidden, optimizer, args, i)
            viz.log_metrics(
                {
                    'loss': loss,
                    'regularized loss': reg_loss,
                    'perplexity': math.exp(loss),
                    'regularized perplexity': math.exp(reg_loss)
                }, "train", global_step)
            global_step += 1

            # ad hoc fix for smaller datasets: evaluate after full epochs
            update_steps = min(args.update_lr_steps, len(train_loader))
            # every update_lr_steps, evaluate performance and save model/progress in learning rate
            if global_step % update_steps == 0 and global_step > 0:
                total_loss = 0
                total_reg_loss = 0
                total_len = 0

                #NOTE Plasmodium sets are 1% the size of Eukarya sets. run 1/100 of total set at each time
                #n_val_steps = (len(val_loader)//100) if len(val_loader) > 100000 else len(val_loader) #works because plasmodium set is smaller, don't want another arg for this
                #old border was too high, cannot train homology reduced eukarya 10 percent with it
                # works because the Plasmodium set is smaller; don't want another arg for this
                n_val_steps = len(val_loader) // 100 if len(val_loader) > 10000 else len(val_loader)
                logger.info(
                    f'Step {global_step}, validating for {n_val_steps} Validation steps'
                )

                for j in range(n_val_steps):
                    val_steps += 1
                    #if val_steps == len(val_loader): #reset the validation data when at its end
                    #    val_iterator = enumerate(val_loader)
                    #    hidden = None
                    try:
                        _, (data, targets) = next(val_iterator)
                    except StopIteration:  # end of the validation set reached: restart the iterator
                        val_iterator = enumerate(val_loader)
                        logger.info(
                            f'validation step {j}: resetting validation enumerator.'
                        )
                        hidden = None
                        _, (data, targets) = next(val_iterator)
                    loss, reg_loss, hidden = validation_step(
                        model, data, targets, hidden)
                    total_len += len(data)
                    total_reg_loss += reg_loss * len(data)
                    total_loss += loss * len(data)

                val_reg_loss = total_reg_loss / total_len

                val_loss = total_loss / total_len

                val_metrics = {
                    'loss': val_loss,
                    'perplexity': math.exp(val_loss),
                    'regularized loss': val_reg_loss,
                    'regularized perplexity': math.exp(val_reg_loss)
                }
                viz.log_metrics(val_metrics, "val", global_step)

                elapsed = time.time() - start_time
                logger.info(
                    f'Training step {global_step}, { elapsed / args.log_interval:.3f} s/batch. tr_loss: {loss:.2f}, tr_perplexity {math.exp(loss):.2f} va_loss: {val_loss:.2f}, va_perplexity {math.exp(val_loss):.2f}'
                )
                start_time = time.time()

                if val_loss < stored_loss:
                    num_epochs_no_improvement = 0
                    model.save_pretrained(args.output_dir)
                    save_training_status(args.output_dir, epoch, global_step,
                                         num_epochs_no_improvement,
                                         stored_loss, learning_rate_steps)
                    #also save with apex
                    if torch.cuda.is_available():
                        checkpoint = {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'amp': amp.state_dict()
                        }
                        torch.save(
                            checkpoint,
                            os.path.join(args.output_dir, 'amp_checkpoint.pt'))
                        logger.info(
                            f'New best model with loss {val_loss}, Saving model, training step {global_step}'
                        )
                    stored_loss = val_loss
                else:
                    num_epochs_no_improvement += 1
                    logger.info(
                        f'Step {global_step}: No improvement for {num_epochs_no_improvement} pseudo-epochs.'
                    )

                    if num_epochs_no_improvement == args.wait_epochs:
                        optimizer.param_groups[0][
                            'lr'] = optimizer.param_groups[0][
                                'lr'] * args.lr_step
                        learning_rate_steps += 1
                        num_epochs_no_improvement = 0
                        logger.info(
                            f'Step {global_step}: Decreasing learning rate. learning rate step {learning_rate_steps}.'
                        )
                        viz.log_metrics(
                            {'Learning Rate': optimizer.param_groups[0]['lr']},
                            "train", global_step)

                        #break early after 5 lr steps
                        if learning_rate_steps > 5:
                            logger.info(
                                'Learning rate step limit reached, ending training early'
                            )
                            return stored_loss

            if args.enforce_walltime and (
                    time.time() - loop_start_time) > 84600:  # 23.5 hours
                logger.info('Wall time limit reached, ending training early')
                return stored_loss

        logger.info(f'Epoch {epoch} training complete')
        logger.info(
            f'Epoch {epoch}, took {time.time() - epoch_start_time:.2f}.\t Train loss: {loss:.2f} \t Train perplexity: {math.exp(loss):.2f}'
        )

    return stored_loss
Example #20
        tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert_bfd",
                                                  do_lower_case=False)
    elif model_version == 'prot_bert':
        model = BertModel.from_pretrained("Rostlab/prot_bert",
                                          output_attentions=True)
        tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert",
                                                  do_lower_case=False)
    elif model_version == 'prot_albert':
        model = AlbertModel.from_pretrained("Rostlab/prot_albert",
                                            output_attentions=True)
        tokenizer = AlbertTokenizer.from_pretrained("Rostlab/prot_albert",
                                                    do_lower_case=False)
    else:
        model = ProteinBertModel.from_pretrained(model_version,
                                                 output_attentions=True)
        tokenizer = TAPETokenizer()
    num_layers = model.config.num_hidden_layers
    num_heads = model.config.num_attention_heads
elif args.model == 'xlnet':
    model_version = args.model_version
    if model_version == 'prot_xlnet':
        model = XLNetModel.from_pretrained("Rostlab/prot_xlnet",
                                           output_attentions=True)
        tokenizer = XLNetTokenizer.from_pretrained("Rostlab/prot_xlnet",
                                                   do_lower_case=False)
    else:
        raise ValueError('Invalid model version')
    num_layers = model.config.n_layer
    num_heads = model.config.n_head
else:
    raise ValueError(f"Invalid model: {args.model}")
Example #21
def escape_tape(virus, vocabulary, pretrained='transformer'):
    if pretrained == 'transformer':
        from tape import ProteinBertModel
        model_class = ProteinBertModel
        model_name = 'bert-base'
        fname_prefix = 'tape_transformer'
        vocab = 'iupac'
    elif pretrained == 'unirep':
        from tape import UniRepModel
        model_class = UniRepModel
        model_name = 'babbler-1900'
        fname_prefix = 'unirep'
        vocab = 'unirep'
    else:
        raise ValueError('invalid pretrained option {}'.format(pretrained))

    from tape import TAPETokenizer
    model = model_class.from_pretrained(model_name)
    tokenizer = TAPETokenizer(vocab=vocab)

    seq, seqs_escape, train_fname, mut_fname, anchor_id = load(virus)
    if virus == 'h1':
        embed_fname = ('target/flu/embedding/{}_h1.npz'.format(fname_prefix))
    elif virus == 'h3':
        embed_fname = ('target/flu/embedding/{}_h3.npz'.format(fname_prefix))
    elif virus == 'bg505':
        embed_fname = ('target/hiv/embedding/{}_hiv.npz'.format(fname_prefix))
    elif virus == 'sarscov2':
        embed_fname = (
            'target/cov/embedding/{}_sarscov2.npz'.format(fname_prefix))
    elif virus == 'cov2rbd':
        embed_fname = (
            'target/cov/embedding/{}_sarscov2.npz'.format(fname_prefix))
    else:
        raise ValueError('invalid option {}'.format(virus))

    anchor = None
    for idx, record in enumerate(SeqIO.parse(train_fname, 'fasta')):
        if record.id == anchor_id:
            anchor = str(record.seq)
    assert (anchor is not None)

    base_embedding = tape_embed(anchor.replace('-', ''), model, tokenizer)
    with np.load(embed_fname, allow_pickle=True) as data:
        embeddings = {name: data[name][()]['avg'] for name in data.files}

    mutations = [str(record.seq) for record in SeqIO.parse(mut_fname, 'fasta')]
    mut2change = {}
    for mutation in mutations:
        didx = [c1 != c2 for c1, c2 in zip(anchor, mutation)].index(True)
        embedding = embeddings['mut_{}_{}'.format(didx, mutation[didx])]
        mutation_clean = mutation.replace('-', '')
        mut2change[mutation_clean] = abs(base_embedding - embedding).sum()

    anchor = anchor.replace('-', '')
    escape_idx, changes = [], []
    for i in range(len(anchor)):
        if virus == 'bg505' and (i < 29 or i > 698):
            continue
        if virus == 'cov2rbd' and (i < 318 or i > 540):
            continue
        for word in vocabulary:
            if anchor[i] == word:
                continue
            mut_seq = anchor[:i] + word + anchor[i + 1:]
            if mut_seq in seqs_escape and \
               (sum([ m['significant'] for m in seqs_escape[mut_seq] ]) > 0):
                escape_idx.append(len(changes))
            changes.append(mut2change[mut_seq])
    changes = np.array(changes)

    plot_result(-changes,
                escape_idx,
                virus,
                fname_prefix,
                legend_name='TAPE ({})'.format(fname_prefix))