def __init__(self, backbone, lr_channels=17, hr_channels=5):
    super(isCNN, self).__init__()
    self.model_name = backbone
    self.lr_backbone = EfficientNet.from_name(backbone)
    self.hr_backbone = EfficientNet.from_name(backbone)
    # Low-resolution head: two stacked convs, then a 1x1 projection.
    self.lr_final = nn.Sequential(double_conv(self.n_channels, 64),
                                  nn.Conv2d(64, lr_channels, 1))
    # Decoder: progressively upsample while fusing encoder skip features.
    self.up_conv1 = up_conv(2 * self.n_channels + lr_channels, 512)
    self.double_conv1 = double_conv(self.size[0], 512)
    self.up_conv2 = up_conv(512, 256)
    self.double_conv2 = double_conv(self.size[1], 256)
    self.up_conv3 = up_conv(256, 128)
    self.double_conv3 = double_conv(self.size[2], 128)
    self.up_conv4 = up_conv(128, 64)
    self.double_conv4 = double_conv(self.size[3], 64)
    self.up_conv_input = up_conv(64, 32)
    self.double_conv_input = double_conv(self.size[4], 32)
    # High-resolution head: 1x1 conv to the requested channel count.
    self.hr_final = nn.Conv2d(self.size[5], hr_channels, kernel_size=1)

def EfficientNet_B8(pretrained=True, num_class=5, onehot=1, onehot2=0):
    if pretrained:
        model = EfficientNet.from_pretrained('efficientnet-b8',
                                             num_classes=num_class,
                                             onehot=onehot,
                                             onehot2=onehot2)
        # Freeze everything except the final fully connected layer.
        for name, param in model.named_parameters():
            if 'fc' not in name:
                param.requires_grad = False
    else:
        model = EfficientNet.from_name('efficientnet-b8',
                                       onehot=onehot,
                                       onehot2=onehot2)
    model.name = "EfficientNet_B8"
    print("EfficientNet B8 Loaded!")
    return model

def EfficientNet_B6(pretrained=True, num_class=5, advprop=False, onehot=1, onehot2=0):
    if pretrained:
        model = EfficientNet.from_pretrained('efficientnet-b6',
                                             num_classes=num_class,
                                             onehot=onehot,
                                             onehot2=onehot2)
        # Freeze everything except the final fully connected layer.
        for name, param in model.named_parameters():
            if 'fc' not in name:  # and 'blocks.24' not in name and 'blocks.25' not in name
                param.requires_grad = False
    else:
        model = EfficientNet.from_name('efficientnet-b6',
                                       onehot=onehot,
                                       onehot2=onehot2)
    model.name = "EfficientNet_B6"
    print("EfficientNet B6 Loaded!")
    return model

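# Hedged usage sketch for the factory functions above: with pretrained=True
# only the 'fc' parameters stay trainable, so pass just those to the
# optimizer. The shapes and learning rate are illustrative, and the call
# assumes forward() takes a single image batch (this repo's EfficientNet has
# custom onehot arguments that may change the forward signature).
import torch
import torch.nn as nn

model = EfficientNet_B6(pretrained=True, num_class=5)
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4)
images = torch.randn(2, 3, 528, 528)   # 528x528 is B6's native resolution
labels = torch.randint(0, 5, (2,))
loss = nn.CrossEntropyLoss()(model(images), labels)
loss.backward()
optimizer.step()
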
def main(args):
    # Step 1: parse args config
    logging.basicConfig(
        format='[%(asctime)s] [p%(process)s] [%(pathname)s:%(lineno)d] [%(levelname)s] %(message)s',
        level=logging.INFO,
        handlers=[
            logging.FileHandler(args.log_file, mode='w'),
            logging.StreamHandler()
        ])
    print_args(args)

    # Step 2: model, criterion, optimizer, scheduler
    # model = MobileNetV3(mode='large').to(args.device)
    model = EfficientNet.from_name(args.arch).to(args.device)
    # auxiliarynet = AuxiliaryNet().to(args.device)
    auxiliarynet = None
    checkpoint = torch.load(args.model_path)
    model.load_state_dict(checkpoint['model'])

    # Step 3: data augmentation
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    # mpiidataset = MPIIDatasets(args.dataroot, train=True, transforms=transform)
    # train_dataset = GazeCaptureDatasets(args.dataroot, train=True, transforms=transform)
    # mpii_val_dataset = MPIIDatasets(args.val_dataroot, train=False, transforms=transform)
    val_dataset = GazeCaptureDatasets(args.val_dataroot,
                                      train=True,
                                      transforms=transform)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.val_batchsize,
                                shuffle=False,
                                num_workers=args.workers)

    # Step 4: run
    val_loss, val_error = validate(args, val_dataloader, model, auxiliarynet, 1)
    print("val_loss: '{}' val_error: '{}'".format(val_loss, val_error))

def __init__(self,
             num_classes,
             network='efficientdet-d0',
             D_bifpn=3,
             W_bifpn=88,
             D_class=3,
             is_training=True,
             threshold=0.5,
             iou_threshold=0.5):
    super(EfficientDet, self).__init__()
    # self.efficientnet = EfficientNet.from_pretrained(MODEL_MAP[network])
    self.efficientnet = EfficientNet.from_name(
        MODEL_MAP[network], override_params={'num_classes': num_classes})
    self.is_training = is_training
    self.BIFPN = BIFPN(in_channels=self.efficientnet.get_list_features()[-5:],
                       out_channels=W_bifpn,
                       stack=D_bifpn,
                       num_outs=5)
    self.regressionModel = RegressionModel(W_bifpn)
    self.classificationModel = ClassificationModel(W_bifpn,
                                                   num_classes=num_classes)
    self.anchors = Anchors()
    self.regressBoxes = BBoxTransform()
    self.clipBoxes = ClipBoxes()
    self.threshold = threshold
    self.iou_threshold = iou_threshold

    # He-style initialisation for convs; BatchNorm starts as identity.
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()

    # Focal-loss prior: bias the class head so every anchor starts out as
    # background with probability 1 - prior.
    prior = 0.01
    self.classificationModel.output.weight.data.fill_(0)
    self.classificationModel.output.bias.data.fill_(-math.log(
        (1.0 - prior) / prior))
    self.regressionModel.output.weight.data.fill_(0)
    self.regressionModel.output.bias.data.fill_(0)
    self.freeze_bn()

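# Why the classification bias is filled with -log((1 - prior) / prior):
# with prior = 0.01 the initial sigmoid output of the class head is
# sigmoid(-log(0.99 / 0.01)) = 0.01, i.e. every anchor starts as background
# with probability 0.99, which keeps focal loss stable early in training
# (RetinaNet, https://arxiv.org/abs/1708.02002). A quick numeric check:
#
#   import math
#   b = -math.log((1.0 - 0.01) / 0.01)    # ≈ -4.595
#   print(1.0 / (1.0 + math.exp(-b)))     # ≈ 0.01
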
def __init__(self, backbone, out_channels=2, concat_input=True):
    super().__init__()
    self.model_name = backbone
    self.backbone = EfficientNet.from_name(backbone)
    self.concat_input = concat_input
    # Decoder: progressively upsample while fusing encoder skip features.
    self.up_conv1 = up_conv(self.n_channels, 512)
    self.double_conv1 = double_conv(self.size[0], 512)
    self.up_conv2 = up_conv(512, 256)
    self.double_conv2 = double_conv(self.size[1], 256)
    self.up_conv3 = up_conv(256, 128)
    self.double_conv3 = double_conv(self.size[2], 128)
    self.up_conv4 = up_conv(128, 64)
    self.double_conv4 = double_conv(self.size[3], 64)
    if self.concat_input:
        # Optionally fuse the raw input for a final full-resolution stage.
        self.up_conv_input = up_conv(64, 32)
        self.double_conv_input = double_conv(self.size[4], 32)
    self.final_conv = nn.Conv2d(self.size[5], out_channels, kernel_size=1)

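# Hedged instantiation sketch for the decoder above (the class name
# EfficientUnet and the 224x224 input size are assumptions, not taken from
# the snippet itself):
#
#   net = EfficientUnet('efficientnet-b0', out_channels=2, concat_input=True)
#   masks = net(torch.randn(1, 3, 224, 224))   # expect (1, 2, 224, 224)
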
def __init__(self,
             num_classes,
             network='efficientdet-d0',
             D_bifpn=3,
             W_bifpn=88,
             D_class=3,
             is_training=True,
             threshold=0.01,
             iou_threshold=0.5):
    super(EfficientDet, self).__init__()
    try:
        self.backbone = EfficientNet.from_pretrained(MODEL_MAP[network])
    except Exception:
        print("pretrained model is not available ", MODEL_MAP[network])
        print("building backbone from name instead")
        self.backbone = EfficientNet.from_name(MODEL_MAP[network])
    self.is_training = is_training
    self.neck = BIFPN(in_channels=self.backbone.get_list_features()[-5:],
                      out_channels=W_bifpn,
                      stack=D_bifpn,
                      num_outs=5)
    self.bbox_head = RetinaHead(num_classes=num_classes, in_channels=W_bifpn)
    self.anchors = Anchors()
    self.regressBoxes = BBoxTransform()
    self.clipBoxes = ClipBoxes()
    self.threshold = threshold
    self.iou_threshold = iou_threshold

    # He-style initialisation for convs; BatchNorm starts as identity.
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
    self.freeze_bn()
    self.criterion = FocalLoss()

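# Hedged construction sketch for the detector above: network names map
# through MODEL_MAP to EfficientNet variants, and the try/except falls back
# to random initialisation when pretrained weights cannot be downloaded.
# The class name EfficientDet and the 80-class COCO setting are assumptions.
#
#   detector = EfficientDet(num_classes=80, network='efficientdet-d0',
#                           is_training=False, threshold=0.3)
#   detector.eval()
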
def __init__(self):
    super(EfficientBase, self).__init__()
    self.model_name = 'efficientnet-b0'
    self.encoder = EfficientNet.from_name('efficientnet-b0')
    # Channel widths of the encoder features consumed by each decoder stage.
    self.size = [1280, 80, 40, 24, 16]
    self.decoder5 = self._make_deconv_layer(DecoderBlock,
                                            self.size[0],
                                            self.size[1],
                                            stride=2)
    self.decoder4 = self._make_deconv_layer(DecoderBlock,
                                            self.size[1],
                                            self.size[2],
                                            stride=2)
    self.decoder3 = self._make_deconv_layer(DecoderBlock,
                                            self.size[2],
                                            self.size[3],
                                            stride=2)
    self.decoder2 = self._make_deconv_layer(DecoderBlock,
                                            self.size[3],
                                            self.size[4],
                                            stride=2)
    self.decoder1 = self._make_deconv_layer(DecoderBlock, self.size[4],
                                            self.size[4])
    self.final = nn.Sequential(
        nn.ConvTranspose2d(16, 8, kernel_size=3, stride=2, padding=1,
                           output_padding=1),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(8, 8, kernel_size=3, padding=1),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(8, 2, kernel_size=1))

def test_model(args):
    # create model
    num_classes = 2
    if args.arch == 'efficientnet_b0':
        if args.pretrained:
            model = EfficientNet.from_pretrained("efficientnet-b0",
                                                 quantize=args.quantize,
                                                 num_classes=num_classes)
        else:
            model = EfficientNet.from_name(
                "efficientnet-b0",
                quantize=args.quantize,
                override_params={'num_classes': num_classes})
        model = torch.nn.DataParallel(model).cuda()
    elif args.arch == 'mobilenet_v1':
        model = mobilenet_v1(quantize=args.quantize, num_classes=num_classes)
        model = torch.nn.DataParallel(model).cuda()
        if args.pretrained:
            checkpoint = torch.load(args.resume)
            state_dict = checkpoint['state_dict']
            if num_classes != 1000:
                # Drop the ImageNet classifier weights when the head differs.
                new_dict = {
                    k: v
                    for k, v in state_dict.items() if 'fc' not in k
                }
                state_dict = new_dict
            res = model.load_state_dict(state_dict, strict=False)
            for missing_key in res.missing_keys:
                assert 'quantize' in missing_key or 'fc' in missing_key
    elif args.arch == 'mobilenet_v2':
        model = mobilenet_v2(pretrained=args.pretrained,
                             num_classes=num_classes,
                             quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()
    elif args.arch == 'resnet18':
        model = resnet18(pretrained=args.pretrained,
                         num_classes=num_classes,
                         quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()
    elif args.arch == 'resnet50':
        model = resnet50(pretrained=args.pretrained,
                         num_classes=num_classes,
                         quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()
    elif args.arch == 'resnet152':
        model = resnet152(pretrained=args.pretrained,
                          num_classes=num_classes,
                          quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()
    elif args.arch == 'resnet164':
        model = resnet_164(num_classes=num_classes, quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()
    elif args.arch == 'vgg11':
        model = vgg11(pretrained=args.pretrained,
                      num_classes=num_classes,
                      quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()
    elif args.arch == 'vgg19':
        model = vgg19(pretrained=args.pretrained,
                      num_classes=num_classes,
                      quantize=args.quantize)
        model = torch.nn.DataParallel(model).cuda()
    else:
        logging.info('No such model.')
        sys.exit()

    if args.resume and not args.pretrained:
        if os.path.isfile(args.resume):
            logging.info('=> loading checkpoint `{}`'.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info('=> loaded checkpoint `{}` (epoch: {})'.format(
                args.resume, checkpoint['epoch']))
        else:
            logging.info('=> no checkpoint found at `{}`'.format(args.resume))

    cudnn.benchmark = False
    test_loader = prepare_test_data(dataset=args.dataset,
                                    datadir=args.datadir,
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    num_workers=args.workers)
    criterion = nn.CrossEntropyLoss().cuda()
    with torch.no_grad():
        prec1 = validate(args, test_loader, model, criterion, 0)

    return total_flops


if __name__ == '__main__':
    # Check 'cifar100' first: the substring 'cifar10' would also match it.
    if 'cifar100' in args.dataset:
        num_classes = 100
        input_res = 32
    elif 'cifar10' in args.dataset:
        num_classes = 10
        input_res = 32
    elif 'imagenet' in args.dataset:
        num_classes = 1000
        input_res = 224

    if args.arch == 'efficientnet_b0':
        model = EfficientNet.from_name(
            "efficientnet-b0", override_params={'num_classes': num_classes})
    elif args.arch == 'mobilenet_v1':
        model = mobilenet_v1(num_classes=num_classes)
    elif args.arch == 'mobilenet_v2':
        model = mobilenet_v2(num_classes=num_classes)
    elif args.arch == 'resnet18':
        model = resnet18(num_classes=num_classes)
    elif args.arch == 'resnet50':
        model = resnet50(num_classes=num_classes)
    elif args.arch == 'resnet152':
        model = resnet152(num_classes=num_classes)
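    # Hedged sanity check alongside the FLOPs count: a raw parameter count
    # needs no hooks and quickly confirms the right architecture was built
    # (the print format here is illustrative, not from the original script).
    n_params = sum(p.numel() for p in model.parameters())
    print('%s: %.2fM params at input %dx%d' %
          (args.arch, n_params / 1e6, input_res, input_res))
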
import os
import cv2
import torch
import numpy as np
import torchvision.transforms as tfs
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from models.efficientnet import EfficientNet
from models.efficientdet import EfficientDet
from models.loss import FocalLoss

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# model = EfficientNet.from_pretrained('efficientnet-b0', False, True)
# Remap an upstream checkpoint onto this repo's parameter names: both state
# dicts enumerate parameters in the same order, so copy the values by
# position and save them back under the local key names.
parameters = torch.load('weights/efficientnet-b0.pth')
model = EfficientNet.from_name('efficientnet-b0')
parameters = [v for _, v in parameters.items()]
model_state_dict = model.state_dict()
for i, (k, v) in enumerate(model_state_dict.items()):
    model_state_dict[k] = parameters[i]
torch.save(model_state_dict, 'weights/efficientnet-b0.pth')
'''
data_dir = 'data'
imgs_path = [os.path.join(data_dir, f) for f in os.listdir(data_dir)]
imgs = [cv2.resize((cv2.imread(i)[..., ::-1] / 255 - mean) / std, (608, 608))
        for i in imgs_path]
imgs = torch.stack([torch.from_numpy(i.astype(np.float32)) for i in imgs],
                   0).permute(0, 3, 1, 2)
imgs = imgs[:4].cuda()
features = model.extract_features(imgs)
print(features.size())
'''
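# Hedged verification of the remapping above: after the positional copy the
# checkpoint keys match this repo's model exactly, so a strict load should
# succeed without missing or unexpected keys.
reloaded = torch.load('weights/efficientnet-b0.pth')
model.load_state_dict(reloaded)  # strict=True by default; raises on mismatch
print('remapped checkpoint loads cleanly')
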
def main(args):
    # Step 1: parse args config
    logging.basicConfig(
        format='[%(asctime)s] [p%(process)s] [%(pathname)s:%(lineno)d] [%(levelname)s] %(message)s',
        level=logging.INFO,
        handlers=[
            logging.FileHandler(args.log_file, mode='w'),
            logging.StreamHandler()
        ])
    print_args(args)

    # Step 2: model, criterion, optimizer, scheduler
    # model = MobileNetV3(mode='large').to(args.device)
    if args.pretrained:
        model = EfficientNet.from_pretrained(args.arch).to(args.device)
        print("=> using pre-trained model '{}'".format(args.arch))
    else:
        print("=> creating model '{}'".format(args.arch))
        model = EfficientNet.from_name(args.arch).to(args.device)
    auxiliarynet = AuxiliaryNet().to(args.device)
    # auxiliarynet = AuxiliaryNet()
    criterion = GazeLoss()
    optimizer = torch.optim.Adam([{
        'params': model.parameters()
    }, {
        'params': auxiliarynet.parameters()
    }],
                                 lr=args.base_lr,
                                 weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        patience=args.lr_patience,
        verbose=True,
        min_lr=args.min_lr)

    # optionally resume from a checkpoint
    min_error = 1e6
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            min_error = checkpoint['error']
            model.load_state_dict(checkpoint['model'])
            auxiliarynet.load_state_dict(checkpoint['aux'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {}) {:.3f}".format(
                args.resume, checkpoint['epoch'], min_error))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Step 3: data augmentation
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    # train_dataset = MPIIDatasets(args.dataroot, train=True, transforms=transform)
    train_dataset = GazeCaptureDatasets(args.dataroot,
                                        train=True,
                                        transforms=transform)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.train_batchsize,
                                  shuffle=True,
                                  num_workers=args.workers,
                                  drop_last=True)
    # val_dataset = MPIIDatasets(args.val_dataroot, train=False, transforms=transform)
    val_dataset = GazeCaptureDatasets(args.val_dataroot,
                                      train=False,
                                      transforms=transform)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.val_batchsize,
                                shuffle=False,
                                num_workers=args.workers)

    # Step 4: run
    writer = SummaryWriter(args.tensorboard)
    for epoch in range(args.start_epoch, args.end_epoch + 1):
        train_loss, train_error = train(args, train_dataloader, model,
                                        auxiliarynet, criterion, optimizer,
                                        epoch)
        val_loss, val_error = validate(args, val_dataloader, model,
                                       auxiliarynet, criterion, epoch)
        filename = os.path.join(
            str(args.snapshot), "checkpoint_epoch_" + str(epoch) + '.pth.tar')
        is_best = min_error > val_error
        min_error = min(min_error, val_error)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'aux': auxiliarynet.state_dict(),
                'optimizer': optimizer.state_dict(),
                'error': min_error,
            }, is_best, filename)
        scheduler.step(val_loss)
        writer.add_scalars('data/error', {
            'val error': val_error,
            'train error': train_error
        }, epoch)
        writer.add_scalars('data/loss', {
            'val loss': val_loss,
            'train loss': train_loss
        }, epoch)
    writer.close()

def main():
    args = parser.parse_args()
    log = logger(args)
    log.write('V' * 50 + " configs " + 'V' * 50 + '\n')
    log.write(args)
    log.write('')
    log.write('Λ' * 50 + " configs " + 'Λ' * 50 + '\n')

    # load data
    input_size = (224, 224)
    dataset = DataLoader(args, input_size)
    train_data, val_data = dataset.load_data()
    num_classes = dataset.num_classes
    classes = dataset.classes
    log.write('\n\n')
    log.write('V' * 50 + " data " + 'V' * 50 + '\n')
    log.info('successfully loaded data.')
    log.info('num classes: %s' % num_classes)
    log.info('classes: ' + str(classes) + '\n')
    log.write('Λ' * 50 + " data " + 'Λ' * 50 + '\n')

    # random seed
    if args.manual_seed is None:
        args.manual_seed = random.randint(1, 10000)
    random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    np.random.seed(args.manual_seed)
    log.write('random seed is %s' % args.manual_seed)

    # pretrained or not
    log.write('\n\n')
    log.write('V' * 50 + " model " + 'V' * 50 + '\n')
    if args.pretrained:
        log.info("using pre-trained model")
    else:
        log.info("creating model from scratch")

    # model
    log.info('using model: %s' % args.arch)
    log.write('')
    log.write('Λ' * 50 + " model " + 'Λ' * 50 + '\n')

    # resume a saved model, or build one from the requested arch
    if args.resume:
        log.info('using resume model: %s' % args.resume)
        states = torch.load(args.resume)
        model = states['model']
        model.load_state_dict(states['state_dict'])
    else:
        log.info('not using resume model')
        if args.arch.startswith('dla'):
            model = eval(args.arch)(args.pretrained, num_classes)
        elif args.arch.startswith('efficientnet'):
            if args.pretrained:
                model = EfficientNet.from_pretrained(args.arch,
                                                     num_classes=num_classes)
            else:
                model = EfficientNet.from_name(args.arch,
                                               num_classes=num_classes)
        else:
            model = make_model(model_name=args.arch,
                               num_classes=num_classes,
                               pretrained=args.pretrained,
                               pool=nn.AdaptiveAvgPool2d(output_size=1),
                               classifier_factory=None,
                               input_size=input_size,
                               original_model_state_dict=None,
                               catch_output_size_exception=True)

    # cuda
    have_cuda = torch.cuda.is_available()
    use_cuda = args.use_gpu and have_cuda
    log.info('using cuda: %s' % use_cuda)
    if have_cuda and not use_cuda:
        log.info(
            '\nWARNING: found gpu but not using it; switch it on with -ug or --use-gpu\n'
        )
    multi_gpus = False
    if use_cuda:
        torch.backends.cudnn.benchmark = True
        if args.multi_gpus:
            gpus = torch.cuda.device_count()
            multi_gpus = gpus > 1
    if multi_gpus:
        log.info('using multi gpus, found %d gpus.' % gpus)
        model = torch.nn.DataParallel(model).cuda()
    elif use_cuda:
        model = model.cuda()

    # criterion
    log.write('\n\n')
    log.write('V' * 50 + " criterion " + 'V' * 50 + '\n')
    if args.label_smoothing > 0 and args.mixup == 1:
        criterion = CrossEntropyWithLabelSmoothing()
        log.info('using label smoothing criterion')
    elif args.label_smoothing > 0 and args.mixup < 1:
        criterion = CrossEntropyWithMixup()
        log.info('using label smoothing and mixup criterion')
    elif args.mixup < 1 and args.label_smoothing == 0:
        criterion = CrossEntropyWithMixup()
        log.info('using mixup criterion')
    else:
        criterion = nn.CrossEntropyLoss()
        log.info('using normal cross entropy criterion')
    if use_cuda:
        criterion = criterion.cuda()
    log.write('using criterion: %s' % str(criterion))
    log.write('')
    log.write('Λ' * 50 + " criterion " + 'Λ' * 50 + '\n')

    # optimizer
    log.write('\n\n')
    log.write('V' * 50 + " optimizer " + 'V' * 50 + '\n')
    if args.linear_scaling:
        # linear LR scaling rule: lr = 0.1 * batch_size / 256
        args.lr = 0.1 * args.train_batch / 256
    log.write('initial lr: %4f\n' % args.lr)
    if args.no_bias_decay:
        log.info('using no bias weight decay')
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': args.weight_decay
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }]
        optimizer = optim.SGD(optimizer_grouped_parameters,
                              lr=args.lr,
                              momentum=args.momentum)
    else:
        log.info('using bias weight decay')
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    if args.resume:
        optimizer.load_state_dict(states['optimizer'])
    log.write('using optimizer: %s' % str(optimizer))
    log.write('')
    log.write('Λ' * 50 + " optimizer " + 'Λ' * 50 + '\n')

    # low precision
    use_low_precision_training = args.low_precision_training
    if use_low_precision_training:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # lr scheduler
    iters_per_epoch = int(np.ceil(len(train_data) / args.train_batch))
    total_iters = iters_per_epoch * args.epochs
    log.write('\n\n')
    log.write('V' * 50 + " lr_scheduler " + 'V' * 50 + '\n')
    if args.warmup:
        log.info('using warmup scheduler, warmup epochs: %d' %
                 args.warmup_epochs)
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, iters_per_epoch * args.warmup_epochs, eta_min=1e-6)
    elif args.cosine:
        log.info('using cosine lr scheduler')
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         T_max=total_iters)
    else:
        log.info('using normal lr decay scheduler')
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                         factor=0.5,
                                                         patience=10,
                                                         min_lr=1e-6,
                                                         mode='min')
    log.write('using lr scheduler: %s' % str(scheduler))
    log.write('')
    log.write('Λ' * 50 + " lr_scheduler " + 'Λ' * 50 + '\n')

    log.write('\n\n')
    log.write('V' * 50 + " training start " + 'V' * 50 + '\n')
    best_acc = 0
    start = time.time()
    log.info('\nstart training ...')
    for epoch in range(1, args.epochs + 1):
        lr = optimizer.param_groups[-1]['lr']
        train_loss, train_acc = train_one_epoch(
            log, scheduler, train_data, model, criterion, optimizer, use_cuda,
            use_low_precision_training, args.label_smoothing, args.mixup)
        test_loss, test_acc = val_one_epoch(log, val_data, model, criterion,
                                            use_cuda)
        end = time.time()
        log.info(
            'epoch: [%d / %d], time spent(s): %.2f, mean time: %.2f, lr: %.4f, '
            'train loss: %.4f, train acc: %.4f, test loss: %.4f, test acc: %.4f'
            % (epoch, args.epochs, end - start, (end - start) / epoch, lr,
               train_loss, train_acc, test_loss, test_acc))

        states = dict()
        states['arch'] = args.arch
        if multi_gpus:
            # unwrap DataParallel so the checkpoint loads on a single GPU
            states['model'] = model.module
            states['state_dict'] = model.module.state_dict()
        else:
            states['model'] = model
            states['state_dict'] = model.state_dict()
        states['optimizer'] = optimizer.state_dict()
        states['test_acc'] = test_acc
        states['train_acc'] = train_acc
        states['epoch'] = epoch
        states['classes'] = classes

        is_best = test_acc > best_acc
        if is_best:
            best_acc = test_acc
        log.save_checkpoint(states, is_best)

    log.write('\ntraining finished.')
    log.write('Λ' * 50 + " training finished " + 'Λ' * 50 + '\n')
    log.log_file.close()
    log.writer.close()

def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:

        def print_pass(*args):
            pass

        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # load teacher
    print("==> loading teacher '{}'".format(args.teacher))
    teacher = model_dict[args.t_arch]()
    for name, param in teacher.named_parameters():
        param.requires_grad = False
    checkpoint = torch.load(args.teacher, map_location="cpu")

    # rename moco pre-trained keys
    state_dict = checkpoint['state_dict']
    for k in list(state_dict.keys()):
        # retain only encoder_q up to before the embedding layer
        if k.startswith('module.encoder_q'
                        ) and not k.startswith('module.encoder_q.fc'):
            # remove prefix
            state_dict[k[len("module.encoder_q."):]] = state_dict[k]
        # delete renamed or unused k
        del state_dict[k]
    args.start_epoch = 0
    msg = teacher.load_state_dict(state_dict, strict=False)
    assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
    print("==> done")

    # create model
    print("==> creating model '{}'".format(args.arch))
    if args.arch not in [
            'efficientnet-b0', 'efficientnet-b1', 'mobilenetv3-large'
    ]:
        model = Distiller(base_encoder=model_dict[args.arch],
                          teacher=teacher,
                          h_dim=args.pred_h_dim,
                          args=args)
    elif 'efficientnet' in args.arch:
        model = Distiller(base_encoder=EfficientNet.from_name(args.arch),
                          teacher=teacher,
                          h_dim=args.pred_h_dim,
                          args=args)
    else:  # mobilenet v3 - large
        model = Distiller(base_encoder=mobilenetv3_large(),
                          teacher=teacher,
                          h_dim=args.pred_h_dim,
                          args=args)
    print("==> done")

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope; otherwise
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to
            # all available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        # comment out the following line for debugging
        raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        # AllGather implementation (batch shuffle, queue update, etc.) in
        # this code only supports DistributedDataParallel.
        raise NotImplementedError("Only DistributedDataParallel is supported.")

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
    augmentation = [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomApply(
            [
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
            ],
            p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([moco.loader.GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ]
    train_dataset = datasets.ImageFolder(
        traindir,
        moco.loader.TwoCropsTransform(transforms.Compose(augmentation)))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    logger = SummaryWriter(logdir=args.tb_folder)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, optimizer, epoch + 1, args, logger)

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            if (epoch + 1) % args.save_freq == 0:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    },
                    filename=os.path.join(
                        args.save_folder,
                        'checkpoint_{:04d}epoch.pth.tar'.format(epoch + 1)))
            # for resume
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                },
                filename=os.path.join(args.save_folder, 'latest.pth.tar'))
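
# Hedged launch sketch for the distillation worker above. The flag spellings
# are inferred from the argparse attribute names used in the code
# (args.t_arch, args.teacher, args.dist_url, ...) and may differ in the
# actual script; the teacher checkpoint path is a placeholder.
#
#   python main.py -a resnet18 --t-arch resnet50 \
#       --teacher /path/to/teacher.pth.tar \
#       --dist-url 'tcp://localhost:10001' --multiprocessing-distributed \
#       --world-size 1 --rank 0 /path/to/imagenet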