def batch_make_att_masks(node_list, tree_decoder=None, walker=None, dtype=None):
    """Encode decoded trees as one-hot decision targets plus per-step rule masks.

    Each node in ``node_list`` is run through ``tree_decoder.decode(node, walker)``.
    The walker then provides ``num_steps``, ``global_rule_used`` (the decision
    index taken at each step) and ``mask_list`` (a collection of indices per step).

    Args:
        node_list: sequence of trees accepted by ``tree_decoder.decode``.
        tree_decoder: decoder to use. When ``None``, ``create_tree_decoder()`` is used.
        walker: walker to use. When ``None``, a fresh ``OnehotBuilder()`` is used.
        dtype: numpy dtype of the returned arrays. When ``None``, ``np.byte`` is used.

    Returns:
        ``(true_binary, rule_masks)``, both of shape
        ``(len(node_list), cmd_args.max_decode_steps, DECISION_DIM)``:

        * ``true_binary[i, t, global_rule_used[t]] = 1`` for every decoded step ``t``.
        * ``rule_masks[i, t, mask_list[t]] = 1`` for every decoded step ``t``.
        * Steps from ``walker.num_steps`` up to ``max_decode_steps`` are padded
          by setting the last column (index ``-1``) to 1 in both arrays.
    """
    if dtype is None:
        dtype = np.byte
    if walker is None:
        walker = OnehotBuilder()
    if tree_decoder is None:
        tree_decoder = create_tree_decoder()

    max_steps = cmd_args.max_decode_steps
    shape = (len(node_list), max_steps, DECISION_DIM)
    true_binary = np.zeros(shape, dtype=dtype)
    rule_masks = np.zeros(shape, dtype=dtype)

    for i, node in enumerate(node_list):
        # NOTE(review): the same walker is reused for every node; this assumes
        # decode() resets num_steps / global_rule_used / mask_list per node — confirm.
        tree_decoder.decode(node, walker)
        n_steps = walker.num_steps
        used_steps = np.arange(n_steps)
        pad_steps = np.arange(n_steps, max_steps)

        # One-hot on the decision actually taken at each decoded step.
        true_binary[i, used_steps, walker.global_rule_used[:n_steps]] = 1
        # mask_list[j] may select several columns at once (fancy indexing).
        for j in range(n_steps):
            rule_masks[i, j, walker.mask_list[j]] = 1
        # Pad the unused tail of the sequence via the last decision column.
        true_binary[i, pad_steps, -1] = 1
        rule_masks[i, pad_steps, -1] = 1
    return true_binary, rule_masks
def __init__(self, *args, **kwargs):
    """Build a ``MolVAE`` from a JSON config and restore its saved weights.

    The model config is read from ``cmd_args.model_config`` (JSON). The
    parameters are loaded with ``paddle.load(cmd_args.saved_model)`` and
    applied via ``set_state_dict``. Also creates the one-hot walker, tree
    decoder and grammar (from ``cmd_args.grammar_file``). ``args`` and
    ``kwargs`` are accepted but not used.
    """
    # Context manager closes the config file deterministically instead of
    # leaking the handle until garbage collection.
    with open(cmd_args.model_config, 'r') as config_file:
        model_config = json.load(config_file)
    self.ae = MolVAE(model_config)

    # Restore trained parameters into the freshly built model.
    model_weights = paddle.load(cmd_args.saved_model)
    self.ae.set_state_dict(model_weights)

    self.onehot_walker = OnehotBuilder()
    self.tree_decoder = create_tree_decoder()
    self.grammar = parser.Grammar(cmd_args.grammar_file)
def process_chunk(smiles_list):
    """Parse SMILES strings and encode them with ``batch_make_att_masks``.

    Args:
        smiles_list: iterable of SMILES strings. Each one must parse to exactly
            one tree under the grammar in ``cmd_args.grammar_file``.

    Returns:
        ``(onehot, masks)`` as produced by ``batch_make_att_masks`` with
        ``dtype=np.byte``.

    Raises:
        ValueError: if ``parser.parse`` does not return a one-element list for
            some SMILES string.
    """
    # Grammar, walker and decoder are created per call rather than shared,
    # presumably so each chunk can be processed independently (e.g. in a worker).
    grammar = parser.Grammar(cmd_args.grammar_file)
    cfg_tree_list = []
    for smiles in smiles_list:
        ts = parser.parse(smiles, grammar)
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, and the error should name the offending input.
        if not (isinstance(ts, list) and len(ts) == 1):
            raise ValueError(
                'expected exactly one parse tree for SMILES %r' % (smiles,))
        cfg_tree_list.append(AnnotatedTree2MolTree(ts[0]))

    walker = OnehotBuilder()
    tree_decoder = create_tree_decoder()
    onehot, masks = batch_make_att_masks(cfg_tree_list, tree_decoder, walker, dtype=np.byte)
    return (onehot, masks)
def __init__(self, *args, **kwargs):
    """Construct the autoencoder chosen by ``cmd_args.ae_type`` and load its weights.

    ``cmd_args.ae_type`` selects ``MolVAE`` ('vae') or ``MolAutoEncoder``
    ('autoenc'); any other value raises ``Exception``. In 'gpu' mode the model
    is moved with ``.cuda()`` before loading. The state dict is read from
    ``cmd_args.saved_model`` (which must not be None). In 'cpu' mode it is
    loaded with a ``map_location`` that keeps tensors on their loaded storage.
    Also creates the one-hot walker, tree decoder and grammar.
    ``args`` and ``kwargs`` are accepted but unused.
    """
    kind = cmd_args.ae_type
    if kind == 'vae':
        model = MolVAE()
    elif kind == 'autoenc':
        model = MolAutoEncoder()
    else:
        raise Exception('unknown ae type %s' % kind)

    if cmd_args.mode == 'gpu':
        model = model.cuda()
    self.ae = model

    assert cmd_args.saved_model is not None
    # Only the 'cpu' mode remaps storages; every other mode loads as saved.
    load_kwargs = {}
    if cmd_args.mode == 'cpu':
        load_kwargs['map_location'] = lambda storage, loc: storage
    state_dict = torch.load(cmd_args.saved_model, **load_kwargs)
    self.ae.load_state_dict(state_dict)

    self.onehot_walker = OnehotBuilder()
    self.tree_decoder = create_tree_decoder()
    self.grammar = parser.Grammar(cmd_args.grammar_file)