Beispiel #1
0
def view_stats(
    sources: PathOrPathsOrDictOfStrList,
    references: PathOrPathsOrDictOfStrList,
    tags: Optional[PathOrPathsOrDictOfStrList] = None,
):
    """Display corpus statistics and sentence-length histograms.

    Renders the ``ipynb_stats.html`` template with the statistics returned by
    ``VizSeqStats.get`` and shows it via ``display``. Then draws one length
    histogram per text source followed by one per text reference, each with a
    red vertical line at the mean length.

    Args:
        sources: Source data, wrapped in ``VizSeqDataSources``.
        references: Reference data, wrapped in ``VizSeqDataSources``.
        tags: Optional tag data; when ``None`` no tag sources are passed to
            ``VizSeqStats.get``.

    Returns:
        None. All output goes through ``display`` and ``plt.show``.
    """
    _src = VizSeqDataSources(sources, text_merged=True)
    _ref = VizSeqDataSources(references, text_merged=True)
    _tags = None if tags is None else VizSeqDataSources(tags, text_merged=True)
    stats = VizSeqStats.get(_src, _ref, _tags)

    html = env.get_template('ipynb_stats.html').render(
        stats=stats.to_dict(formatting=True),
        enum_src_names_and_types=VizSeqDataPageView.get_enum(
            zip(_src.names, [t.name.title() for t in _src.data_types])),
        enum_ref_names=VizSeqDataPageView.get_enum(_ref.names))
    display(HTML(html))

    # One (title, lengths) panel per text source, then per text reference.
    # The number of axes is derived from this list so it always matches what
    # is actually drawn (previously references were counted by n_sources but
    # drawn by text_indices, leaving blank axes for non-text references).
    panels = [
        (f'Source {_src.names[idx]} Length', stats.src_lens[_src.names[idx]])
        for idx in _src.text_indices
    ]
    panels += [
        (f'Reference {_ref.names[idx]} Length',
         stats.ref_lens[_ref.names[idx]])
        for idx in _ref.text_indices
    ]
    if not panels:
        # Nothing textual to plot; plt.subplots rejects ncols=0.
        return

    n_plots = len(panels)
    # squeeze=False always yields a 2-D array of axes, so the single-plot
    # case needs no special handling.
    _, axes = plt.subplots(
        nrows=1, ncols=n_plots, figsize=(7 * n_plots, 5), squeeze=False)
    for cur_ax, (title, sent_lens) in zip(axes[0], panels):
        cur_ax.hist(sent_lens, density=True, bins=25)
        cur_ax.axvline(x=np.mean(sent_lens), color='red', linewidth=3)
        cur_ax.set_title(title)
    plt.show()
Beispiel #2
0
def view_examples(
    sources: PathOrPathsOrDictOfStrList,
    references: PathOrPathsOrDictOfStrList,
    hypothesis: Optional[PathOrPathsOrDictOfStrList] = None,
    metrics: Optional[List[str]] = None,
    query: str = '',
    page_sz: int = DEFAULT_PAGE_SIZE,
    page_no: int = DEFAULT_PAGE_NO,
    sorting: VizSeqSortingType = VizSeqSortingType.original,
    need_g_translate: bool = False,
    disable_alignment: bool = False,
):
    """Render one page of examples as an ``HTML`` object.

    Wraps the inputs in ``VizSeqDataSources``, fetches the requested page
    through ``VizSeqDataPageView.get`` and renders ``ipynb_view.html``.

    Args:
        sources: Source data.
        references: Reference data; must have the same length as ``sources``.
        hypothesis: Optional hypothesis data; when it yields sources, it must
            have the same length as ``references``.
        metrics: Metric names; replaced by ``None`` when there are no
            hypothesis sources.
        query: Query string forwarded to the page view.
        page_sz: Page size forwarded to the page view.
        page_no: Page number forwarded to the page view.
        sorting: Sorting type; its ``value`` is forwarded to the page view.
        need_g_translate: Request translations of the source texts; only
            honoured when the sources contain text.
        disable_alignment: Forwarded to the page view.

    Returns:
        ``HTML`` wrapping the rendered template.
    """
    src_data = VizSeqDataSources(sources)
    ref_data = VizSeqDataSources(references)
    hypo_data = VizSeqDataSources(hypothesis)

    # Metrics are meaningless without any hypothesis to score.
    if hypo_data.n_sources == 0:
        metrics = None
    assert len(src_data) == len(ref_data)
    assert hypo_data.n_sources == 0 or len(ref_data) == len(hypo_data)

    translate = need_g_translate and src_data.has_text
    page_options = dict(
        metrics=metrics,
        query=query,
        sorting=sorting.value,
        need_lang_tags=translate,
        disable_alignment=disable_alignment,
    )
    view = VizSeqDataPageView.get(
        src_data, ref_data, hypo_data, page_sz, page_no, **page_options)

    # Translate each source text on the page into its target language.
    if translate:
        translations = [
            get_g_translate(text, view.trg_lang[pos])
            for pos, text in enumerate(view.cur_src_text)
        ]
    else:
        translations = []

    template = env.get_template('ipynb_view.html')
    context = {
        'enum_metrics': VizSeqDataPageView.get_enum(metrics),
        'enum_models': VizSeqDataPageView.get_enum(hypo_data.text_names),
        'cur_idx': view.cur_idx,
        'src': view.viz_src,
        'ref': view.viz_ref,
        'hypo': view.viz_hypo,
        'enum_src_names_and_types': VizSeqDataPageView.get_enum(
            zip(src_data.names, [t.name for t in src_data.data_types])),
        'enum_ref_names': list(enumerate(ref_data.names)),
        'sent_scores': view.viz_sent_scores,
        'google_translation': translations,
        'span_highlight_js': SPAN_HIGHTLIGHT_JS,
        'total_examples': view.total_examples,
        'n_samples': view.n_samples,
        'n_cur_samples': view.n_cur_samples,
    }
    return HTML(template.render(**context))