# Example #1
def main():
    """Train and evaluate the OADN FER model on FERPLUS.

    Parses command-line arguments, builds the train/test pipelines,
    assembles the three-branch network (ResNet-50 backbone + attention
    branch + region branch) on ``device``, optionally loads pretrained
    weights or resumes from a checkpoint, then runs the per-epoch
    train / validate / checkpoint loop.

    Relies on module-level names: ``parser``, ``device``, ``ImageList``,
    ``resnet50``, ``AttentionBranch``, ``RegionBranch``,
    ``count_parameters``, ``util``, ``train``, ``validate``,
    ``adjust_learning_rate``, ``save_checkpoint`` and the globals
    ``args`` / ``best_prec1``.
    """
    global args, best_prec1
    args = parser.parse_args()
    print('\n\t\t\t\t Aum Sri Sai Ram\nFER on FERPLUS using OADN \n\n')
    print(args)
    print('\nimg_dir: ', args.root_path)
    print('\nTraining mode: ', args.mode)

    # Inputs are normalized to [-1, 1] per channel.
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    valid_transform = transforms.Compose([
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    # NOTE(review): mode is hard-coded to 'majority' even though args.mode
    # is printed above -- confirm the override is intentional.
    train_data = ImageList(root=args.root_path + 'Images/FER2013TrainValid/',
                           landmarksfile=args.train_landmarksfile,
                           fileList=args.train_list,
                           transform=train_transform,
                           mode='majority')

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    test_data = ImageList(root=args.root_path + 'Images/FER2013Test/',
                          landmarksfile=args.test_landmarksfile,
                          fileList=args.test_list,
                          transform=valid_transform,
                          mode='majority')

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size_t,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    print('length of  train+valid Database for training: ' +
          str(len(train_loader.dataset)))

    print('length of  test Database: ' + str(len(test_loader.dataset)))

    # Prepare the three-branch model and spread it across available GPUs.
    basemodel = resnet50(pretrained=False)
    attention_model = AttentionBranch(inputdim=2048,
                                      num_maps=24,
                                      num_classes=args.num_classes)
    region_model = RegionBranch(inputdim=2048,
                                num_regions=4,
                                num_classes=args.num_classes)

    basemodel = torch.nn.DataParallel(basemodel).to(device)
    attention_model = torch.nn.DataParallel(attention_model).to(device)
    region_model = torch.nn.DataParallel(region_model).to(device)

    print('\nNumber of parameters:')
    print(
        'Base Model: {}, Attention Branch:{}, Region Branch:{} and Total: {}'.
        format(
            count_parameters(basemodel), count_parameters(attention_model),
            count_parameters(region_model),
            count_parameters(basemodel) + count_parameters(attention_model) +
            count_parameters(region_model)))

    criterion = nn.CrossEntropyLoss().to(device)

    # The two fresh branches train at args.lr; the backbone (possibly
    # pretrained below) uses a small fixed LR of 1e-4.
    optimizer = torch.optim.SGD([{
        "params": region_model.parameters(),
        "lr": args.lr,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay
    }])

    optimizer.add_param_group({
        "params": attention_model.parameters(),
        "lr": args.lr,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay
    })
    optimizer.add_param_group({
        "params": basemodel.parameters(),
        "lr": 0.0001,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay
    })

    if args.pretrained:
        util.load_state_dict(
            basemodel, 'pretrainedmodels/vgg_msceleb_resnet50_ft_weight.pkl')

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            basemodel.load_state_dict(checkpoint['base_state_dict'])
            attention_model.load_state_dict(checkpoint['attention_state_dict'])
            region_model.load_state_dict(checkpoint['region_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    print('\nTraining starting:\n')
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # Train for one epoch, then evaluate on the test split.
        train(train_loader, basemodel, attention_model, region_model,
              criterion, optimizer, epoch)

        prec1 = validate(test_loader, basemodel, attention_model, region_model,
                         criterion, epoch)

        print("Epoch: {}   Test Acc: {}".format(epoch, prec1))

        # Remember best prec@1 and save a checkpoint.
        is_best = prec1 > best_prec1
        # prec1 is a tensor; .item() alone extracts the Python scalar
        # (the previous .to(device) round-trip before .item() was redundant).
        best_prec1 = max(prec1.item(), best_prec1)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'base_state_dict': basemodel.state_dict(),
                'attention_state_dict': attention_model.state_dict(),
                'region_state_dict': region_model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best.item())
def main():
    """Train on AffectNet + RAFDB and evaluate on the FED-RO occlusion set.

    Uses local and global attention along with a region branch over
    non-overlapping patches. Supports class-imbalance handling via
    ``args.train_rule`` (None / Resample / Reweight) and CE or Focal loss
    via ``args.loss_type``.

    Relies on module-level names: ``parser``, ``device``, ``TestList``,
    ``TrainList``, ``ImbalancedDatasetSampler``, ``FocalLoss``,
    ``resnet50``, ``AttentionBranch``, ``RegionBranch``,
    ``count_parameters``, ``train``, ``validate``,
    ``adjust_learning_rate``, ``save_checkpoint``, ``np``, ``warnings``
    and the globals ``args`` / ``best_prec1``.
    """
    global args, best_prec1
    args = parser.parse_args()
    print('\n\t\t\t\t Aum Sri Sai Ram\nFER on FEDRO using Local and global Attention along with region branch (non-overlapping patches)\n\n')
    print(args)
    print('\nimg_dir: ', args.root_path)
    print('\ntrain rule: ', args.train_rule, ' and loss type: ', args.loss_type, '\n')
    print('\nlr is : ', args.lr)
    print('img_dir:', args.root_path)

    # Inputs are normalized to [-1, 1] per channel.
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.4, contrast=0.3, saturation=0.25, hue=0.05),
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    valid_transform = transforms.Compose([
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    val_data = TestList(root='../data/FED_RO/FED_RO_crop/', fileList=args.valid_list,
                        transform=valid_transform)

    val_loader = torch.utils.data.DataLoader(val_data, args.batch_size, shuffle=False, num_workers=8)

    train_dataset = TrainList(root=args.root_path, fileList=args.train_list,
                              transform=train_transform)

    # Select the class-imbalance strategy.
    if args.train_rule == 'None':
        train_sampler = None
        per_cls_weights = None
    elif args.train_rule == 'Resample':
        train_sampler = ImbalancedDatasetSampler(train_dataset)
        per_cls_weights = None
    elif args.train_rule == 'Reweight':
        # Class-balanced re-weighting via the effective number of samples;
        # beta = 0 would reduce to uniform weighting.
        # NOTE(review): cls_num_list is not defined anywhere in this
        # function, so this branch raises NameError as written -- it likely
        # should come from the train dataset (cf. the AffectWild2 variant's
        # train_dataset.get_cls_num_list()). TODO confirm and fix upstream.
        train_sampler = None
        beta = 0.9999
        effective_num = 1.0 - np.power(beta, cls_num_list)
        per_cls_weights = (1.0 - beta) / np.array(effective_num)
        per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
        per_cls_weights = torch.FloatTensor(per_cls_weights).to(device)
    else:
        # Previously an unknown rule fell through with train_sampler and
        # per_cls_weights undefined, crashing later with NameError; fail
        # fast instead, mirroring the loss-type handling below.
        warnings.warn('Train rule is not listed')
        return

    if args.loss_type == 'CE':
        criterion = nn.CrossEntropyLoss(weight=per_cls_weights).to(device)
    elif args.loss_type == 'Focal':
        criterion = FocalLoss(weight=per_cls_weights, gamma=2).to(device)
    else:
        warnings.warn('Loss type is not listed')
        return

    # When a sampler is supplied, shuffle must be False (mutually exclusive).
    train_loader = torch.utils.data.DataLoader(train_dataset, args.batch_size, shuffle=(train_sampler is None),
                                               num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    print('length of AffectNet + RAFDB train Database: ' + str(len(train_dataset)))
    print('length of FEDRO occlusion Database: ' + str(len(val_loader.dataset)))

    # Prepare the three-branch model and spread it across available GPUs.
    basemodel = resnet50(pretrained=False)
    attention_model = AttentionBranch(inputdim=512, num_regions=args.num_attentive_regions, num_classes=args.num_classes)
    region_model = RegionBranch(inputdim=1024, num_regions=args.num_regions, num_classes=args.num_classes)

    basemodel = torch.nn.DataParallel(basemodel).to(device)
    attention_model = torch.nn.DataParallel(attention_model).to(device)
    region_model = torch.nn.DataParallel(region_model).to(device)

    print('\nNumber of parameters:')
    print('Base Model: {}, Attention Branch:{}, Region Branch:{} and Total: {}'.format(count_parameters(basemodel), count_parameters(attention_model), count_parameters(region_model), count_parameters(basemodel) + count_parameters(attention_model) + count_parameters(region_model)))

    # Backbone uses a small fixed LR of 1e-4; both branches train at args.lr.
    optimizer = torch.optim.SGD([{"params": basemodel.parameters(), "lr": 0.0001, "momentum": args.momentum,
                                  "weight_decay": args.weight_decay}])

    optimizer.add_param_group({"params": attention_model.parameters(), "lr": args.lr, "momentum": args.momentum,
                               "weight_decay": args.weight_decay})

    optimizer.add_param_group({"params": region_model.parameters(), "lr": args.lr, "momentum": args.momentum,
                               "weight_decay": args.weight_decay})

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            basemodel.load_state_dict(checkpoint['base_state_dict'])
            attention_model.load_state_dict(checkpoint['attention_state_dict'])
            region_model.load_state_dict(checkpoint['region_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # NOTE(review): start_epoch and best_prec1 are not restored here,
            # unlike the FERPLUS variant -- confirm that is intended.
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    print('\nTraining starting:\n')
    for epoch in range(args.start_epoch, args.epochs):
        # Train for one epoch, decay the LR, then evaluate on FED-RO.
        train(train_loader, basemodel, attention_model, region_model, criterion, optimizer, epoch)
        adjust_learning_rate(optimizer, epoch)
        prec1 = validate(val_loader, basemodel, attention_model, region_model, criterion, epoch)
        print("Epoch: {}   Test Acc: {}".format(epoch, prec1))

        # Remember best prec@1 and save a checkpoint.
        is_best = prec1 > best_prec1
        # prec1 is a tensor; .item() alone suffices (the .to(device) hop
        # before .item() was redundant).
        best_prec1 = max(prec1.item(), best_prec1)

        save_checkpoint({
            'epoch': epoch + 1,
            'base_state_dict': basemodel.state_dict(),
            'attention_state_dict': attention_model.state_dict(),
            'region_state_dict': region_model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best.item())
# Example #3
def main():
    """Train and validate the OADN FER model on SFEW.

    Builds SAN-aligned SFEW train/val pipelines, assembles the
    three-branch network (ResNet-50 backbone + attention branch + region
    branch), optionally loads pretrained weights or a checkpoint, then
    runs the per-epoch train / validate / checkpoint loop.

    Relies on module-level names: ``parser``, ``device``, ``ImageList``,
    ``resnet50``, ``AttentionBranch``, ``RegionBranch``,
    ``count_parameters``, ``util``, ``train``, ``validate``,
    ``adjust_learning_rate``, ``save_checkpoint``, ``warnings`` and the
    globals ``args`` / ``best_prec1``.
    """
    global args, best_prec1
    args = parser.parse_args()
    print('\n\t\t\t\t Aum Sri Sai Ram\nFER on SFEW using OADN\n\n')
    print(args)
    print('\nimg_dir: ', args.root_path)
    print('\ntrain rule: ', args.train_rule, ' and loss type: ',
          args.loss_type, '\n')

    # Inputs are normalized to [-1, 1] per channel.
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    valid_transform = transforms.Compose([
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    val_data = ImageList(root=args.root_path + 'Val_Aligned_Faces_SAN/',
                         fileList=args.valid_list,
                         landmarksfile=args.test_landmarksfile,
                         transform=valid_transform)

    val_loader = torch.utils.data.DataLoader(val_data,
                                             args.batch_size,
                                             shuffle=False,
                                             num_workers=8)

    train_dataset = ImageList(root=args.root_path + 'Train_Aligned_Faces_SAN/',
                              fileList=args.train_list,
                              landmarksfile=args.train_landmarksfile,
                              transform=train_transform)

    # Only the 'None' rule is supported in this variant.
    if args.train_rule == 'None':
        train_sampler = None
        per_cls_weights = None
    else:
        # Previously an unknown rule fell through with train_sampler and
        # per_cls_weights undefined, crashing later with NameError; fail
        # fast instead, mirroring the loss-type handling below.
        warnings.warn('Train rule is not listed')
        return

    if args.loss_type == 'CE':
        criterion = nn.CrossEntropyLoss(weight=per_cls_weights).to(device)
    else:
        warnings.warn('Loss type is not listed')
        return

    # When a sampler is supplied, shuffle must be False (mutually exclusive).
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    print('length of SFEW train Database: ' + str(len(train_dataset)))
    print('length of SFEW valid Database: ' + str(len(val_loader.dataset)))

    # Prepare the three-branch model and spread it across available GPUs.
    basemodel = resnet50(pretrained=False)
    attention_model = AttentionBranch(inputdim=2048,
                                      num_maps=24,
                                      num_classes=args.num_classes)
    region_model = RegionBranch(inputdim=2048,
                                num_regions=4,
                                num_classes=args.num_classes)

    basemodel = torch.nn.DataParallel(basemodel).to(device)
    attention_model = torch.nn.DataParallel(attention_model).to(device)
    region_model = torch.nn.DataParallel(region_model).to(device)

    print('\nNumber of parameters:')
    print(
        'Base Model: {}, Attention Branch:{}, Region Branch:{} and Total: {}'.
        format(
            count_parameters(basemodel), count_parameters(attention_model),
            count_parameters(region_model),
            count_parameters(basemodel) + count_parameters(attention_model) +
            count_parameters(region_model)))

    # The redundant second `criterion = nn.CrossEntropyLoss().to(device)`
    # that used to sit here silently discarded the loss selected above; it
    # has been removed (behavior-identical for the supported 'None' rule,
    # where per_cls_weights is None).

    # The two fresh branches train at args.lr; the backbone (possibly
    # pretrained below) uses a small fixed LR of 1e-4.
    optimizer = torch.optim.SGD([{
        "params": region_model.parameters(),
        "lr": args.lr,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay
    }])

    optimizer.add_param_group({
        "params": attention_model.parameters(),
        "lr": args.lr,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay
    })
    optimizer.add_param_group({
        "params": basemodel.parameters(),
        "lr": 0.0001,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay
    })

    if args.pretrained:
        util.load_state_dict(
            basemodel, 'pretrainedmodels/vgg_msceleb_resnet50_ft_weight.pkl')

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # NOTE(review): only the base model is restored here -- the
            # attention/region branches, optimizer state, start_epoch and
            # best_prec1 are not, unlike the FERPLUS variant. Confirm.
            basemodel.load_state_dict(checkpoint['base_state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    print('\nTraining starting:\n')
    for epoch in range(args.start_epoch, args.epochs):
        # Train for one epoch, then evaluate on the validation split.
        train(train_loader, basemodel, attention_model, region_model,
              criterion, optimizer, epoch)

        prec1 = validate(val_loader, basemodel, attention_model, region_model,
                         criterion, epoch)

        print("Epoch: {}   Test Acc: {}".format(epoch, prec1))

        # Remember best prec@1 and save a checkpoint.
        is_best = prec1 > best_prec1
        # prec1 is a tensor; .item() alone suffices (the .to(device) hop
        # before .item() was redundant).
        best_prec1 = max(prec1.item(), best_prec1)

        adjust_learning_rate(optimizer, epoch)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'base_state_dict': basemodel.state_dict(),
                'attention_state_dict': attention_model.state_dict(),
                'region_state_dict': region_model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best.item())
def main():
    """Train and evaluate attention-branch FER (gaussian landmark maps) on RAFDB.

    Builds the RAFDB train/test pipelines, assembles the three-branch
    network (ResNet-50 backbone + landmark-attention branch + region
    branch), optionally initializes from ImageNet weights (skipping the
    classifier head and layer4) or resumes from a checkpoint, runs one
    validation pass, then the per-epoch train / validate / checkpoint loop.

    Relies on module-level names: ``parser``, ``device``, ``ImageList``,
    ``resnet50``, ``LandmarksAttentionBranch``, ``RegionBranch``,
    ``count_parameters``, ``train``, ``validate``,
    ``adjust_learning_rate``, ``save_checkpoint`` and the globals
    ``args`` / ``best_prec1``.
    """
    global args, best_prec1
    args = parser.parse_args()
    print(
        '\n\t\t Aum Sri Sai Ram\n\t\tRAFDB FER using  Attention branch based on gaussian maps with region branch\n\n'
    )
    print(args)

    print('img_dir:', args.root_path)

    # Inputs are normalized to [-1, 1] per channel.
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    valid_transform = transforms.Compose([
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    train_dataset = ImageList(root=args.root_path,
                              landmarksfile=args.landmarksfile,
                              fileList=args.train_list,
                              transform=train_transform)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    test_data = ImageList(root=args.root_path,
                          fileList=args.test_list,
                          landmarksfile=args.landmarksfile,
                          transform=valid_transform)

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size_t,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    print('length of RAFDB train Database: ' + str(len(train_dataset)))

    print('length of RAFDB test Database: ' + str(len(test_loader.dataset)))

    # Prepare the three-branch model and spread it across available GPUs.
    basemodel = resnet50(pretrained=False)
    attention_model = LandmarksAttentionBranch(inputdim=1024,
                                               num_maps=24,
                                               num_classes=args.num_classes)
    region_model = RegionBranch(inputdim=1024,
                                num_regions=args.num_regions,
                                num_classes=args.num_classes)

    basemodel = torch.nn.DataParallel(basemodel).to(device)
    attention_model = torch.nn.DataParallel(attention_model).to(device)
    region_model = torch.nn.DataParallel(region_model).to(device)

    print('\nNumber of parameters:')
    print(
        'Base Model: {}, Attention Branch:{}, Region Branch:{} and Total: {}'.
        format(
            count_parameters(basemodel), count_parameters(attention_model),
            count_parameters(region_model),
            count_parameters(basemodel) + count_parameters(attention_model) +
            count_parameters(region_model)))

    criterion = nn.CrossEntropyLoss().to(device)

    # Unlike the other variants, all three parts train at args.lr here.
    optimizer1 = torch.optim.SGD([{
        "params": basemodel.parameters(),
        "lr": args.lr,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay
    }])

    optimizer1.add_param_group({
        "params": attention_model.parameters(),
        "lr": args.lr,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay
    })

    optimizer1.add_param_group({
        "params": region_model.parameters(),
        "lr": args.lr,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay
    })

    if args.pretrained:
        # Copy ImageNet weights into the DataParallel-wrapped backbone,
        # skipping the classifier head (fc/feature) and all of layer4,
        # which differ from the stock ResNet-50 here.
        pretrained_state_dict = torch.load(
            'pretrainedmodels/resnet50-19c8e357.pth')
        model_state_dict = basemodel.state_dict()
        for key in pretrained_state_dict:
            if (key in ('fc.weight', 'fc.bias', 'feature.weight',
                        'feature.bias') or 'layer4.' in key):
                continue
            # DataParallel prefixes parameter names with 'module.'.
            model_state_dict['module.' + key] = pretrained_state_dict[key]

        basemodel.load_state_dict(model_state_dict, strict=True)
        print('\nLoaded resnet50 pretrained on imagenet.\n')

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            basemodel.load_state_dict(checkpoint['base_state_dict'])
            attention_model.load_state_dict(checkpoint['attention_state_dict'])
            region_model.load_state_dict(checkpoint['region_state_dict'])
            optimizer1.load_state_dict(checkpoint['optimizer1'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Sanity-check accuracy before any training (useful after --resume).
    # A debug `assert False` used to sit right after this call, aborting
    # the process and making the training loop below unreachable; removed.
    prec1 = validate(test_loader, basemodel, attention_model, region_model,
                     criterion, 0)
    print("Epoch: {}   Test Acc: {}".format(0, prec1))

    print('\nTraining starting:\n')
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer1, epoch)

        # Train for one epoch, then evaluate on the test split.
        train(train_loader, basemodel, attention_model, region_model,
              criterion, optimizer1, epoch)
        prec1 = validate(test_loader, basemodel, attention_model, region_model,
                         criterion, epoch)
        print("Epoch: {}   Test Acc: {}".format(epoch, prec1))

        # Remember best prec@1 and save a checkpoint.
        is_best = prec1 > best_prec1
        # prec1 is a tensor; .item() alone suffices (the .to(device) hop
        # before .item() was redundant).
        best_prec1 = max(prec1.item(), best_prec1)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'base_state_dict': basemodel.state_dict(),
                'attention_state_dict': attention_model.state_dict(),
                'region_state_dict': region_model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer1': optimizer1.state_dict(),
            },
            is_best.item())
def main():
    #Print args
    global args, best_prec1
    args = parser.parse_args()
    print('\n\t\t\t\t Aum Sri Sai Ram\nFER on AffectWild2 using Local and global Attention along with region branch (non-overlapping patches)\n\n')
    print(args)
    print('\nimg_dir: ', args.root_path)
    print('\ntrain rule: ',args.train_rule, ' and loss type: ', args.loss_type, '\n')
    print('\n lr is : ', args.lr)

    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    imagesize = args.imagesize

    best_expr_f1 = 0
    final_cm = 0
    final_mcm = 0
    best_prec1 = 0


    train_transform = transforms.Compose([          
            transforms.RandomHorizontalFlip(p=0.5),           
            transforms.ColorJitter(brightness=0.4, contrast = 0.3, saturation = 0.25, hue = 0.05),            
            transforms.Resize((args.imagesize,args.imagesize)),
            transforms.ToTensor(),
            transforms.Normalize(mean,std)
        ])

    
    valid_transform = transforms.Compose([
            transforms.Resize((args.imagesize,args.imagesize)),
            transforms.ToTensor(),
            transforms.Normalize(mean,std)
        ])

    val_data = ImageList(root=args.root_path, fileList = args.metafile, train_mode='Validation',
                  transform=valid_transform)   
    
    val_loader = torch.utils.data.DataLoader(val_data, args.batch_size, shuffle=False, num_workers=8)

    train_dataset = ImageList(root=args.root_path, fileList = args.metafile,train_mode='Train',
                  transform=train_transform)


    cls_num_list = train_dataset.get_cls_num_list()
    print('\nTrain cls num list:', cls_num_list)


    
    if args.train_rule == 'None':
       train_sampler = None  
       per_cls_weights = None    
    elif args.train_rule == 'Reweight':
       train_sampler = None
       beta = 0.9999                 #0:normal weighting
       effective_num = 1.0 - np.power(beta, cls_num_list)
       per_cls_weights = (1.0 - beta) / np.array(effective_num)
       per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
       per_cls_weights = torch.FloatTensor(per_cls_weights).to(device)
    
    if args.loss_type == 'CE':
       criterion = nn.CrossEntropyLoss(weight=per_cls_weights).to(device)
    elif args.loss_type == 'Focal':
       criterion = FocalLoss(weight=per_cls_weights, gamma=2).to(device)
    else:
       warnings.warn('Loss type is not listed')
       return
    

        
    train_loader = torch.utils.data.DataLoader(train_dataset, args.batch_size, shuffle=(train_sampler is None),
                                                   num_workers=args.workers, pin_memory=True, sampler=train_sampler)    
    
    print('\nlength of AffectWild2 train Database: ' + str(len(train_dataset)))
    print('\nlength of AffectWild2 valid Database: ' + str(len(val_loader.dataset)))
    
    # prepare model
    basemodel = resnet50(pretrained = False)
    attention_model = AttentionBranch(inputdim = 512, num_regions = args.num_attentive_regions, num_classes = args.num_classes)
    region_model = RegionBranch(inputdim = 1024, num_regions = args.num_regions, num_classes = args.num_classes)
    
    basemodel = torch.nn.DataParallel(basemodel).to(device)
    attention_model = torch.nn.DataParallel(attention_model).to(device)
    region_model = torch.nn.DataParallel(region_model).to(device)
    
    print('\nNumber of parameters:')
    print('Base Model: {}, Attention Branch:{}, Region Branch:{} and Total: {}'.format(count_parameters(basemodel),count_parameters(attention_model),  count_parameters(region_model), count_parameters(basemodel)+count_parameters(attention_model)+count_parameters(region_model)))  
    
    

       
    optimizer =  torch.optim.SGD([{"params": basemodel.parameters(), "lr": 0.0001, "momentum":args.momentum,
                                 "weight_decay":args.weight_decay}])
    
    optimizer.add_param_group({"params": attention_model.parameters(), "lr": args.lr, "momentum":args.momentum,
                                 "weight_decay":args.weight_decay})
    
    optimizer.add_param_group({"params": region_model.parameters(), "lr": args.lr, "momentum":args.momentum,
                                 "weight_decay":args.weight_decay})
  
    if args.pretrained:
        util.load_state_dict(basemodel,'pretrainedmodels/vgg_msceleb_resnet50_ft_weight.pkl')
        
    
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            basemodel.load_state_dict(checkpoint['base_state_dict'])
            attention_model.load_state_dict(checkpoint['attention_state_dict'])
            region_model.load_state_dict(checkpoint['region_state_dict'])            
			optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))