def train(args):
    """Launches training of landmark regression model"""
    if args.dataset == 'vgg':
        drops_schedule = [1, 6, 9, 13]
        dataset = VGGFace2(args.train,
                           args.t_list,
                           args.t_land,
                           landmarks_training=True)
    elif args.dataset == 'celeba':
        drops_schedule = [10, 20]
        dataset = CelebA(args.train, args.t_land)
    else:
        drops_schedule = [90, 140, 200]
        dataset = NDG(args.train, args.t_land)

    if dataset.have_landmarks:
        log.info('Use alignment for the train data')
        dataset.transform = transforms.Compose([
            landmarks_augmentation.Rescale((56, 56)),
            landmarks_augmentation.Blur(k=3, p=.2),
            landmarks_augmentation.HorizontalFlip(p=.5),
            landmarks_augmentation.RandomRotate(50),
            landmarks_augmentation.RandomScale(.8, .9, p=.4),
            landmarks_augmentation.RandomCrop(48),
            landmarks_augmentation.ToTensor(switch_rb=True)
        ])
    else:
        log.info('Error: training dataset has no landmarks data')
        exit()

    train_loader = DataLoader(dataset,
                              batch_size=args.train_batch_size,
                              num_workers=4,
                              shuffle=True)
    writer = SummaryWriter(
        './logs_landm/{:%Y_%m_%d_%H_%M}_'.format(datetime.datetime.now()) +
        args.snap_prefix)
    model = LandmarksNet()

    set_dropout_fn = model.set_dropout_ratio

    if args.snap_to_resume is not None:
        log.info('Resuming snapshot ' + args.snap_to_resume + ' ...')
        model = load_model_state(model,
                                 args.snap_to_resume,
                                 args.device,
                                 eval_state=False)
        model = torch.nn.DataParallel(model, device_ids=[args.device])
    else:
        model = torch.nn.DataParallel(model, device_ids=[args.device])
        model.cuda()
        model.train()
        cudnn.enabled = True
        cudnn.benchmark = True

    log.info('Face landmarks model:')
    log.info(model)

    criterion = AlignmentLoss('wing')
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, drops_schedule)

    log.info('Epoch length: %d' % len(train_loader))
    for epoch_num in range(args.epoch_total_num):
        log.info('Epoch: %d' % epoch_num)

        if epoch_num > 5:
            set_dropout_fn(0.)

        for i, data in enumerate(train_loader, 0):
            iteration = epoch_num * len(train_loader) + i

            if iteration % args.val_step == 0 and iteration > 0:
                snapshot_name = osp.join(
                    args.snap_folder,
                    args.snap_prefix + '_{0}.pt'.format(iteration))
                log.info('Saving Snapshot: ' + snapshot_name)
                save_model_cpu(model, optimizer, snapshot_name, epoch_num)

                model.eval()
                log.info('Evaluating Snapshot: ' + snapshot_name)
                avg_err, per_point_avg_err, failures_rate = evaluate(
                    train_loader, model)
                weights = per_point_avg_err / np.sum(per_point_avg_err)
                criterion.set_weights(weights)
                log.info(str(weights))
                log.info('Avg train error: {}'.format(avg_err))
                log.info('Train failure rate: {}'.format(failures_rate))
                writer.add_scalar('Quality/Avg_error', avg_err, iteration)
                writer.add_scalar('Quality/Failure_rate', failures_rate,
                                  iteration)
                writer.add_scalar('Epoch', epoch_num, iteration)
                model.train()

            data, gt_landmarks = data['img'].cuda(), data['landmarks'].cuda()
            predicted_landmarks = model(data)

            optimizer.zero_grad()
            loss = criterion(predicted_landmarks, gt_landmarks)
            loss.backward()
            optimizer.step()

            if i % 10 == 0:
                log.info('Iteration %d, Loss: %.4f' % (iteration, loss.item()))
                log.info('Learning rate: %f' % scheduler.get_lr()[0])
                writer.add_scalar('Loss/train_loss', loss.item(), iteration)
                writer.add_scalar('Learning_rate',
                                  scheduler.get_lr()[0], iteration)
        scheduler.step()
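
For reference, below is a minimal sketch of how this train() function could be driven from a script entry point. The attribute names mirror the ones the function reads above (args.dataset, args.train, args.t_list, args.t_land, args.val_step, and so on); the flag spellings, defaults, and the parse_args helper are assumptions made for illustration, not the repository's actual command-line interface.

# Hypothetical driver sketch: the flag names and defaults below are assumptions
# chosen to match the attributes train() reads above, not the repository's CLI.
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description='Landmark regression training')
    parser.add_argument('--dataset', choices=['vgg', 'celeba', 'ndg'], default='celeba')
    parser.add_argument('--train', required=True, help='path to the training data root')
    parser.add_argument('--t_list', help='training image list (used by VGGFace2)')
    parser.add_argument('--t_land', help='training landmark annotations')
    parser.add_argument('--train_batch_size', type=int, default=256)
    parser.add_argument('--epoch_total_num', type=int, default=30)
    parser.add_argument('--val_step', type=int, default=2000)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--snap_folder', default='./snapshots')
    parser.add_argument('--snap_prefix', default='LandmarksNet')
    parser.add_argument('--snap_to_resume', default=None)
    return parser.parse_args()


if __name__ == '__main__':
    train(parse_args())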
Example #3
def train(args):
    """Performs training of a face recognition network"""
    input_size = models_backbones[args.model]().get_input_res()
    if args.train_dataset == 'vgg':
        assert args.t_list
        dataset = VGGFace2(args.train, args.t_list, args.t_land)
    elif args.train_dataset == 'imdbface':
        dataset = IMDBFace(args.train, args.t_list)
    elif args.train_dataset == 'trp':
        dataset = TrillionPairs(args.train, args.t_list)
    else:
        dataset = MSCeleb1M(args.train, args.t_list)

    if dataset.have_landmarks:
        log.info('Use alignment for the train data')
        dataset.transform = t.Compose([
            augm.HorizontalFlipNumpy(p=.5),
            augm.CutOutWithPrior(p=0.05, max_area=0.1),
            augm.RandomRotationNumpy(10, p=.95),
            augm.ResizeNumpy(input_size),
            augm.BlurNumpy(k=5, p=.2),
            augm.NumpyToTensor(switch_rb=True)
        ])
    else:
        dataset.transform = t.Compose([
            augm.ResizeNumpy(input_size),
            augm.HorizontalFlipNumpy(),
            augm.RandomRotationNumpy(10),
            augm.NumpyToTensor(switch_rb=True)
        ])

    if args.weighted:
        train_weights = dataset.get_weights()
        train_weights = torch.DoubleTensor(train_weights)
        sampler = torch.utils.data.sampler.WeightedRandomSampler(
            train_weights, len(train_weights))
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.train_batch_size,
            sampler=sampler,
            num_workers=3,
            pin_memory=False)
    else:
        train_loader = DataLoader(dataset,
                                  batch_size=args.train_batch_size,
                                  num_workers=4,
                                  shuffle=True)

    lfw = LFW(args.val, args.v_list, args.v_land)
    if lfw.use_landmarks:
        log.info('Use alignment for the test data')
        lfw.transform = t.Compose(
            [augm.ResizeNumpy(input_size),
             augm.NumpyToTensor(switch_rb=True)])
    else:
        lfw.transform = t.Compose([
            augm.ResizeNumpy((160, 160)),
            augm.CenterCropNumpy(input_size),
            augm.NumpyToTensor(switch_rb=True)
        ])

    log_path = './logs/{:%Y_%m_%d_%H_%M}_{}'.format(datetime.datetime.now(),
                                                    args.snap_prefix)
    writer = SummaryWriter(log_path)

    if not osp.exists(args.snap_folder):
        os.mkdir(args.snap_folder)

    model = models_backbones[args.model](embedding_size=args.embed_size,
                                         num_classes=dataset.get_num_classes(),
                                         feature=False)
    if args.snap_to_resume is not None:
        log.info('Resuming snapshot ' + args.snap_to_resume + ' ...')
        model = load_model_state(model,
                                 args.snap_to_resume,
                                 args.devices[0],
                                 eval_state=False)
        model = torch.nn.DataParallel(model, device_ids=args.devices)
    else:
        model = torch.nn.DataParallel(model,
                                      device_ids=args.devices,
                                      output_device=args.devices[0])
        model.cuda()
        model.train()
        cudnn.benchmark = True

    log.info('Face Recognition model:')
    log.info(model)

    if args.mining_type == 'focal':
        softmax_criterion = AMSoftmaxLoss(gamma=args.gamma,
                                          m=args.m,
                                          margin_type=args.margin_type,
                                          s=args.s)
    else:
        softmax_criterion = AMSoftmaxLoss(t=args.t,
                                          m=0.35,
                                          margin_type=args.margin_type,
                                          s=args.s)
    aux_losses = MetricLosses(dataset.get_num_classes(), args.embed_size,
                              writer)
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [3, 6, 9, 13])
    for epoch_num in range(args.epoch_total_num):
        scheduler.step()
        if epoch_num > 6:
            model.module.set_dropout_ratio(0.)
        classification_correct = 0
        classification_total = 0

        for i, data in enumerate(train_loader, 0):
            iteration = epoch_num * len(train_loader) + i

            if iteration % args.val_step == 0:
                snapshot_name = osp.join(
                    args.snap_folder,
                    args.snap_prefix + '_{0}.pt'.format(iteration))
                if iteration > 0:
                    log.info('Saving Snapshot: ' + snapshot_name)
                    save_model_cpu(model, optimizer, snapshot_name, epoch_num)

                log.info('Evaluating Snapshot: ' + snapshot_name)
                model.eval()
                same_acc, diff_acc, all_acc, auc = evaluate(
                    args,
                    lfw,
                    model,
                    compute_embeddings_lfw,
                    args.val_batch_size,
                    verbose=False)

                model.train()

                log.info('Validation accuracy: {0:.4f}, {1:.4f}'.format(
                    same_acc, diff_acc))
                log.info('Validation accuracy mean: {0:.4f}'.format(all_acc))
                log.info('Validation AUC: {0:.4f}'.format(auc))
                writer.add_scalar('Accuracy/Val_same_accuracy', same_acc,
                                  iteration)
                writer.add_scalar('Accuracy/Val_diff_accuracy', diff_acc,
                                  iteration)
                writer.add_scalar('Accuracy/Val_accuracy', all_acc, iteration)
                writer.add_scalar('Accuracy/AUC', auc, iteration)

            data, label = data['img'], data['label'].cuda()
            features, sm_outputs = model(data)

            optimizer.zero_grad()
            aux_losses.init_iteration()
            aux_loss, aux_log = aux_losses(features, label, epoch_num,
                                           iteration)
            loss_sm = softmax_criterion(sm_outputs, label)
            loss = loss_sm + aux_loss
            loss.backward()
            aux_losses.end_iteration()
            optimizer.step()

            _, predicted = torch.max(sm_outputs.data, 1)
            classification_total += int(label.size(0))
            classification_correct += int(torch.sum(predicted.eq(label)))
            train_acc = float(classification_correct) / classification_total

            if i % 10 == 0:
                log.info('Iteration %d, Softmax loss: %.4f, Total loss: %.4f' %
                         (iteration, loss_sm.item(), loss.item()) + aux_log)
                log.info('Learning rate: %f' % scheduler.get_lr()[0])
                writer.add_scalar('Loss/train_loss', loss.item(), iteration)
                writer.add_scalar('Loss/softmax_loss', loss_sm.item(), iteration)
                writer.add_scalar('Learning_rate',
                                  scheduler.get_lr()[0], iteration)
                writer.add_scalar('Accuracy/classification', train_acc,
                                  iteration)
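
The classification loss above comes from the repository's AMSoftmaxLoss, configured with a margin m and scale s. As a rough illustration of the core idea only, here is a minimal sketch of the plain additive-cosine-margin (CosFace-style) variant; it ignores the margin_type, focal (gamma) and hard-mining (t) options used above, and the function name and signature are my own.

# Minimal sketch of a CosFace-style additive-margin softmax loss (illustration
# only; the repository's AMSoftmaxLoss also supports other margin types,
# focal weighting and hard-example mining, which are omitted here).
import torch
import torch.nn.functional as F


def am_softmax_loss(cosine, labels, m=0.35, s=30.0):
    """cosine: (batch, num_classes) cosine similarities between L2-normalized
    embeddings and L2-normalized class weight vectors."""
    one_hot = F.one_hot(labels, num_classes=cosine.size(1)).float()
    # Subtract the margin from the target-class cosine, then scale the logits.
    logits = s * (cosine - m * one_hot)
    return F.cross_entropy(logits, labels)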