Exemple #1
0
def main(args):
    """Train the network built by ``sym_heatmap.get_symbol`` on the CPU.

    Seeds the Python, NumPy and MXNet random generators, builds a
    ``FaceSegIter`` over ``train.rec`` in ``config.dataset_path``, creates
    (or loads) the parameters, and runs ``mx.mod.Module.fit`` with a batch
    callback that decays the learning rate, periodically validates, and
    optionally checkpoints.

    Args:
        args: Option namespace. Attributes read here: ``per_batch_size``,
            ``exf``, ``pretrained`` (``"prefix,epoch"`` or empty),
            ``optimizer``, ``lr``, ``wd``, ``frequent``, ``lr_step``
            (comma-separated batch counts), ``verbose``, ``ckpt`` and
            ``prefix``. ``ctx_num`` and ``batch_size`` are written onto it.

    Note:
        The batch callback calls ``sys.exit(0)`` once the global batch count
        reaches the last entry of ``args.lr_step``, so training normally ends
        by exiting the process rather than by returning.
    """
    # Fixed seed shared by all three random number generators.
    _seed = 727
    random.seed(_seed)
    np.random.seed(_seed)
    mx.random.seed(_seed)
    ctx = []
    # GPU discovery via CUDA_VISIBLE_DEVICES is disabled below; the context is
    # forced to a single CPU device instead.
    #   cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    #   if len(cvd)>0:
    #     for i in range(len(cvd.split(','))):
    #       ctx.append(mx.gpu(i))
    #   if len(ctx)==0:
    #     ctx = [mx.cpu()]
    #     print('use cpu')
    #   else:
    #     print('gpu num:', len(ctx))
    ctx = [mx.cpu()]
    args.ctx_num = len(ctx)

    # Total batch size is the per-device batch size times the device count.
    args.batch_size = args.per_batch_size * args.ctx_num
    # Mirror the per-device batch size into the shared config object.
    config.per_batch_size = args.per_batch_size

    print('Call with', args, config)
    train_iter = FaceSegIter(
        path_imgrec=os.path.join(config.dataset_path, 'train.rec'),
        batch_size=args.batch_size,
        per_batch_size=args.per_batch_size,
        aug_level=1,
        exf=args.exf,
        args=args,
    )

    data_shape = train_iter.get_data_shape()
    #label_shape = train_iter.get_label_shape()
    sym = sym_heatmap.get_symbol(num_classes=config.num_classes)
    if len(args.pretrained) == 0:
        # No checkpoint given: create fresh parameters from the iterator's
        # shape dictionary.
        #data_shape_dict = {'data' : (args.per_batch_size,)+data_shape, 'softmax_label' : (args.per_batch_size,)+label_shape}
        data_shape_dict = train_iter.get_shape_dict()
        arg_params, aux_params = sym_heatmap.init_weights(sym, data_shape_dict)
    else:
        # args.pretrained is "prefix,epoch"; the symbol stored in the
        # checkpoint is discarded and the freshly built ``sym`` is kept.
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            vec[0], int(vec[1]))
        #sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)

    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
        label_names=train_iter.get_label_names(),
    )
    #lr = 1.0e-3
    #lr = 2.5e-4
    # Gradients are rescaled by the number of devices, not by the batch size.
    _rescale_grad = 1.0 / args.ctx_num
    #_rescale_grad = 1.0/args.batch_size
    #lr = args.lr
    #opt = optimizer.Nadam(learning_rate=args.lr, wd=args.wd, rescale_grad=_rescale_grad, clip_gradient=5.0)
    # Optimizer selection; any unrecognised name falls through to SGD.
    # Only the 'onadam' and SGD branches pass weight decay.
    if args.optimizer == 'onadam':
        opt = ONadam(learning_rate=args.lr,
                     wd=args.wd,
                     rescale_grad=_rescale_grad,
                     clip_gradient=5.0)
    elif args.optimizer == 'nadam':
        opt = optimizer.Nadam(learning_rate=args.lr,
                              rescale_grad=_rescale_grad)
    elif args.optimizer == 'rmsprop':
        opt = optimizer.RMSProp(learning_rate=args.lr,
                                rescale_grad=_rescale_grad)
    elif args.optimizer == 'adam':
        opt = optimizer.Adam(learning_rate=args.lr, rescale_grad=_rescale_grad)
    else:
        opt = optimizer.SGD(learning_rate=args.lr,
                            momentum=0.9,
                            wd=args.wd,
                            rescale_grad=_rescale_grad)
    initializer = mx.init.Xavier(rnd_type='gaussian',
                                 factor_type="in",
                                 magnitude=2)
    # Speedometer is invoked from the batch callback below.
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)
    _metric = LossValueMetric()
    #_metric = NMEMetric()
    #_metric2 = AccMetric()
    #eval_metrics = [_metric, _metric2]
    eval_metrics = [_metric]
    # Global batch counts at which the learning rate is decayed.
    lr_steps = [int(x) for x in args.lr_step.split(',')]
    print('lr-steps', lr_steps)
    # One-element list so the nested callback can mutate the counter.
    global_step = [0]

    def val_test():
        """Print the NME for every validation target whose .rec file exists."""
        all_layers = sym.get_internals()
        vsym = all_layers['heatmap_output']
        vmodel = mx.mod.Module(symbol=vsym, context=ctx, label_names=None)
        #model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
        vmodel.bind(data_shapes=[('data', (args.batch_size, ) + data_shape)])
        arg_params, aux_params = model.get_params()
        vmodel.set_params(arg_params, aux_params)
        # NOTE(review): vmodel is bound and given the current parameters, but
        # the loop below runs forward/update_metric on the training ``model``
        # instead -- confirm this is intended.
        for target in config.val_targets:
            _file = os.path.join(config.dataset_path, '%s.rec' % target)
            if not os.path.exists(_file):
                # Targets without a record file are silently skipped.
                continue
            val_iter = FaceSegIter(
                path_imgrec=_file,
                batch_size=args.batch_size,
                #batch_size = 4,
                aug_level=0,
                args=args,
            )
            _metric = NMEMetric()
            val_metric = mx.metric.create(_metric)
            val_metric.reset()
            val_iter.reset()
            for i, eval_batch in enumerate(val_iter):
                #print(eval_batch.data[0].shape, eval_batch.label[0].shape)
                # Forward a data-only batch; labels go only to the metric.
                batch_data = mx.io.DataBatch(eval_batch.data)
                model.forward(batch_data, is_train=False)
                model.update_metric(val_metric, eval_batch.label)
            nme_value = val_metric.get_name_value()[0][1]
            print('[%d][%s]NME: %f' % (global_step[0], target, nme_value))

    def _batch_callback(param):
        """Per-batch hook: log speed, decay lr, validate, checkpoint, stop."""
        _cb(param)
        global_step[0] += 1
        mbatch = global_step[0]
        # Multiply the learning rate by 0.2 when a scheduled step is reached.
        for _lr in lr_steps:
            if mbatch == _lr:
                opt.lr *= 0.2
                print('lr change to', opt.lr)
                break
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)
        # Validate every args.verbose batches; with ckpt == 1 also save a
        # checkpoint numbered by how many validation rounds have run.
        if mbatch > 0 and mbatch % args.verbose == 0:
            val_test()
            if args.ckpt == 1:
                msave = mbatch // args.verbose
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(args.prefix, msave, model.symbol, arg,
                                         aux)
        # The last lr step ends training: with ckpt == 2 a single final
        # checkpoint (epoch number 1) is written first, then the process exits.
        if mbatch == lr_steps[-1]:
            if args.ckpt == 2:
                #msave = mbatch//args.verbose
                msave = 1
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(args.prefix, msave, model.symbol, arg,
                                         aux)
            sys.exit(0)

    train_iter = mx.io.PrefetchingIter(train_iter)

    # num_epoch is effectively unbounded; the callback above decides when to
    # stop.
    model.fit(
        train_iter,
        begin_epoch=0,
        num_epoch=9999,
        #eval_data          = val_iter,
        eval_data=None,
        eval_metric=eval_metrics,
        kvstore='device',
        optimizer=opt,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=None,
    )
Exemple #2
0
def main(args):
    """Train the network built by ``hg.get_symbol`` on the visible devices.

    Seeds the Python, NumPy and MXNet random generators, creates one GPU
    context per entry of ``CUDA_VISIBLE_DEVICES`` (CPU when the variable is
    unset or empty), builds a ``FaceSegIter`` over ``train.rec`` in
    ``args.data_dir`` and runs ``mx.mod.Module.fit`` with a batch callback
    that decays the learning rate, periodically validates, and optionally
    checkpoints.

    Args:
        args: Option namespace. Attributes read here: ``per_batch_size``,
            ``data_dir``, ``use_coherent``, ``num_classes``, ``binarize``,
            ``output_label_size``, ``input_img_size``, ``use_dla``,
            ``use_N``, ``use_DCN``, ``pretrained`` (``"prefix,epoch"`` or
            empty), ``lr``, ``wd``, ``lr_steps`` (comma-separated batch
            counts or empty), ``verbose``, ``ckpt``, ``prefix`` and
            ``end_epoch``. ``ctx_num`` and ``batch_size`` are written onto
            it.
    """
    _seed = 727
    random.seed(_seed)
    np.random.seed(_seed)
    mx.random.seed(_seed)

    # An unset CUDA_VISIBLE_DEVICES is treated like an empty one so that the
    # CPU fallback below is reached instead of a KeyError.
    cvd = os.environ.get('CUDA_VISIBLE_DEVICES', '').strip()
    ctx = [mx.gpu(i) for i in range(len(cvd.split(',')))] if cvd else []
    if not ctx:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    args.ctx_num = len(ctx)

    # Total batch size is the per-device batch size times the device count.
    args.batch_size = args.per_batch_size * args.ctx_num

    print('Call with', args)
    train_iter = FaceSegIter(path_imgrec=os.path.join(args.data_dir, 'train.rec'),
                             batch_size=args.batch_size,
                             per_batch_size=args.per_batch_size,
                             aug_level=1,
                             use_coherent=args.use_coherent,
                             args=args,
                             )
    # Validation sets; each is looked up as <data_dir>/<name>.rec.
    targets = ['ibug', 'cofw_testset', '300W', 'AFLW2000-3D']

    data_shape = train_iter.get_data_shape()
    sym = hg.get_symbol(num_classes=args.num_classes, binarize=args.binarize, label_size=args.output_label_size,
                        input_size=args.input_img_size, use_coherent=args.use_coherent, use_dla=args.use_dla,
                        use_N=args.use_N, use_DCN=args.use_DCN, per_batch_size=args.per_batch_size)
    if len(args.pretrained) == 0:
        # No checkpoint given: create fresh parameters from the iterator's
        # shape dictionary.
        data_shape_dict = train_iter.get_shape_dict()
        arg_params, aux_params = hg.init_weights(sym, data_shape_dict)
    else:
        # args.pretrained is "prefix,epoch"; the symbol stored in the
        # checkpoint is discarded and the freshly built ``sym`` is kept.
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))

    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
        label_names=train_iter.get_label_names(),
    )
    lr = args.lr
    # Gradients are rescaled by the number of devices, not by the batch size.
    _rescale_grad = 1.0 / args.ctx_num
    opt = optimizer.Nadam(learning_rate=lr, wd=args.wd, rescale_grad=_rescale_grad, clip_gradient=5.0)
    initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)
    _cb = mx.callback.Speedometer(args.batch_size, 10)
    eval_metrics = [LossValueMetric()]

    if len(args.lr_steps) == 0:
        lr_steps = [16000, 24000, 30000]
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    # The step counts are expressed for a batch size of 40 and rescaled to the
    # actual batch size. A float ratio is used because integer division
    # collapses to 0 for batch sizes above 40, which would move every step to
    # batch 0 and disable the decay entirely.
    step_scale = 40.0 / args.batch_size
    lr_steps = [int(step * step_scale) for step in lr_steps]
    print('lr-steps', lr_steps)
    # One-element list so the nested callback can mutate the counter.
    global_step = [0]

    def val_test():
        """Print the NME for every validation target whose .rec file exists."""
        all_layers = sym.get_internals()
        vsym = all_layers['heatmap_output']
        vmodel = mx.mod.Module(symbol=vsym, context=ctx, label_names=None)
        vmodel.bind(data_shapes=[('data', (args.batch_size,) + data_shape)])
        arg_params, aux_params = model.get_params()
        vmodel.set_params(arg_params, aux_params)
        # NOTE(review): vmodel is bound and given the current parameters, but
        # the loop below runs forward/update_metric on the training ``model``
        # instead -- confirm this is intended.
        for target in targets:
            _file = os.path.join(args.data_dir, '%s.rec' % target)
            if not os.path.exists(_file):
                # Targets without a record file are silently skipped.
                continue
            val_iter = FaceSegIter(path_imgrec=_file,
                                   batch_size=args.batch_size,
                                   aug_level=0,
                                   args=args,
                                   )
            val_metric = mx.metric.create(NMEMetric2())
            val_metric.reset()
            val_iter.reset()
            for eval_batch in val_iter:
                # Forward a data-only batch; labels go only to the metric.
                batch_data = mx.io.DataBatch(eval_batch.data)
                model.forward(batch_data, is_train=False)
                model.update_metric(val_metric, eval_batch.label)
            nme_value = val_metric.get_name_value()[0][1]
            print('[%d][%s]NME: %f' % (global_step[0], target, nme_value))

    def _batch_callback(param):
        """Per-batch hook: log speed, decay lr, validate and checkpoint."""
        _cb(param)
        global_step[0] += 1
        mbatch = global_step[0]
        # Multiply the learning rate by 0.2 when a scheduled step is reached.
        for _lr in lr_steps:
            if mbatch == _lr:
                opt.lr *= 0.2
                print('lr change to', opt.lr)
                break
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)
        # Validate every args.verbose batches; with ckpt > 0 also save a
        # checkpoint numbered by how many validation rounds have run.
        if mbatch > 0 and mbatch % args.verbose == 0:
            val_test()
            if args.ckpt > 0:
                msave = mbatch // args.verbose
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(args.prefix, msave, model.symbol, arg, aux)

    model.fit(train_iter,
              begin_epoch=0,
              num_epoch=args.end_epoch,
              eval_data=None,
              eval_metric=eval_metrics,
              kvstore='device',
              optimizer=opt,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              allow_missing=True,
              batch_end_callback=_batch_callback,
              epoch_end_callback=None,
              )
Exemple #3
0
def main(args):
    """Train the network named by ``config.network`` on GPU 0.

    Seeds the Python, NumPy and MXNet random generators, builds a
    ``FaceSegIter`` over ``train.rec`` in ``config.dataset_path`` and runs
    ``mx.mod.Module.fit`` with a batch callback that decays the learning
    rate at epoch boundaries, periodically validates, and checkpoints when
    a validation score improves.

    Args:
        args: Option namespace. Attributes read here: ``per_batch_size``,
            ``exf``, ``pretrained`` (``"prefix,epoch"`` or empty),
            ``optimizer``, ``lr``, ``wd``, ``frequent``, ``lr_epoch_step``
            (comma-separated epoch numbers), ``verbose``, ``ckpt`` and
            ``prefix``. ``ctx_num`` and ``batch_size`` are written onto it.

    Note:
        The batch callback calls ``sys.exit(0)`` once the computed epoch
        reaches the last entry of ``args.lr_epoch_step``, so training
        normally ends by exiting the process rather than by returning.
    """
    _seed = 727
    random.seed(_seed)
    np.random.seed(_seed)
    mx.random.seed(_seed)

    # NOTE(review): this unconditionally overrides whatever the caller
    # exported, so exactly one GPU context is always created and the CPU
    # fallback below cannot trigger -- confirm the hard-coded device is wanted.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    ctx = [mx.gpu(i) for i in range(len(cvd.split(',')))] if cvd else []
    if not ctx:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    args.ctx_num = len(ctx)

    # Total batch size is the per-device batch size times the device count.
    args.batch_size = args.per_batch_size * args.ctx_num
    # Mirror the per-device batch size into the shared config object.
    config.per_batch_size = args.per_batch_size

    print('Call with', args, config)
    train_iter = FaceSegIter(
        path_imgrec=os.path.join(config.dataset_path, 'train.rec'),
        batch_size=args.batch_size,
        per_batch_size=args.per_batch_size,
        aug_level=1,
        exf=args.exf,
        args=args,
    )

    # data_size is used below to convert batch counts into epochs.
    data_shape, data_size = train_iter.get_data_shape()
    # NOTE(review): eval() resolves the network module from a config string;
    # this executes arbitrary expressions, so config.network must never come
    # from untrusted input.
    sym = eval(config.network).get_symbol(num_classes=config.num_classes)
    if len(args.pretrained) == 0:
        # No checkpoint given: create fresh parameters from the iterator's
        # shape dictionary.
        data_shape_dict = train_iter.get_shape_dict()
        arg_params, aux_params = init_weights(sym, data_shape_dict)
    else:
        # args.pretrained is "prefix,epoch"; the symbol stored in the
        # checkpoint is discarded and the freshly built ``sym`` is kept.
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            vec[0], int(vec[1]))

    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
        label_names=train_iter.get_label_names(),
    )
    # Gradients are rescaled by the number of devices, not by the batch size.
    _rescale_grad = 1.0 / args.ctx_num
    # Optimizer selection; any unrecognised name falls through to SGD.
    # Only the 'onadam' and SGD branches pass weight decay.
    if args.optimizer == 'onadam':
        opt = ONadam(learning_rate=args.lr,
                     wd=args.wd,
                     rescale_grad=_rescale_grad,
                     clip_gradient=5.0)
    elif args.optimizer == 'nadam':
        opt = optimizer.Nadam(learning_rate=args.lr,
                              rescale_grad=_rescale_grad)
    elif args.optimizer == 'rmsprop':
        opt = optimizer.RMSProp(learning_rate=args.lr,
                                rescale_grad=_rescale_grad)
    elif args.optimizer == 'adam':
        opt = optimizer.Adam(learning_rate=args.lr, rescale_grad=_rescale_grad)
    else:
        opt = optimizer.SGD(learning_rate=args.lr,
                            momentum=0.9,
                            wd=args.wd,
                            rescale_grad=_rescale_grad)
    initializer = mx.init.Xavier(rnd_type='gaussian',
                                 factor_type="in",
                                 magnitude=2)
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)
    eval_metrics = [LossValueMetric()]
    # All entries but the last trigger a decay; the last one ends training.
    lr_epoch_steps = [int(x) for x in args.lr_epoch_step.split(',')]
    print('lr-epoch-steps', lr_epoch_steps)

    # One-element/two-element lists so the nested callback can mutate them.
    global_step = [0]
    # Best (lowest) validation values seen so far: [first target, mean].
    highest_acc = [1.0, 1.0]

    def _batch_callback(param):
        """Per-batch hook: log speed, decay lr, validate, checkpoint, stop."""
        _cb(param)
        global_step[0] += 1
        mbatch = global_step[0]
        # Epoch index and position inside the epoch, derived from the number
        # of samples consumed so far.
        mepoch = mbatch * args.batch_size // data_size
        pre = mbatch * args.batch_size % data_size
        is_highest = False
        # ``pre < batch_size`` holds only for the first batch of an epoch, so
        # the decay fires once per scheduled epoch.
        for _lr in lr_epoch_steps[0:-1]:
            if mepoch == _lr and pre < args.batch_size:
                opt.lr *= 0.2
                print('lr change to', opt.lr)
                break
        if mbatch % 1000 == 0:
            print('lr:', opt.lr, 'batch:', param.nbatch, 'epoch:', param.epoch)
        if mbatch > 0 and mbatch % args.verbose == 0:
            acc_list = val_test(sym, model, ctx, data_shape, global_step)
            score = np.mean(acc_list)
            # Lower is better for both tracked values.
            if acc_list[0] < highest_acc[0]:  # ibug
                is_highest = True
                highest_acc[0] = acc_list[0]
            if score < highest_acc[1]:  # mean
                is_highest = True
                highest_acc[1] = score
            if args.ckpt == 1 and is_highest:
                msave = mbatch // args.verbose
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(args.prefix, msave, model.symbol, arg,
                                         aux)
        # Reaching the final scheduled epoch ends training; with ckpt == 1 a
        # last validation pass and checkpoint are done first.
        if mepoch == lr_epoch_steps[-1]:
            if args.ckpt == 1:
                acc_list = val_test(sym, model, ctx, data_shape, global_step)
                msave = mbatch // args.verbose
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(args.prefix, msave, model.symbol, arg,
                                         aux)
            sys.exit(0)

    train_iter = mx.io.PrefetchingIter(train_iter)

    # num_epoch is effectively unbounded; the callback above decides when to
    # stop.
    model.fit(
        train_iter,
        begin_epoch=0,
        num_epoch=9999,
        eval_data=None,
        eval_metric=eval_metrics,
        kvstore='device',
        optimizer=opt,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=None,
    )
Exemple #4
0
def train_net(args):
    """Train the face-attribute network described by ``args``.

    Creates one GPU context per entry of ``CUDA_VISIBLE_DEVICES`` (CPU when
    the variable is unset or empty), reads the dataset properties from
    ``args.data_dir``, builds the symbol through ``get_symbol`` and runs
    ``mx.mod.Module.fit`` with a batch callback that decays the learning
    rate, periodically runs validation/verification, and checkpoints.

    Args:
        args: Option namespace. Attributes read here: ``prefix``,
            ``end_epoch``, ``network`` (a letter followed by the layer
            count), ``per_batch_size`` (0 selects 128), ``data_dir`` (exactly
            one directory), ``lr``, ``wd``, ``pretrained``
            (``"prefix,epoch"`` or empty), ``rand_mirror``, ``cutoff``,
            ``target`` (comma-separated verification set names),
            ``lr_steps`` (comma-separated batch counts or empty),
            ``loss_type``, ``verbose``, ``ckpt`` and ``max_steps``.
            ``ctx_num``, ``num_layers``, ``batch_size``,
            ``rescale_threshold``, ``image_channel``, ``num_classes``,
            ``image_h`` and ``image_w`` are written onto it.

    Note:
        When ``args.max_steps`` is positive the batch callback calls
        ``sys.exit(0)`` after that many batches.
    """
    # An unset CUDA_VISIBLE_DEVICES is treated like an empty one so that the
    # CPU fallback below is reached instead of a KeyError.
    cvd = os.environ.get('CUDA_VISIBLE_DEVICES', '').strip()
    ctx = [mx.gpu(i) for i in range(len(cvd.split(',')))] if cvd else []
    if not ctx:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    # dirname() is '' for a prefix without a directory part; makedirs('')
    # would raise, and the current directory already exists.
    if prefix_dir and not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    # args.network is one letter plus the depth, e.g. the part after the
    # first character must parse as an integer.
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    # Total batch size is the per-device batch size times the device count.
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3

    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list) == 1
    data_dir = data_dir_list[0]
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    print('Called with argument:', args)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    # Parameters start empty unless a "prefix,epoch" checkpoint is given; in
    # both cases get_symbol builds the symbol and may adjust the parameters.
    arg_params = None
    aux_params = None
    if len(args.pretrained) > 0:
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            vec[0], int(vec[1]))
    sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    if args.network[0] == 's':
        data_shape_dict = {'data': (args.per_batch_size, ) + data_shape}
        # NOTE(review): the return value of init_weights is discarded here --
        # confirm it is not meant to replace arg_params/aux_params.
        spherenet.init_weights(sym, data_shape_dict, args.num_layers)

    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
    )

    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrec=path_imgrec,
        shuffle=True,
        rand_mirror=args.rand_mirror,
        mean=mean,
        cutoff=args.cutoff,
    )
    # The validation iterator exists only when val.rec is present.
    val_rec = os.path.join(data_dir, "val.rec")
    val_iter = None
    if os.path.exists(val_rec):
        val_iter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=val_rec,
            shuffle=False,
            rand_mirror=False,
            mean=mean,
        )

    # Training metrics depend on which heads are enabled by the module-level
    # USE_FR / USE_GENDER / USE_AGE flags; the gender metric reads prediction
    # index 2 when the FR head is also present and index 1 otherwise.
    eval_metrics = []
    if USE_FR:
        eval_metrics.append(AccMetric(pred_idx=1))
        if USE_GENDER:
            eval_metrics.append(AccMetric(pred_idx=2, name='gender'))
    elif USE_GENDER:
        eval_metrics.append(AccMetric(pred_idx=1, name='gender'))
    if USE_AGE:
        eval_metrics.append(MAEMetric())
        eval_metrics.append(CUMMetric())

    if args.network[0] == 'r':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  #inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    # Gradients are rescaled by the number of devices, not by the batch size.
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.Nadam(learning_rate=base_lr,
                          wd=base_wd,
                          rescale_grad=_rescale)
    som = 20
    _cb = mx.callback.Speedometer(args.batch_size, som)

    # Load every verification set named in args.target whose .bin exists.
    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        """Run every loaded verification set; return the flip accuracies."""
        results = []
        for ver_name, ver_set in zip(ver_name_list, ver_list):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_set, model, args.batch_size, 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name, nbatch, xnorm))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name, nbatch, acc2, std2))
            results.append(acc2)
        return results

    def val_test():
        """Print MAE and CUM over val_iter; callers ensure it is not None."""
        val_metric = mx.metric.create(MAEMetric())
        val_metric.reset()
        val_metric2 = mx.metric.create(CUMMetric())
        val_metric2.reset()
        val_iter.reset()
        for eval_batch in val_iter:
            model.forward(eval_batch, is_train=False)
            model.update_metric(val_metric, eval_batch.label)
            model.update_metric(val_metric2, eval_batch.label)
        _value = val_metric.get_name_value()[0][1]
        print('MAE: %f' % (_value))
        _value = val_metric2.get_name_value()[0][1]
        print('CUM: %f' % (_value))

    highest_acc = [0.0, 0.0]  #lfw and target
    # One-element lists so the nested callback can mutate the counters.
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [40000, 60000, 80000]
        if 1 <= args.loss_type <= 7:
            lr_steps = [100000, 140000, 160000]
        # Default steps are expressed for a batch size of 512 and rescaled to
        # the actual batch size; explicit args.lr_steps are used verbatim.
        p = 512.0 / args.batch_size
        lr_steps = [int(step * p) for step in lr_steps]
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        """Per-batch hook: decay lr, log speed, evaluate, checkpoint, stop."""
        global_step[0] += 1
        mbatch = global_step[0]
        # Multiply the learning rate by 0.1 when a scheduled step is reached.
        for _lr in lr_steps:
            if mbatch == _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        # mbatch is at least 1 here, so only the periodicity needs checking.
        if mbatch % args.verbose == 0:
            if val_iter is not None:
                val_test()
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            if len(acc_list) > 0:
                # The first verification set drives highest_acc[0] and the
                # save thresholds; the last one drives highest_acc[-1].
                lfw_score = acc_list[0]
                if lfw_score > highest_acc[0]:
                    highest_acc[0] = lfw_score
                    if lfw_score >= 0.998:
                        do_save = True
                if acc_list[-1] >= highest_acc[-1]:
                    highest_acc[-1] = acc_list[-1]
                    if lfw_score >= 0.99:
                        do_save = True
            # ckpt == 0 never saves, ckpt > 1 always saves, ckpt == 1 keeps
            # the score-based decision made above.
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt > 1:
                do_save = True
            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    epoch_cb = None

    model.fit(
        train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        eval_data=None,
        eval_metric=eval_metrics,
        kvstore='device',
        optimizer=opt,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=epoch_cb)