def validate(val_data, val_dataset, net, ctx, opt):
    """Run the network over the validation loader and return the metric result.

    Parameters
    ----------
    val_data : iterable
        Validation batches; each batch is handed to ``val_batch_fn``.
    val_dataset : object
        Dataset passed to ``COCOKeyPointsMetric``; its ``joint_pairs``
        attribute is used when the flip test is enabled.
    net : callable
        Model applied to every data slice returned by ``val_batch_fn``.
    ctx : mx.Context or list
        A single ``mx.Context`` is wrapped into a one-element list before
        being forwarded to ``val_batch_fn``.
    opt : object
        Options object; only ``opt.flip_test`` is read here.

    Returns
    -------
    Whatever ``COCOKeyPointsMetric.get()`` returns.

    Notes
    -----
    ``val_batch_fn`` is not a parameter; it is resolved from the enclosing
    (module) scope at call time.
    """
    if isinstance(ctx, mx.Context):
        ctx = [ctx]

    metric = COCOKeyPointsMetric(val_dataset, 'coco_keypoints', in_vis_thresh=0)

    for sample in tqdm(val_data, dynamic_ncols=True):
        data, scale_box, score, imgid = val_batch_fn(sample, ctx)

        heatmaps = [net(x) for x in data]

        if opt.flip_test:
            # Flip test: also predict on inputs flipped along axis 3
            # (presumably the width axis of NCHW input -- TODO confirm),
            # map those heatmaps back, and average with the plain ones.
            mirrored = [net(nd.flip(x, axis=3)) for x in data]
            heatmaps = [
                (plain + flip_heatmap(mirror, val_dataset.joint_pairs, shift=True)) / 2
                for plain, mirror in zip(heatmaps, mirrored)
            ]

        # Gather every slice's output on the CPU; concatenate along dim 0
        # only when there is more than one slice.
        on_cpu = [h.as_in_context(mx.cpu()) for h in heatmaps]
        stacked = nd.concat(*on_cpu, dim=0) if len(on_cpu) > 1 else on_cpu[0]

        preds, maxvals = heatmap_to_coord_alpha_pose(stacked, scale_box)
        metric.update(preds, maxvals, score, imgid)

    # Silence anything printed to stdout while the metric is computed, and
    # always restore the real stream afterwards.
    silencer = NullWriter()
    saved_stdout = sys.stdout
    sys.stdout = silencer
    try:
        return metric.get()
    finally:
        sys.stdout = saved_stdout
transform_val = AlphaPoseDefaultValTransform(num_joints=val_dataset.num_joints, joint_pairs=val_dataset.joint_pairs, image_size=input_size) val_data = gluon.data.DataLoader( val_dataset.transform(transform_val), batch_size=batch_size, shuffle=False, last_batch='keep', num_workers=num_workers) return val_dataset, val_data, val_batch_fn input_size = [int(i) for i in opt.input_size.split(',')] val_dataset, val_data, val_batch_fn = get_data_loader(opt.dataset, batch_size, num_workers, input_size) val_metric = COCOKeyPointsMetric(val_dataset, 'coco_keypoints', data_shape=tuple(input_size), in_vis_thresh=opt.score_threshold) use_pretrained = True if not opt.params_file else False model_name = '_'.join((opt.model, opt.dataset)) kwargs = {'ctx': context, 'pretrained': use_pretrained, 'num_gpus': num_gpus} net = get_model(model_name, **kwargs) if not use_pretrained: net.load_parameters(opt.params_file, ctx=context) net.hybridize() def validate(val_data, val_dataset, net, ctx): if isinstance(ctx, mx.Context): ctx = [ctx]