def batch_make_att_masks(node_list, tree_decoder=None, walker=None, dtype=np.byte):
    """Decode each tree and build one-hot decision targets plus rule masks.

    Args:
        node_list: list of parse-tree root nodes to decode.
        tree_decoder: decoder instance; created via create_tree_decoder() if None.
        walker: OnehotBuilder that records rule choices; created if None.
        dtype: numpy dtype of the returned arrays (default np.byte).

    Returns:
        Tuple (true_binary, rule_masks), each of shape
        (len(node_list), cmd_args.max_decode_steps, DECISION_DIM).
    """
    if walker is None:
        walker = OnehotBuilder()
    if tree_decoder is None:
        tree_decoder = create_tree_decoder()

    shape = (len(node_list), cmd_args.max_decode_steps, DECISION_DIM)
    true_binary = np.zeros(shape, dtype=dtype)
    rule_masks = np.zeros(shape, dtype=dtype)

    for i, node in enumerate(node_list):
        tree_decoder.decode(node, walker)
        steps = walker.num_steps

        # One-hot mark the rule actually used at each decoding step.
        true_binary[i, np.arange(steps), walker.global_rule_used[:steps]] = 1
        # Pad the remaining steps with the final (stop/padding) decision column.
        true_binary[i, np.arange(steps, cmd_args.max_decode_steps), -1] = 1

        for j in range(steps):
            rule_masks[i, j, walker.mask_list[j]] = 1
        # Use an integer literal for consistency with the byte-typed array
        # (the original wrote 1.0, which was silently truncated to 1 anyway).
        rule_masks[i, np.arange(steps, cmd_args.max_decode_steps), -1] = 1
    return true_binary, rule_masks
def decode_chunk(raw_logits, use_random, decode_times):
    """Decode SMILES strings from a chunk of per-step logits.

    Args:
        raw_logits: array indexed as [step, molecule, decision]; column i holds
            the logits for molecule i.
        use_random: whether the ConditionalDecoder samples stochastically.
        decode_times: number of decode attempts per molecule.

    Returns:
        A list with one sub-list per molecule, each containing `decode_times`
        decoded SMILES strings; failed decodes yield a 'JUNK...' placeholder.
    """
    # Hoisted out of the except-block: importing inside a nested loop body
    # re-ran the import machinery on every failure.
    import random
    import string

    tree_decoder = create_tree_decoder()
    chunk_result = [[] for _ in range(raw_logits.shape[1])]
    for i in tqdm(range(raw_logits.shape[1])):
        pred_logits = raw_logits[:, i, :]
        walker = ConditionalDecoder(np.squeeze(pred_logits), use_random)

        for _decode in range(decode_times):
            new_t = Node('smiles')
            try:
                tree_decoder.decode(new_t, walker)
                sampled = get_smiles_from_tree(new_t)
            except Exception as ex:
                # Hitting the decoding-step limit is expected; anything else
                # deserves a warning.
                if not type(ex).__name__ == 'DecodingLimitExceeded':
                    print('Warning, decoder failed with', ex)
                # failed. output a random junk.
                sampled = 'JUNK' + ''.join(
                    random.choice(string.ascii_uppercase + string.digits)
                    for _ in range(256))
            chunk_result[i].append(sampled)
    return chunk_result
def __init__(self, *args, **kwargs):
    """Build the VAE from its JSON config, restore saved weights, and set up
    the decoding helpers (one-hot walker, tree decoder, grammar)."""
    # get model config — use a context manager so the config file handle is
    # closed (the original json.load(open(...)) leaked it).
    with open(cmd_args.model_config, 'r') as f:
        model_config = json.load(f)
    self.ae = MolVAE(model_config)

    # load model weights
    model_weights = paddle.load(cmd_args.saved_model)
    self.ae.set_state_dict(model_weights)

    self.onehot_walker = OnehotBuilder()
    self.tree_decoder = create_tree_decoder()
    self.grammar = parser.Grammar(cmd_args.grammar_file)
def process_chunk(smiles_list):
    """Parse each SMILES string into a mol tree and return its one-hot
    decision targets and rule masks as byte arrays."""
    grammar = parser.Grammar(cmd_args.grammar_file)

    cfg_tree_list = []
    for smiles in smiles_list:
        parsed = parser.parse(smiles, grammar)
        # Each SMILES must yield exactly one parse tree.
        assert isinstance(parsed, list) and len(parsed) == 1
        cfg_tree_list.append(AnnotatedTree2MolTree(parsed[0]))

    walker = OnehotBuilder()
    tree_decoder = create_tree_decoder()
    onehot, masks = batch_make_att_masks(
        cfg_tree_list, tree_decoder, walker, dtype=np.byte)
    return onehot, masks
def __init__(self, *args, **kwargs):
    """Instantiate the configured autoencoder, restore its saved weights,
    and set up the decoding helpers (one-hot walker, tree decoder, grammar)."""
    if cmd_args.ae_type == 'vae':
        self.ae = MolVAE()
    elif cmd_args.ae_type == 'autoenc':
        self.ae = MolAutoEncoder()
    else:
        raise Exception('unknown ae type %s' % cmd_args.ae_type)

    if cmd_args.mode == 'gpu':
        self.ae = self.ae.cuda()

    assert cmd_args.saved_model is not None
    if cmd_args.mode == 'cpu':
        # Remap any GPU-saved tensors onto CPU storage.
        state_dict = torch.load(cmd_args.saved_model,
                                map_location=lambda storage, loc: storage)
    else:
        state_dict = torch.load(cmd_args.saved_model)
    self.ae.load_state_dict(state_dict)

    self.onehot_walker = OnehotBuilder()
    self.tree_decoder = create_tree_decoder()
    self.grammar = parser.Grammar(cmd_args.grammar_file)
def batch_make_att_masks(node_list, tree_decoder=None, walker=None, dtype=np.byte):
    """Decode each tree and build one-hot decision targets plus rule masks.

    Args:
        node_list: list of parse-tree root nodes to decode.
        tree_decoder: decoder instance; created via create_tree_decoder() if None.
        walker: OnehotBuilder that records rule choices; created if None.
        dtype: numpy dtype of the returned arrays (default np.byte).

    Returns:
        Tuple (true_binary, rule_masks), each of shape
        (len(node_list), cmd_args.max_decode_steps, DECISION_DIM).
    """
    if walker is None:
        walker = OnehotBuilder()
    if tree_decoder is None:
        tree_decoder = create_tree_decoder()

    shape = (len(node_list), cmd_args.max_decode_steps, DECISION_DIM)
    true_binary = np.zeros(shape, dtype=dtype)
    rule_masks = np.zeros(shape, dtype=dtype)

    for i, node in enumerate(node_list):
        tree_decoder.decode(node, walker)
        steps = walker.num_steps

        # One-hot mark the rule actually used at each decoding step.
        true_binary[i, np.arange(steps), walker.global_rule_used[:steps]] = 1
        # Pad the remaining steps with the final (stop/padding) decision column.
        true_binary[i, np.arange(steps, cmd_args.max_decode_steps), -1] = 1

        for j in range(steps):
            rule_masks[i, j, walker.mask_list[j]] = 1
        # Use an integer literal for consistency with the byte-typed array
        # (the original wrote 1.0, which was silently truncated to 1 anyway).
        rule_masks[i, np.arange(steps, cmd_args.max_decode_steps), -1] = 1
    return true_binary, rule_masks
def decode_chunk(raw_logits, use_random, decode_times):
    """Decode SMILES strings from a chunk of per-step logits.

    Args:
        raw_logits: array indexed as [step, molecule, decision]; column i holds
            the logits for molecule i.
        use_random: whether the ConditionalDecoder samples stochastically.
        decode_times: number of decode attempts per molecule.

    Returns:
        A list with one sub-list per molecule, each containing `decode_times`
        decoded SMILES strings; failed decodes yield a 'JUNK...' placeholder.
    """
    # Hoisted out of the except-block: importing inside a nested loop body
    # re-ran the import machinery on every failure.
    import random
    import string

    tree_decoder = create_tree_decoder()
    chunk_result = [[] for _ in range(raw_logits.shape[1])]
    for i in tqdm(range(raw_logits.shape[1])):
        pred_logits = raw_logits[:, i, :]
        walker = ConditionalDecoder(np.squeeze(pred_logits), use_random)

        for _decode in range(decode_times):
            new_t = Node('smiles')
            try:
                tree_decoder.decode(new_t, walker)
                sampled = get_smiles_from_tree(new_t)
            except Exception as ex:
                # Hitting the decoding-step limit is expected; anything else
                # deserves a warning.
                if not type(ex).__name__ == 'DecodingLimitExceeded':
                    print('Warning, decoder failed with', ex)
                # failed. output a random junk.
                sampled = 'JUNK' + ''.join(
                    random.choice(string.ascii_uppercase + string.digits)
                    for _ in range(256))
            chunk_result[i].append(sampled)
    return chunk_result
def raw_logit_to_smile_labels(raw_logits, use_random=False):
    """Decode logits into valid SMILES and compute property labels for each.

    Args:
        raw_logits: array indexed as [step, molecule, decision].
        use_random: whether the ConditionalDecoder samples stochastically.

    Returns:
        Tuple (result_list, index, y): decoded SMILES strings, the column
        indices that produced a valid molecule, and per-molecule
        [qed, sa_score, logP] labels.
    """
    y = []
    index = []
    result_list = []
    # Build the tree decoder once instead of once per column — the other
    # decode helpers in this file reuse a single decoder across many decodes.
    tree_decoder = create_tree_decoder()
    for i in range(raw_logits.shape[1]):
        pred_logits = raw_logits[:, i, :]
        walker = ConditionalDecoder(np.squeeze(pred_logits), use_random)

        new_t = Node('smiles')
        try:
            tree_decoder.decode(new_t, walker)
            sampled = get_smiles_from_tree(new_t)
        except Exception as ex:
            if not type(ex).__name__ == 'DecodingLimitExceeded':
                print('Warning, decoder failed with', ex)
            # failed. output None
            sampled = None
        if sampled is None:
            continue

        mol = Chem.MolFromSmiles(sampled)
        ## decoded smile is not valid molecule
        if mol is None:
            continue

        logP = Descriptors.MolLogP(mol)
        sa_score = sascorer.calculateScore(mol)
        qed = QED.qed(mol)
        y.append([qed, sa_score, logP])
        result_list.append(sampled)
        index.append(i)
    return result_list, index, y