Example #1
0
def train_net(args):
    """Train a face network with MXNet from a packed record dataset.

    Reads ``train.rec`` (plus ``val.rec`` when present and ``args.loss_type``
    is below 10) and the dataset property file from ``args.data_dir``, builds
    the symbol via ``get_symbol`` (optionally seeded from ``args.pretrained``,
    given as ``'<checkpoint-prefix>,<epoch>'``) and runs ``Module.fit`` with
    SGD.  Every ``args.verbose`` batches the model is evaluated on each
    ``<name>.bin`` verification set listed in ``args.target``.  A checkpoint
    is written under ``args.prefix`` when a verification accuracy reaches a
    new high while the first target's accuracy is at least 0.99.  For
    ``loss_type == 1``, a checkpoint is also written every 10000 batches
    after the last LR step.

    Args:
        args: parsed command-line namespace.  Derived fields (``ctx_num``,
            ``num_layers``, ``batch_size``, ``num_classes``, ``image_h``,
            ``image_w``, ``use_val``, ...) are written back onto it.

    Side effects:
        Creates the checkpoint directory if needed, writes checkpoints, and
        sets the ``BETA`` environment variable (updated after every batch).
    """
    # One GPU context per entry of CUDA_VISIBLE_DEVICES; fall back to CPU
    # when the variable is unset or empty.
    ctx = []
    cvd = os.environ.get('CUDA_VISIBLE_DEVICES', '').strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))

    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    # A prefix without a directory part has dirname '', which makedirs rejects.
    if prefix_dir and not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    # The first character of args.network selects the family (checked below
    # for the initializer); the remainder is parsed as the depth.
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
        if args.loss_type == 10:
            args.per_batch_size = 256
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3
    ppatch = [int(x) for x in args.patch.split('_')]
    assert len(ppatch) == 5

    os.environ['BETA'] = str(args.beta)
    args.use_val = False
    prop = face_image.load_property(args.data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)

    assert args.num_classes > 0
    print('num_classes', args.num_classes)

    path_imgrec = os.path.join(args.data_dir, "train.rec")
    val_rec = os.path.join(args.data_dir, "val.rec")
    # The validation pass (val_test) uses AccMetric, which this function only
    # pairs with loss_type < 10.
    if os.path.exists(val_rec) and args.loss_type < 10:
        args.use_val = True
    else:
        val_rec = None

    if args.loss_type == 1 and args.num_classes > 40000:
        args.beta_freeze = 5000
        args.gamma = 0.06

    if args.loss_type == 11:
        args.images_per_identity = 2
    elif args.loss_type == 10:
        args.images_per_identity = 16

    if args.loss_type < 10:
        assert args.images_per_identity == 0
    else:
        assert args.images_per_identity >= 2
        args.per_identities = int(args.per_batch_size / args.images_per_identity)

    print('Called with argument:', args)

    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    if args.use_val:
        val_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=val_rec,
            shuffle=False,
            rand_mirror=False,
            mean=mean,
        )
    else:
        val_dataiter = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
    else:
        # args.pretrained is '<checkpoint-prefix>,<epoch>'.
        vec = args.pretrained.split(',')
        _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
    sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)

    # Constant 'extra' input fed alongside the data for loss types 10 and 11.
    data_extra = None
    hard_mining = False
    if args.loss_type == 10:
        hard_mining = True
        # (batch_size, per_batch_size) matrix of -1.0.  Within each device
        # slice, the diagonal block covering each consecutive group of
        # images_per_identity samples is set to 1.0.
        data_extra = np.full((args.batch_size, args.per_batch_size), -1.0,
                             dtype=np.float32)
        for c in range(0, args.batch_size, args.per_batch_size):
            for a in range(0, args.per_batch_size, args.images_per_identity):
                b = a + args.images_per_identity
                data_extra[(c + a):(c + b), a:b] = 1.0
    elif args.loss_type == 11:
        # One-hot rows: within each device slice, row i has a 1.0 in column i.
        data_extra = np.zeros((args.batch_size, args.per_identities),
                              dtype=np.float32)
        for c in range(0, args.batch_size, args.per_batch_size):
            for i in range(args.per_identities):
                data_extra[c + i][i] = 1.0

    label_name = 'softmax_label'
    if data_extra is None:
        model = mx.mod.Module(
            context=ctx,
            symbol=sym,
        )
    else:
        model = mx.mod.Module(
            context=ctx,
            symbol=sym,
            data_names=('data', 'extra'),
            label_names=(label_name,),
        )

    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrec=path_imgrec,
        shuffle=True,
        rand_mirror=True,
        mean=mean,
        ctx_num=args.ctx_num,
        images_per_identity=args.images_per_identity,
        data_extra=data_extra,
        hard_mining=hard_mining,
        mx_model=model,
        label_name=label_name,
    )

    if args.loss_type < 10:
        _metric = AccMetric()
    else:
        _metric = LossValueMetric()
    eval_metrics = [mx.metric.create(_metric)]

    if args.network[0] == 'r':
        initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out",
                                     magnitude=2)  # resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
                                     magnitude=2)  # inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in",
                                     magnitude=2)
    # rescale_grad = 1/ctx_num, presumably to average gradients gathered
    # across devices by the 'device' kvstore.
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd,
                        rescale_grad=_rescale)
    _cb = mx.callback.Speedometer(args.batch_size, 20)

    # Load every '<name>.bin' verification set named in args.target that exists.
    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(args.data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        """Run every verification set; return the 'Accuracy-Flip' values (acc2)."""
        results = []
        for ver_name, data_set in zip(ver_name_list, ver_list):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                data_set, model, args.batch_size, data_extra)
            print('[%s][%d]XNorm: %f' % (ver_name, nbatch, xnorm))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name, nbatch, acc2, std2))
            results.append(acc2)
        return results

    def val_test():
        """Compute and print accuracy over the validation iterator."""
        acc = AccMetric()
        val_metric = mx.metric.create(acc)
        val_metric.reset()
        val_dataiter.reset()
        for i, eval_batch in enumerate(val_dataiter):
            model.forward(eval_batch, is_train=False)
            model.update_metric(val_metric, eval_batch.label)
        acc_value = val_metric.get_name_value()[0][1]
        print('VACC: %f' % (acc_value))

    # Best accuracy seen so far, one entry per verification set.
    highest_acc = [0.0] * len(ver_list)
    # Single-element lists so the nested callback can mutate them (Py2-safe).
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [40000, 60000, 70000]
        if args.loss_type == 1:
            lr_steps = [50000, 70000, 80000]
        # Default steps are expressed for a global batch of 512 and scaled to
        # the actual batch size.
        p = 512.0 / args.batch_size
        lr_steps = [int(step * p) for step in lr_steps]
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        """Per-batch hook: LR decay, logging, verification/checkpoints, BETA."""
        global_step[0] += 1
        mbatch = global_step[0]
        # Decay LR 10x when the step count reaches beta_freeze + each LR step.
        for _lr in lr_steps:
            if mbatch == args.beta_freeze + _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        # mbatch starts at 1, so no lower-bound check is needed.
        if mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            # Saving on a new high is gated on the first target's accuracy
            # (named lfw_score).  An empty target list must not raise here.
            lfw_score = acc_list[0] if acc_list else 0.0
            for i, acc in enumerate(acc_list):
                if acc >= highest_acc[i]:
                    highest_acc[i] = acc
                    if lfw_score >= 0.99:
                        do_save = True
            if args.loss_type == 1 and mbatch > lr_steps[-1] and mbatch % 10000 == 0:
                do_save = True
            if do_save:
                # Reports the last verification set's accuracy.
                print('saving', msave, acc_list[-1] if acc_list else None)
                if val_dataiter is not None:
                    val_test()
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)

        # BETA stays at args.beta for beta_freeze batches, then decays as
        # beta * (1 + gamma * t) ** -power, floored at beta_min.  It is
        # published via the environment -- presumably read by a custom
        # operator elsewhere; confirm.
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = mbatch - args.beta_freeze
            _beta = max(args.beta_min,
                        args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
        os.environ['BETA'] = str(_beta)

    model.fit(
        train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        eval_data=val_dataiter,
        eval_metric=eval_metrics,
        kvstore='device',
        optimizer=opt,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=None,
    )
def train_net(args):
    """Train a face network with MXNet, testing on LFW during training.

    Reads ``train.rec`` (and ``val.rec`` when present) plus the ``property``
    file (class count) from ``args.data_dir``, and the LFW set from
    ``<data_dir>/lfw``.  When ``args.retrain`` is false a fresh symbol is
    built; networks whose name starts with 's' or 'm' get weights from
    ``spherenet``/``marginalnet.init_weights``.  Otherwise parameters are
    loaded from ``args.pretrained`` at ``args.load_epoch``.

    Every ``args.verbose`` batches the LFW test runs.  A checkpoint (and the
    LFW embeddings as ``.npy``) is written to ``<prefix>-<network>-p<patch>``
    when the flip accuracy reaches a new high of at least 0.996.  A
    checkpoint is also written every 10000 batches after the last LR step.

    Args:
        args: parsed command-line namespace.  Derived fields (``ctx_num``,
            ``num_layers``, ``batch_size``, ``num_classes``, ``image_h``,
            ``image_w``, ``use_val``, ...) are written back onto it.

    Raises:
        ValueError: if ``args.loss_type`` is not <= 9, 10 or 11.
    """
    # One GPU context per entry of CUDA_VISIBLE_DEVICES; fall back to CPU
    # when the variable is unset or empty.
    ctx = []
    cvd = os.environ.get('CUDA_VISIBLE_DEVICES', '').strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = "%s-%s-p%s" % (args.prefix, args.network, args.patch)
    end_epoch = args.end_epoch
    pretrained = args.pretrained
    load_epoch = args.load_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
        # Non-'r' networks with 64+ layers get a slightly smaller batch.
        if args.network[0] != 'r' and args.num_layers >= 64:
            args.per_batch_size = 120
        if args.ctx_num == 2:
            args.per_batch_size *= 2
        elif args.ctx_num == 3:
            args.per_batch_size = 170
        if args.network[0] == 'm':
            args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3
    ppatch = [int(x) for x in args.patch.split('_')]
    image_size = [int(x) for x in args.image_size.split(',')]
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    assert len(ppatch) == 5

    os.environ['BETA'] = str(args.beta)
    args.use_val = False
    # NOTE(review): path_imglist is never set to a real list, yet loss types
    # 10/11 pass it to their iterators -- confirm they accept None.
    path_imglist = None

    # The property file holds the class count; the last non-blank line wins.
    with open(os.path.join(args.data_dir, 'property')) as prop_file:
        for line in prop_file:
            line = line.strip()
            if line:
                args.num_classes = int(line)
    assert args.num_classes > 0
    print('num_classes', args.num_classes)

    path_imgrec = os.path.join(args.data_dir, "train.rec")
    val_rec = os.path.join(args.data_dir, "val.rec")
    if os.path.exists(val_rec):
        args.use_val = True
    else:
        val_rec = None

    if args.loss_type == 1 and args.num_classes > 40000:
        args.beta_freeze = 5000
        args.gamma = 0.06

    print('Called with argument:', args)

    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    if args.use_val:
        val_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=val_rec,
            shuffle=False,
            rand_mirror=False,
            mean=mean,
        )
    else:
        val_dataiter = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = 0.9
    if not args.retrain:
        sym, arg_params, aux_params = get_symbol(args, None, None)
        data_shape_dict = {
            'data': (args.batch_size, ) + data_shape,
            'softmax_label': (args.batch_size, )
        }
        if args.network[0] == 's':
            arg_params, aux_params = spherenet.init_weights(
                sym, data_shape_dict, args.num_layers)
        elif args.network[0] == 'm':
            arg_params, aux_params = marginalnet.init_weights(
                sym, data_shape_dict, args.num_layers)
    else:
        _, arg_params, aux_params = mx.model.load_checkpoint(
            pretrained, load_epoch)
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)

    if args.loss_type != 10:
        model = mx.mod.Module(
            context=ctx,
            symbol=sym,
        )
    else:
        model = mx.mod.Module(
            context=ctx,
            symbol=sym,
            data_names=('data', 'extra'),
        )

    if args.loss_type <= 9:
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=True,
            mean=mean,
        )
    elif args.loss_type == 10:
        train_dataiter = FaceImageIter4(
            batch_size=args.batch_size,
            ctx_num=args.ctx_num,
            images_per_identity=args.images_per_identity,
            data_shape=data_shape,
            path_imglist=path_imglist,
            shuffle=True,
            rand_mirror=True,
            mean=mean,
            patch=ppatch,
            use_extra=True,
            model=model,
        )
    elif args.loss_type == 11:
        train_dataiter = FaceImageIter5(
            batch_size=args.batch_size,
            ctx_num=args.ctx_num,
            images_per_identity=args.images_per_identity,
            data_shape=data_shape,
            path_imglist=path_imglist,
            shuffle=True,
            rand_mirror=True,
            mean=mean,
            patch=ppatch,
        )
    else:
        # Without this, train_dataiter would be unbound and fit() would fail
        # with an unrelated-looking UnboundLocalError.
        raise ValueError('unsupported loss_type: %r' % (args.loss_type,))

    _acc = AccMetric()
    eval_metrics = [mx.metric.create(_acc)]

    if args.network[0] == 'r':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  # resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  # inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    # rescale_grad = 1/ctx_num, presumably to average gradients gathered
    # across devices by the 'device' kvstore.
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=base_lr,
                        momentum=base_mom,
                        wd=base_wd,
                        rescale_grad=_rescale)
    _cb = mx.callback.Speedometer(args.batch_size, 10)

    lfw_dir = os.path.join(args.data_dir, 'lfw')
    lfw_set = lfw.load_dataset(lfw_dir, image_size)

    def lfw_test(nbatch):
        """Run the LFW test; return (flip accuracy, embeddings list)."""
        acc1, std1, acc2, std2, xnorm, embeddings_list = lfw.test(
            lfw_set, model, args.batch_size)
        print('[%d]XNorm: %f' % (nbatch, xnorm))
        print('[%d]Accuracy: %1.5f+-%1.5f' % (nbatch, acc1, std1))
        print('[%d]Accuracy-Flip: %1.5f+-%1.5f' % (nbatch, acc2, std2))
        return acc2, embeddings_list

    def val_test():
        """Compute and print accuracy over the validation iterator."""
        acc = AccMetric()
        val_metric = mx.metric.create(acc)
        val_metric.reset()
        val_dataiter.reset()
        for i, eval_batch in enumerate(val_dataiter):
            model.forward(eval_batch, is_train=False)
            model.update_metric(val_metric, eval_batch.label)
        acc_value = val_metric.get_name_value()[0][1]
        print('VACC: %f' % (acc_value))

    # Single-element lists so the nested callback can mutate them (Py2-safe).
    highest_acc = [0.0]
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [40000, 60000, 80000]
        if args.loss_type == 1:
            lr_steps = [100000, 140000, 160000]
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        """Per-batch hook: LR decay, logging, LFW test/checkpoints, BETA."""
        global_step[0] += 1
        mbatch = global_step[0]
        # Decay LR 10x when the step count reaches beta_freeze + each LR step.
        for _lr in lr_steps:
            if mbatch == args.beta_freeze + _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        # mbatch starts at 1, so no lower-bound check is needed.
        if mbatch % args.verbose == 0:
            acc, embeddings_list = lfw_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            if acc >= highest_acc[0]:
                highest_acc[0] = acc
                if acc >= 0.996:
                    do_save = True
            if mbatch > lr_steps[-1] and mbatch % 10000 == 0:
                do_save = True
            if do_save:
                print('saving', msave, acc)
                if val_dataiter is not None:
                    val_test()
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
                # Dump LFW embeddings only when this save is also the best score.
                if acc >= highest_acc[0]:
                    lfw_npy = "%s-lfw-%04d" % (prefix, msave)
                    X = np.concatenate(embeddings_list, axis=0)
                    print('saving lfw npy', X.shape)
                    np.save(lfw_npy, X)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[0]))

        # BETA stays at args.beta for beta_freeze batches, then decays as
        # beta * (1 + gamma * t) ** -power, floored at beta_min.  It is
        # published via the environment -- presumably read by a custom
        # operator elsewhere; confirm.
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = mbatch - args.beta_freeze
            _beta = max(
                args.beta_min,
                args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
        os.environ['BETA'] = str(_beta)

    model.fit(
        train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        eval_data=val_dataiter,
        eval_metric=eval_metrics,
        kvstore='device',
        optimizer=opt,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=None)
def train_net(args):
    """Train a network on ``train.rec`` with MXNet, reporting Acc/MAE/CUM.

    Reads ``train.rec`` and, when present, ``val.rec`` from the single
    directory in ``args.data_dir``.  Builds the symbol via ``get_symbol``,
    optionally seeded from ``args.pretrained`` (``'<checkpoint-prefix>,<epoch>'``),
    and runs ``Module.fit`` with SGD.  ``AccMetric``, ``MAEMetric`` and
    ``CUMMetric`` are combined into the evaluation metric.

    Checkpointing happens in the epoch callback.  Only the ``fc1_output``
    sub-graph is saved to ``args.prefix``, and only when ``args.ckpt == 2``.
    When ``args.max_steps > 0`` the process exits via ``sys.exit(0)`` once
    more than ``max_steps`` batches have run.

    Args:
        args: parsed command-line namespace.  Derived fields (``ctx_num``,
            ``num_layers``, ``batch_size``, ``image_h``, ``image_w``, ...)
            are written back onto it.
    """
    # One GPU context per entry of CUDA_VISIBLE_DEVICES; fall back to CPU
    # when the variable is unset or empty.
    ctx = []
    cvd = os.environ.get('CUDA_VISIBLE_DEVICES', '').strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    # A prefix without a directory part has dirname '', which makedirs rejects.
    if prefix_dir and not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3

    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list) == 1
    data_dir = data_dir_list[0]
    image_size = [int(x) for x in args.image_size.split(',')]
    assert len(image_size) == 2
    assert image_size[0] == image_size[1]
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    path_imgrec = os.path.join(data_dir, "train.rec")
    path_imgrec_val = os.path.join(data_dir, "val.rec")

    print('Called with argument:', args)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
    else:
        # args.pretrained is '<checkpoint-prefix>,<epoch>'.
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            vec[0], int(vec[1]))
    sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)

    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
    )

    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrec=path_imgrec,
        shuffle=True,
        rand_mirror=args.rand_mirror,
        mean=mean,
        cutoff=args.cutoff,
        color_jittering=args.color,
    )
    # Evaluate only when a validation record exists; Module.fit accepts
    # eval_data=None.
    if os.path.exists(path_imgrec_val):
        val_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec_val,
            shuffle=False,
            rand_mirror=False,
            mean=mean,
        )
    else:
        val_dataiter = None

    metric = mx.metric.CompositeEvalMetric(
        [AccMetric(), MAEMetric(), CUMMetric()])

    if args.network[0] == 'r' or args.network[0] == 'y':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  # resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  # inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    # rescale_grad = 1/ctx_num, presumably to average gradients gathered
    # across devices by the 'device' kvstore.
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=base_lr,
                        momentum=base_mom,
                        wd=base_wd,
                        rescale_grad=_rescale)
    som = 20
    _cb = mx.callback.Speedometer(args.batch_size, som)
    # Empty entries are skipped so an empty --lr-steps means "no decay"
    # rather than a crash on int('').
    lr_steps = [int(x) for x in args.lr_steps.split(',') if x.strip()]

    # Single-element lists so the nested callbacks can mutate them (Py2-safe).
    global_step = [0]
    save_step = [0]

    def _batch_callback(param):
        """Per-batch hook: speed logging, 10x LR decay at lr_steps, max_steps."""
        _cb(param)
        global_step[0] += 1
        mbatch = global_step[0]
        for _lr in lr_steps:
            if mbatch == _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    def _epoch_callback(epoch, symbol, arg_params, aux_params):
        """Per-epoch hook: save the fc1_output sub-graph when args.ckpt == 2."""
        save_step[0] += 1
        msave = save_step[0]
        # ckpt == 0 never saves and ckpt == 2 always saves.  Any other value
        # (including 1) also never saves here.
        # NOTE(review): confirm ckpt == 1 is meant to be a no-op.
        do_save = args.ckpt == 2
        if do_save:
            print('saving %s' % msave)
            arg, aux = model.get_params()
            all_layers = model.symbol.get_internals()
            _sym = all_layers['fc1_output']
            mx.model.save_checkpoint(args.prefix, msave, _sym, arg, aux)

    train_dataiter = mx.io.PrefetchingIter(train_dataiter)
    print('start fitting')

    model.fit(
        train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        eval_data=val_dataiter,
        eval_metric=metric,
        kvstore='device',
        optimizer=opt,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=_epoch_callback)