import itertools
import math
import os
import time
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import DataParallel

# Project-local helpers referenced below (get_model, num_class, PyTorchWD,
# IMAGE_SHAPE, TRANSFORMATION, C, get_dataloaders, run_epoch,
# adjust_learning_rate_resnet, GradualWarmupScheduler, logger) are assumed to
# be imported from elsewhere in this repository.


def load_model(file, model_configs, trans_configs=None, gpu_device=None):
    # Resolve the compute device: honor an explicit "cuda:N" request,
    # otherwise fall back to "cuda:0" when CUDA is available, else CPU.
    if torch.cuda.is_available():
        if not (gpu_device and gpu_device.startswith("cuda:")):
            gpu_device = "cuda:0"
        device = torch.device(gpu_device)
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

    if trans_configs is None:
        # Load the undefended model by default.
        trans_configs = {
            "type": "clean",
            "subtype": "",
            "id": 0,
            "description": "clean"
        }

    network_configs = model_configs.get("network_configs")
    model = get_model(
        model_type=network_configs.get("model_type"),
        num_class=num_class(model_configs.get("dataset")),
        data_parallel=torch.cuda.is_available(),
        device=device
    )
    loss_func = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=network_configs.get("lr"),
        momentum=network_configs.get("optimizer_momentum"),
        weight_decay=network_configs.get("optimizer_decay"),
        nesterov=network_configs.get("optimizer_nesterov")
    )

    if os.path.isfile(file):
        print(">>> Loading model from [{}]...".format(file))
        data = torch.load(file, map_location=lambda storage, loc: storage)
        data_key = list(data.keys())[0]
        print(">>> DATA_KEY: {}".format(data_key))

        if 'model' in data or 'state_dict' in data:
            # Full checkpoint: weights live under 'model' or 'state_dict'.
            key = 'model' if 'model' in data else 'state_dict'
            # Detect DataParallel-style prefixes on the state dict's own keys
            # (not on the checkpoint's top-level keys).
            first_key = next(iter(data[key]))
            if not isinstance(model, DataParallel):
                if first_key.startswith('classifier.module.'):
                    model.load_state_dict({k.replace('classifier.module.', ''): v
                                           for k, v in data[key].items()})
                else:
                    # Handles both 'module.'-prefixed and plain keys.
                    model.load_state_dict({k.replace('module.', ''): v
                                           for k, v in data[key].items()})
            else:
                # DataParallel model: make sure every key carries the prefix.
                model.load_state_dict({k if 'module.' in k else 'module.' + k: v
                                       for k, v in data[key].items()})
            optimizer.load_state_dict(data['optimizer'])
        elif data_key.startswith('module.'):
            # Raw DataParallel state dict: strip the prefix.
            model.load_state_dict({k.replace('module.', ''): v
                                   for k, v in data.items()})
            # A raw state dict normally carries no optimizer state; guard the load.
            if 'optimizer' in data:
                optimizer.load_state_dict(data['optimizer'])
        else:
            # Raw plain state dict.
            model.load_state_dict(data)
        del data
    else:
        raise ValueError('{} is not found.'.format(file))

    model.eval()

    if model_configs.get("wrap", True):
        # Wrap the network in the PyTorchWD wrapper.
        print("Wrap model")
        model = PyTorchWD(
            model=model,
            loss=loss_func,
            optimizer=optimizer,
            input_shape=IMAGE_SHAPE,
            nb_classes=num_class(model_configs.get("dataset")),
            trans_configs=trans_configs,
            channel_index=1,
            clip_values=(0., 1.)
        )
    return model, loss_func, optimizer
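
# A minimal usage sketch for load_model (hypothetical values throughout): the
# checkpoint path, the "resnet18" model type, and the optimizer hyperparameters
# below are illustrative assumptions; only the config keys mirror the lookups
# performed by load_model above.
def demo_load_model():
    demo_model_configs = {
        "dataset": "cifar10",
        "wrap": True,  # wrap the network in PyTorchWD after loading
        "network_configs": {
            "model_type": "resnet18",
            "lr": 0.1,
            "optimizer_momentum": 0.9,
            "optimizer_decay": 5e-4,
            "optimizer_nesterov": True,
        },
    }
    model, loss_func, optimizer = load_model(
        file="checkpoints/demo-cifar10.pth",  # hypothetical checkpoint path
        model_configs=demo_model_configs,
        gpu_device="cuda:0",
    )
    return model, loss_func, optimizer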
def train_and_eval(tag, dataroot, trans_type=TRANSFORMATION.clean,
                   test_ratio=0.0, cv_fold=0, reporter=None, metric='last',
                   save_path=None, only_eval=False):
    print('----------------------------')
    print('Arguments for model training')
    print('>>> tag:', tag)
    print('>>> dataroot:', dataroot)
    print('>>> save_path:', save_path)
    print('>>> eval:', only_eval)
    print('----------------------------')

    if not reporter:
        reporter = lambda **kwargs: 0

    max_epoch = C.get()['epoch']

    # Build the dataloaders and time how long the transformation takes.
    start = time.monotonic()
    trainsampler, trainloader, validloader, testloader_ = get_dataloaders(
        C.get()['dataset'], C.get()['batch'], dataroot,
        trans_type=trans_type, split=test_ratio, split_idx=cv_fold)
    trans_cost = time.monotonic() - start
    print('Cost for transformation (min):', round(trans_cost / 60., 6))

    # Create a model & an optimizer.
    model = get_model(C.get()['model'], num_class(C.get()['dataset']),
                      data_parallel=True)
    criterion = nn.CrossEntropyLoss()
    if C.get()['optimizer']['type'] == 'sgd':
        optimizer = optim.SGD(
            model.parameters(),
            lr=C.get()['lr'],
            momentum=C.get()['optimizer'].get('momentum', 0.9),
            weight_decay=C.get()['optimizer']['decay'],
            nesterov=C.get()['optimizer']['nesterov'])
    else:
        raise ValueError(
            'Optimizer type [{}] is not supported; SGD is the only supported optimizer.'
            .format(C.get()['optimizer']['type']))

    is_master = True
    logger.debug('is_master={}'.format(is_master))

    lr_scheduler_type = C.get()['lr_schedule'].get('type', 'cosine')
    if lr_scheduler_type == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=C.get()['epoch'], eta_min=0.)
    elif lr_scheduler_type == 'resnet':
        scheduler = adjust_learning_rate_resnet(optimizer)
    else:
        raise ValueError('invalid lr_scheduler={}'.format(lr_scheduler_type))

    # An optional warmup phase wraps the base scheduler.
    if C.get()['lr_schedule'].get('warmup', None):
        scheduler = GradualWarmupScheduler(
            optimizer,
            multiplier=C.get()['lr_schedule']['warmup']['multiplier'],
            total_epoch=C.get()['lr_schedule']['warmup']['epoch'],
            after_scheduler=scheduler)

    if not tag or not is_master:
        from models.utils.estimator import SummaryWriterDummy as SummaryWriter
        logger.warning('tag not provided, no tensorboard log.')
    else:
        from tensorboardX import SummaryWriter
    writers = [
        SummaryWriter(log_dir='./logs/{}/{}'.format(tag, x))
        for x in ['train', 'valid', 'test']
    ]

    result = OrderedDict()
    epoch_start = 1
    if save_path and os.path.exists(save_path):
        logger.info('Found file [{}]. Loading...'.format(save_path))
        data = torch.load(save_path)
        if 'model' in data or 'state_dict' in data:
            key = 'model' if 'model' in data else 'state_dict'
            logger.info('checkpoint epoch@{}'.format(data['epoch']))
            if not isinstance(model, DataParallel):
                # Strip the DataParallel prefix before loading into a plain model.
                model.load_state_dict({
                    k.replace('module.', ''): v
                    for k, v in data[key].items()
                })
            else:
                # Add the DataParallel prefix where it is missing.
                model.load_state_dict({
                    k if 'module.' in k else 'module.' + k: v
                    for k, v in data[key].items()
                })
            optimizer.load_state_dict(data['optimizer'])
            if data['epoch'] < C.get()['epoch']:
                epoch_start = data['epoch']
            else:
                # Checkpoint already covers all epochs; nothing left to train.
                only_eval = True
        else:
            model.load_state_dict({k: v for k, v in data.items()})
        del data
    else:
        logger.info('[{}] file not found. Skip to pretrain weights...'.format(
            save_path))
        if only_eval:
            logger.warning(
                'model checkpoint not found. only-evaluation mode is off.')
            only_eval = False

    if only_eval:
        logger.info('evaluation only+')
        model.eval()
        rs = dict()
        rs['train'] = run_epoch(model, trainloader, criterion, None,
                                desc_default='train', epoch=0, writer=writers[0])
        rs['valid'] = run_epoch(model, validloader, criterion, None,
                                desc_default='valid', epoch=0, writer=writers[1])
        rs['test'] = run_epoch(model, testloader_, criterion, None,
                               desc_default='*test', epoch=0, writer=writers[2])
        for key, setname in itertools.product(['loss', 'top1', 'top5'],
                                              ['train', 'valid', 'test']):
            if setname not in rs:
                continue
            result['{}_{}'.format(key, setname)] = rs[setname][key]
        result['epoch'] = 0
        return result

    # Train loop.
    best_top1 = 0
    for epoch in range(epoch_start, max_epoch + 1):
        model.train()
        rs = dict()
        rs['train'] = run_epoch(model, trainloader, criterion, optimizer,
                                desc_default='train', epoch=epoch,
                                writer=writers[0], verbose=is_master,
                                scheduler=scheduler)
        model.eval()

        if math.isnan(rs['train']['loss']):
            raise Exception('train loss is NaN.')

        # Evaluate (and possibly checkpoint) every 5 epochs and at the end.
        if epoch % 5 == 0 or epoch == max_epoch:
            rs['valid'] = run_epoch(model, validloader, criterion, None,
                                    desc_default='valid', epoch=epoch,
                                    writer=writers[1], verbose=is_master)
            rs['test'] = run_epoch(model, testloader_, criterion, None,
                                   desc_default='*test', epoch=epoch,
                                   writer=writers[2], verbose=is_master)

            if metric == 'last' or rs[metric]['top1'] > best_top1:
                if metric != 'last':
                    best_top1 = rs[metric]['top1']
                for key, setname in itertools.product(
                        ['loss', 'top1', 'top5'], ['train', 'valid', 'test']):
                    result['{}_{}'.format(key, setname)] = rs[setname][key]
                result['epoch'] = epoch
                writers[1].add_scalar('valid_top1/best', rs['valid']['top1'], epoch)
                writers[2].add_scalar('test_top1/best', rs['test']['top1'], epoch)
                reporter(loss_valid=rs['valid']['loss'],
                         top1_valid=rs['valid']['top1'],
                         loss_test=rs['test']['loss'],
                         top1_test=rs['test']['top1'])

                # Save the checkpoint twice: once at save_path (latest best)
                # and once under an epoch/accuracy-tagged filename.
                if is_master and save_path:
                    logger.info('save model@%d to %s' % (epoch, save_path))
                    checkpoint = {
                        'epoch': epoch,
                        'log': {
                            'train': rs['train'].get_dict(),
                            'valid': rs['valid'].get_dict(),
                            'test': rs['test'].get_dict(),
                        },
                        'optimizer': optimizer.state_dict(),
                        'model': model.state_dict()
                    }
                    torch.save(checkpoint, save_path)
                    torch.save(checkpoint, save_path.replace(
                        '.pth',
                        '_e%d_top1_%.3f_%.3f' % (epoch, rs['train']['top1'],
                                                 rs['test']['top1']) + '.pth'))

    del model
    torch.cuda.empty_cache()
    # Note: best_top1 stays 0 when metric == 'last' (no best-model tracking).
    result['top1_test'] = best_top1
    result['trans_cost'] = trans_cost
    return result
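
# A minimal usage sketch for train_and_eval (hypothetical): it assumes the
# global config C has already been loaded elsewhere (e.g. from a recipe file)
# with 'epoch', 'batch', 'dataset', 'model', 'lr', 'optimizer', and
# 'lr_schedule' entries; the tag, dataroot, and save_path values are made up.
def demo_train():
    result = train_and_eval(
        tag='demo-run',                        # tensorboard logs under ./logs/demo-run/*
        dataroot='./data',                     # dataset root for get_dataloaders
        trans_type=TRANSFORMATION.clean,       # train on untransformed data
        test_ratio=0.0,                        # no extra train/valid split
        metric='test',                         # track the best test top-1
        save_path='checkpoints/demo-run.pth',  # hypothetical checkpoint path
        only_eval=False)
    print('best top1:', result['top1_test'])
    return result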