Example #1
def get_data(path, batch_size, use_bgr=False):
    image_size = (112, 112)
    data_set = verification.load_bin(path, image_size)

    data_list, issame_list = data_set
    _label = nd.ones((batch_size, 1))
    for i in range(1):  # only the first split; originally: for i in range(len(data_list)):
        data = data_list[i]
        print('datasize: ', len(data))
        embeddings = None
        ba = 0
        while ba < data.shape[0]:
            bb = min(ba + batch_size, data.shape[0])
            count = bb - ba
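            # The last batch is padded by re-reading earlier rows (begin=bb - batch_size),
            # so only the trailing `count` rows of each yielded batch are new data.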
            _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb)
            if use_bgr:
                _data = _data[:, ::-1, :, :]
            #print(_data.shape, _label.shape)
            db = mx.io.DataBatch(data=(_data, ))  #, label=(_label,))
            ba = bb
            yield (db, count)
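A minimal consumption sketch for this generator (hypothetical: it assumes a bound MXNet model whose output is the embedding, and a placeholder .bin path; neither is defined above):

import numpy as np

embeddings = []
for db, count in get_data('faces.bin', batch_size=32):  # placeholder path
    model.forward(db, is_train=False)                   # `model` assumed bound elsewhere
    out = model.get_outputs()[0].asnumpy()
    embeddings.append(out[-count:])                     # keep only the non-padded rows
embeddings = np.vstack(embeddings)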
Example #2
def get_exclude_data(args, rec_list, image_size, select_ids, model,
                     all_id_list):
    if len(args.exclude) > 0:
        if os.path.isdir(args.exclude):
            _path_imgrec = os.path.join(args.exclude, 'train.rec')
            _path_imgidx = os.path.join(args.exclude, 'train.idx')
            _imgrec = mx.recordio.MXIndexedRecordIO(_path_imgidx, _path_imgrec,
                                                    'r')  # pylint: disable=redefined-variable-type
            _ds_id = len(rec_list)
            _id_list = get_ids_list(_imgrec, _ds_id, model, select_ids)
        else:
            _id_list = []
            data_set = verification.load_bin(args.exclude, image_size)[0][0]
            print(data_set.shape)
            data = nd.zeros((1, 3, image_size[0], image_size[1]))
            for i in range(data_set.shape[0]):
                data[0] = data_set[i]
                db = mx.io.DataBatch(data=(data, ))
                model.forward(db, is_train=False)
                net_out = model.get_outputs()
                embedding = net_out[0].asnumpy().flatten()
                _norm = np.linalg.norm(embedding)
                embedding /= _norm
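                # unit-norm embeddings: the np.dot further down yields cosine similarity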
                _id_list.append((i, i, embedding))
            X = []
            for id_item in all_id_list:
                X.append(id_item[2])
            X = np.array(X)
            emap = {}
            for id_item in _id_list:
                y = id_item[2]
                sim = np.dot(X, y.T)
                idx = np.where(sim >= args.param2)[0]
                for j in idx:
                    emap[j] = 1
                    all_id_list[j][1] = -1
            print('exclude', len(emap))
        return all_id_list
    else:
        return None
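The exclusion rule above reduces to thresholded cosine similarity between unit vectors. A self-contained toy check of that rule (pure NumPy, fixed stand-ins for the real embeddings):

import numpy as np

X = np.eye(4)                          # four orthogonal unit "identities"
y = np.array([0.1, 0.0, 0.99, 0.0])   # close to identity 2
y /= np.linalg.norm(y)
sim = np.dot(X, y.T)                   # cosine similarity per identity
print(np.where(sim >= 0.8)[0])        # -> [2]: identity 2 would be excluded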
Example #3
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3

    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list) == 1
    data_dir = data_dir_list[0]
    path_imgrec = None
    path_imglist = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    # image_size = prop.image_size
    image_size = [int(x) for x in args.image_size.split(',')]
    assert len(image_size) == 2
    assert image_size[0] == image_size[1]
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    if args.loss_type == 1 and args.num_classes > 20000:
        args.beta_freeze = 5000
        args.gamma = 0.06

    print('Called with argument:', args)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
        if args.network[0] == 's':
            data_shape_dict = {'data': (args.per_batch_size, ) + data_shape}
            spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            vec[0], int(vec[1]))
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)

    # label_name = 'softmax_label'
    # label_shape = (args.batch_size,)
    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
    )
    val_dataiter = None

    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrec=path_imgrec,
        shuffle=True,
        rand_mirror=args.rand_mirror,
        mean=mean,
        cutoff=args.cutoff,
        color_jittering=args.color,
        images_filter=args.images_filter,
    )

    metric1 = AccMetric()
    eval_metrics = [mx.metric.create(metric1)]
    if args.ce_loss:
        metric2 = LossValueMetric()
        eval_metrics.append(mx.metric.create(metric2))

    if args.network[0] == 'r' or args.network[0] == 'y':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  # resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  # inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    # initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=base_lr,
                        momentum=base_mom,
                        wd=base_wd,
                        rescale_grad=_rescale)
    som = 20
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model, args.batch_size, 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            # print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0, 0.0, 0.0, 0.0]
    # for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [40000, 60000, 80000]
        if args.loss_type >= 1 and args.loss_type <= 7:
            lr_steps = [100000, 140000, 160000]
        p = 512.0 / args.batch_size
        for l in range(len(lr_steps)):
            lr_steps[l] = int(lr_steps[l] * p)
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        # global global_step
        global_step[0] += 1
        mbatch = global_step[0]
        for _lr in lr_steps:
            if mbatch == args.beta_freeze + _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            if len(acc_list) > 0:
                score = {}
                score['lfw_score'] = acc_list[0]
                score['cfp_score'] = acc_list[1]
                score['agedb_score'] = acc_list[2]
                score['cplfw_score'] = acc_list[3]
                score['calfw_score'] = acc_list[4]
                print('score=', score)
                if score['lfw_score'] > highest_acc[0]:
                    highest_acc[0] = score['lfw_score']
                    if score['lfw_score'] >= 0.99:
                        do_save = True
                if score['cfp_score'] > highest_acc[1]:
                    highest_acc[1] = score['cfp_score']
                    if score['cfp_score'] > 0.94:
                        do_save = True
                if score['agedb_score'] > highest_acc[2]:
                    highest_acc[2] = score['agedb_score']
                    if score['agedb_score'] > 0.93:
                        do_save = True
                if score['cplfw_score'] > highest_acc[3]:
                    highest_acc[3] = score['cplfw_score']
                    if score['cplfw_score'] > 0.85:
                        do_save = True
                if score['calfw_score'] > highest_acc[4]:
                    highest_acc[4] = score['calfw_score']
                    if score['calfw_score'] > 0.9:
                        do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt > 1:
                do_save = True
            arg, aux = model.get_params()
            print('saving', 0)
            mx.model.save_checkpoint(prefix, 0, model.symbol, arg, aux)
            if do_save:
                print('saving', msave)
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print(
                '[%d]score_highest: lfw: %1.5f cfp: %1.5f agedb: %1.5f cplfw: %1.5f calfw: %1.5f'
                % (mbatch, highest_acc[0], highest_acc[1], highest_acc[2],
                   highest_acc[3], highest_acc[4]))
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = max(0, mbatch - args.beta_freeze)
            _beta = max(
                args.beta_min,
                args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
        # print('beta', _beta)
        os.environ['BETA'] = str(_beta)
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    # model.fit(train_dataiter,
    #           begin_epoch=begin_epoch,
    #           num_epoch=end_epoch,
    #           eval_data=val_dataiter,
    #           eval_metric=eval_metrics,
    #           kvstore='device',
    #           optimizer=opt,
    #           # optimizer_params   = optimizer_params,
    #           initializer=initializer,
    #           arg_params=arg_params,
    #           aux_params=aux_params,
    #           allow_missing=True,
    #           batch_end_callback=_batch_callback,
    #           epoch_end_callback=epoch_cb)

    model.bind(data_shapes=train_dataiter.provide_data,
               label_shapes=train_dataiter.provide_label,
               for_training=True,
               force_rebind=False)
    model.init_params(initializer=initializer,
                      arg_params=arg_params,
                      aux_params=aux_params,
                      allow_missing=True,
                      force_init=False)
    model.init_optimizer(kvstore='device', optimizer=opt)

    if not isinstance(eval_metrics, mx.metric.EvalMetric):
        eval_metrics = mx.metric.create(eval_metrics)
    epoch_eval_metric = copy.deepcopy(eval_metrics)

    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, end_epoch):
        tic = time.time()
        eval_metrics.reset()
        epoch_eval_metric.reset()
        nbatch = 0
        data_iter = iter(train_dataiter)
        end_of_batch = False
        next_data_batch = next(data_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            model.forward_backward(data_batch)
            model.update()

            if isinstance(data_batch, list):
                model.update_metric(eval_metrics,
                                    [db.label for db in data_batch],
                                    pre_sliced=True)
                model.update_metric(epoch_eval_metric,
                                    [db.label for db in data_batch],
                                    pre_sliced=True)
            else:
                model.update_metric(eval_metrics, data_batch.label)
                model.update_metric(epoch_eval_metric, data_batch.label)

            try:
                # pre fetch next batch
                next_data_batch = next(data_iter)
                model.prepare(next_data_batch, sparse_row_id_fn=None)
            except StopIteration:
                end_of_batch = True

            if end_of_batch:
                eval_name_vals = epoch_eval_metric.get_name_value()

            batch_end_params = mx.model.BatchEndParam(epoch=epoch,
                                                      nbatch=nbatch,
                                                      eval_metric=eval_metrics,
                                                      locals=locals())
            _batch_callback(batch_end_params)
            nbatch += 1

        # one epoch of training is finished
        for name, val in eval_name_vals:
            model.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
        toc = time.time()
        model.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # sync aux params across devices
        arg_params, aux_params = model.get_params()
        model.set_params(arg_params, aux_params)

        train_dataiter.reset()
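The BETA environment variable set in _batch_callback follows an inverse-polynomial anneal: beta is held until beta_freeze, then decays toward beta_min. Its shape can be inspected standalone (the constants below are arbitrary illustrations, not the script's defaults):

import math

beta, beta_min, gamma, power, beta_freeze = 1000.0, 5.0, 0.001, 1.0, 5000
for mbatch in (1000, 5000, 20000, 100000):
    move = max(0, mbatch - beta_freeze)
    b = max(beta_min, beta * math.pow(1 + gamma * move, -1.0 * power))
    print(mbatch, round(b, 3))   # 1000.0, 1000.0, 62.5, 10.417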
Example #4
def main(args):
  include_datasets = args.include.split(',')
  prop = face_image.load_property(include_datasets[0])
  image_size = prop.image_size
  print('image_size', image_size)
  model = None
  if len(args.model)>0:
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd)>0:
      for i in range(len(cvd.split(','))):
        ctx.append(mx.gpu(i))
    if len(ctx)==0:
      ctx = [mx.cpu()]
      print('use cpu')
    else:
      print('gpu num:', len(ctx))
    args.ctx_num = len(ctx)
    vec = args.model.split(',')
    prefix = vec[0]
    epoch = int(vec[1])
    print('loading',prefix, epoch)
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    #arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)
    all_layers = sym.get_internals()
    sym = all_layers['fc1_output']
    #model = mx.mod.Module.load(prefix, epoch, context = ctx)
    #model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
    model = mx.mod.Module(symbol=sym, context=ctx, label_names = None)
    model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))])
    model.set_params(arg_params, aux_params)
  else:
    assert args.param1==0.0
  rec_list = []
  for ds in include_datasets:
    path_imgrec = os.path.join(ds, 'train.rec')
    path_imgidx = os.path.join(ds, 'train.idx')
    imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')  # pylint: disable=redefined-variable-type
    rec_list.append(imgrec)
  id_list_map = {}
  all_id_list = []
  test_limit = 0
  for ds_id in range(len(rec_list)):
    id_list = []
    imgrec = rec_list[ds_id]
    s = imgrec.read_idx(0)
    header, _ = mx.recordio.unpack(s)
    assert header.flag>0
    print('header0 label', header.label)
    header0 = (int(header.label[0]), int(header.label[1]))
    #assert(header.flag==1)
    imgidx = range(1, int(header.label[0]))
    id2range = {}
    seq_identity = range(int(header.label[0]), int(header.label[1]))
    pp=0
    for identity in seq_identity:
      pp+=1
      if pp%10==0:
        print('processing id', pp)
      if model is not None:
        embedding = get_embedding(args, imgrec, identity, image_size, model)
      else:
        embedding = None
      #print(embedding.shape)
      id_list.append( [ds_id, identity, embedding] )
      if test_limit>0 and pp>=test_limit:
        break
    id_list_map[ds_id] = id_list
    if ds_id==0 or model is None:
      all_id_list += id_list
      print(ds_id, len(id_list))
    else:
      X = []
      for id_item in all_id_list:
        X.append(id_item[2])
      X = np.array(X)
      for i in range(len(id_list)):
        id_item = id_list[i]
        y = id_item[2]
        sim = np.dot(X, y.T)
        idx = np.where(sim>=args.param1)[0]
        if len(idx)>0:
          continue
        all_id_list.append(id_item)
      print(ds_id, len(id_list), len(all_id_list))


  if len(args.exclude)>0:
    if os.path.isdir(args.exclude):
      _path_imgrec = os.path.join(args.exclude, 'train.rec')
      _path_imgidx = os.path.join(args.exclude, 'train.idx')
      _imgrec = mx.recordio.MXIndexedRecordIO(_path_imgidx, _path_imgrec, 'r')  # pylint: disable=redefined-variable-type
      _ds_id = len(rec_list)
      _id_list = []
      s = _imgrec.read_idx(0)
      header, _ = mx.recordio.unpack(s)
      assert header.flag>0
      print('header0 label', header.label)
      header0 = (int(header.label[0]), int(header.label[1]))
      #assert(header.flag==1)
      imgidx = range(1, int(header.label[0]))
      seq_identity = range(int(header.label[0]), int(header.label[1]))
      pp=0
      for identity in seq_identity:
        pp+=1
        if pp%10==0:
          print('processing ex id', pp)
        embedding = get_embedding(args, _imgrec, identity, image_size, model)
        #print(embedding.shape)
        _id_list.append( (_ds_id, identity, embedding) )
        if test_limit>0 and pp>=test_limit:
          break
    else:
      _id_list = []
      data_set = verification.load_bin(args.exclude, image_size)[0][0]
      print(data_set.shape)
      data = nd.zeros( (1,3,image_size[0], image_size[1]))
      for i in range(data_set.shape[0]):
        data[0] = data_set[i]
        db = mx.io.DataBatch(data=(data,))
        model.forward(db, is_train=False)
        net_out = model.get_outputs()
        embedding = net_out[0].asnumpy().flatten()
        _norm=np.linalg.norm(embedding)
        embedding /= _norm
        _id_list.append( (i, i, embedding) )

    #X = []
    #for id_item in all_id_list:
    #  X.append(id_item[2])
    #X = np.array(X)
    #param1 = 0.3
    #while param1<=1.01:
    #  emap = {}
    #  for id_item in _id_list:
    #    y = id_item[2]
    #    sim = np.dot(X, y.T)
    #    #print(sim.shape)
    #    #print(sim)
    #    idx = np.where(sim>=param1)[0]
    #    for j in idx:
    #      emap[j] = 1
    #  exclude_removed = len(emap)
    #  print(param1, exclude_removed)
    #  param1+=0.05

      X = []
      for id_item in all_id_list:
        X.append(id_item[2])
      X = np.array(X)
      emap = {}
      for id_item in _id_list:
        y = id_item[2]
        sim = np.dot(X, y.T)
        idx = np.where(sim>=args.param2)[0]
        for j in idx:
          emap[j] = 1
          all_id_list[j][1] = -1
      print('exclude', len(emap))

  if args.test>0:
    return

  if not os.path.exists(args.output):
    os.makedirs(args.output)
  writer = mx.recordio.MXIndexedRecordIO(os.path.join(args.output, 'train.idx'), os.path.join(args.output, 'train.rec'), 'w')
  idx = 1
  identities = []
  nlabel = -1
  for id_item in all_id_list:
    if id_item[1]<0:
      continue
    nlabel+=1
    ds_id = id_item[0]
    imgrec = rec_list[ds_id]
    id = id_item[1]
    s = imgrec.read_idx(id)
    header, _ = mx.recordio.unpack(s)
    a, b = int(header.label[0]), int(header.label[1])
    identities.append( (idx, idx+b-a) )
    for _idx in range(a, b):
      s = imgrec.read_idx(_idx)
      _header, _content = mx.recordio.unpack(s)
      nheader = mx.recordio.IRHeader(0, nlabel, idx, 0)
      s = mx.recordio.pack(nheader, _content)
      writer.write_idx(idx, s)
      idx+=1
  id_idx = idx
  for id_label in identities:
    _header = mx.recordio.IRHeader(1, id_label, idx, 0)
    s = mx.recordio.pack(_header, b'')
    writer.write_idx(idx, s)
    idx+=1
  _header = mx.recordio.IRHeader(1, (id_idx, idx), 0, 0)
  s = mx.recordio.pack(_header, b'')
  writer.write_idx(0, s)
  with open(os.path.join(args.output, 'property'), 'w') as f:
    f.write("%d,%d,%d"%(len(identities), image_size[0], image_size[1]))
Example #5
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = os.path.join(args.models_root,
                          '%s-%s-%s' % (args.network, args.loss, args.dataset),
                          'model')
    prefix_dir = os.path.dirname(prefix)
    print('prefix', prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = config.image_shape[2]
    data_dir = config.dataset_path
    path_imgrecs = None
    path_imglist = None
    image_size = config.image_shape[0:2]
    assert len(image_size) == 2
    assert image_size[0] == image_size[1]
    print('image_size', image_size)
    print('num_classes', config.num_classes)
    path_imgrecs = [os.path.join(data_dir, "train.rec")]

    data_shape = (args.image_channel, image_size[0], image_size[1])

    num_workers = config.num_workers
    global_num_ctx = num_workers * args.ctx_num
    if config.num_classes % global_num_ctx == 0:
        args.ctx_num_classes = config.num_classes // global_num_ctx
    else:
        args.ctx_num_classes = config.num_classes // global_num_ctx + 1
    print(config.num_classes, global_num_ctx, args.ctx_num_classes)
    args.local_num_classes = args.ctx_num_classes * args.ctx_num
    args.local_class_start = args.local_num_classes * args.worker_id

    #if len(args.partial)==0:
    #  local_classes_range = (0, args.num_classes)
    #else:
    #  _vec = args.partial.split(',')
    #  local_classes_range = (int(_vec[0]), int(_vec[1]))

    #args.partial_num_classes = local_classes_range[1] - local_classes_range[0]
    #args.partial_start = local_classes_range[0]

    print('Called with argument:', args, config)
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    arg_params = None
    aux_params = None
    esym = get_symbol_embedding()
    asym = get_symbol_arcface
    if config.num_workers == 1:
        sys.path.append(os.path.join(os.path.dirname(__file__), 'utils'))
        from parall_module_local_v1 import ParallModule
    else:
        from parall_module_dist import ParallModule

    model = ParallModule(
        context=ctx,
        symbol=esym,
        data_names=['data'],
        label_names=['softmax_label'],
        asymbol=asym,
        args=args,
    )

    val_dataiter = None
    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrecs=path_imgrecs,
        shuffle=True,
        rand_mirror=config.data_rand_mirror,
        mean=mean,
        cutout=default.cutout if config.data_cutout else None,
        crop=default.crop if config.data_crop else None,
        mask=default.mask if config.data_mask else None,
        gridmask=default.gridmask if config.data_grid else None,
        #color_jittering      = config.data_color,
        #images_filter        = config.data_images_filter,
        loss_type=args.loss,
        #margin_m             = config.loss_m2,
        data_names=['data'],
        downsample_back=config.downsample_back,
        motion_blur=config.motion_blur,
        use_bgr=config.use_bgr)

    if config.net_name == 'fresnet' or config.net_name == 'fmobilefacenet':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)

    _rescale = 1.0 / 8  #/ args.batch_size
    print(base_lr, base_mom, base_wd, args.batch_size)

    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(lr_steps,
                                                        factor=0.1,
                                                        base_lr=base_lr)
    optimizer_params = {
        'learning_rate': base_lr,
        'momentum': base_mom,
        'wd': base_wd,
        'rescale_grad': _rescale,
        'lr_scheduler': lr_scheduler
    }

    opt = optimizer.SGD(learning_rate=base_lr,
                        momentum=base_mom,
                        wd=base_wd,
                        rescale_grad=_rescale)

    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)

    ver_list = []
    ver_name_list = []
    for name in config.val_targets:
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            _, issame_list = ver_list[i]
            if all(issame_list):
                fp_rates, fp_dict, thred_dict, recall_dict = verification.test(
                    ver_list[i],
                    model,
                    args.batch_size,
                    use_bgr=config.use_bgr,
                    label_shape=(args.batch_size, len(path_imgrecs)))
                for k in fp_rates:
                    print("[%s] TPR at FPR %.2e[%.2e: %.4f]:\t%.5f" %
                          (ver_name_list[i], k, fp_dict[k], thred_dict[k],
                           recall_dict[k]))

            else:
                acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                    ver_list[i],
                    model,
                    args.batch_size,
                    10,
                    None,
                    label_shape=(args.batch_size, len(path_imgrecs)),
                    use_bgr=config.use_bgr)
                print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
                #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
                print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                      (ver_name_list[i], nbatch, acc2, std2))
                results.append(acc2)

        return results

    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in range(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        #global global_step
        global_step[0] += 1
        mbatch = global_step[0]

        #for step in lr_steps:
        #  if mbatch==step:
        #    opt.lr *= 0.1
        #    print('lr change to', opt.lr)
        #    break

        _cb(param)
        if mbatch % 1000 == 0:
            #print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)
            print('batch-epoch:', param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            is_highest = False
            if len(acc_list) > 0:
                #lfw_score = acc_list[0]
                #if lfw_score>highest_acc[0]:
                #  highest_acc[0] = lfw_score
                #  if lfw_score>=0.998:
                #    do_save = True
                score = sum(acc_list)
                if acc_list[-1] >= highest_acc[-1]:
                    if acc_list[-1] > highest_acc[-1]:
                        is_highest = True
                    else:
                        if score >= highest_acc[0]:
                            is_highest = True
                            highest_acc[0] = score
                    highest_acc[-1] = acc_list[-1]
                    #if lfw_score>=0.99:
                    #  do_save = True
            if is_highest:
                do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt == 2:
                do_save = True
            elif args.ckpt == 3:
                msave = 1

            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()  #get_export_params()
                all_layers = model.symbol.get_internals()
                _sym = model.symbol  #all_layers['fc1_output']
                mx.model.save_checkpoint(prefix, msave, _sym, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    if len(args.pretrained) != 0:
        model_prefix, epoch = args.pretrained.split(',')
        begin_epoch = int(epoch)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_prefix, begin_epoch)
        #model.set_params(arg_params, aux_params)

    model.fit(
        train_dataiter,
        begin_epoch=0,  #begin_epoch,
        num_epoch=default.end_epoch,
        eval_data=val_dataiter,
        #eval_metric        = eval_metrics,
        kvstore=args.kvstore,
        #optimizer          = opt,
        optimizer_params=optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=epoch_cb)
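Unlike the earlier examples, this variant delegates learning-rate decay to a MultiFactorScheduler passed through optimizer_params instead of mutating opt.lr inside the callback. The scheduler is easy to sanity-check in isolation:

import mxnet as mx

sched = mx.lr_scheduler.MultiFactorScheduler(step=[100, 200], factor=0.1, base_lr=0.1)
for num_update in (50, 150, 250):
    print(num_update, sched(num_update))   # ~0.1, ~0.01, ~0.001 after each boundary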
Example #6
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = os.path.join(args.models_root,
                          '%s-%s-%s' % (args.network, args.loss, args.dataset),
                          'model')
    prefix_dir = os.path.dirname(prefix)
    print('prefix', prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    args.ctx_num = len(ctx)
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = config.image_shape[2]

    data_dir = config.dataset_path
    path_imgrec = None
    path_imglist = None
    image_size = config.image_shape[0:2]
    assert len(image_size) == 2
    assert image_size[0] == image_size[1]
    print('image_size', image_size)
    print('num_classes', config.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    print('Called with argument:', args, config)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym = get_symbol(args)
        if config.net_name == 'spherenet':
            data_shape_dict = {'data': (args.per_batch_size, ) + data_shape}
            spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:
        print('loading', args.pretrained, args.pretrained_epoch)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            args.pretrained, args.pretrained_epoch)
        sym = get_symbol(args)

    if config.count_flops:
        all_layers = sym.get_internals()
        _sym = all_layers['fc1_output']
        FLOPs = flops_counter.count_flops(_sym,
                                          data=(1, 3, image_size[0],
                                                image_size[1]))
        print('Network FLOPs: %d' % FLOPs)

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)
    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
    )
    val_dataiter = None

    if config.loss_name.find('triplet') >= 0:
        from triplet_image_iter import FaceImageIter
        triplet_params = [
            config.triplet_bag_size, config.triplet_alpha,
            config.triplet_max_ap
        ]
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=config.data_rand_mirror,
            mean=mean,
            cutoff=config.data_cutoff,
            ctx_num=args.ctx_num,
            images_per_identity=config.images_per_identity,
            triplet_params=triplet_params,
            mx_model=model,
        )
        _metric = LossValueMetric()
        eval_metrics = [mx.metric.create(_metric)]
    else:
        from image_iter import FaceImageIter
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=config.data_rand_mirror,
            mean=mean,
            cutoff=config.data_cutoff,
            color_jittering=config.data_color,
            images_filter=config.data_images_filter,
        )
        metric1 = AccMetric()
        eval_metrics = [mx.metric.create(metric1)]
        if config.ce_loss:
            metric2 = LossValueMetric()
            eval_metrics.append(mx.metric.create(metric2))

    if config.net_name == 'fresnet' or config.net_name == 'fmobilefacenet':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    #initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=args.lr,
                        momentum=args.mom,
                        wd=args.wd,
                        rescale_grad=_rescale)
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)

    ver_list = []
    ver_name_list = []
    for name in config.val_targets:
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model, args.batch_size, 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        #global global_step
        global_step[0] += 1
        mbatch = global_step[0]
        for step in lr_steps:
            if mbatch == step:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            is_highest = False
            if len(acc_list) > 0:
                #lfw_score = acc_list[0]
                #if lfw_score>highest_acc[0]:
                #  highest_acc[0] = lfw_score
                #  if lfw_score>=0.998:
                #    do_save = True
                score = sum(acc_list)
                if acc_list[-1] >= highest_acc[-1]:
                    if acc_list[-1] > highest_acc[-1]:
                        is_highest = True
                    else:
                        if score >= highest_acc[0]:
                            is_highest = True
                            highest_acc[0] = score
                    highest_acc[-1] = acc_list[-1]
                    #if lfw_score>=0.99:
                    #  do_save = True
            if is_highest:
                do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt == 2:
                do_save = True
            elif args.ckpt == 3:
                msave = 1

            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()
                if config.ckpt_embedding:
                    all_layers = model.symbol.get_internals()
                    _sym = all_layers['fc1_output']
                    _arg = {}
                    for k in arg:
                        if not k.startswith('fc7'):
                            _arg[k] = arg[k]
                    mx.model.save_checkpoint(prefix, msave, _sym, _arg, aux)
                else:
                    mx.model.save_checkpoint(prefix, msave, model.symbol, arg,
                                             aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if config.max_steps > 0 and mbatch > config.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    model.fit(
        train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=999999,
        eval_data=val_dataiter,
        eval_metric=eval_metrics,
        kvstore=args.kvstore,
        optimizer=opt,
        #optimizer_params   = optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=epoch_cb)
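With config.ckpt_embedding the saved checkpoint keeps only the network up to fc1_output, with the fc7 classifier weights stripped. Loading it back for feature extraction then follows the usual Module pattern (prefix, epoch, and the 112x112 input shape below are placeholders):

import mxnet as mx

sym, arg_params, aux_params = mx.model.load_checkpoint('model', 1)  # placeholder prefix/epoch
embed = mx.mod.Module(symbol=sym, context=[mx.cpu()], label_names=None)
embed.bind(data_shapes=[('data', (1, 3, 112, 112))], for_training=False)
embed.set_params(arg_params, aux_params)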
Example #7
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd)>0:
      for i in range(len(cvd.split(','))):
        ctx.append(mx.gpu(i))
    if len(ctx)==0:
      ctx = [mx.cpu()]
      print('use cpu')
    else:
      print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
      os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size==0:
      args.per_batch_size = 128
    args.batch_size = args.per_batch_size*args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3

    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list)==1
    data_dir = data_dir_list[0]
    path_imgrec = None
    path_imglist = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert(args.num_classes>0)
    print('num_classes', args.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    print('Called with argument:', args)
    data_shape = (args.image_channel,image_size[0],image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained)==0:
      arg_params = None
      aux_params = None
      sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    else:
      vec = args.pretrained.split(',')
      print('loading', vec)
      _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
      sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    if args.network[0]=='s':
      data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}
      spherenet.init_weights(sym, data_shape_dict, args.num_layers)

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)
    model = mx.mod.Module(
        context       = ctx,
        symbol        = sym,
    )

    train_dataiter = FaceImageIter(
        batch_size           = args.batch_size,
        data_shape           = data_shape,
        path_imgrec          = path_imgrec,
        shuffle              = True,
        rand_mirror          = args.rand_mirror,
        mean                 = mean,
        cutoff               = args.cutoff,
    )
    val_rec = os.path.join(data_dir, "val.rec")
    val_iter = None
    if os.path.exists(val_rec):
        val_iter = FaceImageIter(
            batch_size           = args.batch_size,
            data_shape           = data_shape,
            path_imgrec          = val_rec,
            shuffle              = False,
            rand_mirror          = False,
            mean                 = mean,
        )

    if args.loss_type<10:
      _metric = AccMetric()
    else:
      _metric = LossValueMetric()
    eval_metrics = []
    if USE_FR:
      _metric = AccMetric(pred_idx=1)
      eval_metrics.append(_metric)
      if USE_GENDER:
          _metric = AccMetric(pred_idx=2, name='gender')
          eval_metrics.append(_metric)
    elif USE_GENDER:
      _metric = AccMetric(pred_idx=1, name='gender')
      eval_metrics.append(_metric)
    if USE_AGE:
      _metric = MAEMetric()
      eval_metrics.append(_metric)
      _metric = CUMMetric()
      eval_metrics.append(_metric)

    if args.network[0]=='r':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    elif args.network[0]=='i' or args.network[0]=='x':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2) #inception
    else:
      initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
    _rescale = 1.0/args.ctx_num
    opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
    som = 20
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
      path = os.path.join(data_dir,name+".bin")
      if os.path.exists(path):
        data_set = verification.load_bin(path, image_size)
        ver_list.append(data_set)
        ver_name_list.append(name)
        print('ver', name)



    def ver_test(nbatch):
      results = []
      for i in range(len(ver_list)):
        acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, 10, None, None)
        print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
        #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
        print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
        results.append(acc2)
      return results

    def val_test():
      _metric = MAEMetric()
      val_metric = mx.metric.create(_metric)
      val_metric.reset()
      _metric2 = CUMMetric()
      val_metric2 = mx.metric.create(_metric2)
      val_metric2.reset()
      val_iter.reset()
      for i, eval_batch in enumerate(val_iter):
        model.forward(eval_batch, is_train=False)
        model.update_metric(val_metric, eval_batch.label)
        model.update_metric(val_metric2, eval_batch.label)
      _value = val_metric.get_name_value()[0][1]
      print('MAE: %f'%(_value))
      _value = val_metric2.get_name_value()[0][1]
      print('CUM: %f'%(_value))


    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps)==0:
      lr_steps = [40000, 60000, 80000]
      if args.loss_type>=1 and args.loss_type<=7:
        lr_steps = [100000, 140000, 160000]
      p = 512.0/args.batch_size
      for l in range(len(lr_steps)):
        lr_steps[l] = int(lr_steps[l]*p)
    else:
      lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)
    def _batch_callback(param):
      #global global_step
      global_step[0]+=1
      mbatch = global_step[0]
      for _lr in lr_steps:
        if mbatch==_lr:
          opt.lr *= 0.1
          print('lr change to', opt.lr)
          break

      _cb(param)
      if mbatch%1000==0:
        print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)

      if mbatch>=0 and mbatch%args.verbose==0:
        if val_iter is not None:
            val_test()
        acc_list = ver_test(mbatch)
        save_step[0]+=1
        msave = save_step[0]
        do_save = False
        if len(acc_list)>0:
          lfw_score = acc_list[0]
          if lfw_score>highest_acc[0]:
            highest_acc[0] = lfw_score
            if lfw_score>=0.998:
              do_save = True
          if acc_list[-1]>=highest_acc[-1]:
            highest_acc[-1] = acc_list[-1]
            if lfw_score>=0.99:
              do_save = True
        if args.ckpt==0:
          do_save = False
        elif args.ckpt>1:
          do_save = True
        if do_save:
          print('saving', msave)
          arg, aux = model.get_params()
          mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
        print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
      if args.max_steps>0 and mbatch>args.max_steps:
        sys.exit(0)

    epoch_cb = None

    model.fit(train_dataiter,
        begin_epoch        = begin_epoch,
        num_epoch          = end_epoch,
        eval_data          = None,
        eval_metric        = eval_metrics,
        kvstore            = 'device',
        optimizer          = opt,
        #optimizer_params   = optimizer_params,
        initializer        = initializer,
        arg_params         = arg_params,
        aux_params         = aux_params,
        allow_missing      = True,
        batch_end_callback = _batch_callback,
        epoch_end_callback = epoch_cb )
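The default lr_steps are tuned for an effective batch size of 512 and rescaled linearly for anything else; for example, batch_size 256 doubles every step:

batch_size = 256
p = 512.0 / batch_size
print([int(s * p) for s in [100000, 140000, 160000]])   # [200000, 280000, 320000]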
Example #8
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3

    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list) == 1
    data_dir = data_dir_list[0]
    path_imgrec = None
    path_imglist = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    if args.loss_type == 1 and args.num_classes > 20000:
        args.beta_freeze = 5000
        args.gamma = 0.06

    print('Called with argument:', args)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    else:
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    if args.network[0] == 's':
        data_shape_dict = {'data': (args.per_batch_size, ) + data_shape}
        spherenet.init_weights(sym, data_shape_dict, args.num_layers)

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)
    model = mx.mod.Module(context=ctx, symbol=sym)
    val_dataiter = None

    train_dataiter = FaceImageIter(
        batch_size  = args.batch_size,
        data_shape  = data_shape,
        path_imgrec = path_imgrec,
        shuffle     = True,
        rand_mirror = args.rand_mirror,
        mean        = mean,
        cutoff      = args.cutoff)

    if args.loss_type < 10:
        _metric = AccMetric()
    else:
        _metric = LossValueMetric()
    eval_metrics = [mx.metric.create(_metric)]

    if args.network[0] == 'r' or args.network[0] == 'y':
        initializer = mx.init.Xavier(
            rnd_type='gaussian', factor_type="out", magnitude=2)  #resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(
            rnd_type='gaussian', factor_type="in", magnitude=2)  #inception
    else:
        initializer = mx.init.Xavier(
            rnd_type='uniform', factor_type="in", magnitude=2)
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(
        learning_rate = base_lr,
        momentum      = base_mom,
        wd            = base_wd,
        rescale_grad  = _rescale)
    som = 20
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list, best_all = verification.test(
                ver_list[i], model, min(args.batch_size, 256), 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            print('[%s][%d]Best-Threshold: %1.2f  %1.5f' %
                  (ver_name_list[i], nbatch, best_all[0], best_all[1]))
            results.append(acc2)
        return results

    def highest_cmp(acc, cpt):
        assert len(acc) > 0
        if acc[0] > cpt[1]:
            return True
        elif acc[0] < cpt[1]:
            return False
        else:
            acc_sum = 0.0
            cpt_sum = 0.0
            for i in range(1, len(acc)):
                acc_sum += acc[i]
                cpt_sum += cpt[i+1]
            if acc_sum >= cpt_sum:
                return True
            else:
                return False

    highest_acc = []  # lfw and target
    for i in range(len(ver_list)):
        highest_acc.append(0.0)
    highest_cpt = [0] + highest_acc
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [40000, 60000, 80000]
        if args.loss_type >= 1 and args.loss_type <= 7:
            lr_steps = [100000, 140000, 160000]
        p = 512.0 / args.batch_size
        for l in range(len(lr_steps)):
            lr_steps[l] = int(lr_steps[l] * p)
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        global_step[0] += 1
        mbatch = global_step[0]
        for _lr in lr_steps:
            if mbatch == args.beta_freeze + _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            do_save = False
            if len(acc_list) > 0:
                if acc_list[0] > 0.997:  # lfw
                    for i in range(len(acc_list)):
                        if acc_list[i] >= highest_acc[i]:
                            do_save = True
                for i in range(len(acc_list)):
                    highest_acc[i] = max(highest_acc[i], acc_list[i])
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt > 1:
                do_save = True
            if do_save:
                save_step[0] += 1
                msave = save_step[0]
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
                if highest_cmp(acc_list, highest_cpt):
                    highest_cpt[0] = msave
                    for i, acc in enumerate(acc_list):
                        highest_cpt[i+1] = acc
            sys.stdout.write('[%d]Accuracy-Highest: ' % mbatch)
            for acc in highest_acc:
                sys.stdout.write('%1.5f  ' % acc)
            sys.stdout.write('\n')
            sys.stdout.write('[%d]Accuracy-BestCpt: (%d) ' % (mbatch, highest_cpt[0]))
            for acc in highest_cpt[1:]:
                sys.stdout.write('%1.5f  ' % acc)
            sys.stdout.write('\n')
            sys.stdout.flush()
            # print('[%d]Accuracy-Highest: %1.5f  %1.5f  %1.5f'%(mbatch, highest_acc[0], highest_acc[1], highest_acc[2]))
            # print('[%d]Accuracy-BestCPt: <%d> %1.5f  %1.5f  %1.5f' % ((mbatch,) + tuple(highest_cpt)))
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = max(0, mbatch - args.beta_freeze)
            _beta = max(args.beta_min,
                args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
        #print('beta', _beta)
        os.environ['BETA'] = str(_beta)
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    model.fit(
        train_dataiter,
        begin_epoch        = begin_epoch,
        num_epoch          = end_epoch,
        eval_data          = val_dataiter,
        eval_metric        = eval_metrics,
        kvstore            = 'device',
        optimizer          = opt,
        # optimizer_params = optimizer_params,
        initializer        = initializer,
        arg_params         = arg_params,
        aux_params         = aux_params,
        allow_missing      = True,
        batch_end_callback = _batch_callback,
        epoch_end_callback = epoch_cb)
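highest_cmp prefers a checkpoint with a higher first (LFW) accuracy and, on an exact tie, compares the summed remaining accuracies. A toy check of the tie-break branch (assuming highest_cmp is lifted out of train_net so it can be called directly):

acc = [0.998, 0.95, 0.93]     # new accuracies, LFW first
cpt = [7, 0.998, 0.94, 0.92]  # stored best: [save step, lfw, other targets...]
print(highest_cmp(acc, cpt))  # True: tie on LFW, and 0.95+0.93 > 0.94+0.92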
Example #9
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd)>0:
      for i in xrange(len(cvd.split(','))):
        ctx.append(mx.gpu(i))
    if len(ctx)==0:
      ctx = [mx.cpu()]
      print('use cpu')
    else:
      print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
      os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size==0:
      args.per_batch_size = 128
    args.batch_size = args.per_batch_size*args.ctx_num
    args.image_channel = 3

    data_dir = args.data_dir
    if args.task=='gender':
      data_dir = args.gender_data_dir
    elif args.task=='age':
      data_dir = args.age_data_dir
    print('data dir', data_dir)
    path_imgrec = None
    path_imglist = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert(args.num_classes>0)
    print('num_classes', args.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")


    print('Called with argument:', args)
    data_shape = (args.image_channel,image_size[0],image_size[1])
    mean = None

    begin_epoch = 0
    net = get_model()
    #if args.task=='':
    #  test_net = get_model_test(net)
    #print(net.__class__)
    #net = net0[0]
    if args.network[0]=='r' or args.network[0]=='y':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    elif args.network[0]=='i' or args.network[0]=='x':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2) #inception
    else:
      initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
    net.hybridize()
    if args.mode=='gluon':
      if len(args.pretrained)==0:
        pass
      else:
        net.load_params(args.pretrained, allow_missing=True, ignore_extra = True)
      net.initialize(initializer)
      net.collect_params().reset_ctx(ctx)

    val_iter = None
    if args.task=='':
      train_iter = FaceImageIter(
          batch_size           = args.batch_size,
          data_shape           = data_shape,
          path_imgrec          = path_imgrec,
          shuffle              = True,
          rand_mirror          = args.rand_mirror,
          mean                 = mean,
          cutoff               = args.cutoff,
      )
    else:
      train_iter = FaceImageIterAge(
          batch_size           = args.batch_size,
          data_shape           = data_shape,
          path_imgrec          = path_imgrec,
          task                 = args.task,
          shuffle              = True,
          rand_mirror          = args.rand_mirror,
          mean                 = mean,
          cutoff               = args.cutoff,
      )

    if args.task=='age':
      metric = CompositeEvalMetric([MAEMetric(), CUMMetric()])
    elif args.task=='gender':
      metric = CompositeEvalMetric([AccMetric()])
    else:
      metric = CompositeEvalMetric([AccMetric()])

    ver_list = []
    ver_name_list = []
    if args.task=='':
      for name in args.eval.split(','):
        path = os.path.join(data_dir,name+".bin")
        if os.path.exists(path):
          data_set = verification.load_bin(path, image_size)
          ver_list.append(data_set)
          ver_name_list.append(name)
          print('ver', name)

    def ver_test(nbatch):
      results = []
      for i in xrange(len(ver_list)):
        acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], net, ctx, batch_size = args.batch_size)
        print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
        #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
        print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
        results.append(acc2)
      return results

    def val_test(nbatch=0):
      acc = 0.0
      #if args.task=='age':
      if len(args.age_data_dir)>0:
        val_iter = FaceImageIterAge(
            batch_size           = args.batch_size,
            data_shape           = data_shape,
            path_imgrec          = os.path.join(args.age_data_dir, 'val.rec'),
            task                 = args.task,
            shuffle              = False,
            rand_mirror          = False,
            mean                 = mean,
        )
        _metric = MAEMetric()
        val_metric = mx.metric.create(_metric)
        val_metric.reset()
        _metric2 = CUMMetric()
        val_metric2 = mx.metric.create(_metric2)
        val_metric2.reset()
        val_iter.reset()
        for batch in val_iter:
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            outputs = []
            for x in data:
                outputs.append(net(x)[2])
            val_metric.update(label, outputs)
            val_metric2.update(label, outputs)
        _value = val_metric.get_name_value()[0][1]
        print('[%d][VMAE]: %f'%(nbatch, _value))
        _value = val_metric2.get_name_value()[0][1]
        if args.task=='age':
          acc = _value
        print('[%d][VCUM]: %f'%(nbatch, _value))
      if len(args.gender_data_dir)>0:
        val_iter = FaceImageIterAge(
            batch_size           = args.batch_size,
            data_shape           = data_shape,
            path_imgrec          = os.path.join(args.gender_data_dir, 'val.rec'),
            task                 = args.task,
            shuffle              = False,
            rand_mirror          = False,
            mean                 = mean,
        )
        _metric = AccMetric()
        val_metric = mx.metric.create(_metric)
        val_metric.reset()
        val_iter.reset()
        for batch in val_iter:
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            outputs = []
            for x in data:
                outputs.append(net(x)[1])
            val_metric.update(label, outputs)
        _value = val_metric.get_name_value()[0][1]
        if args.task=='gender':
          acc = _value
        print('[%d][VACC]: %f'%(nbatch, _value))
      return acc


    total_time = 0
    num_epochs = 0
    best_acc = [0]
    highest_acc = [0.0, 0.0]  #lfw and target
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps)==0:
      lr_steps = [100000, 140000, 160000]
      p = 512.0/args.batch_size
      for l in xrange(len(lr_steps)):
        lr_steps[l] = int(lr_steps[l]*p)
    else:
      lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    kv = mx.kv.create('device')
    #kv = mx.kv.create('local')
    #_rescale = 1.0/args.ctx_num
    #opt = optimizer.SGD(learning_rate=args.lr, momentum=args.mom, wd=args.wd, rescale_grad=_rescale)
    #opt = optimizer.SGD(learning_rate=args.lr, momentum=args.mom, wd=args.wd)
    if args.mode=='gluon':
      trainer = gluon.Trainer(net.collect_params(), 'sgd', 
              {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.mom, 'multi_precision': True},
              kvstore=kv)
    else:
      _rescale = 1.0/args.ctx_num
      opt = optimizer.SGD(learning_rate=args.lr, momentum=args.mom, wd=args.wd, rescale_grad=_rescale)
      _cb = mx.callback.Speedometer(args.batch_size, 20)
      arg_params = None
      aux_params = None
      data = mx.sym.var('data')
      label = mx.sym.var('softmax_label')
      if args.margin_a>0.0:
        fc7 = net(data, label)
      else:
        fc7 = net(data)
      #sym = mx.symbol.SoftmaxOutput(data=fc7, label = label, name='softmax', normalization='valid')
      ceop = gluon.loss.SoftmaxCrossEntropyLoss()
      loss = ceop(fc7, label) 
      #loss = loss/args.per_batch_size
      loss = mx.sym.mean(loss)
      sym = mx.sym.Group( [mx.symbol.BlockGrad(fc7), mx.symbol.MakeLoss(loss, name='softmax')] )

    def _batch_callback():
      mbatch = global_step[0]
      global_step[0]+=1
      for _lr in lr_steps:
        if mbatch==_lr:
          args.lr *= 0.1
          if args.mode=='gluon':
            trainer.set_learning_rate(args.lr)
          else:
            opt.lr  = args.lr
          print('lr change to', args.lr)
          break

      #_cb(param)
      if mbatch%1000==0:
        print('lr-batch-epoch:',args.lr, mbatch)

      if mbatch>0 and mbatch%args.verbose==0:
        save_step[0]+=1
        msave = save_step[0]
        do_save = False
        is_highest = False
        if args.task=='age' or args.task=='gender':
          acc = val_test(mbatch)
          if acc>=highest_acc[-1]:
            highest_acc[-1] = acc
            is_highest = True
            do_save = True
        else:
          acc_list = ver_test(mbatch)
          if len(acc_list)>0:
            lfw_score = acc_list[0]
            if lfw_score>highest_acc[0]:
              highest_acc[0] = lfw_score
              if lfw_score>=0.998:
                do_save = True
            if acc_list[-1]>=highest_acc[-1]:
              highest_acc[-1] = acc_list[-1]
              if lfw_score>=0.99:
                do_save = True
                is_highest = True
        if args.ckpt==0:
          do_save = False
        elif args.ckpt>1:
          do_save = True
        if do_save:
          print('saving', msave)
          #print('saving gluon params')
          fname = os.path.join(args.prefix, 'model-gluon.params')
          net.save_params(fname)
          fname = os.path.join(args.prefix, 'model')
          net.export(fname, msave)
          #arg, aux = model.get_params()
          #mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
        print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
      if args.max_steps>0 and mbatch>args.max_steps:
        sys.exit(0)

    def _batch_callback_sym(param):
      _cb(param)
      _batch_callback()


    if args.mode!='gluon':
      model = mx.mod.Module(
          context       = ctx,
          symbol        = sym,
      )
      model.fit(train_iter,
          begin_epoch        = 0,
          num_epoch          = args.end_epoch,
          eval_data          = None,
          eval_metric        = metric,
          kvstore            = 'device',
          optimizer          = opt,
          initializer        = initializer,
          arg_params         = arg_params,
          aux_params         = aux_params,
          allow_missing      = True,
          batch_end_callback = _batch_callback_sym,
          epoch_end_callback = None )
    else:
      loss_weight = 1.0
      if args.task=='age':
        loss_weight = 1.0/AGE
      #loss = gluon.loss.SoftmaxCrossEntropyLoss(weight = loss_weight)
      loss = nd.SoftmaxOutput
      #loss = gluon.loss.SoftmaxCrossEntropyLoss()
      while True:
          #trainer = update_learning_rate(opt.lr, trainer, epoch, opt.lr_factor, lr_steps)
          tic = time.time()
          train_iter.reset()
          metric.reset()
          btic = time.time()
          for i, batch in enumerate(train_iter):
              _batch_callback()
              #data = gluon.utils.split_and_load(batch.data[0].astype(opt.dtype), ctx_list=ctx, batch_axis=0)
              #label = gluon.utils.split_and_load(batch.label[0].astype(opt.dtype), ctx_list=ctx, batch_axis=0)
              data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
              label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
              outputs = []
              Ls = []
              with ag.record():
                  for x, y in zip(data, label):
                      #print(y.asnumpy())
                      if args.task=='':
                        if args.margin_a>0.0:
                          z = net(x,y)
                        else:
                          z = net(x)
                        #print(z[0].shape, z[1].shape)
                      else:
                        z = net(x)
                      if args.task=='gender':
                        L = loss(z[1], y)
                        #L = L/args.per_batch_size
                        Ls.append(L)
                        outputs.append(z[1])
                      elif args.task=='age':
                        for k in xrange(AGE):
                          _z = nd.slice_axis(z[2], axis=1, begin=k*2, end=k*2+2)
                          _y = nd.slice_axis(y, axis=1, begin=k, end=k+1)
                          _y = nd.flatten(_y)
                          L = loss(_z, _y)
                          #L = L/args.per_batch_size
                          #L /= AGE
                          Ls.append(L)
                        outputs.append(z[2])
                      else:
                        L = loss(z, y)
                        #L = L/args.per_batch_size
                        Ls.append(L)
                        outputs.append(z)
                      # store the loss and do backward after we have done forward
                      # on all GPUs for better speed on multiple GPUs.
                  ag.backward(Ls)
              #trainer.step(batch.data[0].shape[0], ignore_stale_grad=True)
              #trainer.step(args.ctx_num)
              n = batch.data[0].shape[0]
              #print(n,n)
              trainer.step(n)
              metric.update(label, outputs)
              if i>0 and i%20==0:
                  name, acc = metric.get()
                  if len(name)==2:
                    logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f, %s=%f'%(
                                   num_epochs, i, args.batch_size/(time.time()-btic), name[0], acc[0], name[1], acc[1]))
                  else:
                    logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f'%(
                                   num_epochs, i, args.batch_size/(time.time()-btic), name[0], acc[0]))
                  #metric.reset()
              btic = time.time()

          epoch_time = time.time()-tic

          # First epoch will usually be much slower than the subsequent epochs,
          # so don't factor it into the average
          if num_epochs > 0:
            total_time = total_time + epoch_time

          #name, acc = metric.get()
          #logger.info('[Epoch %d] training: %s=%f, %s=%f'%(num_epochs, name[0], acc[0], name[1], acc[1]))
          logger.info('[Epoch %d] time cost: %f'%(num_epochs, epoch_time))
          num_epochs = num_epochs + 1
          #name, val_acc = test(ctx, val_data)
          #logger.info('[Epoch %d] validation: %s=%f, %s=%f'%(epoch, name[0], val_acc[0], name[1], val_acc[1]))

          # save model if meet requirements
          #save_checkpoint(epoch, val_acc[0], best_acc)
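      # NOTE: the summary below is unreachable as written, because the
      # `while True` loop above only exits via sys.exit(0) inside _batch_callback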
      if num_epochs > 1:
          print('Average epoch time: {}'.format(float(total_time)/(num_epochs - 1)))
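
The age branch in this example slices the network output into AGE two-way softmax pairs, one binary decision per age bin (an ordinal-regression style head). Below is a hypothetical numpy sketch of decoding such an output back to a scalar age; the value of AGE and the sum-of-votes rule are assumptions consistent with the slicing and the CUMMetric above, not code from the original script:

import numpy as np

AGE = 100  # assumed bin count, matching the slice loop over z[2] above

def decode_age(logits):
    # logits: (batch, AGE * 2); pair k occupies columns [2k, 2k+1]
    pairs = logits.reshape(-1, AGE, 2)
    votes = pairs.argmax(axis=2)   # (batch, AGE) of 0/1 per-bin decisions
    return votes.sum(axis=1)       # predicted age = number of positive bins

rng = np.random.default_rng(0)
fake_logits = rng.normal(size=(4, AGE * 2)).astype(np.float32)
print(decode_age(fake_logits))     # four age estimates in [0, AGE]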
Example #10
0
def main(args):
    include_datasets = args.include.split(',')
    prop = face_image.load_property(include_datasets[0])
    image_size = prop.image_size
    print('image_size', image_size)
    model = None
    if len(args.model) > 0:
        ctx = []
        cvd = ''
        if 'CUDA_VISIBLE_DEVICES' in os.environ:
            cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
        if len(cvd) > 0:
            for i in xrange(len(cvd.split(','))):
                ctx.append(mx.gpu(i))
        if len(ctx) == 0:
            ctx = [mx.cpu()]
            print('use cpu')
        else:
            print('gpu num:', len(ctx))
        args.ctx_num = len(ctx)
        vec = args.model.split(',')
        prefix = vec[0]
        epoch = int(vec[1])
        print('loading', prefix, epoch)
        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

        all_layers = sym.get_internals()
        sym = all_layers['fc1_output']
        #model = mx.mod.Module.load(prefix, epoch, context = ctx)
        #model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
        model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
        model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0],
                                          image_size[1]))])
        model.set_params(arg_params, aux_params)

    rec_list = []
    for ds in include_datasets:
        path_imgrec = os.path.join(ds, 'train.rec')
        path_imgidx = os.path.join(ds, 'train.idx')
        # allow using user paths by expanding tilde
        path_imgrec = os.path.expanduser(path_imgrec)
        path_imgidx = os.path.expanduser(path_imgidx)
        # print('path_imgrec:', path_imgrec)
        imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')  # pylint: disable=redefined-variable-type
        # print('imgrec:', imgrec)
        rec_list.append(imgrec)
    id_list_map = {}
    all_id_list = []
    test_limit = 0
    for ds_id in xrange(len(rec_list)):
        id_list = []
        imgrec = rec_list[ds_id]
        # print('----',ds_id,'----')
        # print('imgrec:', imgrec)
        # print('keys:', imgrec.keys)
        s = imgrec.read_idx(0)
        # s = imgrec.read()
        header, _ = mx.recordio.unpack(s)
        # print('header:', header)
        assert header.flag > 0
        print('header0 label', header.label)
        header0 = (int(header.label[0]), int(header.label[1]))
        #assert(header.flag==1)
        imgidx = range(1, int(header.label[0]))
        id2range = {}
        # print('** get_embedding:', int(header.label[0]), int(header.label[1]))
        seq_identity = range(int(header.label[0]), int(header.label[1]))
        pp = 0
        for identity in seq_identity:
            pp += 1
            if pp % 10 == 0:
                print('processing id', pp)
            if model is not None:
                embedding = get_embedding(args, imgrec, identity, image_size,
                                          model)
            else:
                embedding = None
            #print(embedding.shape)
            id_list.append([ds_id, identity, embedding])
            if test_limit > 0 and pp >= test_limit:
                break
        id_list_map[ds_id] = id_list
        if ds_id == 0 or model is None:
            all_id_list += id_list
            print(ds_id, len(id_list))
        else:
            X = []
            for id_item in all_id_list:
                X.append(id_item[2])
            X = np.array(X)
            for i in xrange(len(id_list)):
                id_item = id_list[i]
                y = id_item[2]
                sim = np.dot(X, y.T)
                # print('sim:', sim)
                idx = np.where(
                    sim >= args.similarity_upper_threshold_include)[0]
                # we are confident this identity already exists in the set
                if len(idx) > 0:
                    continue
                idx = np.where(
                    sim >= args.similarity_lower_threshold_include)[0]
                # this identity might not exist in the set, let's manually check that
                line = []
                if len(idx) > 0:
                    # print('possible duplicate:', idx)
                    # store in the file: the current set path, the current identity,
                    # then [possible duplicate set path, possible duplicate identity] x n,
                    # where n is the number of possible duplicates found
                    line.append(include_datasets[ds_id])
                    line.append(id_item[1])
                    for duplicate_id in idx:
                        duplicate_dataset_index = all_id_list[duplicate_id][0]
                        line.append(include_datasets[duplicate_dataset_index])
                        duplicate_identity = all_id_list[duplicate_id][1]
                        line.append(duplicate_identity)
                    with open('duplicates.csv', 'a') as csv_file:
                        writer = csv.writer(csv_file, delimiter=',')
                        # print('line: ',line)
                        writer.writerows([line])
                    continue
                all_id_list.append(id_item)
            print(ds_id, len(id_list), len(all_id_list))

    if len(args.exclude) > 0:
        if os.path.isdir(args.exclude):
            _path_imgrec = os.path.join(args.exclude, 'train.rec')
            _path_imgidx = os.path.join(args.exclude, 'train.idx')
            _imgrec = mx.recordio.MXIndexedRecordIO(_path_imgidx, _path_imgrec,
                                                    'r')  # pylint: disable=redefined-variable-type
            _ds_id = len(rec_list)
            _id_list = []
            s = _imgrec.read_idx(0)
            header, _ = mx.recordio.unpack(s)
            assert header.flag > 0
            print('header0 label', header.label)
            header0 = (int(header.label[0]), int(header.label[1]))
            #assert(header.flag==1)
            imgidx = range(1, int(header.label[0]))
            seq_identity = range(int(header.label[0]), int(header.label[1]))
            pp = 0
            for identity in seq_identity:
                pp += 1
                if pp % 10 == 0:
                    print('processing ex id', pp)
                embedding = get_embedding(args, _imgrec, identity, image_size,
                                          model)
                #print(embedding.shape)
                _id_list.append((_ds_id, identity, embedding))
                if test_limit > 0 and pp >= test_limit:
                    break
        else:
            _id_list = []
            data_set = verification.load_bin(args.exclude, image_size)[0][0]
            print(data_set.shape)
            data = nd.zeros((1, 3, image_size[0], image_size[1]))
            for i in range(data_set.shape[0]):
                data[0] = data_set[i]
                db = mx.io.DataBatch(data=(data, ))
                model.forward(db, is_train=False)
                net_out = model.get_outputs()
                embedding = net_out[0].asnumpy().flatten()
                _norm = np.linalg.norm(embedding)
                embedding /= _norm
                _id_list.append((i, i, embedding))

            X = []
            for id_item in all_id_list:
                X.append(id_item[2])
            X = np.array(X)
            emap = {}
            for id_item in _id_list:
                y = id_item[2]
                sim = np.dot(X, y.T)
                idx = np.where(sim >= args.similarity_threshold_exclude)[0]
                for j in idx:
                    emap[j] = 1
                    all_id_list[j][1] = -1
            print('exclude', len(emap))

    if args.test > 0:
        return

    if not os.path.exists(args.output):
        os.makedirs(args.output)
    writer = mx.recordio.MXIndexedRecordIO(
        os.path.join(args.output, 'train.idx'),
        os.path.join(args.output, 'train.rec'), 'w')
    idx = 1
    identities = []
    nlabel = -1
    for id_item in all_id_list:
        if id_item[1] < 0:
            continue
        nlabel += 1
        ds_id = id_item[0]
        imgrec = rec_list[ds_id]
        id = id_item[1]
        s = imgrec.read_idx(id)
        header, _ = mx.recordio.unpack(s)
        a, b = int(header.label[0]), int(header.label[1])
        identities.append((idx, idx + b - a))
        for _idx in xrange(a, b):
            s = imgrec.read_idx(_idx)
            _header, _content = mx.recordio.unpack(s)
            nheader = mx.recordio.IRHeader(0, nlabel, idx, 0)
            s = mx.recordio.pack(nheader, _content)
            writer.write_idx(idx, s)
            idx += 1
    id_idx = idx
    for id_label in identities:
        _header = mx.recordio.IRHeader(1, id_label, idx, 0)
        s = mx.recordio.pack(_header, '')
        writer.write_idx(idx, s)
        idx += 1
    _header = mx.recordio.IRHeader(1, (id_idx, idx), 0, 0)
    s = mx.recordio.pack(_header, '')
    writer.write_idx(0, s)
    with open(os.path.join(args.output, 'property'), 'w') as f:
        f.write("%d,%d,%d" % (len(identities), image_size[0], image_size[1]))
Example #11
0
def train_net(args):
    ## =================== parse context ==========================
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd)>0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx)==0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))


    ## ==================== get model save prefix and log ============
    if len(args.extra_model_name)==0:
        prefix = os.path.join(args.models_root, '%s-%s-%s'%(args.network, args.loss, args.dataset), 'model')
    else:
        prefix = os.path.join(args.models_root, '%s-%s-%s-%s'%(args.network, args.loss, args.dataset, args.extra_model_name), 'model')

    prefix_dir = os.path.dirname(prefix)
    print('prefix', prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    filehandler = logging.FileHandler("{}.log".format(prefix))
    streamhandler = logging.StreamHandler()
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    ## ================ parse batch size and class info ======================
    args.ctx_num = len(ctx)
    if args.per_batch_size==0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size*args.ctx_num
    
    global_num_ctx = config.num_workers * args.ctx_num
    if config.num_classes % global_num_ctx == 0:
        args.ctx_num_classes = config.num_classes//global_num_ctx
    else:
        args.ctx_num_classes = config.num_classes//global_num_ctx+1

    args.local_num_classes = args.ctx_num_classes * args.ctx_num
    args.local_class_start = args.local_num_classes * args.worker_id

    logger.info("Train model with argument: {}\nconfig : {}".format(args, config))

    train_dataiter, val_dataiter = get_data_iter(config, args.batch_size)

    ## =============== get train info ============================
    image_size = config.image_shape[0:2]
    if len(args.pretrained) == 0: # train from scratch 
        esym = get_symbol_embedding(config)
        asym = functools.partial(get_symbol_arcface, config=config)
    else: # load train model to continue
        assert False

    if config.count_flops:
        all_layers = esym.get_internals()
        _sym = all_layers['fc1_output']
        FLOPs = flops_counter.count_flops(_sym, data=(1,3,image_size[0],image_size[1]))
        _str = flops_counter.flops_str(FLOPs)
        print('Network FLOPs: %s'%_str)
        logging.info("Network FLOPs : %s" % _str)

    if config.num_workers==1:
        #from parall_loss_module import ParallLossModule
        from parall_module_local_v1 import ParallModule
    else: # distribute parall loop
        assert False


    model = ParallModule(
        context       = ctx,
        symbol        = esym,
        data_names    = ['data'],
        label_names    = ['softmax_label'],
        asymbol       = asym,
        args = args,
        logger=logger,
    )
    

    ## ============ get optimizer =====================================
    if config.net_name=='fresnet' or config.net_name=='fmobilefacenet':
        initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    else:
        initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)

    opt = optimizer.SGD(learning_rate=args.lr, momentum=args.mom, wd=args.wd, rescale_grad=1.0/args.batch_size)
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)

    ver_list = []
    ver_name_list = []
    for name in config.val_targets:
        path = os.path.join(config.dataset_path, name+".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0]  #lfw and target
    

    global_step = [0]
    save_step = [0]
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    ## =============== batch end callback definition ===================================
    def _batch_callback(param):
        #global global_step
        global_step[0]+=1
        mbatch = global_step[0]
        for step in lr_steps:
            if mbatch==step:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                logger.info('lr change to {}'.format(opt.lr))
                break

        _cb(param)
        if mbatch%1000==0:
            print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)
            logger.info('lr-batch-epoch: {} {} {}'.format(opt.lr, param.nbatch, param.epoch))

        if mbatch>=0 and mbatch%args.verbose==0:
            acc_list = ver_test(mbatch)
            save_step[0]+=1
            msave = save_step[0]
            do_save = False
            is_highest = False
            if len(acc_list)>0:
                #lfw_score = acc_list[0]
                #if lfw_score>highest_acc[0]:
                #  highest_acc[0] = lfw_score
                #  if lfw_score>=0.998:
                #    do_save = True
                score = sum(acc_list)
                if acc_list[-1]>=highest_acc[-1]:
                    if acc_list[-1]>highest_acc[-1]:
                        is_highest = True
                    else:
                        if score>=highest_acc[0]:
                            is_highest = True
                            highest_acc[0] = score
                    highest_acc[-1] = acc_list[-1]
                    
            if is_highest:
                do_save = True
            if args.ckpt==0:
                do_save = False
            elif args.ckpt==2:
                do_save = True
            elif args.ckpt==3:
                msave = 1

            if do_save:
                print('saving', msave)
                logger.info('saving {}'.format(msave))

                arg, aux = model.get_export_params()
                all_layers = model.symbol.get_internals()
                _sym = all_layers['fc1_output']
                mx.model.save_checkpoint(prefix, msave, _sym, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
            logger.info('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))

        if config.max_steps>0 and mbatch>config.max_steps:
            sys.exit(0)

    model.fit(train_dataiter,
        begin_epoch        = 0,
        num_epoch          = 999999,
        eval_data          = val_dataiter,
        kvstore            = args.kvstore,
        optimizer          = opt,
        initializer        = initializer,
        arg_params         = None,
        aux_params         = None,
        allow_missing      = True,
        batch_end_callback = _batch_callback)
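
This example shards the classification layer across devices: every context owns ctx_num_classes output classes, obtained by rounding num_classes / (num_workers * ctx_num) up, and each worker's shard starts at local_num_classes * worker_id. The same arithmetic as a standalone sketch (the helper name and the example numbers are ours):

# Standalone sketch of the class-sharding arithmetic above.
def class_shards(num_classes, num_workers, ctx_num):
    global_num_ctx = num_workers * ctx_num
    if num_classes % global_num_ctx == 0:
        ctx_num_classes = num_classes // global_num_ctx
    else:
        ctx_num_classes = num_classes // global_num_ctx + 1
    local_num_classes = ctx_num_classes * ctx_num
    # [start, end) class range owned by each worker, as in local_class_start above
    shards = [(w * local_num_classes,
               min((w + 1) * local_num_classes, num_classes))
              for w in range(num_workers)]
    return ctx_num_classes, shards

per_ctx, shards = class_shards(num_classes=85742, num_workers=2, ctx_num=4)
print(per_ctx)  # 10718 classes held per GPU
print(shards)   # [(0, 42872), (42872, 85742)]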
Example #12
0
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd)>0:
      for i in xrange(len(cvd.split(','))):
        ctx.append(mx.gpu(i))
    if len(ctx)==0:
      ctx = [mx.cpu()]
      print('use cpu')
    else:
      print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
      os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size==0:
      args.per_batch_size = 128
      if args.loss_type==10:
        args.per_batch_size = 256
    args.batch_size = args.per_batch_size*args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3
    ppatch = [int(x) for x in args.patch.split('_')]
    assert len(ppatch)==5


    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    if args.loss_type!=12 and args.loss_type!=13:
      assert len(data_dir_list)==1
    data_dir = data_dir_list[0]
    args.use_val = False
    path_imgrec = None
    path_imglist = None
    val_rec = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)

    assert(args.num_classes>0)
    print('num_classes', args.num_classes)
    args.coco_scale = 0.5*math.log(float(args.num_classes-1))+3

    #path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2"
    path_imgrec = os.path.join(data_dir, "train.rec")
    val_rec = os.path.join(data_dir, "val.rec")
    if os.path.exists(val_rec) and args.loss_type<10:
      args.use_val = True
    else:
      val_rec = None
    #args.use_val = False

    if args.loss_type==1 and args.num_classes>20000:
      args.beta_freeze = 5000
      args.gamma = 0.06

    if args.loss_type<9:
      assert args.images_per_identity==0
    else:
      if args.images_per_identity==0:
        if args.loss_type==11:
          args.images_per_identity = 2
        elif args.loss_type==10 or args.loss_type==9:
          args.images_per_identity = 16
        elif args.loss_type==12 or args.loss_type==13:
          args.images_per_identity = 5
          assert args.per_batch_size%3==0
      assert args.images_per_identity>=2
      args.per_identities = int(args.per_batch_size/args.images_per_identity)

    print('Called with argument:', args)

    data_shape = (args.image_channel,image_size[0],image_size[1])
    mean = None




    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained)==0:
      arg_params = None
      aux_params = None
      sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    else:
      vec = args.pretrained.split(',')
      print('loading', vec)
      _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
      sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    if args.network[0]=='s':
      data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}
      spherenet.init_weights(sym, data_shape_dict, args.num_layers)

    data_extra = None
    hard_mining = False
    triplet_params = None
    coco_mode = False
    if args.loss_type==10:
      hard_mining = True
      _shape = (args.batch_size, args.per_batch_size)
      data_extra = np.full(_shape, -1.0, dtype=np.float32)
      c = 0
      while c<args.batch_size:
        a = 0
        while a<args.per_batch_size:
          b = a+args.images_per_identity
          data_extra[(c+a):(c+b),a:b] = 1.0
          #print(c+a, c+b, a, b)
          a = b
        c += args.per_batch_size
    elif args.loss_type==11:
      data_extra = np.zeros( (args.batch_size, args.per_identities), dtype=np.float32)
      c = 0
      while c<args.batch_size:
        for i in xrange(args.per_identities):
          data_extra[c+i][i] = 1.0
        c+=args.per_batch_size
    elif args.loss_type==12 or args.loss_type==13:
      triplet_params = [args.triplet_bag_size, args.triplet_alpha, args.triplet_max_ap]
    elif args.loss_type==9:
      coco_mode = True

    label_name = 'softmax_label'
    label_shape = (args.batch_size,)
    if args.output_c2c:
      label_shape = (args.batch_size,2)
    if data_extra is None:
      model = mx.mod.Module(
          context       = ctx,
          symbol        = sym,
      )
    else:
      data_names = ('data', 'extra')
      #label_name = ''
      model = mx.mod.Module(
          context       = ctx,
          symbol        = sym,
          data_names    = data_names,
          label_names   = (label_name,),
      )

    if args.use_val:
      val_dataiter = FaceImageIter(
          batch_size           = args.batch_size,
          data_shape           = data_shape,
          path_imgrec          = val_rec,
          #path_imglist         = val_path,
          shuffle              = False,
          rand_mirror          = False,
          mean                 = mean,
          ctx_num              = args.ctx_num,
          data_extra           = data_extra,
      )
    else:
      val_dataiter = None

    if len(data_dir_list)==1 and args.loss_type!=12 and args.loss_type!=13:
      train_dataiter = FaceImageIter(
          batch_size           = args.batch_size,
          data_shape           = data_shape,
          path_imgrec          = path_imgrec,
          shuffle              = True,
          rand_mirror          = args.rand_mirror,
          mean                 = mean,
          cutoff               = args.cutoff,
          c2c_threshold        = args.c2c_threshold,
          output_c2c           = args.output_c2c,
          c2c_mode             = args.c2c_mode,
          limit                = args.train_limit,
          ctx_num              = args.ctx_num,
          images_per_identity  = args.images_per_identity,
          data_extra           = data_extra,
          hard_mining          = hard_mining,
          triplet_params       = triplet_params,
          coco_mode            = coco_mode,
          mx_model             = model,
          label_name           = label_name,
      )
    else:
      iter_list = []
      for _data_dir in data_dir_list:
        _path_imgrec = os.path.join(_data_dir, "train.rec")
        _dataiter = FaceImageIter(
            batch_size           = args.batch_size,
            data_shape           = data_shape,
            path_imgrec          = _path_imgrec,
            shuffle              = True,
            rand_mirror          = args.rand_mirror,
            mean                 = mean,
            cutoff               = args.cutoff,
            c2c_threshold        = args.c2c_threshold,
            output_c2c           = args.output_c2c,
            c2c_mode             = args.c2c_mode,
            limit                = args.train_limit,
            ctx_num              = args.ctx_num,
            images_per_identity  = args.images_per_identity,
            data_extra           = data_extra,
            hard_mining          = hard_mining,
            triplet_params       = triplet_params,
            coco_mode            = coco_mode,
            mx_model             = model,
            label_name           = label_name,
        )
        iter_list.append(_dataiter)
      train_dataiter = FaceImageIterList(iter_list)

    if args.loss_type<10:
      _metric = AccMetric()
    else:
      _metric = LossValueMetric()
    eval_metrics = [mx.metric.create(_metric)]

    if args.network[0]=='r':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    elif args.network[0]=='i' or args.network[0]=='x':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2) #inception
    else:
      initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
    _rescale = 1.0/args.ctx_num
    if args.noise_sgd>0.0:
      print('use noise sgd')
      opt = NoiseSGD(scale = args.noise_sgd, learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
    else:
      opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
    som = 20
    if args.loss_type==12 or args.loss_type==13:
      som = 2
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
      path = os.path.join(data_dir,name+".bin")
      if os.path.exists(path):
        data_set = verification.load_bin(path, image_size)
        ver_list.append(data_set)
        ver_name_list.append(name)
        print('ver', name)



    def ver_test(nbatch):
      results = []
      for i in xrange(len(ver_list)):
        acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, 10, data_extra, label_shape)
        print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
        #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
        print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
        results.append(acc2)
      return results


    def val_test():
      acc = AccMetric()
      val_metric = mx.metric.create(acc)
      val_metric.reset()
      val_dataiter.reset()
      for i, eval_batch in enumerate(val_dataiter):
        model.forward(eval_batch, is_train=False)
        model.update_metric(val_metric, eval_batch.label)
      acc_value = val_metric.get_name_value()[0][1]
      print('VACC: %f'%(acc_value))


    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps)==0:
      lr_steps = [40000, 60000, 80000]
      if args.loss_type>=1 and args.loss_type<=7:
        lr_steps = [100000, 140000, 160000]
      p = 512.0/args.batch_size
      for l in xrange(len(lr_steps)):
        lr_steps[l] = int(lr_steps[l]*p)
    else:
      lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)
    def _batch_callback(param):
      #global global_step
      global_step[0]+=1
      mbatch = global_step[0]
      for _lr in lr_steps:
        if mbatch==args.beta_freeze+_lr:
          opt.lr *= 0.1
          print('lr change to', opt.lr)
          break

      _cb(param)
      if mbatch%1000==0:
        print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)

      if mbatch>=0 and mbatch%args.verbose==0:
        acc_list = ver_test(mbatch)
        save_step[0]+=1
        msave = save_step[0]
        do_save = False
        if len(acc_list)>0:
          lfw_score = acc_list[0]
          if lfw_score>highest_acc[0]:
            highest_acc[0] = lfw_score
            if lfw_score>=0.998:
              do_save = True
          if acc_list[-1]>=highest_acc[-1]:
            highest_acc[-1] = acc_list[-1]
            if lfw_score>=0.99:
              do_save = True
        if args.ckpt==0:
          do_save = False
        elif args.ckpt>1:
          do_save = True
        #for i in xrange(len(acc_list)):
        #  acc = acc_list[i]
        #  if acc>=highest_acc[i]:
        #    highest_acc[i] = acc
        #    if lfw_score>=0.99:
        #      do_save = True
        #if args.loss_type==1 and mbatch>lr_steps[-1] and mbatch%10000==0:
        #  do_save = True
        if do_save:
          print('saving', msave)
          if val_dataiter is not None:
            val_test()
          arg, aux = model.get_params()
          mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
          #if acc>=highest_acc[0]:
          #  lfw_npy = "%s-lfw-%04d" % (prefix, msave)
          #  X = np.concatenate(embeddings_list, axis=0)
          #  print('saving lfw npy', X.shape)
          #  np.save(lfw_npy, X)
        print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
      if mbatch<=args.beta_freeze:
        _beta = args.beta
      else:
        move = max(0, mbatch-args.beta_freeze)
        _beta = max(args.beta_min, args.beta*math.pow(1+args.gamma*move, -1.0*args.power))
      #print('beta', _beta)
      os.environ['BETA'] = str(_beta)
      if args.max_steps>0 and mbatch>args.max_steps:
        sys.exit(0)

    #epoch_cb = mx.callback.do_checkpoint(prefix, 1)
    epoch_cb = None



    #def _epoch_callback(epoch, sym, arg_params, aux_params):
    #  print('epoch-end', epoch)

    model.fit(train_dataiter,
        begin_epoch        = begin_epoch,
        num_epoch          = end_epoch,
        eval_data          = val_dataiter,
        eval_metric        = eval_metrics,
        kvstore            = 'device',
        optimizer          = opt,
        #optimizer_params   = optimizer_params,
        initializer        = initializer,
        arg_params         = arg_params,
        aux_params         = aux_params,
        allow_missing      = True,
        batch_end_callback = _batch_callback,
        epoch_end_callback = epoch_cb )
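
For loss_type 10 in this example, data_extra is a block mask repeated once per GPU slice: entries covering images of the same identity (runs of images_per_identity) are 1.0 and everything else is -1.0, which lets the hard-mining loss tell positives from negatives. The same construction as a tiny numpy helper (the function name is ours):

import numpy as np

def build_extra_mask(batch_size, per_batch_size, images_per_identity):
    # mirrors the while-loops that fill data_extra above
    extra = np.full((batch_size, per_batch_size), -1.0, dtype=np.float32)
    c = 0
    while c < batch_size:
        a = 0
        while a < per_batch_size:
            b = a + images_per_identity
            extra[c + a:c + b, a:b] = 1.0   # same-identity block
            a = b
        c += per_batch_size
    return extra

# two GPUs, four images per GPU, two images per identity
print(build_extra_mask(batch_size=8, per_batch_size=4, images_per_identity=2))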
Example #13
0
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in xrange(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3

    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list) == 1
    data_dir = data_dir_list[0]
    path_imgrec = None
    path_imglist = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    if args.loss_type == 1 and args.num_classes > 20000:
        args.beta_freeze = 5000
        args.gamma = 0.06

    print('Called with argument:', args)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom

    vec_s = args.pretrained_s.split(',')
    print('loading', vec_s)
    _, arg_params_s, aux_params_s = mx.model.load_checkpoint(
        vec_s[0], int(vec_s[1]))
    sym_s, arg_params_s, aux_params_s = get_symbol(args, arg_params_s,
                                                   aux_params_s)

    model_s = mx.mod.Module(context=ctx,
                            symbol=sym_s,
                            data_names=[
                                'data',
                            ],
                            label_names=['softmax_label', 'logit_t'])

    val_dataiter = None

    train_dataiter = FaceImageIter(
        args=args,
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrec=path_imgrec,
        shuffle=True,
        rand_mirror=args.rand_mirror,
        mean=mean,
        cutoff=args.cutoff,
    )

    if args.loss_type < 10:
        _metric = AccMetric()
    else:
        _metric = LossValueMetric()
    mAcc = mx.metric.create(_metric)

    if args.network[0] == 'r' or args.network[0] == 'y':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  #inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=base_lr,
                        momentum=base_mom,
                        wd=base_wd,
                        rescale_grad=_rescale)
    som = 10
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in xrange(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model_s, args.batch_size, 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [40000, 60000, 80000]
        if args.loss_type >= 1 and args.loss_type <= 7:
            lr_steps = [100000, 140000, 160000]
        p = 512.0 / args.batch_size
        for l in xrange(len(lr_steps)):
            lr_steps[l] = int(lr_steps[l] * p)
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        #global global_step
        global_step[0] += 1
        mbatch = global_step[0]
        for _lr in lr_steps:
            if mbatch == args.beta_freeze + _lr:
                opt.lr *= 0.3
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            if len(acc_list) > 0:
                lfw_score = acc_list[0]
                if lfw_score > highest_acc[0]:
                    highest_acc[0] = lfw_score
                    if lfw_score >= 0.998:
                        do_save = True
                if acc_list[-1] >= highest_acc[-1]:
                    highest_acc[-1] = acc_list[-1]
                    if lfw_score >= 0.99:
                        do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt > 1:
                do_save = True
            if do_save:
                print('saving', msave)
                arg, aux = model_s.get_params()
                mx.model.save_checkpoint(prefix, msave, model_s.symbol, arg,
                                         aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = max(0, mbatch - args.beta_freeze)
            _beta = max(
                args.beta_min,
                args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
        #print('beta', _beta)
        os.environ['BETA'] = str(_beta)
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)
    BatchEndParam = namedtuple('BatchEndParams',
                               ['epoch', 'nbatch', 'eval_metric', 'locals'])

    model_s.bind(data_shapes=[('data', (args.batch_size, 3, 112, 112))], \
                label_shapes=[('softmax_label', (args.batch_size, )), ('logit_t', (args.batch_size, args.num_classes))], \
                for_training=True)
    model_s.init_params(initializer=initializer,
                        arg_params=arg_params_s,
                        aux_params=aux_params_s,
                        allow_missing=True,
                        force_init=False)
    model_s.init_optimizer(kvstore='device', optimizer=opt)

    for epoch in range(begin_epoch, end_epoch + 1):
        tic = time.time()
        mAcc.reset()
        nbatch = 0
        data_iter = iter(train_dataiter)
        end_of_batch = False

        while not end_of_batch:
            next_data_batch = next(data_iter)
            data_batch = next_data_batch
            model_s.forward(mx.io.DataBatch(data_batch.data, data_batch.label),
                            is_train=True)
            model_s.backward()
            model_s.update()
            model_s.update_metric(mAcc, data_batch.label)
            try:
                next_data_batch = next(data_iter)
                model_s.prepare(next_data_batch)
            except StopIteration:
                end_of_batch = True
            if end_of_batch:
                eval_name_vals = mAcc.get_name_value()
            batch_end_params = BatchEndParam(epoch=epoch,
                                             nbatch=nbatch,
                                             eval_metric=mAcc,
                                             locals=locals())
            for callback in _as_list(_batch_callback):
                callback(batch_end_params)
            nbatch += 1

        for name, val in eval_name_vals:
            model_s.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
        toc = time.time()
        model_s.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        arg_params, aux_params = model_s.get_params()
        model_s.set_params(arg_params, aux_params)

        train_dataiter.reset()
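
Several of the examples above (including #12 and #13) anneal the softmax-margin lambda through the BETA environment variable: beta holds at args.beta for the first beta_freeze batches, then decays as beta * (1 + gamma * move) ** (-power), floored at beta_min. A standalone sketch of that schedule; beta_freeze=5000 and gamma=0.06 appear in the snippets, while the other values here are illustrative:

import math

def beta_at(mbatch, beta=1000.0, beta_min=5.0, beta_freeze=5000,
            gamma=0.06, power=1.0):
    # same formula as the _beta update in the batch callbacks above
    if mbatch <= beta_freeze:
        return beta
    move = mbatch - beta_freeze
    return max(beta_min, beta * math.pow(1.0 + gamma * move, -1.0 * power))

for step in (0, 5000, 5100, 6000, 10000):
    print(step, round(beta_at(step), 3))
# beta decays smoothly from 1000.0 toward the beta_min floor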
Example #14
0
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()  # e.g. '0': use the first GPU

    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))  # append the GPU context to ctx, e.g. ctx = [gpu(0)]

    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))  # running on GPU

    prefix = args.prefix  # ../model-r100
    prefix_dir = os.path.dirname(prefix)  # ..

    if not os.path.exists(prefix_dir):  # 未执行
        os.makedirs(prefix_dir)

    end_epoch = args.end_epoch  # 100 000

    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])

    print('num_layers', args.num_layers)  # 100

    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num  # 10

    args.rescale_threshold = 0
    args.image_channel = 3

    os.environ['BETA'] = str(args.beta)  # 1000.0; see Eq. (6) in the ArcFace paper: the annealing lambda

    data_dir_list = args.data_dir.split(',')
    print('data_dir_list: ', data_dir_list)

    data_dir = data_dir_list[0]

    # load the dataset properties
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    print('num_classes: ', args.num_classes)

    path_imgrec = os.path.join(data_dir, "train8631_list.rec")

    if args.loss_type == 1 and args.num_classes > 20000:  # sphereface
        args.beta_freeze = 5000
        args.gamma = 0.06

    print('***Called with argument:', args)

    data_shape = (args.image_channel, image_size[0], image_size[1])  # (3L,112L,112L)

    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd  # weight decay = 0.0005
    base_mom = args.mom  # momentum: 0.9

    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    else:
        vec = args.pretrained.split(',')  # ['../models/model-r50-am-lfw/model', '0000']
        print('***loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
        # print('sym[1]:',sym[1])
        # # mx.viz.plot_network(sym[1]).view()  # visualize the network
        # sys.exit()
    if args.network[0] == 's':  # spherenet
        data_shape_dict = {'data': (args.per_batch_size,) + data_shape}
        spherenet.init_weights(sym, data_shape_dict, args.num_layers)

    # label_name = 'softmax_label'
    # label_shape = (args.batch_size,)
    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
    )

    # print('args.batch_size:',args.batch_size)
    # print('data_shape:',data_shape)
    # print('path_imgrec:',path_imgrec)
    # print('args.rand_mirror:',args.rand_mirror)
    # print('mean:',mean)
    # print('args.cutoff:',args.cutoff)
    # sys.exit()

    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,  # (3L,112L,112L)
        path_imgrec=path_imgrec,  # train.rec
        shuffle=True,
        rand_mirror=args.rand_mirror,  # 1
        mean=mean,
        cutoff=args.cutoff,  # 0
    )

    if args.loss_type < 10:
        _metric = AccMetric()
    else:
        _metric = LossValueMetric()
    # create the evaluation metric
    eval_metrics = [mx.metric.create(_metric)]

    if args.network[0] == 'r' or args.network[0] == 'y' or args.network[0] == 'v':
        initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2)  # resnet style  mobilefacenet
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)  # inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd,
                        rescale_grad=_rescale)  # with multiple GPUs, rescale_grad averages the aggregated gradient
    # opt = optimizer.Adam(learning_rate=base_lr, wd=base_wd,rescale_grad=_rescale)
    som = 64
    # callback that periodically reports training speed and accuracy
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, 10,
                                                                               None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results


    highest_acc = [0.0, 0.0]  # lfw and target
    # for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [30000, 40000, 50000]
        if args.loss_type >= 1 and args.loss_type <= 7:
            lr_steps = [10000, 20000, 40000, 70000, 100000, 150000]
        # single GPU: drop the p rescaling
        # p = 512.0/args.batch_size
        for l in range(len(lr_steps)):
            # lr_steps[l] = int(lr_steps[l]*p)
            lr_steps[l] = int(lr_steps[l])
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        # global global_step

        mbatch = global_step[0]
        global_step[0] += 1
        for _lr in lr_steps:
            if mbatch == args.beta_freeze + _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            print(acc_list)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            if len(acc_list) > 0:
                lfw_score = acc_list[0]

                # if lfw_score > highest_acc[0]:
                # if lfw_score >= 0.50:
                #     do_save = True
                #     highest_acc[0] = lfw_score
                    # adjusted the validation threshold while probing for the best one
                    # if lfw_score>=0.998:
                if acc_list[-1] >= highest_acc[-1]:
                    highest_acc[-1] = acc_list[-1]
                    # if lfw_score>=0.99:
                    if lfw_score >= 0.90:  # save the model once LFW accuracy exceeds 0.90
                        do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt > 1:
                do_save = True
            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)

            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = max(0, mbatch - args.beta_freeze)
            _beta = max(args.beta_min, args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
        # print('beta', _beta)
        os.environ['BETA'] = str(_beta)
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)
    print('data fit...........')
    model.fit(train_data=train_dataiter,
              begin_epoch=begin_epoch,
              num_epoch=end_epoch,
              eval_data=None,
              eval_metric=eval_metrics,
              kvstore='device',
              optimizer=opt,
              # optimizer_params = optimizer_params,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              allow_missing=True,
              batch_end_callback=_batch_callback,
              epoch_end_callback=epoch_cb)
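
# Note: the _batch_callback above anneals the margin multiplier through the
# BETA environment variable, following the lambda-decay annealing schedule of
# SphereFace/ArcFace. A minimal standalone sketch of the same computation
# (the default values here are illustrative, not taken from the snippet):
import math

def annealed_beta(mbatch, beta=1000.0, beta_min=5.0, beta_freeze=0,
                  gamma=0.12, power=1.0):
    # Before beta_freeze steps the multiplier keeps its initial value; after
    # that it decays as beta * (1 + gamma * t)^(-power), floored at beta_min.
    if mbatch <= beta_freeze:
        return beta
    move = max(0, mbatch - beta_freeze)
    return max(beta_min, beta * math.pow(1 + gamma * move, -1.0 * power))

for step in (0, 1000, 10000, 100000):
    print(step, annealed_beta(step))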
Пример #15
0
def train_net(args):
    # Set up kvstore
    kv = mx.kvstore.create(args.kv_store)
    if args.gc_type != 'none':
        kv.set_gradient_compression({
            'type': args.gc_type,
            'threshold': args.gc_threshold
        })

    # logging
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    logging.info('start with arguments %s', args)

    # Get ctx according to num_gpus, gpu id start from 0
    ctx = [mx.cpu()] if args.num_gpus is None or args.num_gpus == 0 else [
        mx.gpu(i) for i in range(args.num_gpus)
    ]

    # model prefix, In UAI Platform, should be /data/output/xxx
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3

    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list) == 1
    data_dir = data_dir_list[0]
    path_imgrec = None
    path_imglist = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    #image_size = prop.image_size
    image_size = [int(x) for x in args.image_size.split(',')]
    assert len(image_size) == 2
    assert image_size[0] == image_size[1]
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")
    path_imglist = os.path.join(data_dir, "train.lst")

    num_samples = 0
    for line in open(path_imglist):
        num_samples += 1

    print('Called with argument:', args)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
        if args.network[0] == 's':
            data_shape_dict = {'data': (args.per_batch_size, ) + data_shape}
            spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:
        # Note: the model is saved every epoch, not every NUM steps as in train_softmax.py
        # args.pretrained should be 'prefix,epoch'
        vec = args.pretrained.split(',')
        print('loading', vec)
        model_prefix = vec[0]
        if kv.rank > 0 and os.path.exists("%s-%d-symbol.json" %
                                          (model_prefix, kv.rank)):
            model_prefix += "-%d" % (kv.rank)
        logging.info('Loaded model %s_%d.params', model_prefix, int(vec[1]))
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_prefix, int(vec[1]))
        begin_epoch = int(vec[1])
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)

    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
    )
    val_dataiter = None

    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrec=path_imgrec,
        shuffle=True,
        rand_mirror=args.rand_mirror,
        mean=mean,
        cutoff=args.cutoff,
        color_jittering=args.color,
        images_filter=args.images_filter,
    )

    metric1 = AccMetric()
    eval_metrics = [mx.metric.create(metric1)]
    if args.ce_loss:
        metric2 = LossValueMetric()
        eval_metrics.append(mx.metric.create(metric2))

    if args.network[0] == 'r' or args.network[0] == 'y':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  #inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    #initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    som = 20
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in xrange(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model, args.batch_size, 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0]  #lfw and target

    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)

    def _batch_callback(param):
        #global global_step
        mbatch = param.nbatch

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            is_highest = False
            if len(acc_list) > 0:
                score = sum(acc_list)
                if acc_list[-1] >= highest_acc[-1]:
                    if acc_list[-1] > highest_acc[-1]:
                        is_highest = True
                    else:
                        if score >= highest_acc[0]:
                            is_highest = True
                            highest_acc[0] = score
                    highest_acc[-1] = acc_list[-1]
                    #if lfw_score>=0.99:
                    #  do_save = True

            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))

    # save model
    checkpoint = _save_model(args, kv.rank)
    epoch_cb = checkpoint

    rescale = 1.0 / args.ctx_num
    lr, lr_scheduler = _get_lr_scheduler(args, kv, begin_epoch, num_samples)
    # learning rate
    optimizer_params = {
        'learning_rate': lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler,
        'multi_precision': True,
        'rescale_grad': rescale
    }
    # Only a limited number of optimizers have 'momentum' property
    has_momentum = {'sgd', 'dcasgd', 'nag'}
    if args.optimizer in has_momentum:
        optimizer_params['momentum'] = args.mom

    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    print('Start training')
    model.fit(train_dataiter,
              begin_epoch=begin_epoch,
              num_epoch=end_epoch,
              eval_data=val_dataiter,
              eval_metric=eval_metrics,
              kvstore=kv,
              optimizer=args.optimizer,
              optimizer_params=optimizer_params,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              allow_missing=True,
              batch_end_callback=_batch_callback,
              epoch_end_callback=epoch_cb)
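
# Note: this example calls helpers _save_model and _get_lr_scheduler that are
# defined elsewhere in the file. A plausible sketch of the scheduler, assuming
# epoch-based decay points passed as an args.lr_step_epochs string and an
# args.lr_factor decay rate (both argument names are assumptions):
import mxnet as mx

def _get_lr_scheduler(args, kv, begin_epoch, num_samples):
    # Convert epoch-based decay points into global batch counts and build a
    # MultiFactorScheduler; steps already passed at resume time are skipped.
    epoch_size = max(num_samples // args.batch_size // kv.num_workers, 1)
    step_epochs = [int(e) for e in args.lr_step_epochs.split(',')]
    steps = [epoch_size * (e - begin_epoch) for e in step_epochs
             if e - begin_epoch > 0]
    if len(steps) == 0:
        return (args.lr, None)
    scheduler = mx.lr_scheduler.MultiFactorScheduler(step=steps,
                                                     factor=args.lr_factor)
    return (args.lr, scheduler)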
Пример #16
0
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd)>0:
      for i in xrange(len(cvd.split(','))):
        ctx.append(mx.gpu(i))
    if len(ctx)==0:
      ctx = [mx.cpu()]
      print('use cpu')
    else:
      print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
      os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size==0:
      args.per_batch_size = 128
    args.batch_size = args.per_batch_size*args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3

    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list)==1
    data_dir = data_dir_list[0]
    path_imgrec = None
    path_imglist = None
    args.num_classes = 0
    image_size = (112,112)
    if os.path.exists(os.path.join(data_dir, 'property')):
      prop = face_image.load_property(data_dir)
      args.num_classes = prop.num_classes
      image_size = prop.image_size
      assert(args.num_classes>0)
      print('num_classes', args.num_classes)
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    path_imgrec = os.path.join(data_dir, "train.rec")

    print('Called with argument:', args)
    data_shape = (args.image_channel,image_size[0],image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained)==0:
      arg_params = None
      aux_params = None
      sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    else:
      vec = args.pretrained.split(',')
      print('loading', vec)
      _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
      sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    if args.network[0]=='s':
      data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}
      spherenet.init_weights(sym, data_shape_dict, args.num_layers)

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)
    model = mx.mod.Module(
        context       = ctx,
        symbol        = sym,
    )

    train_dataiter = FaceImageIter(
        batch_size           = args.batch_size,
        data_shape           = data_shape,
        path_imgrec          = path_imgrec,
        shuffle              = True,
        rand_mirror          = args.rand_mirror,
        mean                 = mean,
        cutoff               = args.cutoff,
    )
    val_rec = os.path.join(data_dir, "val.rec")
    val_iter = None
    if os.path.exists(val_rec):
        val_iter = FaceImageIter(
            batch_size           = args.batch_size,
            data_shape           = data_shape,
            path_imgrec          = val_rec,
            shuffle              = False,
            rand_mirror          = False,
            mean                 = mean,
        )

    eval_metrics = []
    if USE_FR:
      _metric = AccMetric(pred_idx=1, label_idx=0)
      eval_metrics.append(_metric)
      if USE_GENDER:
          _metric = AccMetric(pred_idx=2, label_idx=1, name='gender')
          eval_metrics.append(_metric)
    elif USE_GENDER:
      _metric = AccMetric(pred_idx=1, label_idx=1, name='gender')
      eval_metrics.append(_metric)
    if USE_AGE:
      _metric = MAEMetric()
      eval_metrics.append(_metric)
      _metric = CUMMetric()
      eval_metrics.append(_metric)

    if args.network[0]=='r':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    elif args.network[0]=='i' or args.network[0]=='x':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2) #inception
    else:
      initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
    _rescale = 1.0/args.ctx_num
    opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
    #opt = optimizer.Nadam(learning_rate=base_lr, wd=base_wd, rescale_grad=_rescale)
    som = 20
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
      path = os.path.join(data_dir,name+".bin")
      if os.path.exists(path):
        data_set = verification.load_bin(path, image_size)
        ver_list.append(data_set)
        ver_name_list.append(name)
        print('ver', name)



    def ver_test(nbatch):
      results = []
      for i in xrange(len(ver_list)):
        acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, 10, None, None)
        print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
        #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
        print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
        results.append(acc2)
      return results

    def val_test():
      _metric = MAEMetric()
      val_metric = mx.metric.create(_metric)
      val_metric.reset()
      _metric2 = CUMMetric()
      val_metric2 = mx.metric.create(_metric2)
      val_metric2.reset()
      val_iter.reset()
      for i, eval_batch in enumerate(val_iter):
        model.forward(eval_batch, is_train=False)
        model.update_metric(val_metric, eval_batch.label)
        model.update_metric(val_metric2, eval_batch.label)
      _value = val_metric.get_name_value()[0][1]
      print('MAE: %f'%(_value))
      _value = val_metric2.get_name_value()[0][1]
      print('CUM: %f'%(_value))


    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps)==0:
      lr_steps = [40000, 60000, 80000]
      if args.loss_type>=1 and args.loss_type<=7:
        lr_steps = [100000, 140000, 160000]
      p = 512.0/args.batch_size
      for l in xrange(len(lr_steps)):
        lr_steps[l] = int(lr_steps[l]*p)
    else:
      lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)
    def _batch_callback(param):
      #global global_step
      global_step[0]+=1
      mbatch = global_step[0]
      for _lr in lr_steps:
        if mbatch==_lr:
          opt.lr *= 0.1
          print('lr change to', opt.lr)
          break

      _cb(param)
      if mbatch%1000==0:
        print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)

      if mbatch>=0 and mbatch%args.verbose==0:
        if val_iter is not None:
            val_test()
        acc_list = ver_test(mbatch)
        save_step[0]+=1
        msave = save_step[0]
        do_save = False
        if len(acc_list)>0:
          lfw_score = acc_list[0]
          if lfw_score>highest_acc[0]:
            highest_acc[0] = lfw_score
            if lfw_score>=0.998:
              do_save = True
          if acc_list[-1]>=highest_acc[-1]:
            highest_acc[-1] = acc_list[-1]
            if lfw_score>=0.99:
              do_save = True
        if args.ckpt==0:
          do_save = False
        elif args.ckpt>1:
          do_save = True
        if do_save:
          print('saving', msave)
          arg, aux = model.get_params()
          mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
        print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
      if args.max_steps>0 and mbatch>args.max_steps:
        sys.exit(0)

    epoch_cb = None

    model.fit(train_dataiter,
        begin_epoch        = begin_epoch,
        num_epoch          = end_epoch,
        eval_data          = None,
        eval_metric        = eval_metrics,
        kvstore            = 'device',
        optimizer          = opt,
        #optimizer_params   = optimizer_params,
        initializer        = initializer,
        arg_params         = arg_params,
        aux_params         = aux_params,
        allow_missing      = True,
        batch_end_callback = _batch_callback,
        epoch_end_callback = epoch_cb )
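
# Note: every variant above gates checkpointing on args.ckpt in the same way.
# A condensed sketch of that policy (thresholds copied from the code above;
# the function name is only illustrative):
def should_save(ckpt, acc_list, highest_acc):
    # ckpt == 0 disables saving and ckpt > 1 always saves; otherwise save only
    # when a verification score clears its threshold.
    do_save = False
    if len(acc_list) > 0:
        lfw_score = acc_list[0]
        if lfw_score > highest_acc[0] and lfw_score >= 0.998:
            do_save = True
        if acc_list[-1] >= highest_acc[-1] and lfw_score >= 0.99:
            do_save = True
    if ckpt == 0:
        do_save = False
    elif ckpt > 1:
        do_save = True
    return do_save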
Пример #17
0
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd)>0:
      for i in xrange(len(cvd.split(','))):
        ctx.append(mx.gpu(i))
    if len(ctx)==0:
      ctx = [mx.cpu()]
      print('use cpu')
    else:
      print('gpu num:', len(ctx))
    prefix = os.path.join(args.models_root, '%s-%s-%s'%(args.network, args.loss, args.dataset), 'model')
    prefix_dir = os.path.dirname(prefix)
    print('prefix', prefix)
    if not os.path.exists(prefix_dir):
      os.makedirs(prefix_dir)
    args.ctx_num = len(ctx)  #GPU num
    args.batch_size = args.per_batch_size*args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = config.image_shape[2]
    config.batch_size = args.batch_size
    config.per_batch_size = args.per_batch_size

    data_dir = config.dataset_path
    path_imgrec = None
    path_imglist = None
    image_size = config.image_shape[0:2]
    assert len(image_size)==2
    assert image_size[0]==image_size[1]
    print('image_size', image_size)
    print('num_classes', config.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    print('Called with argument:', args, config)
    data_shape = (args.image_channel,image_size[0],image_size[1]) # chw
    mean = None #[127.5,127.5,127.5]
    


    begin_epoch = 0
    if len(args.pretrained)==0:
      arg_params = None
      aux_params = None
      sym = get_symbol(args)  
      if config.net_name=='spherenet':
        data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}
        spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:  # load from a pretrained model; here sym comes from two_sym(args) rather than get_symbol(args)
      sym, sym_high, arg_params, aux_params, t_arg_params, t_aux_params = two_sym(args)
      d_sym = discriminator(args)

    config.count_flops = False  # me: disabled, so the FLOPs block below is skipped
    if config.count_flops:
      all_layers = sym.get_internals()
      _sym = all_layers['fc1_output']  # fc1: the 128-d embedding output, used to measure FLOPs
      FLOPs = flops_counter.count_flops(_sym, data=(1,3,image_size[0],image_size[1]))
      _str = flops_counter.flops_str(FLOPs)
      print('Network FLOPs: %s'%_str)

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)

    val_dataiter = None

    if config.loss_name.find('triplet')>=0:
      from triplet_image_iter import FaceImageIter
      triplet_params = [config.triplet_bag_size, config.triplet_alpha, config.triplet_max_ap]
      train_dataiter = FaceImageIter(
          batch_size           = args.batch_size,
          data_shape           = data_shape,
          path_imgrec          = path_imgrec,
          shuffle              = True,
          rand_mirror          = config.data_rand_mirror,
        #   rand_resize          = True, # me: added to vary image resolution
          mean                 = mean,
          cutoff               = config.data_cutoff,
          ctx_num              = args.ctx_num,
          images_per_identity  = config.images_per_identity,
          triplet_params       = triplet_params,
          mx_model             = model,
      )
      _metric = LossValueMetric()
      eval_metrics = [mx.metric.create(_metric)]
    else:
      from distribute_image_iter import FaceImageIter

      train_dataiter_low = FaceImageIter(  # yields (img, label) batches; cf. train_dataiter_high
          batch_size           = args.batch_size,
          data_shape           = data_shape,
          path_imgrec          = path_imgrec,
          shuffle              = True,
          rand_mirror          = config.data_rand_mirror, #true
          rand_resize          = True, # me: added to vary image resolution
          mean                 = mean,
          cutoff               = config.data_cutoff,  #0
          color_jittering      = config.data_color,  #0
          images_filter        = config.data_images_filter, #0
      )
      source_imgrec = os.path.join("/home/svt/mxnet_recognition/dataes/faces_glintasia","train.rec")
      data2 = FaceImageIter(  # yields (img, label) batches from the source dataset
          batch_size           = args.batch_size,
          data_shape           = data_shape,
          path_imgrec          = source_imgrec,
          shuffle              = True,
          rand_mirror          = config.data_rand_mirror, #true
          rand_resize          = False, # me: added to vary image resolution
          mean                 = mean,
          cutoff               = config.data_cutoff,  #0
          color_jittering      = config.data_color,  #0
          images_filter        = config.data_images_filter, #0
      )
      metric1 = AccMetric()  # accuracy metric
      eval_metrics = [mx.metric.create(metric1)]
      if config.ce_loss:  #is True
        metric2 = LossValueMetric()  # loss-value metric
        eval_metrics.append( mx.metric.create(metric2) )

    if config.net_name=='fresnet' or config.net_name=='fmobilefacenet':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    else:
      initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
    #initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    _rescale = 1.0/args.ctx_num
    #opt = optimizer.SGD(learning_rate=args.lr, momentum=args.mom, wd=args.wd, rescale_grad=_rescale)
    opt = optimizer.Adam(learning_rate=0.0001, beta1=0.5, beta2=0.9, epsilon=1e-08)
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)

    ver_list = []
    ver_name_list = []
    for name in config.val_targets:
      path = os.path.join(data_dir,name+".bin")
      if os.path.exists(path):
        data_set = verification.load_bin(path, image_size)
        ver_list.append(data_set)
        ver_name_list.append(name)
        print('ver', name)



    def ver_test(nbatch):
      results = []
      for i in xrange(len(ver_list)):
        acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, 10, None, None)
        print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
        #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
        print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
        results.append(acc2)
      return results



    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    high_save = 0 #  me  add
    me_msave = [0]  # me: save counter kept in a list so the callback can mutate it
    print('lr_steps', lr_steps)
    def _batch_callback(param):
      #global global_step
      global_step[0]+=1
      mbatch = global_step[0]
      for step in lr_steps:
        if mbatch==step:
          opt.lr *= 0.1
          print('lr change to', opt.lr)
          break

      _cb(param)
      if mbatch%1000==0:
        print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)
      
      if mbatch %4000==0:  # fc7_save
          name=os.path.join(args.models_root, '%s-%s-%s'%(args.network, args.loss, args.dataset), 'modelfc7')
          arg, aux = model.get_params()
          mx.model.save_checkpoint(name, param.epoch, model.symbol, arg, aux)
          print('save model include fc7 layer')
          print("mbatch",mbatch)
      
      if mbatch>=0 and mbatch%args.verbose==0:  # default.verbose = 2000
        acc_list = ver_test(mbatch)
        save_step[0]+=1
        msave = save_step[0]  # with batch size 512, one epoch is about 1300 batches
        me_msave[0] += 1
        do_save = False
        is_highest = False
        #me add
        save2 = False
        if len(acc_list)>0:
          lfw_score = acc_list[0]
          if lfw_score>highest_acc[0]:
            highest_acc[0] = lfw_score
            if lfw_score>=0.9960:
              save2 = True
              
          score = sum(acc_list)
          if acc_list[-1]>=highest_acc[-1]:
            if acc_list[-1]>highest_acc[-1]:
              is_highest = True
            else:
              if score>=highest_acc[0]:
                is_highest = True
                highest_acc[0] = score
            highest_acc[-1] = acc_list[-1]
            #if lfw_score>=0.99:
            #  do_save = True
        # if is_highest:
          # do_save = True
        if args.ckpt==0:
          do_save = False
        elif args.ckpt==2:
          do_save = True
        elif args.ckpt==3 and is_highest:  # me: added "and is_highest"
          do_save = True  # always save the best-LFW model; the constant epoch tag high_save makes a better model overwrite the previous save

        if do_save:  # save the best parameters
          print('saving high pretrained-epoch always:  ', high_save)
          arg, aux = model.get_params()
          if config.ckpt_embedding:  #true
            all_layers = model.symbol.get_internals()
            _sym = all_layers['fc1_output']
            _arg = {}
            for k in arg:
              if not k.startswith('fc7'):  # skip params of the fc7 classification layer (not saved)
                _arg[k] = arg[k]
            mx.model.save_checkpoint(prefix, high_save, _sym, _arg, aux)  # the checkpoint keeps only the layers up to fc1 (the 128-d embedding)
          else:
            mx.model.save_checkpoint(prefix, high_save, model.symbol, arg, aux)
          print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
          
        if save2:
          arg, aux = model.get_params()
          if config.ckpt_embedding:  #true
            all_layers = model.symbol.get_internals()
            _sym = all_layers['fc1_output']
            _arg = {}
            for k in arg:
              if not k.startswith('fc7'):  # skip params of the fc7 classification layer (not saved)
                _arg[k] = arg[k]
            mx.model.save_checkpoint(prefix, me_msave[0], _sym, _arg, aux)  # the checkpoint keeps only the layers up to fc1 (the 128-d embedding)
          else:
            mx.model.save_checkpoint(prefix, me_msave[0], model.symbol, arg, aux)
          print("save pretrained-epoch: param.epoch, me_msave:", param.epoch, me_msave[0])
          print('[%d]LFW Accuracy>=0.9960: %1.5f'%(mbatch, highest_acc[-1]))  # mbatch runs from 0 to ~13000 over one epoch
    
      if config.max_steps>0 and mbatch>config.max_steps:
        sys.exit(0)
        
    ###########################################################################
   
    
    
    epoch_cb = None
    train_dataiter_low = mx.io.PrefetchingIter(train_dataiter_low)  # multi-threaded prefetching iterator
    data2 = mx.io.PrefetchingIter(data2)  # multi-threaded prefetching iterator

    # workflow: build the model, bind data/label shapes (this allocates device memory),
    # initialize the network params, then train
    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(step=[100, 200, 300], factor=0.1)
    optimizer_params = {'learning_rate':0.01,
                    'momentum':0.9,
                    'wd':0.0005,
                    # 'lr_scheduler':lr_scheduler,
                    "rescale_grad":_rescale}  #���ݶȽ�����ƽ��
    ######################################################################
    # # teacher network
    data_shapes = [('data', (args.batch_size, 3, 112, 112))]  #teacher model only need data, no label 
    t_module = mx.module.Module(symbol=sym_high, context=ctx, label_names=[])
    t_module.bind(data_shapes=data_shapes, for_training=False, grad_req='null')
    t_module.set_params(arg_params=t_arg_params, aux_params=t_aux_params)
    t_model=t_module
    ######################################################################
    ## student network
    label_shapes = [('softmax_label', (args.batch_size, ))]
    model = mx.mod.Module(
        context     = ctx,
        symbol      = sym,
        label_names = []
        # data_names =  # defaults are 'data' and 'softmax_label'; extra label names must be passed in explicitly
    )
    # the student network trains on data plus labels;
    # the teacher network takes data only (no labels), and its outputs serve as the labels
    # print (train_dataiter_low.provide_data)
    # print ((train_dataiter_low.provide_label))
    #opt_d = optimizer.SGD(learning_rate=args.lr*0.01, momentum=args.mom, wd=args.wd, rescale_grad=_rescale) ##lr e-5
    opt_d = optimizer.Adam(learning_rate=0.0001, beta1=0.5, beta2=0.9, epsilon=1e-08)
    model.bind(data_shapes=data_shapes,for_training=True)  # no label_shapes: the student is bound without labels
    model.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
                         allow_missing=True)  # allow_missing=True: missing params are filled in by the initializer
    # model.init_optimizer(kvstore=args.kvstore,optimizer='sgd', optimizer_params=(optimizer_params))
    model.init_optimizer(kvstore=args.kvstore,optimizer=opt_d)
    # metric = eval_metrics  # accuracy metrics, as a list
    ##########################################################################
    ## discriminator network
    # create the module, then bind it
    model_d = mx.module.Module(symbol=d_sym, context=ctx,data_names=['data'], label_names=['softmax_label'])
    data_shapes = [('data', (args.batch_size*2,512))]
    label_shapes = [('softmax_label', (args.batch_size*2,))]  # bind fixes the batch size here; it could also be bound lazily at use time
    model_d.bind(data_shapes=data_shapes,label_shapes = label_shapes,inputs_need_grad=True)
    model_d.init_params(initializer=initializer)
    model_d.init_optimizer(kvstore=args.kvstore,optimizer=opt)  # optimizer; adjust as needed  # lr 1e-3
    ## below: evaluation metrics for the discriminator
    metric_d = AccMetric_d()  # accuracy metric; AccMetric_d was added in metric.py and reads the softmax output
    eval_metrics_d = [mx.metric.create(metric_d)]
    metric2_d = LossValueMetric_d()  # loss-value metric; LossValueMetric_d in metric.py reads the cross-entropy output
    eval_metrics_d.append( mx.metric.create(metric2_d) )
    metric_d = eval_metrics_d  # mx.metric.create('acc')  # the discriminator symbol has a single softmax output

    global_step=[0]
    batch_num=[0]
    resize_acc=[0]
    for epoch in range(0, 40):
        # if epoch==1 or epoch==2 or epoch==3:
        #     model.init_optimizer(kvstore=args.kvstore,optimizer='sgd', optimizer_params=(optimizer_params))
        if not isinstance(metric_d, mx.metric.EvalMetric):  # wrap into an EvalMetric if it is not one already
            metric_d = mx.metric.create(metric_d)
        # metric_d = mx.metric.create(metric_d)
        metric_d.reset()
        train_dataiter_low.reset()
        data2.reset()
        print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")

        data_iter = iter(train_dataiter_low)
        data2_iter = iter(data2)
        data_len=0
        for batch in data_iter:  # batch is high
            ## 1. run the teacher with is_train=False and the student with is_train=True,
            ##    concatenate the two outputs, and label them 1 (teacher) / 0 (student)
            #### the teacher features, given label 1, serve as the "real" data
            data_len +=len(batch.data[0])
            
            if len(batch.data[0])<args.batch_size:  # batch.data[0] is the whole batch
                print ("last batch is smaller than batch_size, stopping")
                print ("data_len:",data_len)
                break
            if data_len >=2830147: # 2830147: size of the target-domain dataset
                print ("one pass over the data finished")
                break

            batch2 = data2_iter.next()
            t_model.forward(batch2, is_train=False)  # source (high) data next to the low data; the two datasets may differ in size
            t_feat = t_model.get_outputs() # type list; the only output is fc1
            
            # print(batch.data[0].grad is None)  # not None; batch.data[0].detach().grad is None
            ## the inputs were not bound with input gradients; bind can request them,
            ## and detach() marks an array as excluded from gradient computation
            ## batch.data[0] returns lists [batch_data] [label], i.e. [array[bchw]] [array[0 1 ...]]
            ## the student network produces the adversarial ("fake") samples
            model.forward(batch,is_train=True)  ## fc1 output
            g_feat = model.get_outputs()    # from get_symbol; the final computed value is just the fc1 output
            label_t = nd.ones((args.batch_size,)) #1
            label_g = nd.zeros((args.batch_size,)) #0
            ## concatenate the teacher and student halves
            label_concat = nd.concat(label_t,label_g,dim=0)
            feat_concat = nd.concat(t_feat[0],g_feat[0],dim=0) # merge the two ndarrays; nd.L2Normalization is not needed here
            
            ### 2.1 train D on the concatenated data (gradient update, step two;
            ##      is_train=True computes gradients w.r.t. the input data, False does not;
            ##      the detached input here needs no gradient)
            feat_data = mx.io.DataBatch([feat_concat.detach()], [label_concat])
            model_d.forward(feat_data, is_train=True)  # discriminator forward pass / loss
            model_d.backward()
            # print(feat_data.data[0].grad is None)  #is None
            ## copy out the discriminator's gradients
            gradD = [[grad.copyto(grad.context) for grad in grads] for grads in model_d._exec_group.grad_arrays]
            model_d.update()   ## gradient update
            model_d.update_metric(metric_d, [label_concat])
            
            
            ### 2.2 feed the student features alone, relabeled as "real", to obtain the input
            ###     gradient to pass back to the student network (input is batch-sized)
            label_g = nd.ones((args.batch_size,))  # labels set to 1

            feat_data = mx.io.DataBatch([g_feat[0]], [label_g])  #have input grad
            model_d.forward(feat_data, is_train=True)  # is_train=True so the input gradient is computed
            model_d.backward() ## gradients are not accumulated; another forward would overwrite the previous result


            #### 3. propagate D's input gradient back into the student (generator) network
            g_grad=model_d.get_input_grads()
            model.backward(g_grad)
            model.update()

            ## training flow: feed the s/t outputs into the discriminator and update it; then push
            ## the student output through D, take the cross-entropy loss, and pass the gradient back.
            ## D's input is the concatenation of teacher and student outputs, labeled 1 and 0
            
            # gan_label = [nd.empty((args.batch_size*2,2))]  # (batch*2,2): 0/1 labels sized for the two models' concatenated outputs
            # discrim_data = [nd.empty((args.batch_size*2,512))]  # (batch*2,512)
            # print (gan_label[0].shape)



            lr_steps = [int(x) for x in args.lr_steps.split(',')]
            global_step[0]+=1
            batch_num[0]+=1
            mbatch = global_step[0]
            for step in lr_steps:
                if mbatch==step:
                    opt.lr *= 0.1
                    opt_d.lr*=0.1
                    print('opt.lr ,opt_d.lr lr change to', opt.lr,opt_d.lr)
                    break
            
            if mbatch %200==0 and mbatch >0:
                print('epoch %d, Training %s' % (epoch, metric_d.get()))

            if mbatch %1000==0 and mbatch >0: 
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, epoch, model.symbol, arg, aux)
                
                arg, aux = model_d.get_params()
                mx.model.save_checkpoint(prefix+"discriminator", epoch, model_d.symbol, arg, aux)
                
                top1,top10 = my_top(epoch)
                yidong_test_top1, yidong_test_top10 = my_top_yidong_test(epoch)
                if top1 >= resize_acc[0]:
                    resize_acc[0]=top1
                    # save the best parameters
                    arg, aux = model.get_params()
                    all_layers = model.symbol.get_internals()
                    _sym = all_layers['fc1_output']
                    _arg = {}
                    for k in arg:
                      if not k.startswith('fc7'):  # skip params of the fc7 classification layer (not saved)
                        _arg[k] = arg[k]
                    mx.model.save_checkpoint(prefix+"_best", 1, _sym, _arg, aux)  
                    acc_list = ver_test(mbatch)
                    if len(acc_list)>0:
                        print ("LFW acc is :",acc_list[0])
 
                print("batch_num",batch_num[0],"epoch",epoch, "lr ",opt.lr)
                print('epoch %d, Training %s' % (epoch, metric_d.get()))
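
# Note: the loop above trains a discriminator to separate teacher embeddings
# (label 1) from student embeddings (label 0), then relabels the student
# features as "real" so that D's input gradient pushes the student toward the
# teacher. A condensed sketch of one such step, assuming the three modules are
# bound and initialized as above (batch handling simplified):
import mxnet as mx
import mxnet.ndarray as nd

def adversarial_step(t_model, model, model_d, batch, batch_size):
    # Teacher forward only, no gradients; its embeddings are the "real" samples.
    t_model.forward(batch, is_train=False)
    t_feat = t_model.get_outputs()[0]
    # Student forward in training mode; its embeddings are the "fake" samples.
    model.forward(batch, is_train=True)
    g_feat = model.get_outputs()[0]
    # 1) Update D on detached features so no gradient reaches the student.
    labels = nd.concat(nd.ones((batch_size,)), nd.zeros((batch_size,)), dim=0)
    feats = nd.concat(t_feat, g_feat, dim=0)
    model_d.forward(mx.io.DataBatch([feats.detach()], [labels]), is_train=True)
    model_d.backward()
    model_d.update()
    # 2) Update the student: relabel its features as "real" (1) and backprop
    #    D's input gradient through the student network.
    model_d.forward(mx.io.DataBatch([g_feat], [nd.ones((batch_size,))]),
                    is_train=True)
    model_d.backward()
    model.backward(model_d.get_input_grads())
    model.update()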
Пример #18
0
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()

    if len(cvd) > 0:
        for i in xrange(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx), ctx, cvd)
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3

    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list) == 1
    data_dir = data_dir_list[0]
    path_imgrec = None
    path_imglist = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    # image_size = prop.image_size
    image_size = [int(x) for x in args.image_size.split(',')]
    assert len(image_size) == 2
    assert image_size[0] == image_size[1]
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    if args.loss_type == 1 and args.num_classes > 20000:
        args.beta_freeze = 5000
        args.gamma = 0.06

    print('Called with argument:', args)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    arg_params = None
    aux_params = None
    sym, arg_params, aux_params = get_symbol(args,
                                             arg_params,
                                             aux_params,
                                             layer_name='ms1m_fc7')
    fixed_args = [n for n in sym.list_arguments() if 'fc7' in n]

    # sym.get_internals()
    # sym.list_arguments()
    # sym.list_auxiliary_states()
    # sym.list_inputs()
    # sym.list_outputs()

    # label_name = 'softmax_label'
    # label_shape = (args.batch_size,)
    # arg_params['glint_fc7_weight'] = arg_params['fc7_weight'].copy()
    # arg_params['ms1m_fc7_weight'] = arg_params['glint_fc7_weight'].copy()
    assert 'ms1m_fc7_weight' in arg_params
    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
        fixed_param_names=fixed_args,
    )
    val_dataiter = None

    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrec=path_imgrec,
        shuffle=True,
        rand_mirror=args.rand_mirror,
        mean=mean,
        cutoff=args.cutoff,
        color_jittering=args.color,
        images_filter=args.images_filter,
    )

    metric1 = AccMetric()
    eval_metrics = [mx.metric.create(metric1)]
    if args.ce_loss:
        metric2 = LossValueMetric()
        eval_metrics.append(mx.metric.create(metric2))

    if args.network[0] == 'r' or args.network[0] == 'y':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  # resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  # inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    # initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    _rescale = 1.0 / args.ctx_num
    # opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
    logging.info(f'base lr {base_lr}')
    opt = optimizer.Adam(
        learning_rate=base_lr,
        wd=base_wd,
        rescale_grad=_rescale,
    )
    som = 20
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in xrange(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model, args.batch_size, 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            # print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    # ver_test( 0 )
    highest_acc = [0.0, 0.0]  # lfw and target
    # for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]

    if len(args.lr_steps) == 0:
        lr_steps = [40000, 60000, 80000]
        if args.loss_type >= 1 and args.loss_type <= 7:
            lr_steps = [100000, 140000, 160000]
        p = 512.0 / args.batch_size
        for l in xrange(len(lr_steps)):
            lr_steps[l] = int(lr_steps[l] * p)
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        # global global_step
        global_step[0] += 1
        mbatch = global_step[0]
        for _lr in lr_steps:
            if mbatch == args.beta_freeze + _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch: lr ', opt.lr, 'nbatch ', param.nbatch,
                  'epoch ', param.epoch, 'mbatch ', mbatch, 'lr_step',
                  lr_steps)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            is_highest = False
            if len(acc_list) > 0:
                # lfw_score = acc_list[0]
                # if lfw_score>highest_acc[0]:
                #  highest_acc[0] = lfw_score
                #  if lfw_score>=0.998:
                #    do_save = True
                score = sum(acc_list)
                if acc_list[-1] >= highest_acc[-1]:
                    if acc_list[-1] > highest_acc[-1]:
                        is_highest = True
                    else:
                        if score >= highest_acc[0]:
                            is_highest = True
                            highest_acc[0] = score
                    highest_acc[-1] = acc_list[-1]
                    # if lfw_score>=0.99:
                    #  do_save = True
            if is_highest:
                do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt == 2:
                do_save = True
            elif args.ckpt == 3:
                msave = 1

            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)

            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = max(0, mbatch - args.beta_freeze)
            _beta = max(
                args.beta_min,
                args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
        # print('beta', _beta)
        os.environ['BETA'] = str(_beta)
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)
    # model.set_params(arg_params, aux_params)
    model.fit(
        train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        eval_data=val_dataiter,
        eval_metric=eval_metrics,
        kvstore='device',
        optimizer=opt,
        # optimizer_params   = optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=epoch_cb)
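
# Note: this example fine-tunes a pretrained backbone while keeping the renamed
# ms1m_fc7 classifier frozen, by listing every fc7 argument in
# fixed_param_names. The same pattern in isolation (symbol and contexts are
# assumed to come from get_symbol and the GPU setup above):
import mxnet as mx

def build_finetune_module(sym, ctx):
    # Collect every argument belonging to the fc7 classification layer and
    # mark it fixed, so only the backbone receives gradient updates.
    fixed_args = [name for name in sym.list_arguments() if 'fc7' in name]
    return mx.mod.Module(context=ctx, symbol=sym,
                         fixed_param_names=fixed_args)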
Пример #19
0
def train_net(args):
    ctx = []
    for ctx_id in [int(x) for x in args.gpus.split(',')]:
        ctx.append(mx.gpu(ctx_id))
    print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    assert args.batch_size % args.ctx_num == 0
    args.per_batch_size = args.batch_size // args.ctx_num
    args.image_channel = 3

    data_dir = args.data_dir
    print('data dir', data_dir)
    path_imgrec = None
    path_imglist = None
    for line in open(os.path.join(data_dir, 'property')):
        vec = line.strip().split(',')
        assert len(vec) == 3
        args.num_classes = int(vec[0])
        image_size = [int(vec[1]), int(vec[2])]
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)
    global_num_ctx = args.ctx_num
    if args.num_classes % global_num_ctx == 0:
        args.ctx_num_classes = args.num_classes // global_num_ctx
    else:
        args.ctx_num_classes = args.num_classes // global_num_ctx + 1
    args.local_num_classes = args.ctx_num_classes * args.ctx_num
    args.local_class_start = 0
    args.ctx_class_start = []
    for i in range(args.ctx_num):

        _c = args.local_class_start + i * args.ctx_num_classes
        args.ctx_class_start.append(_c)
    path_imgrec = os.path.join(data_dir, "train.rec")

    print('Called with argument:', args)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None
    begin_epoch = 0

    #feat_net = fresnet.get(100, 256)
    #net = TrainBlock(args)
    args.use_dropout = True
    if args.num_classes >= 20000:
        args.use_dropout = False
    feat_net = FeatBlock(args, is_train=True)
    feat_net.collect_params().reset_ctx(ctx)
    if args.hybrid:
        feat_net.hybridize()
    cls_nets = []
    for i in range(args.ctx_num):
        cls_net = ArcMarginBlock(args)
        #cls_net.initialize(init=mx.init.Normal(0.01))
        cls_net.collect_params().reset_ctx(mx.gpu(i))
        if args.hybrid:
            cls_net.hybridize()
        cls_nets.append(cls_net)

    #initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    #feat_net.initialize(ctx=ctx, init=initializer)
    #feat_net.hybridize()
    #margin_block.initialize(ctx=ctx, init=mx.init.Normal(0.01))
    #margin_block.hybridize()

    ds = FaceDataset(data_shape=(3, 112, 112), path_imgrec=path_imgrec)
    #print(len(ds))
    #img, label = ds[0]
    #print(img.__class__, label.__class__)
    #print(img.shape, label)
    loader = gluon.data.DataLoader(ds,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=8,
                                   last_batch='discard')

    metric = CompositeEvalMetric([AccMetric()])

    ver_list = []
    ver_name_list = []
    for name in args.eval.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            print('loading ver-set:', name)
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)

    def ver_test(nbatch, tnet):
        results = []
        for i in range(len(ver_list)):
            xnorm, acc, thresh = verification.easytest(
                ver_list[i], tnet, ctx, batch_size=args.batch_size)
            print('[%s][%d]Accuracy-Thresh-XNorm: %.5f - %.5f - %.5f' %
                  (ver_name_list[i], nbatch, acc, thresh, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            #print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc)
        return results

    total_time = 0
    num_epochs = 0
    best_acc = [0]
    highest_acc = [0.0, 0.0]  #lfw and target
    global_step = [0]
    save_step = [0]
    lr_steps = [20000, 28000, 32000]
    if args.num_classes >= 20000:
        lr_steps = [100000, 160000, 220000]
    print('lr_steps', lr_steps)

    kv = mx.kv.create('device')
    trainer = gluon.Trainer(
        feat_net.collect_params(),
        'sgd',
        {
            'learning_rate': args.lr,
            'wd': args.wd,
            'momentum': args.mom,
            'multi_precision': True
        },
    )

    cls_trainers = []
    for i in range(args.ctx_num):
        _trainer = gluon.Trainer(
            cls_nets[i].collect_params(),
            'sgd',
            {
                'learning_rate': args.lr,
                'wd': args.wd,
                'momentum': args.mom,
                'multi_precision': True
            },
        )
        cls_trainers.append(_trainer)

    def _batch_callback():
        mbatch = global_step[0]
        global_step[0] += 1
        for _lr in lr_steps:
            if mbatch == _lr:
                args.lr *= 0.1
                trainer.set_learning_rate(args.lr)
                for _trainer in cls_trainers:
                    _trainer.set_learning_rate(args.lr)
                print('lr change to', args.lr)
                break

        #_cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', args.lr, mbatch)

        if mbatch > 0 and mbatch % args.verbose == 0:
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            is_highest = False
            tnet = FeatBlock(args,
                             is_train=False,
                             params=feat_net.collect_params())
            if args.hybrid:
                tnet.hybridize()
            acc_list = ver_test(mbatch, tnet)
            if len(acc_list) > 0:
                lfw_score = acc_list[0]
                if lfw_score > highest_acc[0]:
                    highest_acc[0] = lfw_score
                if acc_list[-1] >= highest_acc[-1]:
                    highest_acc[-1] = acc_list[-1]
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt == 1:
                do_save = True
                msave = 1
            elif args.ckpt > 1:
                do_save = True
            if do_save:
                print('saving', msave)
                #print('saving gluon params')
                fname = args.prefix + "-gluon.params"
                tnet.save_parameters(fname)
                if args.hybrid:
                    tnet.export(args.prefix, msave)
                #arg, aux = model.get_params()
                #mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    #loss_weight = 1.0
    #loss = gluon.loss.SoftmaxCrossEntropyLoss(weight = loss_weight)
    #loss = nd.SoftmaxOutput
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    tmp_ctx = mx.gpu(0)
    cpu_ctx = mx.cpu()
    cache = NPCache()
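    # NPCache is defined elsewhere in this source; from its use below it
    # appears to hand out preallocated ndarrays keyed by context and name so
    # the training loop can reuse buffers instead of allocating per batch,
    # with get2 additionally copying the given array into the cached buffer.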
    #ctx_fc7_max = mx.np.zeros( (args.batch_size, args.ctx_num), dtype=np.float32, ctx=cpu_ctx)
    #global_fc7_max = mx.np.zeros( (args.batch_size, 1), dtype=np.float32, ctx=cpu_ctx)
    #local_fc7_sum = mx.np.zeros((args.batch_size,1), ctx=cpu_ctx)
    #local_fc1_grad = mx.np.zeros( (args.batch_size,args.emb_size), ctx=cpu_ctx)
    while True:
        #trainer = update_learning_rate(opt.lr, trainer, epoch, opt.lr_factor, lr_steps)
        tic = time.time()
        #train_iter.reset()
        metric.reset()
        btic = time.time()
        #for i, batch in enumerate(train_iter):
        for batch_idx, (x, y) in enumerate(loader):
            y = y.astype(np.float32)
            #print(x.shape, y.shape)
            #print(x.dtype, y.dtype)
            _batch_callback()
            #data = gluon.utils.split_and_load(batch.data[0].astype(opt.dtype), ctx_list=ctx, batch_axis=0)
            #label = gluon.utils.split_and_load(batch.label[0].astype(opt.dtype), ctx_list=ctx, batch_axis=0)
            data = gluon.utils.split_and_load(x, ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(y, ctx_list=ctx, batch_axis=0)
            #outputs = []
            #losses = []
            fc1_list = []
            fc1_out_list = []
            fc1_list_cpu = []
            fc7_list = []
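            # Model-parallel softmax: the classifier is sharded over GPUs
            # (cls_nets[i] holds one slice of the classes); the global
            # softmax statistics are assembled on CPU further below.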
            with ag.record():
                for _data, _label in zip(data, label):
                    #print(y.asnumpy())
                    fc1 = feat_net(_data)
                    fc1_out_list.append(fc1)
                    #fc1_list.append(fc1)
            for _fc1 in fc1_out_list:
                #fc1_cpu = cache.get2(cpu_ctx, 'fc1_cpu', _fc1)
                fc1_cpu = _fc1.as_in_ctx(cpu_ctx)
                fc1_list_cpu.append(fc1_cpu)
            global_fc1 = cache.get(cpu_ctx, 'global_fc1_cpu',
                                   (args.batch_size, args.emb_size))
            mx.np.concatenate(fc1_list_cpu, axis=0, out=global_fc1)
            #mean = mx.np.mean(global_fc1, axis=[0])
            #var = mx.np.var(global_fc1, axis=[0])
            #var = mx.np.sqrt(var + 2e-5)
            #global_fc1 = (global_fc1 - mean) / var
            _xlist = []
            _ylist = []
            for i, cls_net in enumerate(cls_nets):
                _ctx = mx.gpu(i)
                _y = cache.get2(_ctx, 'ctxy', y)
                _y -= args.ctx_class_start[i]
                _x = cache.get2(_ctx, 'ctxfc1', global_fc1)
                _xlist.append(_x)
                _ylist.append(_y)
            with ag.record():
                for i, cls_net in enumerate(cls_nets):
                    _ctx = mx.gpu(i)
                    _x = _xlist[i]
                    _y = _ylist[i]
                    #_y = cache.get2(_ctx, 'ctxy', y)
                    #_y -= args.ctx_class_start[i]
                    #_x = cache.get2(_ctx, 'ctxfc1', global_fc1)
                    #_x = global_fc1.as_in_ctx(_ctx)
                    _x.attach_grad()
                    _fc7 = cls_net(_x, _y)
                    fc7_list.append(_fc7)
                    fc1_list.append(_x)
            #print('log A')
            fc7_grads = [None] * args.ctx_num
            ctx_fc7_max = cache.get(cpu_ctx, 'gctxfc7max',
                                    (args.batch_size, args.ctx_num))
            ctx_fc7_max[:, :] = 0.0
            for i, cls_net in enumerate(cls_nets):
                _fc7 = fc7_list[i]
                _max = cache.get(_fc7.context, 'ctxfc7max',
                                 (args.batch_size, ))
                #_max = cache.get(cpu_ctx, 'ctxfc7max', (args.batch_size, ))
                mx.np.max(_fc7, axis=1, out=_max)
                #_cpumax = cache.get2(cpu_ctx, 'ctxfc7maxcpu', _max)
                _cpumax = _max.as_in_ctx(cpu_ctx)
                ctx_fc7_max[:, i] = _cpumax
                fc7_grads[i] = cache.get2(_fc7.context, 'fc7grad', _fc7)
            #nd.max(ctx_fc7_max, axis=1, keepdims=True, out=local_fc7_max)
            global_fc7_max = cache.get(cpu_ctx, 'globalfc7max',
                                       (args.batch_size, 1))
            mx.np.max(ctx_fc7_max, axis=1, keepdims=True, out=global_fc7_max)
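            # global_fc7_max is the per-row max over all shards; subtracting
            # it before exp() keeps the softmax numerically stable, and the
            # per-shard exp-sums reduced below form the partition function.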
            local_fc7_sum = cache.get(cpu_ctx, 'local_fc7_sum',
                                      (args.batch_size, 1))
            local_fc7_sum[:, :] = 0.0

            for i, cls_net in enumerate(cls_nets):
                _ctx = mx.gpu(i)
                #_max = global_fc7_max.as_in_ctx(mx.gpu(i))
                _max = cache.get2(_ctx, 'fc7maxgpu', global_fc7_max)
                fc7_grads[i] -= _max
                #mx.np.exp(fc7_grads[i], out=fc7_grads[i])
                fc7_grads[i] = mx.np.exp(fc7_grads[i])
                #_sum = cache.get(cpu_ctx, 'ctxfc7sum', (args.batch_size, 1))
                _sum = cache.get(_ctx, 'ctxfc7sum', (args.batch_size, 1))
                mx.np.sum(fc7_grads[i], axis=1, keepdims=True, out=_sum)
                #_cpusum = cache.get2(cpu_ctx, 'ctxfc7maxcpu', _max)
                _cpusum = _sum.as_in_ctx(cpu_ctx)
                local_fc7_sum += _cpusum
            global_fc7_sum = local_fc7_sum

            #print('log B')
            local_fc1_grad = cache.get(cpu_ctx, 'localfc1grad',
                                       (args.batch_size, args.emb_size))
            local_fc1_grad[:, :] = 0.0

            for i, cls_net in enumerate(cls_nets):
                #_sum = global_fc7_sum.as_in_ctx(mx.gpu(i))
                _ctx = mx.gpu(i)
                _sum = cache.get2(_ctx, 'globalfc7sumgpu', global_fc7_sum)
                fc7_grads[i] /= _sum
                a = i * args.ctx_num_classes
                b = (i + 1) * args.ctx_num_classes
                _y = cache.get2(_ctx, 'ctxy2', y)
                _y -= args.ctx_class_start[i]
                _yonehot = cache.get(_ctx, 'yonehot',
                                     (args.batch_size, args.ctx_num_classes))
                mx.npx.one_hot(_y,
                               depth=args.ctx_num_classes,
                               on_value=1.0,
                               off_value=0.0,
                               out=_yonehot)
                #_label = (y - args.ctx_class_start[i]).as_in_ctx(mx.gpu(i))
                #_label = mx.npx.one_hot(_label, depth=args.ctx_num_classes, on_value=1.0, off_value=0.0)
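                # softmax(z) - one_hot(y) is the gradient of cross-entropy
                # w.r.t. the logits, so the backward pass starts from it
                # directly, with no explicit loss node (labels outside this
                # shard one-hot to an all-zero row).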
                fc7_grads[i] -= _yonehot
                fc7_list[i].backward(fc7_grads[i])
                fc1 = fc1_list[i]
                #fc1_grad = cache.get2(cpu_ctx, 'fc1gradcpu', fc1.grad)
                fc1_grad = fc1.grad.as_in_ctx(cpu_ctx)
                #print(fc1.grad.dtype, fc1.grad.shape)
                #print(fc1.grad[0:5,0:5])
                local_fc1_grad += fc1_grad
                cls_trainers[i].step(args.batch_size)
            #print('log C')
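            # local_fc1_grad now holds each sample's embedding gradient
            # summed over all class shards; slice it into per-GPU blocks and
            # backprop through the feature network.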
            for i in range(args.ctx_num):
                p = args.batch_size // args.ctx_num
                a = p * i
                b = p * (i + 1)
                _fc1_grad = local_fc1_grad[a:b, :]
                _grad = cache.get2(mx.gpu(i), 'fc1gradgpu', _fc1_grad)
                #_grad = local_fc1_grad[a:b,:].as_in_ctx(mx.gpu(i))
                #print(i, fc1_out_list[i].shape, _grad.shape)
                fc1_out_list[i].backward(_grad)
            #print('log D')
            trainer.step(args.batch_size)
            #print('after step')
            mx.npx.waitall()
            i = batch_idx
            if i > 0 and i % 20 == 0:
                logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec' %
                            (num_epochs, i, args.batch_size /
                             (time.time() - btic)))
            #metric.update(label, outputs)
            #if i>0 and i%20==0:
            #    name, acc = metric.get()
            #    if len(name)==2:
            #      logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f, %s=%f'%(
            #                     num_epochs, i, args.batch_size/(time.time()-btic), name[0], acc[0], name[1], acc[1]))
            #    else:
            #      logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f'%(
            #                     num_epochs, i, args.batch_size/(time.time()-btic), name[0], acc[0]))
            #    #metric.reset()
            btic = time.time()

        epoch_time = time.time() - tic

        # The first epoch is usually much slower than the subsequent epochs,
        # so don't factor it into the average.
        if num_epochs > 0:
            total_time = total_time + epoch_time

        #name, acc = metric.get()
        #logger.info('[Epoch %d] training: %s=%f, %s=%f'%(num_epochs, name[0], acc[0], name[1], acc[1]))
        logger.info('[Epoch %d] time cost: %f' % (num_epochs, epoch_time))
        num_epochs = num_epochs + 1
        #name, val_acc = test(ctx, val_data)
        #logger.info('[Epoch %d] validation: %s=%f, %s=%f'%(epoch, name[0], val_acc[0], name[1], val_acc[1]))

        # save model if meet requirements
        #save_checkpoint(epoch, val_acc[0], best_acc)
    if num_epochs > 1:
        print('Average epoch time: {}'.format(
            float(total_time) / (num_epochs - 1)))
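
A note on the manual gradient above: the example backpropagates softmax(z) - one_hot(y) into each classifier shard instead of calling a loss function. The following minimal NumPy sketch (not part of the original example; all names are illustrative) checks that identity against a numerical gradient:

import numpy as np

def softmax_ce_grad(logits, labels):
    # gradient of softmax cross-entropy w.r.t. logits: softmax(z) - onehot(y)
    z = logits - logits.max(axis=1, keepdims=True)  # stable, as in the example
    e = np.exp(z)
    p = e / e.sum(axis=1, keepdims=True)
    p[np.arange(len(labels)), labels] -= 1.0
    return p

def ce(z, y):
    s = z - z.max()
    return -(s[0, y] - np.log(np.exp(s).sum()))

z = np.random.randn(1, 5)
y = 2
g = softmax_ce_grad(z, np.array([y]))
eps = 1e-6
for j in range(5):
    zp, zm = z.copy(), z.copy()
    zp[0, j] += eps
    zm[0, j] -= eps
    num = (ce(zp, y) - ce(zm, y)) / (2 * eps)
    assert abs(num - g[0, j]) < 1e-4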
Пример #20
0
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in xrange(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.image_channel = 3

    data_dir = args.data_dir
    if args.task == 'gender':
        data_dir = args.gender_data_dir
    elif args.task == 'age':
        data_dir = args.age_data_dir
    print('data dir', data_dir)
    path_imgrec = None
    path_imglist = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    print('Called with argument:', args)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    net = get_model()
    #if args.task=='':
    #  test_net = get_model_test(net)
    #print(net.__class__)
    #net = net0[0]
    if args.network[0] == 'r' or args.network[0] == 'y':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  #inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    net.hybridize()
    if args.mode == 'gluon':
        if len(args.pretrained) == 0:
            pass
        else:
            net.load_params(args.pretrained,
                            allow_missing=True,
                            ignore_extra=True)
        net.initialize(initializer)
        net.collect_params().reset_ctx(ctx)

    val_iter = None
    if args.task == '':
        train_iter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=args.rand_mirror,
            mean=mean,
            cutoff=args.cutoff,
        )
    else:
        train_iter = FaceImageIterAge(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            task=args.task,
            shuffle=True,
            rand_mirror=args.rand_mirror,
            mean=mean,
            cutoff=args.cutoff,
        )

    if args.task == 'age':
        metric = CompositeEvalMetric([MAEMetric(), CUMMetric()])
    elif args.task == 'gender':
        metric = CompositeEvalMetric([AccMetric()])
    else:
        metric = CompositeEvalMetric([AccMetric()])

    ver_list = []
    ver_name_list = []
    if args.task == '':
        for name in args.eval.split(','):
            path = os.path.join(data_dir, name + ".bin")
            if os.path.exists(path):
                data_set = verification.load_bin(path, image_size)
                ver_list.append(data_set)
                ver_name_list.append(name)
                print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in xrange(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], net, ctx, batch_size=args.batch_size)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    def val_test(nbatch=0):
        acc = 0.0
        #if args.task=='age':
        if len(args.age_data_dir) > 0:
            val_iter = FaceImageIterAge(
                batch_size=args.batch_size,
                data_shape=data_shape,
                path_imgrec=os.path.join(args.age_data_dir, 'val.rec'),
                task=args.task,
                shuffle=False,
                rand_mirror=False,
                mean=mean,
            )
            _metric = MAEMetric()
            val_metric = mx.metric.create(_metric)
            val_metric.reset()
            _metric2 = CUMMetric()
            val_metric2 = mx.metric.create(_metric2)
            val_metric2.reset()
            val_iter.reset()
            for batch in val_iter:
                data = gluon.utils.split_and_load(batch.data[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch.label[0],
                                                   ctx_list=ctx,
                                                   batch_axis=0)
                outputs = []
                for x in data:
                    outputs.append(net(x)[2])
                val_metric.update(label, outputs)
                val_metric2.update(label, outputs)
            _value = val_metric.get_name_value()[0][1]
            print('[%d][VMAE]: %f' % (nbatch, _value))
            _value = val_metric2.get_name_value()[0][1]
            if args.task == 'age':
                acc = _value
            print('[%d][VCUM]: %f' % (nbatch, _value))
        if len(args.gender_data_dir) > 0:
            val_iter = FaceImageIterAge(
                batch_size=args.batch_size,
                data_shape=data_shape,
                path_imgrec=os.path.join(args.gender_data_dir, 'val.rec'),
                task=args.task,
                shuffle=False,
                rand_mirror=False,
                mean=mean,
            )
            _metric = AccMetric()
            val_metric = mx.metric.create(_metric)
            val_metric.reset()
            val_iter.reset()
            for batch in val_iter:
                data = gluon.utils.split_and_load(batch.data[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch.label[0],
                                                   ctx_list=ctx,
                                                   batch_axis=0)
                outputs = []
                for x in data:
                    outputs.append(net(x)[1])
                val_metric.update(label, outputs)
            _value = val_metric.get_name_value()[0][1]
            if args.task == 'gender':
                acc = _value
            print('[%d][VACC]: %f' % (nbatch, _value))
        return acc

    total_time = 0
    num_epochs = 0
    best_acc = [0]
    highest_acc = [0.0, 0.0]  #lfw and target
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [100000, 140000, 160000]
        p = 512.0 / args.batch_size
        for l in xrange(len(lr_steps)):
            lr_steps[l] = int(lr_steps[l] * p)
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    kv = mx.kv.create('device')
    #kv = mx.kv.create('local')
    #_rescale = 1.0/args.ctx_num
    #opt = optimizer.SGD(learning_rate=args.lr, momentum=args.mom, wd=args.wd, rescale_grad=_rescale)
    #opt = optimizer.SGD(learning_rate=args.lr, momentum=args.mom, wd=args.wd)
    if args.mode == 'gluon':
        trainer = gluon.Trainer(net.collect_params(),
                                'sgd', {
                                    'learning_rate': args.lr,
                                    'wd': args.wd,
                                    'momentum': args.mom,
                                    'multi_precision': True
                                },
                                kvstore=kv)
    else:
        _rescale = 1.0 / args.ctx_num
        opt = optimizer.SGD(learning_rate=args.lr,
                            momentum=args.mom,
                            wd=args.wd,
                            rescale_grad=_rescale)
        _cb = mx.callback.Speedometer(args.batch_size, 20)
        arg_params = None
        aux_params = None
        data = mx.sym.var('data')
        label = mx.sym.var('softmax_label')
        if args.margin_a > 0.0:
            fc7 = net(data, label)
        else:
            fc7 = net(data)
        #sym = mx.symbol.SoftmaxOutput(data=fc7, label = label, name='softmax', normalization='valid')
        ceop = gluon.loss.SoftmaxCrossEntropyLoss()
        loss = ceop(fc7, label)
        #loss = loss/args.per_batch_size
        loss = mx.sym.mean(loss)
        sym = mx.sym.Group([
            mx.symbol.BlockGrad(fc7),
            mx.symbol.MakeLoss(loss, name='softmax')
        ])
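        # BlockGrad exposes fc7 as an output without letting gradients flow
        # back through that head; MakeLoss marks the mean CE as the objective
        # the module will backpropagate.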

    def _batch_callback():
        mbatch = global_step[0]
        global_step[0] += 1
        for _lr in lr_steps:
            if mbatch == _lr:
                args.lr *= 0.1
                if args.mode == 'gluon':
                    trainer.set_learning_rate(args.lr)
                else:
                    opt.lr = args.lr
                print('lr change to', args.lr)
                break

        #_cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', args.lr, mbatch)

        if mbatch > 0 and mbatch % args.verbose == 0:
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            is_highest = False
            if args.task == 'age' or args.task == 'gender':
                acc = val_test(mbatch)
                if acc >= highest_acc[-1]:
                    highest_acc[-1] = acc
                    is_highest = True
                    do_save = True
            else:
                acc_list = ver_test(mbatch)
                if len(acc_list) > 0:
                    lfw_score = acc_list[0]
                    if lfw_score > highest_acc[0]:
                        highest_acc[0] = lfw_score
                        if lfw_score >= 0.998:
                            do_save = True
                    if acc_list[-1] >= highest_acc[-1]:
                        highest_acc[-1] = acc_list[-1]
                        if lfw_score >= 0.99:
                            do_save = True
                            is_highest = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt > 1:
                do_save = True
            if do_save:
                print('saving', msave)
                #print('saving gluon params')
                fname = os.path.join(args.prefix, 'model-gluon.params')
                net.save_params(fname)
                fname = os.path.join(args.prefix, 'model')
                net.export(fname, msave)
                #arg, aux = model.get_params()
                #mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    def _batch_callback_sym(param):
        _cb(param)
        _batch_callback()

    if args.mode != 'gluon':
        model = mx.mod.Module(
            context=ctx,
            symbol=sym,
        )
        model.fit(train_iter,
                  begin_epoch=0,
                  num_epoch=args.end_epoch,
                  eval_data=None,
                  eval_metric=metric,
                  kvstore='device',
                  optimizer=opt,
                  initializer=initializer,
                  arg_params=arg_params,
                  aux_params=aux_params,
                  allow_missing=True,
                  batch_end_callback=_batch_callback_sym,
                  epoch_end_callback=None)
    else:
        loss_weight = 1.0
        if args.task == 'age':
            loss_weight = 1.0 / AGE
        #loss = gluon.loss.SoftmaxCrossEntropyLoss(weight = loss_weight)
        loss = nd.SoftmaxOutput
        #loss = gluon.loss.SoftmaxCrossEntropyLoss()
        while True:
            #trainer = update_learning_rate(opt.lr, trainer, epoch, opt.lr_factor, lr_steps)
            tic = time.time()
            train_iter.reset()
            metric.reset()
            btic = time.time()
            for i, batch in enumerate(train_iter):
                _batch_callback()
                #data = gluon.utils.split_and_load(batch.data[0].astype(opt.dtype), ctx_list=ctx, batch_axis=0)
                #label = gluon.utils.split_and_load(batch.label[0].astype(opt.dtype), ctx_list=ctx, batch_axis=0)
                data = gluon.utils.split_and_load(batch.data[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch.label[0],
                                                   ctx_list=ctx,
                                                   batch_axis=0)
                outputs = []
                Ls = []
                with ag.record():
                    for x, y in zip(data, label):
                        #print(y.asnumpy())
                        if args.task == '':
                            if args.margin_a > 0.0:
                                z = net(x, y)
                            else:
                                z = net(x)
                            #print(z[0].shape, z[1].shape)
                        else:
                            z = net(x)
                        if args.task == 'gender':
                            L = loss(z[1], y)
                            #L = L/args.per_batch_size
                            Ls.append(L)
                            outputs.append(z[1])
                        elif args.task == 'age':
                            for k in xrange(AGE):
                                _z = nd.slice_axis(z[2],
                                                   axis=1,
                                                   begin=k * 2,
                                                   end=k * 2 + 2)
                                _y = nd.slice_axis(y,
                                                   axis=1,
                                                   begin=k,
                                                   end=k + 1)
                                _y = nd.flatten(_y)
                                L = loss(_z, _y)
                                #L = L/args.per_batch_size
                                #L /= AGE
                                Ls.append(L)
                            outputs.append(z[2])
                        else:
                            L = loss(z, y)
                            #L = L/args.per_batch_size
                            Ls.append(L)
                            outputs.append(z)
                        # store the loss and do backward after we have done forward
                        # on all GPUs for better speed on multiple GPUs.
                    ag.backward(Ls)
                #trainer.step(batch.data[0].shape[0], ignore_stale_grad=True)
                #trainer.step(args.ctx_num)
                n = batch.data[0].shape[0]
                #print(n,n)
                trainer.step(n)
                metric.update(label, outputs)
                if i > 0 and i % 20 == 0:
                    name, acc = metric.get()
                    if len(name) == 2:
                        logger.info(
                            'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f, %s=%f'
                            % (num_epochs, i, args.batch_size /
                               (time.time() - btic), name[0], acc[0], name[1],
                               acc[1]))
                    else:
                        logger.info(
                            'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f'
                            % (num_epochs, i, args.batch_size /
                               (time.time() - btic), name[0], acc[0]))
                    #metric.reset()
                btic = time.time()

            epoch_time = time.time() - tic

            # The first epoch is usually much slower than the subsequent epochs,
            # so don't factor it into the average.
            if num_epochs > 0:
                total_time = total_time + epoch_time

            #name, acc = metric.get()
            #logger.info('[Epoch %d] training: %s=%f, %s=%f'%(num_epochs, name[0], acc[0], name[1], acc[1]))
            logger.info('[Epoch %d] time cost: %f' % (num_epochs, epoch_time))
            num_epochs = num_epochs + 1
            #name, val_acc = test(ctx, val_data)
            #logger.info('[Epoch %d] validation: %s=%f, %s=%f'%(epoch, name[0], val_acc[0], name[1], val_acc[1]))

            # save model if meet requirements
            #save_checkpoint(epoch, val_acc[0], best_acc)
        if num_epochs > 1:
            print('Average epoch time: {}'.format(
                float(total_time) / (num_epochs - 1)))
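
The age branch above trains AGE independent two-way softmax heads, one per age bin, against per-bin binary labels (the ordinal-regression encoding also suggested by CUMMetric). A minimal NumPy sketch of how such per-bin outputs might decode to an age (illustrative only; the AGE value and the decoding rule are assumptions, not code from the example):

import numpy as np

AGE = 100  # assumed bin count

def decode_age(logits):
    # logits: (batch, AGE*2); pair k acts as a 2-way "age exceeds bin k" head
    pairs = logits.reshape(-1, AGE, 2)
    # a bin fires when its second logit wins; predicted age = count of fired bins
    return (pairs[:, :, 1] > pairs[:, :, 0]).sum(axis=1)

print(decode_age(np.random.randn(4, AGE * 2)))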
Пример #21
0
def train_net(args):
    #_seed = 727
    #random.seed(_seed)
    #np.random.seed(_seed)
    #mx.random.seed(_seed)
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd)>0:
      for i in range(len(cvd.split(','))):
        ctx.append(mx.gpu(i))
    if len(ctx)==0:
      ctx = [mx.cpu()]
      print('use cpu')
    else:
      print('gpu num:', len(ctx))
    if len(args.extra_model_name)==0:
      prefix = os.path.join(args.models_root, '%s-%s-%s'%(args.network, args.loss, args.dataset), 'model')
    else:
      prefix = os.path.join(args.models_root, '%s-%s-%s-%s'%(args.network, args.loss, args.dataset, args.extra_model_name), 'model')
    prefix_dir = os.path.dirname(prefix)
    print('prefix', prefix)
    if not os.path.exists(prefix_dir):
      os.makedirs(prefix_dir)
    args.ctx_num = len(ctx)
    if args.per_batch_size==0:
      args.per_batch_size = 128
    args.batch_size = args.per_batch_size*args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = config.image_shape[2]
    config.batch_size = args.batch_size
    config.per_batch_size = args.per_batch_size
    data_dir = config.dataset_path
    path_imgrec = None
    path_imglist = None
    image_size = config.image_shape[0:2]
    assert len(image_size)==2
    assert image_size[0]==image_size[1]
    print('image_size', image_size)
    print('num_classes', config.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    data_shape = (args.image_channel,image_size[0],image_size[1])

    num_workers = config.num_workers
    global_num_ctx = num_workers * args.ctx_num
    if config.num_classes%global_num_ctx==0:
      args.ctx_num_classes = config.num_classes//global_num_ctx
    else:
      args.ctx_num_classes = config.num_classes//global_num_ctx+1
    args.local_num_classes = args.ctx_num_classes * args.ctx_num
    args.local_class_start = args.local_num_classes * args.worker_id

    #if len(args.partial)==0:
    #  local_classes_range = (0, args.num_classes)
    #else:
    #  _vec = args.partial.split(',')
    #  local_classes_range = (int(_vec[0]), int(_vec[1]))

    #args.partial_num_classes = local_classes_range[1] - local_classes_range[0]
    #args.partial_start = local_classes_range[0]

    print('Called with argument:', args, config)
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    arg_params = None
    aux_params = None
    if len(args.pretrained)==0:
      esym = get_symbol_embedding()
      asym = get_symbol_arcface
    else:
      assert False

    if config.count_flops:
      all_layers = esym.get_internals()
      _sym = all_layers['fc1_output']
      FLOPs = flops_counter.count_flops(_sym, data=(1,3,image_size[0],image_size[1]))
      _str = flops_counter.flops_str(FLOPs)
      print('Network FLOPs: %s'%_str)

    if config.num_workers==1:
      from parall_module_local_v1 import ParallModule
    else:
      from parall_module_dist import ParallModule

    model = ParallModule(
        context       = ctx,
        symbol        = esym,
        data_names    = ['data'],
        label_names    = ['softmax_label'],
        asymbol       = asym,
        args = args,
    )
    val_dataiter = None
    train_dataiter = FaceImageIter(
        batch_size           = args.batch_size,
        data_shape           = data_shape,
        path_imgrec          = path_imgrec,
        shuffle              = True,
        rand_mirror          = config.data_rand_mirror,
        mean                 = mean,
        cutoff               = config.data_cutoff,
        color_jittering      = config.data_color,
        images_filter        = config.data_images_filter,
    )


    
    if config.net_name=='fresnet' or config.net_name=='fmobilefacenet':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    else:
      initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)

    _rescale = 1.0/args.batch_size
    opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)


    ver_list = []
    ver_name_list = []
    for name in config.val_targets:
      path = os.path.join(data_dir,name+".bin")
      if os.path.exists(path):
        data_set = verification.load_bin(path, image_size)
        ver_list.append(data_set)
        ver_name_list.append(name)
        print('ver', name)


    def ver_test(nbatch):
      results = []
      for i in range(len(ver_list)):
        acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, 10, None, None)
        print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
        #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
        print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
        results.append(acc2)
      return results


    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in range(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)
    def _batch_callback(param):
      #global global_step
      global_step[0]+=1
      mbatch = global_step[0]
      for step in lr_steps:
        if mbatch==step:
          opt.lr *= 0.1
          print('lr change to', opt.lr)
          break

      _cb(param)
      if mbatch%1000==0:
        print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)

      if mbatch>=0 and mbatch%args.verbose==0:
        acc_list = ver_test(mbatch)
        save_step[0]+=1
        msave = save_step[0]
        do_save = False
        is_highest = False
        if len(acc_list)>0:
          #lfw_score = acc_list[0]
          #if lfw_score>highest_acc[0]:
          #  highest_acc[0] = lfw_score
          #  if lfw_score>=0.998:
          #    do_save = True
          score = sum(acc_list)
          if acc_list[-1]>=highest_acc[-1]:
            if acc_list[-1]>highest_acc[-1]:
              is_highest = True
            else:
              if score>=highest_acc[0]:
                is_highest = True
                highest_acc[0] = score
            highest_acc[-1] = acc_list[-1]
            #if lfw_score>=0.99:
            #  do_save = True
        if is_highest:
          do_save = True
        if args.ckpt==0:
          do_save = False
        elif args.ckpt==2:
          do_save = True
        elif args.ckpt==3:
          msave = 1

        if do_save:
          print('saving', msave)
          arg, aux = model.get_export_params()
          all_layers = model.symbol.get_internals()
          _sym = all_layers['fc1_output']
          mx.model.save_checkpoint(prefix, msave, _sym, arg, aux)
        print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
      if config.max_steps>0 and mbatch>config.max_steps:
        sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    model.fit(train_dataiter,
        begin_epoch        = begin_epoch,
        num_epoch          = 999999,
        eval_data          = val_dataiter,
        #eval_metric        = eval_metrics,
        kvstore            = args.kvstore,
        optimizer          = opt,
        #optimizer_params   = optimizer_params,
        initializer        = initializer,
        arg_params         = arg_params,
        aux_params         = aux_params,
        allow_missing      = True,
        batch_end_callback = _batch_callback,
        epoch_end_callback = epoch_cb )
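
The worker/GPU class partition in this example reduces to ceiling arithmetic. A standalone sketch (the num_classes / num_workers / ctx_num values are illustrative, not taken from the example):

num_classes, num_workers, ctx_num, worker_id = 85742, 2, 4, 1
global_num_ctx = num_workers * ctx_num               # total classifier shards
ctx_num_classes = -(-num_classes // global_num_ctx)  # ceil division
local_num_classes = ctx_num_classes * ctx_num        # classes on one worker
local_class_start = local_num_classes * worker_id
print(ctx_num_classes, local_num_classes, local_class_start)
# 10718 42872 42872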
Пример #22
0
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in xrange(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = "%s-%s-p%s" % (args.prefix, args.network, args.patch)
    end_epoch = args.end_epoch
    pretrained = args.pretrained
    load_epoch = args.load_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
        if args.network[0] == 'r':
            args.per_batch_size = 128
        else:
            if args.num_layers >= 64:
                args.per_batch_size = 120
        if args.ctx_num == 2:
            args.per_batch_size *= 2
        elif args.ctx_num == 3:
            args.per_batch_size = 170
        if args.network[0] == 'm':
            args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3
    ppatch = [int(x) for x in args.patch.split('_')]
    image_size = [int(x) for x in args.image_size.split(',')]
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    assert len(ppatch) == 5
    #if args.patch%2==1:
    #  args.image_channel = 1

    #os.environ['GLOBAL_STEP'] = "0"
    os.environ['BETA'] = str(args.beta)
    args.use_val = False
    path_imgrec = None
    path_imglist = None
    val_rec = None

    #path_imglist = "/raid5data/dplearn/faceinsight_align_webface.lst.new"
    #path_imglist = "/raid5data/dplearn/faceinsight_align_webface_clean.lst.new"
    for line in open(os.path.join(args.data_dir, 'property')):
        args.num_classes = int(line.strip())
    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)

    #path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2"
    path_imgrec = os.path.join(args.data_dir, "train.rec")
    val_rec = os.path.join(args.data_dir, "val.rec")
    if os.path.exists(val_rec):
        args.use_val = True
    else:
        val_rec = None
    #args.num_classes = 10572 #webface
    #args.num_classes = 81017
    #args.num_classes = 82395

    if args.loss_type == 1 and args.num_classes > 40000:
        args.beta_freeze = 5000
        args.gamma = 0.06

    print('Called with argument:', args)

    data_shape = (args.image_channel, image_size[0], image_size[1])
    #mean = [127.5,127.5,127.5]
    mean = None

    if args.use_val:
        val_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=val_rec,
            #path_imglist         = val_path,
            shuffle=False,
            rand_mirror=False,
            mean=mean,
        )
    else:
        val_dataiter = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = 0.9
    if not args.retrain:
        #load and initialize params
        #print(pretrained)
        #_, arg_params, aux_params = mx.model.load_checkpoint(pretrained, load_epoch)
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
        #arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        data_shape_dict = {
            'data': (args.batch_size, ) + data_shape,
            'softmax_label': (args.batch_size, )
        }
        if args.network[0] == 's':
            arg_params, aux_params = spherenet.init_weights(
                sym, data_shape_dict, args.num_layers)
        #elif args.network[0]=='m':
        #  arg_params, aux_params = marginalnet.init_weights(sym, data_shape_dict, args.num_layers)
        #resnet_dcn.init_weights(sym, data_shape_dict, arg_params, aux_params)
    else:
        #sym, arg_params, aux_params = mx.model.load_checkpoint(pretrained, load_epoch)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            pretrained, load_epoch)
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
        #begin_epoch = load_epoch
        #end_epoch = begin_epoch+10
        #base_wd = 0.00005

    if args.loss_type != 10:
        model = mx.mod.Module(
            context=ctx,
            symbol=sym,
        )
    else:
        data_names = ('data', 'extra')
        model = mx.mod.Module(
            context=ctx,
            symbol=sym,
            data_names=data_names,
        )

    if args.loss_type <= 9:
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=True,
            mean=mean,
        )
    elif args.loss_type == 10:
        train_dataiter = FaceImageIter4(
            batch_size=args.batch_size,
            ctx_num=args.ctx_num,
            images_per_identity=args.images_per_identity,
            data_shape=data_shape,
            path_imglist=path_imglist,
            shuffle=True,
            rand_mirror=True,
            mean=mean,
            patch=ppatch,
            use_extra=True,
            model=model,
        )
    elif args.loss_type == 11:
        train_dataiter = FaceImageIter5(
            batch_size=args.batch_size,
            ctx_num=args.ctx_num,
            images_per_identity=args.images_per_identity,
            data_shape=data_shape,
            path_imglist=path_imglist,
            shuffle=True,
            rand_mirror=True,
            mean=mean,
            patch=ppatch,
        )
    #args.epoch_size = int(math.ceil(train_dataiter.num_samples()/args.batch_size))

    #_dice = DiceMetric()
    _acc = AccMetric()
    eval_metrics = [mx.metric.create(_acc)]

    # rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric
    #for child_metric in [fcn_loss_metric]:
    #    eval_metrics.add(child_metric)

    # callback
    #batch_end_callback = callback.Speedometer(input_batch_size, frequent=args.frequent)
    #epoch_end_callback = mx.callback.module_checkpoint(mod, prefix, period=1, save_optimizer_states=True)

    # decide learning rate
    #lr_step = '10,20,30'
    #train_size = 4848
    #nrof_batch_in_epoch = int(train_size/input_batch_size)
    #print('nrof_batch_in_epoch:', nrof_batch_in_epoch)
    #lr_factor = 0.1
    #lr_epoch = [float(epoch) for epoch in lr_step.split(',')]
    #lr_epoch_diff = [epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch]
    #lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
    #lr_iters = [int(epoch * train_size / batch_size) for epoch in lr_epoch_diff]
    #print 'lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters

    #lr_scheduler = MultiFactorScheduler(lr_iters, lr_factor)

    # optimizer
    #optimizer_params = {'momentum': 0.9,
    #                    'wd': 0.0005,
    #                    'learning_rate': base_lr,
    #                    'rescale_grad': 1.0,
    #                    'clip_gradient': None}
    if args.network[0] == 'r':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  #inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    _rescale = 1.0 / args.ctx_num
    #_rescale = 1.0
    opt = optimizer.SGD(learning_rate=base_lr,
                        momentum=base_mom,
                        wd=base_wd,
                        rescale_grad=_rescale)
    #opt = optimizer.RMSProp(learning_rate=base_lr, wd=base_wd, rescale_grad=_rescale)
    #opt = optimizer.AdaGrad(learning_rate=base_lr, wd=base_wd, rescale_grad=_rescale)
    #opt = optimizer.AdaGrad(learning_rate=base_lr, wd=base_wd, rescale_grad=1.0)
    _cb = mx.callback.Speedometer(args.batch_size, 10)

    ver_list = []
    ver_name_list = []
    for name in ['lfw', 'cfp_ff', 'cfp_fp']:
        path = os.path.join(args.data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in xrange(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model, args.batch_size)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            print('[%s][%d]Accuracy: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    def val_test():
        acc = AccMetric()
        val_metric = mx.metric.create(acc)
        val_metric.reset()
        val_dataiter.reset()
        for i, eval_batch in enumerate(val_dataiter):
            model.forward(eval_batch, is_train=False)
            model.update_metric(val_metric, eval_batch.label)
        acc_value = val_metric.get_name_value()[0][1]
        print('VACC: %f' % (acc_value))

    highest_acc = [0.0]
    last_save_acc = [0.0]
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [40000, 60000, 80000]
        if args.loss_type == 1:
            lr_steps = [100000, 140000, 160000]
        p = 512.0 / args.batch_size
        for l in xrange(len(lr_steps)):
            lr_steps[l] = int(lr_steps[l] * p)
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        #global global_step
        global_step[0] += 1
        mbatch = global_step[0]
        for _lr in lr_steps:
            if mbatch == args.beta_freeze + _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            acc = acc_list[0]
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            if acc >= highest_acc[0]:
                highest_acc[0] = acc
                if acc >= 0.99:
                    do_save = True
            if mbatch > lr_steps[-1] and mbatch % 10000 == 0:
                do_save = True
            if do_save:
                print('saving', msave, acc)
                if val_dataiter is not None:
                    val_test()
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
                #if acc>=highest_acc[0]:
                #  lfw_npy = "%s-lfw-%04d" % (prefix, msave)
                #  X = np.concatenate(embeddings_list, axis=0)
                #  print('saving lfw npy', X.shape)
                #  np.save(lfw_npy, X)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[0]))
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = max(0, mbatch - args.beta_freeze)
            _beta = max(
                args.beta_min,
                args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
            #_beta = max(args.beta_min, args.beta*math.pow(0.7, move//500))
        #print('beta', _beta)
        os.environ['BETA'] = str(_beta)

    #epoch_cb = mx.callback.do_checkpoint(prefix, 1)
    epoch_cb = None

    #def _epoch_callback(epoch, sym, arg_params, aux_params):
    #  print('epoch-end', epoch)

    model.fit(
        train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        eval_data=val_dataiter,
        eval_metric=eval_metrics,
        kvstore='device',
        optimizer=opt,
        #optimizer_params   = optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=epoch_cb)
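
The beta value exported through the BETA environment variable decays as an inverse polynomial once mbatch passes beta_freeze. A short sketch of the schedule (beta_freeze=5000 and gamma=0.06 match values set in this example; beta, beta_min and power are assumed, illustrative hyper-parameters):

import math

beta, beta_min, gamma, power, beta_freeze = 1000.0, 5.0, 0.06, 1.0, 5000

def beta_at(mbatch):
    if mbatch <= beta_freeze:
        return beta
    move = max(0, mbatch - beta_freeze)
    return max(beta_min, beta * math.pow(1 + gamma * move, -1.0 * power))

for step in (0, 5000, 6000, 20000, 100000):
    print(step, beta_at(step))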
Пример #23
0
def train_net(args):
  # GPU / CPU setup
    ctx = []
    # cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    # cvd = os.environ['0'].strip()
    # if len(cvd)>0:
    #   for i in range(len(cvd.split(','))):
    #     ctx.append(mx.gpu(i))
    # if len(ctx)==0:
    #   ctx = [mx.cpu()]
    #   print('use cpu')
    # else:
    #   print('gpu num:', len(ctx))
    ctx.append(mx.gpu(0))
  # model save path setup
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
      os.makedirs(prefix_dir)
  # parameter presets
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)             #print
    if args.per_batch_size==0:
      args.per_batch_size = 128
    args.batch_size = args.per_batch_size*args.ctx_num
    print(args.batch_size)
    args.rescale_threshold = 0
    args.image_channel = 3

    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list)==1
    data_dir = data_dir_list[0]
    path_imgrec = None
    path_imglist = None
    # read the dataset property file
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    #image_size = prop.image_size
    image_size = [int(x) for x in args.image_size.split(',')]
    assert len(image_size)==2
    assert image_size[0]==image_size[1]
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)             #print
    assert(args.num_classes>0)
    print('num_classes', args.num_classes)      #print

    path_imgrec = os.path.join(data_dir, "train.rec")    

    if args.loss_type==1 and args.num_classes>20000:
      args.beta_freeze = 5000
      args.gamma = 0.06

    print('Called with argument:', args)         #print
    data_shape = (args.image_channel,image_size[0],image_size[1])
    mean = None
  # check whether a pretrained model is given
    begin_epoch = 0 
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom

    if len(args.pretrained)==0:
      arg_params = None
      aux_params = None
      sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
      if args.network[0]=='s':
        data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}
        spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:
      vec = args.pretrained.split(',')
      print('loading', vec)
      _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
      sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)


  # initialize the model

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)
    model = mx.mod.Module(
        context       = ctx,
        symbol        = sym,
    )
    val_dataiter = None
  # build the training data iterator
    train_dataiter = FaceImageIter(
        batch_size           = args.batch_size,
        data_shape           = data_shape,
        path_imgrec          = path_imgrec,
        shuffle              = True,
        rand_mirror          = args.rand_mirror,
        mean                 = mean,
        cutoff               = args.cutoff,
        color_jittering      = args.color,
        images_filter        = args.images_filter,
    )
  # build the eval metrics
    metric1 = AccMetric()
    eval_metrics = [mx.metric.create(metric1)]

    if args.ce_loss:
      metric2 = LossValueMetric()
      eval_metrics.append( mx.metric.create(metric2) )


  # pick the weight initializer by network type, then build the optimizer

    if args.network[0]=='r' or args.network[0]=='y':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    elif args.network[0]=='i' or args.network[0]=='x':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2) #inception
    else:
      initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
    #initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    _rescale = 1.0/args.ctx_num
    opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
    som = 20
    _cb = mx.callback.Speedometer(args.batch_size, som)

  # load the verification datasets
    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
      path = os.path.join(data_dir,name+".bin")
      if os.path.exists(path):
        data_set = verification.load_bin(path, image_size)
        ver_list.append(data_set)
        ver_name_list.append(name)
        print('ver', name)


  # run verification on the test sets
    def ver_test(nbatch):
      results = []
      for i in range(len(ver_list)):
        acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, 10, None, None)
        print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
        #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
        print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
        results.append(acc2)
      return results



    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]

  # lr_steps setup
    if len(args.lr_steps)==0:
      lr_steps = [40000, 60000, 80000]
      if args.loss_type>=1 and args.loss_type<=7:
        lr_steps = [100000, 140000, 160000]
      p = 512.0/args.batch_size                        #args.batch_size = 128*x
      for l in range(len(lr_steps)):
        lr_steps[l] = int(lr_steps[l]*p)
    else:
      lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

  # batch callback: model saving plus schedule changes for lr, beta, etc.
    def _batch_callback(param):
      #global global_step
      global_step[0]+=1
      mbatch = global_step[0]

      for _lr in lr_steps:
        if mbatch==args.beta_freeze+_lr:        #args.beta_freeze = 5000
          opt.lr *= 0.1
          print('lr change to', opt.lr)
          break

      _cb(param)
      if mbatch%1000==0:
        print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)

      if mbatch>=0 and mbatch%args.verbose==0:      #args.verbose = 2000
        acc_list = ver_test(mbatch)
        save_step[0]+=1
        msave = save_step[0]
        do_save = False
        is_highest = False
        if len(acc_list)>0:
          #lfw_score = acc_list[0]
          #if lfw_score>highest_acc[0]:
          #  highest_acc[0] = lfw_score
          #  if lfw_score>=0.998:
          #    do_save = True
          score = sum(acc_list)
          if acc_list[-1]>=highest_acc[-1]:

            if acc_list[-1]>highest_acc[-1]:
              is_highest = True
            else:
              if score>=highest_acc[0]:
                is_highest = True
                highest_acc[0] = score

            highest_acc[-1] = acc_list[-1]
            #if lfw_score>=0.99:
            #  do_save = True
        if is_highest:
          do_save = True
        # checkpoint saving policy
        if args.ckpt==0:
          do_save = False
        elif args.ckpt==2:
          do_save = True
        elif args.ckpt==3:
          msave = 1

        if do_save:
          print('saving', msave)
          arg, aux = model.get_params()
          mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
        print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
      if mbatch<=args.beta_freeze:
        _beta = args.beta                 #args.beta = 1000
      else:                               #mbatch>args.beta_freeze
        move = max(0, mbatch-args.beta_freeze)
        _beta = max(args.beta_min, args.beta*math.pow(1+args.gamma*move, -1.0*args.power))
      #print('beta', _beta)
      os.environ['BETA'] = str(_beta)
      if args.max_steps>0 and mbatch>args.max_steps:
        sys.exit(0)
  # epoch callback and iterator setup
    epoch_cb = None
    # mx.io.PrefetchingIter prefetches batches in a background thread (it can also combine several data iterators)
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)
  # training entry point
    model.fit(train_dataiter,
        begin_epoch        = begin_epoch,
        num_epoch          = end_epoch,
        eval_data          = val_dataiter,
        eval_metric        = eval_metrics,
        kvstore            = 'device',
        optimizer          = opt,
        #optimizer_params   = optimizer_params,
        initializer        = initializer,
        arg_params         = arg_params,
        aux_params         = aux_params,
        allow_missing      = True,
        batch_end_callback = _batch_callback,
        epoch_end_callback = epoch_cb )
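
Several of these examples rescale the default lr_steps by 512/batch_size so the decay points land after roughly the same number of training samples regardless of batch size. A one-line sketch with an assumed batch size:

batch_size = 256  # illustrative
lr_steps = [int(s * 512.0 / batch_size) for s in (100000, 140000, 160000)]
print(lr_steps)  # [200000, 280000, 320000]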
Пример #24
0
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in xrange(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.image_channel = 3

    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list) == 1
    data_dir = data_dir_list[0]
    path_imgrec = None
    path_imglist = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)

    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)

    #path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2"
    path_imgrec = os.path.join(data_dir, "train.rec")

    assert args.images_per_identity >= 2
    assert args.triplet_bag_size % args.batch_size == 0

    print('Called with argument:', args)

    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
        if args.network[0] == 's':
            data_shape_dict = {'data': (args.per_batch_size, ) + data_shape}
            spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:
        vec = args.pretrained.split(',')
        print('loading', vec)
        sym, arg_params, aux_params = mx.model.load_checkpoint(
            vec[0], int(vec[1]))
        all_layers = sym.get_internals()
        sym = all_layers['fc1_output']
        sym, arg_params, aux_params = get_symbol(args,
                                                 arg_params,
                                                 aux_params,
                                                 sym_embedding=sym)

    data_extra = None
    hard_mining = False
    triplet_params = [
        args.triplet_bag_size, args.triplet_alpha, args.triplet_max_ap
    ]
    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
        #data_names = ('data',),
        #label_names = None,
        #label_names = ('softmax_label',),
    )
    label_shape = (args.batch_size, )

    val_dataiter = None

    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrec=path_imgrec,
        shuffle=True,
        rand_mirror=args.rand_mirror,
        mean=mean,
        cutoff=args.cutoff,
        ctx_num=args.ctx_num,
        images_per_identity=args.images_per_identity,
        triplet_params=triplet_params,
        mx_model=model,
    )

    _metric = LossValueMetric()
    eval_metrics = [mx.metric.create(_metric)]

    if args.network[0] == 'r':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  #inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    _rescale = 1.0 / args.ctx_num
    if args.noise_sgd > 0.0:
        print('use noise sgd')
        opt = NoiseSGD(scale=args.noise_sgd,
                       learning_rate=base_lr,
                       momentum=base_mom,
                       wd=base_wd,
                       rescale_grad=_rescale)
    else:
        opt = optimizer.SGD(learning_rate=base_lr,
                            momentum=base_mom,
                            wd=base_wd,
                            rescale_grad=_rescale)
    som = 2
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in xrange(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model, args.batch_size, 10, None, label_shape)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [1000000000]
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        #global global_step
        global_step[0] += 1
        mbatch = global_step[0]
        for _lr in lr_steps:
            if mbatch == _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            if len(acc_list) > 0:
                lfw_score = acc_list[0]
                if lfw_score > highest_acc[0]:
                    highest_acc[0] = lfw_score
                    if lfw_score >= 0.998:
                        do_save = True
                if acc_list[-1] >= highest_acc[-1]:
                    highest_acc[-1] = acc_list[-1]
                    if lfw_score >= 0.99:
                        do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt > 1:
                do_save = True
            #for i in xrange(len(acc_list)):
            #  acc = acc_list[i]
            #  if acc>=highest_acc[i]:
            #    highest_acc[i] = acc
            #    if lfw_score>=0.99:
            #      do_save = True
            #if args.loss_type==1 and mbatch>lr_steps[-1] and mbatch%10000==0:
            #  do_save = True
            if do_save:
                print('saving', msave)
                if val_dataiter is not None:
                    val_test()
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    #epoch_cb = mx.callback.do_checkpoint(prefix, 1)
    epoch_cb = None

    model.fit(
        train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        eval_data=val_dataiter,
        eval_metric=eval_metrics,
        kvstore='device',
        optimizer=opt,
        #optimizer_params   = optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=epoch_cb)
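
The checkpoint decision in Example #24's callback mixes the verification scores with the args.ckpt flag. A minimal pure-Python sketch of that policy, assuming the flag semantics implied by the branches above (0 = never save, greater than 1 = save every round, otherwise save only on a sufficiently good new best); the 0.998/0.99 thresholds are copied from the code above:

def should_save(acc_list, highest, ckpt):
    """Decide whether to checkpoint after a verification round.

    highest is mutated in place: [best first target, best last target].
    """
    do_save = False
    if acc_list:
        if acc_list[0] > highest[0]:        # new best on the first target
            highest[0] = acc_list[0]
            if acc_list[0] >= 0.998:
                do_save = True
        if acc_list[-1] >= highest[-1]:     # last target matched or improved
            highest[-1] = acc_list[-1]
            if acc_list[0] >= 0.99:
                do_save = True
    if ckpt == 0:                           # checkpointing disabled
        do_save = False
    elif ckpt > 1:                          # force a save every round
        do_save = True
    return do_save

assert should_save([0.999], [0.0, 0.0], ckpt=1)
assert not should_save([0.5], [0.6, 0.6], ckpt=0)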
Example #25
0
def train_net(args):
    ctx = []
    for ctx_id in [int(x) for x in args.gpus.split(',')]:
        ctx.append(mx.gpu(ctx_id))
    print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    assert args.batch_size % args.ctx_num == 0
    args.per_batch_size = args.batch_size // args.ctx_num
    args.image_channel = 3

    data_dir = args.data_dir
    print('data dir', data_dir)
    path_imgrec = None
    path_imglist = None
    for line in open(os.path.join(data_dir, 'property')):
        vec = line.strip().split(',')
        assert len(vec) == 3
        args.num_classes = int(vec[0])
        image_size = [int(vec[1]), int(vec[2])]
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    print('Called with argument:', args)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None
    begin_epoch = 0

    #feat_net = fresnet.get(100, 256)
    #margin_block = ArcMarginBlock(args)
    net = TrainBlock(args)
    net.collect_params().reset_ctx(ctx)
    net.hybridize()

    #initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    #feat_net.initialize(ctx=ctx, init=initializer)
    #feat_net.hybridize()
    #margin_block.initialize(ctx=ctx, init=mx.init.Normal(0.01))
    #margin_block.hybridize()

    ds = FaceDataset(data_shape=(3, 112, 112), path_imgrec=path_imgrec)
    #print(len(ds))
    #img, label = ds[0]
    #print(img.__class__, label.__class__)
    #print(img.shape, label)
    loader = gluon.data.DataLoader(ds,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=8,
                                   last_batch='discard')

    metric = CompositeEvalMetric([AccMetric()])

    ver_list = []
    ver_name_list = []
    if args.task == '':
        for name in args.eval.split(','):
            path = os.path.join(data_dir, name + ".bin")
            if os.path.exists(path):
                print('loading ver-set:', name)
                data_set = verification.load_bin(path, image_size)
                ver_list.append(data_set)
                ver_name_list.append(name)

    def ver_test(nbatch, tnet):
        results = []
        for i in range(len(ver_list)):
            xnorm, acc, thresh = verification.easytest(
                ver_list[i], tnet, ctx, batch_size=args.batch_size)
            print('[%s][%d]Accuracy-Thresh-XNorm: %.5f - %.5f - %.5f' %
                  (ver_name_list[i], nbatch, acc, thresh, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            #print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc)
        return results

    total_time = 0
    num_epochs = 0
    best_acc = [0]
    highest_acc = [0.0, 0.0]  #lfw and target
    global_step = [0]
    save_step = [0]
    lr_steps = [20000, 28000, 32000]
    if args.num_classes >= 20000:
        lr_steps = [100000, 160000, 220000]
    print('lr_steps', lr_steps)

    kv = mx.kv.create('device')
    trainer = gluon.Trainer(net.collect_params(),
                            'sgd', {
                                'learning_rate': args.lr,
                                'wd': args.wd,
                                'momentum': args.mom,
                                'multi_precision': True
                            },
                            kvstore=kv)

    def _batch_callback():
        mbatch = global_step[0]
        global_step[0] += 1
        for _lr in lr_steps:
            if mbatch == _lr:
                args.lr *= 0.1
                if args.mode == 'gluon':
                    trainer.set_learning_rate(args.lr)
                else:
                    opt.lr = args.lr
                print('lr change to', args.lr)
                break

        #_cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', args.lr, mbatch)

        if mbatch > 0 and mbatch % args.verbose == 0:
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            is_highest = False
            tnet = TestBlock(args, params=net.collect_params())
            tnet.hybridize()
            acc_list = ver_test(mbatch, tnet)
            if len(acc_list) > 0:
                lfw_score = acc_list[0]
                if lfw_score > highest_acc[0]:
                    highest_acc[0] = lfw_score
                if acc_list[-1] >= highest_acc[-1]:
                    highest_acc[-1] = acc_list[-1]
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt == 1:
                do_save = True
                msave = 1
            elif args.ckpt > 1:
                do_save = True
            if do_save:
                print('saving', msave)
                #print('saving gluon params')
                fname = args.prefix + "-gluon.params"
                tnet.save_parameters(fname)
                tnet.export(args.prefix, msave)
                #arg, aux = model.get_params()
                #mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    loss_weight = 1.0
    #loss = gluon.loss.SoftmaxCrossEntropyLoss(weight = loss_weight)
    #loss = nd.SoftmaxOutput
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    while True:
        #trainer = update_learning_rate(opt.lr, trainer, epoch, opt.lr_factor, lr_steps)
        tic = time.time()
        #train_iter.reset()
        metric.reset()
        btic = time.time()
        #for i, batch in enumerate(train_iter):
        for batch_idx, (x, y) in enumerate(loader):
            #print(x.shape, y.shape)
            _batch_callback()
            #data = gluon.utils.split_and_load(batch.data[0].astype(opt.dtype), ctx_list=ctx, batch_axis=0)
            #label = gluon.utils.split_and_load(batch.label[0].astype(opt.dtype), ctx_list=ctx, batch_axis=0)
            data = gluon.utils.split_and_load(x, ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(y, ctx_list=ctx, batch_axis=0)
            outputs = []
            losses = []
            with ag.record():
                for _data, _label in zip(data, label):
                    #print(y.asnumpy())
                    fc7 = net(_data, _label)
                    #print(z[0].shape, z[1].shape)
                    losses.append(loss(fc7, _label))
                    outputs.append(fc7)
            for l in losses:
                l.backward()
            #trainer.step(batch.data[0].shape[0], ignore_stale_grad=True)
            #trainer.step(args.ctx_num)
            n = x.shape[0]
            #print(n,n)
            trainer.step(n)
            metric.update(label, outputs)
            i = batch_idx
            if i > 0 and i % 20 == 0:
                name, acc = metric.get()
                if len(name) == 2:
                    logger.info(
                        'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f, %s=%f'
                        % (num_epochs, i, args.batch_size /
                           (time.time() - btic), name[0], acc[0], name[1],
                           acc[1]))
                else:
                    logger.info(
                        'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f' %
                        (num_epochs, i, args.batch_size /
                         (time.time() - btic), name[0], acc[0]))
                #metric.reset()
            btic = time.time()

        epoch_time = time.time() - tic

        # First epoch will usually be much slower than the subsequent epochs,
        # so don't factor it into the average
        if num_epochs > 0:
            total_time = total_time + epoch_time

        #name, acc = metric.get()
        #logger.info('[Epoch %d] training: %s=%f, %s=%f'%(num_epochs, name[0], acc[0], name[1], acc[1]))
        logger.info('[Epoch %d] time cost: %f' % (num_epochs, epoch_time))
        num_epochs = num_epochs + 1
        #name, val_acc = test(ctx, val_data)
        #logger.info('[Epoch %d] validation: %s=%f, %s=%f'%(epoch, name[0], val_acc[0], name[1], val_acc[1]))

        # save model if meet requirements
        #save_checkpoint(epoch, val_acc[0], best_acc)
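    # NOTE: unreachable as written; the while-loop above only exits via sys.exit() in _batch_callback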
    if num_epochs > 1:
        print('Average epoch time: {}'.format(
            float(total_time) / (num_epochs - 1)))
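
The inner loop of Example #25 is the standard gluon split-and-load pattern: shard a batch across contexts, record the forward pass per shard, run backward on each shard's loss, then take a single trainer step sized by the whole batch. A self-contained sketch of the same pattern; the CPU context pair and the one-layer network are assumptions for illustration:

import mxnet as mx
from mxnet import autograd as ag, gluon

ctx = [mx.cpu(0), mx.cpu(1)]                  # stand-ins for mx.gpu(i)
net = gluon.nn.Dense(10)
net.initialize(ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

x = mx.nd.random.uniform(shape=(8, 20))
y = mx.nd.random.randint(0, 10, shape=(8,)).astype('float32')

data = gluon.utils.split_and_load(x, ctx_list=ctx, batch_axis=0)
label = gluon.utils.split_and_load(y, ctx_list=ctx, batch_axis=0)
losses = []
with ag.record():                             # record per-shard forward passes
    for _data, _label in zip(data, label):
        losses.append(loss_fn(net(_data), _label))
for l in losses:                              # backward each shard
    l.backward()
trainer.step(x.shape[0])                      # one update for the full batch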
Example #26
0
def main(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in xrange(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    args.ctx_num = len(ctx)
    include_datasets = args.include.split(',')
    prop = face_image.load_property(include_datasets[0])
    image_size = prop.image_size
    print('image_size', image_size)
    vec = args.model.split(',')
    prefix = vec[0]
    epoch = int(vec[1])
    print('loading', prefix, epoch)
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    #arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)
    all_layers = sym.get_internals()
    sym = all_layers['fc1_output']
    #model = mx.mod.Module.load(prefix, epoch, context = ctx)
    #model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
    model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
    model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0],
                                      image_size[1]))])
    model.set_params(arg_params, aux_params)
    rec_list = []
    for ds in include_datasets:
        path_imgrec = os.path.join(ds, 'train.rec')
        path_imgidx = os.path.join(ds, 'train.idx')
        imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')  # pylint: disable=redefined-variable-type
        rec_list.append(imgrec)
    id_list_map = {}
    all_id_list = []
    test_limit = 0
    for ds_id in xrange(len(rec_list)):
        id_list = []
        imgrec = rec_list[ds_id]
        s = imgrec.read_idx(0)
        header, _ = mx.recordio.unpack(s)
        assert header.flag > 0
        print('header0 label', header.label)
        header0 = (int(header.label[0]), int(header.label[1]))
        #assert(header.flag==1)
        imgidx = range(1, int(header.label[0]))
        id2range = {}
        seq_identity = range(int(header.label[0]), int(header.label[1]))
        pp = 0
        for identity in seq_identity:
            pp += 1
            if pp % 10 == 0:
                print('processing id', pp)
            embedding = get_embedding(args, imgrec, identity, image_size,
                                      model)
            #print(embedding.shape)
            id_list.append([ds_id, identity, embedding])
            if test_limit > 0 and pp >= test_limit:
                break
        id_list_map[ds_id] = id_list
        if ds_id == 0:
            all_id_list += id_list
            print(ds_id, len(id_list))
        else:
            X = []
            for id_item in all_id_list:
                X.append(id_item[2])
            X = np.array(X)
            for i in xrange(len(id_list)):
                id_item = id_list[i]
                y = id_item[2]
                sim = np.dot(X, y.T)
                idx = np.where(sim >= args.param1)[0]
                if len(idx) > 0:
                    continue
                all_id_list.append(id_item)
            print(ds_id, len(id_list), len(all_id_list))

    if len(args.exclude) > 0:
        if os.path.isdir(args.exclude):
            _path_imgrec = os.path.join(args.exclude, 'train.rec')
            _path_imgidx = os.path.join(args.exclude, 'train.idx')
            _imgrec = mx.recordio.MXIndexedRecordIO(_path_imgidx, _path_imgrec,
                                                    'r')  # pylint: disable=redefined-variable-type
            _ds_id = len(rec_list)
            _id_list = []
            s = _imgrec.read_idx(0)
            header, _ = mx.recordio.unpack(s)
            assert header.flag > 0
            print('header0 label', header.label)
            header0 = (int(header.label[0]), int(header.label[1]))
            #assert(header.flag==1)
            imgidx = range(1, int(header.label[0]))
            seq_identity = range(int(header.label[0]), int(header.label[1]))
            pp = 0
            for identity in seq_identity:
                pp += 1
                if pp % 10 == 0:
                    print('processing ex id', pp)
                embedding = get_embedding(args, _imgrec, identity, image_size,
                                          model)
                #print(embedding.shape)
                _id_list.append((_ds_id, identity, embedding))
                if test_limit > 0 and pp >= test_limit:
                    break
        else:
            _id_list = []
            data_set = verification.load_bin(args.exclude, image_size)[0][0]
            print(data_set.shape)
            data = nd.zeros((1, 3, image_size[0], image_size[1]))
            for i in xrange(data_set.shape[0]):
                data[0] = data_set[i]
                db = mx.io.DataBatch(data=(data, ))
                model.forward(db, is_train=False)
                net_out = model.get_outputs()
                embedding = net_out[0].asnumpy().flatten()
                _norm = np.linalg.norm(embedding)
                embedding /= _norm
                _id_list.append((i, i, embedding))

        #X = []
        #for id_item in all_id_list:
        #  X.append(id_item[2])
        #X = np.array(X)
        #param1 = 0.3
        #while param1<=1.01:
        #  emap = {}
        #  for id_item in _id_list:
        #    y = id_item[2]
        #    sim = np.dot(X, y.T)
        #    #print(sim.shape)
        #    #print(sim)
        #    idx = np.where(sim>=param1)[0]
        #    for j in idx:
        #      emap[j] = 1
        #  exclude_removed = len(emap)
        #  print(param1, exclude_removed)
        #  param1+=0.05

            X = []
            for id_item in all_id_list:
                X.append(id_item[2])
            X = np.array(X)
            emap = {}
            for id_item in _id_list:
                y = id_item[2]
                sim = np.dot(X, y.T)
                idx = np.where(sim >= args.param2)[0]
                for j in idx:
                    emap[j] = 1
                    all_id_list[j][1] = -1
            print('exclude', len(emap))

    if args.test > 0:
        return

    if not os.path.exists(args.output):
        os.makedirs(args.output)
    writer = mx.recordio.MXIndexedRecordIO(
        os.path.join(args.output, 'train.idx'),
        os.path.join(args.output, 'train.rec'), 'w')
    idx = 1
    identities = []
    nlabel = -1
    for id_item in all_id_list:
        if id_item[1] < 0:
            continue
        nlabel += 1
        ds_id = id_item[0]
        imgrec = rec_list[ds_id]
        id = id_item[1]
        s = imgrec.read_idx(id)
        header, _ = mx.recordio.unpack(s)
        a, b = int(header.label[0]), int(header.label[1])
        identities.append((idx, idx + b - a))
        for _idx in xrange(a, b):
            s = imgrec.read_idx(_idx)
            _header, _content = mx.recordio.unpack(s)
            nheader = mx.recordio.IRHeader(0, nlabel, idx, 0)
            s = mx.recordio.pack(nheader, _content)
            writer.write_idx(idx, s)
            idx += 1
    id_idx = idx
    for id_label in identities:
        _header = mx.recordio.IRHeader(1, id_label, idx, 0)
        s = mx.recordio.pack(_header, '')
        writer.write_idx(idx, s)
        idx += 1
    _header = mx.recordio.IRHeader(1, (id_idx, idx), 0, 0)
    s = mx.recordio.pack(_header, '')
    writer.write_idx(0, s)
    with open(os.path.join(args.output, 'property'), 'w') as f:
        f.write("%d,%d,%d" % (len(identities), image_size[0], image_size[1]))
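
The merge in Example #26 keeps or drops identities with a plain cosine-similarity test between L2-normalized embeddings: a candidate counts as a duplicate when its dot product with any kept embedding reaches a threshold (args.param1/args.param2 above). A small numpy sketch of that test; the embedding size and threshold here are illustrative:

import numpy as np

def is_duplicate(kept, candidate, threshold=0.3):
    """kept: (N, D) row-normalized embeddings; candidate: (D,) normalized."""
    if len(kept) == 0:
        return False
    sim = np.dot(np.asarray(kept), candidate)  # cosine similarity for unit-norm rows
    return bool(np.any(sim >= threshold))

rng = np.random.default_rng(0)
kept = rng.normal(size=(5, 512))
kept /= np.linalg.norm(kept, axis=1, keepdims=True)
cand = kept[2] + 0.01 * rng.normal(size=512)
cand /= np.linalg.norm(cand)
print(is_duplicate(kept, cand))                # True: near-copy of a kept identity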
Example #27
0
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    network, num_layers = args.network.split(',')
    print('num_layers', num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3

    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    path_imgrecs = []
    path_imglist = None
    args.num_classes = []
    for data_idx, data_dir in enumerate(data_dir_list):
        prop = face_image.load_property(data_dir)
        args.num_classes.append(prop.num_classes)
        image_size = prop.image_size
        if data_idx == 0:
            args.image_h = image_size[0]
            args.image_w = image_size[1]
        else:
            args.image_h = min(args.image_h, image_size[0])
            args.image_w = min(args.image_w, image_size[1])
        print('image_size', image_size)
        assert (args.num_classes[-1] > 0)
        print('num_classes', args.num_classes)
        path_imgrecs.append(os.path.join(data_dir, "train.rec"))

    if args.loss_type == 1 and max(args.num_classes) > 20000:  # num_classes is a list of per-dataset counts here
        args.beta_freeze = 5000
        args.gamma = 0.06

    print('Called with argument:', args)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(network, int(num_layers),
                                                 args, arg_params, aux_params)
    else:
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            vec[0], int(vec[1]))
        sym, arg_params, aux_params = get_symbol(network, int(num_layers),
                                                 args, arg_params, aux_params)

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)
    ctx_group = dict(zip(['dev%d' % (i + 1) for i in range(len(ctx))], ctx))
    ctx_group['dev0'] = ctx
    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
        data_names=['data'] if args.loss_type != 6 else ['data', 'margin'],
        group2ctxs=ctx_group)
    val_dataiter = None

    from config import crop
    from config import cutout

    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrecs=path_imgrecs,
        shuffle=True,
        rand_mirror=args.rand_mirror,
        mean=mean,
        cutout=cutout,
        crop=crop,
        loss_type=args.loss_type,
        #margin_m             = args.margin_m,
        #margin_policy        = args.margin_policy,
        #max_steps            = args.max_steps,
        #data_names           = ['data', 'margin'],
        downsample_back=args.downsample_back,
        motion_blur=args.motion_blur,
    )

    _metric = AccMetric()
    #_metric = LossValueMetric()
    eval_metrics = [mx.metric.create(_metric)]

    if args.network[0] == 'r' or args.network[0] == 'y':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  #inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    _rescale = 1.0 / args.ctx_num

    if len(args.lr_steps) == 0:
        print('Error: lr_steps is not set')
        sys.exit(0)
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(lr_steps,
                                                        factor=0.1,
                                                        base_lr=base_lr)
    optimizer_params = {
        'learning_rate': base_lr,
        'momentum': base_mom,
        'wd': base_wd,
        'rescale_grad': _rescale,
        'lr_scheduler': lr_scheduler
    }

    #opt = AdaBound()
    #opt = AdaBound(lr=base_lr, wd=base_wd, gamma = 2. / args.max_steps)
    opt = optimizer.SGD(**optimizer_params)

    som = 2000
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            _, issame_list = ver_list[i]
            if all(issame_list):
                fp_rates, fp_dict, thred_dict, recall_dict = verification.test(
                    ver_list[i],
                    model,
                    args.batch_size,
                    label_shape=(args.batch_size, len(path_imgrecs)))
                for k in fp_rates:
                    print("[%s] TPR at FPR %.2e[%.2e: %.4f]:\t%.5f" %
                          (ver_name_list[i], k, fp_dict[k], thred_dict[k],
                           recall_dict[k]))
            else:
                acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                    ver_list[i],
                    model,
                    args.batch_size,
                    10,
                    None,
                    label_shape=(args.batch_size, len(path_imgrecs)))
                print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
                #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
                print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                      (ver_name_list[i], nbatch, acc2, std2))
                results.append(acc2)
        return results

    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in range(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]

    def _batch_callback(param):
        #global global_step
        global_step[0] += 1
        mbatch = global_step[0]

        _cb(param)
        if mbatch % 10000 == 0:
            print('lr-batch-epoch:', opt.learning_rate, param.nbatch,
                  param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            if len(acc_list) > 0:
                lfw_score = acc_list[0]
                if lfw_score > highest_acc[0]:
                    highest_acc[0] = lfw_score
                    if lfw_score >= 0.998:
                        do_save = True
                if acc_list[-1] >= highest_acc[-1]:
                    highest_acc[-1] = acc_list[-1]
                    if lfw_score >= 0.99:
                        do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt > 1:
                do_save = True
            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = max(0, mbatch - args.beta_freeze)
            _beta = max(
                args.beta_min,
                args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
        #print('beta', _beta)
        os.environ['BETA'] = str(_beta)
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    model.fit(train_dataiter,
              begin_epoch=begin_epoch,
              num_epoch=end_epoch,
              eval_data=val_dataiter,
              eval_metric=eval_metrics,
              kvstore='device',
              optimizer=opt,
              optimizer_params=optimizer_params,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              allow_missing=True,
              batch_end_callback=_batch_callback,
              epoch_end_callback=epoch_cb)
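
Unlike the neighboring examples, Example #27 does not decay the learning rate inside the batch callback; it hands lr_steps to mx.lr_scheduler.MultiFactorScheduler and lets the optimizer query it per update. A minimal sketch of how that scheduler maps update counts to learning rates; the step values and base_lr are illustrative:

import mxnet as mx

sched = mx.lr_scheduler.MultiFactorScheduler(step=[100, 200], factor=0.1,
                                             base_lr=0.1)
for num_update in (1, 100, 101, 201):
    print(num_update, sched(num_update))
# stays at 0.1 through update 100, drops to 0.01, then to 0.001 after 200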
Example #28
0
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd)>0:
      for i in xrange(len(cvd.split(','))):
        ctx.append(mx.gpu(i))
    if len(ctx)==0:
      ctx = [mx.cpu()]
      print('use cpu')
    else:
      print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
      os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size==0:
      args.per_batch_size = 128
      if args.loss_type==10:
        args.per_batch_size = 256
    args.batch_size = args.per_batch_size*args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3
    ppatch = [int(x) for x in args.patch.split('_')]
    assert len(ppatch)==5


    os.environ['BETA'] = str(args.beta)
    args.use_val = False
    path_imgrec = None
    path_imglist = None
    val_rec = None
    prop = face_image.load_property(args.data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)

    assert(args.num_classes>0)
    print('num_classes', args.num_classes)

    #path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2"
    path_imgrec = os.path.join(args.data_dir, "train.rec")
    val_rec = os.path.join(args.data_dir, "val.rec")
    if os.path.exists(val_rec) and args.loss_type<10:
      args.use_val = True
    else:
      val_rec = None
    #args.num_classes = 10572 #webface
    #args.num_classes = 81017
    #args.num_classes = 82395



    if args.loss_type==1 and args.num_classes>40000:
      args.beta_freeze = 5000
      args.gamma = 0.06

    if args.loss_type==11:
      args.images_per_identity = 2
    elif args.loss_type==10:
      args.images_per_identity = 16

    if args.loss_type<10:
      assert args.images_per_identity==0
    else:
      assert args.images_per_identity>=2
      args.per_identities = int(args.per_batch_size/args.images_per_identity)

    print('Called with argument:', args)

    data_shape = (args.image_channel,image_size[0],image_size[1])
    mean = None

    if args.use_val:
      val_dataiter = FaceImageIter(
          batch_size           = args.batch_size,
          data_shape           = data_shape,
          path_imgrec          = val_rec,
          #path_imglist         = val_path,
          shuffle              = False,
          rand_mirror          = False,
          mean                 = mean,
      )
    else:
      val_dataiter = None



    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained)==0:
      arg_params = None
      aux_params = None
      sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    else:
      vec = args.pretrained.split(',')
      _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
      sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)

    data_extra = None
    hard_mining = False
    if args.loss_type==10:
      hard_mining = True
      _shape = (args.batch_size, args.per_batch_size)
      data_extra = np.full(_shape, -1.0, dtype=np.float32)
      c = 0
      while c<args.batch_size:
        a = 0
        while a<args.per_batch_size:
          b = a+args.images_per_identity
          data_extra[(c+a):(c+b),a:b] = 1.0
          #print(c+a, c+b, a, b)
          a = b
        c += args.per_batch_size
    elif args.loss_type==11:
      data_extra = np.zeros( (args.batch_size, args.per_identities), dtype=np.float32)
      c = 0
      while c<args.batch_size:
        for i in xrange(args.per_identities):
          data_extra[c+i][i] = 1.0
        c+=args.per_batch_size

    label_name = 'softmax_label'
    if data_extra is None:
      model = mx.mod.Module(
          context       = ctx,
          symbol        = sym,
      )
    else:
      data_names = ('data', 'extra')
      #label_name = ''
      model = mx.mod.Module(
          context       = ctx,
          symbol        = sym,
          data_names    = data_names,
          label_names   = (label_name,),
      )


    train_dataiter = FaceImageIter(
        batch_size           = args.batch_size,
        data_shape           = data_shape,
        path_imgrec          = path_imgrec,
        shuffle              = True,
        rand_mirror          = True,
        mean                 = mean,
        ctx_num              = args.ctx_num,
        images_per_identity  = args.images_per_identity,
        data_extra           = data_extra,
        hard_mining          = hard_mining,
        mx_model             = model,
        label_name           = label_name,
    )

    if args.loss_type<10:
      _metric = AccMetric()
    else:
      _metric = LossValueMetric()
    eval_metrics = [mx.metric.create(_metric)]

    if args.network[0]=='r':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    elif args.network[0]=='i' or args.network[0]=='x':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2) #inception
    else:
      initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
    _rescale = 1.0/args.ctx_num
    opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
    _cb = mx.callback.Speedometer(args.batch_size, 20)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
      path = os.path.join(args.data_dir,name+".bin")
      if os.path.exists(path):
        data_set = verification.load_bin(path, image_size)
        ver_list.append(data_set)
        ver_name_list.append(name)
        print('ver', name)



    def ver_test(nbatch):
      results = []
      for i in xrange(len(ver_list)):
        acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, data_extra)
        print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
        #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
        print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
        results.append(acc2)
      return results


    def val_test():
      acc = AccMetric()
      val_metric = mx.metric.create(acc)
      val_metric.reset()
      val_dataiter.reset()
      for i, eval_batch in enumerate(val_dataiter):
        model.forward(eval_batch, is_train=False)
        model.update_metric(val_metric, eval_batch.label)
      acc_value = val_metric.get_name_value()[0][1]
      print('VACC: %f'%(acc_value))


    highest_acc = []
    for i in xrange(len(ver_list)):
      highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps)==0:
      lr_steps = [40000, 60000, 70000]
      if args.loss_type==1:
        lr_steps = [50000, 70000, 80000]
      p = 512.0/args.batch_size
      for l in xrange(len(lr_steps)):
        lr_steps[l] = int(lr_steps[l]*p)
    else:
      lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)
    def _batch_callback(param):
      #global global_step
      global_step[0]+=1
      mbatch = global_step[0]
      for _lr in lr_steps:
        if mbatch==args.beta_freeze+_lr:
          opt.lr *= 0.1
          print('lr change to', opt.lr)
          break

      _cb(param)
      if mbatch%1000==0:
        print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)

      if mbatch>=0 and mbatch%args.verbose==0:
        acc_list = ver_test(mbatch)
        save_step[0]+=1
        msave = save_step[0]
        do_save = False
        lfw_score = acc_list[0]
        for i in xrange(len(acc_list)):
          acc = acc_list[i]
          if acc>=highest_acc[i]:
            highest_acc[i] = acc
            if lfw_score>=0.99:
              do_save = True
        if args.loss_type==1 and mbatch>lr_steps[-1] and mbatch%10000==0:
          do_save = True
        if do_save:
          print('saving', msave, acc)
          if val_dataiter is not None:
            val_test()
          arg, aux = model.get_params()
          mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
          #if acc>=highest_acc[0]:
          #  lfw_npy = "%s-lfw-%04d" % (prefix, msave)
          #  X = np.concatenate(embeddings_list, axis=0)
          #  print('saving lfw npy', X.shape)
          #  np.save(lfw_npy, X)
        #print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[0]))
      if mbatch<=args.beta_freeze:
        _beta = args.beta
      else:
        move = max(0, mbatch-args.beta_freeze)
        _beta = max(args.beta_min, args.beta*math.pow(1+args.gamma*move, -1.0*args.power))
      #print('beta', _beta)
      os.environ['BETA'] = str(_beta)

    #epoch_cb = mx.callback.do_checkpoint(prefix, 1)
    epoch_cb = None



    #def _epoch_callback(epoch, sym, arg_params, aux_params):
    #  print('epoch-end', epoch)

    model.fit(train_dataiter,
        begin_epoch        = begin_epoch,
        num_epoch          = end_epoch,
        eval_data          = val_dataiter,
        eval_metric        = eval_metrics,
        kvstore            = 'device',
        optimizer          = opt,
        #optimizer_params   = optimizer_params,
        initializer        = initializer,
        arg_params         = arg_params,
        aux_params         = aux_params,
        allow_missing      = True,
        batch_end_callback = _batch_callback,
        epoch_end_callback = epoch_cb )
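
For loss_type 10, Example #28 feeds the network an extra block-diagonal mask marking which samples within each per-device sub-batch share an identity (+1.0 inside an identity's block, -1.0 elsewhere). The same construction in a standalone numpy sketch, with toy sizes (batch_size=8, per_batch_size=4, images_per_identity=2 are assumptions):

import numpy as np

batch_size, per_batch_size, images_per_identity = 8, 4, 2
data_extra = np.full((batch_size, per_batch_size), -1.0, dtype=np.float32)
c = 0
while c < batch_size:                    # one stripe per device sub-batch
    a = 0
    while a < per_batch_size:
        b = a + images_per_identity     # samples a..b share one identity
        data_extra[(c + a):(c + b), a:b] = 1.0
        a = b
    c += per_batch_size
print(data_extra)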
Example #29
0
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in xrange(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    args.batch_size = args.batch_size1
    args.rescale_threshold = 0
    args.image_channel = 3

    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list) == 1
    data_dir = data_dir_list[0]
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)
    path_imgrec1 = os.path.join(data_dir, "train.rec")

    data_dir_interclass_list = args.data_dir_interclass.split(',')
    assert len(data_dir_interclass_list) == 1
    data_dir_interclass = data_dir_interclass_list[0]
    path_imgrec2 = os.path.join(data_dir_interclass, "train.rec")

    if args.loss_type == 1 and args.num_classes > 20000:
        args.beta_freeze = 5000
        args.gamma = 0.06

    print('Called with argument:', args)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
        print('sym: ', sym)

    else:
        vec = args.pretrained.split(',')
        #print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            vec[0], int(vec[1]))
        #if 'fc7_weight' in arg_params.keys():
        #  del arg_params['fc7_weight']
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)

    sym_test, _, _ = get_symbol(args, arg_params, aux_params)
    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)
    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
    )
    val_dataiter = None

    train_dataiter = FaceImageIter(
        mx_model=model,
        ctx=ctx,
        ctx_num=args.ctx_num,
        data_shape=data_shape,
        batch_size1=args.batch_size1,
        path_imgrec1=path_imgrec1,
        batchsize_id=args.batchsize_id,
        batch_size2=args.batch_size2,
        path_imgrec2=path_imgrec2,
        images_per_identity=args.images_per_identity,
        interclass_bag_size=args.bag_size,
        shuffle=True,
        aug_list=None,
        rand_mirror=True,
    )

    eval_metrics = [
        mx.metric.create(AccMetric()),
        mx.metric.create(LossValue()),
        mx.metric.create(LossValue2())
    ]

    if args.network[0] == 'r' or args.network[0] == 'y':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  #inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=base_lr,
                        momentum=base_mom,
                        wd=base_wd,
                        rescale_grad=_rescale)
    som = 2
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        #path = os.path.join('/ssd/MegaFace/MF2_aligned_pic9/', name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    model_t = None

    def ver_test(nbatch, model_t):
        results = []

        if model_t is None:
            all_layers = model.symbol.get_internals()
            symbol_t = all_layers['blockgrad0_output']
            model_t = mx.mod.Module(symbol=symbol_t,
                                    context=ctx,
                                    label_names=None)
            print([('data', (10, ) + data_shape)])
            model_t.bind(data_shapes=[('data', (10, ) + data_shape)])
            arg_t, aux_t = model.get_params()
            model_t.set_params(arg_t, aux_t)
        else:
            arg_t, aux_t = model.get_params()
            model_t.set_params(arg_t, aux_t)

        for i in xrange(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model_t, 10, 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0, 0.0]  #lfw and target
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [100000, 140000, 160000]
        p = 512.0 / args.batch_size
        for l in xrange(len(lr_steps)):
            lr_steps[l] = int(lr_steps[l] * p)
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        #global global_step
        global_step[0] += 1
        mbatch = global_step[0]
        for _lr in lr_steps:
            if mbatch == args.beta_freeze + _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        #if mbatch%1==0:
        #print('mbatch:',mbatch)
        #arg, aux = model.get_params()

        if mbatch == 1:
            ver_test(mbatch, model_t)
            arg, aux = model.get_params()
            mx.model.save_checkpoint(prefix, 1000, model.symbol, arg, aux)
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            arg, aux = model.get_params()
            mx.model.save_checkpoint(prefix, 0, model.symbol, arg, aux)
            acc_list = ver_test(mbatch, model_t)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
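            # do_save is never set True here, so this branch is dead; the checkpoint is written unconditionally above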
            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f %1.5f %1.5f' %
                  (mbatch, highest_acc[0], highest_acc[1], highest_acc[2]))
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = max(0, mbatch - args.beta_freeze)
            _beta = max(
                args.beta_min,
                args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
        #print('beta', _beta)
        os.environ['BETA'] = str(_beta)
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    epoch_cb = None

    #print('arg_params',arg_params,aux_params)
    model.fit(
        train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        eval_data=val_dataiter,
        eval_metric=eval_metrics,
        kvstore='device',
        optimizer=opt,
        #optimizer_params   = optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=epoch_cb)
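
Example #29's ver_test evaluates with a second module built by slicing the training symbol at an internal output ('blockgrad0_output' there) and copying in the live weights. A generic sketch of that pattern; the layer name and shapes are whatever the training symbol actually exposes ('fc1_output' is the embedding layer in the other examples):

import mxnet as mx

def make_eval_module(train_module, layer_name, data_shape, ctx):
    """Bind a label-free module on an internal output and share the trained weights."""
    sym = train_module.symbol.get_internals()[layer_name]
    mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
    mod.bind(data_shapes=[('data', data_shape)], for_training=False)
    arg, aux = train_module.get_params()
    mod.set_params(arg, aux)
    return mod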
Example #30
0
def train_net(args):
    ctx = []
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in xrange(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3

    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list) == 1
    data_dir = data_dir_list[0]
    path_imgrec = None
    path_imglist = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)
    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    if args.loss_type == 1 and args.num_classes > 20000:
        args.beta_freeze = 5000
        args.gamma = 0.06

    print('Called with argument:', args)
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    else:
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            vec[0], int(vec[1]))
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
        # if args.finetune:
        #     def get_fine_tune_model(symbol, arg_params, num_classes, layer_name='flatten0'):
        #         """
        #         symbol: the pretrained network symbol
        #         arg_params: the argument parameters of the pretrained model
        #         num_classes: the number of classes for the fine-tune datasets
        #         layer_name: the layer name before the last fully-connected layer
        #         """
        #         all_layers = symbol.get_internals()
        #         # print(all_layers);exit(0)
        #         for k in arg_params:
        #             if k.startswith('fc'):
        #               print(k)
        #         exit(0)
        #         net = all_layers[layer_name + '_output']
        #         net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc1')
        #         net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
        #         new_args = dict({k: arg_params[k] for k in arg_params if 'fc1' not in k})
        #         return (net, new_args)
        #     sym, arg_params = get_fine_tune_model(sym, arg_params, args.num_classes)

    if args.network[0] == 's':
        data_shape_dict = {'data': (args.per_batch_size, ) + data_shape}
        spherenet.init_weights(sym, data_shape_dict, args.num_layers)

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)
    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
    )
    val_dataiter = None

    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrec=path_imgrec,
        shuffle=True,
        rand_mirror=args.rand_mirror,
        mean=mean,
        cutoff=args.cutoff,
    )

    if args.loss_type < 10:
        _metric = AccMetric()
    else:
        _metric = LossValueMetric()
    eval_metrics = [mx.metric.create(_metric)]

    if args.network[0] == 'r' or args.network[0] == 'y':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  #inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=base_lr,
                        momentum=base_mom,
                        wd=base_wd,
                        rescale_grad=_rescale)
    som = 20
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in xrange(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model, args.batch_size, 10, None, None)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [40000, 60000, 80000]
        if 1 <= args.loss_type <= 7:
            lr_steps = [100000, 140000, 160000]
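        # rescale the default boundaries to an effective batch size of 512;
        # e.g. batch_size=256 gives p=2.0 and doubles every step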
        p = 512.0 / args.batch_size
        for l in range(len(lr_steps)):
            lr_steps[l] = int(lr_steps[l] * p)
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        #global global_step
        global_step[0] += 1
        mbatch = global_step[0]
        for _lr in lr_steps:
            if mbatch == args.beta_freeze + _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        if mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            if len(acc_list) > 0:
                lfw_score = acc_list[0]
                if lfw_score > highest_acc[0]:
                    highest_acc[0] = lfw_score
                    if lfw_score >= 0.998:
                        do_save = True
                if acc_list[-1] >= highest_acc[-1]:
                    highest_acc[-1] = acc_list[-1]
                    if lfw_score >= 0.99:
                        do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt > 1:
                do_save = True
            if do_save:
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
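        # anneal the margin parameter: hold args.beta for beta_freeze batches,
        # then decay it as beta * (1 + gamma*t)^(-power) with a beta_min floor;
        # the loss layer reads the current value back via the BETA env var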
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = max(0, mbatch - args.beta_freeze)
            _beta = max(
                args.beta_min,
                args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
        #print('beta', _beta)
        os.environ['BETA'] = str(_beta)
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    epoch_cb = None
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    model.fit(
        train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        eval_data=val_dataiter,
        eval_metric=eval_metrics,
        kvstore='device',
        optimizer=opt,
        #optimizer_params   = optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=epoch_cb)
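
A quick way to sanity-check the inverse-decay schedule that _batch_callback
drives through the BETA environment variable is to evaluate it standalone.
A minimal sketch; the hyperparameter values below are illustrative
placeholders, not values taken from this example:

import math

def beta_at(mbatch, beta=1000.0, beta_min=5.0, gamma=0.12, power=1.0,
            beta_freeze=0):
    # mirrors _batch_callback: hold beta for beta_freeze batches, then
    # decay as beta * (1 + gamma*t)^(-power), floored at beta_min
    if mbatch <= beta_freeze:
        return beta
    move = max(0, mbatch - beta_freeze)
    return max(beta_min, beta * math.pow(1 + gamma * move, -1.0 * power))

for step in (0, 1000, 10000, 100000):
    print(step, beta_at(step))
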
Example #31
0
def train_net(args):
    ctx = []
    #ctx.append(mx.gpu(1)) #manual
    cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        logging.info('use cpu')
    else:
        logging.info('gpu num: %d', len(ctx))
    prefix = os.path.join(args.models_root,
                          '%s-%s-%s' % (args.network, args.loss, args.dataset),
                          'model')
    prefix_dir = os.path.dirname(prefix)
    logging.info('prefix %s', prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    args.ctx_num = len(ctx)
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = config.image_shape[2]
    config.batch_size = args.batch_size
    config.per_batch_size = args.per_batch_size

    data_dir = config.dataset_path
    path_imgrec = None
    path_imglist = None
    image_size = config.image_shape[0:2]
    assert len(image_size) == 2
    assert image_size[0] == image_size[1]
    logging.info('image_size %s', str(image_size))
    logging.info('num_classes %d', config.num_classes)
    path_imgrec = os.path.join(data_dir, "train.rec")

    logging.info('Called with argument: %s %s', str(args), str(config))
    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = args.pretrained_epoch
    if len(args.pretrained) == 0:  #no pretraining
        arg_params = None
        aux_params = None
        sym = get_symbol(args)
        if config.net_name == 'spherenet':
            data_shape_dict = {'data': (args.per_batch_size, ) + data_shape}
            spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:  #load pretrained model
        logging.info('loading %s %s', str(args.pretrained),
                     str(args.pretrained_epoch))
        _, arg_params, aux_params = mx.model.load_checkpoint(
            args.pretrained, args.pretrained_epoch)
        sym = get_symbol(args)

    if config.count_flops:
        all_layers = sym.get_internals()
        _sym = all_layers['fc1_output']
        FLOPs = flops_counter.count_flops(_sym,
                                          data=(1, 3, image_size[0],
                                                image_size[1]))
        _str = flops_counter.flops_str(FLOPs)
        logging.info('Network FLOPs: %s' % _str)

    #label_name = 'softmax_label'
    #label_shape = (args.batch_size,)
    model = mx.mod.Module(  # bind the full training symbol, loss layers included, into an executable module
        context=ctx,
        symbol=sym,
        #fixed_param_names = ["conv_1_conv2d_weight","res_2_block0_conv_sep_conv2d_weight","res_2_block0_conv_dw_conv2d_weight","res_2_block0_conv_proj_conv2d_weight","res_2_block1_conv_sep_conv2d_weight","res_2_block1_conv_dw_conv2d_weight","res_2_block1_conv_proj_conv2d_weight","dconv_23_conv_sep_conv2d_weight","dconv_23_conv_dw_conv2d_weight","dconv_23_conv_proj_conv2d_weight","res_3_block0_conv_sep_conv2d_weight","res_3_block0_conv_dw_conv2d_weight","res_3_block0_conv_proj_conv2d_weight","res_3_block1_conv_sep_conv2d_weight","res_3_block1_conv_dw_conv2d_weight","res_3_block1_conv_proj_conv2d_weight","res_3_block2_conv_sep_conv2d_weight","res_3_block2_conv_dw_conv2d_weight","res_3_block2_conv_proj_conv2d_weight","res_3_block3_conv_sep_conv2d_weight","res_3_block3_conv_dw_conv2d_weight","res_3_block3_conv_proj_conv2d_weight","res_3_block4_conv_sep_conv2d_weight","res_3_block4_conv_dw_conv2d_weight","res_3_block4_conv_proj_conv2d_weight","res_3_block5_conv_sep_conv2d_weight","res_3_block5_conv_dw_conv2d_weight","res_3_block5_conv_proj_conv2d_weight","res_3_block6_conv_sep_conv2d_weight","res_3_block6_conv_dw_conv2d_weight","res_3_block6_conv_proj_conv2d_weight","res_3_block7_conv_sep_conv2d_weight","res_3_block7_conv_dw_conv2d_weight","res_3_block7_conv_proj_conv2d_weight","dconv_34_conv_sep_conv2d_weight","dconv_34_conv_dw_conv2d_weight","dconv_34_conv_proj_conv2d_weight","res_4_block0_conv_sep_conv2d_weight","res_4_block0_conv_dw_conv2d_weight","res_4_block0_conv_proj_conv2d_weight","res_4_block1_conv_sep_conv2d_weight","res_4_block1_conv_dw_conv2d_weight","res_4_block1_conv_proj_conv2d_weight","res_4_block2_conv_sep_conv2d_weight","res_4_block2_conv_dw_conv2d_weight","res_4_block2_conv_proj_conv2d_weight","res_4_block3_conv_sep_conv2d_weight","res_4_block3_conv_dw_conv2d_weight","res_4_block3_conv_proj_conv2d_weight","res_4_block4_conv_sep_conv2d_weight","res_4_block4_conv_dw_conv2d_weight","res_4_block4_conv_proj_conv2d_weight","res_4_block5_conv_sep_conv2d_weight","res_4_block5_conv_dw_conv2d_weight","res_4_block5_conv_proj_conv2d_weight","res_4_block6_conv_sep_conv2d_weight","res_4_block6_conv_dw_conv2d_weight","res_4_block6_conv_proj_conv2d_weight","res_4_block7_conv_sep_conv2d_weight","res_4_block7_conv_dw_conv2d_weight","res_4_block7_conv_proj_conv2d_weight","res_4_block8_conv_sep_conv2d_weight","res_4_block8_conv_dw_conv2d_weight","res_4_block8_conv_proj_conv2d_weight","res_4_block9_conv_sep_conv2d_weight","res_4_block9_conv_dw_conv2d_weight","res_4_block9_conv_proj_conv2d_weight","res_4_block10_conv_sep_conv2d_weight","res_4_block10_conv_dw_conv2d_weight","res_4_block10_conv_proj_conv2d_weight","res_4_block11_conv_sep_conv2d_weight","res_4_block11_conv_dw_conv2d_weight","res_4_block11_conv_proj_conv2d_weight","res_4_block12_conv_sep_conv2d_weight","res_4_block12_conv_dw_conv2d_weight","res_4_block12_conv_proj_conv2d_weight","res_4_block13_conv_sep_conv2d_weight","res_4_block13_conv_dw_conv2d_weight","res_4_block13_conv_proj_conv2d_weight","res_4_block14_conv_sep_conv2d_weight","res_4_block14_conv_dw_conv2d_weight","res_4_block14_conv_proj_conv2d_weight","res_4_block15_conv_sep_conv2d_weight","res_4_block15_conv_dw_conv2d_weight","res_4_block15_conv_proj_conv2d_weight","dconv_45_conv_sep_conv2d_weight","dconv_45_conv_dw_conv2d_weight","dconv_45_conv_proj_conv2d_weight"],
        #fixed_param_names = ['convolution'+str(i)+'_weight' for i in range(1,40)],
    )
    val_dataiter = None

    if config.loss_name.find('triplet') >= 0:  #if triplet or atriplet loss
        from triplet_image_iter import FaceImageIter
        triplet_params = [
            config.triplet_bag_size, config.triplet_alpha,
            config.triplet_max_ap
        ]
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=config.data_rand_mirror,
            mean=mean,
            cutoff=config.data_cutoff,
            ctx_num=args.ctx_num,
            images_per_identity=config.images_per_identity,
            triplet_params=triplet_params,
            mx_model=model,
        )
        _metric = LossValueMetric()
        eval_metrics = [mx.metric.create(_metric)]
    else:
        from image_iter import FaceImageIter
        train_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=path_imgrec,
            shuffle=True,
            rand_mirror=config.data_rand_mirror,
            mean=mean,
            cutoff=config.data_cutoff,
            color_jittering=config.data_color,
            images_filter=config.data_images_filter,
        )
        metric1 = AccMetric()
        eval_metrics = [mx.metric.create(metric1)]
        if config.ce_loss:
            metric2 = LossValueMetric()
            eval_metrics.append(mx.metric.create(metric2))
    gaussian_nets = [
        'fresnet', 'fmobilefacenet', 'fsqueezefacenet_v1',
        'fsqueezefacenet_v2', 'fshufflefacenetv2', 'fshufflenetv1',
        'fsqueezenet1_0', 'fsqueezenet1_1', 'fsqueezenet1_2',
        'fsqueezenet1_1_no_pool', 'fmobilenetv2', 'fmobilefacenetv1',
        'fmobilenetv2_mxnet', 'vargfacenet', 'mobilenetv3'
    ]
    if config.net_name in gaussian_nets:
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
        print("GAUSSIAN INITIALIZER")
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
        print("UNIFORM INITIALIZER")
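    # rescale_grad=1/ctx_num averages the gradients summed across devices so
    # the effective update matches single-device training at the same lr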
    _rescale = 1.0 / args.ctx_num
    opt = optimizer.SGD(learning_rate=args.lr,
                        momentum=args.mom,
                        wd=args.wd,
                        rescale_grad=_rescale)
    #opt = optimizer.Adam(learning_rate=args.lr,wd=args.wd,rescale_grad=_rescale)
    _cb = mx.callback.Speedometer(args.batch_size, args.frequent)

    ver_list = []
    ver_name_list = []
    for name in config.val_targets:
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            logging.info('ver %s', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model, args.batch_size, 10, None, None)
            logging.info('[%s][%d]XNorm: %f' %
                         (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                         (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0]  #lfw and target
    highest_train_acc = [0.0, 0.0]
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    lr_steps = [int(x) for x in args.lr_steps.split(',')]
    logging.info('lr_steps %s', str(lr_steps))

    def _batch_callback(param):
        #global global_step
        #weights_db = model.get_params()[0]['fire2_act_squeeze_1x1'].asnumpy()

        #weights_db = model.get_params()[0]
        #print(str(weights_db.keys()))

        #for k, v in weights_db.items():
        #print(k)
        #     if(np.any(np.isnan(v.asnumpy())) or np.any(np.isinf(v.asnumpy()))):
        #         print("nan or inf weight found at "+k)

        #name_value = param.eval_metric.get_name_value()
        #for name, value in name_value:
        #    logging.info('Epoch[%d] Validation-%s=%f', param.epoch, name, value)
        name_value = param.eval_metric.get_name_value()
        train_acc = name_value[0]
        # only one metric is registered on the triplet path, so guard the
        # loss lookup instead of indexing [1] unconditionally
        loss = name_value[1] if len(name_value) > 1 else None
        #print(loss)
        #if (np.isnan(loss[1])):
        #    print("Nan loss found")
        #    f = open("nan_loss_weights.txt", "w")
        #    f.write("batch #"+str(global_step[0])+"\n"+str(loss)+"\n"+str(weights_db.keys())+"\n"+str(weights_db))
        #    f.close()
        #    print("Written file at: nan_loss_weights.txt")
        #    exit()

        global_step[0] += 1
        mbatch = global_step[0]
        for step in lr_steps:
            if mbatch == step:
                opt.lr *= 0.1
                logging.info('lr change to %f', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            logging.info('lr-batch-epoch: %f %d %d', opt.lr, param.nbatch,
                         param.epoch)

        if mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            is_highest = False
            if len(acc_list) > 0:
                #lfw_score = acc_list[0]
                #if lfw_score>highest_acc[0]:
                #  highest_acc[0] = lfw_score
                #  if lfw_score>=0.998:
                #    do_save = True
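                # take a new best when the last target improves outright, or
                # ties while the summed score over all targets improves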
                score = sum(acc_list)
                if acc_list[-1] >= highest_acc[-1]:
                    if acc_list[-1] > highest_acc[-1]:
                        is_highest = True
                    else:
                        if score >= highest_acc[0]:
                            is_highest = True
                            highest_acc[0] = score
                    highest_acc[-1] = acc_list[-1]
                    #if lfw_score>=0.99:
                    #  do_save = True
            if is_highest:
                do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt == 2:
                do_save = True
            elif args.ckpt == 3:
                msave = 1

            if do_save:
                logging.info('saving %d', msave)
                arg, aux = model.get_params()
                if config.ckpt_embedding:
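                    # keep only the backbone up to the fc1 embedding output;
                    # fc7 holds the per-class weights and is dropped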
                    all_layers = model.symbol.get_internals()
                    _sym = all_layers['fc1_output']
                    _arg = {}
                    for k in arg:
                        if not k.startswith('fc7'):
                            _arg[k] = arg[k]
                    mx.model.save_checkpoint(prefix, param.epoch + 1, _sym,
                                             _arg, aux)
                else:
                    mx.model.save_checkpoint(prefix, param.epoch + 1,
                                             model.symbol, arg, aux)
            logging.info('[%d]Accuracy-Highest: %1.5f' %
                         (mbatch, highest_acc[-1]))
        if config.max_steps > 0 and mbatch > config.max_steps:
            sys.exit(0)

    epoch_cb = None
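    # prefetch batches on a background thread so data loading overlaps compute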
    train_dataiter = mx.io.PrefetchingIter(train_dataiter)

    model.fit(
        train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=999999,
        eval_data=val_dataiter,
        eval_metric=eval_metrics,
        kvstore=args.kvstore,
        optimizer=opt,
        #optimizer_params   = optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=epoch_cb)
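
The config.ckpt_embedding branch above can also be applied after the fact to a
checkpoint saved with its classification head intact. A minimal sketch using
MXNet's standard checkpoint API; the paths and epoch number are hypothetical:

import mxnet as mx

def strip_to_embedding(prefix, epoch, out_prefix):
    # reload a saved checkpoint, cut the graph at the fc1 embedding output,
    # and drop the fc7 class weights, which depend on num_classes and are
    # not needed at inference time
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    emb_sym = sym.get_internals()['fc1_output']
    arg_params = {k: v for k, v in arg_params.items()
                  if not k.startswith('fc7')}
    mx.model.save_checkpoint(out_prefix, 0, emb_sym, arg_params, aux_params)

#strip_to_embedding('./models/r100-softmax-emore/model', 1,
#                   './models/r100-softmax-emore/model-emb')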