Example #1
        swa_model.update_parameters(model)

        if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
            if args.swa_on_cpu:
                # move swa_model to the GPU (and the training model off it) for evaluation
                model = model.cpu()
                swa_model = swa_model.to(device)
            update_bn(loaders['train'], swa_model, device=torch.device('cuda'))
            print("SWA eval")
            swa_res = utils.eval(loaders['test'],
                                 swa_model,
                                 criterion,
                                 device=device)
            if args.swa_on_cpu:
                model = model.to(device)
                swa_model = swa_model.cpu()
        else:
            swa_res = {'loss': None, 'accuracy': None}

    if (epoch + 1) % args.save_freq == 0:
        utils.save_checkpoint(
            args.dir,
            epoch + 1,
            state_dict=model.state_dict(),
            swa_state_dict=swa_model.state_dict() if args.swa else None,
            optimizer=optimizer.state_dict())

    time_ep = time.time() - time_ep
    values = [
        epoch + 1, lr, train_res['loss'], train_res['accuracy'],
        test_res['loss'], test_res['accuracy'], time_ep
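
The listing above is cut off by the source page. For reference, here is a minimal, self-contained sketch of the SWA pattern it demonstrates (an AveragedModel tracking the running weight average, update_parameters once past the SWA start, and update_bn before evaluation); the toy model, random data, and hyper-parameters below are placeholders, not part of the original script.

import torch
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.utils.data import DataLoader, TensorDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# toy classifier with a BatchNorm layer so update_bn has statistics to refresh
model = nn.Sequential(nn.Linear(10, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 2)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loader = DataLoader(TensorDataset(torch.randn(256, 10), torch.randint(0, 2, (256,))),
                    batch_size=32, shuffle=True)

swa_model = AveragedModel(model)              # running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.01)
swa_start, epochs = 5, 10

for epoch in range(epochs):
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        criterion(model(x), y).backward()
        optimizer.step()
    if epoch >= swa_start:
        swa_model.update_parameters(model)    # fold the current weights into the average
        swa_scheduler.step()

# recompute BatchNorm running statistics for the averaged weights before evaluating
update_bn(loader, swa_model, device=device)
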
Example #2
class BaseTrainer:
    """
    Base class for all trainers
    """
    def __init__(self, model, loss, metrics, resume, config, train_logger=None):
        self.config = config
        self.logger = logging.getLogger(self.__class__.__name__)
        #if type(model) is tuple:
        #    self.model,self.model_ref
        self.model = model
        self.model_ref = model

        self.loss = loss
        self.metrics = metrics
        self.name = config['name']
        self.logged = config['super_computer'] if 'super_computer' in config else False
        self.iterations = config['trainer']['iterations']
        self.val_step = config['trainer']['val_step']
        self.save_step = config['trainer']['save_step']
        self.save_step_minor = config['trainer']['save_step_minor'] if 'save_step_minor' in config['trainer'] else None
        self.log_step = config['trainer']['log_step']
        self.verbosity = config['trainer']['verbosity']
        self.with_cuda = config['cuda'] and torch.cuda.is_available()
        if config['cuda'] and not torch.cuda.is_available():
            self.logger.warning('Warning: There\'s no CUDA support on this machine, '
                                'training is performed on CPU.')
        elif config['cuda']:
            self.gpu = torch.device('cuda:' + str(config['gpu']))
            self.model = self.model.to(self.gpu)
        else:
            self.gpu=None
        if 'multiprocess' in config or 'distributed' in config:
            self.model = DistributedDataParallel(
                    self.model,
                    find_unused_parameters=True)

        self.train_logger = train_logger
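        # Build the optimizer parameter groups: parameters whose names contain the
        # configured "slow"/"slower"/"not_as_slow" fragments get scaled-down learning
        # rates (0.1x / 0.01x / 0.5x), and names matching freeze_param_names are skipped.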
        if config['optimizer_type']!="none":
            main_params=[]
            slow_params=[]
            slower_params=[]
            not_as_slow_params=[]
            discriminator_params=[]
            slow_param_names = config['trainer']['slow_param_names'] if 'slow_param_names' in config['trainer'] else []
            slower_param_names = config['trainer']['slower_param_names'] if 'slower_param_names' in config['trainer'] else []
            not_as_slow_param_names = config['trainer']['not_as_slow_param_names'] if 'not_as_slow_param_names' in config['trainer'] else []
            freeze_param_names = config['trainer']['freeze_param_names'] if 'freeze_param_names' in config['trainer'] else []
            only_params = config['trainer']['only_params'] if 'only_params' in config['trainer'] else None
            for name,param in model.named_parameters():
                if only_params is None or any([p in name for p in only_params]):
                    goSlow=False
                    goSlower=False
                    goNotAsSlow=False
                    freeze=False
                    for sp in slower_param_names:
                        if sp in name:
                            goSlower=True
                            break
                    for sp in not_as_slow_param_names:
                        if sp in name:
                            goNotAsSlow=True
                            break
                    for sp in slow_param_names:
                        if sp in name:
                            goSlow=True
                            break
                    for fp in freeze_param_names:
                        if fp in name:
                            freeze=True
                            break
                    if freeze:
                        pass
                    elif goNotAsSlow:
                        not_as_slow_params.append(param)
                    elif goSlower:
                        slower_params.append(param)
                    elif goSlow:
                        slow_params.append(param)
                    elif ('hwr' in name and self.hwr_frozen) or ('style_extractor' in name and self.style_frozen):
                        pass
                    elif 'style_extractor' in name and self.curriculum.need_style_in_disc:
                        discriminator_params.append(param)
                    else:
                        main_params.append(param)
            to_opt = [
                    {'params': main_params}, 
                    {'params': slow_params, 'lr': config['optimizer']['lr']*0.1}]
            if len(not_as_slow_params)>0:
                to_opt.append({'params': not_as_slow_params, 'lr': config['optimizer']['lr']*0.5})
            if len(slower_params)>0:
                to_opt.append({'params': slower_params, 'lr': config['optimizer']['lr']*0.01})
            should_be_in_to_opt = 2 + (1 if len(slower_param_names)>0 else 0) + (1 if len(not_as_slow_param_names)>0 else 0)
            assert should_be_in_to_opt == len(to_opt) #help catch errors in param names
            self.optimizer = getattr(optim, config['optimizer_type'])(to_opt,
                                                                      **config['optimizer'])
                    #self.optimizer = getattr(optim, config['optimizer_type'])(model.parameters(),



        # Stochastic Weight Averaging: maintain a running average of the model weights
        # in an AveragedModel wrapper; it is updated periodically inside train().
        self.swa = config['trainer']['swa'] if 'swa' in config['trainer'] else (config['trainer']['weight_averaging'] if 'weight_averaging' in config['trainer'] else False)
        if self.swa:
            self.swa_model = AveragedModel(self.model)#type(self.model)(config['model'])
            #if config['cuda']:
            #    self.swa_model = self.swa_model.to(self.gpu)
            self.swa_start = config['trainer']['swa_start'] if 'swa_start' in config['trainer'] else config['trainer']['weight_averaging_start']
            #self.swa_c_iters = config['trainer']['swa_c_iters'] if 'swa_c_iters' in config['trainer'] else config['trainer']['weight_averaging_c_iters']
            self.swa_avg_every = config['trainer']['swa_avg_every'] if 'swa_avg_every' in config['trainer'] else 0
            assert(self.val_step>=self.swa_avg_every) # otherwise we would evaluate more often than the SWA model is updated



        # Learning-rate schedules: each option below is implemented as a LambdaLR
        # multiplier applied to the base learning rate in config['optimizer'].
        self.useLearningSchedule = config['trainer']['use_learning_schedule'] if 'use_learning_schedule' in config['trainer'] else False
        if self.useLearningSchedule=='LR_test':
            start_lr=0.000001
            slope = (1-start_lr)/self.iterations
            lr_lambda = lambda step_num: start_lr + slope*step_num
            self.lr_schedule = torch.optim.lr_scheduler.LambdaLR(self.optimizer,lr_lambda)
        elif self.useLearningSchedule=='cyclic': #only decreasing
            min_lr_mul = config['trainer']['min_lr_mul'] if 'min_lr_mul' in config['trainer'] else 0.001
            cycle_size = config['trainer']['cycle_size'] if 'cycle_size' in config['trainer'] else 500
            lr_lambda = lambda step_num: (1-(1-min_lr_mul)*((step_num-1)%cycle_size)/(cycle_size-1))
            self.lr_schedule = torch.optim.lr_scheduler.LambdaLR(self.optimizer,lr_lambda)
        elif self.useLearningSchedule=='cyclic-full':
            min_lr_mul = config['trainer']['min_lr_mul'] if 'min_lr_mul' in config['trainer'] else 0.25
            cycle_size = config['trainer']['cycle_size'] if 'cycle_size' in config['trainer'] else 500
            def trueCycle (step_num):
                cycle_num = step_num//cycle_size
                if cycle_num%2==0: #even, rising
                    return ((1-min_lr_mul)*((step_num)%cycle_size)/(cycle_size-1)) + min_lr_mul
                else: #odd
                    return (1-(1-min_lr_mul)*((step_num)%cycle_size)/(cycle_size-1))
            self.lr_schedule = torch.optim.lr_scheduler.LambdaLR(self.optimizer,trueCycle)
        elif self.useLearningSchedule=='cyclic-decay':
            min_lr_mul = config['trainer']['min_lr_mul'] if 'min_lr_mul' in config['trainer'] else 0.25
            cycle_size = config['trainer']['cycle_size'] if 'cycle_size' in config['trainer'] else 500
            decay_rate = config['trainer']['decay_rate'] if 'decay_rate' in config['trainer'] else 0.99994 #saturates at about 50000 iterations
            def decayCycle (step_num):
                cycle_num = step_num//cycle_size
                decay = decay_rate**step_num
                if cycle_num%2==0: #even, rising
                    return decay*((1-min_lr_mul)*((step_num)%cycle_size)/(cycle_size-1)) + min_lr_mul
                else: #odd
                    return -decay*(1-min_lr_mul)*((step_num)%cycle_size)/(cycle_size-1) + 1-(1-min_lr_mul)*(1-decay)
            self.lr_schedule = torch.optim.lr_scheduler.LambdaLR(self.optimizer,decayCycle)
        elif self.useLearningSchedule=='1cycle':
            low_lr_mul = config['trainer']['low_lr_mul'] if 'low_lr_mul' in config['trainer'] else 0.25
            min_lr_mul = config['trainer']['min_lr_mul'] if 'min_lr_mul' in config['trainer'] else 0.0001
            cycle_size = config['trainer']['cycle_size'] if 'cycle_size' in config['trainer'] else 1000
            iters_in_trailoff = self.iterations-(2*cycle_size)
            def oneCycle (step_num):
                cycle_num = step_num//cycle_size
                if step_num<cycle_size: #rising
                    return ((1-low_lr_mul)*((step_num)%cycle_size)/(cycle_size-1)) + low_lr_mul
                elif step_num<cycle_size*2: #falling
                    return (1-(1-low_lr_mul)*((step_num)%cycle_size)/(cycle_size-1))
                else: #trail off
                    t_step_num = step_num-(2*cycle_size)
                    return low_lr_mul*(iters_in_trailoff-t_step_num)/iters_in_trailoff + min_lr_mul*t_step_num/iters_in_trailoff

            self.lr_schedule = torch.optim.lr_scheduler.LambdaLR(self.optimizer,oneCycle)
        elif self.useLearningSchedule=='detector':
            warmup_steps = config['trainer']['warmup_steps'] if 'warmup_steps' in config['trainer'] else 1000
            lr_lambda = lambda step_num: min((step_num+1)**-0.3, (step_num+1)*warmup_steps**-1.3)
            self.lr_schedule = torch.optim.lr_scheduler.LambdaLR(self.optimizer,lr_lambda)
        elif self.useLearningSchedule=='step':
            steps = config['trainer']['lr_steps']
            assert(type(steps) is list)
            def stepLR(step_num):
                mul=1
                for step in steps:
                    if step_num>=step:
                        mul*=0.1
                return mul
            self.lr_schedule = torch.optim.lr_scheduler.LambdaLR(self.optimizer,stepLR)
        elif self.useLearningSchedule=='multi_rise':
            steps = config['trainer']['warmup_steps']
            assert(type(steps) is list)
            steps=[0]+steps
            def riseLR(step_num):
                for i,step in enumerate(steps[1:]):
                    if step_num<step:
                        return (step_num-steps[i])*(0.99/(step-steps[i]))+.01
                return 1.0
            self.lr_schedule = torch.optim.lr_scheduler.LambdaLR(self.optimizer,riseLR)
        elif self.useLearningSchedule=='multi_rise then swa':
            steps = config['trainer']['warmup_steps']
            warmup_cap = 1.0
            swa_lr_mul = config['trainer']['swa_lr_mul'] if 'swa_lr_mul' in config['trainer'] else 0.001
            assert(type(steps) is list)
            steps=[0]+steps
            def riseLR(step_num):
                if step_num<self.swa_start:
                    for i,step in enumerate(steps[1:]):
                        if step_num<step:
                            return warmup_cap*((step_num-steps[i])*(0.99/(step-steps[i]))+.01)
                    return 1.0
                else:
                    return swa_lr_mul
            self.lr_schedule = torch.optim.lr_scheduler.LambdaLR(self.optimizer,riseLR)
        elif self.useLearningSchedule=='multi_rise then ramp_to_swa':
            steps = config['trainer']['warmup_steps']
            down_steps = config['trainer']['ramp_down_steps']
            warmup_cap = 1.0
            swa_lr_mul = config['trainer']['swa_lr_mul'] if 'swa_lr_mul' in config['trainer'] else 0.001
            assert(type(steps) is list)
            steps=[0]+steps
            def riseLR(step_num):
                if step_num<self.swa_start-down_steps:
                    for i,step in enumerate(steps[1:]):
                        if step_num<step:
                            return warmup_cap*((step_num-steps[i])*(0.99/(step-steps[i]))+.01)
                    return 1.0
                elif step_num<self.swa_start:
                    return 1 - (1-swa_lr_mul)*(down_steps-(self.swa_start-step_num))/down_steps
                else:
                    return swa_lr_mul
            self.lr_schedule = torch.optim.lr_scheduler.LambdaLR(self.optimizer,riseLR)
        elif self.useLearningSchedule=='multi_rise with cyclic_full then swa':
            steps = config['trainer']['warmup_steps']
            warmup_cap = config['trainer']['warmup_cap']
            min_lr_mul = config['trainer']['min_lr_mul'] if 'min_lr_mul' in config['trainer'] else 0.25
            cycle_size = config['trainer']['cycle_size']
            swa_lr_mul = config['trainer']['swa_lr_mul'] if 'swa_lr_mul' in config['trainer'] else 0.001
            def trueCycle (step_num):
                cycle_num = step_num//cycle_size
                if cycle_num%2==0: #even, rising
                    return ((1-min_lr_mul)*((step_num)%cycle_size)/(cycle_size-1)) + min_lr_mul
                else: #odd
                    return (1-(1-min_lr_mul)*((step_num)%cycle_size)/(cycle_size-1))
            assert(type(steps) is list)
            steps=[0]+steps
            def riseLR(step_num):
                if step_num<self.swa_start:
                    for i,step in enumerate(steps[1:]):
                        if step_num<step:
                            return warmup_cap*((step_num-steps[i])*(0.99/(step-steps[i]))+.01)
                    return trueCycle(step_num)
                else:
                    return swa_lr_mul
            self.lr_schedule = torch.optim.lr_scheduler.LambdaLR(self.optimizer,riseLR)
        elif self.useLearningSchedule=='spike then swa':
            warmup_steps = config['trainer']['warmup_steps'] if 'warmup_steps' in config['trainer'] else 1000
            swa_lr_mul = config['trainer']['swa_lr_mul'] if 'swa_lr_mul' in config['trainer'] else 0.1
            def spikeThenSWA(step_num):
                if step_num<self.swa_start:
                    return min((max(0.000001,step_num-(warmup_steps-3))/100)**-0.1, step_num*(1.485/warmup_steps)+.01)
                else:
                    return swa_lr_mul
            self.lr_schedule = torch.optim.lr_scheduler.LambdaLR(self.optimizer,spikeThenSWA)
        elif self.useLearningSchedule is True:
            warmup_steps = config['trainer']['warmup_steps'] if 'warmup_steps' in config['trainer'] else 1000
            #lr_lambda = lambda step_num: min((step_num+1)**-0.3, (step_num+1)*warmup_steps**-1.3)
            lr_lambda = lambda step_num: min((max(0.000001,step_num-(warmup_steps-3))/100)**-0.1, step_num*(1.485/warmup_steps)+.01)
            #y=((x-(2000-3))/100)^-0.1 and y=x*(1.485/2000)+0.01
            self.lr_schedule = torch.optim.lr_scheduler.LambdaLR(self.optimizer,lr_lambda)
        elif self.useLearningSchedule:
            print('Unrecognized learning schedule: {}'.format(self.useLearningSchedule))
            exit()
        
        self.monitor = config['trainer']['monitor']
        self.monitor_mode = config['trainer']['monitor_mode']
        #assert self.monitor_mode == 'min' or self.monitor_mode == 'max'
        self.monitor_best = math.inf if self.monitor_mode == 'min' else -math.inf
        self.retry_count = config['trainer']['retry_count'] if 'retry_count' in config['trainer'] else 1
        self.start_iteration = 1
        self.iteration=self.start_iteration
        self.checkpoint_dir = os.path.join(config['trainer']['save_dir'], self.name)
        ensure_dir(self.checkpoint_dir)
        with open(os.path.join(self.checkpoint_dir, 'config.json'), 'w') as f:
            json.dump(config, f, indent=4, sort_keys=False)
        self.iteration=999999999999999
        self.side_process=False
        self.reset_iteration = config['trainer']['reset_resume_iteration'] if 'reset_resume_iteration' in config['trainer'] else False
        if resume:
            self._resume_checkpoint(resume)

    def finishSetup(self):
        """
        things that slave processes shouldn't do
        """
        ensure_dir(self.checkpoint_dir)
        with open(os.path.join(self.checkpoint_dir, 'config.json'), 'w') as f:
            json.dump(self.config, f, indent=4, sort_keys=False)

    def train(self):
        """
        Full training logic
        """
        sumLog=defaultdict(lambda:0.0)
        sumTime=0
        #for metric in self.metrics:
        #    sumLog['avg_'+metric.__name__]=0

        for self.iteration in range(self.start_iteration, self.iterations + 1):
            if not self.logged:
                print('iteration: {}'.format(self.iteration), end='\r')

            t = timeit.default_timer()
            result=None
            lastErr=None
            if self.useLearningSchedule:
                self.lr_schedule.step()
            for attempt in range(self.retry_count):
                try:
                    result = self._train_iteration(self.iteration)
                    break
                except RuntimeError as err:
                    print(err)
                    torch.cuda.empty_cache() #this is primarily to catch rare CUDA out of memory errors
                    lastErr = err

            if result is None:
                result = self._train_iteration(self.iteration)
                #if self.retry_count>1:
                #    print('Failed all {} times!'.format(self.retry_count))
                #raise lastErr

            elapsed_time = timeit.default_timer() - t
            sumLog['sec_per_iter'] += elapsed_time
            #print('iter: '+str(elapsed_time))

            #Stochastic Weight Averaging    https://github.com/timgaripov/swa/blob/master/train.py
            if self.swa and self.iteration>=self.swa_start and (self.swa_avg_every==0 or (self.iteration-self.swa_start)%self.swa_avg_every==0):
                #swa_n = (self.iterations-self.swa_start)//self.swa_c_iters
                #moving_average(self.swa_model, self.model, 1.0 / (swa_n + 1))
                #swa_n += 1
                if self.swa_model is None:
                    self.swa_model = AveragedModel(self.model)
                self.swa_model.update_parameters(self.model)

            if self.side_process:
                continue # when multiprocessing, logging and validation are handled only by the master process


            for key, value in result.items():
                if key == 'metrics':
                    for i, metric in enumerate(self.metrics):
                        sumLog['avg_'+metric.__name__] += result['metrics'][i]
                else:
                    sumLog['avg_'+key] += value
            
            #log prep
            if (    self.iteration%self.log_step==0 or 
                    self.iteration%self.val_step==0 or 
                    self.iteration % self.save_step == 0 or 
                    (self.save_step_minor is not None and self.iteration % self.save_step_minor==0)
                ):
                log = {'iteration': self.iteration}

                for key, value in result.items():
                    if key == 'metrics':
                        for i, metric in enumerate(self.metrics):
                            log[metric.__name__] = result['metrics'][i]
                    else:
                        log[key] = value

            #LOG
            if self.iteration%self.log_step==0:
                #print() #clear inplace text
                print('                   ', end='\r')
                if self.iteration-self.start_iteration>=self.log_step: #skip avg if started in odd spot
                    for key in sumLog:
                        sumLog[key] /= self.log_step
                    #self._minor_log(sumLog)
                    log = {**log, **sumLog}
                self._minor_log(log)
                for key in sumLog:
                    sumLog[key] =0.0
                if self.iteration%self.val_step!=0: #we'll do it later if we have a validation pass
                    self.train_logger.add_entry(log)

            #VALIDATION
            if self.iteration%self.val_step==0:
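                # during the SWA phase, temporarily swap in the averaged model and validate it;
                # otherwise validate the current weights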
                if self.swa and self.iteration>=self.swa_start:
                    temp_model = self.model.cpu()
                    self.model = self.swa_model
                    self.bn_update()
                    val_result = self._valid_epoch()
                    self.model = temp_model.cuda()
                    for key, value in val_result.items():
                        if 'metrics' in key:
                            for i, metric in enumerate(self.metrics):
                                log['swa_val_' + metric.__name__] = val_result[key][i]
                        else:
                            log['swa_'+key] = value
                else:
                    val_result = self._valid_epoch()
                    for key, value in val_result.items():
                        if 'metrics' in key:
                            for i, metric in enumerate(self.metrics):
                                log['val_' + metric.__name__] = val_result[key][i]
                        else:
                            log[key] = value
                            #sumLog['avg_'+key] += value

                if self.train_logger is not None:
                    if self.iteration%self.log_step!=0:
                        print('                   ', end='\r')
                    #    print()#clear inplace text
                    self.train_logger.add_entry(log)
                    if self.verbosity >= 1:
                        for key, value in log.items():
                            if self.verbosity>=2 or 'avg' in key or 'val' in key:
                                self.logger.info('    {:15s}: {}'.format(str(key), value))
                if (self.monitor_mode == 'min' and self.monitor in log and log[self.monitor] < self.monitor_best)\
                        or (self.monitor_mode == 'max' and self.monitor in log and log[self.monitor] > self.monitor_best):
                    self.monitor_best = log[self.monitor]
                    self._save_checkpoint(self.iteration, log, save_best=True)

            #SAVE
            if self.iteration % self.save_step == 0:
                self._save_checkpoint(self.iteration, log)
                if self.iteration%self.log_step!=0:
                    print('                   ', end='\r')
                #    print()#clear inplace text
                self.logger.info('Checkpoint saved for iteration '+str(self.iteration))
            elif self.save_step_minor is not None and self.iteration % self.save_step_minor == 0:
                self._save_checkpoint(self.iteration, log, minor=True)
                if self.iteration%self.log_step!=0:
                    print('                   ', end='\r')
                #    print()#clear inplace text
                #self.logger.info('Minor checkpoint saved for iteration '+str(self.iteration))

            

    def _train_iteration(self, iteration):
        """
        Training logic for a single iteration

        :param iteration: Current iteration number
        """
        raise NotImplementedError

    def save(self):
        self._save_checkpoint(self.iteration, None)

    def _save_checkpoint(self, iteration, log, save_best=False, minor=False):
        """
        Saving checkpoints

        :param iteration: current iteration number
        :param log: logging information of the iteration
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'iteration': iteration,
            'logger': self.train_logger,
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.monitor_best,
            'config': self.config
        }
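        # move saved tensors to the CPU so the checkpoint can be loaded on a machine without a GPU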
        if 'save_mode' not in self.config or self.config['save_mode']=='state_dict':
            state_dict = self.model.state_dict()
            for k,v in state_dict.items():
                state_dict[k]=v.cpu()
            state['state_dict']= state_dict
            if self.swa and self.swa_model is not None:
                swa_state_dict = self.swa_model.state_dict()
                for k,v in swa_state_dict.items():
                    swa_state_dict[k]=v.cpu()
                state['swa_state_dict']= swa_state_dict
        else:
            state['model'] = self.model.cpu()
            if self.swa:
                state['swa_model'] = self.swa_model.cpu()
        if self.useLearningSchedule:
            state['lr_schedule'] = self.lr_schedule.state_dict()
        #if self.swa:
        #    state['swa_n']=self.swa_n
        torch.cuda.empty_cache() #weird gpu memory issue when calling torch.save()
        if not minor:
            filename = os.path.join(self.checkpoint_dir, 'checkpoint-iteration{}.pth'
                                    .format(iteration))
        else:
            filename = os.path.join(self.checkpoint_dir, 'checkpoint-latest.pth')
                            
        #print(self.module.state_dict().keys())
        torch.save(state, filename)
        if not minor:
            #remove minor as this is the latest
            filename_late = os.path.join(self.checkpoint_dir, 'checkpoint-latest.pth')
            try:
                os.remove(filename_late)
            except FileNotFoundError:
                pass
            #os.link(filename,filename_late) #this way checkpoint-latest always does have the latest
            torch.save(state, filename_late) #something is wrong with the linking

        if save_best:
            os.rename(filename, os.path.join(self.checkpoint_dir, 'model_best.pth'))
            self.logger.info("Saved current best: {} ...".format('model_best.pth'))
        else:
            self.logger.info("Saved checkpoint: {} ...".format(filename))


        ######DEBUG
        #checkpoint = torch.load(filename)
        #model_dict=self.model.state_dict()
        #for name in checkpoint['state_dict']:
            #if (checkpoint['state_dict'][name]!=model_dict[name]).any():
                #        print('state not equal at: '+name)
        #        import pdb; pdb.set_trace()

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path, map_location=lambda storage, location: storage)
        if 'override' not in self.config or not self.config['override']:
            self.config = checkpoint['config']
        if not self.reset_iteration:
            self.start_iteration = checkpoint['iteration'] + 1
            self.iteration=self.start_iteration
        self.monitor_best = checkpoint['monitor_best']
        #print(checkpoint['state_dict'].keys())
        did_brain_surgery=False
        if ('save_mode' not in self.config or self.config['save_mode']=='state_dict') and 'state_dict' in checkpoint:
            #Brain surgery, allow restarting with modified model
            keys=checkpoint['state_dict'].keys()
            init_state_dict = self.model.state_dict()
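            # if a tensor in the current model is larger than its checkpointed counterpart
            # (e.g. an enlarged output layer), copy the old values into its leading slice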
            for key in keys:
                if len(init_state_dict[key].size())>0 and init_state_dict[key].size(0)>checkpoint['state_dict'][key].size(0):
                    orig_size = checkpoint['state_dict'][key].size(0)
                    init_state_dict[key][:orig_size] = checkpoint['state_dict'][key]
                    checkpoint['state_dict'][key] = init_state_dict[key]
                    self.logger.info('BRAIN SURGERY PERFORMED on {}'.format(key))
                    did_brain_surgery=True
            self.model.load_state_dict(checkpoint['state_dict'])
            if self.swa and 'swa_state_dict' in checkpoint:
                self.swa_model = AveragedModel(self.model)
                keys=checkpoint['swa_state_dict'].keys()
                init_state_dict = self.swa_model.state_dict()
                for key in keys:
                    if torch.is_tensor(init_state_dict[key]) and len(init_state_dict[key].size())>0 and init_state_dict[key].size(0)>checkpoint['swa_state_dict'][key].size(0):
                        orig_size = checkpoint['swa_state_dict'][key].size(0)
                        init_state_dict[key][:orig_size] = checkpoint['swa_state_dict'][key]
                        checkpoint['swa_state_dict'][key] = init_state_dict[key]
                        self.logger.info('BRAIN SURGERY PERFORMED on {}'.format(key))
                self.swa_model.load_state_dict(checkpoint['swa_state_dict'])
        else:
            self.model = checkpoint['model']
            if self.swa:
                self.swa_model = checkpoint['swa_model']
        #if self.swa:
        #    self.swa_n = checkpoint['swa_n']
        dont_load_optimizer = self.config['dont_load_optimizer'] if 'dont_load_optimizer' in self.config else False
        if not did_brain_surgery and not dont_load_optimizer and 'optimizer' in checkpoint:
            try:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                if self.with_cuda:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor):
                                state[k] = v.cuda(self.gpu)
            except ValueError as e:
                print('WARNING did not load optimizer state_dict. {}'.format(e))
        else:
            print('Did not load optimizer')
        if self.useLearningSchedule:
            self.lr_schedule.load_state_dict(checkpoint['lr_schedule'])
        self.train_logger = checkpoint['logger']
        self.logger.info("Checkpoint '{}' (iteration {}) loaded".format(resume_path, self.start_iteration))

    def update_swa_batch_norm(self):
        #update_bn(self.data_loader,self.swa_model)
        # Re-estimate BatchNorm statistics for the averaged weights by running forward
        # passes over the training data; self.data_loader and self.run are expected to
        # be provided by the subclass.
        tmp=self.model.cpu()
        self.model=self.swa_model.train()
        for instance in self.data_loader:
            self.run(instance)
        self.model=tmp.to(self.gpu) if self.gpu is not None else tmp
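
As a rough illustration of the contract that train() expects, a hypothetical subclass (not part of this repository) might implement _train_iteration and _valid_epoch along these lines; the data_loader/valid_data_loader arguments, the _minor_log body, and the metric call signature are assumptions.

import torch

class ToyTrainer(BaseTrainer):
    """Hypothetical minimal subclass; data_loader and valid_data_loader are assumed arguments."""
    def __init__(self, model, loss, metrics, resume, config,
                 data_loader, valid_data_loader=None, train_logger=None):
        super().__init__(model, loss, metrics, resume, config, train_logger)
        self.data_loader = data_loader
        self.valid_data_loader = valid_data_loader
        self.data_iter = iter(self.data_loader)

    def _train_iteration(self, iteration):
        # one optimizer step on the next batch; restart the loader when exhausted
        try:
            data, target = next(self.data_iter)
        except StopIteration:
            self.data_iter = iter(self.data_loader)
            data, target = next(self.data_iter)
        if self.with_cuda:
            data, target = data.to(self.gpu), target.to(self.gpu)
        self.optimizer.zero_grad()
        output = self.model(data)
        loss = self.loss(output, target)
        loss.backward()
        self.optimizer.step()
        # train() expects 'metrics' to be a list aligned with self.metrics and
        # all other entries to be scalars it can average into the running log
        return {'loss': loss.item(),
                'metrics': [float(m(output, target)) for m in self.metrics]}

    def _valid_epoch(self):
        self.model.eval()
        total_loss, metric_sums = 0.0, [0.0] * len(self.metrics)
        with torch.no_grad():
            for data, target in self.valid_data_loader:
                if self.with_cuda:
                    data, target = data.to(self.gpu), target.to(self.gpu)
                output = self.model(data)
                total_loss += self.loss(output, target).item()
                for i, m in enumerate(self.metrics):
                    metric_sums[i] += float(m(output, target))
        self.model.train()
        n = len(self.valid_data_loader)
        return {'val_loss': total_loss / n,
                'val_metrics': [s / n for s in metric_sums]}

    def _minor_log(self, log):
        # lightweight periodic log; BaseTrainer.train() calls this every log_step iterations
        self.logger.info(str(log))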