def validate(val_data, val_dataset, net, ctx):
    """Run ``net`` over ``val_data`` and feed the predictions to ``val_metric``.

    Relies on names defined outside this function: ``mx``, ``nd``, ``opt``,
    ``val_metric``, ``val_batch_fn``, ``flip_heatmap`` and ``get_final_preds``.

    Args:
        val_data: Iterable of batches; each batch is handed to
            ``val_batch_fn`` together with the context list.
        val_dataset: Object whose ``joint_pairs`` attribute is passed to
            ``flip_heatmap``; only read when ``opt.flip_test`` is truthy.
        net: Callable applied to every item of the unpacked batch data.
        ctx: Either a single ``mx.Context`` (wrapped into a one-item list)
            or an already-built collection of contexts.

    Returns:
        None. Results are accumulated in ``val_metric``.
    """
    devices = [ctx] if isinstance(ctx, mx.Context) else ctx
    val_metric.reset()

    from tqdm import tqdm

    for batch in tqdm(val_data):
        data, scale, center, score, imgid = val_batch_fn(batch, devices)

        heatmaps = [net(part) for part in data]

        if opt.flip_test:
            # Second pass on inputs flipped along axis 3; the flipped
            # predictions are mapped back by ``flip_heatmap`` and averaged
            # with the first pass.
            mirrored_inputs = [nd.flip(part, axis=3) for part in data]
            mirrored_heatmaps = [net(part) for part in mirrored_inputs]
            restored_heatmaps = [
                flip_heatmap(hm, val_dataset.joint_pairs, shift=True)
                for hm in mirrored_heatmaps
            ]
            heatmaps = [
                (plain + restored) / 2
                for plain, restored in zip(heatmaps, restored_heatmaps)
            ]

        # Move every per-context output to the CPU, then merge along dim 0
        # when there is more than one of them.
        cpu_heatmaps = [hm.as_in_context(mx.cpu()) for hm in heatmaps]
        if len(cpu_heatmaps) > 1:
            stacked = nd.concat(*cpu_heatmaps, dim=0)
        else:
            stacked = cpu_heatmaps[0]

        preds, maxvals = get_final_preds(
            stacked, center.asnumpy(), scale.asnumpy())
        val_metric.update(preds, maxvals, score, imgid)

    # The metric is still queried once at the end; its return value is not
    # used and nothing is returned to the caller.
    val_metric.get()
def heatmap_to_coord(heatmaps, bbox_list):
    """Turn heatmaps into coordinates using box midpoints and half-sizes.

    Every entry of ``bbox_list`` is read at indices 0..3 as
    ``(x0, y0, x1, y1)``. For each box the midpoint and the half-width /
    half-height are computed as NumPy arrays and forwarded, together with
    ``heatmaps``, to ``get_final_preds``.

    Args:
        heatmaps: Forwarded unchanged to ``get_final_preds``.
        bbox_list: Iterable of indexable boxes with at least four items.

    Returns:
        The ``(coords, maxvals)`` pair produced by ``get_final_preds``.
    """
    midpoints = []
    half_sizes = []
    for box in bbox_list:
        x_start, y_start = box[0], box[1]
        x_end, y_end = box[2], box[3]
        half_w = (x_end - x_start) / 2
        half_h = (y_end - y_start) / 2
        midpoints.append(np.array([x_start + half_w, y_start + half_h]))
        half_sizes.append(np.array([half_w, half_h]))

    points, confidences = get_final_preds(heatmaps, midpoints, half_sizes)
    return points, confidences
def validate(val_data, val_dataset, net, ctx):
    """Run ``net`` over ``val_data``, update ``val_metric`` and print the score.

    Relies on names defined outside this function: ``mx``, ``nd``, ``opt``,
    ``val_metric``, ``val_batch_fn``, ``flip_heatmap``, ``net_dsnt`` and
    ``get_final_preds``.

    Args:
        val_data: Iterable of batches; each batch is handed to
            ``val_batch_fn`` together with the context list.
        val_dataset: Object whose ``joint_pairs`` attribute is passed to
            ``flip_heatmap``; only read when ``opt.flip_test`` is truthy.
        net: Callable applied to every item of the unpacked batch data.
        ctx: Either a single ``mx.Context`` (wrapped into a one-item list)
            or an already-built collection of contexts.

    Returns:
        None. The metric name and score are printed to stdout.
    """
    devices = [ctx] if isinstance(ctx, mx.Context) else ctx
    val_metric.reset()

    from tqdm import tqdm

    for batch in tqdm(val_data):
        data, scale, center, score, imgid = val_batch_fn(batch, devices)

        results = [net(part) for part in data]

        if opt.flip_test:
            # Second pass on inputs flipped along axis 3; the flipped
            # predictions are mapped back by ``flip_heatmap`` and averaged
            # with the first pass.
            mirrored_inputs = [nd.flip(part, axis=3) for part in data]
            mirrored_results = [net(part) for part in mirrored_inputs]
            restored_results = [
                flip_heatmap(hm, val_dataset.joint_pairs, shift=True)
                for hm in mirrored_results
            ]
            results = [
                (plain + restored) / 2
                for plain, restored in zip(results, restored_results)
            ]

        if opt.dsnt:
            # Only the first element of what ``net_dsnt`` returns is kept.
            results = [net_dsnt(hm)[0] for hm in results]

        # Move every per-context output to the CPU, then merge along dim 0
        # when there is more than one of them.
        cpu_results = [res.as_in_context(mx.cpu()) for res in results]
        if len(cpu_results) > 1:
            stacked = nd.concat(*cpu_results, dim=0)
        else:
            stacked = cpu_results[0]

        if not opt.dsnt:
            preds, maxvals = get_final_preds(
                stacked, center.asnumpy(), scale.asnumpy())
        else:
            # NOTE(review): presumably ``net_dsnt`` yields coordinates
            # normalised to [0, 1], hence the 0.5 shift before rescaling
            # with ``scale`` and ``center`` -- confirm.
            shifted = stacked - 0.5
            rescaled = shifted * scale.expand_dims(axis=1)
            preds = rescaled + center.expand_dims(axis=1)
            # This path has no per-prediction confidence, so a tensor of
            # ones is supplied in its place.
            maxvals = nd.ones(preds.shape[:2] + (1,))

        val_metric.update(preds, maxvals, score, imgid)

    metric_name, metric_score = val_metric.get()
    print("Inference Completed! %s = %.4f" % (metric_name, metric_score))