Example #1
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in xrange(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = os.path.join(args.models_root,
                          '%s-%s-%s' % (args.network, args.loss, args.dataset),
                          'model')
    prefix_dir = os.path.dirname(prefix)
    print('prefix', prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    args.ctx_num = len(ctx)
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = config.image_shape[2]

    data_dir = config.dataset_path
    path_imgrec = None
    path_imglist = None
    image_size = config.image_shape[0:2]
    assert len(image_size) == 2
    assert image_size[0] == image_size[1]
    print('image_size', image_size)
    print('num_classes', config.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    print('Called with argument:', args, config)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym = get_symbol(args)
        if config.net_name == 'spherenet':
            data_shape_dict = {'data': (args.per_batch_size, ) + data_shape}
            spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:
        print('loading', args.pretrained, args.pretrained_epoch)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            args.pretrained, args.pretrained_epoch)
        sym = get_symbol(args)

    if config.count_flops:
        all_layers = sym.get_internals()
        _sym = all_layers['fc1_output']
        FLOPs = flops_counter.count_flops(_sym,
                                          data=(1, 3, image_size[0],
                                                image_size[1]))
        print('Network FLOPs: %d' % FLOPs)

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)
    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
    )
    val_dataiter = None

    if config.loss_name.find('triplet') >= 0:
        from triplet_image_iter import FaceImageIter
        triplet_params = [
            config.triplet_bag_size, config.triplet_alpha,
            config.triplet_max_ap
        ]
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=config.data_rand_mirror,
            mean=mean,
            cutoff=config.data_cutoff,
            ctx_num=args.ctx_num,
            images_per_identity=config.images_per_identity,
            triplet_params=triplet_params,
            mx_model=model,
        )
        _metric = LossValueMetric()
        eval_metrics = [mx.metric.create(_metric)]
    else:
        from image_iter import FaceImageIter
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=config.data_rand_mirror,
            mean=mean,
            cutoff=config.data_cutoff,
            color_jittering=config.data_color,
            images_filter=config.data_images_filter,
        )
        metric1 = AccMetric()
        eval_metrics = [mx.metric.create(metric1)]
        if config.ce_loss:
            metric2 = LossValueMetric()
            eval_metrics.append(mx.metric.create(metric2))

    if config.net_name == 'fresnet' or config.net_name == 'fmobilefacenet':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    #initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=args.lr,
                        momentum=args.mom,
                        wd=args.wd,
                        rescale_grad=_rescale)
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)

    ver_list = []
    ver_name_list = []
    for name in config.val_targets:
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in xrange(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model, args.batch_size, 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        #global global_step
        global_step[0] += 1
        mbatch = global_step[0]
        for step in lr_steps:
            if mbatch == step:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            is_highest = False
            if len(acc_list) > 0:
                #lfw_score = acc_list[0]
                #if lfw_score>highest_acc[0]:
                #  highest_acc[0] = lfw_score
                #  if lfw_score>=0.998:
                #    do_save = True
                score = sum(acc_list)
                if acc_list[-1] >= highest_acc[-1]:
                    if acc_list[-1] > highest_acc[-1]:
                        is_highest = True
                    else:
                        if score >= highest_acc[0]:
                            is_highest = True
                            highest_acc[0] = score
                    highest_acc[-1] = acc_list[-1]
                    #if lfw_score>=0.99:
                    #  do_save = True
            if is_highest:
                do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt == 2:
                do_save = True
            elif args.ckpt == 3:
                msave = 1

            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()
                if config.ckpt_embedding:
                    all_layers = model.symbol.get_internals()
                    _sym = all_layers['fc1_output']
                    _arg = {}
                    for k in arg:
                        if not k.startswith('fc7'):
                            _arg[k] = arg[k]
                    mx.model.save_checkpoint(prefix, msave, _sym, _arg, aux)
                else:
                    mx.model.save_checkpoint(prefix, msave, model.symbol, arg,
                                             aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if config.max_steps > 0 and mbatch > config.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    model.fit(
        train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=999999,
        eval_data=val_dataiter,
        eval_metric=eval_metrics,
        kvstore=args.kvstore,
        optimizer=opt,
        #optimizer_params   = optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=epoch_cb)
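
All of these examples import AccMetric and LossValueMetric from the repository's own metric module. For reference, here is a minimal sketch of what such a loss metric might look like, assuming the loss head is the last of the symbol's outputs; the actual metric.py in the source repository may differ:

import mxnet as mx

class LossValueMetric(mx.metric.EvalMetric):
    """Averages the value of a loss output over batches (illustrative sketch)."""
    def __init__(self):
        super(LossValueMetric, self).__init__('lossvalue')

    def update(self, labels, preds):
        # Assumption: the loss is the last output head of the symbol.
        loss = preds[-1].asnumpy().mean()
        self.sum_metric += loss
        self.num_inst += 1
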
Example #2
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd)>0:
      for i in xrange(len(cvd.split(','))):
        ctx.append(mx.gpu(i))
    if len(ctx)==0:
      ctx = [mx.cpu()]
      print('use cpu')
    else:
      print('gpu num:', len(ctx))
    prefix = os.path.join(args.models_root, '%s-%s-%s'%(args.network, args.loss, args.dataset), 'model')
    prefix_dir = os.path.dirname(prefix)
    print('prefix', prefix)
    if not os.path.exists(prefix_dir):
      os.makedirs(prefix_dir)
    args.ctx_num = len(ctx)  #GPU num
    args.batch_size = args.per_batch_size*args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = config.image_shape[2]
    config.batch_size = args.batch_size
    config.per_batch_size = args.per_batch_size

    data_dir = config.dataset_path
    path_imgrec = None
    path_imglist = None
    image_size = config.image_shape[0:2]
    assert len(image_size)==2
    assert image_size[0]==image_size[1]
    print('image_size', image_size)
    print('num_classes', config.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    print('Called with argument:', args, config)
    data_shape = (args.image_channel,image_size[0],image_size[1]) # chw
    mean = None #[127.5,127.5,127.5]
    


    begin_epoch = 0
    if len(args.pretrained)==0:
      arg_params = None
      aux_params = None
      sym = get_symbol(args)  
      if config.net_name=='spherenet':
        data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}
        spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:  # a pretrained model is given: build the student/teacher symbols from it
      sym, sym_high, arg_params, aux_params, t_arg_params, t_aux_params = two_sym(args)
      d_sym = discriminator(args)

    config.count_flops = False  # me add: FLOPs counting disabled
    if config.count_flops:
      all_layers = sym.get_internals()
      _sym = all_layers['fc1_output']  # count FLOPs up to the fc1 embedding output
      FLOPs = flops_counter.count_flops(_sym, data=(1, 3, image_size[0], image_size[1]))
      _str = flops_counter.flops_str(FLOPs)
      print('Network FLOPs: %s' % _str)

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)

    val_dataiter = None

    if config.loss_name.find('triplet')>=0:
      from triplet_image_iter import FaceImageIter
      triplet_params = [config.triplet_bag_size, config.triplet_alpha, config.triplet_max_ap]
      train_dataiter = FaceImageIter(
          batch_size           = args.batch_size,
          data_shape           = data_shape,
          path_imgrec          = path_imgrec,
          shuffle              = True,
          rand_mirror          = config.data_rand_mirror,
        #   rand_resize          = True, #me add to differ resolution img 
          mean                 = mean,
          cutoff               = config.data_cutoff,
          ctx_num              = args.ctx_num,
          images_per_identity  = config.images_per_identity,
          triplet_params       = triplet_params,
          mx_model             = model,  # note: model is only defined later in this function; this branch is unused here
      )
      _metric = LossValueMetric()
      eval_metrics = [mx.metric.create(_metric)]
    else:
      from distribute_image_iter import FaceImageIter

      train_dataiter_low = FaceImageIter(  # target-domain stream; yields (img, label) batches
          batch_size           = args.batch_size,
          data_shape           = data_shape,
          path_imgrec          = path_imgrec,
          shuffle              = True,
          rand_mirror          = config.data_rand_mirror, #true
          rand_resize          = True, #me add to differ resolution img 
          mean                 = mean,
          cutoff               = config.data_cutoff,  #0
          color_jittering      = config.data_color,  #0
          images_filter        = config.data_images_filter, #0
      )
      source_imgrec = os.path.join("/home/svt/mxnet_recognition/dataes/faces_glintasia","train.rec")
      data2 = FaceImageIter(  # source-domain stream; yields (img, label) batches
          batch_size           = args.batch_size,
          data_shape           = data_shape,
          path_imgrec          = source_imgrec,
          shuffle              = True,
          rand_mirror          = config.data_rand_mirror, #true
          rand_resize          = False, #me add to differ resolution img
          mean                 = mean,
          cutoff               = config.data_cutoff,  #0
          color_jittering      = config.data_color,  #0
          images_filter        = config.data_images_filter, #0
      )
      metric1 = AccMetric()  # accuracy metric
      eval_metrics = [mx.metric.create(metric1)]
      if config.ce_loss:  # is True
        metric2 = LossValueMetric()  # loss-value metric
        eval_metrics.append(mx.metric.create(metric2))

    if config.net_name=='fresnet' or config.net_name=='fmobilefacenet':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    else:
      initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
    #initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    _rescale = 1.0/args.ctx_num
    #opt = optimizer.SGD(learning_rate=args.lr, momentum=args.mom, wd=args.wd, rescale_grad=_rescale)
    opt = optimizer.Adam(learning_rate=0.0001, beta1=0.5, beta2=0.9, epsilon=1e-08)
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)

    ver_list = []
    ver_name_list = []
    for name in config.val_targets:
      path = os.path.join(data_dir,name+".bin")
      if os.path.exists(path):
        data_set = verification.load_bin(path, image_size)
        ver_list.append(data_set)
        ver_name_list.append(name)
        print('ver', name)



    def ver_test(nbatch):
      results = []
      for i in xrange(len(ver_list)):
        acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, 10, None, None)
        print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
        #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
        print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
        results.append(acc2)
      return results



    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    high_save = 0  # me add
    print('lr_steps', lr_steps)
    def _batch_callback(param):
      #global global_step
      global_step[0]+=1
      mbatch = global_step[0]
      for step in lr_steps:
        if mbatch==step:
          opt.lr *= 0.1
          print('lr change to', opt.lr)
          break

      _cb(param)
      if mbatch%1000==0:
        print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)
      
      if mbatch % 4000 == 0:  # periodically save a checkpoint that still includes the fc7 layer
          name = os.path.join(args.models_root, '%s-%s-%s' % (args.network, args.loss, args.dataset), 'modelfc7')
          arg, aux = model.get_params()
          mx.model.save_checkpoint(name, param.epoch, model.symbol, arg, aux)
          print('save model include fc7 layer')
          print("mbatch", mbatch)
      
      me_msave = 0  # me add: reset on every callback, so me_msave is always 1 below
      if mbatch >= 0 and mbatch % args.verbose == 0:  # default.verbose = 2000
        acc_list = ver_test(mbatch)
        save_step[0] += 1
        msave = save_step[0]  # at batch size 512, one epoch is roughly 1300 batches
        me_msave = me_msave + 1
        do_save = False
        is_highest = False
        #me add
        save2 = False
        if len(acc_list)>0:
          lfw_score = acc_list[0]
          if lfw_score>highest_acc[0]:
            highest_acc[0] = lfw_score
            if lfw_score>=0.9960:
              save2 = True
              
          score = sum(acc_list)
          if acc_list[-1]>=highest_acc[-1]:
            if acc_list[-1]>highest_acc[-1]:
              is_highest = True
            else:
              if score>=highest_acc[0]:
                is_highest = True
                highest_acc[0] = score
            highest_acc[-1] = acc_list[-1]
            #if lfw_score>=0.99:
            #  do_save = True
        # if is_highest:
          # do_save = True
        if args.ckpt==0:
          do_save = False
        elif args.ckpt==2:
          do_save = True
        elif args.ckpt==3 and is_highest:  # me add: and is_highest
          high_save = 0   # always save the best-LFW model under the same index, replacing the previous one

        if do_save:  # save the best parameters
          print('saving high pretrained-epoch always:  ', high_save)
          arg, aux = model.get_params()
          if config.ckpt_embedding:  #true
            all_layers = model.symbol.get_internals()
            _sym = all_layers['fc1_output']
            _arg = {}
            for k in arg:
              if not k.startswith('fc7'):  # skip fc7 keys (do not save the final classification layer)
                _arg[k] = arg[k]
            mx.model.save_checkpoint(prefix, high_save, _sym, _arg, aux)  # checkpoint keeps only the layers up to fc1 (the embedding output)
          else:
            mx.model.save_checkpoint(prefix, high_save, model.symbol, arg, aux)
          print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
          
        if save2:
          arg, aux = model.get_params()
          if config.ckpt_embedding:  #true
            all_layers = model.symbol.get_internals()
            _sym = all_layers['fc1_output']
            _arg = {}
            for k in arg:
              if not k.startswith('fc7'):  # skip fc7 keys (do not save the final classification layer)
                _arg[k] = arg[k]
            mx.model.save_checkpoint(prefix, (me_msave), _sym, _arg, aux)  # checkpoint keeps only the layers up to fc1
          else:
            mx.model.save_checkpoint(prefix, (me_msave), model.symbol, arg, aux)
          print("save pretrained-epoch :param.epoch + me_msave",param.epoch,me_msave)
          print('[%d]LFW Accuracy>=0.9960: %1.5f'%(mbatch, highest_acc[-1])) #mbatch  �Ǵ�0 ��13000 һ��epoch ,Ȼ���ٴ�0����
    
      if config.max_steps > 0 and mbatch > config.max_steps:
        sys.exit(0)

    ###########################################################################

    epoch_cb = None
    train_dataiter_low = mx.io.PrefetchingIter(train_dataiter_low)  # multi-threaded prefetching iterator
    data2 = mx.io.PrefetchingIter(data2)  # multi-threaded prefetching iterator

    # build the model: bind the data/label shapes (allocating device memory),
    # initialize the network params, then train
    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(step=[100, 200, 300], factor=0.1)
    optimizer_params = {'learning_rate': 0.01,
                        'momentum': 0.9,
                        'wd': 0.0005,
                        # 'lr_scheduler': lr_scheduler,
                        "rescale_grad": _rescale}  # average the gradients across devices
    ######################################################################
    ## teacher network
    data_shapes = [('data', (args.batch_size, 3, 112, 112))]  # teacher model only needs data, no label
    t_module = mx.module.Module(symbol=sym_high, context=ctx, label_names=[])
    t_module.bind(data_shapes=data_shapes, for_training=False, grad_req='null')
    t_module.set_params(arg_params=t_arg_params, aux_params=t_aux_params)
    t_model = t_module
    ######################################################################
    ## student network
    label_shapes = [('softmax_label', (args.batch_size,))]
    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
        label_names=[],
        # data_names defaults to ('data',); custom label names would have to be passed in here
    )
    # the student network trains on the data; the teacher network only consumes data
    # (no labels, never updated) and its output features serve as the "real" samples
    # print (train_dataiter_low.provide_data)
    # print ((train_dataiter_low.provide_label))
    # opt_d = optimizer.SGD(learning_rate=args.lr*0.01, momentum=args.mom, wd=args.wd, rescale_grad=_rescale)  # lr 1e-5
    opt_d = optimizer.Adam(learning_rate=0.0001, beta1=0.5, beta2=0.9, epsilon=1e-08)
    model.bind(data_shapes=data_shapes, for_training=True)  # no label_shapes: the label input is dropped
    model.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
                      allow_missing=True)  # allow_missing=True: params absent from arg_params are filled in by the initializer
    # model.init_optimizer(kvstore=args.kvstore, optimizer='sgd', optimizer_params=(optimizer_params))
    model.init_optimizer(kvstore=args.kvstore, optimizer=opt_d)
    # metric = eval_metrics  # list of metrics
    ##########################################################################
    ## discriminator network
    # binary classification module
    model_d = mx.module.Module(symbol=d_sym, context=ctx, data_names=['data'], label_names=['softmax_label'])
    data_shapes = [('data', (args.batch_size * 2, 512))]
    label_shapes = [('softmax_label', (args.batch_size * 2,))]  # the batch size is fixed once bound; binding could also be deferred to use time
    model_d.bind(data_shapes=data_shapes, label_shapes=label_shapes, inputs_need_grad=True)
    model_d.init_params(initializer=initializer)
    model_d.init_optimizer(kvstore=args.kvstore, optimizer=opt)  # optimizer may need tuning; lr 1e-3
    ## the discriminator classifies the concatenated output features
    metric_d = AccMetric_d()  # accuracy metric; AccMetric_d added in metric.py, reads the softmax output
    eval_metrics_d = [mx.metric.create(metric_d)]
    metric2_d = LossValueMetric_d()  # loss-value metric; LossValueMetric_d added in metric.py, reads the cross-entropy
    eval_metrics_d.append(mx.metric.create(metric2_d))
    metric_d = eval_metrics_d  # the discriminator symbol has a single softmax output

    global_step=[0]
    batch_num=[0]
    resize_acc=[0]
    for epoch in range(0, 40):
        # if epoch==1 or epoch==2 or epoch==3:
        #     model.init_optimizer(kvstore=args.kvstore,optimizer='sgd', optimizer_params=(optimizer_params))
        if not isinstance(metric_d, mx.metric.EvalMetric):  # wrap the metric list in a composite EvalMetric
            metric_d = mx.metric.create(metric_d)
        # metric_d = mx.metric.create(metric_d)
        metric_d.reset()
        train_dataiter_low.reset()
        data2.reset()
        print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")

        data_iter = iter(train_dataiter_low)
        data2_iter = iter(data2)
        data_len=0
        for batch in data_iter:  # batch comes from train_dataiter_low (target-domain stream)
            # 1) teacher forward with is_train=False, student forward with is_train=True;
            #    concatenate the two outputs and label them 1 (real) / 0 (fake)
            # the teacher's features serve as the "real" data
            data_len += len(batch.data[0])

            if len(batch.data[0]) < args.batch_size:  # batch.data[0] is the whole batch
                print("last batch smaller than batch_size, dropping it")
                print("data_len:", data_len)
                break
            if data_len >= 2830147:  # 2830147 = size of the target-domain dataset
                print("one epoch finished")
                break

            batch2 = data2_iter.next()
            t_model.forward(batch2, is_train=False)  # teacher consumes the source-domain batch
            t_feat = t_model.get_outputs()  # list of outputs; here only the fc1 embedding

            # note: after detach() no gradient flows into the inputs; input gradients
            # only exist when requested at bind time (inputs_need_grad=True)
            # batch.data[0] is the batched image array [b, c, h, w]; batch.label the label array
            ## the student network generates the adversarial ("fake") features
            model.forward(batch, is_train=True)  # fc1 output
            g_feat = model.get_outputs()  # the student's fc1 embedding
            label_t = nd.ones((args.batch_size,))  # 1 = real
            label_g = nd.zeros((args.batch_size,))  # 0 = fake
            # concatenate the labels and the features
            label_concat = nd.concat(label_t, label_g, dim=0)
            feat_concat = nd.concat(t_feat[0], g_feat[0], dim=0)  # nd.L2Normalization is not needed here

            # 2.1) train the discriminator on the concatenated features; detach()
            #      cuts the graph so no gradient flows back into the student here
            feat_data = mx.io.DataBatch([feat_concat.detach()], [label_concat])
            model_d.forward(feat_data, is_train=True)  # discriminator loss
            model_d.backward()
            # snapshot of the discriminator gradients
            gradD = [[grad.copyto(grad.context) for grad in grads] for grads in model_d._exec_group.grad_arrays]
            model_d.update()  # gradient update
            model_d.update_metric(metric_d, [label_concat])
            
            
            # 2.2) feed only the student features, with labels flipped to 1, so the
            #      discriminator's input gradients tell the student how to look "real"
            label_g = nd.ones((args.batch_size,))  # labels set to 1

            feat_data = mx.io.DataBatch([g_feat[0]], [label_g])  # input gradients available here
            model_d.forward(feat_data, is_train=True)  # compute the input gradients
            model_d.backward()  # no accumulation: a second forward/backward overwrites the previous result

            # 3) propagate the discriminator's input gradients back into the student network
            g_grad = model_d.get_input_grads()
            model.backward(g_grad)
            model.update()

            # training flow: teacher and student features are concatenated (labels 1 and 0)
            # to update the discriminator; then the student features alone produce the loss
            # whose input gradients are passed back to update the student

            # gan_label = [nd.empty((args.batch_size*2, 2))]    # concatenated outputs of the two models with 0/1 labels
            # discrim_data = [nd.empty((args.batch_size*2, 512))]  # (batch*2, 512)
            # print (gan_label[0].shape)



            lr_steps = [int(x) for x in args.lr_steps.split(',')]
            global_step[0]+=1
            batch_num[0]+=1
            mbatch = global_step[0]
            for step in lr_steps:
                if mbatch==step:
                    opt.lr *= 0.1
                    opt_d.lr*=0.1
                    print('opt.lr ,opt_d.lr lr change to', opt.lr,opt_d.lr)
                    break
            
            if mbatch % 200 == 0 and mbatch > 0:
                print('epoch %d, Training %s' % (epoch, metric_d.get()))

            if mbatch %1000==0 and mbatch >0: 
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, epoch, model.symbol, arg, aux)
                
                arg, aux = model_d.get_params()
                mx.model.save_checkpoint(prefix+"discriminator", epoch, model_d.symbol, arg, aux)
                
                top1, top10 = my_top(epoch)
                yidong_test_top1, yidong_test_top10 = my_top_yidong_test(epoch)
                if top1 >= resize_acc[0]:
                    resize_acc[0] = top1
                    # save the best parameters
                    arg, aux = model.get_params()
                    all_layers = model.symbol.get_internals()
                    _sym = all_layers['fc1_output']
                    _arg = {}
                    for k in arg:
                      if not k.startswith('fc7'):  # skip fc7 keys (do not save the final classification layer)
                        _arg[k] = arg[k]
                    mx.model.save_checkpoint(prefix+"_best", 1, _sym, _arg, aux)  
                    acc_list = ver_test(mbatch)
                    if len(acc_list)>0:
                        print ("LFW acc is :",acc_list[0])
 
                print("batch_num",batch_num[0],"epoch",epoch, "lr ",opt.lr)
                print('mbath %d, Training %s' % (epoch, metric_d.get()))
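
The heart of Example #2 is the adversarial feature-alignment loop above. Condensed to its essentials, one iteration performs a three-step update; the sketch below (function name hypothetical) assumes the same Module setup as above: model_d bound with inputs_need_grad=True and t_feat taken from the frozen teacher.

import mxnet as mx
from mxnet import ndarray as nd

def adversarial_step(model, model_d, batch, t_feat, batch_size):
    # 1) student (generator) forward on the target-domain batch
    model.forward(batch, is_train=True)
    g_feat = model.get_outputs()[0]

    # 2) discriminator step on real (teacher) vs fake (student) features;
    #    detach() keeps this update from flowing into the student
    labels = nd.concat(nd.ones((batch_size,)), nd.zeros((batch_size,)), dim=0)
    feats = nd.concat(t_feat, g_feat, dim=0)
    model_d.forward(mx.io.DataBatch([feats.detach()], [labels]), is_train=True)
    model_d.backward()
    model_d.update()

    # 3) student step: present its features as "real" (labels = 1) and push
    #    the discriminator's input gradients back through the student
    model_d.forward(mx.io.DataBatch([g_feat], [nd.ones((batch_size,))]),
                    is_train=True)
    model_d.backward()
    model.backward(model_d.get_input_grads())
    model.update()
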
Example #3
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in xrange(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.image_channel = 3

    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list) == 1
    data_dir = data_dir_list[0]
    path_imgrec = None
    path_imglist = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)

    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)

    #path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2"
    path_imgrec = os.path.join(data_dir, "train.rec")

    assert args.images_per_identity >= 2
    assert args.triplet_bag_size % args.batch_size == 0

    print('Called with argument:', args)

    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
        if args.network[0] == 's':
            data_shape_dict = {'data': (args.per_batch_size, ) + data_shape}
            spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:
        vec = args.pretrained.split(',')
        print('loading', vec)
        sym, arg_params, aux_params = mx.model.load_checkpoint(
            vec[0], int(vec[1]))
        all_layers = sym.get_internals()
        sym = all_layers['fc1_output']
        sym, arg_params, aux_params = get_symbol(args,
                                                 arg_params,
                                                 aux_params,
                                                 sym_embedding=sym)

    data_extra = None
    hard_mining = False
    triplet_params = [
        args.triplet_bag_size, args.triplet_alpha, args.triplet_max_ap
    ]
    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
        #data_names = ('data',),
        #label_names = None,
        #label_names = ('softmax_label',),
    )
    label_shape = (args.batch_size, )

    val_dataiter = None

    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrec=path_imgrec,
        shuffle=True,
        rand_mirror=args.rand_mirror,
        mean=mean,
        cutoff=args.cutoff,
        ctx_num=args.ctx_num,
        images_per_identity=args.images_per_identity,
        triplet_params=triplet_params,
        mx_model=model,
    )

    _metric = LossValueMetric()
    eval_metrics = [mx.metric.create(_metric)]

    if args.network[0] == 'r':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  #inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    _rescale = 1.0 / args.ctx_num
    if args.noise_sgd > 0.0:
        print('use noise sgd')
        opt = NoiseSGD(scale=args.noise_sgd,
                       learning_rate=base_lr,
                       momentum=base_mom,
                       wd=base_wd,
                       rescale_grad=_rescale)
    else:
        opt = optimizer.SGD(learning_rate=base_lr,
                            momentum=base_mom,
                            wd=base_wd,
                            rescale_grad=_rescale)
    som = 2
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in xrange(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model, args.batch_size, 10, None, label_shape)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [1000000000]
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        #global global_step
        global_step[0] += 1
        mbatch = global_step[0]
        for _lr in lr_steps:
            if mbatch == _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            if len(acc_list) > 0:
                lfw_score = acc_list[0]
                if lfw_score > highest_acc[0]:
                    highest_acc[0] = lfw_score
                    if lfw_score >= 0.998:
                        do_save = True
                if acc_list[-1] >= highest_acc[-1]:
                    highest_acc[-1] = acc_list[-1]
                    if lfw_score >= 0.99:
                        do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt > 1:
                do_save = True
            #for i in xrange(len(acc_list)):
            #  acc = acc_list[i]
            #  if acc>=highest_acc[i]:
            #    highest_acc[i] = acc
            #    if lfw_score>=0.99:
            #      do_save = True
            #if args.loss_type==1 and mbatch>lr_steps[-1] and mbatch%10000==0:
            #  do_save = True
            if do_save:
                print('saving', msave)
                if val_dataiter is not None:
                    val_test()
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    #epoch_cb = mx.callback.do_checkpoint(prefix, 1)
    epoch_cb = None

    model.fit(
        train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        eval_data=val_dataiter,
        eval_metric=eval_metrics,
        kvstore='device',
        optimizer=opt,
        #optimizer_params   = optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=epoch_cb)
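
Example #3 optionally swaps in NoiseSGD, a project-specific optimizer whose implementation is not shown here. A plausible minimal sketch, assuming it simply perturbs each gradient with Gaussian noise before a standard SGD step (the repository's actual version may differ):

import mxnet as mx
from mxnet import ndarray as nd

class NoiseSGD(mx.optimizer.SGD):
    """SGD with Gaussian noise added to the gradients (illustrative sketch)."""
    def __init__(self, scale=0.01, **kwargs):
        super(NoiseSGD, self).__init__(**kwargs)
        self.scale = scale

    def update(self, index, weight, grad, state):
        # Perturb the gradient, then delegate to the regular SGD update.
        noise = nd.random.normal(scale=self.scale, shape=grad.shape,
                                 ctx=grad.context, dtype=grad.dtype)
        super(NoiseSGD, self).update(index, weight, grad + noise, state)
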
Example #4
def train_net(args):
    ctx = []
    #ctx.append(mx.gpu(1)) #manual
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        logging.info('use cpu')
    else:
        logging.info('gpu num: %d', len(ctx))
    prefix = os.path.join(args.models_root,
                          '%s-%s-%s' % (args.network, args.loss, args.dataset),
                          'model')
    prefix_dir = os.path.dirname(prefix)
    logging.info('prefix %s', prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    args.ctx_num = len(ctx)
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = config.image_shape[2]
    config.batch_size = args.batch_size
    config.per_batch_size = args.per_batch_size

    data_dir = config.dataset_path
    path_imgrec = None
    path_imglist = None
    image_size = config.image_shape[0:2]
    assert len(image_size) == 2
    assert image_size[0] == image_size[1]
    logging.info('image_size %s', str(image_size))
    logging.info('num_classes %d', config.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    logging.info('Called with argument: %s %s', str(args), str(config))
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = args.pretrained_epoch
    if len(args.pretrained) == 0:  #no pretraining
        arg_params = None
        aux_params = None
        sym = get_symbol(args)
        if config.net_name == 'spherenet':
            data_shape_dict = {'data': (args.per_batch_size, ) + data_shape}
            spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:  #load pretrained model
        logging.info('loading %s %s', str(args.pretrained),
                     str(args.pretrained_epoch))
        _, arg_params, aux_params = mx.model.load_checkpoint(
            args.pretrained, args.pretrained_epoch)
        sym = get_symbol(args)

    if config.count_flops:
        all_layers = sym.get_internals()
        _sym = all_layers['fc1_output']
        FLOPs = flops_counter.count_flops(_sym,
                                          data=(1, 3, image_size[0],
                                                image_size[1]))
        _str = flops_counter.flops_str(FLOPs)
        logging.info('Network FLOPs: %s' % _str)

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)
    model = mx.mod.Module(  #executable options and full model is loaded, loss functions and all
        context=ctx,
        symbol=sym,
        #fixed_param_names = ["conv_1_conv2d_weight","res_2_block0_conv_sep_conv2d_weight","res_2_block0_conv_dw_conv2d_weight","res_2_block0_conv_proj_conv2d_weight","res_2_block1_conv_sep_conv2d_weight","res_2_block1_conv_dw_conv2d_weight","res_2_block1_conv_proj_conv2d_weight","dconv_23_conv_sep_conv2d_weight","dconv_23_conv_dw_conv2d_weight","dconv_23_conv_proj_conv2d_weight","res_3_block0_conv_sep_conv2d_weight","res_3_block0_conv_dw_conv2d_weight","res_3_block0_conv_proj_conv2d_weight","res_3_block1_conv_sep_conv2d_weight","res_3_block1_conv_dw_conv2d_weight","res_3_block1_conv_proj_conv2d_weight","res_3_block2_conv_sep_conv2d_weight","res_3_block2_conv_dw_conv2d_weight","res_3_block2_conv_proj_conv2d_weight","res_3_block3_conv_sep_conv2d_weight","res_3_block3_conv_dw_conv2d_weight","res_3_block3_conv_proj_conv2d_weight","res_3_block4_conv_sep_conv2d_weight","res_3_block4_conv_dw_conv2d_weight","res_3_block4_conv_proj_conv2d_weight","res_3_block5_conv_sep_conv2d_weight","res_3_block5_conv_dw_conv2d_weight","res_3_block5_conv_proj_conv2d_weight","res_3_block6_conv_sep_conv2d_weight","res_3_block6_conv_dw_conv2d_weight","res_3_block6_conv_proj_conv2d_weight","res_3_block7_conv_sep_conv2d_weight","res_3_block7_conv_dw_conv2d_weight","res_3_block7_conv_proj_conv2d_weight","dconv_34_conv_sep_conv2d_weight","dconv_34_conv_dw_conv2d_weight","dconv_34_conv_proj_conv2d_weight","res_4_block0_conv_sep_conv2d_weight","res_4_block0_conv_dw_conv2d_weight","res_4_block0_conv_proj_conv2d_weight","res_4_block1_conv_sep_conv2d_weight","res_4_block1_conv_dw_conv2d_weight","res_4_block1_conv_proj_conv2d_weight","res_4_block2_conv_sep_conv2d_weight","res_4_block2_conv_dw_conv2d_weight","res_4_block2_conv_proj_conv2d_weight","res_4_block3_conv_sep_conv2d_weight","res_4_block3_conv_dw_conv2d_weight","res_4_block3_conv_proj_conv2d_weight","res_4_block4_conv_sep_conv2d_weight","res_4_block4_conv_dw_conv2d_weight","res_4_block4_conv_proj_conv2d_weight","res_4_block5_conv_sep_conv2d_weight","res_4_block5_conv_dw_conv2d_weight","res_4_block5_conv_proj_conv2d_weight","res_4_block6_conv_sep_conv2d_weight","res_4_block6_conv_dw_conv2d_weight","res_4_block6_conv_proj_conv2d_weight","res_4_block7_conv_sep_conv2d_weight","res_4_block7_conv_dw_conv2d_weight","res_4_block7_conv_proj_conv2d_weight","res_4_block8_conv_sep_conv2d_weight","res_4_block8_conv_dw_conv2d_weight","res_4_block8_conv_proj_conv2d_weight","res_4_block9_conv_sep_conv2d_weight","res_4_block9_conv_dw_conv2d_weight","res_4_block9_conv_proj_conv2d_weight","res_4_block10_conv_sep_conv2d_weight","res_4_block10_conv_dw_conv2d_weight","res_4_block10_conv_proj_conv2d_weight","res_4_block11_conv_sep_conv2d_weight","res_4_block11_conv_dw_conv2d_weight","res_4_block11_conv_proj_conv2d_weight","res_4_block12_conv_sep_conv2d_weight","res_4_block12_conv_dw_conv2d_weight","res_4_block12_conv_proj_conv2d_weight","res_4_block13_conv_sep_conv2d_weight","res_4_block13_conv_dw_conv2d_weight","res_4_block13_conv_proj_conv2d_weight","res_4_block14_conv_sep_conv2d_weight","res_4_block14_conv_dw_conv2d_weight","res_4_block14_conv_proj_conv2d_weight","res_4_block15_conv_sep_conv2d_weight","res_4_block15_conv_dw_conv2d_weight","res_4_block15_conv_proj_conv2d_weight","dconv_45_conv_sep_conv2d_weight","dconv_45_conv_dw_conv2d_weight","dconv_45_conv_proj_conv2d_weight"],
        #fixed_param_names = ['convolution'+str(i)+'_weight' for i in range(1,40)],
    )
    val_dataiter = None

    if config.loss_name.find('triplet') >= 0:  #if triplet or atriplet loss
        from triplet_image_iter import FaceImageIter
        triplet_params = [
            config.triplet_bag_size, config.triplet_alpha,
            config.triplet_max_ap
        ]
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=config.data_rand_mirror,
            mean=mean,
            cutoff=config.data_cutoff,
            ctx_num=args.ctx_num,
            images_per_identity=config.images_per_identity,
            triplet_params=triplet_params,
            mx_model=model,
        )
        _metric = LossValueMetric()
        eval_metrics = [mx.metric.create(_metric)]
    else:
        from image_iter import FaceImageIter
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=config.data_rand_mirror,
            mean=mean,
            cutoff=config.data_cutoff,
            color_jittering=config.data_color,
            images_filter=config.data_images_filter,
        )
        metric1 = AccMetric()
        eval_metrics = [mx.metric.create(metric1)]
        if config.ce_loss:
            metric2 = LossValueMetric()
            eval_metrics.append(mx.metric.create(metric2))
    gaussian_nets = [
        'fresnet', 'fmobilefacenet', 'fsqueezefacenet_v1',
        'fsqueezefacenet_v2', 'fshufflefacenetv2', 'fshufflenetv1',
        'fsqueezenet1_0', 'fsqueezenet1_1', 'fsqueezenet1_2',
        'fsqueezenet1_1_no_pool', 'fmobilenetv2', 'fmobilefacenetv1',
        'fmobilenetv2_mxnet', 'vargfacenet', 'mobilenetv3'
    ]
    if config.net_name in gaussian_nets:

        #    if config.net_name=='fresnet' or config.net_name=='fmobilefacenet' or config.net_name=='fsqueezefacenet_v1' or config.net_name=='fsqueezefacenet_v2' or config.net_name=='fshufflefacenetv2' or config.net_name =='fefficientnet' or config.net_name == 'fshufflenetv1' or config.net_name == 'fsqueezenet1_0' or config.net_name == 'fsqueezenet1_1' or config.net_name =='fsqueezenet1_2' or config.net_name == 'fsqueezenet1_1_no_pool' or config.net_name == 'fmobilenetv2':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
        print("GAUSSIAN INITIALIZER")
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
        print("UNIFORM INITIALIZER")
    #initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=args.lr,
                        momentum=args.mom,
                        wd=args.wd,
                        rescale_grad=_rescale)
    #opt = optimizer.Adam(learning_rate=args.lr,wd=args.wd,rescale_grad=_rescale)
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)

    ver_list = []
    ver_name_list = []
    for name in config.val_targets:
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            logging.info('ver %s', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model, args.batch_size, 10, None, None)
            logging.info('[%s][%d]XNorm: %f' %
                         (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                         (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0]  #lfw and target
    highest_train_acc = [0.0, 0.0]
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    logging.info('lr_steps %s', str(lr_steps))

    def _batch_callback(param):
        #global global_step
        #weights_db = model.get_params()[0]['fire2_act_squeeze_1x1'].asnumpy()

        #weights_db = model.get_params()[0]
        #print(str(weights_db.keys()))

        #for k, v in weights_db.items():
        #print(k)
        #     if(np.any(np.isnan(v.asnumpy())) or np.any(np.isinf(v.asnumpy()))):
        #         print("nan or inf weight found at "+k)

        #name_value = param.eval_metric.get_name_value()
        #for name, value in name_value:
        #    logging.info('Epoch[%d] Validation-%s=%f', param.epoch, name, value)
        loss = param.eval_metric.get_name_value()[1]
        train_acc = param.eval_metric.get_name_value()[0]
        #print(loss)
        #if (np.isnan(loss[1])):
        #    print("Nan loss found")
        #    f = open("nan_loss_weights.txt", "w")
        #    f.write("batch #"+str(global_step[0])+"\n"+str(loss)+"\n"+str(weights_db.keys())+"\n"+str(weights_db))
        #    f.close()
        #    print("Written file at: nan_loss_weights.txt")
        #    exit()

        global_step[0] += 1
        mbatch = global_step[0]
        for step in lr_steps:
            if mbatch == step:
                opt.lr *= 0.1
                logging.info('lr change to %f', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            logging.info('lr-batch-epoch: %f %d %d', opt.lr, param.nbatch,
                         param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            is_highest = False
            if len(acc_list) > 0:
                #lfw_score = acc_list[0]
                #if lfw_score>highest_acc[0]:
                #  highest_acc[0] = lfw_score
                #  if lfw_score>=0.998:
                #    do_save = True
                score = sum(acc_list)
                if acc_list[-1] >= highest_acc[-1]:
                    if acc_list[-1] > highest_acc[-1]:
                        is_highest = True
                    else:
                        if score >= highest_acc[0]:
                            is_highest = True
                            highest_acc[0] = score
                    highest_acc[-1] = acc_list[-1]
                    #if lfw_score>=0.99:
                    #  do_save = True
            if is_highest:
                do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt == 2:
                do_save = True
            elif args.ckpt == 3:
                msave = 1

            if do_save:
                logging.info('saving %d', msave)
                arg, aux = model.get_params()
                if config.ckpt_embedding:
                    all_layers = model.symbol.get_internals()
                    _sym = all_layers['fc1_output']
                    _arg = {}
                    for k in arg:
                        if not k.startswith('fc7'):
                            _arg[k] = arg[k]
                    mx.model.save_checkpoint(prefix, param.epoch + 1, _sym,
                                             _arg, aux)
                else:
                    mx.model.save_checkpoint(prefix, param.epoch + 1,
                                             model.symbol, arg, aux)
            logging.info('[%d]Accuracy-Highest: %1.5f' %
                         (mbatch, highest_acc[-1]))
        if config.max_steps > 0 and mbatch > config.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    model.fit(
        train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=999999,
        eval_data=val_dataiter,
        eval_metric=eval_metrics,
        kvstore=args.kvstore,
        optimizer=opt,
        #optimizer_params   = optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=epoch_cb)
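
Every example expects an args namespace filled in by a command-line parser, which sits outside these snippets. A hypothetical driver covering the fields Example #4 reads; the flag names and defaults are illustrative, not the original repository's CLI:

import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='train face recognition network')
    parser.add_argument('--models-root', dest='models_root', default='./models')
    parser.add_argument('--network', default='r100')
    parser.add_argument('--loss', default='arcface')
    parser.add_argument('--dataset', default='emore')
    parser.add_argument('--per-batch-size', dest='per_batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--lr-steps', dest='lr_steps', default='100000,160000,220000')
    parser.add_argument('--mom', type=float, default=0.9)
    parser.add_argument('--wd', type=float, default=0.0005)
    parser.add_argument('--verbose', type=int, default=2000)
    parser.add_argument('--ckpt', type=int, default=1)
    parser.add_argument('--kvstore', default='device')
    parser.add_argument('--frequent', type=int, default=20)
    parser.add_argument('--pretrained', default='')
    parser.add_argument('--pretrained-epoch', dest='pretrained_epoch', type=int, default=1)
    return parser.parse_args()

if __name__ == '__main__':
    train_net(parse_args())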