コード例 #1
0
ファイル: train.py プロジェクト: Soncaajp/SF_version_2
def run_Affectnet_training():
    """Configure and launch AffectNet training for the MobileFaceNet symbol.

    Builds the training and validation ``AffectnetIter`` iterators from
    ``../training.csv`` and ``../validation.csv``, binds an ``mx.mod.Module``
    on GPU 0 with MSRAPrelu-initialised parameters, and passes the module,
    both iterators and the settings dictionary to ``train_Affectnet``.
    """
    # Optional keys currently disabled: 'load_path', 'multiply_basic_ratio'.
    settings = dict(
        batch_size=64,
        val_batch_size=40,
        img_size=(112, 112),  # alternative that was tried: (128, 128)
        metric_update_period=50,
        layers=50,
        load_epoch=0,
        save_model_prefix='/media/nlab/data/SF/mbnet-singleframe',
        emotions_list=[
            'Neutral', 'Happy', 'Sad', 'Surprise',
            'Fear', 'Anger', 'Disgust', 'Contempt',
        ],
    )

    def build_iter(csv_path, batch, is_train):
        # Both iterators share the image size and run without a detector.
        return AffectnetIter(data_json_path=csv_path,
                             batch_size=batch,
                             train=is_train,
                             img_size=settings['img_size'],
                             detector=None)

    train_iter = build_iter('../training.csv', settings['batch_size'], True)

    # Seed the iterator's instance counter with the number of samples contained
    # in `load_epoch` epochs of whole batches (0 while load_epoch is 0).
    whole_batches = int(train_iter.n_objects / settings['batch_size'])
    train_iter.global_num_inst = (
        whole_batches * settings['batch_size'] * settings['load_epoch'])

    network = fmobilefacenet.get_symbol()
    module = mx.mod.Module(network, context=mx.gpu(0))
    module.bind(data_shapes=train_iter.provide_data,
                label_shapes=train_iter.provide_label)
    module.init_params(arg_params=None,
                       aux_params=None,
                       initializer=mx.init.MSRAPrelu(),
                       allow_missing=True)

    val_iter = build_iter('../validation.csv', settings['val_batch_size'],
                          False)
    train_Affectnet(module, train_iter, val_iter, settings)
コード例 #2
0
def get_symbol(args, arg_params, aux_params):
    """Build the training symbol: embedding backbone, margin head and softmax.

    The backbone is chosen by the first character of ``args.network``
    ('d', 'm', 'i', 'x', 'p', 'n', 's', 'y'; anything else selects the resnet
    builder).  The classification head is chosen by ``args.loss_type``:

    * 0 -- plain fully connected layer (bias unless ``args.fc7_no_bias``).
    * 1 -- ``mx.sym.LSoftmax`` on an L2-normalised weight.
    * 2 -- scaled cosine logits with ``s * m`` subtracted from the target.
    * 4 -- scaled cosine logits whose target becomes ``s * cos(theta + m)``.
    * 5 -- scaled cosine logits whose target becomes
      ``s * (cos(a * theta + m) - b)``.

    Args:
        args: Namespace-like object carrying the network and loss settings
            read below (``network``, ``emb_size``, ``num_classes``,
            ``loss_type``, ``margin_*`` and so on).
        arg_params: Returned unchanged.
        aux_params: Returned unchanged.

    Returns:
        Tuple ``(out, arg_params, aux_params)`` where ``out`` groups the
        gradient-blocked embedding and the ``SoftmaxOutput`` symbol.

    Raises:
        ValueError: If ``args.loss_type`` is not one of the supported values.
            Previously this surfaced as an ``UnboundLocalError`` on ``fc7``
            only after the whole backbone had been built.
        AssertionError: If the margin parameters are outside the range the
            selected loss type requires.
    """
    supported_loss_types = (0, 1, 2, 4, 5)
    if args.loss_type not in supported_loss_types:
        raise ValueError(
            'unsupported loss_type {!r}; expected one of {}'.format(
                args.loss_type, supported_loss_types))

    # ------------------------------------------------------------------
    # Backbone producing the embedding.
    # ------------------------------------------------------------------
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size,
                                         args.num_layers,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(
                args.emb_size,
                version_se=args.version_se,
                version_input=args.version_input,
                version_output=args.version_output,
                version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(
            args.emb_size,
            version_se=args.version_se,
            version_input=args.version_input,
            version_output=args.version_output,
            version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size,
                                    args.num_layers,
                                    version_se=args.version_se,
                                    version_input=args.version_input,
                                    version_output=args.version_output,
                                    version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif args.network[0] == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(
            args.emb_size,
            bn_mom=args.bn_mom,
            version_output=args.version_output)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size,
                                       args.num_layers,
                                       version_se=args.version_se,
                                       version_input=args.version_input,
                                       version_output=args.version_output,
                                       version_unit=args.version_unit,
                                       version_act=args.version_act)

    # ------------------------------------------------------------------
    # Classification head.
    # ------------------------------------------------------------------
    gt_label = mx.symbol.Variable('softmax_label')
    _weight = mx.symbol.Variable("fc7_weight",
                                 shape=(args.num_classes, args.emb_size),
                                 lr_mult=args.fc7_lr_mult,
                                 wd_mult=args.fc7_wd_mult)

    def scaled_cosine_fc(norm_weight, scale):
        # Bias-free FC between the L2-normalised embedding (times `scale`)
        # and an already normalised weight, i.e. scale * cos(theta) logits.
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * scale
        return mx.sym.FullyConnected(data=nembedding,
                                     weight=norm_weight,
                                     no_bias=True,
                                     num_hidden=args.num_classes,
                                     name='fc7')

    def swap_target_logit(logits, old_target, new_target):
        # Add (new - old) only at the ground-truth column of every row.
        diff = new_target - old_target
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=1.0,
                                    off_value=0.0)
        return logits + mx.sym.broadcast_mul(gt_one_hot, diff)

    if args.loss_type == 0:  # softmax
        if args.fc7_no_bias:
            fc7 = mx.sym.FullyConnected(data=embedding,
                                        weight=_weight,
                                        no_bias=True,
                                        num_hidden=args.num_classes,
                                        name='fc7')
        else:
            _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
            fc7 = mx.sym.FullyConnected(data=embedding,
                                        weight=_weight,
                                        bias=_bias,
                                        num_hidden=args.num_classes,
                                        name='fc7')
    elif args.loss_type == 1:  # sphere
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding,
                              label=gt_label,
                              num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta,
                              margin=args.margin,
                              scale=args.scale,
                              beta_min=args.beta_min,
                              verbose=1000,
                              name='fc7')
    elif args.loss_type == 2:
        s = args.margin_s
        m = args.margin_m
        assert (s > 0.0)
        assert (m > 0.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = scaled_cosine_fc(_weight, s)
        # Subtract the fixed margin s*m from the ground-truth logit only.
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=s * m,
                                    off_value=0.0)
        fc7 = fc7 - gt_one_hot
    elif args.loss_type == 4:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)

        # This branch re-declares the classifier weight with its own lr_mult,
        # so fc7_lr_mult / fc7_wd_mult from the declaration above do not
        # apply here.
        _weight = mx.symbol.Variable("fc7_weight",
                                     shape=(args.num_classes, args.emb_size),
                                     lr_mult=1.0)
        if args.finetune:
            print("{}finetuning from trained model".format('-' * 10))
            _weight = mx.symbol.Variable("finetune_weight",
                                         shape=(args.num_classes,
                                                args.emb_size),
                                         lr_mult=10.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = scaled_cosine_fc(_weight, s)

        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        mm = math.sin(math.pi - m) * m
        threshold = math.cos(math.pi - m)
        # `cond` is non-zero where the angular margin is applied; elsewhere
        # the fallback value `zy_keep` is used instead.
        if args.easy_margin:
            cond = mx.symbol.Activation(data=cos_t, act_type='relu')
            zy_keep = zy
        else:
            cond_v = cos_t - threshold
            cond = mx.symbol.Activation(data=cond_v, act_type='relu')
            zy_keep = zy - s * mm
        # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        sin_t = mx.sym.sqrt(1.0 - cos_t * cos_t)
        new_zy = (cos_t * cos_m - sin_t * sin_m) * s
        new_zy = mx.sym.where(cond, new_zy, zy_keep)
        fc7 = swap_target_logit(fc7, zy, new_zy)
    else:  # args.loss_type == 5 (validated at the top)
        s = args.margin_s
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = scaled_cosine_fc(_weight, s)
        # With a == 1, m == 0 and b == 0 the logits are left untouched.
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                # Only the additive cosine margin b is active.
                s_m = s * args.margin_b
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=s_m,
                                            off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if args.margin_a != 1.0:
                    t = t * args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body * s
                fc7 = swap_target_logit(fc7, zy, new_zy)

    # ------------------------------------------------------------------
    # Outputs: gradient-blocked embedding first, then the softmax loss.
    # ------------------------------------------------------------------
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = mx.symbol.SoftmaxOutput(data=fc7,
                                      label=gt_label,
                                      name='softmax',
                                      normalization='valid')
    out_list.append(softmax)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
コード例 #3
0
        model.forward(db, is_train=False)
        embedding = model.get_outputs()[0].asnumpy()
        #embedding = sklearn.preprocessing.normalize(embedding).flatten()
        end_time = time.time()
        embedding_time += end_time - start_time
        #print('cost of generate features:' + str(end_time - start_time))

    return read_img_time / loop_time, crop_time / loop_time, embedding_time / loop_time


# Timing results collected per model variant; each dict is keyed by the
# variant name and filled from the three values returned by cal_time_cost.
ave_image_read_dict = {}
ave_crop_dict = {}
ave_embedding_dict = {}

# Original (unmodified) model.
embedding = fmobilefacenet.get_symbol(128, bn_mom=0.9, version_output='GNAP')
detector = MtcnnDetector(model_folder=mtcnn_path,
                         ctx=mx.cpu(0),
                         num_worker=1,
                         accurate_landmark=True,
                         threshold=[0.6, 0.7, 0.8])
# NOTE(review): the third argument (50) is presumably the number of timing
# loops used for averaging -- confirm against cal_time_cost's signature.
ave_read_image_time, ave_crop_time, ave_embedding_time = cal_time_cost(
    embedding, detector, 50)
print(ave_read_image_time, ave_crop_time, ave_embedding_time)
# The key is spelled 'orignal' (sic); it is a runtime dictionary key, so any
# reader of these dicts must use the same spelling.
ave_image_read_dict['orignal'] = ave_read_image_time
ave_crop_dict['orignal'] = ave_crop_time
ave_embedding_dict['orignal'] = ave_embedding_time

# Variant with the "45,5" layers removed (translated from the original
# comment; which layers that refers to is defined by get_symbol1 -- verify).
embedding = fmobilefacenet.get_symbol1(128, bn_mom=0.9, version_output='GNAP')
detector = MtcnnDetector(model_folder=mtcnn_path,
コード例 #4
0
ファイル: train_softmax.py プロジェクト: xiyou1024/CompreFace
def get_symbol(args, arg_params, aux_params):
    """Build the training symbol: embedding backbone, margin head and losses.

    The backbone is chosen by the first character of ``args.network``
    ('d', 'm', 'i', 'x', 'p', 'n', 's', 'y'; anything else selects the resnet
    builder).  The classification head is chosen by ``args.loss_type``:

    * 0 -- plain fully connected layer (bias unless ``args.fc7_no_bias``).
    * 1 -- ``mx.sym.LSoftmax`` on an L2-normalised weight.
    * 2 -- scaled cosine logits with ``s * m`` subtracted from the target.
    * 4 -- scaled cosine logits whose target becomes ``s * cos(theta + m)``.
    * 5 -- scaled cosine logits whose target becomes
      ``s * (cos(a * theta + m) - b)``.
    * 6 -- like the ``theta + m`` margin, plus an auxiliary ``intra_loss``
      (mean of ``theta / pi`` for the target class).
    * 7 -- like the ``theta + m`` margin, plus an auxiliary ``inter_loss``
      (mean dot product between each sample's target-class weight and all
      normalised class weights).

    Args:
        args: Namespace-like object carrying the network and loss settings
            read below (``network``, ``emb_size``, ``num_classes``,
            ``loss_type``, ``margin_*``, ``ce_loss`` and so on).
        arg_params: Returned unchanged.
        aux_params: Returned unchanged.

    Returns:
        Tuple ``(out, arg_params, aux_params)``.  ``out`` groups, in order:
        the gradient-blocked embedding, the ``SoftmaxOutput`` symbol, the
        auxiliary loss for loss types 6/7, and (when ``args.ce_loss`` is
        truthy) a gradient-blocked cross-entropy value.

    Raises:
        ValueError: If ``args.loss_type`` is not one of the supported values.
            Previously this surfaced as an ``UnboundLocalError`` on ``fc7``
            only after the whole backbone had been built.
        AssertionError: If the margin parameters are outside the range the
            selected loss type requires.
    """
    supported_loss_types = (0, 1, 2, 4, 5, 6, 7)
    if args.loss_type not in supported_loss_types:
        raise ValueError(
            'unsupported loss_type {!r}; expected one of {}'.format(
                args.loss_type, supported_loss_types))

    # ------------------------------------------------------------------
    # Backbone producing the embedding.
    # ------------------------------------------------------------------
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size,
                                         args.num_layers,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(
                args.emb_size,
                version_input=args.version_input,
                version_output=args.version_output,
                version_multiplier=args.version_multiplier)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(
            args.emb_size,
            version_se=args.version_se,
            version_input=args.version_input,
            version_output=args.version_output,
            version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size,
                                    args.num_layers,
                                    version_se=args.version_se,
                                    version_input=args.version_input,
                                    version_output=args.version_output,
                                    version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif args.network[0] == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(
            args.emb_size,
            bn_mom=args.bn_mom,
            version_output=args.version_output)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size,
                                       args.num_layers,
                                       version_se=args.version_se,
                                       version_input=args.version_input,
                                       version_output=args.version_output,
                                       version_unit=args.version_unit,
                                       version_act=args.version_act)

    # ------------------------------------------------------------------
    # Classification head.
    # ------------------------------------------------------------------
    gt_label = mx.symbol.Variable('softmax_label')
    # Auxiliary loss symbol; only loss types 6 and 7 set it.
    extra_loss = None
    _weight = mx.symbol.Variable("fc7_weight",
                                 shape=(args.num_classes, args.emb_size),
                                 lr_mult=args.fc7_lr_mult,
                                 wd_mult=args.fc7_wd_mult)

    def scaled_cosine_fc(norm_weight, scale):
        # Bias-free FC between the L2-normalised embedding (times `scale`)
        # and an already normalised weight, i.e. scale * cos(theta) logits.
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * scale
        return mx.sym.FullyConnected(data=nembedding,
                                     weight=norm_weight,
                                     no_bias=True,
                                     num_hidden=args.num_classes,
                                     name='fc7')

    def swap_target_logit(logits, old_target, new_target):
        # Add (new - old) only at the ground-truth column of every row.
        diff = new_target - old_target
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=1.0,
                                    off_value=0.0)
        return logits + mx.sym.broadcast_mul(gt_one_hot, diff)

    if args.loss_type == 0:  # softmax
        if args.fc7_no_bias:
            fc7 = mx.sym.FullyConnected(data=embedding,
                                        weight=_weight,
                                        no_bias=True,
                                        num_hidden=args.num_classes,
                                        name='fc7')
        else:
            _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
            fc7 = mx.sym.FullyConnected(data=embedding,
                                        weight=_weight,
                                        bias=_bias,
                                        num_hidden=args.num_classes,
                                        name='fc7')
    elif args.loss_type == 1:  # sphere
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding,
                              label=gt_label,
                              num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta,
                              margin=args.margin,
                              scale=args.scale,
                              beta_min=args.beta_min,
                              verbose=1000,
                              name='fc7')
    elif args.loss_type == 2:
        s = args.margin_s
        m = args.margin_m
        assert (s > 0.0)
        assert (m > 0.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = scaled_cosine_fc(_weight, s)
        # Subtract the fixed margin s*m from the ground-truth logit only.
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=s * m,
                                    off_value=0.0)
        fc7 = fc7 - gt_one_hot
    elif args.loss_type == 4:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = scaled_cosine_fc(_weight, s)

        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        mm = math.sin(math.pi - m) * m
        threshold = math.cos(math.pi - m)
        # `cond` is non-zero where the angular margin is applied; elsewhere
        # the fallback value `zy_keep` is used instead.
        if args.easy_margin:
            cond = mx.symbol.Activation(data=cos_t, act_type='relu')
            zy_keep = zy
        else:
            cond_v = cos_t - threshold
            cond = mx.symbol.Activation(data=cond_v, act_type='relu')
            zy_keep = zy - s * mm
        # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        sin_t = mx.sym.sqrt(1.0 - cos_t * cos_t)
        new_zy = (cos_t * cos_m - sin_t * sin_m) * s
        new_zy = mx.sym.where(cond, new_zy, zy_keep)
        fc7 = swap_target_logit(fc7, zy, new_zy)
    elif args.loss_type == 5:
        s = args.margin_s
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = scaled_cosine_fc(_weight, s)
        # With a == 1, m == 0 and b == 0 the logits are left untouched.
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                # Only the additive cosine margin b is active.
                s_m = s * args.margin_b
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=s_m,
                                            off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if args.margin_a != 1.0:
                    t = t * args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body * s
                fc7 = swap_target_logit(fc7, zy, new_zy)
    elif args.loss_type == 6:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert args.margin_b > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = scaled_cosine_fc(_weight, s)
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        t = mx.sym.arccos(cos_t)
        # Auxiliary loss: mean target angle divided by pi, with its gradient
        # scaled by margin_b.
        intra_loss = t / np.pi
        intra_loss = mx.sym.mean(intra_loss)
        extra_loss = mx.sym.MakeLoss(intra_loss,
                                     name='intra_loss',
                                     grad_scale=args.margin_b)
        if m > 0.0:
            t = t + m
            new_zy = mx.sym.cos(t) * s
            fc7 = swap_target_logit(fc7, zy, new_zy)
    else:  # args.loss_type == 7 (validated at the top)
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert args.margin_b > 0.0
        assert args.margin_a > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = scaled_cosine_fc(_weight, s)
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        t = mx.sym.arccos(cos_t)

        # Auxiliary loss: mean dot product between each sample's target-class
        # weight row and every normalised class weight, with its gradient
        # scaled by margin_b.
        counter_weight = mx.sym.take(_weight, gt_label, axis=0)
        counter_cos = mx.sym.dot(counter_weight, _weight, transpose_b=True)
        inter_loss = mx.sym.mean(counter_cos)
        extra_loss = mx.sym.MakeLoss(inter_loss,
                                     name='inter_loss',
                                     grad_scale=args.margin_b)
        if m > 0.0:
            t = t + m
            new_zy = mx.sym.cos(t) * s
            fc7 = swap_target_logit(fc7, zy, new_zy)

    # ------------------------------------------------------------------
    # Outputs: embedding, softmax, optional auxiliary loss, optional CE.
    # ------------------------------------------------------------------
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = mx.symbol.SoftmaxOutput(data=fc7,
                                      label=gt_label,
                                      name='softmax',
                                      normalization='valid')
    out_list.append(softmax)
    if extra_loss is not None:
        out_list.append(extra_loss)
    if args.ce_loss:
        # Cross-entropy value for monitoring only: -sum(log p[target]) divided
        # by the per-device batch size, with gradients blocked.
        body = mx.symbol.SoftmaxActivation(data=fc7)
        body = mx.symbol.log(body)
        _label = mx.sym.one_hot(gt_label,
                                depth=args.num_classes,
                                on_value=-1.0,
                                off_value=0.0)
        body = body * _label
        ce_loss = mx.symbol.sum(body) / args.per_batch_size
        out_list.append(mx.symbol.BlockGrad(ce_loss))
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
コード例 #5
0
ファイル: train_softmax.py プロジェクト: LHQ0308/insightface
def get_symbol(args, arg_params, aux_params):
    """Build the training symbol: embedding backbone, margin head and softmax.

    The backbone is chosen by the first character of ``args.network``
    ('d', 'm', 'i', 'x', 'p', 'n', 's', 'y'; anything else selects the resnet
    builder).  The classification head is chosen by ``args.loss_type``:

    * 0 -- plain fully connected layer with bias.
    * 1 -- ``mx.sym.LSoftmax`` on an L2-normalised weight.
    * 2 -- scaled cosine logits with ``s * m`` subtracted from the target.
    * 4 -- scaled cosine logits whose target becomes ``s * cos(theta + m)``.
    * 5 -- scaled cosine logits whose target becomes
      ``s * (cos(a * theta + m) - b)``.

    Args:
        args: Namespace-like object carrying the network and loss settings
            read below (``network``, ``emb_size``, ``num_classes``,
            ``loss_type``, ``margin_*`` and so on).
        arg_params: Returned unchanged.
        aux_params: Returned unchanged.

    Returns:
        Tuple ``(out, arg_params, aux_params)`` where ``out`` groups the
        gradient-blocked embedding and the ``SoftmaxOutput`` symbol.

    Raises:
        ValueError: If ``args.loss_type`` is not one of the supported values.
            Previously this surfaced as an ``UnboundLocalError`` on ``fc7``
            only after the whole backbone had been built.
        AssertionError: If the margin parameters are outside the range the
            selected loss type requires.
    """
    supported_loss_types = (0, 1, 2, 4, 5)
    if args.loss_type not in supported_loss_types:
        raise ValueError(
            'unsupported loss_type {!r}; expected one of {}'.format(
                args.loss_type, supported_loss_types))

    # ------------------------------------------------------------------
    # Backbone producing the embedding.
    # ------------------------------------------------------------------
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(
            args.emb_size, args.num_layers,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(
                args.emb_size,
                version_se=args.version_se, version_input=args.version_input,
                version_output=args.version_output,
                version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(
            args.emb_size,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(
            args.emb_size,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(
            args.emb_size, args.num_layers,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif args.network[0] == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(
            args.emb_size, bn_mom=args.bn_mom, wd_mult=args.fc7_wd_mult)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(
            args.emb_size, args.num_layers,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit,
            version_act=args.version_act)

    # ------------------------------------------------------------------
    # Classification head.
    # ------------------------------------------------------------------
    gt_label = mx.symbol.Variable('softmax_label')
    _weight = mx.symbol.Variable("fc7_weight",
                                 shape=(args.num_classes, args.emb_size),
                                 lr_mult=1.0,
                                 wd_mult=args.fc7_wd_mult)

    def scaled_cosine_fc(norm_weight, scale):
        # Bias-free FC between the L2-normalised embedding (times `scale`)
        # and an already normalised weight, i.e. scale * cos(theta) logits.
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * scale
        return mx.sym.FullyConnected(data=nembedding,
                                     weight=norm_weight,
                                     no_bias=True,
                                     num_hidden=args.num_classes,
                                     name='fc7')

    def swap_target_logit(logits, old_target, new_target):
        # Add (new - old) only at the ground-truth column of every row.
        diff = new_target - old_target
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=1.0,
                                    off_value=0.0)
        return logits + mx.sym.broadcast_mul(gt_one_hot, diff)

    if args.loss_type == 0:  # softmax
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding,
                                    weight=_weight,
                                    bias=_bias,
                                    num_hidden=args.num_classes,
                                    name='fc7')
    elif args.loss_type == 1:  # sphere
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding,
                              label=gt_label,
                              num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta,
                              margin=args.margin,
                              scale=args.scale,
                              beta_min=args.beta_min,
                              verbose=1000,
                              name='fc7')
    elif args.loss_type == 2:
        s = args.margin_s
        m = args.margin_m
        assert (s > 0.0)
        assert (m > 0.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = scaled_cosine_fc(_weight, s)
        # Subtract the fixed margin s*m from the ground-truth logit only.
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=s * m,
                                    off_value=0.0)
        fc7 = fc7 - gt_one_hot
    elif args.loss_type == 4:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = scaled_cosine_fc(_weight, s)

        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        mm = math.sin(math.pi - m) * m
        threshold = math.cos(math.pi - m)
        # `cond` is non-zero where the angular margin is applied; elsewhere
        # the fallback value `zy_keep` is used instead.
        if args.easy_margin:
            cond = mx.symbol.Activation(data=cos_t, act_type='relu')
            zy_keep = zy
        else:
            cond_v = cos_t - threshold
            cond = mx.symbol.Activation(data=cond_v, act_type='relu')
            zy_keep = zy - s * mm
        # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        sin_t = mx.sym.sqrt(1.0 - cos_t * cos_t)
        new_zy = (cos_t * cos_m - sin_t * sin_m) * s
        new_zy = mx.sym.where(cond, new_zy, zy_keep)
        fc7 = swap_target_logit(fc7, zy, new_zy)
    else:  # args.loss_type == 5 (validated at the top)
        s = args.margin_s
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = scaled_cosine_fc(_weight, s)
        # With a == 1, m == 0 and b == 0 the logits are left untouched.
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                # Only the additive cosine margin b is active.
                s_m = s * args.margin_b
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=s_m,
                                            off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if args.margin_a != 1.0:
                    t = t * args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body * s
                fc7 = swap_target_logit(fc7, zy, new_zy)

    # ------------------------------------------------------------------
    # Outputs: gradient-blocked embedding first, then the softmax loss.
    # ------------------------------------------------------------------
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = mx.symbol.SoftmaxOutput(data=fc7,
                                      label=gt_label,
                                      name='softmax',
                                      normalization='valid')
    out_list.append(softmax)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
コード例 #6
0
ファイル: train_age.py プロジェクト: zmoon111/insightface
def get_symbol(args, arg_params, aux_params):
  """Build the multi-task training symbol (identity / gender / age heads).

  A backbone selected by the first character of ``args.network`` produces
  the embedding.  For every task enabled through the module-level switches
  ``USE_FR``, ``USE_GENDER`` and ``USE_AGE`` one head is created with
  ``get_softmax`` and appended to the output group.

  Args:
    args: parsed options.  Reads ``network``, ``num_layers``, ``emb_size``,
      the ``version_*`` switches, ``margin_s`` and ``per_batch_size``.
      ``args`` itself is not modified; per-task overrides are applied to a
      deep copy.
    arg_params: passed through unchanged.
    aux_params: passed through unchanged.

  Returns:
    Tuple ``(out, arg_params, aux_params)`` where ``out`` is an
    ``mx.symbol.Group`` whose first entry is the gradient-blocked embedding,
    followed by the enabled task heads in the order FR, gender, age.
  """
  if args.network[0]=='d':
    embedding = fdensenet.get_symbol(args.emb_size, args.num_layers,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif args.network[0]=='m':
    print('init mobilenet', args.num_layers)
    if args.num_layers==1:
      embedding = fmobilenet.get_symbol(args.emb_size,
          version_se=args.version_se, version_input=args.version_input,
          version_output=args.version_output, version_unit=args.version_unit)
    else:
      embedding = fmobilenetv2.get_symbol(args.emb_size)
  elif args.network[0]=='i':
    print('init inception-resnet-v2', args.num_layers)
    embedding = finception_resnet_v2.get_symbol(args.emb_size,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif args.network[0]=='x':
    print('init xception', args.num_layers)
    embedding = fxception.get_symbol(args.emb_size,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif args.network[0]=='p':
    print('init dpn', args.num_layers)
    embedding = fdpn.get_symbol(args.emb_size, args.num_layers,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit)
  elif args.network[0]=='n':
    print('init nasnet', args.num_layers)
    embedding = fnasnet.get_symbol(args.emb_size)
  elif args.network[0]=='s':
    print('init spherenet', args.num_layers)
    embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
  elif args.network[0]=='y':
    print('init mobilefacenet', args.num_layers)
    embedding = fmobilefacenet.get_symbol(args.emb_size)
  else:
    print('init resnet', args.num_layers)
    embedding = fresnet.get_symbol(args.emb_size, args.num_layers,
        version_se=args.version_se, version_input=args.version_input,
        version_output=args.version_output, version_unit=args.version_unit,
        version_act=args.version_act)

  all_label = mx.symbol.Variable('softmax_label')
  s = args.margin_s
  assert s>0.0
  # L2-normalised embedding rescaled by s; shared by every task head.
  nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
  out_list = [mx.symbol.BlockGrad(embedding)]

  # Per-task overrides go on a private copy so the caller's args stay intact.
  # The overrides accumulate: a later task sees what an earlier one set.
  _args = copy.deepcopy(args)

  def _label_column(index):
    # Take column `index` of the label matrix and flatten it to one value
    # per sample of the per-device batch.
    column = mx.symbol.slice_axis(all_label, axis=1, begin=index, end=index+1)
    return mx.symbol.reshape(column, (args.per_batch_size,))

  if USE_FR:
    # Identity head: label column 0.
    _args.grad_scale = 1.0
    fr_softmax = get_softmax(_args, embedding, nembedding, _label_column(0), 'fc7')
    out_list.append(fr_softmax)

  if USE_GENDER:
    # Gender head: label column 1, two classes.
    _args.grad_scale = 0.2
    _args.margin_a = 0.0
    _args.num_classes = 2
    gender_softmax = get_softmax(_args, embedding, nembedding, _label_column(1), 'fc8')
    out_list.append(gender_softmax)

  if USE_AGE:
    # AGE two-class heads, one per label column starting at column 2.
    _args.grad_scale = 0.01
    _args.margin_a = 0.0
    _args.num_classes = 2
    # range instead of xrange: xrange does not exist on Python 3.
    for i in range(AGE):
      age_softmax = get_softmax(_args, embedding, nembedding, _label_column(2+i), 'fc9_%d'%(i))
      out_list.append(age_softmax)

  out = mx.symbol.Group(out_list)
  return (out, arg_params, aux_params)
コード例 #7
0
def get_symbol(args, arg_params, aux_params):
    """Build the MobileFaceNet training symbol with a combined soft/hard loss.

    The classification logits ``fc7`` are produced either by a plain fully
    connected layer (``loss_type == 0``) or by an additive angular margin
    head (``loss_type == 4``).  The training loss is::

        soft  = mean((fc7 / tau - logit_t / tau) ** 2)
        hard  = -sum(one_hot(label) * log_softmax(fc7)) / batch_size
        total = lamda * soft + (1 - lamda) * hard

    where ``logit_t`` is an input variable (presumably teacher logits for
    distillation -- confirm against the data iterator).

    Args:
        args: parsed options.  Reads ``emb_size``, ``bn_mom``,
            ``fc7_wd_mult``, ``num_classes``, ``loss_type``, ``margin_s``,
            ``margin_m``, ``easy_margin``, ``tau``, ``lamda`` and
            ``batch_size``.  ``args.num_layers`` is overwritten with 1.
        arg_params: passed through unchanged.
        aux_params: passed through unchanged.

    Returns:
        Tuple ``(out, arg_params, aux_params)``; ``out`` groups the
        gradient-blocked embedding, the gradient-blocked softmax output and
        the ``MakeLoss`` training loss, in that order.

    Raises:
        ValueError: if ``args.loss_type`` is neither 0 nor 4.
        AssertionError: if the margin parameters are out of range for
            ``loss_type == 4``.
    """
    # define network
    args.num_layers = 1
    print('init mobilefacenet', args.num_layers)
    embedding = fmobilefacenet.get_symbol(args.emb_size,
                                          bn_mom=args.bn_mom,
                                          wd_mult=args.fc7_wd_mult)

    # define loss
    gt_label = mx.symbol.Variable('softmax_label')
    # The one-hot target mask is needed by the hard loss for every loss
    # type (and by the margin head), so it is built once up front.  It was
    # previously created only inside the loss_type == 4 branch, which made
    # loss_type == 0 fail with a NameError.
    gt_one_hot = mx.sym.one_hot(gt_label,
                                depth=args.num_classes,
                                on_value=1.0,
                                off_value=0.0)
    _weight = mx.symbol.Variable("fc7_weight",
                                 shape=(args.num_classes, args.emb_size),
                                 lr_mult=1.0,
                                 wd_mult=args.fc7_wd_mult)
    if args.loss_type == 0:  # softmax
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding,
                                    weight=_weight,
                                    bias=_bias,
                                    num_hidden=args.num_classes,
                                    name='fc7')
    elif args.loss_type == 4:  # additive angular margin
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        # zy is the target-class logit, i.e. s * cos(t).
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        mm = math.sin(math.pi - m) * m
        threshold = math.cos(math.pi - m)
        # cond > 0 selects the samples that receive the angular margin.
        if args.easy_margin:
            cond = mx.symbol.Activation(data=cos_t, act_type='relu')
        else:
            cond_v = cos_t - threshold
            cond = mx.symbol.Activation(data=cond_v, act_type='relu')
        # cos(t + m) = cos(t) * cos(m) - sin(t) * sin(m)
        body = cos_t * cos_t
        body = 1.0 - body
        sin_t = mx.sym.sqrt(body)
        new_zy = cos_t * cos_m
        b = sin_t * sin_m
        new_zy = new_zy - b
        new_zy = new_zy * s
        if args.easy_margin:
            zy_keep = zy
        else:
            zy_keep = zy - s * mm
        new_zy = mx.sym.where(cond, new_zy, zy_keep)

        # Replace only the target-class entry of fc7 with the margin logit.
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    else:
        raise ValueError('unsupported loss_type: %r' % (args.loss_type,))

    out_list = [mx.symbol.BlockGrad(embedding)]
    logit_t_val = mx.symbol.Variable('logit_t')
    # The softmax output is exported with gradients blocked; training is
    # driven by total_loss only.
    softmax = mx.symbol.SoftmaxOutput(data=fc7,
                                      label=gt_label,
                                      name='softmax',
                                      normalization='valid')
    soft_loss = mx.symbol.mean(
        mx.symbol.square(fc7 / args.tau - logit_t_val / args.tau))
    log_softmax = mx.sym.log_softmax(fc7)
    hard_loss = -mx.sym.sum(mx.sym.broadcast_mul(
        gt_one_hot, log_softmax)) / args.batch_size
    total_loss = soft_loss * args.lamda + hard_loss * (1 - args.lamda)
    total_loss = mx.symbol.MakeLoss(total_loss)
    out_list.append(mx.sym.BlockGrad(softmax))
    out_list.append(total_loss)
    out = mx.sym.Group(out_list)
    return (out, arg_params, aux_params)
コード例 #8
0
def get_symbol(args, arg_params, aux_params):
    """Assemble the training symbol: backbone embedding plus classification head.

    The backbone is chosen by the first character of ``args.network``
    (anything not matched falls through to ResNet).  The head depends on
    ``args.loss_type``:

    * 0 -- plain fully connected layer with bias,
    * 1 -- ``LSoftmax`` on a weight-normalised layer,
    * 2 -- cosine margin: ``s * m`` is subtracted from the target logit,
    * 4 -- additive angular margin: target logit becomes ``s * cos(t + m)``,
    * 5 -- combined margin: target logit becomes
      ``s * (cos(a * t + m) - b)``.

    Args:
        args: parsed options (network, sizes, margin and version switches).
        arg_params: passed through unchanged.
        aux_params: passed through unchanged.

    Returns:
        Tuple ``(out, arg_params, aux_params)``; ``out`` groups the
        gradient-blocked embedding and the ``SoftmaxOutput`` symbol.
    """
    print('***network: ', args.network)

    def _version_opts():
        # Version switches accepted by most backbones.  Read lazily so that
        # backbones which do not take them never touch these attributes.
        return dict(version_se=args.version_se,
                    version_input=args.version_input,
                    version_output=args.version_output,
                    version_unit=args.version_unit)

    backbone = args.network[0]
    if backbone == 'd':  # densenet
        embedding = fdensenet.get_symbol(args.emb_size, args.num_layers,
                                         **_version_opts())
    elif backbone == 'm':  # mobilenet v1 / v2
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(args.emb_size, **_version_opts())
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif backbone == 'i':  # inception-resnet-v2
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(args.emb_size,
                                                    **_version_opts())
    elif backbone == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size, **_version_opts())
    elif backbone == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size, args.num_layers,
                                    **_version_opts())
    elif backbone == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif backbone == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif backbone == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(
            args.emb_size,
            bn_mom=args.bn_mom,
            version_output=args.version_output)
    else:  # default: resnet
        print('init resnet, 层数: ', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size,
                                       args.num_layers,
                                       version_se=args.version_se,
                                       version_input=args.version_input,
                                       version_output=args.version_output,
                                       version_unit=args.version_unit,
                                       version_act=args.version_act)

    gt_label = mx.symbol.Variable('softmax_label')
    # Explicit variable for the fc7 weight so its lr / wd multipliers can be set.
    fc7_weight = mx.symbol.Variable("fc7_weight",
                                    shape=(args.num_classes, args.emb_size),
                                    lr_mult=args.fc7_lr_mult,
                                    wd_mult=args.fc7_wd_mult)

    def _scaled_cosine_logits(scale):
        # L2-normalise both weight and embedding, rescale the embedding by
        # `scale`; the bias-free product is then scale * cos(theta).
        unit_weight = mx.symbol.L2Normalization(fc7_weight, mode='instance')
        scaled_feat = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * scale
        return mx.sym.FullyConnected(data=scaled_feat,
                                     weight=unit_weight,
                                     no_bias=True,
                                     num_hidden=args.num_classes,
                                     name='fc7')

    def _replace_target_logit(logits, old_target, new_target):
        # Add (new - old) at the ground-truth column only, so that column
        # ends up equal to new_target and all other columns are untouched.
        delta = mx.sym.expand_dims(new_target - old_target, 1)
        target_mask = mx.sym.one_hot(gt_label, depth=args.num_classes,
                                     on_value=1.0, off_value=0.0)
        return logits + mx.sym.broadcast_mul(target_mask, delta)

    loss_type = args.loss_type
    if loss_type == 0:  # softmax
        fc7_bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding,
                                    weight=fc7_weight,
                                    bias=fc7_bias,
                                    num_hidden=args.num_classes,
                                    name='fc7')
    elif loss_type == 1:  # sphere
        print('*******'*10)
        unit_weight = mx.symbol.L2Normalization(fc7_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding,
                              label=gt_label,
                              num_hidden=args.num_classes,
                              weight=unit_weight,
                              beta=args.beta,
                              margin=args.margin,
                              scale=args.scale,
                              beta_min=args.beta_min,
                              verbose=1000,
                              name='fc7')
    elif loss_type == 2:  # cosine margin
        feat_scale = args.margin_s
        cos_margin = args.margin_m
        assert (feat_scale > 0.0)
        assert (cos_margin > 0.0)
        fc7 = _scaled_cosine_logits(feat_scale)
        # One-hot carrying s * m at the target column, 0 elsewhere.
        margin_mask = mx.sym.one_hot(gt_label,
                                     depth=args.num_classes,
                                     on_value=feat_scale * cos_margin,
                                     off_value=0.0)
        fc7 = fc7 - margin_mask
    elif loss_type == 4:  # additive angular margin
        feat_scale = args.margin_s
        ang_margin = args.margin_m
        assert feat_scale > 0.0
        assert ang_margin >= 0.0
        assert ang_margin < (math.pi / 2)
        fc7 = _scaled_cosine_logits(feat_scale)

        # Target-class logit of every sample, i.e. s * cos(t).
        target_logit = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = target_logit / feat_scale
        cos_m = math.cos(ang_margin)
        sin_m = math.sin(ang_margin)
        # sin(pi - m) * m == sin(m) * m; about 0.2397 for m = 0.5.
        penalty = math.sin(math.pi - ang_margin) * ang_margin
        # cos(pi - m) == -cos(m): below this t + m would reach pi or more.
        threshold = math.cos(math.pi - ang_margin)
        if args.easy_margin:
            # Apply the margin only where cos(t) > 0.
            gate = mx.symbol.Activation(data=cos_t, act_type='relu')
        else:
            # Apply the margin where cos(t) is above the negative threshold.
            gate = mx.symbol.Activation(data=cos_t - threshold,
                                        act_type='relu')
        # cos(t + m) = cos(t) * cos(m) - sin(t) * sin(m), then rescale by s.
        sin_t = mx.sym.sqrt(1.0 - cos_t * cos_t)
        margin_logit = (cos_t * cos_m - sin_t * sin_m) * feat_scale
        # Fallback where the gate is not positive: the unchanged logit
        # (easy margin) or the logit minus s * m * sin(m).
        fallback_logit = (target_logit if args.easy_margin
                          else target_logit - feat_scale * penalty)
        margin_logit = mx.sym.where(gate, margin_logit, fallback_logit)

        fc7 = _replace_target_logit(fc7, target_logit, margin_logit)
    elif loss_type == 5:  # combined margin
        feat_scale = args.margin_s
        ang_margin = args.margin_m
        assert feat_scale > 0.0
        fc7 = _scaled_cosine_logits(feat_scale)
        if args.margin_a == 1.0 and ang_margin == 0.0:
            # No angular component; a non-zero b is a pure cosine shift.
            if args.margin_b != 0.0:
                margin_mask = mx.sym.one_hot(
                    gt_label,
                    depth=args.num_classes,
                    on_value=feat_scale * args.margin_b,
                    off_value=0.0)
                fc7 = fc7 - margin_mask
        else:
            target_logit = mx.sym.pick(fc7, gt_label, axis=1)
            theta = mx.sym.arccos(target_logit / feat_scale)
            if args.margin_a != 1.0:
                theta = theta * args.margin_a
            if ang_margin > 0.0:
                theta = theta + ang_margin
            shifted_cos = mx.sym.cos(theta)
            if args.margin_b > 0.0:
                shifted_cos = shifted_cos - args.margin_b
            fc7 = _replace_target_logit(fc7, target_logit,
                                        shifted_cos * feat_scale)

    blocked_embedding = mx.symbol.BlockGrad(embedding)
    softmax = mx.symbol.SoftmaxOutput(data=fc7,
                                      label=gt_label,
                                      name='softmax',
                                      normalization='valid')
    out = mx.symbol.Group([blocked_embedding, softmax])
    return (out, arg_params, aux_params)
コード例 #9
0
def get_symbol(args, arg_params, aux_params):
    """Build the training symbol, with a semi-supervised variant for loss_type 4.

    The backbone is chosen by the first character of ``args.network``
    (anything not matched falls through to ResNet).  For ``loss_type`` 0, 1,
    2 and 5 the head ends in a regular ``SoftmaxOutput``.  For
    ``loss_type == 4`` the batch is split along axis 0: the first
    ``per_batch_size - num_unlabeled`` rows are treated as labeled and get
    the additive angular margin cross-entropy, while the remaining rows
    contribute a likelihood term weighted by
    ``args.likelihood_scale_factor``.

    Args:
        args: parsed options (network, sizes, margin and version switches;
            for ``loss_type == 4`` also ``num_unlabeled``,
            ``per_batch_size`` and ``likelihood_scale_factor``).
        arg_params: passed through unchanged.
        aux_params: passed through unchanged.

    Returns:
        Tuple ``(out, arg_params, aux_params)``.  ``out`` is a group whose
        first entry is the gradient-blocked embedding.  For
        ``loss_type == 4`` it is followed by the loss, the gradient-blocked
        margin logits, the gradient-blocked cross-entropy and the
        gradient-blocked likelihood term; otherwise by the softmax output.

    Raises:
        ValueError: if ``args.loss_type`` is not one of 0, 1, 2, 4, 5.
        AssertionError: if margin parameters or the labeled batch size are
            out of range.
    """
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size,
                                         args.num_layers,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(
                args.emb_size,
                version_se=args.version_se,
                version_input=args.version_input,
                version_output=args.version_output,
                version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(
            args.emb_size,
            version_se=args.version_se,
            version_input=args.version_input,
            version_output=args.version_output,
            version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
                                         version_se=args.version_se,
                                         version_input=args.version_input,
                                         version_output=args.version_output,
                                         version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size,
                                    args.num_layers,
                                    version_se=args.version_se,
                                    version_input=args.version_input,
                                    version_output=args.version_output,
                                    version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif args.network[0] == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(args.emb_size,
                                              bn_mom=args.bn_mom,
                                              wd_mult=args.fc7_wd_mult)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size,
                                       args.num_layers,
                                       version_se=args.version_se,
                                       version_input=args.version_input,
                                       version_output=args.version_output,
                                       version_unit=args.version_unit,
                                       version_act=args.version_act)
    gt_label = mx.symbol.Variable('softmax_label')
    # Initialised here for every loss type.  It used to be created only
    # inside the loss_type == 4 branch, so every other loss type failed
    # with a NameError at the first out_list.append below.
    out_list = []
    _weight = mx.symbol.Variable("fc7_weight",
                                 shape=(args.num_classes, args.emb_size),
                                 lr_mult=1.0,
                                 wd_mult=args.fc7_wd_mult)
    if args.loss_type == 0:  # softmax
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding,
                                    weight=_weight,
                                    bias=_bias,
                                    num_hidden=args.num_classes,
                                    name='fc7')
    elif args.loss_type == 1:  # sphere
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding,
                              label=gt_label,
                              num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta,
                              margin=args.margin,
                              scale=args.scale,
                              beta_min=args.beta_min,
                              verbose=1000,
                              name='fc7')
    elif args.loss_type == 2:  # cosine margin
        s = args.margin_s
        m = args.margin_m
        assert (s > 0.0)
        assert (m > 0.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        s_m = s * m
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=s_m,
                                    off_value=0.0)
        fc7 = fc7 - gt_one_hot
    elif args.loss_type == 4:  # additive angular margin, semi-supervised
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')

        # Split fc7 into fc7 (labeled rows) and fc7_2 (unlabeled rows) and
        # compute the likelihood term of the fc7_2 part.
        # The constant subtracted below is based on the number of labeled
        # classes.
        subtract = args.num_classes * math.log(args.num_classes)
        num_unlabeled = args.num_unlabeled
        num_labeled = args.per_batch_size - args.num_unlabeled
        per_batch_size = args.per_batch_size
        assert num_labeled > 0
        fc7_2 = mx.sym.slice_axis(fc7,
                                  axis=0,
                                  begin=num_labeled,
                                  end=per_batch_size)
        fc7 = mx.sym.slice_axis(fc7, axis=0, begin=0, end=num_labeled)
        # NOTE(review): log_softmax is applied on top of an already
        # softmax-ed tensor here -- confirm this double normalisation is
        # intended.
        fc7_2 = mx.sym.softmax(data=fc7_2, axis=1)

        log_likelihood = -mx.sym.log_softmax(fc7_2, axis=1)
        log_likelihood_loss = (mx.sym.sum(log_likelihood) / num_unlabeled)
        log_likelihood_loss = log_likelihood_loss - subtract

        # Only the labeled rows carry a usable label.
        gt_label = mx.sym.slice_axis(gt_label,
                                     axis=0,
                                     begin=0,
                                     end=num_labeled)

        # zy is the target-class logit, i.e. s * cos(t).
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        cos_m = math.cos(m)
        sin_m = math.sin(m)
        mm = math.sin(math.pi - m) * m
        threshold = math.cos(math.pi - m)
        # cond > 0 selects the samples that receive the angular margin.
        if args.easy_margin:
            cond = mx.symbol.Activation(data=cos_t, act_type='relu')
        else:
            cond_v = cos_t - threshold
            cond = mx.symbol.Activation(data=cond_v, act_type='relu')
        # cos(t + m) = cos(t) * cos(m) - sin(t) * sin(m)
        body = cos_t * cos_t
        body = 1.0 - body
        sin_t = mx.sym.sqrt(body)
        new_zy = cos_t * cos_m
        b = sin_t * sin_m
        new_zy = new_zy - b
        new_zy = new_zy * s
        if args.easy_margin:
            zy_keep = zy
        else:
            zy_keep = zy - s * mm
        new_zy = mx.sym.where(cond, new_zy, zy_keep)

        # Replace only the target-class entry of fc7 with the margin logit.
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label,
                                    depth=args.num_classes,
                                    on_value=1.0,
                                    off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    elif args.loss_type == 5:  # combined margin
        s = args.margin_s
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=args.num_classes,
                                    name='fc7')
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                # No angular component: pure cosine shift by s * b.
                s_m = s * args.margin_b
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=s_m,
                                            off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                # Target logit becomes s * (cos(a * t + m) - b).
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if args.margin_a != 1.0:
                    t = t * args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body * s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=args.num_classes,
                                            on_value=1.0,
                                            off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7 + body
    else:
        raise ValueError('unsupported loss_type: %r' % (args.loss_type,))
    out_list.append(mx.symbol.BlockGrad(embedding))

    if args.loss_type == 4:
        # Margin logits are exported with gradients blocked.
        # TODO: maybe this softmax makes the accuracy inconsistent
        cos_fc7 = mx.symbol.BlockGrad(fc7)
        gt_label_one_hot = mx.sym.one_hot(gt_label, args.num_classes)
        # Cross-entropy over the labeled rows, averaged over their count.
        neg_log_prob = -mx.sym.log_softmax(fc7, axis=1)
        softmax = gt_label_one_hot * neg_log_prob
        softmax = mx.sym.sum(softmax) / (args.per_batch_size -
                                         args.num_unlabeled)
        scale = args.likelihood_scale_factor
        loss = mx.sym.make_loss(data=(softmax + log_likelihood_loss * scale))
        out_list.append(loss)
        out_list.append(cos_fc7)
        out_list.append(mx.sym.BlockGrad(softmax))
        out_list.append(mx.sym.BlockGrad(log_likelihood_loss))
    else:
        softmax = mx.symbol.SoftmaxOutput(data=fc7,
                                          label=gt_label,
                                          name='softmax',
                                          normalization='valid')
        out_list.append(softmax)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
コード例 #10
0
def get_symbol(args, arg_params, aux_params):
    """Build the training symbol: backbone embedding, fc7 logits and softmax.

    The backbone is chosen from the first character of ``args.network``
    ('d', 'm', 'i', 'x', 'p', 'n', 's', 'y'; anything else falls back to
    ``fresnet``).  The classification layer ``fc7`` is then built according
    to ``args.loss_type``:

    * 0 -- plain fully-connected layer with bias (softmax).
    * 1 -- ``mx.sym.LSoftmax`` on an L2-normalised weight (sphere).
    * 2 -- normalised embedding/weight, ``s * m`` subtracted from the
      ground-truth logit.
    * 4 -- normalised embedding/weight, ground-truth logit replaced by
      ``s * (cos(a * t + m) - b)`` where ``t = arccos(logit / s)``.
    * 5 -- like 6 but with dropout on the scaled embedding and a
      ``where``-based selection on the ground-truth logit.
    * 6 -- like 4, but skipped entirely when a == 1, m == 0 and b == 0, and
      reduced to a plain ``s * b`` subtraction when a == 1 and m == 0.

    Args:
        args: Namespace-like object.  Reads ``loss_type``, ``network``,
            ``emb_size``, ``num_layers``, ``num_classes``, ``fc7_wd_mult``
            and, depending on the branch taken, the ``version_*``,
            ``margin_*``, ``bn_mom``, ``beta``, ``beta_min``, ``margin``,
            ``scale`` and ``easy_margin`` attributes.
        arg_params: Returned unchanged.
        aux_params: Returned unchanged.

    Returns:
        Tuple ``(out, arg_params, aux_params)`` where ``out`` is an
        ``mx.symbol.Group`` of ``[BlockGrad(embedding), softmax]``.

    Raises:
        ValueError: If ``args.loss_type`` is not one of 0, 1, 2, 4, 5, 6.
        AssertionError: If the margin parameters violate the per-loss
            constraints checked below.
    """
    # Fail fast: without this check an unsupported loss type would build the
    # whole backbone and then crash with an unbound ``fc7`` further down.
    supported_loss_types = (0, 1, 2, 4, 5, 6)
    if args.loss_type not in supported_loss_types:
        raise ValueError(
            'unsupported loss_type %r; expected one of %s' %
            (args.loss_type, supported_loss_types))

    # --- Backbone selection -------------------------------------------------
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(
            args.emb_size, args.num_layers,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(
                args.emb_size,
                version_se=args.version_se, version_input=args.version_input,
                version_output=args.version_output,
                version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(
            args.emb_size,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(
            args.emb_size,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(
            args.emb_size, args.num_layers,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    elif args.network[0] == 'y':
        print('init mobilefacenet', args.num_layers)
        embedding = fmobilefacenet.get_symbol(
            args.emb_size, bn_mom=args.bn_mom, wd_mult=args.fc7_wd_mult)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(
            args.emb_size, args.num_layers,
            version_se=args.version_se, version_input=args.version_input,
            version_output=args.version_output, version_unit=args.version_unit,
            version_act=args.version_act)

    # --- Classification head ------------------------------------------------
    gt_label = mx.symbol.Variable('softmax_label')
    _weight = mx.symbol.Variable(
        "fc7_weight", shape=(args.num_classes, args.emb_size),
        lr_mult=1.0, wd_mult=args.fc7_wd_mult)

    if args.loss_type == 0:  # softmax
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(
            data=embedding, weight=_weight, bias=_bias,
            num_hidden=args.num_classes, name='fc7')
    elif args.loss_type == 1:  # sphere
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(
            data=embedding, label=gt_label, num_hidden=args.num_classes,
            weight=_weight,
            beta=args.beta, margin=args.margin, scale=args.scale,
            beta_min=args.beta_min, verbose=1000, name='fc7')
    elif args.loss_type == 2:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(
            data=nembedding, weight=_weight, no_bias=True,
            num_hidden=args.num_classes, name='fc7')
        # Subtract the constant s*m from the ground-truth logit only.
        s_m = s * m
        gt_one_hot = mx.sym.one_hot(
            gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
        fc7 = fc7 - gt_one_hot
    elif args.loss_type == 4:
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(
            data=nembedding, weight=_weight, no_bias=True,
            num_hidden=args.num_classes, name='fc7')
        # Recover the angle of the ground-truth logit, apply a*t + m in
        # angle space and -b in cosine space, then rescale by s.
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        t = mx.sym.arccos(cos_t)
        if args.margin_a != 1.0:
            t = t * args.margin_a
        if args.margin_m > 0.0:
            t = t + args.margin_m
        body = mx.sym.cos(t)
        if args.margin_b > 0.0:
            body = body - args.margin_b
        new_zy = body * s
        # Add (new - old) at the ground-truth position; other logits unchanged.
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(
            gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    elif args.loss_type == 5:
        s = args.margin_s
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        nembedding = mx.symbol.Dropout(data=nembedding, p=0.4)
        fc7 = mx.sym.FullyConnected(
            data=nembedding, weight=_weight, no_bias=True,
            num_hidden=args.num_classes, name='fc7')
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                # Only b is active: plain s*b subtraction on the target logit.
                s_m = s * args.margin_b
                gt_one_hot = mx.sym.one_hot(
                    gt_label, depth=args.num_classes,
                    on_value=s_m, off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if args.margin_a != 1.0:
                    t = t * args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                if args.easy_margin:
                    # Condition is relu(cos_t): non-zero only when cos_t > 0.
                    cond = mx.symbol.Activation(data=cos_t, act_type='relu')
                    # NOTE(review): zy_keep and new_zy are both cos(t) here,
                    # so the `where` below selects between identical values.
                    zy_keep = mx.sym.cos(t)
                else:
                    # Condition is relu(pi - t): non-zero while the
                    # margin-adjusted angle is still below pi.
                    cond_v = math.pi - t
                    cond = mx.symbol.Activation(data=cond_v, act_type='relu')
                    print("not easy margin: cond")
                    print("not easy margin: sine")
                    zy_keep = mx.sym.sin(t) - 1
                new_zy = mx.sym.cos(t)
                body = mx.sym.where(cond, new_zy, zy_keep)
                if args.margin_b > 0.0:
                    # NOTE(review): b is *added* here, whereas the loss_type
                    # 4 and 6 branches subtract it -- confirm this is intended.
                    body = body + args.margin_b
                new_zy = body * s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(
                    gt_label, depth=args.num_classes,
                    on_value=1.0, off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7 + body
    elif args.loss_type == 6:
        s = args.margin_s
        assert s > 0.0
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(
            data=nembedding, weight=_weight, no_bias=True,
            num_hidden=args.num_classes, name='fc7')
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                # Only b is active: plain s*b subtraction on the target logit.
                s_m = s * args.margin_b
                gt_one_hot = mx.sym.one_hot(
                    gt_label, depth=args.num_classes,
                    on_value=s_m, off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if args.margin_a != 1.0:
                    t = t * args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body * s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(
                    gt_label, depth=args.num_classes,
                    on_value=1.0, off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7 + body

    # --- Outputs ------------------------------------------------------------
    # The embedding is exposed as an output but gradient-blocked; gradients
    # reach the backbone only through the softmax head.
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = mx.symbol.SoftmaxOutput(
        data=fc7, label=gt_label, name='softmax', normalization='valid')
    out_list.append(softmax)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)
コード例 #11
0
def get_symbol(args, arg_params, aux_params):
    """Build a mobilefacenet training symbol with an angular-margin softmax.

    Only ``args.loss_type == 4`` is implemented.  The embedding and the fc7
    weight are L2-normalised, the embedding is scaled by ``args.margin_s``
    and the ground-truth logit ``s*cos(t)`` is replaced by
    ``s*(cos(t)*cos(m) - sin(t)*sin(m))`` (i.e. ``s*cos(t + m)`` by the
    angle-addition identity) wherever the selection condition is non-zero.

    Args:
        args: Namespace-like object.  Reads ``loss_type``, ``emb_size``,
            ``bn_mom``, ``version_output``, ``num_classes``, ``fc7_lr_mult``,
            ``fc7_wd_mult``, ``margin_s``, ``margin_m`` and ``easy_margin``.
        arg_params: Returned unchanged.
        aux_params: Returned unchanged.

    Returns:
        Tuple ``(out, arg_params, aux_params)`` where ``out`` is an
        ``mx.symbol.Group`` of ``[BlockGrad(embedding), softmax]``.

    Raises:
        ValueError: If ``args.loss_type`` is not 4.
        AssertionError: If ``margin_s <= 0`` or ``margin_m`` is outside
            ``[0, pi/2)``.
    """
    # Fail fast: without this check any other loss type would build the
    # backbone and then crash with an unbound ``fc7`` at the softmax output.
    if args.loss_type != 4:
        raise ValueError(
            'unsupported loss_type %r; only loss_type 4 is implemented' %
            (args.loss_type,))

    embedding = fmobilefacenet.get_symbol(args.emb_size,
                                          bn_mom=args.bn_mom,
                                          version_output=args.version_output)

    gt_label = mx.symbol.Variable('softmax_label')
    _weight = mx.symbol.Variable("fc7_weight",
                                 shape=(args.num_classes, args.emb_size),
                                 lr_mult=args.fc7_lr_mult,
                                 wd_mult=args.fc7_wd_mult)

    s = args.margin_s
    m = args.margin_m
    assert s > 0.0
    assert m >= 0.0
    assert m < (math.pi / 2)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    nembedding = mx.symbol.L2Normalization(
        embedding, mode='instance', name='fc1n') * s
    fc7 = mx.sym.FullyConnected(data=nembedding,
                                weight=_weight,
                                no_bias=True,
                                num_hidden=args.num_classes,
                                name='fc7')

    # Ground-truth logit and its cosine (logit / s).
    zy = mx.sym.pick(fc7, gt_label, axis=1)
    cos_t = zy / s

    # Python-side constants derived from the margin.
    cos_m = math.cos(m)
    sin_m = math.sin(m)
    mm = math.sin(math.pi - m) * m
    threshold = math.cos(math.pi - m)

    # Selection condition for `where`: relu(...) is non-zero exactly where
    # its input is positive, i.e. cos_t > 0 (easy margin) or
    # cos_t > cos(pi - m) (otherwise).
    if args.easy_margin:
        cond = mx.symbol.Activation(data=cos_t, act_type='relu')
    else:
        cond_v = cos_t - threshold
        cond = mx.symbol.Activation(data=cond_v, act_type='relu')

    # sin(t) from cos(t), then s * (cos(t)cos(m) - sin(t)sin(m)).
    body = cos_t * cos_t
    body = 1.0 - body
    sin_t = mx.sym.sqrt(body)
    new_zy = cos_t * cos_m
    b = sin_t * sin_m
    new_zy = new_zy - b
    new_zy = new_zy * s

    # Fallback value used where the condition is zero: the untouched logit
    # (easy margin) or the logit lowered by the constant s * mm.
    if args.easy_margin:
        zy_keep = zy
    else:
        zy_keep = zy - s * mm
    new_zy = mx.sym.where(cond, new_zy, zy_keep)

    # Add (new - old) at the ground-truth position; other logits unchanged.
    diff = new_zy - zy
    diff = mx.sym.expand_dims(diff, 1)
    gt_one_hot = mx.sym.one_hot(gt_label,
                                depth=args.num_classes,
                                on_value=1.0,
                                off_value=0.0)
    body = mx.sym.broadcast_mul(gt_one_hot, diff)
    fc7 = fc7 + body

    # The embedding is exposed as an output but gradient-blocked.
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = mx.symbol.SoftmaxOutput(data=fc7,
                                      label=gt_label,
                                      name='softmax',
                                      normalization='valid')
    out_list.append(softmax)

    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)