def batch_make_att_masks(node_list,
                         tree_decoder=None,
                         walker=None,
                         dtype=np.byte):
    """Encode trees as per-step one-hot rule targets plus per-step rule masks.

    Each node is replayed through ``tree_decoder.decode(node, walker)``. The
    walker's recorded steps are then written into two zero-initialised arrays
    of shape ``(len(node_list), cmd_args.max_decode_steps, DECISION_DIM)``.

    Args:
        node_list: sequence of trees to encode.
        tree_decoder: decoder used for the replay; ``create_tree_decoder()``
            is used when omitted.
        walker: step recorder read after each decode (``num_steps``,
            ``global_rule_used``, ``mask_list``); ``OnehotBuilder()`` when
            omitted. The same walker is reused for every node, presumably
            because ``decode`` resets it. TODO confirm.
        dtype: numpy dtype of both returned arrays.

    Returns:
        ``(true_binary, rule_masks)``. For sample ``i`` and a step ``j`` below
        the walker's ``num_steps``:
        - ``true_binary[i, j]`` has a 1 at ``global_rule_used[j]``.
        - ``rule_masks[i, j]`` has 1s at the indices in ``mask_list[j]``.
        Every later step has only its last column set in both arrays
        (presumably a padding rule; confirm).
    """
    walker = OnehotBuilder() if walker is None else walker
    tree_decoder = create_tree_decoder() if tree_decoder is None else tree_decoder

    max_steps = cmd_args.max_decode_steps
    out_shape = (len(node_list), max_steps, DECISION_DIM)
    true_binary = np.zeros(out_shape, dtype=dtype)
    rule_masks = np.zeros(out_shape, dtype=dtype)

    for row, node in enumerate(node_list):
        tree_decoder.decode(node, walker)
        used = walker.num_steps
        pad_steps = np.arange(used, max_steps)

        # One-hot of the rule actually applied at each recorded step.
        true_binary[row, np.arange(used), walker.global_rule_used[:used]] = 1
        # Every index listed in the walker's mask for each recorded step.
        for step in range(used):
            rule_masks[row, step, walker.mask_list[step]] = 1

        # Remaining steps: only the last column is marked in both arrays.
        true_binary[row, pad_steps, -1] = 1
        rule_masks[row, pad_steps, -1] = 1.0

    return true_binary, rule_masks
示例#2
0
def decode_chunk(raw_logits, use_random, decode_times):
    """Decode every sample of a logits chunk ``decode_times`` times into SMILES.

    Sample ``i`` is ``raw_logits[:, i, :]`` (axis 1 is iterated). It is
    squeezed and wrapped in a ``ConditionalDecoder``, and that one walker is
    reused for all ``decode_times`` attempts on the sample.

    A failed attempt is replaced by ``'JUNK'`` followed by 256 random
    uppercase letters/digits. A warning is printed for the failure unless the
    exception's class name is ``DecodingLimitExceeded``.

    Returns:
        A list with one inner list per sample, each holding ``decode_times``
        strings.
    """
    import random
    import string

    junk_alphabet = string.ascii_uppercase + string.digits
    tree_decoder = create_tree_decoder()
    num_samples = raw_logits.shape[1]
    chunk_result = [[] for _ in range(num_samples)]

    for sample_idx in tqdm(range(num_samples)):
        walker = ConditionalDecoder(np.squeeze(raw_logits[:, sample_idx, :]), use_random)
        samples = chunk_result[sample_idx]

        for _ in range(decode_times):
            tree = Node('smiles')
            try:
                tree_decoder.decode(tree, walker)
                sampled = get_smiles_from_tree(tree)
            except Exception as ex:
                # Hitting the decode limit is expected; anything else is reported.
                if type(ex).__name__ != 'DecodingLimitExceeded':
                    print('Warning, decoder failed with', ex)
                sampled = 'JUNK' + ''.join(random.choice(junk_alphabet) for _ in range(256))
            samples.append(sampled)

    return chunk_result
示例#3
0
    def __init__(self, *args, **kwargs):
        """Build the VAE from its JSON config, load saved weights, set up helpers.

        Reads ``cmd_args.model_config`` (a JSON file passed to ``MolVAE``),
        ``cmd_args.saved_model`` (weights loaded via ``paddle.load``) and
        ``cmd_args.grammar_file``. ``*args``/``**kwargs`` are accepted but not
        used here.
        """
        # Get the model config. The context manager closes the file handle
        # deterministically instead of leaving it to garbage collection.
        with open(cmd_args.model_config, 'r') as config_file:
            model_config = json.load(config_file)
        self.ae = MolVAE(model_config)

        # Load the model weights.
        model_weights = paddle.load(cmd_args.saved_model)
        self.ae.set_state_dict(model_weights)

        self.onehot_walker = OnehotBuilder()
        self.tree_decoder = create_tree_decoder()
        self.grammar = parser.Grammar(cmd_args.grammar_file)
示例#4
0
def process_chunk(smiles_list):
    """Parse SMILES strings and build their one-hot rule targets and rule masks.

    Each string is parsed with the grammar from ``cmd_args.grammar_file``. It
    must yield exactly one parse tree (checked with ``assert``), which is
    converted by ``AnnotatedTree2MolTree``. The resulting trees are encoded by
    ``batch_make_att_masks`` with ``dtype=np.byte``.

    Returns:
        ``(onehot, masks)`` as produced by ``batch_make_att_masks``.
    """
    grammar = parser.Grammar(cmd_args.grammar_file)

    def _to_mol_tree(smiles):
        parse_trees = parser.parse(smiles, grammar)
        assert isinstance(parse_trees, list) and len(parse_trees) == 1
        return AnnotatedTree2MolTree(parse_trees[0])

    cfg_tree_list = [_to_mol_tree(smiles) for smiles in smiles_list]

    walker = OnehotBuilder()
    tree_decoder = create_tree_decoder()
    onehot, masks = batch_make_att_masks(
        cfg_tree_list, tree_decoder=tree_decoder, walker=walker, dtype=np.byte)
    return onehot, masks
示例#5
0
    def __init__(self, *args, **kwargs):
        """Build the autoencoder chosen by ``cmd_args.ae_type`` and load its weights.

        Builds ``MolVAE`` for ``'vae'`` or ``MolAutoEncoder`` for ``'autoenc'``.
        The model is moved to GPU with ``.cuda()`` when ``cmd_args.mode`` is
        ``'gpu'``. Weights from ``cmd_args.saved_model`` are then loaded.
        ``*args``/``**kwargs`` are accepted but not used here.

        Raises:
            ValueError: ``cmd_args.ae_type`` is not ``'vae'`` or ``'autoenc'``.
                This is a subclass of ``Exception``, so existing handlers
                still catch it.
            AssertionError: ``cmd_args.saved_model`` is None.
        """
        if cmd_args.ae_type == 'vae':
            self.ae = MolVAE()
        elif cmd_args.ae_type == 'autoenc':
            self.ae = MolAutoEncoder()
        else:
            raise ValueError('unknown ae type %s' % cmd_args.ae_type)
        if cmd_args.mode == 'gpu':
            self.ae = self.ae.cuda()

        assert cmd_args.saved_model is not None

        # Indented with spaces only: mixing tabs and spaces is a TabError in Python 3.
        if cmd_args.mode == 'cpu':
            # map_location returning the storage unchanged keeps every tensor on
            # CPU, even for a checkpoint saved from GPU.
            self.ae.load_state_dict(torch.load(cmd_args.saved_model, map_location=lambda storage, loc: storage))
        else:
            self.ae.load_state_dict(torch.load(cmd_args.saved_model))

        self.onehot_walker = OnehotBuilder()
        self.tree_decoder = create_tree_decoder()
        self.grammar = parser.Grammar(cmd_args.grammar_file)
示例#6
0
def batch_make_att_masks(node_list, tree_decoder=None, walker=None, dtype=np.byte):
    """Replay each tree through the decoder and record rule targets and masks.

    Both returned arrays have shape
    ``(len(node_list), cmd_args.max_decode_steps, DECISION_DIM)`` and the
    given ``dtype``. For each node, after ``tree_decoder.decode(node, walker)``:
    - every step below ``walker.num_steps`` gets a 1 in ``true_binary`` at
      ``walker.global_rule_used[step]``;
    - the same steps get 1s in ``rule_masks`` at ``walker.mask_list[step]``;
    - every later step gets only its last column set in both arrays
      (presumably a padding rule; confirm).

    ``walker`` defaults to ``OnehotBuilder()`` and ``tree_decoder`` to
    ``create_tree_decoder()``. The walker is shared across nodes; this assumes
    ``decode`` resets it. TODO confirm.

    Returns:
        ``(true_binary, rule_masks)``.
    """
    if walker is None:
        walker = OnehotBuilder()
    if tree_decoder is None:
        tree_decoder = create_tree_decoder()

    horizon = cmd_args.max_decode_steps
    true_binary = np.zeros((len(node_list), horizon, DECISION_DIM), dtype=dtype)
    rule_masks = np.zeros_like(true_binary)

    # Iterating a 3-D array yields writable 2-D views, one per sample, so
    # writes through target_row / mask_row land in the full arrays.
    for node, target_row, mask_row in zip(node_list, true_binary, rule_masks):
        tree_decoder.decode(node, walker)
        n_steps = walker.num_steps
        tail = np.arange(n_steps, horizon)

        target_row[np.arange(n_steps), walker.global_rule_used[:n_steps]] = 1
        target_row[tail, -1] = 1

        for step_idx in range(n_steps):
            mask_row[step_idx, walker.mask_list[step_idx]] = 1
        mask_row[tail, -1] = 1.0

    return true_binary, rule_masks
示例#7
0
def decode_chunk(raw_logits, use_random, decode_times):
    """Return ``decode_times`` decoded SMILES strings for each sample in the chunk.

    Samples are taken along axis 1 (``raw_logits[:, col, :]``). Each one is
    squeezed and wrapped in a ``ConditionalDecoder``, which is reused for all
    of that sample's attempts.

    A failed attempt yields ``'JUNK'`` plus 256 random ``[A-Z0-9]`` characters.
    A warning is printed for the failure unless the exception's class name is
    ``DecodingLimitExceeded``.
    """
    import random
    import string

    def _junk_smiles():
        # Placeholder for a failed decode.
        alphabet = string.ascii_uppercase + string.digits
        return 'JUNK' + ''.join(random.choice(alphabet) for _ in range(256))

    def _decode_once(walker):
        tree = Node('smiles')
        try:
            tree_decoder.decode(tree, walker)
            return get_smiles_from_tree(tree)
        except Exception as ex:
            if type(ex).__name__ != 'DecodingLimitExceeded':
                print('Warning, decoder failed with', ex)
            return _junk_smiles()

    tree_decoder = create_tree_decoder()
    n_samples = raw_logits.shape[1]
    chunk_result = [[] for _ in range(n_samples)]

    for col in tqdm(range(n_samples)):
        walker = ConditionalDecoder(np.squeeze(raw_logits[:, col, :]), use_random)
        chunk_result[col].extend(_decode_once(walker) for _ in range(decode_times))

    return chunk_result
def raw_logit_to_smile_labels(raw_logits, use_random=False):
    """Decode one SMILES per sample and compute properties for the valid ones.

    Samples are taken along axis 1 (``raw_logits[:, i, :]``). Each one is
    squeezed and decoded once with a ``ConditionalDecoder``. A sample is
    skipped when any of the following holds:
    - decoding raises; a warning is printed unless the exception's class name
      is ``DecodingLimitExceeded``;
    - decoding yields ``None``;
    - RDKit cannot parse the decoded SMILES.

    Args:
        raw_logits: logits array; axis 1 is iterated per sample (other axis
            meanings are not shown here).
        use_random: forwarded to ``ConditionalDecoder``.

    Returns:
        ``(result_list, index, y)``:
        - ``result_list``: the valid decoded SMILES strings;
        - ``index``: the sample index each one came from;
        - ``y``: one ``[qed, sa_score, logP]`` list per molecule.
    """
    # One decoder serves every walker, as in decode_chunk, so build it once
    # rather than on every iteration.
    tree_decoder = create_tree_decoder()

    y = []
    index = []
    result_list = []
    for i in range(raw_logits.shape[1]):
        pred_logits = raw_logits[:, i, :]
        walker = ConditionalDecoder(np.squeeze(pred_logits), use_random)
        new_t = Node('smiles')
        try:
            tree_decoder.decode(new_t, walker)
            sampled = get_smiles_from_tree(new_t)
        except Exception as ex:
            if type(ex).__name__ != 'DecodingLimitExceeded':
                print('Warning, decoder failed with', ex)
            # Decoding failed; skip this sample.
            continue

        if sampled is None:
            continue

        mol = Chem.MolFromSmiles(sampled)
        # The decoded SMILES is not a valid molecule.
        if mol is None:
            continue

        logP = Descriptors.MolLogP(mol)
        sa_score = sascorer.calculateScore(mol)
        qed = QED.qed(mol)
        y.append([qed, sa_score, logP])
        result_list.append(sampled)
        index.append(i)
    return result_list, index, y