Example #1
    def test_sklearn_reload(self):
        """Test that trained model can be reloaded correctly."""
        tasks = ["task0"]
        task_types = {task: "classification" for task in tasks}
        n_samples = 10
        n_features = 3
        n_tasks = len(tasks)

        # Generate dummy dataset
        np.random.seed(123)
        ids = np.arange(n_samples)
        X = np.random.rand(n_samples, n_features)
        y = np.random.randint(2, size=(n_samples, n_tasks))
        w = np.ones((n_samples, n_tasks))

        dataset = Dataset.from_numpy(self.train_dir, X, y, w, ids, tasks)

        model_params = {
            "batch_size": None,
            "data_shape": dataset.get_data_shape()
        }

        verbosity = "high"
        classification_metric = Metric(metrics.roc_auc_score,
                                       verbosity=verbosity)
        model = SklearnModel(tasks,
                             task_types,
                             model_params,
                             self.model_dir,
                             mode="classification",
                             model_instance=RandomForestClassifier())

        # Fit the model and save it to disk
        model.fit(dataset)
        model.save()

        # Load trained model
        reloaded_model = SklearnModel(tasks,
                                      task_types,
                                      model_params,
                                      self.model_dir,
                                      mode="classification")
        reloaded_model.reload()

        # Evaluate the reloaded model on the training set
        transformers = []
        evaluator = Evaluator(reloaded_model,
                              dataset,
                              transformers,
                              verbosity=verbosity)
        scores = evaluator.compute_model_performance([classification_metric])

        assert scores[classification_metric.name] > .9
Example #2
    def test_sklearn_reload(self):
        """Test that trained model can be reloaded correctly."""
        tasks = ["task0"]
        task_types = {task: "classification" for task in tasks}
        n_samples = 10
        n_features = 3
        n_tasks = len(tasks)

        # Generate dummy dataset
        np.random.seed(123)
        ids = np.arange(n_samples)
        X = np.random.rand(n_samples, n_features)
        y = np.random.randint(2, size=(n_samples, n_tasks))
        w = np.ones((n_samples, n_tasks))

        dataset = Dataset.from_numpy(self.train_dir, X, y, w, ids, tasks)

        model_params = {"batch_size": None, "data_shape": dataset.get_data_shape()}

        verbosity = "high"
        classification_metric = Metric(metrics.roc_auc_score, verbosity=verbosity)
        model = SklearnModel(
            tasks,
            task_types,
            model_params,
            self.model_dir,
            mode="classification",
            model_instance=RandomForestClassifier(),
        )

        # Fit the model and save it to disk
        model.fit(dataset)
        model.save()

        # Load trained model
        reloaded_model = SklearnModel(tasks, task_types, model_params, self.model_dir, mode="classification")
        reloaded_model.reload()

        # Evaluate the reloaded model on the training set
        transformers = []
        evaluator = Evaluator(reloaded_model, dataset, transformers, verbosity=verbosity)
        scores = evaluator.compute_model_performance([classification_metric])

        assert scores[classification_metric.name] > 0.9
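
Both versions of the test above round-trip a fitted RandomForestClassifier through disk via SklearnModel.save() and reload(). As a minimal sketch of the same round trip without DeepChem's wrapper, assuming only scikit-learn and joblib are available (the file path and dummy data below are illustrative, not taken from the test):

import numpy as np
import joblib
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score

# Dummy data mirroring the test above: 10 samples, 3 features, 1 binary task.
np.random.seed(123)
X = np.random.rand(10, 3)
y = np.array([0, 1, 0, 1, 1, 0, 1, 0, 1, 0])  # fixed labels so both classes appear

# Fit the estimator and persist it to disk.
model = RandomForestClassifier(random_state=0)
model.fit(X, y)
model_path = "/tmp/reload_test_rf.joblib"  # illustrative path
joblib.dump(model, model_path)

# Reload it and check that train-set ROC AUC survives the round trip.
reloaded = joblib.load(model_path)
score = roc_auc_score(y, reloaded.predict_proba(X)[:, 1])
assert score > 0.9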
Example #3
class RFConvexHullPocketFinder(BindingPocketFinder):
  """Uses pre-trained RF model + ConvexHulPocketFinder to select pockets."""

  def __init__(self, pad=5):
    self.pad = pad
    self.convex_finder = ConvexHullPocketFinder(pad)

    # Load binding pocket model
    self.base_dir = tempfile.mkdtemp()
    logger.info("About to download trained model.")
    # TODO(rbharath): Shift refined to full once trained.
    call((
        "wget -nv -c http://deepchem.io.s3-website-us-west-1.amazonaws.com/trained_models/pocket_random_refined_RF.tar.gz"
    ).split())
    call(("tar -zxvf pocket_random_refined_RF.tar.gz").split())
    call(("mv pocket_random_refined_RF %s" % (self.base_dir)).split())
    self.model_dir = os.path.join(self.base_dir, "pocket_random_refined_RF")

    # Reload the trained model from disk
    self.model = SklearnModel(model_dir=self.model_dir)
    self.model.reload()

    # Create featurizers
    self.pocket_featurizer = BindingPocketFeaturizer()
    self.ligand_featurizer = CircularFingerprint(size=1024)

  def find_pockets(self, protein_file, ligand_file):
    """Compute features for a given complex

    TODO(rbharath): This has a log of code overlap with
    compute_binding_pocket_features in
    examples/binding_pockets/binding_pocket_datasets.py. Find way to refactor
    to avoid code duplication.
    """
    # if not ligand_file.endswith(".sdf"):
    #   raise ValueError("Only .sdf ligand files can be featurized.")
    # ligand_basename = os.path.basename(ligand_file).split(".")[0]
    # ligand_mol2 = os.path.join(
    #     self.base_dir, ligand_basename + ".mol2")
    #
    # # Write mol2 file for ligand
    # obConversion = ob.OBConversion()
    # conv_out = obConversion.SetInAndOutFormats(str("sdf"), str("mol2"))
    # ob_mol = ob.OBMol()
    # obConversion.ReadFile(ob_mol, str(ligand_file))
    # obConversion.WriteFile(ob_mol, str(ligand_mol2))
    #
    # # Featurize ligand
    # mol = Chem.MolFromMol2File(str(ligand_mol2), removeHs=False)
    # if mol is None:
    #   return None, None
    # # Default for CircularFingerprint
    # n_ligand_features = 1024
    # ligand_features = self.ligand_featurizer.featurize([mol])
    #
    # # Featurize pocket
    # pockets, pocket_atoms_map, pocket_coords = self.convex_finder.find_pockets(
    #     protein_file, ligand_file)
    # n_pockets = len(pockets)
    # n_pocket_features = BindingPocketFeaturizer.n_features
    #
    # features = np.zeros((n_pockets, n_pocket_features+n_ligand_features))
    # pocket_features = self.pocket_featurizer.featurize(
    #     protein_file, pockets, pocket_atoms_map, pocket_coords)
    # # Note broadcast operation
    # features[:, :n_pocket_features] = pocket_features
    # features[:, n_pocket_features:] = ligand_features
    # dataset = NumpyDataset(X=features)
    # pocket_preds = self.model.predict(dataset)
    # pocket_pred_proba = np.squeeze(self.model.predict_proba(dataset))
    #
    # # Find pockets which are active
    # active_pockets = []
    # active_pocket_atoms_map = {}
    # active_pocket_coords = []
    # for pocket_ind in range(len(pockets)):
    #   #################################################### DEBUG
    #   # TODO(rbharath): For now, using a weak cutoff. Fix later.
    #   #if pocket_preds[pocket_ind] == 1:
    #   if pocket_pred_proba[pocket_ind][1] > .15:
    #   #################################################### DEBUG
    #     pocket = pockets[pocket_ind]
    #     active_pockets.append(pocket)
    #     active_pocket_atoms_map[pocket] = pocket_atoms_map[pocket]
    #     active_pocket_coords.append(pocket_coords[pocket_ind])
    # return active_pockets, active_pocket_atoms_map, active_pocket_coords
    # # TODO(LESWING)
    raise NotImplementedError("find_pockets is not yet implemented.")
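
The constructor above shells out to wget, tar, and mv to fetch and unpack the pretrained model. A minimal sketch of the same download-and-extract step using only the Python standard library, assuming the tarball URL from the snippet is still reachable (the local directory layout below is illustrative):

import os
import tarfile
import tempfile
import urllib.request

# Same archive the constructor fetches with wget.
MODEL_URL = ("http://deepchem.io.s3-website-us-west-1.amazonaws.com/"
             "trained_models/pocket_random_refined_RF.tar.gz")

base_dir = tempfile.mkdtemp()
archive_path = os.path.join(base_dir, "pocket_random_refined_RF.tar.gz")

# Download the tarball (roughly `wget`, without resume support).
urllib.request.urlretrieve(MODEL_URL, archive_path)

# Unpack it next to the download (roughly `tar -zxvf ... && mv ...`).
with tarfile.open(archive_path, "r:gz") as tar:
    tar.extractall(path=base_dir)

model_dir = os.path.join(base_dir, "pocket_random_refined_RF")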
Example #4
class RFConvexHullPocketFinder(BindingPocketFinder):
  """Uses pre-trained RF model + ConvexHulPocketFinder to select pockets."""

  def __init__(self, pad=5):
    self.pad = pad
    self.convex_finder = ConvexHullPocketFinder(pad)

    # Load binding pocket model
    self.base_dir = tempfile.mkdtemp()
    print("About to download trained model.")
    # TODO(rbharath): Shift refined to full once trained.
    call((
        "wget -c http://deepchem.io.s3-website-us-west-1.amazonaws.com/trained_models/pocket_random_refined_RF.tar.gz"
    ).split())
    call(("tar -zxvf pocket_random_refined_RF.tar.gz").split())
    call(("mv pocket_random_refined_RF %s" % (self.base_dir)).split())
    self.model_dir = os.path.join(self.base_dir, "pocket_random_refined_RF")

    # Reload the trained model from disk
    self.model = SklearnModel(model_dir=self.model_dir)
    self.model.reload()

    # Create featurizers
    self.pocket_featurizer = BindingPocketFeaturizer()
    self.ligand_featurizer = CircularFingerprint(size=1024)

  def find_pockets(self, protein_file, ligand_file):
    """Compute features for a given complex

    TODO(rbharath): This has a log of code overlap with
    compute_binding_pocket_features in
    examples/binding_pockets/binding_pocket_datasets.py. Find way to refactor
    to avoid code duplication.
    """
    # if not ligand_file.endswith(".sdf"):
    #   raise ValueError("Only .sdf ligand files can be featurized.")
    # ligand_basename = os.path.basename(ligand_file).split(".")[0]
    # ligand_mol2 = os.path.join(
    #     self.base_dir, ligand_basename + ".mol2")
    #
    # # Write mol2 file for ligand
    # obConversion = ob.OBConversion()
    # conv_out = obConversion.SetInAndOutFormats(str("sdf"), str("mol2"))
    # ob_mol = ob.OBMol()
    # obConversion.ReadFile(ob_mol, str(ligand_file))
    # obConversion.WriteFile(ob_mol, str(ligand_mol2))
    #
    # # Featurize ligand
    # mol = Chem.MolFromMol2File(str(ligand_mol2), removeHs=False)
    # if mol is None:
    #   return None, None
    # # Default for CircularFingerprint
    # n_ligand_features = 1024
    # ligand_features = self.ligand_featurizer.featurize([mol])
    #
    # # Featurize pocket
    # pockets, pocket_atoms_map, pocket_coords = self.convex_finder.find_pockets(
    #     protein_file, ligand_file)
    # n_pockets = len(pockets)
    # n_pocket_features = BindingPocketFeaturizer.n_features
    #
    # features = np.zeros((n_pockets, n_pocket_features+n_ligand_features))
    # pocket_features = self.pocket_featurizer.featurize(
    #     protein_file, pockets, pocket_atoms_map, pocket_coords)
    # # Note broadcast operation
    # features[:, :n_pocket_features] = pocket_features
    # features[:, n_pocket_features:] = ligand_features
    # dataset = NumpyDataset(X=features)
    # pocket_preds = self.model.predict(dataset)
    # pocket_pred_proba = np.squeeze(self.model.predict_proba(dataset))
    #
    # # Find pockets which are active
    # active_pockets = []
    # active_pocket_atoms_map = {}
    # active_pocket_coords = []
    # for pocket_ind in range(len(pockets)):
    #   #################################################### DEBUG
    #   # TODO(rbharath): For now, using a weak cutoff. Fix later.
    #   #if pocket_preds[pocket_ind] == 1:
    #   if pocket_pred_proba[pocket_ind][1] > .15:
    #   #################################################### DEBUG
    #     pocket = pockets[pocket_ind]
    #     active_pockets.append(pocket)
    #     active_pocket_atoms_map[pocket] = pocket_atoms_map[pocket]
    #     active_pocket_coords.append(pocket_coords[pocket_ind])
    # return active_pockets, active_pocket_atoms_map, active_pocket_coords
    # # TODO(LESWING)
    raise NotImplementedError("find_pockets is not yet implemented.")
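
The commented-out body of find_pockets in both versions sketches the intended pipeline: featurize every candidate pocket and the ligand, concatenate the two feature blocks column-wise, score the rows with the reloaded model, and keep pockets whose active-class probability clears a cutoff. The following is a distilled sketch of just the feature-assembly and filtering step, with placeholder arrays standing in for the featurizer and model outputs (the ligand width and the 0.15 cutoff come from the commented code; nothing here calls DeepChem):

import numpy as np

# Placeholder outputs standing in for the pocket/ligand featurizers.
n_pockets = 4
n_pocket_features = 16    # arbitrary stand-in for BindingPocketFeaturizer.n_features
n_ligand_features = 1024  # CircularFingerprint size used in the snippet
pocket_features = np.random.rand(n_pockets, n_pocket_features)
ligand_features = np.random.rand(1, n_ligand_features)

# One row per pocket: pocket features first, then the (shared) ligand fingerprint.
features = np.zeros((n_pockets, n_pocket_features + n_ligand_features))
features[:, :n_pocket_features] = pocket_features
features[:, n_pocket_features:] = ligand_features  # broadcast across all rows

# Placeholder class probabilities standing in for self.model.predict_proba(...).
pocket_pred_proba = np.random.rand(n_pockets, 2)

# Keep pockets whose active-class probability clears the weak 0.15 cutoff.
active_indices = [i for i in range(n_pockets)
                  if pocket_pred_proba[i][1] > 0.15]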