Example #1
def make_data_loader(mols: List[str],
                     values: Optional[List[Any]] = None,
                     batch_size: int = 32,
                     shuffle_buffer: Optional[int] = None,
                     value_spec: tf.TensorSpec = tf.TensorSpec((), dtype=tf.float32),
                     max_size: Optional[int] = None,
                     drop_last_batch: bool = False) -> tf.data.Dataset:
    """Make a data loader for data compatible with NFP-style neural networks

    Args:
        mols: List of molecules in a string format
        values: List of output values, if included in the output
        batch_size: Number of molecules per batch
        shuffle_buffer: Size of the shuffle buffer. Use ``None`` to leave the data unshuffled
        value_spec: TensorFlow specification for the output values
        max_size: Maximum number of atoms per molecule
        drop_last_batch: Whether to drop the final batch if it has fewer than ``batch_size``
            molecules. Set to ``True`` if, for example, you need every batch to be the same size
    Returns:
        Data loader that generates molecules in the desired shapes
    """

    # Convert the molecules to dictionary formats
    mol_dicts = [_to_nfp_dict(convert_string_to_dict(s)) for s in mols]

    # Make the initial data loader
    record_sig = {
        "atom": tf.TensorSpec(shape=(None,), dtype=tf.int32),
        "bond": tf.TensorSpec(shape=(None,), dtype=tf.int32),
        "connectivity": tf.TensorSpec(shape=(None, 2), dtype=tf.int32),
    }
    if values is None:
        def generator():
            yield from mol_dicts
    else:
        def generator():
            yield from zip(mol_dicts, values)

        record_sig = (record_sig, value_spec)

    loader = tf.data.Dataset.from_generator(generator=generator, output_signature=record_sig).cache()  # TODO (wardlt): Make caching optional?

    # Shuffle, if desired
    if shuffle_buffer is not None:
        loader = loader.shuffle(shuffle_buffer)

    # Make the batches
    if max_size is None:
        loader = loader.padded_batch(batch_size=batch_size, drop_remainder=drop_last_batch)
    else:
        max_bonds = 4 * max_size  # Assume at most 4 bonds per atom (e.g., a saturated carbon)
        padded_records = {
            "atom": tf.TensorShape((max_size,)),
            "bond": tf.TensorShape((max_bonds,)),
            "connectivity": tf.TensorShape((max_bonds, 2))
        }
        if values is not None:
            padded_records = (padded_records, value_spec.shape)
        loader = loader.padded_batch(batch_size=batch_size, padded_shapes=padded_records, drop_remainder=drop_last_batch)

    return loader
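# Usage sketch (not part of the original source): build a small, shuffled,
# padded loader from SMILES strings. Assumes TensorFlow is installed and that
# `convert_string_to_dict` accepts these strings.
example_loader = make_data_loader(['C', 'CC', 'CCO'], values=[0.1, 0.2, 0.3],
                                  batch_size=2, shuffle_buffer=64)
for batch_dict, batch_values in example_loader:
    pass  # batch_dict holds padded 'atom', 'bond', and 'connectivity' tensors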
Example #2
    def __init__(self,
                 smiles: List[str],
                 outputs: List[float],
                 batch_size: int,
                 shuffle: bool = True,
                 random_state: Optional[int] = None):
        """

        Args:
            smiles: List of molecules
            outputs: List of molecular outputs
            batch_size: Number of molecules per batch
            shuffle: Whether to shuffle after each epoch
            random_state: Random state for the shuffling
        """

        super(GraphLoader, self).__init__()

        # Convert the molecules to MPNN-ready formats
        mols = [convert_string_to_dict(s) for s in smiles]
        self.entries = np.array(list(zip(mols, outputs)))

        # Other data
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Give it a first shuffle, if needed
        self.rng = np.random.RandomState(random_state)
        if shuffle:
            self.rng.shuffle(self.entries)
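
    # Hedged sketch (not in the original excerpt): assuming ``GraphLoader``
    # subclasses ``tf.keras.utils.Sequence``, it must also provide ``__len__``
    # and ``__getitem__``; ``_merge_batch`` here stands in for a collate helper
    # like the one used in Example #3.
    def __len__(self) -> int:
        # Number of complete batches per epoch
        return len(self.entries) // self.batch_size

    def __getitem__(self, idx: int):
        # Slice out one batch and collate it into model-ready arrays
        batch = self.entries[idx * self.batch_size:(idx + 1) * self.batch_size]
        mols, outputs = zip(*batch)
        return _merge_batch(mols), np.array(outputs, dtype=np.float32)

    def on_epoch_end(self):
        # Reshuffle between epochs, if requested
        if self.shuffle:
            self.rng.shuffle(self.entries)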
Example #3
def evaluate_mpnn(model_msg: Union[List[MPNNMessage], List[tf.keras.Model], List[str], List[Path]],
                  smiles: Union[List[str], List[dict]],
                  batch_size: int = 128, cache: bool = True, n_jobs: Optional[int] = 1) -> np.ndarray:
    """Run inference on a list of molecules

    Args:
        model_msg: List of MPNNs to evaluate. Accepts either a pickled message, model, or a path
        smiles: List of molecules to evaluate, either as SMILES or InChI strings, or lists of MPNN-ready dictionary objects
        batch_size: Number of molecules per batch
        cache: Whether to cache models if being read from disk
        n_jobs: Number of parallel jobs to run. Set to ``None`` to use the total number of cores.
            Note: The Pool is cached, so the first value that n_jobs is set to will remain
            in effect for the life of the process (unless the value is 1, which does not use a Pool)
    Returns:
        Predicted value for each molecule
    """
    assert len(smiles) > 0, "You must provide at least one molecule to the inference function"

    # Access the model
    if isinstance(model_msg[0], MPNNMessage):
        # Unpack the messages
        models = [m.get_model() for m in model_msg]
    elif isinstance(model_msg[0], (str, Path)):
        # Load the model from disk
        if cache:
            models = []
            for p in model_msg:
                if p not in _model_cache:
                    _model_cache[p] = tf.keras.models.load_model(str(p), custom_objects=custom_objects)
                models.append(_model_cache[p])
        else:
            models = [tf.keras.models.load_model(str(p), custom_objects=custom_objects)
                      for p in model_msg]
    else:
        # No action needed
        models = model_msg

    # Ensure all molecules are ready for inference
    if isinstance(smiles[0], dict):
        mols = smiles
    else:
        if n_jobs == 1:
            mols = [convert_string_to_dict(s) for s in smiles]
        else:
            pool = get_process_pool(n_jobs)
            mols = pool.map(convert_string_to_dict, smiles)

    # Stuff them into batches
    chunks = [mols[start:start + batch_size] for start in range(0, len(mols), batch_size)]
    batches = [_merge_batch(c) for c in chunks]

    # Feed the batches through the MPNN
    all_outputs = []
    for model in models:
        outputs = [model(b) for b in batches]
        all_outputs.append(np.vstack(outputs))
    return np.hstack(all_outputs)
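# Hedged sketch (an assumption; ``_merge_batch`` is not defined in this excerpt):
# one plausible implementation pads each per-molecule array to the batch maximum
# and stacks along a new leading axis, mirroring the padded batching in Example #1.
def _merge_batch_sketch(mols: List[dict]) -> dict:
    batch = {}
    for key in mols[0]:
        arrays = [np.asarray(m[key]) for m in mols]
        max_len = max(a.shape[0] for a in arrays)
        batch[key] = np.stack([
            np.pad(a, [(0, max_len - a.shape[0])] + [(0, 0)] * (a.ndim - 1))
            for a in arrays
        ])
    return batch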
Example #4
def test_inference(model):
    # Evaluate in serial, then in parallel
    results_serial = evaluate_mpnn([model], ['C', 'CC'])
    results_parallel = evaluate_mpnn([model], ['C', 'CC'], n_jobs=2)
    assert np.isclose(results_parallel, results_serial).all()

    # Try running inference on a pre-processed molecule
    preparsed = [convert_string_to_dict(x) for x in ['C', 'CC']]
    results_preparsed = evaluate_mpnn([model], preparsed)
    assert np.isclose(results_preparsed, results_serial).all()
Example #5
    def get_inference_inputs(self,
                             record: MoleculeData) -> Tuple[str, Any, float]:
        """Determine which model to use for inference and the inputs needed for that model

        Args:
            record: Molecule to evaluate
        Returns:
            - Name of the model that should be run
            - Inputs to the machine learning model
            - Value to be calibrated
        """

        # Determine which model to run
        current_step = self.get_current_step(record)

        # Get the model spec for that level and the input value
        if current_step == 'base':
            model_spec = self.base_model
            init_value = 0
        else:
            model_spec = self.get_models(current_step)
            init_value = record.oxidation_potential[current_step] if self.oxidation_state == OxidationState.OXIDIZED \
                else record.reduction_potential[current_step]

        # Get the inputs
        model_type = model_spec.model_type
        if model_type == ModelType.SCHNET:
            recipe = get_recipe_by_name(current_step)
            # Use the geometry at the base level of fidelity, and select the charged geometry only if available
            input_val = record.data[recipe.geometry_level][
                self.oxidation_state if recipe.adiabatic else "neutral"].xyz
        elif model_type == ModelType.MPNN:
            # Use a dictionary
            input_val = convert_string_to_dict(record.identifier['inchi'])
        else:
            raise NotImplementedError(f'No support for {model_type} yet')

        return current_step, input_val, init_value
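
    # Hedged usage sketch (an assumption, not from the source): the returned
    # ``init_value`` suggests a delta-learning pattern, where the model output
    # is added to a lower-fidelity value. ``scorer``, ``record``, and
    # ``run_model`` are hypothetical:
    #
    #     step, inputs, init_value = scorer.get_inference_inputs(record)
    #     prediction = init_value + run_model(step, inputs)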