コード例 #1
0
def display_gui(wti_index, print_args=False):
    """Build and show the control panel for the headnote top-list view.

    Every control is bound to ``display_headnote_toplist`` through
    ``widgets.interactive``.  The controls are laid out in four columns above
    the interactive output area, and one initial update is triggered.

    ``wti_index`` supplies the party choices via ``get_countries_list()`` and
    is, like ``print_args``, forwarded unchanged to the display function.
    """

    def sized(width='180px'):
        # A fresh Layout object is created for every widget.
        return widgets.Layout(width=width)

    period_ctl = widgets_config.period_group_widget(index_as_value=True)
    topic_ctl = widgets_config.topic_groups_widget(layout=sized())

    # 'ALL' is pinned first; the remaining parties follow in sorted order,
    # with the 'ALL' and 'ALL OTHER' entries of the index left out.
    excluded_parties = ('ALL', 'ALL OTHER')
    party_choices = ['ALL'] + sorted(
        party for party in wti_index.get_countries_list()
        if party not in excluded_parties
    )
    parties_ctl = widgets_config.parties_widget(
        options=party_choices, value=['ALL'], rows=10)

    recode_ctl = widgets_config.toggle(
        'Recode 7CULT', True,
        tooltip='Recode all treaties with cultural=yes as 7CORR',
        layout=sized(width='110px'))
    lemma_ctl = widgets_config.toggle(
        'Use LEMMA', False,
        tooltip='Use WordNet lemma',
        layout=sized(width='110px'))
    stopwords_ctl = widgets_config.toggle(
        'Remove STOP', True,
        tooltip='Do not include stopwords',
        layout=sized(width='110px'))
    cooccurrence_ctl = widgets_config.toggle(
        'Cooccurrence', True,
        tooltip='Compute Cooccurrence',
        layout=sized(width='110px'))

    groupbys_ctl = widgets_config.dropdown(
        'Groupbys', EXTRA_GROUPBY_OPTIONS, None, layout=sized())
    min_word_ctl = widgets_config.itext(
        0, 5, 2, description='Min word', layout=sized())
    top_ctl = widgets_config.slider(
        'Top/grp #', 2, 100, 25, step=10, layout=sized(width='200px'))
    min_count_ctl = widgets_config.slider(
        'Min count', 1, 10, 5, step=1,
        tooltip='Filter out words with count less than specified value',
        layout=sized(width='200px'))

    filter_ctl = widgets_config.dropdown(
        'Filter', config.TREATY_FILTER_OPTIONS, 'is_cultural', layout=sized())
    format_ctl = widgets_config.dropdown(
        'Output', OUTPUT_OPTIONS, 'plot_stacked_bar', layout=sized())
    progress_ctl = widgets_config.progress(0, 8, 1, 0, layout=sized(width='95%'))

    def report_progress(x=None):
        # Without an argument the bar advances one step; otherwise it is set to x.
        if x is None:
            progress_ctl.value = progress_ctl.value + 1
        else:
            progress_ctl.value = x

    interactive_view = widgets.interactive(
        display_headnote_toplist,
        period_group_index=period_ctl,
        topic_group_name=topic_ctl,
        recode_7corr=recode_ctl,
        treaty_filter=filter_ctl,
        parties=parties_ctl,
        extra_groupbys=groupbys_ctl,
        n_min_count=min_count_ctl,
        n_top=top_ctl,
        min_word_size=min_word_ctl,
        use_lemma=lemma_ctl,
        compute_co_occurance=cooccurrence_ctl,
        remove_stopwords=stopwords_ctl,
        output_format=format_ctl,
        progress=widgets.fixed(report_progress),
        wti_index=widgets.fixed(wti_index),
        print_args=widgets.fixed(print_args),
    )

    # Four columns of controls, left to right.
    columns = [
        widgets.VBox([period_ctl, parties_ctl]),
        widgets.VBox([topic_ctl, groupbys_ctl, top_ctl, min_word_ctl, min_count_ctl]),
        widgets.VBox([recode_ctl, lemma_ctl, stopwords_ctl, cooccurrence_ctl]),
        widgets.VBox([filter_ctl, format_ctl, progress_ctl]),
    ]
    control_panel = widgets.HBox(columns)

    # The last child of the interactive container is its output area.
    display(widgets.VBox([control_panel, interactive_view.children[-1]]))
    interactive_view.update()
コード例 #2
0
def display_corpus_load_gui(data_folder, wti_index, container):
    """Display a GUI for generating a textacy corpus from a treaty archive.

    Lists the ``treaty_text_corpora_??_*.zip`` archives found in
    ``data_folder`` (ordered by modification time, oldest first) and skips
    archives whose name ends in ``_preprocessed.zip``.  When no archive
    remains, "No prepared corpus" is printed and the function returns without
    displaying anything.  Otherwise the widgets are displayed and the
    "Compute" button is wired to ``generate_textacy_corpus``.

    Args:
        data_folder: Folder searched for corpus archives; also forwarded to
            ``generate_textacy_corpus``.
        wti_index: Treaty index.  Its ``unique_sources`` attribute supplies
            the options of the "Sources" selector, and the object itself is
            forwarded to ``generate_textacy_corpus``.
        container: Forwarded unchanged to ``generate_textacy_corpus``.
    """
    # Discover the archives first so that the "nothing to do" exit happens
    # before any option list or widget is built.
    #
    # The suffix is excluded with str.endswith because glob has no negated
    # suffix syntax: a pattern like '[!_preprocessed]' is a character class
    # that rejects every name whose last character is one of those letters.
    preprocessed_suffix = '_preprocessed.zip'
    candidates = glob.glob(
        os.path.join(data_folder, 'treaty_text_corpora_??_*.zip'))
    corpus_files = sorted(
        (path for path in candidates
         if not path.endswith(preprocessed_suffix)),
        key=lambda path: os.stat(path).st_mtime)

    if not corpus_files:
        print("No prepared corpus")
        return

    # The dropdown shows bare file names; the folder is re-joined on compute.
    corpus_files = [os.path.basename(x) for x in corpus_files]

    def lw(w):
        return widgets.Layout(width=w)

    treaty_source_options = wti_index.unique_sources
    treaty_default_source_options = ['LTS', 'UNTS', 'UNXX']

    # Display title -> language code, restricted to English and French.
    language_options = {
        config.LANGUAGE_MAP[k].title(): k
        for k in config.LANGUAGE_MAP if k in ['en', 'fr']
    }

    # Display title -> period group id.
    period_group_options = {
        config.PERIOD_GROUPS_ID_MAP[k]['title']: k
        for k in config.PERIOD_GROUPS_ID_MAP
    }

    gui = types.SimpleNamespace(
        progress=widgets.IntProgress(value=0,
                                     min=0,
                                     max=5,
                                     step=1,
                                     description='',
                                     layout=lw('90%')),
        output=widgets.Output(layout={'border': '1px solid black'}),
        # The most recently modified archive is preselected.
        source_path=widgets_config.dropdown(description='Corpus',
                                            options=corpus_files,
                                            value=corpus_files[-1],
                                            layout=lw('400px')),
        sources=widgets_config.select_multiple(
            description='Sources',
            options=treaty_source_options,
            values=treaty_default_source_options,
            disabled=False,
            layout=lw('180px')),
        language=widgets_config.dropdown(description='Language',
                                         options=language_options,
                                         value='en',
                                         layout=lw('180px')),
        period_group=widgets_config.dropdown('Period',
                                             period_group_options,
                                             'years_1935-1972',
                                             disabled=False,
                                             layout=lw('180px')),
        merge_entities=widgets_config.toggle('Merge NER',
                                             False,
                                             icon='',
                                             layout=lw('100px')),
        overwrite=widgets_config.toggle(
            'Force',
            False,
            icon='',
            layout=lw('100px'),
            tooltip="Force generation of new corpus (even if exists)"),
        compute_pos=widgets_config.toggle(
            'POS',
            True,
            icon='',
            layout=lw('100px'),
            disabled=True,
            tooltip="Enable Part-of-Speech tagging"),
        compute_ner=widgets_config.toggle('NER',
                                          False,
                                          icon='',
                                          layout=lw('100px'),
                                          disabled=False,
                                          tooltip="Enable NER tagging"),
        compute_dep=widgets_config.toggle('DEP',
                                          False,
                                          icon='',
                                          layout=lw('100px'),
                                          disabled=True,
                                          tooltip="Enable dependency parsing"),
        compute=widgets.Button(description='Compute',
                               button_style='Success',
                               layout=lw('100px')))

    display(
        widgets.VBox([
            gui.progress,
            widgets.HBox([
                gui.source_path,
                widgets.VBox([gui.language, gui.period_group, gui.sources]),
                widgets.VBox([gui.merge_entities, gui.overwrite]),
                widgets.VBox(
                    [gui.compute_pos, gui.compute_ner, gui.compute_dep]),
                gui.compute
            ]), gui.output
        ]))

    def tick(step=None, max_step=None):
        """Advance the progress bar by one, or set it to ``step``.

        When ``max_step`` is given the bar's maximum is updated first.
        """
        if max_step is not None:
            gui.progress.max = max_step
        gui.progress.value = gui.progress.value + 1 if step is None else step

    def compute_callback(*_args):
        """Run corpus generation with the current widget values."""
        gui.output.clear_output()
        with gui.output:
            # Each pipeline component whose toggle is off is listed as disabled.
            disabled_pipes = (() if gui.compute_pos.value else ("tagger",)) + \
                             (() if gui.compute_dep.value else ("parser",)) + \
                             (() if gui.compute_ner.value else ("ner",))
            generate_textacy_corpus(data_folder=data_folder,
                                    wti_index=wti_index,
                                    container=container,
                                    source_path=os.path.join(
                                        data_folder, gui.source_path.value),
                                    language=gui.language.value,
                                    merge_entities=gui.merge_entities.value,
                                    overwrite=gui.overwrite.value,
                                    period_group=gui.period_group.value,
                                    parties=None,
                                    disabled_pipes=disabled_pipes,
                                    tick=tick,
                                    treaty_sources=gui.sources.value)

    gui.compute.on_click(compute_callback)
コード例 #3
0
def display_network_analyis_gui(wti_index, plot_data):
    """Build and display the party-network GUI together with its time-travel panel.

    Two interactive views are created:

    * the main network view, bound to ``display_party_network``, and
    * a "time travel" view, bound to ``display_partial_party_network``, that
      takes a slice range and slice unit for the shared ``plot_data``.

    ``wti_index`` provides the party list through ``get_countries_list()`` and
    is forwarded to ``display_party_network``.  ``plot_data`` is forwarded to
    both display functions; this function reads ``plot_data.edges['signed']``
    and ``plot_data.slice_range_type`` and calls ``plot_data.update(...)``.
    """

    # One default Layout instance handed to all slider widgets below.
    # NOTE(review): presumably widgets_config assigns it by reference, in which
    # case these sliders share their layout state - TODO confirm.
    box = widgets.Layout()

    def lw(w='170px'):
        """Return a new Layout of the given width."""
        return widgets.Layout(width=w)

    party_preset_widget = widgets_config.dropdown('Presets',
                                                  config.PARTY_PRESET_OPTIONS,
                                                  None,
                                                  layout=lw())

    period_group_index_widget = widgets_config.period_group_widget(
        index_as_value=True, layout=lw())

    topic_group_widget = widgets_config.topic_groups_widget2(layout=lw())
    party_name_widget = widgets_config.party_name_widget(layout=lw())

    treaty_filter_widget = widgets_config.dropdown(
        'Filter', config.TREATY_FILTER_OPTIONS, 'is_cultural', layout=lw())
    recode_is_cultural_widget = widgets_config.toggle('Recode 7CORR',
                                                      True,
                                                      layout=lw(w='100px'))

    parties_widget = widgets_config.parties_widget(
        options=wti_index.get_countries_list(), value=['FRANCE'], layout=lw())

    palette_widget = widgets_config.dropdown('Color',
                                             PALETTE_OPTIONS,
                                             None,
                                             layout=lw())
    node_size_widget = widgets_config.dropdown('Node size',
                                               NODE_SIZE_OPTIONS,
                                               None,
                                               layout=lw())
    node_size_range_widget = widgets_config.rangeslider('Range',
                                                        5,
                                                        100, [20, 49],
                                                        step=1,
                                                        layout=box)

    # C, K and p are layout-algorithm parameters; their slider descriptions
    # are rewritten per algorithm by on_layout_type_change below.
    C_widget = widgets_config.slider('C',
                                     0,
                                     100,
                                     1,
                                     step=1,
                                     layout=box,
                                     continuous_update=False)
    K_widget = widgets_config.sliderf('K',
                                      0.01,
                                      1.0,
                                      0.01,
                                      0.10,
                                      layout=box,
                                      continuous_update=False)
    p_widget = widgets_config.sliderf('p',
                                      0.01,
                                      2.0,
                                      0.01,
                                      1.10,
                                      layout=box,
                                      continuous_update=False)
    fig_width_widget = widgets_config.slider('Width',
                                             600,
                                             1600,
                                             900,
                                             step=100,
                                             layout=box,
                                             continuous_update=False)
    fig_height_widget = widgets_config.slider('Height',
                                              600,
                                              1600,
                                              700,
                                              step=100,
                                              layout=box,
                                              continuous_update=False)
    output_widget = widgets_config.dropdown('Output',
                                            OUTPUT_OPTIONS,
                                            'network',
                                            layout=lw())
    layout_algorithm_widget = widgets_config.dropdown('Layout',
                                                      NETWORK_LAYOUT_OPTIONS,
                                                      'nx_spring_layout',
                                                      layout=lw())
    progress_widget = widgets_config.progress(0, 4, 1, 0, layout=lw("300px"))
    node_partition_widget = widgets_config.dropdown('Partition',
                                                    COMMUNITY_OPTIONS,
                                                    None,
                                                    layout=lw())
    simple_mode_widget = widgets_config.toggle('Simple',
                                               False,
                                               tooltip='Simple view',
                                               layout=lw(w='100px'))

    slice_range_type_widget = widgets_config.dropdown('Unit',
                                                      SLICE_TYPE_OPTIONS,
                                                      SLICE_TYPE_DEFAULT,
                                                      layout=lw())
    time_travel_range_widget = widgets_config.rangeslider(
        'Time travel',
        0,
        100, [0, 100],
        layout=lw('60%'),
        continuous_update=False)
    # NOTE(review): this label is displayed but never assigned a new value
    # anywhere in this function.
    time_travel_label_widget = widgets.Label(value="")

    # Layout algorithm name -> description shown on the K, C and p sliders.
    # Algorithms without an entry fall back to 'other'.  For the entries
    # below, every '_' placeholder belongs to a slider that
    # on_layout_type_change disables (and whose description it blanks).
    label_map = {
        'graphtool_sfdp': {
            'K': 'K',
            'C': 'C',
            'p': 'gamma'
        },
        'graphtool_arf': {
            'K': 'd',
            'C': 'a',
            'p': '_'
        },
        'graphtool_fr': {
            'K': 'a/(2*N)',
            'C': 'r/2',
            'p': '_'
        },
        'nx_spring_layout': {
            'K': 'k',
            'C': '_',
            'p': '_'
        },
        'nx_spectral_layout': {
            'K': '_',
            'C': '_',
            'p': '_'
        },
        'nx_circular_layout': {
            'K': '_',
            'C': '_',
            'p': '_'
        },
        'nx_shell_layout': {
            'K': '_',
            'C': '_',
            'p': '_'
        },
        'nx_kamada_kawai_layout': {
            'K': '_',
            'C': '_',
            'p': '_'
        },
        'other': {
            'K': 'K',
            'C': '_',
            'p': '_'
        },
    }

    def on_layout_type_change(change):
        """Enable/disable the C, K and p sliders for the selected layout algorithm.

        A disabled slider gets a blank description; an enabled one gets the
        algorithm-specific label from ``label_map``.
        """
        layout = change['new']
        opts = label_map.get(layout, label_map.get('other'))
        gv = [
            'graphviz_neato', 'graphviz_dot', 'graphviz_circo', 'graphviz_fdp',
            'graphviz_sfdp'
        ]
        # C is only enabled for the graphtool algorithms.
        C_widget.disabled = layout not in [
            'graphtool_sfdp', 'graphtool_arf', 'graphtool_fr'
        ]
        C_widget.description = ' ' if C_widget.disabled else opts.get('C')
        # K is enabled for graphtool, networkx spring and the graphviz names.
        K_widget.disabled = layout not in [
            'graphtool_sfdp', 'graphtool_arf', 'graphtool_fr',
            'nx_spring_layout'
        ] + gv
        K_widget.description = ' ' if K_widget.disabled else opts.get('K')
        # p is only enabled for graphtool_sfdp.
        p_widget.disabled = layout not in [
            'graphtool_sfdp',
        ]
        p_widget.description = ' ' if p_widget.disabled else opts.get('p')

    def on_simple_mode_value_change(change):
        """Set ``layout.display`` of the advanced controls from the 'Simple' toggle.

        The value is 'none' when the toggle is on and '' when it is off.
        """
        display_mode = 'none' if change['new'] is True else ''
        node_partition_widget.layout.display = display_mode
        node_size_widget.layout.display = display_mode
        node_size_range_widget.layout.display = display_mode
        layout_algorithm_widget.layout.display = display_mode
        C_widget.layout.display = display_mode
        K_widget.layout.display = display_mode
        p_widget.layout.display = display_mode
        fig_width_widget.layout.display = display_mode
        fig_height_widget.layout.display = display_mode
        palette_widget.layout.display = display_mode
        # NOTE(review): the slider state is refreshed only when simple mode is
        # switched ON (i.e. when the sliders were just hidden) - confirm that
        # this is the intended branch.
        if change['new'] is True:
            on_layout_type_change(dict(new=layout_algorithm_widget.value))

    def progress_callback(step=0):
        """Set the progress bar to ``step``."""
        progress_widget.value = step

    def update_slice_range(slice_range_type):
        """Recompute the time-travel slider bounds from ``plot_data``.

        Passed to ``display_party_network`` as ``done_callback`` and also
        called when the slice unit dropdown changes.  A falsy
        ``slice_range_type`` falls back to ``plot_data.slice_range_type``.
        """
        slice_range_type = slice_range_type or plot_data.slice_range_type
        sign_dates = plot_data.edges['signed']

        # Type 1 slices by position in the 'signed' sequence; any other type
        # slices by the year of the earliest/latest entry.
        if slice_range_type == 1:
            range_min, range_max = 0, len(sign_dates)
        else:
            range_min, range_max = min(sign_dates).year, max(sign_dates).year

        plot_data.update(slice_range_type=slice_range_type,
                         slice_range=(range_min, range_max))

        # Presumably the bounds are widened to 0..10000 first so that
        # assigning the new min and max one at a time can never leave the
        # slider with min > max - TODO confirm.
        time_travel_range_widget.min = 0
        time_travel_range_widget.max = 10000
        time_travel_range_widget.min = range_min
        time_travel_range_widget.max = range_max
        time_travel_range_widget.value = (range_min, range_max)

    def slice_changed_callback(min_year, max_year):
        """Show the given bounds as the description of the range slider."""
        time_travel_range_widget.description = '{}-{}'.format(
            min_year, max_year)

    def on_slice_range_type_change(change):
        """Recompute the slider bounds for the newly selected slice unit."""
        update_slice_range(change['new'])

    def on_party_preset_change(change):  # pylint: disable=W0613
        """Copy the selected preset into the parties selector.

        A preset containing 'ALL' selects every available option.
        """
        if party_preset_widget.value is None:
            return
        parties_widget.value = parties_widget.options if 'ALL' in party_preset_widget.value \
            else party_preset_widget.value

    slice_range_type_widget.observe(on_slice_range_type_change, names='value')
    simple_mode_widget.observe(on_simple_mode_value_change, names='value')
    layout_algorithm_widget.observe(on_layout_type_change, names='value')
    party_preset_widget.observe(on_party_preset_change, names='value')

    # Assigned after the observers are registered; presumably this change
    # fires on_simple_mode_value_change so the GUI starts in simple mode.
    simple_mode_widget.value = True

    wn = widgets.interactive(display_party_network,
                             parties=parties_widget,
                             period_group_index=period_group_index_widget,
                             treaty_filter=treaty_filter_widget,
                             plot_data=widgets.fixed(plot_data),
                             topic_group=topic_group_widget,
                             recode_is_cultural=recode_is_cultural_widget,
                             layout_algorithm=layout_algorithm_widget,
                             C=C_widget,
                             K=K_widget,
                             p1=p_widget,
                             output=output_widget,
                             party_name=party_name_widget,
                             node_size_range=node_size_range_widget,
                             palette_name=palette_widget,
                             width=fig_width_widget,
                             height=fig_height_widget,
                             node_size=node_size_widget,
                             node_partition=node_partition_widget,
                             year_limit=widgets.fixed(None),
                             wti_index=widgets.fixed(wti_index),
                             progress=widgets.fixed(progress_callback),
                             done_callback=widgets.fixed(update_slice_range))

    boxes = widgets.HBox([
        widgets.VBox([
            period_group_index_widget, topic_group_widget, party_name_widget,
            treaty_filter_widget, party_preset_widget
        ]),
        widgets.VBox([parties_widget]),
        widgets.VBox([
            widgets.HBox([
                recode_is_cultural_widget, simple_mode_widget, progress_widget
            ]),
            widgets.HBox([
                widgets.VBox([
                    output_widget, layout_algorithm_widget, palette_widget,
                    node_size_widget, node_partition_widget
                ]),
                widgets.VBox([
                    K_widget, C_widget, p_widget, fig_width_widget,
                    fig_height_widget, node_size_range_widget
                ])
            ])
        ])
    ])

    # The last child of the interactive container is its output area.
    display(widgets.VBox([boxes, wn.children[-1]]))

    wn.update()

    iw_time_travel = widgets.interactive(
        display_partial_party_network,
        plot_data=widgets.fixed(plot_data),
        slice_range_type=slice_range_type_widget,
        slice_range=time_travel_range_widget,
        slice_changed_callback=widgets.fixed(slice_changed_callback))

    time_travel_box = widgets.VBox([
        widgets.HBox([
            time_travel_label_widget, time_travel_range_widget,
            slice_range_type_widget
        ]), iw_time_travel.children[-1]
    ])

    display(time_travel_box)

    # Presumably fires on_slice_range_type_change when 2 differs from
    # SLICE_TYPE_DEFAULT, which re-derives the slider bounds - TODO confirm.
    slice_range_type_widget.value = 2
コード例 #4
0
def display_corpus_load_gui(data_folder,
                            document_index=None,
                            container=None,
                            compute_ner=False,
                            domain_logic=None):
    """Display a GUI for generating a textacy corpus from a ``*.txt.zip`` archive.

    The archives found in ``data_folder`` are offered in a dropdown (sorted by
    path, last one preselected).  When the folder contains no such archive,
    "No prepared corpus" is printed and the function returns without
    displaying anything.  Otherwise the widgets are displayed and the
    "Compute" button is wired to ``generate_textacy_corpus``.

    Args:
        data_folder: Folder searched for ``*.txt.zip`` archives; also
            forwarded to ``generate_textacy_corpus``.
        document_index: Forwarded unchanged to ``generate_textacy_corpus``.
        container: Forwarded unchanged to ``generate_textacy_corpus``.
        compute_ner: Passed as the second argument of both the "Merge NER"
            and the "NER" toggle (presumably their initial state).
        domain_logic: Forwarded unchanged to ``generate_textacy_corpus``.
    """
    # Discover the archives first: with an empty list, corpus_files[-1] below
    # would raise IndexError, so bail out before building any widget.
    corpus_files = sorted(glob.glob(os.path.join(data_folder, '*.txt.zip')))
    if not corpus_files:
        print("No prepared corpus")
        return

    def lw(w):
        return widgets.Layout(width=w)

    # Display title -> language code.
    language_options = {
        config.LANGUAGE_MAP[k].title(): k
        for k in config.LANGUAGE_MAP
    }

    gui = types.SimpleNamespace(
        progress=widgets.IntProgress(value=0,
                                     min=0,
                                     max=5,
                                     step=1,
                                     description='',
                                     layout=lw('90%')),
        output=widgets.Output(layout={'border': '1px solid black'}),
        # Options are full paths as returned by glob.
        source_path=widgets_config.dropdown(description='Corpus',
                                            options=corpus_files,
                                            value=corpus_files[-1],
                                            layout=lw('300px')),
        language=widgets_config.dropdown(description='Language',
                                         options=language_options,
                                         value='en',
                                         layout=lw('180px')),
        merge_entities=widgets_config.toggle('Merge NER',
                                             compute_ner,
                                             icon='',
                                             layout=lw('100px')),
        binary_format=widgets_config.toggle('Store as binary',
                                            False,
                                            disabled=False,
                                            icon='',
                                            layout=lw('130px')),
        use_compression=widgets_config.toggle('Store compressed',
                                              True,
                                              disabled=False,
                                              icon='',
                                              layout=lw('130px')),
        overwrite=widgets_config.toggle(
            'Force if exists',
            False,
            icon='',
            layout=lw('130px'),
            tooltip="Force generation of new corpus (even if exists)"),
        compute_pos=widgets_config.toggle(
            'PoS',
            True,
            icon='',
            layout=lw('100px'),
            disabled=True,
            tooltip="Enable Part-of-Speech tagging"),
        compute_ner=widgets_config.toggle(
            'NER',
            compute_ner,
            icon='',
            layout=lw('100px'),
            disabled=False,
            tooltip="Enable named entity recognition"),
        compute_dep=widgets_config.toggle('DEP',
                                          False,
                                          icon='',
                                          layout=lw('100px'),
                                          disabled=True,
                                          tooltip="Enable dependency parsing"),
        compute=widgets.Button(description='Compute',
                               button_style='Success',
                               layout=lw('100px')))

    display(
        widgets.VBox([
            gui.progress,
            widgets.HBox([
                widgets.VBox([
                    gui.source_path,
                    gui.language,
                ]),
                widgets.VBox(
                    [gui.compute_pos, gui.compute_ner, gui.compute_dep]),
                widgets.VBox(
                    [gui.overwrite, gui.binary_format, gui.use_compression]),
                widgets.VBox([
                    gui.compute,
                    gui.merge_entities,
                ]),
            ]), gui.output
        ]))

    def tick(step=None, max_step=None):
        """Advance the progress bar by one, or set it to ``step``.

        When ``max_step`` is given the bar's maximum is updated first.
        """
        if max_step is not None:
            gui.progress.max = max_step
        gui.progress.value = gui.progress.value + 1 if step is None else step

    def compute_callback(*_args):
        """Run corpus generation with the current widget values."""
        gui.output.clear_output()
        with gui.output:
            # Each pipeline component whose toggle is off is listed as
            # disabled; "textcat" is always disabled.
            disabled_pipes = (() if gui.compute_pos.value else ("tagger",)) + \
                             (() if gui.compute_dep.value else ("parser",)) + \
                             (() if gui.compute_ner.value else ("ner",)) + \
                             ("textcat", )
            generate_textacy_corpus(domain_logic=domain_logic,
                                    data_folder=data_folder,
                                    container=container,
                                    document_index=document_index,
                                    source_path=gui.source_path.value,
                                    language=gui.language.value,
                                    merge_entities=gui.merge_entities.value,
                                    overwrite=gui.overwrite.value,
                                    binary_format=gui.binary_format.value,
                                    use_compression=gui.use_compression.value,
                                    disabled_pipes=disabled_pipes,
                                    tick=tick)

    gui.compute.on_click(compute_callback)
def display_gui(wti_index, print_args=False):
    """Build and show the control panel for the topic quantity charts.

    Every control is bound to ``display_topic_quantity_groups`` through
    ``widgets.interactive``.  The controls are laid out in five columns above
    the interactive output area, and one initial update is triggered.

    ``wti_index`` supplies the preset options (``get_party_preset_options()``)
    and the party list (``get_countries_list()``); it is, like ``print_args``,
    forwarded unchanged to the display function.
    """

    def sized(width='120px'):
        # A fresh Layout object is created for every widget.
        return widgets.Layout(width=width)

    preset_choices = wti_index.get_party_preset_options()

    period_ctl = widgets_config.period_group_widget(index_as_value=True)
    topic_ctl = widgets_config.topic_groups_widget(value='7CULTURE')
    recode_ctl = widgets_config.recode_7corr_widget(layout=sized('120px'))
    normalize_ctl = widgets_config.toggle(
        'Display %', False, icon='', layout=sized('100px'))
    chart_type_ctl = widgets_config.dropdown(
        'Output',
        config.CHART_TYPE_NAME_OPTIONS,
        "plot_stacked_bar",
        layout=sized('200px'))
    style_ctl = widgets_config.plot_style_widget()

    # Every party of the index except the 'ALL OTHER' entry.
    party_choices = [
        party for party in wti_index.get_countries_list()
        if party != 'ALL OTHER'
    ]
    parties_ctl = widgets_config.parties_widget(options=party_choices,
                                                value=['FRANCE'])
    preset_ctl = widgets_config.dropdown(
        'Presets', preset_choices, None, layout=sized(width='200px'))
    per_category_ctl = widgets_config.toggle(
        'Chart per Qty',
        False,
        tooltip='Display one chart per selected quantity category',
        layout=sized())
    other_category_ctl = widgets_config.toggle(
        'Add OTHER topics', False, layout=sized())
    quantity_ctl = widgets_config.dropdown(
        'Quantity',
        ['topic', 'party', 'source', 'continent', 'group'],
        'topic',
        layout=sized(width='200px'))
    progress_ctl = widgets_config.progress(0, 5, 1, 0, layout=sized())

    def advance(step=None):
        # Without an argument the bar advances one step; otherwise it is set
        # to the given step.
        if step is None:
            progress_ctl.value = progress_ctl.value + 1
        else:
            progress_ctl.value = step

    def apply_party_preset(_change):
        # A preset containing 'ALL' selects every available party; any other
        # preset is copied into the selector as-is.
        preset = preset_ctl.value
        if preset is None:
            return
        parties_ctl.value = parties_ctl.options if 'ALL' in preset else preset

    preset_ctl.observe(apply_party_preset, names='value')

    interactive_view = widgets.interactive(
        display_topic_quantity_groups,
        period_group_index=period_ctl,
        topic_group_name=topic_ctl,
        parties=parties_ctl,
        recode_is_cultural=recode_ctl,
        normalize_values=normalize_ctl,
        extra_other_category=other_category_ctl,
        chart_type_name=chart_type_ctl,
        plot_style=style_ctl,
        chart_per_category=per_category_ctl,
        target_quantity=quantity_ctl,
        progress=widgets.fixed(advance),
        wti_index=widgets.fixed(wti_index),
        print_args=widgets.fixed(print_args),
    )

    # Five columns of controls, left to right.
    columns = [
        widgets.VBox([period_ctl, topic_ctl, quantity_ctl, preset_ctl]),
        widgets.VBox([parties_ctl]),
        widgets.VBox([recode_ctl, other_category_ctl, per_category_ctl]),
        widgets.VBox([chart_type_ctl, style_ctl]),
        widgets.VBox([normalize_ctl, progress_ctl]),
    ]
    control_panel = widgets.HBox(columns)

    # The last child of the interactive container is its output area.
    display(widgets.VBox([control_panel, interactive_view.children[-1]]))
    interactive_view.update()
def display_corpus_load_gui(data_folder, wti_index, container):
    """Display a form for generating a textacy corpus from a treaty text archive.

    The form lists the ``treaty_text_corpora_??_??????.zip`` archives found in
    ``data_folder`` and, when *Compute* is clicked, forwards the selected
    options to ``generate_textacy_corpus``. The call runs inside the form's
    output widget, which is cleared before every run.

    Args:
        data_folder: Folder searched for corpus archives; also forwarded to
            ``generate_textacy_corpus``.
        wti_index: Forwarded unchanged to ``generate_textacy_corpus``.
        container: Forwarded unchanged to ``generate_textacy_corpus``.
    """

    def lw(w):
        return widgets.Layout(width=w)

    # Only English and French are offered, keyed by their title-cased names.
    language_options = {
        config.LANGUAGE_MAP[k].title(): k
        for k in config.LANGUAGE_MAP if k in ('en', 'fr')
    }

    period_group_options = {
        config.PERIOD_GROUPS_ID_MAP[k]['title']: k
        for k in config.PERIOD_GROUPS_ID_MAP
    }

    corpus_files = sorted(
        glob.glob(
            os.path.join(data_folder, 'treaty_text_corpora_??_??????.zip')))

    has_corpus = bool(corpus_files)
    if has_corpus:
        # Prefer the first archive whose path matches '*_en*'; otherwise fall
        # back to the last archive in sorted order.
        default_corpus = next(
            (f for f in corpus_files if fnmatch.fnmatch(f, '*_en*')),
            corpus_files[-1])
    else:
        # Placeholder entry so the dropdown still has something to show.
        corpus_files = ['No corpus found']
        default_corpus = corpus_files[0]

    gui = types.SimpleNamespace(
        progress=widgets.IntProgress(value=0,
                                     min=0,
                                     max=5,
                                     step=1,
                                     description='',
                                     layout=lw('90%')),
        output=widgets.Output(layout={'border': '1px solid black'}),
        source_path=widgets_config.dropdown(
            description='Corpus',
            options=corpus_files,
            value=default_corpus,
            layout=lw('400px')),
        language=widgets_config.dropdown(description='Language',
                                         options=language_options,
                                         value='en',
                                         layout=lw('180px')),
        period_group=widgets_config.dropdown('Period',
                                             period_group_options,
                                             'years_1945-1972',
                                             disabled=False,
                                             layout=lw('180px')),
        merge_entities=widgets_config.toggle('Merge NER',
                                             False,
                                             icon='',
                                             layout=lw('100px')),
        overwrite=widgets_config.toggle(
            'Force',
            False,
            icon='',
            layout=lw('100px'),
            tooltip="Force generation of new corpus (even if exists)"),
        compute_pos=widgets_config.toggle(
            'POS',
            True,
            icon='',
            layout=lw('100px'),
            disabled=True,
            tooltip="Enable Part-of-Speech tagging"),
        compute_ner=widgets_config.toggle('NER',
                                          False,
                                          icon='',
                                          layout=lw('100px'),
                                          disabled=False,
                                          tooltip="Enable NER tagging"),
        compute_dep=widgets_config.toggle('DEP',
                                          False,
                                          icon='',
                                          layout=lw('100px'),
                                          disabled=True,
                                          tooltip="Enable dependency parsing"),
        # Without an archive the dropdown only holds the placeholder text, so
        # keep the button inert rather than passing that text on as a path.
        compute=widgets.Button(description='Compute',
                               button_style='Success',
                               layout=lw('100px'),
                               disabled=not has_corpus))

    display(
        widgets.VBox([
            gui.progress,
            widgets.HBox([
                gui.source_path,
                widgets.VBox([gui.language, gui.period_group]),
                widgets.VBox([gui.merge_entities, gui.overwrite]),
                widgets.VBox(
                    [gui.compute_pos, gui.compute_ner, gui.compute_dep]),
                gui.compute
            ]), gui.output
        ]))

    def tick(step=None, max_step=None):
        """Advance the progress bar by one, or jump to ``step`` when given.

        ``max_step``, when given, replaces the bar's maximum first.
        """
        if max_step is not None:
            gui.progress.max = max_step
        gui.progress.value = gui.progress.value + 1 if step is None else step

    def compute_callback(*_args):
        """Run corpus generation with the options currently set in the form."""
        gui.output.clear_output()
        with gui.output:
            # A pipeline component is disabled exactly when its toggle is off.
            pipe_toggles = (
                ("tagger", gui.compute_pos.value),
                ("parser", gui.compute_dep.value),
                ("ner", gui.compute_ner.value),
            )
            disabled_pipes = tuple(
                pipe for pipe, enabled in pipe_toggles if not enabled)
            generate_textacy_corpus(data_folder=data_folder,
                                    wti_index=wti_index,
                                    container=container,
                                    source_path=gui.source_path.value,
                                    language=gui.language.value,
                                    merge_entities=gui.merge_entities.value,
                                    overwrite=gui.overwrite.value,
                                    period_group=gui.period_group.value,
                                    parties=None,
                                    disabled_pipes=disabled_pipes,
                                    tick=tick)

    gui.compute.on_click(compute_callback)
コード例 #7
0
def display_gui(wti_index, print_args=False):
    """Build and display the interactive form driving ``display_quantity_by_party``.

    Args:
        wti_index: Index object supplying ``unique_sources``,
            ``get_party_preset_options()`` and ``get_countries_list()``; it is
            also passed through to ``display_quantity_by_party``.
        print_args: Passed through unchanged to ``display_quantity_by_party``.
    """

    def lw(width='100px', left='0'):
        return widgets.Layout(width=width, left=left)

    def period_group_window(period_group_index):
        """Return (min_year, max_year) for the period group.

        Periods are either a list of years, or a list of tuples
        (from-year, to-year); tuples are flattened before taking min/max.
        """
        period_group = config.DEFAULT_PERIOD_GROUPS[period_group_index]

        periods = period_group['periods']

        if period_group['type'] == 'divisions':
            periods = list(itertools.chain.from_iterable(periods))

        return min(periods), max(periods)

    treaty_source_options = wti_index.unique_sources
    party_preset_options = wti_index.get_party_preset_options()

    period_group_index_widget = widgets_config.period_group_widget(
        index_as_value=True)

    min_year, max_year = period_group_window(period_group_index_widget.value)

    gui = types.SimpleNamespace(
        year_limit=widgets_config.rangeslider('Window',
                                              min_year,
                                              max_year, [min_year, max_year],
                                              layout=lw('900px'),
                                              continuous_update=False),
        sources=widgets_config.select_multiple(description='Sources',
                                               options=treaty_source_options,
                                               values=treaty_source_options,
                                               disabled=False,
                                               layout=lw('200px')),
        period_group_index=period_group_index_widget,
        party_name=widgets_config.party_name_widget(),
        normalize_values=widgets_config.toggle('Display %',
                                               False,
                                               icon='',
                                               layout=lw('100px')),
        chart_type_name=widgets_config.dropdown('Output',
                                                config.CHART_TYPE_NAME_OPTIONS,
                                                "plot_stacked_bar",
                                                layout=lw('200px')),
        plot_style=widgets_config.plot_style_widget(),
        top_n_parties=widgets_config.slider('Top #',
                                            0,
                                            10,
                                            0,
                                            continuous_update=False,
                                            layout=lw(width='200px')),
        party_preset=widgets_config.dropdown('Presets',
                                             party_preset_options,
                                             None,
                                             layout=lw(width='200px')),
        parties=widgets_config.parties_widget(
            options=wti_index.get_countries_list(
                excludes=['ALL', 'ALL OTHER']),
            value=['FRANCE'],
            rows=8,
        ),
        treaty_filter=widgets_config.dropdown('Filter',
                                              config.TREATY_FILTER_OPTIONS,
                                              'is_cultural',
                                              layout=lw(width='200px')),
        extra_category=widgets_config.dropdown('Include',
                                               OTHER_CATEGORY_OPTIONS,
                                               '',
                                               layout=lw(width='200px')),
        progress=widgets_config.progress(0, 5, 1, 0, layout=lw('95%')),
        info=widgets.Label(value="", layout=lw('95%')))

    def stepper(step=None):
        """Advance the progress bar by one, or jump to ``step`` when given."""
        gui.progress.value = gui.progress.value + 1 if step is None else step

    itw = widgets.interactive(
        display_quantity_by_party,
        period_group_index=period_group_index_widget,
        year_limit=gui.year_limit,
        party_name=gui.party_name,
        parties=gui.parties,
        treaty_filter=gui.treaty_filter,
        extra_category=gui.extra_category,
        normalize_values=gui.normalize_values,
        chart_type_name=gui.chart_type_name,
        plot_style=gui.plot_style,
        top_n_parties=gui.top_n_parties,
        overlay=widgets.fixed(False),  # no overlay toggle is exposed in the form
        progress=widgets.fixed(stepper),
        wti_index=widgets.fixed(wti_index),
        print_args=widgets.fixed(print_args),
        treaty_sources=gui.sources)

    def on_party_preset_change(change):  # pylint: disable=W0613
        """Copy the chosen preset into the party list and reset 'Top #' to 0.

        The party and top-n handlers are detached while the widgets are
        updated and are always re-attached afterwards.
        """
        if gui.party_preset.value is None:
            return

        try:
            gui.parties.unobserve(on_parties_change, names='value')
            gui.top_n_parties.unobserve(on_top_n_parties_change, names='value')
            if 'ALL' in gui.party_preset.value:
                gui.parties.value = gui.parties.options
            else:
                gui.parties.value = gui.party_preset.value

            if gui.top_n_parties.value > 0:
                gui.top_n_parties.value = 0
        except Exception as ex:  # pylint: disable=W0703
            logger.info(ex)
        finally:
            gui.parties.observe(on_parties_change, names='value')
            gui.top_n_parties.observe(on_top_n_parties_change, names='value')

    def on_parties_change(change):  # pylint: disable=W0613
        """Reset 'Top #' to 0 when the party selection changes.

        The top-n handler is detached around the reset so that it does not
        run for this programmatic change.
        """
        try:
            if gui.top_n_parties.value != 0:
                gui.top_n_parties.unobserve(on_top_n_parties_change,
                                            names='value')
                gui.top_n_parties.value = 0
                gui.top_n_parties.observe(on_top_n_parties_change,
                                          names='value')
        except Exception as ex:  # pylint: disable=W0703
            logger.info(ex)

    def on_top_n_parties_change(change):  # pylint: disable=W0613
        """Toggle manual party selection depending on the 'Top #' value.

        A positive value disables the party list and the preset dropdown and
        clears the current selection; zero re-enables both.
        """
        try:
            if gui.top_n_parties.value > 0:
                gui.parties.unobserve(on_parties_change, names='value')
                gui.parties.disabled = True
                gui.party_preset.disabled = True
                if len(gui.parties.value) > 0:
                    gui.parties.value = []
            else:
                gui.parties.observe(on_parties_change, names='value')
                gui.parties.disabled = False
                gui.party_preset.disabled = False
        except Exception as ex:  # pylint: disable=W0703
            logger.info(ex)

    def set_years_window(period_group_index):
        """Reset the year slider's bounds and value to the period group's window.

        The slider is only left enabled for period groups of type 'range'.
        """
        try:
            min_year, max_year = period_group_window(period_group_index)
            slider = gui.year_limit
            # Assign the bounds in an order that keeps min <= max at every
            # step. A bounded range slider rejects a min above its current
            # max, so always assigning min first fails (and the error is
            # swallowed below) whenever the new window lies entirely after
            # the current one, leaving the slider stuck on the old period.
            if min_year > slider.max:
                slider.max = max_year
                slider.min = min_year
            else:
                slider.min = min_year
                slider.max = max_year
            slider.value = (min_year, max_year)
            period_group = config.DEFAULT_PERIOD_GROUPS[period_group_index]
            slider.disabled = (period_group['type'] != 'range')
        except Exception as ex:  # pylint: disable=W0703
            logger.info(ex)

    def on_period_change(change):
        """Re-sync the year slider with the newly selected period group."""
        period_group_index = change['new']
        set_years_window(period_group_index)

    gui.parties.observe(on_parties_change, names='value')
    gui.period_group_index.observe(on_period_change, names='value')
    gui.party_preset.observe(on_party_preset_change, names='value')
    gui.top_n_parties.observe(on_top_n_parties_change, names='value')

    set_years_window(gui.period_group_index.value)

    boxes = widgets.HBox([
        widgets.VBox([
            gui.period_group_index, gui.party_name, gui.top_n_parties,
            gui.party_preset
        ]),
        widgets.VBox([gui.parties]),
        widgets.VBox([
            widgets.HBox([
                widgets.VBox(
                    [gui.treaty_filter, gui.extra_category, gui.sources]),
                widgets.VBox([gui.chart_type_name, gui.plot_style]),
                widgets.VBox([gui.normalize_values])
            ]), gui.progress, gui.info
        ])
    ])

    # The last child of the interactive container is its output area.
    display(widgets.VBox([boxes, gui.year_limit, itw.children[-1]]))
    itw.update()
def display_gui(wti_index, print_args=False):
    """Build and display the interactive form driving ``display_topic_quantity_groups``.

    Args:
        wti_index: Index object supplying ``get_party_preset_options()``,
            ``unique_sources`` and ``get_countries_list()``; it is also passed
            through to ``display_topic_quantity_groups``.
        print_args: Passed through unchanged to ``display_topic_quantity_groups``.
    """

    def lw(width='120px'):
        return widgets.Layout(width=width)

    party_preset_options = wti_index.get_party_preset_options()
    treaty_source_options = wti_index.unique_sources

    # Widgets are attached one at a time, in the same order as before.
    gui = types.SimpleNamespace()

    gui.treaty_sources = widgets_config.select_multiple(
        description='Sources',
        options=treaty_source_options,
        values=treaty_source_options,
        disabled=False,
        layout=lw('180px'),
    )
    gui.period_group_index = widgets_config.period_group_widget(
        index_as_value=True)
    gui.topic_group_name = widgets_config.topic_groups_widget(value='7CULTURE')
    gui.recode_is_cultural = widgets_config.recode_7corr_widget(
        layout=lw('120px'))
    gui.normalize_values = widgets_config.toggle(
        'Display %', False, icon='', layout=lw('100px'))
    gui.chart_type_name = widgets_config.dropdown(
        'Output',
        config.CHART_TYPE_NAME_OPTIONS,
        "plot_stacked_bar",
        layout=lw('200px'),
    )
    gui.plot_style = widgets_config.plot_style_widget()

    # Every country except the 'ALL OTHER' bucket is selectable.
    selectable_parties = [
        name for name in wti_index.get_countries_list() if name != 'ALL OTHER'
    ]
    gui.parties = widgets_config.parties_widget(options=selectable_parties,
                                                value=['FRANCE'])
    gui.party_preset = widgets_config.dropdown(
        'Presets', party_preset_options, None, layout=lw(width='200px'))
    gui.chart_per_category = widgets_config.toggle(
        'Chart per Qty',
        False,
        tooltip='Display one chart per selected quantity category',
        layout=lw(),
    )
    gui.extra_other_category = widgets_config.toggle(
        'Add OTHER topics', False, layout=lw())
    gui.target_quantity = widgets_config.dropdown(
        'Quantity',
        ['topic', 'party', 'source', 'continent', 'group'],
        'topic',
        layout=lw(width='200px'),
    )
    gui.progress = widgets_config.progress(0, 5, 1, 0, layout=lw())

    def stepper(step=None):
        """Advance the progress bar by one, or jump to ``step`` when given."""
        if step is None:
            gui.progress.value = gui.progress.value + 1
        else:
            gui.progress.value = step

    def on_party_preset_change(change):  # pylint: disable=W0613
        """Copy the chosen preset into the party list ('ALL' selects every option)."""
        preset = gui.party_preset.value
        if preset is None:
            return
        gui.parties.value = gui.parties.options if 'ALL' in preset else preset

    gui.party_preset.observe(on_party_preset_change, names='value')

    itw = widgets.interactive(
        display_topic_quantity_groups,
        period_group_index=gui.period_group_index,
        topic_group_name=gui.topic_group_name,
        parties=gui.parties,
        recode_is_cultural=gui.recode_is_cultural,
        normalize_values=gui.normalize_values,
        extra_other_category=gui.extra_other_category,
        chart_type_name=gui.chart_type_name,
        plot_style=gui.plot_style,
        chart_per_category=gui.chart_per_category,
        target_quantity=gui.target_quantity,
        progress=widgets.fixed(stepper),
        wti_index=widgets.fixed(wti_index),
        treaty_sources=gui.treaty_sources,
        print_args=widgets.fixed(print_args),
    )

    # Lay the controls out as five columns, left to right.
    selection_column = widgets.VBox([
        gui.period_group_index,
        gui.topic_group_name,
        gui.target_quantity,
        gui.party_preset,
    ])
    party_column = widgets.VBox([gui.treaty_sources, gui.parties])
    toggle_column = widgets.VBox([
        gui.recode_is_cultural,
        gui.extra_other_category,
        gui.chart_per_category,
    ])
    chart_column = widgets.VBox([gui.chart_type_name, gui.plot_style])
    status_column = widgets.VBox([gui.normalize_values, gui.progress])

    boxes = widgets.HBox([
        selection_column,
        party_column,
        toggle_column,
        chart_column,
        status_column,
    ])

    # The last child of the interactive container is its output area.
    display(widgets.VBox([boxes, itw.children[-1]]))
    itw.update()