# Shared module-level imports for the get_symbol variants below. In the
# original training scripts these (plus the network builders fresnet,
# fmobilenet, fmobilenetv2, fdensenet, finception_resnet_v2, fxception, fdpn,
# fnasnet, spherenet, fmobilefacenet, the constants AGE/USE_FR/USE_GENDER/
# USE_AGE, and the get_softmax helper) are defined or imported at file scope.
import copy
import math

import numpy as np
import mxnet as mx


def get_symbol(args, arg_params, aux_params):
    data_shape = (args.image_channel, args.image_h, args.image_w)
    image_shape = ",".join([str(x) for x in data_shape])
    margin_symbols = []
    if args.network[0] == 'm':
        fc1 = fmobilenet.get_symbol(AGE * 2 + 2, multiplier=args.multiplier,
                                    version_input=args.version_input,
                                    version_output=args.version_output)
    else:
        fc1 = fresnet.get_symbol(AGE * 2 + 2, args.num_layers,
                                 version_input=args.version_input,
                                 version_output=args.version_output)
    label = mx.symbol.Variable('softmax_label')
    gender_label = mx.symbol.slice_axis(data=label, axis=1, begin=0, end=1)
    gender_label = mx.symbol.reshape(gender_label, shape=(args.per_batch_size,))
    gender_fc1 = mx.symbol.slice_axis(data=fc1, axis=1, begin=0, end=2)
    #gender_fc7 = mx.sym.FullyConnected(data=gender_fc1, num_hidden=2, name='gender_fc7')
    gender_softmax = mx.symbol.SoftmaxOutput(data=gender_fc1, label=gender_label,
                                             name='gender_softmax',
                                             normalization='valid',
                                             use_ignore=True, ignore_label=9999)
    outs = [gender_softmax]
    for i in range(AGE):
        age_label = mx.symbol.slice_axis(data=label, axis=1, begin=i + 1, end=i + 2)
        age_label = mx.symbol.reshape(age_label, shape=(args.per_batch_size,))
        age_fc1 = mx.symbol.slice_axis(data=fc1, axis=1, begin=2 + i * 2, end=4 + i * 2)
        #age_fc7 = mx.sym.FullyConnected(data=age_fc1, num_hidden=2, name='age_fc7_%i'%i)
        age_softmax = mx.symbol.SoftmaxOutput(data=age_fc1, label=age_label,
                                              name='age_softmax_%d' % i,
                                              normalization='valid', grad_scale=1)
        outs.append(age_softmax)
    outs.append(mx.sym.BlockGrad(fc1))
    out = mx.symbol.Group(outs)
    return (out, arg_params, aux_params)
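# Illustrative sketch (mine, not from the source): how the AGE*2+2 output of
# the head above is typically decoded at inference. The first two logits are
# gender; each of the AGE two-logit pairs is a binary ordinal vote whose
# argmaxes sum to the age estimate. The function name and num_bins default
# are assumptions.
def decode_gender_age(fc1_out, num_bins=100):
    # fc1_out: (num_bins*2 + 2,) raw outputs for one sample; num_bins == AGE
    gender = int(np.argmax(fc1_out[0:2]))
    age_bins = fc1_out[2:2 + num_bins * 2].reshape((num_bins, 2))
    age = int(np.argmax(age_bins, axis=1).sum())  # ordinal sum over bins
    return gender, age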
def get_symbol(args, arg_params, aux_params):
    data_shape = (args.image_channel, args.image_h, args.image_w)
    image_shape = ",".join([str(x) for x in data_shape])
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(512, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        embedding = fmobilenet.get_symbol(512, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(512, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(512, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(512, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(512, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    gt_label = mx.symbol.Variable('softmax_label')
    assert args.loss_type >= 0
    extra_loss = None
    if args.loss_type == 0:
        _weight = mx.symbol.Variable('fc7_weight')
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7')
    elif args.loss_type == 1:
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes, weight=_weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=1000, name='fc7')
    elif args.loss_type == 2:
        _weight = mx.symbol.Variable('fc7_weight')
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7')
        print('center-loss', args.center_alpha, args.center_scale)
        extra_loss = mx.symbol.Custom(data=embedding, label=gt_label, name='center_loss', op_type='centerloss',
                                      num_class=args.num_classes, alpha=args.center_alpha,
                                      scale=args.center_scale, batchsize=args.per_batch_size)
    elif args.loss_type == 10:  #marginal loss
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        params = [1.2, 0.3, 1.0]
        n1 = mx.sym.expand_dims(nembedding, axis=1)  #N,1,C
        n2 = mx.sym.expand_dims(nembedding, axis=0)  #1,N,C
        body = mx.sym.broadcast_sub(n1, n2)  #N,N,C
        body = body * body
        body = mx.sym.sum(body, axis=2)  #N,N
        #body = mx.sym.sqrt(body)
        body = body - params[0]
        mask = mx.sym.Variable('extra')
        body = body * mask
        body = body + params[1]
        #body = mx.sym.maximum(body, 0.0)
        body = mx.symbol.Activation(data=body, act_type='relu')
        body = mx.sym.sum(body)
        body = body / (args.per_batch_size * args.per_batch_size - args.per_batch_size)
        extra_loss = mx.symbol.MakeLoss(body, grad_scale=params[2])
    elif args.loss_type == 11:  #npair loss
        params = [0.9, 0.2]
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        nembedding = mx.sym.transpose(nembedding)
        nembedding = mx.symbol.reshape(nembedding, (512, args.per_identities, args.images_per_identity))
        nembedding = mx.sym.transpose(nembedding, axes=(2, 1, 0))  #2*id*512
        #nembedding = mx.symbol.reshape(nembedding, (512, args.images_per_identity, args.per_identities))
        #nembedding = mx.sym.transpose(nembedding, axes=(1,2,0)) #2*id*512
        n1 = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=1)
        n2 = mx.symbol.slice_axis(nembedding, axis=0, begin=1, end=2)
        #n1 = []
        #n2 = []
        #for i in xrange(args.per_identities):
        #    _n1 = mx.symbol.slice_axis(nembedding, axis=0, begin=2*i, end=2*i+1)
        #    _n2 = mx.symbol.slice_axis(nembedding, axis=0, begin=2*i+1, end=2*i+2)
        #    n1.append(_n1)
        #    n2.append(_n2)
        #n1 = mx.sym.concat(*n1, dim=0)
        #n2 = mx.sym.concat(*n2, dim=0)
        #rembeddings = mx.symbol.reshape(nembedding, (args.images_per_identity, args.per_identities, 512))
        #n1 = mx.symbol.slice_axis(rembeddings, axis=0, begin=0, end=1)
        #n2 = mx.symbol.slice_axis(rembeddings, axis=0, begin=1, end=2)
        n1 = mx.symbol.reshape(n1, (args.per_identities, 512))
        n2 = mx.symbol.reshape(n2, (args.per_identities, 512))
        cosine_matrix = mx.symbol.dot(lhs=n1, rhs=n2, transpose_b=True)  #id*id, id=N of N-pair
        data_extra = mx.sym.Variable('extra')
        data_extra = mx.sym.slice_axis(data_extra, axis=0, begin=0, end=args.per_identities)
        mask = cosine_matrix * data_extra
        #body = mx.sym.mean(mask)
        fii = mx.sym.sum_axis(mask, axis=1)
        fij_fii = mx.sym.broadcast_sub(cosine_matrix, fii)
        fij_fii = mx.sym.exp(fij_fii)
        row = mx.sym.sum_axis(fij_fii, axis=1)
        row = mx.sym.log(row)
        body = mx.sym.mean(row)
        extra_loss = mx.sym.MakeLoss(body)
    elif args.loss_type == 12:
        _weight = mx.symbol.Variable('fc7_weight')
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7')
        params = [0.9, 0.2]
        nembedding = mx.symbol.slice_axis(embedding, axis=0, begin=0, end=args.images_per_identity)
        nembedding = mx.symbol.L2Normalization(nembedding, mode='instance', name='fc1n')
        n1 = mx.sym.expand_dims(nembedding, axis=1)
        n2 = mx.sym.expand_dims(nembedding, axis=0)
        body = mx.sym.broadcast_sub(n1, n2)  #N,N,C
        body = body * body
        body = mx.sym.sum(body, axis=2)  #N,N
        body = body - params[0]
        body = mx.symbol.Activation(data=body, act_type='relu')
        body = mx.sym.sum(body)
        n = args.images_per_identity
        body = body / (n * n - n)
        extra_loss = mx.symbol.MakeLoss(body, grad_scale=params[1])
        #extra_loss = None
    else:
        #embedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*float(args.loss_type)
        embedding = embedding * 5
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance') * 2
        fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes, weight=_weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=100, name='fc7')
        #fc7 = mx.sym.Custom(data=embedding, label=gt_label, weight=_weight, num_hidden=args.num_classes,
        #                    beta=args.beta, margin=args.margin, scale=args.scale,
        #                    op_type='ASoftmax', name='fc7')
    if args.loss_type <= 1 and args.incay > 0.0:
        params = [1.e-10]
        sel = mx.symbol.argmax(data=fc7, axis=1)
        sel = (sel == gt_label)
        norm = embedding * embedding
        norm = mx.symbol.sum(norm, axis=1)
        norm = norm + params[0]
        feature_incay = sel / norm
        feature_incay = mx.symbol.mean(feature_incay) * args.incay
        extra_loss = mx.symbol.MakeLoss(feature_incay)
    #out = softmax
    #l2_embedding = mx.symbol.L2Normalization(embedding)
    #ce = mx.symbol.softmax_cross_entropy(fc7, gt_label, name='softmax_ce')/args.per_batch_size
    #out = mx.symbol.Group([mx.symbol.BlockGrad(embedding), softmax, mx.symbol.BlockGrad(ce)])
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = None
    if args.loss_type < 10:
        softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax', normalization='valid')
        out_list.append(softmax)
    if softmax is None:
        out_list.append(mx.sym.BlockGrad(gt_label))
    if extra_loss is not None:
        out_list.append(extra_loss)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
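# Rough numpy rendering (my paraphrase, not from the source) of the
# loss_type==10 "marginal loss" graph above: squared pairwise distances
# between normalized embeddings, shifted by a threshold, signed by the
# same/different-identity mask fed through the 'extra' input, then hinged.
def marginal_loss_np(emb, mask, thresh=1.2, margin=0.3):
    # emb: (N, C) L2-normalized embeddings; mask: (N, N), assumed +1 for
    # same-identity pairs and -1 for different-identity pairs.
    n = emb.shape[0]
    d2 = ((emb[:, None, :] - emb[None, :, :]) ** 2).sum(axis=2)  # (N, N)
    body = np.maximum((d2 - thresh) * mask + margin, 0.0)
    return body.sum() / (n * n - n)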
def get_symbol(args, arg_params, aux_params):
    data_shape = (args.image_channel, args.image_h, args.image_w)
    image_shape = ",".join([str(x) for x in data_shape])
    margin_symbols = []
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(args.emb_size, version_input=args.version_input, version_output=args.version_output, version_multiplier=args.version_multiplier)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif args.network[0] == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(args.emb_size, bn_mom=args.bn_mom, version_output=args.version_output)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit, version_act=args.version_act)
    all_label = mx.symbol.Variable('softmax_label')
    gt_label = all_label
    extra_loss = None
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size),
                                 lr_mult=args.fc7_lr_mult, wd_mult=args.fc7_wd_mult)
    if args.loss_type == 0:  #softmax
        if args.fc7_no_bias:
            fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        else:
            _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
            fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7')
    elif args.loss_type == 1:  #sphere
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes, weight=_weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=1000, name='fc7')
    elif args.loss_type == 2:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        s_m = s * m
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
        fc7 = fc7 - gt_one_hot
    elif args.loss_type == 4:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        mm = math.sin(math.pi - m) * m
        #threshold = 0.0
        threshold = math.cos(math.pi - m)
        if args.easy_margin:
            cond = mx.symbol.Activation(data=cos_t, act_type='relu')
        else:
            cond_v = cos_t - threshold
            cond = mx.symbol.Activation(data=cond_v, act_type='relu')
        body = cos_t * cos_t
        body = 1.0 - body
        sin_t = mx.sym.sqrt(body)
        new_zy = cos_t * cos_m
        b = sin_t * sin_m
        new_zy = new_zy - b
        new_zy = new_zy * s
        if args.easy_margin:
            zy_keep = zy
        else:
            zy_keep = zy - s * mm
        new_zy = mx.sym.where(cond, new_zy, zy_keep)
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    elif args.loss_type == 5:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                s_m = s * args.margin_b
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if args.margin_a != 1.0:
                    t = t * args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body * s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7 + body
    elif args.loss_type == 6:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert args.margin_b > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        t = mx.sym.arccos(cos_t)
        intra_loss = t / np.pi
        intra_loss = mx.sym.mean(intra_loss)
        #intra_loss = mx.sym.exp(cos_t*-1.0)
        intra_loss = mx.sym.MakeLoss(intra_loss, name='intra_loss', grad_scale=args.margin_b)
        if m > 0.0:
            t = t + m
            body = mx.sym.cos(t)
            new_zy = body * s
            diff = new_zy - zy
            diff = mx.sym.expand_dims(diff, 1)
            gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
            body = mx.sym.broadcast_mul(gt_one_hot, diff)
            fc7 = fc7 + body
    elif args.loss_type == 7:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert args.margin_b > 0.0
        assert args.margin_a > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        t = mx.sym.arccos(cos_t)
        #counter_weight = mx.sym.take(_weight, gt_label, axis=1)
        #counter_cos = mx.sym.dot(counter_weight, _weight, transpose_a=True)
        counter_weight = mx.sym.take(_weight, gt_label, axis=0)
        counter_cos = mx.sym.dot(counter_weight, _weight, transpose_b=True)
        #counter_cos = mx.sym.minimum(counter_cos, 1.0)
        #counter_angle = mx.sym.arccos(counter_cos)
        #counter_angle = counter_angle * -1.0
        #counter_angle = counter_angle/np.pi #[0,1]
        #inter_loss = mx.sym.exp(counter_angle)
        #counter_cos = mx.sym.dot(_weight, _weight, transpose_b=True)
        #counter_cos = mx.sym.minimum(counter_cos, 1.0)
        #counter_angle = mx.sym.arccos(counter_cos)
        #counter_angle = mx.sym.sort(counter_angle, axis=1)
        #counter_angle = mx.sym.slice_axis(counter_angle, axis=1, begin=0, end=int(args.margin_a))
        #inter_loss = counter_angle*-1.0 # [-1,0]
        #inter_loss = inter_loss+1.0 # [0,1]
        inter_loss = counter_cos
        inter_loss = mx.sym.mean(inter_loss)
        inter_loss = mx.sym.MakeLoss(inter_loss, name='inter_loss', grad_scale=args.margin_b)
        if m > 0.0:
            t = t + m
            body = mx.sym.cos(t)
            new_zy = body * s
            diff = new_zy - zy
            diff = mx.sym.expand_dims(diff, 1)
            gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
            body = mx.sym.broadcast_mul(gt_one_hot, diff)
            fc7 = fc7 + body
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax', normalization='valid')
    out_list.append(softmax)
    if args.loss_type == 6:
        out_list.append(intra_loss)
    if args.loss_type == 7:
        out_list.append(inter_loss)
        #out_list.append(mx.sym.BlockGrad(counter_weight))
        #out_list.append(intra_loss)
    if args.ce_loss:
        #ce_loss = mx.symbol.softmax_cross_entropy(data=fc7, label=gt_label, name='ce_loss')/args.per_batch_size
        body = mx.symbol.SoftmaxActivation(data=fc7)
        body = mx.symbol.log(body)
        _label = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=-1.0, off_value=0.0)
        body = body * _label
        ce_loss = mx.symbol.sum(body) / args.per_batch_size
        out_list.append(mx.symbol.BlockGrad(ce_loss))
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
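# Scalar numpy check of the loss_type==4 additive angular margin above
# (a sketch under the same cos/sin identities, not the author's code): the
# target logit s*cos(theta) becomes s*cos(theta+m) while theta+m stays in the
# monotonic range of cosine; otherwise the linear fallback zy - s*m*sin(pi-m)
# (or plain zy with easy_margin) is kept.
def arcface_target_logit(cos_t, s=64.0, m=0.5, easy_margin=False):
    sin_t = math.sqrt(1.0 - cos_t * cos_t)
    new_zy = s * (cos_t * math.cos(m) - sin_t * math.sin(m))  # s*cos(t+m)
    if easy_margin:
        cond = cos_t > 0.0
        zy_keep = s * cos_t
    else:
        cond = cos_t > math.cos(math.pi - m)
        zy_keep = s * cos_t - s * math.sin(math.pi - m) * m
    return new_zy if cond else zy_keep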
def get_symbol(args, arg_params, aux_params):
    data_shape = (args.image_channel, args.image_h, args.image_w)
    image_shape = ",".join([str(x) for x in data_shape])
    margin_symbols = []
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    all_label = mx.symbol.Variable('softmax_label')
    if not args.output_c2c:
        gt_label = all_label
    else:
        gt_label = mx.symbol.slice_axis(all_label, axis=1, begin=0, end=1)
        gt_label = mx.symbol.reshape(gt_label, (args.per_batch_size,))
        c2c_label = mx.symbol.slice_axis(all_label, axis=1, begin=1, end=2)
        c2c_label = mx.symbol.reshape(c2c_label, (args.per_batch_size,))
    assert args.loss_type >= 0
    extra_loss = None
    if args.loss_type == 0:  #softmax
        _weight = mx.symbol.Variable('fc7_weight')
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7')
    elif args.loss_type == 1:  #sphere
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes, weight=_weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=1000, name='fc7')
    elif args.loss_type == 8:  #centerloss, TODO
        _weight = mx.symbol.Variable('fc7_weight')
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7')
        print('center-loss', args.center_alpha, args.center_scale)
        extra_loss = mx.symbol.Custom(data=embedding, label=gt_label, name='center_loss', op_type='centerloss',
                                      num_class=args.num_classes, alpha=args.center_alpha,
                                      scale=args.center_scale, batchsize=args.per_batch_size)
    elif args.loss_type == 2:
        s = args.margin_s
        m = args.margin_m
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        if s > 0.0:
            nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
            fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
            if m > 0.0:
                if args.margin_verbose > 0:
                    zy = mx.sym.pick(fc7, gt_label, axis=1)
                    cos_t = zy / s
                    margin_symbols.append(mx.symbol.mean(cos_t))
                s_m = s * m
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
                fc7 = fc7 - gt_one_hot
                if args.margin_verbose > 0:
                    new_zy = mx.sym.pick(fc7, gt_label, axis=1)
                    new_cos_t = new_zy / s
                    margin_symbols.append(mx.symbol.mean(new_cos_t))
        else:
            fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
            if m > 0.0:
                body = embedding * embedding
                body = mx.sym.sum_axis(body, axis=1, keepdims=True)
                body = mx.sym.sqrt(body)
                body = body * m
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, body)
                fc7 = fc7 - body
    elif args.loss_type == 3:
        s = args.margin_s
        m = args.margin_m
        assert args.margin == 2 or args.margin == 4
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        if args.margin_verbose > 0:
            margin_symbols.append(mx.symbol.mean(cos_t))
        #threshold = math.cos(args.margin_m)
        #cond_v = cos_t - threshold
        #cond = mx.symbol.Activation(data=cond_v, act_type='relu')
        #body = cos_t
        #for i in xrange(args.margin//2):
        #    body = body*body
        #    body = body*2-1
        #new_zy = body*s
        #zy_keep = zy
        #new_zy = mx.sym.where(cond, new_zy, zy_keep)
        #if args.margin_verbose>0:
        #    new_cos_t = new_zy/s
        #    margin_symbols.append(mx.symbol.mean(new_cos_t))
        #diff = new_zy - zy
        #diff = mx.sym.expand_dims(diff, 1)
        #gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        #body = mx.sym.broadcast_mul(gt_one_hot, diff)
        #fc7 = fc7+body
    elif args.loss_type == 4:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        if args.output_c2c == 0:
            cos_m = math.cos(m)
            sin_m = math.sin(m)
            mm = math.sin(math.pi - m) * m
            #threshold = 0.0
            threshold = math.cos(math.pi - m)
            if args.margin_verbose > 0:
                margin_symbols.append(mx.symbol.mean(cos_t))
            if args.easy_margin:
                cond = mx.symbol.Activation(data=cos_t, act_type='relu')
            else:
                cond_v = cos_t - threshold
                cond = mx.symbol.Activation(data=cond_v, act_type='relu')
            body = cos_t * cos_t
            body = 1.0 - body
            sin_t = mx.sym.sqrt(body)
            new_zy = cos_t * cos_m
            b = sin_t * sin_m
            new_zy = new_zy - b
            new_zy = new_zy * s
            if args.easy_margin:
                zy_keep = zy
            else:
                zy_keep = zy - s * mm
            new_zy = mx.sym.where(cond, new_zy, zy_keep)
        else:  #set c2c as cosm^2 in data.py
            cos_m = mx.sym.sqrt(c2c_label)
            sin_m = 1.0 - c2c_label
            sin_m = mx.sym.sqrt(sin_m)
            body = cos_t * cos_t
            body = 1.0 - body
            sin_t = mx.sym.sqrt(body)
            new_zy = cos_t * cos_m
            b = sin_t * sin_m
            new_zy = new_zy - b
            new_zy = new_zy * s
        if args.margin_verbose > 0:
            new_cos_t = new_zy / s
            margin_symbols.append(mx.symbol.mean(new_cos_t))
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    elif args.loss_type == 10:  #marginal loss
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        params = [1.2, 0.3, 1.0]
        n1 = mx.sym.expand_dims(nembedding, axis=1)  #N,1,C
        n2 = mx.sym.expand_dims(nembedding, axis=0)  #1,N,C
        body = mx.sym.broadcast_sub(n1, n2)  #N,N,C
        body = body * body
        body = mx.sym.sum(body, axis=2)  #N,N
        #body = mx.sym.sqrt(body)
        body = body - params[0]
        mask = mx.sym.Variable('extra')
        body = body * mask
        body = body + params[1]
        #body = mx.sym.maximum(body, 0.0)
        body = mx.symbol.Activation(data=body, act_type='relu')
        body = mx.sym.sum(body)
        body = body / (args.per_batch_size * args.per_batch_size - args.per_batch_size)
        extra_loss = mx.symbol.MakeLoss(body, grad_scale=params[2])
    elif args.loss_type == 11:  #npair loss
        params = [0.9, 0.2]
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        nembedding = mx.sym.transpose(nembedding)
        nembedding = mx.symbol.reshape(nembedding, (args.emb_size, args.per_identities, args.images_per_identity))
        nembedding = mx.sym.transpose(nembedding, axes=(2, 1, 0))  #2*id*512
        #nembedding = mx.symbol.reshape(nembedding, (args.emb_size, args.images_per_identity, args.per_identities))
        #nembedding = mx.sym.transpose(nembedding, axes=(1,2,0)) #2*id*512
        n1 = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=1)
        n2 = mx.symbol.slice_axis(nembedding, axis=0, begin=1, end=2)
        #n1 = []
        #n2 = []
        #for i in xrange(args.per_identities):
        #    _n1 = mx.symbol.slice_axis(nembedding, axis=0, begin=2*i, end=2*i+1)
        #    _n2 = mx.symbol.slice_axis(nembedding, axis=0, begin=2*i+1, end=2*i+2)
        #    n1.append(_n1)
        #    n2.append(_n2)
        #n1 = mx.sym.concat(*n1, dim=0)
        #n2 = mx.sym.concat(*n2, dim=0)
        #rembeddings = mx.symbol.reshape(nembedding, (args.images_per_identity, args.per_identities, 512))
        #n1 = mx.symbol.slice_axis(rembeddings, axis=0, begin=0, end=1)
        #n2 = mx.symbol.slice_axis(rembeddings, axis=0, begin=1, end=2)
        n1 = mx.symbol.reshape(n1, (args.per_identities, args.emb_size))
        n2 = mx.symbol.reshape(n2, (args.per_identities, args.emb_size))
        cosine_matrix = mx.symbol.dot(lhs=n1, rhs=n2, transpose_b=True)  #id*id, id=N of N-pair
        data_extra = mx.sym.Variable('extra')
        data_extra = mx.sym.slice_axis(data_extra, axis=0, begin=0, end=args.per_identities)
        mask = cosine_matrix * data_extra
        #body = mx.sym.mean(mask)
        fii = mx.sym.sum_axis(mask, axis=1)
        fij_fii = mx.sym.broadcast_sub(cosine_matrix, fii)
        fij_fii = mx.sym.exp(fij_fii)
        row = mx.sym.sum_axis(fij_fii, axis=1)
        row = mx.sym.log(row)
        body = mx.sym.mean(row)
        extra_loss = mx.sym.MakeLoss(body)
    elif args.loss_type == 12:  #triplet loss
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size // 3)
        positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size // 3, end=2 * args.per_batch_size // 3)
        negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2 * args.per_batch_size // 3, end=args.per_batch_size)
        ap = anchor - positive
        an = anchor - negative
        ap = ap * ap
        an = an * an
        ap = mx.symbol.sum(ap, axis=1, keepdims=1)  #(T,1)
        an = mx.symbol.sum(an, axis=1, keepdims=1)  #(T,1)
        triplet_loss = mx.symbol.Activation(data=(ap - an + args.triplet_alpha), act_type='relu')
        triplet_loss = mx.symbol.mean(triplet_loss)
        #triplet_loss = mx.symbol.sum(triplet_loss)/(args.per_batch_size//3)
        extra_loss = mx.symbol.MakeLoss(triplet_loss)
    elif args.loss_type == 9:  #coco loss
        centroids = []
        for i in range(args.per_identities):
            xs = mx.symbol.slice_axis(embedding, axis=0, begin=i * args.images_per_identity, end=(i + 1) * args.images_per_identity)
            mean = mx.symbol.mean(xs, axis=0, keepdims=True)
            mean = mx.symbol.L2Normalization(mean, mode='instance')
            centroids.append(mean)
        centroids = mx.symbol.concat(*centroids, dim=0)
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * args.coco_scale
        fc7 = mx.symbol.dot(nembedding, centroids, transpose_b=True)  #(batchsize, per_identities)
        #extra_loss = mx.symbol.softmax_cross_entropy(fc7, gt_label, name='softmax_ce')/args.per_batch_size
        #extra_loss = mx.symbol.BlockGrad(extra_loss)
    else:
        #embedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*float(args.loss_type)
        embedding = embedding * 5
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance') * 2
        fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes, weight=_weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=100, name='fc7')
        #fc7 = mx.sym.Custom(data=embedding, label=gt_label, weight=_weight, num_hidden=args.num_classes,
        #                    beta=args.beta, margin=args.margin, scale=args.scale,
        #                    op_type='ASoftmax', name='fc7')
    if args.loss_type <= 1 and args.incay > 0.0:
        params = [1.e-10]
        sel = mx.symbol.argmax(data=fc7, axis=1)
        sel = (sel == gt_label)
        norm = embedding * embedding
        norm = mx.symbol.sum(norm, axis=1)
        norm = norm + params[0]
        feature_incay = sel / norm
        feature_incay = mx.symbol.mean(feature_incay) * args.incay
        extra_loss = mx.symbol.MakeLoss(feature_incay)
    #out = softmax
    #l2_embedding = mx.symbol.L2Normalization(embedding)
    #ce = mx.symbol.softmax_cross_entropy(fc7, gt_label, name='softmax_ce')/args.per_batch_size
    #out = mx.symbol.Group([mx.symbol.BlockGrad(embedding), softmax, mx.symbol.BlockGrad(ce)])
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = None
    if args.loss_type < 10:
        softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax', normalization='valid')
        out_list.append(softmax)
    if softmax is None:
        out_list.append(mx.sym.BlockGrad(gt_label))
    if extra_loss is not None:
        out_list.append(extra_loss)
    for _sym in margin_symbols:
        _sym = mx.sym.BlockGrad(_sym)
        out_list.append(_sym)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
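# Hypothetical numpy sketch of the loss_type==9 COCO-loss logits built above:
# the class "weights" are the L2-normalized per-identity centroids of the
# current batch, which is assumed to be laid out identity-major with
# images_per_identity consecutive rows per identity.
def coco_logits_np(emb, per_identities, images_per_identity, coco_scale=30.0):
    grouped = emb.reshape(per_identities, images_per_identity, -1)
    centroids = grouped.mean(axis=1)
    centroids /= np.linalg.norm(centroids, axis=1, keepdims=True)
    nemb = emb / np.linalg.norm(emb, axis=1, keepdims=True) * coco_scale
    return nemb.dot(centroids.T)  # (batch_size, per_identities)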
def get_symbol(args, arg_params, aux_params, sym_embedding=None):
    if sym_embedding is None:
        if args.network[0] == 'd':
            embedding = fdensenet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
        elif args.network[0] == 'm':
            print('init mobilenet', args.num_layers)
            if args.num_layers == 1:
                embedding = fmobilenet.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
            else:
                embedding = fmobilenetv2.get_symbol(args.emb_size)
        elif args.network[0] == 'i':
            print('init inception-resnet-v2', args.num_layers)
            embedding = finception_resnet_v2.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
        elif args.network[0] == 'x':
            print('init xception', args.num_layers)
            embedding = fxception.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
        elif args.network[0] == 'p':
            print('init dpn', args.num_layers)
            embedding = fdpn.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
        elif args.network[0] == 'n':
            print('init nasnet', args.num_layers)
            embedding = fnasnet.get_symbol(args.emb_size)
        elif args.network[0] == 's':
            print('init spherenet', args.num_layers)
            embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
        else:
            print('init resnet', args.num_layers)
            embedding = fresnet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit, version_act=args.version_act)
    else:
        embedding = sym_embedding
    gt_label = mx.symbol.Variable('softmax_label')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
    anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size // 3)
    positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size // 3, end=2 * args.per_batch_size // 3)
    negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2 * args.per_batch_size // 3, end=args.per_batch_size)
    ap = anchor - positive
    an = anchor - negative
    ap = ap * ap
    an = an * an
    ap = mx.symbol.sum(ap, axis=1, keepdims=1)  #(T,1)
    an = mx.symbol.sum(an, axis=1, keepdims=1)  #(T,1)
    triplet_loss = mx.symbol.Activation(data=(ap - an + args.triplet_alpha), act_type='relu')
    triplet_loss = mx.symbol.mean(triplet_loss)
    #triplet_loss = mx.symbol.sum(triplet_loss)/(args.per_batch_size//3)
    triplet_loss = mx.symbol.MakeLoss(triplet_loss)
    out_list = [mx.symbol.BlockGrad(embedding)]
    out_list.append(mx.sym.BlockGrad(gt_label))
    out_list.append(triplet_loss)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
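# Minimal numpy counterpart (my paraphrase) of the triplet graph above, with
# the batch laid out as [anchors | positives | negatives] in equal thirds and
# alpha standing in for args.triplet_alpha.
def triplet_loss_np(nemb, alpha=0.3):
    t = nemb.shape[0] // 3
    a, p, n = nemb[:t], nemb[t:2 * t], nemb[2 * t:]
    ap = ((a - p) ** 2).sum(axis=1)
    an = ((a - n) ** 2).sum(axis=1)
    return np.maximum(ap - an + alpha, 0.0).mean()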
def get_symbol(args, arg_params, aux_params):
    data_shape = (args.image_channel, args.image_h, args.image_w)
    image_shape = ",".join([str(x) for x in data_shape])
    margin_symbols = []
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit, version_act=args.version_act)
    all_label = mx.symbol.Variable('softmax_label')
    gt_label = all_label
    extra_loss = None
    s = args.margin_s
    #m = args.margin_m
    assert s > 0.0
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
    out_list = [mx.symbol.BlockGrad(embedding)]
    _args = copy.deepcopy(args)
    if USE_FR:
        _args.grad_scale = 1.0
        fr_label = mx.symbol.slice_axis(all_label, axis=1, begin=0, end=1)
        fr_label = mx.symbol.reshape(fr_label, (args.per_batch_size,))
        fr_softmax = get_softmax(_args, embedding, nembedding, fr_label, 'fc7')
        out_list.append(fr_softmax)
    if USE_GENDER:
        _args.grad_scale = 0.2
        _args.margin_a = 0.0
        _args.num_classes = 2
        gender_label = mx.symbol.slice_axis(all_label, axis=1, begin=1, end=2)
        gender_label = mx.symbol.reshape(gender_label, (args.per_batch_size,))
        gender_softmax = get_softmax(_args, embedding, nembedding, gender_label, 'fc8')
        out_list.append(gender_softmax)
    if USE_AGE:
        _args.grad_scale = 0.01
        _args.margin_a = 0.0
        _args.num_classes = 2
        for i in range(AGE):
            age_label = mx.symbol.slice_axis(all_label, axis=1, begin=2 + i, end=3 + i)
            age_label = mx.symbol.reshape(age_label, (args.per_batch_size,))
            age_softmax = get_softmax(_args, embedding, nembedding, age_label, 'fc9_%d' % (i))
            out_list.append(age_softmax)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
def get_symbol(args, arg_params, aux_params):
    data_shape = (args.image_channel, args.image_h, args.image_w)
    image_shape = ",".join([str(x) for x in data_shape])
    margin_symbols = []
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif args.network[0] == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(args.emb_size)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit, version_act=args.version_act)
    all_label = mx.symbol.Variable('softmax_label')
    gt_label = all_label
    extra_loss = None
    s = args.margin_s
    #m = args.margin_m
    assert s > 0.0
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
    out_list = [mx.symbol.BlockGrad(embedding)]
    _args = copy.deepcopy(args)
    if USE_FR:
        _args.grad_scale = 1.0
        fr_label = mx.symbol.slice_axis(all_label, axis=1, begin=0, end=1)
        fr_label = mx.symbol.reshape(fr_label, (args.per_batch_size,))
        fr_softmax = get_softmax(_args, embedding, nembedding, fr_label, 'fc7')
        out_list.append(fr_softmax)
    if USE_GENDER:
        _args.grad_scale = 0.2
        _args.margin_a = 0.0
        _args.num_classes = 2
        gender_label = mx.symbol.slice_axis(all_label, axis=1, begin=1, end=2)
        gender_label = mx.symbol.reshape(gender_label, (args.per_batch_size,))
        gender_softmax = get_softmax(_args, embedding, nembedding, gender_label, 'fc8')
        out_list.append(gender_softmax)
    if USE_AGE:
        _args.grad_scale = 0.01
        _args.margin_a = 0.0
        _args.num_classes = 2
        for i in range(AGE):
            age_label = mx.symbol.slice_axis(all_label, axis=1, begin=2 + i, end=3 + i)
            age_label = mx.symbol.reshape(age_label, (args.per_batch_size,))
            age_softmax = get_softmax(_args, embedding, nembedding, age_label, 'fc9_%d' % (i))
            out_list.append(age_softmax)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
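# Assumed 'softmax_label' row layout consumed by the two multi-task builders
# above: [identity, gender, age_bin_0 ... age_bin_{AGE-1}], where age_bin_i
# is the 0/1 target of the i-th two-class ordinal head (bins below the true
# age fire). This packing helper is illustrative, not from the source.
def pack_multitask_label(identity, gender, age, num_bins=100):
    row = np.zeros(2 + num_bins, dtype=np.float32)  # num_bins == AGE
    row[0] = identity
    row[1] = gender
    row[2:2 + age] = 1.0
    return row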
def get_symbol(args, arg_params, aux_params):
    data_shape = (args.image_channel, args.image_h, args.image_w)
    image_shape = ",".join([str(x) for x in data_shape])
    margin_symbols = []
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif args.network[0] == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(args.emb_size, bn_mom=args.bn_mom, wd_mult=args.fc7_wd_mult)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit, version_act=args.version_act)
    all_label = mx.symbol.Variable('softmax_label')
    #center_label = mx.symbol.Variable('center_label')
    gt_label = all_label
    extra_loss = None
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size),
                                 lr_mult=1.0, wd_mult=args.fc7_wd_mult)
    if args.loss_type == 0:  #softmax
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7')
    elif args.loss_type == 1:  #sphere
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes, weight=_weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=1000, name='fc7')
    elif args.loss_type == 2:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        s_m = s * m
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
        fc7 = fc7 - gt_one_hot
    elif args.loss_type == 4:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        t = mx.sym.arccos(cos_t)
        if args.margin_a != 1.0:
            t = t * args.margin_a
        if args.margin_m > 0.0:
            t = t + args.margin_m
        body = mx.sym.cos(t)
        if args.margin_b > 0.0:
            body = body - args.margin_b
        new_zy = body * s
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    elif args.loss_type == 5:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        nembedding = mx.symbol.Dropout(data=nembedding, p=0.4)
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                s_m = s * args.margin_b
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if args.margin_a != 1.0:
                    t = t * args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                #threshold = math.cos(math.pi-m)
                if args.easy_margin:
                    # m<pi/2 so: pi-m > pi/2
                    cond = mx.symbol.Activation(data=cos_t, act_type='relu')
                    # this means t > pi/2; when pi/2 < t < 3*pi/2, sin decreases
                    #cond_v = cos_t - threshold
                    #cond = mx.symbol.Activation(data=cond_v, act_type='relu')
                else:
                    cond_v = math.pi - t
                    cond = mx.symbol.Activation(data=cond_v, act_type='relu')
                    print("not easy margin: cond")
                    #cond = mx.symbol.Activation(data=cos_t, act_type='relu')
                if args.easy_margin:
                    zy_keep = mx.sym.cos(t)
                else:
                    print("not easy margin: sine")
                    zy_keep = mx.sym.sin(t) - 1
                new_zy = mx.sym.cos(t)
                body = mx.sym.where(cond, new_zy, zy_keep)
                #body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body + args.margin_b
                new_zy = body * s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7 + body
    elif args.loss_type == 6:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                s_m = s * args.margin_b
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if args.margin_a != 1.0:
                    t = t * args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body * s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7 + body
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax', normalization='valid')
    out_list.append(softmax)
    #print("out shape", np.shape(softmax))
    #center_in = mx.symbol.concat(embedding, fc7, dim=1)
    #center_loss_data = mx.symbol.Custom(data=embedding, label=gt_label, name='center_loss_data', op_type='centerloss',
    #    num_class=args.num_classes, alpha=0.5, scale=0.5, lamdb=0.1, batchsize=args.per_batch_size, emb_size=args.emb_size)
    #extra_center_loss = mx.symbol.MakeLoss(name='extra_center_loss', data=center_loss_data)
    #total_loss = mx.symbol.ElementWiseSum([softmax, extra_loss], name='total_loss')
    #total_loss_op = mx.symbol.MakeLoss(name='total_loss_op', data=total_loss)
    #out_list.append(extra_center_loss)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
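# Scalar sketch (assumption, with math.acos standing in for mx.sym.arccos) of
# the combined-margin target logit that the loss_type==4/5/6 branches above
# apply to the ground-truth class: zy -> s * (cos(a*theta + m) - b), covering
# SphereFace-style (a), ArcFace-style (m), and CosFace-style (b) margins in
# one expression.
def combined_margin_logit(cos_t, s=64.0, a=1.0, m=0.3, b=0.2):
    t = math.acos(cos_t)
    return s * (math.cos(a * t + m) - b)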
def get_symbol(args, arg_params, aux_params):
    data_shape = (args.image_channel, args.image_h, args.image_w)
    image_shape = ",".join([str(x) for x in data_shape])
    margin_symbols = []
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif args.network[0] == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(args.emb_size, bn_mom=args.bn_mom, wd_mult=args.fc7_wd_mult)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit, version_act=args.version_act)
    all_label = mx.symbol.Variable('softmax_label')
    gt_label = all_label
    extra_loss = None
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size),
                                 lr_mult=1.0, wd_mult=args.fc7_wd_mult)
    if args.loss_type == 0:  #softmax
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7')
    elif args.loss_type == 1:  #sphere
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes, weight=_weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=1000, name='fc7')
    elif args.loss_type == 2:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        s_m = s * m
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
        fc7 = fc7 - gt_one_hot
    elif args.loss_type == 4:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        mm = math.sin(math.pi - m) * m
        #threshold = 0.0
        threshold = math.cos(math.pi - m)
        if args.easy_margin:
            cond = mx.symbol.Activation(data=cos_t, act_type='relu')
        else:
            cond_v = cos_t - threshold
            cond = mx.symbol.Activation(data=cond_v, act_type='relu')
        body = cos_t * cos_t
        body = 1.0 - body
        sin_t = mx.sym.sqrt(body)
        new_zy = cos_t * cos_m
        b = sin_t * sin_m
        new_zy = new_zy - b
        new_zy = new_zy * s
        if args.easy_margin:
            zy_keep = zy
        else:
            zy_keep = zy - s * mm
        new_zy = mx.sym.where(cond, new_zy, zy_keep)
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    elif args.loss_type == 5:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                s_m = s * args.margin_b
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if args.margin_a != 1.0:
                    t = t * args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body * s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7 + body
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax', normalization='valid')
    out_list.append(softmax)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
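# Hedged usage sketch of how a (symbol, arg_params, aux_params) triple from
# one of these builders is typically bound into an mx.mod.Module for
# training; 'args', the device list, and the single-label shape are
# assumptions (the multi-task builders need a (batch, 2+AGE) label instead).
def build_module(args, arg_params=None, aux_params=None):
    sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    model = mx.mod.Module(symbol=sym, context=[mx.gpu(0)],
                          data_names=['data'], label_names=['softmax_label'])
    model.bind(data_shapes=[('data', (args.per_batch_size, args.image_channel,
                                      args.image_h, args.image_w))],
               label_shapes=[('softmax_label', (args.per_batch_size,))])
    model.init_params(arg_params=arg_params, aux_params=aux_params,
                      allow_missing=True)
    return model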
def get_symbol(args, arg_params, aux_params):
    data_shape = (args.image_channel, args.image_h, args.image_w)
    image_shape = ",".join([str(x) for x in data_shape])
    margin_symbols = []
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit, version_act=args.version_act)
    all_label = mx.symbol.Variable('softmax_label')
    if not args.output_c2c:
        gt_label = all_label
    else:
        gt_label = mx.symbol.slice_axis(all_label, axis=1, begin=0, end=1)
        gt_label = mx.symbol.reshape(gt_label, (args.per_batch_size,))
        c2c_label = mx.symbol.slice_axis(all_label, axis=1, begin=1, end=2)
        c2c_label = mx.symbol.reshape(c2c_label, (args.per_batch_size,))
    assert args.loss_type >= 0
    extra_loss = None
    if args.loss_type == 0:  #softmax
        _weight = mx.symbol.Variable('fc7_weight')
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7')
    elif args.loss_type == 1:  #sphere
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes, weight=_weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=1000, name='fc7')
    elif args.loss_type == 8:  #centerloss, TODO
        _weight = mx.symbol.Variable('fc7_weight')
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7')
        print('center-loss', args.center_alpha, args.center_scale)
        extra_loss = mx.symbol.Custom(data=embedding, label=gt_label, name='center_loss', op_type='centerloss',
                                      num_class=args.num_classes, alpha=args.center_alpha,
                                      scale=args.center_scale, batchsize=args.per_batch_size)
    elif args.loss_type == 2:
        s = args.margin_s
        m = args.margin_m
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        if s > 0.0:
            nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
            fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
            if m > 0.0:
                if args.margin_verbose > 0:
                    zy = mx.sym.pick(fc7, gt_label, axis=1)
                    cos_t = zy / s
                    margin_symbols.append(mx.symbol.mean(cos_t))
                s_m = s * m
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
                fc7 = fc7 - gt_one_hot
                if args.margin_verbose > 0:
                    new_zy = mx.sym.pick(fc7, gt_label, axis=1)
                    new_cos_t = new_zy / s
                    margin_symbols.append(mx.symbol.mean(new_cos_t))
        else:
            fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
            if m > 0.0:
                body = embedding * embedding
                body = mx.sym.sum_axis(body, axis=1, keepdims=True)
                body = mx.sym.sqrt(body)
                body = body * m
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, body)
                fc7 = fc7 - body
    elif args.loss_type == 3:
        s = args.margin_s
        m = args.margin_m
        assert args.margin == 2 or args.margin == 4
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        if args.margin_verbose > 0:
            margin_symbols.append(mx.symbol.mean(cos_t))
        if m > 1.0:
            t = mx.sym.arccos(cos_t)
            t = t * m
            body = mx.sym.cos(t)
            new_zy = body * s
            if args.margin_verbose > 0:
                new_cos_t = new_zy / s
                margin_symbols.append(mx.symbol.mean(new_cos_t))
            diff = new_zy - zy
            diff = mx.sym.expand_dims(diff, 1)
            gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
            body = mx.sym.broadcast_mul(gt_one_hot, diff)
            fc7 = fc7 + body
        #threshold = math.cos(args.margin_m)
        #cond_v = cos_t - threshold
        #cond = mx.symbol.Activation(data=cond_v, act_type='relu')
        #body = cos_t
        #for i in xrange(args.margin//2):
        #    body = body*body
        #    body = body*2-1
        #new_zy = body*s
        #zy_keep = zy
        #new_zy = mx.sym.where(cond, new_zy, zy_keep)
        #if args.margin_verbose>0:
        #    new_cos_t = new_zy/s
        #    margin_symbols.append(mx.symbol.mean(new_cos_t))
        #diff = new_zy - zy
        #diff = mx.sym.expand_dims(diff, 1)
        #gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        #body = mx.sym.broadcast_mul(gt_one_hot, diff)
        #fc7 = fc7+body
    elif args.loss_type == 4:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        if args.margin_verbose > 0:
            margin_symbols.append(mx.symbol.mean(cos_t))
        if args.output_c2c == 0:
            cos_m = math.cos(m)
            sin_m = math.sin(m)
            mm = math.sin(math.pi - m) * m
            #threshold = 0.0
            threshold = math.cos(math.pi - m)
            if args.easy_margin:
                cond = mx.symbol.Activation(data=cos_t, act_type='relu')
            else:
                cond_v = cos_t - threshold
                cond = mx.symbol.Activation(data=cond_v, act_type='relu')
            body = cos_t * cos_t
            body = 1.0 - body
            sin_t = mx.sym.sqrt(body)
            new_zy = cos_t * cos_m
            b = sin_t * sin_m
            new_zy = new_zy - b
            new_zy = new_zy * s
            if args.easy_margin:
                zy_keep = zy
            else:
                zy_keep = zy - s * mm
            new_zy = mx.sym.where(cond, new_zy, zy_keep)
        else:  #set c2c as cosm^2 in data.py
            cos_m = mx.sym.sqrt(c2c_label)
            sin_m = 1.0 - c2c_label
            sin_m = mx.sym.sqrt(sin_m)
            body = cos_t * cos_t
            body = 1.0 - body
            sin_t = mx.sym.sqrt(body)
            new_zy = cos_t * cos_m
            b = sin_t * sin_m
            new_zy = new_zy - b
            new_zy = new_zy * s
        if args.margin_verbose > 0:
            new_cos_t = new_zy / s
            margin_symbols.append(mx.symbol.mean(new_cos_t))
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    elif args.loss_type == 5:
        #s = args.margin_s
        #m = args.margin_m
        #assert s>0.0
        #assert m>=0.0
        #assert m<(math.pi/2)
        #_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        #_weight = mx.symbol.L2Normalization(_weight, mode='instance')
        #nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
        #fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        #zy = mx.sym.pick(fc7, gt_label, axis=1)
        #cos_t = zy/s
        #if args.margin_verbose>0:
        #    margin_symbols.append(mx.symbol.mean(cos_t))
        #if m>0.0:
        #    a1 = args.margin_a
        #    r1 = ta-a1
        #    r1 = mx.symbol.Activation(data=r1, act_type='relu')
        #    r1 = r1+a1
        #    t = mx.sym.arccos(cos_t)
        #    cond = t-1.0
        #    cond = mx.symbol.Activation(data=cond, act_type='relu')
        #    r = mx.sym.where(cond, r2, r1)
        #    t = t+var_m
        #    body = mx.sym.cos(t)
        #    new_zy = body*s
        #    if args.margin_verbose>0:
        #        new_cos_t = new_zy/s
        #        margin_symbols.append(mx.symbol.mean(new_cos_t))
        #        #margin_symbols.append(mx.symbol.mean(var_m))
        #    diff = new_zy - zy
        #    diff = mx.sym.expand_dims(diff, 1)
        #    gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        #    body = mx.sym.broadcast_mul(gt_one_hot, diff)
        #    fc7 = fc7+body
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        #assert m>=0.0
        #assert m<(math.pi/2)
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        t = mx.sym.arccos(cos_t)
        if args.margin_verbose > 0:
            margin_symbols.append(mx.symbol.mean(t))
        if args.margin_a > 0.0:
            t = t * args.margin_a
        if args.margin_m > 0.0:
            t = t + args.margin_m
        body = mx.sym.cos(t)
        if args.margin_b > 0.0:
            body = body - args.margin_b
        new_zy = body * s
        if args.margin_verbose > 0:
            margin_symbols.append(mx.symbol.mean(t))
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    elif args.loss_type == 6:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size),
lr_mult=1.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') zy = mx.sym.pick(fc7, gt_label, axis=1) cos_t = zy/s t = mx.sym.arccos(cos_t) if args.margin_verbose>0: margin_symbols.append(mx.symbol.mean(t)) t_min = mx.sym.min(t) ta = mx.sym.broadcast_div(t_min, t) a1 = args.margin_a r1 = ta-a1 r1 = mx.symbol.Activation(data=r1, act_type='relu') r1 = r1+a1 r2 = mx.symbol.zeros(shape=(args.per_batch_size,)) cond = t-1.0 cond = mx.symbol.Activation(data=cond, act_type='relu') r = mx.sym.where(cond, r2, r1) var_m = r*m t = t+var_m body = mx.sym.cos(t) new_zy = body*s if args.margin_verbose>0: #new_cos_t = new_zy/s #margin_symbols.append(mx.symbol.mean(new_cos_t)) margin_symbols.append(mx.symbol.mean(t)) diff = new_zy - zy diff = mx.sym.expand_dims(diff, 1) gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0) body = mx.sym.broadcast_mul(gt_one_hot, diff) fc7 = fc7+body elif args.loss_type==7: s = args.margin_s m = args.margin_m assert s>0.0 assert m>=0.0 assert m<(math.pi/2) _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') zy = mx.sym.pick(fc7, gt_label, axis=1) cos_t = zy/s t = mx.sym.arccos(cos_t) if args.margin_verbose>0: margin_symbols.append(mx.symbol.mean(t)) var_m = mx.sym.random.uniform(low=args.margin_a, high=args.margin_m, shape=(1,)) t = mx.sym.broadcast_add(t,var_m) body = mx.sym.cos(t) new_zy = body*s if args.margin_verbose>0: #new_cos_t = new_zy/s #margin_symbols.append(mx.symbol.mean(new_cos_t)) margin_symbols.append(mx.symbol.mean(t)) diff = new_zy - zy diff = mx.sym.expand_dims(diff, 1) gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0) body = mx.sym.broadcast_mul(gt_one_hot, diff) fc7 = fc7+body elif args.loss_type==10: #marginal loss nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') params = [1.2, 0.3, 1.0] n1 = mx.sym.expand_dims(nembedding, axis=1) #N,1,C n2 = mx.sym.expand_dims(nembedding, axis=0) #1,N,C body = mx.sym.broadcast_sub(n1, n2) #N,N,C body = body * body body = mx.sym.sum(body, axis=2) # N,N #body = mx.sym.sqrt(body) body = body - params[0] mask = mx.sym.Variable('extra') body = body*mask body = body+params[1] #body = mx.sym.maximum(body, 0.0) body = mx.symbol.Activation(data=body, act_type='relu') body = mx.sym.sum(body) body = body/(args.per_batch_size*args.per_batch_size-args.per_batch_size) extra_loss = mx.symbol.MakeLoss(body, grad_scale=params[2]) elif args.loss_type==11: #npair loss params = [0.9, 0.2] nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') nembedding = mx.sym.transpose(nembedding) nembedding = mx.symbol.reshape(nembedding, (args.emb_size, args.per_identities, args.images_per_identity)) nembedding = mx.sym.transpose(nembedding, axes=(2,1,0)) #2*id*512 #nembedding = mx.symbol.reshape(nembedding, (args.emb_size, args.images_per_identity, args.per_identities)) #nembedding = mx.sym.transpose(nembedding, axes=(1,2,0)) #2*id*512 n1 = mx.symbol.slice_axis(nembedding, axis=0, begin=0, 
end=1) n2 = mx.symbol.slice_axis(nembedding, axis=0, begin=1, end=2) #n1 = [] #n2 = [] #for i in xrange(args.per_identities): # _n1 = mx.symbol.slice_axis(nembedding, axis=0, begin=2*i, end=2*i+1) # _n2 = mx.symbol.slice_axis(nembedding, axis=0, begin=2*i+1, end=2*i+2) # n1.append(_n1) # n2.append(_n2) #n1 = mx.sym.concat(*n1, dim=0) #n2 = mx.sym.concat(*n2, dim=0) #rembeddings = mx.symbol.reshape(nembedding, (args.images_per_identity, args.per_identities, 512)) #n1 = mx.symbol.slice_axis(rembeddings, axis=0, begin=0, end=1) #n2 = mx.symbol.slice_axis(rembeddings, axis=0, begin=1, end=2) n1 = mx.symbol.reshape(n1, (args.per_identities, args.emb_size)) n2 = mx.symbol.reshape(n2, (args.per_identities, args.emb_size)) cosine_matrix = mx.symbol.dot(lhs=n1, rhs=n2, transpose_b = True) #id*id, id=N of N-pair data_extra = mx.sym.Variable('extra') data_extra = mx.sym.slice_axis(data_extra, axis=0, begin=0, end=args.per_identities) mask = cosine_matrix * data_extra #body = mx.sym.mean(mask) fii = mx.sym.sum_axis(mask, axis=1) fij_fii = mx.sym.broadcast_sub(cosine_matrix, fii) fij_fii = mx.sym.exp(fij_fii) row = mx.sym.sum_axis(fij_fii, axis=1) row = mx.sym.log(row) body = mx.sym.mean(row) extra_loss = mx.sym.MakeLoss(body) elif args.loss_type==12: #triplet loss nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size//3) positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size//3, end=2*args.per_batch_size//3) negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2*args.per_batch_size//3, end=args.per_batch_size) ap = anchor - positive an = anchor - negative ap = ap*ap an = an*an ap = mx.symbol.sum(ap, axis=1, keepdims=1) #(T,1) an = mx.symbol.sum(an, axis=1, keepdims=1) #(T,1) triplet_loss = mx.symbol.Activation(data = (ap-an+args.triplet_alpha), act_type='relu') triplet_loss = mx.symbol.mean(triplet_loss) #triplet_loss = mx.symbol.sum(triplet_loss)/(args.per_batch_size//3) extra_loss = mx.symbol.MakeLoss(triplet_loss) elif args.loss_type==13: #triplet loss with angular margin m = args.margin_m sin_m = math.sin(m) cos_m = math.cos(m) nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size//3) positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size//3, end=2*args.per_batch_size//3) negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2*args.per_batch_size//3, end=args.per_batch_size) ap = anchor * positive an = anchor * negative ap = mx.symbol.sum(ap, axis=1, keepdims=1) #(T,1) an = mx.symbol.sum(an, axis=1, keepdims=1) #(T,1) ap = mx.symbol.arccos(ap) an = mx.symbol.arccos(an) triplet_loss = mx.symbol.Activation(data = (ap-an+args.margin_m), act_type='relu') #body = ap*ap #body = 1.0-body #body = mx.symbol.sqrt(body) #body = body*sin_m #ap = ap*cos_m #ap = ap-body #triplet_loss = mx.symbol.Activation(data = (an-ap), act_type='relu') triplet_loss = mx.symbol.mean(triplet_loss) extra_loss = mx.symbol.MakeLoss(triplet_loss) elif args.loss_type==9: #coco loss centroids = [] for i in xrange(args.per_identities): xs = mx.symbol.slice_axis(embedding, axis=0, begin=i*args.images_per_identity, end=(i+1)*args.images_per_identity) mean = mx.symbol.mean(xs, axis=0, keepdims=True) mean = mx.symbol.L2Normalization(mean, mode='instance') centroids.append(mean) centroids = mx.symbol.concat(*centroids, dim=0) nembedding = 
mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*args.coco_scale fc7 = mx.symbol.dot(nembedding, centroids, transpose_b = True) #(batchsize, per_identities) #extra_loss = mx.symbol.softmax_cross_entropy(fc7, gt_label, name='softmax_ce')/args.per_batch_size #extra_loss = mx.symbol.BlockGrad(extra_loss) else: #embedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*float(args.loss_type) embedding = embedding * 5 _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') * 2 fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes, weight = _weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=100, name='fc7') #fc7 = mx.sym.Custom(data=embedding, label=gt_label, weight=_weight, num_hidden=args.num_classes, # beta=args.beta, margin=args.margin, scale=args.scale, # op_type='ASoftmax', name='fc7') if args.loss_type<=1 and args.incay>0.0: params = [1.e-10] sel = mx.symbol.argmax(data = fc7, axis=1) sel = (sel==gt_label) norm = embedding*embedding norm = mx.symbol.sum(norm, axis=1) norm = norm+params[0] feature_incay = sel/norm feature_incay = mx.symbol.mean(feature_incay) * args.incay extra_loss = mx.symbol.MakeLoss(feature_incay) #out = softmax #l2_embedding = mx.symbol.L2Normalization(embedding) #ce = mx.symbol.softmax_cross_entropy(fc7, gt_label, name='softmax_ce')/args.per_batch_size #out = mx.symbol.Group([mx.symbol.BlockGrad(embedding), softmax, mx.symbol.BlockGrad(ce)]) out_list = [mx.symbol.BlockGrad(embedding)] softmax = None if args.loss_type<10: softmax = mx.symbol.SoftmaxOutput(data=fc7, label = gt_label, name='softmax', normalization='valid') out_list.append(softmax) if args.logits_verbose>0: logits = mx.symbol.softmax(data = fc7) logits = mx.sym.pick(logits, gt_label, axis=1) margin_symbols.append(logits) #logit_max = mx.sym.max(logits) #logit_min = mx.sym.min(logits) #margin_symbols.append(logit_max) #margin_symbols.append(logit_min) if softmax is None: out_list.append(mx.sym.BlockGrad(gt_label)) if extra_loss is not None: out_list.append(extra_loss) for _sym in margin_symbols: _sym = mx.sym.BlockGrad(_sym) out_list.append(_sym) out = mx.symbol.Group(out_list) return (out, arg_params, aux_params)
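# A sketch of the combined margin in loss_type==5 above, which unifies the
# SphereFace-, ArcFace- and CosineFace-style penalties on the target logit:
# s*cos(theta) becomes s*(cos(a*theta + m) - b). The s, a, m, b values below
# are illustrative, not the repo's defaults.
import numpy as np

def combined_margin(cos_t, s=64.0, a=1.0, m=0.5, b=0.0):
    t = np.arccos(np.clip(cos_t, -1.0, 1.0))
    return s * (np.cos(a * t + m) - b)

cos_t = np.array([0.8, 0.3])
print(combined_margin(cos_t))                        # ArcFace-like (a=1, b=0)
print(combined_margin(cos_t, a=1.0, m=0.0, b=0.35))  # CosineFace-like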
def get_symbol(args, arg_params, aux_params): data_shape = (args.image_channel, args.image_h, args.image_w) image_shape = ",".join([str(x) for x in data_shape]) if args.network[:2] == 'so': print('init spherenet_o', args.num_layers) embedding = spherenet.get_symbol(0, args.emb_size, args.num_layers) elif args.network[0] == 's': print('init spherenet', args.num_layers) embedding = spherenet_bn.get_symbol(args.emb_size, args.num_layers) else: print('init resnet', args.num_layers) embedding = fresnet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit, version_act=args.version_act) nembedding = mx.symbol.L2Normalization(embedding, mode='instance') out_list = [mx.symbol.BlockGrad(embedding)] all_label = mx.symbol.Variable('softmax_label') label_softmax = mx.sym.slice_axis(all_label, axis=0, begin=0, end=args.batch_size // args.ctx_num) nembedding_softmax = mx.sym.slice_axis(nembedding, axis=0, begin=0, end=args.batch_size // args.ctx_num) label_inter = mx.sym.slice_axis(all_label, axis=0, begin=args.batch_size // args.ctx_num, end=args.batch_size // args.ctx_num + args.batchsize_id // args.ctx_num) nembedding_inter = mx.sym.slice_axis(nembedding, axis=0, begin=args.batch_size // args.ctx_num, end=args.batch_size // args.ctx_num + args.batchsize_id // args.ctx_num) # nembedding_inter = mx.symbol.L2Normalization(embedding_inter, mode='instance') nembedding_inter = mx.sym.transpose(nembedding_inter) nembedding_inter = mx.symbol.reshape( nembedding_inter, (args.emb_size, args.batchsize_id // (args.ctx_num * args.images_per_identity), args.images_per_identity)) nembedding_inter = mx.sym.transpose(nembedding_inter, axes=(2, 1, 0)) # 3*id*512 nembedding_inter = mx.sym.mean(nembedding_inter, axis=0) nembedding_inter = mx.sym.L2Normalization(nembedding_inter, mode='instance') emb_norm = mx.sym.norm(nembedding_inter) nembedding_inter_t = mx.sym.transpose(nembedding_inter) cosine_matrix = mx.sym.dot(nembedding_inter, nembedding_inter_t) cosine_matrix = cosine_matrix - mx.symbol.eye( args.batchsize_id // (args.ctx_num * args.images_per_identity)) cosine_matrix = cosine_matrix * cosine_matrix inter_loss = args.interweight * mx.symbol.mean(cosine_matrix) inter_loss = mx.sym.MakeLoss(inter_loss) if args.loss_type == 0: _weight = mx.symbol.Variable('fc7_weight') _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0) fc7 = mx.sym.FullyConnected(data=nembedding_softmax, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7') else: s = args.margin_s m = args.margin_m assert s > 0.0 assert m >= 0.0 assert m < (math.pi / 2) _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0, wd_mult=args.fc7_wd_mult) _weight = mx.symbol.L2Normalization(_weight, mode='instance') nembedding_softmax = nembedding_softmax * s fc7 = mx.sym.FullyConnected(data=nembedding_softmax, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7') zy = mx.sym.pick(fc7, label_softmax, axis=1) cos_t = zy / s cos_m = math.cos(m) sin_m = math.sin(m) mm = math.sin(math.pi - m) * m # threshold = 0.0 threshold = math.cos(math.pi - m) if args.easy_margin: cond = mx.symbol.Activation(data=cos_t, act_type='relu') else: cond_v = cos_t - threshold cond = mx.symbol.Activation(data=cond_v, act_type='relu') body = cos_t * cos_t body = 1.0 - body sin_t = mx.sym.sqrt(body) new_zy = cos_t * cos_m b = sin_t * sin_m new_zy = new_zy - b new_zy = new_zy * s if
args.easy_margin: zy_keep = zy else: zy_keep = zy - s * mm new_zy = mx.sym.where(cond, new_zy, zy_keep) diff = new_zy - zy diff = mx.sym.expand_dims(diff, 1) gt_one_hot = mx.sym.one_hot(label_softmax, depth=args.num_classes, on_value=1.0, off_value=0.0) body = mx.sym.broadcast_mul(gt_one_hot, diff) fc7 = fc7 + body #1 #softmaxloss = mx.symbol.SoftmaxOutput(data=fc7, label=label_softmax, name='softmax', normalization='valid') #2 #softmax=mx.sym.softmax_cross_entropy(data=fc7, label=label_softmax) #softmaxloss = mx.sym.MakeLoss(softmax) #3 if args.noise: softmaxs = mx.sym.log_softmax(data=fc7, name="softmax") pred_label = mx.sym.argmax(softmaxs, axis=1) pred_one_hot = mx.sym.one_hot(pred_label, depth=args.num_classes, on_value=1.0, off_value=0.0) gt_one_hot = mx.sym.one_hot(label_softmax, depth=args.num_classes, on_value=1.0, off_value=0.0) cross_entropy_gt = -mx.sym.sum( mx.sym.broadcast_mul(gt_one_hot, softmaxs), axis=[0, 1]) cross_entropy_pred = -mx.sym.sum( mx.sym.broadcast_mul(pred_one_hot, softmaxs), axis=[0, 1]) cross_entropy = args.noise_beta * cross_entropy_gt + ( 1 - args.noise_beta) * cross_entropy_pred cross_entropy = cross_entropy / (args.batch_size // 2) softmaxloss = mx.sym.MakeLoss(cross_entropy) else: softmaxs = mx.sym.log_softmax(data=fc7, name="softmax") gt_one_hot = mx.sym.one_hot(label_softmax, depth=args.num_classes, on_value=1.0, off_value=0.0) cross_entropy = -mx.sym.sum(mx.sym.broadcast_mul(gt_one_hot, softmaxs), axis=[0, 1]) cross_entropy = cross_entropy / (args.batch_size // 2) softmaxloss = mx.sym.MakeLoss(cross_entropy) out_list.append(mx.symbol.BlockGrad(label_softmax)) out_list.append(mx.symbol.BlockGrad(fc7)) out_list.append(mx.symbol.BlockGrad(label_inter)) out_list.append(mx.symbol.BlockGrad(softmaxs)) out_list.append(softmaxloss) out_list.append(inter_loss) out_list.append(mx.symbol.BlockGrad(emb_norm)) out_list.append(mx.symbol.BlockGrad(cosine_matrix)) out = mx.symbol.Group(out_list) return (out, arg_params, aux_params)
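# A sketch of the bootstrapped cross-entropy used above when args.noise is
# set: the target mixes the ground-truth one-hot with the model's own argmax
# prediction, weighted by noise_beta. Here the loss is normalized by the
# number of rows rather than args.batch_size // 2, purely for illustration.
import numpy as np

def bootstrap_ce(logits, labels, noise_beta=0.95):
    log_p = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    idx = np.arange(len(labels))
    ce_gt = -log_p[idx, labels].sum()
    ce_pred = -log_p[idx, log_p.argmax(axis=1)].sum()
    return (noise_beta * ce_gt + (1.0 - noise_beta) * ce_pred) / len(labels)

print(bootstrap_ce(np.array([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]]), np.array([0, 2])))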
def get_symbol(args, arg_params, aux_params, sym_embedding=None): if sym_embedding is None: print('init resnet', args.num_layers) embedding = fresnet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit, version_act=args.version_act) else: embedding = sym_embedding gt_label = mx.symbol.Variable('softmax_label') nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') ''' anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size//3) positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size//3, end=2*args.per_batch_size//3) negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2*args.per_batch_size//3, end=args.per_batch_size) ap = anchor - positive an = anchor - negative ap = ap*ap an = an*an ap = mx.symbol.sum(ap, axis=1, keepdims=1) #(T,1) an = mx.symbol.sum(an, axis=1, keepdims=1) #(T,1) triplet_loss = mx.symbol.Activation(data = (ap-an+args.triplet_alpha), act_type='relu') triplet_loss = mx.symbol.mean(triplet_loss) ''' # n = mx.symbol.shape_array(nembedding)[0] n = args.per_batch_size # dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) dist = mx.symbol.pow(nembedding, 2) dist = mx.symbol.sum(dist, 1, True) dist = mx.symbol.broadcast_to(dist, shape=(n, n)) # dist = dist + dist.t() dist = dist + mx.symbol.transpose(dist) # dist.addmm_(1, -2, inputs, inputs.t()).clamp_(min=1e-12).sqrt_() dist = dist - 2 * mx.symbol.dot(nembedding, mx.symbol.transpose(nembedding)) dist = mx.symbol.maximum(dist, 1e-12) dist = mx.symbol.sqrt(dist) # dist = dist * gl_conf.scale # todo: how to use triplet only; could use temperature decay / progressive (curriculum) learning # For each anchor, find the hardest positive and negative #mask = targets.expand(n, n).eq(targets.expand(n, n).t()) label = mx.symbol.reshape(gt_label, (n, 1)) mask = mx.symbol.broadcast_equal(label, mx.symbol.transpose(label)) # a = to_numpy(targets) # print(a.shape, np.unique(a).shape) # daps = dist[mask].view(n, -1) # -1 can be used here, assuming the number of positives per anchor is the same, e.g. always 4 mask = mx.symbol.argsort(mask) mask_o = mx.symbol.slice(mask, begin=(0, n - args.images_per_identity), end=(n, n)) mask_o = mx.symbol.reshape(mask_o, (1, -1)) order = mx.symbol.reshape(mx.symbol.arange(start=0, stop=n), (1, -1)) order_p = mx.symbol.concat(order, order, dim=0) for i in range(args.images_per_identity - 2): order_p = mx.symbol.concat(order_p, order, dim=0) order_p = mx.symbol.reshape(order_p, (1, -1)) order_p = mx.symbol.sort(order_p, axis=1) mask_p = mx.symbol.concat(order_p, mask_o, dim=0) #mask_p[:, 1] = mask_o daps = mx.symbol.gather_nd(dist, mask_p) daps = mx.symbol.reshape(daps, (n, -1)) # todo how to cope with varied lengths?
# dans = dist[mask == 0].view(n, -1) mask_o = mx.symbol.slice(mask, begin=(0, 0), end=(n, n - args.images_per_identity)) mask_o = mx.symbol.reshape(mask_o, (1, -1)) order_n = mx.symbol.concat(order, order, dim=0) for i in range(n - args.images_per_identity - 2): order_n = mx.symbol.concat(order_n, order, dim=0) order_n = mx.symbol.reshape(order_n, (1, -1)) order_n = mx.symbol.sort(order_n, axis=1) mask_n = mx.symbol.concat(order_n, mask_o, dim=0) dans = mx.symbol.gather_nd(dist, mask_n) dans = mx.symbol.reshape(dans, (n, -1)) # ap_wei = F.softmax(daps.detach(), dim=1) # an_wei = F.softmax(-dans.detach(), dim=1) ap_wei = mx.symbol.softmax(daps, axis=1) an_wei = mx.symbol.softmax(-dans, axis=1) ap_wei_ng = mx.symbol.BlockGrad(ap_wei) an_wei_ng = mx.symbol.BlockGrad(an_wei) # dist_ap = (daps * ap_wei).sum(dim=1) # dist_an = (dans * an_wei).sum(dim=1) dist_ap = mx.symbol.broadcast_mul(daps, ap_wei_ng) dist_ap = mx.symbol.sum(dist_ap, axis=1) dist_an = mx.symbol.broadcast_mul(dans, an_wei_ng) dist_an = mx.symbol.sum(dist_an, axis=1) # loss = F.softplus(dist_ap - dist_an).mean() triplet_loss = mx.symbol.relu(dist_ap - dist_an + args.triplet_alpha) triplet_loss = mx.symbol.mean(triplet_loss) triplet_loss = mx.symbol.MakeLoss(triplet_loss) out_list = [mx.symbol.BlockGrad(embedding)] out_list.append(mx.sym.BlockGrad(gt_label)) out_list.append(triplet_loss) out = mx.symbol.Group(out_list) return (out, arg_params, aux_params)
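# A sketch of the pairwise-distance construction used by the soft-mining
# triplet loss above: ||xi - xj||^2 = ||xi||^2 + ||xj||^2 - 2*xi.xj, floored
# at 1e-12 before the square root exactly as the symbol graph does.
import numpy as np

def pairwise_dist(x):
    sq = (x * x).sum(axis=1, keepdims=True)    # (n, 1) squared norms
    d2 = sq + sq.T - 2.0 * (x @ x.T)           # (n, n) squared distances
    return np.sqrt(np.maximum(d2, 1e-12))

x = np.random.randn(6, 4)
x /= np.linalg.norm(x, axis=1, keepdims=True)  # L2-normalize like fc1n
print(pairwise_dist(x).round(3))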
def get_symbol(args, arg_params, aux_params): # data_shape = (args.image_channel, args.image_h, args.image_w) # (3L,112L,112L) # image_shape = ",".join([str(x) for x in data_shape]) #3,112,112 # margin_symbols = [] print('***network: ', args.network) # r100 if args.network[0] == 'd': # densenet embedding = fdensenet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0] == 'm': # mobilenet print('init mobilenet', args.num_layers) if args.num_layers == 1: embedding = fmobilenet.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) else: embedding = fmobilenetv2.get_symbol(args.emb_size) # elif args.network[0] == 'v': # print('init MobileNet-V3', args.num_layers) # embedding = fmobilenetv3.get_symbol(args.emb_size) elif args.network[0] == 'i': # inception-resnet-v2 print('init inception-resnet-v2', args.num_layers) embedding = finception_resnet_v2.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0] == 'x': print('init xception', args.num_layers) embedding = fxception.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0] == 'p': print('init dpn', args.num_layers) embedding = fdpn.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0] == 'n': print('init nasnet', args.num_layers) embedding = fnasnet.get_symbol(args.emb_size) elif args.network[0] == 's': print('init spherenet', args.num_layers) embedding = spherenet.get_symbol(args.emb_size, args.num_layers) elif args.network[0] == 'y': print('init mobilefacenet', args.num_layers) embedding = fmobilefacenet.get_symbol(args.emb_size, bn_mom=args.bn_mom, version_output=args.version_output) else: # fall back to resnet print('init resnet, num_layers: ', args.num_layers) embedding = fresnet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit, version_act=args.version_act) # get_symbol all_label = mx.symbol.Variable('softmax_label') gt_label = all_label # extra_loss = None # redefine the fc7 weight _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=args.fc7_lr_mult, wd_mult=args.fc7_wd_mult) if args.loss_type == 0: # softmax _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0) fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7') elif args.loss_type == 1: # sphere print('*******'*10) _weight = mx.symbol.L2Normalization(_weight, mode='instance') fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes, weight=_weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=1000, name='fc7') elif args.loss_type == 2: # CosineFace s = args.margin_s m = args.margin_m assert (s > 0.0) assert (m > 0.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7') s_m = s * m gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0) # one-hot target: maximum value s_m, minimum value 0.0 fc7 = fc7 - gt_one_hot elif args.loss_type == 4: # ArcFace s = args.margin_s # parameter s, 64 m = args.margin_m # parameter m, 0.5 assert s > 0.0 assert m >= 0.0 assert m < (math.pi / 2) # pdb.set_trace() # weight normalization _weight = mx.symbol.L2Normalization(_weight, mode='instance') # shape = [(4253, 512)] # feature normalization, scaled up to s*x nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7') # args.num_classes:8631 zy = mx.sym.pick(fc7, gt_label, axis=1) # pick the gt_label entry from each row of fc7, i.e. s*cos_t cos_t = zy / s # the network output is s*x/|x|*w/|w|*cos(theta); dividing by s recovers the actual cosine, cos(theta) cos_m = math.cos(m) sin_m = math.sin(m) mm = math.sin(math.pi - m) * m # sin(pi-m)*m = sin(m) * m 0.2397 # threshold = 0.0 threshold = math.cos(math.pi - m) # this threshold prevents theta+m >= pi; in fact threshold = -cos(m) < 0, here -0.8775825618903726 if args.easy_margin: # use 0 as the threshold and find the entries above it cond = mx.symbol.Activation(data=cos_t, act_type='relu') else: cond_v = cos_t - threshold # use the negative value as the threshold cond = mx.symbol.Activation(data=cond_v, act_type='relu') body = cos_t * cos_t # cos_t^2 + sin_t^2 = 1 body = 1.0 - body sin_t = mx.sym.sqrt(body) new_zy = cos_t * cos_m # cos(t+m) = cos(t)cos(m) - sin(t)sin(m) b = sin_t * sin_m new_zy = new_zy - b new_zy = new_zy * s # s*cos(t + m) if args.easy_margin: zy_keep = zy # zy_keep is zy, i.e. s*cos(theta) else: zy_keep = zy - s * mm # zy-s*sin(m)*m = s*cos(t)- s*m*sin(m) new_zy = mx.sym.where(cond, new_zy, zy_keep) # entries with cond>0 keep new_zy=s*cos(theta+m); entries <0 are clipped to zy_keep = s*cos(theta) or s*cos(theta)-s*m*sin(m) diff = new_zy - zy diff = mx.sym.expand_dims(diff, 1) gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0) body = mx.sym.broadcast_mul(gt_one_hot, diff) # new_zy - zy at the ground-truth positions fc7 = fc7 + body # at the ground-truth positions fc7 = zy + (new_zy - zy) = new_zy, i.e. s*cos(theta+m) where cond>0, clipped to s*cos(theta) or s*cos(theta)-s*m*sin(m) where <0 elif args.loss_type == 5: s = args.margin_s m = args.margin_m assert s > 0.0 _weight = mx.symbol.L2Normalization(_weight, mode='instance') nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7') if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0: if args.margin_a == 1.0 and args.margin_m == 0.0: s_m = s * args.margin_b gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0) fc7 = fc7 - gt_one_hot else: zy = mx.sym.pick(fc7, gt_label, axis=1) cos_t = zy / s t = mx.sym.arccos(cos_t) if args.margin_a != 1.0: t = t * args.margin_a if args.margin_m > 0.0: t = t + args.margin_m body = mx.sym.cos(t) if args.margin_b > 0.0: body = body - args.margin_b new_zy = body * s diff = new_zy - zy diff = mx.sym.expand_dims(diff, 1) gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0) body = mx.sym.broadcast_mul(gt_one_hot, diff) fc7 = fc7 + body out_list = [mx.symbol.BlockGrad(embedding)] softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax', normalization='valid') out_list.append(softmax) out = mx.symbol.Group(out_list) # print(out) # sys.exit() return (out, arg_params, aux_params)
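# A sketch of the CosineFace branch above (loss_type==2): after scaling, the
# margin is subtracted only from the target logit, fc7[i, y_i] -= s*m.
import numpy as np

def cosface_logits(cos_matrix, labels, s=64.0, m=0.35):
    fc7 = s * cos_matrix
    fc7[np.arange(len(labels)), labels] -= s * m
    return fc7

print(cosface_logits(np.array([[0.9, 0.2], [0.1, 0.7]]), np.array([0, 1])))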
def get_symbol(args, arg_params, aux_params): data_shape = (args.image_channel, args.image_h, args.image_w) image_shape = ",".join([str(x) for x in data_shape]) margin_symbols = [] if args.network[0] == 'd': embedding = fdensenet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0] == 'm': print('init mobilenet', args.num_layers) if args.num_layers == 1: embedding = fmobilenet.get_symbol( args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) else: embedding = fmobilenetv2.get_symbol(args.emb_size) elif args.network[0] == 'i': print('init inception-resnet-v2', args.num_layers) embedding = finception_resnet_v2.get_symbol( args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0] == 'x': print('init xception', args.num_layers) embedding = fxception.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0] == 'p': print('init dpn', args.num_layers) embedding = fdpn.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0] == 'n': print('init nasnet', args.num_layers) embedding = fnasnet.get_symbol(args.emb_size) elif args.network[0] == 's': print('init spherenet', args.num_layers) embedding = spherenet.get_symbol(args.emb_size, args.num_layers) elif args.network[0] == 'y': print('init mobilefacenet', args.num_layers) embedding = fmobilefacenet.get_symbol(args.emb_size, bn_mom=args.bn_mom, wd_mult=args.fc7_wd_mult) else: print('init resnet', args.num_layers) embedding = fresnet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit, version_act=args.version_act) all_label = mx.symbol.Variable('softmax_label') gt_label = all_label extra_loss = None _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0, wd_mult=args.fc7_wd_mult) if args.loss_type == 0: #softmax _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0) fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7') elif args.loss_type == 1: #sphere _weight = mx.symbol.L2Normalization(_weight, mode='instance') fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes, weight=_weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=1000, name='fc7') elif args.loss_type == 2: s = args.margin_s m = args.margin_m assert (s > 0.0) assert (m > 0.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') nembedding = mx.symbol.L2Normalization( embedding, mode='instance', name='fc1n') * s fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7') s_m = s * m gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0) fc7 = fc7 - gt_one_hot elif args.loss_type == 4: s = args.margin_s m = args.margin_m assert s > 0.0 assert m >= 0.0 assert m < (math.pi / 2) _weight = 
mx.symbol.L2Normalization(_weight, mode='instance') nembedding = mx.symbol.L2Normalization( embedding, mode='instance', name='fc1n') * s fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7') # split fc7 into fc7(labeled part) and fc7_2(unlabeled part) # calculate the likelihood function of the fc7_2 part subtract = args.num_classes * math.log( args.num_classes ) ############### This is calculated based on the number of labeled classes num_unlabeled = args.num_unlabeled num_labeled = args.per_batch_size - args.num_unlabeled per_batch_size = args.per_batch_size assert num_labeled > 0 fc7_2 = mx.sym.slice_axis(fc7, axis=0, begin=num_labeled, end=per_batch_size) fc7 = mx.sym.slice_axis(fc7, axis=0, begin=0, end=num_labeled) fc7_2 = mx.sym.softmax(data=fc7_2, axis=1) log_likelihood = -mx.sym.log_softmax(fc7_2, axis=1) log_likelihood_loss = (mx.sym.sum(log_likelihood) / num_unlabeled) log_likelihood_loss = log_likelihood_loss - subtract gt_label = mx.sym.slice_axis(gt_label, axis=0, begin=0, end=num_labeled) ############################################################ #last_to_append = mx.symbol.BlockGrad(mx.sym.softmax(fc7)) out_list = [] zy = mx.sym.pick(fc7, gt_label, axis=1) cos_t = zy / s cos_m = math.cos(m) sin_m = math.sin(m) mm = math.sin(math.pi - m) * m #threshold = 0.0 threshold = math.cos(math.pi - m) if args.easy_margin: cond = mx.symbol.Activation(data=cos_t, act_type='relu') else: cond_v = cos_t - threshold cond = mx.symbol.Activation(data=cond_v, act_type='relu') body = cos_t * cos_t body = 1.0 - body sin_t = mx.sym.sqrt(body) new_zy = cos_t * cos_m b = sin_t * sin_m new_zy = new_zy - b new_zy = new_zy * s if args.easy_margin: zy_keep = zy else: zy_keep = zy - s * mm new_zy = mx.sym.where(cond, new_zy, zy_keep) diff = new_zy - zy diff = mx.sym.expand_dims(diff, 1) gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0) body = mx.sym.broadcast_mul(gt_one_hot, diff) fc7 = fc7 + body elif args.loss_type == 5: s = args.margin_s m = args.margin_m assert s > 0.0 _weight = mx.symbol.L2Normalization(_weight, mode='instance') nembedding = mx.symbol.L2Normalization( embedding, mode='instance', name='fc1n') * s fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7') if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0: if args.margin_a == 1.0 and args.margin_m == 0.0: s_m = s * args.margin_b gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0) fc7 = fc7 - gt_one_hot else: zy = mx.sym.pick(fc7, gt_label, axis=1) cos_t = zy / s t = mx.sym.arccos(cos_t) if args.margin_a != 1.0: t = t * args.margin_a if args.margin_m > 0.0: t = t + args.margin_m body = mx.sym.cos(t) if args.margin_b > 0.0: body = body - args.margin_b new_zy = body * s diff = new_zy - zy diff = mx.sym.expand_dims(diff, 1) gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0) body = mx.sym.broadcast_mul(gt_one_hot, diff) fc7 = fc7 + body out_list.append(mx.symbol.BlockGrad(embedding)) if args.loss_type == 4: #softmax = mx.sym.softmax_cross_entropy(data=fc7, label=gt_label) cos_fc7 = mx.symbol.BlockGrad( fc7) # TODO: maybe this softmax makes the accuracy inconsistent gt_label_one_hot = mx.sym.one_hot(gt_label, args.num_classes) fc7 = -mx.sym.log_softmax(fc7, axis=1) softmax = gt_label_one_hot * fc7 softmax = mx.sym.sum(softmax) / (args.per_batch_size - args.num_unlabeled) scale = 
args.likelihood_scale_factor loss = mx.sym.make_loss(data=(softmax + log_likelihood_loss * scale)) out_list.append(loss) out_list.append(cos_fc7) out_list.append(mx.sym.BlockGrad(softmax)) out_list.append(mx.sym.BlockGrad(log_likelihood_loss)) else: softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax', normalization='valid') out_list.append(softmax) out = mx.symbol.Group(out_list) return (out, arg_params, aux_params)
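# A sketch of the unlabeled-batch term above (loss_type==4 with
# num_unlabeled > 0): fc7 is split into labeled and unlabeled rows, the
# unlabeled rows contribute a likelihood-style penalty, and the constant
# num_classes*log(num_classes) is subtracted. The double softmax mirrors the
# symbol code (softmax followed by log_softmax); shapes are illustrative.
import numpy as np

def softmax(z, axis=1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def unlabeled_loss(fc7_unlabeled, num_classes):
    p = softmax(fc7_unlabeled)
    log_p = np.log(softmax(p))                 # log_softmax(softmax(fc7_2))
    ll = -log_p.sum() / len(fc7_unlabeled)
    return ll - num_classes * np.log(num_classes)

print(unlabeled_loss(np.random.randn(4, 10), num_classes=10))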
def get_symbol(args, arg_params, aux_params): data_shape = (args.image_channel, args.image_h, args.image_w) image_shape = ",".join([str(x) for x in data_shape]) margin_symbols = [] if args.network[0] == 'd': embedding = fdensenet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0] == 'm': print('init mobilenet', args.num_layers) if args.num_layers == 1: embedding = fmobilenet.get_symbol( args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) else: embedding = fmobilenetv2.get_symbol(args.emb_size) elif args.network[0] == 'i': print('init inception-resnet-v2', args.num_layers) embedding = finception_resnet_v2.get_symbol( args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0] == 'x': print('init xception', args.num_layers) embedding = fxception.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0] == 'p': print('init dpn', args.num_layers) embedding = fdpn.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0] == 'n': print('init nasnet', args.num_layers) embedding = fnasnet.get_symbol(args.emb_size) elif args.network[0] == 's': print('init spherenet', args.num_layers) embedding = spherenet.get_symbol(args.emb_size, args.num_layers) elif args.network[0] == 'y': print('init mobilefacenet', args.num_layers) embedding = fmobilefacenet.get_symbol( args.emb_size, bn_mom=args.bn_mom, version_output=args.version_output) else: print('init resnet', args.num_layers) embedding = fresnet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit, version_act=args.version_act) all_label = mx.symbol.Variable('softmax_label') gt_label = all_label extra_loss = None _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=args.fc7_lr_mult, wd_mult=args.fc7_wd_mult) if args.loss_type == 0: #softmax if args.fc7_no_bias: fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7') else: _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0) fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7') elif args.loss_type == 1: #sphere _weight = mx.symbol.L2Normalization(_weight, mode='instance') fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes, weight=_weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=1000, name='fc7') elif args.loss_type == 2: s = args.margin_s m = args.margin_m assert (s > 0.0) assert (m > 0.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') nembedding = mx.symbol.L2Normalization( embedding, mode='instance', name='fc1n') * s fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7') s_m = s * m gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0) fc7 = fc7 - 
gt_one_hot elif args.loss_type == 4: s = args.margin_s m = args.margin_m assert s > 0.0 assert m >= 0.0 assert m < (math.pi / 2) _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0) if args.finetune: print("{}finetuning from trained model".format('-' * 10)) _weight = mx.symbol.Variable("finetune_weight", shape=(args.num_classes, args.emb_size), lr_mult=10.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') nembedding = mx.symbol.L2Normalization( embedding, mode='instance', name='fc1n') * s fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7') zy = mx.sym.pick(fc7, gt_label, axis=1) cos_t = zy / s cos_m = math.cos(m) sin_m = math.sin(m) mm = math.sin(math.pi - m) * m #threshold = 0.0 threshold = math.cos(math.pi - m) if args.easy_margin: cond = mx.symbol.Activation(data=cos_t, act_type='relu') else: cond_v = cos_t - threshold cond = mx.symbol.Activation(data=cond_v, act_type='relu') body = cos_t * cos_t body = 1.0 - body sin_t = mx.sym.sqrt(body) new_zy = cos_t * cos_m b = sin_t * sin_m new_zy = new_zy - b new_zy = new_zy * s if args.easy_margin: zy_keep = zy else: zy_keep = zy - s * mm new_zy = mx.sym.where(cond, new_zy, zy_keep) diff = new_zy - zy diff = mx.sym.expand_dims(diff, 1) gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0) body = mx.sym.broadcast_mul(gt_one_hot, diff) fc7 = fc7 + body elif args.loss_type == 5: s = args.margin_s m = args.margin_m assert s > 0.0 _weight = mx.symbol.L2Normalization(_weight, mode='instance') nembedding = mx.symbol.L2Normalization( embedding, mode='instance', name='fc1n') * s fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7') if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0: if args.margin_a == 1.0 and args.margin_m == 0.0: s_m = s * args.margin_b gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0) fc7 = fc7 - gt_one_hot else: zy = mx.sym.pick(fc7, gt_label, axis=1) cos_t = zy / s t = mx.sym.arccos(cos_t) if args.margin_a != 1.0: t = t * args.margin_a if args.margin_m > 0.0: t = t + args.margin_m body = mx.sym.cos(t) if args.margin_b > 0.0: body = body - args.margin_b new_zy = body * s diff = new_zy - zy diff = mx.sym.expand_dims(diff, 1) gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0) body = mx.sym.broadcast_mul(gt_one_hot, diff) fc7 = fc7 + body out_list = [mx.symbol.BlockGrad(embedding)] softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax', normalization='valid') out_list.append(softmax) out = mx.symbol.Group(out_list) return (out, arg_params, aux_params)
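# A sketch of why every margin branch above computes cosine logits: with a
# bias-free FullyConnected over L2-normalized embeddings and L2-normalized
# class weights, fc7[i, j] = s * cos(theta_ij), so each entry lies in [-s, s].
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8))                    # embeddings
W = rng.normal(size=(5, 8))                    # fc7_weight: (num_classes, emb_size)
s = 64.0
xn = x / np.linalg.norm(x, axis=1, keepdims=True)
Wn = W / np.linalg.norm(W, axis=1, keepdims=True)
fc7 = s * xn @ Wn.T                            # (2, 5) scaled cosines
print(fc7.min() >= -s, fc7.max() <= s)         # True True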
'stage3_unit12_bn2', 'stage3_unit12_bn3', 'stage3_unit13_bn2', 'stage3_unit13_bn3', 'stage3_unit14_bn2', 'stage3_unit14_bn3', 'stage4_unit1_bn2', 'stage4_unit1_bn3', 'stage4_unit1_sc', 'stage4_unit2_bn2', 'stage4_unit2_bn3', 'stage4_unit3_bn2', 'stage4_unit3_bn3' ] assert (len(conv_names) == len(bn_prefixes)) for i in range(len(conv_names)): conv_name = conv_names[i] bn_prefix = bn_prefixes[i] merge_bn(arg_params, aux_params, conv_name, bn_prefix) emb_size = 128 num_layers = 50 version_se = 0 version_input = 1 version_output = 'E' version_unit = 3 version_act = 'prelu' nobn_sym = fresnet.get_symbol(emb_size, num_layers, version_se=version_se, version_input=version_input, version_output=version_output, version_unit=version_unit, version_act=version_act) mx.model.save_checkpoint('mergebn_test', 0, nobn_sym, arg_params, aux_params)
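# What merging a BatchNorm into its preceding convolution amounts to (a
# generic sketch; the repo's merge_bn helper has its own signature and
# parameter naming, which is not shown here):
#   w' = w * gamma / sqrt(var + eps)
#   b' = beta - mean * gamma / sqrt(var + eps)
import numpy as np

def fold_bn(weight, gamma, beta, mean, var, eps=1e-5):
    scale = gamma / np.sqrt(var + eps)         # one factor per output channel
    w = weight * scale.reshape(-1, 1, 1, 1)    # weight: (out_c, in_c, kh, kw)
    b = beta - mean * scale
    return w, b

w = np.random.randn(16, 3, 3, 3)
g, bt, mu, vr = [np.random.rand(16) for _ in range(4)]
print([a.shape for a in fold_bn(w, g, bt, mu, vr)])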
def get_symbol(args, arg_params, aux_params): if args.retrain: new_args = arg_params else: new_args = None data_shape = (args.image_channel, args.image_h, args.image_w) image_shape = ",".join([str(x) for x in data_shape]) if args.network[0] == 's': embedding = spherenet.get_symbol(512, args.num_layers) elif args.network[0] == 'm': print('init marginal', args.num_layers) embedding = marginalnet.get_symbol(512, args.num_layers) elif args.network[0] == 'i': print('init inception-resnet-v2', args.num_layers) embedding = finception_resnet_v2.get_symbol(512) elif args.network[0] == 'x': print('init xception', args.num_layers) embedding, _ = xception.get_xception_symbol(512) else: print('init resnet', args.num_layers) embedding = fresnet.get_symbol(512, args.num_layers, use_se=args.use_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) gt_label = mx.symbol.Variable('softmax_label') assert args.loss_type >= 0 extra_loss = None if args.loss_type == 0: _weight = mx.symbol.Variable('fc7_weight') _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0) fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7') elif args.loss_type == 1: _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes, weight=_weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=100, name='fc7') elif args.loss_type == 10: _weight = mx.symbol.Variable('fc7_weight') _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0) fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7') nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') params = [1.2, 0.3, 1.0] n1 = mx.sym.expand_dims(nembedding, axis=1) n2 = mx.sym.expand_dims(nembedding, axis=0) body = mx.sym.broadcast_sub(n1, n2) #N,N,C body = body * body body = mx.sym.sum(body, axis=2) # N,N #body = mx.sym.sqrt(body) body = body - params[0] mask = mx.sym.Variable('extra') body = body * mask body = body + params[1] #body = mx.sym.maximum(body, 0.0) body = mx.symbol.Activation(data=body, act_type='relu') body = mx.sym.sum(body) body = body / (args.per_batch_size * args.per_batch_size - args.per_batch_size) extra_loss = mx.symbol.MakeLoss(body, grad_scale=params[2]) elif args.loss_type == 11: _weight = mx.symbol.Variable('fc7_weight') _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0) fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7') params = [0.9, 0.2] nembedding = mx.symbol.slice_axis(embedding, axis=0, begin=0, end=args.images_per_identity) nembedding = mx.symbol.L2Normalization(nembedding, mode='instance', name='fc1n') n1 = mx.sym.expand_dims(nembedding, axis=1) n2 = mx.sym.expand_dims(nembedding, axis=0) body = mx.sym.broadcast_sub(n1, n2) #N,N,C body = body * body body = mx.sym.sum(body, axis=2) # N,N body = body - params[0] body = mx.symbol.Activation(data=body, act_type='relu') body = mx.sym.sum(body) n = args.images_per_identity body = body / (n * n - n) extra_loss = mx.symbol.MakeLoss(body, grad_scale=params[1]) #extra_loss = None else: #embedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*float(args.loss_type) embedding = embedding * 5 _weight = 
mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') * 2 fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes, weight=_weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=100, name='fc7') #fc7 = mx.sym.Custom(data=embedding, label=gt_label, weight=_weight, num_hidden=args.num_classes, # beta=args.beta, margin=args.margin, scale=args.scale, # op_type='ASoftmax', name='fc7') softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax', normalization='valid') if args.loss_type <= 1 and args.incay > 0.0: params = [1.e-10] sel = mx.symbol.argmax(data=fc7, axis=1) sel = (sel == gt_label) norm = embedding * embedding norm = mx.symbol.sum(norm, axis=1) norm = norm + params[0] feature_incay = sel / norm feature_incay = mx.symbol.mean(feature_incay) * args.incay extra_loss = mx.symbol.MakeLoss(feature_incay) #out = softmax #l2_embedding = mx.symbol.L2Normalization(embedding) #ce = mx.symbol.softmax_cross_entropy(fc7, gt_label, name='softmax_ce')/args.per_batch_size #out = mx.symbol.Group([mx.symbol.BlockGrad(embedding), softmax, mx.symbol.BlockGrad(ce)]) if extra_loss is not None: out = mx.symbol.Group( [mx.symbol.BlockGrad(embedding), softmax, extra_loss]) else: out = mx.symbol.Group([mx.symbol.BlockGrad(embedding), softmax]) return (out, new_args, aux_params)
def get_symbol(args, arg_params, aux_params): data_shape = (args.image_channel,args.image_h,args.image_w) image_shape = ",".join([str(x) for x in data_shape]) margin_symbols = [] if args.network[0]=='d': embedding = fdensenet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0]=='m': print('init mobilenet', args.num_layers) if args.num_layers==1: embedding = fmobilenet.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) else: embedding = fmobilenetv2.get_symbol(args.emb_size) elif args.network[0]=='i': print('init inception-resnet-v2', args.num_layers) embedding = finception_resnet_v2.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0]=='x': print('init xception', args.num_layers) embedding = fxception.get_symbol(args.emb_size, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0]=='p': print('init dpn', args.num_layers) embedding = fdpn.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) elif args.network[0]=='n': print('init nasnet', args.num_layers) embedding = fnasnet.get_symbol(args.emb_size) elif args.network[0]=='s': print('init spherenet', args.num_layers) embedding = spherenet.get_symbol(args.emb_size, args.num_layers) elif args.network[0] == 'c': print('init crunet', args.num_layers) embedding = fcrunet.get_symbol(args.emb_size, args.num_layers) elif args.network[0] == 'a': print('init residual attention network', args.num_layers) embedding = fresattnet.get_symbol(args.emb_size, args.num_layers) else: print('init resnet', args.num_layers) ibn = args.ibn if ibn: embedding = fresnet_ibn_a.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit, version_act=args.version_act, ibn=args.ibn) else: embedding = fresnet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit, version_act=args.version_act, stn=args.stn, stn1=args.stn1, stn2=args.stn2, stn3=args.stn3, stn4=args.stn4) all_label = mx.symbol.Variable('softmax_label') gt_label = all_label extra_loss = None if args.loss_type==0: #softmax _weight = mx.symbol.Variable('fc7_weight') _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0) fc7 = mx.sym.FullyConnected( data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7' ) # for softmax ACC computation # fc_pred = mx.sym.FullyConnected( # data=nembedding, weight=_weight, # no_bias=True, num_hidden=args.num_classes, # name='fc_pred' # ) fc_pred = fc7 _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n') elif args.loss_type==1: #sphereface _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0) _weight = mx.symbol.L2Normalization( _weight, mode='instance', name='fc7_weight_n') fc7 = mx.sym.LSoftmax( data=embedding, label=gt_label, num_hidden=args.num_classes, weight = 
_weight, beta=args.beta, margin=args.margin, scale=args.scale, beta_min=args.beta_min, verbose=1000, name='fc7_lsoftmax' ) # for softmax ACC computation fc_pred = mx.sym.FullyConnected( data=embedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7' ) elif args.loss_type==2: s = args.margin_s m = args.margin_m assert(s>0.0) assert(m>0.0) _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0) _weight = mx.symbol.L2Normalization( _weight, mode='instance', name='fc7_weight_n') nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s fc7 = mx.sym.FullyConnected( data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7' ) # for softmax ACC computation # fc_pred = mx.sym.FullyConnected( # data=nembedding, weight=_weight, # no_bias=True, num_hidden=args.num_classes, # name='fc_pred' # ) fc_pred = fc7 s_m = s*m gt_one_hot = mx.sym.one_hot( gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0) fc7 = fc7 - gt_one_hot elif args.loss_type==4: #ArcFace s = args.margin_s m = args.margin_m assert s>0.0 assert m>=0.0 assert m < (math.pi / 2) _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n') nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s fc7 = mx.sym.FullyConnected( data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7' ) # for softmax ACC computation # fc_pred = mx.sym.FullyConnected( # data=nembedding, weight=_weight, # no_bias=True, num_hidden=args.num_classes, # name='fc_pred' # ) fc_pred = fc7 zy = mx.sym.pick(fc7, gt_label, axis=1) cos_t = zy/s cos_m = math.cos(m) sin_m = math.sin(m) mm = math.sin(math.pi-m)*m #threshold = 0.0 threshold = math.cos(math.pi-m) if args.easy_margin: cond = mx.symbol.Activation(data=cos_t, act_type='relu') else: cond_v = cos_t - threshold cond = mx.symbol.Activation(data=cond_v, act_type='relu') body = cos_t*cos_t body = 1.0-body sin_t = mx.sym.sqrt(body) new_zy = cos_t*cos_m b = sin_t*sin_m new_zy = new_zy - b new_zy = new_zy * s if args.easy_margin: zy_keep = zy else: zy_keep = zy - s*mm new_zy = mx.sym.where(cond, new_zy, zy_keep) diff = new_zy - zy diff = mx.sym.expand_dims(diff, 1) gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0) body = mx.sym.broadcast_mul(gt_one_hot, diff) fc7 = fc7 + body elif args.loss_type==5: s = args.margin_s m = args.margin_m assert s>0.0 _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n') nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s fc7 = mx.sym.FullyConnected( data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7') # for softmax ACC computation # fc_pred = mx.sym.FullyConnected( # data=nembedding, weight=_weight, # no_bias=True, num_hidden=args.num_classes, # name='fc_pred' # ) fc_pred = fc7 if args.margin_a!=1.0 or args.margin_m!=0.0 or args.margin_b!=0.0: if args.margin_a==1.0 and args.margin_m==0.0: s_m = s*args.margin_b gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0) fc7 = fc7-gt_one_hot else: zy = mx.sym.pick(fc7, gt_label, axis=1) cos_t = zy/s t = mx.sym.arccos(cos_t) if args.margin_a!=1.0: t = t*args.margin_a if 
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body * s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7 + body
    elif args.loss_type == 6:  # spa loss
        s = args.margin_s
        m = args.margin_m
        b = args.margin_b
        assert (s >= 1.0 and m >= 1.0 and b >= 0.0)
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
        if s > 0.0:  # always true given the assert above; the else branch is kept for reference
            nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
            fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True,
                                        num_hidden=args.num_classes, name='fc7')
            fc_pred = fc7  # logits reused for the softmax ACC metric
            if m > 1.0:
                s_m = s * (m - 1 + b)
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
                fc7 = fc7 * m - gt_one_hot
        else:
            fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, no_bias=True,
                                        num_hidden=args.num_classes, name='fc7')
            fc_pred = fc7  # logits reused for the softmax ACC metric
            if m > 1.0:
                # Unnormalized embeddings: scale the margin by the feature norm.
                body = embedding * embedding
                body = mx.sym.sum_axis(body, axis=1, keepdims=True)
                body = mx.sym.sqrt(body)
                body = body * (m - 1 + b)
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, body)
                fc7 = fc7 * m - body
    elif args.loss_type == 61:  # spa-v1 loss, fixed
        s = args.margin_s
        m = args.margin_m
        b = args.margin_b
        assert (s >= 1.0 and m >= 1.0 and b >= 0.0)
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True,
                                    num_hidden=args.num_classes, name='fc7')
        fc_pred = fc7  # logits reused for the softmax ACC metric
        if m > 1.0:
            cos_theta = fc7.clip(-s, s) / s  # clip cosine into [-1, 1]
            cos_theta_1 = cos_theta - 1
            # s_m = s*(m - 1 + b)
            gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1, off_value=0.0)
            # Target class gets a slope of m on (cos(theta) - 1), other classes a slope of 1.
            cos_theta = cos_theta_1 + cos_theta_1 * gt_one_hot * (m - 1)
            if b > 0:
                cos_theta -= gt_one_hot * b
            fc7 = cos_theta * s
    elif args.loss_type == 63:  # spa-v3 loss
        s = args.margin_s
        m = args.margin_m
        assert (s >= 1.0 and m >= 0.0)
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True,
                                    num_hidden=args.num_classes, name='fc7')
        fc_pred = fc7  # logits reused for the softmax ACC metric
        if m > 0:
            cos_theta = fc7.clip(-s, s) / s  # clip cosine into [-1, 1]
            theta = mx.sym.arccos(cos_theta)
            gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1, off_value=0.0)
            # Penalize the target logit proportionally to its angle: cos(theta) - m*theta.
            cos_theta = cos_theta - gt_one_hot * theta * m
            fc7 = cos_theta * s
    elif args.loss_type == 64:  # spa-v4 loss
        s = args.margin_s
        m = args.margin_m
        b = args.margin_b
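        # spa-v4 replaces the cosine logit with a logit linear in the
        # (normalized) angle: the target class gets slope m, all other classes
        # slope b (clamped to b >= 1), so larger angles are penalized directly
        # rather than through cos().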
        if b < 1.0:
            b = 1.0
        assert (s >= 1.0 and m >= b and b >= 1.0)
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True,
                                    num_hidden=args.num_classes, name='fc7')
        fc_pred = fc7  # logits reused for the softmax ACC metric
        cos_theta = fc7.clip(-s, s) / s  # clip cosine into [-1, 1]
        # Calculate angle theta and normalize it into [0, 1].
        theta = mx.sym.arccos(cos_theta) * (1.0 / np.pi)
        s *= 2  # make the output span a range of 2, like cos(x)
        if m > b:
            gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1, off_value=0.0)
            fc7 = (theta * b + gt_one_hot * theta * (m - b)) * (-s)
        else:
            fc7 = theta * (-s)
    elif args.loss_type == 65:  # spa-v5 loss
        s = args.margin_s
        m = args.margin_m
        assert (s >= 1.0 and m >= 0 and m < 180)
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
        m = m / 180.0  # margin given in degrees; convert to the normalized-angle scale
        # NOTE: unlike the sibling branches, nembedding is not multiplied by s
        # here, so fc7 is already a cosine and the division by s below shrinks
        # it further; this looks inconsistent in the original and is kept as-is.
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True,
                                    num_hidden=args.num_classes, name='fc7')
        fc_pred = fc7  # logits reused for the softmax ACC metric
        cos_theta = fc7.clip(-s, s) / s  # clip cosine into [-1, 1]
        # Calculate angle theta and normalize it into [0, 1].
        theta = mx.sym.arccos(cos_theta) * (1.0 / np.pi)
        s *= 2  # make the output span a range of 2, like cos(x)
        if m > 0:
            # Additive angular margin on the linear-angle logit.
            gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1, off_value=0.0)
            fc7 = (theta + gt_one_hot * m) * (-s)
        else:
            fc7 = theta * (-s)
    elif args.loss_type == 7:  # combined spa loss
        s = args.margin_s
        m = args.margin_m
        b = args.margin_b
        a = args.margin_a
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
        assert m >= 1.0
        assert b >= 0.0
        assert s > 0.0
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True,
                                    num_hidden=args.num_classes, name='fc7')
        fc_pred = fc7  # logits reused for the softmax ACC metric
        # Multiplicative angular margin on the target logit: cos(a*theta),
        # combined below with the m/b rescaling of the spa loss.
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        t = mx.sym.arccos(cos_t)
        t = t * a
        new_zy = mx.sym.cos(t) * s
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        s_m = s * (m - 1)
        gt_one_hot2 = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
        fc7 = m * (fc7 + body) - gt_one_hot2
    elif args.loss_type == 8:  # Adaloss
        s = args.margin_s
        m = args.margin_m
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
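        # Adaloss reshapes the target logit with the curve
        # psi(theta) = 2*(1 - sin(theta/(2m)))*cos(theta/m) - 1, which equals 1
        # at theta = 0 but falls off faster than cos(theta) for small angles,
        # with m controlling how sharp the decay is.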
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True,
                                    num_hidden=args.num_classes, name='fc7')
        fc_pred = fc7  # logits reused for the softmax ACC metric
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        t = mx.sym.arccos(cos_t)
        body = ((1 - mx.sym.sin(1.0 * t / (2 * m))) * mx.sym.cos(1.0 * t / m) * 2) - 1
        new_zy = body * s
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    elif args.loss_type == 9:  # semi-hard loss; marked as faulty in the original
        s = args.margin_s
        m = args.margin_m
        b = args.margin_b
        a = args.margin_a
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
        assert m >= 1.0
        assert b >= 0.0
        if s > 0.0:  # NOTE: fc7/fc_pred stay undefined when s <= 0, which would fail below
            nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
            fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True,
                                        num_hidden=args.num_classes, name='fc7')
            fc_pred = fc7  # logits reused for the softmax ACC metric
            zy = mx.sym.pick(fc7, gt_label, axis=1)
            cos_t = zy / s
            t = mx.sym.arccos(cos_t)  # computed but unused: the semi-hard part was never finished
            # bounding = mx.sym.Variable('')  # broken in the original (empty name, unused); left disabled
            s_m = s * (m - 1 + b)
            gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
            fc7 = fc7 * m - gt_one_hot
    elif args.loss_type == 10:  # combined intra loss
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert args.margin_b > 0.0
        b = args.margin_b
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True,
                                    num_hidden=args.num_classes, name='fc7')
        fc_pred = fc7  # logits reused for the softmax ACC metric
        # Intra-class compactness term: mean normalized angle to the class
        # weight, emitted as an auxiliary loss scaled by margin_b.
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        t = mx.sym.arccos(cos_t)
        intra_loss = t / np.pi
        intra_loss = mx.sym.mean(intra_loss)
        intra_loss = mx.sym.MakeLoss(intra_loss, name='intra_loss', grad_scale=args.margin_b)
        s_m = s * (m - 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
        fc7 = fc7 * m - gt_one_hot
    elif args.loss_type == 11:  # reweight
        s = args.margin_s
        m = args.margin_m
        b = args.margin_b
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
        assert m >= 1.0
        assert b >= 0.0
        # reweight
        spatial_norm = embedding * embedding
        spatial_norm = mx.sym.sum(data=spatial_norm, axis=1, keepdims=True)
        spatial_sqrt = mx.sym.sqrt(spatial_norm)
        spatial_mean = mx.sym.mean(spatial_sqrt)
        spatial_div_inverse = mx.sym.broadcast_div(spatial_mean, spatial_sqrt)
        reweight_s = s * spatial_div_inverse
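        # reweight_s replaces the fixed scale s with a per-sample value
        # s * (mean ||f|| / ||f||): samples whose feature norm falls below the
        # batch mean get a proportionally larger scale, evening out logit
        # magnitudes across easy and hard examples before the margin is applied.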
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        # Separate margin-free head for the softmax ACC metric.
        fc_pred = mx.sym.FullyConnected(data=embedding, weight=_weight, no_bias=True,
                                        num_hidden=args.num_classes, name='fc_pred') * s
        nembedding = mx.symbol.broadcast_mul(nembedding, reweight_s)
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True,
                                    num_hidden=args.num_classes, name='fc7')
        s_m = s * (m - 1 + b)
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
        fc7 = fc7 * m - gt_one_hot
    elif args.loss_type == 12:  # hard-example margin
        s = args.margin_s
        m = args.margin_m
        b = args.margin_b
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
        assert m >= 1.0
        assert b >= 0.0
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True,
                                    num_hidden=args.num_classes, name='fc7')
        fc_pred = fc7  # logits reused for the softmax ACC metric
        predict_label = mx.sym.argmax(fc7, axis=1)
        # NOTE: '<' keeps only predictions whose class index is below the label;
        # masked-out entries collapse to class 0, so class 0 absorbs the margin
        # for correctly classified samples. Kept as in the original, though
        # '!=' may have been intended.
        wrong_label_mask = predict_label < gt_label
        wrong_label = predict_label * wrong_label_mask
        s_m = s * (m - 1 + b)
        wrong_label_one_hot = mx.sym.one_hot(wrong_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
        fc7 = fc7 * m - wrong_label_one_hot
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax', normalization='valid')
    out_list.append(softmax)
    # Margin-free softmax, gradient-blocked, for the ACC metric.
    pred_softmax = mx.symbol.SoftmaxOutput(data=fc_pred, label=gt_label,
                                           name='pred_softmax', normalization='valid')
    out_list.append(mx.symbol.BlockGrad(pred_softmax))
    out_list.append(mx.symbol.BlockGrad(_weight))
    if args.loss_type == 10:
        out_list.append(intra_loss)
    if args.ce_loss:
        # ce_loss = mx.symbol.softmax_cross_entropy(data=fc7, label=gt_label, name='ce_loss')/args.per_batch_size
        # Manual cross-entropy over the margin logits, exposed as a monitoring output.
        body = mx.symbol.SoftmaxActivation(data=fc7)
        body = mx.symbol.log(body)
        _label = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=-1.0, off_value=0.0)
        body = body * _label
        ce_loss = mx.symbol.sum(body) / args.per_batch_size
        out_list.append(mx.symbol.BlockGrad(ce_loss))
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
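# Usage sketch (illustrative, not part of the original file): binding the
# grouped symbol with mx.mod.Module. 'softmax' is the only output that
# back-propagates; the BlockGrad outputs (embedding, pred_softmax,
# fc7_weight_n, ce_loss) are exposed for metrics only. 'args' and
# 'train_iter' are hypothetical stand-ins for the caller's config and
# data iterator.
#
#   sym, arg_params, aux_params = get_symbol(args, None, None)
#   model = mx.mod.Module(symbol=sym, context=[mx.gpu(0)],
#                         data_names=['data'], label_names=['softmax_label'])
#   model.bind(data_shapes=train_iter.provide_data,
#              label_shapes=train_iter.provide_label)
#   model.init_params(arg_params=arg_params, aux_params=aux_params,
#                     allow_missing=True)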