Ejemplo n.º 1
0
    def encode(self, chunk, use_random=False):
        '''
        Encode a batch of molecules into latent vectors.

        Args:
            chunk: a non-empty list of `n` items. If the first item is a
                string, the whole list is treated as SMILES strings and is
                parsed with `parse(chunk, self.grammar)`; otherwise the list
                is used as-is as the already-parsed tree list.
            use_random: if True, `self.ae` is switched to train mode before
                encoding, otherwise to eval mode. The mode is not restored
                afterwards, and the encoder's first output is returned in
                both cases.

        Returns:
            A numpy array of dtype np.float32, of shape (n, latent_dim).
            Note: Each row should be the *mean* of the latent space
            distribution rather than a sampled point from that distribution.
            (It can be anything as long as it fits what self.decode expects.)
        '''
        # Only the first element is inspected to choose between raw SMILES
        # and pre-parsed trees; isinstance also accepts str subclasses.
        if isinstance(chunk[0], str):
            cfg_tree_list = parse(chunk, self.grammar)
        else:
            cfg_tree_list = chunk

        onehot, _ = batch_make_att_masks(cfg_tree_list, self.tree_decoder, self.onehot_walker, dtype=np.float32)

        # Swap the last two axes of the one-hot batch before encoding.
        x_inputs = np.transpose(onehot, [0, 2, 1])
        if use_random:
            self.ae.train()
        else:
            self.ae.eval()
        z_mean, _ = self.ae.encoder(x_inputs)

        return z_mean.data.cpu().numpy()
Ejemplo n.º 2
0
def process_chunk(smiles_list):
    """Convert a list of SMILES strings into one-hot arrays and masks.

    A grammar is built from `cmd_args.grammar_file` on every call, each
    SMILES is parsed with it and wrapped by `AnnotatedTree2MolTree`, and the
    resulting trees are handed to `batch_make_att_masks` together with a
    fresh `OnehotBuilder` and tree decoder (dtype `np.byte`).

    Args:
        smiles_list: iterable of SMILES strings.

    Returns:
        The `(onehot, masks)` pair produced by `batch_make_att_masks`.

    Raises:
        AssertionError: if `parser.parse` does not return a one-element list.
    """
    cfg_grammar = parser.Grammar(cmd_args.grammar_file)

    mol_trees = []
    for smi in smiles_list:
        parsed = parser.parse(smi, cfg_grammar)
        # Exactly one parse tree is expected per SMILES string.
        assert isinstance(parsed, list) and len(parsed) == 1
        mol_trees.append(AnnotatedTree2MolTree(parsed[0]))

    # Keep construction order: walker first, then the decoder.
    onehot_walker = OnehotBuilder()
    decoder = create_tree_decoder()
    onehot, masks = batch_make_att_masks(
        mol_trees, decoder, onehot_walker, dtype=np.byte)

    return onehot, masks
Ejemplo n.º 3
0
def parse_smiles_with_cfg(smiles_file, grammar_file):
    """Parse every line of a SMILES file into a molecule tree.

    Args:
        smiles_file: path to a text file; each line is stripped and parsed
            as one SMILES string.
        grammar_file: path handed to `parser.Grammar` to build the grammar
            used for parsing.

    Returns:
        A list holding one `AnnotatedTree2MolTree(...)` result per line of
        the file, in file order.

    Raises:
        AssertionError: if `parser.parse` does not return a one-element list
            for some line.
    """
    # Build the grammar from the argument, not from the global cmd_args,
    # so the function honours whatever path the caller passed in.
    grammar = parser.Grammar(grammar_file)

    cfg_tree_list = []
    with open(smiles_file, 'r') as f:
        for row in tqdm(f):
            smiles = row.strip()
            ts = parser.parse(smiles, grammar)
            # Exactly one parse tree is expected per SMILES string.
            assert isinstance(ts, list) and len(ts) == 1, \
                'expected exactly one parse tree for SMILES %r' % smiles
            cfg_tree_list.append(AnnotatedTree2MolTree(ts[0]))

    return cfg_tree_list

if __name__ == '__main__':
    # Script entry point: parse the SMILES file, build the one-hot arrays
    # and masks, and store both in an HDF5 file under cmd_args.save_dir.
    import os

    cfg_tree_list = parse_smiles_with_cfg(cmd_args.smiles_file, cmd_args.grammar_file)

    all_true_binary, all_rule_masks = batch_make_att_masks(cfg_tree_list)

    print(all_true_binary.shape, all_rule_masks.shape)

    # The output is named after the input file minus its last extension.
    # splitext keeps an extension-less name intact, whereas splitting on '.'
    # and dropping the last piece would leave an empty stem.
    f_smiles = os.path.splitext(os.path.basename(cmd_args.smiles_file))[0]

    out_file = os.path.join(cmd_args.save_dir, f_smiles + '.h5')
    # The context manager closes the HDF5 file even if a write fails.
    with h5py.File(out_file, 'w') as h5f:
        h5f.create_dataset('x', data=all_true_binary)
        h5f.create_dataset('masks', data=all_rule_masks)
Ejemplo n.º 4
0
    cfg_tree_list = []
    with open(smiles_file, 'r') as f:
        for row in tqdm(f):
            smiles = row.strip()
            ts = parser.parse(smiles, grammar)
            assert isinstance(ts, list) and len(ts) == 1
            n = AnnotatedTree2MolTree(ts[0])
            cfg_tree_list.append(n)

    return cfg_tree_list


if __name__ == '__main__':
    # Script entry point: parse the SMILES file, build the one-hot arrays
    # and masks, and store both in an HDF5 file under cmd_args.save_dir.
    import os

    cfg_tree_list = parse_smiles_with_cfg(cmd_args.smiles_file,
                                          cmd_args.grammar_file)

    all_true_binary, all_rule_masks = batch_make_att_masks(cfg_tree_list)

    print(all_true_binary.shape, all_rule_masks.shape)

    # Derive the output stem with splitext so an input file without any
    # extension keeps its name instead of collapsing to an empty string.
    f_smiles = os.path.splitext(os.path.basename(cmd_args.smiles_file))[0]

    out_file = os.path.join(cmd_args.save_dir, f_smiles + '.h5')
    # Use a with-block so the HDF5 handle is released even when a
    # create_dataset call raises.
    with h5py.File(out_file, 'w') as h5f:
        h5f.create_dataset('x', data=all_true_binary)
        h5f.create_dataset('masks', data=all_rule_masks)