Ejemplo n.º 1
0
  def test_process(self):
    """Relaxing model_output.pdb lowers the energy and preserves the protein.

    Also checks the diagnostics returned by `process` and that no residue is
    reported as violating after relaxation.
    """
    relaxer = relax.AmberRelaxation(**self.test_config)

    model_path = os.path.join(self.test_dir, 'model_output.pdb')
    with open(model_path) as pdb_file:
      original = protein.from_pdb_string(pdb_file.read())
    pdb_min, debug_info, num_violations = relaxer.process(prot=original)

    # The debug dictionary must expose exactly these diagnostics.
    expected_keys = {'initial_energy', 'final_energy', 'attempts', 'rmsd'}
    self.assertCountEqual(debug_info.keys(), expected_keys)
    self.assertLess(debug_info['final_energy'], debug_info['initial_energy'])
    self.assertGreater(debug_info['rmsd'], 0)

    relaxed = protein.from_pdb_string(pdb_min)
    # Sequence and residue numbering must survive relaxation unchanged.
    np.testing.assert_almost_equal(original.aatype, relaxed.aatype)
    np.testing.assert_almost_equal(original.residue_index,
                                   relaxed.residue_index)
    # Atom mask and B-factors must agree everywhere except the terminal OXT
    # of the last residue: compare all residues but the last, then all atom
    # columns but the last.
    for region in (np.s_[:-1, :], np.s_[:, :-1]):
      np.testing.assert_almost_equal(original.atom_mask[region],
                                     relaxed.atom_mask[region])
      np.testing.assert_almost_equal(original.b_factors[region],
                                     relaxed.b_factors[region])
    # No residue may be reported as violating.
    np.testing.assert_equal(num_violations, np.zeros_like(num_violations))
Ejemplo n.º 2
0
    def test_to_pdb(self):
        """Round-tripping chain A of 2rbg.pdb through `to_pdb` keeps all fields."""
        source_path = os.path.join(absltest.get_default_test_srcdir(),
                                   TEST_DATA_DIR, '2rbg.pdb')
        with open(source_path) as handle:
            original_pdb = handle.read()
        original = protein.from_pdb_string(original_pdb, chain_id='A')
        round_tripped = protein.from_pdb_string(protein.to_pdb(original))

        # Integer-valued fields must match exactly; float-valued fields only
        # up to the precision written to the PDB text.
        exact = np.testing.assert_array_equal
        approx = np.testing.assert_array_almost_equal
        field_checks = (
            (exact, 'aatype'),
            (approx, 'atom_positions'),
            (approx, 'atom_mask'),
            (exact, 'residue_index'),
            (approx, 'b_factors'),
        )
        for check, field in field_checks:
            check(getattr(round_tripped, field), getattr(original, field))
Ejemplo n.º 3
0
 def test_from_pdb_str(self, pdb_file, chain_id, num_res):
     """Parsing `pdb_file` for `chain_id` gives `num_res` in-range residues."""
     full_path = os.path.join(
         absltest.get_default_test_srcdir(), TEST_DATA_DIR, pdb_file)
     with open(full_path) as handle:
         contents = handle.read()
     parsed = protein.from_pdb_string(contents, chain_id)
     self._check_shapes(parsed, num_res)
     self.assertGreaterEqual(parsed.aatype.min(), 0)
     # Unknown residue types are encoded as restype_num, so equality with the
     # upper bound is permitted.
     upper = residue_constants.restype_num
     self.assertLessEqual(parsed.aatype.max(), upper)
Ejemplo n.º 4
0
 def test_unresolved_violations(self):
   """Relaxing a structure with known violations must not introduce new ones."""
   relaxer = relax.AmberRelaxation(**self.test_config)
   pdb_path = os.path.join(self.test_dir, 'with_violations_casp14.pdb')
   with open(pdb_path) as handle:
     start_prot = protein.from_pdb_string(handle.read())
   _, _, remaining = relaxer.process(prot=start_prot)
   # Per-residue upper bound on the violations reported after relaxation.
   upper_bound = np.array(
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
        0, 0, 0, 0])
   # Relaxation is stochastic, so only require that no violations were added
   # rather than an exact match.
   self.assertTrue((remaining <= upper_bound).all())
Ejemplo n.º 5
0
 def test_ideal_atom_mask(self):
     """Only residue 102 and residues 127-284 of 2rbg chain A are non-ideal."""
     pdb_path = os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
                             '2rbg.pdb')
     with open(pdb_path) as handle:
         contents = handle.read()
     parsed = protein.from_pdb_string(contents, chain_id='A')
     ideal = protein.ideal_atom_mask(parsed)
     # Residue numbers (as stored in residue_index) whose observed atom mask
     # is expected to differ from the ideal mask.
     non_ideal = {102, *range(127, 285)}
     for row, (res_num, observed) in enumerate(
             zip(parsed.residue_index, parsed.atom_mask)):
         matches = np.all(observed == ideal[row])
         expect = self.assertFalse if res_num in non_ideal else self.assertTrue
         expect(matches, msg=f'{res_num}')
Ejemplo n.º 6
0
  def test_overwrite_b_factors(self):
    """`overwrite_b_factors` rewrites only the B-factor field of ATOM lines.

    Residue i is assigned B-factor i in all 37 atom slots. The test then checks
    two things. First, every ATOM line keeps its text outside the B-factor
    columns. Second, the parsed B-factors equal the requested ones wherever
    an atom is present.
    """
    pdb_path = os.path.join(
        absltest.get_default_test_srcdir(),
        'alphafold/relax/testdata/'
        'multiple_disulfides_target.pdb')
    with open(pdb_path) as f:
      test_pdb = f.read()
    n_residues = 191
    # Shape (n_residues, 37): row i is filled with the value i.
    bfactors = np.stack([np.arange(0, n_residues)] * 37, axis=-1)

    output_pdb = utils.overwrite_b_factors(test_pdb, bfactors)

    # Check that the atom lines are unchanged apart from the B-factors, which
    # sit in the 0-based slice [60:66] of an ATOM record.
    atom_lines_original = [
        line for line in test_pdb.split('\n') if line.startswith('ATOM')]
    atom_lines_new = [
        line for line in output_pdb.split('\n') if line.startswith('ATOM')]
    # zip() stops at the shorter list, so first make sure no ATOM lines were
    # dropped or added; otherwise the pairwise check could pass vacuously.
    self.assertEqual(len(atom_lines_new), len(atom_lines_original))
    for line_original, line_new in zip(atom_lines_original, atom_lines_new):
      self.assertEqual(line_original[:60].strip(), line_new[:60].strip())
      self.assertEqual(line_original[66:].strip(), line_new[66:].strip())

    # Check B-factors are correctly set for all atoms present.
    as_protein = protein.from_pdb_string(output_pdb)
    np.testing.assert_almost_equal(
        np.where(as_protein.atom_mask > 0, as_protein.b_factors, 0),
        np.where(as_protein.atom_mask > 0, bfactors, 0))
Ejemplo n.º 7
0
def _load_test_protein(data_path):
    """Read a PDB file from the test source directory and parse it.

    Args:
        data_path: Path of the PDB file, joined onto
            ``absltest.get_default_test_srcdir()``.

    Returns:
        The value of ``protein.from_pdb_string`` for the file's contents.
    """
    srcdir = absltest.get_default_test_srcdir()
    with open(os.path.join(srcdir, data_path), 'r') as pdb_file:
        pdb_text = pdb_file.read()
    return protein.from_pdb_string(pdb_text)
Ejemplo n.º 8
0
def run_pipeline(
    prot: protein.Protein,
    stiffness: float,
    max_outer_iterations: int = 1,
    place_hydrogens_every_iteration: bool = True,
    max_iterations: int = 0,
    tolerance: float = 2.39,
    restraint_set: str = "non_hydrogen",
    max_attempts: int = 100,
    checks: bool = True,
    exclude_residues: Optional[Sequence[int]] = None):
  """Run iterative amber relax.

  Successive relax iterations are performed until no violations remain or
  `max_outer_iterations` iterations have run. Each iteration involves a
  restrained Amber minimization, with restraint exclusions determined by
  violation-participating residues. Residues found in violation after an
  iteration are excluded from restraints in every later iteration.

  Args:
    prot: A protein to be relaxed.
    stiffness: kcal/mol A**2, the restraint stiffness.
    max_outer_iterations: The maximum number of outer minimization
        iterations. Must be at least 1.
    place_hydrogens_every_iteration: Whether hydrogens are re-initialized
        prior to every minimization. This is done by re-running
        `clean_protein` on the minimized structure.
    max_iterations: An `int` specifying the maximum number of L-BFGS steps
        per relax iteration. A value of 0 specifies no limit.
    tolerance: kcal/mol, the energy tolerance of L-BFGS.
        The default value is the OpenMM default.
    restraint_set: The set of atoms to restrain.
    max_attempts: The maximum number of minimization attempts per iteration.
    checks: Whether to perform cleaning checks when initially cleaning `prot`.
    exclude_residues: An optional collection of zero-indexed residues to
        exclude from restraints.

  Returns:
    out: The dictionary from the last iteration that ran. It is the result of
      `_run_one_iteration` (which includes "min_pdb", "einit", "efinal" and
      "opt_time"). That result is updated with the metrics from
      `get_violation_metrics` (which include "violations_per_residue",
      "residue_violations" and "num_residue_violations"). It also carries
      "num_exclusions" (residues excluded from restraints during that
      iteration) and "iteration" (its zero-based index).

  Raises:
    ValueError: If `max_outer_iterations` is less than 1. In that case no
      iteration would run and there would be no result to return.
  """
  if max_outer_iterations < 1:
    # Without at least one iteration `ret` would never be assigned and the
    # final `return ret` would fail with an UnboundLocalError.
    raise ValueError("max_outer_iterations must be at least 1, got "
                     f"{max_outer_iterations}.")

  # `protein.to_pdb` will strip any poorly-defined residues so we need to
  # perform this check before `clean_protein`.
  _check_residues_are_well_defined(prot)
  pdb_string = clean_protein(prot, checks=checks)

  # Build the set explicitly: `exclude_residues or []` would raise for a
  # multi-element NumPy array, whose truth value is ambiguous.
  exclude_residues = (
      set() if exclude_residues is None else set(exclude_residues))
  # Start at infinity so the loop body always runs at least once.
  violations = np.inf
  iteration = 0

  while violations > 0 and iteration < max_outer_iterations:
    ret = _run_one_iteration(
        pdb_string=pdb_string,
        exclude_residues=exclude_residues,
        max_iterations=max_iterations,
        tolerance=tolerance,
        stiffness=stiffness,
        restraint_set=restraint_set,
        max_attempts=max_attempts)
    prot = protein.from_pdb_string(ret["min_pdb"])
    if place_hydrogens_every_iteration:
      # NOTE(review): re-cleaning always uses checks=True and ignores the
      # `checks` argument; confirm this is intended.
      pdb_string = clean_protein(prot, checks=True)
    else:
      pdb_string = ret["min_pdb"]
    ret.update(get_violation_metrics(prot))
    ret.update({
        # Exclusions in effect for this iteration, i.e. counted before the
        # residues it found in violation are added below.
        "num_exclusions": len(exclude_residues),
        "iteration": iteration,
    })
    violations = ret["violations_per_residue"]
    exclude_residues = exclude_residues.union(ret["residue_violations"])

    logging.info("Iteration completed: Einit %.2f Efinal %.2f Time %.2f s "
                 "num residue violations %d num residue exclusions %d ",
                 ret["einit"], ret["efinal"], ret["opt_time"],
                 ret["num_residue_violations"], ret["num_exclusions"])
    iteration += 1
  return ret