Пример #1
0
def main():
    """Evaluate a network on the validation list and print its accuracy.

    The model is built from a search checkpoint (``args.model_path``) when
    that file exists, otherwise from a JSON config (``args.config_path``).
    Weights stored under ``'state_dict'`` in ``args.weights`` are loaded when
    that file exists. Top-1 / top-5 accuracy returned by ``validate`` and the
    elapsed wall-clock time are printed.

    Raises:
        Exception: if neither ``args.model_path`` nor ``args.config_path``
            names an existing file.
    """
    if not torch.cuda.is_available():
        print('No GPU device available')
        sys.exit(1)
    cudnn.enabled = True
    cudnn.benchmark = True

    # create model
    print('parsing the architecture')
    if args.model_path and os.path.isfile(args.model_path):
        op_weights, depth_weights = get_op_and_depth_weights(args.model_path)
        parsed_arch = parse_architecture(op_weights, depth_weights)
        mc_mask_dddict = torch.load(args.model_path)['mc_mask_dddict']
        mc_num_dddict = get_mc_num_dddict(mc_mask_dddict)
        # NOTE(review): the two trailing 0.0 values are presumably the
        # dropout / drop-connect rates (disabled for evaluation) -- confirm
        # against the Network signature.
        model = Network(args.num_classes, parsed_arch, mc_num_dddict, None,
                        0.0, 0.0)
    elif args.config_path and os.path.isfile(args.config_path):
        # Use a context manager so the config file handle is always closed.
        with open(args.config_path, 'r') as f:
            model_config = json.load(f)
        model = NetworkCfg(args.num_classes, model_config, None, 0.0, 0.0)
    else:
        raise Exception('invalid --model_path and --config_path')
    model = nn.DataParallel(model).cuda()

    # load pretrained weights
    # Guard against an unset (None/empty) path before touching the filesystem.
    if args.weights and os.path.isfile(args.weights):
        print('loading weights from {}'.format(args.weights))
        checkpoint = torch.load(args.weights)
        model.load_state_dict(checkpoint['state_dict'])
    else:
        # Make it visible that the reported accuracy is not from a checkpoint.
        print('warning: no weights file found at {}; evaluating without '
              'loaded weights'.format(args.weights))

    # define transform and initialize dataloader
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])
    val_dataset = ImageList(
        root=args.val_root,
        list_path=args.val_list,
        transform=val_transform,
    )
    val_queue = torch.utils.data.DataLoader(val_dataset,
                                            batch_size=args.batch_size,
                                            shuffle=False,
                                            pin_memory=True,
                                            num_workers=args.workers)

    start = time.time()
    val_acc_top1, val_acc_top5 = validate(val_queue, model)
    print('Val_acc_top1: {:.2f}'.format(val_acc_top1))
    print('Val_acc_top5: {:.2f}'.format(val_acc_top5))
    # The original passed the value as a second print() argument, so the
    # '%d' placeholder was printed literally; interpolate it instead.
    print('Test time: %ds.' % (time.time() - start))
Пример #2
0
def main():
    """Train a ``sphere20`` backbone with a margin-cosine head.

    An initial checkpoint is written before training. After every epoch the
    backbone is saved to ``args.save_path`` and that checkpoint file is
    passed to ``lfw_eval.eval``.
    """
    # --------------------------------------model----------------------------------------
    model = sphere20()
    model = torch.nn.DataParallel(model).cuda()
    print(model)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.save_path, exist_ok=True)
    # os.path.join keeps checkpoints inside save_path even when the argument
    # has no trailing path separator.
    model.module.save(os.path.join(args.save_path, 'CosFace_0_checkpoint.pth'))

    # ------------------------------------load image---------------------------------------
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))  # range [0.0, 1.0] -> [-1.0,1.0]
    ])
    train_loader = torch.utils.data.DataLoader(
        ImageList(root=args.root_path, fileList=args.train_list,
                  transform=train_transform),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    print('length of train Dataset: ' + str(len(train_loader.dataset)))
    print('Number of Classses: ' + str(args.num_class))

    # --------------------------------loss function and optimizer-----------------------------
    # Alternative heads tried by the original author:
    #   layer.AngleLinear(512, args.num_class)
    #   torch.nn.Linear(512, args.num_class, bias=False)
    MCP = layer.MarginCosineProduct(512, args.num_class).cuda()
    criterion = torch.nn.CrossEntropyLoss().cuda()
    # Backbone and classification head are optimised jointly.
    optimizer = torch.optim.SGD([{'params': model.parameters()}, {'params': MCP.parameters()}],
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # ----------------------------------------train----------------------------------------
    for epoch in range(1, args.epochs + 1):
        train(train_loader, model, MCP, criterion, optimizer, epoch)
        # Build the path once so saving and evaluation use the same file.
        checkpoint_path = os.path.join(
            args.save_path, 'CosFace_' + str(epoch) + '_checkpoint.pth')
        model.module.save(checkpoint_path)
        lfw_eval.eval(checkpoint_path)
    print('Finished Training')
Пример #3
0
# Remind the user about --cuda when a CUDA device is present but unused.
if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

# Fixed 4-tuple handed to ``x.crop`` in the first transform step below.
box = (16, 17, 214, 215)

# Training-time pipeline: crop to ``box``, resize to 230x230, apply random
# grayscale / flip / colour jitter, take a random opt.imageSize square crop,
# then convert to a tensor and normalise every channel with mean 0.5, std 0.5.
transform = transforms.Compose([
    transforms.Lambda(lambda x: x.crop(box)),
    transforms.Resize((230, 230)),
    transforms.RandomGrayscale(p=0.1),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(),
    transforms.RandomCrop((opt.imageSize, opt.imageSize)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Dataset built from the training list, wrapped in a shuffling loader.
tensor_dataset = ImageList(opt.train_list, transform)
dataloader = DataLoader(
    tensor_dataset,
    batch_size=opt.batchSize,
    shuffle=True,
    num_workers=opt.workers,
)

# Number of GPUs requested on the command line, coerced to int.
ngpu = int(opt.ngpu)



def weights_init(m):
    """Initialise the weights of convolution-like modules in place.

    A module is treated as a convolution when its class name contains
    ``'Conv'``; its ``weight`` is then filled with Kaiming-normal values
    (``a=0``, ``mode='fan_in'``). Every other module is left untouched.

    Args:
        m: module instance to (possibly) initialise.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        # kaiming_normal_ is the in-place replacement for the deprecated
        # init.kaiming_normal.
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
Пример #4
0
def main():
    """Train a searched/configured network, optionally distributed with apex.

    Steps, in order:
      1. Detect distributed mode from the ``WORLD_SIZE`` environment variable
         and initialise the NCCL process group when it is greater than 1.
      2. Build the model from ``args.model_path`` (search checkpoint) or
         ``args.config_path`` (JSON config).
      3. Create criteria and an SGD optimizer, initialise amp when
         ``args.opt_level`` is set, then wrap the model in DDP or DataParallel.
      4. Build train / validation loaders (with distributed samplers when
         running distributed).
      5. Optionally resume from ``args.snapshot`` and replay the LR scheduler
         up to the stored epoch.
      6. Run the train / validate / checkpoint loop with a cosine LR schedule
         and a 5-epoch linear warm-up when ``args.batch_size > 256``.

    Logging and checkpoint writing are done only when ``args.local_rank == 0``.

    Raises:
        Exception: if neither ``args.model_path`` nor ``args.config_path``
            names an existing file.
    """
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    set_seed(args.seed)
    cudnn.enabled = True
    cudnn.benchmark = True
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        # Re-seed with the local rank once distributed mode is detected.
        set_seed(args.local_rank)
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
    if args.local_rank == 0:
        logging.info("args = {}".format(args))
        logging.info("unparsed_args = {}".format(unparsed))
        logging.info("distributed = {}".format(args.distributed))
        logging.info("opt_level = {}".format(args.opt_level))
        logging.info("keep_batchnorm_fp32 = {}".format(
            args.keep_batchnorm_fp32))
        logging.info("loss_scale = {}".format(args.loss_scale))
        logging.info("CUDNN VERSION: {}".format(
            torch.backends.cudnn.version()))

    # create model
    if args.local_rank == 0:
        logging.info('parsing the architecture')
    if args.model_path and os.path.isfile(args.model_path):
        op_weights, depth_weights = get_op_and_depth_weights(args.model_path)
        parsed_arch = parse_architecture(op_weights, depth_weights)
        mc_mask_dddict = torch.load(args.model_path)['mc_mask_dddict']
        mc_num_dddict = get_mc_num_dddict(mc_mask_dddict)
        model = Network(args.num_classes, parsed_arch, mc_num_dddict, None,
                        args.dropout_rate, args.drop_connect_rate)
    elif args.config_path and os.path.isfile(args.config_path):
        # Use a context manager so the config file handle is always closed.
        with open(args.config_path, 'r') as f:
            model_config = json.load(f)
        model = NetworkCfg(args.num_classes, model_config, None,
                           args.dropout_rate, args.drop_connect_rate)
    else:
        raise Exception('invalid --model_path and --config_path')
    if args.sync_bn:
        if args.local_rank == 0: logging.info("using apex synced BN")
        model = parallel.convert_syncbn_model(model)
    # memory_format is a name defined outside this function.
    model = model.cuda().to(memory_format=memory_format
                            ) if memory_format is not None else model.cuda()
    # Read the config before the model is wrapped by DDP / DataParallel.
    config = model.config
    if args.local_rank == 0:
        with open(os.path.join(args.save, 'model.config'), 'w') as f:
            json.dump(config, f, indent=4)
        # logging.info(config)
        logging.info("param size = %fMB", count_parameters_in_MB(model))

    # define loss function (criterion) and optimizer
    # Plain cross-entropy is used for validation, the label-smoothed variant
    # for training (see the train()/validate() calls in the main loop).
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    criterion_smooth = CrossEntropyLabelSmooth(args.num_classes,
                                               args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # Initialize Amp
    if args.opt_level is not None:
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level=args.opt_level,
            keep_batchnorm_fp32=args.keep_batchnorm_fp32,
            loss_scale=args.loss_scale)

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)
    else:
        model = nn.DataParallel(model)

    # define transform and initialize dataloader
    # args.batch_size / args.workers are totals; split them across processes.
    batch_size = args.batch_size // args.world_size
    workers = args.workers // args.world_size
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,  #),
            hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])
    train_dataset = ImageList(root=args.train_root,
                              list_path=args.train_list,
                              transform=train_transform)
    val_dataset = ImageList(root=args.val_root,
                            list_path=args.val_list,
                            transform=val_transform)
    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset)
    # shuffle and sampler are mutually exclusive in DataLoader, so shuffle
    # only when no distributed sampler is in use.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               shuffle=(train_sampler is None))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             num_workers=workers,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             shuffle=False)

    # define learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))
    best_acc_top1 = 0
    best_acc_top5 = 0
    start_epoch = 0

    # restart from snapshot
    if args.snapshot and os.path.isfile(args.snapshot):
        if args.local_rank == 0:
            logging.info('loading snapshot from {}'.format(args.snapshot))
        # Map every stored tensor onto this process's GPU.
        checkpoint = torch.load(
            args.snapshot,
            map_location=lambda storage, loc: storage.cuda(args.gpu))
        start_epoch = checkpoint['epoch']
        best_acc_top1 = checkpoint['best_acc_top1']
        best_acc_top5 = checkpoint['best_acc_top5']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if args.opt_level is not None:
            amp.load_state_dict(checkpoint['amp'])
        # Rebuild the scheduler on the restored optimizer and replay the
        # already-completed epochs so its state matches start_epoch. The
        # statement order below mirrors the main loop on purpose; do not
        # reorder without checking the scheduler's get_lr() semantics.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(args.epochs), last_epoch=0)
        for epoch in range(start_epoch):
            current_lr = scheduler.get_lr()[0]
            if args.local_rank == 0:
                logging.info('Epoch: %d lr %e', epoch, current_lr)
            if epoch < 5 and args.batch_size > 256:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = current_lr * (epoch + 1) / 5.0
                if args.local_rank == 0:
                    logging.info('Warming-up Epoch: %d, LR: %e', epoch,
                                 current_lr * (epoch + 1) / 5.0)
            # Restore the un-warmed LR before stepping the scheduler.
            if epoch < 5 and args.batch_size > 256:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = current_lr
            scheduler.step()

    # the main loop
    for epoch in range(start_epoch, args.epochs):
        current_lr = scheduler.get_lr()[0]
        if args.local_rank == 0:
            logging.info('Epoch: %d lr %e', epoch, current_lr)
        # Linear warm-up over the first 5 epochs for large total batch sizes.
        if epoch < 5 and args.batch_size > 256:
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr * (epoch + 1) / 5.0
            if args.local_rank == 0:
                logging.info('Warming-up Epoch: %d, LR: %e', epoch,
                             current_lr * (epoch + 1) / 5.0)

        if args.distributed:
            train_sampler.set_epoch(epoch)

        epoch_start = time.time()
        train_acc, train_obj = train(train_loader, model, criterion_smooth,
                                     optimizer)
        if args.local_rank == 0:
            logging.info('Train_acc: %f', train_acc)

        val_acc_top1, val_acc_top5, val_obj = validate(val_loader, model,
                                                       criterion)
        if args.local_rank == 0:
            logging.info('Val_acc_top1: %f', val_acc_top1)
            logging.info('Val_acc_top5: %f', val_acc_top5)
            logging.info('Epoch time: %ds.', time.time() - epoch_start)

        if args.local_rank == 0:
            # "Best" is decided by top-1; top-5 is recorded alongside it.
            is_best = False
            if val_acc_top1 > best_acc_top1:
                best_acc_top1 = val_acc_top1
                best_acc_top5 = val_acc_top5
                is_best = True
            save_checkpoint(
                {
                    # Stored as the next epoch to run, so resume starts there.
                    'epoch':
                    epoch + 1,
                    'state_dict':
                    model.state_dict(),
                    'best_acc_top1':
                    best_acc_top1,
                    'best_acc_top5':
                    best_acc_top5,
                    'optimizer':
                    optimizer.state_dict(),
                    'amp':
                    amp.state_dict() if args.opt_level is not None else None,
                }, is_best, args.save)

        # Undo the warm-up scaling so the scheduler steps from the base LR.
        if epoch < 5 and args.batch_size > 256:
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

        scheduler.step()
Пример #5
0
def main():
    """Train a face-recognition backbone with a selectable classifier head.

    The backbone is chosen by ``args.network`` and the head by
    ``args.classifier_type`` (``'MCP'``, ``'AL'`` or ``'L'``). An initial
    checkpoint is written before training. After every epoch the backbone is
    saved to ``args.save_path`` and evaluated through ``lfw_eval.eval`` using
    a separate, unwrapped copy of the network (``model_eval``).

    Raises:
        ValueError: if ``args.network`` is not one of the supported names.
        KeyError: if ``args.classifier_type`` is not a known head.
    """
    # --------------------------------------model----------------------------------------
    # Compare strings with ==: `is` tests object identity, which is not
    # guaranteed for equal strings coming from the command line.
    if args.network == 'sphere20':
        model = net.sphere(type=20, is_gray=args.is_gray)
        model_eval = net.sphere(type=20, is_gray=args.is_gray)
    elif args.network == 'sphere64':
        model = net.sphere(type=64, is_gray=args.is_gray)
        model_eval = net.sphere(type=64, is_gray=args.is_gray)
    elif args.network == 'LResNet50E_IR':
        model = net.LResNet50E_IR(is_gray=args.is_gray)
        model_eval = net.LResNet50E_IR(is_gray=args.is_gray)
    else:
        raise ValueError("NOT SUPPORT NETWORK! ")

    model = torch.nn.DataParallel(model).to(device)
    model_eval = model_eval.to(device)
    print(model)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.save_path, exist_ok=True)
    # os.path.join keeps checkpoints inside save_path even when the argument
    # has no trailing path separator.
    model.module.save(os.path.join(args.save_path, 'CosFace_0_checkpoint.pth'))

    # 512 is dimension of feature
    # All three heads are instantiated; only the selected one is kept.
    classifier = {
        'MCP': layer.MarginCosineProduct(512, args.num_class).to(device),
        'AL': layer.AngleLinear(512, args.num_class).to(device),
        'L': torch.nn.Linear(512, args.num_class, bias=False).to(device)
    }[args.classifier_type]

    # ------------------------------------load image---------------------------------------
    if args.is_gray:
        train_transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
            transforms.Normalize(mean=(0.5, ), std=(0.5, ))
        ])  # gray
    else:
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
            transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                 std=(0.5, 0.5,
                                      0.5))  # range [0.0, 1.0] -> [-1.0,1.0]
        ])
    train_dataset = ImageList(root=args.root_path,
                              fileList=args.train_list,
                              transform=train_transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    print('length of train Database: ' + str(len(train_loader.dataset)))
    print('Number of Identities: ' + str(args.num_class))

    # --------------------------------loss function and optimizer-----------------------------
    criterion = torch.nn.CrossEntropyLoss().to(device)
    # Backbone and classifier head are optimised jointly.
    optimizer = torch.optim.SGD([{
        'params': model.parameters()
    }, {
        'params': classifier.parameters()
    }],
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # ----------------------------------------train----------------------------------------
    for epoch in range(1, args.epochs + 1):
        train(train_loader, model, classifier, criterion, optimizer, epoch)
        # Build the path once so saving and evaluation use the same file.
        checkpoint_path = os.path.join(
            args.save_path, 'CosFace_' + str(epoch) + '_checkpoint.pth')
        model.module.save(checkpoint_path)
        lfw_eval.eval(model_eval, checkpoint_path, args.is_gray)
    print('Finished Training')
Пример #6
0
def main():
    """Train an ``LResNet50E_IR`` backbone with a margin-cosine head.

    When ``resume`` is true (hard-coded below), backbone and head weights are
    loaded from the epoch-10 checkpoints under ``./checkpoints`` and training
    continues from epoch 10; otherwise training starts at epoch 1. Backbone
    and head state dicts are saved to ``args.save_path`` every 5 epochs.
    """
    resume = True
    # Epoch whose checkpoints are loaded when resuming (matches the file
    # names used below).
    resume_epoch = 10
    if torch.cuda.device_count() > 1:
        print('available gpus is ', torch.cuda.device_count(),
              torch.cuda.get_device_name())
    else:
        print("only one GPU found !!!")
    #model = sphere20()
    model = LResnet50.LResNet50E_IR(is_gray=False)
    print(model)
    # Wrapped in DataParallel but pinned to device 0 only.
    model = torch.nn.DataParallel(
        model, device_ids=[0],
        output_device=0).cuda()

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.save_path, exist_ok=True)
    # model.save(args.save_path + '/CosFace_0_checkpoint.pth')

    # NOTE(review): no checkpoint is written at this point (the save call
    # above is commented out); the message is kept for output compatibility.
    print('save checkpoint finished!')

    # Training-time preprocessing.
    train_transform = transforms.Compose([
        # horizontally flip the PIL image with probability 0.5
        transforms.RandomHorizontalFlip(),
        # convert the PIL image to a tensor
        transforms.ToTensor(),  # range [0, 255] -> [0.0, 1.0]
        # normalise with per-channel mean and standard deviation
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
                             )  # range [0.0, 1.0] -> [-1.0, 1.0]
    ])

    # load the training dataset
    train_loader = torch.utils.data.DataLoader(
        ImageList(
            root=args.root_path,
            fileList=args.image_list,
            transform=train_transform),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=False,
        drop_last=True)

    # print the length of train dataset
    print('length of train dataset: {}'.format(str(len(train_loader.dataset))))
    # print the class number of train dataset
    print('Number of Classes: {}'.format(str(args.num_class)))

    # --------------------------------loss function and optimizer-------------------------------
    # Scale passed to the margin-cosine head: sqrt(2) * ln(num_class - 1).
    scale = math.sqrt(2) * math.log(args.num_class - 1)
    MCP = MarginCosineProduct(512, args.num_class, s=scale).cuda()

    criterion = torch.nn.CrossEntropyLoss().cuda()

    # Backbone and head are optimised jointly.
    optimizer = torch.optim.SGD([{
        'params': model.parameters()
    }, {
        'params': MCP.parameters()
    }],
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # A fresh run starts at epoch 1; the original always started at epoch 10,
    # which skipped the first epochs whenever resume was switched off.
    start_epoch = 1
    if resume:
        print("resume from epoch 10!!!")
        pretrained_cnn = torch.load('./checkpoints/CosFace_10_checkpoint.pth')
        pretrained_mcp = torch.load('./checkpoints/MCP_10_checkpoint.pth')
        model.load_state_dict(pretrained_cnn)
        MCP.load_state_dict(pretrained_mcp)
        # NOTE(review): epoch 10 is trained again after loading its
        # checkpoint (original behaviour, preserved) -- confirm whether
        # resume_epoch + 1 was intended.
        start_epoch = resume_epoch

    for epoch in range(start_epoch, args.epochs + 1):
        train(train_loader, model, MCP, criterion, optimizer, epoch)
        if (epoch % 5 == 0):
            torch.save(
                model.state_dict(),
                os.path.join(args.save_path,
                             'CosFace_' + str(epoch) + '_checkpoint.pth'))
            torch.save(
                MCP.state_dict(),
                os.path.join(args.save_path,
                             'MCP_' + str(epoch) + '_checkpoint.pth'))

    print('Finished Training')
Пример #7
0
def main():
    """Train a searched/configured network on a single node with DataParallel.

    Steps, in order:
      1. Build the model from ``args.model_path`` (search checkpoint) or
         ``args.config_path`` (JSON config) and dump its config to
         ``<args.save>/model.config``.
      2. Create criteria, an SGD optimizer and the train / validation loaders.
      3. Optionally resume from ``args.snapshot`` and replay the LR scheduler
         up to the stored epoch.
      4. Run the train / validate / checkpoint loop with a cosine LR schedule
         and a 5-epoch linear warm-up when ``args.batch_size > 256``.

    Raises:
        Exception: if neither ``args.model_path`` nor ``args.config_path``
            names an existing file.
    """
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    set_seed(args.seed)
    cudnn.enabled = True
    cudnn.benchmark = True
    logging.info("args = %s", args)
    logging.info("unparsed_args = %s", unparsed)

    # create model
    logging.info('parsing the architecture')
    if args.model_path and os.path.isfile(args.model_path):
        op_weights, depth_weights = get_op_and_depth_weights(args.model_path)
        parsed_arch = parse_architecture(op_weights, depth_weights)
        mc_mask_dddict = torch.load(args.model_path)['mc_mask_dddict']
        mc_num_dddict = get_mc_num_dddict(mc_mask_dddict)
        model = Network(args.num_classes, parsed_arch, mc_num_dddict, None,
                        args.dropout_rate, args.drop_connect_rate)
    elif args.config_path and os.path.isfile(args.config_path):
        # Use a context manager so the config file handle is always closed.
        with open(args.config_path, 'r') as f:
            model_config = json.load(f)
        model = NetworkCfg(args.num_classes, model_config, None,
                           args.dropout_rate, args.drop_connect_rate)
    else:
        raise Exception('invalid --model_path and --config_path')
    model = nn.DataParallel(model).cuda()
    # The underlying network sits behind .module once wrapped.
    config = model.module.config
    with open(os.path.join(args.save, 'model.config'), 'w') as f:
        json.dump(config, f, indent=4)
    # logging.info(config)
    logging.info("param size = %fMB", count_parameters_in_MB(model))

    # define loss function (criterion) and optimizer
    # Plain cross-entropy is used for validation, the label-smoothed variant
    # for training (see the train()/validate() calls in the main loop).
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    criterion_smooth = CrossEntropyLabelSmooth(args.num_classes,
                                               args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # define transform and initialize dataloader
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,  #),
            hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])
    train_queue = torch.utils.data.DataLoader(ImageList(
        root=args.train_root,
        list_path=args.train_list,
        transform=train_transform,
    ),
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              pin_memory=True,
                                              num_workers=args.workers)
    val_queue = torch.utils.data.DataLoader(ImageList(
        root=args.val_root,
        list_path=args.val_list,
        transform=val_transform,
    ),
                                            batch_size=args.batch_size,
                                            shuffle=False,
                                            pin_memory=True,
                                            num_workers=args.workers)

    # define learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))
    best_acc_top1 = 0
    best_acc_top5 = 0
    start_epoch = 0

    # restart from snapshot
    if args.snapshot:
        logging.info('loading snapshot from {}'.format(args.snapshot))
        checkpoint = torch.load(args.snapshot)
        start_epoch = checkpoint['epoch']
        best_acc_top1 = checkpoint['best_acc_top1']
        best_acc_top5 = checkpoint['best_acc_top5']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Rebuild the scheduler on the restored optimizer and replay the
        # already-completed epochs so its state matches start_epoch. The
        # statement order below mirrors the main loop on purpose; do not
        # reorder without checking the scheduler's get_lr() semantics.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(args.epochs), last_epoch=0)
        for epoch in range(start_epoch):
            current_lr = scheduler.get_lr()[0]
            logging.info('Epoch: %d lr %e', epoch, current_lr)
            if epoch < 5 and args.batch_size > 256:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = current_lr * (epoch + 1) / 5.0
                logging.info('Warming-up Epoch: %d, LR: %e', epoch,
                             current_lr * (epoch + 1) / 5.0)
            # Restore the un-warmed LR before stepping the scheduler.
            if epoch < 5 and args.batch_size > 256:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = current_lr
            scheduler.step()

    # the main loop
    for epoch in range(start_epoch, args.epochs):
        current_lr = scheduler.get_lr()[0]
        logging.info('Epoch: %d lr %e', epoch, current_lr)
        # Linear warm-up over the first 5 epochs for large batch sizes.
        if epoch < 5 and args.batch_size > 256:
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr * (epoch + 1) / 5.0
            logging.info('Warming-up Epoch: %d, LR: %e', epoch,
                         current_lr * (epoch + 1) / 5.0)

        epoch_start = time.time()
        train_acc, train_obj = train(train_queue, model, criterion_smooth,
                                     optimizer)
        logging.info('Train_acc: %f', train_acc)

        val_acc_top1, val_acc_top5, val_obj = validate(val_queue, model,
                                                       criterion)
        logging.info('Val_acc_top1: %f', val_acc_top1)
        logging.info('Val_acc_top5: %f', val_acc_top5)
        logging.info('Epoch time: %ds.', time.time() - epoch_start)

        # "Best" is decided by top-1; top-5 is recorded alongside it.
        is_best = False
        if val_acc_top1 > best_acc_top1:
            best_acc_top1 = val_acc_top1
            best_acc_top5 = val_acc_top5
            is_best = True
        save_checkpoint(
            {
                # Stored as the next epoch to run, so resume starts there.
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_acc_top1': best_acc_top1,
                'best_acc_top5': best_acc_top5,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save)

        # Undo the warm-up scaling so the scheduler steps from the base LR.
        if epoch < 5 and args.batch_size > 256:
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

        scheduler.step()
Пример #8
0
# Tensors of one candidate op that the checkpoint stores at full (unmasked)
# channel width.  Each entry is:
#   (attribute path below the op,
#    axis sliced by the channel mask, or None to copy the whole tensor,
#    True when the tensor is only handled for squeeze-excite ops)
# The order matches the order in which the tensors are copied.
_MASKED_OP_PARAMS = (
    ('inverted_bottleneck.conv.weight', 0, False),
    ('depth_conv.conv.weight', 0, False),
    ('point_linear.conv.weight', 1, False),
    ('squeeze_excite.conv_reduce.weight', 1, True),
    ('squeeze_excite.conv_reduce.bias', None, True),
    ('squeeze_excite.conv_expand.weight', 0, True),
    ('squeeze_excite.conv_expand.bias', 0, True),
)

# squeeze_excite tensors are only copied for ops whose index is >= this value.
_FIRST_SE_OP_IDX = 4


def _resolve_attr(root, dotted_path):
    """Return the object reached by following ``dotted_path`` from ``root``."""
    target = root
    for name in dotted_path.split('.'):
        target = getattr(target, name)
    return target


def _iter_masked_ops(mask_dddict):
    """Yield ``(stage, block, op_idx, index)`` for every op in the mask dict.

    ``index`` is a CUDA tensor holding the positions of the non-zero entries
    of that op's mask.
    """
    for stage in mask_dddict:
        for block in mask_dddict[stage]:
            for op_idx in mask_dddict[stage][block]:
                index = torch.nonzero(
                    mask_dddict[stage][block][op_idx]).view(-1)
                yield stage, block, op_idx, index.cuda()


def _op_param_specs(op_idx):
    """Return the ``(suffix, dim)`` pairs that apply to the op ``op_idx``."""
    has_se = op_idx >= _FIRST_SE_OP_IDX
    return [(suffix, dim) for suffix, dim, se_only in _MASKED_OP_PARAMS
            if has_se or not se_only]


def _fit_channels_to_latency(parsed_arch, mc_num_dddict, mc_maxnum_dddict,
                             lat_lookup, first_sign):
    """Run the latency-fitting passes and return ``(mc_num_dddict, lat)``.

    The first pass covers stage1..stage6 with ``sign=first_sign``; it is
    followed by passes over stage2..6, stage3..6, ... stage6 that always use
    ``sign=1``.
    """
    after_lat = None
    passes = [(1, first_sign)] + [(start, 1) for start in range(2, 7)]
    for start, sign in passes:
        stages = ['stage{}'.format(x) for x in range(start, 7)]
        mc_num_dddict, after_lat = fit_mc_num_by_latency(
            parsed_arch,
            mc_num_dddict,
            mc_maxnum_dddict,
            lat_lookup_key_dddict,
            lat_lookup,
            args.target_lat,
            stages,
            sign=sign)
    return mc_num_dddict, after_lat


def main():
    """Run the architecture / channel-number search.

    Each epoch builds a fresh ``Network`` sized by the current channel masks,
    fills it from the previous epoch's full-width checkpoint, trains it, and
    writes the trained values back into the full-width ``state_dict``.  From
    epoch 10 on, the channel masks are also adjusted towards
    ``args.target_lat``.  A checkpoint ``searched_model_XX.pth.tar`` (weights
    plus masks) is written to ``args.save`` after every epoch.

    Relies on the module-level ``args``, ``mc_mask_dddict`` and
    ``lat_lookup_key_dddict``.  ``mc_mask_dddict`` is modified in place and
    ``args.T`` is decayed in place.  Exits with status 1 when CUDA is not
    available.
    """
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.enabled = True
    cudnn.benchmark = True
    logging.info("args = %s", args)

    # NOTE(review): pickle.load can execute code embedded in the file; only
    # point --lookup_path at trusted lookup tables.
    with open(args.lookup_path, 'rb') as f:
        lat_lookup = pickle.load(f)

    mc_maxnum_dddict = get_mc_num_dddict(mc_mask_dddict, is_max=True)
    model = Network(args.num_classes, mc_maxnum_dddict, lat_lookup)
    model = torch.nn.DataParallel(model).cuda()
    logging.info("param size = %fMB", count_parameters_in_MB(model))

    # save initial model
    model_path = os.path.join(args.save, 'searched_model_00.pth.tar')
    torch.save(
        {
            'state_dict': model.state_dict(),
            'mc_mask_dddict': mc_mask_dddict,
        }, model_path)

    # Pre-compute the per-epoch learning rates with a throw-away
    # optimizer/scheduler pair, because the real optimizer is rebuilt every
    # epoch together with the model.
    lr_list = []
    optimizer_w = torch.optim.SGD(model.module.weight_parameters(),
                                  lr=args.w_lr,
                                  momentum=args.w_mom,
                                  weight_decay=args.w_wd)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer_w, float(args.epochs))
    for _ in range(args.epochs):
        # Read the rate from the optimizer rather than scheduler.get_lr():
        # newer PyTorch releases warn when get_lr() is called outside step()
        # and may return a value that is not the rate currently in effect.
        lr_list.append(optimizer_w.param_groups[0]['lr'])
        scheduler.step()
    del model
    del optimizer_w
    del scheduler

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4,
                               contrast=0.4,
                               saturation=0.4,
                               hue=0.2),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_queue = torch.utils.data.DataLoader(
        ImageList(root=args.img_root,
                  list_path=args.train_list,
                  transform=train_transform),
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.workers)

    # val_queue is shuffled because it is also handed to train_w_arch below,
    # not only to validate().
    val_queue = torch.utils.data.DataLoader(
        ImageList(root=args.img_root,
                  list_path=args.val_list,
                  transform=val_transform),
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.workers)

    # Epochs below this threshold train weights only and leave the channel
    # masks untouched.
    weight_only_epochs = 10

    for epoch in range(args.epochs):
        mc_num_dddict = get_mc_num_dddict(mc_mask_dddict)
        model = Network(args.num_classes, mc_num_dddict, lat_lookup)
        model = torch.nn.DataParallel(model).cuda()
        model.module.set_temperature(args.T)

        # Load the previous epoch's checkpoint into the freshly built model.
        model_path = os.path.join(args.save,
                                  'searched_model_{:02}.pth.tar'.format(epoch))
        state_dict = torch.load(model_path)['state_dict']
        # Tensors outside the candidate ops are copied as they are.
        for key in state_dict:
            if 'm_ops' not in key:
                _resolve_attr(model, key).data = state_dict[key].data
        # Candidate-op tensors are stored at full width; keep only the
        # channels selected by the op's mask.
        for stage, block, op_idx, index in _iter_masked_ops(mc_mask_dddict):
            op = getattr(getattr(model.module, stage), block).m_ops[op_idx]
            key_prefix = 'module.{}.{}.m_ops.{}.'.format(stage, block, op_idx)
            for suffix, dim in _op_param_specs(op_idx):
                stored = state_dict[key_prefix + suffix]
                param = _resolve_attr(op, suffix)
                if dim is None:
                    param.data = stored.data
                else:
                    param.data = torch.index_select(stored, dim, index).data

        lr = lr_list[epoch]
        optimizer_w = torch.optim.SGD(model.module.weight_parameters(),
                                      lr=lr,
                                      momentum=args.w_mom,
                                      weight_decay=args.w_wd)
        optimizer_a = torch.optim.Adam(model.module.arch_parameters(),
                                       lr=args.a_lr,
                                       betas=(args.a_beta1, args.a_beta2),
                                       weight_decay=args.a_wd)
        logging.info('Epoch: %d lr: %e T: %e', epoch, lr, args.T)

        # training
        epoch_start = time.time()
        if epoch < weight_only_epochs:
            train_acc = train_wo_arch(train_queue, model, criterion,
                                      optimizer_w)
        else:
            train_acc = train_w_arch(train_queue, val_queue, model, criterion,
                                     optimizer_w, optimizer_a)
            args.T *= args.T_decay
        # logging arch parameters
        logging.info('The current arch parameters are:')
        for param in model.module.log_alphas_parameters():
            param = np.exp(param.detach().cpu().numpy())
            logging.info(' '.join(['{:.6f}'.format(p) for p in param]))
        for param in model.module.betas_parameters():
            param = F.softmax(param.detach().cpu(), dim=-1)
            param = param.numpy()
            logging.info(' '.join(['{:.6f}'.format(p) for p in param]))
        logging.info('Train_acc %f', train_acc)
        epoch_duration = time.time() - epoch_start
        logging.info('Epoch time: %ds', epoch_duration)

        # Validate near the end of the search.  The condition holds for
        # epochs args.epochs-4 .. args.epochs-1, i.e. the final four epochs.
        if args.epochs - epoch < 5:
            val_acc = validate(val_queue, model, criterion)
            logging.info('Val_acc %f', val_acc)

        # Write the trained values back into the full-width state_dict.
        state_dict_from_model = model.state_dict()
        for key in state_dict:
            if 'm_ops' not in key:
                state_dict[key].data = state_dict_from_model[key].data
        for stage, block, op_idx, index in _iter_masked_ops(mc_mask_dddict):
            key_prefix = 'module.{}.{}.m_ops.{}.'.format(stage, block, op_idx)
            for suffix, dim in _op_param_specs(op_idx):
                key = key_prefix + suffix
                full = state_dict[key].data
                if dim is None:
                    full[:] = state_dict_from_model[key]
                else:
                    # Equivalent to full[index, :, :, :] / full[:, index, :, :]
                    # / full[index], depending on dim and tensor rank.
                    selector = [slice(None)] * full.dim()
                    selector[dim] = index
                    full[tuple(selector)] = state_dict_from_model[key]
        del state_dict_from_model

        # shrink and expand
        if epoch >= weight_only_epochs:
            logging.info('Now shrinking or expanding the arch')
            op_weights, depth_weights = get_op_and_depth_weights(model)
            parsed_arch = parse_architecture(op_weights, depth_weights)
            mc_num_dddict = get_mc_num_dddict(mc_mask_dddict)
            before_lat = get_lookup_latency(parsed_arch, mc_num_dddict,
                                            lat_lookup_key_dddict, lat_lookup)
            logging.info(
                'Before, the current lat: {:.4f}, the target lat: {:.4f}'.
                format(before_lat, args.target_lat))

            if before_lat > args.target_lat:
                logging.info('Shrinking......')
                mc_num_dddict, after_lat = _fit_channels_to_latency(
                    parsed_arch, mc_num_dddict, mc_maxnum_dddict, lat_lookup,
                    first_sign=-1)
            elif before_lat < args.target_lat:
                logging.info('Expanding......')
                mc_num_dddict, after_lat = _fit_channels_to_latency(
                    parsed_arch, mc_num_dddict, mc_maxnum_dddict, lat_lookup,
                    first_sign=1)
            else:
                logging.info('No opeartion')
                after_lat = before_lat

            # Rebuild the mask of every selected op whose target channel count
            # changed: keep the mc_num channels whose depth-conv filters have
            # the largest L1 norm.
            for stage in parsed_arch:
                for block in parsed_arch[stage]:
                    op_idx = parsed_arch[stage][block]
                    mask = mc_mask_dddict[stage][block][op_idx]
                    mc_num = mc_num_dddict[stage][block][op_idx]
                    if mc_num == int(mask.sum().item()):
                        continue
                    mask.data.zero_()
                    depth_key = 'module.{}.{}.m_ops.{}.depth_conv.conv.weight'.format(
                        stage, block, op_idx)
                    weight_abs = state_dict[depth_key].clone().abs().cpu(
                    ).numpy()
                    weight_l1_norm = np.sum(weight_abs, axis=(1, 2, 3))
                    largest_first = np.argsort(weight_l1_norm)[::-1][:mc_num]
                    mask.data[largest_first.tolist()] = 1.0

            logging.info(
                'After, the current lat: {:.4f}, the target lat: {:.4f}'.
                format(after_lat, args.target_lat))

        # save model
        model_path = os.path.join(
            args.save, 'searched_model_{:02}.pth.tar'.format(epoch + 1))
        torch.save(
            {
                'state_dict': state_dict,
                'mc_mask_dddict': mc_mask_dddict,
            }, model_path)
Пример #9
0
def main():
    """Train ``sphere20`` with a MarginCosineProduct (CosFace) head on GPU.

    Reads every setting from the module-level ``args``.  After each epoch the
    DataParallel-wrapped model's ``state_dict`` is written to
    ``<args.save_path>/CosFace_<epoch>_checkpoint.pth``.
    """
    net = torch.nn.DataParallel(sphere20()).cuda()

    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # NOTE: this message is printed unconditionally; nothing is written to
    # disk at this point.
    print('save checkpoint finished!')

    # Image pre-processing: random horizontal flip (default probability),
    # conversion of the PIL image to a tensor in [0.0, 1.0], then
    # normalisation with mean 0.5 / std 0.5 per channel, i.e. to [-1.0, 1.0].
    preprocess = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    dataset = ImageList(root=args.root_path,
                        fileList=args.image_list,
                        transform=preprocess)
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    print('length of train dataset: {}'.format(len(train_loader.dataset)))
    print('Number of Classes: {}'.format(args.num_class))

    # ---------------------- loss function and optimizer ----------------------
    # The scale passed to the margin head is derived from the class count.
    scale = math.sqrt(2) * math.log(args.num_class - 1)
    MCP = MarginCosineProduct(512, args.num_class, s=scale).cuda()
    criterion = torch.nn.CrossEntropyLoss().cuda()

    # The backbone and the margin head are optimised as two parameter groups
    # that share the same hyper-parameters.
    param_groups = [
        {'params': net.parameters()},
        {'params': MCP.parameters()},
    ]
    optimizer = torch.optim.SGD(param_groups,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    epoch = 1
    while epoch <= args.epochs:
        train(train_loader, net, MCP, criterion, optimizer, epoch)
        checkpoint_name = 'CosFace_' + str(epoch) + '_checkpoint.pth'
        torch.save(net.state_dict(),
                   os.path.join(args.save_path, checkpoint_name))
        epoch += 1

    print('Finished Training')
def main():
    """Train ``net.sphere64a`` with the CosFace loss and evaluate every epoch.

    Uses the module-level settings ``root_path``, ``train_list``,
    ``BatchSize``, ``workers``, ``num_class``, ``use_gpu``, ``lr_ori``,
    ``save_path`` and the summary ``writer``.  After each of the 38 epochs a
    checkpoint ``<save_path>cosface_<epoch>_checkpoints.pth`` is saved, passed
    to ``lfw_eval.eval`` and the returned accuracy is logged under
    ``Test/LFWAcc``.
    """
    # ------------------------------ load images ------------------------------
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # range [0, 255] -> [0.0, 1.0]
        transforms.Normalize(
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5)),  # range [0.0, 1.0] -> [-1.0, 1.0]
    ])
    train_loader = torch.utils.data.DataLoader(
        ImageList(root=root_path,
                  fileList=train_list,
                  transform=train_transform),
        batch_size=BatchSize,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
        drop_last=True)

    print('length of train Dataset: ' + str(len(train_loader.dataset)))
    print('Number of Classses: ' + str(num_class))

    # --------------------------------- model ---------------------------------
    model_ft = net.sphere64a()

    # ------------------------ cosface loss and criterion ----------------------
    MCP = layer.MarginCosineProduct(512, num_class)
    criterion = torch.nn.CrossEntropyLoss()

    # -------------------------------- use gpu --------------------------------
    # Every module is moved under the same flag, so use_gpu=False no longer
    # fails on an unconditional .cuda() call.
    if use_gpu:
        model_ft = nn.DataParallel(model_ft).cuda()
        MCP = MCP.cuda()
        criterion = criterion.cuda()

    optimizer = torch.optim.SGD([{
        'params': model_ft.parameters()
    }, {
        'params': MCP.parameters()
    }],
                                lr=lr_ori,
                                momentum=0.9,
                                weight_decay=0.0005)

    for epoch in range(1, 38 + 1):
        train(train_loader, model_ft, MCP, criterion, optimizer, epoch)

        checkpoint_path = (save_path + 'cosface_' + str(epoch) +
                           '_checkpoints.pth')
        # DataParallel exposes the wrapped network as .module; without the
        # wrapper the network itself provides save().
        getattr(model_ft, 'module', model_ft).save(checkpoint_path)
        acc, _ = lfw_eval.eval(checkpoint_path)

        writer.add_scalar('Test/LFWAcc', acc, epoch)
    print('finished training')
Пример #11
0
def main():
    """Train ``net.sphere20a`` with the mining loss and evaluate every epoch.

    Uses the module-level settings ``root_path``, ``train_list``,
    ``BatchSize``, ``workers``, ``num_class``, ``use_gpu``, ``multi_sphere``,
    ``lr_ori``, ``save_path``, the summary ``writer`` and the already-open
    log file ``f``.  After each of the 30 epochs a checkpoint
    ``<save_path>mnface_<epoch>_checkpoints.pth`` is saved, passed to
    ``lfw_eval.eval`` and the returned accuracy is logged under
    ``Test/LFWAcc``.  ``f`` is closed when the function leaves, whether
    training finished or raised.
    """
    try:
        # ---------------------------- load images ----------------------------
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),  # range [0, 255] -> [0.0, 1.0]
            transforms.Normalize(
                mean=(0.5, 0.5, 0.5),
                std=(0.5, 0.5, 0.5)),  # range [0.0, 1.0] -> [-1.0, 1.0]
        ])
        train_loader = torch.utils.data.DataLoader(
            ImageList(root=root_path,
                      fileList=train_list,
                      transform=train_transform),
            batch_size=BatchSize,
            shuffle=True,
            num_workers=workers,
            pin_memory=True,
            drop_last=True)

        print('length of train Dataset: ' + str(len(train_loader.dataset)))
        f.write('length of train Dataset: ' + str(len(train_loader.dataset)) +
                '\n')
        print('Number of Classses: ' + str(num_class))
        f.write('Number of Classses: ' + str(num_class) + '\n')

        # ------------------------------- model -------------------------------
        model_ft = net.sphere20a()

        # ------------------------------ use gpu ------------------------------
        if use_gpu:
            model_ft = nn.DataParallel(model_ft).cuda()

        # --------------------- loss function and optimizer -------------------
        if multi_sphere:
            mining_loss = layer.MultiMini(512, num_class)
        else:
            mining_loss = layer.miniloss(512, num_class)
        ce_loss = nn.CrossEntropyLoss()
        if use_gpu:
            mining_loss = mining_loss.cuda()
            ce_loss = ce_loss.cuda()

        # The backbone and the mining-loss head are optimised as two
        # parameter groups that share the same hyper-parameters.
        optimizer = optim.SGD([{
            'params': model_ft.parameters()
        }, {
            'params': mining_loss.parameters()
        }],
                              lr=lr_ori,
                              momentum=0.9,
                              weight_decay=0.0005)

        for epoch in range(1, 30 + 1):
            train(train_loader, model_ft, mining_loss, ce_loss, optimizer,
                  epoch)

            checkpoint_path = (save_path + 'mnface_' + str(epoch) +
                               '_checkpoints.pth')
            # DataParallel exposes the wrapped network as .module; without
            # the wrapper (use_gpu false) the network itself provides save().
            getattr(model_ft, 'module', model_ft).save(checkpoint_path)
            acc = lfw_eval.eval(model_path=checkpoint_path)

            writer.add_scalar('Test/LFWAcc', acc, epoch)

        print('finished training')
        f.write("finished training" + '\n')
    finally:
        # Release the log file even when training aborts with an exception.
        f.close()
Пример #12
0
def main():
    """Train ``sphere20`` with the CosFace loss on the CPU.

    Reads every setting from the module-level ``args``.  An initial
    checkpoint ``CosFace_0_checkpoint.pth`` is written to ``args.save_path``
    before training; after every epoch ``CosFace_<epoch>_checkpoint.pth`` is
    written there and passed to ``lfw_eval.eval``.
    """
    # --------------------------------- model ---------------------------------
    # Use the sphere20() network.
    model = sphere20()
    # CPU variant: the multi-GPU DataParallel wrapper stays disabled (the
    # original author's Mac cannot run on a GPU).
    # model = torch.nn.DataParallel(model).cuda()
    print(model)
    # Make sure the checkpoint directory exists before the first save.
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    # With the DataParallel wrapper disabled there is no .module attribute,
    # so fall back to the model itself when saving.
    saver = getattr(model, 'module', model)
    # args.save_path is a directory (created above), so join it properly
    # instead of relying on the caller to supply a trailing separator.
    saver.save(os.path.join(args.save_path, 'CosFace_0_checkpoint.pth'))

    # ------------------------------- load image -------------------------------
    # Pre-processing: random horizontal flip (default probability), convert
    # the PIL image to a tensor in [0.0, 1.0], then normalise with mean 0.5 /
    # std 0.5 per channel, i.e. to [-1.0, 1.0].
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # range [0, 255] -> [0.0, 1.0]
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5, 0.5)),  # -> [-1.0, 1.0]
    ])
    train_loader = torch.utils.data.DataLoader(
        ImageList(root=args.root_path,
                  fileList=args.train_list,
                  transform=train_transform),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)
    # Report the size of the training set and the number of classes.
    print('length of train Dataset: ' + str(len(train_loader.dataset)))
    print('Number of Classses: ' + str(args.num_class))

    # ----------------------- loss function and optimizer ----------------------
    # Margin head and criterion are kept on the CPU (no .cuda()).
    MCP = layer.MarginCosineProduct(512, args.num_class)
    criterion = torch.nn.CrossEntropyLoss()
    # SGD over two parameter groups: the backbone and the margin head.
    optimizer = torch.optim.SGD([{
        'params': model.parameters()
    }, {
        'params': MCP.parameters()
    }],
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # ---------------------------------- train ---------------------------------
    # Train for args.epochs epochs; save and evaluate a checkpoint after each.
    for epoch in range(1, args.epochs + 1):
        train(train_loader, model, MCP, criterion, optimizer, epoch)
        checkpoint_path = os.path.join(
            args.save_path, 'CosFace_' + str(epoch) + '_checkpoint.pth')
        saver.save(checkpoint_path)
        lfw_eval.eval(checkpoint_path)
    print('Finished Training')