def __init__(self, size=2048, radius=4, dtype=np.int32): """An abstract class for managing transformations applied to model-based optimization datasets when constructing the oracle; for example, if the oracle learns from molecule fingerprints Arguments: size: int the number of bits in the morgan fingerprint returned by RDKit, controls the vector size of the molecule embedding radius: int the substructure radius passed to RDKit that controls how local the information encoded in the molecule embedding is """ # wrap the deepchem featurizer that relies on rdkit self.featurizer = feat.CircularFingerprint(size=size, radius=radius) self.size = size self.radius = radius self.dtype = dtype # download the vocabulary file is not present in the cache vocab_file = DiskResource( os.path.join(DATA_DIR, 'smiles_vocab.txt'), download_method="direct", download_target=f'{SERVER_URL}/smiles_vocab.txt') if not vocab_file.is_downloaded: vocab_file.download() self.tokenizer = SmilesTokenizer( os.path.join(DATA_DIR, 'smiles_vocab.txt'))
def __init__(self, dataset: ContinuousDataset, rollout_horizon=100, **kwargs):
    """Initialize the ground truth score function f(x) for a model-based
    optimization problem, which involves loading the parameters of an
    oracle model and estimating its computational cost

    Arguments:

    dataset: ContinuousDataset
        an instance of a subclass of the DatasetBuilder class which has
        a set of design values 'x' and prediction values 'y', and defines
        batching and sampling methods for those attributes
    rollout_horizon: int
        the number of environment transitions per trajectory that are
        sampled when evaluating a morphology with the trained policy
    **kwargs: dict
        additional keyword arguments forwarded to the oracle base class,
        such as noise_std (standard deviation of gaussian noise added to
        predictions 'y') and internal_measurements (number of independent
        measurements that are averaged when the oracle is stochastic)

    """

    # the number of transitions per trajectory to sample
    self.rollout_horizon = rollout_horizon

    # ensure the trained policy has been downloaded
    policy = "dkitty_morphology/dkitty_oracle.pkl"
    policy = DiskResource(policy, is_absolute=False,
                          download_method="direct",
                          download_target=f"{SERVER_URL}/{policy}")
    if not policy.is_downloaded and not policy.download():
        raise ValueError("unable to download trained policy for dkitty")

    # load the weights of the policy
    # NOTE(review): pickle.load executes arbitrary code on load; this is
    # acceptable only because the artifact comes from the trusted server
    with open(policy.disk_target, "rb") as f:
        self.policy = pkl.load(f)

    # initialize the oracle using the super class; one design is
    # evaluated at a time and values are used in their raw scale
    super(DKittyMorphologyOracle, self).__init__(
        dataset, internal_batch_size=1, is_batched=False,
        expect_normalized_y=False,
        expect_normalized_x=False, expect_logits=None, **kwargs)
def register_y_shards(transcription_factor='SIX6_REF_R1'):
    """Registers a remote file for download that contains prediction
    values in a format compatible with the dataset builder class;
    these files are downloaded all at once in the dataset initialization

    Arguments:

    transcription_factor: str
        a string argument that specifies which transcription factor
        to select for model-based optimization, where the goal is to
        find a length 8 polypeptide with maximum binding affinity

    Returns:

    resources: list of RemoteResource
        a list of RemoteResource objects specific to this dataset, which
        will be automatically downloaded while the dataset is built
        and may serve as shards if the dataset is large

    """

    resources = []
    for shard_path in TF_BIND_8_FILES:
        # keep only the shards belonging to the requested factor
        if transcription_factor not in shard_path:
            continue
        # prediction shards share the design shard name with -y- swapped in
        y_path = shard_path.replace("-x-", "-y-")
        resources.append(DiskResource(
            y_path, is_absolute=False,
            download_target=f"{SERVER_URL}/{y_path}",
            download_method="direct"))
    return resources
def register_y_shards(assay_chembl_id="CHEMBL1794345",
                      standard_type="Potency"):
    """Registers a remote file for download that contains prediction
    values in a format compatible with the dataset builder class;
    these files are downloaded all at once in the dataset initialization

    Arguments:

    assay_chembl_id: str
        a string identifier that specifies which assay to use for
        model-based optimization, where the goal is to find a design
        value 'x' that maximizes a certain property
    standard_type: str
        a string identifier that specifies which property of the assay
        is being measured for model-based optimization, where the goal
        is to maximize that property

    Returns:

    resources: list of RemoteResource
        a list of RemoteResource objects specific to this dataset, which
        will be automatically downloaded while the dataset is built
        and may serve as shards if the dataset is large

    """

    # shards for a particular assay/property pair carry this substring
    selector = f"{standard_type}-{assay_chembl_id}"

    resources = []
    for shard_path in CHEMBL_FILES:
        if selector in shard_path:
            # prediction shards mirror design shards with -y- swapped in
            y_path = shard_path.replace("-x-", "-y-")
            resources.append(DiskResource(
                y_path, is_absolute=False,
                download_target=f"{SERVER_URL}/{y_path}",
                download_method="direct"))
    return resources
def register_y_shards():
    """Registers a remote file for download that contains prediction
    values in a format compatible with the dataset builder class;
    these files are downloaded all at once in the dataset initialization

    Returns:

    resources: list of RemoteResource
        a list of RemoteResource objects specific to this dataset, which
        will be automatically downloaded while the dataset is built
        and may serve as shards if the dataset is large

    """

    resources = []
    for shard_path in TOY_DISCRETE_FILES:
        # prediction shards mirror design shards with -y- swapped in
        y_path = shard_path.replace("-x-", "-y-")
        resources.append(DiskResource(
            y_path, is_absolute=False,
            download_target=f"{SERVER_URL}/{y_path}",
            download_method="direct"))
    return resources
def register_x_shards():
    """Registers a remote file for download that contains design
    values in a format compatible with the dataset builder class;
    these files are downloaded all at once in the dataset initialization

    Returns:

    resources: list of RemoteResource
        a list of RemoteResource objects specific to this dataset, which
        will be automatically downloaded while the dataset is built
        and may serve as shards if the dataset is large

    """

    resources = []
    for shard_path in DKITTY_MORPHOLOGY_FILES:
        resources.append(DiskResource(
            shard_path, is_absolute=False,
            download_target=f"{SERVER_URL}/{shard_path}",
            download_method="direct"))
    return resources
def __init__(self, dataset: Union[DatasetBuilder, type, str],
             oracle: Union[OracleBuilder, type, str],
             dataset_kwargs=None, oracle_kwargs=None, relabel=False):
    """Initialize a model-based optimization problem using a static task
    dataset and a ground truth score function that is either an exact
    simulator, or an approximate model such as a neural network

    Arguments:

    dataset: Union[DatasetBuilder, type, str]
        a static dataset in a model-based optimization problem that
        exposes a set of designs 'x' and predictions 'y'; may be an
        already-built instance, a callable (e.g. a class) that builds
        one, or an importable dotted path string
    oracle: Union[OracleBuilder, type, str]
        a ground truth score function in a model-based optimization
        problem that is either an exact simulator or an approximate
        model; accepted in the same three forms as dataset
    dataset_kwargs: dict
        additional keyword arguments that are provided to the dataset
        class when it is initialized for the first time
    oracle_kwargs: dict
        additional keyword arguments that are provided to the oracle
        class when it is initialized for the first time
    relabel: bool
        a boolean indicator that specifies whether the dataset
        prediction values should be relabeled with the predictions of
        the oracle; only takes effect for ApproximateOracle instances

    """

    # use additional_kwargs to override self.kwargs
    kwargs = dataset_kwargs if dataset_kwargs else dict()

    # if self.entry_point is a function call it
    if callable(dataset):
        dataset = dataset(**kwargs)

    # if self.entry_point is a string import it first
    elif isinstance(dataset, str):
        dataset = import_name(dataset)(**kwargs)

    # return if the dataset could not be loaded
    elif not isinstance(dataset, DatasetBuilder):
        raise ValueError("dataset could not be loaded")

    # expose the built dataset
    self.dataset = dataset

    # use additional_kwargs to override self.kwargs
    kwargs = oracle_kwargs if oracle_kwargs else dict()

    # if self.entry_point is a function call it; note that oracles are
    # constructed with the built dataset as their first argument
    if callable(oracle):
        oracle = oracle(dataset, **kwargs)

    # if self.entry_point is a string import it first
    elif isinstance(oracle, str):
        oracle = import_name(oracle)(dataset, **kwargs)

    # return if the oracle could not be loaded
    elif not isinstance(oracle, OracleBuilder):
        raise ValueError("oracle could not be loaded")

    # expose the built oracle
    self.oracle = oracle

    # only relabel when an approximate model is used; exact simulators
    # already define the ground truth so relabeling is a no-op for them
    relabel = relabel and isinstance(oracle, ApproximateOracle)

    # attempt to download the appropriate shards
    new_shards = []
    for shard in dataset.y_shards:
        if relabel and isinstance(shard, DiskResource):

            # create a name for the new sharded prediction by splicing
            # the oracle name into the original shard file name
            # NOTE(review): assumes SHARD_PATTERN always matches a
            # DiskResource target (m is used without a None check)
            m = SHARD_PATTERN.search(shard.disk_target)
            file = f"{m.group(1)}-{oracle.name}-y-{m.group(3)}.npy"
            bare = os.path.join(os.path.basename(os.path.dirname(file)),
                                os.path.basename(file))

            # create a disk resource for the new shard; relabeled shards
            # are cached on a fixed s3 bucket rather than SERVER_URL
            new_shards.append(
                DiskResource(file, is_absolute=True,
                             download_method="direct",
                             download_target=f"https://design-bench."
                                             f"s3-us-west-1.amazonaws.com/{bare}"))

    # check if every shard was downloaded successfully
    # this naturally handles when the shard is already downloaded
    if relabel and len(new_shards) > 0 and all(
            [f.is_downloaded or f.download() for f in new_shards]):

        # assign the y shards to the downloaded files and re sample
        # the dataset if sub sampling is being used
        dataset.y_shards = new_shards
        dataset.subsample(max_samples=dataset.dataset_size,
                          distribution=dataset.dataset_distribution,
                          max_percentile=dataset.dataset_max_percentile,
                          min_percentile=dataset.dataset_min_percentile)

    elif relabel:

        # test if the shards are stored on the disk
        # this means that downloading cached predictions failed
        name = None
        test_shard = dataset.y_shards[0]
        if isinstance(test_shard, DiskResource):

            # create a name for the new sharded prediction so the
            # relabeled values can be persisted for future runs
            m = SHARD_PATTERN.search(test_shard.disk_target)
            name = f"{m.group(1)}-{oracle.name}"

        # relabel the dataset using the new oracle model; predictions
        # are written to disk only when a shard name could be derived
        dataset.relabel(lambda x, y: oracle.predict(x),
                        to_disk=name is not None,
                        is_absolute=True,
                        disk_target=name)