예제 #1
0
    def ForwardMode(self, request, context):
        """Run a forward simulation of a pickled system and reply with its
        du/dl traces, energies and a strided subset of frames.

        Args:
            request: request message. Fields read here: ``precision``
                ('single' or 'double'), ``system`` (pickled system object),
                ``n_frames`` (<= 0 means no frames; otherwise frames are taken
                at a fixed stride, so roughly ``n_frames`` of them are
                returned) and ``inference`` (when False, the context and
                related objects are kept in ``self.state`` for a later
                backward pass).
            context: servicer context; not used by this method.

        Returns:
            service_pb2.ForwardReply whose ``du_dls``, ``energies`` and
            ``frames`` fields hold pickled arrays.

        Raises:
            Exception: if ``request.precision`` is not 'single' or 'double'.
        """
        # self.state must be unset here because it is populated below for a
        # backward pass; presumably it is cleared elsewhere — confirm.
        assert self.state is None

        if request.precision == 'single':
            precision = np.float32
        elif request.precision == 'double':
            precision = np.float64
        else:
            raise Exception("Unknown precision")

        # NOTE(review): pickle.loads executes arbitrary code embedded in the
        # payload; this is only safe if every client of this endpoint is trusted.
        system = pickle.loads(request.system)

        gradients = []
        force_names = []
        for grad_name, grad_args in system.gradients:
            force_names.append(grad_name)
            op_fn = getattr(ops, grad_name)
            gradients.append(op_fn(*grad_args, precision=precision))

        integrator = system.integrator

        # The stepper and context use the _f64 variants regardless of
        # `precision`, which only applies to the individual gradient ops.
        stepper = custom_ops.AlchemicalStepper_f64(gradients, integrator.lambs)

        ctxt = custom_ops.ReversibleContext_f64(stepper, system.x0, system.v0,
                                                integrator.cas, integrator.cbs,
                                                integrator.ccs, integrator.dts,
                                                integrator.seed)

        ctxt.forward_mode()

        full_du_dls = stepper.get_du_dl()  # [FxT]
        energies = stepper.get_energies()

        if request.n_frames > 0:
            xs = ctxt.get_all_coords()
            # Keep every `interval`-th frame, starting with frame 0. Previously
            # the indices were computed but an empty array was always sent.
            interval = max(1, xs.shape[0] // request.n_frames)
            frames = xs[::interval]
        else:
            frames = np.zeros((0, *system.x0.shape), dtype=system.x0.dtype)

        reply = service_pb2.ForwardReply(du_dls=pickle.dumps(full_du_dls),
                                         energies=pickle.dumps(energies),
                                         frames=pickle.dumps(frames))

        # store and set state for backwards mode use.
        if not request.inference:
            self.state = (ctxt, gradients, force_names, stepper, system)

        return reply
예제 #2
0
    def test_reverse_mode_lambda(self):
        """
        Reverse-mode differentiate a loss built from the per-step
        dU/dlambda values of a short trajectory and compare it against a
        pure-JAX reference.

        The reference integrator is differentiated with ``jax.grad`` to get
        dL/dp for the bonded and restraint parameters. The same trajectory
        is then run forward through the custom-ops stepper. Adjoints of the
        loss with respect to each computed dU/dlambda are fed back, and the
        backward pass must reproduce the reference parameter gradients.
        """

        np.random.seed(4321)

        N = 5
        B = 5
        A = 0
        T = 0
        D = 3

        x0 = np.random.rand(N, D).astype(dtype=np.float64) * 2

        precision = np.float64

        (bond_params,
         ref_bond), test_bond = prepare_bonded_system(x0, B, A, T, precision)

        (restr_params,
         ref_restr), test_restr = prepare_restraints(x0, B, precision)

        E = 2

        lambda_plane_idxs = np.random.randint(low=0,
                                              high=2,
                                              size=N,
                                              dtype=np.int32)
        lambda_offset_idxs = np.random.randint(low=0,
                                               high=2,
                                               size=N,
                                               dtype=np.int32)

        (charge_params,
         lj_params), ref_nb_fn, test_nb_ctor = prepare_nonbonded_system(
             x0,
             E,
             lambda_plane_idxs,
             lambda_offset_idxs,
             p_scale=10.0,
             cutoff=1000.0,
             precision=precision)

        test_nb = test_nb_ctor()

        masses = np.random.rand(N)

        # Draws the same N*D values from the random stream as
        # rand(x0.shape[0], x0.shape[1]).
        v0 = np.random.rand(*x0.shape)

        num_steps = 5
        lambda_schedule = np.random.rand(num_steps)
        cas = np.random.rand(num_steps)
        cbs = np.random.rand(len(masses)) / 10
        ccs = np.zeros_like(cbs)

        step_sizes = np.random.rand(num_steps)

        def loss_fn(du_dls):
            # du_dls: one total dU/dlambda per step (reference path).
            return jnp.sum(du_dls * du_dls) / du_dls.shape[0]

        def sum_loss_fn(du_dls):
            # du_dls: per-force dU/dlambda traces (custom-ops path). Summing
            # over axis 0 reduces this to loss_fn on the per-step totals.
            # jnp rather than np because jax.grad differentiates this function.
            du_dls = jnp.sum(du_dls, axis=0)
            return jnp.sum(du_dls * du_dls) / du_dls.shape[0]

        def integrate_once_through(x_t, v_t, bond_params, restr_params,
                                   charge_params, lj_params):
            """Reference JAX integrator; returns loss_fn of the per-step du/dls."""

            ref_bond_impl = functools.partial(ref_bond, params=bond_params)
            ref_restr_impl = functools.partial(ref_restr, params=restr_params)
            ref_nb_impl = functools.partial(ref_nb_fn,
                                            charge_params=charge_params,
                                            lj_params=lj_params)

            def ref_total_nrg_fn(*args):
                nrgs = []
                for fn in [ref_bond_impl, ref_restr_impl, ref_nb_impl]:
                    nrgs.append(fn(*args))
                return jnp.sum(nrgs)

            dU_dx_fn = jax.grad(ref_total_nrg_fn, argnums=(0, ))
            dU_dl_fn = jax.grad(ref_total_nrg_fn, argnums=(1, ))

            # Loop-invariant: per-particle cbs broadcast over the D axis.
            cb_tmp = np.expand_dims(cbs, axis=-1)

            all_du_dls = []
            for step in range(num_steps):
                lamb = lambda_schedule[step]
                # du/dl is recorded before the update, so the du/dl of the
                # last frame is never calculated.
                du_dl = dU_dl_fn(x_t, lamb)[0]
                all_du_dls.append(du_dl)
                dt = step_sizes[step]
                v_t = cas[step] * v_t + cb_tmp * dU_dx_fn(x_t, lamb)[0]
                x_t = x_t + v_t * dt

            all_du_dls = jnp.stack(all_du_dls)
            return loss_fn(all_du_dls)

        # Reference loss, and its gradients w.r.t. bond_params (arg 2) and
        # restr_params (arg 3).
        ref_loss = integrate_once_through(x0, v0, bond_params, restr_params,
                                          charge_params, lj_params)

        grad_fn = jax.grad(integrate_once_through, argnums=(2, 3))
        ref_dl_dp_bond, ref_dl_dp_restr = grad_fn(x0, v0, bond_params,
                                                  restr_params, charge_params,
                                                  lj_params)

        stepper = custom_ops.AlchemicalStepper_f64(
            [test_bond, test_restr, test_nb], lambda_schedule)

        seed = 1234

        ctxt = custom_ops.ReversibleContext_f64(stepper, x0, v0, cas, cbs, ccs,
                                                step_sizes, seed)

        # run num_steps forward
        ctxt.forward_mode()
        test_du_dls = stepper.get_du_dl()
        test_loss = sum_loss_fn(test_du_dls)
        loss_grad_fn = jax.grad(sum_loss_fn, argnums=(0, ))
        du_dl_adjoint = loss_grad_fn(test_du_dls)[0]

        # limit of precision is due to the settings in fixed_point.hpp
        np.testing.assert_allclose(test_loss, ref_loss, rtol=1e-6)

        # Seed the backward pass with dL/d(du_dl). The final coordinates get a
        # zero adjoint because the loss depends only on the du_dls.
        stepper.set_du_dl_adjoint(du_dl_adjoint)
        ctxt.set_x_t_adjoint(np.zeros_like(x0))
        ctxt.backward_mode()

        test_dl_dp = test_bond.get_du_dp_tangents()
        np.testing.assert_allclose(test_dl_dp, ref_dl_dp_bond, rtol=1e-6)

        test_dl_dp = test_restr.get_du_dp_tangents()
        np.testing.assert_allclose(test_dl_dp, ref_dl_dp_restr, rtol=1e-6)
예제 #3
0
    def ForwardMode(self, request, context):
        """Run a forward simulation of a pickled system and reply with its
        du/dl traces, energies and a strided subset of frames.

        Args:
            request: request message. Fields read here: ``precision``
                ('single' or 'double'), ``system`` (pickled system object),
                ``n_frames`` (<= 0 means no frames; otherwise frames are taken
                at a fixed stride, so roughly ``n_frames`` of them are
                returned), ``inference`` and ``key`` (when ``inference`` is
                False the context and related objects are stored in
                ``self.states[key]`` for a later backward pass).
            context: servicer context; not used by this method.

        Returns:
            service_pb2.ForwardReply with pickled ``du_dls`` (one entry per
            force; ``None`` for a force whose du/dl is zero at every step),
            ``energies`` and ``frames``.

        Raises:
            Exception: if ``request.precision`` is not 'single' or 'double'.
        """
        if request.precision == 'single':
            precision = np.float32
        elif request.precision == 'double':
            precision = np.float64
        else:
            raise Exception("Unknown precision")

        # NOTE(review): pickle.loads executes arbitrary code embedded in the
        # payload; this is only safe if every client of this endpoint is trusted.
        system = pickle.loads(request.system)

        gradients = []
        force_names = []
        for grad_name, grad_args in system.gradients:
            force_names.append(grad_name)
            op_fn = getattr(ops, grad_name)
            gradients.append(op_fn(*grad_args, precision=precision))

        integrator = system.integrator

        # The stepper and context use the _f64 variants regardless of
        # `precision`, which only applies to the individual gradient ops.
        stepper = custom_ops.AlchemicalStepper_f64(gradients, integrator.lambs)

        ctxt = custom_ops.ReversibleContext_f64(stepper, system.x0, system.v0,
                                                integrator.cas, integrator.cbs,
                                                integrator.ccs, integrator.dts,
                                                integrator.seed)

        # ensure only one GPU can be running at given time.
        with self.mutex:

            ctxt.forward_mode()
            full_du_dls = stepper.get_du_dl()  # [FxT]
            energies = stepper.get_energies()

            # Replace a force's du/dl trace with None only when it is zero at
            # every step. np.any is the correct test: `np.all(x) == 0` is also
            # true when only *some* steps are zero, which discarded real data.
            stripped_du_dls = [
                force_du_dls if np.any(force_du_dls) else None
                for force_du_dls in full_du_dls
            ]

            if request.n_frames > 0:
                xs = ctxt.get_all_coords()
                # Keep every `interval`-th frame, starting with frame 0.
                interval = max(1, xs.shape[0] // request.n_frames)
                frames = xs[::interval]
            else:
                frames = np.zeros((0, *system.x0.shape), dtype=system.x0.dtype)

            # store and set state for backwards mode use.
            if not request.inference:
                self.states[request.key] = (ctxt, gradients, force_names,
                                            stepper, system)

            return service_pb2.ForwardReply(
                du_dls=pickle.dumps(stripped_du_dls),
                energies=pickle.dumps(energies),
                frames=pickle.dumps(frames),
            )
예제 #4
0
    # Custom lambda schedule: ramp lambda linearly from 1.0 down to 0.0 over
    # the first `lowering_steps` steps, then hold it at 0.0 for the remaining
    # n_steps - lowering_steps steps.
    # NOTE(review): assumes n_steps >= lowering_steps; otherwise np.zeros gets
    # a negative size and raises ValueError.
    lowering_steps = 10000

    new_lambda_schedule = np.concatenate([
        np.linspace(1.0, 0.0, lowering_steps),
        np.zeros(n_steps - lowering_steps)
    ])

    # The custom schedule is used in place of the integrator's own schedule
    # (integrator.lambs, left commented out in the call below).
    stepper = custom_ops.AlchemicalStepper_f64(gradients, new_lambda_schedule
                                               # integrator.lambs
                                               )

    # Zero initial velocities, same shape as x0.
    v0 = np.zeros_like(x0)

    ctxt = custom_ops.ReversibleContext_f64(stepper, x0, v0, integrator.cas,
                                            integrator.cbs, integrator.ccs,
                                            integrator.dts, integrator.seed)

    # The writer receives the combined molecule's PDB block as an in-memory
    # text stream, together with the output path.
    combined_pdb_str = StringIO(Chem.MolToPDBBlock(combined_pdb))
    out_file = "pose_dock.pdb"

    pdb_writer = PDBWriter(combined_pdb_str, out_file)

    # NOTE(review): the trajectory-writing code below is disabled. Within this
    # excerpt, ctxt.forward_mode() is never called and pdb_writer is never
    # closed; confirm whether the enclosing function does so further on.
    # The x*10 factor presumably converts nm to Angstrom — confirm.
    # pdb_writer.write_header()

    # frames = ctxt.get_all_coords()
    # for frame_idx, x in enumerate(frames):
    #     # if frame_idx % 100 == 0:
    #     pdb_writer.write(x*10)
    #     break
    # pdb_writer.close()