コード例 #1
0
    def load_best(self, task):
        """Restore ``task`` from the best saved checkpoint, if storage has one.

        The checkpoint name is ``self.best_name``. If that is ``None``, it
        falls back to ``self.uid``. When ``self.storage.safe_load`` returns
        ``None``, nothing is restored and ``task`` is left untouched.

        Args:
            task: object whose ``device`` attribute is forwarded to
                ``safe_load`` and whose state is restored via the global
                ``load_state_dict`` helper.

        Returns:
            None.
        """
        checkpoint_name = self.uid if self.best_name is None else self.best_name
        snapshot = self.storage.safe_load(checkpoint_name, device=task.device)
        if snapshot is None:
            return

        # RNG state is restored before the task state, as in the original order.
        set_rng_states(snapshot['rng'])
        load_state_dict(task, snapshot)
コード例 #2
0
    def on_new_trial(self, task, step, parameters, uid):
        """Resume a previously checkpointed trial, or register it as new.

        Choosing the trial id:
            The id is ``parameters['uid']`` when that key is present, and the
            ``uid`` argument otherwise. If the result is ``None``, an id is
            derived with ``unique_trial_id`` from the task's class name and
            ``parameters``. The chosen id is stored on ``self.uid``.

        Checkpoint found:
            If ``self.storage`` returns a checkpoint for that id, the RNG
            states and the task state are restored from it.

        No checkpoint:
            The trial metadata (``parameters`` and the task's type name) is
            saved. When ``self.save_init`` is truthy, a snapshot of the
            task's current state is also saved under ``init_<uid>``.

        Args:
            task: object whose ``device`` is forwarded to storage and whose
                state is restored or snapshotted.
            step: not used by this method.
            parameters: mapping of trial parameters; must support ``.get``.
            uid: fallback trial id, used when ``parameters`` has no
                ``'uid'`` key.
        """
        # An explicit 'uid' entry in the parameters wins over the argument.
        self.uid = parameters.get('uid', uid)
        if self.uid is None:
            self.uid = unique_trial_id(task.__class__.__name__, parameters)

        state = self.storage.safe_load(self.uid, device=task.device)

        if state is not None:
            set_rng_states(state['rng'])
            load_state_dict(task, state)
            info(f'Resuming (trial_id: {self.uid})')
            return

        # No checkpoint: this is a fresh trial.
        trial_meta = {'parameters': parameters, 'task': type(task).__name__}
        self.storage.save_meta(self.uid, trial_meta)
        info(f'Starting a new (trial_id: {self.uid})')

        if self.save_init:
            # NOTE(review): this snapshot carries no 'rng' entry, unlike the
            # checkpoints read back above via state['rng']. Confirm that is
            # intended before loading init_<uid> through the same path.
            init_state = state_dict(task)
            self.storage.save(f'init_{self.uid}', init_state)
コード例 #3
0
 def load_state_dict(self, state):
     """Restore this object from ``state``.

     First delegates to the global ``load_state_dict`` helper with
     ``strict=True`` and ``force_default=True``. Then copies the
     ``'state'`` and ``'loss'`` entries of ``state`` onto the instance, in
     that order.
     """
     # Inside a method body the class namespace is not searched, so this
     # resolves to the global helper rather than recursing into this method.
     load_state_dict(self, state, strict=True, force_default=True)
     for entry in ('state', 'loss'):
         setattr(self, entry, state[entry])
コード例 #4
0
 def load_state_dict(self, state, strict=True):
     """Restore this object from ``state`` and reset its epoch bookkeeping.

     Delegates to the global ``load_state_dict`` helper, passing ``strict``
     through and forcing ``force_default=True``. Then sets both
     ``_first_epoch`` and ``current_epoch`` to ``state['epoch']``.

     Args:
         state: mapping that must contain an ``'epoch'`` entry.
         strict: forwarded unchanged to the global helper.
     """
     load_state_dict(self, state, strict, force_default=True)
     # Chained targets are assigned left to right: _first_epoch, then current_epoch.
     self._first_epoch = self.current_epoch = state['epoch']
コード例 #5
0
 def load_state_dict(self, state_dict):
     """Restore ``stopped`` and the ``criterion`` attribute's state.

     ``state_dict['stopped']`` is copied onto ``self.stopped`` first. Then
     ``state_dict['criterion']`` is loaded into ``self.criterion`` through
     the global ``load_state_dict`` helper.
     """
     self.stopped = state_dict['stopped']
     # Global helper, not this method: class scope is skipped during lookup.
     criterion = self.criterion
     load_state_dict(criterion, state_dict['criterion'])