# Loaded (backbone, head) pairs keyed by target device, so the checkpoints are
# read from disk and the networks built once instead of on every predict() call.
# NOTE: replacing the checkpoint files on disk is not picked up after the first
# call in a process.
_PREDICT_MODELS = {}


def _load_predict_models():
    """Return the eval-mode (backbone, head) pair for DEVICE, loading it once."""
    key = str(DEVICE)
    models = _PREDICT_MODELS.get(key)
    if models is None:
        backbone = IR_50(INPUT_SIZE).to(DEVICE)
        head = ArcFace(in_features=EMBEDDING_SIZE, out_features=1000,
                       device_id=GPU_ID).to(DEVICE)
        # map_location lets GPU-saved checkpoints load on a CPU-only host.
        backbone.load_state_dict(
            torch.load('./trained_model/Backbone_IR_50_ArcFace_30.pth',
                       map_location=DEVICE))
        head.load_state_dict(
            torch.load('./trained_model/Head_IR_50_ArcFace_30.pth',
                       map_location=DEVICE))
        backbone.eval()
        head.eval()
        models = _PREDICT_MODELS[key] = (backbone, head)
    return models


def predict(image):
    """Classify a multi-crop image with the IR-50 + ArcFace model.

    The tensor returned by ``image_loader`` is unpacked below as
    ``(batch, ncrops, C, H, W)``; the head output is averaged over the crops.

    Args:
        image: whatever ``image_loader(image=...)`` accepts.

    Returns:
        ``(label, score)``: the top-1 class index as ``int`` and its head output
        (averaged over crops) as ``float``. No softmax is applied, so the score
        is not a normalized probability. Only a batch of one image is
        supported (``.item()`` needs exactly one element).
    """
    backbone, head = _load_predict_models()
    image = image_loader(image=image).to(DEVICE)
    bs, ncrops, c, h, w = image.size()
    # Fold the crop dimension into the batch so the backbone sees 4-D input.
    inputs = image.view(-1, c, h, w)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        features = backbone(inputs)
        outputs = head(features, None)
    outputs = outputs.view(bs, ncrops, -1).mean(1)
    top_probs, top_labs = outputs.topk(1)
    return int(top_labs.item()), float(top_probs.item())
def __init__(
        self,
        model_path='Backbone_IR_50_Epoch_125_Batch_3125_Time_2020-11-19-13-22_checkpoint.pth',
        device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")):
    """Build the preprocessing pipeline and load an IR-50 backbone checkpoint.

    Args:
        model_path: path to a state_dict checkpoint for ``IR_50``.
        device: device the backbone is moved to; also used as ``map_location``
            when reading the checkpoint. The default is evaluated once, when
            this ``def`` statement runs.
    """
    input_size = [112, 112]
    rgb_mean = [0.5, 0.5, 0.5]
    rgb_std = [0.5, 0.5, 0.5]
    # 128 for a 112x112 input; scales proportionally for other sizes.
    resize_to = int(128 * input_size[0] / 112)
    self.transform = transforms.Compose([
        # A sequence size resizes to exactly (h, w); the aspect ratio is not
        # preserved. The center crop then cuts out the model's input size.
        transforms.Resize([resize_to, resize_to]),
        transforms.CenterCrop([input_size[0], input_size[1]]),
        transforms.ToTensor(),
        transforms.Normalize(mean=rgb_mean, std=rgb_std),
    ])
    self.backbone = IR_50(input_size)
    self.backbone.load_state_dict(torch.load(model_path, map_location=device))
    self.backbone.eval()
    self.device = device
    # Required: map_location only decides where the checkpoint tensors are
    # placed while loading. load_state_dict copies them into the module's
    # existing parameters, which IR_50(...) created on the CPU.
    self.backbone.to(self.device)
def __init__(self):
    """Load the IR-50 face-recognition backbone for inference.

    Side effect: calls ``torch.set_grad_enabled(False)``, which turns off
    gradient tracking for the calling thread beyond this constructor.
    Reads the module-level FACE_INPUT_SIZE, MODEL_PATH and DEVICE settings.
    """
    torch.set_grad_enabled(False)
    print('[*] Load face recognition model...')
    side = FACE_INPUT_SIZE
    net = IR_50([side, side])
    self.backbone = net
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
    net.load_state_dict(checkpoint)
    net.to(DEVICE)
    net.eval()
def __init__(self, device, backbone_name=cfg['BACKBONE_NAME'],
             INPUT_SIZE=cfg['INPUT_SIZE'],
             BACKBONE_RESUME_ROOT=cfg['BACKBONE_RESUME_ROOT']):
    """Build the backbone named *backbone_name* and load its checkpoint.

    Args:
        device: device the embedding network is loaded onto and moved to.
        backbone_name: key into the constructor table below.
        INPUT_SIZE: input size passed to the backbone constructor.
        BACKBONE_RESUME_ROOT: path to the backbone's state_dict checkpoint.

    The defaults are read from ``cfg`` once, when this ``def`` runs.

    Raises:
        KeyError: if *backbone_name* is not a known backbone.
    """
    super().__init__()
    # Name -> constructor, so only the requested architecture is instantiated
    # (building all nine just to keep one wastes time and memory).
    backbone_factories = {
        'ResNet_50': ResNet_50,
        'ResNet_101': ResNet_101,
        'ResNet_152': ResNet_152,
        'IR_50': IR_50,
        'IR_101': IR_101,
        'IR_152': IR_152,
        'IR_SE_50': IR_SE_50,
        'IR_SE_101': IR_SE_101,
        'IR_SE_152': IR_SE_152,
    }
    self.device = device
    self.embedding = backbone_factories[backbone_name](INPUT_SIZE)
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    self.embedding.load_state_dict(
        torch.load(BACKBONE_RESUME_ROOT, map_location=device))
    self.embedding = self.embedding.to(device)
def loadModel(data_root, file_list, backbone_net, gpus='0', resume=None):
    """Build a backbone, load its weights and set up an AgeDB-30 data loader.

    Args:
        data_root: root directory passed to ``AgeDB30``.
        file_list: pair-list file passed to ``AgeDB30``.
        backbone_net: one of ``'MobileFace'``, ``'SERes50_IR'``, ``'IR_50'``.
        gpus: comma-separated GPU ids; written to ``CUDA_VISIBLE_DEVICES``.
            More than one id wraps the network in ``DataParallel``.
        resume: path to the state_dict checkpoint to load (required).

    Returns:
        ``(net, device, agedb_dataset, agedb_loader)`` with *net* in eval mode.

    Raises:
        ValueError: if *resume* is None or *backbone_net* is not supported.
    """
    # Fail before building any network: there is nothing to evaluate without weights.
    if resume is None:
        raise ValueError('a checkpoint path must be given via `resume`')
    if backbone_net == 'MobileFace':
        net = MobileFaceNet()
    elif backbone_net == 'SERes50_IR':
        net = Backbone(50, drop_ratio=0.4, mode='ir_se')
    elif backbone_net == 'IR_50':
        net = IR_50((112, 112))
    else:
        raise ValueError('{} is not available!'.format(backbone_net))

    # gpu init
    multi_gpus = len(gpus.split(',')) > 1
    os.environ['CUDA_VISIBLE_DEVICES'] = gpus
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    net.load_state_dict(torch.load(resume, map_location=device))
    if multi_gpus:
        net = DataParallel(net).to(device)
    else:
        net = net.to(device)

    transform = transforms.Compose([
        transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))  # range [0.0, 1.0] -> [-1.0,1.0]
    ])
    agedb_dataset = AgeDB30(data_root, file_list, transform=transform)
    agedb_loader = torch.utils.data.DataLoader(agedb_dataset, batch_size=128,
                                               shuffle=False, num_workers=2,
                                               drop_last=False)

    return net.eval(), device, agedb_dataset, agedb_loader
pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, drop_last=DROP_LAST) NUM_CLASS = len(train_loader.dataset.classes) print("Number of Training Classes: {}".format(NUM_CLASS)) lfw, cfp_ff, cfp_fp, agedb, calfw, cplfw, vgg2_fp, lfw_issame, cfp_ff_issame, cfp_fp_issame, agedb_issame, calfw_issame, cplfw_issame, vgg2_fp_issame = get_val_data( DATA_ROOT) # ======= model & loss & optimizer =======# BACKBONE_DICT = { 'ResNet_50': ResNet_50(INPUT_SIZE), 'ResNet_101': ResNet_101(INPUT_SIZE), 'ResNet_152': ResNet_152(INPUT_SIZE), 'IR_50': IR_50(INPUT_SIZE), 'IR_101': IR_101(INPUT_SIZE), 'IR_152': IR_152(INPUT_SIZE), 'IR_SE_50': IR_SE_50(INPUT_SIZE), 'IR_SE_101': IR_SE_101(INPUT_SIZE), 'IR_SE_152': IR_SE_152(INPUT_SIZE) } BACKBONE = BACKBONE_DICT[BACKBONE_NAME] print("=" * 60) print(BACKBONE) print("{} Backbone Generated".format(BACKBONE_NAME)) print("=" * 60) HEAD_DICT = { 'ArcFace': ArcFace(in_features=EMBEDDING_SIZE, out_features=NUM_CLASS),
image_root2=img_root2, actual_issame=False) tp = tp + tp2 fp = fp + fp2 tn = tn + tn2 fn = fn + fn2 print(tp, tn, fp, fn) tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn) fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn) acc = float(tp + tn) / float(tp + tn + fp + fn) print(tp, tn, fp, fn) print(tpr) print(fpr) print(acc) return tpr, fpr, acc if __name__ == "__main__": cfg = configurations[1] INPUT_SIZE = cfg['INPUT_SIZE'] BACKBONE_NAME = cfg['BACKBONE_NAME'] BACKBONE_DICT = {'IR_50': IR_50(INPUT_SIZE)} BATCH_SIZE = cfg['BATCH_SIZE'] backbone = BACKBONE_DICT[BACKBONE_NAME] DROP_LAST = cfg['DROP_LAST'] PIN_MEMORY = cfg['PIN_MEMORY'] NUM_WORKERS = cfg['NUM_WORKERS'] #model_root = "./model/ms1m-ir50/backbone_ir50_ms1m_epoch120.pth" get_image_pairs()
from PIL import Image
import torch
import torch.nn as nn
import os
from backbone.model_irse import IR_50

# Pretrained IR-50 weights (a state_dict) and the input resolution they expect.
cpt_dir = './model/backbone_ir50_ms1m.pth'
IMAGE_SIZE = [112, 112]

# Build the network, restore the weights, then put it on the GPU in inference
# mode. torch.load is called without map_location, so this assumes CUDA is
# available.
model = IR_50(IMAGE_SIZE)
model.load_state_dict(torch.load(cpt_dir))
model = model.cuda().eval()
img = torch.from_numpy(img) return img def l2_norm(input, axis = 1): norm = torch.norm(input, 2, axis, True) output = torch.div(input, norm) return output def json_load(path): with open(path, "r", encoding="utf-8") as file: config = json.load(file) return config config = json_load('./Name_2.json') model1 = IR_50([112,112]) model1.load_state_dict(torch.load('IR_50.pth',map_location='cpu')) model1.eval() path1="../data/securityAI_round1_images/" alpha = 1.0 mean_feat1 = pd.read_csv('./mean_ir_50_712_2.csv').rename(columns = {'Unnamed: 0':'name'}) for person in config.items(): stage = person[0] steps = person[1]['steps'] coeff = person[1]['coeff'] threshold = person[1]['threshold'] lsteps = None bad_name = person[1]['name']
# Output directory for generated images; exist_ok avoids the check-then-create
# race and makedirs also creates missing parent directories.
os.makedirs(ADV_DIR, exist_ok=True)

MULTIGPU = False

# Checkpoint root; override with the FACE_MODEL_ZOO environment variable
# instead of editing the hard-coded default.
MODEL_ZOO_ROOT = os.environ.get('FACE_MODEL_ZOO',
                                '/home/zhangao/model_zoo/face_model_zoo')
CKPT_LIST = [
    'backbone_resnet50.pth',
    'backbone_ir50_ms1m.pth',
    'backbone_ir50_asia.pth',
    'backbone_ir152.pth',
]
CKPT_LIST = [os.path.join(MODEL_ZOO_ROOT, ckpt) for ckpt in CKPT_LIST]
# One architecture per entry of CKPT_LIST, in the same order.
model_list = [
    ResNet_50([112, 112]),
    IR_50([112, 112]),
    IR_50([112, 112]),
    IR_152([112, 112]),
]
weights = []
LOAD_EMBEDDINGS = None
#LOAD_EMBEDDINGS="id_embeddings_rgb.pkl"
PIN_MEMORY = False
NUM_WORKERS = 0
BATCH_SIZE = 8
EMBEDDING_SIZE = 512
SAVE = True
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#device = 'cpu'
# Backbone architectures main() can build, in the order listed in its error message.
_MODEL_TYPES = (
    'ResNet_50', 'ResNet_101', 'ResNet_152',
    'IR_50', 'IR_101', 'IR_152',
    'IR_SE_50', 'IR_SE_101', 'IR_SE_152',
)

# (label used in log messages, dataset name passed to get_val_pair), in the
# order the benchmarks are evaluated.
_VAL_SETS = (
    ('LFW', 'lfw'),
    ('CALFW', 'calfw'),
    ('CPLFW', 'cplfw'),
    ('CFP-FF', 'cfp_ff'),
    ('CFP-FP', 'cfp_fp'),
    ('AgeDB_30', 'agedb_30'),
    ('VggFace2_FP', 'vgg2_fp'),
)


def _build_model(model_type, input_size):
    """Instantiate the backbone named *model_type* for *input_size*.

    Raises:
        AssertionError: if *model_type* is not in ``_MODEL_TYPES`` (this
            exception type is kept for compatibility with existing callers).
    """
    # Validate before touching any constructor so an unknown name fails fast.
    if model_type not in _MODEL_TYPES:
        raise AssertionError(
            'Unsuported model_type {}. We only support: {}'.format(
                model_type, list(_MODEL_TYPES)))
    constructors = {
        'ResNet_50': ResNet_50,
        'ResNet_101': ResNet_101,
        'ResNet_152': ResNet_152,
        'IR_50': IR_50,
        'IR_101': IR_101,
        'IR_152': IR_152,
        'IR_SE_50': IR_SE_50,
        'IR_SE_101': IR_SE_101,
        'IR_SE_152': IR_SE_152,
    }
    return constructors[model_type](input_size)


def main(ARGS):
    """Load a face backbone and report verification accuracy on seven benchmarks.

    Reads ``model_path``, ``model_type``, ``input_size``, ``data_root``,
    ``embedding_size`` and ``batch_size`` from *ARGS*.

    Raises:
        AssertionError: if ``ARGS.model_path`` is None or ``ARGS.model_type``
            is unsupported.
    """
    if ARGS.model_path is None:
        raise AssertionError("Path should not be None")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    ####### Model setup
    print('Model type: %s' % ARGS.model_type)
    model = _build_model(ARGS.model_type, ARGS.input_size)
    # map_location lets a GPU-saved checkpoint load on a CPU-only host; the
    # parameters are then moved to `device` explicitly.
    model.load_state_dict(torch.load(ARGS.model_path, map_location=device))
    model.to(device)
    model.eval()

    # Single-device evaluation; passed through as perform_val's first argument.
    multi_gpu = False
    print("=" * 60)
    print(
        "Performing Evaluation on LFW, CFP_FF, CFP_FP, AgeDB, CALFW, CPLFW and VGG2_FP, and Save Checkpoints..."
    )
    accuracy = {}
    for label, dataset in _VAL_SETS:
        print("Performing Evaluation on {}...".format(label))
        pairs, issame = get_val_pair(ARGS.data_root, dataset)
        # perform_val also returns the best threshold and ROC curve; only the
        # accuracy is reported.
        accuracy[dataset], _, _ = perform_val(
            multi_gpu, device, ARGS.embedding_size, ARGS.batch_size, model,
            pairs, issame)
        print("Evaluation: {} Acc: {}".format(label, accuracy[dataset]))

    print("=" * 60)
    print("FINAL RESULTS:")
    print(
        "Evaluation: LFW Acc: {}, CFP_FF Acc: {}, CFP_FP Acc: {}, AgeDB Acc: {}, CALFW Acc: {}, CPLFW Acc: {}, VGG2_FP Acc: {}"
        .format(accuracy['lfw'], accuracy['cfp_ff'], accuracy['cfp_fp'],
                accuracy['agedb_30'], accuracy['calfw'], accuracy['cplfw'],
                accuracy['vgg2_fp']))
    print("=" * 60)
import sys sys.path.append('./align') import os import numpy as np import cv2 import torch import pickle as pkl import matplotlib.pyplot as plt from PIL import Image from backbone.model_irse import IR_50 from Utils.img_util import plot_images, align_face, img2tensor, l2_norm model = IR_50([112, 112]) model.load_state_dict( torch.load('backbone_ir50_ms1m_epoch120.pth', map_location='cpu')) model.eval() # @function: store the target image data processed by the model # @parameters: isAver: True: store the average face feature data False: store each face feature data # @attention: the face feature data isn't normalized def save_imgdata(root_dir='./data/train/', isAver=False): paths = os.listdir(root_dir) if not isAver: warehouse = 'warehouse_3.pkl' else: warehouse = 'aver_warehouse.pkl' num_persons = len(paths) num_faces = len(os.listdir(root_dir + paths[0])) face_datas = np.zeros((512, num_faces * num_persons)) names = []
from backbone.model_irse import IR_50 from util.utils_test import * import time, os, glob, random IMG_SIZE = 112 DEVICE = 'cpu' # torch.device("cuda:0") USE_FLIP = True MODEL_PATH = 'pretrained/backbone_ir50_asia.pth' torch.set_grad_enabled(False) # load model print('[*] Load model...') backbone = IR_50([IMG_SIZE, IMG_SIZE]) backbone.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE)) backbone.to(DEVICE) backbone.eval() img_list = glob.glob('imgs/*/*.jpg') random.shuffle(img_list) # test print('[*] Inference...') for img_path_A in img_list: img_path_B = random.sample(img_list, k=1)[0] _, subject_A, id_A = img_path_A.split('/') _, subject_B, id_B = img_path_B.split('/')
NUM_CLASS = len(train_loader.dataset.classes) print("Number of Training Classes: {}".format(NUM_CLASS)) lfw, cfp_ff, cfp_fp, agedb, calfw, cplfw, vgg2_fp, lfw_issame, cfp_ff_issame, cfp_fp_issame, agedb_issame, calfw_issame, cplfw_issame, vgg2_fp_issame = get_val_data( DATA_ROOT) #======= model & loss & optimizer =======# BACKBONE_DICT = { 'ResNet_50': ResNet_50(INPUT_SIZE), 'ResNet_101': ResNet_101(INPUT_SIZE), 'ResNet_152': ResNet_152(INPUT_SIZE), 'IR_50': IR_50(INPUT_SIZE), 'IR_101': IR_101(INPUT_SIZE), 'IR_152': IR_152(INPUT_SIZE), 'IR_SE_50': IR_SE_50(INPUT_SIZE), 'IR_SE_101': IR_SE_101(INPUT_SIZE), 'IR_SE_152': IR_SE_152(INPUT_SIZE), 'ShuffleNetV2_0.5': shufflenet_v2_x0_5(pretrained=True, only_features=True), 'ShuffleNetV2_1.0': shufflenet_v2_x1_0(pretrained=False, only_features=True), 'ShuffleNetV2_1.5':
# For every checkpoint (IR_152 weights in models_list1, then IR_50 weights in
# models_list2), extract test-set features with and without test-time
# augmentation and save each feature dict as a .mat file under feature_path.
# One network per architecture is built and reused across its checkpoints.
for _arch, _ckpt_paths in ((IR_152, models_list1), (IR_50, models_list2)):
    model = _arch([112, 112]).to(device)
    for model_path in _ckpt_paths:
        name = model_path.split('/')[1].split('.')[0]
        model.load_state_dict(torch.load(model_path))
        for _tta, _suffix in ((True, '_filp.mat'), (False, '_nofilp.mat')):
            features = inference_dataload(model, test_path, tta=_tta)
            fe_dict = get_feature_dict(name_list, features)
            print('Output number:', len(fe_dict))
            sio.savemat(feature_path + name + _suffix, fe_dict)
#++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
#get the face feature Similarity
from torchvision import models from torchvision import transforms import torch.nn as nn from PIL import Image import matplotlib.pyplot as plt import matplotlib.pyplot as plt warnings.filterwarnings('ignore') from backbone.model_irse import IR_50, IR_101, IR_152, IR_SE_50, IR_SE_101, IR_SE_152 from backbone.model_resnet import ResNet_50, ResNet_101, ResNet_152 from backbone.model_facenet import model_920, model_921 IMG_SIZE = 224 # ensemble multi-model # 1 model_ir50_epoch120 = IR_50([112, 112]) model_ir50_epoch120.load_state_dict( torch.load('./Defense_Model/backbone_ir50_ms1m_epoch120.pth', map_location='cuda')) model_ir50_epoch120.eval() criterion_ir50_epoch120 = nn.MSELoss() # 2 model_IR_152_Epoch_112 = IR_152([112, 112]) model_IR_152_Epoch_112.load_state_dict( torch.load( './Defense_Model/Backbone_IR_152_Epoch_112_Batch_2547328_Time_2019-07-13-02-59_checkpoint.pth', map_location='cuda')) model_IR_152_Epoch_112.eval() criterion_IR_152_Epoch_112 = nn.MSELoss() # 3 model_IR_SE_50_Epoch_2 = IR_SE_50([112, 112])
num_workers = NUM_WORKERS, drop_last = DROP_LAST ) NUM_CLASS = len(train_loader.dataset.classes) print("Number of Training Classes: {}".format(NUM_CLASS)) lfw, cfp_ff, cfp_fp, agedb, calfw, cplfw, vgg2_fp, lfw_issame, cfp_ff_issame, cfp_fp_issame, agedb_issame, calfw_issame, cplfw_issame, vgg2_fp_issame = get_val_data(DATA_ROOT) #======= model & loss & optimizer =======# if BACKBONE_NAME == 'ResNet_50': BACKBONE = ResNet_50(INPUT_SIZE) # 'ResNet_101': resnet101(INPUT_SIZE), # 'ResNet_152': resnet152(INPUT_SIZE), elif BACKBONE_NAME == 'IR_50': BACKBONE = IR_50(INPUT_SIZE) elif BACKBONE_NAME == 'IR_101': BACKBONE = IR_101(INPUT_SIZE) elif BACKBONE_NAME == 'IR_152': BACKBONE = IR_152(INPUT_SIZE) elif BACKBONE_NAME == 'IR_SE_50': BACKBONE = IR_SE_50(INPUT_SIZE) elif BACKBONE_NAME == 'IR_SE_101': BACKBONE = IR_SE_101(INPUT_SIZE) elif BACKBONE_NAME == 'IR_SE_152': BACKBONE = IR_SE_152(INPUT_SIZE) elif BACKBONE_NAME == 'ShuffleNet': BACKBONE = shufflenet(cfg=cfg) elif BACKBONE_NAME == 'ShuffleNetV2': BACKBONE = shufflenetv2(cfg=cfg) elif BACKBONE_NAME == 'Mobilenet':
train_loader = torch.utils.data.DataLoader(
    dataset_train, batch_size=BATCH_SIZE, sampler=sampler,
    pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, drop_last=DROP_LAST
)

NUM_CLASS = len(train_loader.dataset.classes)
print("Number of Training Classes: {}".format(NUM_CLASS))

# validate on LFW, CFP_FF, CFP_FP, AgeDB, CALFW, CPLFW and VGGFace2_FP
lfw, cfp_ff, cfp_fp, agedb, calfw, cplfw, vgg2_fp, lfw_issame, cfp_ff_issame, cfp_fp_issame, agedb_issame, calfw_issame, cplfw_issame, vgg2_fp_issame = get_val_data(DATA_ROOT)


#======= model & loss & optimizer =======#
# Name -> constructor. Only the selected backbone and head are instantiated;
# the previous name -> instance dicts built all nine backbones and all four
# heads (each with out_features=NUM_CLASS) just to keep one of each.
# NOTE: this changes how much RNG is consumed before the chosen model is
# initialized, so seeded weights differ from the eager version.
BACKBONE_DICT = {'ResNet_50': ResNet_50,
                 'ResNet_101': ResNet_101,
                 'ResNet_152': ResNet_152,
                 'IR_50': IR_50,
                 'IR_101': IR_101,
                 'IR_152': IR_152,
                 'IR_SE_50': IR_SE_50,
                 'IR_SE_101': IR_SE_101,
                 'IR_SE_152': IR_SE_152}
BACKBONE = BACKBONE_DICT[BACKBONE_NAME](INPUT_SIZE)
print("=" * 60)
print(BACKBONE)
print("{} Backbone Generated".format(BACKBONE_NAME))
print("=" * 60)

HEAD_DICT = {'ArcFace': ArcFace,
             'CosFace': CosFace,
             'SphereFace': SphereFace,
             'Am_softmax': Am_softmax}
HEAD = HEAD_DICT[HEAD_NAME](in_features=EMBEDDING_SIZE, out_features=NUM_CLASS)
print("=" * 60)
print(HEAD)
print("{} Head Generated".format(HEAD_NAME))