Example #1
def acquire(dataset_key: str):
    """
    Downloads and/or generates and preprocesses a selected dataset.

    :param dataset_key: key of the dataset
    """
    dataset = get_dataset(dataset_key)
    logger.info(f"Acquiring dataset for key {dataset_key}")
    dataset.acquire()
    logger.info(f"Dataset for key {dataset_key} acquired successfully!")
Example #2
def featurize(dataset_key: str, featurizer_key: str):
    """
    Featurizes a dataset using a selected featurizer.

    :param dataset_key: key of the dataset
    :param featurizer_key: key of the featurizer
    """
    dataset = get_dataset(dataset_key)
    featurizer = get_featurizer(featurizer_key)
    logger.info(
        f"Featurizing with '{featurizer_key}' on dataset '{dataset_key}'")
    featurizer.featurize_dataset(dataset)
    logger.info(
        f"Finished featurizing with '{featurizer_key}' on dataset '{dataset_key}'!"
    )
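
A minimal usage sketch chaining the two helpers above, assuming both keys are registered with get_dataset / get_featurizer ('uspto_50k' is the dataset key used in Example #5; 'my_featurizer' is a hypothetical featurizer key):

acquire('uspto_50k')
featurize('uspto_50k', 'my_featurizer')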
Example #3
if cfg['training']['model_selection_mode'] == 'maximize':
    model_selection_sign = 1
elif cfg['training']['model_selection_mode'] == 'minimize':
    model_selection_sign = -1
else:
    raise ValueError('model_selection_mode must be '
                     'either maximize or minimize.')

# Output directory
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

shutil.copyfile(args.config, os.path.join(out_dir, 'config.yaml'))

# Dataset
train_dataset = config.get_dataset('train', cfg)
val_dataset = config.get_dataset('val', cfg, return_idx=True)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=cfg['training']['n_workers'],
    shuffle=True,
    collate_fn=data.collate_remove_none,
    worker_init_fn=data.worker_init_fn)

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=1,
    num_workers=cfg['training']['n_workers_val'],
    shuffle=False,
    collate_fn=data.collate_remove_none,
    worker_init_fn=data.worker_init_fn)
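
As a side note on the model_selection_mode block at the top of this example, here is a short sketch (not this repository's code) of how the computed sign typically collapses maximize/minimize into a single comparison; validation_metrics is a hypothetical iterable of validation scores:

# "Better" always means a larger signed metric, regardless of the selection mode.
metric_val_best = -model_selection_sign * float('inf')
for metric_val in validation_metrics:
    if model_selection_sign * (metric_val - metric_val_best) > 0:
        metric_val_best = metric_val  # e.g. checkpoint the model here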
Example #4
is_cuda = (torch.cuda.is_available() and not args.no_cuda)
device = torch.device("cuda" if is_cuda else "cpu")

out_dir = cfg['training']['out_dir']
generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir'])
out_time_file = os.path.join(generation_dir, 'time_generation_full.pkl')
out_time_file_class = os.path.join(generation_dir, 'time_generation.pkl')

batch_size = cfg['generation']['batch_size']
input_type = cfg['data']['input_type']
vis_n_outputs = cfg['generation']['vis_n_outputs']
if vis_n_outputs is None:
    vis_n_outputs = -1

# Dataset
dataset = config.get_dataset('test', cfg, return_idx=True)
print(dataset)

# Model
model = config.get_model(cfg, device=device, dataset=dataset)

checkpoint_io = CheckpointIO(out_dir, model=model)
checkpoint_io.load(cfg['test']['model_file'])

# Generator
generator = config.get_generator(model, cfg, device=device)

# Determine what to generate
generate_mesh = cfg['generation']['generate_mesh']
generate_pointcloud = cfg['generation']['generate_pointcloud']
Example #5
def evaluate_megan(save_path: str,
                   beam_size: int = 10,
                   max_gen_steps: int = 16,
                   beam_batch_size: int = 10,
                   show_every: int = 100,
                   n_max_atoms: int = 200,
                   dataset_key: str = 'uspto_50k',
                   split_type: str = 'default',
                   split_key: str = 'test',
                   results_file: str = ''):
    """
    Evaluate MEGAN model
    """
    config_path = os.path.join(save_path, 'config.gin')
    gin.parse_config_file(config_path)

    dataset = config.get_dataset(dataset_key)
    featurizer_key = gin.query_parameter('train_megan.featurizer_key')
    featurizer = get_featurizer(featurizer_key)
    assert isinstance(featurizer, MeganTrainingSamplesFeaturizer)
    action_vocab = featurizer.get_actions_vocabulary(save_path)

    base_action_masks = get_base_action_masks(n_max_atoms + 1,
                                              action_vocab=action_vocab)

    logger.info("Creating model...")
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    predict_forward = featurizer.forward

    logger.info("Loading data...")
    x_df = dataset.load_x()
    meta_df = dataset.load_metadata()
    split_df = config.get_split(split_type).load(dataset.dir)

    model_path = os.path.join(save_path, 'model_best.pt')
    checkpoint = load_state_dict(model_path)
    model = Megan(n_atom_actions=action_vocab['n_atom_actions'],
                  n_bond_actions=action_vocab['n_bond_actions'],
                  prop2oh=action_vocab['prop2oh']).to(device)
    model.load_state_dict(checkpoint['model'])
    model.eval()

    split_ind = np.argwhere(split_df[split_key] == 1).flatten()
    if 'class' in meta_df:
        split_ind = np.argwhere((split_df[split_key] == 1)
                                & (meta_df['class'] == 1)).flatten()

    np.random.shuffle(split_ind)
    logger.info(f"Evaluating on {len(split_ind)} samples from {split_key}")

    top_k = np.zeros(beam_size, dtype=float)
    accs = np.zeros(beam_size, dtype=float)

    def split_every(n, iterable):
        i = iter(iterable)
        piece = list(islice(i, n))
        while piece:
            yield piece
            piece = list(islice(i, n))

    n_batches = int(np.ceil(len(split_ind) / beam_batch_size))
    prog_bar = tqdm(desc=f'{save_path} beam search on {split_key}',
                    total=len(split_ind))

    start_time = time.time()
    n_samples, n_gen_reactions = 0, 0
    is_incorrect, is_duplicate = [], []
    n_preds = np.zeros(len(split_ind), dtype=int)

    pred_path = os.path.join(
        save_path, f'pred_{split_key}_{beam_size}_{max_gen_steps}.txt')
    pred_path_i = 1
    while os.path.exists(pred_path):
        pred_path = os.path.join(
            save_path,
            f'pred_{split_key}_{beam_size}_{max_gen_steps}_{pred_path_i}.txt')
        pred_path_i += 1
    target_sub_key = 'reactants' if 'reactants' in x_df else 'substrates'

    for batch_i, batch_ind in enumerate(split_every(beam_batch_size,
                                                    split_ind)):
        input_mols = []
        target_mols = []

        for ind in batch_ind:
            if predict_forward:
                input_mapped = x_df['substrates'][ind]
                target_mapped = x_df['product'][ind]
            else:
                input_mapped = x_df['product'][ind]
                target_mapped = x_df[target_sub_key][ind]

            try:
                target_mol = Chem.MolFromSmiles(target_mapped)
                input_mol = Chem.MolFromSmiles(input_mapped)

                # mark reactants (used only by models that consume this information)
                if featurizer.forward:
                    mark_reactants(input_mol, target_mol)

                # remap input and target molecules according to canonical SMILES atom order
                input_mol, target_mol = remap_reaction_to_canonical(
                    input_mol, target_mol)

                # fix a bug in how RDKit marks explicit hydrogen atoms
                input_mol = fix_explicit_hs(input_mol)

            except Exception as e:
                logger.warning(f'Exception while preparing mols from SMILES: {str(e)}')
                input_mol = None
                target_mol = None

            if input_mol is None or target_mol is None:
                input_mols.append(None)
                target_mols.append(None)
                continue

            input_mols.append(input_mol)
            target_mols.append(target_mol)

        if 'reaction_type_id' in meta_df:
            reaction_types = meta_df['reaction_type_id'][batch_ind].values
        else:
            reaction_types = None

        rdkit_cache = RdkitCache(props=action_vocab['props'])

        with torch.no_grad():
            beam_search_results = beam_search(
                [model],
                input_mols,
                rdkit_cache=rdkit_cache,
                max_steps=max_gen_steps,
                beam_size=beam_size,
                batch_size=beam_batch_size,
                base_action_masks=base_action_masks,
                max_atoms=n_max_atoms,
                reaction_types=reaction_types,
                action_vocab=action_vocab)

        with open(pred_path, 'a') as fp:
            for sample_i, ind in enumerate(batch_ind):
                input_mol = input_mols[sample_i]
                target_mol = target_mols[sample_i]
                try:
                    target_smi = mol_to_unmapped_smiles(target_mol)
                    target_mapped = Chem.MolToSmiles(target_mol)
                except Exception as e:
                    logger.info(f"Exception while converting target mol to SMILES: {str(e)}")
                    n_samples += 1
                    continue

                has_correct = False
                final_smis = set()

                results = beam_search_results[sample_i]
                n_preds[n_samples] = len(results)

                fp.write(
                    f'{ind} {Chem.MolToSmiles(input_mol)} {target_smi} {target_mapped}\n'
                )

                for i, path in enumerate(results):
                    if path['final_smi_unmapped']:
                        try:
                            final_mol = Chem.MolFromSmiles(
                                path['final_smi_unmapped'])

                            if final_mol is None:
                                final_smi = path['final_smi_unmapped']
                            else:
                                input_mol, final_mol = remap_reaction_to_canonical(
                                    input_mol, final_mol)
                                final_smi = mol_to_unmapped_smiles(final_mol)

                        except Exception as e:
                            final_smi = path['final_smi_unmapped']
                    else:
                        final_smi = path['final_smi_unmapped']

                    # for forward prediction, if more than one product is generated, heuristically keep the one with the longest SMILES
                    if predict_forward:
                        final_smi = list(sorted(final_smi.split('.'), key=len))
                        final_smi = final_smi[-1]

                    str_actions = '|'.join(f"({str(a)};{p})"
                                           for a, p in path['actions'])
                    str_ch = '{' + ','.join(
                        [str(c) for c in path['changed_atoms']]) + '}'
                    fp.write(
                        f'{i} {path["final_smi"]} {final_smi} {str_ch} {str_actions}\n'
                    )
                    is_duplicate.append(final_smi in final_smis)
                    is_incorrect.append(final_smi is None or final_smi == '')
                    final_smis.add(final_smi)
                    correct = prediction_is_correct(final_smi, target_smi)
                    # correct = final_smi == target_smi
                    if correct and not has_correct:
                        top_k[i:] += 1
                        accs[i] += 1
                        has_correct = True
                    n_gen_reactions += 1
                fp.write('\n')
                n_samples += 1

        if (batch_i > 0
                and batch_i % show_every == 0) or batch_i >= n_batches - 1:
            print("^" * 100)
            print(
                f'Beam search parameters: beam size={beam_size}, max steps={max_gen_steps}'
            )
            print()
            for k, top in enumerate(top_k):
                acc = accs[k]
                print('Top {:3d}: {:7.4f}% cum {:7.4f}%'.format(
                    k + 1, acc * 100 / n_samples, top * 100 / n_samples))
            print()
            avg_incorrect = '{:.4f}%'.format(100 * np.sum(is_incorrect) /
                                             len(is_incorrect))
            avg_duplicates = '{:.4f}%'.format(100 * np.sum(is_duplicate) /
                                              len(is_duplicate))
            avg_n_preds = '{:.4f}'.format(n_gen_reactions / n_samples)
            less_preds = '{:.4f}%'.format(
                100 * np.sum(n_preds[:n_samples] < beam_size) / n_samples)
            zero_preds = '{:.4f}%'.format(
                100 * np.sum(n_preds[:n_samples] == 0) / n_samples)
            print(
                f'Avg incorrect reactions in Top {beam_size}: {avg_incorrect}')
            print(
                f'Avg duplicate reactions in Top {beam_size}: {avg_duplicates}'
            )
            print(f'Avg number of predictions per target: {avg_n_preds}')
            print(f'Targets with < {beam_size} predictions: {less_preds}')
            print(f'Targets with zero predictions: {zero_preds}')
            print()

        prog_bar.update(len(batch_ind))
    prog_bar.close()

    total_time = time.time() - start_time
    s_targets = '{:.4f}'.format(total_time / n_samples)
    s_reactions = '{:.4f}'.format(n_gen_reactions / total_time)
    total_time = '{:.4f}'.format(total_time)
    avg_incorrect = '{:.4f}%'.format(100 * np.sum(is_incorrect) /
                                     len(is_incorrect))
    avg_duplicates = '{:.4f}%'.format(100 * np.sum(is_duplicate) /
                                      len(is_duplicate))
    avg_n_preds = '{:.4f}'.format(n_gen_reactions / n_samples)
    less_preds = '{:.4f}%'.format(100 * np.sum(n_preds < beam_size) /
                                  n_samples)
    zero_preds = '{:.4f}%'.format(100 * np.sum(n_preds == 0) / n_samples)

    summary_path = \
        os.path.join(save_path, f'eval_{split_key}_{beam_size}_{max_gen_steps}.txt')

    with open(summary_path, 'w') as fp:
        fp.write(f'Evaluation on {split_key} set ({split_type} split)\n')
        fp.write(f'Beam size = {beam_size}, batch size = {beam_batch_size}, '
                 f'max gen steps = {max_gen_steps}\n')
        fp.write(
            f'Avg incorrect reactions in Top {beam_size}: {avg_incorrect}\n')
        fp.write(
            f'Avg duplicate reactions in Top {beam_size}: {avg_duplicates}\n')
        fp.write(f'Avg number of predictions per target: {avg_n_preds}\n')
        fp.write(f'Targets with < {beam_size} predictions: {less_preds}\n')
        fp.write(f'Targets with zero predictions: {zero_preds}\n')
        fp.write(
            f'Total evaluation time on {len(split_ind)} targets: {total_time} seconds '
            f'({s_targets} seconds per target, {s_reactions} reactions per second) \n\n'
        )
        for k, top in enumerate(top_k):
            acc = accs[k]
            fp.write('Top {:3d}: {:7.4f}% cum {:7.4f}%\n'.format(
                k + 1, acc * 100 / n_samples, top * 100 / n_samples))

    if results_file:
        top_k_str = ' '.join('{:7.2f}%'.format(top * 100 / n_samples)
                             for top in top_k)
        with open(results_file, 'a') as fp:
            fp.write('{:>50s}: {:<s}\n'.format(save_path, top_k_str))

    logger.info(f'Saved Top {beam_size} to {summary_path}')
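
A minimal call sketch for the function above; the path is a hypothetical directory that must contain the config.gin and model_best.pt files loaded at the start of evaluate_megan:

evaluate_megan('models/my_megan_run', beam_size=10, split_key='test')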
Example #6
def get_dataset(dataset_key: str = gin.REQUIRED) -> Dataset:
    """Return the dataset registered under `dataset_key` (bound via gin)."""
    return config.get_dataset(dataset_key)
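
A minimal sketch of binding the required key through gin, assuming get_dataset is registered as @gin.configurable elsewhere ('uspto_50k' is the dataset key used in Example #5):

import gin

# Bind the gin-required parameter, then call the helper with no arguments.
gin.parse_config("get_dataset.dataset_key = 'uspto_50k'")
dataset = get_dataset()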