def test_model():
    """Evaluates the model."""
    # Build the model (before the loaders to speed up debugging)
    model = model_builder.build_model()
    log_model_info(model)
    # Compute precise time
    if cfg.PREC_TIME.ENABLED:
        logger.info("Computing precise time...")
        loss_fun = losses.get_loss_fun()
        bu.compute_precise_time(model, loss_fun)
        nu.reset_bn_stats(model)
    # Load model weights
    cu.load_checkpoint(cfg.TEST.WEIGHTS, model)
    logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS))
    # Create data loaders
    test_loader = loader.construct_test_loader()
    # Create meters
    test_meter = TestMeter(len(test_loader))
    # Evaluate the model
    test_epoch(test_loader, model, test_meter, 0)
def train_model():
    """Trains the model."""
    # Build the model (before the loaders to speed up debugging)
    model = model_builder.build_model()
    log_model_info(model)
    # Define the loss function
    loss_fun = losses.get_loss_fun()
    # Construct the optimizer
    optimizer = optim.construct_optimizer(model)
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and cu.has_checkpoint():
        last_checkpoint = cu.get_last_checkpoint()
        checkpoint_epoch = cu.load_checkpoint(last_checkpoint, model, optimizer)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        cu.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
        logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS))
    # Compute precise time
    if start_epoch == 0 and cfg.PREC_TIME.ENABLED:
        logger.info("Computing precise time...")
        bu.compute_precise_time(model, loss_fun)
        nu.reset_bn_stats(model)
    # Create data loaders
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()
    # Create meters
    train_meter = TrainMeter(len(train_loader))
    test_meter = TestMeter(len(test_loader))
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch)
        # Compute precise BN stats
        if cfg.BN.USE_PRECISE_STATS:
            nu.compute_precise_bn_stats(model, train_loader)
        # Save a checkpoint
        if cu.is_checkpoint_epoch(cur_epoch):
            checkpoint_file = cu.save_checkpoint(model, optimizer, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        # Evaluate the model
        if is_eval_epoch(cur_epoch):
            test_epoch(test_loader, model, test_meter, cur_epoch)
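# The training loop above calls is_eval_epoch() to decide when to run
# test_epoch(). A minimal sketch of such a helper is shown below; it assumes a
# cfg.TRAIN.EVAL_PERIOD setting and may differ from the actual pycls
# implementation.
def is_eval_epoch(cur_epoch):
    """Determine whether to evaluate the model at the current epoch (sketch)."""
    return (
        (cur_epoch + 1) % cfg.TRAIN.EVAL_PERIOD == 0
        or (cur_epoch + 1) == cfg.OPTIM.MAX_EPOCH
    )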
def main():
    # load pretrained model
    checkpoint = torch.load(args.checkpoint_path)
    try:
        model_arch = checkpoint['model_name']
        patch_size = checkpoint['patch_size']
        prime_size = checkpoint['patch_size']
        flops = checkpoint['flops']
        model_flops = checkpoint['model_flops']
        policy_flops = checkpoint['policy_flops']
        fc_flops = checkpoint['fc_flops']
        anytime_classification = checkpoint['anytime_classification']
        budgeted_batch_classification = checkpoint['budgeted_batch_classification']
        dynamic_threshold = checkpoint['dynamic_threshold']
        maximum_length = len(checkpoint['flops'])
    except KeyError:
        print('Error: \n'
              'Please provide essential information '
              'for customized models (as we have done '
              'in pre-trained models)!\n'
              'At least the following information should be given: \n'
              '--model_name: name of the backbone CNNs (e.g., resnet50, densenet121)\n'
              '--patch_size: size of image patches (i.e., H\' or W\' in the paper)\n'
              '--flops: a list containing the Multiply-Adds corresponding to each '
              'length of the input sequence during inference')
        return  # abort: required metadata is missing from the checkpoint

    model_configuration = model_configurations[model_arch]

    if args.eval_mode > 0:
        # create model
        if 'resnet' in model_arch:
            model = resnet.resnet50(pretrained=False)
            model_prime = resnet.resnet50(pretrained=False)
        elif 'densenet' in model_arch:
            model = eval('densenet.' + model_arch)(pretrained=False)
            model_prime = eval('densenet.' + model_arch)(pretrained=False)
        elif 'efficientnet' in model_arch:
            model = create_model(model_arch, pretrained=False, num_classes=1000,
                                 drop_rate=0.3, drop_connect_rate=0.2)
            model_prime = create_model(model_arch, pretrained=False, num_classes=1000,
                                       drop_rate=0.3, drop_connect_rate=0.2)
        elif 'mobilenetv3' in model_arch:
            model = create_model(model_arch, pretrained=False, num_classes=1000,
                                 drop_rate=0.2, drop_connect_rate=0.2)
            model_prime = create_model(model_arch, pretrained=False, num_classes=1000,
                                       drop_rate=0.2, drop_connect_rate=0.2)
        elif 'regnet' in model_arch:
            import pycls.core.model_builder as model_builder
            from pycls.core.config import cfg
            cfg.merge_from_file(model_configuration['cfg_file'])
            cfg.freeze()
            model = model_builder.build_model()
            model_prime = model_builder.build_model()

        traindir = args.data_url + 'train/'
        valdir = args.data_url + 'val/'
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_set = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(
                    model_configuration['image_size'],
                    interpolation=model_configuration['dataset_interpolation']),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize]))
        train_set_index = torch.randperm(len(train_set))
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=256, num_workers=32, pin_memory=False,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                train_set_index[-200000:]))
        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(
                valdir,
                transforms.Compose([
                    transforms.Resize(
                        int(model_configuration['image_size'] /
                            model_configuration['crop_pct']),
                        interpolation=model_configuration['dataset_interpolation']),
                    transforms.CenterCrop(model_configuration['image_size']),
                    transforms.ToTensor(),
                    normalize])),
            batch_size=256, shuffle=False, num_workers=16, pin_memory=False)

        state_dim = model_configuration['feature_map_channels'] * math.ceil(
            patch_size / 32) * math.ceil(patch_size / 32)

        memory = Memory()
        policy = ActorCritic(model_configuration['feature_map_channels'], state_dim,
                             model_configuration['policy_hidden_dim'],
                             model_configuration['policy_conv'])
        fc = Full_layer(model_configuration['feature_num'],
                        model_configuration['fc_hidden_dim'],
                        model_configuration['fc_rnn'])

        model = nn.DataParallel(model.cuda())
        model_prime = nn.DataParallel(model_prime.cuda())
        policy = policy.cuda()
        fc = fc.cuda()

        model.load_state_dict(checkpoint['model_state_dict'])
        model_prime.load_state_dict(checkpoint['model_prime_state_dict'])
        fc.load_state_dict(checkpoint['fc'])
        policy.load_state_dict(checkpoint['policy'])

        budgeted_batch_flops_list = []
        budgeted_batch_acc_list = []

        print('generate logits on test samples...')
        test_logits, test_targets, anytime_classification = generate_logits(
            model_prime, model, fc, memory, policy, val_loader,
            maximum_length, prime_size, patch_size, model_arch)

        if args.eval_mode == 2:
            print('generate logits on training samples...')
            dynamic_threshold = torch.zeros([39, maximum_length])
            train_logits, train_targets, _ = generate_logits(
                model_prime, model, fc, memory, policy, train_loader,
                maximum_length, prime_size, patch_size, model_arch)

        for p in range(1, 40):
            print('inference: {}/40'.format(p))
            _p = torch.FloatTensor(1).fill_(p * 1.0 / 20)
            # torch.arange(1, maximum_length + 1) replaces the deprecated
            # torch.range(1, maximum_length); both yield 1, 2, ..., maximum_length
            probs = torch.exp(torch.log(_p) * torch.arange(1, maximum_length + 1))
            probs /= probs.sum()

            if args.eval_mode == 2:
                dynamic_threshold[p - 1] = dynamic_find_threshold(
                    train_logits, train_targets, probs)

            acc_step, flops_step = dynamic_evaluate(
                test_logits, test_targets, flops, dynamic_threshold[p - 1])

            budgeted_batch_acc_list.append(acc_step)
            budgeted_batch_flops_list.append(flops_step)

        budgeted_batch_classification = [budgeted_batch_flops_list,
                                         budgeted_batch_acc_list]

    print('model_arch :', model_arch)
    print('patch_size :', patch_size)
    print('flops :', flops)
    print('model_flops :', model_flops)
    print('policy_flops :', policy_flops)
    print('fc_flops :', fc_flops)
    print('anytime_classification :', anytime_classification)
    print('budgeted_batch_classification :', budgeted_batch_classification)
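# The budgeted-batch loop above sweeps 39 computational budgets: for each p it
# builds a geometric distribution probs[k] proportional to _p**(k+1) over exit
# lengths 1..maximum_length, which dynamic_find_threshold() uses to allocate
# samples across early exits. A minimal, self-contained sketch of that
# distribution follows; the helper name budget_distribution is ours, not from
# the repository.
import torch

def budget_distribution(p, maximum_length):
    """Return the normalized geometric exit distribution used above (sketch)."""
    _p = torch.FloatTensor(1).fill_(p * 1.0 / 20)
    probs = torch.exp(torch.log(_p) * torch.arange(1, maximum_length + 1))
    return probs / probs.sum()

# Example: p = 10 (i.e. _p = 0.5) with 5 exits assigns roughly 51.6%, 25.8%,
# 12.9%, 6.5%, and 3.2% of the samples to exits 1..5:
# print(budget_distribution(10, 5))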
def main():
    if not os.path.isdir(args.work_dirs):
        mkdir_p(args.work_dirs)

    record_path = args.work_dirs + '/GF-' + str(args.model_arch) \
                  + '_patch-size-' + str(args.patch_size) \
                  + '_T' + str(args.T) \
                  + '_train-stage' + str(args.train_stage)
    if not os.path.isdir(record_path):
        mkdir_p(record_path)
    record_file = record_path + '/record.txt'

    # *create model*
    model_configuration = model_configurations[args.model_arch]
    if 'resnet' in args.model_arch:
        model_arch = 'resnet'
        model = resnet.resnet50(pretrained=False)
        model_prime = resnet.resnet50(pretrained=False)
    elif 'densenet' in args.model_arch:
        model_arch = 'densenet'
        model = eval('densenet.' + args.model_arch)(pretrained=False)
        model_prime = eval('densenet.' + args.model_arch)(pretrained=False)
    elif 'efficientnet' in args.model_arch:
        model_arch = 'efficientnet'
        model = create_model(args.model_arch, pretrained=False, num_classes=1000,
                             drop_rate=0.3, drop_connect_rate=0.2)
        model_prime = create_model(args.model_arch, pretrained=False, num_classes=1000,
                                   drop_rate=0.3, drop_connect_rate=0.2)
    elif 'mobilenetv3' in args.model_arch:
        model_arch = 'mobilenetv3'
        model = create_model(args.model_arch, pretrained=False, num_classes=1000,
                             drop_rate=0.2, drop_connect_rate=0.2)
        model_prime = create_model(args.model_arch, pretrained=False, num_classes=1000,
                                   drop_rate=0.2, drop_connect_rate=0.2)
    elif 'regnet' in args.model_arch:
        model_arch = 'regnet'
        import pycls.core.model_builder as model_builder
        from pycls.core.config import cfg
        cfg.merge_from_file(model_configuration['cfg_file'])
        cfg.freeze()
        model = model_builder.build_model()
        model_prime = model_builder.build_model()

    fc = Full_layer(model_configuration['feature_num'],
                    model_configuration['fc_hidden_dim'],
                    model_configuration['fc_rnn'])

    if args.train_stage == 1:
        model.load_state_dict(torch.load(args.model_path))
        model_prime.load_state_dict(torch.load(args.model_prime_path))
    else:
        checkpoint = torch.load(args.checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        model_prime.load_state_dict(checkpoint['model_prime_state_dict'])
        fc.load_state_dict(checkpoint['fc'])

    train_configuration = train_configurations[model_arch]

    if args.train_stage != 2:
        if train_configuration['train_model_prime']:
            optimizer = torch.optim.SGD([{'params': model.parameters()},
                                         {'params': model_prime.parameters()},
                                         {'params': fc.parameters()}],
                                        lr=0,  # specify in adjust_learning_rate()
                                        momentum=train_configuration['momentum'],
                                        nesterov=train_configuration['Nesterov'],
                                        weight_decay=train_configuration['weight_decay'])
        else:
            optimizer = torch.optim.SGD([{'params': model.parameters()},
                                         {'params': fc.parameters()}],
                                        lr=0,  # specify in adjust_learning_rate()
                                        momentum=train_configuration['momentum'],
                                        nesterov=train_configuration['Nesterov'],
                                        weight_decay=train_configuration['weight_decay'])
        training_epoch_num = train_configuration['epoch_num']
    else:
        optimizer = None
        training_epoch_num = 15
    criterion = nn.CrossEntropyLoss().cuda()

    model = nn.DataParallel(model.cuda())
    model_prime = nn.DataParallel(model_prime.cuda())
    fc = fc.cuda()

    traindir = args.data_url + 'train/'
    valdir = args.data_url + 'val/'
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_set = datasets.ImageFolder(traindir, transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))
    train_set_index = torch.randperm(len(train_set))
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=256, num_workers=32, pin_memory=False,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(train_set_index[:]))
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=train_configuration['batch_size'], shuffle=False,
        num_workers=32, pin_memory=False)

    if args.train_stage != 1:
        state_dim = model_configuration['feature_map_channels'] * math.ceil(
            args.patch_size / 32) * math.ceil(args.patch_size / 32)
        ppo = PPO(model_configuration['feature_map_channels'], state_dim,
                  model_configuration['policy_hidden_dim'],
                  model_configuration['policy_conv'])
        if args.train_stage == 3:
            ppo.policy.load_state_dict(checkpoint['policy'])
            ppo.policy_old.load_state_dict(checkpoint['policy'])
    else:
        ppo = None
    memory = Memory()

    if args.resume:
        resume_ckp = torch.load(args.resume)
        start_epoch = resume_ckp['epoch']
        print('resume from epoch: {}'.format(start_epoch))
        model.module.load_state_dict(resume_ckp['model_state_dict'])
        model_prime.module.load_state_dict(resume_ckp['model_prime_state_dict'])
        fc.load_state_dict(resume_ckp['fc'])
        if optimizer:
            optimizer.load_state_dict(resume_ckp['optimizer'])
        if ppo:
            ppo.policy.load_state_dict(resume_ckp['policy'])
            ppo.policy_old.load_state_dict(resume_ckp['policy'])
            ppo.optimizer.load_state_dict(resume_ckp['ppo_optimizer'])
        best_acc = resume_ckp['best_acc']
    else:
        start_epoch = 0
        best_acc = 0

    for epoch in range(start_epoch, training_epoch_num):
        if args.train_stage != 2:
            print('Training Stage: {}, lr:'.format(args.train_stage))
            adjust_learning_rate(optimizer, train_configuration,
                                 epoch, training_epoch_num, args)
        else:
            print('Training Stage: {}, train ppo only'.format(args.train_stage))

        train(model_prime, model, fc, memory, ppo, optimizer, train_loader,
              criterion, args.print_freq, epoch, train_configuration['batch_size'],
              record_file, train_configuration, args)

        acc = validate(model_prime, model, fc, memory, ppo, optimizer, val_loader,
                       criterion, args.print_freq, epoch,
                       train_configuration['batch_size'], record_file,
                       train_configuration, args)

        if acc > best_acc:
            best_acc = acc
            is_best = True
        else:
            is_best = False

        save_checkpoint({
            'epoch': epoch + 1,
            'model_state_dict': model.module.state_dict(),
            'model_prime_state_dict': model_prime.module.state_dict(),
            'fc': fc.state_dict(),
            'acc': acc,
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict() if optimizer else None,
            'ppo_optimizer': ppo.optimizer.state_dict() if ppo else None,
            'policy': ppo.policy.state_dict() if ppo else None,
        }, is_best, checkpoint=record_path)
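# save_checkpoint() above is assumed to follow the common PyTorch pattern of
# writing the latest training state and copying it to a "best" file whenever
# is_best is set. A minimal sketch under that assumption (the file names and
# default argument are ours, not necessarily the repository's exact code):
import os
import shutil
import torch

def save_checkpoint(state, is_best, checkpoint, filename='checkpoint.pth.tar'):
    """Save the latest training state and keep a copy of the best one (sketch)."""
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))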