Example #1
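# NOTE: excerpt from a larger training script; it assumes module-level imports
# such as os, sys, math, numpy as np, mxnet as mx, mxnet.optimizer as optimizer,
# and the project helpers used below (face_image, verification, spherenet,
# get_symbol, FaceImageIter, FaceImageIterList, AccMetric, LossValueMetric,
# NoiseSGD).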
def train_net(args):
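    # Build the device context list from CUDA_VISIBLE_DEVICES: one mx.gpu(i)
    # per visible device, falling back to a single CPU context when none is set.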
    ctx = []
    cvd = os.environ.get('CUDA_VISIBLE_DEVICES', '').strip()
    if len(cvd)>0:
      for i in range(len(cvd.split(','))):
        ctx.append(mx.gpu(i))
    if len(ctx)==0:
      ctx = [mx.cpu()]
      print('use cpu')
    else:
      print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
      os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size==0:
      args.per_batch_size = 128
      if args.loss_type==10:
        args.per_batch_size = 256
    args.batch_size = args.per_batch_size*args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3
    ppatch = [int(x) for x in args.patch.split('_')]
    assert len(ppatch)==5


    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    if args.loss_type!=12 and args.loss_type!=13:
      assert len(data_dir_list)==1
    data_dir = data_dir_list[0]
    args.use_val = False
    path_imgrec = None
    path_imglist = None
    val_rec = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)

    assert(args.num_classes>0)
    print('num_classes', args.num_classes)
    args.coco_scale = 0.5*math.log(float(args.num_classes-1))+3

    #path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2"
    path_imgrec = os.path.join(data_dir, "train.rec")
    val_rec = os.path.join(data_dir, "val.rec")
    if os.path.exists(val_rec) and args.loss_type<10:
      args.use_val = True
    else:
      val_rec = None
    #args.use_val = False

    if args.loss_type==1 and args.num_classes>20000:
      args.beta_freeze = 5000
      args.gamma = 0.06

    if args.loss_type<9:
      assert args.images_per_identity==0
    else:
      if args.images_per_identity==0:
        if args.loss_type==11:
          args.images_per_identity = 2
        elif args.loss_type==10 or args.loss_type==9:
          args.images_per_identity = 16
        elif args.loss_type==12 or args.loss_type==13:
          args.images_per_identity = 5
          assert args.per_batch_size%3==0
      assert args.images_per_identity>=2
      args.per_identities = int(args.per_batch_size/args.images_per_identity)

    print('Called with argument:', args)

    data_shape = (args.image_channel,image_size[0],image_size[1])
    mean = None




    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained)==0:
      arg_params = None
      aux_params = None
      sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    else:
      vec = args.pretrained.split(',')
      print('loading', vec)
      _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
      sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
    if args.network[0]=='s':
      data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}
      spherenet.init_weights(sym, data_shape_dict, args.num_layers)

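    # Loss-specific auxiliary inputs: loss_type 10 builds a block-diagonal
    # +1/-1 mask pairing images of the same identity (used for hard mining),
    # loss_type 11 builds one-hot identity indicators, loss_type 12/13 forward
    # triplet-sampling parameters, and loss_type 9 switches to coco mode.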
    data_extra = None
    hard_mining = False
    triplet_params = None
    coco_mode = False
    if args.loss_type==10:
      hard_mining = True
      _shape = (args.batch_size, args.per_batch_size)
      data_extra = np.full(_shape, -1.0, dtype=np.float32)
      c = 0
      while c<args.batch_size:
        a = 0
        while a<args.per_batch_size:
          b = a+args.images_per_identity
          data_extra[(c+a):(c+b),a:b] = 1.0
          #print(c+a, c+b, a, b)
          a = b
        c += args.per_batch_size
    elif args.loss_type==11:
      data_extra = np.zeros( (args.batch_size, args.per_identities), dtype=np.float32)
      c = 0
      while c<args.batch_size:
        for i in range(args.per_identities):
          data_extra[c+i][i] = 1.0
        c+=args.per_batch_size
    elif args.loss_type==12 or args.loss_type==13:
      triplet_params = [args.triplet_bag_size, args.triplet_alpha, args.triplet_max_ap]
    elif args.loss_type==9:
      coco_mode = True

    label_name = 'softmax_label'
    label_shape = (args.batch_size,)
    if args.output_c2c:
      label_shape = (args.batch_size,2)
    if data_extra is None:
      model = mx.mod.Module(
          context       = ctx,
          symbol        = sym,
      )
    else:
      data_names = ('data', 'extra')
      #label_name = ''
      model = mx.mod.Module(
          context       = ctx,
          symbol        = sym,
          data_names    = data_names,
          label_names   = (label_name,),
      )

    if args.use_val:
      val_dataiter = FaceImageIter(
          batch_size           = args.batch_size,
          data_shape           = data_shape,
          path_imgrec          = val_rec,
          #path_imglist         = val_path,
          shuffle              = False,
          rand_mirror          = False,
          mean                 = mean,
          ctx_num              = args.ctx_num,
          data_extra           = data_extra,
      )
    else:
      val_dataiter = None

    if len(data_dir_list)==1 and args.loss_type!=12 and args.loss_type!=13:
      train_dataiter = FaceImageIter(
          batch_size           = args.batch_size,
          data_shape           = data_shape,
          path_imgrec          = path_imgrec,
          shuffle              = True,
          rand_mirror          = args.rand_mirror,
          mean                 = mean,
          cutoff               = args.cutoff,
          c2c_threshold        = args.c2c_threshold,
          output_c2c           = args.output_c2c,
          c2c_mode             = args.c2c_mode,
          ctx_num              = args.ctx_num,
          images_per_identity  = args.images_per_identity,
          data_extra           = data_extra,
          hard_mining          = hard_mining,
          triplet_params       = triplet_params,
          coco_mode            = coco_mode,
          mx_model             = model,
          label_name           = label_name,
      )
    else:
      iter_list = []
      for _data_dir in data_dir_list:
        _path_imgrec = os.path.join(_data_dir, "train.rec")
        _dataiter = FaceImageIter(
            batch_size           = args.batch_size,
            data_shape           = data_shape,
            path_imgrec          = _path_imgrec,
            shuffle              = True,
            rand_mirror          = args.rand_mirror,
            mean                 = mean,
            cutoff               = args.cutoff,
            c2c_threshold        = args.c2c_threshold,
            output_c2c           = args.output_c2c,
            c2c_mode             = args.c2c_mode,
            ctx_num              = args.ctx_num,
            images_per_identity  = args.images_per_identity,
            data_extra           = data_extra,
            hard_mining          = hard_mining,
            triplet_params       = triplet_params,
            coco_mode            = coco_mode,
            mx_model             = model,
            label_name           = label_name,
        )
        iter_list.append(_dataiter)
      train_dataiter = FaceImageIterList(iter_list)

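    # Softmax-style losses (loss_type < 10) report classification accuracy;
    # the metric-learning losses only expose a scalar loss value.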
    if args.loss_type<10:
      _metric = AccMetric()
    else:
      _metric = LossValueMetric()
    eval_metrics = [mx.metric.create(_metric)]

    if args.network[0]=='r':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2) #resnet style
    elif args.network[0]=='i' or args.network[0]=='x':
      initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2) #inception
    else:
      initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
    _rescale = 1.0/args.ctx_num
    if args.noise_sgd>0.0:
      print('use noise sgd')
      opt = NoiseSGD(scale = args.noise_sgd, learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
    else:
      opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
    som = 20
    if args.loss_type==12 or args.loss_type==13:
      som = 2
    _cb = mx.callback.Speedometer(args.batch_size, som)

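    # Load any verification benchmark sets (e.g. lfw.bin) that sit next to the
    # training data; they are evaluated periodically during training.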
    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
      path = os.path.join(data_dir,name+".bin")
      if os.path.exists(path):
        data_set = verification.load_bin(path, image_size)
        ver_list.append(data_set)
        ver_name_list.append(name)
        print('ver', name)



    def ver_test(nbatch):
      results = []
      for i in range(len(ver_list)):
        acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(ver_list[i], model, args.batch_size, data_extra, label_shape)
        print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
        #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
        print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc2, std2))
        results.append(acc2)
      return results


    def val_test():
      acc = AccMetric()
      val_metric = mx.metric.create(acc)
      val_metric.reset()
      val_dataiter.reset()
      for i, eval_batch in enumerate(val_dataiter):
        model.forward(eval_batch, is_train=False)
        model.update_metric(val_metric, eval_batch.label)
      acc_value = val_metric.get_name_value()[0][1]
      print('VACC: %f'%(acc_value))


    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps)==0:
      lr_steps = [40000, 60000, 80000]
      if args.loss_type>=1 and args.loss_type<=7:
        lr_steps = [100000, 140000, 160000]
      p = 512.0/args.batch_size
      for l in range(len(lr_steps)):
        lr_steps[l] = int(lr_steps[l]*p)
    else:
      lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)
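    # Per-batch callback: applies the step-decay LR schedule, runs periodic
    # verification, saves checkpoints on new best accuracy, and anneals the
    # BETA environment variable read by the margin-based losses.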
    def _batch_callback(param):
      #global global_step
      global_step[0]+=1
      mbatch = global_step[0]
      for _lr in lr_steps:
        if mbatch==args.beta_freeze+_lr:
          opt.lr *= 0.1
          print('lr change to', opt.lr)
          break

      _cb(param)
      if mbatch%1000==0:
        print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)

      if mbatch>=0 and mbatch%args.verbose==0:
        acc_list = ver_test(mbatch)
        save_step[0]+=1
        msave = save_step[0]
        do_save = False
        if len(acc_list)>0:
          lfw_score = acc_list[0]
          if lfw_score>highest_acc[0]:
            highest_acc[0] = lfw_score
            if lfw_score>=0.998:
              do_save = True
          if acc_list[-1]>=highest_acc[-1]:
            highest_acc[-1] = acc_list[-1]
            if lfw_score>=0.99:
              do_save = True
        if args.ckpt==0:
          do_save = False
        elif args.ckpt>1:
          do_save = True
        #for i in xrange(len(acc_list)):
        #  acc = acc_list[i]
        #  if acc>=highest_acc[i]:
        #    highest_acc[i] = acc
        #    if lfw_score>=0.99:
        #      do_save = True
        #if args.loss_type==1 and mbatch>lr_steps[-1] and mbatch%10000==0:
        #  do_save = True
        if do_save:
          print('saving', msave)
          if val_dataiter is not None:
            val_test()
          arg, aux = model.get_params()
          mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
          #if acc>=highest_acc[0]:
          #  lfw_npy = "%s-lfw-%04d" % (prefix, msave)
          #  X = np.concatenate(embeddings_list, axis=0)
          #  print('saving lfw npy', X.shape)
          #  np.save(lfw_npy, X)
        print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
      if mbatch<=args.beta_freeze:
        _beta = args.beta
      else:
        move = max(0, mbatch-args.beta_freeze)
        _beta = max(args.beta_min, args.beta*math.pow(1+args.gamma*move, -1.0*args.power))
      #print('beta', _beta)
      os.environ['BETA'] = str(_beta)
      if args.max_steps>0 and mbatch>args.max_steps:
        sys.exit(0)

    #epoch_cb = mx.callback.do_checkpoint(prefix, 1)
    epoch_cb = None



    #def _epoch_callback(epoch, sym, arg_params, aux_params):
    #  print('epoch-end', epoch)

    model.fit(train_dataiter,
        begin_epoch        = begin_epoch,
        num_epoch          = end_epoch,
        eval_data          = val_dataiter,
        eval_metric        = eval_metrics,
        kvstore            = 'device',
        optimizer          = opt,
        #optimizer_params   = optimizer_params,
        initializer        = initializer,
        arg_params         = arg_params,
        aux_params         = aux_params,
        allow_missing      = True,
        batch_end_callback = _batch_callback,
        epoch_end_callback = epoch_cb )
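
# --- Usage sketch (not part of the original listing) ---
# A minimal, hypothetical invocation assuming an argparse-style namespace that
# carries the fields this excerpt reads; helpers like get_symbol may require
# more. Field names mirror the code above; the values are placeholders, not
# tuned recommendations.
if __name__ == '__main__':
    import argparse
    args = argparse.Namespace(
        prefix='./models/model', network='r50', data_dir='./faces_dataset',
        loss_type=0, per_batch_size=128, patch='0_0_96_112_0',
        beta=1000.0, beta_min=5.0, beta_freeze=0, gamma=0.12, power=1.0,
        images_per_identity=0, lr=0.1, wd=0.0005, mom=0.9, pretrained='',
        noise_sgd=0.0, rand_mirror=1, cutoff=0, c2c_threshold=0.0,
        c2c_mode=-1, output_c2c=0, target='lfw', verbose=2000, ckpt=1,
        lr_steps='', max_steps=0, end_epoch=100000)
    train_net(args)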
Example #2
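# NOTE: a triplet-loss variant of the routine above, relying on the same
# module-level imports. When a pretrained checkpoint is given, its fc1_output
# layer is taken as the embedding and the triplet loss is built on top of it.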
def train_net(args):
    ctx = []
    cvd = os.environ.get('CUDA_VISIBLE_DEVICES', '').strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.image_channel = 3

    data_dir_list = args.data_dir.split(',')
    assert len(data_dir_list) == 1
    data_dir = data_dir_list[0]
    path_imgrec = None
    path_imglist = None
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)

    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)

    #path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2"
    path_imgrec = os.path.join(data_dir, "train.rec")

    assert args.images_per_identity >= 2
    assert args.triplet_bag_size % args.batch_size == 0

    print('Called with argument:', args)

    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    if len(args.pretrained) == 0:
        arg_params = None
        aux_params = None
        sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
        if args.network[0] == 's':
            data_shape_dict = {'data': (args.per_batch_size, ) + data_shape}
            spherenet.init_weights(sym, data_shape_dict, args.num_layers)
    else:
        vec = args.pretrained.split(',')
        print('loading', vec)
        sym, arg_params, aux_params = mx.model.load_checkpoint(
            vec[0], int(vec[1]))
        all_layers = sym.get_internals()
        sym = all_layers['fc1_output']
        sym, arg_params, aux_params = get_symbol(args,
                                                 arg_params,
                                                 aux_params,
                                                 sym_embedding=sym)

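    # Triplet sampling is always active in this variant: the bag size, the
    # margin (triplet_alpha), and the anchor-positive cap (triplet_max_ap,
    # as suggested by its name) are handed to the data iterator below.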
    data_extra = None
    hard_mining = False
    triplet_params = [
        args.triplet_bag_size, args.triplet_alpha, args.triplet_max_ap
    ]
    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
        #data_names = ('data',),
        #label_names = None,
        #label_names = ('softmax_label',),
    )
    label_shape = (args.batch_size, )

    val_dataiter = None

    train_dataiter = FaceImageIter(
        batch_size=args.batch_size,
        data_shape=data_shape,
        path_imgrec=path_imgrec,
        shuffle=True,
        rand_mirror=args.rand_mirror,
        mean=mean,
        cutoff=args.cutoff,
        ctx_num=args.ctx_num,
        images_per_identity=args.images_per_identity,
        triplet_params=triplet_params,
        mx_model=model,
    )

    _metric = LossValueMetric()
    eval_metrics = [mx.metric.create(_metric)]

    if args.network[0] == 'r':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="out",
                                     magnitude=2)  #resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian',
                                     factor_type="in",
                                     magnitude=2)  #inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform',
                                     factor_type="in",
                                     magnitude=2)
    _rescale = 1.0 / args.ctx_num
    if args.noise_sgd > 0.0:
        print('use noise sgd')
        opt = NoiseSGD(scale=args.noise_sgd,
                       learning_rate=base_lr,
                       momentum=base_mom,
                       wd=base_wd,
                       rescale_grad=_rescale)
    else:
        opt = optimizer.SGD(learning_rate=base_lr,
                            momentum=base_mom,
                            wd=base_wd,
                            rescale_grad=_rescale)
    som = 2
    _cb = mx.callback.Speedometer(args.batch_size, som)

    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        results = []
        for i in range(len(ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                ver_list[i], model, args.batch_size, 10, None, label_shape)
            print('[%s][%d]XNorm: %f' % (ver_name_list[i], nbatch, xnorm))
            #print('[%s][%d]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], nbatch, acc1, std1))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' %
                  (ver_name_list[i], nbatch, acc2, std2))
            results.append(acc2)
        return results

    highest_acc = [0.0, 0.0]  #lfw and target
    #for i in xrange(len(ver_list)):
    #  highest_acc.append(0.0)
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [1000000000]
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        #global global_step
        global_step[0] += 1
        mbatch = global_step[0]
        for _lr in lr_steps:
            if mbatch == _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            if len(acc_list) > 0:
                lfw_score = acc_list[0]
                if lfw_score > highest_acc[0]:
                    highest_acc[0] = lfw_score
                    if lfw_score >= 0.998:
                        do_save = True
                if acc_list[-1] >= highest_acc[-1]:
                    highest_acc[-1] = acc_list[-1]
                    if lfw_score >= 0.99:
                        do_save = True
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt > 1:
                do_save = True
            #for i in xrange(len(acc_list)):
            #  acc = acc_list[i]
            #  if acc>=highest_acc[i]:
            #    highest_acc[i] = acc
            #    if lfw_score>=0.99:
            #      do_save = True
            #if args.loss_type==1 and mbatch>lr_steps[-1] and mbatch%10000==0:
            #  do_save = True
            if do_save:
                print('saving', msave)
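                # val_dataiter is always None in this variant, so this branch
                # (and the val_test() helper from Example #1) is never reached.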
                if val_dataiter is not None:
                    val_test()
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    #epoch_cb = mx.callback.do_checkpoint(prefix, 1)
    epoch_cb = None

    model.fit(
        train_dataiter,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        eval_data=val_dataiter,
        eval_metric=eval_metrics,
        kvstore='device',
        optimizer=opt,
        #optimizer_params   = optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=epoch_cb)
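
# --- Usage sketch (not part of the original listing) ---
# A hypothetical invocation of the triplet variant. Values assume a single
# visible GPU, so batch_size = per_batch_size and triplet_bag_size stays a
# multiple of it; the pretrained prefix/epoch are placeholders. Helpers such
# as get_symbol may require additional fields beyond this excerpt.
if __name__ == '__main__':
    import argparse
    args = argparse.Namespace(
        prefix='./models/triplet', network='r50', data_dir='./faces_dataset',
        per_batch_size=120, images_per_identity=5, triplet_bag_size=3600,
        triplet_alpha=0.3, triplet_max_ap=0.0, lr=0.05, wd=0.0005, mom=0.9,
        pretrained='./models/model,100', noise_sgd=0.0, rand_mirror=1,
        cutoff=0, target='lfw', verbose=2000, ckpt=1, lr_steps='',
        max_steps=0, end_epoch=100000)
    train_net(args)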