class face_learner(object):
    def __init__(self, conf, inference=False):
        self.backbone = Backbone().to(conf.device)
        self.idprehead = PreheadID().to(conf.device)
        self.idhead = Arcface().to(conf.device)
        self.attrhead = Attrhead().to(conf.device)
        print('model generated')
        if not inference:
            self.milestones = conf.milestones
            train_dataset = CelebA(
                'dataset', 'celebA_train.txt',
                trans.Compose([
                    trans.RandomHorizontalFlip(),
                    trans.ToTensor(),
                    trans.Normalize(mean=[0.5, 0.5, 0.5],
                                    std=[0.5, 0.5, 0.5])
                ]))
            valid_dataset = CelebA(
                'dataset', 'celebA_validation.txt',
                trans.Compose([
                    trans.RandomHorizontalFlip(),
                    trans.ToTensor(),
                    trans.Normalize(mean=[0.5, 0.5, 0.5],
                                    std=[0.5, 0.5, 0.5])
                ]))
            self.loader = DataLoader(train_dataset,
                                     batch_size=conf.batch_size,
                                     shuffle=True,
                                     pin_memory=conf.pin_memory,
                                     num_workers=conf.num_workers)
            self.valid_loader = DataLoader(valid_dataset,
                                           batch_size=conf.batch_size,
                                           shuffle=True,
                                           pin_memory=conf.pin_memory,
                                           num_workers=conf.num_workers)
            self.writer = SummaryWriter(conf.log_path)
            self.step = 0
            # collect BN and non-BN parameters of every trainable module so
            # that weight decay is applied only to the non-BN weights
            paras_only_bn_1, paras_wo_bn_1 = separate_bn_paras(self.backbone)
            paras_only_bn_2, paras_wo_bn_2 = separate_bn_paras(self.idprehead)
            paras_only_bn_3, paras_wo_bn_3 = separate_bn_paras(self.attrhead)
            paras_only_bn = paras_only_bn_1 + paras_only_bn_2 + paras_only_bn_3
            paras_wo_bn = paras_wo_bn_1 + paras_wo_bn_2 + paras_wo_bn_3
            self.optimizer = optim.SGD(
                [{'params': paras_wo_bn + [self.idhead.kernel],
                  'weight_decay': 1e-4},
                 {'params': paras_only_bn}],
                lr=conf.lr,
                momentum=conf.momentum)
            # self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)
            print('optimizers generated')
            self.board_loss_every = len(self.loader) // 8
            self.evaluate_every = len(self.loader) // 4
            self.save_every = len(self.loader) // 2
            self.agedb_30, self.cfp_fp, self.lfw, self.agedb_30_issame, \
                self.cfp_fp_issame, self.lfw_issame = get_val_data(
                    Path("data/faces_emore/"))
        else:
            self.threshold = conf.threshold

    def save_state(self, conf, accuracy, to_save_folder=False, extra=None,
                   model_only=False):
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        torch.save(
            self.backbone.state_dict(), save_path /
            ('backbone_{}_accuracy:{}_step:{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra)))
        torch.save(
            self.idprehead.state_dict(), save_path /
            ('idprehead_{}_accuracy:{}_step:{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.idhead.state_dict(), save_path /
                ('idhead_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.attrhead.state_dict(), save_path /
                ('attrhead_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(), save_path /
                ('optimizer_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))

    # def load_state(self, conf, fixed_str, from_save_folder=False, model_only=False):
    #     if from_save_folder:
    #         save_path = conf.save_path
    #     else:
    #         save_path = conf.model_path
    #     self.model.load_state_dict(torch.load(save_path/'model_{}'.format(fixed_str)))
    #     if not model_only:
    #         self.head.load_state_dict(torch.load(save_path/'head_{}'.format(fixed_str)))
    #         self.optimizer.load_state_dict(torch.load(save_path/'optimizer_{}'.format(fixed_str)))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name),
                               best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name),
                              roc_curve_tensor, self.step)
        # self.writer.add_scalar('{}_val:true accept ratio'.format(db_name), val, self.step)
        # self.writer.add_scalar('{}_val_std'.format(db_name), val_std, self.step)
        # self.writer.add_scalar('{}_far:False Acceptance Ratio'.format(db_name), far, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        self.backbone.eval()
        self.idprehead.eval()
        idx = 0
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.idprehead(
                        self.backbone(batch.to(conf.device))) + self.idprehead(
                            self.backbone(fliped.to(conf.device)))
                    # move to cpu before assigning into the numpy buffer
                    embeddings[idx:idx + conf.batch_size] = l2_norm(
                        emb_batch).cpu()
                else:
                    embeddings[idx:idx + conf.batch_size] = self.idprehead(
                        self.backbone(batch.to(conf.device))).cpu()
                idx += conf.batch_size
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.idprehead(
                        self.backbone(batch.to(conf.device))) + self.idprehead(
                            self.backbone(fliped.to(conf.device)))
                    embeddings[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:] = self.idprehead(
                        self.backbone(batch.to(conf.device))).cpu()
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame,
                                                       nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def train(self, conf, epochs):
        self.backbone.train()
        self.idprehead.train()
        self.attrhead.train()
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()
            for imgs, labels in tqdm(iter(self.loader)):
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                # the first 40 label columns are CelebA attributes in {-1, 1};
                # rescale them to {0, 1}; column 40 is the identity label
                attributes = labels[:, :40]
                attributes = (attributes + 1) * 0.5
                ids = labels[:, 40]
                self.optimizer.zero_grad()
                embeddings = self.backbone(imgs)
                thetas = self.idhead(self.idprehead(embeddings), ids)
                # attrs = self.attrhead(embeddings)
                # attributes = attributes.type_as(attrs)
                loss = conf.ce_loss(thetas, ids)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()
                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.
                if self.step % self.evaluate_every == 0 and self.step != 0:
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.agedb_30, self.agedb_30_issame)
                    self.board_val('agedb_30', accuracy, best_threshold,
                                   roc_curve_tensor)
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor)
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.cfp_fp, self.cfp_fp_issame)
                    self.board_val('cfp_fp', accuracy, best_threshold,
                                   roc_curve_tensor)
                    # attr_loss, attr_accu = self.validate_attr(conf)
                    # print(attr_loss, attr_accu)
                    self.backbone.train()
                    self.idprehead.train()
                    self.attrhead.train()
                if self.step % self.save_every == 0 and self.step != 0:
                    self.save_state(conf, accuracy)
                self.step += 1
        self.save_state(conf, accuracy, to_save_folder=True, extra='final')

    def schedule_lr(self):
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def validate_attr(self, conf):
        self.backbone.eval()
        self.attrhead.eval()
        losses = []
        accuracies = []
        with torch.no_grad():
            for i, (input, target) in enumerate(self.valid_loader):
                input = input.to(conf.device)
                target = target.to(conf.device)
                target = target[:, :40]
                target = (target + 1) * 0.5
                target = target.type(torch.cuda.FloatTensor)
                # compute output
                embedding = self.backbone(input)
                output = self.attrhead(embedding)
                # measure accuracy and record loss
                loss = conf.bc_loss(output, target)
                pred = torch.where(torch.sigmoid(output) > 0.5, 1.0, 0.0)
                accuracies.append((pred == target).float().sum() /
                                  (target.size()[0] * target.size()[1]))
                losses.append(loss)
        loss_avg = sum(losses) / len(losses)
        accu_avg = sum(accuracies) / len(accuracies)
        return loss_avg, accu_avg
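# The learners in this file all lean on an l2_norm helper imported from the
# repo's utility module. The sketch below shows the behaviour the call sites
# above assume (row-wise L2 normalisation of a [batch, embedding_size]
# tensor); it is an illustration, not necessarily the repo's exact code, and
# is suffixed `_sketch` so it cannot shadow the real import.
import torch


def l2_norm_sketch(x, axis=1):
    # Divide each embedding by its L2 norm so that squared-distance
    # comparisons between embeddings become a monotone function of cosine
    # similarity.
    norm = torch.norm(x, 2, axis, keepdim=True)
    return torch.div(x, norm)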
class face_learner(object):
    def __init__(self, conf, inference=False):
        print(conf)
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio,
                                  conf.net_mode).to(conf.device)
            print('{}_{} model generated'.format(conf.net_mode,
                                                 conf.net_depth))
        if not inference:
            self.milestones = conf.milestones
            self.loader, self.class_num = get_train_loader(conf)
            print('class_num:', self.class_num)
            self.writer = SummaryWriter(conf.log_path)
            self.step = 0
            self.head = Arcface(embedding_size=conf.embedding_size,
                                classnum=self.class_num).to(conf.device)
            print('two model heads generated')
            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)
            if conf.use_mobilfacenet:
                self.optimizer = optim.SGD(
                    [{'params': paras_wo_bn[:-1], 'weight_decay': 4e-5},
                     {'params': [paras_wo_bn[-1]] + [self.head.kernel],
                      'weight_decay': 4e-4},
                     {'params': paras_only_bn}],
                    lr=conf.lr,
                    momentum=conf.momentum)
            else:
                self.optimizer = optim.SGD(
                    [{'params': paras_wo_bn + [self.head.kernel],
                      'weight_decay': 5e-4},
                     {'params': paras_only_bn}],
                    lr=conf.lr,
                    momentum=conf.momentum)
            print(self.optimizer)
            # self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)
            print('optimizers generated')
            # if conf.data_mode == 'small_vgg':
            #     self.board_loss_every = len(self.loader)
            #     print('len(loader', len(self.loader))
            #     self.evaluate_every = len(self.loader)
            #     self.save_every = len(self.loader)
            #     # self.lfw, self.lfw_issame = get_val_data(conf, conf.smallvgg_folder)
            # else:
            #     self.board_loss_every = len(self.loader)
            #     self.evaluate_every = len(self.loader)//10
            #     self.save_every = len(self.loader)//5
            self.agedb_30, self.cfp_fp, self.lfw, self.kface, \
                self.agedb_30_issame, self.cfp_fp_issame, self.lfw_issame, \
                self.kface_issame = get_val_data(
                    conf, self.loader.dataset.root.parent)
        else:
            self.threshold = conf.threshold

    def save_state(self, conf, accuracy, e, loss, to_save_folder=False,
                   extra=None, model_only=False):
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        model_name = f'model_e:{e+1}_acc:{accuracy}_loss:{loss}_{extra}.pth'
        torch.save(self.model.state_dict(), save_path / model_name)
        if not model_only:
            # save everything in a single checkpoint file
            state = {
                'model': self.model.state_dict(),
                'head': self.head.state_dict(),
                'optimizer': self.optimizer.state_dict()
            }
            torch.save(state, save_path / model_name)
            # print('model saved: ', model_name)
            ## save each part separately
            # torch.save(
            #     self.head.state_dict(), save_path /
            #     ('head_{}_accuracy:{}_epoch:{}_step:{}_{}.pth'.format(get_time(), accuracy, e, self.step, extra)))
            # torch.save(
            #     self.optimizer.state_dict(), save_path /
            #     ('optimizer_{}_accuracy:{}_epoch:{}_step:{}_{}.pth'.format(get_time(), accuracy, e, self.step, extra)))

    def load_state(self, conf, fixed_str, from_save_folder=False,
                   model_only=False):
        if from_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        if model_only:
            self.model.load_state_dict(
                torch.load(save_path / 'model_{}'.format(fixed_str)))
        if not model_only:
            model_name = 'model_{}'.format(fixed_str)
            state = torch.load(save_path / model_name)
            self.model.load_state_dict(state['model'])
            self.head.load_state_dict(state['head'])
            self.optimizer.load_state_dict(state['optimizer'])
            # self.head.load_state_dict(torch.load(save_path/'head_{}'.format(fixed_str)))
            # self.optimizer.load_state_dict(torch.load(save_path/'optimizer_{}'.format(fixed_str)))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name),
                               best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name),
                              roc_curve_tensor, self.step)
        # self.writer.add_scalar('{}_val:true accept ratio'.format(db_name), val, self.step)
        # self.writer.add_scalar('{}_val_std'.format(db_name), val_std, self.step)
        # self.writer.add_scalar('{}_far:False Acceptance Ratio'.format(db_name), far, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        self.model.eval()
        idx = 0
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(
                        fliped.to(conf.device))
                    embeddings[idx:idx + conf.batch_size] = l2_norm(
                        emb_batch).cpu()
                else:
                    embeddings[idx:idx + conf.batch_size] = self.model(
                        batch.to(conf.device)).cpu()
                idx += conf.batch_size
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(
                        fliped.to(conf.device))
                    embeddings[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:] = self.model(batch.to(conf.device)).cpu()
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame,
                                                       nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def find_lr(self, conf, init_value=1e-8, final_value=10., beta=0.98,
                bloding_scale=3., num=None):
        if not num:
            num = len(self.loader)
        mult = (final_value / init_value)**(1 / num)
        lr = init_value
        for params in self.optimizer.param_groups:
            params['lr'] = lr
        self.model.train()
        avg_loss = 0.
        best_loss = 0.
        batch_num = 0
        losses = []
        log_lrs = []
        for i, (imgs, labels) in tqdm(enumerate(self.loader), total=num):
            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            batch_num += 1
            self.optimizer.zero_grad()
            embeddings = self.model(imgs)
            thetas = self.head(embeddings, labels)
            loss = conf.ce_loss(thetas, labels)
            # compute the smoothed loss
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            self.writer.add_scalar('avg_loss', avg_loss, batch_num)
            smoothed_loss = avg_loss / (1 - beta**batch_num)
            self.writer.add_scalar('smoothed_loss', smoothed_loss, batch_num)
            # stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > bloding_scale * best_loss:
                print('exited with best_loss at {}'.format(best_loss))
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
            # record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            # store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            self.writer.add_scalar('log_lr', math.log10(lr), batch_num)
            # do the SGD step, then update the lr for the next step
            loss.backward()
            self.optimizer.step()
            lr *= mult
            for params in self.optimizer.param_groups:
                params['lr'] = lr
            if batch_num > num:
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses

    def train(self, conf, epochs):
        self.model.train()
        running_loss = 0.
        time_ = datetime.datetime.now()
        # check parameter counts of the model
        print("------------------------------------------------------------")
        total_params = sum(p.numel() for p in self.model.parameters())
        print("num of parameters :", total_params)
        trainable_params = sum(p.numel() for p in self.model.parameters()
                               if p.requires_grad)
        print("num of trainable parameters :", trainable_params)
        print("------------------------------------------------------------")
        # if conf.data_mode == 'small_vgg':
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()
            accuracy_list = []
            for iter_, (imgs, labels) in tqdm(enumerate(iter(self.loader))):
                # print('iter_', type(iter_))
                # print('step', self.step)
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()
                if iter_ % conf.print_iter == 0:
                    elapsed = datetime.datetime.now() - time_
                    expected = elapsed * (conf.batch_size / conf.print_iter)
                    _epoch = round(e + ((iter_ + 1) / conf.batch_size), 2)
                    # print('_epoch', _epoch)
                    _loss = round(loss.item(), 5)
                    print(
                        f'[{_epoch}/{conf.epochs}] loss:{_loss}, elapsed:{elapsed}, expected per epoch: {expected}'
                    )
                    time_ = datetime.datetime.now()
                self.step += 1
            # logging
            board_loss_every = len(self.loader) / conf.print_iter
            print('board_loss_every', board_loss_every)
            loss_board = running_loss / board_loss_every
            self.writer.add_scalar('train_loss', loss_board, self.step)
            running_loss = 0.
            # validate
            accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                conf, self.agedb_30, self.agedb_30_issame)
            self.board_val('agedb_30', accuracy, best_threshold,
                           roc_curve_tensor)
            accuracy_list.append(accuracy)
            accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                conf, self.lfw, self.lfw_issame)
            self.board_val('lfw', accuracy, best_threshold, roc_curve_tensor)
            accuracy_list.append(accuracy)
            accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                conf, self.cfp_fp, self.cfp_fp_issame)
            self.board_val('cfp_fp', accuracy, best_threshold,
                           roc_curve_tensor)
            accuracy_list.append(accuracy)
            accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                conf, self.kface, self.kface_issame)
            self.board_val('kface', accuracy, best_threshold,
                           roc_curve_tensor)
            accuracy_list.append(accuracy)
            # save model, info
            # print(accuracy_list)
            accuracy_mean = round(sum(accuracy_list) / conf.testset_num, 5)
            loss_mean = round(loss_board, 5)
            self.save_state(conf, accuracy_mean, e, loss_mean,
                            to_save_folder=False, extra='training')
            # measure the epoch time before resetting the timer
            elapsed = datetime.datetime.now() - time_
            print(
                f'[epoch {e + 1}] acc: {accuracy_mean}, loss: {loss_mean}, elapsed: {elapsed}'
            )
            time_ = datetime.datetime.now()
            # print('train_loss:', loss_board)
        self.save_state(conf, accuracy_mean, e, loss_mean,
                        to_save_folder=False, extra='final')

    def schedule_lr(self):
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print('***self.optimizer:', self.optimizer)

    def infer(self, conf, faces, target_embs, tta=False):
        '''
        faces : list of PIL Image
        target_embs : [n, 512] computed embeddings of faces in facebank
        names : recorded names of faces in facebank
        tta : test time augmentation (hflip, that's all)
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)
        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum
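# Shape sketch of the matching performed in infer() above, with a
# hypothetical threshold standing in for conf.threshold. source_embs holds m
# query embeddings and target_embs the n facebank embeddings; broadcasting
# turns their pairwise difference into an [m, n] squared-distance matrix.
import torch


def _infer_shape_sketch():
    m, n, d = 3, 5, 512
    source_embs = torch.randn(m, d)
    target_embs = torch.randn(n, d)
    # [m, d, 1] - [1, d, n] broadcasts to [m, d, n]
    diff = source_embs.unsqueeze(-1) - target_embs.transpose(1, 0).unsqueeze(0)
    dist = torch.sum(torch.pow(diff, 2), dim=1)  # [m, n] squared L2 distances
    minimum, min_idx = torch.min(dist, dim=1)    # best facebank match per query
    threshold = 1.5                              # hypothetical conf.threshold
    min_idx[minimum > threshold] = -1            # -1 marks "no match"
    return min_idx, minimum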
class face_learner(object):
    def __init__(self, conf, inference=False, transfer=0):
        pprint.pprint(conf)
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio,
                                  conf.net_mode).to(conf.device)
            print('{}_{} model generated'.format(conf.net_mode,
                                                 conf.net_depth))
        if not inference:
            self.milestones = conf.milestones
            self.loader, self.class_num = get_train_loader(conf)
            self.writer = SummaryWriter(conf.log_path)
            self.step = 0
            self.head = Arcface(embedding_size=conf.embedding_size,
                                classnum=self.class_num).to(conf.device)
            print('two model heads generated')
            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)
            if conf.use_mobilfacenet:
                # transfer levels: the larger the value, the more parameter
                # groups are handed to the optimizer (see the note after this
                # class)
                if transfer == 3:
                    self.optimizer = optim.SGD(
                        [{'params': [paras_wo_bn[-1]] + [self.head.kernel],
                          'weight_decay': 4e-4},
                         {'params': paras_only_bn}],
                        lr=conf.lr,
                        momentum=conf.momentum)
                elif transfer == 2:
                    self.optimizer = optim.SGD(
                        [{'params': [paras_wo_bn[-1]] + [self.head.kernel],
                          'weight_decay': 4e-4}],
                        lr=conf.lr,
                        momentum=conf.momentum)
                elif transfer == 1:
                    self.optimizer = optim.SGD(
                        [{'params': [self.head.kernel],
                          'weight_decay': 4e-4}],
                        lr=conf.lr,
                        momentum=conf.momentum)
                else:
                    self.optimizer = optim.SGD(
                        [{'params': paras_wo_bn[:-1], 'weight_decay': 4e-5},
                         {'params': [paras_wo_bn[-1]] + [self.head.kernel],
                          'weight_decay': 4e-4},
                         {'params': paras_only_bn}],
                        lr=conf.lr,
                        momentum=conf.momentum)
            else:
                self.optimizer = optim.SGD(
                    [{'params': paras_wo_bn + [self.head.kernel],
                      'weight_decay': 5e-4},
                     {'params': paras_only_bn}],
                    lr=conf.lr,
                    momentum=conf.momentum)
            print(self.optimizer)
            # self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)
            print('optimizers generated')
            self.board_loss_every = len(self.loader) // 5  # originally, 100
            self.evaluate_every = len(self.loader) // 5  # originally, 10
            self.save_every = len(self.loader) // 2  # originally, 5
            # self.agedb_30, self.cfp_fp, self.lfw, self.agedb_30_issame, self.cfp_fp_issame, self.lfw_issame = get_val_data(self.loader.dataset.root.parent)
            self.val_112, self.val_112_issame = get_val_pair(
                self.loader.dataset.root.parent, 'val_112')
        else:
            self.threshold = conf.threshold

    def save_state(self, conf, accuracy, to_save_folder=False, extra=None,
                   model_only=False):
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        torch.save(
            self.model.state_dict(), save_path /
            ('model_{}_accuracy:{:0.2f}_step:{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(), save_path /
                ('head_{}_accuracy:{:0.2f}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(), save_path /
                ('optimizer_{}_accuracy:{:0.2f}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))

    def load_state(self, conf, fixed_str, from_save_folder=False,
                   model_only=False):
        if from_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        self.model.load_state_dict(
            torch.load(save_path / 'model_{}'.format(fixed_str),
                       map_location=conf.device))
        if not model_only:
            self.head.load_state_dict(
                torch.load(save_path / 'head_{}'.format(fixed_str)))
            self.optimizer.load_state_dict(
                torch.load(save_path / 'optimizer_{}'.format(fixed_str)))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name),
                               best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name),
                              roc_curve_tensor, self.step)
        # self.writer.add_scalar('{}_val:true accept ratio'.format(db_name), val, self.step)
        # self.writer.add_scalar('{}_val_std'.format(db_name), val_std, self.step)
        # self.writer.add_scalar('{}_far:False Acceptance Ratio'.format(db_name), far, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        self.model.eval()
        idx = 0
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(
                        fliped.to(conf.device))
                    embeddings[idx:idx + conf.batch_size] = l2_norm(
                        emb_batch).cpu()
                else:
                    embeddings[idx:idx + conf.batch_size] = self.model(
                        batch.to(conf.device)).cpu()
                idx += conf.batch_size
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(
                        fliped.to(conf.device))
                    embeddings[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:] = self.model(batch.to(conf.device)).cpu()
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame,
                                                       nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def find_lr(self, conf, init_value=1e-8, final_value=10., beta=0.98,
                bloding_scale=3., num=None):
        if not num:
            num = len(self.loader)
        mult = (final_value / init_value)**(1 / num)
        lr = init_value
        for params in self.optimizer.param_groups:
            params['lr'] = lr
        self.model.train()
        avg_loss = 0.
        best_loss = 0.
        batch_num = 0
        losses = []
        log_lrs = []
        for i, (imgs, labels) in tqdm(enumerate(self.loader), total=num):
            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            batch_num += 1
            self.optimizer.zero_grad()
            embeddings = self.model(imgs)
            thetas = self.head(embeddings, labels)
            loss = conf.ce_loss(thetas, labels)
            # compute the smoothed loss
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            self.writer.add_scalar('avg_loss', avg_loss, batch_num)
            smoothed_loss = avg_loss / (1 - beta**batch_num)
            self.writer.add_scalar('smoothed_loss', smoothed_loss, batch_num)
            # stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > bloding_scale * best_loss:
                print('exited with best_loss at {}'.format(best_loss))
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
            # record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            # store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            self.writer.add_scalar('log_lr', math.log10(lr), batch_num)
            # do the SGD step, then update the lr for the next step
            loss.backward()
            self.optimizer.step()
            lr *= mult
            for params in self.optimizer.param_groups:
                params['lr'] = lr
            if batch_num > num:
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses

    def train(self, conf, epochs, ext='final'):
        self.model.train()
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()
            for imgs, labels in tqdm(iter(self.loader)):
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()
                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.
                if self.step % self.evaluate_every == 0 and self.step != 0:
                    # accuracy, best_threshold, roc_curve_tensor = self.evaluate(conf, self.agedb_30, self.agedb_30_issame)
                    # self.board_val('agedb_30', accuracy, best_threshold, roc_curve_tensor)
                    # accuracy, best_threshold, roc_curve_tensor = self.evaluate(conf, self.lfw, self.lfw_issame)
                    # self.board_val('lfw', accuracy, best_threshold, roc_curve_tensor)
                    # accuracy, best_threshold, roc_curve_tensor = self.evaluate(conf, self.cfp_fp, self.cfp_fp_issame)
                    # self.board_val('cfp_fp', accuracy, best_threshold, roc_curve_tensor)
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.val_112, self.val_112_issame)
                    self.board_val('n+n_val_112', accuracy, best_threshold,
                                   roc_curve_tensor)
                    self.model.train()
                # if self.step % self.save_every == 0 and self.step != 0:
                #     self.save_state(conf, accuracy, extra=ext)
                self.step += 1
        # self.save_state(conf, accuracy, to_save_folder=True, extra=ext, model_only=True)

    def schedule_lr(self):
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def infer(self, conf, faces, target_embs, tta=False):
        '''
        faces : list of PIL Image
        target_embs : [n, 512] computed embeddings of faces in facebank
        names : recorded names of faces in facebank
        tta : test time augmentation (hflip, that's all)
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)
        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum

    def binfer(self, conf, faces, target_embs, tta=False):
        '''
        return raw scores for every class
        faces : list of PIL Image
        target_embs : [n, 512] computed embeddings of faces in facebank
        names : recorded names of faces in facebank
        tta : test time augmentation (hflip, that's all)
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)
        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        # print(dist)
        return dist.detach().cpu().numpy()
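# Note on the `transfer` argument of the learner above: only the parameter
# groups handed to optim.SGD are updated by optimizer.step(), so transfer=1
# trains the ArcFace kernel alone, transfer=2 adds the last non-BN backbone
# parameter, and transfer=3 additionally trains the BatchNorm parameters.
# Gradients are still computed for the untouched tensors unless requires_grad
# is disabled; a sketch of that optional extra step (hypothetical helper,
# assuming the head is an nn.Module):
def freeze_all_but_head(model, head):
    for p in model.parameters():
        p.requires_grad = False  # skip gradient computation for the backbone
    for p in head.parameters():
        p.requires_grad = True   # keep the head trainable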
class face_learner(object):
    def __init__(self, conf, inference=False):
        accuracy = 0.0
        logger.debug(conf)
        if conf.use_mobilfacenet:
            # self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            self.model = MobileFaceNet(conf.embedding_size).cuda()
            logger.debug('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio,
                                  conf.net_mode).cuda()  # .to(conf.device)
            logger.debug('{}_{} model generated'.format(
                conf.net_mode, conf.net_depth))
        if not inference:
            self.milestones = conf.milestones
            logger.info('loading data...')
            self.loader, self.class_num = get_train_loader(conf)
            self.writer = SummaryWriter(conf.log_path)
            self.step = 0
            self.head = Arcface(embedding_size=conf.embedding_size,
                                classnum=self.class_num).cuda()
            logger.debug('two model heads generated')
            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)
            if conf.use_mobilfacenet:
                self.optimizer = optim.SGD(
                    [{'params': paras_wo_bn[:-1], 'weight_decay': 4e-5},
                     {'params': [paras_wo_bn[-1]] + [self.head.kernel],
                      'weight_decay': 4e-4},
                     {'params': paras_only_bn}],
                    lr=conf.lr,
                    momentum=conf.momentum)
            else:
                self.optimizer = optim.SGD(
                    [{'params': paras_wo_bn + [self.head.kernel],
                      'weight_decay': 5e-4},
                     {'params': paras_only_bn}],
                    lr=conf.lr,
                    momentum=conf.momentum)
            # self.optimizer = torch.nn.parallel.DistributedDataParallel(optimizer, device_ids=[conf.argsed])
            # self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)
            if conf.fp16:
                self.model, self.optimizer = amp.initialize(self.model,
                                                            self.optimizer,
                                                            opt_level="O2")
                self.model = DistributedDataParallel(self.model).cuda()
            else:
                # wrap the model for distributed training
                self.model = torch.nn.parallel.DistributedDataParallel(
                    self.model, device_ids=[conf.argsed]).cuda()
            logger.debug('dataset {}'.format(self.loader.dataset))
            self.board_loss_every = len(self.loader) // 100
            self.evaluate_every = len(self.loader) // 4
            self.save_every = len(self.loader) // 4
            self.agedb_30, self.cfp_fp, self.lfw, self.agedb_30_issame, \
                self.cfp_fp_issame, self.lfw_issame = get_val_data(
                    Path(self.loader.dataset.root).parent)
        else:
            self.threshold = conf.threshold
            self.loader, self.query_ds, self.gallery_ds = get_test_loader(conf)

    def save_state(self, conf, epoch, accuracy, to_save_folder=False,
                   extra=None, model_only=False):
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        if not os.path.exists(save_path):
            os.makedirs(save_path, exist_ok=True)
        torch.save(
            self.model.state_dict(), save_path /
            ('model_{}_{}_acc:{:.4f}_{}.pth'.format(epoch, self.step,
                                                    accuracy, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(), save_path /
                ('head_{}_{}_acc:{:.4f}_{}.pth'.format(epoch, self.step,
                                                       accuracy, extra)))
            torch.save(
                self.optimizer.state_dict(), save_path /
                ('optimizer_{}_{}_acc:{:.4f}_{}.pth'.format(
                    epoch, self.step, accuracy, extra)))
            # torch.save(
            #     amp.state_dict(), save_path /
            #     ('amp_{}_{}_acc:{:.4f}_{}.pth'.format(epoch, self.step, accuracy, extra)))

    def load_network(self, conf, save_path):
        state_dict = torch.load(save_path,
                                map_location='cuda:{}'.format(conf.local_rank))
        # create a new OrderedDict whose keys do not contain the `module.`
        # prefix added by DistributedDataParallel
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            # logger.debug('key {}'.format(k))
            namekey = k[7:]  # remove 'module.'
            # logger.debug('key {}'.format(namekey))
            new_state_dict[namekey] = v
        return new_state_dict

    def load_state(self, conf, fixed_str, from_save_folder=False,
                   model_only=False):
        if from_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        if conf.resume:
            self.model.load_state_dict(
                torch.load(save_path / 'model_{}'.format(fixed_str),
                           map_location='cuda:{}'.format(conf.local_rank)))
        else:
            self.model.load_state_dict(
                self.load_network(conf,
                                  save_path / 'model_{}'.format(fixed_str)))
        if not model_only:
            self.head.load_state_dict(
                torch.load(save_path / 'head_{}'.format(fixed_str)))
            self.optimizer.load_state_dict(
                torch.load(save_path / 'optimizer_{}'.format(fixed_str)))
            logger.info('load optimizer {}'.format(self.optimizer))
            # amp.load_state_dict(torch.load(save_path / 'amp_{}'.format(fixed_str)))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name),
                               best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name),
                              roc_curve_tensor, self.step)
        # self.writer.add_scalar('{}_val:true accept ratio'.format(db_name), val, self.step)
        # self.writer.add_scalar('{}_val_std'.format(db_name), val_std, self.step)
        # self.writer.add_scalar('{}_far:False Acceptance Ratio'.format(db_name), far, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        self.model.eval()
        idx = 0
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.cuda()) + self.model(
                        fliped.cuda())
                    # emb_batch = self.model(batch.to(conf.device)) + self.model(fliped.to(conf.device))
                    embeddings[idx:idx + conf.batch_size] = l2_norm(
                        emb_batch).cpu()
                else:
                    embeddings[idx:idx + conf.batch_size] = self.model(
                        batch.cuda()).cpu()
                    # embeddings[idx:idx + conf.batch_size] = self.model(batch.to(conf.device)).cpu()
                idx += conf.batch_size
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.cuda()) + self.model(
                        fliped.cuda())
                    embeddings[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:] = self.model(batch.cuda()).cpu()
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame,
                                                       nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    # true top 1, false top 1, miss
    def compute_true_false_miss(self, conf, log_dir, feat_path, tta):
        def gen_distmat(qf, q_pids, gf, g_pids):
            m, n = qf.shape[0], gf.shape[0]
            logger.debug('query shape {}, gallery shape {}'.format(
                qf.shape, gf.shape))
            # logger.debug('q_pids {}, g_pids {}'.format(q_pids, g_pids))
            # squared euclidean distance via ||q||^2 + ||g||^2 - 2 q.g
            distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
                torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
            distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
            distmat = distmat.cpu().numpy()
            return distmat

        def distance(emb1, emb2):
            diff = np.subtract(emb1, emb2)
            dist = np.sum(np.square(diff), 1)
            return dist

        self.model.eval()
        if conf.gen_feature:
            with torch.no_grad():
                query_feature, query_label = extract_feature(
                    conf, self.model, self.loader['query']['dl'], tta)
                gallery_feature, gallery_label = extract_feature(
                    conf, self.model, self.loader['gallery']['dl'], tta)
            result = {
                'query_feature': query_feature.numpy(),
                'query_label': query_label.numpy(),
                'gallery_feature': gallery_feature.numpy(),
                'gallery_label': gallery_label.numpy()
            }
            scipy.io.savemat(feat_path, result)
        else:
            result = scipy.io.loadmat(feat_path)
            query_feature = torch.from_numpy(result['query_feature'])
            query_label = torch.from_numpy(result['query_label'])[0]
            gallery_feature = torch.from_numpy(result['gallery_feature'])
            gallery_label = torch.from_numpy(result['gallery_label'])[0]
        # np.set_printoptions(threshold=np.inf)
        # logger.debug('query_label: {}'.format(query_label.numpy()))
        # logger.debug('gallery_label: {}'.format(gallery_label.numpy()))
        # feat = result['query_feature'][0:2]
        # logger.debug('feat {}'.format(feat.shape))
        # emb1 = np.repeat(feat, [3,1], axis=0)
        # emb2 = result['gallery_feature'][0:4]
        # dist = distance(emb1, emb2)
        # logger.debug('distance {}'.format(dist))
        distmat = gen_distmat(query_feature, query_label, gallery_feature,
                              gallery_label)
        # record txt
        with open(os.path.join(log_dir, 'result.txt'), 'at') as f:
            f.write('%s\t%s\t%s\t%s\n' % ('threshold', 'acc', 'err', 'miss'))
        # record excel
        xls_file = xlwt.Workbook()
        sheet_1 = xls_file.add_sheet('sheet_1', cell_overwrite_ok=True)
        row = 0
        path_excel = os.path.join(log_dir, 'result.xls')
        sheet_title = ['threshold', 'acc', 'err', 'miss']
        for i_sheet in range(len(sheet_title)):
            sheet_1.write(row, i_sheet, sheet_title[i_sheet])
        xls_file.save(path_excel)
        row += 1
        index = np.argsort(distmat)  # from small to large
        max_index = index[:, 0]
        # query_num = distmat.shape[0]
        # logger.debug('distmat {}'.format(distmat[0:2,0:4]))
        # for i in range(query_num):
        #     logger.debug('query: {}, gallery: {}'.format(self.query_ds.imgs[i], self.gallery_ds.imgs[max_index[i]]))
        #     logger.debug('index[i] {}'.format(index[i]))
        #     logger.debug('distmat[i, max_index[i]]: {}'.format(distmat[i, max_index[i]]))
        #     logger.debug('distmat[i] {}'.format(distmat[i]))
        query_list_file = 'data/probe.txt'
        gallery_list_file = 'data/gallery.txt'
        err_rank1 = os.path.join(log_dir, 'err_rank1.txt')
        data_path = DataPath(query_list_file, gallery_list_file)
        with open(err_rank1, 'at') as f:
            f.write('%s\t\t\t%s\n' % ('query', 'gallery'))
        thresholds = np.arange(0.4, 2, 0.01)
        for threshold in thresholds:
            acc, err, miss = compute_rank1(distmat, max_index, query_label,
                                           gallery_label, threshold,
                                           data_path, err_rank1)
            # record txt
            with open(os.path.join(log_dir, 'result.txt'), 'at') as f:
                f.write('%.6f\t%.6f\t%.6f\t%.6f\n' %
                        (threshold, acc, err, miss))
            # record excel
            list_data = [threshold, acc, err, miss]
            for i_1 in range(len(list_data)):
                sheet_1.write(row, i_1, list_data[i_1])
            xls_file.save(path_excel)
            row += 1

    def find_lr(self, conf, init_value=1e-8, final_value=10., beta=0.98,
                bloding_scale=3., num=None):
        if not num:
            num = len(self.loader)
        mult = (final_value / init_value)**(1 / num)
        lr = init_value
        for params in self.optimizer.param_groups:
            params['lr'] = lr
        self.model.train()
        avg_loss = 0.
        best_loss = 0.
        batch_num = 0
        losses = []
        log_lrs = []
        for i, (imgs, labels) in tqdm(enumerate(self.loader), total=num):
            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            batch_num += 1
            self.optimizer.zero_grad()
            embeddings = self.model(imgs)
            thetas = self.head(embeddings, labels)
            loss = conf.ce_loss(thetas, labels)
            # compute the smoothed loss
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            self.writer.add_scalar('avg_loss', avg_loss, batch_num)
            smoothed_loss = avg_loss / (1 - beta**batch_num)
            self.writer.add_scalar('smoothed_loss', smoothed_loss, batch_num)
            # stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > bloding_scale * best_loss:
                print('exited with best_loss at {}'.format(best_loss))
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
            # record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            # store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            self.writer.add_scalar('log_lr', math.log10(lr), batch_num)
            # do the SGD step, then update the lr for the next step
            loss.backward()
            self.optimizer.step()
            lr *= mult
            for params in self.optimizer.param_groups:
                params['lr'] = lr
            if batch_num > num:
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses

    def train(self, conf, epochs):
        self.model.train()
        # logger.debug('model {}'.format(self.model))
        running_loss = 0.
        # resume training from a checkpoint
        if conf.resume:
            logger.debug('resume...')
            self.load_state(conf, 'ir_se50.pth', from_save_folder=True)
            logger.debug('optimizer {}'.format(self.optimizer))
        for epoch in range(epochs):
            logger.debug('epoch {} started'.format(epoch))
            # if epoch == self.milestones[0]:
            #     self.schedule_lr()
            # if epoch == self.milestones[1]:
            #     self.schedule_lr()
            # if epoch == self.milestones[2]:
            if epoch in self.milestones:
                self.schedule_lr()
            # for i, (imgs, labels) in tqdm(enumerate(self.loader)):
            # for imgs, labels in enumerate(self.loader):
            for imgs, labels in tqdm(iter(self.loader)):
                imgs = imgs.cuda()  # to(conf.device)
                labels = labels.cuda()  # to(conf.device)
                self.optimizer.zero_grad()
                # print(imgs)
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                # loss.backward()
                if conf.fp16:
                    # with fp16 enabled, backpropagate through the amp-scaled
                    # loss instead of calling loss.backward() directly
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                running_loss += loss.item()
                # print(loss.item())
                self.optimizer.step()
                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.
                if self.step % self.evaluate_every == 0 and self.step != 0:
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.agedb_30, self.agedb_30_issame)
                    self.board_val('agedb_30', accuracy, best_threshold,
                                   roc_curve_tensor)
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor)
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.cfp_fp, self.cfp_fp_issame)
                    self.board_val('cfp_fp', accuracy, best_threshold,
                                   roc_curve_tensor)
                    # logger.debug('optimizer {}'.format(self.optimizer))
                    logger.debug(
                        'epoch {}, step {}, loss {:.4f}, acc {:.4f}'.format(
                            epoch, self.step, loss.item(), accuracy))
                    self.model.train()
                if conf.local_rank == 0 and epoch >= 10 and \
                        self.step % self.save_every == 0 and self.step != 0:
                    # if conf.local_rank == 0 and self.step % self.save_every == 0 and self.step != 0:
                    self.save_state(conf, epoch, accuracy)
                self.step += 1
        self.save_state(conf, epoch, accuracy, to_save_folder=True,
                        extra='final')

    def schedule_lr(self):
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        logger.debug('optimizer {}'.format(self.optimizer))

    def infer(self, conf, faces, target_embs, tta=False):
        '''
        faces : list of PIL Image
        target_embs : [n, 512] computed embeddings of faces in facebank
        names : recorded names of faces in facebank
        tta : test time augmentation (hflip, that's all)
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)
        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum
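# The distributed learner above wraps the model in DistributedDataParallel
# and indexes conf.argsed as a device id, so it assumes the process group was
# initialised before face_learner is constructed. A minimal launch-side
# sketch of that precondition (an assumption; the repo's actual entry point
# may differ):
import torch
import torch.distributed as dist


def setup_distributed(local_rank):
    # one process per GPU; rank and world size are read from the environment
    # variables set by the launcher (e.g. torch.distributed.launch)
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl')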
class face_learner(object):
    def __init__(self, conf, inference=False):
        print(conf)
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio,
                                  conf.net_mode).to(conf.device)
            print('{}_{} model generated'.format(conf.net_mode,
                                                 conf.net_depth))
        if not inference:
            self.milestones = conf.milestones
            self.loader, self.class_num = get_train_loader(conf)
            self.writer = SummaryWriter(conf.log_path)
            self.step = 0
            self.head = Arcface(embedding_size=conf.embedding_size,
                                classnum=self.class_num).to(conf.device)
            print('two model heads generated')
            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)
            if conf.use_mobilfacenet:
                self.optimizer = optim.SGD(
                    [{'params': paras_wo_bn[:-1], 'weight_decay': 4e-5},
                     {'params': [paras_wo_bn[-1]] + [self.head.kernel],
                      'weight_decay': 4e-4},
                     {'params': paras_only_bn}],
                    lr=conf.lr,
                    momentum=conf.momentum)
            else:
                self.optimizer = optim.SGD(
                    [{'params': paras_wo_bn + [self.head.kernel],
                      'weight_decay': 5e-4},
                     {'params': paras_only_bn}],
                    lr=conf.lr,
                    momentum=conf.momentum)
            print(self.optimizer)
            # self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)
            print('optimizers generated')
            self.board_loss_every = len(self.loader) // 100
            self.evaluate_every = len(self.loader) // 10
            self.save_every = len(self.loader) // 5
            self.lfw, self.lfw_issame = get_val_data(
                self.loader.dataset.root.parent)
        else:
            self.threshold = conf.threshold

    def save_state(self, conf, accuracy, to_save_folder=False, extra=None,
                   model_only=False):
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        torch.save(
            self.model.state_dict(), save_path /
            ('model_{}_accuracy:{}_step:{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(), save_path /
                ('head_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(), save_path /
                ('optimizer_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))

    def load_state(self, conf, fixed_str, from_save_folder=False,
                   model_only=False):
        if from_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        self.model.load_state_dict(
            torch.load(save_path / 'model_{}'.format(fixed_str)))
        if not model_only:
            self.head.load_state_dict(
                torch.load(save_path / 'head_{}'.format(fixed_str)))
            self.optimizer.load_state_dict(
                torch.load(save_path / 'optimizer_{}'.format(fixed_str)))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name),
                               best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name),
                              roc_curve_tensor, self.step)
        # self.writer.add_scalar('{}_val:true accept ratio'.format(db_name), val, self.step)
        # self.writer.add_scalar('{}_val_std'.format(db_name), val_std, self.step)
        # self.writer.add_scalar('{}_far:False Acceptance Ratio'.format(db_name), far, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        self.model.eval()
        idx = 0
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(
                        fliped.to(conf.device))
                    embeddings[idx:idx + conf.batch_size] = l2_norm(
                        emb_batch).cpu()
                else:
                    embeddings[idx:idx + conf.batch_size] = self.model(
                        batch.to(conf.device)).cpu()
                idx += conf.batch_size
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + self.model(
                        fliped.to(conf.device))
                    embeddings[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:] = self.model(batch.to(conf.device)).cpu()
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame,
                                                       nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def train(self, conf, epochs):
        self.model.train()
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()
            for imgs, labels in tqdm(iter(self.loader)):
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()
                if self.step % (self.board_loss_every + 1) == 0 and self.step != 0:
                    loss_board = running_loss / (self.board_loss_every + 1)
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.
                if self.step % (self.evaluate_every + 1) == 0 and self.step != 0:
                    accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor)
                    self.model.train()
                if self.step % (self.save_every + 1) == 0 and self.step != 0:
                    self.save_state(conf, accuracy)
                self.step += 1
        self.save_state(conf, accuracy, to_save_folder=True, extra='final')

    def schedule_lr(self):
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def infer(self, conf, faces, target_embs, tta=False):
        '''
        faces : list of PIL Image
        target_embs : [n, 512] computed embeddings of faces in facebank
        names : recorded names of faces in facebank
        tta : test time augmentation (hflip, that's all)
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)
        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum
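# Every learner in this file feeds (embeddings, labels) through an Arcface
# head to obtain class logits. A minimal sketch of the additive-angular-margin
# idea behind that head (an assumption; the repo's Arcface module includes
# extra numerical safeguards such as easy-margin handling):
import torch
import torch.nn as nn
import torch.nn.functional as F


class ArcfaceSketch(nn.Module):
    def __init__(self, embedding_size=512, classnum=1000, s=64.0, m=0.5):
        super().__init__()
        self.kernel = nn.Parameter(torch.randn(embedding_size, classnum))
        self.s = s  # feature scale
        self.m = m  # additive angular margin

    def forward(self, embeddings, labels):
        # cosine between normalised embeddings and normalised class centres
        cos_theta = F.normalize(embeddings, dim=1) @ F.normalize(self.kernel,
                                                                 dim=0)
        cos_theta = cos_theta.clamp(-1 + 1e-7, 1 - 1e-7)
        theta = torch.acos(cos_theta)
        # add the margin m only to the target-class angle, then rescale by s
        one_hot = F.one_hot(labels, cos_theta.size(1)).bool()
        return self.s * torch.where(one_hot, torch.cos(theta + self.m),
                                    cos_theta)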
class face_learner(object):
    def __init__(self, conf, inference=False, transfer=0, ext='final'):
        pprint.pprint(conf)
        self.conf = conf
        if conf.arch == "mobile":
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        elif conf.arch == "ir_se":
            self.model = Backbone(conf.net_depth, conf.drop_ratio, conf.arch).to(conf.device)
            print('{}_{} model generated'.format(conf.arch, conf.net_depth))
        elif conf.arch == "resnet50":
            self.model = ResNet(embedding_size=512, arch=conf.arch).to(conf.device)
            print("resnet model {} generated".format(conf.arch))
        else:
            exit("model not supported yet!")

        if not inference:
            self.milestones = conf.milestones
            self.loader, self.class_num = get_train_loader(conf)
            self.head = Arcface(embedding_size=conf.embedding_size,
                                classnum=self.class_num).to(conf.device)
            tmp_idx = ext.rfind('_')  # find the last '_' to replace it by '/'
            self.ext = '/' + ext[:tmp_idx] + '/' + ext[tmp_idx + 1:]
            self.writer = SummaryWriter(str(conf.log_path) + self.ext)
            self.step = 0
            print('two model heads generated')

            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)

            if transfer == 3:
                # kernel + non-BN weights + BN parameters
                self.optimizer = optim.Adam(
                    [{'params': paras_wo_bn + [self.head.kernel], 'weight_decay': 4e-4},
                     {'params': paras_only_bn}],
                    lr=conf.lr)  # , momentum=conf.momentum)
            elif transfer == 2:
                # kernel + non-BN weights
                self.optimizer = optim.Adam(
                    [{'params': paras_wo_bn + [self.head.kernel], 'weight_decay': 4e-4}],
                    lr=conf.lr)  # , momentum=conf.momentum)
            elif transfer == 1:
                # only the ArcFace kernel
                self.optimizer = optim.Adam(
                    [{'params': [self.head.kernel], 'weight_decay': 4e-4}],
                    lr=conf.lr)  # , momentum=conf.momentum)
            else:
                """
                self.optimizer = optim.SGD([
                    {'params': paras_wo_bn[:-1], 'weight_decay': 4e-5},
                    {'params': [paras_wo_bn[-1]] + [self.head.kernel], 'weight_decay': 4e-4},
                    {'params': paras_only_bn}
                ], lr=conf.lr, momentum=conf.momentum)
                """
                self.optimizer = optim.Adam(
                    list(self.model.parameters()) + list(self.head.parameters()),
                    lr=conf.lr)
            print(self.optimizer)
            # self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)
            print('optimizers generated')
            self.save_freq = len(self.loader)  # //5  # originally, 100
            self.evaluate_every = len(self.loader)  # //5  # originally, 10
            self.save_every = len(self.loader)  # //2  # originally, 5
            # self.agedb_30, self.cfp_fp, self.lfw, self.agedb_30_issame, self.cfp_fp_issame, self.lfw_issame = get_val_data(self.loader.dataset.root.parent)
            # self.val_112, self.val_112_issame = get_val_pair(self.loader.dataset.root.parent, 'val_112')
        else:
            self.threshold = conf.threshold
        self.train_losses = []
        self.train_counter = []
        self.test_losses = []
        self.test_accuracy = []
        self.test_counter = []

    def save_state(self, model_only=False):
        save_path = self.conf.stored_result_dir
        torch.save(self.model.state_dict(), save_path + os.sep + 'model.pth')
        if not model_only:
            torch.save(self.head.state_dict(), save_path + os.sep + 'head.pth')
            torch.save(self.optimizer.state_dict(), save_path + os.sep + 'optimizer.pth')

    def load_state(self, save_path, from_file=False, model_only=False):
        if from_file:
            if self.conf.arch == "mobile":
                self.model.load_state_dict(
                    torch.load(save_path / 'model_mobilefacenet.pth',
                               map_location=self.conf.device))
            elif self.conf.arch == "ir_se":
                self.model.load_state_dict(
                    torch.load(save_path / 'model_ir_se50.pth',
                               map_location=self.conf.device))
            else:
                exit("loading model not supported yet!")
        else:
            state_dict = torch.load(save_path, map_location=self.conf.device)
            if "module." in list(state_dict.keys())[0]:
                # checkpoint was saved from nn.DataParallel; strip the prefix
                new_dict = {}
                for key in state_dict:
                    new_key = key[7:]
                    assert new_key in self.model.state_dict().keys(), "wrong model loaded!"
                    new_dict[new_key] = state_dict[key]
                self.model.load_state_dict(new_dict)
            else:
                self.model.load_state_dict(state_dict)
        if not model_only:
            # head/optimizer loading expects save_path to be a directory
            self.head.load_state_dict(
                torch.load(save_path / 'head.pth', map_location=self.conf.device))
            self.optimizer.load_state_dict(torch.load(save_path / 'optimizer.pth'))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy, self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name), best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor, self.step)
        # self.writer.add_scalar('{}_val:true accept ratio'.format(db_name), val, self.step)
        # self.writer.add_scalar('{}_val_std'.format(db_name), val_std, self.step)
        # self.writer.add_scalar('{}_far:False Acceptance Ratio'.format(db_name), far, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        self.model.eval()
        idx = 0
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + \
                        self.model(flipped.to(conf.device))
                    # .cpu() needed before assigning into the numpy buffer
                    embeddings[idx:idx + conf.batch_size] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:idx + conf.batch_size] = self.model(
                        batch.to(conf.device)).cpu()
                idx += conf.batch_size
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)) + \
                        self.model(flipped.to(conf.device))
                    embeddings[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:] = self.model(batch.to(conf.device)).cpu()
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def find_lr(self,
                conf,
                init_value=1e-8,
                final_value=10.,
                beta=0.98,
                bloding_scale=3.,
                num=None):
        if not num:
            num = len(self.loader)
        mult = (final_value / init_value)**(1 / num)
        lr = init_value
        for params in self.optimizer.param_groups:
            params['lr'] = lr
        self.model.train()
        avg_loss = 0.
        best_loss = 0.
        batch_num = 0
        losses = []
        log_lrs = []
        for i, (imgs, labels) in enumerate(self.loader):  # tqdm(enumerate(self.loader), total=num):
            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            batch_num += 1
            self.optimizer.zero_grad()
            embeddings = self.model(imgs)
            thetas = self.head(embeddings, labels)
            loss = conf.ce_loss(thetas, labels)
            # Compute the smoothed loss
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            self.writer.add_scalar('avg_loss', avg_loss, batch_num)
            smoothed_loss = avg_loss / (1 - beta**batch_num)
            self.writer.add_scalar('smoothed_loss', smoothed_loss, batch_num)
            # Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > bloding_scale * best_loss:
                print('exited with best_loss at {}'.format(best_loss))
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
            # Record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            # Store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            self.writer.add_scalar('log_lr', math.log10(lr), batch_num)
            # Do the SGD step, then update the lr for the next step
            loss.backward()
            self.optimizer.step()
            lr *= mult
            for params in self.optimizer.param_groups:
                params['lr'] = lr
            if batch_num > num:
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses

    def train(self, conf, epochs):
        self.model.train()
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()
            for imgs, labels in iter(self.loader):  # tqdm(iter(self.loader)):
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()
                if self.step % self.save_freq == 0 and self.step != 0:
                    self.train_losses.append(loss.item())
                    self.train_counter.append(self.step)
                self.step += 1
            self.save_loss()
            # self.save_state(conf, accuracy, to_save_folder=True, extra=self.ext, model_only=True)

    def schedule_lr(self):
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def infer(self, conf, faces, target_embs, tta=False):
        '''
        faces : list of PIL Image
        target_embs : [n, 512] computed embeddings of faces in facebank
        names : recorded names of faces in facebank
        tta : test time augmentation (hflip, that's all)
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(self.model(conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)
        diff = source_embs.unsqueeze(-1) - target_embs.transpose(1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum

    def binfer(self, conf, faces, target_embs, tta=False):
        '''
        return raw scores for every class
        faces : list of PIL Image
        target_embs : [n, 512] computed embeddings of faces in facebank
        names : recorded names of faces in facebank
        tta : test time augmentation (hflip, that's all)
        '''
        self.model.eval()
        self.plot_result()
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(self.model(conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)
        diff = source_embs.unsqueeze(-1) - target_embs.transpose(1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        # print(dist)
        return dist.detach().cpu().numpy()
        # minimum, min_idx = torch.min(dist, dim=1)
        # min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        # return min_idx, minimum

    # NOTE: this second evaluate() shadows the verification-style evaluate()
    # defined above; only this identification-style version survives at runtime.
    def evaluate(self, data_dir, names_idx, target_embs, tta=False):
        '''
        return raw scores for every class
        faces : list of PIL Image
        target_embs : [n, 512] computed embeddings of faces in facebank
        names : recorded names of faces in facebank
        tta : test time augmentation (hflip, that's all)
        '''
        self.model.eval()
        score_names = []
        score = []
        wrong_names = dict()
        test_dir = data_dir
        for path in test_dir.iterdir():
            if path.is_file():
                continue
            # print(path)
            for fil in path.iterdir():
                # print(fil)
                orig_name = ''.join([i for i in fil.name.strip().split('.')[0]])
                for name in names_idx.keys():
                    if name in orig_name:
                        score_names.append(names_idx[name])
                img = Image.open(str(fil))
                with torch.no_grad():
                    if tta:
                        mirror = trans.functional.hflip(img)
                        emb = self.model(
                            self.conf.test_transform(img).to(self.conf.device).unsqueeze(0))
                        emb_mirror = self.model(
                            self.conf.test_transform(mirror).to(self.conf.device).unsqueeze(0))
                        emb = l2_norm(emb + emb_mirror)
                    else:
                        emb = self.model(
                            self.conf.test_transform(img).to(self.conf.device).unsqueeze(0))
                diff = emb.unsqueeze(-1) - target_embs.transpose(1, 0).unsqueeze(0)
                dist = torch.sum(torch.pow(diff, 2), dim=1).cpu().numpy()
                score.append(np.exp(dist.dot(-1)))
                pred = np.argmax(score[-1])
                label = score_names[-1]
                if pred != label:
                    wrong_names[orig_name] = pred
        return score, score_names, wrong_names

    def save_loss(self):
        if not os.path.exists(self.conf.stored_result_dir):
            os.mkdir(self.conf.stored_result_dir)
        result = dict()
        result["train_losses"] = np.asarray(self.train_losses)
        result["train_counter"] = np.asarray(self.train_counter)
        result['test_accuracy'] = np.asarray(self.test_accuracy)
        result['test_losses'] = np.asarray(self.test_losses)
        result["test_counter"] = np.asarray(self.test_counter)
        with open(os.path.join(self.conf.stored_result_dir, "result_log.p"), 'wb') as fp:
            pickle.dump(result, fp)

    def plot_result(self):
        result_log_path = os.path.join(self.conf.stored_result_dir, "result_log.p")
        with open(result_log_path, 'rb') as f:
            result_dict = pickle.load(f)
        train_losses = result_dict['train_losses']
        train_counter = result_dict['train_counter']
        test_losses = result_dict['test_losses']
        test_counter = result_dict['test_counter']
        test_accuracy = result_dict['test_accuracy']
        fig1 = plt.figure(figsize=(12, 8))
        ax1 = fig1.add_subplot(111)
        ax1.plot(train_counter, train_losses, 'b', label='Train_loss')
        ax1.legend()  # label is already set via plot(..., label=...)
        plt.savefig(os.path.join(self.conf.stored_result_dir, "train_loss.png"))
        plt.close()
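
# A minimal sketch of constructing the learner above in inference mode. Every
# config field below is an assumption inferred from the attributes __init__
# and infer() actually read; real runs build `conf` from the project's own
# config module instead of a SimpleNamespace.
def _example_build_learner_for_inference():
    from types import SimpleNamespace
    conf = SimpleNamespace(
        arch='mobile',            # one of: 'mobile', 'ir_se', 'resnet50'
        embedding_size=512,
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        threshold=1.5,            # hypothetical verification threshold
        test_transform=None,      # supply the deploy-time transform for infer()
    )
    learner = face_learner(conf, inference=True)
    # When training instead, the `transfer` argument selects what is optimized:
    #   transfer=1 -> only the ArcFace kernel
    #   transfer=2 -> kernel + non-BN backbone weights
    #   transfer=3 -> kernel + non-BN weights + BN parameters
    #   transfer=0 -> plain Adam over all model + head parameters
    return learner
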
"""
class face_learner(object):
    def __init__(self, conf):
        print(conf)
        self.model = ResNet()
        self.model.cuda()
        if conf.initial:
            self.model.load_state_dict(torch.load("models/" + conf.model))
            print('Load model_ir_se101.pth')
        self.milestones = conf.milestones
        self.loader, self.class_num = get_train_loader(conf)
        self.total_class = 16520
        self.data_num = 285356
        self.writer = SummaryWriter(conf.log_path)
        self.step = 0
        self.paras_only_bn, self.paras_wo_bn = separate_bn_paras(self.model)
        if conf.meta:
            self.head = Arcface(embedding_size=conf.embedding_size,
                                classnum=self.total_class)
            self.head.cuda()
            if conf.initial:
                self.head.load_state_dict(torch.load("models/head_op.pth"))
                print('Load head_op.pth')
            self.optimizer = RAdam([
                {'params': self.paras_wo_bn + [self.head.kernel], 'weight_decay': 5e-4},
                {'params': self.paras_only_bn}
            ], lr=conf.lr)
            self.meta_optimizer = RAdam([
                {'params': self.paras_wo_bn + [self.head.kernel], 'weight_decay': 5e-4},
                {'params': self.paras_only_bn}
            ], lr=conf.lr)
            self.head.train()
        else:
            self.head = dict()
            self.optimizer = dict()
            for race in races:
                self.head[race] = Arcface(embedding_size=conf.embedding_size,
                                          classnum=self.class_num[race])
                self.head[race].cuda()
                if conf.initial:
                    self.head[race].load_state_dict(
                        torch.load("models/head_op_{}.pth".format(race)))
                    print('Load head_op_{}.pth'.format(race))
                self.optimizer[race] = RAdam([
                    {'params': self.paras_wo_bn + [self.head[race].kernel], 'weight_decay': 5e-4},
                    {'params': self.paras_only_bn}
                ], lr=conf.lr, betas=(0.5, 0.999))
                self.head[race].train()
        # self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)
        self.board_loss_every = min(len(self.loader[race]) for race in races) // 10
        self.evaluate_every = self.data_num // 5
        self.save_every = self.data_num // 2
        self.eval, self.eval_issame = get_val_data(conf)

    def save_state(self, conf, accuracy, extra=None, model_only=False, race='All'):
        save_path = 'models/'
        torch.save(
            self.model.state_dict(),
            save_path + 'model_{}_accuracy-{}_step-{}_{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra, race))
        if not model_only:
            if conf.meta:
                torch.save(
                    self.head.state_dict(),
                    save_path + 'head_{}_accuracy-{}_step-{}_{}_{}.pth'.format(
                        get_time(), accuracy, self.step, extra, race))
                # torch.save(
                #     self.optimizer.state_dict(), save_path +
                #     'optimizer_{}_accuracy-{}_step-{}_{}_{}.pth'.format(
                #         get_time(), accuracy, self.step, extra, race))
            else:
                torch.save(
                    self.head[race].state_dict(),
                    save_path + 'head_{}_accuracy-{}_step-{}_{}_{}.pth'.format(
                        get_time(), accuracy, self.step, extra, race))
                # torch.save(
                #     self.optimizer[race].state_dict(), save_path +
                #     'optimizer_{}_accuracy-{}_step-{}_{}_{}.pth'.format(
                #         get_time(), accuracy, self.step, extra, race))

    def load_state(self, conf, fixed_str, model_only=False):
        save_path = 'models/'
        self.model.load_state_dict(torch.load(save_path + conf.model))
        if not model_only:
            self.head.load_state_dict(torch.load(save_path + conf.head))
            self.optimizer.load_state_dict(torch.load(save_path + conf.optim))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy, self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name), best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor, self.step)
        # self.writer.add_scalar('{}_val:true accept ratio'.format(db_name), val, self.step)
        # self.writer.add_scalar('{}_val_std'.format(db_name), val_std, self.step)
        # self.writer.add_scalar('{}_far:False Acceptance Ratio'.format(db_name), far, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        self.model.eval()
        idx = 0
        entry_num = carray.size()[0]
        embeddings = np.zeros([entry_num, conf.embedding_size])
        with torch.no_grad():
            while idx + conf.batch_size <= entry_num:
                batch = carray[idx:idx + conf.batch_size]
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = self.model(batch.cuda()) + self.model(flipped.cuda())
                    embeddings[idx:idx + conf.batch_size] = \
                        l2_norm(emb_batch).cpu().detach().numpy()
                else:
                    embeddings[idx:idx + conf.batch_size] = \
                        self.model(batch.cuda()).cpu().detach().numpy()
                idx += conf.batch_size
            if idx < entry_num:
                batch = carray[idx:]
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = self.model(batch.cuda()) + self.model(flipped.cuda())
                    embeddings[idx:] = l2_norm(emb_batch).cpu().detach().numpy()
                else:
                    embeddings[idx:] = self.model(batch.cuda()).cpu().detach().numpy()
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def train_finetuning(self, conf, epochs, race):
        self.model.train()
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))
            '''
            if e == self.milestones[0]:
                for ra in races:
                    for params in self.optimizer[ra].param_groups:
                        params['lr'] /= 10
            if e == self.milestones[1]:
                for ra in races:
                    for params in self.optimizer[ra].param_groups:
                        params['lr'] /= 10
            if e == self.milestones[2]:
                for ra in races:
                    for params in self.optimizer[ra].param_groups:
                        params['lr'] /= 10
            '''
            for imgs, labels in tqdm(iter(self.loader[race])):
                imgs = imgs.cuda()
                labels = labels.cuda()
                self.optimizer[race].zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head[race](embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                nn.utils.clip_grad_norm_(self.model.parameters(), conf.max_grad_norm)
                nn.utils.clip_grad_norm_(self.head[race].parameters(), conf.max_grad_norm)
                self.optimizer[race].step()
                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.
                if self.step % (1 * len(self.loader[race])) == 0 and self.step != 0:
                    self.save_state(conf, 'None', race=race, model_only=True)
                self.step += 1
        self.save_state(conf, 'None', extra='final', race=race)
        torch.save(self.optimizer[race].state_dict(),
                   'models/optimizer_{}.pth'.format(race))

    def train_maml(self, conf, epochs):
        self.model.train()
        running_loss = 0.
        loader_iter = dict()
        for race in races:
            loader_iter[race] = iter(self.loader[race])
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()
            for i in tqdm(range(self.data_num // conf.batch_size)):
                ra1, ra2 = random.sample(races, 2)
                try:
                    imgs1, labels1 = next(loader_iter[ra1])
                except StopIteration:
                    loader_iter[ra1] = iter(self.loader[ra1])
                    imgs1, labels1 = next(loader_iter[ra1])
                try:
                    imgs2, labels2 = next(loader_iter[ra2])
                except StopIteration:
                    loader_iter[ra2] = iter(self.loader[ra2])
                    imgs2, labels2 = next(loader_iter[ra2])

                # save original weights to make the update
                weights_original_model = deepcopy(self.model.state_dict())
                weights_original_head = deepcopy(self.head.state_dict())

                # base learn
                imgs1 = imgs1.cuda()
                labels1 = labels1.cuda()
                self.optimizer.zero_grad()
                embeddings1 = self.model(imgs1)
                thetas1 = self.head(embeddings1, labels1)
                loss1 = conf.ce_loss(thetas1, labels1)
                loss1.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), conf.max_grad_norm)
                nn.utils.clip_grad_norm_(self.head.parameters(), conf.max_grad_norm)
                self.optimizer.step()

                # meta learn: evaluate the adapted weights on the second task,
                # then apply that gradient to the pre-step weights
                imgs2 = imgs2.cuda()
                labels2 = labels2.cuda()
                embeddings2 = self.model(imgs2)
                thetas2 = self.head(embeddings2, labels2)
                self.model.load_state_dict(weights_original_model)
                self.head.load_state_dict(weights_original_head)
                self.meta_optimizer.zero_grad()
                loss2 = conf.ce_loss(thetas2, labels2)
                loss2.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), conf.max_grad_norm)
                nn.utils.clip_grad_norm_(self.head.parameters(), conf.max_grad_norm)
                self.meta_optimizer.step()
                running_loss += loss2.item()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.
                if self.step % self.evaluate_every == 0 and self.step != 0:
                    for race in races:
                        accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                            conf, self.eval[race], self.eval_issame[race])
                        self.board_val(race, accuracy, best_threshold, roc_curve_tensor)
                    self.model.train()
                if self.step % (self.data_num // conf.batch_size // 2) == 0 and self.step != 0:
                    self.save_state(conf, e)
                self.step += 1
        self.save_state(conf, epochs, extra='final')

    def train_meta_head(self, conf, epochs):
        self.model.train()
        running_loss = 0.
        optimizer = optim.SGD(self.head.parameters(), lr=conf.lr, momentum=conf.momentum)
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()
            for race in races:
                for imgs, labels in tqdm(iter(self.loader[race])):
                    imgs = imgs.cuda()
                    labels = labels.cuda()
                    optimizer.zero_grad()
                    embeddings = self.model(imgs)
                    thetas = self.head(embeddings, labels)
                    loss = conf.ce_loss(thetas, labels)
                    loss.backward()
                    running_loss += loss.item()
                    optimizer.step()
                    if self.step % self.board_loss_every == 0 and self.step != 0:
                        loss_board = running_loss / self.board_loss_every
                        self.writer.add_scalar('train_loss', loss_board, self.step)
                        running_loss = 0.
                    self.step += 1
            torch.save(self.head.state_dict(),
                       'models/head_{}_meta_{}.pth'.format(get_time(), e))

    def train_race_head(self, conf, epochs, race):
        self.model.train()
        running_loss = 0.
        optimizer = optim.SGD(self.head[race].parameters(), lr=conf.lr, momentum=conf.momentum)
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()
            for imgs, labels in tqdm(iter(self.loader[race])):
                imgs = imgs.cuda()
                labels = labels.cuda()
                optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head[race](embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                optimizer.step()
                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.
                self.step += 1
        torch.save(self.head[race].state_dict(),
                   'models/head_{}_{}_{}.pth'.format(get_time(), race, epochs))

    def schedule_lr(self):
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        for params in self.meta_optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer, self.meta_optimizer)
"""
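
# The train_maml routine in the commented-out block above follows a
# first-order MAML-style update: take an optimizer step on a batch from task
# A, evaluate the adapted weights on a batch from task B, then apply the
# task-B gradient to the pre-step weights. A minimal self-contained sketch of
# that update rule with a toy model and random data; the backward-before-
# restore ordering here is an assumption chosen to avoid the in-place version
# check that load_state_dict can trip when a graph is still pending.
def _example_first_order_maml_step():
    from copy import deepcopy
    model = nn.Linear(8, 4)
    opt = optim.SGD(model.parameters(), lr=0.1)
    meta_opt = optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.CrossEntropyLoss()
    x_a, y_a = torch.randn(16, 8), torch.randint(0, 4, (16,))
    x_b, y_b = torch.randn(16, 8), torch.randint(0, 4, (16,))

    snapshot = deepcopy(model.state_dict())  # weights before the inner step

    opt.zero_grad()
    loss_fn(model(x_a), y_a).backward()      # inner (base) step on task A
    opt.step()

    meta_opt.zero_grad()
    loss_fn(model(x_b), y_b).backward()      # gradient at the adapted weights
    grads = [p.grad.clone() for p in model.parameters()]
    model.load_state_dict(snapshot)          # roll back to the pre-step weights
    for p, g in zip(model.parameters(), grads):
        p.grad = g                           # apply task-B gradient to old weights
    meta_opt.step()
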
class face_learner(object):
    def __init__(self, conf, inference=False):
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio, conf.net_mode).to(conf.device)
            self.growup = GrowUP().to(conf.device)
            self.discriminator = Discriminator().to(conf.device)
            print('{}_{} model generated'.format(conf.net_mode, conf.net_depth))

        if not inference:
            self.milestones = conf.milestones
            self.loader, self.class_num = get_train_loader(conf)
            if conf.discriminator:
                self.child_loader, self.adult_loader = get_train_loader_d(conf)
            os.makedirs(conf.log_path, exist_ok=True)
            self.writer = SummaryWriter(conf.log_path)
            self.step = 0
            self.head = Arcface(embedding_size=conf.embedding_size,
                                classnum=self.class_num).to(conf.device)

            # Will not use anymore
            if conf.use_dp:
                self.model = nn.DataParallel(self.model)
                self.head = nn.DataParallel(self.head)

            print(self.class_num)
            print(conf)
            print('two model heads generated')

            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)

            if conf.use_mobilfacenet:
                self.optimizer = optim.SGD(
                    [{'params': paras_wo_bn[:-1], 'weight_decay': 4e-5},
                     {'params': [paras_wo_bn[-1]] + [self.head.kernel], 'weight_decay': 4e-4},
                     {'params': paras_only_bn}],
                    lr=conf.lr, momentum=conf.momentum)
            else:
                self.optimizer = optim.SGD(
                    [{'params': paras_wo_bn + [self.head.kernel], 'weight_decay': 5e-4},
                     {'params': paras_only_bn}],
                    lr=conf.lr, momentum=conf.momentum)

            if conf.discriminator:
                self.optimizer_g = optim.Adam(self.growup.parameters(),
                                              lr=1e-4, betas=(0.5, 0.999))
                self.optimizer_g2 = optim.Adam(self.growup.parameters(),
                                               lr=1e-4, betas=(0.5, 0.999))
                self.optimizer_d = optim.Adam(self.discriminator.parameters(),
                                              lr=1e-4, betas=(0.5, 0.999))
                self.optimizer2 = optim.SGD(
                    [{'params': paras_wo_bn + [self.head.kernel], 'weight_decay': 5e-4},
                     {'params': paras_only_bn}],
                    lr=conf.lr, momentum=conf.momentum)

            if conf.finetune_model_path is not None:
                self.optimizer = optim.SGD(
                    [{'params': paras_wo_bn, 'weight_decay': 5e-4},
                     {'params': paras_only_bn}],
                    lr=conf.lr, momentum=conf.momentum)
            print('optimizers generated')

            self.board_loss_every = len(self.loader) // 100
            self.evaluate_every = len(self.loader) // 2
            self.save_every = len(self.loader)

            dataset_root = "/home/nas1_userD/yonggyu/Face_dataset/face_emore"
            self.lfw = np.load(
                os.path.join(dataset_root, "lfw_align_112_list.npy")).astype(np.float32)
            self.lfw_issame = np.load(
                os.path.join(dataset_root, "lfw_align_112_label.npy"))
            self.fgnetc = np.load(
                os.path.join(dataset_root, "FGNET_new_align_list.npy")).astype(np.float32)
            self.fgnetc_issame = np.load(
                os.path.join(dataset_root, "FGNET_new_align_label.npy"))
        else:
            # Will not use anymore
            # self.model = nn.DataParallel(self.model)
            self.threshold = conf.threshold

    # negative_wrong/positive_wrong default to 0 so callers that do not track
    # the error counts (e.g. train_with_growup) can still log the basics
    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor,
                  negative_wrong=0, positive_wrong=0):
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy, self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name), best_threshold, self.step)
        self.writer.add_scalar('{}_negative_wrong'.format(db_name), negative_wrong, self.step)
        self.writer.add_scalar('{}_positive_wrong'.format(db_name), positive_wrong, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=10, tta=True):
        self.model.eval()
        self.growup.eval()
        self.discriminator.eval()
        idx = 0
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)).cpu() + \
                        self.model(flipped.to(conf.device)).cpu()
                    embeddings[idx:idx + conf.batch_size] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:idx + conf.batch_size] = self.model(
                        batch.to(conf.device)).cpu()
                idx += conf.batch_size
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)).cpu() + \
                        self.model(flipped.to(conf.device)).cpu()
                    embeddings[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:] = self.model(batch.to(conf.device)).cpu()
        tpr, fpr, accuracy, best_thresholds, dist = evaluate_dist(
            embeddings, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = transforms.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor, dist

    def evaluate_child(self, conf, carray, issame, nrof_folds=10, tta=True):
        self.model.eval()
        self.growup.eval()
        self.discriminator.eval()
        idx = 0
        embeddings1 = np.zeros([len(carray) // 2, conf.embedding_size])
        embeddings2 = np.zeros([len(carray) // 2, conf.embedding_size])
        carray1 = carray[::2, ]
        carray2 = carray[1::2, ]
        with torch.no_grad():
            while idx + conf.batch_size <= len(carray1):
                batch = torch.tensor(carray1[idx:idx + conf.batch_size])
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = self.growup(self.model(batch.to(conf.device))).cpu() + \
                        self.growup(self.model(flipped.to(conf.device))).cpu()
                    embeddings1[idx:idx + conf.batch_size] = l2_norm(emb_batch).cpu()
                else:
                    embeddings1[idx:idx + conf.batch_size] = self.growup(
                        self.model(batch.to(conf.device))).cpu()
                idx += conf.batch_size
            if idx < len(carray1):
                batch = torch.tensor(carray1[idx:])
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = self.growup(self.model(batch.to(conf.device))).cpu() + \
                        self.growup(self.model(flipped.to(conf.device))).cpu()
                    embeddings1[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings1[idx:] = self.growup(
                        self.model(batch.to(conf.device))).cpu()
            idx = 0  # reset the batch cursor before processing the second half
            while idx + conf.batch_size <= len(carray2):
                batch = torch.tensor(carray2[idx:idx + conf.batch_size])
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)).cpu() + \
                        self.model(flipped.to(conf.device)).cpu()
                    embeddings2[idx:idx + conf.batch_size] = l2_norm(emb_batch).cpu()
                else:
                    embeddings2[idx:idx + conf.batch_size] = self.model(
                        batch.to(conf.device)).cpu()
                idx += conf.batch_size
            if idx < len(carray2):
                batch = torch.tensor(carray2[idx:])
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)).cpu() + \
                        self.model(flipped.to(conf.device)).cpu()
                    embeddings2[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings2[idx:] = self.model(batch.to(conf.device)).cpu()
        tpr, fpr, accuracy, best_thresholds = evaluate_child(
            embeddings1, embeddings2, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = transforms.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def zero_grad(self):
        self.optimizer.zero_grad()
        self.optimizer_g.zero_grad()
        self.optimizer_d.zero_grad()

    def train(self, conf, epochs):
        self.model.train()
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e in self.milestones:
                self.schedule_lr()
            for imgs, labels, ages in tqdm(iter(self.loader)):
                self.optimizer.zero_grad()
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:  # XXX
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                # added wrong counts on evaluations
                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    # LFW evaluation
                    accuracy, best_threshold, roc_curve_tensor, dist = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    # NEGATIVE WRONG: different pairs accepted below the threshold
                    wrong_list = np.where((self.lfw_issame == False)
                                          & (dist < best_threshold))[0]
                    negative_wrong = len(wrong_list)
                    # POSITIVE WRONG: same pairs rejected above the threshold
                    wrong_list = np.where((self.lfw_issame == True)
                                          & (dist > best_threshold))[0]
                    positive_wrong = len(wrong_list)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor, negative_wrong, positive_wrong)

                    # FGNETC evaluation
                    accuracy2, best_threshold2, roc_curve_tensor2, dist2 = self.evaluate(
                        conf, self.fgnetc, self.fgnetc_issame)
                    wrong_list = np.where((self.fgnetc_issame == False)
                                          & (dist2 < best_threshold2))[0]
                    negative_wrong2 = len(wrong_list)
                    wrong_list = np.where((self.fgnetc_issame == True)
                                          & (dist2 > best_threshold2))[0]
                    positive_wrong2 = len(wrong_list)
                    self.board_val('fgent_c', accuracy2, best_threshold2,
                                   roc_curve_tensor2, negative_wrong2, positive_wrong2)
                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # save with most recently calculated accuracy?
                    self.save_state(conf, accuracy2,
                                    extra=str(conf.data_mode) + '_' + str(conf.net_depth) +
                                    '_' + str(conf.batch_size) + conf.model_name)
                self.step += 1
        print('Hooray!')

    def train_with_growup(self, conf, epochs):
        ''' Our method '''
        self.model.train()
        running_loss = 0.
        l1_loss = 0
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e in self.milestones:
                self.schedule_lr()
            a_loader = iter(self.adult_loader)
            c_loader = iter(self.child_loader)
            for imgs, labels, ages in tqdm(iter(self.loader)):
                # loader : base loader that returns images with id
                # a_loader, c_loader : adult, child loader with same datasize
                # ages : 0 == child, 1 == adult
                try:
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)
                except StopIteration:
                    a_loader = iter(self.adult_loader)
                    c_loader = iter(self.child_loader)
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                imgs_a, labels_a = imgs_a.to(conf.device), \
                    labels_a.to(conf.device).type(torch.float32)
                imgs_c, labels_c = imgs_c.to(conf.device), \
                    labels_c.to(conf.device).type(torch.float32)
                bs_a = imgs_a.shape[0]
                imgs_ac = torch.cat([imgs_a, imgs_c], dim=0)

                ###########################
                #       Train head        #
                ###########################
                self.optimizer.zero_grad()
                self.optimizer_g2.zero_grad()
                self.growup.train()
                c = (ages == 0)  # select children for enhancement
                embeddings = self.model(imgs)
                if sum(c) > 1:  # there might be no children in loader's batch
                    embeddings_c = embeddings[c]
                    embeddings_a_hat = self.growup(embeddings_c)
                    embeddings[c] = embeddings_a_hat
                elif sum(c) == 1:
                    self.growup.eval()
                    embeddings_c = embeddings[c]
                    embeddings_a_hat = self.growup(embeddings_c)
                    embeddings[c] = embeddings_a_hat
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()
                self.optimizer_g2.step()

                ##############################
                #    Train discriminator     #
                ##############################
                self.optimizer_d.zero_grad()
                self.growup.train()
                _embeddings = self.model(imgs_ac)
                embeddings_a, embeddings_c = _embeddings[:bs_a], _embeddings[bs_a:]
                embeddings_a_hat = self.growup(embeddings_c)
                labels_ac = torch.cat([labels_a, labels_c], dim=0)
                pred_a = torch.squeeze(self.discriminator(embeddings_a))  # separate since batchnorm exists
                pred_c = torch.squeeze(self.discriminator(embeddings_a_hat))
                pred_ac = torch.cat([pred_a, pred_c], dim=0)
                d_loss = conf.ls_loss(pred_ac, labels_ac)
                d_loss.backward()
                self.optimizer_d.step()

                #############################
                #      Train generator      #
                #############################
                self.optimizer_g.zero_grad()
                embeddings_c = self.model(imgs_c)
                embeddings_a_hat = self.growup(embeddings_c)
                pred_c = torch.squeeze(self.discriminator(embeddings_a_hat))
                labels_a = torch.ones_like(labels_c, dtype=torch.float)
                # generator should make child 1
                g_loss = conf.ls_loss(pred_c, labels_a)
                l1_loss = conf.l1_loss(embeddings_a_hat, embeddings_c)
                g_total_loss = g_loss + 10 * l1_loss
                g_total_loss.backward()
                # g_loss.backward()
                self.optimizer_g.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:  # XXX
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    self.writer.add_scalar('d_loss', d_loss, self.step)
                    self.writer.add_scalar('g_loss', g_loss, self.step)
                    self.writer.add_scalar('l1_loss', l1_loss, self.step)
                    running_loss = 0.
                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    accuracy, best_threshold, roc_curve_tensor, _ = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold, roc_curve_tensor)
                    accuracy2, best_threshold2, roc_curve_tensor2 = self.evaluate_child(
                        conf, self.fgnetc, self.fgnetc_issame)
                    self.board_val('fgent_c', accuracy2, best_threshold2, roc_curve_tensor2)
                    self.model.train()
                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # save with most recently calculated accuracy?
                    self.save_state(conf, accuracy2,
                                    extra=str(conf.data_mode) + '_' + str(conf.net_depth) +
                                    '_' + str(conf.batch_size) + conf.model_name)
                self.step += 1
        self.save_state(conf, accuracy2, to_save_folder=True,
                        extra=str(conf.data_mode) + '_' + str(conf.net_depth) +
                        '_' + str(conf.batch_size) + '_discriminator_final')

    def train_age_invariant(self, conf, epochs):
        ''' Our method, without growup '''
        self.model.train()
        running_loss = 0.
        l1_loss = 0
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e in self.milestones:
                self.schedule_lr()
                self.schedule_lr2()
            a_loader = iter(self.adult_loader)
            c_loader = iter(self.child_loader)
            for imgs, labels, ages in tqdm(iter(self.loader)):
                # loader : base loader that returns images with id
                # a_loader, c_loader : adult, child loader with same datasize
                # ages : 0 == child, 1 == adult
                try:
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)
                except StopIteration:
                    a_loader = iter(self.adult_loader)
                    c_loader = iter(self.child_loader)
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                imgs_a, labels_a = imgs_a.to(conf.device), \
                    labels_a.to(conf.device).type(torch.float32)
                imgs_c, labels_c = imgs_c.to(conf.device), \
                    labels_c.to(conf.device).type(torch.float32)
                bs_a = imgs_a.shape[0]
                imgs_ac = torch.cat([imgs_a, imgs_c], dim=0)

                ###########################
                #       Train head        #
                ###########################
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()

                ##############################
                #    Train discriminator     #
                ##############################
                self.optimizer_d.zero_grad()
                _embeddings = self.model(imgs_ac)
                embeddings_a, embeddings_c = _embeddings[:bs_a], _embeddings[bs_a:]
                labels_ac = torch.cat([labels_a, labels_c], dim=0)
                pred_a = torch.squeeze(self.discriminator(embeddings_a))  # separate since batchnorm exists
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                pred_ac = torch.cat([pred_a, pred_c], dim=0)
                d_loss = conf.ls_loss(pred_ac, labels_ac)
                d_loss.backward()
                self.optimizer_d.step()

                #############################
                #      Train generator      #
                #############################
                self.optimizer2.zero_grad()
                embeddings_c = self.model(imgs_c)
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                labels_a = torch.ones_like(labels_c, dtype=torch.float)
                # generator should make child 1
                g_loss = conf.ls_loss(pred_c, labels_a)
                g_loss.backward()
                self.optimizer2.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:  # XXX
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    self.writer.add_scalar('d_loss', d_loss, self.step)
                    self.writer.add_scalar('g_loss', g_loss, self.step)
                    self.writer.add_scalar('l1_loss', l1_loss, self.step)
                    running_loss = 0.
                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    accuracy, best_threshold, roc_curve_tensor, _ = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold, roc_curve_tensor)
                    accuracy2, best_threshold2, roc_curve_tensor2, _ = self.evaluate(
                        conf, self.fgnetc, self.fgnetc_issame)
                    self.board_val('fgent_c', accuracy2, best_threshold2, roc_curve_tensor2)
                    self.model.train()
                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # save with most recently calculated accuracy?
                    self.save_state(conf, accuracy2,
                                    extra=str(conf.data_mode) + '_' + str(conf.net_depth) +
                                    '_' + str(conf.batch_size) + conf.model_name)
                self.step += 1
        self.save_state(conf, accuracy2, to_save_folder=True,
                        extra=str(conf.data_mode) + '_' + str(conf.net_depth) +
                        '_' + str(conf.batch_size) + '_discriminator_final')

    def train_age_invariant2(self, conf, epochs):
        '''
        Our method, without growup, using paired dataset
        TODO: currently identical to train_age_invariant until the paired
        dataset handling is implemented
        '''
        self.model.train()
        running_loss = 0.
        l1_loss = 0
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e in self.milestones:
                self.schedule_lr()
                self.schedule_lr2()
            a_loader = iter(self.adult_loader)
            c_loader = iter(self.child_loader)
            for imgs, labels, ages in tqdm(iter(self.loader)):
                # loader : base loader that returns images with id
                # a_loader, c_loader : adult, child loader with same datasize
                # ages : 0 == child, 1 == adult
                try:
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)
                except StopIteration:
                    a_loader = iter(self.adult_loader)
                    c_loader = iter(self.child_loader)
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)
                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                imgs_a, labels_a = imgs_a.to(conf.device), \
                    labels_a.to(conf.device).type(torch.float32)
                imgs_c, labels_c = imgs_c.to(conf.device), \
                    labels_c.to(conf.device).type(torch.float32)
                bs_a = imgs_a.shape[0]
                imgs_ac = torch.cat([imgs_a, imgs_c], dim=0)

                ###########################
                #       Train head        #
                ###########################
                self.optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()

                ##############################
                #    Train discriminator     #
                ##############################
                self.optimizer_d.zero_grad()
                _embeddings = self.model(imgs_ac)
                embeddings_a, embeddings_c = _embeddings[:bs_a], _embeddings[bs_a:]
                labels_ac = torch.cat([labels_a, labels_c], dim=0)
                pred_a = torch.squeeze(self.discriminator(embeddings_a))  # separate since batchnorm exists
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                pred_ac = torch.cat([pred_a, pred_c], dim=0)
                d_loss = conf.ls_loss(pred_ac, labels_ac)
                d_loss.backward()
                self.optimizer_d.step()

                #############################
                #      Train generator      #
                #############################
                self.optimizer2.zero_grad()
                embeddings_c = self.model(imgs_c)
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                labels_a = torch.ones_like(labels_c, dtype=torch.float)
                # generator should make child 1
                g_loss = conf.ls_loss(pred_c, labels_a)
                g_loss.backward()
                self.optimizer2.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:  # XXX
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    self.writer.add_scalar('d_loss', d_loss, self.step)
                    self.writer.add_scalar('g_loss', g_loss, self.step)
                    self.writer.add_scalar('l1_loss', l1_loss, self.step)
                    running_loss = 0.
                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    accuracy, best_threshold, roc_curve_tensor, _ = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    self.board_val('lfw', accuracy, best_threshold, roc_curve_tensor)
                    accuracy2, best_threshold2, roc_curve_tensor2, _ = self.evaluate(
                        conf, self.fgnetc, self.fgnetc_issame)
                    self.board_val('fgent_c', accuracy2, best_threshold2, roc_curve_tensor2)
                    self.model.train()
                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # save with most recently calculated accuracy?
                    self.save_state(conf, accuracy2,
                                    extra=str(conf.data_mode) + '_' + str(conf.net_depth) +
                                    '_' + str(conf.batch_size) + conf.model_name)
                self.step += 1
        self.save_state(conf, accuracy2, to_save_folder=True,
                        extra=str(conf.data_mode) + '_' + str(conf.net_depth) +
                        '_' + str(conf.batch_size) + '_discriminator_final')

    def analyze_angle(self, conf, name):
        '''
        Only works on age-labeled vgg dataset, agedb dataset
        '''
        angle_table = [{0: set(), 1: set(), 2: set(), 3: set(),
                        4: set(), 5: set(), 6: set(), 7: set()}
                       for i in range(self.class_num)]
        # batch = 0
        # _angle_table = torch.zeros(self.class_num, 8, len(self.loader)//conf.batch_size).to(conf.device)
        if conf.resume_analysis:
            self.loader = []
        for imgs, labels, ages in tqdm(iter(self.loader)):
            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            ages = ages.to(conf.device)

            embeddings = self.model(imgs)
            if conf.use_dp:
                kernel_norm = l2_norm(self.head.module.kernel, axis=0)
                cos_theta = torch.mm(embeddings, kernel_norm)
                cos_theta = cos_theta.clamp(-1, 1)
            else:
                cos_theta = self.head.get_angle(embeddings)
            thetas = torch.abs(torch.rad2deg(torch.acos(cos_theta)))

            for i in range(len(thetas)):
                age_bin = 7
                if ages[i] < 26:
                    age_bin = 0 if ages[i] < 13 else 1 if ages[i] < 19 else 2
                elif ages[i] < 66:
                    age_bin = int(((ages[i] + 4) // 10).item())
                angle_table[labels[i]][age_bin].add(thetas[i][labels[i]].item())

        if conf.resume_analysis:
            with open('analysis/angle_table.pkl', 'rb') as f:
                angle_table = pickle.load(f)
        else:
            with open('analysis/angle_table.pkl', 'wb') as f:
                pickle.dump(angle_table, f)

        count, avg_angle = [], []
        for i in range(self.class_num):
            count.append([len(single_set) for single_set in angle_table[i].values()])
            avg_angle.append([
                sum(list(single_set)) / len(single_set)
                if len(single_set) else 0  # if set() size is zero, avg is zero
                for single_set in angle_table[i].values()
            ])

        count_df = pd.DataFrame(count)
        avg_angle_df = pd.DataFrame(avg_angle)

        with pd.ExcelWriter('analysis/analyze_angle_{}_{}.xlsx'.format(
                conf.data_mode, name)) as writer:
            count_df.to_excel(writer, sheet_name='count')
            avg_angle_df.to_excel(writer, sheet_name='avg_angle')

    def schedule_lr(self):
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def schedule_lr2(self):
        for params in self.optimizer2.param_groups:
            params['lr'] /= 10
        print(self.optimizer2)

    def infer(self, conf, faces, target_embs, tta=False):
        '''
        faces : list of PIL Image
        target_embs : [n, 512] computed embeddings of faces in facebank
        names : recorded names of faces in facebank
        tta : test time augmentation (hflip, that's all)
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = transforms.functional.hflip(img)
                emb = self.model(conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(self.model(conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)
        diff = source_embs.unsqueeze(-1) - target_embs.transpose(1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum

    def save_best_state(self, conf, accuracy, to_save_folder=False, extra=None,
                        model_only=False):
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        os.makedirs('work_space/models', exist_ok=True)
        torch.save(
            self.model.state_dict(),
            str(save_path) + ('/lfw_best_model_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(),
                str(save_path) + ('/lfw_best_head_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(),
                str(save_path) + ('/lfw_best_optimizer_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))

    def save_state(self, conf, accuracy, to_save_folder=False, extra=None,
                   model_only=False):
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        os.makedirs('work_space/models', exist_ok=True)
        torch.save(
            self.model.state_dict(),
            str(save_path) + ('/model_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(),
                str(save_path) + ('/head_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(),
                str(save_path) + ('/optimizer_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            if conf.discriminator:
                torch.save(
                    self.growup.state_dict(),
                    str(save_path) + ('/growup_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                        get_time(), accuracy, self.step, extra)))

    def load_state(self, conf, fixed_str, from_save_folder=False,
                   model_only=False, analyze=False):
        if from_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        self.model.load_state_dict(
            torch.load(os.path.join(save_path, 'model_{}'.format(fixed_str))))
        if not model_only:
            self.head.load_state_dict(
                torch.load(os.path.join(save_path, 'head_{}'.format(fixed_str))))
            if not analyze:
                self.optimizer.load_state_dict(
                    torch.load(os.path.join(save_path, 'optimizer_{}'.format(fixed_str))))
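
# The open-set decision in infer() above reduces to a squared-L2 nearest-
# neighbour search over the facebank with a rejection threshold. A toy sketch
# with random embeddings; the 512-dim size matches the heads in this file,
# while the threshold value here is only an assumption for illustration.
def _example_facebank_matching():
    source_embs = l2_norm(torch.randn(3, 512))    # query faces
    target_embs = l2_norm(torch.randn(10, 512))   # enrolled facebank
    diff = source_embs.unsqueeze(-1) - target_embs.transpose(1, 0).unsqueeze(0)
    dist = torch.sum(torch.pow(diff, 2), dim=1)   # [3, 10] pairwise distances
    minimum, min_idx = torch.min(dist, dim=1)     # best facebank match per query
    threshold = 1.5                               # hypothetical rejection cutoff
    min_idx[minimum > threshold] = -1             # -1 means "unknown face"
    return min_idx, minimum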