Beispiel #1
0
class TestKmeans(unittest.TestCase):
    def setUp(self):
        """Create a Kmeans instance over the first two iris attributes."""
        iris = Table('iris')
        two_attr_domain = Domain(iris.domain.attributes[:2])
        self.data = Table(two_attr_domain, iris)
        self.kmeans = Kmeans(self.data)

    def test_k(self):
        """``k`` must equal the number of centroids that were added."""
        self.kmeans.add_centroids([[5.2, 3.1], [6.5, 3], [7, 4]])
        self.assertEqual(self.kmeans.k, 3)
        n_centroids = len(self.kmeans.centroids)
        self.assertEqual(self.kmeans.k, n_centroids)

    def test_converged(self):
        """``converged`` is False after odd steps and True once max_iter hits."""
        km = self.kmeans
        km.add_centroids([[5.2, 3.1], [6.5, 3], [7, 4]])
        self.assertFalse(km.converged)

        # first (odd) step: the step is not complete, so not converged
        km.step()
        self.assertFalse(km.converged)

        # the second step completes an iteration; whether it converged
        # depends on the data, so nothing is asserted until the third step
        km.step()
        km.step()
        self.assertFalse(km.converged)

        # every further pair of steps ends on an incomplete step again
        n_rounds = km.max_iter // 2 + 1
        for _ in range(n_rounds):
            km.step()
            km.step()
            self.assertFalse(km.converged)

        # one more step completes the iteration; max_iter forces convergence
        km.step()
        self.assertTrue(km.converged)

    def test_centroids_belonging_points(self):
        """Check that points are grouped under the centroid they belong to.

        The expected groups have different lengths, so they are kept in a
        plain list: building one ragged ``np.array`` out of them raises
        ``ValueError`` on NumPy >= 1.24.
        """
        self.kmeans.add_centroids([[5.2, 3.6]])

        # with a single cluster every point sits in group 0
        np.testing.assert_equal(self.kmeans.centroids_belonging_points,
                                np.array([self.data.X]))

        # fewer rows and one more centroid
        self.kmeans.set_data(self.data[:3])
        self.kmeans.add_centroids([[4.7, 3.0]])
        expected_groups = [
            np.array([[5.100, 3.500]]),
            np.array([[4.900, 3.000], [4.700, 3.200]]),
        ]
        for i, group in enumerate(self.kmeans.centroids_belonging_points):
            np.testing.assert_equal(group, expected_groups[i])

    def test_step_completed(self):
        """``step_completed`` alternates False/True with every step."""
        self.kmeans.add_centroids([[5.2, 3.1], [6.5, 3], [7, 4]])
        self.assertEqual(self.kmeans.step_completed, True)
        for expected in (False, True):
            self.kmeans.step()
            self.assertEqual(self.kmeans.step_completed, expected)

    def test_set_data(self):
        """``set_data`` stores the data and resets the stepping state."""
        km = self.kmeans

        # regular table
        km.set_data(self.data)
        self.assertEqual(km.data, self.data)
        self.assertEqual(km.centroids_history, [])
        self.assertEqual(km.step_no, 0)
        self.assertEqual(km.step_completed, True)
        self.assertEqual(km.centroids_moved, False)

        # no data at all
        km.set_data(None)
        self.assertEqual(km.data, None)
        self.assertEqual(km.centroids_history, [])
        self.assertEqual(km.clusters, None)
        self.assertEqual(km.step_no, 0)
        self.assertEqual(km.step_completed, True)
        self.assertEqual(km.centroids_moved, False)

    def test_find_clusters(self):
        """``find_clusters`` returns the cluster index of every data row."""
        km = self.kmeans
        km.add_centroids([[5.2, 3.6]])

        # a single centroid: every row is assigned to cluster 0
        all_zero = np.zeros(len(self.data))
        np.testing.assert_equal(km.find_clusters(km.centroids), all_zero)

        # fewer rows and one more centroid
        km.set_data(self.data[:3])
        km.add_centroids([[4.7, 3.0]])
        assigned = km.find_clusters(km.centroids)
        np.testing.assert_equal(assigned, np.array([0, 1, 1]))

    def test_step(self):
        """Walk through two steps and one step/step_back round trip."""
        km = self.kmeans
        km.add_centroids([[5.2, 3.1], [6.5, 3], [7, 4]])
        start_centroids = np.copy(km.centroids)
        start_clusters = np.copy(km.clusters)

        # odd step: centroids move, the previous ones are second to last in
        # the history, cluster membership is untouched
        km.step()
        self.assertEqual(km.step_completed, False)
        self.assertEqual(km.centroids_moved, True)
        np.testing.assert_equal(start_centroids, km.centroids_history[-2])
        np.testing.assert_equal(start_clusters, km.clusters)

        # even step: the centroids stay where they are
        snapshot = np.copy(km.centroids)
        km.step()
        self.assertEqual(km.step_completed, True)
        self.assertEqual(km.centroids_moved, False)
        np.testing.assert_equal(snapshot, km.centroids)

        # a step followed by a step back restores the centroids
        snapshot = np.copy(km.centroids)
        km.step()
        km.step_back()
        np.testing.assert_equal(snapshot, km.centroids)

    def test_step_back(self):
        """Check ``step_back`` at step 0 and after a completed iteration.

        Every "before" value is taken with ``np.copy``. Holding a plain
        reference would compare an object with itself if ``Kmeans`` updated
        its arrays in place, making the assertion pass vacuously.
        """
        km = self.kmeans
        km.add_centroids([[5.2, 3.1], [6.5, 3], [7, 4]])

        # nothing must change when stepping back at step 0
        centroids_before = np.copy(km.centroids)
        clusters_before = np.copy(km.clusters)
        km.step_back()
        np.testing.assert_equal(centroids_before, km.centroids)
        np.testing.assert_equal(clusters_before, km.clusters)

        # complete one iteration (two steps)
        km.step()
        km.step()

        # stepping back from an even step keeps the centroids
        centroids_before = np.copy(km.centroids)
        km.step_back()
        np.testing.assert_equal(centroids_before, km.centroids)
        self.assertEqual(km.step_completed, False)
        self.assertEqual(km.centroids_moved, False)

        # stepping back from an odd step keeps the clusters
        clusters_before = np.copy(km.clusters)
        km.step_back()
        np.testing.assert_equal(clusters_before, km.clusters)
        self.assertEqual(km.step_completed, True)
        self.assertEqual(km.centroids_moved, True)

    def test_random_positioning(self):
        """``random_positioning(n)`` returns n 2-D points, or none for n <= 0."""
        expected_shapes = ((4, (4, 2)), (1, (1, 2)), (0, (0, )), (-1, (0, )))
        for count, shape in expected_shapes:
            positions = self.kmeans.random_positioning(count)
            self.assertEqual(positions.shape, shape)

    def test_add_centroids(self):
        """Centroids can be added by position list, by count or one by one."""
        km = self.kmeans
        km.add_centroids([[5.2, 3.1]])
        self.assertEqual(km.k, 1)
        km.add_centroids([[6.5, 3], [7, 4]])
        self.assertEqual(km.k, 3)
        km.add_centroids(2)
        self.assertEqual(km.k, 5)
        km.add_centroids()
        self.assertEqual(km.k, 6)

        # adding a centroid after an odd step leaves step_no two ahead and
        # clears the centroids_moved flag
        initial_step = km.step_no
        self.assertEqual(initial_step, km.step_no)
        km.step()
        km.add_centroids()
        self.assertEqual(initial_step + 2, km.step_no)
        self.assertEqual(km.centroids_moved, False)

    def test_delete_centroids(self):
        """Deleting centroids lowers ``k`` and never goes below zero."""
        self.kmeans.add_centroids([[6.5, 3], [7, 4], [5.2, 3.1]])
        # (number to delete, k expected afterwards)
        for to_delete, remaining in ((1, 2), (2, 0), (2, 0)):
            self.kmeans.delete_centroids(to_delete)
            self.assertEqual(self.kmeans.k, remaining)

    def test_move_centroid(self):
        """Moving a centroid changes its position but not ``k``."""
        km = self.kmeans
        km.add_centroids([[6.5, 3], [7, 4], [5.2, 3.1]])
        km.move_centroid(1, 3, 3.2)
        moved_to = np.array([3, 3.2])
        np.testing.assert_equal(km.centroids[1], moved_to)
        self.assertEqual(km.k, 3)

        # moving after an odd step: flag cleared and step_no equals 2
        km.step()
        km.move_centroid(2, 3.2, 3.4)
        self.assertEqual(km.centroids_moved, False)
        self.assertEqual(km.step_no, 2)

    def test_set_list(self):
        """``set_list`` sets an index, padding with None when needed."""
        # (input list, index, value, expected result)
        cases = (
            # empty list: padded with None up to the index
            ([], 2, 1, [None, None, 1]),
            # list too short by one: a single None is inserted
            ([2], 2, 1, [2, None, 1]),
            # index equals the length: value is appended
            ([2, 1], 2, 1, [2, 1, 1]),
            # last element is replaced
            ([2, 1], 1, 3, [2, 3]),
            # middle element is replaced
            ([2, 1, 3], 1, 3, [2, 3, 3]),
        )
        for source, index, value, expected in cases:
            result = self.kmeans.set_list(source, index, value)
            self.assertEqual(result, expected)
class TestKmeans(unittest.TestCase):

    def setUp(self):
        """Create a centroid-less Kmeans over the first two iris attributes."""
        iris = Table('iris')
        two_attr_domain = Domain(iris.domain.attributes[:2])
        self.data = Table(two_attr_domain, iris)
        self.kmeans = Kmeans(self.data, centroids=None)

    def test_k(self):
        """``k`` must equal the number of centroids that were added."""
        self.kmeans.add_centroids([[5.2, 3.1], [6.5, 3], [7, 4]])
        self.assertEqual(self.kmeans.k, 3)
        n_centroids = len(self.kmeans.centroids)
        self.assertEqual(self.kmeans.k, n_centroids)

    def test_centroids_belonging_points(self):
        """Check that points are grouped under the centroid they belong to.

        The expected groups have different lengths, so they are kept in a
        plain list: building one ragged ``np.array`` out of them raises
        ``ValueError`` on NumPy >= 1.24.
        """
        self.kmeans.add_centroids([[5.2, 3.6]])

        # with a single cluster every point sits in group 0
        np.testing.assert_equal(self.kmeans.centroids_belonging_points,
                                np.array([self.data.X]))

        # fewer rows and one more centroid
        self.kmeans.set_data(self.data[:3])
        self.kmeans.add_centroids([[4.7, 3.0]])
        expected_groups = [
            np.array([[5.100, 3.500]]),
            np.array([[4.900, 3.000], [4.700, 3.200]]),
        ]
        for i, group in enumerate(self.kmeans.centroids_belonging_points):
            np.testing.assert_equal(group, expected_groups[i])

    def test_step_completed(self):
        """``step_completed`` alternates False/True with every step."""
        self.kmeans.add_centroids([[5.2, 3.1], [6.5, 3], [7, 4]])
        self.assertEqual(self.kmeans.step_completed, True)
        for expected in (False, True):
            self.kmeans.step()
            self.assertEqual(self.kmeans.step_completed, expected)

    def test_set_data(self):
        """``set_data`` stores the data and resets the stepping state."""
        km = self.kmeans

        # regular table
        km.set_data(self.data)
        self.assertEqual(km.data, self.data)
        self.assertEqual(km.centroids_history, [])
        self.assertEqual(km.stepNo, 0)
        self.assertEqual(km.step_completed, True)
        self.assertEqual(km.centroids_moved, False)

        # no data at all
        km.set_data(None)
        self.assertEqual(km.data, None)
        self.assertEqual(km.centroids_history, [])
        self.assertEqual(km.clusters, None)
        self.assertEqual(km.stepNo, 0)
        self.assertEqual(km.step_completed, True)
        self.assertEqual(km.centroids_moved, False)

    def test_find_clusters(self):
        """``find_clusters`` returns the cluster index of every data row."""
        km = self.kmeans
        km.add_centroids([[5.2, 3.6]])

        # a single centroid: every row is assigned to cluster 0
        all_zero = np.zeros(len(self.data))
        np.testing.assert_equal(km.find_clusters(km.centroids), all_zero)

        # fewer rows and one more centroid
        km.set_data(self.data[:3])
        km.add_centroids([[4.7, 3.0]])
        assigned = km.find_clusters(km.centroids)
        np.testing.assert_equal(assigned, np.array([0, 1, 1]))

    def test_step(self):
        """Walk through two steps and one step/step_back round trip."""
        km = self.kmeans
        km.add_centroids([[5.2, 3.1], [6.5, 3], [7, 4]])
        start_centroids = np.copy(km.centroids)
        start_clusters = np.copy(km.clusters)

        # odd step: centroids move, the previous ones are the last history
        # entry, cluster membership is untouched
        km.step()
        self.assertEqual(km.step_completed, False)
        self.assertEqual(km.centroids_moved, True)
        np.testing.assert_equal(start_centroids, km.centroids_history[-1])
        np.testing.assert_equal(start_clusters, km.clusters)

        # even step: the centroids stay where they are
        snapshot = np.copy(km.centroids)
        km.step()
        self.assertEqual(km.step_completed, True)
        self.assertEqual(km.centroids_moved, False)
        np.testing.assert_equal(snapshot, km.centroids)

        # a step followed by a step back restores the centroids
        snapshot = np.copy(km.centroids)
        km.step()
        km.step_back()
        np.testing.assert_equal(snapshot, km.centroids)

    def test_step_back(self):
        """Check ``step_back`` at step 0 and after a completed iteration.

        Every "before" value is taken with ``np.copy``. Holding a plain
        reference would compare an object with itself if ``Kmeans`` updated
        its arrays in place, making the assertion pass vacuously.
        """
        km = self.kmeans
        km.add_centroids([[5.2, 3.1], [6.5, 3], [7, 4]])

        # nothing must change when stepping back at step 0
        centroids_before = np.copy(km.centroids)
        clusters_before = np.copy(km.clusters)
        km.step_back()
        np.testing.assert_equal(centroids_before, km.centroids)
        np.testing.assert_equal(clusters_before, km.clusters)

        # complete one iteration (two steps)
        km.step()
        km.step()

        # stepping back from an even step keeps the centroids
        centroids_before = np.copy(km.centroids)
        km.step_back()
        np.testing.assert_equal(centroids_before, km.centroids)
        self.assertEqual(km.step_completed, False)
        self.assertEqual(km.centroids_moved, False)

        # stepping back from an odd step keeps the clusters
        clusters_before = np.copy(km.clusters)
        km.step_back()
        np.testing.assert_equal(clusters_before, km.clusters)
        self.assertEqual(km.step_completed, True)
        self.assertEqual(km.centroids_moved, True)

    def test_random_positioning(self):
        """``random_positioning(n)`` returns n 2-D points, or none for n <= 0."""
        expected_shapes = ((4, (4, 2)), (1, (1, 2)), (0, (0,)), (-1, (0,)))
        for count, shape in expected_shapes:
            positions = self.kmeans.random_positioning(count)
            self.assertEqual(positions.shape, shape)

    def test_add_centroids(self):
        """Centroids can be added by position list, by count or one by one."""
        km = self.kmeans
        km.add_centroids([[5.2, 3.1]])
        self.assertEqual(km.k, 1)
        km.add_centroids([[6.5, 3], [7, 4]])
        self.assertEqual(km.k, 3)
        km.add_centroids(2)
        self.assertEqual(km.k, 5)
        km.add_centroids()
        self.assertEqual(km.k, 6)

        # adding a centroid after an odd step leaves stepNo two ahead and
        # clears the centroids_moved flag
        initial_step = km.stepNo
        self.assertEqual(initial_step, km.stepNo)
        km.step()
        km.add_centroids()
        self.assertEqual(initial_step + 2, km.stepNo)
        self.assertEqual(km.centroids_moved, False)

    def test_delete_centroids(self):
        """Deleting centroids lowers ``k`` and never goes below zero."""
        self.kmeans.add_centroids([[6.5, 3], [7, 4], [5.2, 3.1]])
        # (number to delete, k expected afterwards)
        for to_delete, remaining in ((1, 2), (2, 0), (2, 0)):
            self.kmeans.delete_centroids(to_delete)
            self.assertEqual(self.kmeans.k, remaining)

    def test_move_centroid(self):
        """Moving a centroid changes its position but not ``k``."""
        km = self.kmeans
        km.add_centroids([[6.5, 3], [7, 4], [5.2, 3.1]])
        km.move_centroid(1, 3, 3.2)
        moved_to = np.array([3, 3.2])
        np.testing.assert_equal(km.centroids[1], moved_to)
        self.assertEqual(km.k, 3)

        # moving after an odd step: flag cleared and stepNo equals 2
        km.step()
        km.move_centroid(2, 3.2, 3.4)
        self.assertEqual(km.centroids_moved, False)
        self.assertEqual(km.stepNo, 2)
class OWKmeans(OWWidget):
    """
    Interactive k-means widget.

    Lets the user step through k-means on two numeric attributes of the
    input table, shown in a scatter plot. It outputs the input data
    annotated with a cluster variable and a table of centroid positions.
    """

    # widget metadata
    name = "Interactive k-Means"
    description = "Widget demonstrates working of k-means algorithm."
    keywords = ["kmeans", "clustering", "interactive"]
    icon = "icons/InteractiveKMeans.svg"
    want_main_area = False
    priority = 300

    # inputs and outputs
    class Inputs:
        data = Input("Data", Table)

    class Outputs:
        annotated_data = Output("Annotated Data", Table, default=True)
        centroids = Output("Centroids", Table)

    # warnings shown when the input cannot be clustered
    class Warning(OWWidget.Warning):
        num_features = Msg(
            "Widget requires at least two numeric features with valid values")
        cluster_points = Msg(
            "The number of clusters can't exceed the number of points")

    # settings
    number_of_clusters = settings.Setting(3)
    # autoplay state; plain class attributes, not stored as settings
    auto_play_enabled = False
    auto_play_thread = None

    # data
    data = None
    selected_rows = None  # rows that are selected for kmeans (not nan rows)

    # selected attributes in chart
    attr_x = settings.Setting('')
    attr_y = settings.Setting('')

    # other settings
    k_means = None  # Kmeans instance, created once valid data arrives
    auto_play_speed = settings.Setting(1)
    lines_to_centroids = settings.Setting(True)
    graph_name = 'scatter'
    output_name = "cluster"  # name of the cluster variable in the output
    # step button labels, indexed by k_means.step_completed
    STEP_BUTTONS = ["Reassign Membership", "Recompute Centroids"]
    # autoplay button labels, indexed by auto_play_enabled
    AUTOPLAY_BUTTONS = ["Run", "Stop"]

    # colors taken from chart.options.colors in Highchart
    # (if more required check for more in chart.options.color)
    # cluster i uses colors[i % len(colors)]
    colors = [
        "#1F7ECA", "#D32525", "#28D825", "#D5861F", "#98257E", "#2227D5",
        "#D5D623", "#D31BD6", "#6A7CDB", "#78D5D4"
    ]

    # signals, connected to step / stop_auto_play when autoplay starts
    step_trigger = pyqtSignal()
    stop_auto_play_trigger = pyqtSignal()

    def __init__(self):
        """Build the control-area GUI and an initially empty scatter plot."""
        super().__init__()

        # "Data" box: two combos selecting the plotted attributes
        self.options_box = gui.widgetBox(self.controlArea, "Data")
        combo_kwargs = {
            "widget": self.options_box,
            "master": self,
            "orientation": Qt.Horizontal,
            "callback": self.restart,
            "sendSelectedValue": True,
            "maximumContentsLength": 15,
        }
        self.cbx = gui.comboBox(value='attr_x', label='X: ', **combo_kwargs)
        self.cby = gui.comboBox(value='attr_y', label='Y: ', **combo_kwargs)

        # "Centroids" box: count spinner, randomize button, lines checkbox
        self.centroids_box = gui.widgetBox(self.controlArea, "Centroids")
        self.centroid_numbers_spinner = gui.spin(
            self.centroids_box, self, 'number_of_clusters',
            minv=1, maxv=10, step=1,
            label='Number of centroids:',
            alignment=Qt.AlignRight,
            callback=self.number_of_clusters_change)
        self.restart_button = gui.button(
            self.centroids_box, self, "Randomize Positions",
            callback=self.restart)
        gui.separator(self.centroids_box)
        self.lines_checkbox = gui.checkBox(
            self.centroids_box, self, 'lines_to_centroids',
            'Show membership lines',
            callback=self.complete_replot)

        # manual stepping controls
        gui.separator(self.controlArea, 20, 20)
        self.step_box = gui.widgetBox(
            self.controlArea, "Manually step through")
        self.step_button = gui.button(
            self.step_box, self, self.STEP_BUTTONS[1], callback=self.step)
        self.step_back_button = gui.button(
            self.step_box, self, "Step Back", callback=self.step_back)

        # autoplay controls
        self.run_box = gui.widgetBox(self.controlArea, "Run")
        self.auto_play_speed_spinner = gui.hSlider(
            self.run_box, self, 'auto_play_speed',
            label='Speed:',
            minValue=0, maxValue=1.91, step=0.1,
            intOnly=False,
            createLabel=False)
        self.auto_play_button = gui.button(
            self.run_box, self, self.AUTOPLAY_BUTTONS[0],
            callback=self.auto_play)

        gui.rubber(self.controlArea)

        # everything stays disabled until data is loaded
        self.set_disabled_all(True)

        # scatter plot placed in mainArea
        self.scatter = Scatterplot(
            click_callback=self.graph_clicked,
            drop_callback=self.centroid_dropped,
            xAxis_gridLineWidth=0,
            yAxis_gridLineWidth=0,
            tooltip_enabled=False,
            debug=False)

        # render an empty chart so it shows a nice 'No data to display'
        self.scatter.chart()
        self.mainArea.layout().addWidget(self.scatter)

    def concat_x_y(self):
        """
        Build a two-column table from the attributes selected for X and Y.

        Rows with a NaN in either column are dropped; the indices of the
        kept rows are stored in ``self.selected_rows``.

        Returns
        -------
        Orange.data.Table
            table with the two selected columns and no NaN rows
        """
        x_var = self.data.domain[self.attr_x]
        y_var = self.data.domain[self.attr_y]

        columns = []
        for var in (x_var, y_var):
            column = self.data[:, var]
            # take the values from Y when it is non-empty, otherwise from X
            columns.append(column.Y if column.Y.size else column.X)
        stacked = np.column_stack(columns)

        valid_mask = ~np.isnan(stacked).any(axis=1)
        self.selected_rows = np.where(valid_mask)
        return Table(Domain([x_var, y_var]), stacked[valid_mask])

    def set_empty_plot(self):
        """
        Function clears the scatter plot
        """
        self.scatter.clear()

    def set_disabled_all(self, disabled):
        """
        Function disables or enables all four control boxes

        Parameters
        ----------
        disabled : bool
            True to disable the controls, False to enable them
        """
        control_boxes = (self.options_box, self.centroids_box,
                         self.step_box, self.run_box)
        for box in control_boxes:
            box.setDisabled(disabled)

    @Inputs.data
    def set_data(self, data):
        """
        Function receives data from input and initializes the widget if the
        data are usable. Otherwise it empties the plot, disables the controls
        and clears both outputs so that downstream widgets do not keep tables
        computed from a previous input.

        Parameters
        ----------
        data : Orange.data.Table or None
            input data
        """
        self.data = data

        def get_valid_attributes(data):
            """Continuous attributes with at least one non-NaN value."""
            attrs = [
                var for var in data.domain.attributes if var.is_continuous
            ]
            return [var for var in attrs if sum(~np.isnan(data[:, var])) > 0]

        def reset_combos():
            self.cbx.clear()
            self.cby.clear()

        def init_combos():
            """
            function initialize the combos with attributes
            """
            reset_combos()
            valid_class_vars = [
                var for var in data.domain.class_vars if var.is_continuous
            ]
            for var in chain(valid_attributes, valid_class_vars):
                self.cbx.addItem(gui.attributeIconDict[var], var.name)
                self.cby.addItem(gui.attributeIconDict[var], var.name)

        def reject_input():
            """Put the widget in the no-data state and clear the outputs."""
            reset_combos()
            self.set_empty_plot()
            self.set_disabled_all(True)
            self.Outputs.annotated_data.send(None)
            self.Outputs.centroids.send(None)

        # remove warnings about too few continuous attributes or data points
        self.Warning.clear()

        if self.auto_play_thread:
            self.auto_play_thread.stop()

        if data is None or len(data) == 0:
            reject_input()
            return

        valid_attributes = get_valid_attributes(data)

        if len(valid_attributes) < 2:
            self.Warning.num_features()
            reject_input()
            return

        init_combos()
        self.set_disabled_all(False)
        # both combos hold the same items; default to the first two
        self.attr_x = self.cbx.itemText(0)
        self.attr_y = self.cbx.itemText(1)
        if self.k_means is None:
            self.k_means = Kmeans(self.concat_x_y())
        else:
            self.k_means.set_data(self.concat_x_y())
        self.number_of_clusters_change()

    def restart(self):
        """
        Function starts over with a fresh Kmeans on the selected columns;
        triggered on attribute change or when the restart button is pressed
        """
        fresh_table = self.concat_x_y()
        self.k_means = Kmeans(fresh_table)
        self.number_of_clusters_change()

    def step(self):
        """
        Function performs one k-means step, then refreshes the plot,
        the button labels and the outputs
        """
        self.k_means.step()
        for refresh in (self.replot, self.button_text_change, self.send_data):
            refresh()

    def step_back(self):
        """
        Function undoes one k-means step, refreshes plot, buttons and
        outputs, and syncs the centroid-count setting with the model
        """
        self.k_means.step_back()
        for refresh in (self.replot, self.button_text_change, self.send_data):
            refresh()
        self.number_of_clusters = self.k_means.k

    def button_text_change(self):
        """
        Function changes the text on the step button and updates whether
        the step back button is enabled
        """
        # step_completed is used as an index into the two labels
        label = self.STEP_BUTTONS[self.k_means.step_completed]
        self.step_button.setText(label)

        at_first_step = self.k_means.step_no <= 0
        if at_first_step:
            self.step_back_button.setDisabled(True)
        elif not self.auto_play_enabled:
            # during autoplay the button is left as it is
            self.step_back_button.setDisabled(False)

    def auto_play(self):
        """
        Function called when autoplay button pressed; toggles autoplay.

        When autoplay is switched on, the manual controls are disabled, the
        trigger signals are connected to their slots and a new Autoplay
        thread is started. When it is switched off, ``stop_auto_play``
        restores the controls.
        """
        self.auto_play_enabled = not self.auto_play_enabled
        self.auto_play_button.setText(
            self.AUTOPLAY_BUTTONS[self.auto_play_enabled])
        if not self.auto_play_enabled:
            self.stop_auto_play()
            return

        self.options_box.setDisabled(True)
        self.centroids_box.setDisabled(True)
        self.step_box.setDisabled(True)
        self.auto_play_thread = Autoplay(self)
        # Qt keeps duplicate connections, so connecting again on every run
        # would make each trigger call its slot once per earlier run.
        # Drop any existing connection first: each slot stays connected once.
        connections = (
            (self.step_trigger, self.step),
            (self.stop_auto_play_trigger, self.stop_auto_play),
        )
        for signal, slot in connections:
            try:
                signal.disconnect(slot)
            except TypeError:
                # PyQt raises TypeError when the slot was not connected yet
                pass
            signal.connect(slot)
        self.auto_play_thread.start()

    def stop_auto_play(self):
        """
        Function re-enables the controls and resets the autoplay button;
        called when stop is pressed or at the end of autoplay
        """
        for box in (self.options_box, self.centroids_box, self.step_box):
            box.setDisabled(False)
        self.auto_play_enabled = False
        stopped_label = self.AUTOPLAY_BUTTONS[self.auto_play_enabled]
        self.auto_play_button.setText(stopped_label)
        self.button_text_change()

    def replot(self):
        """
        Function refreshes the chart. If the centroids did not move in the
        last step the whole chart is redrawn, otherwise only the centroid
        series and the membership lines are updated.
        """
        if self.data is None or not (self.attr_x and self.attr_y):
            return

        km = self.k_means
        if not km.centroids_moved:
            self.complete_replot()
            return

        # centroids moved during the step: series 0 holds the centroids
        self.scatter.update_series(0, km.centroids)

        if not self.lines_to_centroids:
            return
        # series 1 + idx holds the lines of cluster idx as alternating
        # point / centroid coordinates
        clusters = zip(km.centroids, km.centroids_belonging_points)
        for idx, (centroid, members) in enumerate(clusters):
            segments = []
            for point in members:
                segments.append([point[0], point[1]])
                segments.append([centroid[0], centroid[1]])
            self.scatter.update_series(1 + idx, segments)

    def complete_replot(self):
        """
        Function redraws the whole chart (centroids, optional membership
        lines and data points) without animation. It does nothing when the
        selected attributes are not in the data domain.
        """
        try:
            x_var = self.data.domain[self.attr_x]
            y_var = self.data.domain[self.attr_y]
        except KeyError:
            return

        palette = self.colors
        n_colors = len(palette)
        km = self.k_means

        # centroids: draggable squares, one color per cluster
        centroid_points = [
            {
                'x': pos[0],
                'y': pos[1],
                'marker': {'fillColor': palette[idx % n_colors]},
            }
            for idx, pos in enumerate(km.centroids)
        ]
        series = [dict(
            data=centroid_points,
            type="scatter",
            draggableX=True,
            draggableY=True,
            cursor="move",
            zIndex=10,
            marker=dict(symbol='square', radius=8))]

        # lines between each centroid and the points that belong to it
        if self.lines_to_centroids:
            clusters = zip(km.centroids, km.centroids_belonging_points)
            for centroid, members in clusters:
                segments = list(chain.from_iterable(
                    ([pt[0], pt[1]], [centroid[0], centroid[1]])
                    for pt in members))
                series.append(dict(
                    data=segments,
                    type="line",
                    lineWidth=0.2,
                    enableMouseTracking=False,
                    color="#ccc"))

        # data points, in a brighter shade of the cluster color
        for idx, points in enumerate(km.centroids_belonging_points):
            series.append(dict(
                data=points,
                type="scatter",
                color=rgb_hash_brighter(palette[idx % n_colors], 0.3)))

        # highcharts parameters
        point_format = (
            "<strong>%s:</strong> {point.x:.2f} <br/>"
            "<strong>%s:</strong> {point.y:.2f}" % (self.attr_x, self.attr_y))
        self.scatter.chart(
            dict(series=series),
            xAxis_title_text=x_var.name,
            yAxis_title_text=y_var.name,
            tooltip_headerFormat="",
            tooltip_pointFormat=point_format)

    def replot_series(self):
        """
        Function adds fresh line and point series for the current clusters
        and then asks the chart to remove as many series as were added
        (k, or 2 * k when membership lines are shown).
        """
        km = self.k_means
        n_clusters = km.k

        fresh_series = []
        # lines between each centroid and the points that belong to it
        if self.lines_to_centroids:
            clusters = zip(km.centroids, km.centroids_belonging_points)
            for centroid, members in clusters:
                segments = list(chain.from_iterable(
                    ([pt[0], pt[1]], [centroid[0], centroid[1]])
                    for pt in members))
                fresh_series.append(dict(
                    data=segments,
                    type="line",
                    showInLegend=False,
                    lineWidth=0.2,
                    enableMouseTracking=False,
                    color="#ccc"))

        # data points, in a brighter shade of the cluster color
        for idx, points in enumerate(km.centroids_belonging_points):
            shade = rgb_hash_brighter(
                self.colors[idx % len(self.colors)], 0.5)
            fresh_series.append(dict(
                data=points,
                type="scatter",
                showInLegend=False,
                color=shade))

        self.scatter.add_series(fresh_series)

        if self.lines_to_centroids:
            n_to_remove = n_clusters * 2
        else:
            n_to_remove = n_clusters
        self.scatter.remove_last_series(n_to_remove)

    def number_of_clusters_change(self):
        """
        Function brings the number of centroids in the model in line with
        the ``number_of_clusters`` setting, or shows a warning when there
        are more clusters requested than data rows
        """
        if self.data is None:
            return

        if self.number_of_clusters > len(self.data):
            # not enough data for the requested number of clusters
            self.Warning.cluster_points()
            self.set_empty_plot()
            self.step_box.setDisabled(True)
            self.run_box.setDisabled(True)
            self.button_text_change()
            return

        self.Warning.cluster_points.clear()
        self.step_box.setDisabled(False)
        self.run_box.setDisabled(False)
        if self.k_means is None:  # no model has been created yet
            self.k_means = Kmeans(self.concat_x_y())

        missing = self.number_of_clusters - self.k_means.k
        if missing > 0:
            self.k_means.add_centroids(missing)
        elif missing < 0:
            self.k_means.delete_centroids(-missing)

        self.replot()
        self.send_data()
        self.button_text_change()

    def graph_clicked(self, x, y):
        """
        Function called when user clicks in the graph; adds a centroid at
        the clicked position (x, y). Does nothing without model or data.
        """
        if self.k_means is None or self.data is None:
            return
        self.k_means.add_centroids([x, y])
        self.number_of_clusters += 1
        for refresh in (self.replot, self.send_data, self.button_text_change):
            refresh()

    def centroid_dropped(self, _index, x, y):
        """
        Function called when the centroid with _index was dragged to (x, y);
        moves it in the model and redraws everything
        """
        self.k_means.move_centroid(_index, x, y)
        for refresh in (self.complete_replot, self.send_data,
                        self.button_text_change):
            refresh()

    def send_data(self):
        """
        Function sends the input data with an added cluster class variable
        and a table with centroid positions to the outputs. Both outputs
        get None when there is no model or no cluster assignment.
        """
        km = self.k_means
        if km is None or km.clusters is None:
            self.Outputs.annotated_data.send(None)
            self.Outputs.centroids.send(None)
            return

        cluster_labels = ["C%d" % (idx + 1) for idx in range(km.k)]
        cluster_var = DiscreteVariable(self.output_name,
                                       values=cluster_labels)

        source_domain = self.data.domain
        metas = source_domain.metas
        original_classes = source_domain.class_vars
        if original_classes:
            # the original class variables are kept as metas
            metas += original_classes
        out_domain = Domain(source_domain.attributes, [cluster_var], metas)

        annotated = Table.from_table(out_domain, self.data)
        # only the rows used for clustering get a cluster value
        annotated.Y[self.selected_rows] = km.clusters

        centroid_table = Table(Domain(km.data.domain.attributes),
                               km.centroids)
        self.Outputs.annotated_data.send(annotated)
        self.Outputs.centroids.send(centroid_table)

    def send_report(self):
        """
        Function adds the plot and the number of centroids to the report;
        nothing is reported without data
        """
        if self.data is None:
            return
        items = (("Number of centroids:", self.number_of_clusters), )
        caption = report.render_items_vert(items)
        self.report_plot(self.scatter)
        self.report_caption(caption)
class TestKmeans(unittest.TestCase):
    """Tests of the ``Kmeans`` helper on the first two iris attributes."""

    def setUp(self):
        """Create a two-attribute iris table and a fresh ``Kmeans`` on it."""
        self.data = Table('iris')
        new_domain = Domain(self.data.domain.attributes[:2])
        self.data = Table(new_domain, self.data)
        self.kmeans = Kmeans(self.data)

    def test_k(self):
        """``k`` equals the number of centroids that were added."""
        centroids = [[5.2, 3.1], [6.5, 3], [7, 4]]
        self.kmeans.add_centroids(centroids)
        self.assertEqual(self.kmeans.k, 3)
        self.assertEqual(self.kmeans.k, len(self.kmeans.centroids))

    def test_converged(self):
        """``converged`` is false on odd steps and true once max_iter passes."""
        centroids = [[5.2, 3.1], [6.5, 3], [7, 4]]
        self.kmeans.add_centroids(centroids)
        self.assertFalse(self.kmeans.converged)

        self.kmeans.step()
        self.assertFalse(self.kmeans.converged)
        # step not complete so false every odd state

        self.kmeans.step()
        # it is even step so maybe converged but it depends on example
        # unable to test
        self.kmeans.step()
        self.assertFalse(self.kmeans.converged)

        # check if false every not completed step
        for _ in range(self.kmeans.max_iter // 2 + 1):
            self.kmeans.step()
            self.kmeans.step()
            self.assertFalse(self.kmeans.converged)

        # it converged because of max iter
        self.kmeans.step()
        self.assertTrue(self.kmeans.converged)

    def test_centroids_belonging_points(self):
        """Points are grouped under the centroid they belong to."""
        centroids = [[5.2, 3.6]]
        self.kmeans.add_centroids(centroids)

        # if only one cluster all points in 0th element of first dimension
        np.testing.assert_equal(
            self.kmeans.centroids_belonging_points, np.array([self.data.X]))

        # try with two clusters and less data
        self.kmeans.set_data(self.data[:3])
        self.kmeans.add_centroids([[4.7, 3.0]])
        # The groups have different lengths, so they are kept in a plain
        # list: wrapping them in np.array would build a ragged array, which
        # recent NumPy versions reject.
        desired_arrays = [
            np.array([[5.100, 3.500]]),
            np.array([[4.900, 3.000], [4.700, 3.200]]),
        ]
        for i, arr in enumerate(self.kmeans.centroids_belonging_points):
            np.testing.assert_equal(arr, desired_arrays[i])

    def test_step_completed(self):
        """``step_completed`` toggles with every call of ``step``."""
        centroids = [[5.2, 3.1], [6.5, 3], [7, 4]]
        self.kmeans.add_centroids(centroids)
        self.assertEqual(self.kmeans.step_completed, True)
        self.kmeans.step()
        self.assertEqual(self.kmeans.step_completed, False)
        self.kmeans.step()
        self.assertEqual(self.kmeans.step_completed, True)

    def test_set_data(self):
        """``set_data`` stores the data and resets the stepping state."""
        self.kmeans.set_data(self.data)
        self.assertEqual(self.kmeans.data, self.data)
        self.assertEqual(self.kmeans.centroids_history, [])
        self.assertEqual(self.kmeans.step_no, 0)
        self.assertEqual(self.kmeans.step_completed, True)
        self.assertEqual(self.kmeans.centroids_moved, False)

        # try with none data
        self.kmeans.set_data(None)
        self.assertEqual(self.kmeans.data, None)
        self.assertEqual(self.kmeans.centroids_history, [])
        self.assertEqual(self.kmeans.clusters, None)
        self.assertEqual(self.kmeans.step_no, 0)
        self.assertEqual(self.kmeans.step_completed, True)
        self.assertEqual(self.kmeans.centroids_moved, False)

    def test_find_clusters(self):
        """``find_clusters`` returns one cluster index per data row."""
        self.kmeans.add_centroids([[5.2, 3.6]])

        # if only one cluster all points in 0th element of first dimension
        np.testing.assert_equal(
            self.kmeans.find_clusters(self.kmeans.centroids),
            np.zeros(len(self.data)))

        # try with two clusters and less data
        self.kmeans.set_data(self.data[:3])
        self.kmeans.add_centroids([[4.7, 3.0]])
        np.testing.assert_equal(
            self.kmeans.find_clusters(self.kmeans.centroids),
            np.array([0, 1, 1]))

    def test_step(self):
        """Odd steps move centroids, even steps keep them in place."""
        centroids = [[5.2, 3.1], [6.5, 3], [7, 4]]
        self.kmeans.add_centroids(centroids)
        centroids_before = np.copy(self.kmeans.centroids)
        clusters_before = np.copy(self.kmeans.clusters)
        self.kmeans.step()
        self.assertEqual(self.kmeans.step_completed, False)
        self.assertEqual(self.kmeans.centroids_moved, True)
        np.testing.assert_equal(centroids_before,
                                self.kmeans.centroids_history[-2])
        # clusters do not change in odd steps
        np.testing.assert_equal(clusters_before, self.kmeans.clusters)

        centroids_before = np.copy(self.kmeans.centroids)
        self.kmeans.step()
        self.assertEqual(self.kmeans.step_completed, True)
        self.assertEqual(self.kmeans.centroids_moved, False)
        np.testing.assert_equal(centroids_before, self.kmeans.centroids)

        centroids_before = np.copy(self.kmeans.centroids)
        self.kmeans.step()
        self.kmeans.step_back()
        np.testing.assert_equal(centroids_before, self.kmeans.centroids)

    def test_step_back(self):
        """``step_back`` undoes steps and is a no-op at step zero."""
        centroids = [[5.2, 3.1], [6.5, 3], [7, 4]]
        self.kmeans.add_centroids(centroids)

        # check if nothing happens when step = 0
        centroids_before = np.copy(self.kmeans.centroids)
        clusters_before = np.copy(self.kmeans.clusters)
        self.kmeans.step_back()
        np.testing.assert_equal(centroids_before, self.kmeans.centroids)
        np.testing.assert_equal(clusters_before, self.kmeans.clusters)

        # check if centroids remain in even step
        self.kmeans.step()
        self.kmeans.step()

        # Snapshots are copies (as in test_step) so that an in-place
        # modification could not make the comparison trivially true.
        centroids_before = np.copy(self.kmeans.centroids)
        self.kmeans.step_back()
        np.testing.assert_equal(centroids_before, self.kmeans.centroids)
        self.assertEqual(self.kmeans.step_completed, False)
        self.assertEqual(self.kmeans.centroids_moved, False)

        # check if clusters remain in even step
        clusters_before = np.copy(self.kmeans.clusters)
        self.kmeans.step_back()
        np.testing.assert_equal(clusters_before, self.kmeans.clusters)
        self.assertEqual(self.kmeans.step_completed, True)
        self.assertEqual(self.kmeans.centroids_moved, True)

    def test_random_positioning(self):
        """``random_positioning`` returns one 2-d position per request."""
        self.assertEqual(self.kmeans.random_positioning(4).shape, (4, 2))
        self.assertEqual(self.kmeans.random_positioning(1).shape, (1, 2))
        self.assertEqual(self.kmeans.random_positioning(0).shape, (0,))
        self.assertEqual(self.kmeans.random_positioning(-1).shape, (0,))

    def test_add_centroids(self):
        """Centroids can be added as positions, as a count, or one by default."""
        self.kmeans.add_centroids([[5.2, 3.1]])
        self.assertEqual(self.kmeans.k, 1)
        self.kmeans.add_centroids([[6.5, 3], [7, 4]])
        self.assertEqual(self.kmeans.k, 3)
        self.kmeans.add_centroids(2)
        self.assertEqual(self.kmeans.k, 5)
        self.kmeans.add_centroids()
        self.assertEqual(self.kmeans.k, 6)

        step_before = self.kmeans.step_no
        self.assertEqual(step_before, self.kmeans.step_no)
        self.kmeans.step()
        self.kmeans.add_centroids()
        # adding a centroid in the middle of a step completes that step
        self.assertEqual(step_before + 2, self.kmeans.step_no)
        self.assertEqual(self.kmeans.centroids_moved, False)

    def test_delete_centroids(self):
        """Deleting never drives the number of centroids below zero."""
        self.kmeans.add_centroids([[6.5, 3], [7, 4], [5.2, 3.1]])
        self.kmeans.delete_centroids(1)
        self.assertEqual(self.kmeans.k, 2)
        self.kmeans.delete_centroids(2)
        self.assertEqual(self.kmeans.k, 0)
        self.kmeans.delete_centroids(2)
        self.assertEqual(self.kmeans.k, 0)

    def test_move_centroid(self):
        """Moving a centroid changes its position but not ``k``."""
        self.kmeans.add_centroids([[6.5, 3], [7, 4], [5.2, 3.1]])
        self.kmeans.move_centroid(1, 3, 3.2)
        np.testing.assert_equal(self.kmeans.centroids[1], np.array([3, 3.2]))
        self.assertEqual(self.kmeans.k, 3)

        self.kmeans.step()
        self.kmeans.move_centroid(2, 3.2, 3.4)
        # moving a centroid in the middle of a step completes that step
        self.assertEqual(self.kmeans.centroids_moved, False)
        self.assertEqual(self.kmeans.step_no, 2)

    def test_set_list(self):
        """``set_list`` sets an element, padding the list with ``None``."""
        # empty list is padded with Nones up to the index
        self.assertEqual(self.kmeans.set_list([], 2, 1), [None, None, 1])
        # too short list is padded with Nones up to the index
        self.assertEqual(self.kmeans.set_list([2], 2, 1), [2, None, 1])
        # adding to end
        self.assertEqual(self.kmeans.set_list([2, 1], 2, 1), [2, 1, 1])
        # changing the element in the last place
        self.assertEqual(self.kmeans.set_list([2, 1], 1, 3), [2, 3])
        # changing the element in the middle place
        self.assertEqual(self.kmeans.set_list([2, 1, 3], 1, 3), [2, 3, 3])
class OWKmeans(OWWidget):
    """
    Interactive k-means widget.

    The widget plots two numeric columns of the input data together with
    draggable centroids and lets the user step through the algorithm
    manually or run it automatically. It outputs the input data annotated
    with a cluster column and a table with centroid positions.
    """

    name = "Interactive k-Means"
    description = "Widget demonstrates working of k-means algorithm."
    icon = "icons/InteractiveKMeans.svg"
    want_main_area = False

    # inputs and outputs
    class Inputs:
        data = Input("Data", Table)

    class Outputs:
        annotated_data = Output("Annotated Data", Table, default=True)
        centroids = Output("Centroids", Table)

    class Warning(OWWidget.Warning):
        num_features = Msg(
            "Widget requires at least two numeric features with valid values")
        cluster_points = Msg(
            "The number of clusters can't exceed the number of points")

    # settings
    number_of_clusters = settings.Setting(3)
    auto_play_enabled = False
    auto_play_thread = None

    # data
    data = None
    selected_rows = None  # rows that are selected for kmeans (not nan rows)

    # selected attributes in chart
    attr_x = settings.Setting('')
    attr_y = settings.Setting('')

    # other settings
    k_means = None
    auto_play_speed = settings.Setting(1)
    lines_to_centroids = settings.Setting(True)
    graph_name = 'scatter'
    output_name = "cluster"
    # indexed with ``k_means.step_completed`` (False -> 0, True -> 1)
    STEP_BUTTONS = ["Reassign Membership", "Recompute Centroids"]
    # indexed with ``auto_play_enabled`` (False -> 0, True -> 1)
    AUTOPLAY_BUTTONS = ["Run", "Stop"]

    # colors taken from chart.options.colors in Highchart
    # (if more required check for more in chart.options.color)
    colors = ["#1F7ECA", "#D32525", "#28D825", "#D5861F", "#98257E",
              "#2227D5", "#D5D623", "#D31BD6", "#6A7CDB", "#78D5D4"]

    # signals
    step_trigger = pyqtSignal()
    stop_auto_play_trigger = pyqtSignal()

    def __init__(self):
        """Build the control area, the chart and connect auto-play signals."""
        super().__init__()

        # options box
        self.options_box = gui.widgetBox(self.controlArea, "Data")
        opts = dict(
            widget=self.options_box, master=self, orientation=Qt.Horizontal,
            callback=self.restart, sendSelectedValue=True,
            maximumContentsLength=15)

        self.cbx = gui.comboBox(value='attr_x', label='X: ', **opts)
        self.cby = gui.comboBox(value='attr_y', label='Y: ', **opts)

        self.centroids_box = gui.widgetBox(self.controlArea, "Centroids")
        self.centroid_numbers_spinner = gui.spin(
            self.centroids_box, self, 'number_of_clusters',
            minv=1, maxv=10, step=1, label='Number of centroids:',
            alignment=Qt.AlignRight, callback=self.number_of_clusters_change)
        self.restart_button = gui.button(
            self.centroids_box, self, "Randomize Positions",
            callback=self.restart)
        gui.separator(self.centroids_box)
        self.lines_checkbox = gui.checkBox(
            self.centroids_box, self, 'lines_to_centroids',
            'Show membership lines', callback=self.complete_replot)

        # control box
        gui.separator(self.controlArea, 20, 20)
        self.step_box = gui.widgetBox(
            self.controlArea, "Manually step through")
        self.step_button = gui.button(
            self.step_box, self, self.STEP_BUTTONS[1], callback=self.step)
        self.step_back_button = gui.button(
            self.step_box, self, "Step Back", callback=self.step_back)

        self.run_box = gui.widgetBox(self.controlArea, "Run")

        self.auto_play_speed_spinner = gui.hSlider(
            self.run_box, self, 'auto_play_speed', label='Speed:',
            minValue=0, maxValue=1.91, step=0.1, intOnly=False,
            createLabel=False)
        self.auto_play_button = gui.button(
            self.run_box, self, self.AUTOPLAY_BUTTONS[0],
            callback=self.auto_play)

        gui.rubber(self.controlArea)

        # disable until data loaded
        self.set_disabled_all(True)

        # graph in mainArea
        self.scatter = Scatterplot(
            click_callback=self.graph_clicked,
            drop_callback=self.centroid_dropped,
            xAxis_gridLineWidth=0, yAxis_gridLineWidth=0,
            tooltip_enabled=False,
            debug=False)

        # Just render an empty chart so it shows a nice 'No data to display'
        self.scatter.chart()
        self.mainArea.layout().addWidget(self.scatter)

        # Connect the auto-play signals exactly once. Connecting them on
        # every press of the Run button would add a further connection each
        # time, so a single trigger would call the slot several times.
        self.step_trigger.connect(self.step)
        self.stop_auto_play_trigger.connect(self.stop_auto_play)

    @staticmethod
    def _membership_lines(centroid, points):
        """
        Build the data of a line series joining points with their centroid.

        Returns
        -------
        list
            ``[x, y]`` pairs alternating between a point and the centroid
        """
        return list(chain.from_iterable(
            ([p[0], p[1]], [centroid[0], centroid[1]]) for p in points))

    def concat_x_y(self):
        """
        Function takes two selected columns from data table and merge them in
        new Orange.data.Table. Rows with a missing value in either column are
        dropped and the indices of the kept rows are stored in
        ``selected_rows``.

        Returns
        -------
        Orange.data.Table
            table with selected columns
        """
        attr_x = self.data.domain[self.attr_x]
        attr_y = self.data.domain[self.attr_y]
        cols = []
        for attr in (attr_x, attr_y):
            subset = self.data[:, attr]
            # the values are taken from Y when it is not empty (presumably
            # when the variable is a class variable), otherwise from X
            cols.append(subset.Y if subset.Y.size else subset.X)
        x = np.column_stack(cols)
        not_nan = ~np.isnan(x).any(axis=1)
        x = x[not_nan]  # remove rows with nan
        self.selected_rows = np.where(not_nan)
        domain = Domain([attr_x, attr_y])
        return Table(domain, x)

    def set_empty_plot(self):
        """Clear the chart."""
        self.scatter.clear()

    def set_disabled_all(self, disabled):
        """
        Disable or enable all control boxes.

        Parameters
        ----------
        disabled : bool
            True to disable the controls, False to enable them
        """
        self.options_box.setDisabled(disabled)
        self.centroids_box.setDisabled(disabled)
        self.step_box.setDisabled(disabled)
        self.run_box.setDisabled(disabled)

    @Inputs.data
    def set_data(self, data):
        """
        Function receives data from input and init part of widget if data are
        ok. Otherwise it empties the plot, disables the controls, clears the
        outputs and, when features are insufficient, warns the user.

        Parameters
        ----------
        data : Orange.data.Table or None
            input data
        """
        self.data = data

        def get_valid_attributes(data):
            attrs = [var for var in data.domain.attributes
                     if var.is_continuous]
            return [var for var in attrs if sum(~np.isnan(data[:, var])) > 0]

        def reset_combos():
            self.cbx.clear()
            self.cby.clear()

        def init_combos():
            """
            function initialize the combos with attributes
            """
            reset_combos()
            valid_class_vars = [var for var in data.domain.class_vars
                                if var.is_continuous]
            for var in chain(valid_attributes, valid_class_vars):
                self.cbx.addItem(gui.attributeIconDict[var], var.name)
                self.cby.addItem(gui.attributeIconDict[var], var.name)

        def reset_widget():
            """
            function puts the widget in the no-data state; the model is
            dropped so that stale clusters are not kept on the outputs
            """
            reset_combos()
            self.set_empty_plot()
            self.set_disabled_all(True)
            self.k_means = None
            self.send_data()

        # remove warnings about too less continuous attributes and not
        # enough data
        self.Warning.clear()

        if self.auto_play_thread:
            self.auto_play_thread.stop()

        if data is None or len(data) == 0:
            reset_widget()
            return

        valid_attributes = get_valid_attributes(data)

        if len(valid_attributes) < 2:
            self.Warning.num_features()
            reset_widget()
        else:
            init_combos()
            self.set_disabled_all(False)
            self.attr_x = self.cbx.itemText(0)
            self.attr_y = self.cby.itemText(1)
            if self.k_means is None:
                self.k_means = Kmeans(self.concat_x_y())
            else:
                self.k_means.set_data(self.concat_x_y())
            self.number_of_clusters_change()

    def restart(self):
        """
        Function triggered on data change or restart button pressed
        """
        self.k_means = Kmeans(self.concat_x_y())
        self.number_of_clusters_change()

    def step(self):
        """
        Function called on every step
        """
        self.k_means.step()
        self.replot()
        self.button_text_change()
        self.send_data()

    def step_back(self):
        """
        Function called for step back
        """
        self.k_means.step_back()
        self.replot()
        self.button_text_change()
        self.send_data()
        self.number_of_clusters = self.k_means.k

    def button_text_change(self):
        """
        Function changes the text on the step button and disables the step
        back button when there is no step to undo
        """
        if self.k_means is None:
            return
        self.step_button.setText(
            self.STEP_BUTTONS[self.k_means.step_completed])
        if self.k_means.step_no <= 0:
            self.step_back_button.setDisabled(True)
        elif not self.auto_play_enabled:
            self.step_back_button.setDisabled(False)

    def auto_play(self):
        """
        Function called when autoplay button pressed
        """
        self.auto_play_enabled = not self.auto_play_enabled
        self.auto_play_button.setText(
            self.AUTOPLAY_BUTTONS[self.auto_play_enabled])
        if self.auto_play_enabled:
            self.options_box.setDisabled(True)
            self.centroids_box.setDisabled(True)
            self.step_box.setDisabled(True)
            # the trigger signals are already connected in __init__
            self.auto_play_thread = Autoplay(self)
            self.auto_play_thread.start()
        else:
            self.stop_auto_play()

    def stop_auto_play(self):
        """
        Called when stop autoplay button pressed or in the end of autoplay
        """
        # The stop trigger can arrive after the input data was removed; the
        # controls must stay disabled in that case.
        if self.data is not None and self.k_means is not None:
            self.options_box.setDisabled(False)
            self.centroids_box.setDisabled(False)
            self.step_box.setDisabled(False)
        self.auto_play_enabled = False
        self.auto_play_button.setText(
            self.AUTOPLAY_BUTTONS[self.auto_play_enabled])
        self.button_text_change()

    def replot(self):
        """
        Function refreshes the chart
        """
        km = self.k_means
        if (self.data is None or km is None
                or not self.attr_x or not self.attr_y):
            return

        if not km.centroids_moved:
            self.complete_replot()
            return

        # when centroids moved during step
        self.scatter.update_series(0, km.centroids)

        if self.lines_to_centroids:
            # line series follow the centroid series, hence the offset of 1
            for i, (c, pts) in enumerate(zip(
                    km.centroids, km.centroids_belonging_points)):
                self.scatter.update_series(
                    1 + i, self._membership_lines(c, pts))

    def complete_replot(self):
        """
        This function performs complete replot of the graph without animation
        """
        km = self.k_means
        if self.data is None or km is None:
            return
        try:
            attr_x = self.data.domain[self.attr_x]
            attr_y = self.data.domain[self.attr_y]
        except KeyError:
            return

        # plot centroids
        options = dict(series=[])
        n_colors = len(self.colors)
        options['series'].append(
            dict(
                data=[{'x': p[0], 'y': p[1],
                       'marker': {'fillColor': self.colors[i % n_colors]}}
                      for i, p in enumerate(km.centroids)],
                type="scatter",
                draggableX=True,
                draggableY=True,
                cursor="move",
                zIndex=10,
                marker=dict(symbol='square', radius=8)))

        # plot lines between centroids and points
        if self.lines_to_centroids:
            for c, pts in zip(km.centroids, km.centroids_belonging_points):
                options['series'].append(dict(
                    data=self._membership_lines(c, pts),
                    type="line",
                    lineWidth=0.2,
                    enableMouseTracking=False,
                    color="#ccc"))

        # plot data points
        for i, points in enumerate(km.centroids_belonging_points):
            options['series'].append(dict(
                data=points,
                type="scatter",
                color=rgb_hash_brighter(self.colors[i % n_colors], 0.3)))

        # highcharts parameters
        kwargs = dict(
            xAxis_title_text=attr_x.name,
            yAxis_title_text=attr_y.name,
            tooltip_headerFormat="",
            tooltip_pointFormat="<strong>%s:</strong> {point.x:.2f} <br/>"
                                "<strong>%s:</strong> {point.y:.2f}" %
                                (self.attr_x, self.attr_y))

        # plot
        self.scatter.chart(options, **kwargs)

    def replot_series(self):
        """
        This function replot just series connected with centroids and
        uses animation for that
        """
        km = self.k_means
        k = km.k
        n_colors = len(self.colors)

        series = []
        # plot lines between centroids and points
        if self.lines_to_centroids:
            for c, pts in zip(km.centroids, km.centroids_belonging_points):
                series.append(dict(
                    data=self._membership_lines(c, pts),
                    type="line",
                    showInLegend=False,
                    lineWidth=0.2,
                    enableMouseTracking=False,
                    color="#ccc"))

        # plot data points
        for i, points in enumerate(km.centroids_belonging_points):
            series.append(dict(
                data=points,
                type="scatter",
                showInLegend=False,
                color=rgb_hash_brighter(self.colors[i % n_colors], 0.5)))

        self.scatter.add_series(series)

        self.scatter.remove_last_series(
            k * 2 if self.lines_to_centroids else k)

    def number_of_clusters_change(self):
        """
        Function that change number of clusters if required
        """
        if self.data is None:
            return
        if self.number_of_clusters > len(self.data):
            # if too less data for clusters number
            self.Warning.cluster_points()
            self.set_empty_plot()
            self.step_box.setDisabled(True)
            self.run_box.setDisabled(True)
        else:
            self.Warning.cluster_points.clear()
            self.step_box.setDisabled(False)
            self.run_box.setDisabled(False)
            if self.k_means is None:  # if before too less data k_means is None
                self.k_means = Kmeans(self.concat_x_y())
            if self.k_means.k < self.number_of_clusters:
                self.k_means.add_centroids(
                    self.number_of_clusters - self.k_means.k)
            elif self.k_means.k > self.number_of_clusters:
                self.k_means.delete_centroids(
                    self.k_means.k - self.number_of_clusters)
            self.replot()
            self.send_data()
        self.button_text_change()

    def graph_clicked(self, x, y):
        """
        Function called when user click in graph. Centroid have to be added.
        """
        if self.k_means is not None and self.data is not None:
            self.k_means.add_centroids([x, y])
            self.number_of_clusters += 1
            self.replot()
            self.send_data()
            self.button_text_change()

    def centroid_dropped(self, _index, x, y):
        """
        Function called when centroid with _index moved.
        """
        if self.k_means is None or self.data is None:
            return
        self.k_means.move_centroid(_index, x, y)
        self.complete_replot()
        self.send_data()
        self.button_text_change()

    def send_data(self):
        """
        Function sends data with clusters column and data with centroids
        position to the output. None is sent on both outputs when there is
        no model or the model has no clusters.
        """
        km = self.k_means
        if km is None or km.clusters is None:
            self.Outputs.annotated_data.send(None)
            self.Outputs.centroids.send(None)
        else:
            clust_var = DiscreteVariable(
                self.output_name,
                values=["C%d" % (x + 1) for x in range(km.k)])
            attributes = self.data.domain.attributes
            classes = self.data.domain.class_vars
            meta_attrs = self.data.domain.metas
            if classes:
                # original class variables are kept among the metas
                meta_attrs += classes
            classes = [clust_var]
            domain = Domain(attributes, classes, meta_attrs)
            annotated_data = Table.from_table(domain, self.data)
            # only the rows used for clustering (without nan) get a label
            annotated_data.Y[self.selected_rows] = km.clusters

            centroids = Table(Domain(km.data.domain.attributes), km.centroids)
            self.Outputs.annotated_data.send(annotated_data)
            self.Outputs.centroids.send(centroids)

    def send_report(self):
        """Add the plot and the number of centroids to the report."""
        if self.data is None:
            return
        caption = report.render_items_vert((
            ("Number of centroids:", self.number_of_clusters),
        ))
        self.report_plot(self.scatter)
        self.report_caption(caption)
Beispiel #6
0
class OWKmeans(OWWidget):
    """
    Interactive k-means widget.

    Shows two selected continuous attributes in a scatter plot and lets the
    user walk through the k-means algorithm step by step (or automatically),
    add centroids by clicking and move them by dragging. The widget outputs
    the data annotated with cluster labels and a table of centroids.
    """

    # widget metadata
    name = "Interactive k-Means"
    description = "Widget demonstrates working of k-means algorithm."
    icon = "icons/mywidget.svg"
    # NOTE(review): __init__ adds the plot to self.mainArea although the main
    # area is switched off here - confirm this is intended
    want_main_area = False

    # inputs and outputs
    inputs = [("Data", Orange.data.Table, "set_data")]
    outputs = [("Annotated Data", Table, widget.Default),
               ("Centroids", Table)]

    # settings
    numberOfClusters = settings.Setting(1)
    # True while autoplay is running (toggled in auto_play / stop_auto_play)
    autoPlay = False

    # data received on the input
    data = None

    # names of the attributes selected for the chart axes
    attr_x = settings.Setting('')
    attr_y = settings.Setting('')

    # other settings
    # k-means model; created once suitable data arrives
    k_means = None
    autoPlaySpeed = settings.Setting(1)
    # whether lines between points and their centroids are drawn
    lines_to_centroids = settings.Setting(0)
    graph_name = 'scatter'
    # name of the cluster variable added to the annotated output
    outputName = "cluster"
    # captions of the buttons, looked up by key
    button_labels = {"step1": "Reassign membership",
                     "step2": "Recompute centroids",
                     "step_back": "Step back",
                     "autoplay_run": "Run",
                     "autoplay_stop": "Stop",
                     "random_centroids": "Randomize"}
    # colors used for centroid markers (a brightened variant is used for the
    # points); reused cyclically when there are more clusters than colors
    colors = ['#2f7ed8', '#0d233a', '#8bbc21', '#910000', '#1aadce',
              '#492970', '#f28f43', '#77a1e5', '#c42525', '#a6c96a']

    def __init__(self):
        """
        Build the widget GUI: axis attribute selectors, centroid controls,
        step / autoplay controls and the scatter plot.

        All control boxes start disabled; ``set_data`` enables them once
        suitable data arrives.
        """
        super().__init__()

        def expanding_policy():
            # controls grow horizontally but keep a fixed height
            return QSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed)

        # options box: attributes shown on the x and y axes
        self.optionsBox = gui.widgetBox(self.controlArea)
        self.cbx = gui.comboBox(self.optionsBox, self, 'attr_x',
                                label='X:',
                                orientation=Qt.Horizontal,
                                callback=self.restart,
                                sendSelectedValue=True)
        self.cbx.setSizePolicy(expanding_policy())
        # same orientation constant as for the X combo (was the string
        # 'horizontal'), so both selectors are configured alike
        self.cby = gui.comboBox(self.optionsBox, self, 'attr_y',
                                label='Y:',
                                orientation=Qt.Horizontal,
                                callback=self.restart,
                                sendSelectedValue=True)
        self.cby.setSizePolicy(expanding_policy())

        # centroids box: number of centroids, randomize, membership lines
        self.centroidsBox = gui.widgetBox(self.controlArea, "Centroids")
        self.centroidNumbersSpinner = gui.spin(self.centroidsBox,
                                               self,
                                               'numberOfClusters',
                                               minv=1,
                                               maxv=10,
                                               step=1,
                                               label='Number of centroids:',
                                               callback=self.number_of_clusters_change)
        self.centroidNumbersSpinner.setSizePolicy(expanding_policy())
        self.restartButton = gui.button(self.centroidsBox, self, self.button_labels["random_centroids"],
                                        callback=self.restart)
        self.linesCheckbox = gui.checkBox(self.centroidsBox,
                                          self,
                                          'lines_to_centroids',
                                          'Show membership lines',
                                          callback=self.complete_replot)

        # control box: step, step back, autoplay and its speed
        self.commandsBox = gui.widgetBox(self.controlArea)
        self.stepButton = gui.button(self.commandsBox, self, self.button_labels["step2"],
                                     callback=self.step)
        self.stepBackButton = gui.button(self.commandsBox, self, self.button_labels["step_back"],
                                         callback=self.step_back)
        self.autoPlayButton = gui.button(self.commandsBox, self, self.button_labels["autoplay_run"],
                                         callback=self.auto_play)
        self.autoPlaySpeedSpinner = gui.hSlider(self.commandsBox,
                                                self,
                                                'autoPlaySpeed',
                                                minValue=0,
                                                maxValue=1.91,
                                                step=0.1,
                                                intOnly=False,
                                                createLabel=False,
                                                label='Speed:')

        gui.rubber(self.controlArea)

        # disable until data loaded
        self.set_disabled_all(True)

        # graph in mainArea; clicks add centroids, drops move them
        self.scatter = Scatterplot(click_callback=self.graph_clicked,
                                   drop_callback=self.centroid_dropped,
                                   xAxis_gridLineWidth=0,
                                   yAxis_gridLineWidth=0,
                                   title_text='',
                                   tooltip_shared=False,
                                   debug=True)  # TODO: set false when end of development
        # Just render an empty chart so it shows a nice 'No data to display'
        self.scatter.chart()
        self.mainArea.layout().addWidget(self.scatter)

    def concat_x_y(self):
        """
        Build a two-column table from the attributes selected for the x and y axes.

        :return: table whose domain holds only the two selected attributes
        :type: Orange.data.Table
        """
        selected = [self.data.domain[name]
                    for name in (self.attr_x, self.attr_y)]

        def column_of(var):
            # take the values from Y when the sliced table has any, else from X
            part = self.data[:, var]
            return part.Y if part.Y.size else part.X

        values = np.column_stack([column_of(var) for var in selected])
        return Table(Domain(selected), values)

    def set_empty_plot(self):
        """
        Clear the scatter plot.
        """
        plot = self.scatter
        plot.clear()

    def set_disabled_all(self, disabled):
        """
        Enable or disable all three control boxes at once.

        :param disabled: True to disable the controls, False to enable them
        """
        for box in (self.optionsBox, self.centroidsBox, self.commandsBox):
            box.setDisabled(disabled)

    def set_data(self, data):
        """
        Input handler: store the new data and (re)initialise the widget.

        With no data, an empty table or fewer than two continuous attributes the plot is emptied and all
        controls are disabled. Otherwise the axis combos are filled, the first two listed attributes are
        selected and the k-means model is created or fed with the new data.

        :param data: input data
        :type data: Orange.data.Table or None
        """
        self.data = data

        def clear_combos():
            for combo in (self.cbx, self.cby):
                combo.clear()

        def fill_combos():
            """
            list every primitive continuous variable of the domain in both combos
            """
            clear_combos()
            if data is None:
                return
            for var in data.domain:
                if not (var.is_primitive() and var.is_continuous):
                    continue
                for combo in (self.cbx, self.cby):
                    combo.addItem(gui.attributeIconDict[var], var.name)

        # drop the warnings (too few continuous attributes / not enough data) left from the previous input
        for warning_id in (1, 2):
            self.warning(warning_id)

        if data is None or len(data) == 0:
            clear_combos()
            self.set_empty_plot()
            self.set_disabled_all(True)
            return

        continuous_count = sum(
            1 for var in data.domain.attributes if isinstance(var, ContinuousVariable))
        if continuous_count < 2:
            clear_combos()
            self.warning(1, "Too few Continuous feature. Min 2 required")
            self.set_empty_plot()
            self.set_disabled_all(True)
            return

        fill_combos()
        self.set_disabled_all(False)
        # both combos list the same items, so the first two entries of cbx serve as the defaults
        self.attr_x = self.cbx.itemText(0)
        self.attr_y = self.cbx.itemText(1)
        xy_table = self.concat_x_y()
        if self.k_means is None:
            self.k_means = Kmeans(xy_table)
        else:
            self.k_means.set_data(xy_table)
        self.modify_kmeans()


    def restart(self):
        """
        Start over with a fresh k-means model built from the currently selected attributes.

        Used as the callback of both axis combos and of the randomize button.
        """
        fresh_model = Kmeans(self.concat_x_y())
        self.k_means = fresh_model
        self.modify_kmeans()

    def modify_kmeans(self):
        """
        Bring the widget in line with the current k-means model: adjust the number of clusters first,
        then refresh the button labels.
        """
        for sync in (self.number_of_clusters_change, self.button_text_change):
            sync()

    def step(self):
        """
        Advance the k-means model by one step, then refresh the plot, the button labels and the outputs.
        """
        self.k_means.step()
        for refresh in (self.replot, self.button_text_change, self.send_data):
            refresh()

    def step_back(self):
        """
        Undo one step of the k-means model, then refresh the plot, the button labels and the outputs.
        """
        self.k_means.step_back()
        for refresh in (self.replot, self.button_text_change, self.send_data):
            refresh()

    def button_text_change(self):
        """
        Update the text of the step button and the enabled state of the step-back button.

        The step button shows the "step2" label when the model reports a completed step and the "step1"
        label otherwise. Step back is disabled while no step has been made; it is re-enabled only when
        autoplay is not running.
        """
        model = self.k_means
        if model.step_completed:
            label_key = "step2"
        else:
            label_key = "step1"
        self.stepButton.setText(self.button_labels[label_key])

        if model.stepNo <= 0:
            # nothing to go back to
            self.stepBackButton.setDisabled(True)
            return
        if not self.autoPlay:
            self.stepBackButton.setDisabled(False)

    def auto_play(self):
        """
        Toggle autoplay; callback of the autoplay button.

        When switching on, the other controls are disabled and an ``Autoplay`` thread object is started
        whose "step()" and "stop_auto_play()" signals are connected to the widget. When switching off,
        ``stop_auto_play`` restores the controls.
        """
        self.autoPlay = not self.autoPlay
        label_key = "autoplay_stop" if self.autoPlay else "autoplay_run"
        self.autoPlayButton.setText(self.button_labels[label_key])

        if not self.autoPlay:
            self.stop_auto_play()
            return

        # lock the manual controls while autoplay is running
        for control in (self.optionsBox, self.centroidsBox,
                        self.stepButton, self.stepBackButton):
            control.setDisabled(True)
        self.autoPlayThread = Autoplay(self)
        self.connect(self.autoPlayThread, SIGNAL("step()"), self.step)
        self.connect(self.autoPlayThread, SIGNAL("stop_auto_play()"), self.stop_auto_play)
        self.autoPlayThread.start()

    def stop_auto_play(self):
        """
        Leave autoplay mode: re-enable the controls and reset the autoplay button.

        Called when the stop button is pressed or at the end of autoplay.
        """
        self.optionsBox.setDisabled(False)
        self.stepButton.setDisabled(False)
        self.centroidsBox.setDisabled(False)
        self.stepBackButton.setDisabled(False)
        self.autoPlay = False
        # autoplay is off from here on, so the button always offers "Run";
        # no conditional on self.autoPlay is needed
        self.autoPlayButton.setText(self.button_labels["autoplay_run"])

    def replot(self):
        """
        Refresh the chart.

        When the model reports that the centroids moved, only the centroid series (and, if shown, the
        membership line series) are updated in place; otherwise the chart is redrawn completely.
        Nothing happens without data or without both axis attributes selected.
        """
        if self.data is None or not (self.attr_x and self.attr_y):
            return

        model = self.k_means
        if not model.centroids_moved:
            self.complete_replot()
            return

        # series 0 holds the centroids
        self.scatter.update_series(0, model.centroids)
        if not self.lines_to_centroids:
            return

        # series 1..k hold the lines; each point is followed by its centroid
        for idx, centroid in enumerate(model.centroids):
            segments = []
            for point in model.centroids_belonging_points[idx]:
                segments.append([point[0], point[1]])
                segments.append([centroid[0], centroid[1]])
            self.scatter.update_series(1 + idx, segments)

    def complete_replot(self):
        """
        Redraw the whole chart from scratch, without animation.

        The chart gets one series with the centroids, optionally one line series per centroid joining it
        with its points, and one scatter series per cluster with the data points.
        """
        model = self.k_means
        palette = self.colors
        domain = self.data.domain
        attr_x, attr_y = domain[self.attr_x], domain[self.attr_y]

        # centroids: draggable only when the current step is completed
        draggable = bool(model.step_completed)
        centroid_markers = [
            {'x': centre[0],
             'y': centre[1],
             'marker': {'fillColor': palette[idx % len(palette)]}}
            for idx, centre in enumerate(model.centroids)]
        series = [dict(data=centroid_markers,
                       type="scatter",
                       draggableX=draggable,
                       draggableY=draggable,
                       showInLegend=False,
                       zIndex=10,
                       marker=dict(symbol='diamond',
                                   radius=10))]

        # lines between centroids and points: each point is followed by its centroid
        if self.lines_to_centroids:
            for idx, centre in enumerate(model.centroids):
                segments = []
                for point in model.centroids_belonging_points[idx]:
                    segments.append([point[0], point[1]])
                    segments.append([centre[0], centre[1]])
                series.append(dict(data=segments,
                                   type="line",
                                   showInLegend=False,
                                   lineWidth=0.2,
                                   enableMouseTracking=False,
                                   color="#ccc"))

        # data points, one series per cluster in a brightened cluster color
        for idx, points in enumerate(model.centroids_belonging_points):
            series.append(dict(data=points,
                               type="scatter",
                               showInLegend=False,
                               color=rgb_hash_brighter(palette[idx % len(palette)], 30)))

        # highcharts parameters
        tooltip = ("<strong>%s:</strong> {point.x:.2f} <br/>"
                   "<strong>%s:</strong> {point.y:.2f}" %
                   (self.attr_x, self.attr_y))

        # plot
        self.scatter.chart(dict(series=series),
                           xAxis_title_text=attr_x.name,
                           yAxis_title_text=attr_y.name,
                           tooltip_headerFormat="",
                           tooltip_pointFormat=tooltip)

    def replot_series(self):
        """
        Replace only the series tied to the centroids (membership lines and data points).

        The freshly built series are added to the chart and then ``remove_last_series`` is called with the
        number of line and point series (``2 * k`` with lines shown, ``k`` otherwise).
        """
        model = self.k_means
        k = model.k
        palette = self.colors

        series = []
        # lines between centroids and points: each point is followed by its centroid
        if self.lines_to_centroids:
            for idx, centre in enumerate(model.centroids):
                segments = []
                for point in model.centroids_belonging_points[idx]:
                    segments.append([point[0], point[1]])
                    segments.append([centre[0], centre[1]])
                series.append(dict(data=segments,
                                   type="line",
                                   showInLegend=False,
                                   lineWidth=0.2,
                                   enableMouseTracking=False,
                                   color="#ccc"))

        # data points, one series per cluster
        for idx, points in enumerate(model.centroids_belonging_points):
            series.append(dict(data=points,
                               type="scatter",
                               showInLegend=False,
                               color=rgb_hash_brighter(palette[idx % len(palette)], 30)))

        self.scatter.add_series(series)

        if self.lines_to_centroids:
            stale_count = k * 2
        else:
            stale_count = k
        self.scatter.remove_last_series(stale_count)

    def number_of_clusters_change(self):
        """
        Bring the number of centroids in the model in line with the requested number of clusters.

        With fewer data points than requested clusters a warning is shown, the plot is emptied and the
        command box is disabled. Otherwise centroids are added or deleted as needed, the chart is
        refreshed and the outputs are re-sent.
        """
        if self.numberOfClusters > len(self.data):
            # not enough data points for the requested number of clusters
            self.warning(2, "Please provide at least number of points equal to "
                            "number of clusters selected or decrease number of clusters")
            self.set_empty_plot()
            self.commandsBox.setDisabled(True)
            return

        self.warning(2)
        self.commandsBox.setDisabled(False)
        if self.k_means is None:
            # no model yet, e.g. there were too few points before
            self.k_means = Kmeans(self.concat_x_y())

        missing = self.numberOfClusters - self.k_means.k
        if missing > 0:
            self.k_means.add_centroids(missing)
        elif missing < 0:
            self.k_means.delete_centroids(-missing)
        self.replot()
        self.send_data()

    def graph_clicked(self, x, y):
        """
        Function called when user click in graph. A centroid is added at the clicked position.

        Clicks are ignored while there is no k-means model or no input data: ``send_data`` needs the
        input data to build the outputs, so adding a centroid without it would fail.

        :param x: x coordinate of the click
        :param y: y coordinate of the click
        """
        if self.k_means is None or self.data is None:
            return
        self.k_means.add_centroids([x, y])
        self.numberOfClusters += 1
        self.replot()
        self.send_data()
        self.button_text_change()

    def centroid_dropped(self, _index, x, y):
        """
        Handle a centroid that was dragged to a new position in the plot.

        The centroid is moved in the k-means model; afterwards the chart is redrawn completely, the
        outputs are re-sent and the button labels are refreshed.

        :param _index: index of the moved centroid
        :param x: new x coordinate of the centroid
        :param y: new y coordinate of the centroid
        """
        self.k_means.move_centroid(_index, x, y)
        for refresh in (self.complete_replot,
                        self.send_data,
                        self.button_text_change):
            refresh()

    def send_data(self):
        """
        Send the clustering result to the outputs.

        "Annotated Data" receives the input table with the cluster label as its class variable (the
        original class variables are appended to the metas); "Centroids" receives a table with the
        centroid positions. Both outputs are cleared while there is no model or no clustering.
        """
        model = self.k_means
        if model is None or model.clusters is None:
            for channel in ("Annotated Data", "Centroids"):
                self.send(channel, None)
            return

        cluster_labels = ["C%d" % (idx + 1) for idx in range(model.k)]
        clust_var = DiscreteVariable(self.outputName, values=cluster_labels)

        source_domain = self.data.domain
        original_classes = source_domain.class_vars
        metas = source_domain.metas
        if original_classes:
            # the cluster column takes over as class variable, so the original ones are kept as metas
            metas += original_classes
        domain = Domain(source_domain.attributes, [clust_var], metas)

        annotated_data = Table.from_table(domain, self.data)
        # write the cluster indices into the new class column
        annotated_data.get_column_view(clust_var)[0][:] = model.clusters

        centroids = Table(Domain(model.data.domain.attributes), model.centroids)
        self.send("Annotated Data", annotated_data)
        self.send("Centroids", centroids)