Example #1
def val_test():
    # Build an eval-only module from the 'heatmap_output' internal of the
    # training symbol and copy the current training weights into it.
    all_layers = sym.get_internals()
    vsym = all_layers['heatmap_output']
    vmodel = mx.mod.Module(symbol=vsym, context=ctx, label_names=None)
    vmodel.bind(data_shapes=[('data', (args.batch_size, ) + data_shape)])
    arg_params, aux_params = model.get_params()
    vmodel.set_params(arg_params, aux_params)
    for target in config.val_targets:
        _file = os.path.join(config.dataset_path, '%s.rec' % target)
        if not os.path.exists(_file):
            continue
        val_iter = FaceSegIter(
            path_imgrec=_file,
            batch_size=args.batch_size,
            aug_level=0,
            args=args,
        )
        _metric = NMEMetric()
        val_metric = mx.metric.create(_metric)
        val_metric.reset()
        val_iter.reset()
        for i, eval_batch in enumerate(val_iter):
            # Note: the forward pass and metric update run on the outer
            # training module `model`; `vmodel` above is prepared but unused.
            batch_data = mx.io.DataBatch(eval_batch.data)
            model.forward(batch_data, is_train=False)
            model.update_metric(val_metric, eval_batch.label)
        nme_value = val_metric.get_name_value()[0][1]
        print('[%d][%s]NME: %f' % (global_step[0], target, nme_value))
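The snippet assumes a custom `NMEMetric` (and, in the next example, `NMEMetric2`) defined elsewhere in the repository. For orientation only, here is a minimal sketch of what such an `mx.metric.EvalMetric` subclass could look like; the `(batch, points, H, W)` heatmap layout and the 68-point inter-ocular normalization (indices 36/45) are assumptions, not the source's exact logic:

import numpy as np
import mxnet as mx

class NMEMetric(mx.metric.EvalMetric):
    """Sketch: normalized mean error over landmark heatmaps."""
    def __init__(self):
        super(NMEMetric, self).__init__('NME')

    def update(self, labels, preds):
        pred = preds[-1].asnumpy()  # assumed layout: (batch, points, H, W)
        gt = labels[0].asnumpy()
        for p, g in zip(pred, gt):
            # Decode each heatmap to its peak (row, col) coordinate.
            pp = np.array([np.unravel_index(h.argmax(), h.shape) for h in p])
            gp = np.array([np.unravel_index(h.argmax(), h.shape) for h in g])
            # Hypothetical normalization: outer-eye-corner distance
            # in the 68-point annotation scheme.
            norm = np.linalg.norm((gp[36] - gp[45]).astype(np.float64)) + 1e-6
            self.sum_metric += float(np.mean(np.linalg.norm(pp - gp, axis=1)) / norm)
            self.num_inst += 1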
Example #2
def val_test():
    all_layers = sym.get_internals()
    vsym = all_layers['heatmap_output']
    vmodel = mx.mod.Module(symbol=vsym, context=ctx, label_names=None)
    vmodel.bind(data_shapes=[('data', (args.batch_size,) + data_shape)])
    arg_params, aux_params = model.get_params()
    vmodel.set_params(arg_params, aux_params)
    for target in targets:
        _file = os.path.join(args.data_dir, '%s.rec' % target)
        if not os.path.exists(_file):
            continue
        val_iter = FaceSegIter(
            path_imgrec=_file,
            batch_size=args.batch_size,
            aug_level=0,
            args=args,
        )
        _metric = NMEMetric2()
        val_metric = mx.metric.create(_metric)
        val_metric.reset()
        val_iter.reset()
        for i, eval_batch in enumerate(val_iter):
            batch_data = mx.io.DataBatch(eval_batch.data)
            model.forward(batch_data, is_train=False)
            model.update_metric(val_metric, eval_batch.label)
        nme_value = val_metric.get_name_value()[0][1]
        print('[%d][%s]NME: %f' % (global_step[0], target, nme_value))
Example #3
def main(args):
    _seed = 727
    random.seed(_seed)
    np.random.seed(_seed)
    mx.random.seed(_seed)
    # GPU auto-detection via CUDA_VISIBLE_DEVICES is disabled in this
    # variant; training runs on CPU only.
    ctx = [mx.cpu()]
    args.ctx_num = len(ctx)

    args.batch_size = args.per_batch_size * args.ctx_num
    config.per_batch_size = args.per_batch_size

    print('Call with', args, config)
    train_iter = FaceSegIter(
        path_imgrec=os.path.join(config.dataset_path, 'train.rec'),
        batch_size=args.batch_size,
        per_batch_size=args.per_batch_size,
        aug_level=1,
        exf=args.exf,
        args=args,
    )

    data_shape = train_iter.get_data_shape()
    #label_shape = train_iter.get_label_shape()
    sym = sym_heatmap.get_symbol(num_classes=config.num_classes)
    if len(args.pretrained) == 0:
        #data_shape_dict = {'data' : (args.per_batch_size,)+data_shape, 'softmax_label' : (args.per_batch_size,)+label_shape}
        data_shape_dict = train_iter.get_shape_dict()
        arg_params, aux_params = sym_heatmap.init_weights(sym, data_shape_dict)
    else:
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            vec[0], int(vec[1]))
        #sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)

    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
        label_names=train_iter.get_label_names(),
    )
    # Rescale gradients by the number of devices.
    _rescale_grad = 1.0 / args.ctx_num
    if args.optimizer == 'onadam':
        opt = ONadam(learning_rate=args.lr,
                     wd=args.wd,
                     rescale_grad=_rescale_grad,
                     clip_gradient=5.0)
    elif args.optimizer == 'nadam':
        opt = optimizer.Nadam(learning_rate=args.lr,
                              rescale_grad=_rescale_grad)
    elif args.optimizer == 'rmsprop':
        opt = optimizer.RMSProp(learning_rate=args.lr,
                                rescale_grad=_rescale_grad)
    elif args.optimizer == 'adam':
        opt = optimizer.Adam(learning_rate=args.lr, rescale_grad=_rescale_grad)
    else:
        opt = optimizer.SGD(learning_rate=args.lr,
                            momentum=0.9,
                            wd=args.wd,
                            rescale_grad=_rescale_grad)
    initializer = mx.init.Xavier(rnd_type='gaussian',
                                 factor_type="in",
                                 magnitude=2)
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)
    _metric = LossValueMetric()
    #_metric = NMEMetric()
    #_metric2 = AccMetric()
    #eval_metrics = [_metric, _metric2]
    eval_metrics = [_metric]
    lr_steps = [int(x) for x in args.lr_step.split(',')]
    print('lr-steps', lr_steps)
    global_step = [0]

    def val_test():
        all_layers = sym.get_internals()
        vsym = all_layers['heatmap_output']
        vmodel = mx.mod.Module(symbol=vsym, context=ctx, label_names=None)
        #model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
        vmodel.bind(data_shapes=[('data', (args.batch_size, ) + data_shape)])
        arg_params, aux_params = model.get_params()
        vmodel.set_params(arg_params, aux_params)
        for target in config.val_targets:
            _file = os.path.join(config.dataset_path, '%s.rec' % target)
            if not os.path.exists(_file):
                continue
            val_iter = FaceSegIter(
                path_imgrec=_file,
                batch_size=args.batch_size,
                #batch_size = 4,
                aug_level=0,
                args=args,
            )
            _metric = NMEMetric()
            val_metric = mx.metric.create(_metric)
            val_metric.reset()
            val_iter.reset()
            for i, eval_batch in enumerate(val_iter):
                #print(eval_batch.data[0].shape, eval_batch.label[0].shape)
                batch_data = mx.io.DataBatch(eval_batch.data)
                model.forward(batch_data, is_train=False)
                model.update_metric(val_metric, eval_batch.label)
            nme_value = val_metric.get_name_value()[0][1]
            print('[%d][%s]NME: %f' % (global_step[0], target, nme_value))

    def _batch_callback(param):
        _cb(param)
        global_step[0] += 1
        mbatch = global_step[0]
        for _lr in lr_steps:
            if mbatch == _lr:
                opt.lr *= 0.2
                print('lr change to', opt.lr)
                break
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)
        if mbatch > 0 and mbatch % args.verbose == 0:
            val_test()
            if args.ckpt == 1:
                msave = mbatch // args.verbose
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(args.prefix, msave, model.symbol, arg,
                                         aux)
        if mbatch == lr_steps[-1]:
            if args.ckpt == 2:
                #msave = mbatch//args.verbose
                msave = 1
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(args.prefix, msave, model.symbol, arg,
                                         aux)
            sys.exit(0)

    train_iter = mx.io.PrefetchingIter(train_iter)

    model.fit(
        train_iter,
        begin_epoch=0,
        num_epoch=9999,
        #eval_data          = val_iter,
        eval_data=None,
        eval_metric=eval_metrics,
        kvstore='device',
        optimizer=opt,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=None,
    )
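Example #3 reads a number of attributes off `args` (`per_batch_size`, `lr`, `lr_step`, `wd`, `optimizer`, `verbose`, `ckpt`, `prefix`, `pretrained`, `frequent`, `exf`) whose parser is not shown. A hypothetical `argparse` setup that would satisfy this script might look like the following; every flag default here is an assumption, not the source's value:

import argparse

def parse_args():
    # Hypothetical CLI; attribute names mirror what main(args) accesses.
    p = argparse.ArgumentParser(description='heatmap face-alignment training')
    p.add_argument('--per-batch-size', dest='per_batch_size', type=int, default=16)
    p.add_argument('--lr', type=float, default=2.5e-4)
    p.add_argument('--lr-step', dest='lr_step', type=str, default='16000,24000,30000')
    p.add_argument('--wd', type=float, default=5e-4)
    p.add_argument('--optimizer', type=str, default='nadam')
    p.add_argument('--frequent', type=int, default=20)    # Speedometer print interval
    p.add_argument('--verbose', type=int, default=2000)   # batches between val/checkpoint
    p.add_argument('--ckpt', type=int, default=1)
    p.add_argument('--prefix', type=str, default='./model/hg')
    p.add_argument('--pretrained', type=str, default='')
    p.add_argument('--exf', type=int, default=1)          # assumed augmentation factor
    return p.parse_args()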
Example #4
def main(args):
    _seed = 727
    random.seed(_seed)
    np.random.seed(_seed)
    mx.random.seed(_seed)
    ctx = []
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    #ctx = [mx.gpu(0)]
    args.ctx_num = len(ctx)

    args.batch_size = args.per_batch_size * args.ctx_num
    config.per_batch_size = args.per_batch_size

    print('Call with', args, config)
    train_iter = FaceSegIter(
        path_imgrec=os.path.join(config.dataset_path, 'train.rec'),
        batch_size=args.batch_size,
        per_batch_size=args.per_batch_size,
        aug_level=1,
        exf=args.exf,
        args=args,
    )

    data_shape, data_size = train_iter.get_data_shape()
    #label_shape = train_iter.get_label_shape()
    sym = eval(config.network).get_symbol(num_classes=config.num_classes)
    if len(args.pretrained) == 0:
        #data_shape_dict = {'data' : (args.per_batch_size,)+data_shape, 'softmax_label' : (args.per_batch_size,)+label_shape}
        data_shape_dict = train_iter.get_shape_dict()
        arg_params, aux_params = init_weights(sym, data_shape_dict)
    else:
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            vec[0], int(vec[1]))
        #sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)

    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
        label_names=train_iter.get_label_names(),
    )
    # Rescale gradients by the number of devices.
    _rescale_grad = 1.0 / args.ctx_num
    if args.optimizer == 'onadam':
        opt = ONadam(learning_rate=args.lr,
                     wd=args.wd,
                     rescale_grad=_rescale_grad,
                     clip_gradient=5.0)
    elif args.optimizer == 'nadam':
        opt = optimizer.Nadam(learning_rate=args.lr,
                              rescale_grad=_rescale_grad)
    elif args.optimizer == 'rmsprop':
        opt = optimizer.RMSProp(learning_rate=args.lr,
                                rescale_grad=_rescale_grad)
    elif args.optimizer == 'adam':
        opt = optimizer.Adam(learning_rate=args.lr, rescale_grad=_rescale_grad)
    else:
        opt = optimizer.SGD(learning_rate=args.lr,
                            momentum=0.9,
                            wd=args.wd,
                            rescale_grad=_rescale_grad)
    initializer = mx.init.Xavier(rnd_type='gaussian',
                                 factor_type="in",
                                 magnitude=2)
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)
    _metric = LossValueMetric()
    #_metric = NMEMetric()
    #_metric2 = AccMetric()
    #eval_metrics = [_metric, _metric2]
    eval_metrics = [_metric]
    lr_epoch_steps = [int(x) for x in args.lr_epoch_step.split(',')]
    print('lr-epoch-steps', lr_epoch_steps)

    global_step = [0]
    highest_acc = [1.0, 1.0]

    def _batch_callback(param):
        _cb(param)
        global_step[0] += 1
        mbatch = global_step[0]
        mepoch = mbatch * args.batch_size // data_size
        pre = mbatch * args.batch_size % data_size
        is_highest = False
        for _lr in lr_epoch_steps[0:-1]:
            if mepoch == _lr and pre < args.batch_size:
                opt.lr *= 0.2
                print('lr change to', opt.lr)
                break
        if mbatch % 1000 == 0:
            print('lr:', opt.lr, 'batch:', param.nbatch, 'epoch:', param.epoch)
        if mbatch > 0 and mbatch % args.verbose == 0:
            acc_list = val_test(sym, model, ctx, data_shape, global_step)
            score = np.mean(acc_list)
            if acc_list[0] < highest_acc[0]:  # ibug
                is_highest = True
                highest_acc[0] = acc_list[0]
            if score < highest_acc[1]:  # mean
                is_highest = True
                highest_acc[1] = score
            if args.ckpt == 1 and is_highest:
                msave = mbatch // args.verbose
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(args.prefix, msave, model.symbol, arg,
                                         aux)
        if mepoch == lr_epoch_steps[-1]:
            if args.ckpt == 1:
                acc_list = val_test(sym, model, ctx, data_shape, global_step)
                msave = mbatch // args.verbose
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(args.prefix, msave, model.symbol, arg,
                                         aux)
            sys.exit(0)

    train_iter = mx.io.PrefetchingIter(train_iter)

    model.fit(
        train_iter,
        begin_epoch=0,
        num_epoch=9999,
        #eval_data          = val_iter,
        eval_data=None,
        eval_metric=eval_metrics,
        kvstore='device',
        optimizer=opt,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=None,
    )
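Example #4 decays the learning rate manually inside `_batch_callback`, keyed to epoch boundaries. For comparison, MXNet ships `mx.lr_scheduler.MultiFactorScheduler`, which applies the same multiplicative decay but counts optimizer updates rather than epochs; a sketch of the equivalent wiring, with placeholder step counts and hyperparameters:

import mxnet as mx
from mxnet import optimizer

# Decay-by-0.2 at fixed update counts, attached to the optimizer instead
# of manual `opt.lr *= 0.2` bookkeeping in the batch callback.
lr_sched = mx.lr_scheduler.MultiFactorScheduler(step=[16000, 24000, 30000],
                                                factor=0.2)
opt = optimizer.SGD(learning_rate=2.5e-4, momentum=0.9, wd=5e-4,
                    rescale_grad=1.0, lr_scheduler=lr_sched)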
Example #5
def main(args):
    _seed = 727
    random.seed(_seed)
    np.random.seed(_seed)
    mx.random.seed(_seed)
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    # ctx = [mx.gpu(0)]
    args.ctx_num = len(ctx)

    args.batch_size = args.per_batch_size * args.ctx_num

    print('Call with', args)
    train_iter = FaceSegIter(path_imgrec=os.path.join(args.data_dir, 'train.rec'),
                             batch_size=args.batch_size,
                             per_batch_size=args.per_batch_size,
                             aug_level=1,
                             use_coherent=args.use_coherent,
                             args=args,
                             )
    targets = ['ibug', 'cofw_testset', '300W', 'AFLW2000-3D']

    data_shape = train_iter.get_data_shape()
    # label_shape = train_iter.get_label_shape()
    sym = hg.get_symbol(num_classes=args.num_classes, binarize=args.binarize, label_size=args.output_label_size,
                        input_size=args.input_img_size, use_coherent=args.use_coherent, use_dla=args.use_dla,
                        use_N=args.use_N, use_DCN=args.use_DCN, per_batch_size=args.per_batch_size)
    if len(args.pretrained) == 0:
        # data_shape_dict = {'data' : (args.per_batch_size,)+data_shape, 'softmax_label' : (args.per_batch_size,)+label_shape}
        data_shape_dict = train_iter.get_shape_dict()
        arg_params, aux_params = hg.init_weights(sym, data_shape_dict)
    else:
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
        # sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)

    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
        label_names=train_iter.get_label_names(),
    )
    # lr = 1.0e-3
    # lr = 2.5e-4
    lr = args.lr
    # _rescale_grad = 1.0
    _rescale_grad = 1.0 / args.ctx_num
    # lr = args.lr
    # opt = optimizer.SGD(learning_rate=lr, momentum=0.9, wd=5.e-4, rescale_grad=_rescale_grad)
    # opt = optimizer.Adam(learning_rate=lr, wd=args.wd, rescale_grad=_rescale_grad)
    opt = optimizer.Nadam(learning_rate=lr, wd=args.wd, rescale_grad=_rescale_grad, clip_gradient=5.0)
    # opt = optimizer.RMSProp(learning_rate=lr, wd=args.wd, rescale_grad=_rescale_grad)
    initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)
    _cb = mx.callback.Speedometer(args.batch_size, 10)
    _metric = LossValueMetric()
    # _metric2 = AccMetric()
    # eval_metrics = [_metric, _metric2]
    eval_metrics = [_metric]

    # lr_steps = [40000,60000,80000]
    # lr_steps = [12000,18000,22000]
    if len(args.lr_steps) == 0:
        lr_steps = [16000, 24000, 30000]
        # lr_steps = [14000,24000,30000]
        # lr_steps = [5000,10000]
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    # Scale the step schedule relative to a reference batch size of 40.
    _a = 40 // args.batch_size
    for i in range(len(lr_steps)):
        lr_steps[i] *= _a
    print('lr-steps', lr_steps)
    global_step = [0]

    def val_test():
        all_layers = sym.get_internals()
        vsym = all_layers['heatmap_output']
        vmodel = mx.mod.Module(symbol=vsym, context=ctx, label_names=None)
        # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
        vmodel.bind(data_shapes=[('data', (args.batch_size,) + data_shape)])
        arg_params, aux_params = model.get_params()
        vmodel.set_params(arg_params, aux_params)
        for target in targets:
            _file = os.path.join(args.data_dir, '%s.rec' % target)
            if not os.path.exists(_file):
                continue
            val_iter = FaceSegIter(path_imgrec=_file,
                                   batch_size=args.batch_size,
                                   # batch_size = 4,
                                   aug_level=0,
                                   args=args,
                                   )
            _metric = NMEMetric2()
            val_metric = mx.metric.create(_metric)
            val_metric.reset()
            val_iter.reset()
            for i, eval_batch in enumerate(val_iter):
                # print(eval_batch.data[0].shape, eval_batch.label[0].shape)
                batch_data = mx.io.DataBatch(eval_batch.data)
                model.forward(batch_data, is_train=False)
                model.update_metric(val_metric, eval_batch.label)
            nme_value = val_metric.get_name_value()[0][1]
            print('[%d][%s]NME: %f' % (global_step[0], target, nme_value))

    def _batch_callback(param):
        _cb(param)
        global_step[0] += 1
        mbatch = global_step[0]
        for _lr in lr_steps:
            if mbatch == _lr:
                opt.lr *= 0.2
                print('lr change to', opt.lr)
                break
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)
        if mbatch > 0 and mbatch % args.verbose == 0:
            val_test()
            if args.ckpt > 0:
                msave = mbatch // args.verbose
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(args.prefix, msave, model.symbol, arg, aux)

    model.fit(train_iter,
              begin_epoch=0,
              num_epoch=args.end_epoch,
              # eval_data          = val_iter,
              eval_data=None,
              eval_metric=eval_metrics,
              kvstore='device',
              optimizer=opt,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              allow_missing=True,
              batch_end_callback=_batch_callback,
              epoch_end_callback=None,
              )
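One pitfall in this variant: `_a = 40 // args.batch_size` is integer division, so any batch size above 40 makes `_a` zero and zeroes out every entry of `lr_steps`, silently disabling the schedule. If the intent is a schedule calibrated for an effective batch size of 40, a guard keeps the steps positive (a suggested variant, not the source's behavior):

# Hypothetical guard: never scale the schedule down to zero.
_a = max(1, 40 // args.batch_size)
lr_steps = [s * _a for s in lr_steps]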
Example #6
if ctx_id >= 0:
    ctx = mx.gpu(ctx_id)
else:
    ctx = mx.cpu()
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
all_layers = sym.get_internals()
sym = all_layers['heatmap_output']
model = mx.mod.Module(symbol=sym, context=ctx, data_names=['data'], label_names=None)
model.bind(for_training=False, data_shapes=[('data', (1, 3, image_size[0], image_size[1]))])
model.set_params(arg_params, aux_params)

val_iter = FaceSegIter(path_imgrec=rec_path,
                       batch_size=1,
                       aug_level=0)
_metric = NMEMetric()
nme = []
for i, eval_batch in enumerate(val_iter):
    if i % 10 == 0:
        print('processing', i)
    batch_data = mx.io.DataBatch(eval_batch.data)
    model.forward(batch_data, is_train=False)
    pred_label = model.get_outputs()[-1].asnumpy()
    label = eval_batch.label[0].asnumpy()
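The snippet ends after pulling the predicted and ground-truth heatmaps into NumPy without computing the error. A hedged sketch of how the loop might be finished, reusing the peak-decoding idea from the NMEMetric sketch above (the source's actual NME computation may differ):

import numpy as np

# (continuing inside the for-loop above; batch size is 1)
pred_pts = np.array([np.unravel_index(h.argmax(), h.shape) for h in pred_label[0]])
gt_pts = np.array([np.unravel_index(h.argmax(), h.shape) for h in label[0]])
norm = np.linalg.norm((gt_pts[36] - gt_pts[45]).astype(np.float64)) + 1e-6  # hypothetical 68-pt norm
nme.append(np.mean(np.linalg.norm(pred_pts - gt_pts, axis=1)) / norm)

# after the loop:
print('mean NME: %f' % np.mean(nme))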