def _new_trial(self, force=False, parameters=None, **kwargs):
    """Create the trial, or defer its creation until parameters are known.

    When ``parameters`` is missing, no trial is built; instead a delayed call
    to this method is stored in ``self.trial`` (holding the keyword arguments
    received so far) so the trial can be created once parameters arrive.
    The parameters are required up front because the trial uid is a hash of
    them.

    If no project is set, the trial is attached to a catch-all project
    named ``orphan``.

    Parameters
    ----------
    force: bool
        Once a trial is set it is not replaced unless ``force`` is true.

    parameters: dict or Namespace, optional
        Parameters of the experiment; a ``Namespace`` is converted to a dict.

    kwargs
        See :func:`~track.structure.Trial` for possible arguments

    Returns
    -------
    The trial, or the delayed call standing in for it.

    Raises
    ------
    RuntimeError
        If the trial is already a delayed call and no parameters are given.
    """
    # Accept an argparse-style namespace by flattening it into a plain dict.
    if isinstance(parameters, Namespace):
        parameters = dict(**vars(parameters))

    # A concrete (non-delayed) trial is kept unless the caller forces a reset.
    existing = self.trial
    already_created = existing is not None and not is_delayed_call(existing)
    if already_created and not force:
        info('Trial is already set, to override use force=True')
        return self.trial

    if parameters is None:
        # A delayed trial resolved without parameters cannot be hashed.
        if is_delayed_call(self.trial):
            raise RuntimeError('Trial needs parameters')

        # Postpone creation: remember the kwargs and wait for the parameters.
        self.trial = delay_call(self._new_trial, **kwargs)
        return self.trial

    # Swap the trial (or its delayed placeholder) for the real thing.
    needs_creation = bool(parameters) or is_delayed_call(self.trial)
    if needs_creation:
        self.trial = self._make_trial(parameters=parameters, **kwargs)

    assert self.trial is not None, f'Trial cant be None because {parameters} or {is_delayed_call(self.trial)}'

    # Trials created without a project land in the catch-all one.
    if self.project is None:
        self.project = self.set_project(name='orphan')

    assert self.project is not None, 'Project cant be None'

    # Register the trial with its project and, when present, its group.
    self.protocol.add_project_trial(self.project, self.trial)
    if self.group is not None:
        self.protocol.add_group_trial(self.group, self.trial)

    return self.trial
def test_delay():
    """Check delayed calls on a plain function and on a bound method."""

    class Obj:
        # Class-level defaults, overridden per instance by ``set``.
        a = 0
        b = 0

        def set(self, a=0, b=0):
            self.a, self.b = a, b
            return self

        def get(self):
            return self.a + self.b

    def add(a, b):
        return a + b

    # Plain function: one argument bound up front, the other at call time.
    pending_sum = delay_call(add, a=2)
    assert pending_sum(b=2) == 4

    # Bound method: the delayed object is recognised as such, and ``get``
    # can be reached through it.
    holder = Obj()
    pending_set = delay_call(holder.set, a=2)
    assert is_delayed_call(pending_set)
    assert pending_set.get() == 2
def add_tags(self, **kwargs):
    """Attach tags to the current trial.

    Parameters
    ----------
    kwargs
        Tag names and values to attach.
    """
    trial_is_pending = is_delayed_call(self.trial)

    if not trial_is_pending:
        # The trial exists: hand the tags to the logger.
        self.logger.add_tags(**kwargs)
        return

    # The trial does not need to exist yet: queue the tags on the delayed
    # call so they are supplied when the trial is eventually created.
    self.trial.add_arguments(tags=kwargs)
def __getattr__(self, item):
    """Fall back to the logger for attributes not found on this object.

    If the trial is still a delayed call, it is resolved first (without
    parameters, hence the warning): the result of the call is stored as
    ``self.logger`` and its ``trial`` attribute becomes ``self.trial``.

    Parameters
    ----------
    item: str
        Name of the attribute that normal lookup failed to find.

    Returns
    -------
    The attribute of the same name on ``self.logger``.

    Raises
    ------
    AttributeError
        If the logger does not provide ``item`` either, or if ``item`` is
        ``trial`` or ``logger`` and that attribute has not been set yet.
    """
    # __getattr__ only runs when regular lookup has failed. This method reads
    # ``self.trial`` and ``self.logger`` itself, so if one of them is missing
    # (e.g. instance built through ``__new__`` by copy/pickle) the lookup
    # would re-enter here forever. Fail fast with AttributeError instead of
    # hitting a RecursionError.
    if item in ('trial', 'logger'):
        raise AttributeError(item)

    if is_delayed_call(self.trial):
        warning('Creating a trial without parameters!')
        self.logger = self.trial()
        self.trial = self.logger.trial

    # Look for the attribute in the top level logger
    if hasattr(self.logger, item):
        return getattr(self.logger, item)

    raise AttributeError(item)
def log_arguments(self, args: Union[ArgumentParser, Namespace, Dict, None] = None, show=False,
                  **kwargs) -> Union[Namespace, Dict, None]:
    """Store the arguments that were used to run the trial.

    The positional ``args`` and the extra ``kwargs`` are merged (``args``
    wins on conflicts). If the trial is still a delayed call it is created
    now with the merged arguments as its parameters and a new logger is
    built for it; otherwise the merged arguments are forwarded to the
    existing logger.

    Parameters
    ----------
    args: Union[ArgumentParser, Namespace, Dict]
        The trial's arguments. An ``ArgumentParser`` is parsed with
        ``parse_args()`` first.

    show: bool
        Print the arguments given through ``args`` on the command line.

    kwargs
        More trial arguments.

    Returns
    -------
    The trial's arguments: the parsed ``Namespace`` when a parser was given,
    otherwise ``args`` unchanged.
    """
    # Resolve a parser into its parsed namespace once, so that both the
    # printout and the return value refer to actual arguments rather than
    # to the parser object.
    parsed = args
    if isinstance(args, ArgumentParser):
        parsed = args.parse_args()

    # Normalise whatever we were given into a plain dict.
    if parsed is None:
        nargs = {}
    elif isinstance(parsed, Namespace):
        nargs = dict(vars(parsed))
    else:
        nargs = dict(parsed)

    kwargs.update(nargs)

    # if we have a pending trial create it now as we have all the information
    if is_delayed_call(self.trial):
        self.trial = self.trial(parameters=kwargs)
        self.logger = TrialLogger(self.trial, self.protocol)
    else:
        # we do not need to log the arguments they are inside the trial already
        self.logger.log_arguments(**kwargs)

    if show:
        # Iterate the normalised dict: ``vars()`` on a dict or None raises
        # TypeError, and on a parser it would dump the parser's internals.
        print('-' * 80)
        for k, v in nargs.items():
            print(f'{k:>30}: {v}')
        print('-' * 80)

    return parsed
def log_arguments(self, **kwargs):
    """Log the trial arguments.

    This function has no effect if the trial was already created: arguments
    are set at trial creation, so they are only used to resolve a trial that
    is still a delayed call.

    Parameters
    ----------
    kwargs
        The trial's arguments, used as the parameters of the new trial.
    """
    if is_delayed_call(self.trial):
        # The trial factory in this file takes these values under the
        # ``parameters`` keyword (and refuses to resolve a delayed trial
        # without it), so pass them under that name.
        # NOTE(review): assumes the delayed call wraps ``_new_trial`` - confirm.
        self.trial = self.trial(parameters=kwargs)