Example #1
0
def _get_case_data(mongo_client):
    """
    Build a :class:`.CaseData` instance for testing fitness caching.

    Parameters
    ----------
    mongo_client : :class:`pymongo.MongoClient`
        The mongo client the database should connect to.

    Returns
    -------
    :class:`.CaseData`
        Holds the fitness calculator, the molecule, and the fitness
        value the calculator produced for that molecule on its first
        call.

    """

    # _counter.get_count hands back a different number every time it is
    # invoked, so it behaves like a fitness function whose result keeps
    # changing. The value computed below is written to the database.
    # With working caching, asking the calculator again for the same
    # molecule yields this stored number; without it, a new count comes
    # back instead.

    database = stk.ValueMongoDb(
        mongo_client=mongo_client,
        collection='test_caching',
        database='_stk_pytest_database',
    )
    calculator = stk.FitnessFunction(
        fitness_function=_counter.get_count,
        # Reading and writing the same database is what enables caching.
        input_database=database,
        output_database=database,
    )
    molecule = stk.BuildingBlock('BrCCBr')

    return CaseData(
        fitness_calculator=calculator,
        molecule=molecule,
        fitness_value=calculator.get_fitness_value(molecule),
    )
Example #2
0
    return sum(scores)


# Synthetic accessibility scoring function, used as the final entry of
# the cage property vector below.
synthetic_accesibility_func = scscore

# Each cage is scored by the property vector: pore diameter, largest
# window, window standard deviation and synthetic accessibility.
cage_fitness_calculator = stk.PropertyVector(
    pore_diameter, largest_window, window_std, synthetic_accesibility_func,
)

# Molecules for which ``failed_optimizer.is_in_cache(mol)`` is true get a
# fitness value of None; every other molecule is scored by
# ``cage_fitness_calculator``.
fitness_calculator = stk.If(
    false_calculator=cage_fitness_calculator,
    true_calculator=stk.FitnessFunction(lambda mol: None),
    condition=lambda mol: failed_optimizer.is_in_cache(mol),
)

# #####################################################################
# Fitness normalizer.
# #####################################################################


def valid_fitness(population, mol):
    """
    Check whether `mol` has a usable fitness value in `population`.

    Parameters
    ----------
    population
        Object whose ``get_fitness_values()`` returns a container
        indexable by `mol`.
    mol
        The molecule whose fitness value is checked.

    Returns
    -------
    bool
        For a list fitness value, ``True`` if none of its entries is
        ``None``. For any other fitness value, ``True`` if it is not
        ``None``.

    """
    # Look the value up once; the original fetched the fitness values a
    # second time for lists.
    fitness = population.get_fitness_values()[mol]
    # A list fitness value only counts as valid when every entry exists.
    if isinstance(fitness, list):
        return None not in fitness
    return fitness is not None
Example #3
0
import pytest
import stk

from ..case_data import CaseData


@pytest.fixture(
    params=[
        CaseData(
            fitness_calculator=stk.FitnessFunction(
                stk.Molecule.get_num_atoms,
            ),
            molecule=stk.BuildingBlock('BrCCBr'),
            # The expected count of 8 is 2 Br + 2 C plus 4 hydrogens.
            fitness_value=8,
        ),
    ],
)
def fitness_function(request):
    """
    Provide a :class:`.CaseData` whose fitness is the molecule's atom count.

    """
    case_data = request.param
    return case_data
Example #4
0
def main():
    """
    Run the example evolutionary algorithm and write its results.

    Parses ``--mongodb_uri``, and builds the initial population from the
    ``fluoros.txt`` and ``bromos.txt`` building block files located next
    to this script. It then runs 50 generations, storing the molecule of
    every record of every generation in a
    :class:`stk.ConstructedMoleculeMongoDb`. Finally it writes the initial
    and final populations as ``.mol`` files and saves fitness and
    rotatable-bond progress plots.

    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--mongodb_uri',
        help='The MongoDB URI for the database to connect to.',
        default='mongodb://localhost:27017/',
    )
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    # Use a fixed random seed to get reproducible results.
    random_seed = 4
    generator = np.random.RandomState(random_seed)

    logger.info('Making building blocks.')

    # Load the building block databases.
    fluoros = tuple(
        get_building_blocks(
            path=pathlib.Path(__file__).parent / 'fluoros.txt',
            functional_group_factory=stk.FluoroFactory(),
        )
    )
    bromos = tuple(
        get_building_blocks(
            path=pathlib.Path(__file__).parent / 'bromos.txt',
            functional_group_factory=stk.BromoFactory(),
        )
    )

    initial_population = tuple(get_initial_population(fluoros, bromos))
    # Write the initial population.
    for i, record in enumerate(initial_population):
        write(record.get_molecule(), f'initial_{i}.mol')

    generations = []
    client = pymongo.MongoClient(args.mongodb_uri)
    try:
        db = stk.ConstructedMoleculeMongoDb(client)
        # NOTE: each generator.randint() call below consumes the seeded
        # random stream in source order, so reordering these arguments
        # changes which seed each component receives.
        ea = stk.EvolutionaryAlgorithm(
            initial_population=initial_population,
            fitness_calculator=stk.FitnessFunction(get_fitness_value),
            mutator=stk.RandomMutator(
                mutators=(
                    stk.RandomBuildingBlock(
                        building_blocks=fluoros,
                        is_replaceable=is_fluoro,
                        random_seed=generator.randint(0, 1000),
                    ),
                    stk.SimilarBuildingBlock(
                        building_blocks=fluoros,
                        is_replaceable=is_fluoro,
                        random_seed=generator.randint(0, 1000),
                    ),
                    stk.RandomBuildingBlock(
                        building_blocks=bromos,
                        is_replaceable=is_bromo,
                        random_seed=generator.randint(0, 1000),
                    ),
                    stk.SimilarBuildingBlock(
                        building_blocks=bromos,
                        is_replaceable=is_bromo,
                        random_seed=generator.randint(0, 1000),
                    ),
                ),
                random_seed=generator.randint(0, 1000),
            ),
            crosser=stk.GeneticRecombination(
                get_gene=get_functional_group_type,
            ),
            generation_selector=stk.Best(
                num_batches=25,
                duplicate_molecules=False,
            ),
            mutation_selector=stk.Roulette(
                num_batches=5,
                random_seed=generator.randint(0, 1000),
            ),
            crossover_selector=stk.Roulette(
                num_batches=3,
                batch_size=2,
                random_seed=generator.randint(0, 1000),
            ),
        )

        logger.info('Starting EA.')

        for generation in ea.get_generations(50):
            for record in generation.get_molecule_records():
                db.put(record.get_molecule())
            generations.append(generation)
    finally:
        # The database is only needed while the EA runs; release the
        # connection even if the EA raises.
        client.close()

    # Write the final population. Use the stored list instead of the
    # leaked loop variable, which would not exist if no generation ran.
    if generations:
        final_generation = generations[-1]
        for i, record in enumerate(final_generation.get_molecule_records()):
            write(record.get_molecule(), f'final_{i}.mol')

    logger.info('Making fitness plot.')

    fitness_progress = stk.ProgressPlotter(
        generations=generations,
        get_property=lambda record: record.get_fitness_value(),
        y_label='Fitness Value',
    )
    fitness_progress.write('fitness_progress.png')

    logger.info('Making rotatable bonds plot.')

    rotatable_bonds_progress = stk.ProgressPlotter(
        generations=generations,
        get_property=get_num_rotatable_bonds,
        y_label='Number of Rotatable Bonds',
    )
    rotatable_bonds_progress.write('rotatable_bonds_progress.png')