コード例 #1
0
    def create_optimization(self):
        """Create the loss function and the RMSprop optimizer.

        Reads ``self.args`` (``loss_function``, ``gamma``, ``cuda``,
        ``classify``, ``learning_rate``, ``momentum``, ``weight_decay``) and
        ``self.model``. Sets ``self.loss`` and ``self.optimizer``; when
        ``args.classify`` is true it also creates ``self.metric_fc``.
        """
        if self.args.loss_function == 'FocalLoss':
            self.loss = FocalLoss(gamma=self.args.gamma)
        else:
            self.loss = nn.CrossEntropyLoss()

        if self.args.cuda:
            self.loss.cuda()

        if self.args.classify:
            self.metric_fc = ArcMarginModel(self.args)

        # The backbone is always optimized. The margin head is optimized
        # whenever one exists: previously a head created just above was left
        # out of the optimizer, while the other branch read ``self.metric_fc``
        # unconditionally and failed if nothing had created it.
        param_groups = [{'params': self.model.parameters()}]
        metric_fc = getattr(self, 'metric_fc', None)
        if metric_fc is not None:
            param_groups.append({'params': metric_fc.parameters()})

        self.optimizer = RMSprop(param_groups,
                                 self.args.learning_rate,
                                 momentum=self.args.momentum,
                                 weight_decay=self.args.weight_decay)
コード例 #2
0
def train_net(args):
    """Train a backbone with an ArcFace margin head, validating every 5th epoch.

    Args:
        args: namespace read for ``checkpoint``, ``network``, ``optimizer``,
            ``lr``, ``mom``, ``weight_decay``, ``focal_loss``, ``gamma``,
            ``batch_size`` and ``end_epoch``.

    Raises:
        TypeError: if ``args.network`` is not one of the supported names.

    ``device`` and ``num_workers`` are module-level names defined outside
    this function.
    """
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = float('-inf')
    writer = SummaryWriter()
    epochs_since_improvement = 0
    # Counter value at which the learning rate was last decayed. Validation
    # only runs every 5th epoch, so the counter keeps the same value for
    # several epochs; without this guard the best checkpoint would be
    # reloaded and the LR halved again on each of them.
    last_decay_at = None

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        elif args.network == 'mobile':
            from mobilenet_v2 import MobileNetV2
            model = MobileNetV2()
        else:
            raise TypeError('network {} is not supported.'.format(
                args.network))

        metric_fc = ArcMarginModel(args)

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD([{
                'params': model.parameters()
            }, {
                'params': metric_fc.parameters()
            }],
                                        lr=args.lr,
                                        momentum=args.mom,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam([{
                'params': model.parameters()
            }, {
                'params': metric_fc.parameters()
            }],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)

    else:
        # The checkpoint stores whole pickled objects, not state dicts.
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    model = nn.DataParallel(model)
    metric_fc = nn.DataParallel(metric_fc)

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_dataset = ArcFaceDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=num_workers)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        # Terminate training after 10 validations without improvement.
        if epochs_since_improvement == 10:
            break
        # Roll back to the best checkpoint and halve the learning rate after
        # every 2 validations without improvement -- once per counter value.
        if (epochs_since_improvement > 0
                and epochs_since_improvement % 2 == 0
                and epochs_since_improvement != last_decay_at):
            checkpoint = 'BEST_checkpoint.tar'
            checkpoint = torch.load(checkpoint)
            model = checkpoint['model']
            metric_fc = checkpoint['metric_fc']
            optimizer = checkpoint['optimizer']

            adjust_learning_rate(optimizer, 0.5)
            last_decay_at = epochs_since_improvement

        # One epoch's training
        train_loss, train_top1_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch)
        lr = optimizer.param_groups[0]['lr']
        print('\nCurrent effective learning rate: {}\n'.format(lr))

        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_accuracy', train_top1_accs, epoch)
        writer.add_scalar('model/learning_rate', lr, epoch)

        if epoch % 5 == 0:
            # One epoch's validation
            megaface_acc = megaface_test(model)
            writer.add_scalar('model/megaface_accuracy', megaface_acc, epoch)

            # Check if there was an improvement
            is_best = megaface_acc > best_acc
            best_acc = max(megaface_acc, best_acc)
            if not is_best:
                epochs_since_improvement += 1
                print("\nEpochs since last improvement: %d\n" %
                      (epochs_since_improvement, ))
            else:
                epochs_since_improvement = 0
                # A fresh improvement re-arms the decay for the next plateau.
                last_decay_at = None

            # Save checkpoint
            save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                            optimizer, best_acc, is_best)
コード例 #3
0
def test():
    """Compare the float model with its quantized variants.

    Runs, in order: baseline float evaluation, post-training static
    quantization with the default qconfig, post-training quantization with
    the 'fbgemm' qconfig, a benchmark of the two scripted models, and eight
    epochs of quantization-aware training with an evaluation after each.
    Results are printed; nothing is returned.
    """
    # Margin head passed to every ``evaluate`` call; kept on CPU in eval mode.
    margin_head = ArcMarginModel()
    margin_head.load_state_dict(torch.load('metric_fc.pt'))
    margin_head = margin_head.to('cpu')
    margin_head.eval()

    train_batch_size = 30  # not referenced below
    eval_batch_size = 30

    float_script_path = 'mobilenet_quantization_scripted.pth'
    quantized_script_path = 'mobilenet_quantization_scripted_quantized.pth'

    calib_loader, eval_loader = prepare_data_loaders()
    loss_fn = nn.CrossEntropyLoss()

    backbone_file = 'image_matching_mobile.pt'
    print(f'loading {backbone_file}...')
    started = time.time()
    baseline = load_model(backbone_file)
    print(f'elapsed {time.time() - started} sec')

    print('\n Inverted Residual Block: Before fusion \n\n', baseline.features[1].conv)
    baseline.eval()

    # Fuse modules (Conv+BN+Relu and Conv+Relu).
    baseline.fuse_model()
    print('\n Inverted Residual Block: After fusion\n\n', baseline.features[1].conv)

    n_eval_batches = 10

    def _report(meter, prefix=''):
        # Shared accuracy line; ``prefix`` carries the per-epoch label in QAT.
        print(prefix + 'Evaluation accuracy on %d images, %2.2f' % (
            n_eval_batches * eval_batch_size, meter.avg))

    print("Size of baseline model")
    print_size_of_model(baseline)

    _report(evaluate(baseline, margin_head, loss_fn, eval_loader,
                     neval_batches=n_eval_batches))
    torch.jit.save(torch.jit.script(baseline), float_script_path)

    # 4. Post-training static quantization
    print(bcolors.HEADER + '\nPost-training static quantization' + bcolors.ENDC)
    n_calibration_batches = 10

    ptq_model = load_model(backbone_file).to('cpu')
    ptq_model.eval()
    ptq_model.fuse_model()

    # Default qconfig first: simple min/max range estimation with
    # per-tensor weight quantization.
    ptq_model.qconfig = torch.quantization.default_qconfig
    print(ptq_model.qconfig)
    torch.quantization.prepare(ptq_model, inplace=True)

    print('Post Training Quantization Prepare: Inserting Observers')
    print('\n Inverted Residual Block:After observer insertion \n\n', ptq_model.features[1].conv)

    # Calibration pass over the training loader.
    print('Calibrate with the training set')
    evaluate(ptq_model, margin_head, loss_fn, calib_loader,
             neval_batches=n_calibration_batches)
    print('Post Training Quantization: Calibration done')

    torch.quantization.convert(ptq_model, inplace=True)
    print('Post Training Quantization: Convert done')
    print('\n Inverted Residual Block: After fusion and quantization, note fused modules: \n\n',
          ptq_model.features[1].conv)

    print("Size of model after quantization")
    print_size_of_model(ptq_model)

    _report(evaluate(ptq_model, margin_head, loss_fn, eval_loader,
                     neval_batches=n_eval_batches))

    # Same flow again with the 'fbgemm' default qconfig.
    fbgemm_model = load_model(backbone_file)
    fbgemm_model.eval()
    fbgemm_model.fuse_model()
    fbgemm_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    print(fbgemm_model.qconfig)

    torch.quantization.prepare(fbgemm_model, inplace=True)
    print('Calibrate with the training set')
    evaluate(fbgemm_model, margin_head, loss_fn, calib_loader, n_calibration_batches)
    torch.quantization.convert(fbgemm_model, inplace=True)
    _report(evaluate(fbgemm_model, margin_head, loss_fn, eval_loader,
                     neval_batches=n_eval_batches))
    torch.jit.save(torch.jit.script(fbgemm_model), quantized_script_path)

    # Speedup from quantization
    print(bcolors.HEADER + '\nSpeedup from quantization' + bcolors.ENDC)
    for script_path in (float_script_path, quantized_script_path):
        run_benchmark(script_path, eval_loader)

    # 5. Quantization-aware training
    print(bcolors.HEADER + '\nQuantization-aware training' + bcolors.ENDC)
    qat_model = load_model(backbone_file)
    qat_model.fuse_model()

    sgd = torch.optim.SGD(qat_model.parameters(), lr=0.0001)
    qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

    torch.quantization.prepare_qat(qat_model, inplace=True)
    print('Inverted Residual Block: After preparation for QAT, note fake-quantization modules \n',
          qat_model.features[1].conv)

    n_train_batches = 20
    cpu = torch.device('cpu')

    # Train, then convert and evaluate a quantized copy after every epoch.
    for nepoch in range(8):
        train_one_epoch(qat_model, loss_fn, sgd, calib_loader, cpu, n_train_batches)
        if nepoch > 3:
            # Freeze quantizer parameters
            qat_model.apply(torch.quantization.disable_observer)
        if nepoch > 2:
            # Freeze batch norm mean and variance estimates
            qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

        snapshot = torch.quantization.convert(qat_model.eval(), inplace=False)
        snapshot.eval()
        _report(evaluate(snapshot, margin_head, loss_fn, eval_loader,
                         neval_batches=n_eval_batches),
                prefix='Epoch %d :' % nepoch)
コード例 #4
0
ファイル: train.py プロジェクト: hanson-young/InsightFace-v2
def train_net(args):
    """Train a backbone plus ArcFace head, validating on LFW after every epoch.

    Args:
        args: namespace read for ``checkpoint``, ``network``, ``optimizer``,
            ``lr``, ``mom``, ``weight_decay``, ``focal_loss``, ``gamma``,
            ``batch_size``, ``lr_step`` and ``end_epoch``.

    Raises:
        TypeError: if ``args.network`` is not a supported backbone name.

    ``device`` is a module-level name defined outside this function.
    """
    torch.manual_seed(7)
    np.random.seed(7)

    resume_from = args.checkpoint
    start_epoch = 0
    best_acc = float('-inf')
    writer = SummaryWriter()
    epochs_since_improvement = 0

    if resume_from is not None:
        # Resume: the checkpoint holds the pickled model, head and optimizer.
        state = torch.load(resume_from)
        start_epoch = state['epoch'] + 1
        epochs_since_improvement = state['epochs_since_improvement']
        model = state['model']
        metric_fc = state['metric_fc']
        optimizer = state['optimizer']
    else:
        # Fresh start: choose the backbone by name. Builders are wrapped in
        # lambdas so only the selected constructor is looked up and called.
        builders = {
            'r18': lambda: resnet18(args),
            'r34': lambda: resnet34(args),
            'r50': lambda: resnet50(args),
            'r101': lambda: resnet101(args),
            'r152': lambda: resnet152(args),
            'mobile': lambda: MobileNetV2(),
        }
        if args.network not in builders:
            raise TypeError('network {} is not supported.'.format(
                args.network))

        model = nn.DataParallel(builders[args.network]())
        metric_fc = nn.DataParallel(ArcMarginModel(args))

        # Backbone and margin head are optimized as two parameter groups.
        param_groups = [
            {'params': model.parameters()},
            {'params': metric_fc.parameters()},
        ]
        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(param_groups,
                                        lr=args.lr,
                                        momentum=args.mom,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam(param_groups,
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    loss_module = FocalLoss(gamma=args.gamma) if args.focal_loss \
        else nn.CrossEntropyLoss()
    criterion = loss_module.to(device)

    # Training data
    train_loader = torch.utils.data.DataLoader(ArcFaceDataset('train'),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

    scheduler = StepLR(optimizer, step_size=args.lr_step, gamma=0.1)

    for epoch in range(start_epoch, args.end_epoch):
        # One epoch's training
        train_loss, train_acc = train(train_loader=train_loader,
                                      model=model,
                                      metric_fc=metric_fc,
                                      criterion=criterion,
                                      optimizer=optimizer,
                                      epoch=epoch,
                                      logger=logger)
        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_acc', train_acc, epoch)

        # One epoch's validation
        lfw_acc, threshold = lfw_test(model)
        writer.add_scalar('model/valid_acc', lfw_acc, epoch)
        writer.add_scalar('model/valid_thres', threshold, epoch)

        # Track the best accuracy and how long it has been since it improved.
        is_best = lfw_acc > best_acc
        best_acc = max(lfw_acc, best_acc)
        if is_best:
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))

        save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                        optimizer, best_acc, is_best)

        scheduler.step(epoch)
コード例 #5
0
ファイル: train.py プロジェクト: visionNoob/InsightFace-v2
def train_net(args):
    """Train a backbone plus ArcFace head with a StepLR schedule and LFW validation.

    Args:
        args: namespace read for ``checkpoint``, ``network``, ``use_se``,
            ``optimizer``, ``lr``, ``mom``, ``weight_decay``, ``focal_loss``,
            ``gamma``, ``batch_size``, ``lr_step``, ``end_epoch`` and
            ``full_log``.

    Any ``args.network`` value not matched explicitly falls back to
    ``resnet_face18(args.use_se)``. ``device`` is a module-level name
    defined outside this function.
    """
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = 0
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        elif args.network == 'mobile':
            model = MobileNet(1.0)
        elif args.network == 'mr18':
            print("mr18")
            model = myResnet18()
        else:
            model = resnet_face18(args.use_se)
        model = nn.DataParallel(model)
        metric_fc = ArcMarginModel(args)
        metric_fc = nn.DataParallel(metric_fc)

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD([{
                'params': model.parameters()
            }, {
                'params': metric_fc.parameters()
            }],
                                        lr=args.lr,
                                        momentum=args.mom,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam([{
                'params': model.parameters()
            }, {
                'params': metric_fc.parameters()
            }],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)

    else:
        # The checkpoint stores whole pickled objects, not state dicts.
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_dataset = ArcFaceDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    scheduler = StepLR(optimizer, step_size=args.lr_step, gamma=0.1)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        if args.full_log:
            # Extra pre-epoch evaluation, logged under its own tag.
            lfw_acc, threshold = lfw_test(model)
            writer.add_scalar('LFW_Accuracy', lfw_acc, epoch)
            full_log(epoch)

        start = datetime.now()
        # One epoch's training
        train_loss, train_top5_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch,
                                            logger=logger,
                                            writer=writer)

        # Advance the schedule only after this epoch's optimizer steps.
        # Stepping first (as before) makes PyTorch >= 1.1 skip the first
        # value of the schedule, so every decay lands one epoch early.
        scheduler.step()

        writer.add_scalar('Train_Loss', train_loss, epoch)
        writer.add_scalar('Train_Top5_Accuracy', train_top5_accs, epoch)

        end = datetime.now()
        delta = end - start
        # ``timedelta.seconds`` wraps at 24 hours; total_seconds() does not.
        print('{} seconds'.format(int(delta.total_seconds())))

        # One epoch's validation
        lfw_acc, threshold = lfw_test(model)
        writer.add_scalar('LFW Accuracy', lfw_acc, epoch)

        # Check if there was an improvement
        is_best = lfw_acc > best_acc
        best_acc = max(lfw_acc, best_acc)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                        optimizer, best_acc, is_best)
コード例 #6
0
def train_net(args):
    """Train a ResNet backbone plus ArcFace head on batched data with MultiStepLR.

    Args:
        args: namespace read for ``checkpoint``, ``network``, ``pretrained``,
            ``optimizer``, ``lr``, ``mom``, ``weight_decay``, ``focal_loss``,
            ``gamma``, ``batch_size``, ``end_epoch`` and ``eval_ds``.

    Raises:
        TypeError: if ``args.network`` is not a supported backbone name.

    ``device``, ``logger``, ``num_workers`` and ``img_batch_size`` are not
    defined here and must come from module scope.
    """
    torch.manual_seed(7)
    np.random.seed(7)

    resume_from = args.checkpoint
    start_epoch = 0
    best_acc = float('-inf')
    writer = SummaryWriter()
    epochs_since_improvement = 0

    if resume_from is not None:
        # Resume: the checkpoint holds the pickled model, head and optimizer.
        state = torch.load(resume_from)
        start_epoch = state['epoch'] + 1
        epochs_since_improvement = state['epochs_since_improvement']
        model = state['model']
        metric_fc = state['metric_fc']
        optimizer = state['optimizer']
    else:
        # Builders are wrapped in lambdas so only the selected constructor
        # is looked up and called.
        builders = {
            'r18': lambda: resnet18(args),
            'r34': lambda: resnet34(args),
            'r50': lambda: resnet50(args),
            'r101': lambda: resnet101(args),
            'r152': lambda: resnet152(args),
        }
        if args.network not in builders:
            raise TypeError('network {} is not supported.'.format(
                args.network))
        model = builders[args.network]()

        # Optional warm start; weights are loaded before DataParallel wrapping.
        if args.pretrained:
            model.load_state_dict(torch.load('insight-face-v3.pt'))

        model = nn.DataParallel(model)
        metric_fc = nn.DataParallel(ArcMarginModel(args))

        param_groups = [
            {'params': model.parameters()},
            {'params': metric_fc.parameters()},
        ]
        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(param_groups,
                                        lr=args.lr,
                                        momentum=args.mom,
                                        nesterov=True,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam(param_groups,
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function (left on its default device)
    criterion = FocalLoss(gamma=args.gamma) if args.focal_loss \
        else nn.CrossEntropyLoss()

    # The dataset is built with ``img_batch_size``, and the loader's batch
    # size is divided by it.
    train_loader = torch.utils.data.DataLoader(
        ArcFaceDatasetBatched('train', img_batch_size),
        batch_size=args.batch_size // img_batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=batched_collate_fn)

    scheduler = MultiStepLR(optimizer, milestones=[8, 16, 24, 32], gamma=0.1)

    for epoch in range(start_epoch, args.end_epoch):
        current_lr = optimizer.param_groups[0]['lr']
        logger.info('\nCurrent effective learning rate: {}\n'.format(
            current_lr))
        writer.add_scalar('model/learning_rate', current_lr, epoch)

        # One epoch's training
        train_loss, train_top1_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch)
        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_accuracy', train_top1_accs, epoch)

        scheduler.step(epoch)

        # Evaluate on the configured benchmark; -1 when none is selected.
        accuracy = -1
        if args.eval_ds == "LFW":
            from lfw_eval import lfw_test

            accuracy, threshold = lfw_test(model)
        elif args.eval_ds == "Megaface":
            from megaface_eval import megaface_test

            accuracy = megaface_test(model)

        writer.add_scalar('model/evaluation_accuracy', accuracy, epoch)

        # Track the best accuracy and how long it has been since it improved.
        is_best = accuracy > best_acc
        best_acc = max(accuracy, best_acc)
        if is_best:
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            logger.info("\nEpochs since last improvement: %d\n" %
                        (epochs_since_improvement, ))

        save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                        optimizer, best_acc, is_best, scheduler)
コード例 #7
0
ファイル: train.py プロジェクト: igor-lakoza/InsightFace-v3
def train_net(args):
    """Train a backbone plus ArcFace head, validating on MegaFace every epoch.

    Args:
        args: namespace read for ``checkpoint``, ``network``, ``use_se``,
            ``optimizer``, ``lr``, ``mom``, ``weight_decay``, ``focal_loss``,
            ``gamma``, ``batch_size`` and ``end_epoch``.

    Any ``args.network`` value not matched explicitly falls back to
    ``resnet_face18(args.use_se)``. ``device`` is a module-level name
    defined outside this function.
    """
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = 0
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        elif args.network == 'mobile':
            model = MobileNet(1.0)
        else:
            model = resnet_face18(args.use_se)
        model = nn.DataParallel(model)
        metric_fc = ArcMarginModel(args)
        metric_fc = nn.DataParallel(metric_fc)

        if args.optimizer == 'sgd':
            # SGD is wrapped in the project's InsightFaceOptimizer, which
            # exposes ``lr`` and ``step_num`` (read in the epoch loop below).
            optimizer = InsightFaceOptimizer(
                torch.optim.SGD([{
                    'params': model.parameters()
                }, {
                    'params': metric_fc.parameters()
                }],
                                lr=args.lr,
                                momentum=args.mom,
                                weight_decay=args.weight_decay))
        else:
            optimizer = torch.optim.Adam([{
                'params': model.parameters()
            }, {
                'params': metric_fc.parameters()
            }],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)

    else:
        # The checkpoint stores whole pickled objects, not state dicts.
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_dataset = ArcFaceDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=8)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        # One epoch's training
        train_loss, train_top1_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch,
                                            logger=logger)

        # Plain torch optimizers (the Adam branch above) have no ``lr`` or
        # ``step_num`` attribute, so fall back to the first parameter
        # group's learning rate instead of raising AttributeError.
        lr = getattr(optimizer, 'lr', None)
        if lr is None:
            lr = optimizer.param_groups[0]['lr']
        print('\nCurrent effective learning rate: {}\n'.format(lr))
        if hasattr(optimizer, 'step_num'):
            print('Step num: {}\n'.format(optimizer.step_num))

        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_accuracy', train_top1_accs, epoch)
        writer.add_scalar('model/learning_rate', lr, epoch)

        # One epoch's validation
        megaface_acc = megaface_test(model)
        writer.add_scalar('model/megaface_accuracy', megaface_acc, epoch)

        # Check if there was an improvement
        is_best = megaface_acc > best_acc
        best_acc = max(megaface_acc, best_acc)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                        optimizer, best_acc, is_best)
def train_net(args):
    """Train a backbone plus ArcFace head, checkpointing every 10th epoch.

    Args:
        args: namespace read for ``checkpoint``, ``network``, ``use_se``,
            ``optimizer``, ``lr``, ``mom``, ``weight_decay``, ``focal_loss``,
            ``gamma``, ``train_path``, ``batch_size`` and ``end_epoch``.

    Any ``args.network`` value not matched explicitly falls back to
    ``resnet_face18(args.use_se)``. No validation is run in this loop, so
    ``best_acc`` and ``epochs_since_improvement`` are passed to
    ``save_checkpoint`` unchanged. ``device`` is a module-level name defined
    outside this function.
    """
    torch.manual_seed(7)  # seed torch's RNG (used by e.g. torch.randn)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = 0
    writer = SummaryWriter()  # tensorboard
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        if args.network == 'r18':
            model = resnet18(args)
        elif args.network == 'r34':
            model = resnet34(args)
        elif args.network == 'r50':
            model = resnet50(args)
        elif args.network == 'r101':
            model = resnet101(args)
        elif args.network == 'r152':
            model = resnet152(args)
        elif args.network == 'mobile':
            model = MobileNet(1.0)
        else:
            model = resnet_face18(args.use_se)

        model = nn.DataParallel(model)
        metric_fc = ArcMarginModel(args)
        metric_fc = nn.DataParallel(metric_fc)

        if args.optimizer == 'sgd':
            # SGD is wrapped in the project's InsightFaceOptimizer, which
            # exposes ``lr`` and ``step_num`` (read in the epoch loop below).
            optimizer = InsightFaceOptimizer(
                torch.optim.SGD([{
                    'params': model.parameters()
                }, {
                    'params': metric_fc.parameters()
                }],
                                lr=args.lr,
                                momentum=args.mom,
                                weight_decay=args.weight_decay))
        else:
            optimizer = torch.optim.Adam([{
                'params': model.parameters()
            }, {
                'params': metric_fc.parameters()
            }],
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)

    else:
        checkpoint = torch.load(checkpoint)
        # The saved objects still have to be restored by hand here.
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    if args.focal_loss:
        criterion = FocalLoss(gamma=args.gamma).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    train_dataset = Dataset(root=args.train_path,
                            phase='train',
                            input_shape=(3, 112, 112))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=8)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        # One epoch's training; delegating it to train() keeps this loop short.
        train_loss, train_top1_accs = train(train_loader=train_loader,
                                            model=model,
                                            metric_fc=metric_fc,
                                            criterion=criterion,
                                            optimizer=optimizer,
                                            epoch=epoch,
                                            logger=logger)

        # Plain torch optimizers (the Adam branch above) have no ``lr`` or
        # ``step_num`` attribute, so fall back to the first parameter
        # group's learning rate instead of raising AttributeError.
        lr = getattr(optimizer, 'lr', None)
        if lr is None:
            lr = optimizer.param_groups[0]['lr']
        print('\nCurrent effective learning rate: {}\n'.format(lr))
        if hasattr(optimizer, 'step_num'):
            print('Step num: {}\n'.format(optimizer.step_num))

        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/train_accuracy', train_top1_accs, epoch)
        writer.add_scalar('model/learning_rate', lr, epoch)

        # Save checkpoint
        if epoch % 10 == 0:
            save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                            optimizer, best_acc)
コード例 #9
0
def train_net(args):
    """Train a ``MobileNetMatchModel`` with an ``ArcMarginModel`` head.

    Either builds a fresh model/optimizer pair or resumes from the checkpoint
    named by ``args.checkpoint``, then runs the train/validate/checkpoint
    cycle for every epoch in ``range(start_epoch, args.end_epoch)``.

    Args:
        args: Parsed options. The attributes read here are ``checkpoint``
            (path or ``None``), ``optimizer`` (``'sgd'`` selects SGD, any
            other value selects Adam), ``lr``, ``mom``, ``weight_decay``,
            ``batch_size`` and ``end_epoch``. ``args`` is also forwarded to
            ``ArcMarginModel``.

    Returns:
        None. Results are reported through the TensorBoard writer, stdout
        and ``save_checkpoint``.
    """
    # Fixed seeds so runs are repeatable.
    torch.manual_seed(9527)
    np.random.seed(9527)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_acc = 0
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        model = MobileNetMatchModel()
        metric_fc = ArcMarginModel(args)

        # Backbone and margin head are optimized jointly as two param groups.
        param_groups = [{
            'params': model.parameters()
        }, {
            'params': metric_fc.parameters()
        }]
        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(param_groups,
                                        lr=args.lr,
                                        momentum=args.mom,
                                        weight_decay=args.weight_decay,
                                        nesterov=True)
        else:
            optimizer = torch.optim.Adam(param_groups,
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)

        model = nn.DataParallel(model)
        metric_fc = nn.DataParallel(metric_fc)

    else:
        # SECURITY: torch.load unpickles arbitrary Python objects; only load
        # checkpoint files that come from a trusted source.
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        metric_fc = checkpoint['metric_fc']
        optimizer = checkpoint['optimizer']
        # NOTE(review): best_acc is not restored here, so after a resume it
        # restarts at 0 and the first validated epoch is treated as the best
        # so far. Restore it if the checkpoint stores it -- confirm against
        # save_checkpoint.

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)
    metric_fc = metric_fc.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    cudnn.benchmark = True

    # Custom dataloaders
    train_loader = torch.utils.data.DataLoader(FrameDataset('train'),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True)

    # Learning rate is multiplied by 0.1 at each milestone.
    scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)

    # Created only once setup has succeeded and closed in ``finally`` so that
    # buffered events are flushed even if an epoch raises.
    writer = SummaryWriter()
    try:
        # Epochs
        for epoch in range(start_epoch, args.end_epoch):
            # One epoch's training
            train_loss, train_top5_accs = train(train_loader=train_loader,
                                                model=model,
                                                metric_fc=metric_fc,
                                                criterion=criterion,
                                                optimizer=optimizer,
                                                epoch=epoch,
                                                logger=logger)

            writer.add_scalar('model/train_loss', train_loss, epoch)
            writer.add_scalar('model/train_accuracy', train_top5_accs, epoch)

            lr = optimizer.param_groups[0]['lr']
            print('\nLearning rate: {}'.format(lr))
            writer.add_scalar('model/learning_rate', lr, epoch)

            # One epoch's validation
            val_acc, thres = test(model)
            writer.add_scalar('model/valid_accuracy', val_acc, epoch)
            writer.add_scalar('model/valid_threshold', thres, epoch)

            # The explicit epoch keeps the schedule aligned with the real
            # epoch index when resuming from a checkpoint.
            # NOTE(review): newer PyTorch releases deprecate the epoch
            # argument -- confirm against the installed version.
            scheduler.step(epoch)

            # Check if there was an improvement
            is_best = val_acc > best_acc
            best_acc = max(val_acc, best_acc)
            if not is_best:
                epochs_since_improvement += 1
                print("\nEpochs since last improvement: %d\n" %
                      (epochs_since_improvement, ))
            else:
                epochs_since_improvement = 0

            # Save checkpoint
            save_checkpoint(epoch, epochs_since_improvement, model, metric_fc,
                            optimizer, best_acc, is_best)
    finally:
        writer.close()