コード例 #1
0
ファイル: train.py プロジェクト: anminhhung/Face-Align
 def val_test():
     """Evaluate ``model`` on each configured validation set and print its NME.

     For every name in ``config.val_targets`` the record file
     ``<config.dataset_path>/<name>.rec`` is read through ``FaceSegIter``
     with ``aug_level=0``; targets whose file does not exist are skipped.
     One line per evaluated target is printed in the form
     ``[<global_step[0]>][<target>]NME: <value>``.

     Evaluation runs on ``model`` itself with ``is_train=False``, so no
     separate validation module is built or bound here.
     """
     for target in config.val_targets:
         rec_file = os.path.join(config.dataset_path, '%s.rec' % target)
         if not os.path.exists(rec_file):
             # Not every configured validation set has to be present on disk.
             continue
         val_iter = FaceSegIter(
             path_imgrec=rec_file,
             batch_size=args.batch_size,
             aug_level=0,
             args=args,
         )
         val_metric = mx.metric.create(NMEMetric())
         val_metric.reset()
         val_iter.reset()
         for eval_batch in val_iter:
             # Only the data goes through the network; the labels are handed
             # directly to the metric.
             batch_data = mx.io.DataBatch(eval_batch.data)
             model.forward(batch_data, is_train=False)
             model.update_metric(val_metric, eval_batch.label)
         # First (name, value) pair reported by the metric.
         nme_value = val_metric.get_name_value()[0][1]
         print('[%d][%s]NME: %f' % (global_step[0], target, nme_value))
コード例 #2
0
else:
  ctx = mx.cpu()
# Load the checkpoint and keep only the graph up to the 'heatmap_output' node.
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
all_layers = sym.get_internals()
sym = all_layers['heatmap_output']

# Inference-only module: a single 3-channel image per batch, no label inputs.
model = mx.mod.Module(
  symbol=sym,
  context=ctx,
  data_names=['data'],
  label_names=None,
)
model.bind(
  for_training=False,
  data_shapes=[('data', (1, 3, image_size[0], image_size[1]))],
)
model.set_params(arg_params, aux_params)

val_iter = FaceSegIter(path_imgrec=rec_path, batch_size=1, aug_level=0)
_metric = NMEMetric()

# One NME value per batch, obtained by calling the metric's cal_nme directly.
nme = []
for i, eval_batch in enumerate(val_iter):
  # Progress message every 10th batch.
  if not i % 10:
    print('processing', i)
  batch_data = mx.io.DataBatch(eval_batch.data)
  model.forward(batch_data, is_train=False)
  # The last module output is used as the prediction.
  pred_label = model.get_outputs()[-1].asnumpy()
  label = eval_batch.label[0].asnumpy()
  _nme = _metric.cal_nme(label, pred_label)
  nme.append(_nme)