Example #1
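Convergence phase of a BRER iteration, apparently a private method of run_brer's run-configuration class: MD work is built from the prepared TPR input, the convergence restraint plugins are attached as dependencies, the gmxapi session is run, and the time at which the run finished is stored in run_data so a production run can later be restarted from it.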
    def __converge(self, tpr_file=None, **kwargs):
        for key in ('append_output', ):
            if key in kwargs:
                raise TypeError(
                    'Conflicting keyword argument. Cannot accept {}.'.format(
                        key))

        self.__prep_input(tpr_file)

        md = from_tpr(self._tprs, append_output=False, **kwargs)
        self.build_plugins(ConvergencePluginConfig())
        for plugin in self.__plugins:
            md.add_dependency(plugin)

        workdir = os.getcwd()
        self._logger.info("=====CONVERGENCE INFO======\n")
        self._logger.info(f'Working directory: {workdir}')

        context = _context(md,
                           workdir_list=self.workdirs,
                           communicator=self._communicator)
        with context as session:
            session.run()

        # Get the absolute time (in ps) at which the convergence run finished.
        # This value will be needed if a production run needs to be restarted.
        # noinspection PyUnresolvedReferences
        self.run_data.set(start_time=context.potentials[0].time)
        for name in self.__names:
            current_alpha = self.run_data.get('alpha', name=name)
            current_target = self.run_data.get('target', name=name)
            message = f'Plugin {name}: alpha = {current_alpha}, target = {current_target}'
            self._logger.info(message)

        return context
Example #2
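A more defensive variant of the same convergence helper: it warns when no BRER restraint plugins were built and raises an explicit RuntimeError if the gmxapi Context lacks a potentials attribute after the session runs.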
    def __converge(self, tpr_file=None, **kwargs):

        for key in ('append_output', ):
            if key in kwargs:
                raise TypeError(
                    'Conflicting keyword argument. Cannot accept {}.'.format(
                        key))

        self.__prep_input(tpr_file)

        md = from_tpr(self._tprs, append_output=False, **kwargs)
        self.build_plugins(ConvergencePluginConfig())
        if len(self.__plugins) == 0:
            warnings.warn('No BRER restraints are being applied! User error?')
        for plugin in self.__plugins:
            md.add_dependency(plugin)

        workdir = os.getcwd()
        self._logger.info("=====CONVERGENCE INFO======\n")
        self._logger.info(f'Working directory: {workdir}')

        context = _context(md,
                           workdir_list=self.workdirs,
                           communicator=self._communicator)
        # WARNING: We do not yet handle situations where a rank has no work to do.
        # See https://github.com/kassonlab/run_brer/issues/18
        # and https://github.com/kassonlab/run_brer/issues/55
        with context as session:
            session.run()

        # Through at least gmxapi 0.4, the *potentials* attribute is created on
        # the Context for any Session launched with MD work to perform. An explicit
        # error message here should be more helpful than an AttributeError below,
        # but we don't really know what went wrong.
        # Ref https://github.com/kassonlab/run_brer/issues/55
        if not hasattr(context, 'potentials'):
            raise RuntimeError(
                'Invalid gmxapi Context: missing "potentials" attribute.')

        # Get the absolute time (in ps) at which the convergence run finished.
        # This value will be needed if a production run needs to be restarted.
        # noinspection PyUnresolvedReferences
        self.run_data.set(start_time=context.potentials[0].time)
        for name in self.__names:
            current_alpha = self.run_data.get('alpha', name=name)
            current_target = self.run_data.get('target', name=name)
            message = f'Plugin {name}: alpha = {current_alpha}, target = {current_target}'
            self._logger.info(message)

        return context
Example #3
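A standalone (non-MPI) test of the restraint plugin: a WorkElement from the myplugin namespace wrapping the ensemble_restraint operation is attached to MD work built from a single TPR file, and the work graph is executed in a session.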
def test_ensemble_potential_nompi(spc_water_box):
    """Test ensemble potential without an ensemble.
    """
    tpr_filename = spc_water_box
    print("Testing plugin potential with input file {}".format(
        os.path.abspath(tpr_filename)))

    assert api_is_at_least(0, 0, 5)
    md = from_tpr([tpr_filename], append_output=False)

    # Create a WorkElement for the potential
    params = {
        'sites': [1, 4],
        'nbins': 10,
        'binWidth': 0.1,
        'min_dist': 0.,
        'max_dist': 10.,
        'experimental': [1.] * 10,
        'nsamples': 1,
        'sample_period': 0.001,
        'nwindows': 4,
        'k': 10000.,
        'sigma': 1.
    }
    potential = WorkElement(namespace="myplugin",
                            operation="ensemble_restraint",
                            params=params)
    # Note that we could flexibly capture accessor methods as workflow elements, too. Maybe we can
    # hide the extra Python bindings by letting myplugin.HarmonicRestraint automatically convert
    # to a WorkElement when add_dependency is called on it.
    potential.name = "ensemble_restraint"
    md.add_dependency(potential)

    context = _context(md)

    with context as session:
        session.run()
Example #4
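Production phase of a BRER iteration: MD work is built with end_time set to the recorded start_time plus the user-requested production_time, the production restraint plugins are attached, and after the session runs the per-potential start and end times are collected (with consistency warnings), the segment length is logged, and the end time is stored in run_data.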
    def __production(self, tpr_file=None, **kwargs):

        for key in ('append_output', 'end_time'):
            if key in kwargs:
                raise TypeError(
                    'Conflicting keyword argument. Cannot accept {}.'.format(
                        key))

        tpr_list = list(self._tprs)
        tpr_list[self._rank] = self.__prep_input(tpr_file)
        if tpr_file is not None:
            # If bootstrap TPR is provided, we are not continuing from the
            # convergence phase trajectory.
            self.run_data.set(start_time=0.0)

        # Calculate the time (in ps) at which the trajectory for this BRER iteration should finish.
        # This should be: the end time of the convergence run + the amount of time for
        # production simulation (specified by the user).
        start_time = self.run_data.get('start_time')
        target_end_time = self.run_data.get('production_time') + start_time

        md = from_tpr(tpr_list,
                      end_time=target_end_time,
                      append_output=False,
                      **kwargs)

        self.build_plugins(ProductionPluginConfig())
        if len(self.__plugins) == 0:
            warnings.warn('No BRER restraints are being applied! User error?')
        for plugin in self.__plugins:
            md.add_dependency(plugin)

        workdir = os.getcwd()
        self._logger.info("=====PRODUCTION INFO======\n")
        self._logger.info(f'Working directory: {workdir}')

        context = _context(md,
                           workdir_list=self.workdirs,
                           communicator=self._communicator)
        # WARNING: We do not yet handle situations where a rank has no work to do.
        # See https://github.com/kassonlab/run_brer/issues/18
        # and https://github.com/kassonlab/run_brer/issues/55
        with context as session:
            session.run()

        # Through at least gmxapi 0.4, the *potentials* attribute is created on
        # the Context for any Session launched with MD work to perform. An explicit
        # error message here should be more helpful than an AttributeError below,
        # but we don't really know what went wrong.
        # Ref https://github.com/kassonlab/run_brer/issues/55
        if not hasattr(context, 'potentials'):
            raise RuntimeError(
                'Invalid gmxapi Context: missing "potentials" attribute.')

        # Get the start and end times for the simulation managed by this Python interpreter.
        # Note that these are the times for all potentials in a single simulation
        # (which should be the same). We are not gathering values across any potential ensembles.
        start_times = [
            potential.start_time for potential in context.potentials
            if hasattr(potential, 'start_time')
        ]
        if len(start_times) > 0:
            session_start_time = start_times[0]
            if not all(session_start_time == t for t in start_times):
                self._logger.warning(
                    'Potentials report inconsistent start times: '
                    + ', '.join(str(t) for t in start_times))
            assert session_start_time >= start_time
        else:
            # If the plugin attribute is missing, assume that the convergence phase behaved properly.
            session_start_time = start_time

        end_times = [
            potential.time for potential in context.potentials
            if hasattr(potential, 'time')
        ]
        if len(end_times) > 0:
            session_end_time = end_times[0]
            if not all(session_end_time == t for t in end_times):
                self._logger.warning(
                    'Potentials report inconsistent end times: '
                    + ', '.join(str(t) for t in end_times))
        else:
            session_end_time = None

        if session_end_time is not None:
            self.run_data.set(end_time=session_end_time)

        trajectory_time = None
        if session_end_time is not None:
            trajectory_time = session_end_time - session_start_time

        if trajectory_time is not None:
            self._logger.info(
                f"{trajectory_time} ps production phase trajectory segment.")
        for name in self.__names:
            current_alpha = self.run_data.get('alpha', name=name)
            current_target = self.run_data.get('target', name=name)
            self._logger.info("Plugin {}: alpha = {}, target = {}".format(
                name, current_alpha, current_target))

        return context
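The end-time bookkeeping above passes an absolute time, not a duration, to from_tpr. A minimal numeric sketch (the values are made up for illustration):

start_time = 100.0                                # ps; end of the convergence run, read from run_data
production_time = 10000.0                         # ps; production length requested by the user
target_end_time = production_time + start_time    # 10100.0 ps; passed as from_tpr(..., end_time=...)
assert target_end_time == 10100.0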
Example #5
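Training phase of a BRER iteration: pair targets are re-sampled and saved to the BRER checkpoint, any existing GROMACS state.cpt is backed up, the training plugins are attached, and after the session runs each potential's alpha and target are read back (raising if alpha collapsed to 0.0) and stored in run_data.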
    def __train(self, tpr_file=None, **kwargs):
        for key in ('append_output', ):
            if key in kwargs:
                raise TypeError(
                    'Conflicting keyword argument. Cannot accept {}.'.format(
                        key))

        # do re-sampling
        targets = self.pairs.re_sample()
        self._logger.info('New targets: {}'.format(targets))
        for name in self.__names:
            self.run_data.set(name=name, target=targets[name])

        # save the new targets to the BRER checkpoint file.
        self.run_data.save_config(fnm=self.state_json)

        workdir = self.workdirs[self._rank]

        # backup existing checkpoint.
        # TODO: Don't backup the cpt, actually use it!!
        cpt = '{}/state.cpt'.format(workdir)
        if os.path.exists(cpt):
            self._logger.warning(
                'There is a checkpoint file in your current working directory, but you '
                'are training. The cpt will be backed up and the run will start over '
                'with new targets')
            shutil.move(cpt, '{}.bak'.format(cpt))

        # If this is not the first BRER iteration, grab the checkpoint from the production
        # phase of the last round
        self.__prep_input(tpr_file)

        # Set up a dictionary to go from plugin name -> restraint name
        sites_to_name = {}

        # Build the gmxapi session.
        tpr_list: Sequence[str] = self._tprs
        md = from_tpr(tpr_list, append_output=False, **kwargs)
        self.build_plugins(TrainingPluginConfig())
        if len(self.__plugins) == 0:
            warnings.warn('No BRER restraints are being applied! User error?')
        for plugin in self.__plugins:
            plugin_name = plugin.name
            for name in self.__names:
                run_data_sites = "{}".format(
                    self.run_data.get('sites', name=name))
                if run_data_sites == plugin_name:
                    sites_to_name[plugin_name] = name
            md.add_dependency(plugin)
        context = _context(md,
                           workdir_list=self.workdirs,
                           communicator=self._communicator)

        self._logger.info("=====TRAINING INFO======\n")
        self._logger.info(f'Working directory: {workdir}')

        # Run it.
        # WARNING: We do not yet handle situations where a rank has no work to do.
        # See https://github.com/kassonlab/run_brer/issues/18
        # and https://github.com/kassonlab/run_brer/issues/55
        with context as session:
            session.run()

        # Through at least gmxapi 0.4, the *potentials* attribute is created on
        # the Context for any Session launched with MD work to perform. An explicit
        # error message here should be more helpful than an AttributeError below,
        # but we don't really know what went wrong.
        # Ref https://github.com/kassonlab/run_brer/issues/55
        if not hasattr(context, 'potentials'):
            raise RuntimeError(
                'Invalid gmxapi Context: missing "potentials" attribute.')

        for i in range(len(self.__names)):
            # TODO: ParallelArrayContext.potentials needs to be declared to avoid IDE
            #  warnings.
            # noinspection PyUnresolvedReferences
            current_name = sites_to_name[context.potentials[i].name]
            # In the future runs (convergence, production) we need the ABSOLUTE VALUE
            # of alpha.
            # noinspection PyUnresolvedReferences
            current_alpha = context.potentials[i].alpha
            if current_alpha == 0.0:
                raise RuntimeError(
                    'Alpha value was constrained to 0.0, which indicates something went wrong'
                )

            # noinspection PyUnresolvedReferences
            current_target = context.potentials[i].target

            self.run_data.set(name=current_name, alpha=current_alpha)
            self.run_data.set(name=current_name, target=current_target)
            self._logger.info("Plugin {}: alpha = {}, target = {}".format(
                current_name, current_alpha, current_target))

        return context
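The sites_to_name lookup above relies on the convention, visible in the loop, that a plugin's name is the stringified site list stored in run_data. A small self-contained sketch (the restraint names and site indices are hypothetical):

# Hypothetical data illustrating the plugin-name -> restraint-name mapping built in __train.
run_data_sites = {'196_228': [196, 228], '105_216': [105, 216]}   # restraint name -> sites
sites_to_name = {str(sites): name for name, sites in run_data_sites.items()}
# A plugin whose .name is '[196, 228]' maps back to the restraint named '196_228'.
assert sites_to_name['[196, 228]'] == '196_228'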
Example #6
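A simpler variant of the production helper in Example #4: it omits the empty-plugin warning and the check that the Context exposes a potentials attribute.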
    def __production(self, tpr_file=None, **kwargs):

        for key in ('append_output', 'end_time'):
            if key in kwargs:
                raise TypeError(
                    'Conflicting keyword argument. Cannot accept {}.'.format(
                        key))

        tpr_list = list(self._tprs)
        tpr_list[self._rank] = self.__prep_input(tpr_file)
        if tpr_file is not None:
            # If bootstrap TPR is provided, we are not continuing from the
            # convergence phase trajectory.
            self.run_data.set(start_time=0.0)

        # Calculate the time (in ps) at which the trajectory for this BRER iteration should finish.
        # This should be: the end time of the convergence run + the amount of time for
        # production simulation (specified by the user).
        start_time = self.run_data.get('start_time')
        target_end_time = self.run_data.get('production_time') + start_time

        md = from_tpr(tpr_list,
                      end_time=target_end_time,
                      append_output=False,
                      **kwargs)

        self.build_plugins(ProductionPluginConfig())
        for plugin in self.__plugins:
            md.add_dependency(plugin)

        workdir = os.getcwd()
        self._logger.info("=====PRODUCTION INFO======\n")
        self._logger.info(f'Working directory: {workdir}')

        context = _context(md,
                           workdir_list=self.workdirs,
                           communicator=self._communicator)
        with context as session:
            session.run()

        # Get the start and end times for the simulation managed by this Python interpreter.
        # Note that these are the times for all potentials in a single simulation
        # (which should be the same). We are not gathering values across any potential ensembles.
        start_times = [
            potential.start_time for potential in context.potentials
            if hasattr(potential, 'start_time')
        ]
        if len(start_times) > 0:
            session_start_time = start_times[0]
            if not all(session_start_time == t for t in start_times):
                self._logger.warning(
                    'Potentials report inconsistent start times: '
                    + ', '.join(str(t) for t in start_times))
            assert session_start_time >= start_time
        else:
            # If the plugin attribute is missing, assume that the convergence phase behaved properly.
            session_start_time = start_time

        end_times = [
            potential.time for potential in context.potentials
            if hasattr(potential, 'time')
        ]
        if len(end_times) > 0:
            session_end_time = end_times[0]
            if not all(session_end_time == t for t in end_times):
                self._logger.warning(
                    'Potentials report inconsistent end times: '
                    + ', '.join(str(t) for t in end_times))
        else:
            session_end_time = None

        if session_end_time is not None:
            self.run_data.set(end_time=session_end_time)

        trajectory_time = None
        if session_end_time is not None:
            trajectory_time = session_end_time - session_start_time

        if trajectory_time is not None:
            self._logger.info(
                f"{trajectory_time} ps production phase trajectory segment.")
        for name in self.__names:
            current_alpha = self.run_data.get('alpha', name=name)
            current_target = self.run_data.get('target', name=name)
            self._logger.info("Plugin {}: alpha = {}, target = {}".format(
                name, current_alpha, current_target))

        return context
Example #7
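A simpler variant of the training helper in Example #5: it omits the empty-plugin warning, the potentials check, and the guard against a zero alpha value.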
    def __train(self, tpr_file=None, **kwargs):
        for key in ('append_output', ):
            if key in kwargs:
                raise TypeError(
                    'Conflicting keyword argument. Cannot accept {}.'.format(
                        key))

        # do re-sampling
        targets = self.pairs.re_sample()
        self._logger.info('New targets: {}'.format(targets))
        for name in self.__names:
            self.run_data.set(name=name, target=targets[name])

        # save the new targets to the BRER checkpoint file.
        self.run_data.save_config(fnm=self.state_json)

        workdir = self.workdirs[self._rank]

        # backup existing checkpoint.
        # TODO: Don't backup the cpt, actually use it!!
        cpt = '{}/state.cpt'.format(workdir)
        if os.path.exists(cpt):
            self._logger.warning(
                'There is a checkpoint file in your current working directory, but you '
                'are training. The cpt will be backed up and the run will start over '
                'with new targets')
            shutil.move(cpt, '{}.bak'.format(cpt))

        # If this is not the first BRER iteration, grab the checkpoint from the production
        # phase of the last round
        self.__prep_input(tpr_file)

        # Set up a dictionary to go from plugin name -> restraint name
        sites_to_name = {}

        # Build the gmxapi session.
        tpr_list: Sequence[str] = self._tprs
        md = from_tpr(tpr_list, append_output=False, **kwargs)
        self.build_plugins(TrainingPluginConfig())
        for plugin in self.__plugins:
            plugin_name = plugin.name
            for name in self.__names:
                run_data_sites = "{}".format(
                    self.run_data.get('sites', name=name))
                if run_data_sites == plugin_name:
                    sites_to_name[plugin_name] = name
            md.add_dependency(plugin)
        context = _context(md,
                           workdir_list=self.workdirs,
                           communicator=self._communicator)

        self._logger.info("=====TRAINING INFO======\n")
        self._logger.info(f'Working directory: {workdir}')

        # Run it.
        with context as session:
            session.run()

        for i in range(len(self.__names)):
            # TODO: ParallelArrayContext.potentials needs to be declared to avoid IDE
            #  warnings.
            # noinspection PyUnresolvedReferences
            current_name = sites_to_name[context.potentials[i].name]
            # In the future runs (convergence, production) we need the ABSOLUTE VALUE
            # of alpha.
            # noinspection PyUnresolvedReferences
            current_alpha = context.potentials[i].alpha
            # noinspection PyUnresolvedReferences
            current_target = context.potentials[i].target

            self.run_data.set(name=current_name, alpha=current_alpha)
            self.run_data.set(name=current_name, target=current_target)
            self._logger.info("Plugin {}: alpha = {}, target = {}".format(
                current_name, current_alpha, current_target))

        return context
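The three private phase helpers shown above (__train, __converge, __production) are presumably driven by a public entry point that picks the current phase out of run_data. The dispatcher below is an assumed sketch, not code from the project; the method name, the 'phase' key, and the phase labels are illustrative:

    def run(self, **kwargs):
        # Assumed dispatcher (illustrative names); sits in the same class as the helpers above.
        phase = self.run_data.get('phase')
        if phase == 'training':
            return self.__train(**kwargs)
        if phase == 'convergence':
            return self.__converge(**kwargs)
        return self.__production(**kwargs)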