예제 #1
0
    def __init__(self, size=2048, radius=4, dtype=np.int32):
        """An abstract class for managing transformations applied to
        model-based optimization datasets when constructing the oracle; for
        example, if the oracle learns from molecule fingerprints

        Arguments:

        size: int
            the number of bits in the morgan fingerprint returned by RDKit,
            controls the vector size of the molecule embedding
        radius: int
            the substructure radius passed to RDKit that controls how local
            the information encoded in the molecule embedding is
        dtype: np.dtype
            numpy data type stored for downstream use (presumably the dtype
            of the fingerprint vectors -- confirm with callers)

        """

        # record the fingerprint hyper-parameters
        self.size = size
        self.radius = radius
        self.dtype = dtype

        # wrap the deepchem featurizer that relies on rdkit
        self.featurizer = feat.CircularFingerprint(size=size, radius=radius)

        # download the vocabulary file if it is not present in the cache
        vocab_path = os.path.join(DATA_DIR, 'smiles_vocab.txt')
        vocab_file = DiskResource(
            vocab_path,
            download_method="direct",
            download_target=f'{SERVER_URL}/smiles_vocab.txt')
        if not vocab_file.is_downloaded:
            vocab_file.download()

        # build a tokenizer over the downloaded vocabulary
        self.tokenizer = SmilesTokenizer(vocab_path)
예제 #2
0
    def __init__(self,
                 dataset: ContinuousDataset,
                 rollout_horizon=100,
                 **kwargs):
        """Initialize the ground truth score function f(x) for a model-based
        optimization problem, which involves loading the parameters of an
        oracle model and estimating its computational cost

        Arguments:

        dataset: ContinuousDataset
            an instance of a subclass of the DatasetBuilder class which has
            a set of design values 'x' and prediction values 'y', and defines
            batching and sampling methods for those attributes
        rollout_horizon: int
            the number of transitions per trajectory sampled when the
            trained policy is rolled out to score a design
        **kwargs: dict
            additional keyword arguments (for example noise_std or
            internal_measurements) forwarded to the superclass constructor

        """

        # the number of transitions per trajectory to sample
        self.rollout_horizon = rollout_horizon

        # ensure the trained policy has been downloaded
        policy = "dkitty_morphology/dkitty_oracle.pkl"
        policy = DiskResource(policy,
                              is_absolute=False,
                              download_method="direct",
                              download_target=f"{SERVER_URL}/{policy}")
        if not policy.is_downloaded and not policy.download():
            # fixed: message previously said "ant", but this oracle
            # downloads the dkitty policy
            raise ValueError("unable to download trained policy for dkitty")

        # load the weights of the policy
        # NOTE(review): pickle.load on a downloaded file executes arbitrary
        # code if the server is compromised -- trusted source assumed
        with open(policy.disk_target, "rb") as f:
            self.policy = pkl.load(f)

        # initialize the oracle using the super class
        super(DKittyMorphologyOracle, self).__init__(dataset,
                                                     internal_batch_size=1,
                                                     is_batched=False,
                                                     expect_normalized_y=False,
                                                     expect_normalized_x=False,
                                                     expect_logits=None,
                                                     **kwargs)
예제 #3
0
    def register_y_shards(transcription_factor='SIX6_REF_R1'):
        """Registers a remote file for download that contains prediction
        values in a format compatible with the dataset builder class;
        these files are downloaded all at once in the dataset initialization

        Arguments:

        transcription_factor: str
            a string argument that specifies which transcription factor to
            select for model-based optimization, where the goal is to find
            a length 8 polypeptide with maximum binding affinity

        Returns:

        resources: list of RemoteResource
            a list of RemoteResource objects specific to this dataset, which
            will be automatically downloaded while the dataset is built
            and may serve as shards if the dataset is large

        """

        # map each matching design-value shard to its prediction-value
        # counterpart by swapping the '-x-' marker for '-y-'
        resources = []
        for file in TF_BIND_8_FILES:
            if transcription_factor not in file:
                continue
            y_file = file.replace("-x-", "-y-")
            resources.append(DiskResource(
                y_file, is_absolute=False,
                download_target=f"{SERVER_URL}/{y_file}",
                download_method="direct"))
        return resources
예제 #4
0
    def register_y_shards(assay_chembl_id="CHEMBL1794345",
                          standard_type="Potency"):
        """Registers a remote file for download that contains prediction
        values in a format compatible with the dataset builder class;
        these files are downloaded all at once in the dataset initialization

        Arguments:

        assay_chembl_id: str
            a string identifier that specifies which assay to use for
            model-based optimization, where the goal is to find a design
            value 'x' that maximizes a certain property
        standard_type: str
            a string identifier that specifies which property of the assay
            is being measured for model-based optimization, where the goal is
            to maximize that property

        Returns:

        resources: list of RemoteResource
            a list of RemoteResource objects specific to this dataset, which
            will be automatically downloaded while the dataset is built
            and may serve as shards if the dataset is large

        """

        # shards for this assay/property pair are identified by a
        # '{standard_type}-{assay_chembl_id}' substring in the file name
        token = f"{standard_type}-{assay_chembl_id}"
        y_files = [file.replace("-x-", "-y-")
                   for file in CHEMBL_FILES if token in file]

        # swap each '-x-' design shard for its '-y-' prediction shard
        return [
            DiskResource(y_file,
                         is_absolute=False,
                         download_target=f"{SERVER_URL}/{y_file}",
                         download_method="direct")
            for y_file in y_files
        ]
예제 #5
0
    def register_y_shards():
        """Registers a remote file for download that contains prediction
        values in a format compatible with the dataset builder class;
        these files are downloaded all at once in the dataset initialization

        Returns:

        resources: list of RemoteResource
            a list of RemoteResource objects specific to this dataset, which
            will be automatically downloaded while the dataset is built
            and may serve as shards if the dataset is large

        """

        # prediction shards share their name with the design shards,
        # with the '-x-' marker replaced by '-y-'
        y_files = [file.replace("-x-", "-y-") for file in TOY_DISCRETE_FILES]
        return [DiskResource(y_file,
                             is_absolute=False,
                             download_target=f"{SERVER_URL}/{y_file}",
                             download_method="direct")
                for y_file in y_files]
예제 #6
0
    def register_x_shards():
        """Registers a remote file for download that contains design values
        in a format compatible with the dataset builder class;
        these files are downloaded all at once in the dataset initialization

        Returns:

        resources: list of RemoteResource
            a list of RemoteResource objects specific to this dataset, which
            will be automatically downloaded while the dataset is built
            and may serve as shards if the dataset is large

        """

        # wrap every known design-value shard file in a disk resource
        # pointing at its copy on the remote server
        shards = []
        for file in DKITTY_MORPHOLOGY_FILES:
            shards.append(
                DiskResource(file,
                             is_absolute=False,
                             download_target=f"{SERVER_URL}/{file}",
                             download_method="direct"))
        return shards
예제 #7
0
    def __init__(self,
                 dataset: Union[DatasetBuilder, type, str],
                 oracle: Union[OracleBuilder, type, str],
                 dataset_kwargs=None,
                 oracle_kwargs=None,
                 relabel=False):
        """Initialize a model-based optimization problem using a static task
        dataset and a ground truth score function that is either an exact
        simulator, or an approximate model such as a neural network

        Arguments:

        dataset: Union[DatasetBuilder, type, str]
            a static dataset in a model-based optimization problem that
            exposes a set of designs 'x' and predictions 'y'; may be given
            as an already-built instance, a class (or callable factory) to
            instantiate, or a dotted import path string
        oracle: Union[OracleBuilder, type, str]
            a ground truth score function in a model-based optimization
            problem that is either an exact simulator or an approximate
            model; may be given as an instance, a class, or an import path
        dataset_kwargs: dict
            additional keyword arguments that are provided to the dataset
            class when it is initialized for the first time
        oracle_kwargs: dict
            additional keyword arguments that are provided to the oracle
            class when it is initialized for the first time
        relabel: bool
            a boolean indicator that specifies whether the dataset prediction
            values should be relabeled with the predictions of the oracle

        Raises:

        ValueError
            if dataset (or oracle) is neither callable, a string, nor an
            instance of DatasetBuilder (or OracleBuilder)

        """

        # keyword arguments used only when the dataset must be built here
        kwargs = dataset_kwargs if dataset_kwargs else dict()

        # a class (or any callable factory) is called to build the dataset
        if callable(dataset):
            dataset = dataset(**kwargs)

        # a string is treated as an import path, imported, then called
        elif isinstance(dataset, str):
            dataset = import_name(dataset)(**kwargs)

        # anything else must already be a built DatasetBuilder instance
        elif not isinstance(dataset, DatasetBuilder):
            raise ValueError("dataset could not be loaded")

        # expose the built dataset
        self.dataset = dataset

        # keyword arguments used only when the oracle must be built here
        kwargs = oracle_kwargs if oracle_kwargs else dict()

        # a class (or any callable factory) is called with the built dataset
        if callable(oracle):
            oracle = oracle(dataset, **kwargs)

        # a string is treated as an import path, imported, then called
        elif isinstance(oracle, str):
            oracle = import_name(oracle)(dataset, **kwargs)

        # anything else must already be a built OracleBuilder instance
        elif not isinstance(oracle, OracleBuilder):
            raise ValueError("oracle could not be loaded")

        # expose the built oracle
        self.oracle = oracle

        # only relabel when an approximate model is used; for other oracle
        # types the relabel request is silently ignored
        relabel = relabel and isinstance(oracle, ApproximateOracle)

        # attempt to download pre-computed relabeled predictions for every
        # disk-backed shard of the dataset's 'y' values
        new_shards = []
        for shard in dataset.y_shards:
            if relabel and isinstance(shard, DiskResource):

                # derive the cached-prediction file name by inserting the
                # oracle's name before the '-y-' marker of the shard name
                m = SHARD_PATTERN.search(shard.disk_target)
                file = f"{m.group(1)}-{oracle.name}-y-{m.group(3)}.npy"
                # last directory component plus file name: the key under
                # which the cached shard is stored on the remote bucket
                bare = os.path.join(os.path.basename(os.path.dirname(file)),
                                    os.path.basename(file))

                # create a disk resource for the new shard
                # NOTE(review): this bucket URL is hard-coded rather than
                # derived from SERVER_URL as elsewhere -- confirm intended
                new_shards.append(
                    DiskResource(file,
                                 is_absolute=True,
                                 download_method="direct",
                                 download_target=f"https://design-bench."
                                 f"s3-us-west-1.amazonaws.com/{bare}"))

        # check if every shard was downloaded successfully
        # this naturally handles when the shard is already downloaded
        if relabel and len(new_shards) > 0 and all(
            [f.is_downloaded or f.download() for f in new_shards]):

            # assign the y shards to the downloaded files and re sample
            # the dataset if sub sampling is being used
            dataset.y_shards = new_shards
            dataset.subsample(max_samples=dataset.dataset_size,
                              distribution=dataset.dataset_distribution,
                              max_percentile=dataset.dataset_max_percentile,
                              min_percentile=dataset.dataset_min_percentile)

        elif relabel:

            # cached predictions could not be downloaded (or the shards are
            # not disk-backed), so evaluate the oracle locally instead
            name = None
            test_shard = dataset.y_shards[0]
            if isinstance(test_shard, DiskResource):

                # disk-backed shards get the relabeled predictions written
                # back to disk under a name that encodes the oracle used
                m = SHARD_PATTERN.search(test_shard.disk_target)
                name = f"{m.group(1)}-{oracle.name}"

            # replace the dataset's 'y' values with oracle.predict(x);
            # the original predictions passed as 'y' are discarded
            dataset.relabel(lambda x, y: oracle.predict(x),
                            to_disk=name is not None,
                            is_absolute=True,
                            disk_target=name)