def main():
    """Run the trained three-branch FER model on the AffectWild2 test split.

    Parses the command line into the module-level ``args``, builds the
    ResNet-50 base network together with the attention and region branches
    (each wrapped in ``DataParallel``), optionally restores their weights from
    ``args.resume`` and, when ``args.predict_test`` is set, runs ``test`` over
    the test loader and then calls ``create_test_output``.
    """
    global args, best_prec1
    args = parser.parse_args()
    print('\n\t\t\t\t Aum Sri Sai Ram\nFER Test on AffectWild2 \n\n')
    print(args)
    print('\nimg_dir: ', args.root_path)

    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    best_prec1 = 0

    test_transform = transforms.Compose([
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # prepare model
    basemodel = resnet50(pretrained=False)
    attention_model = AttentionBranch(inputdim=512,
                                      num_regions=args.num_attentive_regions,
                                      num_classes=args.num_classes)
    region_model = RegionBranch(inputdim=1024,
                                num_regions=args.num_regions,
                                num_classes=args.num_classes)

    basemodel = torch.nn.DataParallel(basemodel).to(device)
    attention_model = torch.nn.DataParallel(attention_model).to(device)
    region_model = torch.nn.DataParallel(region_model).to(device)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # map_location lets a checkpoint written on one device layout be
            # restored on whatever `device` this process is using.
            checkpoint = torch.load(args.resume, map_location=device)
            basemodel.load_state_dict(checkpoint['base_state_dict'])
            attention_model.load_state_dict(checkpoint['attention_state_dict'])
            region_model.load_state_dict(checkpoint['region_state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            # NOTE(review): prediction below then runs with freshly
            # initialised weights - confirm this fall-through is intended.
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.predict_test:
        print('\n Test Mode:')
        test_dataset = ImageList(
            root=args.root_path,
            fileList='../data/Affwild2/Annotations/test_set.pkl',
            train_mode='Test',
            transform=test_transform)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  args.batch_size,
                                                  shuffle=False,
                                                  num_workers=8)
        print('\n length of AffectWild2 test Database: '
              + str(len(test_loader.dataset)))
        test(test_loader, basemodel, attention_model, region_model)
        create_test_output()

    print('Sairam. Exiting. Bye.')
def main():
    """Train and evaluate the three-branch FER model on FERPlus.

    Parses the command line into the module-level ``args``, builds the
    train/test loaders, the ResNet-50 base network with its attention and
    region branches, and one SGD optimizer with three parameter groups. It can
    initialise the base network from a pretrained file and resume from
    ``args.resume``. Each epoch adjusts the learning rate, trains, validates on
    the test loader and saves a checkpoint, tracking the best accuracy in the
    module-level ``best_prec1``.
    """
    global args, best_prec1
    args = parser.parse_args()
    print('\n\t\t\t\t Aum Sri Sai Ram\nFER on FERPLUS using OADN \n\n')
    print(args)
    print('\nimg_dir: ', args.root_path)
    print('\nTraining mode: ', args.mode)

    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    valid_transform = transforms.Compose([
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # NOTE(review): the label mode is fixed to 'majority' here even though
    # args.mode is printed above - confirm which one is intended.
    train_data = ImageList(root=args.root_path + 'Images/FER2013TrainValid/',
                           landmarksfile=args.train_landmarksfile,
                           fileList=args.train_list,
                           transform=train_transform,
                           mode='majority')
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    test_data = ImageList(root=args.root_path + 'Images/FER2013Test/',
                          landmarksfile=args.test_landmarksfile,
                          fileList=args.test_list,
                          transform=valid_transform,
                          mode='majority')
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size_t,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    print('length of train+valid Database for training: '
          + str(len(train_loader.dataset)))
    print('length of test Database: ' + str(len(test_loader.dataset)))

    # prepare model
    basemodel = resnet50(pretrained=False)
    attention_model = AttentionBranch(inputdim=2048,
                                      num_maps=24,
                                      num_classes=args.num_classes)
    region_model = RegionBranch(inputdim=2048,
                                num_regions=4,
                                num_classes=args.num_classes)

    basemodel = torch.nn.DataParallel(basemodel).to(device)
    attention_model = torch.nn.DataParallel(attention_model).to(device)
    region_model = torch.nn.DataParallel(region_model).to(device)

    print('\nNumber of parameters:')
    n_base = count_parameters(basemodel)
    n_attention = count_parameters(attention_model)
    n_region = count_parameters(region_model)
    print(
        'Base Model: {}, Attention Branch:{}, Region Branch:{} and Total: {}'.
        format(n_base, n_attention, n_region,
               n_base + n_attention + n_region))

    criterion = nn.CrossEntropyLoss().to(device)

    # Group order (region, attention, base) must stay stable: the optimizer
    # state stored in checkpoints is indexed by parameter-group position.
    optimizer = torch.optim.SGD([
        {
            "params": region_model.parameters(),
            "lr": args.lr,
            "momentum": args.momentum,
            "weight_decay": args.weight_decay,
        },
        {
            "params": attention_model.parameters(),
            "lr": args.lr,
            "momentum": args.momentum,
            "weight_decay": args.weight_decay,
        },
        {
            # The base network uses a fixed learning rate instead of args.lr.
            "params": basemodel.parameters(),
            "lr": 0.0001,
            "momentum": args.momentum,
            "weight_decay": args.weight_decay,
        },
    ])

    if args.pretrained:
        util.load_state_dict(
            basemodel, 'pretrainedmodels/vgg_msceleb_resnet50_ft_weight.pkl')

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # map_location lets a checkpoint written on one device layout be
            # restored on whatever `device` this process is using.
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            basemodel.load_state_dict(checkpoint['base_state_dict'])
            attention_model.load_state_dict(checkpoint['attention_state_dict'])
            region_model.load_state_dict(checkpoint['region_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    print('\nTraining starting:\n')

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, basemodel, attention_model, region_model,
              criterion, optimizer, epoch)

        prec1 = validate(test_loader, basemodel, attention_model,
                         region_model, criterion, epoch)
        print("Epoch: {} Test Acc: {}".format(epoch, prec1))

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1.to(device).item(), best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'base_state_dict': basemodel.state_dict(),
                'attention_state_dict': attention_model.state_dict(),
                'region_state_dict': region_model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best.item())
def main():
    """Train the three-branch FER model and validate it on FED-RO.

    Parses the command line into the module-level ``args``, builds the
    validation and training datasets, selects a sampler / per-class weights
    from ``args.train_rule`` and a loss from ``args.loss_type``, creates the
    ResNet-50 base network with its attention and region branches plus an SGD
    optimizer with three parameter groups, optionally restores weights and
    optimizer state from ``args.resume``, and then trains, validates and
    checkpoints once per epoch.

    Returns early (after a warning) when ``args.train_rule`` or
    ``args.loss_type`` is not one of the handled values.
    """
    global args, best_prec1
    args = parser.parse_args()
    print('\n\t\t\t\t Aum Sri Sai Ram\nFER on FEDRO using Local and global Attention along with region branch (non-overlapping patches)\n\n')
    print(args)
    print('\nimg_dir: ', args.root_path)
    print('\ntrain rule: ', args.train_rule, ' and loss type: ',
          args.loss_type, '\n')
    print('\nlr is : ', args.lr)
    print('img_dir:', args.root_path)

    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.4, contrast=0.3,
                               saturation=0.25, hue=0.05),
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    valid_transform = transforms.Compose([
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    val_data = TestList(root='../data/FED_RO/FED_RO_crop/',
                        fileList=args.valid_list,
                        transform=valid_transform)
    val_loader = torch.utils.data.DataLoader(val_data, args.batch_size,
                                             shuffle=False, num_workers=8)

    train_dataset = TrainList(root=args.root_path,
                              fileList=args.train_list,
                              transform=train_transform)

    if args.train_rule == 'None':
        train_sampler = None
        per_cls_weights = None
    elif args.train_rule == 'Resample':
        train_sampler = ImbalancedDatasetSampler(train_dataset)
        per_cls_weights = None
    elif args.train_rule == 'Reweight':
        train_sampler = None
        beta = 0.9999  # 0:normal weighting
        # NOTE(review): cls_num_list is not assigned anywhere in this
        # function; unless it is a module-level name this branch raises
        # NameError - confirm where the per-class counts should come from.
        effective_num = 1.0 - np.power(beta, cls_num_list)
        per_cls_weights = (1.0 - beta) / np.array(effective_num)
        per_cls_weights = (per_cls_weights / np.sum(per_cls_weights)
                           * len(cls_num_list))
        per_cls_weights = torch.FloatTensor(per_cls_weights).to(device)
    else:
        # Without this guard an unknown rule left train_sampler and
        # per_cls_weights unbound and failed later with a NameError.
        warnings.warn('Train rule is not listed')
        return

    if args.loss_type == 'CE':
        criterion = nn.CrossEntropyLoss(weight=per_cls_weights).to(device)
    elif args.loss_type == 'Focal':
        criterion = FocalLoss(weight=per_cls_weights, gamma=2).to(device)
    else:
        warnings.warn('Loss type is not listed')
        return

    # A sampler and shuffle=True are mutually exclusive in DataLoader.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler)

    print('length of AffectNet + RAFDB train Database: '
          + str(len(train_dataset)))
    print('length of FEDRO occlusion Database: '
          + str(len(val_loader.dataset)))

    # prepare model
    basemodel = resnet50(pretrained=False)
    attention_model = AttentionBranch(inputdim=512,
                                      num_regions=args.num_attentive_regions,
                                      num_classes=args.num_classes)
    region_model = RegionBranch(inputdim=1024,
                                num_regions=args.num_regions,
                                num_classes=args.num_classes)

    basemodel = torch.nn.DataParallel(basemodel).to(device)
    attention_model = torch.nn.DataParallel(attention_model).to(device)
    region_model = torch.nn.DataParallel(region_model).to(device)

    print('\nNumber of parameters:')
    n_base = count_parameters(basemodel)
    n_attention = count_parameters(attention_model)
    n_region = count_parameters(region_model)
    print('Base Model: {}, Attention Branch:{}, Region Branch:{} and Total: {}'
          .format(n_base, n_attention, n_region,
                  n_base + n_attention + n_region))

    # Group order (base, attention, region) must stay stable: the optimizer
    # state stored in checkpoints is indexed by parameter-group position.
    optimizer = torch.optim.SGD([
        {
            # The base network uses a fixed learning rate instead of args.lr.
            "params": basemodel.parameters(),
            "lr": 0.0001,
            "momentum": args.momentum,
            "weight_decay": args.weight_decay,
        },
        {
            "params": attention_model.parameters(),
            "lr": args.lr,
            "momentum": args.momentum,
            "weight_decay": args.weight_decay,
        },
        {
            "params": region_model.parameters(),
            "lr": args.lr,
            "momentum": args.momentum,
            "weight_decay": args.weight_decay,
        },
    ])

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # map_location lets a checkpoint written on one device layout be
            # restored on whatever `device` this process is using.
            checkpoint = torch.load(args.resume, map_location=device)
            # Only weights and optimizer state are restored; args.start_epoch
            # and best_prec1 are left untouched.
            basemodel.load_state_dict(checkpoint['base_state_dict'])
            attention_model.load_state_dict(checkpoint['attention_state_dict'])
            region_model.load_state_dict(checkpoint['region_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    print('\nTraining starting:\n')

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, basemodel, attention_model, region_model,
              criterion, optimizer, epoch)
        adjust_learning_rate(optimizer, epoch)

        prec1 = validate(val_loader, basemodel, attention_model,
                         region_model, criterion, epoch)
        print("Epoch: {} Test Acc: {}".format(epoch, prec1))

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1.to(device).item(), best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'base_state_dict': basemodel.state_dict(),
            'attention_state_dict': attention_model.state_dict(),
            'region_state_dict': region_model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best.item())
def main():
    """Train and evaluate the landmark-attention FER model on RAF-DB.

    Parses the command line into the module-level ``args``, builds the
    train/test loaders, the ResNet-50 base network with a landmark attention
    branch and a region branch, and one SGD optimizer with three parameter
    groups. It can initialise the base network from an ImageNet ResNet-50
    state dict (skipping ``layer4`` and the classifier / feature heads) and
    resume from ``args.resume``.

    When ``args`` carries a truthy ``evaluate`` attribute, the model is
    validated once on the test loader and the function returns. Otherwise each
    epoch adjusts the learning rate, trains, validates and saves a checkpoint,
    tracking the best accuracy in the module-level ``best_prec1``.
    """
    global args, best_prec1
    args = parser.parse_args()
    print(
        '\n\t\t Aum Sri Sai Ram\n\t\tRAFDB FER using Attention branch based on gaussian maps with region branch\n\n'
    )
    print(args)
    print('img_dir:', args.root_path)

    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    valid_transform = transforms.Compose([
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_dataset = ImageList(root=args.root_path,
                              landmarksfile=args.landmarksfile,
                              fileList=args.train_list,
                              transform=train_transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    test_data = ImageList(root=args.root_path,
                          fileList=args.test_list,
                          landmarksfile=args.landmarksfile,
                          transform=valid_transform)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size_t,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    print('length of RAFDB train Database: ' + str(len(train_dataset)))
    print('length of RAFDB test Database: ' + str(len(test_loader.dataset)))

    # prepare model
    basemodel = resnet50(pretrained=False)
    attention_model = LandmarksAttentionBranch(inputdim=1024,
                                               num_maps=24,
                                               num_classes=args.num_classes)
    region_model = RegionBranch(inputdim=1024,
                                num_regions=args.num_regions,
                                num_classes=args.num_classes)

    basemodel = torch.nn.DataParallel(basemodel).to(device)
    attention_model = torch.nn.DataParallel(attention_model).to(device)
    region_model = torch.nn.DataParallel(region_model).to(device)

    print('\nNumber of parameters:')
    n_base = count_parameters(basemodel)
    n_attention = count_parameters(attention_model)
    n_region = count_parameters(region_model)
    print(
        'Base Model: {}, Attention Branch:{}, Region Branch:{} and Total: {}'.
        format(n_base, n_attention, n_region,
               n_base + n_attention + n_region))

    criterion = nn.CrossEntropyLoss().to(device)

    # Group order (base, attention, region) must stay stable: the optimizer
    # state stored in checkpoints is indexed by parameter-group position.
    optimizer1 = torch.optim.SGD([
        {
            "params": basemodel.parameters(),
            "lr": args.lr,
            "momentum": args.momentum,
            "weight_decay": args.weight_decay,
        },
        {
            "params": attention_model.parameters(),
            "lr": args.lr,
            "momentum": args.momentum,
            "weight_decay": args.weight_decay,
        },
        {
            "params": region_model.parameters(),
            "lr": args.lr,
            "momentum": args.momentum,
            "weight_decay": args.weight_decay,
        },
    ])

    if args.pretrained:
        pretrained_state_dict = torch.load(
            'pretrainedmodels/resnet50-19c8e357.pth', map_location=device)
        model_state_dict = basemodel.state_dict()
        # Entries that are deliberately NOT copied from the pretrained file.
        skipped_keys = ('fc.weight', 'fc.bias',
                        'feature.weight', 'feature.bias')
        for key in pretrained_state_dict:
            if key in skipped_keys or 'layer4.' in key:
                continue
            # DataParallel prefixes every parameter name with 'module.'.
            model_state_dict['module.' + key] = pretrained_state_dict[key]
        basemodel.load_state_dict(model_state_dict, strict=True)
        print('\nLoaded resent50 pretrained on imagenet.\n')

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # map_location lets a checkpoint written on one device layout be
            # restored on whatever `device` this process is using.
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            basemodel.load_state_dict(checkpoint['base_state_dict'])
            attention_model.load_state_dict(checkpoint['attention_state_dict'])
            region_model.load_state_dict(checkpoint['region_state_dict'])
            optimizer1.load_state_dict(checkpoint['optimizer1'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Evaluation-only mode. This replaces an unconditional validate followed
    # by `assert (False)`, which made the training loop below unreachable and
    # - because asserts are stripped under `python -O` - silently started
    # training in optimized runs. getattr keeps parsers that do not define an
    # `evaluate` option working.
    if getattr(args, 'evaluate', False):
        prec1 = validate(test_loader, basemodel, attention_model,
                         region_model, criterion, 0)
        print("Epoch: {} Test Acc: {}".format(0, prec1))
        return

    print('\nTraining starting:\n')

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer1, epoch)

        # train for one epoch
        train(train_loader, basemodel, attention_model, region_model,
              criterion, optimizer1, epoch)

        prec1 = validate(test_loader, basemodel, attention_model,
                         region_model, criterion, epoch)
        print("Epoch: {} Test Acc: {}".format(epoch, prec1))

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1.to(device).item(), best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'base_state_dict': basemodel.state_dict(),
                'attention_state_dict': attention_model.state_dict(),
                'region_state_dict': region_model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer1': optimizer1.state_dict(),
            }, is_best.item())
def main():
    """Train the three-branch FER model on SFEW and validate every epoch.

    Parses the command line into the module-level ``args``, builds the
    validation and training datasets, creates an unweighted cross-entropy
    loss, the ResNet-50 base network with its attention and region branches,
    and an SGD optimizer with three parameter groups. The base network can be
    initialised from a pretrained file and/or from the ``base_state_dict`` of
    ``args.resume``. Each epoch trains, validates, adjusts the learning rate
    and saves a checkpoint, tracking the best accuracy in the module-level
    ``best_prec1``.

    Returns early (after a warning) when ``args.train_rule`` is not ``'None'``
    or ``args.loss_type`` is not ``'CE'``.
    """
    global args, best_prec1
    args = parser.parse_args()
    print('\n\t\t\t\t Aum Sri Sai Ram\nFER on SFEW using OADN\n\n')
    print(args)
    print('\nimg_dir: ', args.root_path)
    print('\ntrain rule: ', args.train_rule, ' and loss type: ',
          args.loss_type, '\n')

    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    valid_transform = transforms.Compose([
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    val_data = ImageList(root=args.root_path + 'Val_Aligned_Faces_SAN/',
                         fileList=args.valid_list,
                         landmarksfile=args.test_landmarksfile,
                         transform=valid_transform)
    val_loader = torch.utils.data.DataLoader(val_data, args.batch_size,
                                             shuffle=False, num_workers=8)

    train_dataset = ImageList(root=args.root_path + 'Train_Aligned_Faces_SAN/',
                              fileList=args.train_list,
                              landmarksfile=args.train_landmarksfile,
                              transform=train_transform)

    if args.train_rule == 'None':
        train_sampler = None
        per_cls_weights = None
    else:
        # Without this guard any other rule left train_sampler and
        # per_cls_weights unbound and failed later with a NameError.
        warnings.warn('Train rule is not listed')
        return

    if args.loss_type == 'CE':
        # per_cls_weights is always None here, so this is the plain
        # unweighted loss; the original rebuilt an identical criterion
        # further down, which was redundant.
        criterion = nn.CrossEntropyLoss(weight=per_cls_weights).to(device)
    else:
        warnings.warn('Loss type is not listed')
        return

    # A sampler and shuffle=True are mutually exclusive in DataLoader.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler)

    print('length of SFEW train Database: ' + str(len(train_dataset)))
    print('length of SFEW valid Database: ' + str(len(val_loader.dataset)))

    # prepare model
    basemodel = resnet50(pretrained=False)
    attention_model = AttentionBranch(inputdim=2048,
                                      num_maps=24,
                                      num_classes=args.num_classes)
    region_model = RegionBranch(inputdim=2048,
                                num_regions=4,
                                num_classes=args.num_classes)

    basemodel = torch.nn.DataParallel(basemodel).to(device)
    attention_model = torch.nn.DataParallel(attention_model).to(device)
    region_model = torch.nn.DataParallel(region_model).to(device)

    print('\nNumber of parameters:')
    n_base = count_parameters(basemodel)
    n_attention = count_parameters(attention_model)
    n_region = count_parameters(region_model)
    print(
        'Base Model: {}, Attention Branch:{}, Region Branch:{} and Total: {}'.
        format(n_base, n_attention, n_region,
               n_base + n_attention + n_region))

    # Group order (region, attention, base) must stay stable: the optimizer
    # state stored in checkpoints is indexed by parameter-group position.
    optimizer = torch.optim.SGD([
        {
            "params": region_model.parameters(),
            "lr": args.lr,
            "momentum": args.momentum,
            "weight_decay": args.weight_decay,
        },
        {
            "params": attention_model.parameters(),
            "lr": args.lr,
            "momentum": args.momentum,
            "weight_decay": args.weight_decay,
        },
        {
            # The base network uses a fixed learning rate instead of args.lr.
            "params": basemodel.parameters(),
            "lr": 0.0001,
            "momentum": args.momentum,
            "weight_decay": args.weight_decay,
        },
    ])

    if args.pretrained:
        util.load_state_dict(
            basemodel, 'pretrainedmodels/vgg_msceleb_resnet50_ft_weight.pkl')

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # map_location lets a checkpoint written on one device layout be
            # restored on whatever `device` this process is using.
            checkpoint = torch.load(args.resume, map_location=device)
            # Only the base network is restored; the branches, optimizer,
            # start epoch and best accuracy are left as initialised.
            basemodel.load_state_dict(checkpoint['base_state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    print('\nTraining starting:\n')

    for epoch in range(args.start_epoch, args.epochs):
        train(train_loader, basemodel, attention_model, region_model,
              criterion, optimizer, epoch)

        prec1 = validate(val_loader, basemodel, attention_model,
                         region_model, criterion, epoch)
        print("Epoch: {} Test Acc: {}".format(epoch, prec1))

        is_best = prec1 > best_prec1
        best_prec1 = max(prec1.to(device).item(), best_prec1)

        adjust_learning_rate(optimizer, epoch)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'base_state_dict': basemodel.state_dict(),
                'attention_state_dict': attention_model.state_dict(),
                'region_state_dict': region_model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best.item())
def main(): #Print args global args, best_prec1 args = parser.parse_args() print('\n\t\t\t\t Aum Sri Sai Ram\nFER on AffectWild2 using Local and global Attention along with region branch (non-overlapping patches)\n\n') print(args) print('\nimg_dir: ', args.root_path) print('\ntrain rule: ',args.train_rule, ' and loss type: ', args.loss_type, '\n') print('\n lr is : ', args.lr) mean = [0.5, 0.5, 0.5] std = [0.5, 0.5, 0.5] imagesize = args.imagesize best_expr_f1 = 0 final_cm = 0 final_mcm = 0 best_prec1 = 0 train_transform = transforms.Compose([ transforms.RandomHorizontalFlip(p=0.5), transforms.ColorJitter(brightness=0.4, contrast = 0.3, saturation = 0.25, hue = 0.05), transforms.Resize((args.imagesize,args.imagesize)), transforms.ToTensor(), transforms.Normalize(mean,std) ]) valid_transform = transforms.Compose([ transforms.Resize((args.imagesize,args.imagesize)), transforms.ToTensor(), transforms.Normalize(mean,std) ]) val_data = ImageList(root=args.root_path, fileList = args.metafile, train_mode='Validation', transform=valid_transform) val_loader = torch.utils.data.DataLoader(val_data, args.batch_size, shuffle=False, num_workers=8) train_dataset = ImageList(root=args.root_path, fileList = args.metafile,train_mode='Train', transform=train_transform) cls_num_list = train_dataset.get_cls_num_list() print('\nTrain cls num list:', cls_num_list) if args.train_rule == 'None': train_sampler = None per_cls_weights = None elif args.train_rule == 'Reweight': train_sampler = None beta = 0.9999 #0:normal weighting effective_num = 1.0 - np.power(beta, cls_num_list) per_cls_weights = (1.0 - beta) / np.array(effective_num) per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list) per_cls_weights = torch.FloatTensor(per_cls_weights).to(device) if args.loss_type == 'CE': criterion = nn.CrossEntropyLoss(weight=per_cls_weights).to(device) elif args.loss_type == 'Focal': criterion = FocalLoss(weight=per_cls_weights, gamma=2).to(device) else: warnings.warn('Loss 
type is not listed') return train_loader = torch.utils.data.DataLoader(train_dataset, args.batch_size, shuffle=(train_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_sampler) print('\nlength of AffectWild2 train Database: ' + str(len(train_dataset))) print('\nlength of AffectWild2 valid Database: ' + str(len(val_loader.dataset))) # prepare model basemodel = resnet50(pretrained = False) attention_model = AttentionBranch(inputdim = 512, num_regions = args.num_attentive_regions, num_classes = args.num_classes) region_model = RegionBranch(inputdim = 1024, num_regions = args.num_regions, num_classes = args.num_classes) basemodel = torch.nn.DataParallel(basemodel).to(device) attention_model = torch.nn.DataParallel(attention_model).to(device) region_model = torch.nn.DataParallel(region_model).to(device) print('\nNumber of parameters:') print('Base Model: {}, Attention Branch:{}, Region Branch:{} and Total: {}'.format(count_parameters(basemodel),count_parameters(attention_model), count_parameters(region_model), count_parameters(basemodel)+count_parameters(attention_model)+count_parameters(region_model))) optimizer = torch.optim.SGD([{"params": basemodel.parameters(), "lr": 0.0001, "momentum":args.momentum, "weight_decay":args.weight_decay}]) optimizer.add_param_group({"params": attention_model.parameters(), "lr": args.lr, "momentum":args.momentum, "weight_decay":args.weight_decay}) optimizer.add_param_group({"params": region_model.parameters(), "lr": args.lr, "momentum":args.momentum, "weight_decay":args.weight_decay}) if args.pretrained: util.load_state_dict(basemodel,'pretrainedmodels/vgg_msceleb_resnet50_ft_weight.pkl') if args.resume: if os.path.isfile(args.resume): print("=> loading checkpoint '{}'".format(args.resume)) checkpoint = torch.load(args.resume) basemodel.load_state_dict(checkpoint['base_state_dict']) attention_model.load_state_dict(checkpoint['attention_state_dict']) region_model.load_state_dict(checkpoint['region_state_dict']) 
optimizer.load_state_dict(checkpoint['optimizer']) print("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch'])) else: print("=> no checkpoint found at '{}'".format(args.resume))