Example #1
0
    def _plot_perm_imp(perm, test_sample, node_params, **kwargs):
        """
        Plot n-gram permutation importances as a graph.

        Each column of ``test_sample`` is paired positionally with an entry of
        ``perm.feature_importances_`` (``zip`` semantics, so surplus entries on
        either side are dropped). The columns are presumably n-gram tuples of
        event names (TODO confirm against the vectorizer):

        * Columns of length <= 1 become nodes, keyed by their first element.
        * Columns of length 2 become edges ``(col[0], col[1], importance)``.
        * Longer columns are ignored.

        :param perm: fitted object exposing ``feature_importances_`` (assumed
            to be a permutation-importance estimator -- confirm with callers).
        :param test_sample: object whose ``columns`` holds the n-gram keys.
        :param node_params: mapping of node name -> node style passed to
            ``plot.graph``. When None, a fresh mapping is built that labels
            each node 'nice_node' (importance >= 0) or 'bad_node' (otherwise).
            A caller-supplied mapping is passed through unmodified.
        :param kwargs: extra keyword arguments forwarded to ``plot.graph``.
        :return: Nothing. If there are no two-element columns, a hint is
            printed and nothing is plotted.
        """
        importance_by_col = dict(
            zip(test_sample.columns.tolist(), perm.feature_importances_)
        )

        # Only invent node styles when the caller did not provide any.
        derive_styles = node_params is None
        if derive_styles:
            node_params = {}

        node_weights = {}
        for col, importance in importance_by_col.items():
            if len(col) <= 1:
                if derive_styles:
                    node_params.update({
                        col[0]: 'nice_node' if importance >= 0 else 'bad_node',
                    })
                node_weights[col[0]] = importance

        pairs = [col for col in test_sample.columns if len(col) == 2]
        if not pairs:
            print(
                "Sorry, you use only unigrams, change ngram_range to (1, 2) or greater"
            )
            return

        # .get(): a column missing from the importances (zip truncation)
        # yields a None weight rather than a KeyError.
        edge_rows = [
            [pair[0], pair[1], importance_by_col.get(pair)] for pair in pairs
        ]
        plot.graph(pd.DataFrame(edge_rows),
                   node_params,
                   node_weights=node_weights,
                   **kwargs)
Example #2
0
    def plot_graph(self, user_based=True, node_params=None, **kwargs):
        """
        Create an interactive graph visualization of the trajectory edge list.

        :param user_based: if True, edge weights are calculated as the
            normalized number of unique users (the ``index_col`` of
            ``retention_config``) who went through each edge. This sets
            ``edge_col``, ``edge_attributes='_nunique'`` and ``norm=True``
            in ``kwargs``, overriding any values the caller passed for them.
        :param node_params: mapping that describes which nodes should be
            highlighted as targets or sources, in the following form
            ```{
                    'lost': 'bad_target',
                    'passed': 'nice_target',
                    'onboarding_welcome_screen': 'source',
                }```
            If the mapping is not given, it is constructed from the
            ``positive_target_event``, ``negative_target_event`` and
            ``source_event`` entries of ``retention_config``; entries that
            are missing or None are skipped.
        :param kwargs: other parameters for visualization; forwarded to both
            ``trajectory.get_edgelist`` and ``plot.graph``.
        :return: Nothing
        """
        if user_based:
            # Deliberately overrides caller-supplied values for these keys.
            kwargs['edge_col'] = self.retention_config['index_col']
            kwargs['edge_attributes'] = '_nunique'
            kwargs['norm'] = True

        if node_params is None:
            config_roles = (
                ('positive_target_event', 'nice_target'),
                ('negative_target_event', 'bad_target'),
                ('source_event', 'source'),
            )
            node_params = {}
            for config_key, role in config_roles:
                event_name = self.retention_config.get(config_key)
                if event_name is not None:
                    node_params[event_name] = role

        edgelist = self._obj.trajectory.get_edgelist(**kwargs)
        plot.graph(edgelist, node_params, **kwargs)