Example #1
0
def get_symbol(args, arg_params, aux_params):
  """Build the joint gender / age multi-head training symbol.

  The backbone emits ``AGE*2 + 2`` values per sample (``AGE`` is a
  module-level constant defined elsewhere — presumably the number of binary
  age heads; confirm):

    * columns ``[0, 2)`` are the 2-way gender logits;
    * columns ``[2 + 2*i, 4 + 2*i)`` are the 2-way logits of age head ``i``.

  ``softmax_label`` supplies one column per head: column 0 is the gender
  label and column ``i + 1`` is the label for age head ``i``.

  Args:
    args: parsed options. Reads ``network`` (first char 'm' selects
      fmobilenet, anything else fresnet), ``multiplier`` (mobilenet only),
      ``num_layers`` (resnet only), ``version_input``, ``version_output``
      and ``per_batch_size``.
    arg_params: passed through unchanged.
    aux_params: passed through unchanged.

  Returns:
    ``(group, arg_params, aux_params)``. ``group`` holds, in order: the
    gender softmax, the ``AGE`` age softmaxes, and a gradient-blocked copy
    of the raw backbone output.
  """
  num_outputs = AGE * 2 + 2
  if args.network[0] == 'm':
    fc1 = fmobilenet.get_symbol(num_outputs,
        multiplier=args.multiplier,
        version_input=args.version_input,
        version_output=args.version_output)
  else:
    fc1 = fresnet.get_symbol(num_outputs, args.num_layers,
        version_input=args.version_input,
        version_output=args.version_output)
  label = mx.symbol.Variable('softmax_label')

  # Gender head: label column 0, logits columns 0-1. Samples labelled 9999
  # are ignored, and 'valid' normalization averages over the rest.
  gender_label = mx.symbol.slice_axis(data=label, axis=1, begin=0, end=1)
  gender_label = mx.symbol.reshape(gender_label, shape=(args.per_batch_size,))
  gender_fc1 = mx.symbol.slice_axis(data=fc1, axis=1, begin=0, end=2)
  gender_softmax = mx.symbol.SoftmaxOutput(data=gender_fc1, label=gender_label, name='gender_softmax',
                                           normalization='valid', use_ignore=True, ignore_label=9999)
  outs = [gender_softmax]

  # Age heads: head i reads label column i+1 and logits columns 2+2i .. 3+2i.
  for i in range(AGE):
    age_label = mx.symbol.slice_axis(data=label, axis=1, begin=i + 1, end=i + 2)
    age_label = mx.symbol.reshape(age_label, shape=(args.per_batch_size,))
    age_fc1 = mx.symbol.slice_axis(data=fc1, axis=1, begin=2 + i * 2, end=4 + i * 2)
    age_softmax = mx.symbol.SoftmaxOutput(data=age_fc1, label=age_label, name='age_softmax_%d' % i,
                                          normalization='valid', grad_scale=1)
    outs.append(age_softmax)

  # Expose the raw backbone output without letting gradients flow through it.
  outs.append(mx.sym.BlockGrad(fc1))

  out = mx.symbol.Group(outs)
  return (out, arg_params, aux_params)
Example #2
0
def get_symbol(args, arg_params, aux_params):
  """Build the 512-d face-embedding training symbol for ``args.loss_type``.

  The backbone is picked by ``args.network[0]``: 'd' densenet, 'm' mobilenet,
  'i' inception-resnet-v2, 'x' xception, 'p' dpn, anything else resnet.
  The head is picked by ``args.loss_type`` (must be >= 0):

    0      linear classifier ``fc7`` (bias at 2x LR, no weight decay)
    1      LSoftmax over L2-normalized class weights
    2      as 0, plus a custom 'centerloss' op on the embedding
    10     "marginal" loss: relu(mask * (d^2 - 1.2) + 0.3) averaged over the
           off-diagonal pairs, where d^2 is the squared distance between
           L2-normalized embeddings and mask is the 'extra' input variable
    11     N-pair style loss on cosines between the first and second image
           of every identity
    12     as 0, plus relu(d^2 - 0.9) over the first ``images_per_identity``
           embeddings (grad scale 0.2)
    other  LSoftmax over ``embedding * 5`` and ``2 * normalized weights``;
           the scaled embedding is also what gets exported below

  A SoftmaxOutput over ``fc7`` is attached only when ``loss_type < 10``.
  Otherwise the label is exported as a gradient-blocked output instead.
  For loss types 0 and 1 with ``args.incay > 0`` an additional
  feature-incay loss is added.

  Returns:
    ``(group, arg_params, aux_params)``. ``group`` is
    ``[BlockGrad(embedding), softmax or BlockGrad(label), extra_loss]``,
    where the last entry is present only when an auxiliary loss was built.
  """
  net = args.network[0]
  if net == 'd':
    embedding = fdensenet.get_symbol(512, args.num_layers,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif net == 'm':
    print('init mobilenet', args.num_layers)
    embedding = fmobilenet.get_symbol(512,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif net == 'i':
    print('init inception-resnet-v2', args.num_layers)
    embedding = finception_resnet_v2.get_symbol(512,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif net == 'x':
    print('init xception', args.num_layers)
    embedding = fxception.get_symbol(512,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif net == 'p':
    print('init dpn', args.num_layers)
    embedding = fdpn.get_symbol(512, args.num_layers,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  else:
    print('init resnet', args.num_layers)
    embedding = fresnet.get_symbol(512, args.num_layers,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)

  gt_label = mx.symbol.Variable('softmax_label')
  assert args.loss_type >= 0
  extra_loss = None  # optional auxiliary loss appended to the output group

  def _linear_fc7(data):
    # Plain FullyConnected classifier shared by loss types 0, 2 and 12.
    # The bias trains at 2x learning rate without weight decay.
    weight = mx.symbol.Variable('fc7_weight')
    bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
    return mx.sym.FullyConnected(data=data, weight=weight, bias=bias,
                                 num_hidden=args.num_classes, name='fc7')

  if args.loss_type == 0:
    fc7 = _linear_fc7(embedding)
  elif args.loss_type == 1:
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,
                          weight=_weight,
                          beta=args.beta, margin=args.margin, scale=args.scale,
                          beta_min=args.beta_min, verbose=1000, name='fc7')
  elif args.loss_type == 2:
    fc7 = _linear_fc7(embedding)
    print('center-loss', args.center_alpha, args.center_scale)
    extra_loss = mx.symbol.Custom(data=embedding, label=gt_label, name='center_loss', op_type='centerloss',
                                  num_class=args.num_classes, alpha=args.center_alpha,
                                  scale=args.center_scale, batchsize=args.per_batch_size)
  elif args.loss_type == 10:  # marginal loss
    dist_margin, offset, grad_scale = 1.2, 0.3, 1.0
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
    n1 = mx.sym.expand_dims(nembedding, axis=1)  # N,1,C
    n2 = mx.sym.expand_dims(nembedding, axis=0)  # 1,N,C
    body = mx.sym.broadcast_sub(n1, n2)  # N,N,C
    body = body * body
    body = mx.sym.sum(body, axis=2)  # N,N squared pairwise distances
    body = body - dist_margin
    mask = mx.sym.Variable('extra')
    body = body * mask
    body = body + offset
    body = mx.symbol.Activation(data=body, act_type='relu')
    body = mx.sym.sum(body)
    # Average over the N*N - N off-diagonal pairs (assumes N == per_batch_size).
    body = body / (args.per_batch_size * args.per_batch_size - args.per_batch_size)
    extra_loss = mx.symbol.MakeLoss(body, grad_scale=grad_scale)
  elif args.loss_type == 11:  # npair loss
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
    nembedding = mx.sym.transpose(nembedding)
    nembedding = mx.symbol.reshape(nembedding, (512, args.per_identities, args.images_per_identity))
    nembedding = mx.sym.transpose(nembedding, axes=(2, 1, 0))  # images_per_identity, per_identities, 512
    # First and second image of every identity (needs images_per_identity >= 2).
    n1 = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=1)
    n2 = mx.symbol.slice_axis(nembedding, axis=0, begin=1, end=2)
    n1 = mx.symbol.reshape(n1, (args.per_identities, 512))
    n2 = mx.symbol.reshape(n2, (args.per_identities, 512))
    cosine_matrix = mx.symbol.dot(lhs=n1, rhs=n2, transpose_b=True)  # id*id, id=N of N-pair
    data_extra = mx.sym.Variable('extra')
    data_extra = mx.sym.slice_axis(data_extra, axis=0, begin=0, end=args.per_identities)
    # NOTE(review): presumably 'extra' is an identity-like mask so fii picks
    # the matching-pair cosine of each row; confirm against the data iterator.
    mask = cosine_matrix * data_extra
    fii = mx.sym.sum_axis(mask, axis=1)
    # NOTE(review): fii has shape (per_identities,). If broadcast_sub pads it
    # to (1, per_identities), it is subtracted per column rather than per row.
    # Confirm this is the intended pairing.
    fij_fii = mx.sym.broadcast_sub(cosine_matrix, fii)
    fij_fii = mx.sym.exp(fij_fii)
    row = mx.sym.sum_axis(fij_fii, axis=1)
    row = mx.sym.log(row)
    body = mx.sym.mean(row)
    extra_loss = mx.sym.MakeLoss(body)
  elif args.loss_type == 12:
    fc7 = _linear_fc7(embedding)
    # NOTE(review): loss_type 12 is >= 10, so no SoftmaxOutput is attached
    # below and this fc7 does not appear in the returned group. Confirm
    # whether a softmax head was intended here.
    dist_margin, grad_scale = 0.9, 0.2
    nembedding = mx.symbol.slice_axis(embedding, axis=0, begin=0, end=args.images_per_identity)
    nembedding = mx.symbol.L2Normalization(nembedding, mode='instance', name='fc1n')
    n1 = mx.sym.expand_dims(nembedding, axis=1)
    n2 = mx.sym.expand_dims(nembedding, axis=0)
    body = mx.sym.broadcast_sub(n1, n2)  # N,N,C
    body = body * body
    body = mx.sym.sum(body, axis=2)  # N,N squared pairwise distances
    body = body - dist_margin
    body = mx.symbol.Activation(data=body, act_type='relu')
    body = mx.sym.sum(body)
    n = args.images_per_identity
    body = body / (n * n - n)  # mean over off-diagonal pairs
    extra_loss = mx.symbol.MakeLoss(body, grad_scale=grad_scale)
  else:
    # NOTE: this reassigns `embedding`, so the exported BlockGrad(embedding)
    # below is the x5-scaled tensor for these loss types.
    embedding = embedding * 5
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, 512), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance') * 2
    fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,
                          weight=_weight,
                          beta=args.beta, margin=args.margin, scale=args.scale,
                          beta_min=args.beta_min, verbose=100, name='fc7')

  if args.loss_type <= 1 and args.incay > 0.0:
    # Feature incay: mean over the batch of 1/||embedding||^2 for correctly
    # classified samples, scaled by args.incay. The epsilon avoids division
    # by zero. extra_loss is None for loss types 0/1, so nothing is
    # overwritten here.
    eps = 1.e-10
    sel = mx.symbol.argmax(data=fc7, axis=1)
    sel = (sel == gt_label)
    norm = embedding * embedding
    norm = mx.symbol.sum(norm, axis=1)
    norm = norm + eps
    feature_incay = sel / norm
    feature_incay = mx.symbol.mean(feature_incay) * args.incay
    extra_loss = mx.symbol.MakeLoss(feature_incay)

  out_list = [mx.symbol.BlockGrad(embedding)]
  softmax = None
  if args.loss_type < 10:
    softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax', normalization='valid')
    out_list.append(softmax)
  if softmax is None:
    # Keep the label visible as an output when there is no softmax head.
    out_list.append(mx.sym.BlockGrad(gt_label))
  if extra_loss is not None:
    out_list.append(extra_loss)
  out = mx.symbol.Group(out_list)
  return (out, arg_params, aux_params)
Example #3
0
def _scaled_cosine_fc7(embedding, norm_weight, s, num_classes):
    """Return bias-free ``fc7`` logits equal to ``s * cos(theta)``.

    ``norm_weight`` must already be L2-normalized per class row. The
    embedding is L2-normalized here (symbol name 'fc1n') and scaled by ``s``.
    The FullyConnected product is therefore the scaled cosine similarity.
    """
    nembedding = mx.symbol.L2Normalization(
        embedding, mode='instance', name='fc1n') * s
    return mx.sym.FullyConnected(data=nembedding,
                                 weight=norm_weight,
                                 no_bias=True,
                                 num_hidden=num_classes,
                                 name='fc7')


def _replace_target_logit(fc7, zy, new_zy, gt_label, num_classes):
    """Return ``fc7`` with each sample's ground-truth logit moved from ``zy`` to ``new_zy``.

    The per-sample delta ``new_zy - zy`` is added only in the ground-truth
    column, through a one-hot mask. All other logits are left untouched.
    """
    diff = new_zy - zy
    diff = mx.sym.expand_dims(diff, 1)
    gt_one_hot = mx.sym.one_hot(gt_label,
                                depth=num_classes,
                                on_value=1.0,
                                off_value=0.0)
    return fc7 + mx.sym.broadcast_mul(gt_one_hot, diff)


def get_symbol(args, arg_params, aux_params):
    """Build the face-recognition training symbol: backbone plus margin-softmax head.

    The backbone is chosen by ``args.network[0]``:

      'd' densenet, 'm' mobilenet (v1 if ``num_layers == 1`` else v2),
      'i' inception-resnet-v2, 'x' xception, 'p' dpn, 'n' nasnet,
      's' spherenet, 'y' mobilefacenet, anything else resnet.

    It produces an ``args.emb_size`` embedding. The classifier head is
    chosen by ``args.loss_type``:

      0  FullyConnected ``fc7`` (bias unless ``args.fc7_no_bias``)
      1  LSoftmax over L2-normalized class weights
      2  additive cosine margin: target logit ``s*cos(t) - s*m``
      4  additive angular margin: target logit ``s*cos(t + m)``. Where the
         validity condition fails, it falls back to ``zy`` (easy_margin:
         cos(t) <= 0) or to ``zy - s*sin(pi - m)*m`` (otherwise:
         cos(t) <= cos(pi - m)).
      5  combined margin ``s*(cos(a*t + m) - b)``, using
         ``margin_a/margin_m/margin_b``
      6  angular margin ``m`` (if > 0) plus an auxiliary 'intra_loss'
         ``mean(t / pi)`` weighted by ``margin_b``
      7  angular margin ``m`` (if > 0) plus an auxiliary 'inter_loss': the
         mean cosine between each sample's target-class weight and every
         class weight, weighted by ``margin_b``

    ``s = args.margin_s``, ``m = args.margin_m``, and ``t`` is the angle
    between the embedding and its ground-truth class weight.

    Returns:
        ``(group, arg_params, aux_params)``. ``group`` holds, in order:
        ``BlockGrad(embedding)``, the softmax, the auxiliary loss (types 6
        and 7), and ``BlockGrad(ce_loss)`` when ``args.ce_loss`` is set.

    Raises:
        ValueError: if ``args.loss_type`` is not one of 0, 1, 2, 4, 5, 6, 7.
            Previously this surfaced as an UnboundLocalError on ``fc7``.
        AssertionError: on invalid margin parameters for the chosen loss.
    """
    net = args.network[0]
    if net == 'd':
        embedding = fdensenet.get_symbol(args.emb_size,
                                         args.num_layers,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif net == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(
                args.emb_size,
                version_input=args.version_input,
                version_output=args.version_output,
                version_multiplier=args.version_multiplier)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif net == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(
            args.emb_size,
            version_se=args.version_se,
            version_input=args.version_input,
            version_output=args.version_output,
            version_unit=args.version_unit)
    elif net == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif net == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size,
                                    args.num_layers,
                                    version_se=args.version_se,
                                    version_input=args.version_input,
                                    version_output=args.version_output,
                                    version_unit=args.version_unit)
    elif net == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif net == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif net == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(
            args.emb_size,
            bn_mom=args.bn_mom,
            version_output=args.version_output)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size,
                                       args.num_layers,
                                       version_se=args.version_se,
                                       version_input=args.version_input,
                                       version_output=args.version_output,
                                       version_unit=args.version_unit,
                                       version_act=args.version_act)

    gt_label = mx.symbol.Variable('softmax_label')
    num_classes = args.num_classes
    extra_loss = None  # auxiliary MakeLoss head (loss types 6 and 7 only)
    _weight = mx.symbol.Variable("fc7_weight",
                                 shape=(num_classes, args.emb_size),
                                 lr_mult=args.fc7_lr_mult,
                                 wd_mult=args.fc7_wd_mult)

    if args.loss_type == 0:  # plain softmax
        if args.fc7_no_bias:
            fc7 = mx.sym.FullyConnected(data=embedding,
                                        weight=_weight,
                                        no_bias=True,
                                        num_hidden=num_classes,
                                        name='fc7')
        else:
            _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
            fc7 = mx.sym.FullyConnected(data=embedding,
                                        weight=_weight,
                                        bias=_bias,
                                        num_hidden=num_classes,
                                        name='fc7')
    elif args.loss_type == 1:  # sphere
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding,
                              label=gt_label,
                              num_hidden=num_classes,
                              weight=_weight,
                              beta=args.beta,
                              margin=args.margin,
                              scale=args.scale,
                              beta_min=args.beta_min,
                              verbose=1000,
                              name='fc7')
    elif args.loss_type == 2:  # additive cosine margin
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = _scaled_cosine_fc7(embedding, _weight, s, num_classes)
        # Subtract s*m from the ground-truth logit only.
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=num_classes,
                                    on_value=s * m,
                                    off_value=0.0)
        fc7 = fc7 - gt_one_hot
    elif args.loss_type == 4:  # additive angular margin
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = _scaled_cosine_fc7(embedding, _weight, s, num_classes)
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        mm = math.sin(math.pi - m) * m
        threshold = math.cos(math.pi - m)
        # cond is non-zero exactly where the margined logit is used.
        if args.easy_margin:
            cond = mx.symbol.Activation(data=cos_t, act_type='relu')
        else:
            cond = mx.symbol.Activation(data=cos_t - threshold, act_type='relu')
        sin_t = mx.sym.sqrt(1.0 - cos_t * cos_t)
        # cos(t + m) = cos_t*cos_m - sin_t*sin_m, rescaled by s.
        new_zy = (cos_t * cos_m - sin_t * sin_m) * s
        if args.easy_margin:
            zy_keep = zy
        else:
            zy_keep = zy - s * mm
        new_zy = mx.sym.where(cond, new_zy, zy_keep)
        fc7 = _replace_target_logit(fc7, zy, new_zy, gt_label, num_classes)
    elif args.loss_type == 5:  # combined margin s*(cos(a*t + m) - b)
        s = args.margin_s
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = _scaled_cosine_fc7(embedding, _weight, s, num_classes)
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                # Pure cosine margin: subtract s*b from the ground-truth logit.
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=num_classes,
                                            on_value=s * args.margin_b,
                                            off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if args.margin_a != 1.0:
                    t = t * args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body * s
                fc7 = _replace_target_logit(fc7, zy, new_zy, gt_label,
                                            num_classes)
    elif args.loss_type == 6:  # angular margin + intra-class angle loss
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert args.margin_b > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = _scaled_cosine_fc7(embedding, _weight, s, num_classes)
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        t = mx.sym.arccos(cos_t)
        # Mean target angle; arccos lies in [0, pi], so t/pi lies in [0, 1].
        intra_loss = mx.sym.mean(t / np.pi)
        extra_loss = mx.sym.MakeLoss(intra_loss,
                                     name='intra_loss',
                                     grad_scale=args.margin_b)
        if m > 0.0:
            new_zy = mx.sym.cos(t + m) * s
            fc7 = _replace_target_logit(fc7, zy, new_zy, gt_label, num_classes)
    elif args.loss_type == 7:  # angular margin + inter-class weight loss
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert args.margin_b > 0.0
        assert args.margin_a > 0.0  # validated although this branch does not use it
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = _scaled_cosine_fc7(embedding, _weight, s, num_classes)
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        t = mx.sym.arccos(cos_t)
        # Cosine between each sample's target-class weight row and every
        # class weight row (its own class included), averaged.
        counter_weight = mx.sym.take(_weight, gt_label, axis=0)
        counter_cos = mx.sym.dot(counter_weight, _weight, transpose_b=True)
        inter_loss = mx.sym.mean(counter_cos)
        extra_loss = mx.sym.MakeLoss(inter_loss,
                                     name='inter_loss',
                                     grad_scale=args.margin_b)
        if m > 0.0:
            new_zy = mx.sym.cos(t + m) * s
            fc7 = _replace_target_logit(fc7, zy, new_zy, gt_label, num_classes)
    else:
        raise ValueError('unsupported loss_type: %s' % args.loss_type)

    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = mx.symbol.SoftmaxOutput(data=fc7,
                                      label=gt_label,
                                      name='softmax',
                                      normalization='valid')
    out_list.append(softmax)
    if extra_loss is not None:
        out_list.append(extra_loss)
    if args.ce_loss:
        # Monitoring-only cross entropy: -sum(one_hot(label) * log softmax(fc7))
        # divided by the batch size. Gradients are blocked.
        body = mx.symbol.SoftmaxActivation(data=fc7)
        body = mx.symbol.log(body)
        _label = mx.sym.one_hot(gt_label,
                                depth=num_classes,
                                on_value=-1.0,
                                off_value=0.0)
        body = body * _label
        ce_loss = mx.symbol.sum(body) / args.per_batch_size
        out_list.append(mx.symbol.BlockGrad(ce_loss))
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
Example #4
0
def get_symbol(args, arg_params, aux_params):
    """Build the training symbol: a backbone embedding plus a selectable loss head.

    The backbone is chosen by ``args.network[0]``:
    'd' densenet, 'm' mobilenet (v1 when ``args.num_layers == 1``, else v2),
    'i' inception-resnet-v2, 'x' xception, 'p' dpn, 'n' nasnet,
    's' spherenet, anything else resnet.

    The head is chosen by ``args.loss_type``:
      0   plain softmax (FullyConnected with bias ``fc7``).
      1   LSoftmax ("sphere") on L2-normalised class weights.
      2   cosine logits; ``s * m`` is subtracted from the target logit
          (or ``m * ||embedding||`` when ``args.margin_s <= 0``).
      3   scaled cosine logits; the margin code is currently commented out.
      4   angular margin: the target logit becomes ``s * cos(theta + m)``.
          With ``args.output_c2c`` the per-sample ``cos(m)^2`` is read from
          label column 1 instead of using a fixed ``m``.
      8   softmax plus a custom 'centerloss' op (marked TODO).
      9   "coco" loss: logits against per-identity L2-normalised centroids.
      10  marginal loss; 11 N-pair loss; 12 triplet loss.
          These three have no softmax head.
      other
          LSoftmax on ``5 * embedding`` and ``2 * normalised weights``.

    Args:
        args: parsed training options. Besides the fields above it reads
            emb_size, num_classes, per_batch_size, margin_*, beta*, scale,
            incay, easy_margin, output_c2c, per_identities, etc.
        arg_params: passed through unchanged.
        aux_params: passed through unchanged.

    Returns:
        ``(group, arg_params, aux_params)``. ``group`` contains, in order:
        the gradient-blocked embedding; the softmax output (loss_type < 10)
        or else the gradient-blocked label; ``extra_loss`` if one was built;
        and gradient-blocked mean-cosine diagnostics when
        ``args.margin_verbose > 0``.
    """
    # Diagnostic scalars (mean target cosine before/after the margin),
    # collected only when args.margin_verbose > 0.
    margin_symbols = []
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size,
                                         args.num_layers,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(
                args.emb_size,
                version_se=args.version_se,
                version_input=args.version_input,
                version_output=args.version_output,
                version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(
            args.emb_size,
            version_se=args.version_se,
            version_input=args.version_input,
            version_output=args.version_output,
            version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size,
                                    args.num_layers,
                                    version_se=args.version_se,
                                    version_input=args.version_input,
                                    version_output=args.version_output,
                                    version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size,
                                       args.num_layers,
                                       version_se=args.version_se,
                                       version_input=args.version_input,
                                       version_output=args.version_output,
                                       version_unit=args.version_unit)
    all_label = mx.symbol.Variable('softmax_label')
    if not args.output_c2c:
        gt_label = all_label
    else:
        # With output_c2c the label carries the class id in column 0 and a
        # per-sample c2c value in column 1; loss_type 4 consumes the latter.
        gt_label = mx.symbol.slice_axis(all_label, axis=1, begin=0, end=1)
        gt_label = mx.symbol.reshape(gt_label, (args.per_batch_size, ))
        c2c_label = mx.symbol.slice_axis(all_label, axis=1, begin=1, end=2)
        c2c_label = mx.symbol.reshape(c2c_label, (args.per_batch_size, ))
    assert args.loss_type >= 0
    extra_loss = None
    if args.loss_type == 0:  #softmax
        _weight = mx.symbol.Variable('fc7_weight')
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding,
                                    weight=_weight,
                                    bias=_bias,
                                    num_hidden=args.num_classes,
                                    name='fc7')
    elif args.loss_type == 1:  #sphere
        _weight = mx.symbol.Variable("fc7_weight",
                                     shape=(args.num_classes, args.emb_size),
                                     lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding,
                              label=gt_label,
                              num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta,
                              margin=args.margin,
                              scale=args.scale,
                              beta_min=args.beta_min,
                              verbose=1000,
                              name='fc7')
    elif args.loss_type == 8:  #centerloss, TODO
        _weight = mx.symbol.Variable('fc7_weight')
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding,
                                    weight=_weight,
                                    bias=_bias,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        print('center-loss', args.center_alpha, args.center_scale)
        # 'centerloss' must be a custom operator registered elsewhere
        # (mx.operator.register) — not visible in this function.
        extra_loss = mx.symbol.Custom(data=embedding,
                                      label=gt_label,
                                      name='center_loss',
                                      op_type='centerloss',
                                      num_class=args.num_classes,
                                      alpha=args.center_alpha,
                                      scale=args.center_scale,
                                      batchsize=args.per_batch_size)
    elif args.loss_type == 2:
        s = args.margin_s
        m = args.margin_m
        _weight = mx.symbol.Variable("fc7_weight",
                                     shape=(args.num_classes, args.emb_size),
                                     lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        if s > 0.0:
            # Logits are s * cos(theta) between normalised embedding and weights.
            nembedding = mx.symbol.L2Normalization(
                embedding, mode='instance', name='fc1n') * s
            fc7 = mx.sym.FullyConnected(data=nembedding,
                                        weight=_weight,
                                        no_bias=True,
                                        num_hidden=args.num_classes,
                                        name='fc7')
            if m > 0.0:
                if args.margin_verbose > 0:
                    zy = mx.sym.pick(fc7, gt_label, axis=1)
                    cos_t = zy / s
                    margin_symbols.append(mx.symbol.mean(cos_t))

                # Subtract s*m from the target-class logit only.
                s_m = s * m
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=s_m,
                                            off_value=0.0)
                fc7 = fc7 - gt_one_hot

                if args.margin_verbose > 0:
                    new_zy = mx.sym.pick(fc7, gt_label, axis=1)
                    new_cos_t = new_zy / s
                    margin_symbols.append(mx.symbol.mean(new_cos_t))
        else:
            fc7 = mx.sym.FullyConnected(data=embedding,
                                        weight=_weight,
                                        no_bias=True,
                                        num_hidden=args.num_classes,
                                        name='fc7')
            if m > 0.0:
                # Unnormalised embedding: the margin is m * ||embedding||,
                # subtracted from the target-class logit only.
                body = embedding * embedding
                body = mx.sym.sum_axis(body, axis=1, keepdims=True)
                body = mx.sym.sqrt(body)
                body = body * m
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=1.0,
                                            off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, body)
                fc7 = fc7 - body

    elif args.loss_type == 3:
        s = args.margin_s
        m = args.margin_m
        assert args.margin == 2 or args.margin == 4
        _weight = mx.symbol.Variable("fc7_weight",
                                     shape=(args.num_classes, args.emb_size),
                                     lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        if args.margin_verbose > 0:
            margin_symbols.append(mx.symbol.mean(cos_t))
        # NOTE: the multiplicative-margin logic below is disabled; as written
        # this branch trains on plain scaled cosine logits.
        #threshold = math.cos(args.margin_m)
        #cond_v = cos_t - threshold
        #cond = mx.symbol.Activation(data=cond_v, act_type='relu')
        #body = cos_t
        #for i in xrange(args.margin//2):
        #  body = body*body
        #  body = body*2-1
        #new_zy = body*s
        #zy_keep = zy
        #new_zy = mx.sym.where(cond, new_zy, zy_keep)
        #if args.margin_verbose>0:
        #  new_cos_t = new_zy/s
        #  margin_symbols.append(mx.symbol.mean(new_cos_t))
        #diff = new_zy - zy
        #diff = mx.sym.expand_dims(diff, 1)
        #gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
        #body = mx.sym.broadcast_mul(gt_one_hot, diff)
        #fc7 = fc7+body
    elif args.loss_type == 4:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.Variable("fc7_weight",
                                     shape=(args.num_classes, args.emb_size),
                                     lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        if args.output_c2c == 0:
            cos_m = math.cos(m)
            sin_m = math.sin(m)
            mm = math.sin(math.pi - m) * m
            #threshold = 0.0
            threshold = math.cos(math.pi - m)
            if args.margin_verbose > 0:
                margin_symbols.append(mx.symbol.mean(cos_t))
            # cond is non-zero where the margin is applied:
            #   easy_margin: only where cos_t > 0;
            #   otherwise:   only where cos_t > cos(pi - m), i.e. theta + m < pi.
            if args.easy_margin:
                cond = mx.symbol.Activation(data=cos_t, act_type='relu')
            else:
                cond_v = cos_t - threshold
                cond = mx.symbol.Activation(data=cond_v, act_type='relu')
            # cos(theta + m) = cos_t * cos_m - sin_t * sin_m
            body = cos_t * cos_t
            body = 1.0 - body
            sin_t = mx.sym.sqrt(body)
            new_zy = cos_t * cos_m
            b = sin_t * sin_m
            new_zy = new_zy - b
            new_zy = new_zy * s
            # Where the margin is not applied, keep zy (easy_margin) or use
            # the fallback zy - s * m * sin(pi - m).
            if args.easy_margin:
                zy_keep = zy
            else:
                zy_keep = zy - s * mm
            new_zy = mx.sym.where(cond, new_zy, zy_keep)
        else:
            #set c2c as cosm^2 in data.py
            cos_m = mx.sym.sqrt(c2c_label)
            sin_m = 1.0 - c2c_label
            sin_m = mx.sym.sqrt(sin_m)
            body = cos_t * cos_t
            body = 1.0 - body
            sin_t = mx.sym.sqrt(body)
            new_zy = cos_t * cos_m
            b = sin_t * sin_m
            new_zy = new_zy - b
            new_zy = new_zy * s

        if args.margin_verbose > 0:
            new_cos_t = new_zy / s
            margin_symbols.append(mx.symbol.mean(new_cos_t))
        # Replace only the target-class logit: fc7 += one_hot * (new_zy - zy).
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=1.0,
                                    off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    elif args.loss_type == 10:  #marginal loss
        nembedding = mx.symbol.L2Normalization(embedding,
                                               mode='instance',
                                               name='fc1n')
        params = [1.2, 0.3, 1.0]
        n1 = mx.sym.expand_dims(nembedding, axis=1)  #N,1,C
        n2 = mx.sym.expand_dims(nembedding, axis=0)  #1,N,C
        body = mx.sym.broadcast_sub(n1, n2)  #N,N,C
        body = body * body
        body = mx.sym.sum(body, axis=2)  # N,N
        #body = mx.sym.sqrt(body)
        body = body - params[0]
        # 'extra' is an additional data input; presumably a +/-1 same-identity
        # mask for each pair — TODO confirm against the data iterator.
        mask = mx.sym.Variable('extra')
        body = body * mask
        body = body + params[1]
        #body = mx.sym.maximum(body, 0.0)
        body = mx.symbol.Activation(data=body, act_type='relu')
        body = mx.sym.sum(body)
        # Average over the N*N - N off-diagonal pairs.
        body = body / (args.per_batch_size * args.per_batch_size -
                       args.per_batch_size)
        extra_loss = mx.symbol.MakeLoss(body, grad_scale=params[2])
    elif args.loss_type == 11:  #npair loss
        params = [0.9, 0.2]
        nembedding = mx.symbol.L2Normalization(embedding,
                                               mode='instance',
                                               name='fc1n')
        nembedding = mx.sym.transpose(nembedding)
        nembedding = mx.symbol.reshape(
            nembedding,
            (args.emb_size, args.per_identities, args.images_per_identity))
        nembedding = mx.sym.transpose(nembedding, axes=(2, 1, 0))  #2*id*512
        #nembedding = mx.symbol.reshape(nembedding, (args.emb_size, args.images_per_identity, args.per_identities))
        #nembedding = mx.sym.transpose(nembedding, axes=(1,2,0)) #2*id*512
        n1 = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=1)
        n2 = mx.symbol.slice_axis(nembedding, axis=0, begin=1, end=2)
        #n1 = []
        #n2 = []
        #for i in xrange(args.per_identities):
        #  _n1 = mx.symbol.slice_axis(nembedding, axis=0, begin=2*i, end=2*i+1)
        #  _n2 = mx.symbol.slice_axis(nembedding, axis=0, begin=2*i+1, end=2*i+2)
        #  n1.append(_n1)
        #  n2.append(_n2)
        #n1 = mx.sym.concat(*n1, dim=0)
        #n2 = mx.sym.concat(*n2, dim=0)
        #rembeddings = mx.symbol.reshape(nembedding, (args.images_per_identity, args.per_identities, 512))
        #n1 = mx.symbol.slice_axis(rembeddings, axis=0, begin=0, end=1)
        #n2 = mx.symbol.slice_axis(rembeddings, axis=0, begin=1, end=2)
        n1 = mx.symbol.reshape(n1, (args.per_identities, args.emb_size))
        n2 = mx.symbol.reshape(n2, (args.per_identities, args.emb_size))
        cosine_matrix = mx.symbol.dot(lhs=n1, rhs=n2,
                                      transpose_b=True)  #id*id, id=N of N-pair
        data_extra = mx.sym.Variable('extra')
        data_extra = mx.sym.slice_axis(data_extra,
                                       axis=0,
                                       begin=0,
                                       end=args.per_identities)
        mask = cosine_matrix * data_extra
        #body = mx.sym.mean(mask)
        fii = mx.sym.sum_axis(mask, axis=1)
        fij_fii = mx.sym.broadcast_sub(cosine_matrix, fii)
        fij_fii = mx.sym.exp(fij_fii)
        row = mx.sym.sum_axis(fij_fii, axis=1)
        row = mx.sym.log(row)
        body = mx.sym.mean(row)
        extra_loss = mx.sym.MakeLoss(body)
    elif args.loss_type == 12:  #triplet loss
        nembedding = mx.symbol.L2Normalization(embedding,
                                               mode='instance',
                                               name='fc1n')
        # Batch layout: anchors, positives, negatives in consecutive thirds.
        anchor = mx.symbol.slice_axis(nembedding,
                                      axis=0,
                                      begin=0,
                                      end=args.per_batch_size // 3)
        positive = mx.symbol.slice_axis(nembedding,
                                        axis=0,
                                        begin=args.per_batch_size // 3,
                                        end=2 * args.per_batch_size // 3)
        negative = mx.symbol.slice_axis(nembedding,
                                        axis=0,
                                        begin=2 * args.per_batch_size // 3,
                                        end=args.per_batch_size)
        ap = anchor - positive
        an = anchor - negative
        ap = ap * ap
        an = an * an
        ap = mx.symbol.sum(ap, axis=1, keepdims=1)  #(T,1)
        an = mx.symbol.sum(an, axis=1, keepdims=1)  #(T,1)
        triplet_loss = mx.symbol.Activation(data=(ap - an +
                                                  args.triplet_alpha),
                                            act_type='relu')
        triplet_loss = mx.symbol.mean(triplet_loss)
        #triplet_loss = mx.symbol.sum(triplet_loss)/(args.per_batch_size//3)
        extra_loss = mx.symbol.MakeLoss(triplet_loss)
    elif args.loss_type == 9:  #coco loss
        centroids = []
        # range (not xrange) so this path also runs under Python 3.
        for i in range(args.per_identities):
            xs = mx.symbol.slice_axis(embedding,
                                      axis=0,
                                      begin=i * args.images_per_identity,
                                      end=(i + 1) * args.images_per_identity)
            mean = mx.symbol.mean(xs, axis=0, keepdims=True)
            mean = mx.symbol.L2Normalization(mean, mode='instance')
            centroids.append(mean)
        centroids = mx.symbol.concat(*centroids, dim=0)
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * args.coco_scale
        fc7 = mx.symbol.dot(nembedding, centroids,
                            transpose_b=True)  #(batchsize, per_identities)
        #extra_loss = mx.symbol.softmax_cross_entropy(fc7, gt_label, name='softmax_ce')/args.per_batch_size
        #extra_loss = mx.symbol.BlockGrad(extra_loss)
    else:
        # NOTE(review): the scaled embedding below is also what is emitted as
        # the first output. For loss_type >= 13 the fc7 built here is not
        # attached to any output, because the softmax head is only added for
        # loss_type < 10.
        #embedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*float(args.loss_type)
        embedding = embedding * 5
        _weight = mx.symbol.Variable("fc7_weight",
                                     shape=(args.num_classes, args.emb_size),
                                     lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance') * 2
        fc7 = mx.sym.LSoftmax(data=embedding,
                              label=gt_label,
                              num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta,
                              margin=args.margin,
                              scale=args.scale,
                              beta_min=args.beta_min,
                              verbose=100,
                              name='fc7')

        #fc7 = mx.sym.Custom(data=embedding, label=gt_label, weight=_weight, num_hidden=args.num_classes,
        #                       beta=args.beta, margin=args.margin, scale=args.scale,
        #                       op_type='ASoftmax', name='fc7')
    if args.loss_type <= 1 and args.incay > 0.0:
        # Feature incay: mean of 1/||x||^2 over correctly classified samples
        # (sel is 1 where argmax(fc7) == label). The small constant avoids
        # division by zero.
        params = [1.e-10]
        sel = mx.symbol.argmax(data=fc7, axis=1)
        sel = (sel == gt_label)
        norm = embedding * embedding
        norm = mx.symbol.sum(norm, axis=1)
        norm = norm + params[0]
        feature_incay = sel / norm
        feature_incay = mx.symbol.mean(feature_incay) * args.incay
        extra_loss = mx.symbol.MakeLoss(feature_incay)

    out_list = [mx.symbol.BlockGrad(embedding)]
    if args.loss_type < 10:
        softmax = mx.symbol.SoftmaxOutput(data=fc7,
                                          label=gt_label,
                                          name='softmax',
                                          normalization='valid')
        out_list.append(softmax)
    else:
        # No softmax head: expose the label so the output layout stays aligned.
        out_list.append(mx.sym.BlockGrad(gt_label))
    if extra_loss is not None:
        out_list.append(extra_loss)
    for _sym in margin_symbols:
        _sym = mx.sym.BlockGrad(_sym)
        out_list.append(_sym)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
Example #5
0
def get_symbol(args, arg_params, aux_params, sym_embedding=None):
    """Build a triplet-loss training symbol on top of a face-embedding backbone.

    The batch is split along axis 0 into three consecutive slices:
    anchors ``[0, B//3)``, positives ``[B//3, 2B//3)`` and negatives
    ``[2B//3, B)``, where ``B = args.per_batch_size``. The loss is the mean
    over rows of ``relu(||a - p||^2 - ||a - n||^2 + args.triplet_alpha)``,
    computed on L2-normalised embeddings.

    Args:
        args: parsed training options. Reads network, num_layers, emb_size,
            the version_* flags, per_batch_size and triplet_alpha.
        arg_params: passed through unchanged.
        aux_params: passed through unchanged.
        sym_embedding: optional pre-built embedding symbol. When given, no
            backbone is constructed from ``args.network``.

    Returns:
        ``(group, arg_params, aux_params)``. ``group`` holds, in order: the
        gradient-blocked raw embedding, the gradient-blocked
        ``softmax_label`` variable, and the triplet ``MakeLoss`` symbol.
    """
    if sym_embedding is not None:
        embedding = sym_embedding
    else:
        def _version_kwargs():
            # Backbone options shared by most builders. They are read lazily,
            # so backbones that ignore them never touch these fields.
            return dict(version_se=args.version_se,
                        version_input=args.version_input,
                        version_output=args.version_output,
                        version_unit=args.version_unit)

        family = args.network[0]
        if family == 'd':
            embedding = fdensenet.get_symbol(args.emb_size, args.num_layers,
                                             **_version_kwargs())
        elif family == 'm':
            print('init mobilenet', args.num_layers)
            if args.num_layers == 1:
                embedding = fmobilenet.get_symbol(args.emb_size,
                                                  **_version_kwargs())
            else:
                embedding = fmobilenetv2.get_symbol(args.emb_size)
        elif family == 'i':
            print('init inception-resnet-v2', args.num_layers)
            embedding = finception_resnet_v2.get_symbol(args.emb_size,
                                                        **_version_kwargs())
        elif family == 'x':
            print('init xception', args.num_layers)
            embedding = fxception.get_symbol(args.emb_size,
                                             **_version_kwargs())
        elif family == 'p':
            print('init dpn', args.num_layers)
            embedding = fdpn.get_symbol(args.emb_size, args.num_layers,
                                        **_version_kwargs())
        elif family == 'n':
            print('init nasnet', args.num_layers)
            embedding = fnasnet.get_symbol(args.emb_size)
        elif family == 's':
            print('init spherenet', args.num_layers)
            embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
        else:
            print('init resnet', args.num_layers)
            embedding = fresnet.get_symbol(args.emb_size, args.num_layers,
                                           version_act=args.version_act,
                                           **_version_kwargs())

    gt_label = mx.symbol.Variable('softmax_label')
    unit_emb = mx.symbol.L2Normalization(embedding,
                                         mode='instance',
                                         name='fc1n')

    # Anchor / positive / negative blocks, created in that order.
    bs = args.per_batch_size
    cuts = [0, bs // 3, 2 * bs // 3, bs]
    anchor, positive, negative = [
        mx.symbol.slice_axis(unit_emb, axis=0, begin=cuts[k], end=cuts[k + 1])
        for k in range(3)
    ]

    # Squared Euclidean distances anchor->positive and anchor->negative,
    # each of shape (T, 1).
    gaps = [anchor - positive, anchor - negative]
    gaps = [g * g for g in gaps]
    d_pos, d_neg = [mx.symbol.sum(g, axis=1, keepdims=1) for g in gaps]

    hinge = mx.symbol.Activation(data=(d_pos - d_neg + args.triplet_alpha),
                                 act_type='relu')
    triplet = mx.symbol.MakeLoss(mx.symbol.mean(hinge))

    outputs = [mx.symbol.BlockGrad(embedding),
               mx.sym.BlockGrad(gt_label),
               triplet]
    return (mx.symbol.Group(outputs), arg_params, aux_params)
Example #6
0
def get_symbol(args, arg_params, aux_params):
    """Build a multi-task face symbol with identity, gender and age heads.

    A backbone selected by ``args.network[0]`` produces the embedding:
    'd' densenet, 'm' mobilenet (v1 when ``num_layers == 1``, else v2),
    'i' inception-resnet-v2, 'x' xception, 'p' dpn, 'n' nasnet,
    's' spherenet, otherwise resnet. Its L2-normalised copy, scaled by
    ``args.margin_s``, is passed to ``get_softmax`` (defined elsewhere) once
    per enabled task. The tasks are switched on by the module-level flags
    below:

    * ``USE_FR``: identity head 'fc7' on label column 0 (grad_scale 1.0).
    * ``USE_GENDER``: 2-class head 'fc8' on label column 1
      (grad_scale 0.2, margin_a 0.0).
    * ``USE_AGE``: ``AGE`` 2-class heads 'fc9_<i>' on label columns
      2 .. 1 + AGE (grad_scale 0.01, margin_a 0.0).

    Args:
        args: parsed training options (network, num_layers, emb_size,
            version_* flags, margin_s, per_batch_size, num_classes, ...).
            It is not modified; the heads get a deep copy.
        arg_params: passed through unchanged.
        aux_params: passed through unchanged.

    Returns:
        ``(group, arg_params, aux_params)``. ``group`` holds the
        gradient-blocked embedding followed by each enabled head's output,
        in the order listed above.
    """
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size, args.num_layers,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(args.emb_size,
                                              version_se=args.version_se,
                                              version_input=args.version_input,
                                              version_output=args.version_output,
                                              version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(args.emb_size,
                                                    version_se=args.version_se,
                                                    version_input=args.version_input,
                                                    version_output=args.version_output,
                                                    version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size, args.num_layers,
                                    version_se=args.version_se,
                                    version_input=args.version_input,
                                    version_output=args.version_output,
                                    version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size, args.num_layers,
                                       version_se=args.version_se,
                                       version_input=args.version_input,
                                       version_output=args.version_output,
                                       version_unit=args.version_unit,
                                       version_act=args.version_act)
    all_label = mx.symbol.Variable('softmax_label')
    s = args.margin_s
    assert s > 0.0
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance',
                                           name='fc1n') * s
    out_list = [mx.symbol.BlockGrad(embedding)]

    # One deep copy of args is shared by every head and mutated as we go.
    # Each head sets grad_scale (and margin_a / num_classes for the attribute
    # heads) before calling get_softmax; no other field is changed here.
    _args = copy.deepcopy(args)

    if USE_FR:
        _args.grad_scale = 1.0
        fr_label = mx.symbol.slice_axis(all_label, axis=1, begin=0, end=1)
        fr_label = mx.symbol.reshape(fr_label, (args.per_batch_size,))
        fr_softmax = get_softmax(_args, embedding, nembedding, fr_label, 'fc7')
        out_list.append(fr_softmax)

    if USE_GENDER:
        _args.grad_scale = 0.2
        _args.margin_a = 0.0
        _args.num_classes = 2
        gender_label = mx.symbol.slice_axis(all_label, axis=1, begin=1, end=2)
        gender_label = mx.symbol.reshape(gender_label, (args.per_batch_size,))
        gender_softmax = get_softmax(_args, embedding, nembedding,
                                     gender_label, 'fc8')
        out_list.append(gender_softmax)

    if USE_AGE:
        _args.grad_scale = 0.01
        _args.margin_a = 0.0
        _args.num_classes = 2
        # range (not xrange) so this also runs under Python 3.
        for i in range(AGE):
            age_label = mx.symbol.slice_axis(all_label, axis=1,
                                             begin=2 + i, end=3 + i)
            age_label = mx.symbol.reshape(age_label, (args.per_batch_size,))
            age_softmax = get_softmax(_args, embedding, nembedding,
                                      age_label, 'fc9_%d' % (i))
            out_list.append(age_softmax)

    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
Example #7
0
def get_symbol(args, arg_params, aux_params):
    """Build a multi-task face symbol with identity, gender and age heads.

    A backbone selected by ``args.network[0]`` produces the embedding:
    'd' densenet, 'm' mobilenet (v1 when ``num_layers == 1``, else v2),
    'i' inception-resnet-v2, 'x' xception, 'p' dpn, 'n' nasnet,
    's' spherenet, 'y' mobilefacenet, otherwise resnet. Its L2-normalised
    copy, scaled by ``args.margin_s``, is passed to ``get_softmax`` (defined
    elsewhere) once per enabled task. The tasks are switched on by the
    module-level flags below:

    * ``USE_FR``: identity head 'fc7' on label column 0 (grad_scale 1.0).
    * ``USE_GENDER``: 2-class head 'fc8' on label column 1
      (grad_scale 0.2, margin_a 0.0).
    * ``USE_AGE``: ``AGE`` 2-class heads 'fc9_<i>' on label columns
      2 .. 1 + AGE (grad_scale 0.01, margin_a 0.0).

    Args:
        args: parsed training options (network, num_layers, emb_size,
            version_* flags, margin_s, per_batch_size, num_classes, ...).
            It is not modified; the heads get a deep copy.
        arg_params: passed through unchanged.
        aux_params: passed through unchanged.

    Returns:
        ``(group, arg_params, aux_params)``. ``group`` holds the
        gradient-blocked embedding followed by each enabled head's output,
        in the order listed above.
    """
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size, args.num_layers,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(args.emb_size,
                                              version_se=args.version_se,
                                              version_input=args.version_input,
                                              version_output=args.version_output,
                                              version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(args.emb_size,
                                                    version_se=args.version_se,
                                                    version_input=args.version_input,
                                                    version_output=args.version_output,
                                                    version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size, args.num_layers,
                                    version_se=args.version_se,
                                    version_input=args.version_input,
                                    version_output=args.version_output,
                                    version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif args.network[0] == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(args.emb_size)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size, args.num_layers,
                                       version_se=args.version_se,
                                       version_input=args.version_input,
                                       version_output=args.version_output,
                                       version_unit=args.version_unit,
                                       version_act=args.version_act)
    all_label = mx.symbol.Variable('softmax_label')
    s = args.margin_s
    assert s > 0.0
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance',
                                           name='fc1n') * s
    out_list = [mx.symbol.BlockGrad(embedding)]

    # One deep copy of args is shared by every head and mutated as we go.
    # Each head sets grad_scale (and margin_a / num_classes for the attribute
    # heads) before calling get_softmax; no other field is changed here.
    _args = copy.deepcopy(args)

    if USE_FR:
        _args.grad_scale = 1.0
        fr_label = mx.symbol.slice_axis(all_label, axis=1, begin=0, end=1)
        fr_label = mx.symbol.reshape(fr_label, (args.per_batch_size,))
        fr_softmax = get_softmax(_args, embedding, nembedding, fr_label, 'fc7')
        out_list.append(fr_softmax)

    if USE_GENDER:
        _args.grad_scale = 0.2
        _args.margin_a = 0.0
        _args.num_classes = 2
        gender_label = mx.symbol.slice_axis(all_label, axis=1, begin=1, end=2)
        gender_label = mx.symbol.reshape(gender_label, (args.per_batch_size,))
        gender_softmax = get_softmax(_args, embedding, nembedding,
                                     gender_label, 'fc8')
        out_list.append(gender_softmax)

    if USE_AGE:
        _args.grad_scale = 0.01
        _args.margin_a = 0.0
        _args.num_classes = 2
        # range (not xrange) so this also runs under Python 3.
        for i in range(AGE):
            age_label = mx.symbol.slice_axis(all_label, axis=1,
                                             begin=2 + i, end=3 + i)
            age_label = mx.symbol.reshape(age_label, (args.per_batch_size,))
            age_softmax = get_softmax(_args, embedding, nembedding,
                                      age_label, 'fc9_%d' % (i))
            out_list.append(age_softmax)

    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
Example #8
0
def get_symbol(args, arg_params, aux_params):
    """Build the training symbol: backbone embedding plus a margin-softmax head.

    The backbone is selected by the first character of ``args.network``:
    'd' densenet, 'm' mobilenet (``num_layers == 1``) or mobilenetv2,
    'i' inception-resnet-v2, 'x' xception, 'p' dpn, 'n' nasnet,
    's' spherenet, 'y' mobilefacenet, anything else resnet.

    The classifier ``fc7`` is built according to ``args.loss_type``:

    * 0 -- plain softmax (FullyConnected with bias).
    * 1 -- LSoftmax over an L2-normalized weight.
    * 2 -- additive cosine margin: target logit becomes ``s*cos(t) - s*m``.
    * 4 -- combined margin: target logit becomes ``s*(cos(a*t + m) - b)``.
    * 5 -- like 6, with dropout on the scaled embedding and an
      easy-margin / out-of-range fallback controlled by ``args.easy_margin``.
    * 6 -- combined margin, only applied when any of a, m, b is non-default.

    Here ``t`` is the angle between the normalized embedding and the
    normalized target-class weight, ``s = args.margin_s``,
    ``a = args.margin_a``, ``m = args.margin_m`` and ``b = args.margin_b``.

    Args:
        args: training options (network, num_layers, emb_size, num_classes,
            loss_type, margin_* and backbone version flags).
        arg_params: returned unchanged.
        aux_params: returned unchanged.

    Returns:
        ``(out, arg_params, aux_params)`` where ``out`` is a Group of
        ``BlockGrad(embedding)`` and the SoftmaxOutput named ``'softmax'``.

    Raises:
        ValueError: if ``args.loss_type`` is not one of 0, 1, 2, 4, 5, 6.
    """
    net = args.network[0]
    if net == 'd':
        embedding = fdensenet.get_symbol(args.emb_size, args.num_layers,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif net == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(args.emb_size,
                version_se=args.version_se, version_input=args.version_input,
                version_output=args.version_output, version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif net == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(args.emb_size,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif net == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif net == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size, args.num_layers,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif net == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif net == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif net == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(args.emb_size, bn_mom=args.bn_mom, wd_mult=args.fc7_wd_mult)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size, args.num_layers,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit,
            version_act=args.version_act)

    gt_label = mx.symbol.Variable('softmax_label')
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0, wd_mult=args.fc7_wd_mult)
    if args.loss_type == 0:  # softmax
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7')
    elif args.loss_type == 1:  # sphere
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta, margin=args.margin, scale=args.scale,
                              beta_min=args.beta_min, verbose=1000, name='fc7')
    elif args.loss_type == 2:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        # Subtract s*m from the target-class logit only.
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s*m, off_value=0.0)
        fc7 = fc7 - gt_one_hot
    elif args.loss_type == 4:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi/2)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        t = mx.sym.arccos(zy/s)
        if args.margin_a != 1.0:
            t = t*args.margin_a
        if args.margin_m > 0.0:
            t = t + args.margin_m
        body = mx.sym.cos(t)
        if args.margin_b > 0.0:
            body = body - args.margin_b
        new_zy = body*s
        # Replace the target logit zy by new_zy, leaving other classes untouched.
        diff = mx.sym.expand_dims(new_zy - zy, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        fc7 = fc7 + mx.sym.broadcast_mul(gt_one_hot, diff)
    elif args.loss_type == 5:
        s = args.margin_s
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
        # Dropout rescales kept activations by 1/(1-p), so during training the
        # embedding norm is no longer exactly s and zy/s may leave [-1, 1].
        nembedding = mx.symbol.Dropout(data=nembedding, p=0.4)
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                # Pure additive margin: subtract s*b from the target logit.
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s*args.margin_b, off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy/s
                # Clamp so arccos stays finite despite the dropout rescaling above.
                t = mx.sym.arccos(mx.sym.clip(cos_t, -1.0, 1.0))
                if args.margin_a != 1.0:
                    t = t*args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                if args.easy_margin:
                    # Apply the margin only where the original cosine is
                    # positive; elsewhere keep the unmodified cosine.
                    cond = mx.symbol.Activation(data=cos_t, act_type='relu')
                    zy_keep = cos_t
                else:
                    # Apply the margin while the shifted angle is below pi.
                    # Past pi, sin(t) - 1 continues the curve: it equals -1 at
                    # t == pi and keeps decreasing just beyond it.
                    cond = mx.symbol.Activation(data=math.pi - t, act_type='relu')
                    print("not easy margin: cond")
                    print("not easy margin: sine")
                    zy_keep = mx.sym.sin(t) - 1
                body = mx.sym.where(cond, mx.sym.cos(t), zy_keep)
                # Same sign convention as the shortcut above and loss types 4/6.
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body*s
                diff = mx.sym.expand_dims(new_zy - zy, 1)
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
                fc7 = fc7 + mx.sym.broadcast_mul(gt_one_hot, diff)
    elif args.loss_type == 6:
        s = args.margin_s
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                # Pure additive margin: subtract s*b from the target logit.
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s*args.margin_b, off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                t = mx.sym.arccos(zy/s)
                if args.margin_a != 1.0:
                    t = t*args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body*s
                diff = mx.sym.expand_dims(new_zy - zy, 1)
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
                fc7 = fc7 + mx.sym.broadcast_mul(gt_one_hot, diff)
    else:
        # Previously fell through and failed later with an UnboundLocalError on fc7.
        raise ValueError('unsupported loss_type %r' % (args.loss_type,))

    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax', normalization='valid')
    out_list.append(softmax)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
Example #9
0
def get_symbol(args, arg_params, aux_params):
  """Build the training symbol: backbone embedding plus a margin-softmax head.

  The backbone is selected by the first character of ``args.network``:
  'd' densenet, 'm' mobilenet (``num_layers == 1``) or mobilenetv2,
  'i' inception-resnet-v2, 'x' xception, 'p' dpn, 'n' nasnet,
  's' spherenet, 'y' mobilefacenet, anything else resnet.

  The classifier ``fc7`` is built according to ``args.loss_type``:

  * 0 -- plain softmax (FullyConnected with bias).
  * 1 -- LSoftmax over an L2-normalized weight.
  * 2 -- additive cosine margin: target logit becomes ``s*cos(t) - s*m``.
  * 4 -- additive angular margin: target logit becomes ``s*cos(t + m)``,
    with a fallback when the margin would be applied outside its valid range.
  * 5 -- combined margin ``s*(cos(a*t + m) - b)``, applied only when any of
    a, m, b is non-default.

  Here ``t`` is the angle between the normalized embedding and the
  normalized target-class weight, ``s = args.margin_s``,
  ``a = args.margin_a``, ``m = args.margin_m`` and ``b = args.margin_b``.

  Args:
    args: training options (network, num_layers, emb_size, num_classes,
      loss_type, margin_*, easy_margin and backbone version flags).
    arg_params: returned unchanged.
    aux_params: returned unchanged.

  Returns:
    ``(out, arg_params, aux_params)`` where ``out`` is a Group of
    ``BlockGrad(embedding)`` and the SoftmaxOutput named ``'softmax'``.

  Raises:
    ValueError: if ``args.loss_type`` is not one of 0, 1, 2, 4, 5.
  """
  net = args.network[0]
  if net == 'd':
    embedding = fdensenet.get_symbol(args.emb_size, args.num_layers,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif net == 'm':
    print('init mobilenet', args.num_layers)
    if args.num_layers == 1:
      embedding = fmobilenet.get_symbol(args.emb_size,
          version_se=args.version_se, version_input=args.version_input,
          version_output=args.version_output, version_unit=args.version_unit)
    else:
      embedding = fmobilenetv2.get_symbol(args.emb_size)
  elif net == 'i':
    print('init inception-resnet-v2', args.num_layers)
    embedding = finception_resnet_v2.get_symbol(args.emb_size,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif net == 'x':
    print('init xception', args.num_layers)
    embedding = fxception.get_symbol(args.emb_size,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif net == 'p':
    print('init dpn', args.num_layers)
    embedding = fdpn.get_symbol(args.emb_size, args.num_layers,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif net == 'n':
    print('init nasnet', args.num_layers)
    embedding = fnasnet.get_symbol(args.emb_size)
  elif net == 's':
    print('init spherenet', args.num_layers)
    embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
  elif net == 'y':
    print('init mobilefacenet', args.num_layers)
    embedding = fmobilefacenet.get_symbol(args.emb_size, bn_mom=args.bn_mom, wd_mult=args.fc7_wd_mult)
  else:
    print('init resnet', args.num_layers)
    embedding = fresnet.get_symbol(args.emb_size, args.num_layers,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit,
        version_act=args.version_act)

  gt_label = mx.symbol.Variable('softmax_label')
  _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0, wd_mult=args.fc7_wd_mult)
  if args.loss_type == 0:  # softmax
    _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
    fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7')
  elif args.loss_type == 1:  # sphere
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,
                          weight=_weight,
                          beta=args.beta, margin=args.margin, scale=args.scale,
                          beta_min=args.beta_min, verbose=1000, name='fc7')
  elif args.loss_type == 2:
    s = args.margin_s
    m = args.margin_m
    assert s > 0.0
    assert m > 0.0
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
    fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
    # Subtract s*m from the target-class logit only.
    gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s*m, off_value=0.0)
    fc7 = fc7 - gt_one_hot
  elif args.loss_type == 4:
    s = args.margin_s
    m = args.margin_m
    assert s > 0.0
    assert m >= 0.0
    assert m < (math.pi/2)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
    fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
    zy = mx.sym.pick(fc7, gt_label, axis=1)
    cos_t = zy/s
    cos_m = math.cos(m)
    sin_m = math.sin(m)
    # Penalty used instead of the angular margin outside the valid region.
    mm = math.sin(math.pi-m)*m
    threshold = math.cos(math.pi-m)
    if args.easy_margin:
      # Margin only where cos_t > 0; other samples keep their original logit.
      cond = mx.symbol.Activation(data=cos_t, act_type='relu')
    else:
      # Margin only where cos_t > cos(pi - m), i.e. where t + m stays below pi.
      cond = mx.symbol.Activation(data=cos_t - threshold, act_type='relu')
    # cos(t + m) = cos_t*cos_m - sin_t*sin_m, with sin_t = sqrt(1 - cos_t^2).
    sin_t = mx.sym.sqrt(1.0 - cos_t*cos_t)
    new_zy = (cos_t*cos_m - sin_t*sin_m)*s
    if args.easy_margin:
      zy_keep = zy
    else:
      zy_keep = zy - s*mm
    new_zy = mx.sym.where(cond, new_zy, zy_keep)
    # Replace the target logit zy by new_zy, leaving other classes untouched.
    diff = mx.sym.expand_dims(new_zy - zy, 1)
    gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
    fc7 = fc7 + mx.sym.broadcast_mul(gt_one_hot, diff)
  elif args.loss_type == 5:
    s = args.margin_s
    assert s > 0.0
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
    fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
    if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
      if args.margin_a == 1.0 and args.margin_m == 0.0:
        # Pure additive margin: subtract s*b from the target logit.
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s*args.margin_b, off_value=0.0)
        fc7 = fc7 - gt_one_hot
      else:
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        t = mx.sym.arccos(zy/s)
        if args.margin_a != 1.0:
          t = t*args.margin_a
        if args.margin_m > 0.0:
          t = t + args.margin_m
        body = mx.sym.cos(t)
        if args.margin_b > 0.0:
          body = body - args.margin_b
        new_zy = body*s
        diff = mx.sym.expand_dims(new_zy - zy, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        fc7 = fc7 + mx.sym.broadcast_mul(gt_one_hot, diff)
  else:
    # Previously fell through and failed later with an UnboundLocalError on fc7.
    raise ValueError('unsupported loss_type %r' % (args.loss_type,))

  out_list = [mx.symbol.BlockGrad(embedding)]
  softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax', normalization='valid')
  out_list.append(softmax)
  out = mx.symbol.Group(out_list)
  return (out, arg_params, aux_params)
Example #10
0
def get_symbol(args, arg_params, aux_params):
  data_shape = (args.image_channel,args.image_h,args.image_w)
  image_shape = ",".join([str(x) for x in data_shape])
  margin_symbols = []
  if args.network[0]=='d':
    embedding = fdensenet.get_symbol(args.emb_size, args.num_layers,
        version_se=args.version_se, version_input=args.version_input, 
        version_output=args.version_output, version_unit=args.version_unit)
  elif args.network[0]=='m':
    print('init mobilenet', args.num_layers)
    if args.num_layers==1:
      embedding = fmobilenet.get_symbol(args.emb_size, 
          version_se=args.version_se, version_input=args.version_input, 
          version_output=args.version_output, version_unit=args.version_unit)
    else:
      embedding = fmobilenetv2.get_symbol(args.emb_size)
  elif args.network[0]=='i':
    print('init inception-resnet-v2', args.num_layers)
    embedding = finception_resnet_v2.get_symbol(args.emb_size,
        version_se=args.version_se, version_input=args.version_input, 
        version_output=args.version_output, version_unit=args.version_unit)
  elif args.network[0]=='x':
    print('init xception', args.num_layers)
    embedding = fxception.get_symbol(args.emb_size,
        version_se=args.version_se, version_input=args.version_input, 
        version_output=args.version_output, version_unit=args.version_unit)
  elif args.network[0]=='p':
    print('init dpn', args.num_layers)
    embedding = fdpn.get_symbol(args.emb_size, args.num_layers,
        version_se=args.version_se, version_input=args.version_input, 
        version_output=args.version_output, version_unit=args.version_unit)
  elif args.network[0]=='n':
    print('init nasnet', args.num_layers)
    embedding = fnasnet.get_symbol(args.emb_size)
  elif args.network[0]=='s':
    print('init spherenet', args.num_layers)
    embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
  else:
    print('init resnet', args.num_layers)
    embedding = fresnet.get_symbol(args.emb_size, args.num_layers, 
        version_se=args.version_se, version_input=args.version_input, 
        version_output=args.version_output, version_unit=args.version_unit,
        version_act=args.version_act)
  all_label = mx.symbol.Variable('softmax_label')
  if not args.output_c2c:
    gt_label = all_label
  else:
    gt_label = mx.symbol.slice_axis(all_label, axis=1, begin=0, end=1)
    gt_label = mx.symbol.reshape(gt_label, (args.per_batch_size,))
    c2c_label = mx.symbol.slice_axis(all_label, axis=1, begin=1, end=2)
    c2c_label = mx.symbol.reshape(c2c_label, (args.per_batch_size,))
  assert args.loss_type>=0
  extra_loss = None
  if args.loss_type==0: #softmax
    _weight = mx.symbol.Variable('fc7_weight')
    _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
    fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7')
  elif args.loss_type==1: #sphere
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,
                          weight = _weight,
                          beta=args.beta, margin=args.margin, scale=args.scale,
                          beta_min=args.beta_min, verbose=1000, name='fc7')
  elif args.loss_type==8: #centerloss, TODO
    _weight = mx.symbol.Variable('fc7_weight')
    _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
    fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, bias = _bias, num_hidden=args.num_classes, name='fc7')
    print('center-loss', args.center_alpha, args.center_scale)
    extra_loss = mx.symbol.Custom(data=embedding, label=gt_label, name='center_loss', op_type='centerloss',\
          num_class=args.num_classes, alpha=args.center_alpha, scale=args.center_scale, batchsize=args.per_batch_size)
  elif args.loss_type==2:
    s = args.margin_s
    m = args.margin_m
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    if s>0.0:
      nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
      fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
      if m>0.0:
        if args.margin_verbose>0:
          zy = mx.sym.pick(fc7, gt_label, axis=1)
          cos_t = zy/s
          margin_symbols.append(mx.symbol.mean(cos_t))

        s_m = s*m
        gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0)
        fc7 = fc7-gt_one_hot

        if args.margin_verbose>0:
          new_zy = mx.sym.pick(fc7, gt_label, axis=1)
          new_cos_t = new_zy/s
          margin_symbols.append(mx.symbol.mean(new_cos_t))
    else:
      fc7 = mx.sym.FullyConnected(data=embedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
      if m>0.0:
        body = embedding*embedding
        body = mx.sym.sum_axis(body, axis=1, keepdims=True)
        body = mx.sym.sqrt(body)
        body = body*m
        gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, body)
        fc7 = fc7-body

  elif args.loss_type==3:
    s = args.margin_s
    m = args.margin_m
    assert args.margin==2 or args.margin==4
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
    fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
    zy = mx.sym.pick(fc7, gt_label, axis=1)
    cos_t = zy/s
    if args.margin_verbose>0:
      margin_symbols.append(mx.symbol.mean(cos_t))
    if m>1.0:
      t = mx.sym.arccos(cos_t)
      t = t*m
      body = mx.sym.cos(t)
      new_zy = body*s
      if args.margin_verbose>0:
        new_cos_t = new_zy/s
        margin_symbols.append(mx.symbol.mean(new_cos_t))
      diff = new_zy - zy
      diff = mx.sym.expand_dims(diff, 1)
      gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
      body = mx.sym.broadcast_mul(gt_one_hot, diff)
      fc7 = fc7+body

    #threshold = math.cos(args.margin_m)
    #cond_v = cos_t - threshold
    #cond = mx.symbol.Activation(data=cond_v, act_type='relu')
    #body = cos_t
    #for i in xrange(args.margin//2):
    #  body = body*body
    #  body = body*2-1
    #new_zy = body*s
    #zy_keep = zy
    #new_zy = mx.sym.where(cond, new_zy, zy_keep)
    #if args.margin_verbose>0:
    #  new_cos_t = new_zy/s
    #  margin_symbols.append(mx.symbol.mean(new_cos_t))
    #diff = new_zy - zy
    #diff = mx.sym.expand_dims(diff, 1)
    #gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
    #body = mx.sym.broadcast_mul(gt_one_hot, diff)
    #fc7 = fc7+body
  elif args.loss_type==4:
    s = args.margin_s
    m = args.margin_m
    assert s>0.0
    assert m>=0.0
    assert m<(math.pi/2)
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
    fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
    zy = mx.sym.pick(fc7, gt_label, axis=1)
    cos_t = zy/s
    if args.margin_verbose>0:
      margin_symbols.append(mx.symbol.mean(cos_t))
    if args.output_c2c==0:
      cos_m = math.cos(m)
      sin_m = math.sin(m)
      mm = math.sin(math.pi-m)*m
      #threshold = 0.0
      threshold = math.cos(math.pi-m)
      if args.easy_margin:
        cond = mx.symbol.Activation(data=cos_t, act_type='relu')
      else:
        cond_v = cos_t - threshold
        cond = mx.symbol.Activation(data=cond_v, act_type='relu')
      body = cos_t*cos_t
      body = 1.0-body
      sin_t = mx.sym.sqrt(body)
      new_zy = cos_t*cos_m
      b = sin_t*sin_m
      new_zy = new_zy - b
      new_zy = new_zy*s
      if args.easy_margin:
        zy_keep = zy
      else:
        zy_keep = zy - s*mm
      new_zy = mx.sym.where(cond, new_zy, zy_keep)
    else:
      #set c2c as cosm^2 in data.py
      cos_m = mx.sym.sqrt(c2c_label)
      sin_m = 1.0-c2c_label
      sin_m = mx.sym.sqrt(sin_m)
      body = cos_t*cos_t
      body = 1.0-body
      sin_t = mx.sym.sqrt(body)
      new_zy = cos_t*cos_m
      b = sin_t*sin_m
      new_zy = new_zy - b
      new_zy = new_zy*s

    if args.margin_verbose>0:
      new_cos_t = new_zy/s
      margin_symbols.append(mx.symbol.mean(new_cos_t))
    diff = new_zy - zy
    diff = mx.sym.expand_dims(diff, 1)
    gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
    body = mx.sym.broadcast_mul(gt_one_hot, diff)
    fc7 = fc7+body
  elif args.loss_type==5:
    #s = args.margin_s
    #m = args.margin_m
    #assert s>0.0
    #assert m>=0.0
    #assert m<(math.pi/2)
    #_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    #_weight = mx.symbol.L2Normalization(_weight, mode='instance')
    #nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
    #fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
    #zy = mx.sym.pick(fc7, gt_label, axis=1)
    #cos_t = zy/s
    #if args.margin_verbose>0:
    #  margin_symbols.append(mx.symbol.mean(cos_t))
    #if m>0.0:
    #  a1 = args.margin_a
    #  r1 = ta-a1
    #  r1 = mx.symbol.Activation(data=r1, act_type='relu')
    #  r1 = r1+a1
    #  t = mx.sym.arccos(cos_t)
    #  cond = t-1.0
    #  cond = mx.symbol.Activation(data=cond, act_type='relu')
    #  r = mx.sym.where(cond, r2, r1)
    #  t = t+var_m
    #  body = mx.sym.cos(t)
    #  new_zy = body*s
    #  if args.margin_verbose>0:
    #    new_cos_t = new_zy/s
    #    margin_symbols.append(mx.symbol.mean(new_cos_t))
    #    #margin_symbols.append(mx.symbol.mean(var_m))
    #  diff = new_zy - zy
    #  diff = mx.sym.expand_dims(diff, 1)
    #  gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
    #  body = mx.sym.broadcast_mul(gt_one_hot, diff)
    #  fc7 = fc7+body
    s = args.margin_s
    m = args.margin_m
    assert s>0.0
    #assert m>=0.0
    #assert m<(math.pi/2)
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
    fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
    zy = mx.sym.pick(fc7, gt_label, axis=1)
    cos_t = zy/s
    t = mx.sym.arccos(cos_t)
    if args.margin_verbose>0:
      margin_symbols.append(mx.symbol.mean(t))
    if args.margin_a>0.0:
      t = t*args.margin_a
    if args.margin_m>0.0:
      t = t+args.margin_m
    body = mx.sym.cos(t)
    if args.margin_b>0.0:
      body = body - args.margin_b
    new_zy = body*s
    if args.margin_verbose>0:
      margin_symbols.append(mx.symbol.mean(t))
    diff = new_zy - zy
    diff = mx.sym.expand_dims(diff, 1)
    gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
    body = mx.sym.broadcast_mul(gt_one_hot, diff)
    fc7 = fc7+body
  elif args.loss_type==6:
    s = args.margin_s
    m = args.margin_m
    assert s>0.0
    assert m>=0.0
    assert m<(math.pi/2)
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
    fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
    zy = mx.sym.pick(fc7, gt_label, axis=1)
    cos_t = zy/s
    t = mx.sym.arccos(cos_t)
    if args.margin_verbose>0:
      margin_symbols.append(mx.symbol.mean(t))
    t_min = mx.sym.min(t)
    ta = mx.sym.broadcast_div(t_min, t)

    a1 = args.margin_a
    r1 = ta-a1
    r1 = mx.symbol.Activation(data=r1, act_type='relu')
    r1 = r1+a1

    r2 = mx.symbol.zeros(shape=(args.per_batch_size,))

    cond = t-1.0
    cond = mx.symbol.Activation(data=cond, act_type='relu')
    r = mx.sym.where(cond, r2, r1)
    var_m = r*m
    t = t+var_m
    body = mx.sym.cos(t)
    new_zy = body*s
    if args.margin_verbose>0:
      #new_cos_t = new_zy/s
      #margin_symbols.append(mx.symbol.mean(new_cos_t))
      margin_symbols.append(mx.symbol.mean(t))
    diff = new_zy - zy
    diff = mx.sym.expand_dims(diff, 1)
    gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
    body = mx.sym.broadcast_mul(gt_one_hot, diff)
    fc7 = fc7+body
  elif args.loss_type==7:
    s = args.margin_s
    m = args.margin_m
    assert s>0.0
    assert m>=0.0
    assert m<(math.pi/2)
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
    fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
    zy = mx.sym.pick(fc7, gt_label, axis=1)
    cos_t = zy/s
    t = mx.sym.arccos(cos_t)
    if args.margin_verbose>0:
      margin_symbols.append(mx.symbol.mean(t))
    var_m = mx.sym.random.uniform(low=args.margin_a, high=args.margin_m, shape=(1,))
    t = mx.sym.broadcast_add(t,var_m)
    body = mx.sym.cos(t)
    new_zy = body*s
    if args.margin_verbose>0:
      #new_cos_t = new_zy/s
      #margin_symbols.append(mx.symbol.mean(new_cos_t))
      margin_symbols.append(mx.symbol.mean(t))
    diff = new_zy - zy
    diff = mx.sym.expand_dims(diff, 1)
    gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
    body = mx.sym.broadcast_mul(gt_one_hot, diff)
    fc7 = fc7+body
  elif args.loss_type==10: #marginal loss
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
    params = [1.2, 0.3, 1.0]
    n1 = mx.sym.expand_dims(nembedding, axis=1) #N,1,C
    n2 = mx.sym.expand_dims(nembedding, axis=0) #1,N,C
    body = mx.sym.broadcast_sub(n1, n2) #N,N,C
    body = body * body
    body = mx.sym.sum(body, axis=2) # N,N
    #body = mx.sym.sqrt(body)
    body = body - params[0]
    mask = mx.sym.Variable('extra')
    body = body*mask
    body = body+params[1]
    #body = mx.sym.maximum(body, 0.0)
    body = mx.symbol.Activation(data=body, act_type='relu')
    body = mx.sym.sum(body)
    body = body/(args.per_batch_size*args.per_batch_size-args.per_batch_size)
    extra_loss = mx.symbol.MakeLoss(body, grad_scale=params[2])
  elif args.loss_type==11: #npair loss
    params = [0.9, 0.2]
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
    nembedding = mx.sym.transpose(nembedding)
    nembedding = mx.symbol.reshape(nembedding, (args.emb_size, args.per_identities, args.images_per_identity))
    nembedding = mx.sym.transpose(nembedding, axes=(2,1,0)) #2*id*512
    #nembedding = mx.symbol.reshape(nembedding, (args.emb_size, args.images_per_identity, args.per_identities))
    #nembedding = mx.sym.transpose(nembedding, axes=(1,2,0)) #2*id*512
    n1 = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=1)
    n2 = mx.symbol.slice_axis(nembedding, axis=0, begin=1, end=2)
    #n1 = []
    #n2 = []
    #for i in xrange(args.per_identities):
    #  _n1 = mx.symbol.slice_axis(nembedding, axis=0, begin=2*i, end=2*i+1)
    #  _n2 = mx.symbol.slice_axis(nembedding, axis=0, begin=2*i+1, end=2*i+2)
    #  n1.append(_n1)
    #  n2.append(_n2)
    #n1 = mx.sym.concat(*n1, dim=0)
    #n2 = mx.sym.concat(*n2, dim=0)
    #rembeddings = mx.symbol.reshape(nembedding, (args.images_per_identity, args.per_identities, 512))
    #n1 = mx.symbol.slice_axis(rembeddings, axis=0, begin=0, end=1)
    #n2 = mx.symbol.slice_axis(rembeddings, axis=0, begin=1, end=2)
    n1 = mx.symbol.reshape(n1, (args.per_identities, args.emb_size))
    n2 = mx.symbol.reshape(n2, (args.per_identities, args.emb_size))
    cosine_matrix = mx.symbol.dot(lhs=n1, rhs=n2, transpose_b = True) #id*id, id=N of N-pair
    data_extra = mx.sym.Variable('extra')
    data_extra = mx.sym.slice_axis(data_extra, axis=0, begin=0, end=args.per_identities)
    mask = cosine_matrix * data_extra
    #body = mx.sym.mean(mask)
    fii = mx.sym.sum_axis(mask, axis=1)
    fij_fii = mx.sym.broadcast_sub(cosine_matrix, fii)
    fij_fii = mx.sym.exp(fij_fii)
    row = mx.sym.sum_axis(fij_fii, axis=1)
    row = mx.sym.log(row)
    body = mx.sym.mean(row)
    extra_loss = mx.sym.MakeLoss(body)
  elif args.loss_type==12: #triplet loss
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
    anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size//3)
    positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size//3, end=2*args.per_batch_size//3)
    negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2*args.per_batch_size//3, end=args.per_batch_size)
    ap = anchor - positive
    an = anchor - negative
    ap = ap*ap
    an = an*an
    ap = mx.symbol.sum(ap, axis=1, keepdims=1) #(T,1)
    an = mx.symbol.sum(an, axis=1, keepdims=1) #(T,1)
    triplet_loss = mx.symbol.Activation(data = (ap-an+args.triplet_alpha), act_type='relu')
    triplet_loss = mx.symbol.mean(triplet_loss)
    #triplet_loss = mx.symbol.sum(triplet_loss)/(args.per_batch_size//3)
    extra_loss = mx.symbol.MakeLoss(triplet_loss)
  elif args.loss_type==13: #triplet loss with angular margin
    m = args.margin_m
    sin_m = math.sin(m)
    cos_m = math.cos(m)
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
    anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size//3)
    positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size//3, end=2*args.per_batch_size//3)
    negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2*args.per_batch_size//3, end=args.per_batch_size)
    ap = anchor * positive
    an = anchor * negative
    ap = mx.symbol.sum(ap, axis=1, keepdims=1) #(T,1)
    an = mx.symbol.sum(an, axis=1, keepdims=1) #(T,1)

    ap = mx.symbol.arccos(ap)
    an = mx.symbol.arccos(an)
    triplet_loss = mx.symbol.Activation(data = (ap-an+args.margin_m), act_type='relu')

    #body = ap*ap
    #body = 1.0-body
    #body = mx.symbol.sqrt(body)
    #body = body*sin_m
    #ap = ap*cos_m
    #ap = ap-body
    #triplet_loss = mx.symbol.Activation(data = (an-ap), act_type='relu')

    triplet_loss = mx.symbol.mean(triplet_loss)
    extra_loss = mx.symbol.MakeLoss(triplet_loss)
  elif args.loss_type==9: #coco loss
    centroids = []
    for i in xrange(args.per_identities):
      xs = mx.symbol.slice_axis(embedding, axis=0, begin=i*args.images_per_identity, end=(i+1)*args.images_per_identity)
      mean = mx.symbol.mean(xs, axis=0, keepdims=True)
      mean = mx.symbol.L2Normalization(mean, mode='instance')
      centroids.append(mean)
    centroids = mx.symbol.concat(*centroids, dim=0)
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*args.coco_scale
    fc7 = mx.symbol.dot(nembedding, centroids, transpose_b = True) #(batchsize, per_identities)
    #extra_loss = mx.symbol.softmax_cross_entropy(fc7, gt_label, name='softmax_ce')/args.per_batch_size
    #extra_loss = mx.symbol.BlockGrad(extra_loss)
  else:
    #embedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*float(args.loss_type)
    embedding = embedding * 5
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance') * 2
    fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,
                          weight = _weight,
                          beta=args.beta, margin=args.margin, scale=args.scale,
                          beta_min=args.beta_min, verbose=100, name='fc7')

    #fc7 = mx.sym.Custom(data=embedding, label=gt_label, weight=_weight, num_hidden=args.num_classes,
    #                       beta=args.beta, margin=args.margin, scale=args.scale,
    #                       op_type='ASoftmax', name='fc7')
  if args.loss_type<=1 and args.incay>0.0:
    params = [1.e-10]
    sel = mx.symbol.argmax(data = fc7, axis=1)
    sel = (sel==gt_label)
    norm = embedding*embedding
    norm = mx.symbol.sum(norm, axis=1)
    norm = norm+params[0]
    feature_incay = sel/norm
    feature_incay = mx.symbol.mean(feature_incay) * args.incay
    extra_loss = mx.symbol.MakeLoss(feature_incay)
  #out = softmax
  #l2_embedding = mx.symbol.L2Normalization(embedding)

  #ce = mx.symbol.softmax_cross_entropy(fc7, gt_label, name='softmax_ce')/args.per_batch_size
  #out = mx.symbol.Group([mx.symbol.BlockGrad(embedding), softmax, mx.symbol.BlockGrad(ce)])
  out_list = [mx.symbol.BlockGrad(embedding)]
  softmax = None
  if args.loss_type<10:
    softmax = mx.symbol.SoftmaxOutput(data=fc7, label = gt_label, name='softmax', normalization='valid')
    out_list.append(softmax)
    if args.logits_verbose>0:
      logits = mx.symbol.softmax(data = fc7)
      logits = mx.sym.pick(logits, gt_label, axis=1)
      margin_symbols.append(logits)
      #logit_max = mx.sym.max(logits)
      #logit_min = mx.sym.min(logits)
      #margin_symbols.append(logit_max)
      #margin_symbols.append(logit_min)
  if softmax is None:
    out_list.append(mx.sym.BlockGrad(gt_label))
  if extra_loss is not None:
    out_list.append(extra_loss)
  for _sym in margin_symbols:
    _sym = mx.sym.BlockGrad(_sym)
    out_list.append(_sym)
  out = mx.symbol.Group(out_list)
  return (out, arg_params, aux_params)
def get_symbol(args, arg_params, aux_params):
    """Build a margin-softmax + inter-identity decorrelation training symbol.

    Per device, the label/embedding batch is split by the slicing below into:

    * rows ``[0, batch_size // ctx_num)``: classification samples, trained
      with a plain softmax (``loss_type == 0``) or an additive angular
      margin softmax (any other ``loss_type``);
    * the next ``batchsize_id // ctx_num`` rows: grouped samples whose
      per-identity mean embeddings are pushed towards mutual orthogonality.

    Args:
        args: option namespace; backbone (``network``, ``num_layers``,
            ``emb_size``, ``version_*``), batch layout (``batch_size``,
            ``batchsize_id``, ``ctx_num``, ``images_per_identity``) and loss
            settings (``loss_type``, ``num_classes``, ``margin_s``,
            ``margin_m``, ``fc7_wd_mult``, ``easy_margin``, ``interweight``,
            ``noise``, ``noise_beta``) are read from it.
        arg_params: returned unchanged.
        aux_params: returned unchanged.

    Returns:
        Tuple ``(out, arg_params, aux_params)``; ``out`` groups the
        gradient-blocked monitoring outputs with the cross-entropy loss and
        the inter-identity loss.
    """
    # --- backbone -----------------------------------------------------------
    # ``args.network[0]`` is a single character, so the original tests
    # ``== 'so'`` and ``== ''`` could never be true and every network fell
    # through to resnet.  Prefix tests make the spherenet branches reachable.
    # NOTE(review): the second branch originally compared against ''; 's' is
    # the assumed intent (spherenet-style network names) -- confirm.
    if args.network.startswith('so'):
        print('init spherenet_o', args.num_layers)
        embedding = spherenet.get_symbol(0, args.emb_size, args.num_layers)
    elif args.network.startswith('s'):
        print('init spherenet', args.num_layers)
        embedding = spherenet_bn.get_symbol(args.emb_size, args.num_layers)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size,
                                       args.num_layers,
                                       version_se=args.version_se,
                                       version_input=args.version_input,
                                       version_output=args.version_output,
                                       version_unit=args.version_unit,
                                       version_act=args.version_act)
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance')
    out_list = [mx.symbol.BlockGrad(embedding)]
    all_label = mx.symbol.Variable('softmax_label')

    # --- per-device batch layout ---------------------------------------------
    n_softmax = args.batch_size // args.ctx_num  # classification rows
    n_inter = args.batchsize_id // args.ctx_num  # inter-loss rows
    n_ids = args.batchsize_id // (args.ctx_num * args.images_per_identity)

    label_softmax = mx.sym.slice_axis(all_label,
                                      axis=0,
                                      begin=0,
                                      end=n_softmax)
    nembedding_softmax = mx.sym.slice_axis(nembedding,
                                           axis=0,
                                           begin=0,
                                           end=n_softmax)
    label_inter = mx.sym.slice_axis(all_label,
                                    axis=0,
                                    begin=n_softmax,
                                    end=n_softmax + n_inter)
    nembedding_inter = mx.sym.slice_axis(nembedding,
                                         axis=0,
                                         begin=n_softmax,
                                         end=n_softmax + n_inter)

    # --- inter-identity loss ---------------------------------------------------
    # (n_inter, emb) -> (emb, n_ids, ipi) -> (ipi, n_ids, emb).  The reshape
    # groups consecutive ``images_per_identity`` rows as one identity.
    nembedding_inter = mx.sym.transpose(nembedding_inter)
    nembedding_inter = mx.symbol.reshape(
        nembedding_inter, (args.emb_size, n_ids, args.images_per_identity))
    nembedding_inter = mx.sym.transpose(nembedding_inter, axes=(2, 1, 0))
    # Mean embedding per identity, re-normalised to unit length: (n_ids, emb).
    nembedding_inter = mx.sym.mean(nembedding_inter, axis=0)
    nembedding_inter = mx.sym.L2Normalization(nembedding_inter,
                                              mode='instance')
    emb_norm = mx.sym.norm(nembedding_inter)
    # Squared off-diagonal cosines between identity centres; subtracting the
    # identity matrix removes the unit self-similarity on the diagonal.
    cosine_matrix = mx.sym.dot(nembedding_inter,
                               mx.sym.transpose(nembedding_inter))
    cosine_matrix = cosine_matrix - mx.symbol.eye(n_ids)
    cosine_matrix = cosine_matrix * cosine_matrix
    inter_loss = args.interweight * mx.symbol.mean(cosine_matrix)
    inter_loss = mx.sym.MakeLoss(inter_loss)

    # --- classification logits ---------------------------------------------------
    if args.loss_type == 0:
        _weight = mx.symbol.Variable('fc7_weight')
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=nembedding_softmax,
                                    weight=_weight,
                                    bias=_bias,
                                    num_hidden=args.num_classes,
                                    name='fc7')
    else:
        # Additive angular margin: the target logit s*cos(t) is replaced by
        # s*cos(t + m) where ``cond`` holds, and by a fallback elsewhere.
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.Variable("fc7_weight",
                                     shape=(args.num_classes, args.emb_size),
                                     lr_mult=1.0,
                                     wd_mult=args.fc7_wd_mult)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.FullyConnected(data=nembedding_softmax * s,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        zy = mx.sym.pick(fc7, label_softmax, axis=1)
        cos_t = zy / s
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        mm = math.sin(math.pi - m) * m
        # cos(pi - m): cos_t above it means t + m < pi.
        threshold = math.cos(math.pi - m)
        if args.easy_margin:
            cond = mx.symbol.Activation(data=cos_t, act_type='relu')
        else:
            cond = mx.symbol.Activation(data=cos_t - threshold,
                                        act_type='relu')
        # cos(t + m) = cos(t)cos(m) - sin(t)sin(m)
        sin_t = mx.sym.sqrt(1.0 - cos_t * cos_t)
        new_zy = (cos_t * cos_m - sin_t * sin_m) * s
        if args.easy_margin:
            zy_keep = zy
        else:
            zy_keep = zy - s * mm
        new_zy = mx.sym.where(cond, new_zy, zy_keep)

        # Add (new_zy - zy) only at the target column of each row.
        diff = mx.sym.expand_dims(new_zy - zy, 1)
        margin_one_hot = mx.sym.one_hot(label_softmax,
                                        depth=args.num_classes,
                                        on_value=1.0,
                                        off_value=0.0)
        fc7 = fc7 + mx.sym.broadcast_mul(margin_one_hot, diff)

    # --- cross-entropy ---------------------------------------------------------
    softmaxs = mx.sym.log_softmax(data=fc7, name="softmax")
    gt_one_hot = mx.sym.one_hot(label_softmax,
                                depth=args.num_classes,
                                on_value=1.0,
                                off_value=0.0)
    cross_entropy = -mx.sym.sum(mx.sym.broadcast_mul(gt_one_hot, softmaxs),
                                axis=[0, 1])
    if args.noise:
        # Blend the ground-truth cross-entropy with the cross-entropy against
        # the model's own argmax prediction, weighted by ``noise_beta``.
        pred_label = mx.sym.argmax(softmaxs, axis=1)
        pred_one_hot = mx.sym.one_hot(pred_label,
                                      depth=args.num_classes,
                                      on_value=1.0,
                                      off_value=0.0)
        cross_entropy_pred = -mx.sym.sum(
            mx.sym.broadcast_mul(pred_one_hot, softmaxs), axis=[0, 1])
        cross_entropy = args.noise_beta * cross_entropy + (
            1 - args.noise_beta) * cross_entropy_pred
    # Average over the classification rows actually present on this device.
    # The original hard-coded ``batch_size // 2``, which equals this only
    # when ``ctx_num == 2``.
    cross_entropy = cross_entropy / n_softmax
    softmaxloss = mx.sym.MakeLoss(cross_entropy)

    out_list.append(mx.symbol.BlockGrad(label_softmax))
    out_list.append(mx.symbol.BlockGrad(fc7))
    out_list.append(mx.symbol.BlockGrad(label_inter))
    out_list.append(mx.symbol.BlockGrad(softmaxs))
    out_list.append(softmaxloss)
    out_list.append(inter_loss)
    out_list.append(mx.symbol.BlockGrad(emb_norm))
    out_list.append(mx.symbol.BlockGrad(cosine_matrix))
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
Example #12
0
def get_symbol(args, arg_params, aux_params, sym_embedding=None):
    """Build a triplet-loss symbol with softmax-weighted positive/negative mining.

    Weighting per anchor:

    * Distances to same-label samples (including the anchor itself) are
      averaged with weights ``softmax(d)``, so far positives weigh more.
    * Distances to other-label samples are averaged with ``softmax(-d)``, so
      near negatives weigh more.
    * The weights are gradient-blocked.

    The loss is ``mean(relu(d_ap - d_an + args.triplet_alpha))``.

    The gather layout assumes that every identity present in the batch of
    ``args.per_batch_size`` samples appears exactly
    ``args.images_per_identity`` times.

    Args:
        args: option namespace (``per_batch_size``, ``images_per_identity``,
            ``triplet_alpha`` and, when building the backbone, ``emb_size``,
            ``num_layers`` and the ``version_*`` options).
        arg_params: returned unchanged.
        aux_params: returned unchanged.
        sym_embedding: optional pre-built embedding symbol; a resnet backbone
            is built when omitted.

    Returns:
        ``(out, arg_params, aux_params)`` where ``out`` groups the
        gradient-blocked embedding and labels with the triplet loss.
    """
    if sym_embedding is None:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size,
                                       args.num_layers,
                                       version_se=args.version_se,
                                       version_input=args.version_input,
                                       version_output=args.version_output,
                                       version_unit=args.version_unit,
                                       version_act=args.version_act)
    else:
        embedding = sym_embedding

    gt_label = mx.symbol.Variable('softmax_label')
    nembedding = mx.symbol.L2Normalization(embedding,
                                           mode='instance',
                                           name='fc1n')

    n = args.per_batch_size
    n_pos = args.images_per_identity  # same-label columns per row (incl. self)
    n_neg = n - n_pos

    # Pairwise Euclidean distances ||a||^2 + ||b||^2 - 2<a, b>, clamped
    # before the sqrt so the value and its gradient stay finite at zero.
    sq_norm = mx.symbol.sum(mx.symbol.pow(nembedding, 2), 1, True)  # (n, 1)
    sq_norm = mx.symbol.broadcast_to(sq_norm, shape=(n, n))
    dist = sq_norm + mx.symbol.transpose(sq_norm)
    dist = dist - 2 * mx.symbol.dot(nembedding,
                                    mx.symbol.transpose(nembedding))
    dist = mx.symbol.sqrt(mx.symbol.maximum(dist, 1e-12))

    # same_id[i, j] is 1 where samples i and j share a label.  Argsorting each
    # row ascending lists different-label columns first and same-label last.
    label = mx.symbol.reshape(gt_label, (n, 1))
    same_id = mx.symbol.broadcast_equal(label, mx.symbol.transpose(label))
    col_order = mx.symbol.argsort(same_id)

    def gather_rows(begin_col, end_col):
        """Return ``dist[i, col_order[i, begin_col:end_col]]`` as (n, k)."""
        k = end_col - begin_col
        cols = mx.symbol.slice(col_order,
                               begin=(0, begin_col),
                               end=(n, end_col))
        cols = mx.symbol.reshape(cols, (1, -1))
        # Row index of each flattened entry: k copies of 0, k copies of 1, ...
        # matching the row-major flattening of ``cols``.  A single arange with
        # ``repeat`` replaces the former concat-in-a-loop + sort, which added
        # O(n) graph nodes and broke for k < 2.
        rows = mx.symbol.arange(start=0, stop=n, repeat=k)
        rows = mx.symbol.reshape(rows, (1, -1))
        picked = mx.symbol.gather_nd(dist,
                                     mx.symbol.concat(rows, cols, dim=0))
        return mx.symbol.reshape(picked, (n, -1))

    daps = gather_rows(n_neg, n)  # (n, n_pos) anchor-positive distances
    dans = gather_rows(0, n_neg)  # (n, n_neg) anchor-negative distances

    # Mining weights come from the distances but are detached, so gradients
    # flow only through the distances themselves.
    ap_wei = mx.symbol.BlockGrad(mx.symbol.softmax(daps, axis=1))
    an_wei = mx.symbol.BlockGrad(mx.symbol.softmax(-dans, axis=1))
    dist_ap = mx.symbol.sum(mx.symbol.broadcast_mul(daps, ap_wei), axis=1)
    dist_an = mx.symbol.sum(mx.symbol.broadcast_mul(dans, an_wei), axis=1)

    triplet_loss = mx.symbol.relu(dist_ap - dist_an + args.triplet_alpha)
    triplet_loss = mx.symbol.MakeLoss(mx.symbol.mean(triplet_loss))

    out_list = [mx.symbol.BlockGrad(embedding),
                mx.sym.BlockGrad(gt_label),
                triplet_loss]
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
Example #13
0
def get_symbol(args, arg_params, aux_params):
    """Build a face-recognition training symbol: backbone + (margin) softmax.

    The backbone is chosen by the first character of ``args.network``:

    * 'd' densenet, 'm' mobilenet, 'i' inception-resnet-v2, 'x' xception;
    * 'p' dpn, 'n' nasnet, 's' spherenet, 'y' mobilefacenet;
    * anything else: resnet.

    The classification head is chosen by ``args.loss_type``:

    * 0: plain softmax (fully connected layer with bias);
    * 1: ``LSoftmax`` (sphere) over an L2-normalised weight;
    * 2: CosFace, which subtracts ``s * m`` from the target logit;
    * 4: ArcFace, where the target logit becomes ``s * cos(theta + m)``
      (with a fallback when ``theta + m`` leaves the valid range);
    * 5: combined margin, roughly ``s * (cos(a * theta + m) - b)``.

    Args:
        args: option namespace; backbone options (``network``,
            ``num_layers``, ``emb_size``, ``version_*``, ``bn_mom``) and head
            options (``loss_type``, ``num_classes``, ``fc7_lr_mult``,
            ``fc7_wd_mult``, ``margin_*``, ``easy_margin``, ``beta*``,
            ``margin``, ``scale``) are read from it.
        arg_params: returned unchanged.
        aux_params: returned unchanged.

    Returns:
        ``(out, arg_params, aux_params)`` where ``out`` groups the
        gradient-blocked embedding with the ``SoftmaxOutput`` loss.

    Raises:
        ValueError: if ``args.loss_type`` is not one of 0, 1, 2, 4, 5
            (previously this surfaced as an unbound ``fc7`` error).
    """
    print('***network: ', args.network)  # e.g. r100

    if args.network[0] == 'd':  # densenet
        embedding = fdensenet.get_symbol(args.emb_size, args.num_layers,
                                         version_se=args.version_se, version_input=args.version_input,
                                         version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'm':  # mobilenet
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(args.emb_size,
                                              version_se=args.version_se, version_input=args.version_input,
                                              version_output=args.version_output, version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':  # inception-resnet-v2
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(args.emb_size,
                                                    version_se=args.version_se, version_input=args.version_input,
                                                    version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
                                         version_se=args.version_se, version_input=args.version_input,
                                         version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size, args.num_layers,
                                    version_se=args.version_se, version_input=args.version_input,
                                    version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif args.network[0] == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(args.emb_size, bn_mom=args.bn_mom, version_output=args.version_output)
    else:  # resnet
        print('init resnet, 层数: ', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size,
                                       args.num_layers,
                                       version_se=args.version_se,
                                       version_input=args.version_input,
                                       version_output=args.version_output,
                                       version_unit=args.version_unit,
                                       version_act=args.version_act)

    gt_label = mx.symbol.Variable('softmax_label')
    # fc7 weight declared explicitly so the margin heads can normalise it.
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=args.fc7_lr_mult,
                                 wd_mult=args.fc7_wd_mult)
    if args.loss_type == 0:  # softmax
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7')
    elif args.loss_type == 1:  # sphere
        print('*******'*10)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta, margin=args.margin, scale=args.scale,
                              beta_min=args.beta_min, verbose=1000, name='fc7')
    elif args.loss_type == 2:  # CosineFace
        s = args.margin_s
        m = args.margin_m
        assert (s > 0.0)
        assert (m > 0.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes,
                                    name='fc7')
        # One-hot holding s*m at the target column and 0 elsewhere.
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s * m,
                                    off_value=0.0)
        fc7 = fc7 - gt_one_hot
    elif args.loss_type == 4:  # ArcFace
        s = args.margin_s  # logit scale s
        m = args.margin_m  # additive angular margin m (radians)
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        # Normalise the weight rows, and the features scaled by s.
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')  # (num_classes, emb_size)
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes,
                                    name='fc7')

        zy = mx.sym.pick(fc7, gt_label, axis=1)  # target logit per row: s*cos(theta)
        cos_t = zy / s  # divide out s to recover cos(theta)
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        mm = math.sin(math.pi - m) * m  # sin(pi - m) * m == sin(m) * m
        # cos(pi - m) == -cos(m) < 0; cos_t above it means theta + m < pi.
        threshold = math.cos(math.pi - m)
        if args.easy_margin:  # apply the margin only where cos_t > 0
            cond = mx.symbol.Activation(data=cos_t, act_type='relu')
        else:
            cond = mx.symbol.Activation(data=cos_t - threshold, act_type='relu')
        sin_t = mx.sym.sqrt(1.0 - cos_t * cos_t)  # sin^2 + cos^2 = 1
        # cos(t + m) = cos(t)cos(m) - sin(t)sin(m), rescaled by s.
        new_zy = (cos_t * cos_m - sin_t * sin_m) * s
        if args.easy_margin:
            zy_keep = zy  # s*cos(theta), unchanged
        else:
            zy_keep = zy - s * mm  # s*cos(theta) - s*m*sin(m)
        # Where cond > 0 use s*cos(theta + m); elsewhere fall back to zy_keep.
        new_zy = mx.sym.where(cond, new_zy, zy_keep)

        # Add (new_zy - zy) at the target column only, so fc7[y] == new_zy.
        diff = mx.sym.expand_dims(new_zy - zy, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        fc7 = fc7 + mx.sym.broadcast_mul(gt_one_hot, diff)
    elif args.loss_type == 5:  # combined margin
        s = args.margin_s
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes,
                                    name='fc7')
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                # Pure cosine margin: subtract s*b at the target column.
                s_m = s * args.margin_b
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if args.margin_a != 1.0:
                    t = t * args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body * s
                diff = mx.sym.expand_dims(new_zy - zy, 1)
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
                fc7 = fc7 + mx.sym.broadcast_mul(gt_one_hot, diff)
    else:
        raise ValueError('unsupported loss_type: %r (expected 0, 1, 2, 4 or 5)'
                         % (args.loss_type,))

    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax', normalization='valid')
    out_list.append(softmax)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
Example #14
0
def get_symbol(args, arg_params, aux_params):
    """Build the training graph: backbone embedding + classifier head.

    The backbone is chosen by ``args.network[0]`` ('d' densenet, 'm'
    mobilenet v1/v2, 'i' inception-resnet-v2, 'x' xception, 'p' dpn,
    'n' nasnet, 's' spherenet, 'y' mobilefacenet, anything else resnet).
    The head is chosen by ``args.loss_type``:

    * 0 -- fully connected softmax with bias;
    * 1 -- ``LSoftmax`` over an L2-normalised weight;
    * 2 -- target logit of the normalised classifier reduced by ``s * m``;
    * 4 -- target logit replaced by ``s * cos(theta + m)`` on the first
      ``per_batch_size - num_unlabeled`` rows, plus an extra term computed
      on the remaining ``num_unlabeled`` rows;
    * 5 -- target logit replaced by ``s * (cos(a * theta + m) - b)``.

    Args:
        args: configuration namespace; which attributes are read depends on
            the selected backbone and loss type.
        arg_params: returned unchanged.
        aux_params: returned unchanged.

    Returns:
        ``(group, arg_params, aux_params)``. ``group`` always starts with
        ``BlockGrad(embedding)``. For loss_type 4 it continues with the
        ``make_loss`` total, the gradient-blocked margin logits, the blocked
        labeled cross-entropy and the blocked unlabeled term. For all other
        loss types it continues with the ``SoftmaxOutput``.

    Raises:
        ValueError: if ``args.loss_type`` is not one of 0, 1, 2, 4, 5.
    """
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size,
                                         args.num_layers,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(
                args.emb_size,
                version_se=args.version_se,
                version_input=args.version_input,
                version_output=args.version_output,
                version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(
            args.emb_size,
            version_se=args.version_se,
            version_input=args.version_input,
            version_output=args.version_output,
            version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size,
                                    args.num_layers,
                                    version_se=args.version_se,
                                    version_input=args.version_input,
                                    version_output=args.version_output,
                                    version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif args.network[0] == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(args.emb_size,
                                              bn_mom=args.bn_mom,
                                              wd_mult=args.fc7_wd_mult)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size,
                                       args.num_layers,
                                       version_se=args.version_se,
                                       version_input=args.version_input,
                                       version_output=args.version_output,
                                       version_unit=args.version_unit,
                                       version_act=args.version_act)
    all_label = mx.symbol.Variable('softmax_label')
    gt_label = all_label
    _weight = mx.symbol.Variable("fc7_weight",
                                 shape=(args.num_classes, args.emb_size),
                                 lr_mult=1.0,
                                 wd_mult=args.fc7_wd_mult)
    # Graph outputs. Created before the loss branches because the common
    # tail below appends to it for every loss type, not only loss_type 4.
    out_list = []
    if args.loss_type == 0:  # plain softmax
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding,
                                    weight=_weight,
                                    bias=_bias,
                                    num_hidden=args.num_classes,
                                    name='fc7')
    elif args.loss_type == 1:  # sphere (LSoftmax)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding,
                              label=gt_label,
                              num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta,
                              margin=args.margin,
                              scale=args.scale,
                              beta_min=args.beta_min,
                              verbose=1000,
                              name='fc7')
    elif args.loss_type == 2:  # additive cosine margin
        s = args.margin_s
        m = args.margin_m
        assert (s > 0.0)
        assert (m > 0.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        # Subtract s*m from the target-class logit only.
        s_m = s * m
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=s_m,
                                    off_value=0.0)
        fc7 = fc7 - gt_one_hot
    elif args.loss_type == 4:  # additive angular margin + unlabeled term
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')

        # The first num_labeled rows of the batch are paired with labels; the
        # remaining num_unlabeled rows (fc7_2) only feed the extra term.
        # subtract = C * log(C) with C = args.num_classes.
        subtract = args.num_classes * math.log(args.num_classes)
        num_unlabeled = args.num_unlabeled
        num_labeled = args.per_batch_size - args.num_unlabeled
        per_batch_size = args.per_batch_size
        assert num_labeled > 0
        # With no unlabeled rows the slice below is empty and the
        # normalisation by num_unlabeled divides by zero.
        assert num_unlabeled > 0
        fc7_2 = mx.sym.slice_axis(fc7,
                                  axis=0,
                                  begin=num_labeled,
                                  end=per_batch_size)
        fc7 = mx.sym.slice_axis(fc7, axis=0, begin=0, end=num_labeled)
        fc7_2 = mx.sym.softmax(data=fc7_2, axis=1)

        # NOTE(review): log_softmax is applied to rows that are already
        # probabilities. For a uniform row the summed -log_softmax equals
        # C*log(C), which `subtract` cancels; confirm this double softmax is
        # the intended objective.
        log_likelihood = -mx.sym.log_softmax(fc7_2, axis=1)
        log_likelihood_loss = (mx.sym.sum(log_likelihood) / num_unlabeled)
        log_likelihood_loss = log_likelihood_loss - subtract

        # Only the labeled rows take part in the margin softmax below.
        gt_label = mx.sym.slice_axis(gt_label,
                                     axis=0,
                                     begin=0,
                                     end=num_labeled)

        # Target logit zy = s*cos(theta); replace it by s*cos(theta + m)
        # using cos(t + m) = cos(t)cos(m) - sin(t)sin(m).
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        mm = math.sin(math.pi - m) * m
        threshold = math.cos(math.pi - m)
        if args.easy_margin:
            # Margin applied only where cos(theta) > 0.
            cond = mx.symbol.Activation(data=cos_t, act_type='relu')
        else:
            # Margin applied only where cos(theta) > cos(pi - m).
            cond_v = cos_t - threshold
            cond = mx.symbol.Activation(data=cond_v, act_type='relu')
        body = cos_t * cos_t
        body = 1.0 - body
        sin_t = mx.sym.sqrt(body)
        new_zy = cos_t * cos_m
        b = sin_t * sin_m
        new_zy = new_zy - b
        new_zy = new_zy * s
        if args.easy_margin:
            zy_keep = zy
        else:
            # Fallback where the margin is not applied: zy - s*m*sin(m).
            zy_keep = zy - s * mm
        new_zy = mx.sym.where(cond, new_zy, zy_keep)

        # Add (new_zy - zy) to the target column only.
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=1.0,
                                    off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    elif args.loss_type == 5:  # combined margin s*(cos(a*theta + m) - b)
        s = args.margin_s
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                # Only b is set: a constant s*b subtraction on the target.
                s_m = s * args.margin_b
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=s_m,
                                            off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if args.margin_a != 1.0:
                    t = t * args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body * s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=1.0,
                                            off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7 + body
    else:
        raise ValueError('unsupported loss_type: %r' % (args.loss_type,))
    out_list.append(mx.symbol.BlockGrad(embedding))

    if args.loss_type == 4:
        # Gradient-blocked copy of the margin logits, exposed as an output.
        # TODO: maybe this softmax makes the accuracy inconsistent
        cos_fc7 = mx.symbol.BlockGrad(fc7)
        gt_label_one_hot = mx.sym.one_hot(gt_label, args.num_classes)
        fc7 = -mx.sym.log_softmax(fc7, axis=1)
        # Cross-entropy summed over the labeled rows, divided by their count.
        softmax = gt_label_one_hot * fc7
        softmax = mx.sym.sum(softmax) / (args.per_batch_size -
                                         args.num_unlabeled)
        scale = args.likelihood_scale_factor
        loss = mx.sym.make_loss(data=(softmax + log_likelihood_loss * scale))
        out_list.append(loss)
        out_list.append(cos_fc7)
        out_list.append(mx.sym.BlockGrad(softmax))
        out_list.append(mx.sym.BlockGrad(log_likelihood_loss))
    else:
        softmax = mx.symbol.SoftmaxOutput(data=fc7,
                                          label=gt_label,
                                          name='softmax',
                                          normalization='valid')
        out_list.append(softmax)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
Example #15
0
def get_symbol(args, arg_params, aux_params):
    """Build the training graph: backbone embedding + classifier head.

    The backbone is selected by ``args.network[0]`` ('d' densenet,
    'm' mobilenet v1/v2, 'i' inception-resnet-v2, 'x' xception, 'p' dpn,
    'n' nasnet, 's' spherenet, 'y' mobilefacenet, otherwise resnet). The
    head is selected by ``args.loss_type``:

    * 0 -- fully connected softmax (no bias when ``args.fc7_no_bias``);
    * 1 -- ``LSoftmax`` over an L2-normalised weight;
    * 2 -- normalised classifier with ``s * m`` subtracted from the target
      logit (s = ``args.margin_s``, m = ``args.margin_m``);
    * 4 -- normalised classifier with the target logit replaced by
      ``s * cos(theta + m)``; with ``args.finetune`` the weight variable is
      named ``finetune_weight`` instead of ``fc7_weight``;
    * 5 -- normalised classifier with the target logit replaced by
      ``s * (cos(a * theta + m) - b)`` when any of a != 1, m != 0, b != 0.

    Any other ``loss_type`` leaves ``fc7`` unbound, so the ``SoftmaxOutput``
    call at the end fails.

    Returns:
        ``(group, arg_params, aux_params)`` where ``group`` is
        ``[BlockGrad(embedding), SoftmaxOutput]``; the params are passed
        through unchanged.
    """
    # data_shape, image_shape and margin_symbols are not used below.
    data_shape = (args.image_channel, args.image_h, args.image_w)
    image_shape = ",".join([str(x) for x in data_shape])
    margin_symbols = []
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size,
                                         args.num_layers,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(
                args.emb_size,
                version_se=args.version_se,
                version_input=args.version_input,
                version_output=args.version_output,
                version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(
            args.emb_size,
            version_se=args.version_se,
            version_input=args.version_input,
            version_output=args.version_output,
            version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size,
                                    args.num_layers,
                                    version_se=args.version_se,
                                    version_input=args.version_input,
                                    version_output=args.version_output,
                                    version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif args.network[0] == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(
            args.emb_size,
            bn_mom=args.bn_mom,
            version_output=args.version_output)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size,
                                       args.num_layers,
                                       version_se=args.version_se,
                                       version_input=args.version_input,
                                       version_output=args.version_output,
                                       version_unit=args.version_unit,
                                       version_act=args.version_act)
    all_label = mx.symbol.Variable('softmax_label')
    gt_label = all_label
    extra_loss = None  # not used below
    # Classifier weight shared by loss types 0, 1, 2 and 5; loss type 4
    # declares its own variable below.
    _weight = mx.symbol.Variable("fc7_weight",
                                 shape=(args.num_classes, args.emb_size),
                                 lr_mult=args.fc7_lr_mult,
                                 wd_mult=args.fc7_wd_mult)
    if args.loss_type == 0:  # plain softmax
        if args.fc7_no_bias:
            fc7 = mx.sym.FullyConnected(data=embedding,
                                        weight=_weight,
                                        no_bias=True,
                                        num_hidden=args.num_classes,
                                        name='fc7')
        else:
            _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
            fc7 = mx.sym.FullyConnected(data=embedding,
                                        weight=_weight,
                                        bias=_bias,
                                        num_hidden=args.num_classes,
                                        name='fc7')
    elif args.loss_type == 1:  # sphere (LSoftmax)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding,
                              label=gt_label,
                              num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta,
                              margin=args.margin,
                              scale=args.scale,
                              beta_min=args.beta_min,
                              verbose=1000,
                              name='fc7')
    elif args.loss_type == 2:  # additive cosine margin
        s = args.margin_s
        m = args.margin_m
        assert (s > 0.0)
        assert (m > 0.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        # Subtract s*m from the target-class logit only.
        s_m = s * m
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=s_m,
                                    off_value=0.0)
        fc7 = fc7 - gt_one_hot
    elif args.loss_type == 4:  # additive angular margin
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)

        # NOTE(review): this re-declares "fc7_weight" with lr_mult=1.0 and no
        # wd_mult, so args.fc7_lr_mult / args.fc7_wd_mult (used in the shared
        # declaration above) do not apply to this loss type. Confirm this is
        # intended.
        _weight = mx.symbol.Variable("fc7_weight",
                                     shape=(args.num_classes, args.emb_size),
                                     lr_mult=1.0)
        if args.finetune:
            # Fine-tuning uses a separately named weight with a 10x learning
            # rate multiplier.
            print("{}finetuning from trained model".format('-' * 10))
            _weight = mx.symbol.Variable("finetune_weight",
                                         shape=(args.num_classes,
                                                args.emb_size),
                                         lr_mult=10.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        # Target logit zy = s*cos(theta); replace it by s*cos(theta + m)
        # using cos(t + m) = cos(t)cos(m) - sin(t)sin(m).
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        mm = math.sin(math.pi - m) * m
        # threshold = 0.0
        threshold = math.cos(math.pi - m)
        if args.easy_margin:
            # Margin applied only where cos(theta) > 0.
            cond = mx.symbol.Activation(data=cos_t, act_type='relu')
        else:
            # Margin applied only where cos(theta) > cos(pi - m).
            cond_v = cos_t - threshold
            cond = mx.symbol.Activation(data=cond_v, act_type='relu')
        body = cos_t * cos_t
        body = 1.0 - body
        sin_t = mx.sym.sqrt(body)
        new_zy = cos_t * cos_m
        b = sin_t * sin_m
        new_zy = new_zy - b
        new_zy = new_zy * s
        if args.easy_margin:
            zy_keep = zy
        else:
            # Fallback where the margin is not applied: zy - s*m*sin(m).
            zy_keep = zy - s * mm
        new_zy = mx.sym.where(cond, new_zy, zy_keep)

        # Add (new_zy - zy) to the target column only.
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=1.0,
                                    off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    elif args.loss_type == 5:  # combined margin s*(cos(a*theta + m) - b)
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                # Only b is set: a constant s*b subtraction on the target.
                s_m = s * args.margin_b
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=s_m,
                                            off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if args.margin_a != 1.0:
                    t = t * args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body * s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=1.0,
                                            off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7 + body
    # Outputs: gradient-blocked embedding first, then the softmax loss.
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = mx.symbol.SoftmaxOutput(data=fc7,
                                      label=gt_label,
                                      name='softmax',
                                      normalization='valid')
    out_list.append(softmax)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
Example #16
0
        'stage3_unit12_bn2', 'stage3_unit12_bn3', 'stage3_unit13_bn2',
        'stage3_unit13_bn3', 'stage3_unit14_bn2', 'stage3_unit14_bn3',
        'stage4_unit1_bn2', 'stage4_unit1_bn3', 'stage4_unit1_sc',
        'stage4_unit2_bn2', 'stage4_unit2_bn3', 'stage4_unit3_bn2',
        'stage4_unit3_bn3'
    ]

    assert (len(conv_names) == len(bn_prefixes))
    for i in xrange(len(conv_names)):
        conv_name = conv_names[i]
        bn_prefix = bn_prefixes[i]
        merge_bn(arg_params, aux_params, conv_name, bn_prefix)

    emb_size = 128
    num_layers = 50
    version_se = 0
    version_input = 1
    version_output = 'E'
    version_unit = 3
    version_act = 'prelu'
    nobn_sym = fresnet.get_symbol(emb_size,
                                  num_layers,
                                  version_se=version_se,
                                  version_input=version_input,
                                  version_output=version_output,
                                  version_unit=version_unit,
                                  version_act=version_act)

    mx.model.save_checkpoint('mergebn_test', 0, nobn_sym, arg_params,
                             aux_params)
def get_symbol(args, arg_params, aux_params):
    """Assemble a 512-d embedding network, its classifier and any extra loss.

    ``args.network[0]`` picks the backbone ('s' spherenet, 'm' marginalnet,
    'i' inception-resnet-v2, 'x' xception, anything else resnet), and
    ``args.loss_type`` picks the head:

    * 0  -- biased fully connected classifier;
    * 1  -- ``LSoftmax`` over an L2-normalised weight;
    * 10 -- biased classifier plus a masked hinge on pairwise squared
      distances of the normalised embeddings (mask from the ``extra`` input);
    * 11 -- biased classifier plus a hinge on pairwise squared distances of
      the first ``args.images_per_identity`` normalised embeddings;
    * any other non-negative value -- ``LSoftmax`` on ``5 * embedding`` with
      the L2-normalised weight scaled by 2.

    For loss types 0 and 1 a feature-incay loss is added when
    ``args.incay > 0``.

    Returns:
        ``(group, new_args, aux_params)``. ``group`` is
        ``[BlockGrad(embedding), softmax]``, followed by the extra loss when
        one was built. ``new_args`` is ``arg_params`` if ``args.retrain`` is
        set, else ``None``.
    """
    new_args = arg_params if args.retrain else None
    # Not used below.
    data_shape = (args.image_channel, args.image_h, args.image_w)
    image_shape = ",".join(str(dim) for dim in data_shape)

    emb_dim = 512
    family = args.network[0]
    if family == 's':
        embedding = spherenet.get_symbol(emb_dim, args.num_layers)
    elif family == 'm':
        print('init marginal', args.num_layers)
        embedding = marginalnet.get_symbol(emb_dim, args.num_layers)
    elif family == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(emb_dim)
    elif family == 'x':
        print('init xception', args.num_layers)
        embedding, _ = xception.get_xception_symbol(emb_dim)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(emb_dim,
                                       args.num_layers,
                                       use_se=args.use_se,
                                       version_input=args.version_input,
                                       version_output=args.version_output,
                                       version_unit=args.version_unit)

    gt_label = mx.symbol.Variable('softmax_label')
    assert args.loss_type >= 0
    extra_loss = None
    loss_type = args.loss_type

    def _plain_fc7():
        # Biased fully connected classifier on the (unscaled) embedding.
        weight = mx.symbol.Variable('fc7_weight')
        bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        return mx.sym.FullyConnected(data=embedding,
                                     weight=weight,
                                     bias=bias,
                                     num_hidden=args.num_classes,
                                     name='fc7')

    def _lsoftmax_fc7(data, weight_scale=None):
        # LSoftmax classifier over an L2-normalised (optionally scaled) weight.
        weight = mx.symbol.Variable("fc7_weight",
                                    shape=(args.num_classes, emb_dim),
                                    lr_mult=1.0)
        weight = mx.symbol.L2Normalization(weight, mode='instance')
        if weight_scale is not None:
            weight = weight * weight_scale
        return mx.sym.LSoftmax(data=data,
                               label=gt_label,
                               num_hidden=args.num_classes,
                               weight=weight,
                               beta=args.beta,
                               margin=args.margin,
                               scale=args.scale,
                               beta_min=args.beta_min,
                               verbose=100,
                               name='fc7')

    def _pairwise_sq_dist(rows):
        # Broadcast (N,1,C) - (1,N,C) -> (N,N,C), square, sum over C -> (N,N).
        delta = mx.sym.broadcast_sub(mx.sym.expand_dims(rows, axis=1),
                                     mx.sym.expand_dims(rows, axis=0))
        return mx.sym.sum(delta * delta, axis=2)

    if loss_type == 0:
        fc7 = _plain_fc7()
    elif loss_type == 1:
        fc7 = _lsoftmax_fc7(embedding)
    elif loss_type == 10:
        fc7 = _plain_fc7()
        normed = mx.symbol.L2Normalization(embedding,
                                           mode='instance',
                                           name='fc1n')
        hinge = _pairwise_sq_dist(normed) - 1.2
        mask = mx.sym.Variable('extra')
        hinge = mx.symbol.Activation(data=hinge * mask + 0.3, act_type='relu')
        batch = args.per_batch_size
        # Average over the N*N - N off-diagonal pairs.
        extra_loss = mx.symbol.MakeLoss(
            mx.sym.sum(hinge) / (batch * batch - batch), grad_scale=1.0)
    elif loss_type == 11:
        fc7 = _plain_fc7()
        anchors = mx.symbol.slice_axis(embedding,
                                       axis=0,
                                       begin=0,
                                       end=args.images_per_identity)
        anchors = mx.symbol.L2Normalization(anchors,
                                            mode='instance',
                                            name='fc1n')
        hinge = mx.symbol.Activation(data=_pairwise_sq_dist(anchors) - 0.9,
                                     act_type='relu')
        n = args.images_per_identity
        # Average over the n*n - n off-diagonal pairs.
        extra_loss = mx.symbol.MakeLoss(mx.sym.sum(hinge) / (n * n - n),
                                        grad_scale=0.2)
    else:
        # Experimental head: scaled embedding and doubled normalised weight.
        embedding = embedding * 5
        fc7 = _lsoftmax_fc7(embedding, weight_scale=2)

    softmax = mx.symbol.SoftmaxOutput(data=fc7,
                                      label=gt_label,
                                      name='softmax',
                                      normalization='valid')
    if loss_type <= 1 and args.incay > 0.0:
        # Feature incay: mean of 1/||x||^2 over correctly classified rows.
        is_correct = mx.symbol.argmax(data=fc7, axis=1) == gt_label
        sq_norm = mx.symbol.sum(embedding * embedding, axis=1) + 1.e-10
        extra_loss = mx.symbol.MakeLoss(
            mx.symbol.mean(is_correct / sq_norm) * args.incay)

    outputs = [mx.symbol.BlockGrad(embedding), softmax]
    if extra_loss is not None:
        outputs.append(extra_loss)
    return (mx.symbol.Group(outputs), new_args, aux_params)
def get_symbol(args, arg_params, aux_params):
  """Build the training symbol: a backbone embedding plus a (margin) softmax head.

  The backbone is picked from the first character of ``args.network``:
  'd' densenet, 'm' mobilenet (v1 when num_layers == 1, else v2),
  'i' inception-resnet-v2, 'x' xception, 'p' dpn, 'n' nasnet, 's' spherenet,
  'c' crunet, 'a' residual attention network, anything else resnet
  (resnet-ibn-a when ``args.ibn`` is truthy).

  The classification head is picked by ``args.loss_type``:
  0 softmax, 1 sphereface (LSoftmax), 2 additive margin on the target cosine,
  4 ArcFace, 5 combined (a, m, b) margin, 6 spa, 61 spa-v1, 63 spa-v3,
  64 spa-v4, 65 spa-v5, 7 combined spa, 8 Adaloss, 9 semi-hard (experimental),
  10 spa + intra-class angle loss, 11 norm-reweighted spa, 12 hard-example margin.

  Args:
    args: parsed training options.  Fields read include network, num_layers,
      emb_size, num_classes, loss_type, margin_s/margin_m/margin_a/margin_b,
      easy_margin, ce_loss, per_batch_size and the backbone version_* flags.
    arg_params: pretrained argument params; returned unchanged.
    aux_params: pretrained auxiliary params; returned unchanged.

  Returns:
    ``(out, arg_params, aux_params)`` where ``out`` is an ``mx.symbol.Group`` of,
    in order: BlockGrad(embedding), the training softmax over ``fc7``,
    BlockGrad of a softmax over the margin-free ``fc_pred`` (accuracy only),
    BlockGrad of the L2-normalized class weight, then ``intra_loss``
    (loss_type 10 only) and BlockGrad(ce_loss) (only when ``args.ce_loss``).

  Raises:
    ValueError: if no classification head was built, i.e. ``args.loss_type``
      is unsupported, or loss_type 9 was used with ``margin_s <= 0``.
  """
  net_type = args.network[0]
  if net_type == 'd':
    embedding = fdensenet.get_symbol(args.emb_size, args.num_layers,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif net_type == 'm':
    print('init mobilenet', args.num_layers)
    if args.num_layers == 1:
      embedding = fmobilenet.get_symbol(args.emb_size,
          version_se=args.version_se, version_input=args.version_input,
          version_output=args.version_output, version_unit=args.version_unit)
    else:
      embedding = fmobilenetv2.get_symbol(args.emb_size)
  elif net_type == 'i':
    print('init inception-resnet-v2', args.num_layers)
    embedding = finception_resnet_v2.get_symbol(args.emb_size,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif net_type == 'x':
    print('init xception', args.num_layers)
    embedding = fxception.get_symbol(args.emb_size,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif net_type == 'p':
    print('init dpn', args.num_layers)
    embedding = fdpn.get_symbol(args.emb_size, args.num_layers,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif net_type == 'n':
    print('init nasnet', args.num_layers)
    embedding = fnasnet.get_symbol(args.emb_size)
  elif net_type == 's':
    print('init spherenet', args.num_layers)
    embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
  elif net_type == 'c':
    print('init crunet', args.num_layers)
    embedding = fcrunet.get_symbol(args.emb_size, args.num_layers)
  elif net_type == 'a':
    print('init residual attention network', args.num_layers)
    embedding = fresattnet.get_symbol(args.emb_size, args.num_layers)
  else:
    print('init resnet', args.num_layers)
    if args.ibn:
      embedding = fresnet_ibn_a.get_symbol(args.emb_size, args.num_layers,
          version_se=args.version_se, version_input=args.version_input,
          version_output=args.version_output, version_unit=args.version_unit,
          version_act=args.version_act, ibn=args.ibn)
    else:
      embedding = fresnet.get_symbol(args.emb_size, args.num_layers,
          version_se=args.version_se, version_input=args.version_input,
          version_output=args.version_output, version_unit=args.version_unit,
          version_act=args.version_act, stn=args.stn, stn1=args.stn1,
          stn2=args.stn2, stn3=args.stn3, stn4=args.stn4)

  gt_label = mx.symbol.Variable('softmax_label')
  # fc7: logits fed to the training softmax (margins applied).
  # fc_pred: margin-free logits, used only for accuracy computation.
  # Both stay None if no branch below builds a head.
  fc7 = None
  fc_pred = None

  if args.loss_type == 0:  # softmax
    _weight = mx.symbol.Variable('fc7_weight')
    _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
    fc7 = mx.sym.FullyConnected(
      data=embedding, weight=_weight,
      bias=_bias, num_hidden=args.num_classes,
      name='fc7')
    fc_pred = fc7
    # The logits above use the raw weight; only the exported (BlockGrad)
    # weight output below is the L2-normalized one.
    _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')

  elif args.loss_type == 1:  # sphereface
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
    fc7 = mx.sym.LSoftmax(
      data=embedding, label=gt_label,
      num_hidden=args.num_classes, weight=_weight,
      beta=args.beta, margin=args.margin,
      scale=args.scale, beta_min=args.beta_min,
      verbose=1000, name='fc7_lsoftmax')
    # Plain logits without the LSoftmax margin, for accuracy only.
    fc_pred = mx.sym.FullyConnected(
      data=embedding, weight=_weight,
      no_bias=True, num_hidden=args.num_classes,
      name='fc7')

  elif args.loss_type == 2:  # additive margin on the target cosine
    s = args.margin_s
    m = args.margin_m
    assert s > 0.0
    assert m > 0.0
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
    fc7 = mx.sym.FullyConnected(
      data=nembedding, weight=_weight,
      no_bias=True, num_hidden=args.num_classes,
      name='fc7')
    fc_pred = fc7
    # s*cos(theta_y) -> s*(cos(theta_y) - m) on the ground-truth column only.
    gt_one_hot = mx.sym.one_hot(
      gt_label, depth=args.num_classes,
      on_value=s * m, off_value=0.0)
    fc7 = fc7 - gt_one_hot

  elif args.loss_type == 4:  # ArcFace
    s = args.margin_s
    m = args.margin_m
    assert s > 0.0
    assert m >= 0.0
    assert m < (math.pi / 2)
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
    fc7 = mx.sym.FullyConnected(
      data=nembedding, weight=_weight, no_bias=True,
      num_hidden=args.num_classes, name='fc7')
    fc_pred = fc7

    zy = mx.sym.pick(fc7, gt_label, axis=1)
    cos_t = zy / s
    cos_m = math.cos(m)
    sin_m = math.sin(m)
    mm = math.sin(math.pi - m) * m
    threshold = math.cos(math.pi - m)
    # cond > 0 exactly where the margin is applied: cos_t > 0 (easy margin)
    # or theta < pi - m, so that theta + m stays within [0, pi].
    if args.easy_margin:
      cond = mx.symbol.Activation(data=cos_t, act_type='relu')
    else:
      cond = mx.symbol.Activation(data=cos_t - threshold, act_type='relu')
    sin_t = mx.sym.sqrt(1.0 - cos_t * cos_t)
    # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
    new_zy = (cos_t * cos_m - sin_t * sin_m) * s
    if args.easy_margin:
      zy_keep = zy
    else:
      zy_keep = zy - s * mm
    new_zy = mx.sym.where(cond, new_zy, zy_keep)
    # Replace only the ground-truth logit: fc7 += one_hot * (new_zy - zy).
    diff = mx.sym.expand_dims(new_zy - zy, 1)
    gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
    fc7 = fc7 + mx.sym.broadcast_mul(gt_one_hot, diff)

  elif args.loss_type == 5:  # combined margin s*(cos(a*theta + m) - b)
    s = args.margin_s
    assert s > 0.0
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
    fc7 = mx.sym.FullyConnected(
      data=nembedding, weight=_weight,
      no_bias=True, num_hidden=args.num_classes,
      name='fc7')
    fc_pred = fc7

    if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
      if args.margin_a == 1.0 and args.margin_m == 0.0:
        # Only an additive cosine margin: no need to go through arccos.
        s_m = s * args.margin_b
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
        fc7 = fc7 - gt_one_hot
      else:
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        t = mx.sym.arccos(cos_t)
        if args.margin_a != 1.0:
          t = t * args.margin_a
        if args.margin_m > 0.0:
          t = t + args.margin_m
        body = mx.sym.cos(t)
        if args.margin_b > 0.0:
          body = body - args.margin_b
        new_zy = body * s
        diff = mx.sym.expand_dims(new_zy - zy, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        fc7 = fc7 + mx.sym.broadcast_mul(gt_one_hot, diff)

  elif args.loss_type == 6:  # spa loss
    s = args.margin_s
    m = args.margin_m
    b = args.margin_b
    assert (s >= 1.0 and
            m >= 1.0 and
            b >= 0.0)
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')

    if s > 0.0:
      nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
      fc7 = mx.sym.FullyConnected(
        data=nembedding, weight=_weight,
        no_bias=True, num_hidden=args.num_classes,
        name='fc7')
      fc_pred = fc7
      if m > 1.0:
        s_m = s * (m - 1 + b)
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
        fc7 = fc7 * m - gt_one_hot
    else:
      # Only reachable with asserts disabled (python -O): the assert above
      # requires s >= 1.0.  Uses the un-normalized embedding and its norm.
      fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, no_bias=True,
                                  num_hidden=args.num_classes, name='fc7')
      fc_pred = fc7
      if m > 1.0:
        body = embedding * embedding
        body = mx.sym.sum_axis(body, axis=1, keepdims=True)
        body = mx.sym.sqrt(body)
        body = body * (m - 1 + b)
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, body)
        fc7 = fc7 * m - body

  elif args.loss_type == 61:  # spa-v1 loss fixed
    s = args.margin_s
    m = args.margin_m
    b = args.margin_b
    assert (s >= 1.0 and
            m >= 1.0 and
            b >= 0.0)
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
    fc7 = mx.sym.FullyConnected(
      data=nembedding, weight=_weight,
      no_bias=True, num_hidden=args.num_classes,
      name='fc7')
    fc_pred = fc7

    if m > 1.0:
      cos_theta = fc7.clip(-s, s) / s  # clip cosine into [-1, 1]
      cos_theta_1 = cos_theta - 1
      gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1, off_value=0.0)
      cos_theta = cos_theta_1 + cos_theta_1 * gt_one_hot * (m - 1)
      if b > 0:
        # Plain rebinding: MXNet Symbols are immutable; in-place operators
        # may not be supported on them.
        cos_theta = cos_theta - gt_one_hot * b
      fc7 = cos_theta * s

  elif args.loss_type == 63:  # spa-v3 loss
    s = args.margin_s
    m = args.margin_m
    assert (s >= 1.0 and
            m >= 0.0)
    _weight = mx.symbol.Variable("fc7_weight", shape=(
        args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
    fc7 = mx.sym.FullyConnected(
        data=nembedding, weight=_weight,
        no_bias=True, num_hidden=args.num_classes,
        name='fc7')
    fc_pred = fc7

    if m > 0:
      cos_theta = fc7.clip(-s, s) / s  # clip cosine into [-1, 1]
      theta = mx.sym.arccos(cos_theta)
      gt_one_hot = mx.sym.one_hot(
          gt_label, depth=args.num_classes, on_value=1, off_value=0.0)
      cos_theta = cos_theta - gt_one_hot * theta * m
      fc7 = cos_theta * s

  elif args.loss_type == 64:  # spa-v4 loss
    s = args.margin_s
    m = args.margin_m
    b = args.margin_b
    if b < 1.0:
      b = 1.0
    assert (s >= 1.0 and
            m >= b and
            b >= 1.0)
    _weight = mx.symbol.Variable("fc7_weight", shape=(
        args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
    fc7 = mx.sym.FullyConnected(
        data=nembedding, weight=_weight,
        no_bias=True, num_hidden=args.num_classes,
        name='fc7')
    fc_pred = fc7

    cos_theta = fc7.clip(-s, s) / s  # clip cosine into [-1, 1]
    # calculate angle theta and normalize into [0, 1.0]
    theta = mx.sym.arccos(cos_theta) * (1.0 / np.pi)
    s *= 2  # make the output have a region span of 2 like cos(x)

    if m > b:
      gt_one_hot = mx.sym.one_hot(
          gt_label, depth=args.num_classes, on_value=1, off_value=0.0)
      fc7 = (theta * b + gt_one_hot * theta * (m - b)) * (-s)
    else:
      fc7 = theta * (-s)

  elif args.loss_type == 65:  # spa-v5 loss
    s = args.margin_s
    m = args.margin_m
    assert (s >= 1.0 and
            m >= 0 and
            m < 180)
    _weight = mx.symbol.Variable("fc7_weight", shape=(
        args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')

    m = m / 180.0  # margin in degrees -> fraction of pi

    # The embedding is NOT scaled by s here, so fc7 is already cos(theta).
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
    fc7 = mx.sym.FullyConnected(
        data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')
    fc_pred = fc7

    # Clip directly into [-1, 1]; dividing by s (as in loss_type 64, where the
    # embedding is scaled by s) would shrink the cosine to [-1/s, 1/s].
    cos_theta = fc7.clip(-1.0, 1.0)
    # calculate angle theta and normalize into [0, 1.0]
    theta = mx.sym.arccos(cos_theta) * (1.0 / np.pi)
    s *= 2  # make the output have a region span of 2 like cos(x)

    if m > 0:
      gt_one_hot = mx.sym.one_hot(
          gt_label, depth=args.num_classes, on_value=1, off_value=0.0)
      fc7 = (theta + gt_one_hot * m) * (-s)
    else:
      fc7 = theta * (-s)

  elif args.loss_type == 7:  # combine spa loss
    s = args.margin_s
    m = args.margin_m
    b = args.margin_b
    a = args.margin_a
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
    assert m >= 1.0
    assert b >= 0.0
    assert s > 0.0
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
    fc7 = mx.sym.FullyConnected(
      data=nembedding, weight=_weight,
      no_bias=True, num_hidden=args.num_classes,
      name='fc7')
    fc_pred = fc7

    # Ground-truth logit becomes s*cos(a*theta) ...
    zy = mx.sym.pick(fc7, gt_label, axis=1)
    cos_t = zy / s
    t = mx.sym.arccos(cos_t) * a
    new_zy = mx.sym.cos(t) * s
    diff = mx.sym.expand_dims(new_zy - zy, 1)
    gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
    body = mx.sym.broadcast_mul(gt_one_hot, diff)
    # ... then all logits are scaled by m and the target is shifted by s*(m-1).
    s_m = s * (m - 1)
    gt_one_hot2 = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
    fc7 = m * (fc7 + body) - gt_one_hot2

  elif args.loss_type == 8:  # Adaloss
    s = args.margin_s
    m = args.margin_m
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
    fc7 = mx.sym.FullyConnected(
      data=nembedding, weight=_weight,
      no_bias=True, num_hidden=args.num_classes,
      name='fc7')
    fc_pred = fc7

    zy = mx.sym.pick(fc7, gt_label, axis=1)
    cos_t = zy / s
    t = mx.sym.arccos(cos_t)
    body = ((1 - mx.sym.sin(1.0 * t / (2 * m))) * mx.sym.cos(1.0 * t / m) * 2) - 1
    new_zy = body * s
    diff = mx.sym.expand_dims(new_zy - zy, 1)
    gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
    fc7 = fc7 + mx.sym.broadcast_mul(gt_one_hot, diff)

  elif args.loss_type == 9:  # semi hard loss (experimental; marked faulty upstream)
    s = args.margin_s
    m = args.margin_m
    b = args.margin_b
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
    assert m >= 1.0
    assert b >= 0.0
    # With s <= 0 no head is built; that is reported by the ValueError below.
    if s > 0.0:
      nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
      fc7 = mx.sym.FullyConnected(
        data=nembedding, weight=_weight,
        no_bias=True, num_hidden=args.num_classes,
        name='fc7')
      fc_pred = fc7
      s_m = s * (m - 1 + b)
      gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
      fc7 = fc7 * m - gt_one_hot

  elif args.loss_type == 10:  # combine intra loss
    s = args.margin_s
    m = args.margin_m
    assert s > 0.0
    assert args.margin_b > 0.0
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
    fc7 = mx.sym.FullyConnected(
      data=nembedding, weight=_weight,
      no_bias=True, num_hidden=args.num_classes,
      name='fc7')
    fc_pred = fc7

    # Auxiliary loss: mean target angle normalized by pi, weighted by margin_b.
    zy = mx.sym.pick(fc7, gt_label, axis=1)
    cos_t = zy / s
    t = mx.sym.arccos(cos_t)
    intra_loss = mx.sym.mean(t / np.pi)
    intra_loss = mx.sym.MakeLoss(intra_loss, name='intra_loss', grad_scale=args.margin_b)

    s_m = s * (m - 1)
    gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
    fc7 = fc7 * m - gt_one_hot

  elif args.loss_type == 11:  # reweight
    s = args.margin_s
    m = args.margin_m
    b = args.margin_b
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
    assert m >= 1.0
    assert b >= 0.0

    # Per-sample scale s * mean(||x||) / ||x||: small-norm samples get a larger scale.
    spatial_norm = embedding * embedding
    spatial_norm = mx.sym.sum(data=spatial_norm, axis=1, keepdims=True)
    spatial_sqrt = mx.sym.sqrt(spatial_norm)
    spatial_mean = mx.sym.mean(spatial_sqrt)
    spatial_div_inverse = mx.sym.broadcast_div(spatial_mean, spatial_sqrt)
    reweight_s = s * spatial_div_inverse

    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')

    # For accuracy only: raw embedding against normalized weights.
    fc_pred = mx.sym.FullyConnected(
      data=embedding, weight=_weight,
      no_bias=True, num_hidden=args.num_classes,
      name='fc_pred') * s

    nembedding = mx.symbol.broadcast_mul(nembedding, reweight_s)
    fc7 = mx.sym.FullyConnected(
      data=nembedding, weight=_weight,
      no_bias=True, num_hidden=args.num_classes,
      name='fc7')

    s_m = s * (m - 1 + b)
    gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
    fc7 = fc7 * m - gt_one_hot

  elif args.loss_type == 12:  # hard example margin
    s = args.margin_s
    m = args.margin_m
    b = args.margin_b
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance', name='fc7_weight_n')
    assert m >= 1.0
    assert b >= 0.0

    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
    fc7 = mx.sym.FullyConnected(
      data=nembedding, weight=_weight,
      no_bias=True, num_hidden=args.num_classes,
      name='fc7')
    fc_pred = fc7

    predict_label = mx.sym.argmax(fc7, axis=1)
    # 1.0 for misclassified samples, 0.0 otherwise.  (A `<` comparison would
    # only catch misclassifications into lower-indexed classes.)
    wrong_label_mask = predict_label != gt_label
    s_m = s * (m - 1 + b)
    wrong_label_one_hot = mx.sym.one_hot(predict_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
    # Zero whole rows for correctly classified samples.  Masking the label
    # index instead would redirect those rows to class 0 and penalize it.
    wrong_label_one_hot = mx.sym.broadcast_mul(
      wrong_label_one_hot, mx.sym.expand_dims(wrong_label_mask, 1))
    fc7 = fc7 * m - wrong_label_one_hot

  if fc7 is None:
    raise ValueError(
      'no classification head built for loss_type=%r (unsupported loss_type, '
      'or loss_type 9 with margin_s <= 0)' % (args.loss_type,))

  out_list = [mx.symbol.BlockGrad(embedding)]
  softmax = mx.symbol.SoftmaxOutput(
    data=fc7, label=gt_label,
    name='softmax', normalization='valid')
  out_list.append(softmax)

  # Margin-free softmax, exposed (without gradient) for accuracy computation.
  pred_softmax = mx.symbol.SoftmaxOutput(
    data=fc_pred, label=gt_label,
    name='pred_softmax', normalization='valid')
  out_list.append(mx.symbol.BlockGrad(pred_softmax))

  out_list.append(mx.symbol.BlockGrad(_weight))

  if args.loss_type == 10:
    out_list.append(intra_loss)

  if args.ce_loss:
    # Monitoring only: sum of -log p(ground truth) over the batch / per_batch_size.
    body = mx.symbol.SoftmaxActivation(data=fc7)
    body = mx.symbol.log(body)
    _label = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=-1.0, off_value=0.0)
    body = body * _label
    ce_loss = mx.symbol.sum(body) / args.per_batch_size
    out_list.append(mx.symbol.BlockGrad(ce_loss))

  out = mx.symbol.Group(out_list)
  return (out, arg_params, aux_params)