def save(self, task):
    """Checkpoint the state of ``task`` under ``self.uid``.

    The state is written to storage only when more than ``self.time_buffer``
    seconds have elapsed since ``self.last_save``; otherwise it is parked in
    ``self.pending`` as an ``(is_best, state)`` tuple.  When ``self.keep_best``
    is set, it is called with ``task.metrics.value()`` to decide whether the
    current state is the best one seen so far, and the best state is kept
    under a separate name (``self.best_name``) whenever it is not also the
    latest state.

    An empty task state is never written: a warning is logged instead.

    Args:
        task: object whose state is extracted with ``state_dict(task)``; its
            ``metrics.value()`` is read when ``self.keep_best`` is set.

    Raises:
        BadCheckpoint: if ``self.uid`` is ``None``.
    """
    if self.uid is None:
        raise BadCheckpoint('No uid was given cannot save state')

    was_saved = False
    state = state_dict(task)

    # Emptiness has to be recorded BEFORE the RNG states are attached: the
    # extra 'rng' entry would make the dictionary always truthy and the
    # "state dictionary was empty" warning below could never be reached.
    has_state = bool(state)
    if has_state:
        state['rng'] = get_rng_states()

    # Was enough time passed since last save
    now = datetime.utcnow()
    elapsed = now - self.last_save
    should_save = elapsed.total_seconds() > self.time_buffer

    # Is it the best model we have seen so far
    # (without a `keep_best` criterion every state counts as the best)
    is_best = True
    if self.keep_best is not None:
        is_best = self.keep_best(task.metrics.value())

    if has_state:
        # Current model is not the best and we did not save the last model in a different path
        # (which is the best right now)
        # So we need to move the last state so it does not get overridden by current state
        if not is_best and self.best_name is None:
            info(f'Saving best ({self.keep_best.metric}: {self.keep_best.best})')
            self.best_name = self.new_best_name()

            # NOTE(review): `save_pending` is presumed to flush `self.pending`
            # and report whether there was something to flush - confirm
            # against its definition.
            was_pending = self.save_pending()
            if not was_pending:
                self.storage.rename(self.uid, self.best_name)

        if should_save:
            was_saved = self.storage.save(self.uid, state)
            self.save_pending()
            self.pending = None
            self.last_save = datetime.utcnow()
        else:
            # Too early to write again: keep the state around for later
            self.save_pending()
            self.pending = (is_best, state)

        # we have a new best and the best was saved as with a different filename
        # So we need to change both the best state and the latest state
        if is_best and self.best_name is not None:
            info(f'New best ({self.keep_best.metric}: {self.keep_best.best})')

            self.storage.remove(self.best_name)
            self.best_name = self.new_best_name()

            was_pending = self.save_pending()
            if not was_pending:
                self.storage.copyfile(self.uid, self.best_name)

    else:
        warning('The state dictionary was empty!')

    if was_saved:
        info('Checkpoint saved')
        return

    info('Skipped Checkpoint')
def state_dict(self, destination=None, prefix='', keep_vars=False):
    """Return the state of this object, extended with the current epoch.

    The base state comes from the module-level ``state_dict`` helper (called
    with ``force_default=True``); ``self.current_epoch`` is then stored under
    the ``'epoch'`` key.

    Args:
        destination, prefix, keep_vars: forwarded unchanged to the
            module-level ``state_dict`` helper.

    Returns:
        The mapping produced by the helper, with the ``'epoch'`` entry added.
    """
    snapshot = state_dict(
        self,
        destination,
        prefix,
        keep_vars,
        force_default=True,
    )
    snapshot['epoch'] = self.current_epoch
    return snapshot
def on_new_trial(self, task, step, parameters, uid):
    """Resume the trial from storage when possible, otherwise register a new one.

    The trial id is taken from ``parameters['uid']``, falling back to the
    ``uid`` argument and finally to ``unique_trial_id``.  If storage holds a
    state for that id, the RNG states and the task state are restored from
    it.  Otherwise the trial meta data is saved and, when ``self.save_init``
    is set, the initial task state is stored under ``init_<uid>``.
    """
    # Make a unique id for resuming
    trial_id = parameters.get('uid', uid)
    if trial_id is None:
        trial_id = unique_trial_id(task.__class__.__name__, parameters)
    self.uid = trial_id

    previous = self.storage.safe_load(self.uid, device=task.device)

    # A previous state exists: restore it and stop here
    if previous is not None:
        set_rng_states(previous['rng'])
        load_state_dict(task, previous)
        info(f'Resuming (trial_id: {self.uid})')
        return

    # Nothing to resume from: record what this trial is about
    self.storage.save_meta(
        self.uid,
        dict(parameters=parameters, task=type(task).__name__),
    )
    info(f'Starting a new (trial_id: {self.uid})')

    if self.save_init:
        # NOTE: the RNG states are not part of the initial snapshot
        initial = state_dict(task)
        self.storage.save(f'init_{self.uid}', initial)
def state_dict(self, destination=None, prefix='', keep_vars=False):
    """Return the state of this object, extended with ``state`` and ``loss``.

    The base state comes from the module-level ``state_dict`` helper (called
    with ``force_default=True``); ``self.state`` and ``self.loss`` are then
    stored under the ``'state'`` and ``'loss'`` keys, in that order.

    Args:
        destination, prefix, keep_vars: forwarded unchanged to the
            module-level ``state_dict`` helper.

    Returns:
        The mapping produced by the helper, with both entries added.
    """
    snapshot = state_dict(
        self,
        destination,
        prefix,
        keep_vars,
        force_default=True,
    )
    snapshot['state'] = self.state
    snapshot['loss'] = self.loss
    return snapshot
def state_dict(self):
    """Return a dictionary holding the ``stopped`` flag and the criterion state.

    The criterion state is obtained by calling the module-level
    ``state_dict`` helper on ``self.criterion``.
    """
    return {
        'stopped': self.stopped,
        'criterion': state_dict(self.criterion),
    }