Example #1
def train_net(args):
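    # Build the device context list from CUDA_VISIBLE_DEVICES;
    # fall back to a single CPU context when no GPU is listed.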
    ctx = []
    cvd = os.environ.get('CUDA_VISIBLE_DEVICES', '').strip()  # tolerate an unset variable
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
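    # Checkpoint prefix: <models_root>/<network>-<loss>-<dataset>/model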
    prefix = os.path.join(args.models_root,
                          '%s-%s-%s' % (args.network, args.loss, args.dataset),
                          'model')
    prefix_dir = os.path.dirname(prefix)
    print('prefix', prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    args.ctx_num = len(ctx)
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = config.image_shape[2]

    data_dir = config.dataset_path
    path_imgrec = None
    path_imglist = None
    image_size = config.image_shape[0:2]
    assert len(image_size) == 2
    assert image_size[0] == image_size[1]
    print('image_size', image_size)
    print('num_classes', config.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    print('Called with argument:', args, config)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

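    # Build the symbol from scratch or load a pretrained checkpoint;
    # spherenet additionally needs its custom weight initialization.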
    begin_epoch = 0
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym = get_symbol(args)
        if config.net_name == 'spherenet':
            data_shape_dict = {'data': (args.per_batch_size, ) + data_shape}
            spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:
        print('loading', args.pretrained, args.pretrained_epoch)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            args.pretrained, args.pretrained_epoch)
        sym = get_symbol(args)

    if config.count_flops:
        all_layers = sym.get_internals()
        _sym = all_layers['fc1_output']
        FLOPs = flops_counter.count_flops(_sym,
                                          data=(1, 3, image_size[0],
                                                image_size[1]))
        print('Network FLOPs: %d' % FLOPs)

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)
    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
    )
    val_dataiter = None

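    # Triplet losses need an iterator that mines triplets with the current
    # model; other losses use the plain image iterator with accuracy/CE metrics.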
    if config.loss_name.find('triplet') >= 0:
        from triplet_image_iter import FaceImageIter
        triplet_params = [
            config.triplet_bag_size, config.triplet_alpha,
            config.triplet_max_ap
        ]
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=config.data_rand_mirror,
            mean=mean,
            cutoff=config.data_cutoff,
            ctx_num=args.ctx_num,
            images_per_identity=config.images_per_identity,
            triplet_params=triplet_params,
            mx_model=model,
        )
        _metric = LossValueMetric()
        eval_metrics = [mx.metric.create(_metric)]
    else:
        from image_iter import FaceImageIter
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=config.data_rand_mirror,
            mean=mean,
            cutoff=config.data_cutoff,
            color_jittering=config.data_color,
            images_filter=config.data_images_filter,
        )
        metric1 = AccMetric()
        eval_metrics = [mx.metric.create(metric1)]
        if config.ce_loss:
            metric2 = LossValueMetric()
            eval_metrics.append(mx.metric.create(metric2))

    if config.net_name == 'fresnet' or config.net_name == 'fmobilefacenet':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    #initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=args.lr,
                        momentum=args.mom,
                        wd=args.wd,
                        rescale_grad=_rescale)
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)

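    # Load each verification set (<name>.bin) listed in config.val_targets.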
    ver_list = []
    ver_name_list = []
    for name in config.val_targets:
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model, args.batch_size, 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in range(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

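    # Batch-end callback: step LR decay, periodic logging, verification every
    # args.verbose batches, and checkpointing of the best model.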
    def _batch_callback(param):
        #global global_step
        global_step[0] += 1
        mbatch = global_step[0]
        for step in lr_steps:
            if mbatch == step:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            is_highest = False
            if len(acc_list) > 0:
                #lfw_score = acc_list[0]
                #if lfw_score>highest_acc[0]:
                #  highest_acc[0] = lfw_score
                #  if lfw_score>=0.998:
                #    do_save = True
                score = sum(acc_list)
                if acc_list[-1] >= highest_acc[-1]:
                    if acc_list[-1] > highest_acc[-1]:
                        is_highest = True
                    else:
                        if score >= highest_acc[0]:
                            is_highest = True
                            highest_acc[0] = score
                    highest_acc[-1] = acc_list[-1]
                    #if lfw_score>=0.99:
                    #  do_save = True
            if is_highest:
                do_save = True
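            # args.ckpt: 0 = never save, 2 = always save, 3 = reuse
            # checkpoint slot 1; otherwise save only on a new best accuracy.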
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt == 2:
                do_save = True
            elif args.ckpt == 3:
                msave = 1

            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()
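                # With ckpt_embedding, drop the fc7 classification weights and
                # save only the embedding network up to fc1_output.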
                if config.ckpt_embedding:
                    all_layers = model.symbol.get_internals()
                    _sym = all_layers['fc1_output']
                    _arg = {}
                    for k in arg:
                        if not k.startswith('fc7'):
                            _arg[k] = arg[k]
                    mx.model.save_checkpoint(prefix, msave, _sym, _arg, aux)
                else:
                    mx.model.save_checkpoint(prefix, msave, model.symbol, arg,
                                             aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if config.max_steps > 0 and mbatch > config.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    model.fit(
        train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=999999,
        eval_data=val_dataiter,
        eval_metric=eval_metrics,
        kvstore=args.kvstore,
        optimizer=opt,
        #optimizer_params   = optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=epoch_cb)
Example #2
def train_net(args):
    ## =================== parse context ==========================
    ctx = []
    cvd = os.environ.get('CUDA_VISIBLE_DEVICES', '').strip()  # tolerate an unset variable
    if len(cvd)>0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx)==0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))


    ## ==================== get model save prefix and log ============
    if len(args.extra_model_name)==0:
        prefix = os.path.join(args.models_root, '%s-%s-%s'%(args.network, args.loss, args.dataset), 'model')
    else:
        prefix = os.path.join(args.models_root, '%s-%s-%s-%s'%(args.network, args.loss, args.dataset, args.extra_model_name), 'model')

    prefix_dir = os.path.dirname(prefix)
    print('prefix', prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    filehandler = logging.FileHandler("{}.log".format(prefix))
    streamhandler = logging.StreamHandler()
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    ## ================ parse batch size and class info ======================
    args.ctx_num = len(ctx)
    if args.per_batch_size==0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size*args.ctx_num
    
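    # Model-parallel softmax: split num_classes evenly across all devices of
    # all workers, rounding up when the division is not exact.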
    global_num_ctx = config.num_workers * args.ctx_num
    if config.num_classes % global_num_ctx == 0:
        args.ctx_num_classes = config.num_classes//global_num_ctx
    else:
        args.ctx_num_classes = config.num_classes//global_num_ctx+1

    args.local_num_classes = args.ctx_num_classes * args.ctx_num
    args.local_class_start = args.local_num_classes * args.worker_id

    logger.info("Train model with argument: {}\nconfig : {}".format(args, config))

    train_dataiter, val_dataiter = get_data_iter(config, args.batch_size)

    ## =============== get train info ============================
    image_size = config.image_shape[0:2]
    if len(args.pretrained) == 0: # train from scratch 
        esym = get_symbol_embedding(config)
        asym = functools.partial(get_symbol_arcface, config=config)
    else:  # load a pretrained model to continue training
        assert False, 'resuming from a pretrained model is not implemented in this example'

    if config.count_flops:
        all_layers = esym.get_internals()
        _sym = all_layers['fc1_output']
        FLOPs = flops_counter.count_flops(_sym, data=(1,3,image_size[0],image_size[1]))
        _str = flops_counter.flops_str(FLOPs)
        print('Network FLOPs: %s'%_str)
        logging.info("Network FLOPs : %s" % _str)

    if config.num_workers==1:
        #from parall_loss_module import ParallLossModule
        from parall_module_local_v1 import ParallModule
    else:  # distributed parallel training
        assert False, 'multi-worker distributed training is not implemented in this example'


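    # ParallModule pairs the embedding symbol (esym) with the margin-head
    # factory (asym) so each device holds its ctx_num_classes classifier slice.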
    model = ParallModule(
        context       = ctx,
        symbol        = esym,
        data_names    = ['data'],
        label_names   = ['softmax_label'],
        asymbol       = asym,
        args = args,
        logger=logger,
    )
    

    ## ============ get optimizer =====================================
    if config.net_name=='fresnet' or config.net_name=='fmobilefacenet':
        initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    else:
        initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)

    opt = optimizer.SGD(learning_rate=args.lr, momentum=args.mom, wd=args.wd, rescale_grad=1.0/args.batch_size)
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)

    ver_list = []
    ver_name_list = []
    for name in config.val_targets:
        path = os.path.join(config.dataset_path, name+".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0]  #lfw and target
    

    global_step = [0]
    save_step = [0]
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    ## =============== batch end callback definition ===================================
    def _batch_callback(param):
        #global global_step
        global_step[0]+=1
        mbatch = global_step[0]
        for step in lr_steps:
            if mbatch==step:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                logger.info('lr change to {}'.format(opt.lr))
                break

        _cb(param)
        if mbatch%1000==0:
            print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)
            logger.info('lr-batch-epoch: {} {} {}'.format(opt.lr, param.nbatch, param.epoch))

        if mbatch>=0 and mbatch%args.verbose==0:
            acc_list = ver_test(mbatch)
            save_step[0]+=1
            msave = save_step[0]
            do_save = False
            is_highest = False
            if len(acc_list)>0:
                #lfw_score = acc_list[0]
                #if lfw_score>highest_acc[0]:
                #  highest_acc[0] = lfw_score
                #  if lfw_score>=0.998:
                #    do_save = True
                score = sum(acc_list)
                if acc_list[-1]>=highest_acc[-1]:
                    if acc_list[-1]>highest_acc[-1]:
                        is_highest = True
                    else:
                        if score>=highest_acc[0]:
                            is_highest = True
                            highest_acc[0] = score
                    highest_acc[-1] = acc_list[-1]
                    
            if is_highest:
                do_save = True
            if args.ckpt==0:
                do_save = False
            elif args.ckpt==2:
                do_save = True
            elif args.ckpt==3:
                msave = 1

            if do_save:
                print('saving', msave)
                logger.info('saving {}'.format(msave))

                arg, aux = model.get_export_params()
                all_layers = model.symbol.get_internals()
                _sym = all_layers['fc1_output']
                mx.model.save_checkpoint(prefix, msave, _sym, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
            logger.info('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))

        if config.max_steps>0 and mbatch>config.max_steps:
            sys.exit(0)

    model.fit(train_dataiter,
        begin_epoch        = 0,
        num_epoch          = 999999,
        eval_data          = val_dataiter,
        kvstore            = args.kvstore,
        optimizer          = opt,
        initializer        = initializer,
        arg_params         = None,
        aux_params         = None,
        allow_missing      = True,
        batch_end_callback = _batch_callback)
Example #3
def train_net(args):
    #_seed = 727
    #random.seed(_seed)
    #np.random.seed(_seed)
    #mx.random.seed(_seed)
    ctx = []
    cvd = os.environ.get('CUDA_VISIBLE_DEVICES', '').strip()  # tolerate an unset variable
    if len(cvd)>0:
      for i in range(len(cvd.split(','))):
        ctx.append(mx.gpu(i))
    if len(ctx)==0:
      ctx = [mx.cpu()]
      print('use cpu')
    else:
      print('gpu num:', len(ctx))
    if len(args.extra_model_name)==0:
      prefix = os.path.join(args.models_root, '%s-%s-%s'%(args.network, args.loss, args.dataset), 'model')
    else:
      prefix = os.path.join(args.models_root, '%s-%s-%s-%s'%(args.network, args.loss, args.dataset, args.extra_model_name), 'model')
    prefix_dir = os.path.dirname(prefix)
    print('prefix', prefix)
    if not os.path.exists(prefix_dir):
      os.makedirs(prefix_dir)
    args.ctx_num = len(ctx)
    if args.per_batch_size==0:
      args.per_batch_size = 128
    args.batch_size = args.per_batch_size*args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = config.image_shape[2]
    config.batch_size = args.batch_size
    config.per_batch_size = args.per_batch_size
    data_dir = config.dataset_path
    path_imgrec = None
    path_imglist = None
    image_size = config.image_shape[0:2]
    assert len(image_size)==2
    assert image_size[0]==image_size[1]
    print('image_size', image_size)
    print('num_classes', config.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    data_shape = (args.image_channel,image_size[0],image_size[1])

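    # Partition the classifier across num_workers * ctx_num devices; each
    # device holds ctx_num_classes output classes.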
    num_workers = config.num_workers
    global_num_ctx = num_workers * args.ctx_num
    if config.num_classes%global_num_ctx==0:
      args.ctx_num_classes = config.num_classes//global_num_ctx
    else:
      args.ctx_num_classes = config.num_classes//global_num_ctx+1
    args.local_num_classes = args.ctx_num_classes * args.ctx_num
    args.local_class_start = args.local_num_classes * args.worker_id

    #if len(args.partial)==0:
    #  local_classes_range = (0, args.num_classes)
    #else:
    #  _vec = args.partial.split(',')
    #  local_classes_range = (int(_vec[0]), int(_vec[1]))

    #args.partial_num_classes = local_classes_range[1] - local_classes_range[0]
    #args.partial_start = local_classes_range[0]

    print('Called with argument:', args, config)
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    arg_params = None
    aux_params = None
    if len(args.pretrained)==0:
      esym = get_symbol_embedding()
      asym = get_symbol_arcface
    else:
      assert False, 'loading a pretrained model is not implemented in this example'

    if config.count_flops:
      all_layers = esym.get_internals()
      _sym = all_layers['fc1_output']
      FLOPs = flops_counter.count_flops(_sym, data=(1,3,image_size[0],image_size[1]))
      _str = flops_counter.flops_str(FLOPs)
      print('Network FLOPs: %s'%_str)

    if config.num_workers==1:
      from parall_module_local_v1 import ParallModule
    else:
      from parall_module_dist import ParallModule

    model = ParallModule(
        context       = ctx,
        symbol        = esym,
        data_names    = ['data'],
        label_names   = ['softmax_label'],
        asymbol       = asym,
        args = args,
    )
    val_dataiter = None
    train_dataiter = FaceImageIter(
        batch_size           = args.batch_size,
        data_shape           = data_shape,
        path_imgrec          = path_imgrec,
        shuffle              = True,
        rand_mirror          = config.data_rand_mirror,
        mean                 = mean,
        cutoff               = config.data_cutoff,
        color_jittering      = config.data_color,
        images_filter        = config.data_images_filter,
    )


    
    if config.net_name=='fresnet' or config.net_name=='fmobilefacenet':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    else:
      initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)

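    # Note: the parallel-module examples rescale gradients by the global batch
    # size, while the plain Module examples rescale by the GPU count only.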
    _rescale = 1.0/args.batch_size
    opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)


    ver_list = []
    ver_name_list = []
    for name in config.val_targets:
      path = os.path.join(data_dir,name+".bin")
      if os.path.exists(path):
        data_set = verification.load_bin(path, image_size)
        ver_list.append(data_set)
        ver_name_list.append(name)
        print('ver', name)


    def ver_test(nbatch):
      results = []
      for i in range(len(ver_list)):
        acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, 10, None, None)
        print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
        #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
        print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
        results.append(acc2)
      return results


    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in range(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)
    def _batch_callback(param):
      #global global_step
      global_step[0]+=1
      mbatch = global_step[0]
      for step in lr_steps:
        if mbatch==step:
          opt.lr *= 0.1
          print('lr change to', opt.lr)
          break

      _cb(param)
      if mbatch%1000==0:
        print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)

      if mbatch>=0 and mbatch%args.verbose==0:
        acc_list = ver_test(mbatch)
        save_step[0]+=1
        msave = save_step[0]
        do_save = False
        is_highest = False
        if len(acc_list)>0:
          #lfw_score = acc_list[0]
          #if lfw_score>highest_acc[0]:
          #  highest_acc[0] = lfw_score
          #  if lfw_score>=0.998:
          #    do_save = True
          score = sum(acc_list)
          if acc_list[-1]>=highest_acc[-1]:
            if acc_list[-1]>highest_acc[-1]:
              is_highest = True
            else:
              if score>=highest_acc[0]:
                is_highest = True
                highest_acc[0] = score
            highest_acc[-1] = acc_list[-1]
            #if lfw_score>=0.99:
            #  do_save = True
        if is_highest:
          do_save = True
        if args.ckpt==0:
          do_save = False
        elif args.ckpt==2:
          do_save = True
        elif args.ckpt==3:
          msave = 1

        if do_save:
          print('saving', msave)
          arg, aux = model.get_export_params()
          all_layers = model.symbol.get_internals()
          _sym = all_layers['fc1_output']
          mx.model.save_checkpoint(prefix, msave, _sym, arg, aux)
        print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
      if config.max_steps>0 and mbatch>config.max_steps:
        sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    model.fit(train_dataiter,
        begin_epoch        = begin_epoch,
        num_epoch          = 999999,
        eval_data          = val_dataiter,
        #eval_metric        = eval_metrics,
        kvstore            = args.kvstore,
        optimizer          = opt,
        #optimizer_params   = optimizer_params,
        initializer        = initializer,
        arg_params         = arg_params,
        aux_params         = aux_params,
        allow_missing      = True,
        batch_end_callback = _batch_callback,
        epoch_end_callback = epoch_cb )
Example #4
def train_net(args):
    ctx = []
    #ctx.append(mx.gpu(1)) #manual
    cvd = os.environ.get('CUDA_VISIBLE_DEVICES', '').strip()  # tolerate an unset variable
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        logging.info('use cpu')
    else:
        logging.info('gpu num: %d', len(ctx))
    prefix = os.path.join(args.models_root,
                          '%s-%s-%s' % (args.network, args.loss, args.dataset),
                          'model')
    prefix_dir = os.path.dirname(prefix)
    logging.info('prefix %s', prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    args.ctx_num = len(ctx)
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = config.image_shape[2]
    config.batch_size = args.batch_size
    config.per_batch_size = args.per_batch_size

    data_dir = config.dataset_path
    path_imgrec = None
    path_imglist = None
    image_size = config.image_shape[0:2]
    assert len(image_size) == 2
    assert image_size[0] == image_size[1]
    logging.info('image_size %s', str(image_size))
    logging.info('num_classes %d', config.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    logging.info('Called with argument: %s %s', str(args), str(config))
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

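    # This variant starts epoch numbering from args.pretrained_epoch
    # (useful when fine-tuning a loaded checkpoint).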
    begin_epoch = args.pretrained_epoch
    if len(args.pretrained) == 0:  #no pretraining
        arg_params = None
        aux_params = None
        sym = get_symbol(args)
        if config.net_name == 'spherenet':
            data_shape_dict = {'data': (args.per_batch_size, ) + data_shape}
            spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:  #load pretrained model
        logging.info('loading %s %s', str(args.pretrained),
                     str(args.pretrained_epoch))
        _, arg_params, aux_params = mx.model.load_checkpoint(
            args.pretrained, args.pretrained_epoch)
        sym = get_symbol(args)

    if config.count_flops:
        all_layers = sym.get_internals()
        _sym = all_layers['fc1_output']
        FLOPs = flops_counter.count_flops(_sym,
                                          data=(1, 3, image_size[0],
                                                image_size[1]))
        _str = flops_counter.flops_str(FLOPs)
        logging.info('Network FLOPs: %s' % _str)

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)
    model = mx.mod.Module(  # full training module: the symbol includes the loss head and all trainable layers
        context=ctx,
        symbol=sym,
        #fixed_param_names = ["conv_1_conv2d_weight","res_2_block0_conv_sep_conv2d_weight","res_2_block0_conv_dw_conv2d_weight","res_2_block0_conv_proj_conv2d_weight","res_2_block1_conv_sep_conv2d_weight","res_2_block1_conv_dw_conv2d_weight","res_2_block1_conv_proj_conv2d_weight","dconv_23_conv_sep_conv2d_weight","dconv_23_conv_dw_conv2d_weight","dconv_23_conv_proj_conv2d_weight","res_3_block0_conv_sep_conv2d_weight","res_3_block0_conv_dw_conv2d_weight","res_3_block0_conv_proj_conv2d_weight","res_3_block1_conv_sep_conv2d_weight","res_3_block1_conv_dw_conv2d_weight","res_3_block1_conv_proj_conv2d_weight","res_3_block2_conv_sep_conv2d_weight","res_3_block2_conv_dw_conv2d_weight","res_3_block2_conv_proj_conv2d_weight","res_3_block3_conv_sep_conv2d_weight","res_3_block3_conv_dw_conv2d_weight","res_3_block3_conv_proj_conv2d_weight","res_3_block4_conv_sep_conv2d_weight","res_3_block4_conv_dw_conv2d_weight","res_3_block4_conv_proj_conv2d_weight","res_3_block5_conv_sep_conv2d_weight","res_3_block5_conv_dw_conv2d_weight","res_3_block5_conv_proj_conv2d_weight","res_3_block6_conv_sep_conv2d_weight","res_3_block6_conv_dw_conv2d_weight","res_3_block6_conv_proj_conv2d_weight","res_3_block7_conv_sep_conv2d_weight","res_3_block7_conv_dw_conv2d_weight","res_3_block7_conv_proj_conv2d_weight","dconv_34_conv_sep_conv2d_weight","dconv_34_conv_dw_conv2d_weight","dconv_34_conv_proj_conv2d_weight","res_4_block0_conv_sep_conv2d_weight","res_4_block0_conv_dw_conv2d_weight","res_4_block0_conv_proj_conv2d_weight","res_4_block1_conv_sep_conv2d_weight","res_4_block1_conv_dw_conv2d_weight","res_4_block1_conv_proj_conv2d_weight","res_4_block2_conv_sep_conv2d_weight","res_4_block2_conv_dw_conv2d_weight","res_4_block2_conv_proj_conv2d_weight","res_4_block3_conv_sep_conv2d_weight","res_4_block3_conv_dw_conv2d_weight","res_4_block3_conv_proj_conv2d_weight","res_4_block4_conv_sep_conv2d_weight","res_4_block4_conv_dw_conv2d_weight","res_4_block4_conv_proj_conv2d_weight","res_4_block5_conv_sep_conv2d_weight","res_4_block5_conv_dw_conv2d_weight","res_4_block5_conv_proj_conv2d_weight","res_4_block6_conv_sep_conv2d_weight","res_4_block6_conv_dw_conv2d_weight","res_4_block6_conv_proj_conv2d_weight","res_4_block7_conv_sep_conv2d_weight","res_4_block7_conv_dw_conv2d_weight","res_4_block7_conv_proj_conv2d_weight","res_4_block8_conv_sep_conv2d_weight","res_4_block8_conv_dw_conv2d_weight","res_4_block8_conv_proj_conv2d_weight","res_4_block9_conv_sep_conv2d_weight","res_4_block9_conv_dw_conv2d_weight","res_4_block9_conv_proj_conv2d_weight","res_4_block10_conv_sep_conv2d_weight","res_4_block10_conv_dw_conv2d_weight","res_4_block10_conv_proj_conv2d_weight","res_4_block11_conv_sep_conv2d_weight","res_4_block11_conv_dw_conv2d_weight","res_4_block11_conv_proj_conv2d_weight","res_4_block12_conv_sep_conv2d_weight","res_4_block12_conv_dw_conv2d_weight","res_4_block12_conv_proj_conv2d_weight","res_4_block13_conv_sep_conv2d_weight","res_4_block13_conv_dw_conv2d_weight","res_4_block13_conv_proj_conv2d_weight","res_4_block14_conv_sep_conv2d_weight","res_4_block14_conv_dw_conv2d_weight","res_4_block14_conv_proj_conv2d_weight","res_4_block15_conv_sep_conv2d_weight","res_4_block15_conv_dw_conv2d_weight","res_4_block15_conv_proj_conv2d_weight","dconv_45_conv_sep_conv2d_weight","dconv_45_conv_dw_conv2d_weight","dconv_45_conv_proj_conv2d_weight"],
        #fixed_param_names = ['convolution'+str(i)+'_weight' for i in range(1,40)],
    )
    val_dataiter = None

    if config.loss_name.find('triplet') >= 0:  #if triplet or atriplet loss
        from triplet_image_iter import FaceImageIter
        triplet_params = [
            config.triplet_bag_size, config.triplet_alpha,
            config.triplet_max_ap
        ]
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=config.data_rand_mirror,
            mean=mean,
            cutoff=config.data_cutoff,
            ctx_num=args.ctx_num,
            images_per_identity=config.images_per_identity,
            triplet_params=triplet_params,
            mx_model=model,
        )
        _metric = LossValueMetric()
        eval_metrics = [mx.metric.create(_metric)]
    else:
        from image_iter import FaceImageIter
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=config.data_rand_mirror,
            mean=mean,
            cutoff=config.data_cutoff,
            color_jittering=config.data_color,
            images_filter=config.data_images_filter,
        )
        metric1 = AccMetric()
        eval_metrics = [mx.metric.create(metric1)]
        if config.ce_loss:
            metric2 = LossValueMetric()
            eval_metrics.append(mx.metric.create(metric2))
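    # Networks initialized ResNet-style: Xavier gaussian with factor "out".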
    gaussian_nets = [
        'fresnet', 'fmobilefacenet', 'fsqueezefacenet_v1',
        'fsqueezefacenet_v2', 'fshufflefacenetv2', 'fshufflenetv1',
        'fsqueezenet1_0', 'fsqueezenet1_1', 'fsqueezenet1_2',
        'fsqueezenet1_1_no_pool', 'fmobilenetv2', 'fmobilefacenetv1',
        'fmobilenetv2_mxnet', 'vargfacenet', 'mobilenetv3'
    ]
    if config.net_name in gaussian_nets:

        #    if config.net_name=='fresnet' or config.net_name=='fmobilefacenet' or config.net_name=='fsqueezefacenet_v1' or config.net_name=='fsqueezefacenet_v2' or config.net_name=='fshufflefacenetv2' or config.net_name =='fefficientnet' or config.net_name == 'fshufflenetv1' or config.net_name == 'fsqueezenet1_0' or config.net_name == 'fsqueezenet1_1' or config.net_name =='fsqueezenet1_2' or config.net_name == 'fsqueezenet1_1_no_pool' or config.net_name == 'fmobilenetv2':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
        print("GAUSSIAN INITIALIZER")
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
        print("UNIFORM INITIALIZER")
    #initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=args.lr,
                        momentum=args.mom,
                        wd=args.wd,
                        rescale_grad=_rescale)
    #opt = optimizer.Adam(learning_rate=args.lr,wd=args.wd,rescale_grad=_rescale)
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)

    ver_list = []
    ver_name_list = []
    for name in config.val_targets:
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            logging.info('ver %s', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model, args.batch_size, 10, None, None)
            logging.info('[%s][%d]XNorm: %f' %
                         (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                         (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0]  #lfw and target
    highest_train_acc = [0.0, 0.0]
    #for i in range(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    logging.info('lr_steps %s', str(lr_steps))

    def _batch_callback(param):
        #global global_step
        #weights_db = model.get_params()[0]['fire2_act_squeeze_1x1'].asnumpy()

        #weights_db = model.get_params()[0]
        #print(str(weights_db.keys()))

        #for k, v in weights_db.items():
        #print(k)
        #     if(np.any(np.isnan(v.asnumpy())) or np.any(np.isinf(v.asnumpy()))):
        #         print("nan or inf weight found at "+k)

        #name_value = param.eval_metric.get_name_value()
        #for name, value in name_value:
        #    logging.info('Epoch[%d] Validation-%s=%f', param.epoch, name, value)
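        # eval_metric.get_name_value() returns (name, value) pairs; these two
        # are kept for the commented-out NaN debugging below.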
        loss = param.eval_metric.get_name_value()[1]
        train_acc = param.eval_metric.get_name_value()[0]
        #print(loss)
        #if (np.isnan(loss[1])):
        #    print("Nan loss found")
        #    f = open("nan_loss_weights.txt", "w")
        #    f.write("batch #"+str(global_step[0])+"\n"+str(loss)+"\n"+str(weights_db.keys())+"\n"+str(weights_db))
        #    f.close()
        #    print("Written file at: nan_loss_weights.txt")
        #    exit()

        global_step[0] += 1
        mbatch = global_step[0]
        for step in lr_steps:
            if mbatch == step:
                opt.lr *= 0.1
                logging.info('lr change to %f', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            logging.info('lr-batch-epoch: %f %d %d', opt.lr, param.nbatch,
                         param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            is_highest = False
            if len(acc_list) > 0:
                #lfw_score = acc_list[0]
                #if lfw_score>highest_acc[0]:
                #  highest_acc[0] = lfw_score
                #  if lfw_score>=0.998:
                #    do_save = True
                score = sum(acc_list)
                if acc_list[-1] >= highest_acc[-1]:
                    if acc_list[-1] > highest_acc[-1]:
                        is_highest = True
                    else:
                        if score >= highest_acc[0]:
                            is_highest = True
                            highest_acc[0] = score
                    highest_acc[-1] = acc_list[-1]
                    #if lfw_score>=0.99:
                    #  do_save = True
            if is_highest:
                do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt == 2:
                do_save = True
            elif args.ckpt == 3:
                msave = 1

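            # This variant numbers checkpoints by param.epoch + 1 rather than
            # by the msave counter used for logging.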
            if do_save:
                logging.info('saving %d', msave)
                arg, aux = model.get_params()
                if config.ckpt_embedding:
                    all_layers = model.symbol.get_internals()
                    _sym = all_layers['fc1_output']
                    _arg = {}
                    for k in arg:
                        if not k.startswith('fc7'):
                            _arg[k] = arg[k]
                    mx.model.save_checkpoint(prefix, param.epoch + 1, _sym,
                                             _arg, aux)
                else:
                    mx.model.save_checkpoint(prefix, param.epoch + 1,
                                             model.symbol, arg, aux)
            logging.info('[%d]Accuracy-Highest: %1.5f' %
                         (mbatch, highest_acc[-1]))
        if config.max_steps > 0 and mbatch > config.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    model.fit(
        train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=999999,
        eval_data=val_dataiter,
        eval_metric=eval_metrics,
        kvstore=args.kvstore,
        optimizer=opt,
        #optimizer_params   = optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=epoch_cb)