def make_hdf5_from_array(cls, array: Union[np.ndarray, pd.Series], num_batches: int, output_file: str, bptt_length=75):
    '''Tokenize sequences from an array or Series, concatenate and cut into num_batches sequences.
    Save as HDF5, and return a Dataset with this HDF5 file as source.
    '''
    tokenizer = TAPETokenizer(vocab='iupac')

    # load and tokenize
    tokenlist = []
    for seq in array:
        words = tokenizer.tokenize(seq) + [tokenizer.stop_token]
        for word in words:
            tokenlist.append(tokenizer.convert_token_to_id(word))

    # split into batches
    tokensperbatch = len(tokenlist) // num_batches
    end = tokensperbatch * num_batches  # trim
    tokenlist = tokenlist[0:end]
    data = np.array(tokenlist)
    data = data.reshape(-1, num_batches)

    with h5py.File(output_file, "w") as f:
        f.create_dataset('tokenized_sequences', data=data)

    return cls(output_file, bptt_length)
def make_hdf5_from_array(cls, array: Union[np.ndarray, pd.Series], output_file: str, num_batches: int = 100, bptt_length=75):
    '''Tokenize sequences from an array or Series and concatenate them.
    Save as HDF5, and return a Dataset with this HDF5 file as source.
    Properties of the HDF5 file:
        dataset tokenized_sequences: concatenation of all tokenized sequences (stop tokens inserted).
                                     1D array of size total_n_tokens
        dataset starting_indices:    starting index in tokenized_sequences of each sequence.
                                     1D array of size n_sequences
    '''
    tokenizer = TAPETokenizer(vocab='iupac')

    # load and tokenize
    startidxlist = []
    tokenlist = []
    current_start_idx = 0
    for seq in array:
        startidxlist.append(current_start_idx)
        words = tokenizer.tokenize(seq) + [tokenizer.stop_token]
        for word in words:
            tokenlist.append(tokenizer.convert_token_to_id(word))
        current_start_idx = len(tokenlist)

    data = np.array(tokenlist)
    startidx = np.array(startidxlist)

    with h5py.File(output_file, "w") as f:
        f.create_dataset('tokenized_sequences', data=data)
        f.create_dataset('starting_indices', data=startidx)

    return cls(output_file, bptt_length)
def make_hdf5_from_txt(cls, file: str, num_batches: int = 100, output_file: str = None, bptt_length=75, buffer_size=1000):
    '''Tokenize sequences from a line-by-line txt file, concatenate and cut into num_batches sequences.
    Save as HDF5, and return a Dataset with this HDF5 file as source.
    '''
    if not os.path.exists(file):
        raise FileNotFoundError(file)
    tokenizer = TAPETokenizer(vocab='iupac')

    # load and tokenize
    startidxlist = []
    tokenlist = []
    current_start_idx = 0
    with open(file, 'r') as f:
        for line in f:
            startidxlist.append(current_start_idx)
            words = tokenizer.tokenize(line.rstrip()) + [tokenizer.stop_token]
            for word in words:
                tokenlist.append(tokenizer.convert_token_to_id(word))
            current_start_idx = len(tokenlist)

    data = np.array(tokenlist)
    startidx = np.array(startidxlist)

    if not output_file:
        output_file = file + '.hdf5'
    with h5py.File(output_file, "w") as f:
        f.create_dataset('tokenized_sequences', data=data)
        f.create_dataset('starting_indices', data=startidx)

    return cls(output_file, num_batches, bptt_length, buffer_size)
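# Hedged read-back sketch (not part of the original): how the HDF5 file written
# above can be inspected with h5py, and how a single tokenized sequence can be
# recovered from the flat token array via starting_indices. The file name
# 'sequences.txt.hdf5' only illustrates the default output path (input file + '.hdf5').
import h5py
import numpy as np

with h5py.File('sequences.txt.hdf5', 'r') as f:
    tokens = f['tokenized_sequences'][:]
    starts = f['starting_indices'][:]

bounds = np.append(starts, len(tokens))
first_seq_tokens = tokens[bounds[0]:bounds[1]]  # includes the trailing stop token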
def __init__(self,
             data_file: Union[str, Path],
             tokenizer: Union[str, TAPETokenizer] = 'iupac',
             batch_size: int = 50,
             bptt_length: int = 75):
    super().__init__()
    if isinstance(tokenizer, str):
        tokenizer = TAPETokenizer(vocab=tokenizer)
    self.tokenizer = tokenizer
    self.batch_size = batch_size
    self.bptt_length = bptt_length

    self.data_file = Path(data_file)
    if not os.path.exists(self.data_file):
        raise FileNotFoundError(self.data_file)

    fn = Path(str(data_file) + '.data')  # cache file; str() so this also works when a Path is passed
    if os.path.exists(fn):
        logger.info('Loading cached dataset...')
        data = torch.load(fn)
    else:
        logger.info('Producing dataset...')
        data = self._concatenate_full_dataset(data_file)
        torch.save(data, fn)
        logger.info(f'Cached dataset at {fn}')

    data = self._batchify(data, self.batch_size)
    self.data = data
    self.start_idx, self.end_idx = self._get_bptt_indices(len(self.data), self.bptt_length)
def __init__(self, dataset_sequences):
    self.dataset_sequences = dataset_sequences
    self.model = ProteinBertModel.from_pretrained('bert-base')
    # iupac is the vocab for TAPE models, use unirep for the UniRep model
    self.tokenizer = TAPETokenizer(vocab='iupac')
def make_hdf5_from_txt(cls, file: str, num_batches: int, output_file: str = None, bptt_length=75):
    '''Tokenize sequences from a line-by-line txt file, concatenate and cut into num_batches sequences.
    Save as HDF5, and return a Dataset with this HDF5 file as source.
    '''
    if not os.path.exists(file):
        raise FileNotFoundError(file)
    tokenizer = TAPETokenizer(vocab='iupac')

    # load and tokenize
    tokenlist = []
    with open(file, 'r') as f:
        for line in f:
            words = tokenizer.tokenize(line.rstrip()) + [tokenizer.stop_token]
            for word in words:
                tokenlist.append(tokenizer.convert_token_to_id(word))

    # split into batches
    tokensperbatch = len(tokenlist) // num_batches
    end = tokensperbatch * num_batches  # trim
    tokenlist = tokenlist[0:end]
    data = np.array(tokenlist)
    data = data.reshape(-1, num_batches)

    if not output_file:
        output_file = file + '.hdf5'
    with h5py.File(output_file, "w") as f:
        f.create_dataset('tokenized_sequences', data=data)

    return cls(output_file, bptt_length)
def test_basic():
    import torch
    from tape import ProteinBertModel, ProteinBertConfig, TAPETokenizer  # type: ignore

    config = ProteinBertConfig(hidden_size=12, intermediate_size=12 * 4, num_hidden_layers=2)
    model = ProteinBertModel(config)
    tokenizer = TAPETokenizer(vocab='iupac')

    sequence = 'GCTVEDRCLIGMGAILLNGCVIGSGSLVAAGALITQ'
    token_ids = torch.tensor([tokenizer.encode(sequence)])
    output = model(token_ids)
    sequence_output = output[0]  # noqa
    pooled_output = output[1]  # noqa
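# Hedged follow-up sketch (not part of the original tests): a quick check of the
# output shapes, assuming TAPE's usual (sequence_output, pooled_output) return
# convention. It rebuilds the same toy model so the snippet is self-contained.
def test_output_shapes():
    import torch
    from tape import ProteinBertModel, ProteinBertConfig, TAPETokenizer  # type: ignore

    config = ProteinBertConfig(hidden_size=12, intermediate_size=12 * 4, num_hidden_layers=2)
    model = ProteinBertModel(config)
    tokenizer = TAPETokenizer(vocab='iupac')
    token_ids = torch.tensor([tokenizer.encode('GCTVEDRCLIGMGAILLNGCVIGSGSLVAAGALITQ')])

    sequence_output, pooled_output = model(token_ids)[:2]
    # per-residue embeddings: (batch, sequence length incl. special tokens, hidden size)
    assert sequence_output.shape == (1, token_ids.shape[1], config.hidden_size)
    # pooled embedding: (batch, hidden size)
    assert pooled_output.shape == (1, config.hidden_size)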
def test_forcedownload():
    model = ProteinBertModel.from_pretrained('bert-base')
    url = BERT_PRETRAINED_MODEL_ARCHIVE_MAP['bert-base']
    filename = url_to_filename(url, get_etag(url))
    wholepath = get_cache() / filename
    oldtime = time.ctime(os.path.getmtime(wholepath))
    model = ProteinBertModel.from_pretrained('bert-base', force_download=True)
    newtime = time.ctime(os.path.getmtime(wholepath))
    assert (newtime != oldtime)

    # Deploy model
    # iupac is the vocab for TAPE models, use unirep for the UniRep model
    tokenizer = TAPETokenizer(vocab='iupac')
    # Pfam Family: Hexapep, Clan: CL0536
    sequence = 'GCTVEDRCLIGMGAILLNGCVIGSGSLVAAGALITQ'
    token_ids = torch.tensor([tokenizer.encode(sequence)])
    model(token_ids)
def run_eval_epoch(eval_loader: DataLoader,
                   runner: ForwardRunner,
                   is_master: bool = True) -> typing.List[typing.Dict[str, typing.Any]]:
    torch.set_grad_enabled(False)
    runner.eval()
    save_outputs = []

    from tape import TAPETokenizer
    tokenizer = TAPETokenizer(vocab="iupac")
    # map raw sequence -> identifier so predictions can be traced back to their dataset entries
    data_dict = {entry["primary"]: entry["id"] for entry in eval_loader.dataset.data}

    for batch in tqdm(eval_loader, desc='Evaluation', total=len(eval_loader),
                      disable=not is_master):
        loss, metrics, outputs = runner.forward(batch, return_outputs=True)  # type: ignore
        predictions = outputs[1].cpu().numpy()
        targets = batch['targets'].cpu().numpy()

        seqs = []
        ids = []
        for k in range(batch['input_ids'].shape[0]):
            # decode the input ids, dropping padding and the start/stop tokens
            seq = "".join([c for c in tokenizer.convert_ids_to_tokens(batch['input_ids'][k])
                           if c != "<pad>"][1:-1])
            seqs.append(seq)
            ids.append(data_dict[seq])

        for pred, target, seq, seq_id in zip(predictions, targets, seqs, ids):
            save_outputs.append({'prediction': pred,
                                 'target': target,
                                 'seq': seq,
                                 'ids': seq_id})

    return save_outputs
def __init__(self,
             data_path: Union[str, Path],
             tokenizer: Union[str, TAPETokenizer] = 'iupac',
             in_memory: bool = False):
    super().__init__()
    if isinstance(tokenizer, str):
        tokenizer = TAPETokenizer(vocab=tokenizer)
    self.tokenizer = tokenizer

    self.data_file = Path(data_path)
    if not self.data_file.exists():
        raise FileNotFoundError(self.data_file)

    seqs = []
    with open(self.data_file, 'r') as f:
        for line in f:
            seqs.append(line.rstrip())
    self.data = seqs
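# Hedged usage sketch (not part of the original): the class this __init__
# belongs to is not named in the snippet, so LineByLineDataset is a placeholder.
# The sketch only relies on the attributes set above (.data and .tokenizer).
dataset = LineByLineDataset('sequences.txt', tokenizer='iupac')
first_ids = dataset.tokenizer.encode(dataset.data[0])  # numpy array of token ids for the first sequence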
def __init__(self,
             data_file: Union[str, Path],
             tokenizer: Union[str, TAPETokenizer] = 'iupac',
             label_column='target label',
             sequence_column='Sequence'):
    super().__init__()
    if isinstance(tokenizer, str):
        tokenizer = TAPETokenizer(vocab=tokenizer)
    self.tokenizer = tokenizer

    if not os.path.exists(data_file):
        raise FileNotFoundError(data_file)

    df = pd.read_csv(data_file, sep='\t')
    self.data = df[sequence_column]
    self.labels = df[label_column]
    # no more pandas from here
    self.data = list(self.data)
    self.labels = list(self.labels)
def UniRep_Embed(input_seq):
    T0 = time.time()
    UNIREPEB_ = []
    PID = []

    print("UniRep Embedding...")
    model = UniRepModel.from_pretrained('babbler-1900')
    model = model.to(DEVICE)
    tokenizer = TAPETokenizer(vocab='unirep')

    for key, value in input_seq.items():
        sequence = value
        if len(sequence) == 0:
            print('# WARNING: sequence', key, 'has length=0. Skipping.', file=sys.stderr)
            continue
        PID.append(key)  # only keep ids of sequences that are actually embedded
        with torch.no_grad():
            token_ids = torch.tensor([tokenizer.encode(sequence)])
            token_ids = token_ids.to(DEVICE)
            output = model(token_ids)
            unirep_output = output[0]
            unirep_output = torch.squeeze(unirep_output)
            unirep_output = unirep_output.mean(0)  # average over sequence positions
            unirep_output = unirep_output.cpu().numpy()
            UNIREPEB_.append(unirep_output.tolist())

    unirep_feature = pd.DataFrame(UNIREPEB_)
    col = ["UniRep_F" + str(i + 1) for i in range(0, 1900)]
    unirep_feature.columns = col
    unirep_feature.index = PID
    unirep_feature.to_csv("./dataset/unirep_feature.csv")

    print("Getting Deep Representation Learning Features with UniRep is done.")
    print("it took %0.3f mins.\n" % ((time.time() - T0) / 60))
    return unirep_feature
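# Hedged usage sketch (not part of the original): UniRep_Embed expects a mapping
# from identifier to sequence and writes ./dataset/unirep_feature.csv, so that
# directory must already exist. DEVICE is assumed to be a torch.device defined
# elsewhere in the module.
example_seqs = {'example_id': 'GCTVEDRCLIGMGAILLNGCVIGSGSLVAAGALITQ'}
unirep_features = UniRep_Embed(example_seqs)
print(unirep_features.shape)  # expected (1, 1900): one row per sequence, 1900 UniRep features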
def __init__(self,
             data_file: Union[str, Path],
             tokenizer: Union[str, TAPETokenizer] = 'iupac',
             label_dict=None):
    super().__init__()
    if isinstance(tokenizer, str):
        tokenizer = TAPETokenizer(vocab=tokenizer)
    self.tokenizer = tokenizer

    if not os.path.exists(data_file):
        raise FileNotFoundError(data_file)

    if label_dict is None:
        self.label_dict = {'S': 0, 'P': 1}  # TODO make arg or learn from input data once
    else:
        self.label_dict = label_dict

    df = pd.read_csv(data_file, sep='\t')
    self.data = df[df.columns[0]]
    self.labels = df[df.columns[1]]
    # no more pandas from here
    self.data = list(self.data)
    self.labels = list(self.labels)
os.makedirs(f_savepath, exist_ok=True)
all_indices = {}

if rpr == 'protein':
    import torch
    from tape import ProteinBertModel, TAPETokenizer

    unique_protein = list(set(all_protein))
    print(f'n unique protein used to compute repr: {len(unique_protein)}')
    unique_prot_to_idx = get_data_to_idx_mapping(all_protein)

    # init protein pretrained model
    model = ProteinBertModel.from_pretrained('bert-base')
    tokenizer = TAPETokenizer(vocab='iupac')

    results = Parallel(n_jobs=nworkers)(
        delayed(get_PROTrepr)(i, x, model, tokenizer, f_savepath)
        for i, x in enumerate(unique_protein))

    all_indices = {}
    for x in results:
        i = x[0]
        prot = x[1]
        _all_idx = unique_prot_to_idx[prot]
        for _idx in _all_idx:
            all_indices[_idx] = i

hp.save_pkl(f'{savepath}all_indices_{rpr}.pkl', all_indices)

z_norma = True
if z_norma:
import torch
from tape import ProteinBertModel, TAPETokenizer

model = ProteinBertModel.from_pretrained('bert-base')
# iupac is the vocab for TAPE models, use unirep for the UniRep model
tokenizer = TAPETokenizer(vocab='iupac')

# Pfam Family: Hexapep, Clan: CL0536
sequence = 'GCTVEDRCLIGMGAILLNGCVIGSGSLVAAGALITQ'
token_ids = torch.tensor([tokenizer.encode(sequence)])
output = model(token_ids)
sequence_output = output[0]
pooled_output = output[1]
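# Hedged follow-up (not part of the original snippet): a common way to obtain a
# single fixed-size embedding per protein is to average the per-residue
# representations along the sequence dimension, as the embedding scripts above do.
per_protein_embedding = sequence_output.mean(dim=1)  # shape: (1, hidden_size)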
def DRLF_Embed(fastaFile, outFile, device=-2):
    path = fastaFile
    count = 0
    SSAEMB_ = []
    UNIREPEB_ = []

    # read FASTA file
    inData = fasta.fasta2csv(path)
    Seqs = inData["Seq"]
    PID_ = []

    # SSA embedding
    print("SSA Embedding...")
    lm_embed, lstm_stack, proj = load_model("./src/PretrainedModel/SSA_embed.model", use_cuda=True)
    with open(path, 'rb') as f:
        for name, sequence in fasta.parse_stream(f):
            pid = str(name.decode('utf-8'))
            if len(sequence) == 0:
                print('# WARNING: sequence', pid, 'has length=0. Skipping.', file=sys.stderr)
                continue
            PID_.append(pid)
            z = embed_sequence(sequence, lm_embed, lstm_stack, proj,
                               final_only=True, pool='avg', use_cuda=True)
            SSAEMB_.append(z)
            count += 1
            print(sequence, '# {} sequences processed...'.format(count), file=sys.stderr, end='\r')
    print("SSA embedding finished!")
    ssa_feature = pd.DataFrame(SSAEMB_)
    col = ["SSA_F" + str(i + 1) for i in range(0, 121)]
    ssa_feature.columns = col

    # UniRep embedding
    print("UniRep Embedding...")
    print("Loading UniRep Model...", file=sys.stderr, end='\r')
    model = UniRepModel.from_pretrained('babbler-1900')
    model = model.to(DEVICE)
    tokenizer = TAPETokenizer(vocab='unirep')
    count = 0
    PID_ = inData["PID"]
    for pid, sequence in zip(PID_, Seqs):
        if len(sequence) == 0:
            print('# WARNING: sequence', pid, 'has length=0. Skipping.', file=sys.stderr)
            continue
        with torch.no_grad():
            token_ids = torch.tensor([tokenizer.encode(sequence)])
            token_ids = token_ids.to(DEVICE)
            output = model(token_ids)
            unirep_output = output[0]
            unirep_output = torch.squeeze(unirep_output)
            unirep_output = unirep_output.mean(0)  # average over sequence positions
            unirep_output = unirep_output.cpu().numpy()
            UNIREPEB_.append(unirep_output.tolist())
            count += 1
            print(sequence, '# {} sequences processed...'.format(count), file=sys.stderr, end='\r')

    unirep_feature = pd.DataFrame(UNIREPEB_)
    col = ["UniRep_avg_F" + str(i + 1) for i in range(0, 1900)]
    unirep_feature.columns = col
    print("UniRep Embedding Finished!")

    Features = pd.concat([ssa_feature, unirep_feature], axis=1)
    Features.index = PID_
    Features.to_csv(outFile)
    print("Getting Deep Representation Learning Features is done.")
    return Features, inData
def main_training_loop(args: argparse.ArgumentParser):
    if args.enforce_walltime == True:
        loop_start_time = time.time()
        logger.info('Started timing loop')

    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    # Setup model
    tokenizer = TAPETokenizer(vocab='iupac')
    config = ProteinAWDLSTMConfig(**vars(args))
    config.vocab_size = tokenizer.vocab_size
    if args.reset_hidden:
        logger.info(f'Resetting hidden state after {tokenizer.stop_token}')
        config.reset_token_id = tokenizer.convert_token_to_id(tokenizer.stop_token)
    model = ProteinAWDLSTMForLM(config)

    # training logger
    time_stamp = time.strftime("%y-%m-%d-%H-%M-%S", time.gmtime())
    experiment_name = f"{args.experiment_name}_{model.base_model_prefix}_{time_stamp}"
    # local_rank=-1 means training is not distributed; debug makes the experiment a dry run for wandb
    viz = visualization.get(args.output_dir, experiment_name, local_rank=-1)  # debug=args.debug)

    train_data = Hdf5Dataset(os.path.join(args.data, 'train.hdf5'), batch_size=args.batch_size,
                             bptt_length=args.bptt, buffer_size=args.buffer_size)
    val_data = Hdf5Dataset(os.path.join(args.data, 'valid.hdf5'), batch_size=args.batch_size,
                           bptt_length=args.bptt, buffer_size=args.buffer_size)
    logger.info(f'Data loaded. One train epoch = {len(train_data)} steps.')
    logger.info(f'Data loaded. One valid epoch = {len(val_data)} steps.')

    train_loader = DataLoader(train_data, batch_size=1, collate_fn=train_data.collate_fn)
    val_loader = DataLoader(val_data, batch_size=1, collate_fn=train_data.collate_fn)

    # set up validation here so a subsample can be drawn from where the iterator stopped, each time it is needed
    val_iterator = enumerate(val_loader)
    val_steps = 0
    hidden = None

    # overwrite model when restarting/changing params
    if args.resume:
        logger.info(f'Loading pretrained model in {args.resume}')
        model = ProteinAWDLSTMForLM.from_pretrained(args.resume)
    if args.wandb_sweep:
        # This prevents errors. When the model is partly set up from config, not the command line,
        # the config received from wandb might not match what is in args and ProteinConfig.
        # Calling log_config on an inconsistent pair would then throw an error.
        # This overwrites args and ProteinConfig, so wandb has priority.
        # In case of doubt, check the wandb run and the save_pretrained config.json; they should always agree.
        logger.info(f'Receiving config from wandb!')
        import wandb
        from training_utils import override_from_wandb
        override_from_wandb(wandb.config, args, config)
        model = ProteinAWDLSTMForLM(config)

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wdecay)
    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wdecay)

    model.to(device)
    logger.info('Model set up!')
    num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f'Model has {num_parameters} trainable parameters')

    if torch.cuda.is_available():
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    else:
        logger.info(f'Running model on {device}, not using nvidia apex')

    # set up wandb logging; the tape visualizer class takes care of everything,
    # just log in to wandb in the env as usual
    viz.log_config(args)
    viz.log_config(model.config.to_dict())
    viz.watch(model)
    logger.info(f'Logging experiment as {experiment_name} to wandb/tensorboard')

    # keep track of best loss
    num_epochs_no_improvement = 0
    stored_loss = 100000000
    learning_rate_steps = 0
    global_step = 0

    for epoch in range(1, args.epochs + 1):
        logger.info(f'Starting epoch {epoch}')
        viz.log_metrics({'Learning Rate': optimizer.param_groups[0]['lr']}, "train", global_step)
        epoch_start_time = time.time()
        start_time = time.time()  # for lr update interval
        hidden = None

        for i, batch in enumerate(train_loader):
            data, targets = batch
            loss, reg_loss, hidden = training_step(model, data, targets, hidden, optimizer, args, i)
            viz.log_metrics({'loss': loss,
                             'regularized loss': reg_loss,
                             'perplexity': math.exp(loss),
                             'regularized perplexity': math.exp(reg_loss)},
                            "train", global_step)
            global_step += 1

            # ad hoc fix for smaller datasets: evaluate after full epochs
            update_steps = args.update_lr_steps if len(train_loader) > args.update_lr_steps else len(train_loader)

            # every update_steps, evaluate performance and save the model / step the learning rate
            if global_step % update_steps == 0 and global_step > 0:
                total_loss = 0
                total_reg_loss = 0
                total_len = 0

                # NOTE Plasmodium sets are 1% the size of Eukarya sets; run 1/100 of the total set each time.
                # Works because the plasmodium set is smaller, no extra arg wanted for this.
                # The old threshold (100000) was too high; homology-reduced eukarya 10 percent could not be trained with it.
                n_val_steps = (len(val_loader) // 100) if len(val_loader) > 10000 else len(val_loader)
                logger.info(f'Step {global_step}, validating for {n_val_steps} Validation steps')

                for j in range(n_val_steps):
                    val_steps += 1
                    try:
                        _, (data, targets) = next(val_iterator)
                    except StopIteration:
                        # reset the validation data when at its end
                        val_iterator = enumerate(val_loader)
                        logger.info(f'validation step {j}: resetting validation enumerator.')
                        hidden = None
                        _, (data, targets) = next(val_iterator)
                    loss, reg_loss, hidden = validation_step(model, data, targets, hidden)
                    total_len += len(data)
                    total_reg_loss += reg_loss * len(data)
                    total_loss += loss * len(data)

                val_reg_loss = total_reg_loss / total_len
                val_loss = total_loss / total_len

                val_metrics = {'loss': val_loss,
                               'perplexity': math.exp(val_loss),
                               'regularized loss': val_reg_loss,
                               'regularized perplexity': math.exp(val_reg_loss)}
                viz.log_metrics(val_metrics, "val", global_step)

                elapsed = time.time() - start_time
                logger.info(f'Training step {global_step}, {elapsed / args.log_interval:.3f} s/batch. '
                            f'tr_loss: {loss:.2f}, tr_perplexity {math.exp(loss):.2f} '
                            f'va_loss: {val_loss:.2f}, va_perplexity {math.exp(val_loss):.2f}')
                start_time = time.time()

                if val_loss < stored_loss:
                    num_epochs_no_improvement = 0
                    model.save_pretrained(args.output_dir)
                    save_training_status(args.output_dir, epoch, global_step, num_epochs_no_improvement,
                                         stored_loss, learning_rate_steps)
                    # also save with apex
                    if torch.cuda.is_available():
                        checkpoint = {'model': model.state_dict(),
                                      'optimizer': optimizer.state_dict(),
                                      'amp': amp.state_dict()}
                        torch.save(checkpoint, os.path.join(args.output_dir, 'amp_checkpoint.pt'))
                    logger.info(f'New best model with loss {val_loss}, Saving model, training step {global_step}')
                    stored_loss = val_loss
                else:
                    num_epochs_no_improvement += 1
                    logger.info(f'Step {global_step}: No improvement for {num_epochs_no_improvement} pseudo-epochs.')
                    if num_epochs_no_improvement == args.wait_epochs:
                        optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] * args.lr_step
                        learning_rate_steps += 1
                        num_epochs_no_improvement = 0
                        logger.info(f'Step {global_step}: Decreasing learning rate. learning rate step {learning_rate_steps}.')
                        viz.log_metrics({'Learning Rate': optimizer.param_groups[0]['lr']}, "train", global_step)

                # break early after 5 lr steps
                if learning_rate_steps > 5:
                    logger.info('Learning rate step limit reached, ending training early')
                    return stored_loss

            if args.enforce_walltime == True and (time.time() - loop_start_time) > 84600:  # 23.5 hours
                logger.info('Wall time limit reached, ending training early')
                return stored_loss

        logger.info(f'Epoch {epoch} training complete')
        logger.info(f'Epoch {epoch}, took {time.time() - epoch_start_time:.2f}.\t'
                    f' Train loss: {loss:.2f} \t Train perplexity: {math.exp(loss):.2f}')

    return stored_loss
        tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case=False)
    elif model_version == 'prot_bert':
        model = BertModel.from_pretrained("Rostlab/prot_bert", output_attentions=True)
        tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
    elif model_version == 'prot_albert':
        model = AlbertModel.from_pretrained("Rostlab/prot_albert", output_attentions=True)
        tokenizer = AlbertTokenizer.from_pretrained("Rostlab/prot_albert", do_lower_case=False)
    else:
        model = ProteinBertModel.from_pretrained(model_version, output_attentions=True)
        tokenizer = TAPETokenizer()
    num_layers = model.config.num_hidden_layers
    num_heads = model.config.num_attention_heads
elif args.model == 'xlnet':
    model_version = args.model_version
    if model_version == 'prot_xlnet':
        model = XLNetModel.from_pretrained("Rostlab/prot_xlnet", output_attentions=True)
        tokenizer = XLNetTokenizer.from_pretrained("Rostlab/prot_xlnet", do_lower_case=False)
    else:
        raise ValueError('Invalid model version')
    num_layers = model.config.n_layer
    num_heads = model.config.n_head
else:
    raise ValueError(f"Invalid model: {args.model}")
def escape_tape(virus, vocabulary, pretrained='transformer'):
    if pretrained == 'transformer':
        from tape import ProteinBertModel
        model_class = ProteinBertModel
        model_name = 'bert-base'
        fname_prefix = 'tape_transformer'
        vocab = 'iupac'
    elif pretrained == 'unirep':
        from tape import UniRepModel
        model_class = UniRepModel
        model_name = 'babbler-1900'
        fname_prefix = 'unirep'
        vocab = 'unirep'
    from tape import TAPETokenizer

    model = model_class.from_pretrained(model_name)
    tokenizer = TAPETokenizer(vocab=vocab)

    seq, seqs_escape, train_fname, mut_fname, anchor_id = load(virus)

    if virus == 'h1':
        embed_fname = 'target/flu/embedding/{}_h1.npz'.format(fname_prefix)
    elif virus == 'h3':
        embed_fname = 'target/flu/embedding/{}_h3.npz'.format(fname_prefix)
    elif virus == 'bg505':
        embed_fname = 'target/hiv/embedding/{}_hiv.npz'.format(fname_prefix)
    elif virus == 'sarscov2':
        embed_fname = 'target/cov/embedding/{}_sarscov2.npz'.format(fname_prefix)
    elif virus == 'cov2rbd':
        embed_fname = 'target/cov/embedding/{}_sarscov2.npz'.format(fname_prefix)
    else:
        raise ValueError('invalid option {}'.format(virus))

    anchor = None
    for idx, record in enumerate(SeqIO.parse(train_fname, 'fasta')):
        if record.id == anchor_id:
            anchor = str(record.seq)
    assert (anchor is not None)

    base_embedding = tape_embed(anchor.replace('-', ''), model, tokenizer)

    with np.load(embed_fname, allow_pickle=True) as data:
        embeddings = {name: data[name][()]['avg'] for name in data.files}

    mutations = [str(record.seq) for record in SeqIO.parse(mut_fname, 'fasta')]
    mut2change = {}
    for mutation in mutations:
        didx = [c1 != c2 for c1, c2 in zip(anchor, mutation)].index(True)
        embedding = embeddings['mut_{}_{}'.format(didx, mutation[didx])]
        mutation_clean = mutation.replace('-', '')
        mut2change[mutation_clean] = abs(base_embedding - embedding).sum()

    anchor = anchor.replace('-', '')
    escape_idx, changes = [], []
    for i in range(len(anchor)):
        if virus == 'bg505' and (i < 29 or i > 698):
            continue
        if virus == 'cov2rbd' and (i < 318 or i > 540):
            continue
        for word in vocabulary:
            if anchor[i] == word:
                continue
            mut_seq = anchor[:i] + word + anchor[i + 1:]
            if mut_seq in seqs_escape and \
               (sum([m['significant'] for m in seqs_escape[mut_seq]]) > 0):
                escape_idx.append(len(changes))
                changes.append(mut2change[mut_seq])
    changes = np.array(changes)

    plot_result(-changes, escape_idx, virus, fname_prefix,
                legend_name='TAPE ({})'.format(fname_prefix))
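# The tape_embed helper used above is referenced but not defined in this snippet.
# A minimal sketch of one plausible implementation, assuming the embedding used
# for comparison is the mean-pooled per-residue output of the pretrained model
# (the pattern used in the UniRep embedding code above); the actual helper in
# the source may differ.
import torch
import numpy as np

def tape_embed_sketch(sequence, model, tokenizer):
    """Encode a sequence and return its mean-pooled embedding as a numpy array."""
    with torch.no_grad():
        token_ids = torch.tensor([tokenizer.encode(sequence)])
        sequence_output = model(token_ids)[0]           # (1, seq_len, hidden)
    return sequence_output.squeeze(0).mean(0).numpy()   # (hidden,)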