Example #1
def check_features(args):
    from image_iter_gen_feature import FaceImageIter
    global image_shape
    global net

    print(args)
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
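    # Note: mx.gpu(i) indexes into the *visible* device set, so with
    # CUDA_VISIBLE_DEVICES="2,3" this builds ctx = [gpu(0), gpu(1)], which
    # CUDA maps back to physical devices 2 and 3.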
    image_shape = [int(x) for x in args.image_size.split(',')]
    vec = args.model.split(',')
    assert len(vec) > 1
    prefix = vec[0]
    epoch = int(vec[1])
    print('loading', prefix, epoch)
    net = edict()
    net.ctx = ctx
    net.sym, net.arg_params, net.aux_params = mx.model.load_checkpoint(
        prefix, epoch)
    # net.arg_params, net.aux_params = ch_dev(net.arg_params, net.aux_params, net.ctx)
    all_layers = net.sym.get_internals()
    net.sym = all_layers['fc1_output']
    net.model = mx.mod.Module(symbol=net.sym,
                              context=net.ctx,
                              label_names=None)
    net.model.bind(data_shapes=[('data', (args.batch_size, 3, image_shape[1],
                                          image_shape[2]))])
    net.model.set_params(net.arg_params, net.aux_params)

    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=(3, 112, 112),
        path_imgrec=args.input_data,
        shuffle=True,
        rand_mirror=False,
        mean=None,
        cutoff=False,
        color_jittering=0,
        images_filter=0,
    )

    for i in range(10):
        db, features_data = train_dataiter.next()
        net.model.forward(db, is_train=False)
        embedding = net.model.get_outputs()[0].asnumpy()

        print((embedding == features_data).any())
Example #2
def get_data_iter(config, batch_size):
    data_dir = config.dataset_path
    path_imgrec = None
    path_imglist = None
    image_size = config.image_shape[0:2]

    assert len(image_size)==2
    assert image_size[0]==image_size[1]
    print('image_size', image_size)
    print('num_classes', config.num_classes)

    path_imgrec = os.path.join(data_dir, "train.rec")
    data_shape = (config.image_shape[2], image_size[0], image_size[1])

    val_dataiter = None

    mean = None
    train_dataiter = FaceImageIter(
        batch_size           = batch_size,
        data_shape           = data_shape,
        path_imgrec          = path_imgrec,
        shuffle              = True,
        rand_mirror          = config.data_rand_mirror,
        mean                 = mean,
        cutoff               = config.data_cutoff,
        color_jittering      = config.data_color,
        images_filter        = config.data_images_filter,
    )
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)
    return train_dataiter, val_dataiter
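# A minimal, hypothetical driver for get_data_iter above; `easydict` supplies
# the attribute-style config the function reads (all field values here are
# assumptions, not taken from the original snippet):
from easydict import EasyDict as edict

config = edict()
config.dataset_path = 'datasets/faces_emore'  # assumed layout with train.rec inside
config.image_shape = [112, 112, 3]
config.num_classes = 85742
config.data_rand_mirror = True
config.data_cutoff = 0
config.data_color = 0
config.data_images_filter = 0
train_iter, val_iter = get_data_iter(config, batch_size=512)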
Example #3
def train_net(args):

    ctx = []
    # cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()  # e.g. '0': use the first GPU
    cvd = ''  # left empty here, so training falls back to CPU

    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))  # append one GPU context per visible device, e.g. ctx = [gpu(0)]

    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))  # GPUs in use

    prefix = args.prefix  # e.g. '../model-r100'
    prefix_dir = os.path.dirname(prefix)  # e.g. '..'

    if not os.path.exists(prefix_dir):  # skipped when the directory already exists
        os.makedirs(prefix_dir)

    end_epoch = args.end_epoch  # e.g. 100000

    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])

    print('num_layers', args.num_layers)  #100

    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num  #10

    args.rescale_threshold = 0
    args.image_channel = 3

    os.environ['BETA'] = str(args.beta)  # e.g. 1000.0; see ArcFace Eq. (6): the lambda for annealed training

    data_dir_list = args.data_dir.split(',')
    print('data_dir_list: ', data_dir_list)

    data_dir = data_dir_list[0]

    # load dataset properties
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    print('num_classes: ', args.num_classes)

    # path_imgrec = os.path.join(data_dir, "train.rec")
    path_imgrec = os.path.join(data_dir, "all.rec")

    if args.loss_type == 1 and args.num_classes > 20000:  #sphereface
        args.beta_freeze = 5000
        args.gamma = 0.06

    print('***Called with argument:', args)

    data_shape = (args.image_channel, image_size[0], image_size[1])  # (3, 112, 112)

    mean = None

    begin_epoch = 0
    base_lr = args.lr  #0.1
    base_wd = args.wd  #weight decay = 0.0005
    base_mom = args.mom  # momentum: 0.9

    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    else:
        vec = args.pretrained.split(
            ',')  #['../models/model-r50-am-lfw/model', '0000']
        print('***loading', vec)

        _, arg_params, aux_params = mx.model.load_checkpoint(
            vec[0], int(vec[1]))
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
        # print(sym[1])
        # mx.viz.plot_network(sym[1]).view()  # visualize the network
        # sys.exit()
    if args.network[0] == 's':  # spherenet
        data_shape_dict = {'data': (args.per_batch_size, ) + data_shape}
        spherenet.init_weights(sym, data_shape_dict, args.num_layers)

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)
    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
    )

    # print(args.batch_size)
    # print(data_shape)
    # print(path_imgrec)
    # print(args.rand_mirror)
    # print(mean)
    # print(args.cutoff)
    # sys.exit()

    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,  #(3L,112L,112L)
        path_imgrec=path_imgrec,  # train.rec
        shuffle=True,
        rand_mirror=args.rand_mirror,  # 1
        mean=mean,
        cutoff=args.cutoff,  # 0
    )

    if args.loss_type < 10:
        _metric = AccMetric()
    else:
        _metric = LossValueMetric()
    # create an evaluation metric
    eval_metrics = [mx.metric.create(_metric)]

    if args.network[0] == 'r' or args.network[0] == 'y':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style  mobilefacenet
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  #inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=base_lr,
                        momentum=base_mom,
                        wd=base_wd,
                        rescale_grad=_rescale)  # for multi-GPU training, rescale_grad splits the summed gradient across devices
    som = 64
    # callback that periodically reports training speed and accuracy
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model, args.batch_size, 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [30000, 40000, 50000]
        if args.loss_type >= 1 and args.loss_type <= 7:

            lr_steps = [100000, 140000, 160000]
        # single GPU: the batch-size scaling factor p is dropped
        # p = 512.0/args.batch_size
        for l in range(len(lr_steps)):
            # lr_steps[l] = int(lr_steps[l]*p)
            lr_steps[l] = int(lr_steps[l])
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        #global global_step

        mbatch = global_step[0]
        global_step[0] += 1
        for _lr in lr_steps:
            if mbatch == args.beta_freeze + _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)
        print('mbatch=', mbatch)
        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            print(acc_list)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            if len(acc_list) > 0:
                lfw_score = acc_list[0]
                if lfw_score > highest_acc[0]:
                    highest_acc[0] = lfw_score
                    # relax the save threshold to probe for a better operating point
                    # if lfw_score>=0.998:
                    if lfw_score >= 0.99:
                        do_save = True
                if acc_list[-1] >= highest_acc[-1]:
                    highest_acc[-1] = acc_list[-1]
                    if lfw_score >= 0.99:  # save the model when LFW accuracy reaches 0.99
                        do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt > 1:
                do_save = True
            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = max(0, mbatch - args.beta_freeze)
            _beta = max(
                args.beta_min,
                args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
        #print('beta', _beta)
        os.environ['BETA'] = str(_beta)
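        # Annealing sketch: with beta=1000, gamma=0.12, power=1 the lambda decays
        # as 1000 * (1 + 0.12*move)^-1, e.g. ~89 once move reaches 85 steps past
        # beta_freeze, until it hits the beta_min floor (illustrative values).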
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    model.fit(
        train_data=train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        eval_data=None,
        eval_metric=eval_metrics,
        kvstore='device',
        optimizer=opt,
        #optimizer_params   = optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=epoch_cb)
Example #4
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.image_channel = 3

    data_dir = args.data_dir
    if args.task == 'gender':
        data_dir = args.gender_data_dir
    elif args.task == 'age':
        data_dir = args.age_data_dir
    print('data dir', data_dir)
    path_imgrec = None
    path_imglist = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    print('Called with argument:', args)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    net = get_model()
    #if args.task=='':
    #  test_net = get_model_test(net)
    #print(net.__class__)
    #net = net0[0]
    if args.network[0] == 'r' or args.network[0] == 'y':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  #inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    net.hybridize()
    if args.mode == 'gluon':
        if len(args.pretrained) == 0:
            pass
        else:
            net.load_params(args.pretrained,
                            allow_missing=True,
                            ignore_extra=True)
        net.initialize(initializer)
        net.collect_params().reset_ctx(ctx)

    val_iter = None
    if args.task == '':
        train_iter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=args.rand_mirror,
            mean=mean,
            cutoff=args.cutoff,
        )
    else:
        train_iter = FaceImageIterAge(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            task=args.task,
            shuffle=True,
            rand_mirror=args.rand_mirror,
            mean=mean,
            cutoff=args.cutoff,
        )

    if args.task == 'age':
        metric = CompositeEvalMetric([MAEMetric(), CUMMetric()])
    elif args.task == 'gender':
        metric = CompositeEvalMetric([AccMetric()])
    else:
        metric = CompositeEvalMetric([AccMetric()])

    ver_list = []
    ver_name_list = []
    if args.task == '':
        for name in args.eval.split(','):
            path = os.path.join(data_dir, name + ".bin")
            if os.path.exists(path):
                data_set = verification.load_bin(path, image_size)
                ver_list.append(data_set)
                ver_name_list.append(name)
                print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], net, ctx, batch_size=args.batch_size)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    def val_test(nbatch=0):
        acc = 0.0
        #if args.task=='age':
        if len(args.age_data_dir) > 0:
            val_iter = FaceImageIterAge(
                batch_size=args.batch_size,
                data_shape=data_shape,
                path_imgrec=os.path.join(args.age_data_dir, 'val.rec'),
                task=args.task,
                shuffle=False,
                rand_mirror=False,
                mean=mean,
            )
            _metric = MAEMetric()
            val_metric = mx.metric.create(_metric)
            val_metric.reset()
            _metric2 = CUMMetric()
            val_metric2 = mx.metric.create(_metric2)
            val_metric2.reset()
            val_iter.reset()
            for batch in val_iter:
                data = gluon.utils.split_and_load(batch.data[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch.label[0],
                                                   ctx_list=ctx,
                                                   batch_axis=0)
                outputs = []
                for x in data:
                    outputs.append(net(x)[2])
                val_metric.update(label, outputs)
                val_metric2.update(label, outputs)
            _value = val_metric.get_name_value()[0][1]
            print('[%d][VMAE]: %f' % (nbatch, _value))
            _value = val_metric2.get_name_value()[0][1]
            if args.task == 'age':
                acc = _value
            print('[%d][VCUM]: %f' % (nbatch, _value))
        if len(args.gender_data_dir) > 0:
            val_iter = FaceImageIterAge(
                batch_size=args.batch_size,
                data_shape=data_shape,
                path_imgrec=os.path.join(args.gender_data_dir, 'val.rec'),
                task=args.task,
                shuffle=False,
                rand_mirror=False,
                mean=mean,
            )
            _metric = AccMetric()
            val_metric = mx.metric.create(_metric)
            val_metric.reset()
            val_iter.reset()
            for batch in val_iter:
                data = gluon.utils.split_and_load(batch.data[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch.label[0],
                                                   ctx_list=ctx,
                                                   batch_axis=0)
                outputs = []
                for x in data:
                    outputs.append(net(x)[1])
                val_metric.update(label, outputs)
            _value = val_metric.get_name_value()[0][1]
            if args.task == 'gender':
                acc = _value
            print('[%d][VACC]: %f' % (nbatch, _value))
        return acc

    total_time = 0
    num_epochs = 0
    best_acc = [0]
    highest_acc = [0.0, 0.0]  #lfw and target
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [100000, 140000, 160000]
        p = 512.0 / args.batch_size
        for l in range(len(lr_steps)):
            lr_steps[l] = int(lr_steps[l] * p)
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)
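    # Scaling sketch: the default schedule assumes an effective batch of 512,
    # so with args.batch_size = 256, p = 2.0 and the steps become
    # [200000, 280000, 320000].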

    kv = mx.kv.create('device')
    #kv = mx.kv.create('local')
    #_rescale = 1.0/args.ctx_num
    #opt = optimizer.SGD(learning_rate=args.lr, momentum=args.mom, wd=args.wd, rescale_grad=_rescale)
    #opt = optimizer.SGD(learning_rate=args.lr, momentum=args.mom, wd=args.wd)
    if args.mode == 'gluon':
        trainer = gluon.Trainer(net.collect_params(),
                                'sgd', {
                                    'learning_rate': args.lr,
                                    'wd': args.wd,
                                    'momentum': args.mom,
                                    'multi_precision': True
                                },
                                kvstore=kv)
    else:
        _rescale = 1.0 / args.ctx_num
        opt = optimizer.SGD(learning_rate=args.lr,
                            momentum=args.mom,
                            wd=args.wd,
                            rescale_grad=_rescale)
        _cb = mx.callback.Speedometer(args.batch_size, 20)
        arg_params = None
        aux_params = None
        data = mx.sym.var('data')
        label = mx.sym.var('softmax_label')
        if args.margin_a > 0.0:
            fc7 = net(data, label)
        else:
            fc7 = net(data)
        #sym = mx.symbol.SoftmaxOutput(data=fc7, label = label, name='softmax', normalization='valid')
        ceop = gluon.loss.SoftmaxCrossEntropyLoss()
        loss = ceop(fc7, label)
        #loss = loss/args.per_batch_size
        loss = mx.sym.mean(loss)
        sym = mx.sym.Group([
            mx.symbol.BlockGrad(fc7),
            mx.symbol.MakeLoss(loss, name='softmax')
        ])
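        # The grouped symbol exposes two outputs: BlockGrad(fc7) returns the raw
        # logits for metric computation while stopping gradients through that
        # path, and MakeLoss marks the mean cross-entropy as the quantity whose
        # gradient drives the update.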

    def _batch_callback():
        mbatch = global_step[0]
        global_step[0] += 1
        for _lr in lr_steps:
            if mbatch == _lr:
                args.lr *= 0.1
                if args.mode == 'gluon':
                    trainer.set_learning_rate(args.lr)
                else:
                    opt.lr = args.lr
                print('lr change to', args.lr)
                break

        #_cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', args.lr, mbatch)

        if mbatch > 0 and mbatch % args.verbose == 0:
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            is_highest = False
            if args.task == 'age' or args.task == 'gender':
                acc = val_test(mbatch)
                if acc >= highest_acc[-1]:
                    highest_acc[-1] = acc
                    is_highest = True
                    do_save = True
            else:
                acc_list = ver_test(mbatch)
                if len(acc_list) > 0:
                    lfw_score = acc_list[0]
                    if lfw_score > highest_acc[0]:
                        highest_acc[0] = lfw_score
                        if lfw_score >= 0.998:
                            do_save = True
                    if acc_list[-1] >= highest_acc[-1]:
                        highest_acc[-1] = acc_list[-1]
                        if lfw_score >= 0.99:
                            do_save = True
                            is_highest = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt > 1:
                do_save = True
            if do_save:
                print('saving', msave)
                #print('saving gluon params')
                fname = os.path.join(args.prefix, 'model-gluon.params')
                net.save_params(fname)
                fname = os.path.join(args.prefix, 'model')
                net.export(fname, msave)
                #arg, aux = model.get_params()
                #mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    def _batch_callback_sym(param):
        _cb(param)
        _batch_callback()

    if args.mode != 'gluon':
        model = mx.mod.Module(
            context=ctx,
            symbol=sym,
        )
        model.fit(train_iter,
                  begin_epoch=0,
                  num_epoch=args.end_epoch,
                  eval_data=None,
                  eval_metric=metric,
                  kvstore='device',
                  optimizer=opt,
                  initializer=initializer,
                  arg_params=arg_params,
                  aux_params=aux_params,
                  allow_missing=True,
                  batch_end_callback=_batch_callback_sym,
                  epoch_end_callback=None)
    else:
        loss_weight = 1.0
        if args.task == 'age':
            loss_weight = 1.0 / AGE
        #loss = gluon.loss.SoftmaxCrossEntropyLoss(weight = loss_weight)
        # nd.SoftmaxOutput's backward pass is the softmax cross-entropy gradient,
        # so under autograd it can stand in for the loss op
        loss = nd.SoftmaxOutput
        #loss = gluon.loss.SoftmaxCrossEntropyLoss()
        while True:
            #trainer = update_learning_rate(opt.lr, trainer, epoch, opt.lr_factor, lr_steps)
            tic = time.time()
            train_iter.reset()
            metric.reset()
            btic = time.time()
            for i, batch in enumerate(train_iter):
                _batch_callback()
                #data = gluon.utils.split_and_load(batch.data[0].astype(opt.dtype), ctx_list=ctx, batch_axis=0)
                #label = gluon.utils.split_and_load(batch.label[0].astype(opt.dtype), ctx_list=ctx, batch_axis=0)
                data = gluon.utils.split_and_load(batch.data[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch.label[0],
                                                   ctx_list=ctx,
                                                   batch_axis=0)
                outputs = []
                Ls = []
                with ag.record():
                    for x, y in zip(data, label):
                        #print(y.asnumpy())
                        if args.task == '':
                            if args.margin_a > 0.0:
                                z = net(x, y)
                            else:
                                z = net(x)
                            #print(z[0].shape, z[1].shape)
                        else:
                            z = net(x)
                        if args.task == 'gender':
                            L = loss(z[1], y)
                            #L = L/args.per_batch_size
                            Ls.append(L)
                            outputs.append(z[1])
                        elif args.task == 'age':
                            for k in range(AGE):
                                _z = nd.slice_axis(z[2],
                                                   axis=1,
                                                   begin=k * 2,
                                                   end=k * 2 + 2)
                                _y = nd.slice_axis(y,
                                                   axis=1,
                                                   begin=k,
                                                   end=k + 1)
                                _y = nd.flatten(_y)
                                L = loss(_z, _y)
                                #L = L/args.per_batch_size
                                #L /= AGE
                                Ls.append(L)
                            outputs.append(z[2])
                        else:
                            L = loss(z, y)
                            #L = L/args.per_batch_size
                            Ls.append(L)
                            outputs.append(z)
                        # store the loss and do backward after we have done forward
                        # on all GPUs for better speed on multiple GPUs.
                    ag.backward(Ls)
                #trainer.step(batch.data[0].shape[0], ignore_stale_grad=True)
                #trainer.step(args.ctx_num)
                n = batch.data[0].shape[0]
                #print(n,n)
                trainer.step(n)
                metric.update(label, outputs)
                if i > 0 and i % 20 == 0:
                    name, acc = metric.get()
                    if len(name) == 2:
                        logger.info(
                            'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f, %s=%f'
                            % (num_epochs, i, args.batch_size /
                               (time.time() - btic), name[0], acc[0], name[1],
                               acc[1]))
                    else:
                        logger.info(
                            'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f'
                            % (num_epochs, i, args.batch_size /
                               (time.time() - btic), name[0], acc[0]))
                    #metric.reset()
                btic = time.time()

            epoch_time = time.time() - tic

            # The first epoch is usually much slower than subsequent epochs,
            # so don't factor it into the average
            if num_epochs > 0:
                total_time = total_time + epoch_time

            #name, acc = metric.get()
            #logger.info('[Epoch %d] training: %s=%f, %s=%f'%(num_epochs, name[0], acc[0], name[1], acc[1]))
            logger.info('[Epoch %d] time cost: %f' % (num_epochs, epoch_time))
            num_epochs = num_epochs + 1
            #name, val_acc = test(ctx, val_data)
            #logger.info('[Epoch %d] validation: %s=%f, %s=%f'%(epoch, name[0], val_acc[0], name[1], val_acc[1]))

            # save model if meet requirements
            #save_checkpoint(epoch, val_acc[0], best_acc)
        if num_epochs > 1:
            print('Average epoch time: {}'.format(
                float(total_time) / (num_epochs - 1)))
Example #5
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()

    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx), ctx, cvd)
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3

    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list) == 1
    data_dir = data_dir_list[0]
    path_imgrec = None
    path_imglist = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    # image_size = prop.image_size
    image_size = [int(x) for x in args.image_size.split(',')]
    assert len(image_size) == 2
    assert image_size[0] == image_size[1]
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    if args.loss_type == 1 and args.num_classes > 20000:
        args.beta_freeze = 5000
        args.gamma = 0.06

    print('Called with argument:', args)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    arg_params = None
    aux_params = None
    sym, arg_params, aux_params = get_symbol(args,
                                             arg_params,
                                             aux_params,
                                             layer_name='ms1m_fc7')
    fixed_args = [n for n in sym.list_arguments() if 'fc7' in n]

    # sym.get_internals()
    # sym.list_arguments()
    # sym.list_auxiliary_states()
    # sym.list_inputs()
    # sym.list_outputs()

    # label_name = 'softmax_label'
    # label_shape = (args.batch_size,)
    # arg_params['glint_fc7_weight'] = arg_params['fc7_weight'].copy()
    # arg_params['ms1m_fc7_weight'] = arg_params['glint_fc7_weight'].copy()
    assert 'ms1m_fc7_weight' in arg_params
    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
        fixed_param_names=fixed_args,
    )
    val_dataiter = None

    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrec=path_imgrec,
        shuffle=True,
        rand_mirror=args.rand_mirror,
        mean=mean,
        cutoff=args.cutoff,
        color_jittering=args.color,
        images_filter=args.images_filter,
    )

    metric1 = AccMetric()
    eval_metrics = [mx.metric.create(metric1)]
    if args.ce_loss:
        metric2 = LossValueMetric()
        eval_metrics.append(mx.metric.create(metric2))

    if args.network[0] == 'r' or args.network[0] == 'y':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  # resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  # inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    # initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    _rescale = 1.0 / args.ctx_num
    # opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
    logging.info(f'base lr {base_lr}')
    opt = optimizer.Adam(
        learning_rate=base_lr,
        wd=base_wd,
        rescale_grad=_rescale,
    )
    som = 20
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model, args.batch_size, 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            # print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    # ver_test( 0 )
    highest_acc = [0.0, 0.0]  # lfw and target
    # for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]

    if len(args.lr_steps) == 0:
        lr_steps = [40000, 60000, 80000]
        if args.loss_type >= 1 and args.loss_type <= 7:
            lr_steps = [100000, 140000, 160000]
        p = 512.0 / args.batch_size
        for l in range(len(lr_steps)):
            lr_steps[l] = int(lr_steps[l] * p)
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        # global global_step
        global_step[0] += 1
        mbatch = global_step[0]
        for _lr in lr_steps:
            if mbatch == args.beta_freeze + _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch: lr ', opt.lr, 'nbatch ', param.nbatch,
                  'epoch ', param.epoch, 'mbatch ', mbatch, 'lr_step',
                  lr_steps)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            is_highest = False
            if len(acc_list) > 0:
                # lfw_score = acc_list[0]
                # if lfw_score>highest_acc[0]:
                #  highest_acc[0] = lfw_score
                #  if lfw_score>=0.998:
                #    do_save = True
                score = sum(acc_list)
                if acc_list[-1] >= highest_acc[-1]:
                    if acc_list[-1] > highest_acc[-1]:
                        is_highest = True
                    else:
                        if score >= highest_acc[0]:
                            is_highest = True
                            highest_acc[0] = score
                    highest_acc[-1] = acc_list[-1]
                    # if lfw_score>=0.99:
                    #  do_save = True
            if is_highest:
                do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt == 2:
                do_save = True
            elif args.ckpt == 3:
                msave = 1

            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)

            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = max(0, mbatch - args.beta_freeze)
            _beta = max(
                args.beta_min,
                args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
        # print('beta', _beta)
        os.environ['BETA'] = str(_beta)
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)
    # model.set_params(arg_params, aux_params)
    model.fit(
        train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        eval_data=val_dataiter,
        eval_metric=eval_metrics,
        kvstore='device',
        optimizer=opt,
        # optimizer_params   = optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=epoch_cb)
Example #6
def train_net():
    args = parse_args()
    hvd.init()

    # size is the total number of GPUs; rank is this process's unique (GPU) ID
    # in [0, size); local_rank is its unique (GPU) ID within this server
    rank = hvd.rank()
    local_rank = hvd.local_rank()
    size = hvd.size()

    prefix = os.path.join(args.models_root, 'model')
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir) and not local_rank:
        os.makedirs(prefix_dir)
    else:
        time.sleep(2)

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    set_logger(logger, rank, prefix_dir)

    data_shape = (3, config.image_size, config.image_size)

    # The class centers (the weight matrix of the softmax linear transformation)
    # are stored evenly across all GPUs, in order.
    num_local = (config.num_classes + size - 1) // size
    num_sample = int(num_local * config.sample_ratio)
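    # Partition sketch: with config.num_classes = 1000000 and size = 8 GPUs,
    # num_local = 125000 centers live on each GPU; at sample_ratio = 0.1 each
    # step draws num_sample = 12500 of them (illustrative values).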
    memory_bank = MemoryBank(num_sample=num_sample,
                             num_local=num_local,
                             rank=rank,
                             local_rank=local_rank,
                             embedding_size=config.embedding_size,
                             prefix=prefix_dir,
                             gpu=True)

    if config.debug:
        train_iter = DummyIter(config.batch_size, data_shape, 1000 * 10000)
    else:
        train_iter = FaceImageIter(batch_size=config.batch_size,
                                   data_shape=data_shape,
                                   path_imgrec=config.rec,
                                   shuffle=True,
                                   rand_mirror=True,
                                   context=rank,
                                   context_num=size)
    train_data_iter = mx.io.PrefetchingIter(train_iter)

    esym, save_symbol = get_symbol_embedding()
    margins = (config.loss_m1, config.loss_m2, config.loss_m3)
    fc7_model = MarginLoss(margins, config.loss_s, config.embedding_size)

    # optimizer
    # backbone  lr_scheduler & optimizer
    backbone_lr_scheduler, memory_bank_lr_scheduler = get_scheduler()

    backbone_kwargs = {
        'learning_rate': config.backbone_lr,
        'momentum': 0.9,
        'wd': 5e-4,
        'rescale_grad': 1.0 / (config.batch_size * size) * size,
        'multi_precision': config.fp16,
        'lr_scheduler': backbone_lr_scheduler,
    }

    # memory_bank lr_scheduler & optimizer
    memory_bank_optimizer = MemoryBankSGDOptimizer(
        lr_scheduler=memory_bank_lr_scheduler,
        rescale_grad=1.0 / config.batch_size / size,
    )
    #
    train_module = SampleDistributeModule(
        symbol=esym,
        fc7_model=fc7_model,
        memory_bank=memory_bank,
        memory_optimizer=memory_bank_optimizer)
    #
    if not config.debug and local_rank == 0:
        cb_vert = CallBackVertification(esym, train_module)
    cb_speed = CallBackLogging(rank, size, prefix_dir)
    cb_save = CallBackModelSave(save_symbol,
                                train_module,
                                prefix,
                                rank,
                                save_interval=config.verbose)
    cb_center_save = CallBackCenterSave(memory_bank)

    def call_back_fn(params):
        cb_speed(params)
        if not config.debug and local_rank == 0:
            cb_vert(params)
        cb_center_save(params)
        cb_save(params)

    train_module.fit(train_data_iter,
                     optimizer_params=backbone_kwargs,
                     initializer=mx.init.Normal(0.1),
                     batch_end_callback=call_back_fn)
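# Launch note (script name assumed): Horovod jobs like this are typically run
# with one process per GPU, e.g. `horovodrun -np 8 python train_parall.py`.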
Example #7
sys.path.append("/home/gaomingda/insightface/recognition")
from image_iter import FaceImageIter

import cv2
import os
import mxnet as mx
import numpy as np

save_root = '/home/gaomingda/insightface/recognition/createLmdb/My_Files'

train_dataiter = FaceImageIter(
    batch_size=4,
    data_shape=(3, 112, 112),
    path_imgrec="/home/gaomingda/insightface/datasets/ms1m-retinaface-t1/train.rec",
    shuffle=True,
    rand_mirror=False,
    mean=None,
    cutoff=False,
    color_jittering=0,
    images_filter=0,
)
data_nums = train_dataiter.num_samples()
train_dataiter.reset()
train_dataiter.is_init = True

f = open(os.path.join(save_root, "train.txt"), 'w')
f.truncate()

for i in range(data_nums):
    label, s, _, _ = train_dataiter.next_sample()
    img_ = mx.image.imdecode(s)  # mx.NDArray in RGB order
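    # Apparent remainder of the loop, sketched under assumptions: the file
    # naming scheme and the scalar label are NOT in the original snippet.
    img_np = cv2.cvtColor(img_.asnumpy(), cv2.COLOR_RGB2BGR)  # OpenCV expects BGR
    fname = os.path.join(save_root, '%08d.jpg' % i)  # hypothetical file name
    cv2.imwrite(fname, img_np)
    f.write('%s %d\n' % (fname, int(label)))  # assumes label is a scalar

f.close()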
Example #8
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = os.path.join(args.models_root,
                          '%s-%s-%s' % (args.network, args.loss, args.dataset),
                          'model')
    prefix_dir = os.path.dirname(prefix)
    print('prefix', prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = config.image_shape[2]
    data_dir = config.dataset_path
    path_imgrecs = None
    path_imglist = None
    image_size = config.image_shape[0:2]
    assert len(image_size) == 2
    assert image_size[0] == image_size[1]
    print('image_size', image_size)
    print('num_classes', config.num_classes)
    path_imgrecs = [os.path.join(data_dir, "train.rec")]

    data_shape = (args.image_channel, image_size[0], image_size[1])

    num_workers = config.num_workers
    global_num_ctx = num_workers * args.ctx_num
    if config.num_classes % global_num_ctx == 0:
        args.ctx_num_classes = config.num_classes // global_num_ctx
    else:
        args.ctx_num_classes = config.num_classes // global_num_ctx + 1
    print(config.num_classes, global_num_ctx, args.ctx_num_classes)
    args.local_num_classes = args.ctx_num_classes * args.ctx_num
    args.local_class_start = args.local_num_classes * args.worker_id
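    # Partition sketch: with config.num_classes = 85742, num_workers = 1 and
    # args.ctx_num = 4, global_num_ctx = 4 and ctx_num_classes rounds up to
    # 21436, so local_num_classes = 85744 and worker 0 starts at class 0.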

    #if len(args.partial)==0:
    #  local_classes_range = (0, args.num_classes)
    #else:
    #  _vec = args.partial.split(',')
    #  local_classes_range = (int(_vec[0]), int(_vec[1]))

    #args.partial_num_classes = local_classes_range[1] - local_classes_range[0]
    #args.partial_start = local_classes_range[0]

    print('Called with argument:', args, config)
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    arg_params = None
    aux_params = None
    esym = get_symbol_embedding()
    asym = get_symbol_arcface
    if config.num_workers == 1:
        sys.path.append(os.path.join(os.path.dirname(__file__), 'utils'))
        from parall_module_local_v1 import ParallModule
    else:
        from parall_module_dist import ParallModule

    model = ParallModule(
        context=ctx,
        symbol=esym,
        data_names=['data'],
        label_names=['softmax_label'],
        asymbol=asym,
        args=args,
    )

    val_dataiter = None
    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrecs=path_imgrecs,
        shuffle=True,
        rand_mirror=config.data_rand_mirror,
        mean=mean,
        cutout=default.cutout if config.data_cutout else None,
        crop=default.crop if config.data_crop else None,
        mask=default.mask if config.data_mask else None,
        gridmask=default.gridmask if config.data_grid else None,
        #color_jittering      = config.data_color,
        #images_filter        = config.data_images_filter,
        loss_type=args.loss,
        #margin_m             = config.loss_m2,
        data_names=['data'],
        downsample_back=config.downsample_back,
        motion_blur=config.motion_blur,
        use_bgr=config.use_bgr)

    if config.net_name == 'fresnet' or config.net_name == 'fmobilefacenet':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)

    _rescale = 1.0 / 8  #/ args.batch_size
    print(base_lr, base_mom, base_wd, args.batch_size)

    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(lr_steps,
                                                        factor=0.1,
                                                        base_lr=base_lr)
    optimizer_params = {
        'learning_rate': base_lr,
        'momentum': base_mom,
        'wd': base_wd,
        'rescale_grad': _rescale,
        'lr_scheduler': lr_scheduler
    }

    opt = optimizer.SGD(learning_rate=base_lr,
                        momentum=base_mom,
                        wd=base_wd,
                        rescale_grad=_rescale)

    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)

    ver_list = []
    ver_name_list = []
    for name in config.val_targets:
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            _, issame_list = ver_list[i]
            if all(issame_list):
                fp_rates, fp_dict, thred_dict, recall_dict = verification.test(
                    ver_list[i],
                    model,
                    args.batch_size,
                    use_bgr=config.use_bgr,
                    label_shape=(args.batch_size, len(path_imgrecs)))
                for k in fp_rates:
                    print("[%s] TPR at FPR %.2e[%.2e: %.4f]:\t%.5f" %
                          (ver_name_list[i], k, fp_dict[k], thred_dict[k],
                           recall_dict[k]))

            else:
                acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                    ver_list[i],
                    model,
                    args.batch_size,
                    10,
                    None,
                    label_shape=(args.batch_size, len(path_imgrecs)),
                    use_bgr=config.use_bgr)
                print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
                #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
                print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                      (ver_name_list[i], nbatch, acc2, std2))
                results.append(acc2)

        return results

    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in range(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        #global global_step
        global_step[0] += 1
        mbatch = global_step[0]

        #for step in lr_steps:
        #  if mbatch==step:
        #    opt.lr *= 0.1
        #    print('lr change to', opt.lr)
        #    break

        _cb(param)
        if mbatch % 1000 == 0:
            #print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)
            print('batch-epoch:', param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            is_highest = False
            if len(acc_list) > 0:
                #lfw_score = acc_list[0]
                #if lfw_score>highest_acc[0]:
                #  highest_acc[0] = lfw_score
                #  if lfw_score>=0.998:
                #    do_save = True
                score = sum(acc_list)
                if acc_list[-1] >= highest_acc[-1]:
                    if acc_list[-1] > highest_acc[-1]:
                        is_highest = True
                    else:
                        if score >= highest_acc[0]:
                            is_highest = True
                            highest_acc[0] = score
                    highest_acc[-1] = acc_list[-1]
                    #if lfw_score>=0.99:
                    #  do_save = True
            if is_highest:
                do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt == 2:
                do_save = True
            elif args.ckpt == 3:
                msave = 1

            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()  #get_export_params()
                all_layers = model.symbol.get_internals()
                _sym = model.symbol  #all_layers['fc1_output']
                mx.model.save_checkpoint(prefix, msave, _sym, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    if len(args.pretrained) != 0:
        model_prefix, epoch = args.pretrained.split(',')
        begin_epoch = int(epoch)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_prefix, begin_epoch)
        #model.set_params(arg_params, aux_params)

    model.fit(
        train_dataiter,
        begin_epoch=0,  #begin_epoch,
        num_epoch=default.end_epoch,
        eval_data=val_dataiter,
        #eval_metric        = eval_metrics,
        kvstore=args.kvstore,
        #optimizer          = opt,
        optimizer_params=optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=epoch_cb)
Example #9
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd)>0:
      for i in range(len(cvd.split(','))):
        ctx.append(mx.gpu(i))
    if len(ctx)==0:
      ctx = [mx.cpu()]
      print('use cpu')
    else:
      print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
      os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size==0:
      args.per_batch_size = 128
    args.batch_size = args.per_batch_size*args.ctx_num
    args.image_channel = 3

    data_dir = args.data_dir
    if args.task=='gender':
      data_dir = args.gender_data_dir
    elif args.task=='age':
      data_dir = args.age_data_dir
    print('data dir', data_dir)
    path_imgrec = None
    path_imglist = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert(args.num_classes>0)
    print('num_classes', args.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")


    print('Called with argument:', args)
    data_shape = (args.image_channel,image_size[0],image_size[1])
    mean = None

    begin_epoch = 0
    net = get_model()
    #if args.task=='':
    #  test_net = get_model_test(net)
    #print(net.__class__)
    #net = net0[0]
    if args.network[0]=='r' or args.network[0]=='y':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    elif args.network[0]=='i' or args.network[0]=='x':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2) #inception
    else:
      initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
    net.hybridize()
    if args.mode=='gluon':
      if len(args.pretrained)==0:
        pass
      else:
        net.load_params(args.pretrained, allow_missing=True, ignore_extra = True)
      net.initialize(initializer)
      net.collect_params().reset_ctx(ctx)

    val_iter = None
    if args.task=='':
      train_iter = FaceImageIter(
          batch_size           = args.batch_size,
          data_shape           = data_shape,
          path_imgrec          = path_imgrec,
          shuffle              = True,
          rand_mirror          = args.rand_mirror,
          mean                 = mean,
          cutoff               = args.cutoff,
      )
    else:
      train_iter = FaceImageIterAge(
          batch_size           = args.batch_size,
          data_shape           = data_shape,
          path_imgrec          = path_imgrec,
          task                 = args.task,
          shuffle              = True,
          rand_mirror          = args.rand_mirror,
          mean                 = mean,
          cutoff               = args.cutoff,
      )

    if args.task=='age':
      metric = CompositeEvalMetric([MAEMetric(), CUMMetric()])
    elif args.task=='gender':
      metric = CompositeEvalMetric([AccMetric()])
    else:
      metric = CompositeEvalMetric([AccMetric()])

    ver_list = []
    ver_name_list = []
    if args.task=='':
      for name in args.eval.split(','):
        path = os.path.join(data_dir,name+".bin")
        if os.path.exists(path):
          data_set = verification.load_bin(path, image_size)
          ver_list.append(data_set)
          ver_name_list.append(name)
          print('ver', name)

    def ver_test(nbatch):
      results = []
      for i in range(len(ver_list)):
        acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], net, ctx, batch_size = args.batch_size)
        print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
        #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
        print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
        results.append(acc2)
      return results

    def val_test(nbatch=0):
      acc = 0.0
      #if args.task=='age':
      if len(args.age_data_dir)>0:
        val_iter = FaceImageIterAge(
            batch_size           = args.batch_size,
            data_shape           = data_shape,
            path_imgrec          = os.path.join(args.age_data_dir, 'val.rec'),
            task                 = args.task,
            shuffle              = False,
            rand_mirror          = False,
            mean                 = mean,
        )
        _metric = MAEMetric()
        val_metric = mx.metric.create(_metric)
        val_metric.reset()
        _metric2 = CUMMetric()
        val_metric2 = mx.metric.create(_metric2)
        val_metric2.reset()
        val_iter.reset()
        for batch in val_iter:
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            outputs = []
            for x in data:
                outputs.append(net(x)[2])
            val_metric.update(label, outputs)
            val_metric2.update(label, outputs)
        _value = val_metric.get_name_value()[0][1]
        print('[%d][VMAE]: %f'%(nbatch, _value))
        _value = val_metric2.get_name_value()[0][1]
        if args.task=='age':
          acc = _value
        print('[%d][VCUM]: %f'%(nbatch, _value))
      if len(args.gender_data_dir)>0:
        val_iter = FaceImageIterAge(
            batch_size           = args.batch_size,
            data_shape           = data_shape,
            path_imgrec          = os.path.join(args.gender_data_dir, 'val.rec'),
            task                 = args.task,
            shuffle              = False,
            rand_mirror          = False,
            mean                 = mean,
        )
        _metric = AccMetric()
        val_metric = mx.metric.create(_metric)
        val_metric.reset()
        val_iter.reset()
        for batch in val_iter:
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            outputs = []
            for x in data:
                outputs.append(net(x)[1])
            val_metric.update(label, outputs)
        _value = val_metric.get_name_value()[0][1]
        if args.task=='gender':
          acc = _value
        print('[%d][VACC]: %f'%(nbatch, _value))
      return acc
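    # Note the head indexing used above: net(x)[1] is the gender branch and
    # net(x)[2] the age branch, matching the slices taken in the training loop.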


    total_time = 0
    num_epochs = 0
    best_acc = [0]
    highest_acc = [0.0, 0.0]  #lfw and target
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps)==0:
      lr_steps = [100000, 140000, 160000]
      p = 512.0/args.batch_size
      for l in range(len(lr_steps)):
        lr_steps[l] = int(lr_steps[l]*p)
    else:
      lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)
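    # Worked example of the scaling above (hypothetical batch size): the default
    # steps assume an effective batch of 512, so with args.batch_size=256 we get
    # p = 512/256 = 2.0 and lr_steps = [200000, 280000, 320000] -- half the
    # batch, twice the steps, the same number of samples seen before each decay.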

    kv = mx.kv.create('device')
    #kv = mx.kv.create('local')
    #_rescale = 1.0/args.ctx_num
    #opt = optimizer.SGD(learning_rate=args.lr, momentum=args.mom, wd=args.wd, rescale_grad=_rescale)
    #opt = optimizer.SGD(learning_rate=args.lr, momentum=args.mom, wd=args.wd)
    if args.mode=='gluon':
      trainer = gluon.Trainer(net.collect_params(), 'sgd', 
              {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.mom, 'multi_precision': True},
              kvstore=kv)
    else:
      _rescale = 1.0/args.ctx_num
      opt = optimizer.SGD(learning_rate=args.lr, momentum=args.mom, wd=args.wd, rescale_grad=_rescale)
      _cb = mx.callback.Speedometer(args.batch_size, 20)
      arg_params = None
      aux_params = None
      data = mx.sym.var('data')
      label = mx.sym.var('softmax_label')
      if args.margin_a>0.0:
        fc7 = net(data, label)
      else:
        fc7 = net(data)
      #sym = mx.symbol.SoftmaxOutput(data=fc7, label = label, name='softmax', normalization='valid')
      ceop = gluon.loss.SoftmaxCrossEntropyLoss()
      loss = ceop(fc7, label) 
      #loss = loss/args.per_batch_size
      loss = mx.sym.mean(loss)
      sym = mx.sym.Group( [mx.symbol.BlockGrad(fc7), mx.symbol.MakeLoss(loss, name='softmax')] )
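      # The grouped symbol has two heads: BlockGrad(fc7) exposes the logits for
      # metrics without letting gradients flow back through that output, while
      # MakeLoss marks the scalar mean cross-entropy as the training objective.
      # A toy check of the same pattern (hypothetical, fresh name scope):
      #
      #   logits = mx.sym.FullyConnected(mx.sym.var('data'), num_hidden=10)
      #   ce = mx.sym.mean(gluon.loss.SoftmaxCrossEntropyLoss()(
      #       logits, mx.sym.var('softmax_label')))
      #   g = mx.sym.Group([mx.sym.BlockGrad(logits),
      #                     mx.sym.MakeLoss(ce, name='softmax')])
      #   g.list_outputs()   # e.g. ['blockgrad0_output', 'softmax_output']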

    def _batch_callback():
      mbatch = global_step[0]
      global_step[0]+=1
      for _lr in lr_steps:
        if mbatch==_lr:
          args.lr *= 0.1
          if args.mode=='gluon':
            trainer.set_learning_rate(args.lr)
          else:
            opt.lr  = args.lr
          print('lr change to', args.lr)
          break

      #_cb(param)
      if mbatch%1000==0:
        print('lr-batch-epoch:',args.lr, mbatch)

      if mbatch>0 and mbatch%args.verbose==0:
        save_step[0]+=1
        msave = save_step[0]
        do_save = False
        is_highest = False
        if args.task=='age' or args.task=='gender':
          acc = val_test(mbatch)
          if acc>=highest_acc[-1]:
            highest_acc[-1] = acc
            is_highest = True
            do_save = True
        else:
          acc_list = ver_test(mbatch)
          if len(acc_list)>0:
            lfw_score = acc_list[0]
            if lfw_score>highest_acc[0]:
              highest_acc[0] = lfw_score
              if lfw_score>=0.998:
                do_save = True
            if acc_list[-1]>=highest_acc[-1]:
              highest_acc[-1] = acc_list[-1]
              if lfw_score>=0.99:
                do_save = True
                is_highest = True
        if args.ckpt==0:
          do_save = False
        elif args.ckpt>1:
          do_save = True
        if do_save:
          print('saving', msave)
          #print('saving gluon params')
          fname = os.path.join(args.prefix, 'model-gluon.params')
          net.save_params(fname)
          fname = os.path.join(args.prefix, 'model')
          net.export(fname, msave)
          #arg, aux = model.get_params()
          #mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
        print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
      if args.max_steps>0 and mbatch>args.max_steps:
        sys.exit(0)

    def _batch_callback_sym(param):
      _cb(param)
      _batch_callback()


    if args.mode!='gluon':
      model = mx.mod.Module(
          context       = ctx,
          symbol        = sym,
      )
      model.fit(train_iter,
          begin_epoch        = 0,
          num_epoch          = args.end_epoch,
          eval_data          = None,
          eval_metric        = metric,
          kvstore            = 'device',
          optimizer          = opt,
          initializer        = initializer,
          arg_params         = arg_params,
          aux_params         = aux_params,
          allow_missing      = True,
          batch_end_callback = _batch_callback_sym,
          epoch_end_callback = None )
    else:
      loss_weight = 1.0
      if args.task=='age':
        loss_weight = 1.0/AGE
      #loss = gluon.loss.SoftmaxCrossEntropyLoss(weight = loss_weight)
      loss = nd.SoftmaxOutput  # used as a loss op: its backward is the softmax cross-entropy gradient
      #loss = gluon.loss.SoftmaxCrossEntropyLoss()
      while True:
          #trainer = update_learning_rate(opt.lr, trainer, epoch, opt.lr_factor, lr_steps)
          tic = time.time()
          train_iter.reset()
          metric.reset()
          btic = time.time()
          for i, batch in enumerate(train_iter):
              _batch_callback()
              #data = gluon.utils.split_and_load(batch.data[0].astype(opt.dtype), ctx_list=ctx, batch_axis=0)
              #label = gluon.utils.split_and_load(batch.label[0].astype(opt.dtype), ctx_list=ctx, batch_axis=0)
              data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
              label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
              outputs = []
              Ls = []
              with ag.record():
                  for x, y in zip(data, label):
                      #print(y.asnumpy())
                      if args.task=='':
                        if args.margin_a>0.0:
                          z = net(x,y)
                        else:
                          z = net(x)
                        #print(z[0].shape, z[1].shape)
                      else:
                        z = net(x)
                      if args.task=='gender':
                        L = loss(z[1], y)
                        #L = L/args.per_batch_size
                        Ls.append(L)
                        outputs.append(z[1])
                      elif args.task=='age':
                        for k in range(AGE):
                          _z = nd.slice_axis(z[2], axis=1, begin=k*2, end=k*2+2)
                          _y = nd.slice_axis(y, axis=1, begin=k, end=k+1)
                          _y = nd.flatten(_y)
                          L = loss(_z, _y)
                          #L = L/args.per_batch_size
                          #L /= AGE
                          Ls.append(L)
                        outputs.append(z[2])
                      else:
                        L = loss(z, y)
                        #L = L/args.per_batch_size
                        Ls.append(L)
                        outputs.append(z)
                      # store the loss and do backward after we have done forward
                      # on all GPUs for better speed on multiple GPUs.
                  ag.backward(Ls)
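              # For the age task, z[2] is laid out as AGE independent two-way
              # classifiers with shape (n, 2*AGE): slice_axis pulls out the
              # k-th (n, 2) logit pair and the k-th label column (flattened to
              # (n,)), so every age bin contributes its own cross-entropy term.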
              #trainer.step(batch.data[0].shape[0], ignore_stale_grad=True)
              #trainer.step(args.ctx_num)
              n = batch.data[0].shape[0]
              #print(n,n)
              trainer.step(n)
              metric.update(label, outputs)
              if i>0 and i%20==0:
                  name, acc = metric.get()
                  if len(name)==2:
                    logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f, %s=%f'%(
                                   num_epochs, i, args.batch_size/(time.time()-btic), name[0], acc[0], name[1], acc[1]))
                  else:
                    logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f'%(
                                   num_epochs, i, args.batch_size/(time.time()-btic), name[0], acc[0]))
                  #metric.reset()
              btic = time.time()

          epoch_time = time.time()-tic

          # The first epoch is usually much slower than the subsequent epochs,
          # so don't factor it into the average.
          if num_epochs > 0:
            total_time = total_time + epoch_time

          #name, acc = metric.get()
          #logger.info('[Epoch %d] training: %s=%f, %s=%f'%(num_epochs, name[0], acc[0], name[1], acc[1]))
          logger.info('[Epoch %d] time cost: %f'%(num_epochs, epoch_time))
          num_epochs = num_epochs + 1
          #name, val_acc = test(ctx, val_data)
          #logger.info('[Epoch %d] validation: %s=%f, %s=%f'%(epoch, name[0], val_acc[0], name[1], val_acc[1]))

          # save model if meet requirements
          #save_checkpoint(epoch, val_acc[0], best_acc)
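      # NOTE: the `while True` loop above only exits via sys.exit() inside
      # _batch_callback, so the average-time report below is unreachable as written.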
      if num_epochs > 1:
          print('Average epoch time: {}'.format(float(total_time)/(num_epochs - 1)))
Example #10
0
def train_net(args):
    # Set up kvstore
    kv = mx.kvstore.create(args.kv_store)
    if args.gc_type != 'none':
        kv.set_gradient_compression({
            'type': args.gc_type,
            'threshold': args.gc_threshold
        })
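    # Gradient compression trades gradient fidelity for bandwidth: with
    # {'type': '2bit', 'threshold': t}, pushed values are quantized to
    # {-t, 0, +t} and the quantization error is carried over to the next push.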

    # logging
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    logging.info('start with arguments %s', args)

    # Build ctx from num_gpus; GPU ids start from 0
    ctx = [mx.cpu()] if args.num_gpus is None or args.num_gpus == 0 else [
        mx.gpu(i) for i in range(args.num_gpus)
    ]

    # model prefix; on the UAI Platform this should be /data/output/xxx
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3

    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list) == 1
    data_dir = data_dir_list[0]
    path_imgrec = None
    path_imglist = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    #image_size = prop.image_size
    image_size = [int(x) for x in args.image_size.split(',')]
    assert len(image_size) == 2
    assert image_size[0] == image_size[1]
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")
    path_imglist = os.path.join(data_dir, "train.lst")

    num_samples = 0
    for line in open(path_imglist):
        num_samples += 1

    print('Called with argument:', args)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
        if args.network[0] == 's':
            data_shape_dict = {'data': (args.per_batch_size, ) + data_shape}
            spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:
        # Note: the model is saved every epoch here, not every NUM steps as in train_softmax.py.
        # args.pretrained should be 'prefix,epoch'
        vec = args.pretrained.split(',')
        print('loading', vec)
        model_prefix = vec[0]
        if kv.rank > 0 and os.path.exists("%s-%d-symbol.json" %
                                          (model_prefix, kv.rank)):
            model_prefix += "-%d" % (kv.rank)
        logging.info('Loaded model %s_%d.params', model_prefix, int(vec[1]))
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_prefix, int(vec[1]))
        begin_epoch = int(vec[1])
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)

    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
    )
    val_dataiter = None

    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrec=path_imgrec,
        shuffle=True,
        rand_mirror=args.rand_mirror,
        mean=mean,
        cutoff=args.cutoff,
        color_jittering=args.color,
        images_filter=args.images_filter,
    )

    metric1 = AccMetric()
    eval_metrics = [mx.metric.create(metric1)]
    if args.ce_loss:
        metric2 = LossValueMetric()
        eval_metrics.append(mx.metric.create(metric2))

    if args.network[0] == 'r' or args.network[0] == 'y':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  #inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    #initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    som = 20
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model, args.batch_size, 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0]  #lfw and target

    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)

    def _batch_callback(param):
        #global global_step
        mbatch = param.nbatch

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            is_highest = False
            if len(acc_list) > 0:
                score = sum(acc_list)
                if acc_list[-1] >= highest_acc[-1]:
                    if acc_list[-1] > highest_acc[-1]:
                        is_highest = True
                    else:
                        if score >= highest_acc[0]:
                            is_highest = True
                            highest_acc[0] = score
                    highest_acc[-1] = acc_list[-1]
                    #if lfw_score>=0.99:
                    #  do_save = True

            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))

    # save model
    checkpoint = _save_model(args, kv.rank)
    epoch_cb = checkpoint

    rescale = 1.0 / args.ctx_num
    lr, lr_scheduler = _get_lr_scheduler(args, kv, begin_epoch, num_samples)
    # learning rate
    optimizer_params = {
        'learning_rate': lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler,
        'multi_precision': True,
        'rescale_grad': rescale
    }
    # Only a limited number of optimizers have 'momentum' property
    has_momentum = {'sgd', 'dcasgd', 'nag'}
    if args.optimizer in has_momentum:
        optimizer_params['momentum'] = args.mom
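    # rescale_grad = 1/ctx_num keeps the update scale independent of the device
    # count: e.g. (hypothetical numbers) per_batch_size=128 on 4 GPUs gives a
    # 512-sample global batch whose aggregated gradients are scaled by 0.25.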

    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    print('Start training')
    model.fit(train_dataiter,
              begin_epoch=begin_epoch,
              num_epoch=end_epoch,
              eval_data=val_dataiter,
              eval_metric=eval_metrics,
              kvstore=kv,
              optimizer=args.optimizer,
              optimizer_params=optimizer_params,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              allow_missing=True,
              batch_end_callback=_batch_callback,
              epoch_end_callback=epoch_cb)
Example #11
0
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3

    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list) == 1
    data_dir = data_dir_list[0]
    path_imgrec = None
    path_imglist = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    if args.loss_type == 1 and args.num_classes > 20000:
        args.beta_freeze = 5000
        args.gamma = 0.06

    print('Called with argument:', args)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    else:
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    if args.network[0] == 's':
        data_shape_dict = {'data': (args.per_batch_size, ) + data_shape}
        spherenet.init_weights(sym, data_shape_dict, args.num_layers)

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)
    model = mx.mod.Module(context=ctx, symbol=sym)
    val_dataiter = None

    train_dataiter = FaceImageIter(
        batch_size  = args.batch_size,
        data_shape  = data_shape,
        path_imgrec = path_imgrec,
        shuffle     = True,
        rand_mirror = args.rand_mirror,
        mean        = mean,
        cutoff      = args.cutoff)

    if args.loss_type < 10:
        _metric = AccMetric()
    else:
        _metric = LossValueMetric()
    eval_metrics = [mx.metric.create(_metric)]

    if args.network[0] == 'r' or args.network[0] == 'y':
        initializer = mx.init.Xavier(
            rnd_type='gaussian', factor_type="out", magnitude=2)  #resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(
            rnd_type='gaussian', factor_type="in", magnitude=2)  #inception
    else:
        initializer = mx.init.Xavier(
            rnd_type='uniform', factor_type="in", magnitude=2)
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(
        learning_rate = base_lr,
        momentum      = base_mom,
        wd            = base_wd,
        rescale_grad  = _rescale)
    som = 20
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list, best_all = verification.test(
                ver_list[i], model, min(args.batch_size, 256), 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            print('[%s][%d]Best-Threshold: %1.2f  %1.5f' %
                  (ver_name_list[i], nbatch, best_all[0], best_all[1]))
            results.append(acc2)
        return results

    def highest_cmp(acc, cpt):
        assert len(acc) > 0
        if acc[0] > cpt[1]:
            return True
        elif acc[0] < cpt[1]:
            return False
        else:
            # tie on the first metric: compare the sums of the remaining ones
            acc_sum = sum(acc[1:])
            cpt_sum = sum(cpt[2:len(acc) + 1])
            return acc_sum >= cpt_sum
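    # Hypothetical comparison: acc=[0.995, 0.90] vs cpt=[7, 0.994, 0.95] returns
    # True on the first metric alone; if acc[0] == cpt[1], the sums of the
    # remaining metrics break the tie.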

    highest_acc = []  # lfw and target
    for i in range(len(ver_list)):
        highest_acc.append(0.0)
    highest_cpt = [0] + highest_acc
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [40000, 60000, 80000]
        if args.loss_type >= 1 and args.loss_type <= 7:
            lr_steps = [100000, 140000, 160000]
        p = 512.0 / args.batch_size
        for l in range(len(lr_steps)):
            lr_steps[l] = int(lr_steps[l] * p)
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        global_step[0] += 1
        mbatch = global_step[0]
        for _lr in lr_steps:
            if mbatch == args.beta_freeze + _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            do_save = False
            if len(acc_list) > 0:
                if acc_list[0] > 0.997:  # lfw
                    for i in range(len(acc_list)):
                        if acc_list[i] >= highest_acc[i]:
                            do_save = True
                for i in range(len(acc_list)):
                    highest_acc[i] = max(highest_acc[i], acc_list[i])
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt > 1:
                do_save = True
            if do_save:
                save_step[0] += 1
                msave = save_step[0]
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
                if highest_cmp(acc_list, highest_cpt):
                    highest_cpt[0] = msave
                    for i, acc in enumerate(acc_list):
                        highest_cpt[i+1] = acc
            sys.stdout.write('[%d]Accuracy-Highest: ' % mbatch)
            for acc in highest_acc:
                sys.stdout.write('%1.5f  ' % acc)
            sys.stdout.write('\n')
            sys.stdout.write('[%d]Accuracy-BestCpt: (%d) ' % (mbatch, highest_cpt[0]))
            for acc in highest_cpt[1:]:
                sys.stdout.write('%1.5f  ' % acc)
            sys.stdout.write('\n')
            sys.stdout.flush()
            # print('[%d]Accuracy-Highest: %1.5f  %1.5f  %1.5f'%(mbatch, highest_acc[0], highest_acc[1], highest_acc[2]))
            # print('[%d]Accuracy-BestCPt: <%d> %1.5f  %1.5f  %1.5f' % ((mbatch,) + tuple(highest_cpt)))
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = max(0, mbatch - args.beta_freeze)
            _beta = max(args.beta_min,
                args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
        #print('beta', _beta)
        os.environ['BETA'] = str(_beta)
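        # Worked example of the anneal (hypothetical defaults beta=1000,
        # gamma=0.12, power=1): 1000 steps past beta_freeze gives
        # beta = 1000 * (1 + 0.12*1000)^-1 ~ 8.3, decaying toward beta_min --
        # the SphereFace-style lambda schedule for the margin softmax.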
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    model.fit(
        train_dataiter,
        begin_epoch        = begin_epoch,
        num_epoch          = end_epoch,
        eval_data          = val_dataiter,
        eval_metric        = eval_metrics,
        kvstore            = 'device',
        optimizer          = opt,
        # optimizer_params = optimizer_params,
        initializer        = initializer,
        arg_params         = arg_params,
        aux_params         = aux_params,
        allow_missing      = True,
        batch_end_callback = _batch_callback,
        epoch_end_callback = epoch_cb)
Example #12
0
def train_net(args):
  # GPU / CPU setup
    ctx = []
    # cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    # cvd = os.environ['0'].strip()
    # if len(cvd)>0:
    #   for i in range(len(cvd.split(','))):
    #     ctx.append(mx.gpu(i))
    # if len(ctx)==0:
    #   ctx = [mx.cpu()]
    #   print('use cpu')
    # else:
    #   print('gpu num:', len(ctx))
    ctx.append(mx.gpu(0))
  # path for saving the model
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
      os.makedirs(prefix_dir)
  # preset parameters
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)             #print
    if args.per_batch_size==0:
      args.per_batch_size = 128
    args.batch_size = args.per_batch_size*args.ctx_num
    print(args.batch_size)
    args.rescale_threshold = 0
    args.image_channel = 3

    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list)==1
    data_dir = data_dir_list[0]
    path_imgrec = None
    path_imglist = None
    # read the property file
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    #image_size = prop.image_size
    image_size = [int(x) for x in args.image_size.split(',')]
    assert len(image_size)==2
    assert image_size[0]==image_size[1]
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)             #print
    assert(args.num_classes>0)
    print('num_classes', args.num_classes)      #print

    path_imgrec = os.path.join(data_dir, "train.rec")    

    if args.loss_type==1 and args.num_classes>20000:
      args.beta_freeze = 5000
      args.gamma = 0.06

    print('Called with argument:', args)         #print
    data_shape = (args.image_channel,image_size[0],image_size[1])
    mean = None
  # check whether a pretrained model exists
    begin_epoch = 0 
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom

    if len(args.pretrained)==0:
      arg_params = None
      aux_params = None
      sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
      if args.network[0]=='s':
        data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}
        spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:
      vec = args.pretrained.split(',')
      print('loading', vec)
      _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
      sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)


  # initialize the model

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)
    model = mx.mod.Module(
        context       = ctx,
        symbol        = sym,
    )
    val_dataiter = None
  # set up the train_data iterator
    train_dataiter = FaceImageIter(
        batch_size           = args.batch_size,
        data_shape           = data_shape,
        path_imgrec          = path_imgrec,
        shuffle              = True,
        rand_mirror          = args.rand_mirror,
        mean                 = mean,
        cutoff               = args.cutoff,
        color_jittering      = args.color,
        images_filter        = args.images_filter,
    )
  # set up the eval_metric
    metric1 = AccMetric()
    eval_metrics = [mx.metric.create(metric1)]

    if args.ce_loss:
      metric2 = LossValueMetric()
      eval_metrics.append( mx.metric.create(metric2) )


  # pick the initializer (weight init) by network type, and create the optimizer

    if args.network[0]=='r' or args.network[0]=='y':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    elif args.network[0]=='i' or args.network[0]=='x':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2) #inception
    else:
      initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
    #initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    _rescale = 1.0/args.ctx_num
    opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
    som = 20
    _cb = mx.callback.Speedometer(args.batch_size, som)

  # load the verification datasets
    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
      path = os.path.join(data_dir,name+".bin")
      if os.path.exists(path):
        data_set = verification.load_bin(path, image_size)
        ver_list.append(data_set)
        ver_name_list.append(name)
        print('ver', name)


  # run verification on the test sets
    def ver_test(nbatch):
      results = []
      for i in range(len(ver_list)):
        acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, 10, None, None)
        print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
        #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
        print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
        results.append(acc2)
      return results



    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]

  # set up lr_steps
    if len(args.lr_steps)==0:
      lr_steps = [40000, 60000, 80000]
      if args.loss_type>=1 and args.loss_type<=7:
        lr_steps = [100000, 140000, 160000]
      p = 512.0/args.batch_size                        #args.batch_size = 128*x
      for l in range(len(lr_steps)):
        lr_steps[l] = int(lr_steps[l]*p)
    else:
      lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

  # per-batch handling of checkpoint saving and the lr schedule
    def _batch_callback(param):
      #global global_step
      global_step[0]+=1
      mbatch = global_step[0]

      for _lr in lr_steps:
        if mbatch==args.beta_freeze+_lr:        #args.beta_freeze = 5000
          opt.lr *= 0.1
          print('lr change to', opt.lr)
          break

      _cb(param)
      if mbatch%1000==0:
        print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)

      if mbatch>=0 and mbatch%args.verbose==0:      #args.verbose = 2000
        acc_list = ver_test(mbatch)
        save_step[0]+=1
        msave = save_step[0]
        do_save = False
        is_highest = False
        if len(acc_list)>0:
          #lfw_score = acc_list[0]
          #if lfw_score>highest_acc[0]:
          #  highest_acc[0] = lfw_score
          #  if lfw_score>=0.998:
          #    do_save = True
          score = sum(acc_list)
          if acc_list[-1]>=highest_acc[-1]:

            if acc_list[-1]>highest_acc[-1]:
              is_highest = True
            else:
              if score>=highest_acc[0]:
                is_highest = True
                highest_acc[0] = score

            highest_acc[-1] = acc_list[-1]
            #if lfw_score>=0.99:
            #  do_save = True
        if is_highest:
          do_save = True
        # checkpoint saving policy
        if args.ckpt==0:
          do_save = False
        elif args.ckpt==2:
          do_save = True
        elif args.ckpt==3:
          msave = 1

        if do_save:
          print('saving', msave)
          arg, aux = model.get_params()
          mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
        print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
      if mbatch<=args.beta_freeze:
        _beta = args.beta                 #args.beta = 1000
      else:                               #mbatch>args.beta_freeze
        move = max(0, mbatch-args.beta_freeze)
        _beta = max(args.beta_min, args.beta*math.pow(1+args.gamma*move, -1.0*args.power))
      #print('beta', _beta)
      os.environ['BETA'] = str(_beta)
      if args.max_steps>0 and mbatch>args.max_steps:
        sys.exit(0)
  # model saving setup
    epoch_cb = None
    # mx.io.PrefetchingIter wraps the iterator so batches are prefetched in a background thread
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)
  # training entry point
    model.fit(train_dataiter,
        begin_epoch        = begin_epoch,
        num_epoch          = end_epoch,
        eval_data          = val_dataiter,
        eval_metric        = eval_metrics,
        kvstore            = 'device',
        optimizer          = opt,
        #optimizer_params   = optimizer_params,
        initializer        = initializer,
        arg_params         = arg_params,
        aux_params         = aux_params,
        allow_missing      = True,
        batch_end_callback = _batch_callback,
        epoch_end_callback = epoch_cb )
Example #13
0
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3

    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list) == 1
    data_dir = data_dir_list[0]
    path_imgrec = None
    path_imglist = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    # image_size = prop.image_size
    image_size = [int(x) for x in args.image_size.split(',')]
    assert len(image_size) == 2
    assert image_size[0] == image_size[1]
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    if args.loss_type == 1 and args.num_classes > 20000:
        args.beta_freeze = 5000
        args.gamma = 0.06

    print('Called with argument:', args)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
        if args.network[0] == 's':
            data_shape_dict = {'data': (args.per_batch_size, ) + data_shape}
            spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            vec[0], int(vec[1]))
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)

    # label_name = 'softmax_label'
    # label_shape = (args.batch_size,)
    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
    )
    val_dataiter = None

    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrec=path_imgrec,
        shuffle=True,
        rand_mirror=args.rand_mirror,
        mean=mean,
        cutoff=args.cutoff,
        color_jittering=args.color,
        images_filter=args.images_filter,
    )

    metric1 = AccMetric()
    eval_metrics = [mx.metric.create(metric1)]
    if args.ce_loss:
        metric2 = LossValueMetric()
        eval_metrics.append(mx.metric.create(metric2))

    if args.network[0] == 'r' or args.network[0] == 'y':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  # resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  # inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    # initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=base_lr,
                        momentum=base_mom,
                        wd=base_wd,
                        rescale_grad=_rescale)
    som = 20
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model, args.batch_size, 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            # print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0, 0.0, 0.0, 0.0]
    # for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [40000, 60000, 80000]
        if args.loss_type >= 1 and args.loss_type <= 7:
            lr_steps = [100000, 140000, 160000]
        p = 512.0 / args.batch_size
        for l in range(len(lr_steps)):
            lr_steps[l] = int(lr_steps[l] * p)
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        # global global_step
        global_step[0] += 1
        mbatch = global_step[0]
        for _lr in lr_steps:
            if mbatch == args.beta_freeze + _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            if len(acc_list) > 0:
                score = {}
                score['lfw_score'] = acc_list[0]
                score['cfp_score'] = acc_list[1]
                score['agedb_score'] = acc_list[2]
                score['cplfw_score'] = acc_list[3]
                score['calfw_score'] = acc_list[4]
                print('score=', score)
                if score['lfw_score'] > highest_acc[0]:
                    highest_acc[0] = score['lfw_score']
                    if score['lfw_score'] >= 0.99:
                        do_save = True
                if score['cfp_score'] > highest_acc[1]:
                    highest_acc[1] = score['cfp_score']
                    if score['cfp_score'] > 0.94:
                        do_save = True
                if score['agedb_score'] > highest_acc[2]:
                    highest_acc[2] = score['agedb_score']
                    if score['agedb_score'] > 0.93:
                        do_save = True
                if score['cplfw_score'] > highest_acc[3]:
                    highest_acc[3] = score['cplfw_score']
                    if score['cplfw_score'] > 0.85:
                        do_save = True
                if score['calfw_score'] > highest_acc[4]:
                    highest_acc[4] = score['calfw_score']
                    if score['calfw_score'] > 0.9:
                        do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt > 1:
                do_save = True
            arg, aux = model.get_params()
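            # checkpoint 0 is overwritten on every eval as the latest snapshot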
            print('saving', 0)
            mx.model.save_checkpoint(prefix, 0, model.symbol, arg, aux)
            if do_save:
                print('saving', msave)
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print(
                '[%d]score_highest: lfw: %1.5f cfp: %1.5f agedb: %1.5f cplfw: %1.5f calfw: %1.5f'
                % (mbatch, highest_acc[0], highest_acc[1], highest_acc[2],
                   highest_acc[3], highest_acc[4]))
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = max(0, mbatch - args.beta_freeze)
            _beta = max(
                args.beta_min,
                args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
        # print('beta', _beta)
        os.environ['BETA'] = str(_beta)
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    # model.fit(train_dataiter,
    #           begin_epoch=begin_epoch,
    #           num_epoch=end_epoch,
    #           eval_data=val_dataiter,
    #           eval_metric=eval_metrics,
    #           kvstore='device',
    #           optimizer=opt,
    #           # optimizer_params   = optimizer_params,
    #           initializer=initializer,
    #           arg_params=arg_params,
    #           aux_params=aux_params,
    #           allow_missing=True,
    #           batch_end_callback=_batch_callback,
    #           epoch_end_callback=epoch_cb)

    model.bind(data_shapes=train_dataiter.provide_data,
               label_shapes=train_dataiter.provide_label,
               for_training=True,
               force_rebind=False)
    model.init_params(initializer=initializer,
                      arg_params=arg_params,
                      aux_params=aux_params,
                      allow_missing=True,
                      force_init=False)
    model.init_optimizer(kvstore='device', optimizer=opt)

    if not isinstance(eval_metrics, mx.metric.EvalMetric):
        eval_metrics = mx.metric.create(eval_metrics)
    epoch_eval_metric = copy.deepcopy(eval_metrics)
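    # Two copies are kept because Speedometer auto-resets the metric it logs:
    # eval_metrics feeds the per-batch callback, while epoch_eval_metric
    # accumulates untouched for the end-of-epoch summary.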

    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, end_epoch):
        tic = time.time()
        eval_metrics.reset()
        epoch_eval_metric.reset()
        nbatch = 0
        data_iter = iter(train_dataiter)
        end_of_batch = False
        next_data_batch = next(data_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            model.forward_backward(data_batch)
            model.update()

            if isinstance(data_batch, list):
                model.update_metric(eval_metrics,
                                    [db.label for db in data_batch],
                                    pre_sliced=True)
                model.update_metric(epoch_eval_metric,
                                    [db.label for db in data_batch],
                                    pre_sliced=True)
            else:
                model.update_metric(eval_metrics, data_batch.label)
                model.update_metric(epoch_eval_metric, data_batch.label)

            try:
                # pre fetch next batch
                next_data_batch = next(data_iter)
                model.prepare(next_data_batch, sparse_row_id_fn=None)
            except StopIteration:
                end_of_batch = True

            if end_of_batch:
                eval_name_vals = epoch_eval_metric.get_name_value()

            batch_end_params = mx.model.BatchEndParam(epoch=epoch,
                                                      nbatch=nbatch,
                                                      eval_metric=eval_metrics,
                                                      locals=locals())
            _batch_callback(batch_end_params)
            nbatch += 1

        # one epoch of training is finished
        for name, val in eval_name_vals:
            model.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
        toc = time.time()
        model.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # sync aux params across devices
        arg_params, aux_params = model.get_params()
        model.set_params(arg_params, aux_params)

        train_dataiter.reset()
Example #14
0
def main(args):
  # sys.path.append("/home/gaomingda/insightface/recognition")
  from image_iter import FaceImageIter

  global image_shape
  global net

  print(args)
  ctx = []
  cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
  if len(cvd)>0:
    for i in range(len(cvd.split(','))):
      ctx.append(mx.gpu(i))
  if len(ctx)==0:
    ctx = [mx.cpu()]
    print('use cpu')
  else:
    print('gpu num:', len(ctx))
  image_shape = [int(x) for x in args.image_size.split(',')]
  vec = args.model.split(',')
  assert len(vec)>1
  prefix = vec[0]
  epoch = int(vec[1])
  print('loading',prefix, epoch)
  net = edict()
  net.ctx = ctx
  net.sym, net.arg_params, net.aux_params = mx.model.load_checkpoint(prefix, epoch)
  #net.arg_params, net.aux_params = ch_dev(net.arg_params, net.aux_params, net.ctx)
  all_layers = net.sym.get_internals()
  net.sym = all_layers['fc1_output']
  net.model = mx.mod.Module(symbol=net.sym, context=net.ctx, label_names = None)
  net.model.bind(data_shapes=[('data', (args.batch_size, 3, image_shape[1], image_shape[2]))])
  net.model.set_params(net.arg_params, net.aux_params)


  train_dataiter = FaceImageIter(
    batch_size=4,
    data_shape=(3, 112, 112),
    path_imgrec=args.input_data,
    shuffle=True,
    rand_mirror=False,
    mean=None,
    cutoff=False,
    color_jittering=0,
    images_filter=0,
  )
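  # Note: the iterator's batch_size is irrelevant below; samples are pulled one
  # at a time with next_sample() and batched manually into data_buff.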
  data_size = train_dataiter.num_samples()
  i = 0
  fstart = 0

  features_all = np.zeros((data_size, 512), dtype=np.float32)
  features_all_flip = np.zeros((data_size, 512), dtype=np.float32)

  # features_all = np.zeros((102, 512), dtype=np.float32)
  # features_all_flip = np.zeros((102, 512), dtype=np.float32)

  data_buff = nd.empty((args.batch_size, 3, 112, 112))
  count = 0
  for i in range(train_dataiter.num_samples()):
    if i%1000==0:
      print("processing ",i)
    label, s, box, landmark = train_dataiter.next_sample()
    img = train_dataiter.imdecode(s)
    img = nd.transpose(img, axes=(2, 0, 1))
    data_buff[count] = img
    count += 1
    if count==args.batch_size:
      embedding = get_feature(data_buff, args.batch_size)
      count = 0
      fend = fstart+embedding.shape[0]

      #print('writing', fstart, fend)
      features_all[fstart:fend,:] = embedding
      # flipped image
      data_buff_flip = mx.ndarray.flip(data=data_buff, axis=3)
      embedding_flipped = get_feature(data_buff_flip, args.batch_size)
      features_all_flip[fstart:fend, :] = embedding_flipped

      fstart = fend

    # if i==102:
    #   break

  if count>0:
    embedding = get_feature(data_buff, args.batch_size)
    fend = fstart+count
    print('writing', fstart, fend)
    features_all[fstart:fend,:] = embedding[:count, :]

    # flipped image
    data_buff_flip = mx.ndarray.flip(data=data_buff, axis=3)
    embedding_flipped = get_feature(data_buff_flip, args.batch_size)
    features_all_flip[fstart:fend, :] = embedding_flipped[:count, :]

  # write_bin(args.output, features_all)
  #os.system("bypy upload %s"%args.output)
  print("save features ...")
  features_all.tofile('train_features_oct200')



  print("save train_features_flip ...")
  features_all_flip.tofile('train_features_flip_oct200')
Example #15
0
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = os.path.join(args.models_root,
                          '%s-%s-%s' % (args.network, args.loss, args.dataset),
                          'model')
    prefix_dir = os.path.dirname(prefix)
    print('prefix', prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = config.image_shape[2]

    data_dir = config.dataset_path
    path_imgrec = None
    path_imglist = None
    image_size = config.image_shape[0:2]
    assert len(image_size) == 2
    assert image_size[0] == image_size[1]
    print('image_size', image_size)
    print('num_classes', config.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    print('Called with argument:', args, config)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym = get_symbol(args)
        if config.net_name == 'spherenet':
            data_shape_dict = {'data': (args.per_batch_size, ) + data_shape}
            spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            vec[0], int(vec[1]))
        sym = get_symbol(args)

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)
    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
    )
    val_dataiter = None

    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrec=path_imgrec,
        shuffle=True,
        rand_mirror=args.rand_mirror,
        mean=mean,
        cutoff=args.cutoff,
        color_jittering=args.color,
        images_filter=args.images_filter,
    )

    metric1 = AccMetric()
    eval_metrics = [mx.metric.create(metric1)]
    if args.ce_loss:
        metric2 = LossValueMetric()
        eval_metrics.append(mx.metric.create(metric2))

    if config.net_name == 'fresnet' or config.net_name == 'fmobilefacenet':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    #initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=args.lr,
                        momentum=args.mom,
                        wd=args.wd,
                        rescale_grad=_rescale)
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)

    ver_list = []
    ver_name_list = []
    for name in config.val_targets:
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model, args.batch_size, 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        #global global_step
        global_step[0] += 1
        mbatch = global_step[0]
        for step in lr_steps:
            if mbatch == step:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            is_highest = False
            if len(acc_list) > 0:
                #lfw_score = acc_list[0]
                #if lfw_score>highest_acc[0]:
                #  highest_acc[0] = lfw_score
                #  if lfw_score>=0.998:
                #    do_save = True
                score = sum(acc_list)
                if acc_list[-1] >= highest_acc[-1]:
                    if acc_list[-1] > highest_acc[-1]:
                        is_highest = True
                    else:
                        if score >= highest_acc[0]:
                            is_highest = True
                            highest_acc[0] = score
                    highest_acc[-1] = acc_list[-1]
                    #if lfw_score>=0.99:
                    #  do_save = True
            if is_highest:
                do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt == 2:
                do_save = True
            elif args.ckpt == 3:
                msave = 1
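            # args.ckpt policy: 0 = never save, 1 = save on a new best only,
            # 2 = save at every verification step, 3 = save on a new best but
            # always overwrite checkpoint slot 1.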

            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    model.fit(
        train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        eval_data=val_dataiter,
        eval_metric=eval_metrics,
        kvstore='device',
        optimizer=opt,
        #optimizer_params   = optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=epoch_cb)
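
A minimal sketch (not part of the original example) showing how a checkpoint
saved by the callback above could be reloaded for embedding extraction; the
prefix and epoch values here are placeholders:

import mxnet as mx

# mx.model.save_checkpoint wrote <prefix>-symbol.json and <prefix>-%04d.params.
prefix, epoch = 'model-r100', 1  # hypothetical values
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
fc1 = sym.get_internals()['fc1_output']  # embedding layer used throughout
model = mx.mod.Module(symbol=fc1, context=mx.cpu(), label_names=None)
model.bind(data_shapes=[('data', (1, 3, 112, 112))], for_training=False)
model.set_params(arg_params, aux_params, allow_extra=True)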
Example #16
0
def train_net(args):
    #_seed = 727
    #random.seed(_seed)
    #np.random.seed(_seed)
    #mx.random.seed(_seed)
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd)>0:
      for i in range(len(cvd.split(','))):
        ctx.append(mx.gpu(i))
    if len(ctx)==0:
      ctx = [mx.cpu()]
      print('use cpu')
    else:
      print('gpu num:', len(ctx))
    if len(args.extra_model_name)==0:
      prefix = os.path.join(args.models_root, '%s-%s-%s'%(args.network, args.loss, args.dataset), 'model')
    else:
      prefix = os.path.join(args.models_root, '%s-%s-%s-%s'%(args.network, args.loss, args.dataset, args.extra_model_name), 'model')
    prefix_dir = os.path.dirname(prefix)
    print('prefix', prefix)
    if not os.path.exists(prefix_dir):
      os.makedirs(prefix_dir)
    args.ctx_num = len(ctx)
    if args.per_batch_size==0:
      args.per_batch_size = 128
    args.batch_size = args.per_batch_size*args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = config.image_shape[2]
    config.batch_size = args.batch_size
    config.per_batch_size = args.per_batch_size
    data_dir = config.dataset_path
    path_imgrec = None
    path_imglist = None
    image_size = config.image_shape[0:2]
    assert len(image_size)==2
    assert image_size[0]==image_size[1]
    print('image_size', image_size)
    print('num_classes', config.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    data_shape = (args.image_channel,image_size[0],image_size[1])

    num_workers = config.num_workers
    global_num_ctx = num_workers * args.ctx_num
    if config.num_classes%global_num_ctx==0:
      args.ctx_num_classes = config.num_classes//global_num_ctx
    else:
      args.ctx_num_classes = config.num_classes//global_num_ctx+1
    args.local_num_classes = args.ctx_num_classes * args.ctx_num
    args.local_class_start = args.local_num_classes * args.worker_id
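    # Worked example (hypothetical numbers): num_classes=85742, num_workers=2,
    # ctx_num=4 -> global_num_ctx=8, ctx_num_classes=ceil(85742/8)=10718,
    # local_num_classes=42872, and worker 1 starts at class 42872.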

    #if len(args.partial)==0:
    #  local_classes_range = (0, args.num_classes)
    #else:
    #  _vec = args.partial.split(',')
    #  local_classes_range = (int(_vec[0]), int(_vec[1]))

    #args.partial_num_classes = local_classes_range[1] - local_classes_range[0]
    #args.partial_start = local_classes_range[0]

    print('Called with argument:', args, config)
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    arg_params = None
    aux_params = None
    if len(args.pretrained)==0:
      esym = get_symbol_embedding()
      asym = get_symbol_arcface
    else:
      assert False

    if config.count_flops:
      all_layers = esym.get_internals()
      _sym = all_layers['fc1_output']
      FLOPs = flops_counter.count_flops(_sym, data=(1,3,image_size[0],image_size[1]))
      _str = flops_counter.flops_str(FLOPs)
      print('Network FLOPs: %s'%_str)

    if config.num_workers==1:
      from parall_module_local_v1 import ParallModule
    else:
      from parall_module_dist import ParallModule

    model = ParallModule(
        context       = ctx,
        symbol        = esym,
        data_names    = ['data'],
        label_names    = ['softmax_label'],
        asymbol       = asym,
        args = args,
    )
    val_dataiter = None
    train_dataiter = FaceImageIter(
        batch_size           = args.batch_size,
        data_shape           = data_shape,
        path_imgrec          = path_imgrec,
        shuffle              = True,
        rand_mirror          = config.data_rand_mirror,
        mean                 = mean,
        cutoff               = config.data_cutoff,
        color_jittering      = config.data_color,
        images_filter        = config.data_images_filter,
    )

    if config.net_name=='fresnet' or config.net_name=='fmobilefacenet':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    else:
      initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)

    _rescale = 1.0/args.batch_size
    opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
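    # Unlike the non-parallel examples (which rescale by 1/ctx_num), this one
    # rescales by 1/batch_size; presumably ParallModule accumulates
    # unnormalized gradients over the whole global batch.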
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)


    ver_list = []
    ver_name_list = []
    for name in config.val_targets:
      path = os.path.join(data_dir,name+".bin")
      if os.path.exists(path):
        data_set = verification.load_bin(path, image_size)
        ver_list.append(data_set)
        ver_name_list.append(name)
        print('ver', name)


    def ver_test(nbatch):
      results = []
      for i in range(len(ver_list)):
        acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, 10, None, None)
        print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
        #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
        print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
        results.append(acc2)
      return results


    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in range(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)
    def _batch_callback(param):
      #global global_step
      global_step[0]+=1
      mbatch = global_step[0]
      for step in lr_steps:
        if mbatch==step:
          opt.lr *= 0.1
          print('lr change to', opt.lr)
          break

      _cb(param)
      if mbatch%1000==0:
        print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)

      if mbatch>=0 and mbatch%args.verbose==0:
        acc_list = ver_test(mbatch)
        save_step[0]+=1
        msave = save_step[0]
        do_save = False
        is_highest = False
        if len(acc_list)>0:
          #lfw_score = acc_list[0]
          #if lfw_score>highest_acc[0]:
          #  highest_acc[0] = lfw_score
          #  if lfw_score>=0.998:
          #    do_save = True
          score = sum(acc_list)
          if acc_list[-1]>=highest_acc[-1]:
            if acc_list[-1]>highest_acc[-1]:
              is_highest = True
            else:
              if score>=highest_acc[0]:
                is_highest = True
                highest_acc[0] = score
            highest_acc[-1] = acc_list[-1]
            #if lfw_score>=0.99:
            #  do_save = True
        if is_highest:
          do_save = True
        if args.ckpt==0:
          do_save = False
        elif args.ckpt==2:
          do_save = True
        elif args.ckpt==3:
          msave = 1

        if do_save:
          print('saving', msave)
          arg, aux = model.get_export_params()
          all_layers = model.symbol.get_internals()
          _sym = all_layers['fc1_output']
          mx.model.save_checkpoint(prefix, msave, _sym, arg, aux)
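          # Only the 'fc1' embedding subgraph is exported; the large parallel
          # softmax (class-weight) parameters are left out of the checkpoint.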
        print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
      if config.max_steps>0 and mbatch>config.max_steps:
        sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    model.fit(train_dataiter,
        begin_epoch        = begin_epoch,
        num_epoch          = 999999,
        eval_data          = val_dataiter,
        #eval_metric        = eval_metrics,
        kvstore            = args.kvstore,
        optimizer          = opt,
        #optimizer_params   = optimizer_params,
        initializer        = initializer,
        arg_params         = arg_params,
        aux_params         = aux_params,
        allow_missing      = True,
        batch_end_callback = _batch_callback,
        epoch_end_callback = epoch_cb )
Example #17
0
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    network, num_layers = args.network.split(',')
    print('num_layers', num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3

    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    path_imgrecs = []
    path_imglist = None
    args.num_classes = []
    for data_idx, data_dir in enumerate(data_dir_list):
        prop = face_image.load_property(data_dir)
        args.num_classes.append(prop.num_classes)
        image_size = prop.image_size
        if data_idx == 0:
            args.image_h = image_size[0]
            args.image_w = image_size[1]
        else:
            args.image_h = min(args.image_h, image_size[0])
            args.image_w = min(args.image_w, image_size[1])
        print('image_size', image_size)
        assert (args.num_classes[-1] > 0)
        print('num_classes', args.num_classes)
        path_imgrecs.append(os.path.join(data_dir, "train.rec"))

    # args.num_classes is a list here (one entry per dataset), so compare the total.
    if args.loss_type == 1 and sum(args.num_classes) > 20000:
        args.beta_freeze = 5000
        args.gamma = 0.06

    print('Called with argument:', args)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(network, int(num_layers),
                                                 args, arg_params, aux_params)
    else:
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            vec[0], int(vec[1]))
        sym, arg_params, aux_params = get_symbol(network, int(num_layers),
                                                 args, arg_params, aux_params)

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)
    ctx_group = dict(zip(['dev%d' % (i + 1) for i in range(len(ctx))], ctx))
    ctx_group['dev0'] = ctx
    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
        data_names=['data'] if args.loss_type != 6 else ['data', 'margin'],
        group2ctxs=ctx_group)
    val_dataiter = None

    from config import crop
    from config import cutout

    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrecs=path_imgrecs,
        shuffle=True,
        rand_mirror=args.rand_mirror,
        mean=mean,
        cutout=cutout,
        crop=crop,
        loss_type=args.loss_type,
        #margin_m             = args.margin_m,
        #margin_policy        = args.margin_policy,
        #max_steps            = args.max_steps,
        #data_names           = ['data', 'margin'],
        downsample_back=args.downsample_back,
        motion_blur=args.motion_blur,
    )

    _metric = AccMetric()
    #_metric = LossValueMetric()
    eval_metrics = [mx.metric.create(_metric)]

    if args.network[0] == 'r' or args.network[0] == 'y':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  #inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    _rescale = 1.0 / args.ctx_num

    if len(args.lr_steps) == 0:
        print('Error: lr_steps is not set')
        sys.exit(0)
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(lr_steps,
                                                        factor=0.1,
                                                        base_lr=base_lr)
    optimizer_params = {
        'learning_rate': base_lr,
        'momentum': base_mom,
        'wd': base_wd,
        'rescale_grad': _rescale,
        'lr_scheduler': lr_scheduler
    }
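    # MultiFactorScheduler folds the step-wise LR decay into the optimizer
    # itself, replacing the manual "if mbatch == step: opt.lr *= 0.1" logic
    # used in the other examples.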

    #opt = AdaBound()
    #opt = AdaBound(lr=base_lr, wd=base_wd, gamma = 2. / args.max_steps)
    opt = optimizer.SGD(**optimizer_params)

    som = 2000
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            _, issame_list = ver_list[i]
            if all(issame_list):
                fp_rates, fp_dict, thred_dict, recall_dict = verification.test(
                    ver_list[i],
                    model,
                    args.batch_size,
                    label_shape=(args.batch_size, len(path_imgrecs)))
                for k in fp_rates:
                    print("[%s] TPR at FPR %.2e[%.2e: %.4f]:\t%.5f" %
                          (ver_name_list[i], k, fp_dict[k], thred_dict[k],
                           recall_dict[k]))
            else:
                acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                    ver_list[i],
                    model,
                    args.batch_size,
                    10,
                    None,
                    label_shape=(args.batch_size, len(path_imgrecs)))
                print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
                #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
                print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                      (ver_name_list[i], nbatch, acc2, std2))
                results.append(acc2)
        return results

    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in range(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]

    def _batch_callback(param):
        #global global_step
        global_step[0] += 1
        mbatch = global_step[0]

        _cb(param)
        if mbatch % 10000 == 0:
            print('lr-batch-epoch:', opt.learning_rate, param.nbatch,
                  param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            if len(acc_list) > 0:
                lfw_score = acc_list[0]
                if lfw_score > highest_acc[0]:
                    highest_acc[0] = lfw_score
                    if lfw_score >= 0.998:
                        do_save = True
                if acc_list[-1] >= highest_acc[-1]:
                    highest_acc[-1] = acc_list[-1]
                    if lfw_score >= 0.99:
                        do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt > 1:
                do_save = True
            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = max(0, mbatch - args.beta_freeze)
            _beta = max(
                args.beta_min,
                args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
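        # _beta decays as beta * (1 + gamma*move)^(-power) until it floors at
        # beta_min; with, say, beta=1000, gamma=0.12, power=1, it is about
        # 1000/(1 + 0.12*100) ~= 76.9 a hundred batches past the freeze point.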
        #print('beta', _beta)
        os.environ['BETA'] = str(_beta)
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    model.fit(train_dataiter,
              begin_epoch=begin_epoch,
              num_epoch=end_epoch,
              eval_data=val_dataiter,
              eval_metric=eval_metrics,
              kvstore='device',
              optimizer=opt,
              #optimizer_params=optimizer_params,  # ignored: opt is already an Optimizer instance
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              allow_missing=True,
              batch_end_callback=_batch_callback,
              epoch_end_callback=epoch_cb)
Example #18
0
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3

    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list) == 1
    data_dir = data_dir_list[0]
    path_imgrec = None
    path_imglist = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    if args.loss_type == 1 and args.num_classes > 20000:
        args.beta_freeze = 5000
        args.gamma = 0.06

    print('Called with argument:', args)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    else:
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            vec[0], int(vec[1]))
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
        # if args.finetune:
        #     def get_fine_tune_model(symbol, arg_params, num_classes, layer_name='flatten0'):
        #         """
        #         symbol: the pretrained network symbol
        #         arg_params: the argument parameters of the pretrained model
        #         num_classes: the number of classes for the fine-tune datasets
        #         layer_name: the layer name before the last fully-connected layer
        #         """
        #         all_layers = symbol.get_internals()
        #         # print(all_layers);exit(0)
        #         for k in arg_params:
        #             if k.startswith('fc'):
        #               print(k)
        #         exit(0)
        #         net = all_layers[layer_name + '_output']
        #         net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc1')
        #         net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
        #         new_args = dict({k: arg_params[k] for k in arg_params if 'fc1' not in k})
        #         return (net, new_args)
        #     sym, arg_params = get_fine_tune_model(sym, arg_params, args.num_classes)

    if args.network[0] == 's':
        data_shape_dict = {'data': (args.per_batch_size, ) + data_shape}
        spherenet.init_weights(sym, data_shape_dict, args.num_layers)

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)
    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
    )
    val_dataiter = None

    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrec=path_imgrec,
        shuffle=True,
        rand_mirror=args.rand_mirror,
        mean=mean,
        cutoff=args.cutoff,
    )

    if args.loss_type < 10:
        _metric = AccMetric()
    else:
        _metric = LossValueMetric()
    eval_metrics = [mx.metric.create(_metric)]

    if args.network[0] == 'r' or args.network[0] == 'y':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  #inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=base_lr,
                        momentum=base_mom,
                        wd=base_wd,
                        rescale_grad=_rescale)
    som = 20
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model, args.batch_size, 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in range(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [40000, 60000, 80000]
        if args.loss_type >= 1 and args.loss_type <= 7:
            lr_steps = [100000, 140000, 160000]
        p = 512.0 / args.batch_size
        for l in range(len(lr_steps)):
            lr_steps[l] = int(lr_steps[l] * p)
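        # Linear-scaling heuristic: the default schedules assume a global batch
        # of 512, so e.g. batch_size=256 gives p=2 and doubles every step.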
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        #global global_step
        global_step[0] += 1
        mbatch = global_step[0]
        for _lr in lr_steps:
            if mbatch == args.beta_freeze + _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            if len(acc_list) > 0:
                lfw_score = acc_list[0]
                if lfw_score > highest_acc[0]:
                    highest_acc[0] = lfw_score
                    if lfw_score >= 0.998:
                        do_save = True
                if acc_list[-1] >= highest_acc[-1]:
                    highest_acc[-1] = acc_list[-1]
                    if lfw_score >= 0.99:
                        do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt > 1:
                do_save = True
            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = max(0, mbatch - args.beta_freeze)
            _beta = max(
                args.beta_min,
                args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
        #print('beta', _beta)
        os.environ['BETA'] = str(_beta)
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    model.fit(
        train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        eval_data=val_dataiter,
        eval_metric=eval_metrics,
        kvstore='device',
        optimizer=opt,
        #optimizer_params   = optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=epoch_cb)
Example #19
0
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
        if args.loss_type == 10:
            args.per_batch_size = 256
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3
    ppatch = [int(x) for x in args.patch.split('_')]
    assert len(ppatch) == 5

    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    if args.loss_type != 12 and args.loss_type != 13:
        assert len(data_dir_list) == 1
    data_dir = data_dir_list[0]
    args.use_val = False
    path_imgrec = None
    path_imglist = None
    val_rec = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)

    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)
    args.coco_scale = 0.5 * math.log(float(args.num_classes - 1)) + 3
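    # e.g. num_classes = 10001 -> coco_scale = 0.5*ln(10000) + 3 ~= 7.61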

    # path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2"
    path_imgrec = os.path.join(data_dir, "train.rec")
    val_rec = os.path.join(data_dir, "val.rec")
    if os.path.exists(val_rec) and args.loss_type < 10:
        args.use_val = True
    else:
        val_rec = None
    # args.use_val = False

    if args.loss_type == 1 and args.num_classes > 20000:
        args.beta_freeze = 5000
        args.gamma = 0.06

    if args.loss_type < 9:
        assert args.images_per_identity == 0
    else:
        if args.images_per_identity == 0:
            if args.loss_type == 11:
                args.images_per_identity = 2
            elif args.loss_type == 10 or args.loss_type == 9:
                args.images_per_identity = 16
            elif args.loss_type == 12 or args.loss_type == 13:
                args.images_per_identity = 5
                assert args.per_batch_size % 3 == 0
        assert args.images_per_identity >= 2
        args.per_identities = int(args.per_batch_size / args.images_per_identity)

    print('Called with argument:', args)

    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    else:
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    if args.network[0] == 's':
        data_shape_dict = {'data': (args.per_batch_size,) + data_shape}
        spherenet.init_weights(sym, data_shape_dict, args.num_layers)

    data_extra = None
    hard_mining = False
    triplet_params = None
    coco_mode = False
    if args.loss_type == 10:
        hard_mining = True
        _shape = (args.batch_size, args.per_batch_size)
        data_extra = np.full(_shape, -1.0, dtype=np.float32)
        c = 0
        while c < args.batch_size:
            a = 0
            while a < args.per_batch_size:
                b = a + args.images_per_identity
                data_extra[(c + a):(c + b), a:b] = 1.0
                # print(c+a, c+b, a, b)
                a = b
            c += args.per_batch_size
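        # data_extra is a +-1 identity mask: entries are 1.0 where two samples
        # share an identity (images_per_identity-sized blocks along the
        # diagonal of each per-device batch) and -1.0 everywhere else.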
    elif args.loss_type == 11:
        data_extra = np.zeros((args.batch_size, args.per_identities), dtype=np.float32)
        c = 0
        while c < args.batch_size:
            for i in range(args.per_identities):
                data_extra[c + i][i] = 1.0
            c += args.per_batch_size
    elif args.loss_type == 12 or args.loss_type == 13:
        triplet_params = [args.triplet_bag_size, args.triplet_alpha, args.triplet_max_ap]
    elif args.loss_type == 9:
        coco_mode = True

    label_name = 'softmax_label'
    label_shape = (args.batch_size,)
    if args.output_c2c:
        label_shape = (args.batch_size, 2)
    if data_extra is None:
        model = mx.mod.Module(
            context=ctx,
            symbol=sym,
        )
    else:
        data_names = ('data', 'extra')
        # label_name = ''
        model = mx.mod.Module(
            context=ctx,
            symbol=sym,
            data_names=data_names,
            label_names=(label_name,),
        )

    if args.use_val:
        val_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=val_rec,
            # path_imglist         = val_path,
            shuffle=False,
            rand_mirror=False,
            mean=mean,
            ctx_num=args.ctx_num,
            data_extra=data_extra,
        )
    else:
        val_dataiter = None

    if len(data_dir_list) == 1 and args.loss_type != 12 and args.loss_type != 13:
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=args.rand_mirror,
            mean=mean,
            cutoff=args.cutoff,
            c2c_threshold=args.c2c_threshold,
            output_c2c=args.output_c2c,
            c2c_mode=args.c2c_mode,
            limit=args.train_limit,
            ctx_num=args.ctx_num,
            images_per_identity=args.images_per_identity,
            data_extra=data_extra,
            hard_mining=hard_mining,
            triplet_params=triplet_params,
            coco_mode=coco_mode,
            mx_model=model,
            label_name=label_name,
        )
    else:
        iter_list = []
        for _data_dir in data_dir_list:
            _path_imgrec = os.path.join(_data_dir, "train.rec")
            _dataiter = FaceImageIter(
                batch_size=args.batch_size,
                data_shape=data_shape,
                path_imgrec=_path_imgrec,
                shuffle=True,
                rand_mirror=args.rand_mirror,
                mean=mean,
                cutoff=args.cutoff,
                c2c_threshold=args.c2c_threshold,
                output_c2c=args.output_c2c,
                c2c_mode=args.c2c_mode,
                limit=args.train_limit,
                ctx_num=args.ctx_num,
                images_per_identity=args.images_per_identity,
                data_extra=data_extra,
                hard_mining=hard_mining,
                triplet_params=triplet_params,
                coco_mode=coco_mode,
                mx_model=model,
                label_name=label_name,
            )
            iter_list.append(_dataiter)
        train_dataiter = FaceImageIterList(iter_list)

    if args.loss_type < 10:
        _metric = AccMetric()
    else:
        _metric = LossValueMetric()
    eval_metrics = [mx.metric.create(_metric)]

    if args.network[0] == 'r':
        initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2)  # resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)  # inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
    _rescale = 1.0 / args.ctx_num
    if args.noise_sgd > 0.0:
        print('use noise sgd')
        opt = NoiseSGD(scale=args.noise_sgd, learning_rate=base_lr, momentum=base_mom, wd=base_wd,
                       rescale_grad=_rescale)
    else:
        opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
    som = 20
    if args.loss_type == 12 or args.loss_type == 13:
        som = 2
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, 10,
                                                                               data_extra, label_shape)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            # print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    def val_test():
        acc = AccMetric()
        val_metric = mx.metric.create(acc)
        val_metric.reset()
        val_dataiter.reset()
        for i, eval_batch in enumerate(val_dataiter):
            model.forward(eval_batch, is_train=False)
            model.update_metric(val_metric, eval_batch.label)
        acc_value = val_metric.get_name_value()[0][1]
        print('VACC: %f' % (acc_value))

    highest_acc = [0.0, 0.0]  # lfw and target
    # for i in range(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [40000, 60000, 80000]
        if args.loss_type >= 1 and args.loss_type <= 7:
            lr_steps = [100000, 140000, 160000]
        p = 512.0 / args.batch_size
        for l in range(len(lr_steps)):
            lr_steps[l] = int(lr_steps[l] * p)
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        # global global_step
        global_step[0] += 1
        mbatch = global_step[0]
        for _lr in lr_steps:
            if mbatch == args.beta_freeze + _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            if len(acc_list) > 0:
                lfw_score = acc_list[0]
                if lfw_score > highest_acc[0]:
                    highest_acc[0] = lfw_score
                    if lfw_score >= 0.998:
                        do_save = True
                if acc_list[-1] >= highest_acc[-1]:
                    highest_acc[-1] = acc_list[-1]
                    if lfw_score >= 0.99:
                        do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt > 1:
                do_save = True
            # for i in range(len(acc_list)):
            #  acc = acc_list[i]
            #  if acc>=highest_acc[i]:
            #    highest_acc[i] = acc
            #    if lfw_score>=0.99:
            #      do_save = True
            # if args.loss_type==1 and mbatch>lr_steps[-1] and mbatch%10000==0:
            #  do_save = True
            if do_save:
                print('saving', msave)
                if val_dataiter is not None:
                    val_test()
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
                # if acc>=highest_acc[0]:
                #  lfw_npy = "%s-lfw-%04d" % (prefix, msave)
                #  X = np.concatenate(embeddings_list, axis=0)
                #  print('saving lfw npy', X.shape)
                #  np.save(lfw_npy, X)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = max(0, mbatch - args.beta_freeze)
            _beta = max(args.beta_min, args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
        # print('beta', _beta)
        os.environ['BETA'] = str(_beta)
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    # epoch_cb = mx.callback.do_checkpoint(prefix, 1)
    epoch_cb = None

    # def _epoch_callback(epoch, sym, arg_params, aux_params):
    #  print('epoch-end', epoch)

    model.fit(train_dataiter,
              begin_epoch=begin_epoch,
              num_epoch=end_epoch,
              eval_data=val_dataiter,
              eval_metric=eval_metrics,
              kvstore='device',
              optimizer=opt,
              # optimizer_params   = optimizer_params,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              allow_missing=True,
              batch_end_callback=_batch_callback,
              epoch_end_callback=epoch_cb)
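
NoiseSGD in Example #19 is project-specific and its definition is not shown
here. A minimal sketch of one plausible implementation, assuming it simply
perturbs each gradient with zero-mean Gaussian noise before the standard SGD
update (the class name and scale parameter are placeholders):

import mxnet as mx
from mxnet import optimizer

class NoiseSGDSketch(optimizer.SGD):
    """SGD that adds zero-mean Gaussian noise to every gradient."""
    def __init__(self, scale=0.01, **kwargs):
        super(NoiseSGDSketch, self).__init__(**kwargs)
        self.scale = scale

    def update(self, index, weight, grad, state):
        # Perturb the gradient, then delegate to the normal SGD update.
        noise = mx.nd.random.normal(0, self.scale, shape=grad.shape,
                                    ctx=grad.context)
        super(NoiseSGDSketch, self).update(index, weight, grad + noise, state)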