class face_learner(object):
    """Training / evaluation / inference wrapper around a face-embedding backbone.

    With ``inference=False`` it builds the train loader, an ``Arcface`` head,
    an SGD optimizer and a ``SummaryWriter``, and loads the agedb_30 / cfp_fp /
    lfw verification sets. With ``inference=True`` only the backbone is
    built and ``conf.threshold`` is stored for :meth:`infer`.
    """

    def __init__(self, conf, inference=False):
        print(conf)
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio,
                                  conf.net_mode).to(conf.device)
            print('{}_{} model generated'.format(conf.net_mode,
                                                 conf.net_depth))

        if not inference:
            self.milestones = conf.milestones
            self.loader, self.class_num = get_train_loader(conf)

            self.writer = SummaryWriter(conf.log_path)
            self.step = 0
            self.head = Arcface(embedding_size=conf.embedding_size,
                                classnum=self.class_num).to(conf.device)

            print('two model heads generated')

            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)

            # BatchNorm parameters go in a group without an explicit
            # 'weight_decay', so SGD's default (0) applies to them.
            if conf.use_mobilfacenet:
                # The last non-BN parameter is grouped with the head kernel
                # and decayed 10x harder than the rest of the network.
                self.optimizer = optim.SGD(
                    [{
                        'params': paras_wo_bn[:-1],
                        'weight_decay': 4e-5
                    }, {
                        'params': [paras_wo_bn[-1]] + [self.head.kernel],
                        'weight_decay': 4e-4
                    }, {
                        'params': paras_only_bn
                    }],
                    lr=conf.lr,
                    momentum=conf.momentum)
            else:
                self.optimizer = optim.SGD(
                    [{
                        'params': paras_wo_bn + [self.head.kernel],
                        'weight_decay': 5e-4
                    }, {
                        'params': paras_only_bn
                    }],
                    lr=conf.lr,
                    momentum=conf.momentum)
            print(self.optimizer)

            print('optimizers generated')
            n_batches = len(self.loader)
            # Clamp to >= 1: these are used as modulo divisors in train(),
            # which would raise ZeroDivisionError for short loaders.
            self.board_loss_every = max(1, n_batches // 100)
            self.evaluate_every = max(1, n_batches // 10)
            self.save_every = max(1, n_batches // 5)
            (self.agedb_30, self.cfp_fp, self.lfw, self.agedb_30_issame,
             self.cfp_fp_issame, self.lfw_issame) = get_val_data(
                 self.loader.dataset.root.parent)
        else:
            self.threshold = conf.threshold

    def save_state(self,
                   conf,
                   accuracy,
                   to_save_folder=False,
                   extra=None,
                   model_only=False):
        """Write model (and, unless ``model_only``, head and optimizer) state dicts.

        Files go to ``conf.save_path`` when ``to_save_folder`` is true,
        otherwise to ``conf.model_path``. ``accuracy``, the current step and
        ``extra`` are embedded in the file names.
        """
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        torch.save(
            self.model.state_dict(),
            save_path / ('model_{}_accuracy:{}_step:{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(),
                save_path / ('head_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(),
                save_path / ('optimizer_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))

    def load_state(self,
                   conf,
                   fixed_str,
                   from_save_folder=False,
                   model_only=False):
        """Load ``model_<fixed_str>`` (and head / optimizer files) from disk.

        Tensors are mapped onto ``conf.device``, so checkpoints written on a
        different device (e.g. GPU -> CPU) can be restored.
        """
        if from_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        self.model.load_state_dict(
            torch.load(save_path / 'model_{}'.format(fixed_str),
                       map_location=conf.device))
        if not model_only:
            self.head.load_state_dict(
                torch.load(save_path / 'head_{}'.format(fixed_str),
                           map_location=conf.device))
            self.optimizer.load_state_dict(
                torch.load(save_path / 'optimizer_{}'.format(fixed_str),
                           map_location=conf.device))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        """Log accuracy, best threshold and ROC image of one benchmark at ``self.step``."""
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name),
                               best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor,
                              self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        """Embed ``carray`` in batches and run the k-fold verification metric.

        With ``tta`` the embedding of each image is the L2-normalised sum of
        the embeddings of the image and its horizontal flip.

        Returns:
            (mean accuracy, mean best threshold, ROC curve image tensor)
        """
        self.model.eval()
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            # The last slice may be shorter than batch_size; slicing past the
            # end is harmless for both carray and the numpy buffer.
            for idx in range(0, len(carray), conf.batch_size):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(
                        fliped.to(conf.device))
                    emb_batch = l2_norm(emb_batch)
                else:
                    emb_batch = self.model(batch.to(conf.device))
                # Move to host before writing into the numpy buffer; a CUDA
                # tensor cannot be converted to numpy directly.
                embeddings[idx:idx + conf.batch_size] = emb_batch.cpu()
        # `evaluate` here is the module-level metric function, not this method.
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame,
                                                       nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def find_lr(self,
                conf,
                init_value=1e-8,
                final_value=10.,
                beta=0.98,
                bloding_scale=3.,
                num=None):
        """Learning-rate range test: grow lr geometrically while training.

        lr starts at ``init_value`` and is multiplied each batch so that it
        would reach ``final_value`` after ``num`` batches. The loss is
        smoothed with a bias-corrected exponential moving average (``beta``);
        the sweep stops early once the smoothed loss exceeds
        ``bloding_scale`` times the best seen.

        Returns:
            (log10 learning rates, smoothed losses)
        """
        if not num:
            num = len(self.loader)
        mult = (final_value / init_value)**(1 / num)
        lr = init_value
        for params in self.optimizer.param_groups:
            params['lr'] = lr
        self.model.train()
        avg_loss = 0.
        best_loss = 0.
        batch_num = 0
        losses = []
        log_lrs = []
        for imgs, labels in tqdm(self.loader, total=num):

            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            batch_num += 1

            self.optimizer.zero_grad()

            embeddings = self.model(imgs)
            thetas = self.head(embeddings, labels)
            loss = conf.ce_loss(thetas, labels)

            # Compute the smoothed (bias-corrected EMA) loss
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            self.writer.add_scalar('avg_loss', avg_loss, batch_num)
            smoothed_loss = avg_loss / (1 - beta**batch_num)
            self.writer.add_scalar('smoothed_loss', smoothed_loss, batch_num)
            # Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > bloding_scale * best_loss:
                print('exited with best_loss at {}'.format(best_loss))
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
            # Record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            # Store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            self.writer.add_scalar('log_lr', math.log10(lr), batch_num)

            # Do the SGD step, then update the lr for the next step
            loss.backward()
            self.optimizer.step()

            lr *= mult
            for params in self.optimizer.param_groups:
                params['lr'] = lr
            if batch_num > num:
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
        # The loader ran out before batch_num exceeded num (always the case
        # with the default num=len(self.loader)); return the collected sweep
        # instead of implicitly returning None.
        plt.plot(log_lrs[10:-5], losses[10:-5])
        return log_lrs, losses

    def train(self, conf, epochs):
        """Train for ``epochs`` epochs.

        The lr is divided by 10 at every epoch listed in ``self.milestones``.
        Every ``board_loss_every`` steps the mean loss is logged; every
        ``evaluate_every`` steps agedb_30 / lfw / cfp_fp are evaluated and
        logged; every ``save_every`` steps a checkpoint is written to
        ``conf.model_path``. A final checkpoint goes to ``conf.save_path``.
        """
        self.model.train()
        running_loss = 0.
        # Latest evaluation accuracy (cfp_fp, the last set evaluated), used in
        # checkpoint names. None until the first evaluation, so the final
        # save cannot fail on an unbound local.
        accuracy = None
        for e in range(epochs):
            print('epoch {} started'.format(e))
            # Supports any number of milestones (one lr decay per match).
            for milestone in self.milestones:
                if e == milestone:
                    self.schedule_lr()
            for imgs, labels in tqdm(iter(self.loader)):
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    for db_name, carray, issame in (
                            ('agedb_30', self.agedb_30, self.agedb_30_issame),
                            ('lfw', self.lfw, self.lfw_issame),
                            ('cfp_fp', self.cfp_fp, self.cfp_fp_issame)):
                        accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                            conf, carray, issame)
                        self.board_val(db_name, accuracy, best_threshold,
                                       roc_curve_tensor)
                    # evaluate() switched the model to eval mode
                    self.model.train()
                if self.step % self.save_every == 0 and self.step != 0:
                    self.save_state(conf, accuracy)

                self.step += 1

        self.save_state(conf, accuracy, to_save_folder=True, extra='final')

    def schedule_lr(self):
        """Divide the lr of every optimizer param group by 10 and print the optimizer."""
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def infer(self, conf, faces, target_embs, tta=False):
        '''Match each face against a facebank by squared L2 distance.

        faces : list of PIL Image
        target_embs : [n, embedding_size] computed embeddings of faces in facebank
        tta : test time augmentation (hflip, that's all)

        Returns (min_idx, minimum): for every face the index of the closest
        facebank entry (-1 when the distance exceeds self.threshold) and that
        squared distance.
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)

        # (m, d, 1) - (1, d, n) -> (m, d, n); summing over d gives (m, n)
        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum
# ---- Example #2 ----
class face_learner(object):
    """Multi-GPU (DataParallel) trainer with a QAMFace head and focal loss.

    The train loader is always built. With ``inference=False`` it also
    builds the head, the focal loss, an SGD optimizer and a ``SummaryWriter``.
    With ``inference=True`` it only stores ``conf.threshold``.
    """

    # Benchmarks scored by eval_emore_bmk during training and validation.
    # The last one's accuracy ends up in checkpoint file names.
    _EVAL_BENCHMARKS = ('agedb_30', 'lfw', 'calfw', 'cfp_ff', 'cfp_fp',
                        'cplfw', 'vgg2_fp')

    def __init__(self, conf, inference=False):
        print(conf)
        # self.loader, self.class_num = construct_msr_dataset(conf)
        self.loader, self.class_num = get_train_loader(conf)
        self.model = Backbone(conf.net_depth, conf.drop_ratio, conf.net_mode)
        print('{}_{} model generated'.format(conf.net_mode, conf.net_depth))

        if not inference:
            self.milestones = conf.milestones

            self.writer = SummaryWriter(conf.log_path)
            self.step = 0
            self.head = QAMFace(embedding_size=conf.embedding_size,
                                classnum=self.class_num).to(conf.device)
            self.focalLoss = FocalLoss()

            print('two model heads generated')

            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)

            # BN parameters get no explicit weight decay (SGD default of 0).
            self.optimizer = optim.SGD(
                [{
                    'params': paras_wo_bn + [self.head.kernel],
                    'weight_decay': 5e-4
                }, {
                    'params': paras_only_bn
                }],
                lr=conf.lr,
                momentum=conf.momentum)
            print(self.optimizer)

            print('optimizers generated')
            n_batches = len(self.loader)
            # Clamp to >= 1: these are modulo divisors in train(), and
            # n_batches // 1000 is 0 for any loader under 1000 batches.
            self.board_loss_every = max(1, n_batches // 1000)
            self.evaluate_every = max(1, n_batches // 10)
            self.save_every = max(1, n_batches // 2)
        else:
            self.threshold = conf.threshold

        # Multi-GPU training: wrap in DataParallel and move to conf.device.
        self.model = torch.nn.DataParallel(self.model)
        self.model.to(conf.device)
        # The head only exists in training mode; wrapping it unconditionally
        # would raise AttributeError for inference=True.
        if not inference:
            self.head = torch.nn.DataParallel(self.head)
            self.head = self.head.to(conf.device)

    def save_state(self,
                   conf,
                   accuracy,
                   to_save_folder=False,
                   extra=None,
                   model_only=False):
        """Write model (and, unless ``model_only``, head and optimizer) state dicts.

        Files go to ``conf.save_path`` when ``to_save_folder`` is true,
        otherwise to ``conf.model_path``.
        """
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        torch.save(
            self.model.state_dict(),
            save_path / ('model_{}_accuracy:{}_step:{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(),
                save_path / ('head_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(),
                save_path / ('optimizer_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))

    def load_state(self,
                   conf,
                   fixed_str,
                   from_save_folder=False,
                   model_only=False):
        """Resume from ``model_<fixed_str>`` (and head / optimizer files).

        Tensors are mapped onto ``conf.device`` so checkpoints from another
        device can be restored.
        """
        print('resume model from ' + fixed_str)
        if from_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        self.model.load_state_dict(
            torch.load(save_path / 'model_{}'.format(fixed_str),
                       map_location=conf.device))
        if not model_only:
            self.head.load_state_dict(
                torch.load(save_path / 'head_{}'.format(fixed_str),
                           map_location=conf.device))
            self.optimizer.load_state_dict(
                torch.load(save_path / 'optimizer_{}'.format(fixed_str),
                           map_location=conf.device))

    def board_val(self,
                  db_name,
                  accuracy,
                  best_threshold=0,
                  roc_curve_tensor=0):
        """Log only the accuracy of ``db_name``; the other arguments are ignored."""
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)

    def train(self, conf, epochs):
        """Train for ``epochs`` epochs with focal loss on the QAMFace logits.

        The lr is divided by 10 at every epoch in ``self.milestones``. Every
        ``evaluate_every`` steps all benchmarks in ``_EVAL_BENCHMARKS`` are
        scored; every ``save_every`` steps a checkpoint is saved.
        """
        self.model.train()
        running_loss = 0.
        # Accuracy of the last benchmark evaluated; None until the first
        # evaluation so the final save never hits an unbound local.
        acc = None
        for e in range(epochs):
            print('epoch {} started'.format(e))
            # manually decay lr
            if e in self.milestones:
                self.schedule_lr()
            for imgs, labels in tqdm(iter(self.loader)):
                # Reorder channels with index (2, 1, 0) and scale by 255;
                # presumably the backbone expects swapped channel order and a
                # 0-255 range -- TODO confirm against the loader transform.
                imgs = (imgs[:, (2, 1, 0)].to(conf.device) * 255)
                labels = labels.to(conf.device)
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)

                loss = self.focalLoss(thetas, labels)
                loss.backward()
                running_loss += loss.item() / conf.batch_size
                self.optimizer.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    self.model.eval()
                    for bmk in self._EVAL_BENCHMARKS:
                        acc = eval_emore_bmk(conf, self.model, bmk)
                        self.board_val(bmk, acc)

                    self.model.train()
                if self.step % self.save_every == 0 and self.step != 0:
                    self.save_state(conf, acc)

                self.step += 1

        self.save_state(conf, acc, to_save_folder=True, extra='final')

    def myValidation(self, conf):
        """Put the model in eval mode and score every benchmark (results not kept)."""
        self.model.eval()

        for bmk in self._EVAL_BENCHMARKS:
            eval_emore_bmk(conf, self.model, bmk)

    def schedule_lr(self):
        """Divide the lr of every optimizer param group by 10 and print the optimizer."""
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)
# ---- Example #3 ----
import torch
from mtcnn import MTCNN
import cv2
import numpy as np

import PIL.Image as Image
from model import Backbone, Arcface, MobileFaceNet, Am_softmax, l2_norm
from torchvision import transforms as trans

device = torch.device('cuda:0')
mtcnn = MTCNN()

# Backbone(net_depth=50, drop_ratio=0.6, net_mode='ir_se')
model = Backbone(50, 0.6, 'ir_se').to(device)
model.eval()
# map_location lets a checkpoint saved on another device load onto `device`
model.load_state_dict(
    torch.load('./saved_models/model_ir_se50.pth', map_location=device))

# threshold = 1.54
test_transform = trans.Compose(
    [trans.ToTensor(),
     trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

image_path = '/home/taotao/Downloads/celeba-512/000014.jpg.jpg'
bgr = cv2.imread(image_path)
if bgr is None:
    # cv2.imread signals a missing/unreadable file by returning None
    raise FileNotFoundError('could not read image: {}'.format(image_path))
img = bgr[:, :, ::-1]  # OpenCV loads BGR; reverse the channel axis for PIL

bboxes, faces = mtcnn.align_multi(Image.fromarray(img),
                                  limit=10,
                                  min_face_size=30)
if len(faces) == 0:
    raise RuntimeError('no face detected in {}'.format(image_path))
face_input = test_transform(faces[0]).unsqueeze(0)  # add batch dimension
with torch.no_grad():
    embbed = model(face_input.to(device))
print(embbed.shape)
print(bboxes)
# ---- Example #4 ----
class face_learner(object):
    def __init__(self, conf, inference=False, transfer=0, ext='final'):
        """Build the backbone and, unless ``inference``, the training state.

        Args:
            conf: configuration object (device, model and training settings).
            inference: when true only the backbone and ``conf.threshold``
                are set up.
            transfer: accepted for compatibility; not used here.
            ext: run tag. Its last '_' becomes a directory separator for the
                TensorBoard log sub-directory, e.g. 'run_3' -> '/run/3'.
                A tag without '_' is used as a single level, e.g. '/final'.
        """
        pprint.pprint(conf)
        self.conf = conf
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio,
                                  conf.net_mode).to(conf.device)
            print('{}_{} model generated'.format(conf.net_mode,
                                                 conf.net_depth))

        if not inference:
            self.milestones = conf.milestones
            self.loader, self.class_num = get_train_loader(conf)

            # rfind() returned -1 for tags without '_' and mangled them
            # ('final' -> '/fina/final'); rpartition handles that case.
            prefix, sep, suffix = ext.rpartition('_')
            if sep:
                self.ext = '/' + prefix + '/' + suffix
            else:
                self.ext = '/' + ext
            self.writer = SummaryWriter(str(conf.log_path) + self.ext)
            self.step = 0
            self.head = Arcface(embedding_size=conf.embedding_size,
                                classnum=self.class_num).to(conf.device)

            print('two model heads generated')

            # Adam over all backbone + head parameters (no BN/non-BN split).
            self.optimizer = optim.Adam(
                list(self.model.parameters()) + list(self.head.parameters()),
                conf.lr)
            print(self.optimizer)

            print('optimizers generated')
            n_batches = len(self.loader)
            # Clamp to >= 1 so train() can use these as modulo divisors.
            # (Earlier settings used n_batches // 100, // 10 and // 5.)
            self.save_freq = max(1, n_batches // 5)
            self.evaluate_every = max(1, n_batches)
            self.save_every = max(1, n_batches)
        else:
            self.threshold = conf.threshold

        # Loss / accuracy history, pickled by save_loss().
        self.train_losses = []
        self.train_counter = []
        self.test_losses = []
        self.test_accuracy = []
        self.test_counter = []

    def save_state(self,
                   conf,
                   accuracy,
                   to_save_folder=False,
                   extra=None,
                   model_only=False):
        """Persist state dicts with accuracy (2 decimals), step and ``extra`` in the name.

        The model is always written; head and optimizer are written too
        unless ``model_only``. The destination is ``conf.save_path`` when
        ``to_save_folder`` is true, otherwise ``conf.model_path``.
        """
        target_dir = conf.save_path if to_save_folder else conf.model_path
        # (attribute name, file-name template) in the order they are saved;
        # attributes are looked up lazily, one save at a time.
        jobs = [('model', 'model_{}_accuracy:{:0.2f}_step:{}_{}.pth')]
        if not model_only:
            jobs.append(('head', 'head_{}_accuracy:{:0.2f}_step:{}_{}.pth'))
            jobs.append(
                ('optimizer', 'optimizer_{}_accuracy:{:0.2f}_step:{}_{}.pth'))
        for attr_name, template in jobs:
            torch.save(
                getattr(self, attr_name).state_dict(),
                target_dir / template.format(get_time(), accuracy, self.step,
                                             extra))

    def load_state(self,
                   conf,
                   fixed_str,
                   from_save_folder=False,
                   model_only=False):
        if from_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        self.model.load_state_dict(
            torch.load(save_path / 'model_{}'.format(fixed_str),
                       map_location=conf.device))
        if not model_only:
            self.head.load_state_dict(
                torch.load(save_path / 'head_{}'.format(fixed_str)))
            self.optimizer.load_state_dict(
                torch.load(save_path / 'optimizer_{}'.format(fixed_str)))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name),
                               best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor,
                              self.step)
        # self.writer.add_scalar('{}_val:true accept ratio'.format(db_name), val, self.step)
        # self.writer.add_scalar('{}_val_std'.format(db_name), val_std, self.step)
        # self.writer.add_scalar('{}_far:False Acceptance Ratio'.format(db_name), far, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        """Embed ``carray`` in batches and run the k-fold verification metric.

        With ``tta`` each embedding is the L2-normalised sum of the
        embeddings of the image and its horizontal flip.

        Returns:
            (mean accuracy, mean best threshold, ROC curve image tensor)
        """
        self.model.eval()
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            # One loop covers full batches and the shorter tail batch:
            # slicing past the end is harmless for carray and the buffer.
            for idx in range(0, len(carray), conf.batch_size):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(
                        fliped.to(conf.device))
                    emb_batch = l2_norm(emb_batch)
                else:
                    emb_batch = self.model(batch.to(conf.device))
                # Copy to host first: a CUDA tensor cannot be written into a
                # numpy array (the TTA path previously skipped .cpu()).
                embeddings[idx:idx + conf.batch_size] = emb_batch.cpu()
        # `evaluate` here is the module-level metric function, not this method.
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame,
                                                       nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def find_lr(self,
                conf,
                init_value=1e-8,
                final_value=10.,
                beta=0.98,
                bloding_scale=3.,
                num=None):
        """Learning-rate range test: grow lr geometrically while training.

        lr starts at ``init_value`` and is multiplied each batch so it would
        reach ``final_value`` after ``num`` batches (default: one epoch). The
        loss is smoothed with a bias-corrected exponential moving average
        (``beta``). The sweep stops early once the smoothed loss exceeds
        ``bloding_scale`` times the best seen.

        Returns:
            (log10 learning rates, smoothed losses) -- also when the loader
            is exhausted before ``num`` batches have run.
        """
        if not num:
            num = len(self.loader)
        mult = (final_value / init_value)**(1 / num)
        lr = init_value
        for params in self.optimizer.param_groups:
            params['lr'] = lr
        self.model.train()
        avg_loss = 0.
        best_loss = 0.
        batch_num = 0
        losses = []
        log_lrs = []
        for imgs, labels in self.loader:

            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            batch_num += 1

            self.optimizer.zero_grad()

            embeddings = self.model(imgs)
            thetas = self.head(embeddings, labels)
            loss = conf.ce_loss(thetas, labels)

            # Compute the smoothed (bias-corrected EMA) loss
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            self.writer.add_scalar('avg_loss', avg_loss, batch_num)
            smoothed_loss = avg_loss / (1 - beta**batch_num)
            self.writer.add_scalar('smoothed_loss', smoothed_loss, batch_num)
            # Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > bloding_scale * best_loss:
                print('exited with best_loss at {}'.format(best_loss))
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
            # Record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            # Store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            self.writer.add_scalar('log_lr', math.log10(lr), batch_num)

            # Do the optimizer step, then update the lr for the next step
            loss.backward()
            self.optimizer.step()

            lr *= mult
            for params in self.optimizer.param_groups:
                params['lr'] = lr
            if batch_num > num:
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
        # Loader exhausted before batch_num exceeded num (always the case with
        # the default num=len(self.loader)); return the sweep rather than None.
        plt.plot(log_lrs[10:-5], losses[10:-5])
        return log_lrs, losses

    def train(self, conf, epochs):
        self.model.train()
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()
            for imgs, labels in iter(self.loader):  #tqdm(iter(self.loader)):
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()

                if self.step % self.save_freq == 0 and self.step != 0:
                    self.train_losses.append(loss.item())
                    self.train_counter.append(self.step)

                self.step += 1

            self.save_loss()

        # self.save_state(conf, accuracy, to_save_folder=True, extra=self.ext, model_only=True)

    def schedule_lr(self):
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def infer(self, conf, faces, target_embs, tta=False):
        '''
        faces : list of PIL Image
        target_embs : [n, 512] computed embeddings of faces in facebank
        names : recorded names of faces in facebank
        tta : test time augmentation (hfilp, that's all)
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)

        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum

    def binfer(self, conf, faces, target_embs, tta=False):
        '''
        return raw scores for every class 
        faces : list of PIL Image
        target_embs : [n, 512] computed embeddings of faces in facebank
        names : recorded names of faces in facebank
        tta : test time augmentation (hfilp, that's all)
        '''
        self.model.eval()
        self.plot_result()
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)

        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        # print(dist)
        return dist.detach().cpu().numpy()
        # minimum, min_idx = torch.min(dist, dim=1)
        # min_idx[minimum > self.threshold] = -1 # if no match, set idx to -1
        # return min_idx, minimum

    def save_loss(self):
        if not os.path.exists(self.conf.stored_result_dir):
            os.mkdir(self.conf.stored_result_dir)

        result = dict()
        result["train_losses"] = np.asarray(self.train_losses)
        result["train_counter"] = np.asarray(self.train_counter)
        result['test_accuracy'] = np.asarray(self.test_accuracy)
        result['test_losses'] = np.asarray(self.test_losses)
        result["test_counter"] = np.asarray(self.test_counter)

        with open(os.path.join(self.conf.stored_result_dir, "result_log.p"),
                  'wb') as fp:
            pickle.dump(result, fp)

    def plot_result(self):
        result_log_path = os.path.join(self.conf.stored_result_dir,
                                       "result_log.p")
        with open(result_log_path, 'rb') as f:
            result_dict = pickle.load(f)

        train_losses = result_dict['train_losses']
        train_counter = result_dict['train_counter']
        test_losses = result_dict['test_losses']
        test_counter = result_dict['test_counter']
        test_accuracy = result_dict['test_accuracy']

        fig1 = plt.figure(figsize=(12, 8))
        ax1 = fig1.add_subplot(111)
        ax1.plot(train_counter, train_losses, 'b', label='Train_loss')
        ax1.legend('Train_losses')
        plt.savefig(os.path.join(self.conf.stored_result_dir,
                                 "train_loss.png"))
        """
示例#5
0
class face_learner(object):
    """Face-recognition trainer with an identity head and an auxiliary race head.

    Holds a backbone (``MobileFaceNet``, IR-SE ``Backbone`` or ``resnet101``),
    an identity head ``self.head``, a 4-class race head ``self.head_race``,
    an SGD optimizer over the backbone and both head kernels, a TensorBoard
    writer and the AgeDB-30 / CFP-FP / LFW validation sets.
    """

    def __init__(self, conf, inference=False):
        """Build the backbone and, unless ``inference``, the training state.

        Args:
            conf: configuration object. Always read: ``lr``,
                ``use_mobilfacenet``, ``embedding_size``, ``device`` and, for
                non-MobileFaceNet models, ``struct`` (``'ir_se_50'`` or
                ``'ir_se_101'``), ``net_depth``, ``drop_ratio``, ``net_mode``.
                Training also reads ``struct``, ``milestones``, ``log_path``,
                ``momentum`` and ``val_folder``; inference reads the optional
                ``threshold``.
            inference: when True only the backbone is built.

        Raises:
            ValueError: if ``conf.struct`` is not ``'ir_se_50'`` or
                ``'ir_se_101'`` where a struct-specific model or head is needed.
        """
        print(conf)
        self.lr = conf.lr  # initial lr, restored by init_lr()
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        elif conf.struct == 'ir_se_50':
            self.model = Backbone(conf.net_depth, conf.drop_ratio, conf.net_mode).to(conf.device)
            print('{}_{} model generated'.format(conf.net_mode, conf.net_depth))
        elif conf.struct == 'ir_se_101':
            self.model = resnet101().to(conf.device)
            print('resnet101 model generated')
        else:
            raise ValueError('unsupported conf.struct: {!r}'.format(conf.struct))

        if not inference:
            self.milestones = conf.milestones
            self.loader, self.class_num = get_train_loader(conf)

            self.writer = SummaryWriter(conf.log_path)
            self.step = 0

            if conf.struct == 'ir_se_50':
                head_cls = Arcface
            elif conf.struct == 'ir_se_101':
                head_cls = ArcMarginModel
            else:
                raise ValueError('unsupported conf.struct: {!r}'.format(conf.struct))
            self.head = head_cls(embedding_size=conf.embedding_size,
                                 classnum=self.class_num).to(conf.device)
            # The race head always has 4 classes: train() maps identity labels to
            # races 0-3 and builds a 4-row target for the alignment loss.
            self.head_race = head_cls(embedding_size=conf.embedding_size,
                                      classnum=4).to(conf.device)
            print('two model heads generated')

            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)

            if conf.use_mobilfacenet:
                self.optimizer = optim.SGD([
                    {'params': paras_wo_bn[:-1], 'weight_decay': 4e-5},
                    {'params': [paras_wo_bn[-1]] + [self.head.kernel] + [self.head_race.kernel],
                     'weight_decay': 4e-4},
                    {'params': paras_only_bn},
                ], lr=conf.lr, momentum=conf.momentum)
            else:
                self.optimizer = optim.SGD([
                    {'params': paras_wo_bn + [self.head.kernel] + [self.head_race.kernel],
                     'weight_decay': 5e-4},
                    {'params': paras_only_bn},
                ], lr=conf.lr, momentum=conf.momentum)
            print(self.optimizer)

            print('optimizers generated')
            n_batches = len(self.loader)
            print('len of loader:', n_batches)
            # Clamp to 1 so the `step % every` checks in train() never divide by
            # zero on an empty loader.
            self.board_loss_every = max(1, n_batches // 100)
            self.evaluate_every = max(1, n_batches)
            self.save_every = max(1, n_batches)
            (self.agedb_30, self.cfp_fp, self.lfw, self.agedb_30_issame,
             self.cfp_fp_issame, self.lfw_issame) = get_val_data(conf.val_folder)
        else:
            # Used by infer() to reject matches; None disables the rejection.
            self.threshold = getattr(conf, 'threshold', None)

    def save_state(self, conf, accuracy, to_save_folder=False, extra=None, model_only=False):
        """Save model (and unless ``model_only`` both heads and the optimizer).

        Files go to ``conf.save_path`` when ``to_save_folder`` else
        ``conf.model_path`` and share one timestamp/accuracy/step/extra suffix.
        """
        save_path = conf.save_path if to_save_folder else conf.model_path
        # One timestamp per checkpoint so the files of a single save match.
        suffix = '{}_accuracy:{}_step:{}_{}.pth'.format(get_time(), accuracy, self.step, extra)
        torch.save(self.model.state_dict(), save_path / ('model_' + suffix))
        if not model_only:
            torch.save(self.head.state_dict(), save_path / ('head_' + suffix))
            torch.save(self.head_race.state_dict(), save_path / ('head_race_' + suffix))
            torch.save(self.optimizer.state_dict(), save_path / ('optimizer_' + suffix))

    def load_state(self, model, head=None, head_race=None, optimizer=None):
        """Load state dicts from the given file paths; ``None`` entries are skipped.

        The backbone is loaded with ``strict=False`` so missing/unexpected keys
        are tolerated.
        """
        self.model.load_state_dict(torch.load(model), strict=False)
        if head is not None:
            self.head.load_state_dict(torch.load(head))
        if head_race is not None:
            self.head_race.load_state_dict(torch.load(head_race))
        if optimizer is not None:
            self.optimizer.load_state_dict(torch.load(optimizer))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor, tpr_val):
        """Log one validation result for ``db_name`` to TensorBoard at ``self.step``."""
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy, self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name), best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor, self.step)
        # TPR at FPR in (0.0008, 0.0012); the original tag text was garbled.
        self.writer.add_scalar('{}_tpr@fpr=1e-3'.format(db_name), tpr_val, self.step)

    @staticmethod
    def _tpr_at_fpr(tpr, fpr, low=0.0008, high=0.0012):
        """Return the first TPR whose FPR lies strictly in (low, high), else 0."""
        candidates = tpr[np.less(fpr, high) & np.greater(fpr, low)]
        return candidates[0] if len(candidates) else 0

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        """Embed ``carray`` in batches and run pair verification.

        With ``tta`` each embedding is ``l2_norm(model(x) + model(hflip(x)))``.

        Returns:
            (mean accuracy, mean best threshold, ROC-curve image tensor, TPR at
            FPR ~1e-3 or 0 when no ROC point falls in that window).
        """
        self.model.eval()
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            for start in range(0, len(carray), conf.batch_size):
                stop = start + conf.batch_size
                batch = torch.tensor(carray[start:stop])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(fliped.to(conf.device))
                    # .cpu() is required before writing into the NumPy buffer.
                    embeddings[start:stop] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[start:stop] = self.model(batch.to(conf.device)).cpu()
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame, nrof_folds)
        tpr_val = self._tpr_at_fpr(tpr, fpr)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor, tpr_val

    def find_lr(self, conf, init_value=1e-8, final_value=10., beta=0.98,
                bloding_scale=3., num=None):
        """LR range test on the identity head.

        Multiplies the lr geometrically from ``init_value`` towards
        ``final_value`` over ``num`` batches (default: one pass of the loader),
        recording a bias-corrected EMA of the loss. Stops early when the
        smoothed loss exceeds ``bloding_scale`` times the best one.

        Returns:
            (log10 lrs, smoothed losses); also plots them without the first 10
            and last 5 points.
        """
        if not num:
            num = len(self.loader)
        mult = (final_value / init_value) ** (1 / num)
        lr = init_value
        for params in self.optimizer.param_groups:
            params['lr'] = lr
        self.model.train()
        avg_loss = 0.
        best_loss = 0.
        batch_num = 0
        losses = []
        log_lrs = []
        for imgs, labels in tqdm(self.loader, total=num):
            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            batch_num += 1

            self.optimizer.zero_grad()

            embeddings = self.model(imgs)
            # The heads return (logits, weights); see train().
            thetas, _ = self.head(embeddings, labels)
            loss = conf.ce_loss(thetas, labels)

            # Exponential moving average of the loss, bias-corrected.
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            self.writer.add_scalar('avg_loss', avg_loss, batch_num)
            smoothed_loss = avg_loss / (1 - beta ** batch_num)
            self.writer.add_scalar('smoothed_loss', smoothed_loss, batch_num)
            # Stop if the loss is exploding.
            if batch_num > 1 and smoothed_loss > bloding_scale * best_loss:
                print('exited with best_loss at {}'.format(best_loss))
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            self.writer.add_scalar('log_lr', math.log10(lr), batch_num)

            loss.backward()
            self.optimizer.step()

            # Increase the lr for the next batch.
            lr *= mult
            for params in self.optimizer.param_groups:
                params['lr'] = lr
            if batch_num > num:
                break
        # Reached the batch budget or exhausted the loader.
        plt.plot(log_lrs[10:-5], losses[10:-5])
        return log_lrs, losses

    @staticmethod
    def _race_labels(labels, race_num):
        """Map identity labels to race indices 0-3.

        Identity ids are assumed to be grouped by race: ids below
        ``sum(race_num[:1])`` are race 0, below ``sum(race_num[:2])`` race 1,
        below ``sum(race_num[:3])`` race 2, everything else race 3.
        """
        race = torch.zeros_like(labels)
        for k in (1, 2, 3):
            race += (labels >= sum(race_num[:k])).to(labels.dtype)
        return race

    @staticmethod
    def _race_target_weight(like, race_num):
        """Build the BCE target and weight for the head-alignment loss.

        ``like`` is the race-head x identity-head similarity matrix (presumably
        4 x num_identities). Row ``r`` of the target is 1 on the identity
        columns of race ``r``; row ``r`` of the weight is
        ``sum(race_num) / race_num[r]``.
        """
        bounds = [0] + [sum(race_num[:k]) for k in (1, 2, 3)] + [None]
        target = torch.zeros_like(like)
        weight = torch.zeros_like(like)
        total = sum(race_num)
        for r in range(4):
            target[r][bounds[r]:bounds[r + 1]] = 1
            weight[r, :] = total / race_num[r]
        return target, weight

    def train(self, conf, epochs):
        """Staged training with identity, race and head-alignment losses.

        Per batch:
            loss0: identity CE (weighted x2).
            loss1: race CE.
            loss2: weighted BCE between sigmoid(race-head x identity-head
                weights) and the race membership target.

        ``conf.loss0/1/2`` select the losses and ``conf.model/head/head_race``
        which modules get gradients. Schedule:
            epoch 8: train only the race head.
            epoch 16: re-enable everything.
            lr is divided via ``schedule_lr`` at epochs 16, 28, 32 and 35.

        Logs, evaluates (AgeDB-30, LFW, CFP-FP) and checkpoints periodically.
        """
        self.model = self.model.to(conf.device)
        self.head = self.head.to(conf.device)
        self.head_race = self.head_race.to(conf.device)
        self.model.train()
        self.head.train()
        self.head_race.train()
        running_loss = 0.
        # Last evaluated accuracy (CFP-FP); None until the first evaluation so
        # checkpoints saved before it do not hit an unbound name.
        accuracy = None
        # loss2 target/weight are batch-invariant; built from the first batch.
        race_target = race_weight = None
        for e in range(epochs):
            print('epoch {} started'.format(e))

            if e == 8:
                # Stage 2: race head only, on the race and alignment losses.
                conf.loss0 = False
                conf.loss1 = True
                conf.loss2 = True
                conf.model = False
                conf.head = False
                conf.head_race = True
                print(conf)
            if e == 16:
                # Stage 3: train everything jointly.
                self.schedule_lr()
                conf.loss0 = True
                conf.loss1 = True
                conf.loss2 = True
                conf.model = True
                conf.head = True
                conf.head_race = True
                print(conf)
            if e in (28, 32, 35):
                self.schedule_lr()

            requires_grad(self.head, conf.head)
            requires_grad(self.head_race, conf.head_race)
            requires_grad(self.model, conf.model)
            for imgs, labels in tqdm(iter(self.loader)):
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                labels_race = self._race_labels(labels, conf.race_num)

                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas, w = self.head(embeddings, labels)
                thetas_race, w_race = self.head_race(embeddings, labels_race)
                loss0 = conf.ce_loss(thetas, labels)
                loss1 = conf.ce_loss(thetas_race, labels_race)
                # Similarity of every race-head weight with every identity weight.
                sim = torch.mm(w_race.t(), w).to(conf.device)
                if race_target is None:
                    race_target, race_weight = self._race_target_weight(sim, conf.race_num)
                loss2 = F.binary_cross_entropy(torch.sigmoid(sim), race_target, race_weight)

                loss = 0
                if conf.loss0:
                    loss += 2 * loss0
                if conf.loss1:
                    loss += loss1
                if conf.loss2:
                    loss += loss2
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    for db_name, carray, issame in (
                            ('agedb_30', self.agedb_30, self.agedb_30_issame),
                            ('lfw', self.lfw, self.lfw_issame),
                            ('cfp_fp', self.cfp_fp, self.cfp_fp_issame)):
                        accuracy, best_threshold, roc_curve_tensor, tpr_val = self.evaluate(
                            conf, carray, issame)
                        self.board_val(db_name, accuracy, best_threshold, roc_curve_tensor, tpr_val)
                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    self.save_state(conf, accuracy)
                self.step += 1

        self.save_state(conf, accuracy, to_save_folder=True, extra='final')

    def schedule_lr(self):
        """Divide the lr of every parameter group by 10."""
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def init_lr(self):
        """Reset every parameter group to the initial lr ``self.lr``."""
        for params in self.optimizer.param_groups:
            params['lr'] = self.lr
        print(self.optimizer)

    def schedule_lr_add(self):
        """Multiply the lr of every parameter group by 10."""
        for params in self.optimizer.param_groups:
            params['lr'] *= 10
        print(self.optimizer)

    def infer(self, conf, faces, target_embs, tta=False):
        '''Match each face against the facebank by squared L2 distance.

        faces : list of PIL Image
        target_embs : [n, 512] computed embeddings of faces in facebank
        tta : test time augmentation (hflip, that's all)

        Returns (min_idx, minimum): index of and distance to the nearest
        facebank entry per face. When a threshold is available (self.threshold,
        else conf.threshold) indices whose distance exceeds it are set to -1.
        '''
        embs = []
        for img in faces:
            emb = self.model(conf.test_transform(img).to(conf.device).unsqueeze(0))
            if tta:
                mirror = trans.functional.hflip(img)
                emb_mirror = self.model(conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                emb = l2_norm(emb + emb_mirror)
            embs.append(emb)
        source_embs = torch.cat(embs)

        diff = source_embs.unsqueeze(-1) - target_embs.transpose(1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        # self.threshold is only set in inference mode; fall back to conf.
        threshold = getattr(self, 'threshold', None)
        if threshold is None:
            threshold = getattr(conf, 'threshold', None)
        if threshold is not None:
            min_idx[minimum > threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum
# Script setup: CPU IR-SE-50 embedding model plus MTCNN, applied to every
# jpg/png found under img_root_dir by the loop that follows.
# NOTE(review): presumably set to avoid the Intel OpenMP "duplicate runtime"
# abort when several libraries bundle libiomp — confirm.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# import libnvjpeg
# import pickle

# img_root_dir = '/media/taotao/958c7d2d-c4ce-4117-a93b-c8a7aa4b88e3/taotao/part1/'
# save_path = '/media/taotao/958c7d2d-c4ce-4117-a93b-c8a7aa4b88e3/taotao/stars_256_0.85/'
# Root directory walked (recursively, via os.walk) for input images.
img_root_dir = './images/'
# NOTE(review): presumably the output directory for aligned faces; it is not
# used in the visible code — confirm.
save_path = './aligned/'
# embed_path = '/home/taotao/Downloads/celeb-aligned-256/embed.pkl'

# Everything runs on the CPU.
device = torch.device('cpu')
# Face detector; presumably used by the (truncated) loop below for alignment.
mtcnn = MTCNN()

# IR-SE backbone with depth 50 and drop ratio 0.6, in eval mode; weights are
# loaded from ./model_ir_se50.pth and mapped onto `device`.
model = Backbone(50, 0.6, 'ir_se').to(device)
model.eval()
model.load_state_dict(torch.load('./model_ir_se50.pth', map_location=device))

# threshold = 1.54
# Preprocessing: to tensor, then per-channel (x - 0.5) / 0.5, which maps
# inputs in [0, 1] to [-1, 1].
test_transform = trans.Compose(
    [trans.ToTensor(),
     trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

# decoder = libnvjpeg.py_NVJpegDecoder()

# Counter and embedding map, presumably filled by the loop below — confirm.
ind = 0
embed_map = {}

for root, dirs, files in os.walk(img_root_dir):
    for name in files:
        if name.endswith('jpg') or name.endswith('png'):
            try:
示例#7
0
class face_learner(object):
    def __init__(self, conf, inference=False):
        """Build the backbone and, unless ``inference``, all training state.

        Args:
            conf: configuration object.
                Always read: ``use_mobilfacenet``, ``embedding_size``,
                ``device`` and, for the IR-SE Backbone, ``net_depth``,
                ``drop_ratio``, ``net_mode``.
                Training reads: ``milestones``, ``discriminator``,
                ``log_path``, ``use_dp``, ``lr``, ``momentum``,
                ``finetune_model_path`` and the optional
                ``val_dataset_root``.
                Inference reads: ``threshold``.
            inference: when True, skip training-only setup.
        """
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio,
                                  conf.net_mode).to(conf.device)
            # GrowUP / Discriminator are only built for the IR-SE Backbone.
            self.growup = GrowUP().to(conf.device)
            self.discriminator = Discriminator().to(conf.device)
            print('{}_{} model generated'.format(conf.net_mode,
                                                 conf.net_depth))

        if not inference:

            self.milestones = conf.milestones
            self.loader, self.class_num = get_train_loader(conf)
            if conf.discriminator:
                self.child_loader, self.adult_loader = get_train_loader_d(conf)

            os.makedirs(conf.log_path, exist_ok=True)
            self.writer = SummaryWriter(conf.log_path)
            self.step = 0

            self.head = Arcface(embedding_size=conf.embedding_size,
                                classnum=self.class_num).to(conf.device)

            # Will not use anymore
            if conf.use_dp:
                self.model = nn.DataParallel(self.model)
                self.head = nn.DataParallel(self.head)
            # nn.DataParallel exposes the wrapped module as `.module`; its
            # parameters are not reachable as attributes of the wrapper.
            head_kernel = getattr(self.head, 'module', self.head).kernel

            print(self.class_num)
            print(conf)

            print('two model heads generated')

            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)

            if conf.use_mobilfacenet:
                self.optimizer = optim.SGD(
                    [{
                        'params': paras_wo_bn[:-1],
                        'weight_decay': 4e-5
                    }, {
                        'params': [paras_wo_bn[-1]] + [head_kernel],
                        'weight_decay': 4e-4
                    }, {
                        'params': paras_only_bn
                    }],
                    lr=conf.lr,
                    momentum=conf.momentum)
            else:
                self.optimizer = optim.SGD(
                    [{
                        'params': paras_wo_bn + [head_kernel],
                        'weight_decay': 5e-4
                    }, {
                        'params': paras_only_bn
                    }],
                    lr=conf.lr,
                    momentum=conf.momentum)
            if conf.discriminator:
                self.optimizer_g = optim.Adam(self.growup.parameters(),
                                              lr=1e-4,
                                              betas=(0.5, 0.999))
                self.optimizer_g2 = optim.Adam(self.growup.parameters(),
                                               lr=1e-4,
                                               betas=(0.5, 0.999))
                self.optimizer_d = optim.Adam(self.discriminator.parameters(),
                                              lr=1e-4,
                                              betas=(0.5, 0.999))
                self.optimizer2 = optim.SGD(
                    [{
                        'params': paras_wo_bn + [head_kernel],
                        'weight_decay': 5e-4
                    }, {
                        'params': paras_only_bn
                    }],
                    lr=conf.lr,
                    momentum=conf.momentum)

            if conf.finetune_model_path is not None:
                # Fine-tuning: optimize only the backbone (no head kernel).
                self.optimizer = optim.SGD([{
                    'params': paras_wo_bn,
                    'weight_decay': 5e-4
                }, {
                    'params': paras_only_bn
                }],
                                           lr=conf.lr,
                                           momentum=conf.momentum)
            print('optimizers generated')

            n_batches = len(self.loader)
            # Clamp to 1: train() uses `self.step % every`, which raises
            # ZeroDivisionError for loaders shorter than 100 (or 2) batches.
            self.board_loss_every = max(1, n_batches // 100)
            self.evaluate_every = max(1, n_batches // 2)
            self.save_every = max(1, n_batches)

            dataset_root = getattr(
                conf, 'val_dataset_root',
                "/home/nas1_userD/yonggyu/Face_dataset/face_emore")
            self.lfw = np.load(
                os.path.join(dataset_root,
                             "lfw_align_112_list.npy")).astype(np.float32)
            self.lfw_issame = np.load(
                os.path.join(dataset_root, "lfw_align_112_label.npy"))
            self.fgnetc = np.load(
                os.path.join(dataset_root,
                             "FGNET_new_align_list.npy")).astype(np.float32)
            self.fgnetc_issame = np.load(
                os.path.join(dataset_root, "FGNET_new_align_label.npy"))
        else:
            # Will not use anymore
            # self.model = nn.DataParallel(self.model)
            self.threshold = conf.threshold

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor,
                  negative_wrong, positive_wrong):
        """Write one validation result for ``db_name`` to TensorBoard.

        Four scalars (accuracy, best threshold, and the two wrong-pair counts)
        followed by the ROC-curve image, all tagged ``<db_name>_<metric>`` and
        logged at ``self.step``.
        """
        scalar_metrics = (
            ('accuracy', accuracy),
            ('best_threshold', best_threshold),
            ('negative_wrong', negative_wrong),
            ('positive_wrong', positive_wrong),
        )
        for metric, value in scalar_metrics:
            self.writer.add_scalar('{}_{}'.format(db_name, metric), value,
                                   self.step)
        self.writer.add_image('{}_{}'.format(db_name, 'roc_curve'),
                              roc_curve_tensor, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=10, tta=True):
        """Embed ``carray`` with the backbone and run pair verification.

        With ``tta`` each embedding is ``l2_norm(model(x) + model(hflip(x)))``.

        Returns:
            (mean accuracy, mean best threshold, ROC-curve image tensor,
            per-pair distances from ``evaluate_dist``).
        """
        self.model.eval()
        # growup/discriminator exist only for the IR-SE Backbone (see
        # __init__); with MobileFaceNet they are absent and must be skipped.
        for aux in (getattr(self, 'growup', None),
                    getattr(self, 'discriminator', None)):
            if aux is not None:
                aux.eval()
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            # One loop covers full batches and the final partial batch.
            for start in range(0, len(carray), conf.batch_size):
                stop = start + conf.batch_size
                batch = torch.tensor(carray[start:stop])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)).cpu() + \
                        self.model(fliped.to(conf.device)).cpu()
                    embeddings[start:stop] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[start:stop] = self.model(
                        batch.to(conf.device)).cpu()
        tpr, fpr, accuracy, best_thresholds, dist = evaluate_dist(
            embeddings, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = transforms.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor, dist

    def evaluate_child(self, conf, carray, issame, nrof_folds=10, tta=True):
        """Verification on interleaved pairs with GrowUP applied to one side.

        Even rows of ``carray`` are embedded with ``growup(model(x))`` and odd
        rows with ``model(x)``; the two embedding sets are then compared by
        the module-level ``evaluate_child``. With ``tta`` each embedding is
        the l2-normalised sum of the original and horizontally flipped input.

        Returns:
            (mean accuracy, mean best threshold, ROC-curve image tensor).
        """
        self.model.eval()
        self.growup.eval()
        self.discriminator.eval()

        carray1 = carray[::2, ]
        carray2 = carray[1::2, ]

        def grown(x):
            return self.growup(self.model(x))

        def embed(arr, net):
            # Each call starts from index 0; the original reused the index
            # left over from the first half, leaving the second mostly zeros.
            out = np.zeros([len(arr), conf.embedding_size])
            for start in range(0, len(arr), conf.batch_size):
                stop = start + conf.batch_size
                batch = torch.tensor(arr[start:stop])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = net(batch.to(conf.device)).cpu() + \
                        net(fliped.to(conf.device)).cpu()
                    out[start:stop] = l2_norm(emb_batch).cpu()
                else:
                    out[start:stop] = net(batch.to(conf.device)).cpu()
            return out

        with torch.no_grad():
            embeddings1 = embed(carray1, grown)
            embeddings2 = embed(carray2, self.model)

        tpr, fpr, accuracy, best_thresholds = evaluate_child(
            embeddings1, embeddings2, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = transforms.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def zero_grad(self):
        """Reset the gradients tracked by the main, generator and discriminator optimizers (in that order)."""
        for optimizer_name in ('optimizer', 'optimizer_g', 'optimizer_d'):
            getattr(self, optimizer_name).zero_grad()

    def train(self, conf, epochs):
        """Plain ArcFace training of backbone + identity head (no GrowUP).

        Calls ``self.schedule_lr()`` at every epoch in ``self.milestones``.
        Periodic actions:
            every ``self.board_loss_every`` steps: log the mean training loss.
            every ``self.evaluate_every`` steps: evaluate on LFW and FG-NET.
            every ``self.save_every`` steps: checkpoint, naming the checkpoint
                with the latest FG-NET accuracy.
        """
        self.model.train()
        running_loss = 0.
        # Latest FG-NET accuracy for checkpoint names; None until the first
        # evaluation so a save can never reference an unbound local.
        accuracy2 = None

        def log_eval(db_name, carray, issame):
            """Evaluate one set, log it with wrong-pair counts, return its accuracy."""
            accuracy, best_threshold, roc_curve_tensor, dist = self.evaluate(
                conf, carray, issame)
            # NEGATIVE WRONG: different-identity pairs below the threshold.
            negative_wrong = len(
                np.where((issame == False) & (dist < best_threshold))[0])  # noqa: E712
            # POSITIVE WRONG: same-identity pairs above the threshold.
            positive_wrong = len(
                np.where((issame == True) & (dist > best_threshold))[0])  # noqa: E712
            self.board_val(db_name, accuracy, best_threshold,
                           roc_curve_tensor, negative_wrong, positive_wrong)
            return accuracy

        for e in range(epochs):
            print('epoch {} started'.format(e))

            if e in self.milestones:
                self.schedule_lr()

            for imgs, labels, _ages in tqdm(iter(self.loader)):

                self.optimizer.zero_grad()

                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)

                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)

                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()

                self.optimizer.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    log_eval('lfw', self.lfw, self.lfw_issame)
                    accuracy2 = log_eval('fgent_c', self.fgnetc,
                                         self.fgnetc_issame)
                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # The former finetune / non-finetune branches were identical.
                    self.save_state(conf, accuracy2, extra=str(conf.data_mode) + '_' + str(conf.net_depth)
                                    + '_' + str(conf.batch_size) + conf.model_name)

                self.step += 1
        print('Horray!')

    def train_with_growup(self, conf, epochs):
        '''
        Our method

        Each batch runs three optimisation stages in sequence:

        1. Head: ``conf.ce_loss`` on the base loader's batch, after the
           embeddings of its children (``ages == 0``) have been replaced by
           ``self.growup`` outputs. Steps ``self.optimizer`` and
           ``self.optimizer_g2``.
        2. Discriminator: ``conf.ls_loss`` of ``self.discriminator`` on adult
           embeddings and grown-up child embeddings against the adult/child
           loaders' labels. Steps ``self.optimizer_d``.
        3. Generator: ``conf.ls_loss`` driving the discriminator output for
           grown-up child embeddings towards 1, plus ``10 * conf.l1_loss``
           between grown-up and original child embeddings. Steps
           ``self.optimizer_g``.

        Losses are logged every ``self.board_loss_every`` steps, LFW and
        ``self.fgnetc`` are evaluated every ``self.evaluate_every`` steps,
        and a checkpoint is written every ``self.save_every`` steps and once
        more after the last epoch.

        Args:
            conf: training config (device, losses, checkpoint naming fields).
            epochs: number of passes over ``self.loader``.
        '''
        self.model.train()
        running_loss = 0.
        l1_loss = 0
        # Checkpoints are tagged with the most recent fgnetc accuracy. Start
        # at 0 so a save that happens before the first evaluation (or a run
        # too short to ever evaluate) does not hit an unbound local.
        accuracy2 = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))

            if e in self.milestones:
                self.schedule_lr()

            a_loader = iter(self.adult_loader)
            c_loader = iter(self.child_loader)
            for imgs, labels, ages in tqdm(iter(self.loader)):
                # loader : base loader that returns images with id
                # a_loader, c_loader : adult, child loader with same datasize
                # ages : 0 == child, 1 == adult
                try:
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)
                except StopIteration:
                    # restart the auxiliary loaders once they are exhausted
                    a_loader = iter(self.adult_loader)
                    c_loader = iter(self.child_loader)
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)

                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                imgs_a, labels_a = imgs_a.to(conf.device), labels_a.to(
                    conf.device).type(torch.float32)
                imgs_c, labels_c = imgs_c.to(conf.device), labels_c.to(
                    conf.device).type(torch.float32)
                bs_a = imgs_a.shape[0]

                imgs_ac = torch.cat([imgs_a, imgs_c], dim=0)

                ###########################
                #       Train head        #
                ###########################
                self.optimizer.zero_grad()
                self.optimizer_g2.zero_grad()
                self.growup.train()

                c = (ages == 0)  # select children for enhancement
                num_children = int(c.sum())

                embeddings = self.model(imgs)

                # there might be no children in loader's batch
                if num_children > 0:
                    if num_children == 1:
                        # presumably growup contains BatchNorm, which cannot
                        # use batch statistics for a single sample
                        self.growup.eval()
                    embeddings_c = embeddings[c]
                    embeddings_a_hat = self.growup(embeddings_c)
                    embeddings[c] = embeddings_a_hat

                thetas = self.head(embeddings, labels)

                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()
                self.optimizer_g2.step()

                ##############################
                #    Train discriminator     #
                ##############################
                self.optimizer_d.zero_grad()
                self.growup.train()
                _embeddings = self.model(imgs_ac)
                embeddings_a, embeddings_c = _embeddings[:bs_a], _embeddings[
                    bs_a:]

                embeddings_a_hat = self.growup(embeddings_c)
                labels_ac = torch.cat([labels_a, labels_c], dim=0)
                # adult and child halves go through the discriminator
                # separately since batchnorm exists
                pred_a = torch.squeeze(self.discriminator(embeddings_a))
                pred_c = torch.squeeze(self.discriminator(embeddings_a_hat))
                pred_ac = torch.cat([pred_a, pred_c], dim=0)
                d_loss = conf.ls_loss(pred_ac, labels_ac)
                d_loss.backward()
                self.optimizer_d.step()

                #############################
                #      Train generator      #
                #############################
                self.optimizer_g.zero_grad()
                embeddings_c = self.model(imgs_c)
                embeddings_a_hat = self.growup(embeddings_c)
                pred_c = torch.squeeze(self.discriminator(embeddings_a_hat))
                labels_a = torch.ones_like(labels_c, dtype=torch.float)
                # generator should make child 1
                g_loss = conf.ls_loss(pred_c, labels_a)

                # keep grown-up embeddings close to the original child ones
                l1_loss = conf.l1_loss(embeddings_a_hat, embeddings_c)
                g_total_loss = g_loss + 10 * l1_loss
                g_total_loss.backward()
                self.optimizer_g.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:  # XXX
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    self.writer.add_scalar('d_loss', d_loss, self.step)
                    self.writer.add_scalar('g_loss', g_loss, self.step)
                    self.writer.add_scalar('l1_loss', l1_loss, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor)
                    accuracy2, best_threshold2, roc_curve_tensor2 = self.evaluate_child(
                        conf, self.fgnetc, self.fgnetc_issame)
                    self.board_val('fgent_c', accuracy2, best_threshold2,
                                   roc_curve_tensor2)

                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # save with most recently calculated accuracy
                    self.save_state(conf, accuracy2, extra=str(conf.data_mode) + '_' + str(conf.net_depth) \
                        + '_' + str(conf.batch_size) + conf.model_name)

                self.step += 1
        self.save_state(conf, accuracy2, to_save_folder=True, extra=str(conf.data_mode)  + '_' + str(conf.net_depth)\
             + '_'+ str(conf.batch_size) +'_discriminator_final')

    def train_age_invariant(self, conf, epochs):
        '''
        Our method, without growup

        Each batch runs three optimisation stages in sequence:

        1. Head: ``conf.ce_loss`` on the base loader's batch; steps
           ``self.optimizer``.
        2. Discriminator: ``conf.ls_loss`` of ``self.discriminator`` on the
           embeddings of an adult batch and a child batch against those
           loaders' labels; steps ``self.optimizer_d``.
        3. "Generator": ``conf.ls_loss`` driving the discriminator output for
           child embeddings towards 1; steps ``self.optimizer2``.

        At each milestone epoch both ``self.optimizer`` and
        ``self.optimizer2`` have their learning rates divided by 10.
        ``l1_loss`` is never updated in this method, so the logged
        ``l1_loss`` scalar is always 0.

        Args:
            conf: training config (device, losses, checkpoint naming fields).
            epochs: number of passes over ``self.loader``.
        '''
        self.model.train()
        running_loss = 0.
        l1_loss = 0
        # Checkpoints are tagged with the most recent fgnetc accuracy. Start
        # at 0 so a save that happens before the first evaluation (or a run
        # too short to ever evaluate) does not hit an unbound local.
        accuracy2 = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))

            if e in self.milestones:
                self.schedule_lr()
                self.schedule_lr2()

            a_loader = iter(self.adult_loader)
            c_loader = iter(self.child_loader)
            for imgs, labels, ages in tqdm(iter(self.loader)):
                # loader : base loader that returns images with id
                # a_loader, c_loader : adult, child loader with same datasize
                # ages : 0 == child, 1 == adult
                try:
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)
                except StopIteration:
                    # restart the auxiliary loaders once they are exhausted
                    a_loader = iter(self.adult_loader)
                    c_loader = iter(self.child_loader)
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)

                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                imgs_a, labels_a = imgs_a.to(conf.device), labels_a.to(
                    conf.device).type(torch.float32)
                imgs_c, labels_c = imgs_c.to(conf.device), labels_c.to(
                    conf.device).type(torch.float32)
                bs_a = imgs_a.shape[0]

                imgs_ac = torch.cat([imgs_a, imgs_c], dim=0)

                ###########################
                #       Train head        #
                ###########################
                self.optimizer.zero_grad()

                embeddings = self.model(imgs)

                thetas = self.head(embeddings, labels)

                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()

                ##############################
                #    Train discriminator     #
                ##############################
                self.optimizer_d.zero_grad()
                _embeddings = self.model(imgs_ac)
                embeddings_a, embeddings_c = _embeddings[:bs_a], _embeddings[
                    bs_a:]

                labels_ac = torch.cat([labels_a, labels_c], dim=0)
                # adult and child halves go through the discriminator
                # separately since batchnorm exists
                pred_a = torch.squeeze(self.discriminator(embeddings_a))
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                pred_ac = torch.cat([pred_a, pred_c], dim=0)
                d_loss = conf.ls_loss(pred_ac, labels_ac)
                d_loss.backward()
                self.optimizer_d.step()

                #############################
                #      Train generator      #
                #############################
                self.optimizer2.zero_grad()
                embeddings_c = self.model(imgs_c)
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                labels_a = torch.ones_like(labels_c, dtype=torch.float)
                # generator should make child 1
                g_loss = conf.ls_loss(pred_c, labels_a)

                g_loss.backward()
                self.optimizer2.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:  # XXX
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    self.writer.add_scalar('d_loss', d_loss, self.step)
                    self.writer.add_scalar('g_loss', g_loss, self.step)
                    self.writer.add_scalar('l1_loss', l1_loss, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor)
                    accuracy2, best_threshold2, roc_curve_tensor2 = self.evaluate(
                        conf, self.fgnetc, self.fgnetc_issame)
                    self.board_val('fgent_c', accuracy2, best_threshold2,
                                   roc_curve_tensor2)

                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # save with most recently calculated accuracy
                    self.save_state(conf, accuracy2, extra=str(conf.data_mode) + '_' + str(conf.net_depth) \
                        + '_' + str(conf.batch_size) + conf.model_name)

                self.step += 1
        self.save_state(conf, accuracy2, to_save_folder=True, extra=str(conf.data_mode)  + '_' + str(conf.net_depth)\
             + '_'+ str(conf.batch_size) +'_discriminator_final')

    def train_age_invariant2(self, conf, epochs):
        '''
        Our method, without growup, using paired dataset TODO

        Currently performs exactly the same steps as ``train_age_invariant``:

        1. Head: ``conf.ce_loss`` on the base loader's batch; steps
           ``self.optimizer``.
        2. Discriminator: ``conf.ls_loss`` of ``self.discriminator`` on the
           embeddings of an adult batch and a child batch against those
           loaders' labels; steps ``self.optimizer_d``.
        3. "Generator": ``conf.ls_loss`` driving the discriminator output for
           child embeddings towards 1; steps ``self.optimizer2``.

        ``l1_loss`` is never updated in this method, so the logged
        ``l1_loss`` scalar is always 0.

        Args:
            conf: training config (device, losses, checkpoint naming fields).
            epochs: number of passes over ``self.loader``.
        '''
        self.model.train()
        running_loss = 0.
        l1_loss = 0
        # Checkpoints are tagged with the most recent fgnetc accuracy. Start
        # at 0 so a save that happens before the first evaluation (or a run
        # too short to ever evaluate) does not hit an unbound local.
        accuracy2 = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))

            if e in self.milestones:
                self.schedule_lr()
                self.schedule_lr2()

            a_loader = iter(self.adult_loader)
            c_loader = iter(self.child_loader)
            for imgs, labels, ages in tqdm(iter(self.loader)):
                # loader : base loader that returns images with id
                # a_loader, c_loader : adult, child loader with same datasize
                # ages : 0 == child, 1 == adult
                try:
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)
                except StopIteration:
                    # restart the auxiliary loaders once they are exhausted
                    a_loader = iter(self.adult_loader)
                    c_loader = iter(self.child_loader)
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)

                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                imgs_a, labels_a = imgs_a.to(conf.device), labels_a.to(
                    conf.device).type(torch.float32)
                imgs_c, labels_c = imgs_c.to(conf.device), labels_c.to(
                    conf.device).type(torch.float32)
                bs_a = imgs_a.shape[0]

                imgs_ac = torch.cat([imgs_a, imgs_c], dim=0)

                ###########################
                #       Train head        #
                ###########################
                self.optimizer.zero_grad()

                embeddings = self.model(imgs)

                thetas = self.head(embeddings, labels)

                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()

                ##############################
                #    Train discriminator     #
                ##############################
                self.optimizer_d.zero_grad()
                _embeddings = self.model(imgs_ac)
                embeddings_a, embeddings_c = _embeddings[:bs_a], _embeddings[
                    bs_a:]

                labels_ac = torch.cat([labels_a, labels_c], dim=0)
                # adult and child halves go through the discriminator
                # separately since batchnorm exists
                pred_a = torch.squeeze(self.discriminator(embeddings_a))
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                pred_ac = torch.cat([pred_a, pred_c], dim=0)
                d_loss = conf.ls_loss(pred_ac, labels_ac)
                d_loss.backward()
                self.optimizer_d.step()

                #############################
                #      Train generator      #
                #############################
                self.optimizer2.zero_grad()
                embeddings_c = self.model(imgs_c)
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                labels_a = torch.ones_like(labels_c, dtype=torch.float)
                # generator should make child 1
                g_loss = conf.ls_loss(pred_c, labels_a)

                g_loss.backward()
                self.optimizer2.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:  # XXX
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    self.writer.add_scalar('d_loss', d_loss, self.step)
                    self.writer.add_scalar('g_loss', g_loss, self.step)
                    self.writer.add_scalar('l1_loss', l1_loss, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor)
                    accuracy2, best_threshold2, roc_curve_tensor2 = self.evaluate(
                        conf, self.fgnetc, self.fgnetc_issame)
                    self.board_val('fgent_c', accuracy2, best_threshold2,
                                   roc_curve_tensor2)

                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # save with most recently calculated accuracy
                    self.save_state(conf, accuracy2, extra=str(conf.data_mode) + '_' + str(conf.net_depth) \
                        + '_' + str(conf.batch_size) + conf.model_name)

                self.step += 1
        self.save_state(conf, accuracy2, to_save_folder=True, extra=str(conf.data_mode)  + '_' + str(conf.net_depth)\
             + '_'+ str(conf.batch_size) +'_discriminator_final')

    def analyze_angle(self, conf, name):
        '''
        Only works on age labeled vgg dataset, agedb dataset

        For every sample of ``self.loader`` (which must yield
        ``(imgs, labels, ages)``), record the angle in degrees between its
        embedding and its own class weight in the head, grouped by class and
        age bin. Then write the per-bin sample counts and mean angles to
        ``analysis/analyze_angle_<data_mode>_<name>.xlsx``.

        Age bins (half-open intervals): 0: <13, 1: [13, 19), 2: [19, 26),
        3: [26, 36), 4: [36, 46), 5: [46, 56), 6: [56, 66), 7: >=66.

        When ``conf.resume_analysis`` is set, no forward passes run and the
        table is loaded from ``analysis/angle_table.pkl``. Otherwise the
        freshly computed table is pickled there. The model is evaluated in
        eval mode without gradients, and its previous train/eval mode is
        restored afterwards.

        Args:
            conf: reads ``device``, ``use_dp``, ``resume_analysis`` and
                ``data_mode``.
            name: tag inserted into the output file name.
        '''
        num_age_bins = 8
        # Lists rather than sets, so that two samples with the same angle are
        # both counted and both contribute to the mean.
        angle_table = [{age_bin: [] for age_bin in range(num_age_bins)}
                       for _ in range(self.class_num)]
        # Resuming skips the forward passes; use a local so self.loader is
        # left intact for any later training or analysis.
        loader = [] if conf.resume_analysis else self.loader

        # Eval mode + no_grad: dropout off, BatchNorm running statistics are
        # not modified by the analysis, and no autograd graph is kept.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for imgs, labels, ages in tqdm(iter(loader)):
                    imgs = imgs.to(conf.device)
                    labels = labels.to(conf.device)

                    embeddings = self.model(imgs)
                    if conf.use_dp:
                        kernel_norm = l2_norm(self.head.module.kernel, axis=0)
                        cos_theta = torch.mm(embeddings, kernel_norm)
                        cos_theta = cos_theta.clamp(-1, 1)
                    else:
                        cos_theta = self.head.get_angle(embeddings)

                    thetas = torch.abs(torch.rad2deg(torch.acos(cos_theta)))
                    # Angle to each sample's own class column, copied to the
                    # host once per batch instead of one .item() per sample.
                    own_thetas = thetas.gather(
                        1, labels.long().view(-1, 1)).squeeze(1).tolist()

                    for label, age, theta in zip(labels.tolist(),
                                                 ages.tolist(), own_thetas):
                        if age < 26:
                            age_bin = 0 if age < 13 else 1 if age < 19 else 2
                        elif age < 66:
                            age_bin = int((age + 4) // 10)
                        else:
                            age_bin = 7
                        angle_table[label][age_bin].append(theta)
        finally:
            self.model.train(was_training)

        os.makedirs('analysis', exist_ok=True)
        if conf.resume_analysis:
            # NOTE: pickle must only be used on trusted files; this one is
            # written by the non-resume branch below.
            with open('analysis/angle_table.pkl', 'rb') as f:
                angle_table = pickle.load(f)
        else:
            with open('analysis/angle_table.pkl', 'wb') as f:
                pickle.dump(angle_table, f)

        count, avg_angle = [], []
        for i in range(self.class_num):
            bins = list(angle_table[i].values())
            count.append([len(angles) for angles in bins])
            # an empty bin has an average of zero
            avg_angle.append(
                [sum(angles) / len(angles) if angles else 0 for angles in bins])

        count_df = pd.DataFrame(count)
        avg_angle_df = pd.DataFrame(avg_angle)

        with pd.ExcelWriter('analysis/analyze_angle_{}_{}.xlsx'.format(
                conf.data_mode, name)) as writer:
            count_df.to_excel(writer, sheet_name='count')
            avg_angle_df.to_excel(writer, sheet_name='avg_angle')

    def schedule_lr(self):
        '''Step decay: shrink the learning rate of every param group of
        ``self.optimizer`` tenfold, then print the optimizer.'''
        for group in self.optimizer.param_groups:
            group['lr'] = group['lr'] / 10
        print(self.optimizer)

    def schedule_lr2(self):
        '''Step decay for the second optimizer: shrink the learning rate of
        every param group of ``self.optimizer2`` tenfold, then print it.'''
        for group in self.optimizer2.param_groups:
            group['lr'] = group['lr'] / 10
        print(self.optimizer2)

    def infer(self, conf, faces, target_embs, tta=False):
        '''
        Match each face against the facebank by squared L2 distance.

        faces : list of PIL Image
        target_embs : [n, 512] computed embeddings of faces in facebank
        tta : test time augmentation (hflip, that's all)

        Returns ``(min_idx, minimum)``. For every face, ``min_idx`` is the
        index of the closest facebank embedding, or ``-1`` when that
        distance exceeds ``self.threshold``, and ``minimum`` is that squared
        distance. Runs under ``torch.no_grad()``, so no autograd graph is
        built and the returned tensors do not require grad.
        '''
        embs = []
        with torch.no_grad():
            for img in faces:
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                if tta:
                    mirror = transforms.functional.hflip(img)
                    emb_mirror = self.model(
                        conf.test_transform(mirror).to(
                            conf.device).unsqueeze(0))
                    # re-normalise the sum of original + mirrored embeddings
                    emb = l2_norm(emb + emb_mirror)
                embs.append(emb)
            source_embs = torch.cat(embs)

            # [m, d, 1] - [1, d, n] -> [m, d, n]; summing over d gives the
            # [m, n] matrix of squared distances.
            diff = source_embs.unsqueeze(-1) - target_embs.transpose(
                1, 0).unsqueeze(0)
            dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum

    def save_best_state(self,
                        conf,
                        accuracy,
                        to_save_folder=False,
                        extra=None,
                        model_only=False):
        '''
        Save the current best checkpoint under an ``lfw_best_`` prefix.

        Writes ``self.model``'s state dict and, unless ``model_only``, the
        state dicts of ``self.head`` and ``self.optimizer``. All files share
        one suffix ``<time>_accuracy:<acc>_step:<step>_<extra>.pth``.

        Args:
            conf: reads ``save_path`` (if ``to_save_folder``) or
                ``model_path``.
            accuracy: value embedded in the file names (``{:.3f}``).
            to_save_folder: write to ``conf.save_path`` instead of
                ``conf.model_path``.
            extra: free-form tag appended to the file names.
            model_only: skip the head and optimizer files.
        '''
        save_path = str(conf.save_path if to_save_folder else conf.model_path)
        os.makedirs(save_path, exist_ok=True)
        # Compute the timestamp once so every file of this checkpoint gets
        # an identical suffix.
        suffix = '{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
            get_time(), accuracy, self.step, extra)
        # os.path.join inserts the separator that plain concatenation
        # omitted (which placed files next to, not inside, save_path).
        torch.save(self.model.state_dict(),
                   os.path.join(save_path, 'lfw_best_model_' + suffix))
        if not model_only:
            torch.save(self.head.state_dict(),
                       os.path.join(save_path, 'lfw_best_head_' + suffix))
            torch.save(self.optimizer.state_dict(),
                       os.path.join(save_path, 'lfw_best_optimizer_' + suffix))

    def save_state(self,
                   conf,
                   accuracy,
                   to_save_folder=False,
                   extra=None,
                   model_only=False):
        '''
        Save a checkpoint loadable with ``load_state``.

        Writes ``model_<suffix>`` and, unless ``model_only``, ``head_<suffix>``
        and ``optimizer_<suffix>``, plus ``growup_<suffix>`` when
        ``conf.discriminator`` is truthy. ``suffix`` is
        ``<time>_accuracy:<acc>_step:<step>_<extra>.pth`` and is identical for
        all files of one call.

        Args:
            conf: reads ``save_path`` / ``model_path`` and ``discriminator``.
            accuracy: value embedded in the file names (``{:.3f}``).
            to_save_folder: write to ``conf.save_path`` instead of
                ``conf.model_path``.
            extra: free-form tag appended to the file names.
            model_only: save only the backbone.
        '''
        save_path = str(conf.save_path if to_save_folder else conf.model_path)
        os.makedirs(save_path, exist_ok=True)
        # One timestamp for the whole checkpoint; calling get_time() per file
        # could give files different suffixes, which a single fixed_str in
        # load_state cannot match.
        suffix = '{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
            get_time(), accuracy, self.step, extra)
        torch.save(self.model.state_dict(),
                   os.path.join(save_path, 'model_' + suffix))
        if not model_only:
            torch.save(self.head.state_dict(),
                       os.path.join(save_path, 'head_' + suffix))
            torch.save(self.optimizer.state_dict(),
                       os.path.join(save_path, 'optimizer_' + suffix))
            if conf.discriminator:
                torch.save(self.growup.state_dict(),
                           os.path.join(save_path, 'growup_' + suffix))

    def load_state(self,
                   conf,
                   fixed_str,
                   from_save_folder=False,
                   model_only=False,
                   analyze=False):
        '''
        Restore state dicts from files named ``<prefix>_<fixed_str>``.

        Loads ``model_<fixed_str>`` into ``self.model``. Unless
        ``model_only``, it also loads ``head_<fixed_str>`` into ``self.head``
        and, unless ``analyze``, ``optimizer_<fixed_str>`` into
        ``self.optimizer``.

        Args:
            conf: reads ``save_path`` / ``model_path``; either may be a
                ``str`` or a ``pathlib.Path``.
            fixed_str: file-name suffix shared by the checkpoint files.
            from_save_folder: read from ``conf.save_path`` instead of
                ``conf.model_path``.
            model_only: load only the backbone.
            analyze: skip loading the optimizer state.
        '''
        save_path = conf.save_path if from_save_folder else conf.model_path

        def _load(prefix):
            # os.path.join works for both str and Path; the ``/`` operator
            # previously used for head/optimizer fails on a plain str.
            return torch.load(
                os.path.join(save_path, '{}_{}'.format(prefix, fixed_str)))

        self.model.load_state_dict(_load('model'))
        if not model_only:
            self.head.load_state_dict(_load('head'))
            if not analyze:
                self.optimizer.load_state_dict(_load('optimizer'))
示例#8
0
def train(args):
    '''
    Adversarially train generator ``G`` and attribute encoder ``E`` against
    the multiscale discriminator ``D``. The backbone ``F`` (loaded from
    ``args.arc_model_path``) is frozen and only supplies identity features
    for the identity loss.

    Args:
        args: parsed options. Reads ``gpus``, ``arc_model_path``, ``resume``,
            ``resume_model_path``, ``data_path``, ``batch_size``,
            ``total_epoch``, ``log_interval`` and ``save_interval``; sets
            ``start_epoch``.

    Side effects: sets ``CUDA_VISIBLE_DEVICES`` and writes a checkpoint to
    ``../model/`` every ``args.save_interval`` epochs.
    '''
    # gpu init
    multi_gpu = len(args.gpus.split(',')) > 1
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def print_with_time(string):
        print(time.strftime("%Y-%m-%d %H:%M:%S ", time.localtime()) + string)

    def weights_init(m):
        # N(0, 0.02) conv weights; N(1, 0.02) / 0 for BatchNorm weight / bias
        classname = m.__class__.__name__
        if isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        if classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    def set_requires_grad(nets, requires_grad=False):
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def trans_batch(batch):
        # Only referenced by the commented-out F(trans_batch(...)) variants.
        t = trans.Compose(
            [trans.ToPILImage(),
             trans.Resize((112, 112)),
             trans.ToTensor()])
        bs = batch.shape[0]
        res = torch.ones(bs, 3, 112, 112).type_as(batch)
        for i in range(bs):
            res[i] = t(batch[i].cpu())
        return res

    D = MultiscaleDiscriminator(
        input_nc=3,
        ndf=64,
        n_layers=3,
        use_sigmoid=False,
        norm_layer=torch.nn.InstanceNorm2d)  # pix2pix use MSEloss
    G = AAD_Gen()
    F = Backbone(50, drop_ratio=0.6, mode='ir_se')
    F.load_state_dict(torch.load(args.arc_model_path))
    E = Att_Encoder()

    # Random initialisation has to happen before a checkpoint is (possibly)
    # loaded below, otherwise it would overwrite the resumed weights.
    D.apply(weights_init)
    G.apply(weights_init)
    E.apply(weights_init)

    optimizer_D = torch.optim.Adam(D.parameters(),
                                   lr=0.0004,
                                   betas=(0.0, 0.999))
    optimizer_GE = torch.optim.Adam([{
        'params': G.parameters()
    }, {
        'params': E.parameters()
    }],
                                    lr=0.0004,
                                    betas=(0.0, 0.999))

    if multi_gpu:
        D = DataParallel(D).to(device)
        G = DataParallel(G).to(device)
        F = DataParallel(F).to(device)
        E = DataParallel(E).to(device)
    else:
        D = D.to(device)
        G = G.to(device)
        F = F.to(device)
        E = E.to(device)

    if args.resume:
        if os.path.isfile(args.resume_model_path):
            print("Loading checkpoint from {}".format(args.resume_model_path))
            checkpoint = torch.load(args.resume_model_path)
            # The checkpoint is written after epoch checkpoint["epoch"] has
            # completed, so training continues with the following epoch.
            args.start_epoch = checkpoint["epoch"] + 1
            D.load_state_dict(checkpoint["state_dict_D"])
            G.load_state_dict(checkpoint["state_dict_G"])
            E.load_state_dict(checkpoint["state_dict_E"])
            optimizer_D.load_state_dict(checkpoint['optimizer_D'])
            optimizer_GE.load_state_dict(checkpoint['optimizer_GE'])
        else:
            print('Cannot found checkpoint {}'.format(args.resume_model_path))
            # keep a user-supplied start epoch, otherwise start from scratch
            args.start_epoch = getattr(args, 'start_epoch', 1)
    else:
        args.start_epoch = 1

    set_requires_grad(F, requires_grad=False)
    # only used by the commented-out ImageFolder dataset below
    data_transform = trans.Compose([
        trans.ToTensor(),
        trans.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    #dataset = ImageFolder(args.data_path, transform=data_transform)
    dataset = FaceEmbed(args.data_path)
    data_loader = DataLoader(dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True)

    for epoch in range(args.start_epoch, args.total_epoch + 1):
        D.train()
        G.train()
        F.eval()  # F is frozen and only extracts identity features
        E.train()

        batch_idx = 0  # defined even if the loader yields no batch
        for batch_idx, data in enumerate(data_loader):
            try:
                source, target, label = data

                source = source.to(device)
                target = target.to(device)
                label = torch.LongTensor(label).to(device)

                #Zid =F(trans_batch(source))  # bs, 512
                Zid = F(
                    downsample(source[:, :, 50:-10, 30:-30], size=(112, 112)))
                Zatt = E(target)  # list:8  each:bs,,,
                Yst0 = G(Zid, Zatt)  # bs,3,256,256

                # train discriminators
                pred_gen = D(Yst0.detach())
                pred_real = D(target)
                optimizer_D.zero_grad()
                loss_real, loss_fake = loss_hinge_dis()(pred_gen, pred_real)
                L_dis = loss_real + loss_fake
                L_dis.backward()
                optimizer_D.step()

                # train generators
                pred_gen = D(Yst0)
                L_gen = loss_hinge_gen()(pred_gen)
                #L_id = IdLoss()(F(trans_batch(Yst0)), Zid)
                L_id = IdLoss()(F(
                    downsample(Yst0[:, :, 50:-10, 30:-30], size=(112, 112))),
                                Zid)
                L_att = AttrLoss()(E(Yst0), Zatt)
                L_Rec = RecLoss()(Yst0, target, label)

                Loss = (L_gen + 10 * L_att + 5 * L_id + 10 * L_Rec).to(device)
                optimizer_GE.zero_grad()
                Loss.backward()
                optimizer_GE.step()

            except Exception as e:
                # deliberate best effort: report and skip a failing batch
                print(e)
                continue

            if batch_idx % args.log_interval == 0 or batch_idx == 20:
                # `data` is the (source, target, label) tuple, so len(data)
                # is 3; count samples with the batch size (constant, since
                # drop_last=True).
                seen = batch_idx * args.batch_size
                print_with_time(
                    'Train Epoch: {} [{}/{} ({:.0f}%)], L_dis:{:.4f}, loss_real:{:.4f}, loss_fake:{:.4f}, Loss:{:.4f}, L_gen:{:.4f}, L_id:{:.4f}, L_att:{:.4f}, L_Rec:{:.4f}'
                    .format(
                        epoch, seen, len(data_loader.dataset),
                        100. * seen / len(data_loader.dataset), L_dis.item(),
                        loss_real.item(), loss_fake.item(), Loss.item(),
                        L_gen.item(), 5 * L_id.item(), 10 * L_att.item(),
                        10 * L_Rec))

        if epoch % args.save_interval == 0:
            state = {
                "epoch": epoch,
                "state_dict_D": D.state_dict(),
                "state_dict_G": G.state_dict(),
                "state_dict_E": E.state_dict(),
                "optimizer_D": optimizer_D.state_dict(),
                "optimizer_GE": optimizer_GE.state_dict(),
            }
            filename = "../model/train1_{:03d}_{:03d}.pth.tar".format(
                epoch, batch_idx * args.batch_size)
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            torch.save(state, filename)
示例#9
0
class face_learner(object):
    """ArcFace trainer mixing one labelled dataset with pseudo-labelled splits.

    Each training step draws one batch from the labelled loader (index 0) and
    one batch from each of ``num_splits`` pseudo-labelled loaders. All images
    go through a shared backbone. The labelled embeddings are scored by
    ``self.head[0]`` and the pseudo-labelled ones by ``self.head[1]``.
    """

    def __init__(self, conf, inference=False):
        """Build the backbone and, unless ``inference``, data, heads and optimizer.

        Args:
            conf: configuration object. Attributes read include ``meta_file``,
                ``use_mobilfacenet``, ``embedding_size``, ``device`` (treated
                as a GPU count: >1 enables DataParallel), ``batch_size``,
                ``lr``, ``momentum``, ``resume``, ``milestones`` and the
                dataset/log/eval paths.
            inference: when True, only the backbone and ``self.threshold``
                are set up.
        """
        print(conf)

        # Number of pseudo-labelled splits: the single character right before
        # '_labels.txt' in the meta-file name. This must be parsed before any
        # '_clean' suffix is appended to conf.meta_file.
        self.num_splits = int(conf.meta_file.split('_labels.txt')[0][-1])

        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size)
            print('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio,
                                  conf.net_mode)
            print('{}_{} model generated'.format(conf.net_mode,
                                                 conf.net_depth))

        gpu_ids = None
        if conf.device > 1:
            gpu_ids = list(
                range(0, min(torch.cuda.device_count(), conf.device)))
            self.model = nn.DataParallel(self.model, device_ids=gpu_ids).cuda()
        else:
            self.model = self.model.cuda()

        if inference:
            self.threshold = conf.threshold
            return

        self.milestones = conf.milestones
        pseudo_classes, pseudo_classnum = self._load_pseudo_labels(conf)
        self._build_train_loaders(conf, pseudo_classes)

        self.writer = SummaryWriter(conf.log_path)
        self.step = 0
        self._build_heads(conf, pseudo_classnum, gpu_ids)
        self._build_optimizer(conf)

        if conf.resume is not None:
            self.start_epoch = self.load_state(conf.resume)
        else:
            self.start_epoch = 0

        print('optimizers generated')
        steps_per_epoch = len(self.train_loader[0])
        # Clamp to >= 1 so a short loader cannot trigger a modulo-by-zero
        # in train().
        self.board_loss_every = max(1, steps_per_epoch // 10)
        self.evaluate_every = max(1, steps_per_epoch // 5)
        self.save_every = max(1, steps_per_epoch // 5)
        (self.agedb_30, self.cfp_fp, self.lfw, self.agedb_30_issame,
         self.cfp_fp_issame, self.lfw_issame) = get_val_data(conf.eval_path)

    @staticmethod
    def _unwrap(module):
        """Return the module wrapped by a DataParallel, or ``module`` itself."""
        return module.module if isinstance(module, nn.DataParallel) else module

    def _load_pseudo_labels(self, conf):
        """Read the pseudo-label meta file.

        Returns:
            (pseudo_classes, pseudo_classnum): the per-split label lists, and
            the number of distinct labels excluding -1.
        """
        if conf.remove_single is True:
            conf.meta_file = conf.meta_file.replace('.txt', '_clean.txt')
        with open(conf.meta_file, 'r') as meta_file:
            pseudo_all = [int(line.strip()) for line in meta_file]

        # -1 is excluded from the class count (presumably "no pseudo label").
        label_set = set(pseudo_all)
        label_set.discard(-1)
        pseudo_classnum = len(label_set)
        print('classnum:{}'.format(pseudo_classnum))

        # NOTE(review): `count` is not defined in this class. It is presumably
        # a module-level sequence of num_splits + 1 offsets into the meta
        # file — confirm.
        pseudo_classes = [
            pseudo_all[count[index]:count[index + 1]]
            for index in range(self.num_splits)
        ]
        return pseudo_classes, pseudo_classnum

    def _build_train_loaders(self, conf, pseudo_classes):
        """Create the labelled loader plus one loader per pseudo split.

        Each pseudo split gets ``batch_size // (num_splits + 1)`` images per
        step and the labelled loader gets the remainder. Every sampler is
        sized to ``max_batches * its batch size``. Here ``max_batches`` is the
        largest per-dataset count of full batches, so all loaders are meant
        to yield the same number of steps.
        """
        train_dataset = [get_train_dataset(conf.emore_folder)] + [
            get_pseudo_dataset([conf.pseudo_folder, index + 1],
                               pseudo_classes[index], conf.remove_single)
            for index in range(self.num_splits)
        ]
        self.class_num = [num for _, num in train_dataset]
        print('Loading dataset done')

        per_split = int(conf.batch_size // (self.num_splits + 1))
        self.batch_size = [conf.batch_size - per_split * self.num_splits
                           ] + [per_split] * self.num_splits
        max_batches = max(
            int(np.floor(len(dataset) / bs))
            for (dataset, _), bs in zip(train_dataset, self.batch_size))
        samplers = [
            GivenSizeSampler(dataset,
                             total_size=max_batches * bs,
                             rand_seed=None)
            for (dataset, _), bs in zip(train_dataset, self.batch_size)
        ]
        self.train_loader = [
            DataLoader(dataset,
                       batch_size=bs,
                       shuffle=False,
                       pin_memory=conf.pin_memory,
                       num_workers=conf.num_workers,
                       sampler=sampler)
            for (dataset, _), bs, sampler in zip(train_dataset,
                                                 self.batch_size, samplers)
        ]
        print('Loading loader done')

    def _build_heads(self, conf, pseudo_classnum, gpu_ids):
        """Create the labelled head and the pseudo-label head on the GPU."""
        heads = [
            Arcface(embedding_size=conf.embedding_size,
                    classnum=self.class_num[0]),
            Arcface(embedding_size=conf.embedding_size,
                    classnum=pseudo_classnum)
        ]
        if conf.device > 1:
            self.head = [
                nn.DataParallel(head, device_ids=gpu_ids).cuda()
                for head in heads
            ]
        else:
            self.head = [head.cuda() for head in heads]
        print('two model heads generated')

    def _build_optimizer(self, conf):
        """Create SGD with weight decay on non-BN parameters and both heads."""
        # _unwrap makes this work whether or not DataParallel is in use.
        paras_only_bn, paras_wo_bn = separate_bn_paras(
            self._unwrap(self.model))
        head_params = [
            param for head in self.head
            for param in self._unwrap(head).parameters()
        ]

        if conf.use_mobilfacenet:
            param_groups = [{
                'params': paras_wo_bn[:-1],
                'weight_decay': 4e-5
            }, {
                'params': [paras_wo_bn[-1]] + head_params,
                'weight_decay': 4e-4
            }, {
                'params': paras_only_bn
            }]
        else:
            param_groups = [{
                'params': paras_wo_bn + head_params,
                'weight_decay': 5e-4
            }, {
                'params': paras_only_bn
            }]
        self.optimizer = optim.SGD(param_groups,
                                   lr=conf.lr,
                                   momentum=conf.momentum)
        print(self.optimizer)

    def save_state(self,
                   conf,
                   accuracy,
                   e,
                   to_save_folder=False,
                   extra=None,
                   model_only=False):
        """Write a checkpoint file.

        The file goes to ``conf.save_path`` when ``to_save_folder`` is true,
        otherwise to ``conf.model_path``; the directory is created if needed.
        With ``model_only`` only the backbone ``state_dict`` is saved.
        Otherwise the optimizer, both heads, the backbone and epoch ``e`` are
        saved together.
        """
        save_path = str(conf.save_path if to_save_folder else conf.model_path)
        os.makedirs(save_path, exist_ok=True)

        if model_only:
            torch.save(
                self.model.state_dict(),
                os.path.join(save_path,
                             ('model_{}_accuracy:{}_step:{}_{}.pth'.format(
                                 get_time(), accuracy, self.step, extra))))
            return

        save = {
            'optimizer': self.optimizer.state_dict(),
            'head': [self.head[0].state_dict(), self.head[1].state_dict()],
            'model': self.model.state_dict(),
            'epoch': e
        }
        # Four fields, four placeholders: timestamp, accuracy, step, extra.
        torch.save(
            save,
            os.path.join(save_path,
                         ('{}_accuracy:{}_step:{}_{}.pth'.format(
                             get_time(), accuracy, self.step, extra))))

    def load_state(self, save_path, from_save_folder=False, model_only=False):
        """Restore a checkpoint written by ``save_state``.

        ``from_save_folder`` is accepted for backward compatibility and is
        unused.

        Returns:
            The epoch to resume from: the saved epoch + 1 for a full
            checkpoint, or 0 for a ``model_only`` checkpoint (which stores no
            epoch).
        """
        if model_only:
            self.model.load_state_dict(torch.load(save_path))
            return 0
        state = torch.load(save_path)
        self.model.load_state_dict(state['model'])
        self.head[0].load_state_dict(state['head'][0])
        self.head[1].load_state_dict(state['head'][1])
        self.optimizer.load_state_dict(state['optimizer'])
        return state['epoch'] + 1

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        """Log one validation set's accuracy, threshold and ROC image."""
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name),
                               best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor,
                              self.step)

    def _embed_batch(self, batch, tta):
        """Return CPU embeddings for one batch, optionally with flip TTA."""
        if tta:
            fliped = hflip_batch(batch)
            emb_batch = self.model(batch.cuda()) + self.model(fliped.cuda())
            # .cpu() is required: the result is written into a numpy array.
            return l2_norm(emb_batch).cpu()
        return self.model(batch.cuda()).cpu()

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        """Embed ``carray`` in batches and score the verification pairs.

        Leaves the backbone in eval mode; the caller restores train mode.

        Returns:
            (mean accuracy, mean best threshold, ROC curve image tensor).
        """
        self.model.eval()
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            for start in range(0, len(carray), conf.batch_size):
                stop = start + conf.batch_size
                batch = torch.tensor(carray[start:stop])
                embeddings[start:stop] = self._embed_batch(batch, tta)
        # `evaluate` here is the module-level metric function, not this method.
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame,
                                                       nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def find_lr(self,
                conf,
                init_value=1e-8,
                final_value=10.,
                beta=0.98,
                bloding_scale=3.,
                num=None):
        """Run an exponential learning-rate range test on the labelled data.

        The learning rate grows geometrically from ``init_value`` towards
        ``final_value`` over ``num`` batches (default: one pass of the
        labelled loader). The search stops early once the bias-corrected
        moving average of the loss exceeds ``bloding_scale`` times the best
        value seen. Note that this performs real optimizer steps.

        Returns:
            (log_lrs, losses): log10 learning rates and smoothed losses.
        """
        # This trainer holds lists of loaders/heads; the range test uses the
        # labelled pair at index 0.
        loader = self.train_loader[0]
        head = self.head[0]
        if not num:
            num = len(loader)
        mult = (final_value / init_value)**(1 / num)
        lr = init_value
        for params in self.optimizer.param_groups:
            params['lr'] = lr
        self.model.train()
        avg_loss = 0.
        best_loss = 0.
        batch_num = 0
        losses = []
        log_lrs = []
        for imgs, labels in tqdm(loader, total=num):
            imgs = imgs.cuda()
            labels = labels.cuda()
            batch_num += 1

            self.optimizer.zero_grad()
            embeddings = self.model(imgs)
            thetas = head(embeddings, labels)
            loss = conf.ce_loss(thetas, labels)

            # Bias-corrected exponential moving average of the loss.
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            self.writer.add_scalar('avg_loss', avg_loss, batch_num)
            smoothed_loss = avg_loss / (1 - beta**batch_num)
            self.writer.add_scalar('smoothed_loss', smoothed_loss, batch_num)
            # Stop once the loss explodes.
            if batch_num > 1 and smoothed_loss > bloding_scale * best_loss:
                print('exited with best_loss at {}'.format(best_loss))
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            self.writer.add_scalar('log_lr', math.log10(lr), batch_num)

            loss.backward()
            self.optimizer.step()

            # Grow the learning rate for the next batch.
            lr *= mult
            for params in self.optimizer.param_groups:
                params['lr'] = lr
            if batch_num > num:
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses

        # The loader ran out before either stop condition fired (the normal
        # case with num == len(loader)); return the curve instead of None.
        plt.plot(log_lrs[10:-5], losses[10:-5])
        return log_lrs, losses

    def _validate(self, conf):
        """Evaluate agedb_30, lfw and cfp_fp, log them, return cfp_fp accuracy."""
        results = []
        for db_name, carray, issame in (
            ('agedb_30', self.agedb_30, self.agedb_30_issame),
            ('lfw', self.lfw, self.lfw_issame),
            ('cfp_fp', self.cfp_fp, self.cfp_fp_issame),
        ):
            accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                conf, carray, issame)
            self.board_val(db_name, accuracy, best_threshold,
                           roc_curve_tensor)
            results.append(accuracy)
        print('step:{}, agedb:{},lfw:{},cfp_fp:{}'.format(
            self.step, results[0], results[1], results[2]))
        self.model.train()
        return results[-1]

    def train(self, conf, epochs):
        """Train from ``self.start_epoch`` up to (not including) ``epochs``.

        Each step concatenates one batch per loader. The loss is the
        batch-size-weighted mix of the labelled-head and pseudo-head
        cross-entropies.
        """
        self.model.train()
        running_loss = 0.
        accuracy = None  # passed to save_state before the first evaluation
        e = self.start_epoch - 1  # epoch recorded if no epoch runs at all
        for e in range(self.start_epoch, epochs):
            print('epoch {} started'.format(e))
            # Only the first three milestones are honoured, as before.
            for milestone in self.milestones[:3]:
                if e == milestone:
                    self.schedule_lr()
            self.iters = [iter(loader) for loader in self.train_loader]
            for _ in tqdm(range(len(self.train_loader[0]))):
                # Builtin next(): recent PyTorch DataLoader iterators no
                # longer provide a .next() method.
                data = [next(it) for it in self.iters]
                imgs, labels = zip(*data)
                labeled_num = len(imgs[0])

                imgs = torch.cat(imgs, dim=0).cuda()
                labels = torch.cat(labels, dim=0).cuda()
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)

                # The first labeled_num rows come from the labelled loader.
                thetas = self.head[0](embeddings[:labeled_num],
                                      labels[:labeled_num])
                losses1 = conf.ce_loss(thetas, labels[:labeled_num])
                thetas = self.head[1](embeddings[labeled_num:],
                                      labels[labeled_num:])
                losses2 = conf.ce_loss(thetas, labels[labeled_num:])

                num_ratio = labeled_num / len(embeddings)
                loss = num_ratio * losses1 + (1 - num_ratio) * losses2

                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    print('step:{}, train_loss:{}'.format(
                        self.step, loss_board))
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    accuracy = self._validate(conf)
                if self.step % self.save_every == 0 and self.step != 0:
                    self.save_state(conf, accuracy, e)

                self.step += 1

        self.save_state(conf, accuracy, e, to_save_folder=True, extra='final')

    def schedule_lr(self):
        """Divide the learning rate of every parameter group by 10."""
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def infer(self, conf, faces, target_embs, tta=False):
        '''Match each face against the facebank by squared L2 distance.

        faces : list of PIL Image
        target_embs : 2-D tensor of facebank embeddings, one row per entry
        tta : test time augmentation (horizontal flip only)

        Returns (min_idx, minimum): the index of the nearest facebank row
        (-1 when its squared distance exceeds self.threshold) and that
        squared distance.
        '''
        embs = []
        for img in faces:
            emb = self.model(conf.test_transform(img).cuda().unsqueeze(0))
            if tta:
                mirror = trans.functional.hflip(img)
                emb_mirror = self.model(
                    conf.test_transform(mirror).cuda().unsqueeze(0))
                emb = l2_norm(emb + emb_mirror)
            embs.append(emb)
        source_embs = torch.cat(embs)

        # (n, D, 1) - (1, D, m) -> (n, D, m); summing over D yields (n, m).
        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # no match -> -1
        return min_idx, minimum
示例#10
0
def prepare(args):
    """Build everything needed for training.

    Args:
        args: parsed command-line arguments; only
            ``args.resume_from_checkpoint`` is read here.

    Returns:
        dict with keys 'start_epoch', 'model', 'branches', 'train_loader',
        'test_loader', 'optimizer', 'scheduler' and 'losses'. ``start_epoch``
        is 1 for a fresh run, or the checkpoint's epoch + 1 when resuming.
    """
    resume_from_checkpoint = args.resume_from_checkpoint

    prepare_start_time = time.time()
    logger.info('global', 'Start preparing.')
    check_config_dir()
    logger.info('setting', config_info(), time_report=False)

    model = Backbone()
    model = model.cuda()
    logger.info('setting', model_summary(model), time_report=False)
    logger.info('setting', str(model), time_report=False)

    # One main branch followed by four parsing branches. They are moved to
    # the GPU here, before the optimizer below is built over their
    # parameters, as the torch.optim documentation recommends.
    branches = [
        main_branch(Config.nr_class, Config.in_planes),
        parsing_branch(Config.nr_class, Config.in_planes),
        parsing_branch(Config.nr_class, Config.in_planes),
        parsing_branch(Config.nr_class, Config.in_planes),
        parsing_branch(Config.nr_class, Config.in_planes)
    ]
    branches = [branch.cuda() for branch in branches]

    # NOTE(review): RandomErasing runs before Normalize, so erased pixels are
    # filled before normalization is applied — confirm this is intended.
    train_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(Config.input_shape),
        transforms.RandomApply([
            transforms.ColorJitter(
                brightness=0.3, contrast=0.3, saturation=0.3, hue=0)
        ],
                               p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.Pad(10),
        transforms.RandomCrop(Config.input_shape),
        transforms.ToTensor(),
        transforms.RandomErasing(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # No ToPILImage here, unlike train_transforms; presumably Veri776_test
    # already yields PIL images — confirm.
    test_transforms = transforms.Compose([
        transforms.Resize(Config.input_shape),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    trainset = Veri776_train(transforms=train_transforms,
                             need_mask=True,
                             bg_switch=Config.p_bgswitch)
    testset = Veri776_test(transforms=test_transforms)

    pksampler = PKSampler(trainset, p=Config.P, k=Config.K)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=Config.batch_size,
                                               sampler=pksampler,
                                               num_workers=Config.nr_worker,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        testset,
        batch_size=Config.batch_size,
        sampler=torch.utils.data.SequentialSampler(testset),
        num_workers=Config.nr_worker,
        pin_memory=True)

    weight_decay_setting = parm_list_with_Wdecay_multi([model] + branches)
    optimizer = torch.optim.Adam(weight_decay_setting, lr=Config.lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                  lr_lambda=lr_multi_func)

    # Index 0 of each loss list pairs with the main branch; indices 1-4 pair
    # with the parsing branches (same order as `branches`).
    losses = {}
    losses['cross_entropy_loss'] = [
        torch.nn.CrossEntropyLoss(),
        weight_cross_entropy(Config.ce_thres[0]),
        weight_cross_entropy(Config.ce_thres[1]),
        weight_cross_entropy(Config.ce_thres[2]),
        weight_cross_entropy(Config.ce_thres[3])
    ]
    losses['triplet_hard_loss'] = [
        triplet_hard_loss(margin=Config.triplet_margin),
        weighted_triplet_hard_loss(margin=Config.branch_margin,
                                   soft_margin=Config.soft_marigin),
        weighted_triplet_hard_loss(margin=Config.branch_margin,
                                   soft_margin=Config.soft_marigin),
        weighted_triplet_hard_loss(margin=Config.branch_margin,
                                   soft_margin=Config.soft_marigin),
        weighted_triplet_hard_loss(margin=Config.branch_margin,
                                   soft_margin=Config.soft_marigin)
    ]

    for name, loss_fns in losses.items():
        if isinstance(loss_fns, list):
            losses[name] = [loss_fn.cuda() for loss_fn in loss_fns]
        else:
            losses[name] = loss_fns.cuda()

    start_epoch = 0
    if resume_from_checkpoint and os.path.exists(Config.checkpoint_path):
        checkpoint = load_checkpoint()
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model'])
        # NOTE(review): branch weights are not restored here even though their
        # parameters are in the optimizer state — confirm whether the
        # checkpoint stores them.
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])

    # Resume at the epoch after the checkpoint's, or start from 1.
    start_epoch += 1

    ret = {
        'start_epoch': start_epoch,
        'model': model,
        'branches': branches,
        'train_loader': train_loader,
        'test_loader': test_loader,
        'optimizer': optimizer,
        'scheduler': scheduler,
        'losses': losses
    }

    prepare_end_time = time.time()
    time_spent = sec2min_sec(prepare_start_time, prepare_end_time)
    logger.info(
        'global', 'Finish preparing, time spend: {}mins {}s.'.format(
            time_spent[0], time_spent[1]))

    return ret