# --- Standalone evaluation: rebuild a derived network from a saved search
# result and measure its ImageNet validation accuracy. Relies on
# module-level `config`, `args`, `utils`, `model_derived`, `imagenet_data`,
# `Trainer`, `comp_multadds` being in scope (defined elsewhere in the file).

# Restore the architecture description saved alongside the checkpoint.
config.net_config, net_type = utils.load_net_config(
    os.path.join(args.load_path, 'net_config'))
# Resolve the concrete network class by name, e.g. net_type "xxx" -> "XXX_Net".
derivedNetwork = getattr(model_derived, '%s_Net' % net_type.upper())
model = derivedNetwork(config.net_config, config=config, num_classes=1000)

logging.info("Network Structure: \n" + '\n'.join(map(str, model.net_config)))
logging.info("Params = %.2fMB" % utils.count_parameters_in_MB(model))
logging.info("Mult-Adds = %.2fMB" % comp_multadds(model, input_size=config.data.input_size))

model = model.cuda()
model = nn.DataParallel(model)

# NOTE(review): strict=False silently skips missing/unexpected keys when
# loading weights — confirm the checkpoint really matches the architecture.
checkpoint = torch.load(
    os.path.join(args.load_path, 'weight.pt'),
    map_location="cpu")  # weight checkpoint
model.load_state_dict(checkpoint['state_dict'], strict=False)

imagenet = imagenet_data.ImageNet12(
    trainFolder=os.path.join(args.data_path, 'train'),
    testFolder=os.path.join(args.data_path, 'val'),
    num_workers=config.data.num_workers,
    data_config=config.data)
valid_queue = imagenet.getTestLoader(config.data.batch_size)

# Only validation is performed here, so every train-side Trainer argument
# is passed as None.
trainer = Trainer(None, valid_queue, None, None, None, config, args.report_freq)
with torch.no_grad():
    val_acc_top1, val_acc_top5, valid_obj, batch_time = trainer.infer(
        model)
# --- Training setup: loss criterion, optimizer, data loaders, LR scheduler.
# Relies on module-level `model`, `config`, `args`, `start_epoch`,
# `utils`, `imagenet_data`, `get_lr_scheduler` (defined elsewhere in file).
model = model.cuda()

if config.optim.label_smooth:
    # Plain function, not an nn.Module — it must NOT receive .cuda().
    criterion = utils.cross_entropy_with_label_smoothing
else:
    # Bug fix: the original applied criterion.cuda() unconditionally, which
    # raises AttributeError when criterion is the label-smoothing function
    # above. Move the .cuda() call into the nn.Module branch only.
    criterion = nn.CrossEntropyLoss().cuda()

optimizer = torch.optim.SGD(
    model.parameters(),
    config.optim.init_lr,
    momentum=config.optim.momentum,
    weight_decay=config.optim.weight_decay)

imagenet = imagenet_data.ImageNet12(
    trainFolder=os.path.join(args.data_path, 'train'),
    testFolder=os.path.join(args.data_path, 'val'),
    num_workers=config.data.num_workers,
    type_of_data_augmentation=config.data.type_of_data_aug,
    data_config=config.data)

if config.optim.use_multi_stage:
    # Multi-stage training additionally yields a weakly-augmented train queue.
    (train_queue, week_train_queue), valid_queue = imagenet.getSetTrainTestLoader(
        config.data.batch_size)
else:
    train_queue, valid_queue = imagenet.getTrainTestLoader(
        config.data.batch_size)

# Idiom fix: len(x) instead of calling x.__len__() directly.
scheduler = get_lr_scheduler(config, optimizer, len(train_queue.dataset))
# Fast-forward the per-step scheduler when resuming from `start_epoch`
# (steps per epoch = dataset_size // batch_size + 1).
scheduler.last_step = start_epoch * (
    len(train_queue.dataset) // config.data.batch_size + 1) - 1
def main():
    """Knowledge-distillation training (teacher -> student) on ImageNet.

    Builds the data loaders, the teacher/student pair and the WSL
    distiller, optionally resumes from a checkpoint, then trains the
    student with distillation while tracking the best top-1/top-5 error
    in the module globals ``best_err1`` / ``best_err5``.

    Raises:
        RuntimeError: if ``config.net_type`` is neither 'mobilenet'
            nor 'resnet'.
    """
    global best_err1, best_err5

    if config.train_params.use_seed:
        utils.set_seed(config.train_params.seed)

    # Data loaders.
    imagenet = imagenet_data.ImageNet12(
        trainFolder=os.path.join(config.data.data_path, 'train'),
        testFolder=os.path.join(config.data.data_path, 'val'),
        num_workers=config.data.num_workers,
        type_of_data_augmentation=config.data.type_of_data_aug,
        data_config=config.data)
    train_loader, val_loader = imagenet.getTrainTestLoader(
        config.data.batch_size)

    # Select the teacher/student pair for the requested network family.
    if config.net_type == 'mobilenet':
        t_net = ResNet.resnet50(pretrained=True)
        s_net = Mov.MobileNet()
    elif config.net_type == 'resnet':
        t_net = ResNet.resnet34(pretrained=True)
        s_net = ResNet.resnet18(pretrained=False)
    else:
        print('undefined network type !!!')
        raise RuntimeError('%s does not support' % config.net_type)

    import knowledge_distiller
    d_net = knowledge_distiller.WSLDistiller(t_net, s_net)

    print('Teacher Net: ')
    print(t_net)
    print('Student Net: ')
    print(s_net)
    # Idiom fix: sum over a generator instead of materializing a list.
    print('the number of teacher model parameters: {}'.format(
        sum(p.data.nelement() for p in t_net.parameters())))
    print('the number of student model parameters: {}'.format(
        sum(p.data.nelement() for p in s_net.parameters())))

    t_net = torch.nn.DataParallel(t_net)
    s_net = torch.nn.DataParallel(s_net)
    d_net = torch.nn.DataParallel(d_net)

    # Optionally resume: distiller weights, best metrics and epoch counter.
    if config.optim.if_resume:
        checkpoint = torch.load(config.optim.resume_path)
        d_net.module.load_state_dict(checkpoint['train_state_dict'])
        best_err1 = checkpoint['best_err1']
        best_err5 = checkpoint['best_err5']
        start_epoch = checkpoint['epoch'] + 1
    else:
        start_epoch = 0

    t_net = t_net.cuda()
    s_net = s_net.cuda()
    d_net = d_net.cuda()

    # Only the student's parameters are optimized; the teacher stays fixed.
    # Idiom fix: SGD accepts any iterable — the list(...) wrapper was
    # unnecessary.
    optimizer = torch.optim.SGD(
        s_net.parameters(),
        config.optim.init_lr,
        momentum=config.optim.momentum,
        weight_decay=config.optim.weight_decay,
        nesterov=True)

    cudnn.benchmark = True
    cudnn.enabled = True

    print('Teacher network performance')
    validate(val_loader, t_net, 0)

    for epoch in range(start_epoch, config.train_params.epochs + 1):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_with_distill(train_loader, d_net, optimizer, epoch)

        # evaluate on validation set
        err1, err5 = validate(val_loader, s_net, epoch)

        # remember best prec@1 and save checkpoint; best_err5 tracks the
        # top-5 error observed at the best top-1 epoch, not the global min.
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5
        print('Current best accuracy (top-1 and 5 error):',
              best_err1, best_err5)
        save_checkpoint({
            'epoch': epoch,
            'state_dict': s_net.module.state_dict(),
            'train_state_dict': d_net.module.state_dict(),
            'best_err1': best_err1,
            'best_err5': best_err5,
            'optimizer': optimizer.state_dict(),
        }, is_best)
        gc.collect()

    print('Best accuracy (top-1 and 5 error):', best_err1, best_err5)