示例#1
0
    def __init__(self, size=2048, radius=4, dtype=np.int32):
        """Initialize a molecule featurizer used when constructing an oracle
        that learns from molecule fingerprints, together with a SMILES
        tokenizer whose vocabulary file is downloaded on demand

        Arguments:

        size: int
            the number of bits in the morgan fingerprint returned by RDKit,
            controls the vector size of the molecule embedding
        radius: int
            the substructure radius passed to RDKit that controls how local
            the information encoded in the molecule embedding is
        dtype: type
            stored as self.dtype and not used inside __init__ itself;
            presumably the numeric dtype of fingerprints produced later
            (TODO confirm against the rest of the class)

        Raises:

        ValueError
            if 'smiles_vocab.txt' is not already present in DATA_DIR and
            downloading it from SERVER_URL fails

        """

        # wrap the deepchem featurizer that relies on rdkit
        self.featurizer = feat.CircularFingerprint(size=size, radius=radius)
        self.size = size
        self.radius = radius
        self.dtype = dtype

        # download the vocabulary file if it is not present in the cache;
        # fail loudly here instead of letting the tokenizer hit a missing file
        vocab_path = os.path.join(DATA_DIR, 'smiles_vocab.txt')
        vocab_file = DiskResource(
            vocab_path,
            download_method="direct",
            download_target=f'{SERVER_URL}/smiles_vocab.txt')
        if not vocab_file.is_downloaded and not vocab_file.download():
            raise ValueError("unable to download smiles vocabulary file")

        # build the tokenizer from the cached vocabulary file
        self.tokenizer = SmilesTokenizer(vocab_path)
示例#2
0
    def __init__(self,
                 dataset: ContinuousDataset,
                 rollout_horizon=100,
                 **kwargs):
        """Initialize the ground truth score function f(x) for the D'Kitty
        morphology optimization problem by loading a pre-trained policy
        (stored as self.policy) and configuring the base oracle

        Arguments:

        dataset: ContinuousDataset
            an instance of a subclass of the DatasetBuilder class which has
            a set of design values 'x' and prediction values 'y', and defines
            batching and sampling methods for those attributes
        rollout_horizon: int
            the number of transitions per trajectory to sample
        **kwargs
            forwarded unchanged to the superclass constructor; presumably
            options such as noise_std or internal_measurements (TODO confirm
            against the base oracle class). Must not repeat the keywords this
            constructor already passes (internal_batch_size, is_batched,
            expect_normalized_y, expect_normalized_x, expect_logits)

        Raises:

        ValueError
            if the trained policy is not cached and cannot be downloaded

        """

        # the number of transitions per trajectory to sample
        self.rollout_horizon = rollout_horizon

        # ensure the trained policy has been downloaded
        policy_path = "dkitty_morphology/dkitty_oracle.pkl"
        policy_resource = DiskResource(
            policy_path,
            is_absolute=False,
            download_method="direct",
            download_target=f"{SERVER_URL}/{policy_path}")
        if not policy_resource.is_downloaded \
                and not policy_resource.download():
            raise ValueError("unable to download trained policy for dkitty")

        # load the weights of the policy
        # NOTE(review): pickle executes arbitrary code on load; this trusts
        # the file fetched from SERVER_URL. Only acceptable for a trusted server.
        with open(policy_resource.disk_target, "rb") as f:
            self.policy = pkl.load(f)

        # initialize the oracle using the super class; the oracle scores one
        # unnormalized design at a time
        super().__init__(dataset,
                         internal_batch_size=1,
                         is_batched=False,
                         expect_normalized_y=False,
                         expect_normalized_x=False,
                         expect_logits=None,
                         **kwargs)