Example #1
import os
import os.path as osp
from pathlib import Path

import numpy as np
import torch
import torch.optim as optim
from PIL import Image
from torchvision import transforms as trans
from tqdm import tqdm

# Project-local helpers assumed importable from the surrounding package:
# Baseline, get_train_loader, get_val_data, read_pairs, get_paths,
# hflip_batch, l2_norm, scores, gen_plot, get_time.


class FACERecognition(object):
    def __init__(self, cfg, inference=False, threshold=0.5):
        self.device = torch.device(
            "cuda") if cfg.MODEL.DEVICE == 'cuda' else torch.device("cpu")

        if not inference:
            print('load training data')
            self.dataloader, class_num = get_train_loader(cfg)

            print('load testing data')
            if cfg.TEST.MODE == 'face':
                self.agedb_30, self.cfp_fp, self.lfw, self.agedb_30_issame, self.cfp_fp_issame, self.lfw_issame = get_val_data(
                    self.dataloader.dataset.root.parent)
            else:
                pairs = read_pairs(
                    os.path.join(cfg.DATASETS.FOLDER, 'pairs.txt'))

                self.data, self.data_issame = get_paths(
                    os.path.join(cfg.DATASETS.FOLDER, 'test'), pairs)

            print('load model')
            self.model = Baseline(cfg)
            self.model = self.model.to(self.device)
            self.load_state(cfg)
            if cfg.SOLVER.OPT == 'SGD':
                self.optimizer = optim.SGD(
                    [{
                        'params': self.model.parameters()
                    }],
                    lr=cfg.SOLVER.BASE_LR,
                    momentum=cfg.SOLVER.MOMENTUM,
                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
            else:
                self.optimizer = optim.Adam(
                    [{
                        'params': self.model.parameters()
                    }],
                    lr=cfg.SOLVER.BASE_LR,
                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)

            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=cfg.SOLVER.MAX_EPOCH,
                eta_min=cfg.SOLVER.ETA_MIN_LR)
            checkpoints = cfg.CHECKPOINT.SAVE_DIR
            os.makedirs(checkpoints, exist_ok=True)

            self.best_score = 0.
            self.best_threshold = 0.
        else:
            self.device = torch.device(
                "cuda") if cfg.TEST.DEVICE == 'cuda' else torch.device("cpu")
            print('load model')
            self.model = Baseline(cfg)
            self.model = self.model.to(self.device)
            self.load_state(cfg)
            self.threshold = threshold
            self.test_transform = trans.Compose([
                trans.ToTensor(),
                trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
            ])

    def load_state(self, cfg):
        if cfg.CHECKPOINT.RESTORE:
            os.makedirs(cfg.CHECKPOINT.SAVE_DIR, exist_ok=True)
            weights_path = osp.join(cfg.CHECKPOINT.SAVE_DIR,
                                    cfg.CHECKPOINT.RESTORE_MODEL)
            self.model.load_state_dict(torch.load(weights_path,
                                                  map_location=self.device),
                                       strict=False)
            print('loaded model {}'.format(weights_path))

    def save_state(self, cfg, save_name):
        save_path = Path(cfg.CHECKPOINT.SAVE_DIR)
        torch.save(self.model.state_dict(), save_path / save_name)

    def evaluate(self, cfg, carray, issame, nrof_folds=5, tta=False):
        self.model.eval()
        idx = 0
        embeddings = np.zeros([len(carray), cfg.MODEL.HEADS.EMBEDDING_DIM])
        batch_size = cfg.SOLVER.IMS_PER_BATCH
        with torch.no_grad():
            while idx + batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + batch_size])
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(self.device)) + self.model(
                        flipped.to(self.device))
                    embeddings[idx:idx + batch_size] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:idx + batch_size] = self.model(
                        batch.to(self.device)).cpu()
                idx += batch_size
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(self.device)) + self.model(
                        flipped.to(self.device))
                    embeddings[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:] = self.model(batch.to(self.device)).cpu()
        tpr, fpr, accuracy, best_thresholds = scores(embeddings, issame,
                                                     nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def train(self, cfg):
        self.model.train()
        step = 0
        for e in range(cfg.SOLVER.MAX_EPOCH):
            for data, labels in tqdm(self.dataloader,
                                     desc=f"Epoch {e}/{cfg.SOLVER.MAX_EPOCH}",
                                     ascii=True,
                                     total=len(self.dataloader)):
                data = data.to(self.device)
                labels = labels.to(self.device)

                self.optimizer.zero_grad()
                loss_dict = self.model(data, labels)
                losses = sum(loss_dict.values())
                losses.backward()
                self.optimizer.step()

                accuracy = 0.0
                if step % cfg.TEST.SHOW_PERIOD == 0:
                    print(
                        f"Epoch {e}/{cfg.SOLVER.MAX_EPOCH}, Step {step}, CE Loss: {loss_dict.get('loss_cls')}, Triplet Loss: {loss_dict.get('loss_triplet')}, Circle Loss: {loss_dict.get('loss_circle')}, Cos Loss: {loss_dict.get('loss_cosface')}"
                    )
                if step % cfg.TEST.EVAL_PERIOD == 0:
                    if cfg.TEST.MODE == 'face':
                        accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                            cfg, self.agedb_30, self.agedb_30_issame)
                        print("dataset {}, acc {}, best_threshold {}".format(
                            'agedb_30',
                            accuracy,
                            best_threshold,
                        ))
                        accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                            cfg, self.lfw, self.lfw_issame)
                        print("dataset {}, acc {}, best_threshold {}".format(
                            'lfw', accuracy, best_threshold))
                        if accuracy > self.best_score:
                            self.best_score = accuracy
                            self.best_threshold = best_threshold
                            self.save_state(
                                cfg,
                                'model_{}_best_accuracy:{:.3f}_step:{}.pth'.
                                format(get_time(), accuracy, step))
                        accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                            cfg, self.cfp_fp, self.cfp_fp_issame)
                        print("dataset {}, acc {}, best_threshold {}".format(
                            'cfp_fp', accuracy, best_threshold))
                    else:
                        accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                            cfg, self.data, self.data_issame)
                        print("dataset {}, acc {}, best_threshold {}".format(
                            'test', accuracy, best_threshold))
                        if accuracy > self.best_score:
                            self.best_score = accuracy
                            self.best_threshold = best_threshold
                            self.save_state(
                                cfg,
                                'model_{}_best_accuracy:{:.3f}_step:{}.pth'.
                                format(get_time(), accuracy, step))
                    self.model.train()
                if step % cfg.TEST.SAVE_PERIOD == 0:
                    self.save_state(
                        cfg, 'model_{}_accuracy:{:.3f}_step:{}.pth'.format(
                            get_time(), accuracy, step))

                step += 1

            self.save_state(cfg,
                            'model_{}_step:{}.pth'.format(get_time(), step))

            self.scheduler.step()

    def infer(self, faces, target_embs, tta=False):
        '''
        faces : list of PIL Images to identify
        target_embs : [n, 512] tensor of precomputed facebank embeddings,
                      or a list of PIL Images to embed on the fly
        tta : test-time augmentation (horizontal flip only)
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(
                    self.test_transform(img).to(self.device).unsqueeze(0))
                emb_mirror = self.model(
                    self.test_transform(mirror).to(self.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        self.test_transform(img).to(self.device).unsqueeze(0)))
        source_embs = torch.cat(embs)

        if isinstance(target_embs, list):
            tmp = []
            for img in target_embs:
                if tta:
                    mirror = trans.functional.hflip(img)
                    emb = self.model(
                        self.test_transform(img).to(self.device).unsqueeze(0))
                    emb_mirror = self.model(
                        self.test_transform(mirror).to(
                            self.device).unsqueeze(0))
                    tmp.append(l2_norm(emb + emb_mirror))
                else:
                    tmp.append(
                        self.model(
                            self.test_transform(img).to(
                                self.device).unsqueeze(0)))
            target_embs = torch.cat(tmp)

        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum
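
Usage note (not part of the original snippet): a minimal sketch of how the class above can be driven, assuming `cfg` is a yacs-style config exposing the `MODEL`, `SOLVER`, `TEST`, `DATASETS` and `CHECKPOINT` nodes read in `__init__`; `get_config`, `query_faces` and `facebank_embs` are hypothetical placeholders.

# Hedged usage sketch; get_config, query_faces and facebank_embs are placeholders.
cfg = get_config()  # hypothetical helper returning the yacs-style config used above

# training with periodic evaluation and checkpointing
trainer = FACERecognition(cfg, inference=False)
trainer.train(cfg)

# inference against a precomputed facebank ([n, 512] tensor or list of PIL Images)
recognizer = FACERecognition(cfg, inference=True, threshold=0.5)
min_idx, min_dist = recognizer.infer(query_faces, facebank_embs, tta=True)
# min_idx[i] == -1 means face i matched nothing within the distance threshold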
Example #2
import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix  # presumed source of confusion_matrix

# Project-local pieces assumed available: Baseline, GAT, GAT_edge, plot.


def test(test_iter,
         test_loader,
         weights_path,
         num_epoch,
         model_type=0,
         threshold=0.7):
    if model_type == 0:
        model = Baseline(in_channels=7,
                         out_channels_1=7,
                         out_channels_2=7,
                         KT_1=4,
                         KT_2=3,
                         num_nodes=39,
                         batch_size=32,
                         frames=33,
                         frames_0=12,
                         num_generator=10)
    elif model_type == 1:
        model = GAT()
    elif model_type == 2:
        model = GAT_edge()
    else:
        raise ValueError(
            'model_type must be 0, 1 or 2, got {}'.format(model_type))
    model.load_state_dict(torch.load(weights_path))
    # model = nn.DataParallel(model)
    model = model.cuda()
    model.eval()
    accu = 0
    true_labels = np.array([])
    pred_labels = np.array([])
    label_float = np.array([])
    for epoch in range(num_epoch):
        try:
            Y, infos, labels = next(test_iter)
            Y, infos, labels = Y.float().cuda(), infos.float().cuda(
            ), labels.type(torch.int32)
        except StopIteration:
            batch_iterator = iter(test_loader)
            Y, infos, labels = next(batch_iterator)
            Y, infos, labels = Y.float().cuda(), infos.float().cuda(
            ), labels.type(torch.int32)
        with torch.no_grad():
            label_predicted = model(Y, infos)
        label_float = np.concatenate(
            (label_float, label_predicted.cpu().reshape((1, -1))[0]))
        labels_threshold = label_predicted > threshold
        true_labels = np.concatenate((true_labels, labels.reshape((1, -1))[0]))
        pred_labels = np.concatenate(
            (pred_labels, labels_threshold.cpu().reshape((1, -1))[0]))
        all_right = 1 - torch.mean(
            (labels ^ labels_threshold.cpu()).type(torch.float32))
        print('epoch:{}, accu:{}'.format(epoch, all_right))
        accu += all_right
    accu /= num_epoch
    plot(confusion_matrix(true_labels, pred_labels))
    plt.figure(figsize=(20, 8), dpi=100)
    distance = 0.1
    group_num = int((max(label_float) - min(label_float)) / distance)
    plt.hist(label_float, bins=group_num)
    # plt.xticks(range(min(label_float), max(label_float))[::2])
    plt.grid(linestyle="--", alpha=0.5)
    plt.xlabel("label output")
    plt.ylabel("frequency")
    plt.savefig('./data/frequency.png')
    return accu
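
Usage note (not part of the original snippet): a hedged call sketch for the function above, reusing the `trainSet`/`DataLoader` pattern from Example #3; the dataset arguments, checkpoint path and `num_epoch` value are illustrative placeholders.

from torch.utils.data import DataLoader

test_data = trainSet(39, 3200, 4)                 # project dataset helper, as in Example #3
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)
test_iter = iter(test_loader)
mean_acc = test(test_iter, test_loader,
                weights_path='./weights/baseline.pth',  # placeholder checkpoint
                num_epoch=100, model_type=0, threshold=0.7)
print('mean accuracy:', mean_acc)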
Example #3
import sys
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Project-local pieces assumed available here: args (parsed CLI options),
# Baseline, GAT, GAT_edge, init_weights, trainSet, MSE_loss, evaluate.


def train():
    if args.model == 'baseline':
        net = Baseline(in_channels=7,
                       out_channels_1=7,
                       out_channels_2=7,
                       KT_1=4,
                       KT_2=3,
                       num_nodes=39,
                       batch_size=args.batch_size,
                       frames=33,
                       frames_0=12,
                       num_generator=10)
    elif args.model == 'GAT':
        net = GAT()
    elif args.model == 'GAT_edge':
        net = GAT_edge()
    else:
        raise ValueError('must choose a model in the choices: '
                         'baseline, GAT or GAT_edge')

    if args.init_type is not None:
        try:
            init_weights(net, init_type=args.init_type)
        except Exception:
            sys.exit('Load Network  <==> Init_weights error!')

    # net = nn.DataParallel(net)
    net = net.cuda()

    accuracy = 0
    train_file = 4
    train_amount = 6400  # 8144
    eval_amount = 3200
    num_epoch = train_amount // args.batch_size * train_file
    train_data = trainSet(39, train_amount, [0, 1, 2, 3])
    trainloader = DataLoader(train_data,
                             batch_size=args.batch_size,
                             shuffle=True)
    batch_loader = iter(trainloader)
    eval_data = trainSet(39, eval_amount, 4)
    evalloader = DataLoader(eval_data,
                            batch_size=args.batch_size,
                            shuffle=True)
    eval_iter = iter(evalloader)
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    net.train()
    #  train ------------------------------------------------
    print('---- epoch start ----')
    start_time = time.time()
    for epoch in range(num_epoch):
        # load train data
        try:
            Y, infos, labels = next(batch_loader)
            Y, infos, labels = Y.float().cuda(), infos.float().cuda(
            ), labels.float().cuda()
        except StopIteration:
            batch_iterator = iter(trainloader)
            Y, infos, labels = next(batch_iterator)
            Y, infos, labels = Y.float().cuda(), infos.float().cuda(
            ), labels.float().cuda()
        label_predicted = net(Y, infos)
        # labels were already cast to float on the GPU above; MSE expects float targets
        # criteria = nn.BCELoss()
        loss = MSE_loss(label_predicted, labels)
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(net.parameters(), max_norm=20, norm_type=2)
        optimizer.step()
        print('epoch:{}/{} | loss:{:.4f}'.format(epoch + 1, num_epoch,
                                                 loss.item()))
        with open(args.log_folder + 'loss.log', mode='a') as f:
            f.writelines('\n epoch:{}/{} | loss:{:.4f}'.format(
                epoch + 1, num_epoch, loss.item()))

        #  eval ------------------------------------------------
        if epoch % 20 == 0:
            net.eval()
            accu, _ = evaluate(model=net,
                               data_iter=eval_iter,
                               data_loader=evalloader,
                               num_epoch=10)
            print('accuracy:{}'.format(accu))
            with open(args.log_folder + 'accu.log', mode='a') as f:
                f.writelines('\n eval epoch:{} | accu:{:.4f}'.format(
                    epoch // 20 + 1, accu))
            if accu > accuracy:
                torch.save(
                    net.state_dict(),
                    args.save_folder + '{}_{}.pth'.format(args.model, accu))
                accuracy = accu
            net.train()  # switch back to training mode after the eval pass

    stop_time = time.time()
    print("program run for {} s".format(stop_time - start_time))