def check(shape, fail_cond, fail_func):
    """Compare ``jt.candidate`` against the reference ``check_candidate``.

    A random array of the given ``shape`` is created, filtered through
    ``jt.candidate`` using the ``fail_cond`` expression, and the result is
    compared (as nested Python lists) with what ``check_candidate`` returns
    for the same data and ``fail_func``.

    Args:
        shape: shape passed to ``jt.random``.
        fail_cond: condition forwarded to ``jt.candidate``.
        fail_func: callable forwarded to ``check_candidate``; presumably the
            Python equivalent of ``fail_cond`` -- confirm against callers.

    Raises:
        AssertionError: if the two selections differ; the message carries
            both the actual and the expected selection.
    """
    source = jt.random(shape)
    picked = jt.candidate(source, fail_cond)

    # Pull the input and the operator output out through ``.data`` so the
    # reference implementation and the comparison work on host-side values.
    source_values = source.data
    actual = picked.data

    expected = check_candidate(source_values, fail_func)

    selections_match = actual.tolist() == expected.tolist()
    assert selections_match, (actual, expected)
def nms(dets, thresh):
    """Run non-maximum suppression using ``jt.candidate``.

    Args:
        dets: jt array whose rows are ``[x1, y1, x2, y2, score]``, i.e.
            ``x(:,0)->x1, x(:,1)->y1, x(:,2)->x2, x(:,3)->y2, x(:,4)->score``.
        thresh: overlap threshold; it is converted with ``str`` and embedded
            verbatim into the condition expression.

    Returns:
        ``order[selected]``: entries of the score-descending ordering picked
        by ``jt.candidate``. Presumably these are row indices into the
        original ``dets`` -- confirm against the Jittor ``argsort`` docs.
    """
    limit = str(thresh)

    # Reorder the boxes by score (column 4), highest first. Element [0] of
    # the argsort result is used as the ordering.
    order = jt.argsort(dets[:, 4], descending=True)[0]
    dets = dets[order]

    # Expression fragments evaluated by jt.candidate for a pair of rows
    # (i, j). Box sides use the inclusive "+1" convention.
    area_j = '(@x(j,2)-@x(j,0)+1)*(@x(j,3)-@x(j,1)+1)'
    area_i = '(@x(i,2)-@x(i,0)+1)*(@x(i,3)-@x(i,1)+1)'
    overlap_w = 'max((Tx)0,min(@x(j,2),@x(i,2))-max(@x(j,0),@x(i,0))+1)'
    overlap_h = 'max((Tx)0,min(@x(j,3),@x(i,3))-max(@x(j,1),@x(i,1))+1)'

    # intersection / (area_j + area_i - intersection) > threshold
    overlap = '*'.join((overlap_h, overlap_w))
    union = '(' + area_j + '+' + area_i + '-' + overlap + ')'
    fail_cond = '{}/{}>{}'.format(overlap, union, limit)

    keep = jt.candidate(dets, fail_cond)
    return order[keep]