def validate(val_data, val_dataset, net, ctx):
    """Run ``net`` over ``val_data`` and evaluate the predictions with ``val_metric``.

    Output heatmaps are decoded with ``heatmap_to_coord_alpha_pose`` using the
    per-sample ``scale_box`` returned by ``val_batch_fn``. When ``opt.flip_test``
    is set, the network is also run on inputs flipped along axis 3. Those outputs
    are mapped back with ``flip_heatmap`` and averaged with the un-flipped
    outputs.

    Relies on names that are not defined in this function: ``opt``,
    ``val_metric``, ``val_batch_fn``, ``flip_heatmap``,
    ``heatmap_to_coord_alpha_pose``, ``mx`` and ``nd``.

    Parameters
    ----------
    val_data : iterable
        Validation batches; each one is unpacked by ``val_batch_fn``.
    val_dataset
        Dataset whose ``joint_pairs`` are passed to ``flip_heatmap``.
    net : callable
        Model applied to every data shard.
    ctx : mx.Context or list of mx.Context
        Device(s) forwarded to ``val_batch_fn``; a single context is wrapped
        in a list.

    Returns
    -------
    object
        Whatever ``val_metric.get()`` returns after all batches are processed.
        (Previously this value was computed and then discarded.)
    """
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    val_metric.reset()
    from tqdm import tqdm
    for batch in tqdm(val_data):
        data, scale_box, score, imgid = val_batch_fn(batch, ctx)
        outputs = [net(X) for X in data]
        if opt.flip_test:
            # Axis 3 is presumably the width axis of NCHW input -- TODO confirm.
            data_flip = [nd.flip(X, axis=3) for X in data]
            outputs_flip = [net(X) for X in data_flip]
            outputs_flipback = [
                flip_heatmap(o, val_dataset.joint_pairs, shift=True) for o in outputs_flip
            ]
            outputs = [(o + o_flip) / 2 for o, o_flip in zip(outputs, outputs_flipback)]
        # Gather the per-device outputs on the CPU before decoding them.
        if len(outputs) > 1:
            outputs_stack = nd.concat(*[o.as_in_context(mx.cpu()) for o in outputs], dim=0)
        else:
            outputs_stack = outputs[0].as_in_context(mx.cpu())
        preds, maxvals = heatmap_to_coord_alpha_pose(outputs_stack, scale_box)
        val_metric.update(preds, maxvals, score, imgid)
    return val_metric.get()
def validate(val_data, val_dataset, net, ctx):
    """Run ``net`` over ``val_data`` and evaluate the predictions with ``val_metric``.

    Output heatmaps are decoded with ``get_final_preds``, using the per-sample
    ``center`` and ``scale`` returned by ``val_batch_fn`` (converted with
    ``.asnumpy()``). When ``opt.flip_test`` is set, the network is also run on
    inputs flipped along axis 3. Those outputs are mapped back with
    ``flip_heatmap`` and averaged with the un-flipped outputs.

    Relies on names that are not defined in this function: ``opt``,
    ``val_metric``, ``val_batch_fn``, ``flip_heatmap``, ``get_final_preds``,
    ``mx`` and ``nd``.

    Parameters
    ----------
    val_data : iterable
        Validation batches; each one is unpacked by ``val_batch_fn``.
    val_dataset
        Dataset whose ``joint_pairs`` are passed to ``flip_heatmap``.
    net : callable
        Model applied to every data shard.
    ctx : mx.Context or list of mx.Context
        Device(s) forwarded to ``val_batch_fn``; a single context is wrapped
        in a list.

    Returns
    -------
    object
        Whatever ``val_metric.get()`` returns after all batches are processed.
        (Previously this value was computed and then discarded.)
    """
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    val_metric.reset()
    from tqdm import tqdm
    for batch in tqdm(val_data):
        data, scale, center, score, imgid = val_batch_fn(batch, ctx)
        outputs = [net(X) for X in data]
        if opt.flip_test:
            # Axis 3 is presumably the width axis of NCHW input -- TODO confirm.
            data_flip = [nd.flip(X, axis=3) for X in data]
            outputs_flip = [net(X) for X in data_flip]
            outputs_flipback = [
                flip_heatmap(o, val_dataset.joint_pairs, shift=True) for o in outputs_flip
            ]
            outputs = [(o + o_flip) / 2 for o, o_flip in zip(outputs, outputs_flipback)]
        # Gather the per-device outputs on the CPU before decoding them.
        if len(outputs) > 1:
            outputs_stack = nd.concat(*[o.as_in_context(mx.cpu()) for o in outputs], dim=0)
        else:
            outputs_stack = outputs[0].as_in_context(mx.cpu())
        preds, maxvals = get_final_preds(outputs_stack, center.asnumpy(), scale.asnumpy())
        val_metric.update(preds, maxvals, score, imgid)
    return val_metric.get()
def validate(val_data, val_dataset, net, ctx):
    """Run ``net`` over ``val_data``, evaluate with ``val_metric`` and report the score.

    When ``opt.flip_test`` is set, the network is also run on inputs flipped
    along axis 3. Those outputs are mapped back with ``flip_heatmap`` and
    averaged with the un-flipped outputs.

    Decoding depends on ``opt.dsnt``:

    - If ``opt.dsnt`` is set, each output is passed through ``net_dsnt``. The
      result is shifted by -0.5, multiplied by the per-sample ``scale`` and
      offset by ``center``. Every joint gets a ``maxval`` of 1.
    - Otherwise, heatmaps are decoded with ``get_final_preds``.

    Relies on names that are not defined in this function: ``opt``,
    ``val_metric``, ``val_batch_fn``, ``flip_heatmap``, ``net_dsnt``,
    ``get_final_preds``, ``mx`` and ``nd``.

    Parameters
    ----------
    val_data : iterable
        Validation batches; each one is unpacked by ``val_batch_fn``.
    val_dataset
        Dataset whose ``joint_pairs`` are passed to ``flip_heatmap``.
    net : callable
        Model applied to every data shard.
    ctx : mx.Context or list of mx.Context
        Device(s) forwarded to ``val_batch_fn``; a single context is wrapped
        in a list.

    Returns
    -------
    tuple
        ``(metric_name, metric_score)`` as unpacked from ``val_metric.get()``.
        The same pair is also printed. (Previously it was only printed.)
    """
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    val_metric.reset()
    from tqdm import tqdm
    for batch in tqdm(val_data):
        data, scale, center, score, imgid = val_batch_fn(batch, ctx)
        outputs = [net(X) for X in data]
        if opt.flip_test:
            # Axis 3 is presumably the width axis of NCHW input -- TODO confirm.
            data_flip = [nd.flip(X, axis=3) for X in data]
            outputs_flip = [net(X) for X in data_flip]
            outputs_flipback = [
                flip_heatmap(o, val_dataset.joint_pairs, shift=True) for o in outputs_flip
            ]
            outputs = [(o + o_flip) / 2 for o, o_flip in zip(outputs, outputs_flipback)]
        if opt.dsnt:
            # Run each (possibly flip-averaged) output through net_dsnt and keep
            # only element 0 of what it returns.
            outputs = [net_dsnt(X)[0] for X in outputs]
        # Gather the per-device outputs on the CPU before decoding them.
        if len(outputs) > 1:
            outputs_stack = nd.concat(*[o.as_in_context(mx.cpu()) for o in outputs], dim=0)
        else:
            outputs_stack = outputs[0].as_in_context(mx.cpu())
        if opt.dsnt:
            # NOTE(review): assumes net_dsnt yields box-relative coordinates
            # centred on 0.5 -- confirm. They are mapped into image space per
            # sample via scale and center.
            preds = (
                (outputs_stack - 0.5) * scale.expand_dims(axis=1)
                + center.expand_dims(axis=1)
            )
            # No confidence is derived in this branch: every joint gets maxval 1.
            maxvals = nd.ones(preds.shape[0:2] + (1, ))
        else:
            preds, maxvals = get_final_preds(outputs_stack, center.asnumpy(), scale.asnumpy())
        val_metric.update(preds, maxvals, score, imgid)
    metric_name, metric_score = val_metric.get()
    print("Inference Completed! %s = %.4f" % (metric_name, metric_score))
    return metric_name, metric_score
def validate(val_data, val_dataset, net, ctx, opt):
    """Run ``net`` over ``val_data`` and return the COCO keypoint metric result.

    A fresh ``COCOKeyPointsMetric`` is built for ``val_dataset`` on every call.
    Output heatmaps are decoded with ``heatmap_to_coord_alpha_pose`` using the
    per-sample ``scale_box`` returned by ``val_batch_fn``. When
    ``opt.flip_test`` is set, the network is also run on inputs flipped along
    axis 3. Those outputs are mapped back with ``flip_heatmap`` and averaged
    with the un-flipped outputs.

    Relies on names that are not defined in this function: ``val_batch_fn``,
    ``flip_heatmap``, ``heatmap_to_coord_alpha_pose``,
    ``COCOKeyPointsMetric``, ``NullWriter``, ``tqdm``, ``mx`` and ``nd``.

    Parameters
    ----------
    val_data : iterable
        Validation batches; each one is unpacked by ``val_batch_fn``.
    val_dataset
        Dataset used to build the metric; its ``joint_pairs`` are passed to
        ``flip_heatmap``.
    net : callable
        Model applied to every data shard.
    ctx : mx.Context or list of mx.Context
        Device(s) forwarded to ``val_batch_fn``; a single context is wrapped
        in a list.
    opt
        Options object; only ``opt.flip_test`` is read here.

    Returns
    -------
    object
        Whatever ``val_metric.get()`` returns. Anything it writes to stdout is
        discarded.
    """
    import contextlib

    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    val_metric = COCOKeyPointsMetric(val_dataset, 'coco_keypoints', in_vis_thresh=0)
    for batch in tqdm(val_data, dynamic_ncols=True):
        data, scale_box, score, imgid = val_batch_fn(batch, ctx)
        outputs = [net(X) for X in data]
        if opt.flip_test:
            # Axis 3 is presumably the width axis of NCHW input -- TODO confirm.
            data_flip = [nd.flip(X, axis=3) for X in data]
            outputs_flip = [net(X) for X in data_flip]
            outputs_flipback = [
                flip_heatmap(o, val_dataset.joint_pairs, shift=True) for o in outputs_flip
            ]
            outputs = [(o + o_flip) / 2 for o, o_flip in zip(outputs, outputs_flipback)]
        # Gather the per-device outputs on the CPU before decoding them.
        if len(outputs) > 1:
            outputs_stack = nd.concat(*[o.as_in_context(mx.cpu()) for o in outputs], dim=0)
        else:
            outputs_stack = outputs[0].as_in_context(mx.cpu())
        preds, maxvals = heatmap_to_coord_alpha_pose(outputs_stack, scale_box)
        val_metric.update(preds, maxvals, score, imgid)
    # Silence whatever val_metric.get() prints. redirect_stdout restores the
    # previous sys.stdout on exit, including when get() raises.
    with contextlib.redirect_stdout(NullWriter()):
        res = val_metric.get()
    return res