예제 #1
0
def _gen_trias(diag_pts, diag_pos, c, theta_invl=15, max_diff=1):
    """Build candidate triangles by pairing diagonals with the points in ``c.pos``.

    Every row of ``diag_pts`` is paired with every row of ``c.pos``, giving
    the Cartesian product of the two. In each resulting row, columns 0:4 hold
    the diagonal's two endpoints and columns 4:6 hold the paired point.

    Two filters are applied:

    * Triangles whose ``compute_tria_area`` is exactly zero are dropped.
    * The remaining triangles are filtered by angle. The observed angle of
      the third point around the diagonal's centre is binned into
      ``theta_invl``-degree labels. A triangle is kept when ``diff_link``
      between that label and the label read from ``c.cls`` at the third
      point is ``<= max_diff``.

    For the survivors, the third point is replaced by the point at half the
    diagonal's length from the diagonal's centre, at angle ``theta + 180``
    degrees. The result is cast to int32.

    Args:
        diag_pts: Diagonal endpoints in columns 0:4. Presumably shape (N, 4)
            as (x1, y1, x2, y2); confirm against callers.
        diag_pos: Per-diagonal values, stacked column-wise in front of the
            values ``get_value(c.pos, c.prb)`` returns.
        c: Object exposing ``pos`` (points, presumably (M, 2)), plus ``prb``
            and ``cls`` maps that are read through ``get_value``.
        theta_invl: Angular bin width in degrees; there are
            ``360 / theta_invl`` labels.
        max_diff: Largest accepted ``diff_link`` value between the observed
            and predicted labels.

    Returns:
        Tuple ``(tria_pts, tria_pos)``: the kept triangles' coordinates, with
        columns 4:6 rewritten, and their matching stacked value rows.
    """
    num_labels = 360 / theta_invl

    # Cartesian product of diagonal indices and point indices.
    grid_diag, grid_pt = np.meshgrid(np.arange(0, diag_pts.shape[0]),
                                     np.arange(0, c.pos.shape[0]))
    sel_diag = grid_diag.ravel()
    sel_pt = grid_pt.ravel()

    pts = np.hstack((diag_pts[sel_diag, :], c.pos[sel_pt, :]))
    vals = np.hstack(
        (diag_pos[sel_diag, :], get_value(c.pos[sel_pt, :], c.prb)))

    # Drop triangles whose computed area is exactly zero.
    area = compute_tria_area(pts[:, 0:2], pts[:, 2:4], pts[:, 4:6])
    nonzero = np.where(area != 0)[0]
    pts = pts[nonzero, :]
    vals = vals[nonzero, :]

    # Keep triangles whose observed angle bin (1-based) agrees, within
    # max_diff, with the bin predicted at the third point.
    _, _, mid_x, mid_y = whctrs(pts[:, 0:4])
    theta = compute_theta(pts[:, 4:6], np.vstack((mid_x, mid_y)).transpose())
    observed = np.floor(theta / theta_invl) + 1
    predicted = get_value(pts[:, 4:6], c.cls)
    agree = np.where(diff_link(observed, predicted, num_labels) <= max_diff)[0]
    pts = pts[agree, :]
    vals = vals[agree, :]
    theta = theta[agree]

    # Place the third point at distance half-diagonal from the diagonal's
    # centre, at angle theta + 180 degrees (converted to radians).
    # theta is indexed as [:, 0] below, so compute_theta presumably returns
    # a 2-D column; confirm.
    radians = np.mod(theta + 180.0, 360.0) / 180.0 * np.pi
    radius = np.sqrt(
        np.sum(np.square(pts[:, 0:2] - pts[:, 2:4]), axis=1)) / 2.
    step_x = radius * np.cos(radians[:, 0])
    step_y = radius * np.sin(radians[:, 0])
    _, _, mid_x, mid_y = whctrs(pts[:, 0:4])
    # y is subtracted; presumably the y axis points down (image
    # coordinates). Confirm.
    moved = np.vstack((mid_x + step_x, mid_y - step_y))
    pts[:, 4:6] = moved.astype(np.int32, copy=False).transpose()
    return pts, vals
예제 #2
0
def _get_last_one(tria, d):
    """Reflect each triangle's third point through the centre of its diagonal.

    For every row, the point ``2 * centre - (tria[:, 4], tria[:, 5])`` is
    computed, where ``centre`` comes from ``whctrs(tria[:, 0:4])``. It is
    clamped so that x lies in ``[0, d.prb.shape[1] - 1]`` and y lies in
    ``[0, d.prb.shape[0] - 1]``. The clamped point is cast to int32 and used
    to sample ``d.prb`` through ``get_value``.

    Args:
        tria: Triangle rows. Diagonal endpoints are in columns 0:4 and the
            third point (x, y) is in columns 4 and 5.
        d: Object exposing a ``prb`` array with at least two dimensions.

    Returns:
        Tuple ``(pos, prb)``: the int32 reflected points (one row per
        triangle, columns x and y) and the values ``get_value`` reads at
        them from ``d.prb``.
    """
    map_shape = d.prb.shape[:2]
    _, _, mid_x, mid_y = whctrs(tria[:, 0:4])

    # Point reflection through the centre, clamped to the map bounds.
    refl_x = np.maximum(np.minimum(2 * mid_x - tria[:, 4], map_shape[1] - 1), 0)
    refl_y = np.maximum(np.minimum(2 * mid_y - tria[:, 5], map_shape[0] - 1), 0)

    # Values are non-negative after clamping, so the int32 cast floors them.
    pos = np.array(np.vstack((refl_x, refl_y)).transpose(), dtype=np.int32)
    return pos, get_value(pos, d.prb)