def __converge(self, tpr_file=None, **kwargs):
    """Run the convergence phase of a BRER iteration.

    Builds an MD simulation from the stored TPR inputs, attaches the
    convergence-phase restraint plugins, runs the session, and records the
    absolute end time of the run so a later production phase can be restarted.

    Parameters
    ----------
    tpr_file : str, optional
        Bootstrap TPR input forwarded to ``self.__prep_input``.
    **kwargs
        Forwarded to ``from_tpr``. Must not include ``append_output``,
        which this method sets itself.

    Returns
    -------
    The gmxapi context used for the run.

    Raises
    ------
    TypeError
        If a reserved keyword argument is supplied.
    RuntimeError
        If the session produced no ``potentials`` attribute on the context.
    """
    for key in ('append_output', ):
        if key in kwargs:
            raise TypeError(
                'Conflicting key word argument. Cannot accept {}.'.format(
                    key))
    self.__prep_input(tpr_file)
    md = from_tpr(self._tprs, append_output=False, **kwargs)
    self.build_plugins(ConvergencePluginConfig())
    # Consistency fix: warn (as the other phases do) if no restraints were
    # built — running BRER with no restraints is almost certainly user error.
    if len(self.__plugins) == 0:
        warnings.warn('No BRER restraints are being applied! User error?')
    for plugin in self.__plugins:
        md.add_dependency(plugin)
    workdir = os.getcwd()
    self._logger.info("=====CONVERGENCE INFO======\n")
    self._logger.info(f'Working directory: {workdir}')
    context = _context(md,
                       workdir_list=self.workdirs,
                       communicator=self._communicator)
    # WARNING: We do not yet handle situations where a rank has no work to do.
    # See https://github.com/kassonlab/run_brer/issues/18
    # and https://github.com/kassonlab/run_brer/issues/55
    with context as session:
        session.run()
    # Through at least gmxapi 0.4, the *potentials* attribute is created on
    # the Context for any Session launched with MD work to perform. An explicit
    # error message here is more helpful than an AttributeError below.
    # Ref https://github.com/kassonlab/run_brer/issues/55
    if not hasattr(context, 'potentials'):
        raise RuntimeError(
            'Invalid gmxapi Context: missing "potentials" attribute.')
    # Get the absolute time (in ps) at which the convergence run finished.
    # This value will be needed if a production run needs to be restarted.
    # noinspection PyUnresolvedReferences
    self.run_data.set(start_time=context.potentials[0].time)
    for name in self.__names:
        current_alpha = self.run_data.get('alpha', name=name)
        current_target = self.run_data.get('target', name=name)
        message = f'Plugin {name}: alpha = {current_alpha}, target = {current_target}'
        self._logger.info(message)
    return context
def __converge(self, tpr_file=None, **kwargs):
    """Run the convergence phase of a BRER iteration.

    Assembles the MD work graph with the convergence-phase restraints,
    executes the session, and stores the absolute end time (ps) of the run
    for use when restarting a production phase.

    Parameters
    ----------
    tpr_file : str, optional
        Bootstrap TPR input forwarded to ``self.__prep_input``.
    **kwargs
        Forwarded to ``from_tpr``; ``append_output`` is reserved.

    Returns
    -------
    The gmxapi context used for the run.
    """
    reserved = ('append_output', )
    for key in reserved:
        if key in kwargs:
            raise TypeError(
                'Conflicting key word argument. Cannot accept {}.'.format(key))
    self.__prep_input(tpr_file)
    simulation = from_tpr(self._tprs, append_output=False, **kwargs)
    self.build_plugins(ConvergencePluginConfig())
    if not self.__plugins:
        warnings.warn('No BRER restraints are being applied! User error?')
    for restraint in self.__plugins:
        simulation.add_dependency(restraint)
    cwd = os.getcwd()
    self._logger.info("=====CONVERGENCE INFO======\n")
    self._logger.info('Working directory: {}'.format(cwd))
    context = _context(simulation,
                       workdir_list=self.workdirs,
                       communicator=self._communicator)
    # WARNING: We do not yet handle situations where a rank has no work to do.
    # See https://github.com/kassonlab/run_brer/issues/18
    # and https://github.com/kassonlab/run_brer/issues/55
    with context as session:
        session.run()
    # Through at least gmxapi 0.4, the *potentials* attribute is created on
    # the Context for any Session launched with MD work to perform. Failing
    # loudly here beats an AttributeError below.
    # Ref https://github.com/kassonlab/run_brer/issues/55
    if not hasattr(context, 'potentials'):
        raise RuntimeError(
            'Invalid gmxapi Context: missing "potentials" attribute.')
    # Record the absolute time (in ps) at which the convergence run finished;
    # needed if a production run has to be restarted.
    # noinspection PyUnresolvedReferences
    self.run_data.set(start_time=context.potentials[0].time)
    for name in self.__names:
        alpha = self.run_data.get('alpha', name=name)
        target = self.run_data.get('target', name=name)
        self._logger.info(
            'Plugin {}: alpha = {}, target = {}'.format(name, alpha, target))
    return context
def test_ensemble_potential_nompi(spc_water_box):
    """Test ensemble potential without an ensemble. """
    tpr_filename = spc_water_box
    print(f"Testing plugin potential with input file "
          f"{os.path.abspath(tpr_filename)}")

    assert api_is_at_least(0, 0, 5)
    simulation = from_tpr([tpr_filename], append_output=False)

    # Parameters for the restraint plugin's WorkElement.
    restraint_params = {
        'sites': [1, 4],
        'nbins': 10,
        'binWidth': 0.1,
        'min_dist': 0.,
        'max_dist': 10.,
        'experimental': [1.] * 10,
        'nsamples': 1,
        'sample_period': 0.001,
        'nwindows': 4,
        'k': 10000.,
        'sigma': 1.
    }
    potential = WorkElement(namespace="myplugin",
                            operation="ensemble_restraint",
                            params=restraint_params)
    # Note that we could flexibly capture accessor methods as workflow elements, too. Maybe we can
    # hide the extra Python bindings by letting myplugin.HarmonicRestraint automatically convert
    # to a WorkElement when add_dependency is called on it.
    potential.name = "ensemble_restraint"
    simulation.add_dependency(potential)

    context = _context(simulation)
    with context as session:
        session.run()
def __production(self, tpr_file=None, **kwargs):
    """Run the production phase of a BRER iteration.

    Continues (or bootstraps) the trajectory up to the target end time,
    with the production-phase restraints attached, then records the session
    end time in the run data.

    Parameters
    ----------
    tpr_file : str, optional
        Bootstrap TPR input. When provided, the run does NOT continue from
        the convergence-phase trajectory and the start time is reset to 0.
    **kwargs
        Forwarded to ``from_tpr``. Must not include ``append_output`` or
        ``end_time``, which this method sets itself.

    Returns
    -------
    The gmxapi context used for the run.

    Raises
    ------
    TypeError
        If a reserved keyword argument is supplied.
    RuntimeError
        If the session produced no ``potentials`` attribute on the context.
    """
    for key in ('append_output', 'end_time'):
        if key in kwargs:
            raise TypeError(
                'Conflicting key word argument. Cannot accept {}.'.format(
                    key))
    tpr_list = list(self._tprs)
    tpr_list[self._rank] = self.__prep_input(tpr_file)
    if tpr_file is not None:
        # If bootstrap TPR is provided, we are not continuing from the
        # convergence phase trajectory.
        self.run_data.set(start_time=0.0)

    # Calculate the time (in ps) at which the trajectory for this BRER iteration should finish.
    # This should be: the end time of the convergence run + the amount of time for
    # production simulation (specified by the user).
    start_time = self.run_data.get('start_time')
    target_end_time = self.run_data.get('production_time') + start_time

    md = from_tpr(tpr_list,
                  end_time=target_end_time,
                  append_output=False,
                  **kwargs)
    self.build_plugins(ProductionPluginConfig())
    if len(self.__plugins) == 0:
        warnings.warn('No BRER restraints are being applied! User error?')
    for plugin in self.__plugins:
        md.add_dependency(plugin)
    workdir = os.getcwd()
    self._logger.info("=====PRODUCTION INFO======\n")
    self._logger.info(f'Working directory: {workdir}')
    context = _context(md,
                       workdir_list=self.workdirs,
                       communicator=self._communicator)
    # WARNING: We do not yet handle situations where a rank has no work to do.
    # See https://github.com/kassonlab/run_brer/issues/18
    # and https://github.com/kassonlab/run_brer/issues/55
    with context as session:
        session.run()
    # Through at least gmxapi 0.4, the *potentials* attribute is created on
    # the Context for any Session launched with MD work to perform. An explicit
    # error message here should be more helpful than an AttributeError below,
    # but we don't really know what went wrong.
    # Ref https://github.com/kassonlab/run_brer/issues/55
    if not hasattr(context, 'potentials'):
        raise RuntimeError(
            'Invalid gmxapi Context: missing "potentials" attribute.')
    # Get the start and end times for the simulation managed by this Python interpreter.
    # Note that these are the times for all potentials in a single simulation
    # (which should be the same). We are not gathering values across any potential ensembles.
    start_times = [
        potential.start_time for potential in context.potentials
        if hasattr(potential, 'start_time')
    ]
    if len(start_times) > 0:
        session_start_time = start_times[0]
        if not all(session_start_time == t for t in start_times):
            # BUG FIX: the original relied on implicit string-literal
            # concatenation ('msg: ' ', '.join(...)), which made the whole
            # message the join *separator*. Use explicit '+' so the message
            # prefixes the joined values.
            self._logger.warning(
                'Potentials report inconsistent start times: ' +
                ', '.join(str(t) for t in start_times))
        assert session_start_time >= start_time
    else:
        # If the plugin attribute is missing, assume that the convergence phase behaved properly.
        session_start_time = start_time
    end_times = [
        potential.time for potential in context.potentials
        if hasattr(potential, 'time')
    ]
    if len(end_times) > 0:
        session_end_time = end_times[0]
        if not all(session_end_time == t for t in end_times):
            # Same implicit-concatenation bug fixed here as well.
            self._logger.warning(
                'Potentials report inconsistent end times: ' +
                ', '.join(str(t) for t in end_times))
    else:
        session_end_time = None
    if session_end_time is not None:
        self.run_data.set(end_time=session_end_time)
        trajectory_time = session_end_time - session_start_time
        self._logger.info(
            f"{trajectory_time} ps production phase trajectory segment.")
    for name in self.__names:
        current_alpha = self.run_data.get('alpha', name=name)
        current_target = self.run_data.get('target', name=name)
        self._logger.info("Plugin {}: alpha = {}, target = {}".format(
            name, current_alpha, current_target))
    return context
def __train(self, tpr_file=None, **kwargs):
    """Run the training phase of a BRER iteration.

    Re-samples pair targets, discards (backs up) any stale GROMACS
    checkpoint, runs a training simulation with the training-phase
    restraints, and stores the trained ``alpha`` and ``target`` for each
    restraint in the run data.

    Parameters
    ----------
    tpr_file : str, optional
        Bootstrap TPR input forwarded to ``self.__prep_input``.
    **kwargs
        Forwarded to ``from_tpr``. Must not include ``append_output``,
        which this method sets itself.

    Returns
    -------
    The gmxapi context used for the run.

    Raises
    ------
    TypeError
        If a reserved keyword argument is supplied.
    RuntimeError
        If the session produced no ``potentials`` attribute, or if any
        trained alpha came back exactly 0.0 (a sign the training failed).
    """
    for key in ('append_output', ):
        if key in kwargs:
            raise TypeError(
                'Conflicting key word argument. Cannot accept {}.'.format(
                    key))
    # do re-sampling
    targets = self.pairs.re_sample()
    self._logger.info('New targets: {}'.format(targets))
    for name in self.__names:
        self.run_data.set(name=name, target=targets[name])
    # save the new targets to the BRER checkpoint file.
    self.run_data.save_config(fnm=self.state_json)
    workdir = self.workdirs[self._rank]
    # backup existing checkpoint.
    # TODO: Don't backup the cpt, actually use it!!
    cpt = '{}/state.cpt'.format(workdir)
    if os.path.exists(cpt):
        self._logger.warning(
            'There is a checkpoint file in your current working directory, but you '
            'are '
            'training. The cpt will be backed up and the run will start over with '
            'new targets')
        shutil.move(cpt, '{}.bak'.format(cpt))
    # If this is not the first BRER iteration, grab the checkpoint from the production
    # phase of the last round
    self.__prep_input(tpr_file)
    # Set up a dictionary to go from plugin name -> restraint name
    sites_to_name = {}
    # Build the gmxapi session.
    tpr_list: Sequence[str] = self._tprs
    md = from_tpr(tpr_list, append_output=False, **kwargs)
    self.build_plugins(TrainingPluginConfig())
    if len(self.__plugins) == 0:
        warnings.warn('No BRER restraints are being applied! User error?')
    for plugin in self.__plugins:
        plugin_name = plugin.name
        # A plugin's name is the string form of its site list; match it
        # against each restraint's configured sites to recover the
        # restraint name for reporting after the run.
        for name in self.__names:
            run_data_sites = "{}".format(
                self.run_data.get('sites', name=name))
            if run_data_sites == plugin_name:
                sites_to_name[plugin_name] = name
        md.add_dependency(plugin)
    context = _context(md,
                       workdir_list=self.workdirs,
                       communicator=self._communicator)
    self._logger.info("=====TRAINING INFO======\n")
    self._logger.info(f'Working directory: {workdir}')
    # Run it.
    # WARNING: We do not yet handle situations where a rank has no work to do.
    # See https://github.com/kassonlab/run_brer/issues/18
    # and https://github.com/kassonlab/run_brer/issues/55
    with context as session:
        session.run()
    # Through at least gmxapi 0.4, the *potentials* attribute is created on
    # the Context for any Session launched with MD work to perform. An explicit
    # error message here should be more helpful than an AttributeError below,
    # but we don't really know what went wrong.
    # Ref https://github.com/kassonlab/run_brer/issues/55
    if not hasattr(context, 'potentials'):
        raise RuntimeError(
            'Invalid gmxapi Context: missing "potentials" attribute.')
    for i in range(len(self.__names)):
        # TODO: ParallelArrayContext.potentials needs to be declared to avoid IDE
        # warnings.
        # noinspection PyUnresolvedReferences
        current_name = sites_to_name[context.potentials[i].name]
        # In the future runs (convergence, production) we need the ABSOLUTE VALUE
        # of alpha.
        # noinspection PyUnresolvedReferences
        current_alpha = context.potentials[i].alpha
        if current_alpha == 0.0:
            raise RuntimeError(
                'Alpha value was constrained to 0.0, which indicates something went wrong'
            )
        # noinspection PyUnresolvedReferences
        current_target = context.potentials[i].target
        # Persist the trained values for use by the convergence/production
        # phases of this iteration.
        self.run_data.set(name=current_name, alpha=current_alpha)
        self.run_data.set(name=current_name, target=current_target)
        self._logger.info("Plugin {}: alpha = {}, target = {}".format(
            current_name, current_alpha, current_target))
    return context
def __production(self, tpr_file=None, **kwargs):
    """Run the production phase of a BRER iteration.

    Continues (or bootstraps) the trajectory up to the target end time,
    with the production-phase restraints attached, then records the session
    end time in the run data.

    Parameters
    ----------
    tpr_file : str, optional
        Bootstrap TPR input. When provided, the run does NOT continue from
        the convergence-phase trajectory and the start time is reset to 0.
    **kwargs
        Forwarded to ``from_tpr``. Must not include ``append_output`` or
        ``end_time``, which this method sets itself.

    Returns
    -------
    The gmxapi context used for the run.

    Raises
    ------
    TypeError
        If a reserved keyword argument is supplied.
    RuntimeError
        If the session produced no ``potentials`` attribute on the context.
    """
    for key in ('append_output', 'end_time'):
        if key in kwargs:
            raise TypeError(
                'Conflicting key word argument. Cannot accept {}.'.format(
                    key))
    tpr_list = list(self._tprs)
    tpr_list[self._rank] = self.__prep_input(tpr_file)
    if tpr_file is not None:
        # If bootstrap TPR is provided, we are not continuing from the
        # convergence phase trajectory.
        self.run_data.set(start_time=0.0)

    # Calculate the time (in ps) at which the trajectory for this BRER iteration should finish.
    # This should be: the end time of the convergence run + the amount of time for
    # production simulation (specified by the user).
    start_time = self.run_data.get('start_time')
    target_end_time = self.run_data.get('production_time') + start_time

    md = from_tpr(tpr_list,
                  end_time=target_end_time,
                  append_output=False,
                  **kwargs)
    self.build_plugins(ProductionPluginConfig())
    # Consistency fix: warn (as the other phases do) if no restraints were
    # built — running BRER with no restraints is almost certainly user error.
    if len(self.__plugins) == 0:
        warnings.warn('No BRER restraints are being applied! User error?')
    for plugin in self.__plugins:
        md.add_dependency(plugin)
    workdir = os.getcwd()
    self._logger.info("=====PRODUCTION INFO======\n")
    self._logger.info(f'Working directory: {workdir}')
    context = _context(md,
                       workdir_list=self.workdirs,
                       communicator=self._communicator)
    # WARNING: We do not yet handle situations where a rank has no work to do.
    # See https://github.com/kassonlab/run_brer/issues/18
    # and https://github.com/kassonlab/run_brer/issues/55
    with context as session:
        session.run()
    # Through at least gmxapi 0.4, the *potentials* attribute is created on
    # the Context for any Session launched with MD work to perform. An explicit
    # error message here is more helpful than an AttributeError below.
    # Ref https://github.com/kassonlab/run_brer/issues/55
    if not hasattr(context, 'potentials'):
        raise RuntimeError(
            'Invalid gmxapi Context: missing "potentials" attribute.')
    # Get the start and end times for the simulation managed by this Python interpreter.
    # Note that these are the times for all potentials in a single simulation
    # (which should be the same). We are not gathering values across any potential ensembles.
    start_times = [
        potential.start_time for potential in context.potentials
        if hasattr(potential, 'start_time')
    ]
    if len(start_times) > 0:
        session_start_time = start_times[0]
        if not all(session_start_time == t for t in start_times):
            # BUG FIX: the original relied on implicit string-literal
            # concatenation ('msg: ' ', '.join(...)), which made the whole
            # message the join *separator*. Use explicit '+' so the message
            # prefixes the joined values.
            self._logger.warning(
                'Potentials report inconsistent start times: ' +
                ', '.join(str(t) for t in start_times))
        assert session_start_time >= start_time
    else:
        # If the plugin attribute is missing, assume that the convergence phase behaved properly.
        session_start_time = start_time
    end_times = [
        potential.time for potential in context.potentials
        if hasattr(potential, 'time')
    ]
    if len(end_times) > 0:
        session_end_time = end_times[0]
        if not all(session_end_time == t for t in end_times):
            # Same implicit-concatenation bug fixed here as well.
            self._logger.warning(
                'Potentials report inconsistent end times: ' +
                ', '.join(str(t) for t in end_times))
    else:
        session_end_time = None
    if session_end_time is not None:
        self.run_data.set(end_time=session_end_time)
        trajectory_time = session_end_time - session_start_time
        self._logger.info(
            f"{trajectory_time} ps production phase trajectory segment.")
    for name in self.__names:
        current_alpha = self.run_data.get('alpha', name=name)
        current_target = self.run_data.get('target', name=name)
        self._logger.info("Plugin {}: alpha = {}, target = {}".format(
            name, current_alpha, current_target))
    return context
def __train(self, tpr_file=None, **kwargs):
    """Run the training phase of a BRER iteration.

    Re-samples pair targets, discards (backs up) any stale GROMACS
    checkpoint, runs a training simulation with the training-phase
    restraints, and stores the trained ``alpha`` and ``target`` for each
    restraint in the run data.

    Parameters
    ----------
    tpr_file : str, optional
        Bootstrap TPR input forwarded to ``self.__prep_input``.
    **kwargs
        Forwarded to ``from_tpr``. Must not include ``append_output``,
        which this method sets itself.

    Returns
    -------
    The gmxapi context used for the run.

    Raises
    ------
    TypeError
        If a reserved keyword argument is supplied.
    RuntimeError
        If the session produced no ``potentials`` attribute, or if any
        trained alpha came back exactly 0.0 (a sign the training failed).
    """
    for key in ('append_output', ):
        if key in kwargs:
            raise TypeError(
                'Conflicting key word argument. Cannot accept {}.'.format(
                    key))
    # do re-sampling
    targets = self.pairs.re_sample()
    self._logger.info('New targets: {}'.format(targets))
    for name in self.__names:
        self.run_data.set(name=name, target=targets[name])
    # save the new targets to the BRER checkpoint file.
    self.run_data.save_config(fnm=self.state_json)
    workdir = self.workdirs[self._rank]
    # backup existing checkpoint.
    # TODO: Don't backup the cpt, actually use it!!
    cpt = '{}/state.cpt'.format(workdir)
    if os.path.exists(cpt):
        self._logger.warning(
            'There is a checkpoint file in your current working directory, but you '
            'are '
            'training. The cpt will be backed up and the run will start over with '
            'new targets')
        shutil.move(cpt, '{}.bak'.format(cpt))
    # If this is not the first BRER iteration, grab the checkpoint from the production
    # phase of the last round
    self.__prep_input(tpr_file)
    # Set up a dictionary to go from plugin name -> restraint name
    sites_to_name = {}
    # Build the gmxapi session.
    tpr_list: Sequence[str] = self._tprs
    md = from_tpr(tpr_list, append_output=False, **kwargs)
    self.build_plugins(TrainingPluginConfig())
    # Consistency fix: warn (as the other phases do) if no restraints were
    # built — running BRER with no restraints is almost certainly user error.
    if len(self.__plugins) == 0:
        warnings.warn('No BRER restraints are being applied! User error?')
    for plugin in self.__plugins:
        plugin_name = plugin.name
        # A plugin's name is the string form of its site list; match it
        # against each restraint's configured sites to recover the
        # restraint name for reporting after the run.
        for name in self.__names:
            run_data_sites = "{}".format(
                self.run_data.get('sites', name=name))
            if run_data_sites == plugin_name:
                sites_to_name[plugin_name] = name
        md.add_dependency(plugin)
    context = _context(md,
                       workdir_list=self.workdirs,
                       communicator=self._communicator)
    self._logger.info("=====TRAINING INFO======\n")
    self._logger.info(f'Working directory: {workdir}')
    # Run it.
    # WARNING: We do not yet handle situations where a rank has no work to do.
    # See https://github.com/kassonlab/run_brer/issues/18
    # and https://github.com/kassonlab/run_brer/issues/55
    with context as session:
        session.run()
    # Through at least gmxapi 0.4, the *potentials* attribute is created on
    # the Context for any Session launched with MD work to perform. An explicit
    # error message here is more helpful than an AttributeError below.
    # Ref https://github.com/kassonlab/run_brer/issues/55
    if not hasattr(context, 'potentials'):
        raise RuntimeError(
            'Invalid gmxapi Context: missing "potentials" attribute.')
    for i in range(len(self.__names)):
        # TODO: ParallelArrayContext.potentials needs to be declared to avoid IDE
        # warnings.
        # noinspection PyUnresolvedReferences
        current_name = sites_to_name[context.potentials[i].name]
        # In the future runs (convergence, production) we need the ABSOLUTE VALUE
        # of alpha.
        # noinspection PyUnresolvedReferences
        current_alpha = context.potentials[i].alpha
        # Consistency fix: a trained alpha of exactly 0.0 means the restraint
        # never moved — fail loudly instead of silently persisting it.
        if current_alpha == 0.0:
            raise RuntimeError(
                'Alpha value was constrained to 0.0, which indicates something went wrong'
            )
        # noinspection PyUnresolvedReferences
        current_target = context.potentials[i].target
        self.run_data.set(name=current_name, alpha=current_alpha)
        self.run_data.set(name=current_name, target=current_target)
        self._logger.info("Plugin {}: alpha = {}, target = {}".format(
            current_name, current_alpha, current_target))
    return context