Ejemplo n.º 1
0
  def test_iris(self):
    """Fit SDML on the iris fixture and bound its class-separation score.

    Constraints are built from the iris labels via
    ``SDML.prepare_constraints``; the test passes when
    ``class_separation`` of the transformed points is below 0.25.
    """
    points = self.iris_points
    labels = self.iris_labels

    # One constraint matrix for all samples (row count of the point array).
    constraints = SDML.prepare_constraints(labels, points.shape[0], 1500)

    learner = SDML(points, constraints)
    learner.fit()

    separation = class_separation(learner.transform(), labels)
    self.assertLess(separation, 0.25)
Ejemplo n.º 2
0
def sandwich_demo():
  """Plot the sandwich data in input space and after each metric learner.

  The figure uses a 3-row layout: the top row shows the raw data from
  ``sandwich_data()`` with its ``nearest_neighbors(x, k=2)`` graph, and the
  four cells below show the same plot for the output of ``transform()`` of
  LMNN, ITML, SDML and LSML after fitting. Ends by calling ``pyplot.show()``.
  """
  x, y = sandwich_data()
  knn = nearest_neighbors(x, k=2)
  ax = pyplot.subplot(3, 1, 1)  # take the whole top row
  plot_sandwich_data(x, y, ax)
  plot_neighborhood_graph(x, knn, y, ax)
  ax.set_title('input space')
  ax.set_aspect('equal')
  ax.set_xticks([])
  ax.set_yticks([])

  num_constraints = 60
  mls = [
    LMNN(x, y),
    ITML(x, ITML.prepare_constraints(y, len(x), num_constraints)),
    SDML(x, SDML.prepare_constraints(y, len(x), num_constraints)),
    LSML(x, LSML.prepare_constraints(y, num_constraints))
  ]

  # Cells 3..6 of the 3x2 grid are the two rows beneath the top plot.
  # enumerate(mls, 3) yields the same (cell, learner) pairs as the former
  # zip(xrange(3, 7), mls) but also runs on Python 3, where xrange is gone.
  for ax_num, ml in enumerate(mls, 3):
    ml.fit()
    tx = ml.transform()
    ml_knn = nearest_neighbors(tx, k=2)
    ax = pyplot.subplot(3, 2, ax_num)
    plot_sandwich_data(tx, y, ax)
    plot_neighborhood_graph(tx, ml_knn, y, ax)
    ax.set_title('%s space' % ml.__class__.__name__)
    ax.set_xticks([])
    ax.set_yticks([])
  pyplot.show()