def run_Affectnet_training():
    """Configure data iterators and a mobilefacenet module, then launch AffectNet training."""
    cfg = {
        'batch_size': 64,
        'val_batch_size': 40,
        'img_size': (112, 112),  # alternative: (128, 128)
        'metric_update_period': 50,
        'layers': 50,
        'load_epoch': 0,
        # 'load_path': '/media/nlab/data/test/resnext50-valence-llr',
        'save_model_prefix': '/media/nlab/data/SF/mbnet-singleframe',
        'emotions_list': [
            'Neutral', 'Happy', 'Sad', 'Surprise',
            'Fear', 'Anger', 'Disgust', 'Contempt',
        ],
        # 'multiply_basic_ratio': 4
    }

    # Training and validation iterators share the image size; detector=None,
    # so no face detector is attached to either iterator.
    train_iter = AffectnetIter(data_json_path='../training.csv',
                               batch_size=cfg['batch_size'],
                               train=True,
                               img_size=cfg['img_size'],
                               detector=None)
    val_iter = AffectnetIter(data_json_path='../validation.csv',
                             batch_size=cfg['val_batch_size'],
                             train=False,
                             img_size=cfg['img_size'],
                             detector=None)

    # Resume bookkeeping: instances already processed = whole batches per
    # epoch * batch size * the epoch index we resume from (0 = fresh start).
    whole_batches = int(train_iter.n_objects / cfg['batch_size'])
    train_iter.global_num_inst = whole_batches * cfg['batch_size'] * cfg['load_epoch']

    net = fmobilefacenet.get_symbol()
    model = mx.mod.Module(net, context=mx.gpu(0))
    model.bind(data_shapes=train_iter.provide_data,
               label_shapes=train_iter.provide_label)
    model.init_params(arg_params=None,
                      aux_params=None,
                      initializer=mx.init.MSRAPrelu(),
                      allow_missing=True)

    train_Affectnet(model, train_iter, val_iter, cfg)
def get_symbol(args, arg_params, aux_params):
    """Build the training symbol: a backbone embedding plus a margin-softmax head.

    The backbone is dispatched on the first character of ``args.network``; the
    classification head on ``args.loss_type``:
      0 = plain softmax, 1 = SphereFace (LSoftmax), 2 = CosineFace/AM-softmax,
      4 = ArcFace, 5 = combined margin cos(a*t + m) - b.
    Returns ``(out, arg_params, aux_params)``; the parameter dicts are passed
    through unchanged.
    """
    data_shape = (args.image_channel, args.image_h, args.image_w)
    # NOTE(review): image_shape and margin_symbols are computed but unused below.
    image_shape = ",".join([str(x) for x in data_shape])
    margin_symbols = []
    # --- backbone selection --------------------------------------------------
    if args.network[0] == 'd':  # densenet
        embedding = fdensenet.get_symbol(args.emb_size,
                                         args.num_layers,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'm':  # mobilenet v1 (num_layers == 1) or v2
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(
                args.emb_size,
                version_se=args.version_se,
                version_input=args.version_input,
                version_output=args.version_output,
                version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(
            args.emb_size,
            version_se=args.version_se,
            version_input=args.version_input,
            version_output=args.version_output,
            version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size,
                                    args.num_layers,
                                    version_se=args.version_se,
                                    version_input=args.version_input,
                                    version_output=args.version_output,
                                    version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif args.network[0] == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(
            args.emb_size,
            bn_mom=args.bn_mom,
            version_output=args.version_output)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size,
                                       args.num_layers,
                                       version_se=args.version_se,
                                       version_input=args.version_input,
                                       version_output=args.version_output,
                                       version_unit=args.version_unit,
                                       version_act=args.version_act)
    # --- loss head -----------------------------------------------------------
    all_label = mx.symbol.Variable('softmax_label')
    gt_label = all_label
    extra_loss = None  # NOTE(review): never reassigned in this variant
    # fc7 classifier weight: (num_classes, emb_size).
    _weight = mx.symbol.Variable("fc7_weight",
                                 shape=(args.num_classes, args.emb_size),
                                 lr_mult=args.fc7_lr_mult,
                                 wd_mult=args.fc7_wd_mult)
    if args.loss_type == 0:  #softmax
        if args.fc7_no_bias:
            fc7 = mx.sym.FullyConnected(data=embedding,
                                        weight=_weight,
                                        no_bias=True,
                                        num_hidden=args.num_classes,
                                        name='fc7')
        else:
            _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
            fc7 = mx.sym.FullyConnected(data=embedding,
                                        weight=_weight,
                                        bias=_bias,
                                        num_hidden=args.num_classes,
                                        name='fc7')
    elif args.loss_type == 1:  #sphere
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding,
                              label=gt_label,
                              num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta,
                              margin=args.margin,
                              scale=args.scale,
                              beta_min=args.beta_min,
                              verbose=1000,
                              name='fc7')
    elif args.loss_type == 2:
        # CosineFace: subtract s*m from the target-class logit.
        s = args.margin_s
        m = args.margin_m
        assert (s > 0.0)
        assert (m > 0.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        s_m = s * m
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=s_m,
                                    off_value=0.0)
        fc7 = fc7 - gt_one_hot
    elif args.loss_type == 4:
        # ArcFace: additive angular margin, target logit becomes s*cos(t + m).
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        # NOTE(review): the fc7 weight is re-declared here, discarding the
        # fc7_lr_mult/fc7_wd_mult settings applied above (lr_mult is forced to
        # 1.0, or 10.0 under a separately-named variable when finetuning).
        _weight = mx.symbol.Variable("fc7_weight",
                                     shape=(args.num_classes, args.emb_size),
                                     lr_mult=1.0)
        if args.finetune:
            print("{}finetuning from trained model".format('-' * 10))
            _weight = mx.symbol.Variable("finetune_weight",
                                         shape=(args.num_classes,
                                                args.emb_size),
                                         lr_mult=10.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)  # target-class logit
        cos_t = zy / s
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        mm = math.sin(math.pi - m) * m  # fallback penalty outside easy margin
        #threshold = 0.0
        threshold = math.cos(math.pi - m)
        # cond > 0 selects samples where the margin form is applicable.
        if args.easy_margin:
            cond = mx.symbol.Activation(data=cos_t, act_type='relu')
        else:
            cond_v = cos_t - threshold
            cond = mx.symbol.Activation(data=cond_v, act_type='relu')
        body = cos_t * cos_t
        body = 1.0 - body
        sin_t = mx.sym.sqrt(body)
        # cos(t + m) = cos t * cos m - sin t * sin m
        new_zy = cos_t * cos_m
        b = sin_t * sin_m
        new_zy = new_zy - b
        new_zy = new_zy * s
        if args.easy_margin:
            zy_keep = zy
        else:
            zy_keep = zy - s * mm
        new_zy = mx.sym.where(cond, new_zy, zy_keep)
        # Apply the logit change only at the ground-truth class.
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=1.0,
                                    off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    elif args.loss_type == 5:
        # Combined margin: target logit s*(cos(a*t + m) - b).
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                # Degenerates to the CosineFace form with margin b.
                s_m = s * args.margin_b
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=s_m,
                                            off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if args.margin_a != 1.0:
                    t = t * args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body * s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=1.0,
                                            off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7 + body
    # Outputs: embedding (gradient-blocked) + the softmax training head.
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = mx.symbol.SoftmaxOutput(data=fc7,
                                      label=gt_label,
                                      name='softmax',
                                      normalization='valid')
    out_list.append(softmax)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
model.forward(db, is_train=False) embedding = model.get_outputs()[0].asnumpy() #embedding = sklearn.preprocessing.normalize(embedding).flatten() end_time = time.time() embedding_time += end_time - start_time #print('cost of generate features:' + str(end_time - start_time)) return read_img_time / loop_time, crop_time / loop_time, embedding_time / loop_time ave_image_read_dict = {} ave_crop_dict = {} ave_embedding_dict = {} # 原始模型 embedding = fmobilefacenet.get_symbol(128, bn_mom=0.9, version_output='GNAP') detector = MtcnnDetector(model_folder=mtcnn_path, ctx=mx.cpu(0), num_worker=1, accurate_landmark=True, threshold=[0.6, 0.7, 0.8]) ave_read_image_time, ave_crop_time, ave_embedding_time = cal_time_cost( embedding, detector, 50) print(ave_read_image_time, ave_crop_time, ave_embedding_time) ave_image_read_dict['orignal'] = ave_read_image_time ave_crop_dict['orignal'] = ave_crop_time ave_embedding_dict['orignal'] = ave_embedding_time # 去掉45,5层 embedding = fmobilefacenet.get_symbol1(128, bn_mom=0.9, version_output='GNAP') detector = MtcnnDetector(model_folder=mtcnn_path,
def get_symbol(args, arg_params, aux_params):
    """Build the training symbol with the extended loss set of this variant.

    Backbone dispatch is on the first character of ``args.network``.
    Heads by ``args.loss_type``: 0 = softmax, 1 = SphereFace, 2 = CosineFace,
    4 = ArcFace, 5 = combined margin, 6 = ArcFace + auxiliary intra-class
    angle loss, 7 = ArcFace + auxiliary inter-class (weight-cosine) loss.
    When ``args.ce_loss`` is set an extra gradient-blocked cross-entropy
    monitor output is appended.  Returns ``(out, arg_params, aux_params)``.
    """
    data_shape = (args.image_channel, args.image_h, args.image_w)
    # NOTE(review): image_shape and margin_symbols are unused below.
    image_shape = ",".join([str(x) for x in data_shape])
    margin_symbols = []
    # --- backbone selection --------------------------------------------------
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size,
                                         args.num_layers,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            # This variant passes a width multiplier instead of SE options.
            embedding = fmobilenet.get_symbol(
                args.emb_size,
                version_input=args.version_input,
                version_output=args.version_output,
                version_multiplier=args.version_multiplier)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(
            args.emb_size,
            version_se=args.version_se,
            version_input=args.version_input,
            version_output=args.version_output,
            version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size,
                                    args.num_layers,
                                    version_se=args.version_se,
                                    version_input=args.version_input,
                                    version_output=args.version_output,
                                    version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif args.network[0] == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(
            args.emb_size,
            bn_mom=args.bn_mom,
            version_output=args.version_output)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size,
                                       args.num_layers,
                                       version_se=args.version_se,
                                       version_input=args.version_input,
                                       version_output=args.version_output,
                                       version_unit=args.version_unit,
                                       version_act=args.version_act)
    # --- loss head -----------------------------------------------------------
    all_label = mx.symbol.Variable('softmax_label')
    gt_label = all_label
    extra_loss = None  # NOTE(review): never reassigned in this variant
    _weight = mx.symbol.Variable("fc7_weight",
                                 shape=(args.num_classes, args.emb_size),
                                 lr_mult=args.fc7_lr_mult,
                                 wd_mult=args.fc7_wd_mult)
    if args.loss_type == 0:  #softmax
        if args.fc7_no_bias:
            fc7 = mx.sym.FullyConnected(data=embedding,
                                        weight=_weight,
                                        no_bias=True,
                                        num_hidden=args.num_classes,
                                        name='fc7')
        else:
            _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
            fc7 = mx.sym.FullyConnected(data=embedding,
                                        weight=_weight,
                                        bias=_bias,
                                        num_hidden=args.num_classes,
                                        name='fc7')
    elif args.loss_type == 1:  #sphere
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding,
                              label=gt_label,
                              num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta,
                              margin=args.margin,
                              scale=args.scale,
                              beta_min=args.beta_min,
                              verbose=1000,
                              name='fc7')
    elif args.loss_type == 2:
        # CosineFace: subtract s*m from the target-class logit.
        s = args.margin_s
        m = args.margin_m
        assert (s > 0.0)
        assert (m > 0.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        s_m = s * m
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=s_m,
                                    off_value=0.0)
        fc7 = fc7 - gt_one_hot
    elif args.loss_type == 4:
        # ArcFace: target logit becomes s*cos(t + m).
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)  # target-class logit
        cos_t = zy / s
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        mm = math.sin(math.pi - m) * m  # fallback penalty outside easy margin
        #threshold = 0.0
        threshold = math.cos(math.pi - m)
        if args.easy_margin:
            cond = mx.symbol.Activation(data=cos_t, act_type='relu')
        else:
            cond_v = cos_t - threshold
            cond = mx.symbol.Activation(data=cond_v, act_type='relu')
        body = cos_t * cos_t
        body = 1.0 - body
        sin_t = mx.sym.sqrt(body)
        # cos(t + m) = cos t * cos m - sin t * sin m
        new_zy = cos_t * cos_m
        b = sin_t * sin_m
        new_zy = new_zy - b
        new_zy = new_zy * s
        if args.easy_margin:
            zy_keep = zy
        else:
            zy_keep = zy - s * mm
        new_zy = mx.sym.where(cond, new_zy, zy_keep)
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=1.0,
                                    off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    elif args.loss_type == 5:
        # Combined margin: target logit s*(cos(a*t + m) - b).
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                s_m = s * args.margin_b
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=s_m,
                                            off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if args.margin_a != 1.0:
                    t = t * args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body * s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=1.0,
                                            off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7 + body
    elif args.loss_type == 6:
        # ArcFace + intra-class loss: mean target angle / pi, weighted by
        # grad_scale=args.margin_b.
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert args.margin_b > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        t = mx.sym.arccos(cos_t)
        intra_loss = t / np.pi
        intra_loss = mx.sym.mean(intra_loss)
        #intra_loss = mx.sym.exp(cos_t*-1.0)
        intra_loss = mx.sym.MakeLoss(intra_loss,
                                     name='intra_loss',
                                     grad_scale=args.margin_b)
        if m > 0.0:
            # Apply the angular margin only when m is set.
            t = t + m
            body = mx.sym.cos(t)
            new_zy = body * s
            diff = new_zy - zy
            diff = mx.sym.expand_dims(diff, 1)
            gt_one_hot = mx.sym.one_hot(gt_label,
                                        depth=args.num_classes,
                                        on_value=1.0,
                                        off_value=0.0)
            body = mx.sym.broadcast_mul(gt_one_hot, diff)
            fc7 = fc7 + body
    elif args.loss_type == 7:
        # ArcFace + inter-class loss: mean cosine between each sample's class
        # weight row and all class weights, weighted by grad_scale=args.margin_b.
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert args.margin_b > 0.0
        assert args.margin_a > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        t = mx.sym.arccos(cos_t)
        #counter_weight = mx.sym.take(_weight, gt_label, axis=1)
        #counter_cos = mx.sym.dot(counter_weight, _weight, transpose_a=True)
        counter_weight = mx.sym.take(_weight, gt_label, axis=0)
        counter_cos = mx.sym.dot(counter_weight, _weight, transpose_b=True)
        #counter_cos = mx.sym.minimum(counter_cos, 1.0)
        #counter_angle = mx.sym.arccos(counter_cos)
        #counter_angle = counter_angle * -1.0
        #counter_angle = counter_angle/np.pi #[0,1]
        #inter_loss = mx.sym.exp(counter_angle)
        #counter_cos = mx.sym.dot(_weight, _weight, transpose_b=True)
        #counter_cos = mx.sym.minimum(counter_cos, 1.0)
        #counter_angle = mx.sym.arccos(counter_cos)
        #counter_angle = mx.sym.sort(counter_angle, axis=1)
        #counter_angle = mx.sym.slice_axis(counter_angle, axis=1, begin=0,end=int(args.margin_a))
        #inter_loss = counter_angle*-1.0 # [-1,0]
        #inter_loss = inter_loss+1.0 # [0,1]
        inter_loss = counter_cos
        inter_loss = mx.sym.mean(inter_loss)
        inter_loss = mx.sym.MakeLoss(inter_loss,
                                     name='inter_loss',
                                     grad_scale=args.margin_b)
        if m > 0.0:
            # Apply the angular margin only when m is set.
            t = t + m
            body = mx.sym.cos(t)
            new_zy = body * s
            diff = new_zy - zy
            diff = mx.sym.expand_dims(diff, 1)
            gt_one_hot = mx.sym.one_hot(gt_label,
                                        depth=args.num_classes,
                                        on_value=1.0,
                                        off_value=0.0)
            body = mx.sym.broadcast_mul(gt_one_hot, diff)
            fc7 = fc7 + body
    # Outputs: embedding (gradient-blocked), softmax, plus auxiliary losses.
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = mx.symbol.SoftmaxOutput(data=fc7,
                                      label=gt_label,
                                      name='softmax',
                                      normalization='valid')
    out_list.append(softmax)
    if args.loss_type == 6:
        out_list.append(intra_loss)
    if args.loss_type == 7:
        out_list.append(inter_loss)
        #out_list.append(mx.sym.BlockGrad(counter_weight))
        #out_list.append(intra_loss)
    if args.ce_loss:
        # Gradient-blocked cross-entropy value for monitoring only.
        #ce_loss = mx.symbol.softmax_cross_entropy(data=fc7, label = gt_label, name='ce_loss')/args.per_batch_size
        body = mx.symbol.SoftmaxActivation(data=fc7)
        body = mx.symbol.log(body)
        _label = mx.sym.one_hot(gt_label,
                                depth=args.num_classes,
                                on_value=-1.0,
                                off_value=0.0)
        body = body * _label
        ce_loss = mx.symbol.sum(body) / args.per_batch_size
        out_list.append(mx.symbol.BlockGrad(ce_loss))
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
def get_symbol(args, arg_params, aux_params):
    """Build the training symbol (older variant: biased softmax, fixed lr_mult).

    Backbone dispatch on the first character of ``args.network``; heads by
    ``args.loss_type``: 0 = softmax (always with bias), 1 = SphereFace,
    2 = CosineFace, 4 = ArcFace, 5 = combined margin.
    Returns ``(out, arg_params, aux_params)``.
    """
    data_shape = (args.image_channel, args.image_h, args.image_w)
    # NOTE(review): image_shape and margin_symbols are unused below.
    image_shape = ",".join([str(x) for x in data_shape])
    margin_symbols = []
    # --- backbone selection --------------------------------------------------
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size,
                                         args.num_layers,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(args.emb_size,
                                              version_se=args.version_se,
                                              version_input=args.version_input,
                                              version_output=args.version_output,
                                              version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(args.emb_size,
                                                    version_se=args.version_se,
                                                    version_input=args.version_input,
                                                    version_output=args.version_output,
                                                    version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size,
                                    args.num_layers,
                                    version_se=args.version_se,
                                    version_input=args.version_input,
                                    version_output=args.version_output,
                                    version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif args.network[0] == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(args.emb_size,
                                              bn_mom=args.bn_mom,
                                              wd_mult=args.fc7_wd_mult)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size,
                                       args.num_layers,
                                       version_se=args.version_se,
                                       version_input=args.version_input,
                                       version_output=args.version_output,
                                       version_unit=args.version_unit,
                                       version_act=args.version_act)
    # --- loss head -----------------------------------------------------------
    all_label = mx.symbol.Variable('softmax_label')
    gt_label = all_label
    extra_loss = None  # NOTE(review): never reassigned in this variant
    # In this variant the fc7 weight lr_mult is fixed at 1.0.
    _weight = mx.symbol.Variable("fc7_weight",
                                 shape=(args.num_classes, args.emb_size),
                                 lr_mult=1.0,
                                 wd_mult=args.fc7_wd_mult)
    if args.loss_type == 0:  #softmax
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding,
                                    weight=_weight,
                                    bias=_bias,
                                    num_hidden=args.num_classes,
                                    name='fc7')
    elif args.loss_type == 1:  #sphere
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding,
                              label=gt_label,
                              num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta,
                              margin=args.margin,
                              scale=args.scale,
                              beta_min=args.beta_min,
                              verbose=1000,
                              name='fc7')
    elif args.loss_type == 2:
        # CosineFace: subtract s*m from the target-class logit.
        s = args.margin_s
        m = args.margin_m
        assert (s > 0.0)
        assert (m > 0.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding,
                                               mode='instance',
                                               name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        s_m = s * m
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=s_m,
                                    off_value=0.0)
        fc7 = fc7 - gt_one_hot
    elif args.loss_type == 4:
        # ArcFace: target logit becomes s*cos(t + m).
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding,
                                               mode='instance',
                                               name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)  # target-class logit
        cos_t = zy / s
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        mm = math.sin(math.pi - m) * m  # fallback penalty outside easy margin
        #threshold = 0.0
        threshold = math.cos(math.pi - m)
        if args.easy_margin:
            cond = mx.symbol.Activation(data=cos_t, act_type='relu')
        else:
            cond_v = cos_t - threshold
            cond = mx.symbol.Activation(data=cond_v, act_type='relu')
        body = cos_t * cos_t
        body = 1.0 - body
        sin_t = mx.sym.sqrt(body)
        # cos(t + m) = cos t * cos m - sin t * sin m
        new_zy = cos_t * cos_m
        b = sin_t * sin_m
        new_zy = new_zy - b
        new_zy = new_zy * s
        if args.easy_margin:
            zy_keep = zy
        else:
            zy_keep = zy - s * mm
        new_zy = mx.sym.where(cond, new_zy, zy_keep)
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=1.0,
                                    off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    elif args.loss_type == 5:
        # Combined margin: target logit s*(cos(a*t + m) - b).
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding,
                                               mode='instance',
                                               name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                s_m = s * args.margin_b
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=s_m,
                                            off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if args.margin_a != 1.0:
                    t = t * args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body * s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=1.0,
                                            off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7 + body
    # Outputs: embedding (gradient-blocked) + softmax training head.
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = mx.symbol.SoftmaxOutput(data=fc7,
                                      label=gt_label,
                                      name='softmax',
                                      normalization='valid')
    out_list.append(softmax)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
def get_symbol(args, arg_params, aux_params):
    """Build a multi-task training symbol: face recognition + gender + age heads.

    The backbone is selected from ``args.network`` as in the other variants.
    Labels arrive packed in ``softmax_label``: column 0 = identity, column 1 =
    gender, columns 2..2+AGE = per-bin age labels.  Each enabled task
    (module-level USE_FR / USE_GENDER / USE_AGE flags) gets its own head via
    ``get_softmax`` on a deep copy of ``args`` with task-specific grad_scale.
    Returns ``(out, arg_params, aux_params)``.
    """
    data_shape = (args.image_channel, args.image_h, args.image_w)
    # NOTE(review): image_shape and margin_symbols are unused below.
    image_shape = ",".join([str(x) for x in data_shape])
    margin_symbols = []
    # --- backbone selection --------------------------------------------------
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size,
                                         args.num_layers,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(args.emb_size,
                                              version_se=args.version_se,
                                              version_input=args.version_input,
                                              version_output=args.version_output,
                                              version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(args.emb_size,
                                                    version_se=args.version_se,
                                                    version_input=args.version_input,
                                                    version_output=args.version_output,
                                                    version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size,
                                    args.num_layers,
                                    version_se=args.version_se,
                                    version_input=args.version_input,
                                    version_output=args.version_output,
                                    version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif args.network[0] == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(args.emb_size)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size,
                                       args.num_layers,
                                       version_se=args.version_se,
                                       version_input=args.version_input,
                                       version_output=args.version_output,
                                       version_unit=args.version_unit,
                                       version_act=args.version_act)
    # --- task heads ----------------------------------------------------------
    all_label = mx.symbol.Variable('softmax_label')
    gt_label = all_label
    extra_loss = None  # NOTE(review): never reassigned in this variant
    s = args.margin_s
    #m = args.margin_m
    assert s > 0.0
    # Shared scaled, L2-normalized embedding for all margin heads.
    nembedding = mx.symbol.L2Normalization(embedding,
                                           mode='instance',
                                           name='fc1n') * s
    out_list = [mx.symbol.BlockGrad(embedding)]
    # Deep-copied args are mutated per task below (grad_scale, margin_a,
    # num_classes) without touching the caller's args.
    _args = copy.deepcopy(args)
    if USE_FR:
        # Identity head: label is column 0, full gradient scale.
        _args.grad_scale = 1.0
        fr_label = mx.symbol.slice_axis(all_label, axis=1, begin=0, end=1)
        fr_label = mx.symbol.reshape(fr_label, (args.per_batch_size, ))
        fr_softmax = get_softmax(_args, embedding, nembedding, fr_label, 'fc7')
        out_list.append(fr_softmax)
    if USE_GENDER:
        # Gender head: binary, margin disabled, down-weighted gradient.
        _args.grad_scale = 0.2
        _args.margin_a = 0.0
        _args.num_classes = 2
        gender_label = mx.symbol.slice_axis(all_label, axis=1, begin=1, end=2)
        gender_label = mx.symbol.reshape(gender_label, (args.per_batch_size, ))
        gender_softmax = get_softmax(_args, embedding, nembedding,
                                     gender_label, 'fc8')
        out_list.append(gender_softmax)
    if USE_AGE:
        # Age: AGE independent binary heads, one per label column from 2 on.
        _args.grad_scale = 0.01
        _args.margin_a = 0.0
        _args.num_classes = 2
        for i in xrange(AGE):  # NOTE(review): xrange implies Python 2
            age_label = mx.symbol.slice_axis(all_label,
                                             axis=1,
                                             begin=2 + i,
                                             end=3 + i)
            age_label = mx.symbol.reshape(age_label, (args.per_batch_size, ))
            age_softmax = get_softmax(_args, embedding, nembedding, age_label,
                                      'fc9_%d' % (i))
            out_list.append(age_softmax)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
def get_symbol(args, arg_params, aux_params):
    """Build a mobilefacenet student symbol for knowledge distillation.

    The classifier logits ``fc7`` (plain softmax for loss_type 0, ArcFace for
    loss_type 4) are trained against a combination of:
      * soft loss: mean squared difference between temperature-scaled student
        logits and the teacher logits supplied through the 'logit_t' input;
      * hard loss: cross-entropy against the ground-truth labels.
    The two are mixed as ``soft*lamda + hard*(1-lamda)`` and emitted via
    MakeLoss; the softmax output itself is gradient-blocked (monitoring only).
    Returns ``(out, arg_params, aux_params)``; the param dicts pass through
    unchanged.

    Side effect: forces ``args.num_layers = 1`` (kept from the original).
    """
    # define network (mobilefacenet only in this variant)
    args.num_layers = 1
    print('init mobilefacenet', args.num_layers)
    embedding = fmobilefacenet.get_symbol(args.emb_size,
                                          bn_mom=args.bn_mom,
                                          wd_mult=args.fc7_wd_mult)
    # define loss
    all_label = mx.symbol.Variable('softmax_label')
    gt_label = all_label
    # One-hot ground truth is needed both by the ArcFace head and by the
    # distillation hard loss below.  Hoisted here so the loss_type == 0 path
    # no longer hits a NameError when hard_loss is built (bug fix: previously
    # gt_one_hot was only defined inside the loss_type == 4 branch).
    gt_one_hot = mx.sym.one_hot(gt_label,
                                depth=args.num_classes,
                                on_value=1.0,
                                off_value=0.0)
    _weight = mx.symbol.Variable("fc7_weight",
                                 shape=(args.num_classes, args.emb_size),
                                 lr_mult=1.0,
                                 wd_mult=args.fc7_wd_mult)
    if args.loss_type == 0:  #softmax
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding,
                                    weight=_weight,
                                    bias=_bias,
                                    num_hidden=args.num_classes,
                                    name='fc7')
    elif args.loss_type == 4:
        # ArcFace: target logit becomes s*cos(t + m).
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)  # target-class logit
        cos_t = zy / s
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        mm = math.sin(math.pi - m) * m  # fallback penalty outside easy margin
        #threshold = 0.0
        threshold = math.cos(math.pi - m)
        if args.easy_margin:
            cond = mx.symbol.Activation(data=cos_t, act_type='relu')
        else:
            cond_v = cos_t - threshold
            cond = mx.symbol.Activation(data=cond_v, act_type='relu')
        body = cos_t * cos_t
        body = 1.0 - body
        sin_t = mx.sym.sqrt(body)
        # cos(t + m) = cos t * cos m - sin t * sin m
        new_zy = cos_t * cos_m
        b = sin_t * sin_m
        new_zy = new_zy - b
        new_zy = new_zy * s
        if args.easy_margin:
            zy_keep = zy
        else:
            zy_keep = zy - s * mm
        new_zy = mx.sym.where(cond, new_zy, zy_keep)
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    out_list = [mx.symbol.BlockGrad(embedding)]
    # Teacher logits are fed in as an extra input named 'logit_t'.
    logit_t_val = mx.symbol.Variable('logit_t')
    softmax = mx.symbol.SoftmaxOutput(data=fc7,
                                      label=gt_label,
                                      name='softmax',
                                      normalization='valid')
    # Soft (distillation) loss on temperature-scaled logits.
    soft_loss = mx.symbol.mean(
        mx.symbol.square(fc7 / args.tau - logit_t_val / args.tau))
    # Hard loss: cross-entropy of the student against ground truth.
    log_softmax = mx.sym.log_softmax(fc7)
    hard_loss = -mx.sym.sum(mx.sym.broadcast_mul(
        gt_one_hot, log_softmax)) / args.batch_size
    total_loss = soft_loss * args.lamda + hard_loss * (1 - args.lamda)
    total_loss = mx.symbol.MakeLoss(total_loss)
    out_list.append(mx.sym.BlockGrad(softmax))
    out_list.append(total_loss)
    out = mx.sym.Group(out_list)
    # out = mx.sym.Group([mx.sym.BlockGrad(embedding), softmax])
    return (out, arg_params, aux_params)
def get_symbol(args, arg_params, aux_params):
    """Build the recognition training symbol for the backbone selected by
    ``args.network`` and the margin loss selected by ``args.loss_type``.

    Supported loss types: 0 softmax, 1 SphereFace (LSoftmax), 2 CosineFace,
    4 ArcFace, 5 combined margin.  Returns ``(out, arg_params, aux_params)``
    where ``out`` groups the gradient-blocked embedding and the softmax
    output.
    """
    # data_shape = (args.image_channel, args.image_h, args.image_w)  # (3L,112L,112L)
    # image_shape = ",".join([str(x) for x in data_shape])  # 3,112,112
    # margin_symbols = []
    print('***network: ', args.network)  # e.g. r100
    if args.network[0] == 'd':  # densenet
        embedding = fdensenet.get_symbol(args.emb_size,
                                         args.num_layers,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'm':  # mobilenet
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(args.emb_size,
                                              version_se=args.version_se,
                                              version_input=args.version_input,
                                              version_output=args.version_output,
                                              version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    # elif args.network[0] == 'v':
    #     print('init MobileNet-V3', args.num_layers)
    #     embedding = fmobilenetv3.get_symbol(args.emb_size)
    elif args.network[0] == 'i':  # inception-resnet-v2
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(args.emb_size,
                                                    version_se=args.version_se,
                                                    version_input=args.version_input,
                                                    version_output=args.version_output,
                                                    version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size,
                                    args.num_layers,
                                    version_se=args.version_se,
                                    version_input=args.version_input,
                                    version_output=args.version_output,
                                    version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif args.network[0] == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(args.emb_size,
                                              bn_mom=args.bn_mom,
                                              version_output=args.version_output)
    else:  # fall through to resnet
        print('init resnet, 层数: ', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size,
                                       args.num_layers,
                                       version_se=args.version_se,
                                       version_input=args.version_input,
                                       version_output=args.version_output,
                                       version_unit=args.version_unit,
                                       version_act=args.version_act)
    # get_symbol
    all_label = mx.symbol.Variable('softmax_label')
    gt_label = all_label
    # extra_loss = None
    # Define the fc7 classification weight (num_classes x emb_size).
    _weight = mx.symbol.Variable("fc7_weight",
                                 shape=(args.num_classes, args.emb_size),
                                 lr_mult=args.fc7_lr_mult,
                                 wd_mult=args.fc7_wd_mult)
    if args.loss_type == 0:  # softmax
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding,
                                    weight=_weight,
                                    bias=_bias,
                                    num_hidden=args.num_classes,
                                    name='fc7')
    elif args.loss_type == 1:  # sphere
        print('*******'*10)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding,
                              label=gt_label,
                              num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta,
                              margin=args.margin,
                              scale=args.scale,
                              beta_min=args.beta_min,
                              verbose=1000,
                              name='fc7')
    elif args.loss_type == 2:  # CosineFace
        s = args.margin_s
        m = args.margin_m
        assert (s > 0.0)
        assert (m > 0.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding,
                                               mode='instance',
                                               name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        s_m = s * m
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=s_m,
                                    off_value=0.0)  # one-hot: s_m at target class, 0.0 elsewhere
        fc7 = fc7 - gt_one_hot
    elif args.loss_type == 4:  # ArcFace
        s = args.margin_s  # scale hyper-parameter s (e.g. 64)
        m = args.margin_m  # margin hyper-parameter m (e.g. 0.5)
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        # pdb.set_trace()
        # Normalize the class weights.
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')  # shape = [(num_classes, emb_size)]
        # Normalize the features, then scale them by s.
        nembedding = mx.symbol.L2Normalization(embedding,
                                               mode='instance',
                                               name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')  # one logit per class
        zy = mx.sym.pick(fc7, gt_label, axis=1)  # pick the ground-truth logit from each row of fc7, i.e. s*cos(theta)
        cos_t = zy / s  # the network outputs s * (x/|x|)·(w/|w|) = s*cos(theta); divide by s to recover cos(theta)
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        mm = math.sin(math.pi - m) * m  # sin(pi-m)*m = sin(m)*m, e.g. ~0.2397 for m=0.5
        # threshold = 0.0
        threshold = math.cos(math.pi - m)  # guard so that theta+m < pi; threshold = -cos(m) < 0 (~-0.8776 for m=0.5)
        if args.easy_margin:
            # Use 0 as the threshold: positive cos(theta) keeps the margin.
            cond = mx.symbol.Activation(data=cos_t, act_type='relu')
        else:
            cond_v = cos_t - threshold  # use the negative threshold instead
            cond = mx.symbol.Activation(data=cond_v, act_type='relu')
        body = cos_t * cos_t  # cos_t^2 + sin_t^2 = 1
        body = 1.0 - body
        sin_t = mx.sym.sqrt(body)
        new_zy = cos_t * cos_m  # cos(t+m) = cos(t)cos(m) - sin(t)sin(m)
        b = sin_t * sin_m
        new_zy = new_zy - b
        new_zy = new_zy * s  # s*cos(t + m)
        if args.easy_margin:
            zy_keep = zy  # fallback is the plain logit, s*cos(theta)
        else:
            zy_keep = zy - s * mm  # zy - s*sin(m)*m = s*cos(t) - s*m*sin(m)
        # Where cond > 0 keep new_zy = s*cos(theta+m); otherwise fall back to
        # zy_keep = s*cos(theta) or s*cos(theta) - s*m*sin(m).
        new_zy = mx.sym.where(cond, new_zy, zy_keep)
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=1.0,
                                    off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)  # new_zy - zy at the target class, 0 elsewhere
        fc7 = fc7 + body  # at the target class fc7 = zy + (new_zy - zy) = new_zy; other logits unchanged
    elif args.loss_type == 5:  # combined margin
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding,
                                               mode='instance',
                                               name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                # Pure additive cosine margin (CosineFace-style).
                s_m = s * args.margin_b
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=s_m,
                                            off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                # General form: s * (cos(a*theta + m) - b) at the target class.
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if args.margin_a != 1.0:
                    t = t * args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body * s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=1.0,
                                            off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7 + body
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = mx.symbol.SoftmaxOutput(data=fc7,
                                      label=gt_label,
                                      name='softmax',
                                      normalization='valid')
    out_list.append(softmax)
    out = mx.symbol.Group(out_list)
    # print(out)
    # sys.exit()
    return (out, arg_params, aux_params)
def get_symbol(args, arg_params, aux_params):
    """Build the semi-supervised training symbol: a margin-based softmax on
    the labeled part of the batch plus a likelihood regularizer on the
    unlabeled part (``args.num_unlabeled`` trailing samples of each batch).

    Supported loss types: 0 softmax, 1 SphereFace, 2 CosineFace,
    4 ArcFace + unlabeled likelihood loss, 5 combined margin.
    Returns ``(out, arg_params, aux_params)``.
    """
    data_shape = (args.image_channel, args.image_h, args.image_w)
    image_shape = ",".join([str(x) for x in data_shape])
    margin_symbols = []
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size,
                                         args.num_layers,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(
                args.emb_size,
                version_se=args.version_se,
                version_input=args.version_input,
                version_output=args.version_output,
                version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(
            args.emb_size,
            version_se=args.version_se,
            version_input=args.version_input,
            version_output=args.version_output,
            version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size,
                                    args.num_layers,
                                    version_se=args.version_se,
                                    version_input=args.version_input,
                                    version_output=args.version_output,
                                    version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif args.network[0] == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(args.emb_size,
                                              bn_mom=args.bn_mom,
                                              wd_mult=args.fc7_wd_mult)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size,
                                       args.num_layers,
                                       version_se=args.version_se,
                                       version_input=args.version_input,
                                       version_output=args.version_output,
                                       version_unit=args.version_unit,
                                       version_act=args.version_act)
    all_label = mx.symbol.Variable('softmax_label')
    gt_label = all_label
    extra_loss = None
    _weight = mx.symbol.Variable("fc7_weight",
                                 shape=(args.num_classes, args.emb_size),
                                 lr_mult=1.0,
                                 wd_mult=args.fc7_wd_mult)
    if args.loss_type == 0:  #softmax
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding,
                                    weight=_weight,
                                    bias=_bias,
                                    num_hidden=args.num_classes,
                                    name='fc7')
    elif args.loss_type == 1:  #sphere
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding,
                              label=gt_label,
                              num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta,
                              margin=args.margin,
                              scale=args.scale,
                              beta_min=args.beta_min,
                              verbose=1000,
                              name='fc7')
    elif args.loss_type == 2:  # CosineFace
        s = args.margin_s
        m = args.margin_m
        assert (s > 0.0)
        assert (m > 0.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        s_m = s * m
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=s_m,
                                    off_value=0.0)
        fc7 = fc7 - gt_one_hot
    elif args.loss_type == 4:  # ArcFace + unlabeled likelihood regularizer
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        # Split fc7 into fc7 (labeled part) and fc7_2 (unlabeled part) and
        # compute the likelihood regularizer of the unlabeled part.
        # `subtract` is a constant offset based on the number of classes.
        subtract = args.num_classes * math.log(args.num_classes)
        num_unlabeled = args.num_unlabeled
        num_labeled = args.per_batch_size - args.num_unlabeled
        per_batch_size = args.per_batch_size
        assert num_labeled > 0
        fc7_2 = mx.sym.slice_axis(fc7,
                                  axis=0,
                                  begin=num_labeled,
                                  end=per_batch_size)
        fc7 = mx.sym.slice_axis(fc7, axis=0, begin=0, end=num_labeled)
        # NOTE(review): log_softmax is applied on top of an already-softmaxed
        # tensor here (double normalization) — confirm this is intended.
        fc7_2 = mx.sym.softmax(data=fc7_2, axis=1)
        log_likelihood = -mx.sym.log_softmax(fc7_2, axis=1)
        log_likelihood_loss = (mx.sym.sum(log_likelihood) / num_unlabeled)
        log_likelihood_loss = log_likelihood_loss - subtract
        # Only the labeled part keeps its ground-truth labels.
        gt_label = mx.sym.slice_axis(gt_label,
                                     axis=0,
                                     begin=0,
                                     end=num_labeled)
        ############################################################
        #last_to_append = mx.symbol.BlockGrad(mx.sym.softmax(fc7))
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        mm = math.sin(math.pi - m) * m
        #threshold = 0.0
        threshold = math.cos(math.pi - m)
        if args.easy_margin:
            cond = mx.symbol.Activation(data=cos_t, act_type='relu')
        else:
            cond_v = cos_t - threshold
            cond = mx.symbol.Activation(data=cond_v, act_type='relu')
        body = cos_t * cos_t
        body = 1.0 - body
        sin_t = mx.sym.sqrt(body)
        new_zy = cos_t * cos_m
        b = sin_t * sin_m
        new_zy = new_zy - b
        new_zy = new_zy * s
        if args.easy_margin:
            zy_keep = zy
        else:
            zy_keep = zy - s * mm
        new_zy = mx.sym.where(cond, new_zy, zy_keep)
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=1.0,
                                    off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    elif args.loss_type == 5:  # combined margin
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                s_m = s * args.margin_b
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=s_m,
                                            off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if args.margin_a != 1.0:
                    t = t * args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body * s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=1.0,
                                            off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7 + body
    # BUGFIX: out_list was previously initialized only inside the
    # loss_type == 4 branch, so every other loss type crashed with a
    # NameError on the append below.  Initialize it unconditionally here
    # (behavior-identical for loss_type == 4, which appended nothing before
    # this point).
    out_list = []
    out_list.append(mx.symbol.BlockGrad(embedding))
    if args.loss_type == 4:
        #softmax = mx.sym.softmax_cross_entropy(data=fc7, label=gt_label)
        cos_fc7 = mx.symbol.BlockGrad(
            fc7)  # TODO: maybe this softmax makes the accuracy inconsistent
        gt_label_one_hot = mx.sym.one_hot(gt_label, args.num_classes)
        fc7 = -mx.sym.log_softmax(fc7, axis=1)
        softmax = gt_label_one_hot * fc7
        softmax = mx.sym.sum(softmax) / (args.per_batch_size -
                                         args.num_unlabeled)
        scale = args.likelihood_scale_factor
        loss = mx.sym.make_loss(data=(softmax + log_likelihood_loss * scale))
        out_list.append(loss)
        out_list.append(cos_fc7)
        out_list.append(mx.sym.BlockGrad(softmax))
        out_list.append(mx.sym.BlockGrad(log_likelihood_loss))
    else:
        softmax = mx.symbol.SoftmaxOutput(data=fc7,
                                          label=gt_label,
                                          name='softmax',
                                          normalization='valid')
        out_list.append(softmax)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
def get_symbol(args, arg_params, aux_params):
  """Build the training symbol for the backbone selected by ``args.network``
  with the margin loss selected by ``args.loss_type``.

  Supported loss types: 0 softmax, 1 SphereFace, 2 CosineFace,
  4 combined margin (a*theta + m, -b), 5 modified margin with dropout and a
  sine-based fallback, 6 combined margin.  Returns
  ``(out, arg_params, aux_params)``.
  """
  data_shape = (args.image_channel,args.image_h,args.image_w)
  image_shape = ",".join([str(x) for x in data_shape])
  margin_symbols = []
  if args.network[0]=='d':  # densenet
    embedding = fdensenet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
  elif args.network[0]=='m':  # mobilenet v1/v2
    print('init mobilenet', args.num_layers)
    if args.num_layers==1:
      embedding = fmobilenet.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    else:
      embedding = fmobilenetv2.get_symbol(args.emb_size)
  elif args.network[0]=='i':  # inception-resnet-v2
    print('init inception-resnet-v2', args.num_layers)
    embedding = finception_resnet_v2.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
  elif args.network[0]=='x':  # xception
    print('init xception', args.num_layers)
    embedding = fxception.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
  elif args.network[0]=='p':  # dpn
    print('init dpn', args.num_layers)
    embedding = fdpn.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
  elif args.network[0]=='n':  # nasnet
    print('init nasnet', args.num_layers)
    embedding = fnasnet.get_symbol(args.emb_size)
  elif args.network[0]=='s':  # spherenet
    print('init spherenet', args.num_layers)
    embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
  elif args.network[0]=='y':  # mobilefacenet
    print('init mobilefacenet', args.num_layers)
    embedding = fmobilefacenet.get_symbol(args.emb_size, bn_mom = args.bn_mom, wd_mult = args.fc7_wd_mult)
  else:  # default: resnet
    print('init resnet', args.num_layers)
    embedding = fresnet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit, version_act=args.version_act)
  all_label = mx.symbol.Variable('softmax_label')
  #center_label = mx.symbol.Variable('center_label')
  gt_label = all_label
  extra_loss = None
  # fc7 classification weight (num_classes x emb_size).
  _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0, wd_mult=args.fc7_wd_mult)
  if args.loss_type==0: #softmax
    _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
    fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7')
  elif args.loss_type==1: #sphere
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes, weight = _weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=1000, name='fc7')
  elif args.loss_type==2:  # CosineFace: subtract s*m from the target logit
    s = args.margin_s
    m = args.margin_m
    assert(s>0.0)
    assert(m>0.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
    fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
    s_m = s*m
    gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0)
    fc7 = fc7-gt_one_hot
  elif args.loss_type==4:
    # NOTE: unlike the classic ArcFace formulation, this branch applies the
    # combined margin s*(cos(a*theta + m) - b) directly, with no easy-margin
    # guard.
    s = args.margin_s
    m = args.margin_m
    assert s>0.0
    assert m>=0.0
    assert m<(math.pi/2)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
    fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
    zy = mx.sym.pick(fc7, gt_label, axis=1)  # target logit, s*cos(theta)
    cos_t = zy/s
    t = mx.sym.arccos(cos_t)
    if args.margin_a!=1.0:
      t = t*args.margin_a
    if args.margin_m>0.0:
      t = t+args.margin_m
    body = mx.sym.cos(t)
    if args.margin_b>0.0:
      body = body - args.margin_b
    new_zy = body*s
    diff = new_zy - zy
    diff = mx.sym.expand_dims(diff, 1)
    gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
    body = mx.sym.broadcast_mul(gt_one_hot, diff)
    fc7 = fc7+body
  elif args.loss_type==5:
    # Combined margin variant with dropout on the scaled embedding and a
    # sine-based fallback for large angles.
    s = args.margin_s
    m = args.margin_m
    assert s>0.0
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
    nembedding = mx.symbol.Dropout(data=nembedding, p=0.4)
    fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
    if args.margin_a!=1.0 or args.margin_m!=0.0 or args.margin_b!=0.0:
      if args.margin_a==1.0 and args.margin_m==0.0:
        # Pure additive cosine margin.
        s_m = s*args.margin_b
        gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0)
        fc7 = fc7-gt_one_hot
      else:
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy/s
        t = mx.sym.arccos(cos_t)
        if args.margin_a!=1.0:
          t = t*args.margin_a
        if args.margin_m>0.0:
          t = t+args.margin_m
        #threshold = math.cos(math.pi-m)
        if args.easy_margin:
          # m<pi/2 so: pi-m > pi/2
          cond = mx.symbol.Activation(data=cos_t, act_type='relu')
          # this means t > pi/2, when pi/2 <t < 3*pi/2, sin is decrease
          #cond_v = cos_t - threshold
          #cond = mx.symbol.Activation(data=cond_v, act_type='relu')
        else:
          cond_v = math.pi - t  # positive while the (scaled) angle stays below pi
          cond = mx.symbol.Activation(data=cond_v, act_type='relu')
          print("not easy margin: cond")
          #cond = mx.symbol.Activation(data=cos_t, act_type='relu')
        if args.easy_margin:
          zy_keep = mx.sym.cos(t)
        else:
          print("not easy margin: sine")
          zy_keep = mx.sym.sin(t)-1  # sine-based fallback, always <= 0
        new_zy = mx.sym.cos(t)
        body = mx.sym.where(cond, new_zy, zy_keep)
        #body = mx.sym.cos(t)
        if args.margin_b>0.0:
          # NOTE(review): margin_b is ADDED here, unlike loss_type 4/6 where
          # it is subtracted — confirm this asymmetry is intended.
          body = body + args.margin_b
        new_zy = body*s
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7+body
  elif args.loss_type==6:
    # Combined margin s*(cos(a*theta + m) - b), same shape as loss_type 4
    # but without dropout and reachable via a separate switch value.
    s = args.margin_s
    m = args.margin_m
    assert s>0.0
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
    fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
    if args.margin_a!=1.0 or args.margin_m!=0.0 or args.margin_b!=0.0:
      if args.margin_a==1.0 and args.margin_m==0.0:
        s_m = s*args.margin_b
        gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0)
        fc7 = fc7-gt_one_hot
      else:
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy/s
        t = mx.sym.arccos(cos_t)
        if args.margin_a!=1.0:
          t = t*args.margin_a
        if args.margin_m>0.0:
          t = t+args.margin_m
        body = mx.sym.cos(t)
        if args.margin_b>0.0:
          body = body - args.margin_b
        new_zy = body*s
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7+body
  out_list = [mx.symbol.BlockGrad(embedding)]
  softmax = mx.symbol.SoftmaxOutput(data=fc7, label = gt_label, name='softmax', normalization='valid')
  out_list.append(softmax)
  #print("out shape",np.shape(softmax))
  #center_in = mx.symbol.concat(embedding,fc7,dim=1)
  #center_loss_data = mx.symbol.Custom(data=embedding, label=gt_label, name='center_loss_data', op_type='centerloss',\
  #        num_class=args.num_classes, alpha=0.5, scale=0.5,lamdb=0.1,batchsize=args.per_batch_size,emb_size=args.emb_size)
  #extra_center_loss = mx.symbol.MakeLoss(name='extra_center_loss', data=center_loss_data)
  #total_loss = mx.symbol.ElementWiseSum([softmax, extra_loss],name='total_loss')
  #total_loss_op = mx.symbol.MakeLoss(name='total_loss_op',data=total_loss)
  #out_list.append(extra_center_loss)
  out = mx.symbol.Group(out_list)
  return (out, arg_params, aux_params)
def get_symbol(args, arg_params, aux_params):
    """Build the ArcFace training symbol on a MobileFaceNet backbone.

    Returns ``(out, arg_params, aux_params)`` where ``out`` groups the
    gradient-blocked embedding and the softmax output.

    NOTE(review): only ``args.loss_type == 4`` is handled; any other value
    leaves ``fc7`` unbound and raises ``NameError`` (unchanged behavior).
    """
    # data_shape = (args.image_channel, args.image_h, args.image_w)
    data_shape = (3, 112, 112)
    image_shape = ",".join(str(dim) for dim in data_shape)
    margin_symbols = []
    # print('init mobilefacenet', args.num_layers)
    embedding = fmobilefacenet.get_symbol(args.emb_size,
                                          bn_mom=args.bn_mom,
                                          version_output=args.version_output)
    all_label = mx.symbol.Variable('softmax_label')
    gt_label = all_label
    extra_loss = None
    # Classification weight for the final fc7 layer.
    _weight = mx.symbol.Variable("fc7_weight",
                                 shape=(args.num_classes, args.emb_size),
                                 lr_mult=args.fc7_lr_mult,
                                 wd_mult=args.fc7_wd_mult)
    if args.loss_type == 4:
        scale = args.margin_s
        margin = args.margin_m
        assert scale > 0.0
        assert 0.0 <= margin < (math.pi / 2)
        # Normalize weights and features; scale the features by s so that
        # fc7 holds s * cos(theta) per class.
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        scaled_embedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * scale
        fc7 = mx.sym.FullyConnected(data=scaled_embedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        # Ground-truth logit and its angle cosine.
        target_logit = mx.sym.pick(fc7, gt_label, axis=1)
        cos_theta = target_logit / scale
        cos_margin = math.cos(margin)
        sin_margin = math.sin(margin)
        mm_term = math.sin(math.pi - margin) * margin
        # threshold = 0.0
        # Guard so theta + m never exceeds pi (cos would wrap around).
        flip_threshold = math.cos(math.pi - margin)
        if args.easy_margin:
            keep_mask = mx.symbol.Activation(data=cos_theta, act_type='relu')
        else:
            keep_mask = mx.symbol.Activation(data=cos_theta - flip_threshold,
                                             act_type='relu')
        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m), scaled by s.
        sin_theta = mx.sym.sqrt(1.0 - cos_theta * cos_theta)
        margined_logit = (cos_theta * cos_margin -
                          sin_theta * sin_margin) * scale
        if args.easy_margin:
            fallback_logit = target_logit
        else:
            fallback_logit = target_logit - scale * mm_term
        # Apply the margin only where the guard condition holds.
        margined_logit = mx.sym.where(keep_mask, margined_logit,
                                      fallback_logit)
        logit_delta = mx.sym.expand_dims(margined_logit - target_logit, 1)
        target_mask = mx.sym.one_hot(gt_label,
                                     depth=args.num_classes,
                                     on_value=1.0,
                                     off_value=0.0)
        # Replace the target-class logit with the margined one.
        fc7 = fc7 + mx.sym.broadcast_mul(target_mask, logit_delta)
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = mx.symbol.SoftmaxOutput(data=fc7,
                                      label=gt_label,
                                      name='softmax',
                                      normalization='valid')
    out_list.append(softmax)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)