def main():
    global args, best_err1, best_err5
    args = parser.parse_args()

    traindir = os.path.join(args.data_path, 'train')
    valdir = os.path.join(args.data_path, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_trm = [transforms.RandomResizedCrop(args.input_res),
                 transforms.RandomHorizontalFlip()]
    if args.autoaug:
        train_trm.append(AutoAugment())
    if args.cutout:
        train_trm.append(Cutout(length=args.input_res // 4))
    train_trm.append(transforms.ToTensor())
    train_trm.append(normalize)
    train_trm = transforms.Compose(train_trm)

    train_dataset = datasets.ImageFolder(traindir, transform=train_trm)
    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(int(args.input_res / 0.875)),
        transforms.CenterCrop(args.input_res),
        transforms.ToTensor(),
        normalize,
    ]))
    num_classes = len(train_dataset.classes)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, sampler=None)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.net_type == 'self_sup':
        t_net = ResNet.resnet50(pretrained=False, num_classes=num_classes)
        s_net = ResNet.resnet50(pretrained=False, num_classes=num_classes)

        # load from pre-trained, before DistributedDataParallel constructor
        if args.pretrained:
            if os.path.isfile(args.pretrained):
                print("=> loading checkpoint '{}'".format(args.pretrained))
                checkpoint = torch.load(args.pretrained, map_location="cpu")

                # rename moco pre-trained keys
                state_dict = checkpoint['state_dict']
                for k in list(state_dict.keys()):
                    # retain only encoder_q up to before the embedding layer
                    if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                        # remove prefix
                        state_dict[k[len("module.encoder_q."):]] = state_dict[k]
                    # delete renamed or unused k
                    del state_dict[k]

                args.start_epoch = 0
                msg = t_net.load_state_dict(state_dict, strict=False)
                assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
                msg = s_net.load_state_dict(state_dict, strict=False)
                assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
                print("=> loaded pre-trained model '{}'".format(args.pretrained))
            else:
                print("=> no checkpoint found at '{}'".format(args.pretrained))

    elif args.net_type == 'infomin_self':
        from collections import OrderedDict
        t_net = ResNet.resnet50(pretrained=False, num_classes=num_classes)
        s_net = ResNet.resnet50(pretrained=False, num_classes=num_classes)
        if args.pretrained:
            ckpt = torch.load(args.pretrained, map_location='cpu')
            state_dict = ckpt['model']
            # keep only the encoder weights and strip the 'module.'/'encoder.' prefixes
            encoder_state_dict = OrderedDict()
            for k, v in state_dict.items():
                k = k.replace('module.', '')
                if 'encoder' in k:
                    k = k.replace('encoder.', '')
                    encoder_state_dict[k] = v
            msg = t_net.load_state_dict(encoder_state_dict, strict=False)
            print(set(msg.missing_keys))
            msg = s_net.load_state_dict(encoder_state_dict, strict=False)
            print(set(msg.missing_keys))

    elif args.net_type == 'r50':
        # plain ResNet-50 student; no teacher checkpoint is loaded here
        s_net = ResNet.resnet50(pretrained=False, num_classes=num_classes)

    else:
        print('undefined network type !!!')
        raise ValueError('undefined network type: {}'.format(args.net_type))

    if not args.no_tea:
        d_net = distiller.Distiller(t_net, s_net)

    if not args.no_tea:
        print('the number of teacher model parameters: {}'.format(
            sum([p.data.nelement() for p in t_net.parameters()])))
    print('the number of student model parameters: {}'.format(
        sum([p.data.nelement() for p in s_net.parameters()])))

    if not args.no_tea:
        t_net = torch.nn.DataParallel(t_net).cuda()
        s_net = torch.nn.DataParallel(s_net).cuda()
        d_net = torch.nn.DataParallel(d_net).cuda()
    else:
        s_net = torch.nn.DataParallel(s_net).cuda()

    # define loss function (criterion) and optimizer
    if not args.label_smooth:
        criterion_CE = nn.CrossEntropyLoss().cuda()
    else:
        criterion_CE = CrossEntropyLabelSmooth(num_classes=num_classes)

    if not args.no_tea:
        optimizer = torch.optim.SGD(
            list(s_net.parameters()) + list(d_net.module.Connectors.parameters()),
            args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=True)
    else:
        optimizer = torch.optim.SGD(
            list(s_net.parameters()),
            args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=True)

    cudnn.benchmark = True

    for epoch in range(1, args.epochs + 1):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        if not args.no_tea:
            if args.mixup:
                train_with_distill_mixup(train_loader, d_net, optimizer, criterion_CE, epoch, args.no_distill_epoch)
            else:
                train_with_distill(train_loader, d_net, optimizer, criterion_CE, epoch, args.no_distill_epoch)
        else:
            if args.mixup:
                train_mixup(train_loader, s_net, optimizer, criterion_CE, epoch)
            else:
                train(train_loader, s_net, optimizer, criterion_CE, epoch)

        # evaluate on validation set
        err1, err5 = validate(val_loader, s_net, criterion_CE, epoch)

        # remember best prec@1 and save checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5

        print('Current best accuracy (top-1 and 5 error):', best_err1, best_err5)
        save_checkpoint({
            'epoch': epoch,
            'arch': args.net_type,
            'state_dict': s_net.state_dict(),
            'best_err1': best_err1,
            'best_err5': best_err5,
            'optimizer': optimizer.state_dict(),
        }, is_best)
        gc.collect()

    print('Best accuracy (top-1 and 5 error):', best_err1, best_err5)
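# CrossEntropyLabelSmooth is referenced above when args.label_smooth is set but is not
# defined in this file. The following is a minimal sketch of a standard label-smoothing
# cross entropy (target distribution (1 - eps) * one_hot + eps / num_classes); the class
# name carries a "Sketch" suffix because the repo's actual implementation may differ.
import torch
import torch.nn as nn

class CrossEntropyLabelSmoothSketch(nn.Module):
    def __init__(self, num_classes, epsilon=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        # inputs: (N, num_classes) logits; targets: (N,) ground-truth class indices
        log_probs = self.logsoftmax(inputs)
        one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1.0)
        smoothed = (1.0 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        return (-smoothed * log_probs).sum(dim=1).mean()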
def load_model(config, num_classes):
    if config['model']['type'] == 'resnet':
        if config['model']['arch'] == 'resnet50':
            net = ResNet.resnet50(pretrained=False, progress=False, num_classes=num_classes)
        elif config['model']['arch'] == 'resnext50':
            net = ResNet.resnext50_32x4d(pretrained=False, progress=False, num_classes=num_classes)
        elif config['model']['arch'] == 'resnet50d':
            net = ResNetD.resnet50d(pretrained=False, progress=False, num_classes=num_classes)
        else:
            raise ValueError('Unsupported architecture: ' + str(config['model']['arch']))

    elif config['model']['type'] == 'tresnet':
        if config['model']['arch'] == 'tresnetm':
            net = TResNet.TResnetM(num_classes=num_classes)
        elif config['model']['arch'] == 'tresnetl':
            net = TResNet.TResnetL(num_classes=num_classes)
        elif config['model']['arch'] == 'tresnetxl':
            net = TResNet.TResnetXL(num_classes=num_classes)
        else:
            raise ValueError('Unsupported architecture: ' + str(config['model']['arch']))

    elif config['model']['type'] == 'regnet':
        # (depth, w0, wa, wm, group_w, se_on) for each supported RegNet variant
        regnet_params = {
            'regnetx-200mf': (13, 24, 36.44, 2.49, 8, False),
            'regnetx-600mf': (16, 48, 36.97, 2.24, 24, False),
            'regnetx-4.0gf': (23, 96, 38.65, 2.43, 40, False),
            'regnetx-6.4gf': (17, 184, 60.83, 2.07, 56, False),
            'regnety-200mf': (13, 24, 36.44, 2.49, 8, True),
            'regnety-600mf': (15, 48, 32.54, 2.32, 16, True),
            'regnety-4.0gf': (22, 96, 31.41, 2.24, 64, True),
            'regnety-6.4gf': (25, 112, 33.22, 2.27, 72, True),
        }
        arch = config['model']['arch']
        if arch not in regnet_params:
            raise ValueError('Unsupported architecture: ' + str(arch))
        depth, w0, wa, wm, group_w, se_on = regnet_params[arch]
        regnet_config = {'depth': depth, 'w0': w0, 'wa': wa, 'wm': wm,
                         'group_w': group_w, 'se_on': se_on,
                         'num_classes': num_classes}
        net = RegNet.RegNet(regnet_config)

    elif config['model']['type'] == 'resnest':
        if config['model']['arch'] == 'resnest50':
            net = ResNest.resnest50(pretrained=False, num_classes=num_classes)
        elif config['model']['arch'] == 'resnest101':
            net = ResNest.resnest101(pretrained=False, num_classes=num_classes)
        else:
            raise ValueError('Unsupported architecture: ' + str(config['model']['arch']))

    elif config['model']['type'] == 'efficient':
        # EfficientNet b0-b6 share the same constructor signature, so dispatch by name
        if config['model']['arch'] not in ('b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6'):
            raise ValueError('Unsupported architecture: ' + str(config['model']['arch']))
        builder = getattr(EfficientNet, 'efficientnet_' + config['model']['arch'])
        net = builder(pretrained=False, num_classes=num_classes)

    elif config['model']['type'] == 'assembled':
        # not implemented; falling through silently would leave `net` undefined at return
        raise NotImplementedError('assembled models are not supported yet')

    elif config['model']['type'] == 'shufflenet':
        if config['model']['arch'] == 'v2_x0_5':
            net = ShuffleNetV2.shufflenet_v2_x0_5(pretrained=False, progress=False, num_classes=num_classes)
        elif config['model']['arch'] == 'v2_x1_0':
            net = ShuffleNetV2.shufflenet_v2_x1_0(pretrained=False, progress=False, num_classes=num_classes)
        elif config['model']['arch'] == 'v2_x1_5':
            net = ShuffleNetV2.shufflenet_v2_x1_5(pretrained=False, progress=False, num_classes=num_classes)
        elif config['model']['arch'] == 'v2_x2_0':
            net = ShuffleNetV2.shufflenet_v2_x2_0(pretrained=False, progress=False, num_classes=num_classes)
        else:
            raise ValueError('Unsupported architecture: ' + str(config['model']['arch']))

    elif config['model']['type'] == 'mobilenet':
        if config['model']['arch'] == 'small_075':
            net = Mobilenetv3.mobilenetv3_small_075(pretrained=False, num_classes=num_classes)
        elif config['model']['arch'] == 'small_100':
            net = Mobilenetv3.mobilenetv3_small_100(pretrained=False, num_classes=num_classes)
        elif config['model']['arch'] == 'large_075':
            net = Mobilenetv3.mobilenetv3_large_075(pretrained=False, num_classes=num_classes)
        elif config['model']['arch'] == 'large_100':
            net = Mobilenetv3.mobilenetv3_large_100(pretrained=False, num_classes=num_classes)
        else:
            raise ValueError('Unsupported architecture: ' + str(config['model']['arch']))

    elif config['model']['type'] == 'rexnet':
        if config['model']['arch'] == 'rexnet1.0x':
            net = ReXNet.rexnet(num_classes=num_classes, width_multi=1.0)
        elif config['model']['arch'] == 'rexnet1.5x':
            net = ReXNet.rexnet(num_classes=num_classes, width_multi=1.5)
        elif config['model']['arch'] == 'rexnet2.0x':
            net = ReXNet.rexnet(num_classes=num_classes, width_multi=2.0)
        else:
            raise ValueError('Unsupported architecture: ' + str(config['model']['arch']))

    else:
        raise ValueError('Unsupported model type: ' + str(config['model']['type']))

    return net
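# Example of how load_model might be invoked. The nested dict layout ('model' -> 'type'
# and 'arch') is inferred from the accesses above; in the repo the config is presumably
# parsed from a config file, so this literal dict is purely illustrative.
if __name__ == '__main__':
    example_config = {'model': {'type': 'regnet', 'arch': 'regnety-600mf'}}
    model = load_model(example_config, num_classes=1000)
    print('parameters:', sum(p.numel() for p in model.parameters()))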
def main():
    global args, best_err1, best_err5
    args = parser.parse_args()

    traindir = os.path.join(args.data_path, 'train')
    valdir = os.path.join(args.data_path, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(traindir, transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, sampler=None)

    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]))
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.net_type == 'mobilenet':
        t_net = ResNet.resnet50(pretrained=True)
        s_net = Mov.MobileNet()
    elif args.net_type == 'resnet':
        t_net = ResNet.resnet152(pretrained=True)
        s_net = ResNet.resnet50(pretrained=False)
    else:
        print('undefined network type !!!')
        raise ValueError('undefined network type: {}'.format(args.net_type))

    d_net = distiller.Distiller(t_net, s_net)

    print('Teacher Net: ')
    print(t_net)
    print('Student Net: ')
    print(s_net)
    print('the number of teacher model parameters: {}'.format(
        sum([p.data.nelement() for p in t_net.parameters()])))
    print('the number of student model parameters: {}'.format(
        sum([p.data.nelement() for p in s_net.parameters()])))

    t_net = torch.nn.DataParallel(t_net).cuda()
    s_net = torch.nn.DataParallel(s_net).cuda()
    d_net = torch.nn.DataParallel(d_net).cuda()

    # define loss function (criterion) and optimizer
    criterion_CE = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        list(s_net.parameters()) + list(d_net.module.Connectors.parameters()),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        nesterov=True)

    cudnn.benchmark = True

    print('Teacher network performance')
    validate(val_loader, t_net, criterion_CE, 0)

    for epoch in range(1, args.epochs + 1):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_with_distill(train_loader, d_net, optimizer, criterion_CE, epoch)

        # evaluate on validation set
        err1, err5 = validate(val_loader, s_net, criterion_CE, epoch)

        # remember best prec@1 and save checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5

        print('Current best accuracy (top-1 and 5 error):', best_err1, best_err5)
        save_checkpoint({
            'epoch': epoch,
            'arch': args.net_type,
            'state_dict': s_net.state_dict(),
            'best_err1': best_err1,
            'best_err5': best_err5,
            'optimizer': optimizer.state_dict(),
        }, is_best)
        gc.collect()

    print('Best accuracy (top-1 and 5 error):', best_err1, best_err5)
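# adjust_learning_rate() is called above but not shown here. A minimal sketch of a common
# ImageNet-style step schedule (divide the base LR by 10 every 30 epochs); the "_sketch"
# suffix and the default values are assumptions, not the repo's actual schedule.
def adjust_learning_rate_sketch(optimizer, epoch, base_lr=0.1, decay_every=30):
    lr = base_lr * (0.1 ** (epoch // decay_every))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr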
def main():
    global best_err1, best_err5

    if config.train_params.use_seed:
        utils.set_seed(config.train_params.seed)

    imagenet = imagenet_data.ImageNet12(
        trainFolder=os.path.join(config.data.data_path, 'train'),
        testFolder=os.path.join(config.data.data_path, 'val'),
        num_workers=config.data.num_workers,
        type_of_data_augmentation=config.data.type_of_data_aug,
        data_config=config.data)
    train_loader, val_loader = imagenet.getTrainTestLoader(config.data.batch_size)

    if config.net_type == 'mobilenet':
        t_net = ResNet.resnet50(pretrained=True)
        s_net = Mov.MobileNet()
    elif config.net_type == 'resnet':
        t_net = ResNet.resnet34(pretrained=True)
        s_net = ResNet.resnet18(pretrained=False)
    else:
        print('undefined network type !!!')
        raise RuntimeError('%s is not supported' % config.net_type)

    import knowledge_distiller
    d_net = knowledge_distiller.WSLDistiller(t_net, s_net)

    print('Teacher Net: ')
    print(t_net)
    print('Student Net: ')
    print(s_net)
    print('the number of teacher model parameters: {}'.format(
        sum([p.data.nelement() for p in t_net.parameters()])))
    print('the number of student model parameters: {}'.format(
        sum([p.data.nelement() for p in s_net.parameters()])))

    t_net = torch.nn.DataParallel(t_net)
    s_net = torch.nn.DataParallel(s_net)
    d_net = torch.nn.DataParallel(d_net)

    if config.optim.if_resume:
        checkpoint = torch.load(config.optim.resume_path)
        d_net.module.load_state_dict(checkpoint['train_state_dict'])
        best_err1 = checkpoint['best_err1']
        best_err5 = checkpoint['best_err5']
        start_epoch = checkpoint['epoch'] + 1
    else:
        start_epoch = 0

    t_net = t_net.cuda()
    s_net = s_net.cuda()
    d_net = d_net.cuda()

    # choose optimizer parameters
    optimizer = torch.optim.SGD(
        list(s_net.parameters()),
        config.optim.init_lr,
        momentum=config.optim.momentum,
        weight_decay=config.optim.weight_decay,
        nesterov=True)

    cudnn.benchmark = True
    cudnn.enabled = True

    print('Teacher network performance')
    validate(val_loader, t_net, 0)

    for epoch in range(start_epoch, config.train_params.epochs + 1):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_with_distill(train_loader, d_net, optimizer, epoch)

        # evaluate on validation set
        err1, err5 = validate(val_loader, s_net, epoch)

        # remember best prec@1 and save checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5

        print('Current best accuracy (top-1 and 5 error):', best_err1, best_err5)
        save_checkpoint({
            'epoch': epoch,
            'state_dict': s_net.module.state_dict(),
            'train_state_dict': d_net.module.state_dict(),
            'best_err1': best_err1,
            'best_err5': best_err5,
            'optimizer': optimizer.state_dict(),
        }, is_best)
        gc.collect()

    print('Best accuracy (top-1 and 5 error):', best_err1, best_err5)
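# validate() above returns top-1 and top-5 error rates. A minimal sketch of such an
# evaluation loop using standard top-k matching on the logits; the repo's own validate()
# (with AverageMeter bookkeeping, logging, etc.) is not reproduced here.
import torch

def validate_sketch(val_loader, model, epoch=0):
    model.eval()
    top1_correct, top5_correct, total = 0, 0, 0
    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = images.cuda(), targets.cuda()
            outputs = model(images)
            _, pred = outputs.topk(5, dim=1)           # (N, 5) highest-scoring classes
            correct = pred.eq(targets.unsqueeze(1))    # (N, 5) boolean matches
            top1_correct += correct[:, :1].sum().item()
            top5_correct += correct.any(dim=1).sum().item()
            total += targets.size(0)
    err1 = 100.0 * (1.0 - top1_correct / total)
    err5 = 100.0 * (1.0 - top5_correct / total)
    return err1, err5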
import models.ResNet as resnet

print("running image analyzer")
dualnet = resnet.resnet50(pretrained=True)
print("end image analyzer")