示例#1
0
  def test_process(self):
    """Amber relaxation lowers the energy and preserves residue identity.

    Relaxes `model_output.pdb` from `self.test_dir`, then checks the debug
    info returned by `process`, the relaxed structure, and that no residue is
    left with a violation.
    """
    relaxer = relax.AmberRelaxation(**self.test_config)

    with open(os.path.join(self.test_dir, 'model_output.pdb')) as handle:
      original = protein.from_pdb_string(handle.read())
    relaxed_pdb, info, violations = relaxer.process(prot=original)

    self.assertCountEqual(
        info.keys(), {'initial_energy', 'final_energy', 'attempts', 'rmsd'})
    self.assertLess(info['final_energy'], info['initial_energy'])
    self.assertGreater(info['rmsd'], 0)

    relaxed = protein.from_pdb_string(relaxed_pdb)
    # Sequence and residue numbering must be untouched by relaxation.
    np.testing.assert_almost_equal(original.aatype, relaxed.aatype)
    np.testing.assert_almost_equal(original.residue_index,
                                   relaxed.residue_index)
    # Compare atom mask and B-factors first without the last residue, then
    # without the last atom slot. Between the two regions every entry except
    # [-1, -1] (the terminal OXT of the last residue) is checked.
    for region in (np.s_[:-1, :], np.s_[:, :-1]):
      np.testing.assert_almost_equal(original.atom_mask[region],
                                     relaxed.atom_mask[region])
      np.testing.assert_almost_equal(original.b_factors[region],
                                     relaxed.b_factors[region])
    # No residue may remain in violation after relaxation.
    np.testing.assert_equal(violations, np.zeros_like(violations))
示例#2
0
  def test_overwrite_b_factors(self):
    """Overwriting B-factors changes nothing but the B-factor field.

    Writes per-residue B-factors (residue i gets the value i in all 37 atom
    slots) into `multiple_disulfides_target.pdb`. It then checks that the rest
    of every ATOM record is preserved and that parsing the result yields those
    values for every atom that is present.
    """
    pdb_path = os.path.join(
        absltest.get_default_test_srcdir(),
        'alphafold/relax/testdata/'
        'multiple_disulfides_target.pdb')
    with open(pdb_path) as handle:
      original_pdb = handle.read()
    num_res = 191
    per_residue = np.arange(0, num_res)
    bfactors = np.stack([per_residue] * 37, axis=-1)

    rewritten_pdb = utils.overwrite_b_factors(original_pdb, bfactors)

    def atom_records(pdb_text):
      """Returns the lines of `pdb_text` that start with 'ATOM'."""
      return [line for line in pdb_text.split('\n') if line.startswith('ATOM')]

    # Characters [60:66] (the B-factor field) are skipped. Everything before
    # and after must match, ignoring surrounding whitespace.
    paired_records = zip(atom_records(original_pdb),
                         atom_records(rewritten_pdb))
    for before, after in paired_records:
      self.assertEqual(before[:60].strip(), after[:60].strip())
      self.assertEqual(before[66:].strip(), after[66:].strip())

    # Only atoms actually present in the parsed output are compared.
    parsed = protein.from_pdb_string(rewritten_pdb)
    present = parsed.atom_mask > 0
    np.testing.assert_almost_equal(
        np.where(present, parsed.b_factors, 0),
        np.where(present, bfactors, 0))
示例#3
0
 def process(self, *,
             prot: protein.Protein) -> Tuple[str, Dict[str, Any], np.ndarray]:
     """Runs Amber relax on a prediction and returns the relaxed PDB string.

     Args:
       prot: The predicted protein to relax.

     Returns:
       A tuple `(relaxed_pdb, debug_data, violations)`:
         relaxed_pdb: the PDB text from `amber_minimize.clean_protein(prot)`
           with the minimized coordinates and the input's B-factors written
           into it.
         debug_data: dict with 'initial_energy', 'final_energy', 'attempts'
           and 'rmsd' (root-mean-square displacement between the starting
           and minimized positions).
         violations: the pipeline's 'total_per_residue_violations_mask'.
     """
     result = amber_minimize.run_pipeline(
         prot=prot,
         max_iterations=self._max_iterations,
         tolerance=self._tolerance,
         stiffness=self._stiffness,
         exclude_residues=self._exclude_residues,
         max_outer_iterations=self._max_outer_iterations)

     relaxed_positions = result['pos']
     initial_positions = result['posinit']
     # Normalized by the size of the first axis. This presumes positions are
     # laid out as (num_atoms, 3), so this is the per-atom RMSD; TODO confirm.
     total_sq_shift = np.sum((initial_positions - relaxed_positions)**2)
     rmsd = np.sqrt(total_sq_shift / initial_positions.shape[0])
     debug_data = dict(
         initial_energy=result['einit'],
         final_energy=result['efinal'],
         attempts=result['min_attempts'],
         rmsd=rmsd,
     )

     cleaned_pdb = amber_minimize.clean_protein(prot)
     relaxed_pdb = utils.overwrite_pdb_coordinates(cleaned_pdb,
                                                   relaxed_positions)
     relaxed_pdb = utils.overwrite_b_factors(relaxed_pdb, prot.b_factors)
     # Sanity check that relaxation kept the atom set of the input. Per the
     # helper's name, terminal atoms are exempt from this comparison.
     relaxed_mask = protein.from_pdb_string(relaxed_pdb).atom_mask
     utils.assert_equal_nonterminal_atom_types(relaxed_mask, prot.atom_mask)

     per_residue_violations = result['structural_violations'][
         'total_per_residue_violations_mask']
     return relaxed_pdb, debug_data, per_residue_violations
示例#4
0
    def test_to_pdb(self):
        """Chain A of 2rbg round-trips through `to_pdb`/`from_pdb_string`."""
        pdb_path = os.path.join(absltest.get_default_test_srcdir(),
                                TEST_DATA_DIR, '2rbg.pdb')
        with open(pdb_path) as handle:
            pdb_text = handle.read()
        original = protein.from_pdb_string(pdb_text, chain_id='A')
        round_tripped = protein.from_pdb_string(protein.to_pdb(original))

        # aatype and residue_index must match exactly. The remaining fields
        # are compared with NumPy's almost-equal check.
        field_checks = (
            ('aatype', np.testing.assert_array_equal),
            ('atom_positions', np.testing.assert_array_almost_equal),
            ('atom_mask', np.testing.assert_array_almost_equal),
            ('residue_index', np.testing.assert_array_equal),
            ('b_factors', np.testing.assert_array_almost_equal),
        )
        for field, check in field_checks:
            check(getattr(round_tripped, field), getattr(original, field))
示例#5
0
 def test_from_pdb_str(self, pdb_file, chain_id, num_res):
     """Parsing `pdb_file` gives `num_res` residues with valid restype ids.

     Args:
       pdb_file: File name inside TEST_DATA_DIR to parse.
       chain_id: Chain to extract, forwarded to `protein.from_pdb_string`.
       num_res: Expected number of residues, checked by `_check_shapes`.
     """
     full_path = os.path.join(absltest.get_default_test_srcdir(),
                              TEST_DATA_DIR, pdb_file)
     with open(full_path) as handle:
         pdb_text = handle.read()
     parsed = protein.from_pdb_string(pdb_text, chain_id)
     self._check_shapes(parsed, num_res)
     # Restype indices must lie in [0, restype_num]. The upper bound is
     # inclusive because unknown restypes have index equal to restype_num.
     self.assertGreaterEqual(parsed.aatype.min(), 0)
     self.assertLessEqual(parsed.aatype.max(), residue_constants.restype_num)
示例#6
0
 def get_label(self, protein):
     """Builds the label for `protein` from its gzip-compressed PDB file.

     Args:
       protein: Identifier handed to `self._get_protein_struct_file` to find
         the structure file. Note that inside this method this name shadows
         any module-level `protein` import.

     Returns:
       The result of `generate_label` applied to the chain data produced by
       `load_pdb_chain(..., confidence_threshold=0.5)`.
     """
     struct_path = self._get_protein_struct_file(protein)
     with gzip.open(struct_path, 'r') as handle:
         pdb_text = handle.read().decode('utf8')
     # chain_id is None, so the whole structure is parsed. Deriving a chain
     # from a '<name>_<chain>' style identifier was previously attempted and
     # is disabled.
     parsed = protein_utils.from_pdb_string(pdb_text, None)
     chain_data = load_pdb_chain(parsed, confidence_threshold=0.5)
     return generate_label(chain_data)
示例#7
0
 def test_unresolved_violations(self):
   """Relaxation never adds violations to `with_violations_casp14.pdb`.

   Relaxation is stochastic, so the exact remaining violations cannot be
   pinned down. Instead, every residue's violation flag must be no greater
   than its flag in the recorded reference mask.
   """
   relaxer = relax.AmberRelaxation(**self.test_config)
   pdb_path = os.path.join(self.test_dir, 'with_violations_casp14.pdb')
   with open(pdb_path) as handle:
     structure = protein.from_pdb_string(handle.read())
   _, _, violations = relaxer.process(prot=structure)
   reference = np.array(
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
        0, 0, 0, 0])
   no_new_violations = np.all(violations <= reference)
   self.assertTrue(no_new_violations)
示例#8
0
 def test_ideal_atom_mask(self):
     """Only the listed residues of 2rbg chain A deviate from the ideal mask.

     Residue 102 and residues 127-284 (by `residue_index`) must differ from
     `protein.ideal_atom_mask`. Every other residue must match it exactly.
     """
     pdb_path = os.path.join(absltest.get_default_test_srcdir(),
                             TEST_DATA_DIR, '2rbg.pdb')
     with open(pdb_path) as handle:
         pdb_text = handle.read()
     prot = protein.from_pdb_string(pdb_text, chain_id='A')
     ideal_mask = protein.ideal_atom_mask(prot)
     non_ideal_residues = {102, *range(127, 285)}
     residues_and_masks = zip(prot.residue_index, prot.atom_mask)
     for idx, (res, atom_mask) in enumerate(residues_and_masks):
         matches_ideal = np.all(atom_mask == ideal_mask[idx])
         expect = (self.assertFalse if res in non_ideal_residues
                   else self.assertTrue)
         expect(matches_ideal, msg=f'{res}')
def run_pipeline(prot: protein.Protein,
                 stiffness: float,
                 max_outer_iterations: int = 1,
                 place_hydrogens_every_iteration: bool = True,
                 max_iterations: int = 0,
                 tolerance: float = 2.39,
                 restraint_set: str = "non_hydrogen",
                 max_attempts: int = 100,
                 checks: bool = True,
                 exclude_residues: Optional[Sequence[int]] = None):
    """Run iterative amber relax.

    Successive relax iterations are performed until all violations have been
    resolved or `max_outer_iterations` iterations have run. Each iteration
    involves a restrained Amber minimization, with restraint exclusions
    determined by violation-participating residues.

    Args:
        prot: A protein to be relaxed.
        stiffness: kcal/mol A**2, the restraint stiffness.
        max_outer_iterations: The maximum number of outer minimization
            iterations. Must be at least 1.
        place_hydrogens_every_iteration: Whether hydrogens are re-initialized
            prior to every minimization. If True, each minimized structure is
            passed through `clean_protein` again. Otherwise the minimized PDB
            string is fed to the next iteration as-is.
        max_iterations: An `int` specifying the maximum number of L-BFGS steps
            per relax iteration. A value of 0 specifies no limit.
        tolerance: kcal/mol, the energy tolerance of L-BFGS.
            The default value is the OpenMM default.
        restraint_set: The set of atoms to restrain.
        max_attempts: The maximum number of minimization attempts per
            iteration.
        checks: Whether to perform cleaning checks in the initial
            `clean_protein` call.
        exclude_residues: An optional list of zero-indexed residues to exclude
            from restraints.

    Returns:
        out: The result dictionary of the last iteration. This is the
            dictionary returned by `_run_one_iteration`, updated with the
            output of `get_violation_metrics` and the keys "num_exclusions"
            and "iteration".

    Raises:
        ValueError: If `max_outer_iterations` is less than 1, since then no
            iteration would run and there would be no result to return.
    """
    if max_outer_iterations < 1:
        # Without this guard the loop below never runs and `return ret`
        # fails with an UnboundLocalError.
        raise ValueError(
            f"max_outer_iterations must be >= 1, got {max_outer_iterations}.")

    # `protein.to_pdb` will strip any poorly-defined residues so we need to
    # perform this check before `clean_protein`.
    _check_residues_are_well_defined(prot)
    pdb_string = clean_protein(prot, checks=checks)

    exclude_residues = set(exclude_residues or ())
    # Sentinel so that the loop body runs at least once.
    violations = np.inf
    iteration = 0

    while violations > 0 and iteration < max_outer_iterations:
        ret = _run_one_iteration(pdb_string=pdb_string,
                                 exclude_residues=exclude_residues,
                                 max_iterations=max_iterations,
                                 tolerance=tolerance,
                                 stiffness=stiffness,
                                 restraint_set=restraint_set,
                                 max_attempts=max_attempts)
        prot = protein.from_pdb_string(ret["min_pdb"])
        if place_hydrogens_every_iteration:
            # NOTE(review): re-cleaning always uses checks=True regardless of
            # the `checks` argument. Confirm this is intentional.
            pdb_string = clean_protein(prot, checks=True)
        else:
            pdb_string = ret["min_pdb"]
        ret.update(get_violation_metrics(prot))
        ret.update({
            "num_exclusions": len(exclude_residues),
            "iteration": iteration,
        })
        # The loop continues only while this metric is positive.
        violations = ret["violations_per_residue"]
        # Residues that violated in this iteration are excluded from the
        # restraints of all subsequent iterations.
        exclude_residues = exclude_residues.union(ret["residue_violations"])

        logger.info(
            "Iteration completed: Einit %.2f Efinal %.2f Time %.2f s "
            "num residue violations %d num residue exclusions %d ",
            ret["einit"], ret["efinal"], ret["opt_time"],
            ret["num_residue_violations"], ret["num_exclusions"])
        iteration += 1
    return ret
示例#10
0
def _load_test_protein(data_path):
    """Parses a test PDB file into a protein.

    Args:
        data_path: Path joined onto `absltest.get_default_test_srcdir()` to
            locate the PDB file.

    Returns:
        The result of `protein.from_pdb_string` on the file's contents.
    """
    full_path = os.path.join(absltest.get_default_test_srcdir(), data_path)
    with open(full_path, 'r') as handle:
        contents = handle.read()
    return protein.from_pdb_string(contents)