Пример #1
0
# Run VWB clustering of the two marginals x1, x2 starting from x, then, for
# each selected centroid index k, save a figure showing the samples colored by
# their entry in `idx` and gray polylines through centroid k.
vwb = VWB(x, [x1, x2], device=device, verbose=False)
output = vwb.cluster(max_iter_h=5000, max_iter_p=1)
idx = output['idx']

xmin, xmax, ymin, ymax = -1.0, 1.0, -0.5, 0.5

for k in [33]:
    plt.figure(figsize=(8, 4))
    for i in range(2):
        # Scale the per-sample indices into [0, 1] for the colormap
        # (divides by K - 1, so this assumes K > 1).
        ce = np.array(plt.get_cmap('viridis')(idx[i].cpu().numpy() / (K - 1)))
        utils.scatter_otsamples(vwb.data_p,
                                vwb.data_e[i],
                                size_p=30,
                                marker_p='o',
                                color_x=ce,
                                xmin=xmin,
                                xmax=xmax,
                                ymin=ymin,
                                ymax=ymax,
                                facecolor_p='none')

    # Samples of each marginal whose idx equals k, and centroid k itself.
    e0s = vwb.data_e[0][idx[0] == k]
    e1s = vwb.data_e[1][idx[1] == k]
    p = vwb.data_p[k]

    # NOTE(review): zip pairs the j-th selected sample of each marginal and
    # silently stops at the shorter list — confirm this pairing is intended.
    for e0, e1 in zip(e0s, e1s):
        # Dedicated names: reusing `x`/`y` here would overwrite the input
        # data `x` that was passed to VWB above.
        line_x = [e1[0], p[0], e0[0]]
        line_y = [e1[1], p[1], e0[1]]
        plt.plot(line_x, line_y, c='lightgray', alpha=0.4)
    # plt.savefig(f"ship{k}.svg")
    plt.savefig(f"ship{k}.png", dpi=300, bbox_inches='tight')
Пример #2
0
# Run VOT clustering of the two marginals x1, x2 starting from x, then, for
# each selected centroid index k, draw the samples colored by their entry in
# `idx` and gray segments from centroid k to every sample whose idx is k.
vot = VOT(x, [x1, x2], verbose=False)
vot.cluster(max_iter_h=3000, max_iter_y=1)
idx = vot.idx

xmin, xmax, ymin, ymax = -1., 1., 0., 1.

for k in [21]:
    plt.figure(figsize=(8, 4))
    for i in range(2):
        # Scale indices into [0, 1] for the colormap (assumes K > 1).
        ce = np.array(plt.get_cmap('viridis')(idx[i] / (K - 1)))
        utils.scatter_otsamples(vot.y,
                                vot.x[i],
                                size_p=30,
                                marker_p='o',
                                color_e=ce,
                                xmin=xmin,
                                xmax=xmax,
                                ymin=ymin,
                                ymax=ymax,
                                facecolor_p='none')

    p = vot.y[k]

    for i in range(2):
        es = vot.x[i][idx[i] == k]
        for e in es:
            # Dedicated names: reusing `x`/`y` here would overwrite the input
            # data `x` that was passed to VOT above.
            line_x = [p[0], e[0]]
            line_y = [p[1], e[1]]
            plt.plot(line_x, line_y, c='lightgray', alpha=0.4)

    # plt.savefig(f"ship{k}.svg")
Пример #3
0
                 vot.data_p,
                 fig232,
                 color=cp,
                 title='w/o reg map',
                 facecolor_after='none')

# ------ plot after ----- #
plt.subplot(233)
# Per-sample colors: each integer label selects a row of a two-entry palette
# (DARK_BLUE/RED for the data_p points, LIGHT_BLUE/LIGHT_RED for data_e).
ce = np.array([utils.COLOR_LIGHT_BLUE,
               utils.COLOR_LIGHT_RED])[pred_label_e.int().cpu().numpy(), :]
cp = np.array([utils.COLOR_DARK_BLUE,
               utils.COLOR_RED])[vot.label_p.int().cpu().numpy(), :]
utils.scatter_otsamples(vot.data_p, vot.data_e,
                        size_p=30, marker_p='o',
                        color_p=cp, color_e=ce,
                        title='w/o reg after', facecolor_p='none')

# -------------------------------------- #
# --------- w/ regularization ---------- #
# -------------------------------------- #

# ------- run RWM ------- #
# Reload the points from CSV as a float64 tensor on `device`.
data_p = torch.from_numpy(
    np.loadtxt('data/p.csv', delimiter=",")
).double().to(device)
vot_reg = VotReg(data_p[:, 1:],
                 data_e[:, 1:],
                 data_p[:, 0],
                 data_e[:, 0],
Пример #4
0
# K-means on x0 starting from the centers in x. With an explicit array init,
# scikit-learn performs a single run anyway (and warns if n_init != 1), so
# n_init=1 is passed explicitly. Assumes x is an array of K initial centers.
kmeans = KMeans(n_clusters=K, init=x, n_init=1).fit(x0)

label = kmeans.predict(x0)
newx = kmeans.cluster_centers_

# RGBA palette (orange, green, blue) scaled to [0, 1]. `color_map[label]`
# needs one row per cluster, so this only supports K <= 3.
color_map = np.array([[237, 125, 49, 255], [112, 173, 71, 255],
                      [91, 155, 213, 255]]) / 255

plt.figure(figsize=(4, 4))
ce = color_map[label]
utils.scatter_otsamples(newx,
                        x0,
                        size_p=30,
                        marker_p='o',
                        color_e=ce,
                        xmin=xmin,
                        xmax=xmax,
                        ymin=ymin,
                        ymax=ymax,
                        facecolor_p='none')
plt.axis('off')
# plt.savefig("kmeans.svg", bbox_inches='tight')
plt.savefig("kmeans.png", dpi=300, bbox_inches='tight')

# Use the first GPU only when explicitly enabled and CUDA is available.
use_gpu = False
if use_gpu and torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
# ---------------VWB---------------
Пример #5
0
idx, _ = vot.map(max_iter=3000)
tock = time.time()
print(f'total time: {tock - tick:.4f}')

# Note: area-preserving maps usually need a predefined boundary, which is
#  beyond the scope of this demo. Without that boundary condition the map may
#  be inaccurate near the boundary; drawing the Voronoi diagram or Delaunay
#  triangulation can reveal slight intersections near the boundary centroids.

# ----- plot before ----- #
plt.figure(figsize=(12, 8))
plt.subplot(231)
utils.scatter_otsamples(vot.data_p_original, vot.x, title='before')

# ------ plot map ------- #
fig232 = plt.subplot(232)
utils.plot_otmap(vot.data_p_original, vot.y, fig232,
                 title='vot map', facecolor_after='none')

# ------ plot after ----- #
# Color each sample by its mapped index, scaled into [0, 1] (assumes N > 1).
ce = np.array(plt.get_cmap('viridis')(idx / (N - 1)))
plt.subplot(233)
utils.scatter_otsamples(vot.y, vot.x, color_x=ce, title='after')
Пример #6
0
# Refine the initial centers y with K-means on x, then replace y with the
# resulting centers. With an explicit array init, scikit-learn performs a
# single run anyway (and warns if n_init != 1), so n_init=1 is passed
# explicitly. Assumes y is an array of K initial centers.
kmeans = KMeans(n_clusters=K, init=y, n_init=1).fit(x)

label = kmeans.predict(x)
y = kmeans.cluster_centers_

# RGBA palette (orange, green, blue) scaled to [0, 1]. `color_map[label]`
# needs one row per cluster, so this only supports K <= 3.
color_map = np.array([[237, 125, 49, 255], [112, 173, 71, 255],
                      [91, 155, 213, 255]]) / 255

plt.figure(figsize=(4, 4))
ce = color_map[label]
utils.scatter_otsamples(y,
                        x,
                        size_p=30,
                        marker_p='o',
                        color_e=ce,
                        xmin=xmin,
                        xmax=xmax,
                        ymin=ymin,
                        ymax=ymax,
                        facecolor_p='none')
plt.axis('off')
# plt.savefig("kmeans.svg", bbox_inches='tight')
plt.savefig("kmeans.png", dpi=300, bbox_inches='tight')

# --------------- OT ---------------
# VOT receives copies, presumably so y and x stay intact if it modifies its
# inputs in place.
y_copy = y.copy()
x_copy = x.copy()

vwb = VOT(y_copy, [x_copy], verbose=False)
output = vwb.cluster(lr=0.5,
                     max_iter_h=20,