Exemplo n.º 1
0
    def fit(self, series):
        """Fit a hierarchical clustering tree.

        The pairwise distance matrix is computed with ``self.dists_fun`` (called
        with ``self.dists_options``); its strict upper triangle is flattened into
        the condensed vector expected by scipy's ``linkage``. The linkage tree is
        available in self.linkage.

        :param series: Collection of series, wrapped with ``SeriesContainer.wrap``
            and stored in ``self.series``.
        :return: The linkage matrix (same object as ``self.linkage``).
        :raises NumpyException: If Numpy is not installed.
        :raises ScipyException: If scipy is not installed.
        """
        if np is None:
            raise NumpyException(
                "The fit function requires Numpy to be installed.")
        try:
            from scipy.cluster.hierarchy import linkage
        except ImportError as exc:
            raise ScipyException(
                "The LinkageTree class requires the scipy package to be installed."
            ) from exc
        self.series = SeriesContainer.wrap(series)
        nb_series = len(self.series)
        dists = self.dists_fun(self.series, **self.dists_options)
        # scipy's condensed format is the strict upper triangle read row by row,
        # which is exactly the order produced by triu_indices(k=1). Only the
        # upper triangle is read, so the lower triangle may contain anything.
        rows, cols = np.triu_indices(nb_series, k=1)
        dists_cond = np.asarray(dists, dtype=np.double)[rows, cols]

        # For a condensed distance vector, scipy does not use `metric`; the
        # distances above are used as-is.
        self.linkage = linkage(dists_cond,
                               method=self.method,
                               metric='euclidean')
        return self.linkage
Exemplo n.º 2
0
    def fit(self, series):
        """Merge sequences.

        Greedy agglomerative clustering around prototypes. The method repeatedly
        takes the smallest entry ``(i1, i2)`` of the distance matrix and merges
        the cluster of ``i2`` into the cluster of ``i1``, so ``i1`` stays the
        prototype. It then masks ``i2`` so it cannot be selected again.

        Merging stops in any of these cases:

        - the smallest remaining distance exceeds ``self.max_dist``;
        - no finite distance is left;
        - only one prototype remains.

        :param series: Iterator over series.
        :return: Dictionary with as keys the prototype indices and as values all the indices of the series in
            that cluster.
        :raises NumpyException: If Numpy is not installed.
        """
        if np is None:
            raise NumpyException(
                "The fit function requires Numpy to be installed.")
        nb_series = len(series)
        cluster_idx = dict()
        if nb_series == 0:
            return cluster_idx
        # Work on a copy so that forcing only_triu does not permanently modify
        # self.dists_options (which may be a dict owned by the caller).
        dists_options = dict(self.dists_options)
        dists_options['only_triu'] = True
        dists = self.dists_fun(series, **dists_options)
        deleted = set()
        cnt_merge = 0
        logger.debug('Merging patterns')
        if self.show_progress and tqdm:
            pbar = tqdm(total=dists.shape[0])
        else:
            pbar = None
        try:
            # Hierarchical clustering (distance to prototype)
            min_value = np.min(dists)
            # Stop on a non-finite minimum. Merged series are masked with inf
            # below, and with only_triu the unused entries are presumably inf too
            # (assumption about dists_fun; confirm). Merging "at inf" would select
            # such a masked entry (e.g. a diagonal pair i1 == i2), which removes a
            # cluster and, with max_dist=inf, never terminates.
            while np.isfinite(min_value) and min_value <= self.max_dist:
                cnt_merge += 1
                min_idxs = np.argwhere(dists == min_value)
                if self.order_hook:
                    min_idx = self.order_hook(min_idxs)
                else:
                    min_idx = min_idxs[0, :]
                i1, i2 = int(min_idx[0]), int(min_idx[1])
                if self.merge_hook:
                    # Called as (merged index, prototype index, distance); a
                    # truthy return value replaces (i1, i2).
                    result = self.merge_hook(i2, i1, min_value)
                    if result:
                        i1, i2 = result
                logger.debug("Merge {} <- {} ({:.3f})".format(i1, i2, min_value))
                # Move i2 (and everything previously merged into it) into i1's cluster.
                cluster = cluster_idx.setdefault(i1, {i1})
                cluster.update(cluster_idx.pop(i2, {i2}))
                # Mask every upper-triangle entry involving i2 so it is never
                # selected again.
                dists[:i2, i2] = np.inf
                dists[i2, i2 + 1:] = np.inf
                deleted.add(i2)
                if len(deleted) == nb_series - 1:
                    break
                if pbar is not None:
                    pbar.update(1)
                min_value = np.min(dists)
            if pbar is not None:
                pbar.update(dists.shape[0] - cnt_merge)
        finally:
            if pbar is not None:
                pbar.close()

        # Every series that was never merged into another one is a prototype;
        # the ones that never took part in a merge form singleton clusters.
        for i in range(nb_series):
            if i not in deleted and i not in cluster_idx:
                cluster_idx[i] = {i}
        return cluster_idx
Exemplo n.º 3
0
    def plot(self,
             filename=None,
             axes=None,
             ts_height=10,
             bottom_margin=2,
             top_margin=2,
             ts_left_margin=0,
             ts_sample_length=1,
             tr_label_margin=3,
             tr_left_margin=2,
             ts_label_margin=0,
             show_ts_label=None,
             show_tr_label=None,
             cmap='viridis_r',
             ts_color=None):
        """Plot the hierarchy and time series.

        The tree is drawn on ``axes[0]`` and the time series, in the order of the
        tree's leaves, on ``axes[1]``. The vertical offset at which each series
        is drawn is stored in ``self._series_y`` (indexed by series index), and
        the vertical scaling factor in ``self.ts_height_factor``.

        :param filename: If a filename is passed, the image is written to this file.
        :param axes: If a axes array is passed the image is added to this figure.
            Expects axes[0] and axes[1] to be present.
        :param ts_height: Height of a time series
        :param bottom_margin: Margin on bottom
        :param top_margin: Margin on top
        :param ts_left_margin: Margin on left of time series image
        :param ts_sample_length: Space between two points in the time series
        :param tr_label_margin: Margin between tree split and label
        :param tr_left_margin: Left margin for tree
        :param ts_label_margin: Margin between start of series and label
        :param show_ts_label: Show label indices. Boolean, callable or subscriptable object.
            If it is a callable object, the index of the time series will be given and the
            return string will be printed.
        :param show_tr_label: Show tree distances. Boolean, callable or subscriptable object.
            If it is a callable object, the merge distance of the tree node will be given and
            the return string will be printed.
        :param cmap: Matplotlib colormap name (or colormap object)
        :param ts_color: function that takes the index and returns a color
            (compatible with the matplotlib.color color argument)
        :return: Tuple ``(fig, ax)``. ``fig`` is None when ``axes`` was passed;
            both are None when the image was written to ``filename``.
        :raises NumpyException: If Numpy is not installed.
        :raises MatplotlibException: If Matplotlib is not installed.
        :raises AttributeError: If show_ts_label or show_tr_label has an unsupported type.
        """
        if np is None:
            raise NumpyException(
                "The plot function requires Numpy to be installed.")
        try:
            from matplotlib import pyplot as plt
            from matplotlib.lines import Line2D
            import matplotlib.colors as colors
            import matplotlib.cm as cmx
        except ImportError as exc:
            raise MatplotlibException(
                "The plot function requires Matplotlib to be installed.") from exc

        def as_label_fn(option, default_fn, option_name):
            """Normalize a show_*_label option into a callable returning a string."""
            if option is True:
                return default_fn
            if option is False or option is None:
                return lambda _: ""
            if callable(option):
                return option
            if hasattr(option, "__getitem__"):
                return lambda key: option[key]
            raise AttributeError(
                "Unknown type for {}, expecting boolean, subscriptable or callable, "
                "got {}".format(option_name, type(option)))

        show_ts_label = as_label_fn(show_ts_label,
                                    lambda idx: str(int(idx)),
                                    "show_ts_label")
        show_tr_label = as_label_fn(show_tr_label,
                                    lambda dist: "{:.2f}".format(dist),
                                    "show_tr_label")

        nb_series = len(self.series)
        self._series_y = [0] * nb_series

        # Largest finite merge distance: upper bound of the color scale and the
        # reference for drawing infinite merges (at 1.5 times this value).
        max_dist = 0
        for _, _, d, _ in self.linkage:
            if not np.isinf(d):
                max_dist = max(max_dist, d)

        # node -> (nb_leaves, depth, nb_leaves_left, nb_leaves_right, cumdist)
        node_props = dict()

        max_y = self.series.get_max_y()
        self.ts_height_factor = (ts_height / max_y) * 0.9

        def count(node):
            """Fill node_props for the subtree rooted at node.

            :return: (nb_leaves, depth, cumdist), where cumdist is the largest sum
                of merge distances on a path from node down to a leaf.
            """
            if node < nb_series:
                # Leaf
                node_props[int(node)] = (1, 0, 0, 0, 0)
                return 1, 0, 0
            # Inner node
            child_left, child_right, dist, _ = self.get_linkage(int(node))
            child_left, child_right = int(child_left), int(child_right)
            if child_left == child_right:
                raise Exception(
                    "Row in linkage contains same node as left and right child: {}-{}"
                    .format(child_left, child_right))
            if np.isinf(dist):
                dist = 1.5 * max_dist
            left_cnt, left_depth, left_cumdist = count(child_left)
            right_cnt, right_depth, right_cumdist = count(child_right)
            cnt = left_cnt + right_cnt
            depth = max(left_depth, right_depth) + 1
            cumdist = max(left_cumdist, right_cumdist) + dist
            node_props[int(node)] = (cnt, depth, left_cnt, right_cnt, cumdist)
            return cnt, depth, cumdist

        _, _, maxcumdist = count(self.maxnode)

        if axes is None:
            fig, ax = plt.subplots(nrows=1, ncols=2, frameon=False)
        else:
            fig, ax = None, axes
        total_height = bottom_margin + ts_height * nb_series + top_margin

        ax[0].set_axis_off()
        ax[0].set_xlim(left=0, right=tr_left_margin + maxcumdist + 0.05)
        ax[0].set_ylim(bottom=0, top=total_height)

        ax[1].set_axis_off()
        max_length = max(len(s) for s in self.series)
        ax[1].set_xlim(left=0,
                       right=ts_left_margin + ts_sample_length * max_length)
        ax[1].set_ylim(bottom=0, top=total_height)

        if isinstance(cmap, str):
            cmap = plt.get_cmap(cmap)
        line_colors = cmx.ScalarMappable(norm=colors.Normalize(vmin=0,
                                                               vmax=max_dist),
                                         cmap=cmap)

        def draw_branch(px, py, cx, cy, color):
            """Draw the vertical segment at the parent and the horizontal one to the child."""
            ax[0].add_line(
                Line2D((px, px), (py, cy), lw=1, color=color, axes=ax[0]))
            ax[0].add_line(
                Line2D((px, cx), (cy, cy), lw=1, color=color, axes=ax[0]))

        def plot_i(node, cnt_ts, prev_lcnt):
            """Draw the subtree rooted at node.

            :param cnt_ts: Number of series plotted so far (row of the next leaf).
            :param prev_lcnt: Vertical position of node, in units of ts_height.
            :return: Updated number of series plotted.
            """
            if node < nb_series:
                # Leaf: plot the series itself in the next free row.
                idx = int(node)
                row_y = bottom_margin + ts_height * cnt_ts
                self._series_y[idx] = row_y
                serie = self.series[idx]
                ax[1].text(ts_left_margin + ts_label_margin,
                           row_y + ts_height / 2,
                           show_ts_label(idx),
                           ha='left',
                           va='center')
                curcolor = ts_color(idx) if ts_color else None
                ax[1].plot(ts_left_margin +
                           ts_sample_length * np.arange(len(serie)),
                           row_y + self.ts_height_factor * serie,
                           color=curcolor)
                return cnt_ts + 1

            px = tr_left_margin + maxcumdist - node_props[node][4]
            py = prev_lcnt * ts_height
            child_left, child_right, dist, _ = self.get_linkage(node)
            child_left, child_right = int(child_left), int(child_right)
            color = line_colors.to_rgba(dist)
            ax[0].text(px + tr_label_margin,
                       py,
                       show_tr_label(dist),
                       ha='left',
                       va='center',
                       color=color)

            # Left child: sits as many rows below this node as it has leaves on
            # its own right side; nudge by half a row if that coincides with py.
            _, _, _, left_rcnt, left_cumdist = node_props[child_left]
            left_pos = prev_lcnt - left_rcnt
            cy = left_pos * ts_height
            if py == cy:
                cy -= ts_height / 2
            draw_branch(px, py, tr_left_margin + maxcumdist - left_cumdist, cy,
                        color)
            cnt_ts = plot_i(child_left, cnt_ts, left_pos)

            # Right child: mirror image of the left child, above this node.
            _, _, right_lcnt, _, right_cumdist = node_props[child_right]
            right_pos = prev_lcnt + right_lcnt
            cy = right_pos * ts_height
            if py == cy:
                cy += ts_height / 2
            draw_branch(px, py, tr_left_margin + maxcumdist - right_cumdist, cy,
                        color)
            return plot_i(child_right, cnt_ts, right_pos)

        plot_i(self.maxnode, 0, node_props[self.maxnode][2])

        if filename:
            if isinstance(filename, Path):
                filename = str(filename)
            plt.savefig(filename, bbox_inches='tight', pad_inches=0)
            plt.close()
            fig, ax = None, None

        return fig, ax