示例#1
0
def validate(val_data, val_dataset, net, ctx, opt):
    """Evaluate ``net`` on ``val_data`` and return the keypoint metric result.

    Parameters
    ----------
    val_data : iterable
        Batches that ``val_batch_fn`` splits into per-device inputs.
    val_dataset :
        Dataset given to ``COCOKeyPointsMetric``; its ``joint_pairs`` are
        used by ``flip_heatmap`` when flip testing is enabled.
    net :
        Network called on each per-device data slice.
    ctx : mx.Context or list of mx.Context
        Device(s) to evaluate on; a single context is wrapped in a list.
    opt :
        Options object; only ``opt.flip_test`` is read here.

    Returns
    -------
    The value returned by ``COCOKeyPointsMetric.get()``.
    """
    import contextlib

    if isinstance(ctx, mx.Context):
        ctx = [ctx]

    val_metric = COCOKeyPointsMetric(val_dataset,
                                     'coco_keypoints',
                                     in_vis_thresh=0)

    for batch in tqdm(val_data, dynamic_ncols=True):
        data, scale_box, score, imgid = val_batch_fn(batch, ctx)

        outputs = [net(X) for X in data]
        if opt.flip_test:
            # Also run the input flipped along axis 3 (presumably the width
            # axis of NCHW data -- confirm), flip those outputs back with
            # flip_heatmap, and average them with the unflipped outputs.
            data_flip = [nd.flip(X, axis=3) for X in data]
            outputs_flip = [net(X) for X in data_flip]
            outputs_flipback = [
                flip_heatmap(o, val_dataset.joint_pairs, shift=True)
                for o in outputs_flip
            ]
            outputs = [(o + o_flip) / 2
                       for o, o_flip in zip(outputs, outputs_flipback)]

        # Move the per-device outputs to the CPU and join them along dim 0.
        if len(outputs) > 1:
            outputs_stack = nd.concat(
                *[o.as_in_context(mx.cpu()) for o in outputs], dim=0)
        else:
            outputs_stack = outputs[0].as_in_context(mx.cpu())

        preds, maxvals = heatmap_to_coord_alpha_pose(outputs_stack, scale_box)
        val_metric.update(preds, maxvals, score, imgid)

    # Suppress anything written to stdout while the metric is computed;
    # redirect_stdout restores the previous sys.stdout even if get() raises.
    with contextlib.redirect_stdout(NullWriter()):
        return val_metric.get()
示例#2
0
    transform_val = AlphaPoseDefaultValTransform(num_joints=val_dataset.num_joints,
                                                 joint_pairs=val_dataset.joint_pairs,
                                                 image_size=input_size)
    val_data = gluon.data.DataLoader(
        val_dataset.transform(transform_val),
        batch_size=batch_size, shuffle=False, last_batch='keep',
        num_workers=num_workers)

    return val_dataset, val_data, val_batch_fn

# Parse the comma-separated input size option into a list of ints.
input_size = [int(i) for i in opt.input_size.split(',')]
val_dataset, val_data, val_batch_fn = get_data_loader(opt.dataset, batch_size,
                                                      num_workers, input_size)
val_metric = COCOKeyPointsMetric(val_dataset, 'coco_keypoints',
                                 data_shape=tuple(input_size),
                                 in_vis_thresh=opt.score_threshold)

# Request pretrained weights only when no params file was supplied; otherwise
# the parameters are loaded from opt.params_file after the model is built.
use_pretrained = not opt.params_file
model_name = '_'.join((opt.model, opt.dataset))
kwargs = {'ctx': context,
          'pretrained': use_pretrained,
          'num_gpus': num_gpus}
net = get_model(model_name, **kwargs)
if not use_pretrained:
    net.load_parameters(opt.params_file, ctx=context)
net.hybridize()

def validate(val_data, val_dataset, net, ctx):
    if isinstance(ctx, mx.Context):
        ctx = [ctx]