class face_learner(object):
    """Face-recognition training/inference harness (ArcFace-style).

    Wraps a backbone (MobileFaceNet or IR-SE Backbone), an Arcface head,
    an SGD optimizer, TensorBoard logging and periodic evaluation on the
    agedb_30 / cfp_fp / lfw verification sets.
    """

    def __init__(self, conf, inference=False):
        # conf: project config object carrying device, lr, paths, etc.
        # inference=True skips all training state (loader, head, optimizer).
        print(conf)
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio,
                                  conf.net_mode).to(conf.device)
            print('{}_{} model generated'.format(conf.net_mode,
                                                 conf.net_depth))
        if not inference:
            self.milestones = conf.milestones
            self.loader, self.class_num = get_train_loader(conf)
            self.writer = SummaryWriter(conf.log_path)
            self.step = 0  # global batch counter, drives log/eval/save cadence
            self.head = Arcface(embedding_size=conf.embedding_size,
                                classnum=self.class_num).to(conf.device)
            print('two model heads generated')
            # Split params so BatchNorm params get no weight decay.
            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)
            if conf.use_mobilfacenet:
                # MobileFaceNet: last non-BN param group (plus the Arcface
                # kernel) gets a larger weight decay than the rest.
                self.optimizer = optim.SGD(
                    [{
                        'params': paras_wo_bn[:-1],
                        'weight_decay': 4e-5
                    }, {
                        'params': [paras_wo_bn[-1]] + [self.head.kernel],
                        'weight_decay': 4e-4
                    }, {
                        'params': paras_only_bn
                    }],
                    lr=conf.lr,
                    momentum=conf.momentum)
            else:
                self.optimizer = optim.SGD(
                    [{
                        'params': paras_wo_bn + [self.head.kernel],
                        'weight_decay': 5e-4
                    }, {
                        'params': paras_only_bn
                    }],
                    lr=conf.lr,
                    momentum=conf.momentum)
            print(self.optimizer)
            # self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)
            print('optimizers generated')
            # Cadences derived from loader length: log 100x, eval 10x,
            # save 5x per epoch.
            self.board_loss_every = len(self.loader) // 100
            self.evaluate_every = len(self.loader) // 10
            self.save_every = len(self.loader) // 5
            # Validation data lives next to the training dataset root.
            self.agedb_30, self.cfp_fp, self.lfw, self.agedb_30_issame, self.cfp_fp_issame, self.lfw_issame = get_val_data(
                self.loader.dataset.root.parent)
        else:
            self.threshold = conf.threshold  # distance cutoff used by infer()

    def save_state(self,
                   conf,
                   accuracy,
                   to_save_folder=False,
                   extra=None,
                   model_only=False):
        """Save model (and optionally head + optimizer) state dicts.

        File names embed timestamp, accuracy and step; save_path must be
        a pathlib.Path (uses the / operator).
        """
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        torch.save(
            self.model.state_dict(), save_path /
            ('model_{}_accuracy:{}_step:{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(), save_path /
                ('head_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(), save_path /
                ('optimizer_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))

    def load_state(self,
                   conf,
                   fixed_str,
                   from_save_folder=False,
                   model_only=False):
        """Load state dicts saved by save_state; fixed_str is the common
        suffix shared by the model/head/optimizer files."""
        if from_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        self.model.load_state_dict(
            torch.load(save_path / 'model_{}'.format(fixed_str)))
        if not model_only:
            self.head.load_state_dict(
                torch.load(save_path / 'head_{}'.format(fixed_str)))
            self.optimizer.load_state_dict(
                torch.load(save_path / 'optimizer_{}'.format(fixed_str)))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        """Log one validation run (accuracy, threshold, ROC image) to
        TensorBoard under the given dataset name."""
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name),
                               best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name),
                              roc_curve_tensor, self.step)
        # self.writer.add_scalar('{}_val:true accept ratio'.format(db_name), val, self.step)
        # self.writer.add_scalar('{}_val_std'.format(db_name), val_std, self.step)
        # self.writer.add_scalar('{}_far:False Acceptance Ratio'.format(db_name), far, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        """Compute verification accuracy on a pre-aligned image array.

        carray: array of images; issame: pairwise same/different labels.
        tta adds horizontally-flipped embeddings before L2-normalizing.
        Returns (mean accuracy, mean best threshold, ROC-curve tensor).
        NOTE(review): leaves the model in eval() mode; train() switches it
        back afterwards.
        """
        self.model.eval()
        idx = 0
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            # Full batches first...
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(
                        fliped.to(conf.device))
                    # NOTE(review): no .cpu() here — on a CUDA device this
                    # assignment into a numpy array would fail; confirm
                    # tta is only used with CPU tensors.
                    embeddings[idx:idx + conf.batch_size] = l2_norm(emb_batch)
                else:
                    embeddings[idx:idx + conf.batch_size] = self.model(
                        batch.to(conf.device)).cpu()
                idx += conf.batch_size
            # ...then the ragged tail, if any.
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(
                        fliped.to(conf.device))
                    embeddings[idx:] = l2_norm(emb_batch)
                else:
                    embeddings[idx:] = self.model(batch.to(conf.device)).cpu()
        # evaluate() here is the module-level k-fold verification routine,
        # not this method (no recursion).
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame,
                                                       nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def find_lr(self,
                conf,
                init_value=1e-8,
                final_value=10.,
                beta=0.98,
                bloding_scale=3.,
                num=None):
        """LR range test (Leslie Smith style): sweep lr geometrically from
        init_value to final_value and record the smoothed loss.

        Returns (log10 lrs, smoothed losses); stops early when the loss
        exceeds bloding_scale * best_loss. Side effect: mutates the
        optimizer's lr.
        """
        if not num:
            num = len(self.loader)
        mult = (final_value / init_value)**(1 / num)
        lr = init_value
        for params in self.optimizer.param_groups:
            params['lr'] = lr
        self.model.train()
        avg_loss = 0.
        best_loss = 0.
        batch_num = 0
        losses = []
        log_lrs = []
        for i, (imgs, labels) in tqdm(enumerate(self.loader), total=num):
            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            batch_num += 1
            self.optimizer.zero_grad()
            embeddings = self.model(imgs)
            thetas = self.head(embeddings, labels)
            loss = conf.ce_loss(thetas, labels)
            # Compute the smoothed loss (exponential moving average with
            # bias correction, as in Adam).
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            self.writer.add_scalar('avg_loss', avg_loss, batch_num)
            smoothed_loss = avg_loss / (1 - beta**batch_num)
            self.writer.add_scalar('smoothed_loss', smoothed_loss, batch_num)
            # Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > bloding_scale * best_loss:
                print('exited with best_loss at {}'.format(best_loss))
                # Trim warm-up and blow-up ends before plotting.
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
            # Record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            # Store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            self.writer.add_scalar('log_lr', math.log10(lr), batch_num)
            # Do the SGD step
            # Update the lr for the next step
            loss.backward()
            self.optimizer.step()
            lr *= mult
            for params in self.optimizer.param_groups:
                params['lr'] = lr
            if batch_num > num:
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses

    def train(self, conf, epochs):
        """Main training loop: cross-entropy over Arcface logits, staged
        lr decay at the three configured milestones, periodic TensorBoard
        logging, evaluation and checkpointing."""
        self.model.train()
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))
            # Manual step-decay schedule keyed on epoch milestones.
            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()
            for imgs, labels in tqdm(iter(self.loader)):
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()
                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board,
                                           self.step)
                    running_loss = 0.
                if self.step % self.evaluate_every == 0 and self.step != 0:
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.agedb_30, self.agedb_30_issame)
                    self.board_val('agedb_30', accuracy, best_threshold,
                                   roc_curve_tensor)
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor)
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.cfp_fp, self.cfp_fp_issame)
                    self.board_val('cfp_fp', accuracy, best_threshold,
                                   roc_curve_tensor)
                    # evaluate() put the model in eval mode; restore.
                    self.model.train()
                if self.step % self.save_every == 0 and self.step != 0:
                    # NOTE(review): 'accuracy' is only bound once an
                    # evaluate_every step has run; because save_every is a
                    # multiple cadence this appears to hold, but a save
                    # before any evaluation would raise NameError — verify.
                    self.save_state(conf, accuracy)
                self.step += 1
        self.save_state(conf, accuracy, to_save_folder=True, extra='final')

    def schedule_lr(self):
        """Divide every param group's learning rate by 10."""
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def infer(self, conf, faces, target_embs, tta=False):
        '''Match each face against the facebank embeddings.

        faces : list of PIL Images
        target_embs : [n, 512] computed embeddings of faces in facebank
        tta : test-time augmentation (horizontal flip only)
        Returns (min_idx, minimum): per-face index of the nearest facebank
        entry (-1 when the squared L2 distance exceeds self.threshold) and
        the distance itself.
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)
        # Broadcasted pairwise squared-L2 distances: [faces, n].
        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum
class face_learner(object):
    """QAMFace-head training harness with focal loss, multi-GPU
    (DataParallel) support and benchmark evaluation via eval_emore_bmk."""

    def __init__(self, conf, inference=False):
        print(conf)
        # self.loader, self.class_num = construct_msr_dataset(conf)
        # Loader is created even for inference (differs from sibling
        # learners in this file).
        self.loader, self.class_num = get_train_loader(conf)
        # Model stays on CPU here; moved to device after DataParallel wrap.
        self.model = Backbone(conf.net_depth, conf.drop_ratio, conf.net_mode)
        print('{}_{} model generated'.format(conf.net_mode, conf.net_depth))
        if not inference:
            self.milestones = conf.milestones
            self.writer = SummaryWriter(conf.log_path)
            self.step = 0  # global batch counter
            self.head = QAMFace(embedding_size=conf.embedding_size,
                                classnum=self.class_num).to(conf.device)
            self.focalLoss = FocalLoss()
            print('two model heads generated')
            # Split params so BatchNorm params get no weight decay.
            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)
            self.optimizer = optim.SGD(
                [{
                    'params': paras_wo_bn + [self.head.kernel],
                    'weight_decay': 5e-4
                }, {
                    'params': paras_only_bn
                }],
                lr=conf.lr,
                momentum=conf.momentum)
            print(self.optimizer)
            # self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)
            print('optimizers generated')
            self.board_loss_every = len(self.loader) // 1000
            self.evaluate_every = len(self.loader) // 10
            self.save_every = len(self.loader) // 2
        else:
            self.threshold = conf.threshold
        # Multi-GPU training (original comment: 多GPU训练)
        self.model = torch.nn.DataParallel(self.model)
        self.model.to(conf.device)
        # NOTE(review): in inference mode self.head was never created, so
        # this line would raise AttributeError — confirm inference path.
        self.head = torch.nn.DataParallel(self.head)
        self.head = self.head.to(conf.device)

    def save_state(self,
                   conf,
                   accuracy,
                   to_save_folder=False,
                   extra=None,
                   model_only=False):
        """Save model (and optionally head + optimizer) state dicts; file
        names embed timestamp, accuracy and step."""
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        torch.save(
            self.model.state_dict(), save_path /
            ('model_{}_accuracy:{}_step:{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(), save_path /
                ('head_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(), save_path /
                ('optimizer_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))

    def load_state(self,
                   conf,
                   fixed_str,
                   from_save_folder=False,
                   model_only=False):
        """Resume model/head/optimizer from files sharing suffix fixed_str."""
        print('resume model from ' + fixed_str)
        if from_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        self.model.load_state_dict(
            torch.load(save_path / 'model_{}'.format(fixed_str)))
        if not model_only:
            self.head.load_state_dict(
                torch.load(save_path / 'head_{}'.format(fixed_str)))
            self.optimizer.load_state_dict(
                torch.load(save_path / 'optimizer_{}'.format(fixed_str)))

    def board_val(self,
                  db_name,
                  accuracy,
                  best_threshold=0,
                  roc_curve_tensor=0):
        """Log benchmark accuracy to TensorBoard (threshold/ROC params are
        accepted for interface compatibility but ignored)."""
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)

    def train(self, conf, epochs):
        """Training loop: focal loss over QAMFace logits, milestone lr
        decay, periodic logging / benchmark evaluation / checkpointing."""
        self.model.train()
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))
            # manually decay lr
            if e in self.milestones:
                self.schedule_lr()
            for imgs, labels in tqdm(iter(self.loader)):
                # Channel swap BGR->RGB and rescale to [0, 255].
                imgs = (imgs[:, (2, 1, 0)].to(conf.device) * 255)  # RGB
                labels = labels.to(conf.device)
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = self.focalLoss(thetas, labels)
                loss.backward()
                # Loss is accumulated per-sample (divided by batch size).
                running_loss += loss.item() / conf.batch_size
                self.optimizer.step()
                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board,
                                           self.step)
                    running_loss = 0.
                if self.step % self.evaluate_every == 0 and self.step != 0:
                    self.model.eval()
                    for bmk in [
                            'agedb_30', 'lfw', 'calfw', 'cfp_ff', 'cfp_fp',
                            'cplfw', 'vgg2_fp'
                    ]:
                        acc = eval_emore_bmk(conf, self.model, bmk)
                        self.board_val(bmk, acc)
                    self.model.train()
                if self.step % self.save_every == 0 and self.step != 0:
                    # NOTE(review): 'acc' is only bound after an evaluation
                    # step has run (here it is the last benchmark's
                    # accuracy); a save before any evaluation would raise
                    # NameError — verify cadence.
                    self.save_state(conf, acc)
                self.step += 1
        self.save_state(conf, acc, to_save_folder=True, extra='final')

    def myValidation(self, conf):
        """Run every supported emore benchmark once (results printed /
        handled inside eval_emore_bmk)."""
        self.model.eval()
        for bmk in [
                'agedb_30', 'lfw', 'calfw', 'cfp_ff', 'cfp_fp', 'cplfw',
                'vgg2_fp'
        ]:
            eval_emore_bmk(conf, self.model, bmk)

    def schedule_lr(self):
        """Divide every param group's learning rate by 10."""
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)
# Smoke-test script: detect/align faces with MTCNN, then embed the first
# face with a pretrained IR-SE50 backbone and print the embedding shape
# and detected bounding boxes.
import torch
from mtcnn import MTCNN
import cv2
import numpy as np
import PIL.Image as Image
from model import Backbone, Arcface, MobileFaceNet, Am_softmax, l2_norm
from torchvision import transforms as trans

device = torch.device('cuda:0')
mtcnn = MTCNN()
model = Backbone(50, 0.6, 'ir_se').to(device)
model.eval()
# map_location keeps the load working regardless of which device the
# checkpoint was saved from (fix: bare torch.load fails when the
# checkpoint holds CUDA tensors from another device).
model.load_state_dict(
    torch.load('./saved_models/model_ir_se50.pth', map_location=device))
# threshold = 1.54
test_transform = trans.Compose(
    [trans.ToTensor(),
     trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
# cv2 loads BGR; [:, :, ::-1] flips to RGB for PIL.
img = cv2.imread(
    '/home/taotao/Downloads/celeba-512/000014.jpg.jpg')[:, :, ::-1]
bboxes, faces = mtcnn.align_multi(Image.fromarray(img),
                                  limit=10,
                                  min_face_size=30)
# NOTE: 'input' shadows the builtin; kept for backward compatibility of
# the module attribute.
input = test_transform(faces[0]).unsqueeze(0)
# Fix: run inference without building an autograd graph.
with torch.no_grad():
    embbed = model(input.to(device))
print(embbed.shape)
print(bboxes)
class face_learner(object):
    """Transfer-learning variant of the face trainer: Adam over all model
    + head parameters, loss bookkeeping pickled to disk (save_loss /
    plot_result), and raw-score inference via binfer."""

    def __init__(self, conf, inference=False, transfer=0, ext='final'):
        # transfer: accepted for interface compatibility (unused here).
        # ext: experiment tag; its last '_' segment becomes a log subdir.
        pprint.pprint(conf)
        self.conf = conf
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio,
                                  conf.net_mode).to(conf.device)
            print('{}_{} model generated'.format(conf.net_mode,
                                                 conf.net_depth))
        if not inference:
            self.milestones = conf.milestones
            self.loader, self.class_num = get_train_loader(conf)
            tmp_idx = ext.rfind('_')  # find the last '_' to replace it by '/'
            self.ext = '/' + ext[:tmp_idx] + '/' + ext[tmp_idx + 1:]
            self.writer = SummaryWriter(str(conf.log_path) + self.ext)
            self.step = 0  # global batch counter
            self.head = Arcface(embedding_size=conf.embedding_size,
                                classnum=self.class_num).to(conf.device)
            print('two model heads generated')
            # Computed but unused here: Adam optimizes all params uniformly.
            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)
            self.optimizer = optim.Adam(
                list(self.model.parameters()) + list(self.head.parameters()),
                conf.lr)
            print(self.optimizer)
            # self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)
            print('optimizers generated')
            self.save_freq = len(self.loader) // 5  #//5 # originally, 100
            self.evaluate_every = len(self.loader)  #//5 # originally, 10
            self.save_every = len(self.loader)  #//2 # originally, 5
            # self.agedb_30, self.cfp_fp, self.lfw, self.agedb_30_issame, self.cfp_fp_issame, self.lfw_issame = get_val_data(self.loader.dataset.root.parent)
            # self.val_112, self.val_112_issame = get_val_pair(self.loader.dataset.root.parent, 'val_112')
        else:
            self.threshold = conf.threshold
        # Loss/accuracy history buffers, persisted by save_loss().
        # NOTE(review): placement outside the if/else is inferred from the
        # collapsed source — confirm against the original file.
        self.train_losses = []
        self.train_counter = []
        self.test_losses = []
        self.test_accuracy = []
        self.test_counter = []

    def save_state(self,
                   conf,
                   accuracy,
                   to_save_folder=False,
                   extra=None,
                   model_only=False):
        """Save model (and optionally head + optimizer) state dicts;
        accuracy is embedded in the filename with 2-decimal formatting."""
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        torch.save(
            self.model.state_dict(), save_path /
            ('model_{}_accuracy:{:0.2f}_step:{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(), save_path /
                ('head_{}_accuracy:{:0.2f}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(), save_path /
                ('optimizer_{}_accuracy:{:0.2f}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))

    def load_state(self,
                   conf,
                   fixed_str,
                   from_save_folder=False,
                   model_only=False):
        """Load state dicts saved by save_state; the model load is mapped
        onto conf.device (head/optimizer loads are not — note asymmetry)."""
        if from_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        self.model.load_state_dict(
            torch.load(save_path / 'model_{}'.format(fixed_str),
                       map_location=conf.device))
        if not model_only:
            self.head.load_state_dict(
                torch.load(save_path / 'head_{}'.format(fixed_str)))
            self.optimizer.load_state_dict(
                torch.load(save_path / 'optimizer_{}'.format(fixed_str)))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        """Log one validation run (accuracy, threshold, ROC image) to
        TensorBoard under the given dataset name."""
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name),
                               best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name),
                              roc_curve_tensor, self.step)
        # self.writer.add_scalar('{}_val:true accept ratio'.format(db_name), val, self.step)
        # self.writer.add_scalar('{}_val_std'.format(db_name), val_std, self.step)
        # self.writer.add_scalar('{}_far:False Acceptance Ratio'.format(db_name), far, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        """k-fold verification over a pre-aligned image array; returns
        (mean accuracy, mean best threshold, ROC-curve tensor)."""
        self.model.eval()
        idx = 0
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            # Full batches first...
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(
                        fliped.to(conf.device))
                    embeddings[idx:idx + conf.batch_size] = l2_norm(emb_batch)
                else:
                    embeddings[idx:idx + conf.batch_size] = self.model(
                        batch.to(conf.device)).cpu()
                idx += conf.batch_size
            # ...then the ragged tail.
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(
                        fliped.to(conf.device))
                    embeddings[idx:] = l2_norm(emb_batch)
                else:
                    embeddings[idx:] = self.model(batch.to(conf.device)).cpu()
        # Module-level evaluate(), not this method.
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame,
                                                       nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def find_lr(self,
                conf,
                init_value=1e-8,
                final_value=10.,
                beta=0.98,
                bloding_scale=3.,
                num=None):
        """LR range test: sweep lr geometrically and record smoothed loss;
        stops early on loss explosion. Mutates the optimizer's lr."""
        if not num:
            num = len(self.loader)
        mult = (final_value / init_value)**(1 / num)
        lr = init_value
        for params in self.optimizer.param_groups:
            params['lr'] = lr
        self.model.train()
        avg_loss = 0.
        best_loss = 0.
        batch_num = 0
        losses = []
        log_lrs = []
        for i, (imgs, labels) in enumerate(
                self.loader):  #tqdm(enumerate(self.loader), total=num):
            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            batch_num += 1
            self.optimizer.zero_grad()
            embeddings = self.model(imgs)
            thetas = self.head(embeddings, labels)
            loss = conf.ce_loss(thetas, labels)
            # Compute the smoothed loss (EMA with bias correction).
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            self.writer.add_scalar('avg_loss', avg_loss, batch_num)
            smoothed_loss = avg_loss / (1 - beta**batch_num)
            self.writer.add_scalar('smoothed_loss', smoothed_loss, batch_num)
            # Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > bloding_scale * best_loss:
                print('exited with best_loss at {}'.format(best_loss))
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
            # Record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            # Store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            self.writer.add_scalar('log_lr', math.log10(lr), batch_num)
            # Do the SGD step
            # Update the lr for the next step
            loss.backward()
            self.optimizer.step()
            lr *= mult
            for params in self.optimizer.param_groups:
                params['lr'] = lr
            if batch_num > num:
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses

    def train(self, conf, epochs):
        """Training loop: cross-entropy over Arcface logits with milestone
        lr decay; records raw loss snapshots every save_freq steps and
        pickles the history via save_loss()."""
        self.model.train()
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()
            for imgs, labels in iter(self.loader):  #tqdm(iter(self.loader)):
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()
                if self.step % self.save_freq == 0 and self.step != 0:
                    self.train_losses.append(loss.item())
                    self.train_counter.append(self.step)
                self.step += 1
        # NOTE(review): save_loss placement at end-of-training is inferred
        # from the collapsed source — confirm against the original file.
        self.save_loss()
        # self.save_state(conf, accuracy, to_save_folder=True, extra=self.ext, model_only=True)

    def schedule_lr(self):
        """Divide every param group's learning rate by 10."""
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def infer(self, conf, faces, target_embs, tta=False):
        '''Match each face against the facebank embeddings.

        faces : list of PIL Images
        target_embs : [n, 512] computed embeddings of faces in facebank
        tta : test-time augmentation (horizontal flip only)
        Returns (min_idx, minimum); min_idx is -1 where the squared L2
        distance exceeds self.threshold.
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)
        # Broadcasted pairwise squared-L2 distances: [faces, n].
        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum

    def binfer(self, conf, faces, target_embs, tta=False):
        '''Return the raw squared-L2 distance matrix for every class.

        faces : list of PIL Images
        target_embs : [n, 512] computed embeddings of faces in facebank
        tta : test-time augmentation (horizontal flip only)
        Side effect: calls plot_result(), which requires a previously
        pickled result_log.p.
        '''
        self.model.eval()
        self.plot_result()
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)
        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        # print(dist)
        return dist.detach().cpu().numpy()
        # minimum, min_idx = torch.min(dist, dim=1)
        # min_idx[minimum > self.threshold] = -1 # if no match, set idx to -1
        # return min_idx, minimum

    def save_loss(self):
        """Pickle the accumulated train/test loss and accuracy histories
        to <stored_result_dir>/result_log.p."""
        if not os.path.exists(self.conf.stored_result_dir):
            os.mkdir(self.conf.stored_result_dir)
        result = dict()
        result["train_losses"] = np.asarray(self.train_losses)
        result["train_counter"] = np.asarray(self.train_counter)
        result['test_accuracy'] = np.asarray(self.test_accuracy)
        result['test_losses'] = np.asarray(self.test_losses)
        result["test_counter"] = np.asarray(self.test_counter)
        with open(os.path.join(self.conf.stored_result_dir, "result_log.p"),
                  'wb') as fp:
            pickle.dump(result, fp)

    def plot_result(self):
        """Load the pickled history and save a train-loss curve PNG into
        stored_result_dir."""
        result_log_path = os.path.join(self.conf.stored_result_dir,
                                       "result_log.p")
        with open(result_log_path, 'rb') as f:
            result_dict = pickle.load(f)
        train_losses = result_dict['train_losses']
        train_counter = result_dict['train_counter']
        test_losses = result_dict['test_losses']
        test_counter = result_dict['test_counter']
        test_accuracy = result_dict['test_accuracy']
        fig1 = plt.figure(figsize=(12, 8))
        ax1 = fig1.add_subplot(111)
        ax1.plot(train_counter, train_losses, 'b', label='Train_loss')
        ax1.legend('Train_losses')
        plt.savefig(
            os.path.join(self.conf.stored_result_dir, "train_loss.png"))


# NOTE(review): dangling triple-quote below appears to be residue of a
# truncated commented-out region in the original file — preserved as-is.
"""
class face_learner(object):
    """Dual-head trainer: an identity head plus a 4-class race head, with
    a decorrelation loss (loss2) between the two heads' weight kernels.
    Training stages are toggled per-epoch by flags stored on conf."""

    def __init__(self, conf, inference=False):
        print(conf)
        self.lr = conf.lr  # remembered so init_lr() can restore the base lr
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        else:
            ############################### ir_se50 ########################################
            if conf.struct == 'ir_se_50':
                self.model = Backbone(conf.net_depth, conf.drop_ratio,
                                      conf.net_mode).to(conf.device)
                print('{}_{} model generated'.format(conf.net_mode,
                                                     conf.net_depth))
            ############################### resnet101 ######################################
            if conf.struct == 'ir_se_101':
                self.model = resnet101().to(conf.device)
                print('resnet101 model generated')
        if not inference:
            self.milestones = conf.milestones
            self.loader, self.class_num = get_train_loader(conf)
            self.writer = SummaryWriter(conf.log_path)
            self.step = 0  # global batch counter
            ############################### ir_se50 ########################################
            if conf.struct == 'ir_se_50':
                self.head = Arcface(embedding_size=conf.embedding_size,
                                    classnum=self.class_num).to(conf.device)
                # Race head classifies 4 groups.
                self.head_race = Arcface(embedding_size=conf.embedding_size,
                                         classnum=4).to(conf.device)
            ############################### resnet101 ######################################
            if conf.struct == 'ir_se_101':
                self.head = ArcMarginModel(
                    embedding_size=conf.embedding_size,
                    classnum=self.class_num).to(conf.device)
                # NOTE(review): race head here uses classnum=self.class_num,
                # not 4, unlike the ir_se_50 branch — looks inconsistent;
                # confirm intent.
                self.head_race = ArcMarginModel(
                    embedding_size=conf.embedding_size,
                    classnum=self.class_num).to(conf.device)
            print('two model heads generated')
            # Split params so BatchNorm params get no weight decay; both
            # heads' kernels join the weight-decayed group.
            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)
            if conf.use_mobilfacenet:
                self.optimizer = optim.SGD([
                    {'params': paras_wo_bn[:-1], 'weight_decay': 4e-5},
                    {'params': [paras_wo_bn[-1]] + [self.head.kernel] + [self.head_race.kernel], 'weight_decay': 4e-4},
                    {'params': paras_only_bn}
                ], lr=conf.lr, momentum=conf.momentum)
            else:
                self.optimizer = optim.SGD([
                    {'params': paras_wo_bn + [self.head.kernel] + [self.head_race.kernel], 'weight_decay': 5e-4},
                    {'params': paras_only_bn}
                ], lr=conf.lr, momentum=conf.momentum)
            print(self.optimizer)
            # self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)
            print('optimizers generated')
            print('len of loader:', len(self.loader))
            # Log every step for short loaders, else every 100 steps;
            # evaluate and save once per epoch.
            self.board_loss_every = len(self.loader) // min(
                len(self.loader), 100)
            self.evaluate_every = len(self.loader) // 1
            self.save_every = len(self.loader) // 1
            self.agedb_30, self.cfp_fp, self.lfw, self.agedb_30_issame, self.cfp_fp_issame, self.lfw_issame = get_val_data(
                conf.val_folder)
        else:
            # self.threshold = conf.threshold
            pass

    def save_state(self,
                   conf,
                   accuracy,
                   to_save_folder=False,
                   extra=None,
                   model_only=False):
        """Save model (and optionally both heads + optimizer) state dicts;
        file names embed timestamp, accuracy and step."""
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        torch.save(
            self.model.state_dict(), save_path /
            ('model_{}_accuracy:{}_step:{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(), save_path /
                ('head_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.head_race.state_dict(), save_path /
                ('head__race{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(), save_path /
                ('optimizer_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))

    def load_state(self, model, head=None, head_race=None, optimizer=None):
        """Load checkpoints from explicit file paths; each optional part is
        loaded only when a path is supplied. Model load is non-strict."""
        self.model.load_state_dict(torch.load(model), strict=False)
        if head is not None:
            self.head.load_state_dict(torch.load(head))
        if head_race is not None:
            self.head_race.load_state_dict(torch.load(head_race))
        if optimizer is not None:
            self.optimizer.load_state_dict(torch.load(optimizer))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor,
                  tpr_val):
        """Log one validation run to TensorBoard, including the TPR sampled
        near FPR=0.001 (tpr_val)."""
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name),
                               best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name),
                              roc_curve_tensor, self.step)
        # NOTE(review): the tag below looks like a mangled
        # '{}_tpr@fpr=...' label — preserved byte-identical.
        self.writer.add_scalar('{}[email protected]'.format(db_name), tpr_val,
                               self.step)
        # self.writer.add_scalar('{}_val:true accept ratio'.format(db_name), val, self.step)
        # self.writer.add_scalar('{}_val_std'.format(db_name), val_std, self.step)
        # self.writer.add_scalar('{}_far:False Acceptance Ratio'.format(db_name), far, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        """k-fold verification; additionally extracts the TPR at the point
        where FPR falls in (0.0008, 0.0012), defaulting to 0 when the ROC
        has no sample in that band.

        Returns (mean accuracy, mean best threshold, ROC tensor, tpr_val).
        """
        self.model.eval()
        idx = 0
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            # Full batches first...
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(fliped.to(conf.device))
                    embeddings[idx:idx + conf.batch_size] = l2_norm(emb_batch)
                else:
                    embeddings[idx:idx + conf.batch_size] = self.model(batch.to(conf.device)).cpu()
                idx += conf.batch_size
            # ...then the ragged tail.
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(fliped.to(conf.device))
                    embeddings[idx:] = l2_norm(emb_batch)
                else:
                    embeddings[idx:] = self.model(batch.to(conf.device)).cpu()
        # Module-level evaluate(), not this method.
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame,
                                                       nrof_folds)
        try:
            # TPR at FPR ~ 0.001 (first ROC sample inside the band).
            tpr_val = tpr[np.less(fpr, 0.0012) & np.greater(fpr, 0.0008)][0]
        except:
            tpr_val = 0
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor, tpr_val

    def find_lr(self,
                conf,
                init_value=1e-8,
                final_value=10.,
                beta=0.98,
                bloding_scale=3.,
                num=None):
        """LR range test: sweep lr geometrically and record smoothed loss;
        stops early on loss explosion. Mutates the optimizer's lr.
        NOTE(review): uses the identity head only (thetas unpacking differs
        from train(), which expects (thetas, w)) — confirm head signature.
        """
        if not num:
            num = len(self.loader)
        mult = (final_value / init_value)**(1 / num)
        lr = init_value
        for params in self.optimizer.param_groups:
            params['lr'] = lr
        self.model.train()
        avg_loss = 0.
        best_loss = 0.
        batch_num = 0
        losses = []
        log_lrs = []
        for i, (imgs, labels) in tqdm(enumerate(self.loader), total=num):
            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            batch_num += 1
            self.optimizer.zero_grad()
            embeddings = self.model(imgs)
            thetas = self.head(embeddings, labels)
            loss = conf.ce_loss(thetas, labels)
            # Compute the smoothed loss (EMA with bias correction).
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            self.writer.add_scalar('avg_loss', avg_loss, batch_num)
            smoothed_loss = avg_loss / (1 - beta**batch_num)
            self.writer.add_scalar('smoothed_loss', smoothed_loss, batch_num)
            # Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > bloding_scale * best_loss:
                print('exited with best_loss at {}'.format(best_loss))
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
            # Record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            # Store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            self.writer.add_scalar('log_lr', math.log10(lr), batch_num)
            # Do the SGD step
            # Update the lr for the next step
            loss.backward()
            self.optimizer.step()
            lr *= mult
            for params in self.optimizer.param_groups:
                params['lr'] = lr
            if batch_num > num:
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses

    def train(self, conf, epochs):
        """Staged training with three losses:

        loss0 — identity cross-entropy; loss1 — race cross-entropy;
        loss2 — BCE pushing head/head_race kernel correlations toward a
        block-diagonal target, with per-race balancing weights.
        Epochs 0-7 train everything; 8-15 freeze model+identity head and
        train the race head; 16+ unfreeze all. lr decays at 16/28/32/35.
        Flags live on conf and gate requires_grad per component.
        """
        self.model = self.model.to(conf.device)
        self.head = self.head.to(conf.device)
        self.head_race = self.head_race.to(conf.device)
        self.model.train()
        self.head.train()
        self.head_race.train()
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e == 8:  # 5
                # Stage 2: train head_race only.
                # self.init_lr()
                conf.loss0 = False
                conf.loss1 = True
                conf.loss2 = True
                conf.model = False
                conf.head = False
                conf.head_race = True
                print(conf)
            if e == 16:  # 10:
                # Stage 3: unfreeze everything, decay lr.
                # self.init_lr()
                self.schedule_lr()
                conf.loss0 = True
                conf.loss1 = True
                conf.loss2 = True
                conf.model = True
                conf.head = True
                conf.head_race = True
                print(conf)
            if e == 28:  # 22
                self.schedule_lr()
            if e == 32:
                self.schedule_lr()
            if e == 35:
                self.schedule_lr()
            # Apply the current stage's freeze/unfreeze flags.
            requires_grad(self.head, conf.head)
            requires_grad(self.head_race, conf.head_race)
            requires_grad(self.model, conf.model)
            for imgs, labels in tqdm(iter(self.loader)):
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                # Identity labels are grouped by race: cumulative sums of
                # conf.race_num give the label ranges per race bucket.
                labels_race = torch.zeros_like(labels)
                race0_index = labels.lt(sum(conf.race_num[:1]))
                race1_index = labels.lt(sum(conf.race_num[:2])) & labels.ge(sum(conf.race_num[:1]))
                race2_index = labels.lt(sum(conf.race_num[:3])) & labels.ge(sum(conf.race_num[:2]))
                race3_index = labels.ge(sum(conf.race_num[:3]))
                labels_race[race0_index] = 0
                labels_race[race1_index] = 1
                labels_race[race2_index] = 2
                labels_race[race3_index] = 3
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                # Heads return (logits, weight kernel).
                thetas, w = self.head(embeddings, labels)
                thetas_race, w_race = self.head_race(embeddings, labels_race)
                loss = 0
                loss0 = conf.ce_loss(thetas, labels)
                loss1 = conf.ce_loss(thetas_race, labels_race)
                # Correlation matrix between race kernels and identity
                # kernels: [4, class_num].
                loss2 = torch.mm(w_race.t(), w).to(conf.device)
                # Target: identity column i should correlate only with its
                # own race row (block-diagonal ones).
                target = torch.zeros_like(loss2).to(conf.device)
                target[0][:sum(conf.race_num[:1])] = 1
                target[1][sum(conf.race_num[:1]):sum(conf.race_num[:2])] = 1
                target[2][sum(conf.race_num[:2]):sum(conf.race_num[:3])] = 1
                target[3][sum(conf.race_num[:3]):] = 1
                # Inverse-frequency weights to balance race group sizes.
                weight = torch.zeros_like(loss2).to(conf.device)
                for i in range(4):
                    weight[i, :] = sum(conf.race_num) / conf.race_num[i]
                # loss2 = torch.nn.functional.mse_loss(loss2 , target)
                loss2 = F.binary_cross_entropy(torch.sigmoid(loss2), target,
                                               weight)
                # Identity loss is weighted 2x when active.
                if conf.loss0 == True:
                    loss += 2 * loss0
                if conf.loss1 == True:
                    loss += loss1
                if conf.loss2 == True:
                    loss += loss2
                # loss = loss0 + loss1 + loss2
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()
                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board,
                                           self.step)
                    running_loss = 0.
                if self.step % self.evaluate_every == 0 and self.step != 0:
                    accuracy = None
                    accuracy, best_threshold, roc_curve_tensor, tpr_val = self.evaluate(
                        conf, self.agedb_30, self.agedb_30_issame)
                    self.board_val('agedb_30', accuracy, best_threshold,
                                   roc_curve_tensor, tpr_val)
                    accuracy, best_threshold, roc_curve_tensor, tpr_val = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor, tpr_val)
                    accuracy, best_threshold, roc_curve_tensor, tpr_val = self.evaluate(
                        conf, self.cfp_fp, self.cfp_fp_issame)
                    self.board_val('cfp_fp', accuracy, best_threshold,
                                   roc_curve_tensor, tpr_val)
                    self.model.train()
                if self.step % self.save_every == 0 and self.step != 0:
                    # NOTE(review): relies on 'accuracy' being bound by a
                    # prior evaluate step — NameError if save fires first.
                    self.save_state(conf, accuracy)
                self.step += 1
        self.save_state(conf, accuracy, to_save_folder=True, extra='final')

    def schedule_lr(self):
        """Divide every param group's learning rate by 10."""
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def init_lr(self):
        """Reset every param group's learning rate to the base lr."""
        for params in self.optimizer.param_groups:
            params['lr'] = self.lr
        print(self.optimizer)

    def schedule_lr_add(self):
        """Multiply every param group's learning rate by 10 (inverse of
        schedule_lr)."""
        for params in self.optimizer.param_groups:
            params['lr'] *= 10
        print(self.optimizer)

    def infer(self, conf, faces, target_embs, tta=False):
        '''Match each face against the facebank embeddings.

        faces : list of PIL Images
        target_embs : [n, 512] computed embeddings of faces in facebank
        tta : test-time augmentation (horizontal flip only)
        Returns (min_idx, minimum); min_idx is -1 where the squared L2
        distance exceeds self.threshold.
        NOTE(review): self.threshold is never set in this variant (the
        assignment in __init__ is commented out) — would raise
        AttributeError; confirm.
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(self.model(conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)
        # Broadcasted pairwise squared-L2 distances: [faces, n].
        diff = source_embs.unsqueeze(-1) - target_embs.transpose(1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' # import libnvjpeg # import pickle # img_root_dir = '/media/taotao/958c7d2d-c4ce-4117-a93b-c8a7aa4b88e3/taotao/part1/' # save_path = '/media/taotao/958c7d2d-c4ce-4117-a93b-c8a7aa4b88e3/taotao/stars_256_0.85/' img_root_dir = './images/' save_path = './aligned/' # embed_path = '/home/taotao/Downloads/celeb-aligned-256/embed.pkl' device = torch.device('cpu') mtcnn = MTCNN() model = Backbone(50, 0.6, 'ir_se').to(device) model.eval() model.load_state_dict(torch.load('./model_ir_se50.pth', map_location=device)) # threshold = 1.54 test_transform = trans.Compose( [trans.ToTensor(), trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) # decoder = libnvjpeg.py_NVJpegDecoder() ind = 0 embed_map = {} for root, dirs, files in os.walk(img_root_dir): for name in files: if name.endswith('jpg') or name.endswith('png'): try:
class face_learner(object):
    """ArcFace trainer with an optional age-invariance (GAN) branch.

    The plain ``train`` path optimizes backbone + ArcFace head with cross
    entropy.  ``train_with_growup`` additionally trains a GrowUP generator
    that maps child embeddings into "adult" embedding space, adversarially
    against an age Discriminator; ``train_age_invariant`` / ``..._2`` drop
    the generator and push the backbone itself to fool the discriminator.
    """

    def __init__(self, conf, inference=False):
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio,
                                  conf.net_mode).to(conf.device)
            print('{}_{} model generated'.format(conf.net_mode, conf.net_depth))
        # GAN branch modules.  Created for every backbone (the original only
        # built them in the non-MobileFaceNet branch, yet evaluate()/zero_grad()
        # reference them unconditionally -> AttributeError with MobileFaceNet).
        self.growup = GrowUP().to(conf.device)
        self.discriminator = Discriminator().to(conf.device)

        if not inference:
            self.milestones = conf.milestones
            self.loader, self.class_num = get_train_loader(conf)
            if conf.discriminator:
                self.child_loader, self.adult_loader = get_train_loader_d(conf)

            os.makedirs(conf.log_path, exist_ok=True)
            self.writer = SummaryWriter(conf.log_path)
            self.step = 0
            self.head = Arcface(embedding_size=conf.embedding_size,
                                classnum=self.class_num).to(conf.device)

            # Will not use anymore
            if conf.use_dp:
                self.model = nn.DataParallel(self.model)
                self.head = nn.DataParallel(self.head)

            print(self.class_num)
            print(conf)
            print('two model heads generated')

            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)
            if conf.use_mobilfacenet:
                self.optimizer = optim.SGD(
                    [{'params': paras_wo_bn[:-1], 'weight_decay': 4e-5},
                     {'params': [paras_wo_bn[-1]] + [self.head.kernel],
                      'weight_decay': 4e-4},
                     {'params': paras_only_bn}],
                    lr=conf.lr, momentum=conf.momentum)
            else:
                self.optimizer = optim.SGD(
                    [{'params': paras_wo_bn + [self.head.kernel],
                      'weight_decay': 5e-4},
                     {'params': paras_only_bn}],
                    lr=conf.lr, momentum=conf.momentum)

            if conf.discriminator:
                # NOTE(review): optimizer_g and optimizer_g2 both optimize the
                # same growup parameters — presumably intentional (head-phase
                # vs. generator-phase updates); confirm before merging them.
                self.optimizer_g = optim.Adam(self.growup.parameters(),
                                              lr=1e-4, betas=(0.5, 0.999))
                self.optimizer_g2 = optim.Adam(self.growup.parameters(),
                                               lr=1e-4, betas=(0.5, 0.999))
                self.optimizer_d = optim.Adam(self.discriminator.parameters(),
                                              lr=1e-4, betas=(0.5, 0.999))
                # Secondary backbone optimizer used by the age-invariant loops.
                self.optimizer2 = optim.SGD(
                    [{'params': paras_wo_bn + [self.head.kernel],
                      'weight_decay': 5e-4},
                     {'params': paras_only_bn}],
                    lr=conf.lr, momentum=conf.momentum)

            if conf.finetune_model_path is not None:
                # Fine-tuning: exclude the head kernel from the optimizer.
                self.optimizer = optim.SGD(
                    [{'params': paras_wo_bn, 'weight_decay': 5e-4},
                     {'params': paras_only_bn}],
                    lr=conf.lr, momentum=conf.momentum)
            print('optimizers generated')

            self.board_loss_every = len(self.loader) // 100
            self.evaluate_every = len(self.loader) // 2
            self.save_every = len(self.loader)

            # TODO(review): hard-coded dataset location — should come from conf.
            dataset_root = "/home/nas1_userD/yonggyu/Face_dataset/face_emore"
            self.lfw = np.load(
                os.path.join(dataset_root,
                             "lfw_align_112_list.npy")).astype(np.float32)
            self.lfw_issame = np.load(
                os.path.join(dataset_root, "lfw_align_112_label.npy"))
            self.fgnetc = np.load(
                os.path.join(dataset_root,
                             "FGNET_new_align_list.npy")).astype(np.float32)
            self.fgnetc_issame = np.load(
                os.path.join(dataset_root, "FGNET_new_align_label.npy"))
        else:
            # Will not use anymore
            # self.model = nn.DataParallel(self.model)
            self.threshold = conf.threshold

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor,
                  negative_wrong, positive_wrong):
        """Log one validation round (accuracy, threshold, mistake counts, ROC)."""
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name),
                               best_threshold, self.step)
        self.writer.add_scalar('{}_negative_wrong'.format(db_name),
                               negative_wrong, self.step)
        self.writer.add_scalar('{}_positive_wrong'.format(db_name),
                               positive_wrong, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor,
                              self.step)

    def _extract_embeddings(self, conf, carray, tta, post=None):
        """Embed every image in `carray` batch-wise.

        post: optional module applied to backbone features (e.g. self.growup).
        With `tta`, averages original + horizontally-flipped embeddings and
        L2-normalizes the sum, matching the original per-loop code.
        """
        embed = (lambda x: post(self.model(x))) if post is not None else self.model
        out = np.zeros([len(carray), conf.embedding_size])
        idx = 0
        with torch.no_grad():
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = embed(batch.to(conf.device)).cpu() + \
                        embed(fliped.to(conf.device)).cpu()
                    out[idx:idx + conf.batch_size] = l2_norm(emb_batch).cpu()
                else:
                    out[idx:idx + conf.batch_size] = embed(
                        batch.to(conf.device)).cpu()
                idx += conf.batch_size
            if idx < len(carray):  # trailing partial batch
                batch = torch.tensor(carray[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = embed(batch.to(conf.device)).cpu() + \
                        embed(fliped.to(conf.device)).cpu()
                    out[idx:] = l2_norm(emb_batch).cpu()
                else:
                    out[idx:] = embed(batch.to(conf.device)).cpu()
        return out

    @staticmethod
    def _wrong_counts(issame, dist, threshold):
        """Count verification mistakes at `threshold`.

        negative_wrong: different-person pairs judged "same" (dist below thr).
        positive_wrong: same-person pairs judged "different" (dist above thr).
        """
        negative_wrong = len(np.where((issame == False) &
                                      (dist < threshold))[0])
        positive_wrong = len(np.where((issame == True) &
                                      (dist > threshold))[0])
        return negative_wrong, positive_wrong

    def evaluate(self, conf, carray, issame, nrof_folds=10, tta=True):
        """Verification evaluation; returns (acc, best_thr, ROC image, dists)."""
        self.model.eval()
        self.growup.eval()
        self.discriminator.eval()
        embeddings = self._extract_embeddings(conf, carray, tta)
        tpr, fpr, accuracy, best_thresholds, dist = evaluate_dist(
            embeddings, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = transforms.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor, dist

    def evaluate_child(self, conf, carray, issame, nrof_folds=10, tta=True):
        """Child-vs-adult evaluation: even-indexed images (children) are run
        through growup, odd-indexed (adults) through the plain backbone.

        BUG FIX: the original reused the running `idx` from the first loop
        without resetting it, so the adult embeddings were mostly never
        computed; the helper restarts from 0 for each half.
        """
        self.model.eval()
        self.growup.eval()
        self.discriminator.eval()
        embeddings1 = self._extract_embeddings(conf, carray[::2, ], tta,
                                               post=self.growup)
        embeddings2 = self._extract_embeddings(conf, carray[1::2, ], tta)
        tpr, fpr, accuracy, best_thresholds = evaluate_child(
            embeddings1, embeddings2, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = transforms.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def zero_grad(self):
        """Clear gradients on all GAN-branch optimizers at once."""
        self.optimizer.zero_grad()
        self.optimizer_g.zero_grad()
        self.optimizer_d.zero_grad()

    def train(self, conf, epochs):
        """Baseline ArcFace training (no GAN branch)."""
        self.model.train()
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e in self.milestones:
                self.schedule_lr()
            for imgs, labels, ages in tqdm(iter(self.loader)):
                self.optimizer.zero_grad()
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    # LFW evaluation
                    accuracy, best_threshold, roc_curve_tensor, dist = \
                        self.evaluate(conf, self.lfw, self.lfw_issame)
                    negative_wrong, positive_wrong = self._wrong_counts(
                        self.lfw_issame, dist, best_threshold)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor, negative_wrong,
                                   positive_wrong)
                    # FGNETC evaluation
                    accuracy2, best_threshold2, roc_curve_tensor2, dist2 = \
                        self.evaluate(conf, self.fgnetc, self.fgnetc_issame)
                    negative_wrong2, positive_wrong2 = self._wrong_counts(
                        self.fgnetc_issame, dist2, best_threshold2)
                    self.board_val('fgent_c', accuracy2, best_threshold2,
                                   roc_curve_tensor2, negative_wrong2,
                                   positive_wrong2)
                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # Original had a finetune/scratch if-else whose branches
                    # were byte-identical; collapsed to a single call.
                    self.save_state(
                        conf, accuracy2,
                        extra=str(conf.data_mode) + '_' + str(conf.net_depth)
                        + '_' + str(conf.batch_size) + conf.model_name)

                self.step += 1
        print('Horray!')

    def train_with_growup(self, conf, epochs):
        '''
        Our method: jointly train head + GrowUP generator + age discriminator.
        '''
        self.model.train()
        running_loss = 0.
        l1_loss = 0
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e in self.milestones:
                self.schedule_lr()
            a_loader = iter(self.adult_loader)
            c_loader = iter(self.child_loader)
            for imgs, labels, ages in tqdm(iter(self.loader)):
                # loader : base loader that returns images with id
                # a_loader, c_loader : adult, child loader with same datasize
                # ages : 0 == child, 1 == adult
                try:
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)
                except StopIteration:
                    # restart the smaller loaders when exhausted
                    a_loader = iter(self.adult_loader)
                    c_loader = iter(self.child_loader)
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)

                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                imgs_a, labels_a = imgs_a.to(conf.device), labels_a.to(
                    conf.device).type(torch.float32)
                imgs_c, labels_c = imgs_c.to(conf.device), labels_c.to(
                    conf.device).type(torch.float32)
                bs_a = imgs_a.shape[0]
                imgs_ac = torch.cat([imgs_a, imgs_c], dim=0)

                ###########################
                #       Train head        #
                ###########################
                self.optimizer.zero_grad()
                self.optimizer_g2.zero_grad()
                self.growup.train()
                c = (ages == 0)  # select children for enhancement
                embeddings = self.model(imgs)
                if sum(c) > 1:  # there might be no children in loader's batch
                    embeddings_c = embeddings[c]
                    embeddings_a_hat = self.growup(embeddings_c)
                    embeddings[c] = embeddings_a_hat
                elif sum(c) == 1:
                    # batchnorm cannot train on a single sample
                    self.growup.eval()
                    embeddings_c = embeddings[c]
                    embeddings_a_hat = self.growup(embeddings_c)
                    embeddings[c] = embeddings_a_hat
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()
                self.optimizer_g2.step()

                ##############################
                #     Train discriminator    #
                ##############################
                self.optimizer_d.zero_grad()
                self.growup.train()
                _embeddings = self.model(imgs_ac)
                embeddings_a, embeddings_c = _embeddings[:bs_a], _embeddings[
                    bs_a:]
                embeddings_a_hat = self.growup(embeddings_c)
                labels_ac = torch.cat([labels_a, labels_c], dim=0)
                pred_a = torch.squeeze(self.discriminator(
                    embeddings_a))  # sperate since batchnorm exists
                pred_c = torch.squeeze(self.discriminator(embeddings_a_hat))
                pred_ac = torch.cat([pred_a, pred_c], dim=0)
                d_loss = conf.ls_loss(pred_ac, labels_ac)
                d_loss.backward()
                self.optimizer_d.step()

                #############################
                #      Train genertator     #
                #############################
                self.optimizer_g.zero_grad()
                embeddings_c = self.model(imgs_c)
                embeddings_a_hat = self.growup(embeddings_c)
                pred_c = torch.squeeze(self.discriminator(embeddings_a_hat))
                labels_a = torch.ones_like(
                    labels_c, dtype=torch.float)  # generator should make child 1
                g_loss = conf.ls_loss(pred_c, labels_a)
                l1_loss = conf.l1_loss(embeddings_a_hat, embeddings_c)
                g_total_loss = g_loss + 10 * l1_loss
                g_total_loss.backward()
                self.optimizer_g.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    self.writer.add_scalar('d_loss', d_loss, self.step)
                    self.writer.add_scalar('g_loss', g_loss, self.step)
                    self.writer.add_scalar('l1_loss', l1_loss, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    # BUG FIX: evaluate() returns four values and board_val()
                    # takes six args; the original unpacked three / passed four.
                    accuracy, best_threshold, roc_curve_tensor, dist = \
                        self.evaluate(conf, self.lfw, self.lfw_issame)
                    negative_wrong, positive_wrong = self._wrong_counts(
                        self.lfw_issame, dist, best_threshold)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor, negative_wrong,
                                   positive_wrong)
                    accuracy2, best_threshold2, roc_curve_tensor2 = \
                        self.evaluate_child(conf, self.fgnetc,
                                            self.fgnetc_issame)
                    # evaluate_child reports no distances -> counts unavailable
                    self.board_val('fgent_c', accuracy2, best_threshold2,
                                   roc_curve_tensor2, 0, 0)
                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # save with most recently calculated accuracy
                    self.save_state(
                        conf, accuracy2,
                        extra=str(conf.data_mode) + '_' + str(conf.net_depth)
                        + '_' + str(conf.batch_size) + conf.model_name)

                self.step += 1
        self.save_state(
            conf, accuracy2, to_save_folder=True,
            extra=str(conf.data_mode) + '_' + str(conf.net_depth)
            + '_' + str(conf.batch_size) + '_discriminator_final')

    def _run_age_invariant(self, conf, epochs):
        """Shared loop behind train_age_invariant / train_age_invariant2:
        no GrowUP generator — the backbone itself (via optimizer2) is trained
        to make child embeddings fool the age discriminator."""
        self.model.train()
        running_loss = 0.
        l1_loss = 0  # not computed on this path; logged as a constant 0
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e in self.milestones:
                self.schedule_lr()
                self.schedule_lr2()
            a_loader = iter(self.adult_loader)
            c_loader = iter(self.child_loader)
            for imgs, labels, ages in tqdm(iter(self.loader)):
                # loader : base loader that returns images with id
                # a_loader, c_loader : adult, child loader with same datasize
                # ages : 0 == child, 1 == adult
                try:
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)
                except StopIteration:
                    a_loader = iter(self.adult_loader)
                    c_loader = iter(self.child_loader)
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)

                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                imgs_a, labels_a = imgs_a.to(conf.device), labels_a.to(
                    conf.device).type(torch.float32)
                imgs_c, labels_c = imgs_c.to(conf.device), labels_c.to(
                    conf.device).type(torch.float32)
                bs_a = imgs_a.shape[0]
                imgs_ac = torch.cat([imgs_a, imgs_c], dim=0)

                ###########################
                #       Train head        #
                ###########################
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()

                ##############################
                #     Train discriminator    #
                ##############################
                self.optimizer_d.zero_grad()
                _embeddings = self.model(imgs_ac)
                embeddings_a, embeddings_c = _embeddings[:bs_a], _embeddings[
                    bs_a:]
                labels_ac = torch.cat([labels_a, labels_c], dim=0)
                pred_a = torch.squeeze(self.discriminator(
                    embeddings_a))  # sperate since batchnorm exists
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                pred_ac = torch.cat([pred_a, pred_c], dim=0)
                d_loss = conf.ls_loss(pred_ac, labels_ac)
                d_loss.backward()
                self.optimizer_d.step()

                #############################
                #      Train genertator     #
                #############################
                self.optimizer2.zero_grad()
                embeddings_c = self.model(imgs_c)
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                labels_a = torch.ones_like(
                    labels_c, dtype=torch.float)  # generator should make child 1
                g_loss = conf.ls_loss(pred_c, labels_a)
                g_loss.backward()
                self.optimizer2.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    self.writer.add_scalar('d_loss', d_loss, self.step)
                    self.writer.add_scalar('g_loss', g_loss, self.step)
                    self.writer.add_scalar('l1_loss', l1_loss, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    # BUG FIX: evaluate() returns four values and board_val()
                    # takes six args; the original unpacked three / passed four.
                    accuracy, best_threshold, roc_curve_tensor, dist = \
                        self.evaluate(conf, self.lfw, self.lfw_issame)
                    negative_wrong, positive_wrong = self._wrong_counts(
                        self.lfw_issame, dist, best_threshold)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor, negative_wrong,
                                   positive_wrong)
                    accuracy2, best_threshold2, roc_curve_tensor2, dist2 = \
                        self.evaluate(conf, self.fgnetc, self.fgnetc_issame)
                    negative_wrong2, positive_wrong2 = self._wrong_counts(
                        self.fgnetc_issame, dist2, best_threshold2)
                    self.board_val('fgent_c', accuracy2, best_threshold2,
                                   roc_curve_tensor2, negative_wrong2,
                                   positive_wrong2)
                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # save with most recently calculated accuracy
                    self.save_state(
                        conf, accuracy2,
                        extra=str(conf.data_mode) + '_' + str(conf.net_depth)
                        + '_' + str(conf.batch_size) + conf.model_name)

                self.step += 1
        self.save_state(
            conf, accuracy2, to_save_folder=True,
            extra=str(conf.data_mode) + '_' + str(conf.net_depth)
            + '_' + str(conf.batch_size) + '_discriminator_final')

    def train_age_invariant(self, conf, epochs):
        '''
        Our method, without growup
        '''
        self._run_age_invariant(conf, epochs)

    def train_age_invariant2(self, conf, epochs):
        '''
        Our method, without growup, using paired dataset
        TODO: the original body was identical to train_age_invariant; the
        paired-dataset variant is still unimplemented.
        '''
        self._run_age_invariant(conf, epochs)

    def analyze_angle(self, conf, name):
        '''
        Only works on age labeled vgg dataset, agedb dataset
        '''
        # angle_table[class][age_bin] -> set of angles between an embedding
        # and its class kernel, bucketed into 8 age bins.
        angle_table = [{
            0: set(), 1: set(), 2: set(), 3: set(),
            4: set(), 5: set(), 6: set(), 7: set()
        } for i in range(self.class_num)]
        if conf.resume_analysis:
            self.loader = []  # skip recomputation; table is loaded from disk below
        for imgs, labels, ages in tqdm(iter(self.loader)):
            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            ages = ages.to(conf.device)
            embeddings = self.model(imgs)
            if conf.use_dp:
                kernel_norm = l2_norm(self.head.module.kernel, axis=0)
                cos_theta = torch.mm(embeddings, kernel_norm)
                cos_theta = cos_theta.clamp(-1, 1)
            else:
                cos_theta = self.head.get_angle(embeddings)
            thetas = torch.abs(torch.rad2deg(torch.acos(cos_theta)))
            for i in range(len(thetas)):
                # bins: 0:<13, 1:13-18, 2:19-25, 3..6: decades to 65, 7: 66+
                age_bin = 7
                if ages[i] < 26:
                    age_bin = 0 if ages[i] < 13 else 1 if ages[i] < 19 else 2
                elif ages[i] < 66:
                    age_bin = int(((ages[i] + 4) // 10).item())
                angle_table[labels[i]][age_bin].add(
                    thetas[i][labels[i]].item())
        if conf.resume_analysis:
            with open('analysis/angle_table.pkl', 'rb') as f:
                angle_table = pickle.load(f)
        else:
            with open('analysis/angle_table.pkl', 'wb') as f:
                pickle.dump(angle_table, f)
        count, avg_angle = [], []
        for i in range(self.class_num):
            count.append(
                [len(single_set) for single_set in angle_table[i].values()])
            avg_angle.append([
                sum(list(single_set)) / len(single_set)
                if len(single_set) else 0  # if set() size is zero, avg is zero
                for single_set in angle_table[i].values()
            ])
        count_df = pd.DataFrame(count)
        avg_angle_df = pd.DataFrame(avg_angle)
        with pd.ExcelWriter('analysis/analyze_angle_{}_{}.xlsx'.format(
                conf.data_mode, name)) as writer:
            count_df.to_excel(writer, sheet_name='count')
            avg_angle_df.to_excel(writer, sheet_name='avg_angle')

    def schedule_lr(self):
        """Divide the primary optimizer's learning rate by 10."""
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def schedule_lr2(self):
        """Divide the secondary (age-invariant) optimizer's LR by 10."""
        for params in self.optimizer2.param_groups:
            params['lr'] /= 10
        print(self.optimizer2)

    def infer(self, conf, faces, target_embs, tta=False):
        '''
        faces : list of PIL Image
        target_embs : [n, 512] computed embeddings of faces in facebank
        names : recorded names of faces in facebank
        tta : test time augmentation (hfilp, that's all)
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = transforms.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)
        # squared L2 distance between every source and every target embedding
        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum

    def save_best_state(self, conf, accuracy, to_save_folder=False, extra=None,
                        model_only=False):
        """Checkpoint model/head/optimizer tagged as the best LFW result.

        BUG FIX: the original concatenated str(save_path) directly with the
        filename (no separator), writing misnamed files next to the intended
        directory; os.path.join restores the intended layout.
        """
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        os.makedirs('work_space/models', exist_ok=True)
        torch.save(
            self.model.state_dict(),
            os.path.join(
                str(save_path),
                'lfw_best_model_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(),
                os.path.join(
                    str(save_path),
                    'lfw_best_head_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                        get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(),
                os.path.join(
                    str(save_path),
                    'lfw_best_optimizer_{}_accuracy:{:.3f}_step:{}_{}.pth'.
                    format(get_time(), accuracy, self.step, extra)))

    def save_state(self, conf, accuracy, to_save_folder=False, extra=None,
                   model_only=False):
        """Checkpoint model (and optionally head/optimizer/growup) to disk."""
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        os.makedirs('work_space/models', exist_ok=True)
        torch.save(
            self.model.state_dict(),
            os.path.join(
                str(save_path), 'model_{}_accuracy:{:.3f}_step:{}_{}.pth'.
                format(get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(),
                os.path.join(
                    str(save_path), 'head_{}_accuracy:{:.3f}_step:{}_{}.pth'.
                    format(get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(),
                os.path.join(
                    str(save_path),
                    'optimizer_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                        get_time(), accuracy, self.step, extra)))
            if conf.discriminator:
                torch.save(
                    self.growup.state_dict(),
                    os.path.join(
                        str(save_path),
                        'growup_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                            get_time(), accuracy, self.step, extra)))

    def load_state(self, conf, fixed_str, from_save_folder=False,
                   model_only=False, analyze=False):
        """Restore model (and optionally head/optimizer) from a checkpoint.

        Consistency fix: the original mixed os.path.join and pathlib division
        for the same directory; all three now use os.path.join.
        """
        if from_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        self.model.load_state_dict(
            torch.load(os.path.join(save_path, 'model_{}'.format(fixed_str))))
        if not model_only:
            self.head.load_state_dict(
                torch.load(os.path.join(save_path,
                                        'head_{}'.format(fixed_str))))
            if not analyze:
                self.optimizer.load_state_dict(
                    torch.load(
                        os.path.join(save_path,
                                     'optimizer_{}'.format(fixed_str))))
def train(args):
    """Train the AAD face-swap GAN (generator G + attribute encoder E vs.
    multiscale discriminator D), with a frozen ArcFace backbone F providing
    identity embeddings.

    args is expected to carry: gpus, arc_model_path, resume,
    resume_model_path, start_epoch, total_epoch, data_path, batch_size,
    log_interval, save_interval.
    """
    # gpu init
    multi_gpu = len(args.gpus.split(',')) > 1
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    D = MultiscaleDiscriminator(
        input_nc=3, ndf=64, n_layers=3, use_sigmoid=False,
        norm_layer=torch.nn.InstanceNorm2d)  # pix2pix use MSEloss
    G = AAD_Gen()
    F = Backbone(50, drop_ratio=0.6, mode='ir_se')
    F.load_state_dict(torch.load(args.arc_model_path))
    E = Att_Encoder()

    optimizer_D = torch.optim.Adam(D.parameters(), lr=0.0004,
                                   betas=(0.0, 0.999))
    optimizer_GE = torch.optim.Adam(
        [{'params': G.parameters()}, {'params': E.parameters()}],
        lr=0.0004, betas=(0.0, 0.999))

    if multi_gpu:
        D = DataParallel(D).to(device)
        G = DataParallel(G).to(device)
        F = DataParallel(F).to(device)
        E = DataParallel(E).to(device)
    else:
        D = D.to(device)
        G = G.to(device)
        F = F.to(device)
        E = E.to(device)

    if args.resume:
        if os.path.isfile(args.resume_model_path):
            print("Loading checkpoint from {}".format(args.resume_model_path))
            # BUG FIX: the original called torch.load(args.resume) — the
            # resume *flag* — instead of the checkpoint path it just verified.
            checkpoint = torch.load(args.resume_model_path)
            args.start_epoch = checkpoint["epoch"]
            D.load_state_dict(checkpoint["state_dict_D"])
            G.load_state_dict(checkpoint["state_dict_G"])
            E.load_state_dict(checkpoint["state_dict_E"])
            # optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            optimizer_D.load_state_dict(checkpoint['optimizer_D'])
            optimizer_GE.load_state_dict(checkpoint['optimizer_GE'])
        else:
            print('Cannot found checkpoint {}'.format(args.resume_model_path))
    else:
        args.start_epoch = 1

    def print_with_time(string):
        # Prefix every log line with a wall-clock timestamp.
        print(time.strftime("%Y-%m-%d %H:%M:%S ", time.localtime()) + string)

    def weights_init(m):
        # DCGAN-style init: N(0, 0.02) convs, N(1, 0.02) batchnorm scales.
        classname = m.__class__.__name__
        if isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        if classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    def set_requires_grad(nets, requires_grad=False):
        # Freeze/unfreeze every parameter of one or more networks.
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def trans_batch(batch):
        # Resize a batch to 112x112 via PIL (alternative to `downsample`).
        t = trans.Compose(
            [trans.ToPILImage(), trans.Resize((112, 112)), trans.ToTensor()])
        bs = batch.shape[0]
        res = torch.ones(bs, 3, 112, 112).type_as(batch)
        for i in range(bs):
            res[i] = t(batch[i].cpu())
        return res

    set_requires_grad(F, requires_grad=False)  # identity net stays frozen

    data_transform = trans.Compose([
        trans.ToTensor(),
        trans.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    #dataset = ImageFolder(args.data_path, transform=data_transform)
    dataset = FaceEmbed(args.data_path)
    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                             drop_last=True)

    D.apply(weights_init)
    G.apply(weights_init)
    E.apply(weights_init)

    for epoch in range(args.start_epoch, args.total_epoch + 1):
        D.train()
        G.train()
        F.eval()  # Only extract features!  # input dim=3,256,256 out dim=256
        E.train()
        for batch_idx, data in enumerate(data_loader):
            time_curr = time.time()
            try:
                source, target, label = data
                source = source.to(device)
                target = target.to(device)
                label = torch.LongTensor(label).to(device)
                #Zid = F(trans_batch(source))  # bs, 512
                # crop face region before feeding the 112x112 identity net
                Zid = F(
                    downsample(source[:, :, 50:-10, 30:-30], size=(112, 112)))
                Zatt = E(target)  # list:8
                Yst0 = G(Zid, Zatt)  # bs,3,256,256

                # train discriminators
                pred_gen = D(Yst0.detach())
                #pred_gen = list(map(lambda x: x[0].detach(), pred_gen))
                pred_real = D(target)
                optimizer_D.zero_grad()
                loss_real, loss_fake = loss_hinge_dis()(pred_gen, pred_real)
                L_dis = loss_real + loss_fake
                # if batch_idx%3==0:
                L_dis.backward()
                optimizer_D.step()

                # train generators
                pred_gen = D(Yst0)
                L_gen = loss_hinge_gen()(pred_gen)
                #L_id = IdLoss()(F(trans_batch(Yst0)), Zid)
                L_id = IdLoss()(F(
                    downsample(Yst0[:, :, 50:-10, 30:-30], size=(112, 112))),
                    Zid)
                #Zatt = list(map(lambda x: x.detach(), Zatt))
                L_att = AttrLoss()(E(Yst0), Zatt)
                L_Rec = RecLoss()(Yst0, target, label)
                Loss = (L_gen + 10 * L_att + 5 * L_id + 10 * L_Rec).to(device)
                optimizer_GE.zero_grad()
                Loss.backward()
                optimizer_GE.step()
            except Exception as e:
                # NOTE(review): broad best-effort skip of a failing batch is
                # deliberate here; it hides data errors, so at least log them.
                print(e)
                continue

            if batch_idx % args.log_interval == 0 or batch_idx == 20:
                print_with_time(
                    'Train Epoch: {} [{}/{} ({:.0f}%)], L_dis:{:.4f}, '
                    'loss_real:{:.4f}, loss_fake:{:.4f}, Loss:{:.4f}, '
                    'L_gen:{:.4f}, L_id:{:.4f}, L_att:{:.4f}, L_Rec:{:.4f}'
                    .format(
                        epoch, batch_idx * len(data),
                        len(data_loader.dataset),
                        100. * batch_idx * len(data) /
                        len(data_loader.dataset), L_dis.item(),
                        loss_real.item(), loss_fake.item(), Loss.item(),
                        L_gen.item(), 5 * L_id.item(), 10 * L_att.item(),
                        # BUG FIX: L_Rec was formatted without .item(),
                        # unlike every sibling loss in this message.
                        10 * L_Rec.item()))

        # presumably intended once per epoch — TODO confirm original placement
        if epoch % args.save_interval == 0:  #or batch_idx*len(data) % 350004==0:
            state = {
                "epoch": epoch,
                "state_dict_D": D.state_dict(),
                "state_dict_G": G.state_dict(),
                "state_dict_E": E.state_dict(),
                "optimizer_D": optimizer_D.state_dict(),
                "optimizer_GE": optimizer_GE.state_dict(),
                # "optimizer_E": optimizer_E.state_dict(),
            }
            filename = "../model/train1_{:03d}_{:03d}.pth.tar".format(
                epoch, batch_idx * len(data))
            torch.save(state, filename)
class face_learner(object):
    """Face-recognition trainer mixing one labeled dataset with several
    pseudo-labeled splits.

    A shared backbone feeds two Arcface heads: head[0] classifies the labeled
    (emore) classes, head[1] the pseudo-label classes. ``conf.device`` is
    interpreted as the number of GPUs; >1 wraps model and heads in
    ``nn.DataParallel``.
    """

    def __init__(self, conf, inference=False):
        print(conf)
        # Number of pseudo-label splits, parsed from the meta filename
        # (assumes the digit right before '_labels.txt' is the split count --
        # TODO confirm the naming convention).
        self.num_splits = int(conf.meta_file.split('_labels.txt')[0][-1])
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size)
            print('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio,
                                  conf.net_mode)
            print('{}_{} model generated'.format(conf.net_mode,
                                                 conf.net_depth))
        if conf.device > 1:
            gpu_ids = list(
                range(0, min(torch.cuda.device_count(), conf.device)))
            self.model = nn.DataParallel(self.model,
                                         device_ids=gpu_ids).cuda()
        else:
            self.model = self.model.cuda()

        if not inference:
            self.milestones = conf.milestones
            if conf.remove_single is True:
                conf.meta_file = conf.meta_file.replace('.txt', '_clean.txt')
            # Context manager replaces the manual open/close of the original.
            with open(conf.meta_file, 'r') as meta_file:
                meta = meta_file.readlines()
            pseudo_all = [int(item.split('\n')[0]) for item in meta]
            unique_labels = set(pseudo_all)
            # -1 marks removed/unassigned samples and is not a real class.
            if -1 in unique_labels:
                pseudo_classnum = len(unique_labels) - 1
            else:
                pseudo_classnum = len(unique_labels)
            print('classnum:{}'.format(pseudo_classnum))
            # NOTE(review): 'count' is not defined anywhere in this class; it
            # must be a module-level list of split offsets into pseudo_all --
            # verify before running.
            pseudo_classes = [
                pseudo_all[count[index]:count[index + 1]]
                for index in range(self.num_splits)
            ]

            # Dataset list: [labeled emore dataset] + one pseudo dataset per
            # split; each entry is a (dataset, class_num) pair.
            train_dataset = [get_train_dataset(conf.emore_folder)] + [
                get_pseudo_dataset([conf.pseudo_folder, index + 1],
                                   pseudo_classes[index], conf.remove_single)
                for index in range(self.num_splits)
            ]
            self.class_num = [num for _, num in train_dataset]
            print('Loading dataset done')

            # Split the global batch size; the labeled dataset takes the
            # remainder so the per-loader sizes sum to conf.batch_size.
            temp = conf.batch_size // (self.num_splits + 1)
            self.batch_size = [conf.batch_size - temp * self.num_splits
                               ] + [temp] * self.num_splits
            # Pad every sampler to the batch count of the longest loader so
            # all loaders yield the same number of batches per epoch.
            train_longest_size = max(
                len(item[0]) // bs
                for item, bs in zip(train_dataset, self.batch_size))
            train_sampler = [
                GivenSizeSampler(td[0],
                                 total_size=train_longest_size * bs,
                                 rand_seed=None)
                for td, bs in zip(train_dataset, self.batch_size)
            ]
            self.train_loader = [
                DataLoader(train_dataset[k][0],
                           batch_size=self.batch_size[k],
                           shuffle=False,  # ordering comes from the sampler
                           pin_memory=conf.pin_memory,
                           num_workers=conf.num_workers,
                           sampler=train_sampler[k])
                for k in range(1 + self.num_splits)
            ]
            print('Loading loader done')

            self.writer = SummaryWriter(conf.log_path)
            self.step = 0
            self.head = [
                Arcface(embedding_size=conf.embedding_size,
                        classnum=self.class_num[0]),
                Arcface(embedding_size=conf.embedding_size,
                        classnum=pseudo_classnum)
            ]
            if conf.device > 1:
                self.head = [
                    nn.DataParallel(self.head[0], device_ids=gpu_ids).cuda(),
                    nn.DataParallel(self.head[1], device_ids=gpu_ids).cuda()
                ]
            else:
                self.head = [self.head[0].cuda(), self.head[1].cuda()]
            print('two model heads generated')

            # BUGFIX: unwrap DataParallel only when it was applied; the
            # original always accessed .module and crashed on single-GPU runs.
            core_model = self.model.module if conf.device > 1 else self.model
            paras_only_bn, paras_wo_bn = separate_bn_paras(core_model)
            head_params = []
            for h in self.head:
                core_head = h.module if conf.device > 1 else h
                head_params += list(core_head.parameters())
            if conf.use_mobilfacenet:
                # BUGFIX: the original passed [self.head.parameters()] here,
                # but self.head is a list (AttributeError); use the flattened
                # head parameter list instead.
                self.optimizer = optim.SGD(
                    [{
                        'params': paras_wo_bn[:-1],
                        'weight_decay': 4e-5
                    }, {
                        'params': [paras_wo_bn[-1]] + head_params,
                        'weight_decay': 4e-4
                    }, {
                        'params': paras_only_bn
                    }],
                    lr=conf.lr,
                    momentum=conf.momentum)
            else:
                self.optimizer = optim.SGD(
                    [{
                        'params': paras_wo_bn + head_params,
                        'weight_decay': 5e-4
                    }, {
                        'params': paras_only_bn
                    }],
                    lr=conf.lr,
                    momentum=conf.momentum)
            print(self.optimizer)

            if conf.resume is not None:
                self.start_epoch = self.load_state(conf.resume)
            else:
                self.start_epoch = 0
            print('optimizers generated')
            self.board_loss_every = len(self.train_loader[0]) // 10
            self.evaluate_every = len(self.train_loader[0]) // 5
            self.save_every = len(self.train_loader[0]) // 5
            (self.agedb_30, self.cfp_fp, self.lfw, self.agedb_30_issame,
             self.cfp_fp_issame,
             self.lfw_issame) = get_val_data(conf.eval_path)
        else:
            self.threshold = conf.threshold

    def save_state(self,
                   conf,
                   accuracy,
                   e,
                   to_save_folder=False,
                   extra=None,
                   model_only=False):
        """Save a checkpoint.

        ``model_only`` saves just the backbone weights; otherwise a dict with
        optimizer, both heads, model, and the epoch index ``e`` is written.
        """
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        if not os.path.exists(str(save_path)):
            os.makedirs(str(save_path))
        if model_only:
            torch.save(
                self.model.state_dict(),
                os.path.join(str(save_path),
                             ('model_{}_accuracy:{}_step:{}_{}.pth'.format(
                                 get_time(), accuracy, self.step, extra))))
        else:
            save = {
                'optimizer': self.optimizer.state_dict(),
                'head':
                [self.head[0].state_dict(), self.head[1].state_dict()],
                'model': self.model.state_dict(),
                'epoch': e
            }
            torch.save(
                save,
                os.path.join(str(save_path),
                             ('accuracy:{}_step:{}_{}.pth'.format(
                                 get_time(), accuracy, self.step, extra))))

    def load_state(self, save_path, from_save_folder=False, model_only=False):
        """Load a checkpoint written by :meth:`save_state`.

        Returns the next epoch to run when a full checkpoint is loaded.
        ``from_save_folder`` is kept for interface compatibility but unused.
        """
        if model_only:
            self.model.load_state_dict(torch.load(save_path))
        else:
            state = torch.load(save_path)
            self.model.load_state_dict(state['model'])
            self.head[0].load_state_dict(state['head'][0])
            self.head[1].load_state_dict(state['head'][1])
            self.optimizer.load_state_dict(state['optimizer'])
            return state['epoch'] + 1

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        """Log one validation set's accuracy/threshold/ROC to tensorboard."""
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name),
                               best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name),
                              roc_curve_tensor, self.step)
        # self.writer.add_scalar('{}_val:true accept ratio'.format(db_name), val, self.step)
        # self.writer.add_scalar('{}_val_std'.format(db_name), val_std, self.step)
        # self.writer.add_scalar('{}_far:False Acceptance Ratio'.format(db_name), far, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        """Embed ``carray`` in batches and run pair verification.

        tta averages the embedding of the image and its horizontal flip
        (followed by L2 normalization). Returns (mean accuracy, mean best
        threshold, ROC curve image tensor).
        """
        self.model.eval()
        idx = 0
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.cuda()) + self.model(
                        fliped.cuda())
                    embeddings[idx:idx + conf.batch_size] = l2_norm(emb_batch)
                else:
                    embeddings[idx:idx + conf.batch_size] = self.model(
                        batch.cuda()).cpu()
                idx += conf.batch_size
            # Remainder batch (len(carray) not divisible by batch_size).
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.cuda()) + self.model(
                        fliped.cuda())
                    embeddings[idx:] = l2_norm(emb_batch)
                else:
                    embeddings[idx:] = self.model(batch.cuda()).cpu()
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame,
                                                       nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def find_lr(self,
                conf,
                init_value=1e-8,
                final_value=10.,
                beta=0.98,
                bloding_scale=3.,
                num=None):
        """LR range test (Smith): sweep lr geometrically and record smoothed
        loss; stop when the loss explodes.

        BUGFIX: the original referenced ``self.loader`` and called
        ``self.head`` directly, neither of which exists in this multi-head
        class; the sweep now runs on the labeled loader with head[0].
        """
        loader = self.train_loader[0]
        if not num:
            num = len(loader)
        mult = (final_value / init_value)**(1 / num)
        lr = init_value
        for params in self.optimizer.param_groups:
            params['lr'] = lr
        self.model.train()
        avg_loss = 0.
        best_loss = 0.
        batch_num = 0
        losses = []
        log_lrs = []
        for i, (imgs, labels) in tqdm(enumerate(loader), total=num):
            imgs = imgs.cuda()
            labels = labels.cuda()
            batch_num += 1
            self.optimizer.zero_grad()
            embeddings = self.model(imgs)
            thetas = self.head[0](embeddings, labels)
            loss = conf.ce_loss(thetas, labels)
            # Compute the smoothed (bias-corrected EMA) loss.
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            self.writer.add_scalar('avg_loss', avg_loss, batch_num)
            smoothed_loss = avg_loss / (1 - beta**batch_num)
            self.writer.add_scalar('smoothed_loss', smoothed_loss, batch_num)
            # Stop if the loss is exploding.
            if batch_num > 1 and smoothed_loss > bloding_scale * best_loss:
                print('exited with best_loss at {}'.format(best_loss))
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
            # Record the best loss.
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            # Store the values.
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            self.writer.add_scalar('log_lr', math.log10(lr), batch_num)
            # SGD step, then raise the lr for the next step.
            loss.backward()
            self.optimizer.step()
            lr *= mult
            for params in self.optimizer.param_groups:
                params['lr'] = lr
            if batch_num > num:
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses

    def train(self, conf, epochs):
        """Joint training: every step draws one batch from each loader,
        embeds them together, and mixes the two head losses weighted by the
        labeled/unlabeled sample ratio."""
        self.model.train()
        running_loss = 0.
        # BUGFIX: define accuracy up front so save_state cannot hit a
        # NameError before the first evaluation has run.
        accuracy = 0.0
        for e in range(self.start_epoch, epochs):
            print('epoch {} started'.format(e))
            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()
            self.iters = [iter(loader) for loader in self.train_loader]
            for i in tqdm(range(len(self.train_loader[0]))):
                # BUGFIX: Python-3 next(); the original used the Python-2-only
                # .next() method.
                data = [next(it) for it in self.iters]
                imgs, labels = zip(
                    *[data[k] for k in range(self.num_splits + 1)])
                # First chunk is the labeled batch; everything after it is
                # pseudo-labeled.
                labeled_num = len(imgs[0])
                imgs = torch.cat(imgs, dim=0)
                labels = torch.cat(labels, dim=0)
                imgs = imgs.cuda()
                labels = labels.cuda()
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head[0](embeddings[:labeled_num],
                                      labels[:labeled_num])
                losses1 = conf.ce_loss(thetas, labels[:labeled_num])
                thetas = self.head[1](embeddings[labeled_num:],
                                      labels[labeled_num:])
                losses2 = conf.ce_loss(thetas, labels[labeled_num:])
                num_ratio = labeled_num / len(embeddings)
                loss = num_ratio * losses1 + (1 - num_ratio) * losses2
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()
                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board,
                                           self.step)
                    print('step:{}, train_loss:{}'.format(
                        self.step, loss_board))
                    running_loss = 0.
                if self.step % self.evaluate_every == 0 and self.step != 0:
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.agedb_30, self.agedb_30_issame)
                    accuracy1 = accuracy
                    self.board_val('agedb_30', accuracy, best_threshold,
                                   roc_curve_tensor)
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    accuracy2 = accuracy
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor)
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.cfp_fp, self.cfp_fp_issame)
                    accuracy3 = accuracy
                    self.board_val('cfp_fp', accuracy, best_threshold,
                                   roc_curve_tensor)
                    print('step:{}, agedb:{},lfw:{},cfp_fp:{}'.format(
                        self.step, accuracy1, accuracy2, accuracy3))
                    # evaluate() switched the model to eval mode; switch back.
                    self.model.train()
                if self.step % self.save_every == 0 and self.step != 0:
                    self.save_state(conf, accuracy, e)
                self.step += 1
        self.save_state(conf, accuracy, e, to_save_folder=True, extra='final')

    def schedule_lr(self):
        """Divide every param group's learning rate by 10."""
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def infer(self, conf, faces, target_embs, tta=False):
        '''
        faces : list of PIL Image
        target_embs : [n, 512] computed embeddings of faces in facebank
        names : recorded names of faces in facebank
        tta : test time augmentation (hfilp, that's all)
        Returns (min_idx, minimum): nearest facebank index per face
        (-1 when the distance exceeds self.threshold) and the distance.
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(conf.test_transform(img).cuda().unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).cuda().unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(conf.test_transform(img).cuda().unsqueeze(0)))
        source_embs = torch.cat(embs)
        # Pairwise squared L2 distances: [n_faces, n_targets].
        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum
def prepare(args):
    """Assemble everything a training run needs.

    Builds the backbone and its five branches, train/test loaders, optimizer,
    scheduler, and loss functions, optionally restoring them from the last
    checkpoint, and returns them all in a single dict.
    """
    want_resume = args.resume_from_checkpoint
    t_start = time.time()
    logger.info('global', 'Start preparing.')
    check_config_dir()
    logger.info('setting', config_info(), time_report=False)

    model = Backbone().cuda()
    logger.info('setting', model_summary(model), time_report=False)
    logger.info('setting', str(model), time_report=False)

    # One main branch plus four parsing branches, all with the same shape.
    branches = [main_branch(Config.nr_class, Config.in_planes)]
    branches += [
        parsing_branch(Config.nr_class, Config.in_planes) for _ in range(4)
    ]

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(Config.input_shape),
        transforms.RandomApply([
            transforms.ColorJitter(
                brightness=0.3, contrast=0.3, saturation=0.3, hue=0)
        ],
                               p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.Pad(10),
        transforms.RandomCrop(Config.input_shape),
        transforms.ToTensor(),
        transforms.RandomErasing(),
        normalize,
    ])
    test_transforms = transforms.Compose([
        transforms.Resize(Config.input_shape),
        transforms.ToTensor(),
        normalize,
    ])

    trainset = Veri776_train(transforms=train_transforms,
                             need_mask=True,
                             bg_switch=Config.p_bgswitch)
    testset = Veri776_test(transforms=test_transforms)
    pksampler = PKSampler(trainset, p=Config.P, k=Config.K)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=Config.batch_size,
                                               sampler=pksampler,
                                               num_workers=Config.nr_worker,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        testset,
        batch_size=Config.batch_size,
        sampler=torch.utils.data.SequentialSampler(testset),
        num_workers=Config.nr_worker,
        pin_memory=True)

    weight_decay_setting = parm_list_with_Wdecay_multi([model] + branches)
    optimizer = torch.optim.Adam(weight_decay_setting, lr=Config.lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                  lr_lambda=lr_multi_func)

    # Per-branch losses: index 0 is the main branch, 1..4 the parsing branches.
    losses = {
        'cross_entropy_loss':
        [torch.nn.CrossEntropyLoss()] +
        [weight_cross_entropy(Config.ce_thres[i]) for i in range(4)],
        'triplet_hard_loss':
        [triplet_hard_loss(margin=Config.triplet_margin)] + [
            weighted_triplet_hard_loss(margin=Config.branch_margin,
                                       soft_margin=Config.soft_marigin)
            for _ in range(4)
        ],
    }
    # Move every loss (lists or single modules) and every branch to the GPU.
    for key, value in losses.items():
        if isinstance(value, list):
            losses[key] = [fn.cuda() for fn in value]
        else:
            losses[key] = value.cuda()
    branches = [branch.cuda() for branch in branches]

    start_epoch = 0
    if want_resume and os.path.exists(Config.checkpoint_path):
        checkpoint = load_checkpoint()
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        # Continue with the epoch after the checkpointed one.
        start_epoch = checkpoint['epoch'] + 1

    ret = {
        'start_epoch': start_epoch,
        'model': model,
        'branches': branches,
        'train_loader': train_loader,
        'test_loader': test_loader,
        'optimizer': optimizer,
        'scheduler': scheduler,
        'losses': losses
    }
    minutes, seconds = sec2min_sec(t_start, time.time())
    logger.info(
        'global',
        'Finish preparing, time spend: {}mins {}s.'.format(minutes, seconds))
    return ret