コード例 #1
0
def main():
    """Entry point: load predicted and target SMILES, then score matches.

    For multi-latent models (n_latent > 1) the per-latent output files in
    -input_dir are merged via combine_latent; otherwise a single output
    file (-input_file) is read directly.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-input_dir', type=str)
    parser.add_argument('-input_file', type=str, default='',
                        help='Optional single input file input')
    parser.add_argument('-target_file', type=str, required=True)
    parser.add_argument('-n_latent', type=int, default=0)
    parser.add_argument('-beam_size', type=int, default=5)
    parser.add_argument('-clean', action='store_true', default=False)
    parser.add_argument('-result_path', type=str)
    parser.add_argument('-seed', type=int)
    args = parser.parse_args()

    beam_size = args.beam_size
    n_latent = args.n_latent

    if n_latent > 1:
        # Merge the per-latent-class output files into one beam per reaction.
        smiles_list = combine_latent(
            input_dir=args.input_dir,
            n_latent=n_latent,
            beam_size=beam_size,
            output_path='%s/combined' % args.input_dir,
            clean=args.clean)
    else:
        smiles_list = data_utils.read_file(args.input_file,
                                           beam_size=beam_size)

    target_list = data_utils.read_file(args.target_file, beam_size=1)
    match_smiles_lists(smiles_list, target_list, beam_size, args)
コード例 #2
0
ファイル: rxn_dataset.py プロジェクト: fallen32/retro_diverse
    def __init__(self, src_path, tgt_path, tgt_beam_size=1):
        """Load aligned (source smiles, target smiles, class) triples.

        Assumes each source line is '<rxn_token> <smiles tokens...>' where
        the first token carries a reaction class label.
        """

        def parse_func(line):
            # First token encodes the reaction class (shifted to 0-indexed);
            # the remaining tokens join into a single SMILES string.
            tokens = line.strip().split(' ')
            label = data_utils.parse_rxn_token(tokens[0]) - 1
            return (label, ''.join(tokens[1:]))

        src_data = data_utils.read_file(src_path, parse_func=parse_func)
        classes, src_smiles = zip(*src_data)
        tgt_smiles = data_utils.read_file(tgt_path)

        # Expand the src input when the target has multiple solutions for
        # each input, keeping the lists aligned index-for-index.
        if tgt_beam_size > 1:
            classes = [c for c in classes for _ in range(tgt_beam_size)]
            src_smiles = [s for s in src_smiles for _ in range(tgt_beam_size)]

            # Replace targets that fail SMILES parsing with '' so downstream
            # code can skip them.
            # NOTE(review): this cleanup only happens when tgt_beam_size > 1;
            # confirm single-target files are guaranteed valid.
            tgt_smiles = [
                smiles if Chem.MolFromSmiles(smiles) is not None else ''
                for smiles in tgt_smiles
            ]

        self.data = list(zip(src_smiles, tgt_smiles, classes))
コード例 #3
0
def combine_latent(input_dir,
                   n_latent,
                   beam_size,
                   output_path=None,
                   clean=False,
                   cross=False,
                   cross_ensem=False,
                   alternative=False):
    """
    Reads the output smiles from each of the latent classes and combines them.

    Args:
        input_dir: The path to the input directory containing output files,
            one file named 'output_<idx>' per latent class.
        n_latent: The number of latent classes used for the model.
        beam_size: Number of smiles results per reaction.
        output_path: If given, writes the combined smiles (with score and
            source latent index) to this path.
        clean: If True, canonicalize each smiles and drop ones that fail.
        cross, cross_ensem, alternative: Unused here; kept for
            interface compatibility with callers.

    Returns:
        A list with one entry per reaction, each a list of up to beam_size
        smiles strings ordered by descending score.
    """

    def parse(line):
        # Each line is '<smiles>,<score>' with optional spaces.
        c_line = line.strip().replace(' ', '')
        smiles, score = c_line.split(',')
        return (smiles, float(score))

    # results_path is the prefix for the different latent file outputs
    latent_list = []
    for latent_idx in range(n_latent):
        file_path = '%s/output_%d' % (input_dir, latent_idx)
        latent_list.append(
            data_utils.read_file(file_path,
                                 beam_size=beam_size,
                                 parse_func=parse))

    combined_list = []

    # Fix: the original opened output_file without any cleanup path, leaking
    # the handle if an exception occurred mid-loop. Open once, close in
    # finally.
    output_file = open(output_path, 'w+') if output_path is not None else None
    try:
        n_data = len(latent_list[0])
        for data_idx in tqdm.tqdm(range(n_data)):
            # Best (score, latent_idx) observed per distinct smiles.
            r_dict = {}
            for latent_idx in range(n_latent):
                for smiles, score in latent_list[latent_idx][data_idx]:
                    if clean:
                        smiles = data_utils.canonicalize(smiles)
                        if smiles == '':
                            continue  # invalid smiles; skip

                    # Keep the best score seen for this smiles.
                    if smiles not in r_dict or score > r_dict[smiles][0]:
                        r_dict[smiles] = (score, latent_idx)

            # Sort by the (score, latent_idx) value tuple, descending;
            # score dominates the ordering.
            sorted_output = sorted(r_dict.items(),
                                   key=operator.itemgetter(1),
                                   reverse=True)
            top_smiles = []
            for smiles, (score, latent_idx) in sorted_output[:beam_size]:
                top_smiles.append(smiles)
                if output_file is not None:
                    output_file.write('%s,%.4f,%d\n' %
                                      (smiles, score, latent_idx))

            combined_list.append(top_smiles)
    finally:
        if output_file is not None:
            output_file.close()
    return combined_list
コード例 #4
0
def main():
    """Augment a dataset with either random SMILES sets or template-derived
    outputs, writing parallel src/tgt files per split.

    -type 'sets': rejection-samples up to n_aug distinct random SMILES per
    source (bounded by MAX_ITS attempts). -type 'templates': applies reaction
    templates in random order until n_aug distinct outcomes are found.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-data_dir', required=True)
    parser.add_argument('-output_dir', required=True)
    parser.add_argument('-template', type=str)
    parser.add_argument('-n_aug', default=5, type=int)
    parser.add_argument('-type', choices=['sets', 'templates'], required=True)
    parser.add_argument('-transductive', action='store_true', default=False)
    args = parser.parse_args()

    # Make output directory
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Read the source SMILES
    src_smiles_dict = {}
    for data_type in ['train', 'val', 'test']:
        src_smiles_dict[data_type] = data_utils.read_file(
            '%s/src-%s.txt' % (args.data_dir, data_type))

    template_list = []  # populated only for -type templates
    if args.type == 'templates':
        with open(args.template, 'r+') as template_file:
            template_dict = json.load(template_file)
            template_list = list(template_dict.keys())

    # If transductive flag is on, add the valid and test source smiles into
    # the train set.
    if args.transductive:
        src_smiles_dict['train'] = (src_smiles_dict['train'] +
                                    src_smiles_dict['val'] +
                                    src_smiles_dict['test'])

    for data_type in ['train', 'val', 'test']:
        # Fix: the original opened these files and never closed them; 'with'
        # guarantees they are flushed and closed per split.
        with open('%s/src-%s.txt' % (args.output_dir, data_type),
                  'w+') as output_src, \
             open('%s/tgt-%s.txt' % (args.output_dir, data_type),
                  'w+') as output_tgt:
            for smiles in tqdm(src_smiles_dict[data_type]):
                if args.type == 'sets':
                    # Rejection-sample distinct random sets, capped at
                    # MAX_ITS attempts so an unproductive input terminates.
                    new_smiles_list = []
                    n = 0
                    while len(new_smiles_list) < args.n_aug:
                        n += 1
                        if n > MAX_ITS:
                            break
                        new_smiles = get_random_set(smiles)
                        if (new_smiles is not None
                                and new_smiles not in new_smiles_list):
                            new_smiles_list.append(new_smiles)
                    write_smiles(output_src, [smiles] * len(new_smiles_list))
                    write_smiles(output_tgt, new_smiles_list)
                elif args.type == 'templates':
                    # randomly shuffle the template list, then collect up to
                    # n_aug distinct template outcomes.
                    random.shuffle(template_list)
                    new_smiles_list = []
                    for template in template_list:
                        rd_rxn = rdchiralReaction(template)
                        rd_rct = rdchiralReactants(smiles)

                        outcome_list = rdchiralRun(rd_rxn, rd_rct)
                        if len(outcome_list) > 0:
                            # Pick one outcome at random.
                            random.shuffle(outcome_list)
                            new_smiles = outcome_list[0]
                            if new_smiles not in new_smiles_list:
                                new_smiles_list.append(new_smiles)
                        if len(new_smiles_list) >= args.n_aug:
                            break
                    write_smiles(output_src, [smiles] * len(new_smiles_list))
                    write_smiles(output_tgt, new_smiles_list)
コード例 #5
0
ファイル: draw_output.py プロジェクト: fallen32/retro_diverse
def main():
    """Render side-by-side PNG comparisons of base vs. mixture model output.

    Samples n_output random test reactions and, for each, draws the source
    and target molecules plus both models' beams. Which model occupies which
    column is randomized per example and recorded in example_labels.txt so
    the comparison can be graded blind; smiles.txt logs the raw src/tgt
    pairs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-data_dir', default='data/stanford_clean')
    parser.add_argument('-base_output', required=True)
    parser.add_argument('-mixture_output', required=True)
    parser.add_argument('-n_output', type=int, default=100,
                        help='Number of random output examples')
    parser.add_argument('-n_draw', type=int, default=5,
                        help='Number of output to draw per example')
    parser.add_argument('-output_dir', required=True)
    parser.add_argument('-beam_size', type=int, default=10)
    parser.add_argument('-dim', type=int, default=500)
    args = parser.parse_args()

    dim = args.dim
    n_draw = args.n_draw
    beam_size = args.beam_size
    n_output = args.n_output

    assert n_draw <= beam_size  # Cannot draw more than there are examples

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    examples_dir = '%s/examples' % output_dir
    os.makedirs(examples_dir, exist_ok=True)

    # Load data
    data_type = 'test'
    src_path = '%s/src-%s.txt' % (args.data_dir, data_type)
    tgt_path = '%s/tgt-%s.txt' % (args.data_dir, data_type)

    def parse_line_with_class_label(line):
        # '<class_label> <smiles tokens...>' -> (label, joined smiles)
        splits = line.strip().split(' ')
        rxn_class_label = splits[0]
        return (rxn_class_label, ''.join(splits[1:]))

    src_data = data_utils.read_file(
        src_path, parse_func=parse_line_with_class_label)
    src_class, src_smiles = zip(*src_data)
    tgt_smiles = data_utils.read_file(tgt_path)
    print('Data loaded...')

    base_smiles = data_utils.read_file(
        args.base_output, beam_size=beam_size)
    print('Base model output smiles loaded...')
    mixture_smiles = data_utils.read_file(
        args.mixture_output, beam_size=beam_size)
    print('Mixture model output smiles loaded...')

    # Sample up to n_output example indices without replacement.
    n_data = len(src_smiles)
    indices = list(range(n_data))
    random.shuffle(indices)
    selected_indices = indices[:n_output]

    src_class = [src_class[i] for i in selected_indices]
    src_smiles = [src_smiles[i] for i in selected_indices]
    tgt_smiles = [tgt_smiles[i] for i in selected_indices]
    base_smiles = [base_smiles[i] for i in selected_indices]
    mixture_smiles = [mixture_smiles[i] for i in selected_indices]

    # Fix: smiles_file was never closed in the original; 'with' closes both
    # logs even if drawing raises.
    with open('%s/example_labels.txt' % output_dir, 'w+') as output_file, \
         open('%s/smiles.txt' % output_dir, 'w+') as smiles_file:
        # Fix: iterate over what was actually selected — the original's
        # range(n_output) raised IndexError when n_output > dataset size.
        for idx in tqdm(range(len(selected_indices))):
            cur_src_smiles = src_smiles[idx]
            cur_tgt_smiles = tgt_smiles[idx]

            smiles_file.write('%s,%s\n' % (cur_src_smiles, cur_tgt_smiles))

            base_smiles_beam = base_smiles[idx]
            mixture_smiles_beam = mixture_smiles[idx]

            # Randomize which model occupies which column; record the
            # assignment so results can be un-blinded later.
            if random.random() > 0.5:
                beam_1 = base_smiles_beam
                beam_2 = mixture_smiles_beam
                output_file.write('Example %d,%s,%s\n' %
                                  (idx, 'base', 'mixture'))
            else:
                beam_1 = mixture_smiles_beam
                beam_2 = base_smiles_beam
                output_file.write('Example %d,%s,%s\n' %
                                  (idx, 'mixture', 'base'))

            # Interleave the two beams: src, tgt, then (beam_1[i], beam_2[i]).
            smiles_list = [cur_src_smiles, cur_tgt_smiles]
            for beam_idx in range(beam_size):
                smiles_list += [beam_1[beam_idx], beam_2[beam_idx]]

            # Replace unparsable smiles with '' so drawing does not crash.
            draw_smiles_list = []
            for smiles in smiles_list:
                if Chem.MolFromSmiles(smiles) is None:
                    draw_smiles_list.append('')
                else:
                    draw_smiles_list.append(smiles)

            draw_mols = [prep_mol(smiles) for smiles in draw_smiles_list]

            # NOTE(review): the canvas has n_draw + 1 rows but up to
            # 2 + 2 * beam_size molecules are passed to DrawMolecules —
            # confirm the overflow is clipped intentionally.
            n_x, n_y = 2, args.n_draw + 1
            drawer = rdMolDraw2D.MolDraw2DSVG(n_x * dim, n_y * dim, dim, dim)
            drawer.SetFontSize(0.6)

            drawer.DrawMolecules(draw_mols)

            drawer.FinishDrawing()
            svg = drawer.GetDrawingText()

            # Write the SVG to a scratch file, then rasterize to PNG.
            temp_path = '%s/temp' % output_dir
            with open(temp_path, 'w+') as f_temp:
                f_temp.write(svg)

            cairosvg.svg2png(
                url=temp_path,
                write_to='%s/example_%d.png' % (examples_dir, idx))