def _get_scores(dir_path: str, metric: str, model: str):
    """Score one model's hypotheses in *dir_path* under *metric*.

    Loads the hypothesis, reference and tag data sources from *dir_path*,
    then runs the scorer for *metric* at both corpus and sentence level.
    """
    hypotheses = _get_hypo(dir_path, [model])
    references = _get_ref(dir_path)
    tag_source = _get_tag(dir_path)
    scorer = get_scorer(metric)(corpus_level=True, sent_level=True)
    return scorer.score(
        hypotheses.data[0].text, references.text, tags=tag_source.text
    )
def get_scores(
    sources: PathOrPathsOrDictOfStrList,
    references: PathOrPathsOrDictOfStrList,
    model_to_hypotheses: PathOrPathsOrDictOfStrList,
    metrics: List[str],
    tags: Optional[PathOrPathsOrDictOfStrList] = None,
    verbose: bool = False,
    problem: Optional[str] = None,  # was `str = None` (implicit Optional, PEP 484)
) -> Tuple[Dict, Dict]:
    """Compute corpus-level (and per-tag group) scores for each metric/model.

    Args:
        sources: source-side data (paths or dict of string lists).
        references: reference translations/outputs.
        model_to_hypotheses: per-model hypothesis outputs; keys become model
            names in the returned dicts.
        metrics: scorer ids; unknown ids are skipped with a warning.
        tags: optional per-example tags used to form score groups.
        verbose: forwarded to each scorer.
        problem: extra context passed only to the ProcGen-family scorers
            via their ``extra_args``.

    Returns:
        ``(corpus_scores, group_scores)`` — both keyed by metric, then model;
        ``group_scores`` is additionally keyed by tag.
    """
    # Copyright (c) Facebook, Inc. and its affiliates.
    # The code in this function is licensed under the MIT license.
    _srcs = VizSeqDataSources(sources)
    _refs = VizSeqDataSources(references)
    _hypos = VizSeqDataSources(model_to_hypotheses)
    _tags, tag_set = None, []
    if tags is not None:
        _tags = VizSeqDataSources(tags, text_merged=True)
        tag_set = sorted(_tags.unique())
        _tags = _tags.text
    models = _hypos.names
    all_metrics = get_scorer_ids()
    # Keep only recognized metric ids; warn (don't fail) on the rest.
    _metrics = []
    for s in metrics:
        if s in all_metrics:
            _metrics.append(s)
        else:
            logger.warning(f'"{s}" is not a valid metric.')

    def scorer_kwargs(s):
        # Per-metric scorer construction arguments; only corpus-level scores
        # are needed here, so sentence-level scoring is disabled.
        kwargs = {"corpus_level": True, "sent_level": False, "verbose": verbose}
        if s in (
            "kendall_task_ranking",
            "req_cov",
            "essential_req_cov",
            "achievement",
            "granularity",
        ):  # ProcGenScorer's
            kwargs["extra_args"] = {"problem": problem}
        return kwargs

    scores = {
        s: {
            m: get_scorer(s)(**scorer_kwargs(s)).score(
                _hypos.data[i].text, _refs.text, tags=_tags, sources=_srcs.text
            )
            for i, m in enumerate(models)
        }
        for s in _metrics
    }
    corpus_scores = {
        s: {m: scores[s][m].corpus_score for m in models} for s in _metrics
    }
    group_scores = {
        s: {t: {m: scores[s][m].group_scores[t] for m in models} for t in tag_set}
        for s in _metrics
    }
    return corpus_scores, group_scores
def view_scores(
    references: PathOrPathsOrDictOfStrList,
    hypothesis: Optional[PathOrPathsOrDictOfStrList],
    metrics: List[str],
    tags: Optional[PathOrPathsOrDictOfStrList] = None,
):
    """Render corpus- and group-level scores as an HTML table (for notebooks).

    Args:
        references: reference data sources.
        hypothesis: per-model hypothesis data sources; names become columns.
        metrics: scorer ids; unknown ids are skipped with a warning.
        tags: optional per-example tags used to form score groups.

    Returns:
        An ``IPython.display.HTML`` object with the rendered score table,
        including LaTeX/CSV export strings.
    """
    _ref = VizSeqDataSources(references)
    _hypo = VizSeqDataSources(hypothesis)
    _tags, tag_set = None, []
    if tags is not None:
        _tags = VizSeqDataSources(tags, text_merged=True)
        tag_set = sorted(_tags.unique())
        _tags = _tags.text
    models = _hypo.names
    all_metrics = get_scorer_ids()
    _metrics = []
    for s in metrics:
        if s in all_metrics:
            _metrics.append(s)
        else:
            # logger.warn() is deprecated (since Python 3.3); use warning(),
            # consistent with get_scores().
            logger.warning(f'"{s}" is not a valid metric.')
    scores = {
        s: {
            m: get_scorer(s)(corpus_level=True, sent_level=False).score(
                _hypo.data[i].text, _ref.text, tags=_tags
            )
            for i, m in enumerate(models)
        }
        for s in _metrics
    }
    corpus_scores = {
        s: {m: scores[s][m].corpus_score for m in models} for s in _metrics
    }
    group_scores = {
        s: {
            t: {m: scores[s][m].group_scores[t] for m in models}
            for t in tag_set
        }
        for s in _metrics
    }
    metrics_and_names = [[s, get_scorer_name(s)] for s in _metrics]
    html = env.get_template('ipynb_scores.html').render(
        metrics_and_names=metrics_and_names,
        models=models,
        tag_set=tag_set,
        corpus_scores=corpus_scores,
        group_scores=group_scores,
        corpus_and_group_score_latex=VizSeqWebView.latex_corpus_group_scores(
            corpus_scores, group_scores),
        corpus_and_group_score_csv=VizSeqWebView.csv_corpus_group_scores(
            corpus_scores, group_scores),
    )
    return HTML(html)
def get(
    cls,
    src: VizSeqDataSources,
    ref: VizSeqDataSources,
    hypo: VizSeqDataSources,
    page_sz: int,
    page_no: int,
    metrics: Optional[List[str]] = None,
    query: str = '',
    sorting: int = 0,
    sorting_metric: str = '',
    need_lang_tags: bool = False,
    disable_alignment: bool = False,
) -> VizSeqPageData:
    """Build one page of example data: filter, sort, paginate, score, render.

    Pipeline: (1) filter examples by *query*, (2) sort by the requested
    criterion, (3) slice out page *page_no* of size *page_sz*, (4) compute
    sentence-level scores for *metrics*, (5) produce visualized (aligned)
    source/reference/hypothesis text, and pack everything into a
    VizSeqPageData.
    """
    assert page_no > 0 and page_sz > 0  # page numbers are 1-based
    page_sz = min(page_sz, MAX_PAGE_SZ)
    metrics = [] if metrics is None else metrics
    models = hypo.text_names
    # query: filter on source text when present, else on reference text;
    # with no text at all, every index is kept.
    cur_idx = list(range(len(src)))
    if src.has_text:
        cur_idx = VizSeqFilter.filter(src.text, query)
    elif ref.has_text:
        cur_idx = VizSeqFilter.filter(ref.text, query)
    n_samples = len(cur_idx)
    # sorting: map the integer code to a VizSeqSortingType member
    # (assert rejects unknown codes).
    sorting = {e.value: e for e in VizSeqSortingType}.get(sorting, None)
    assert sorting is not None
    if sorting == VizSeqSortingType.random:
        cur_idx = VizSeqRandomSorter.sort(cur_idx)
    elif sorting == VizSeqSortingType.ref_len:
        cur_idx = VizSeqByLenSorter.sort(ref.main_text, cur_idx)
    elif sorting == VizSeqSortingType.ref_alphabetical:
        cur_idx = VizSeqByStrOrderSorter.sort(ref.main_text, cur_idx)
    elif sorting == VizSeqSortingType.src_len:
        # source-based sorts silently fall back to the filtered order
        # when there is no source text
        if src.has_text:
            cur_idx = VizSeqByLenSorter.sort(src.main_text, cur_idx)
    elif sorting == VizSeqSortingType.src_alphabetical:
        if src.has_text:
            cur_idx = VizSeqByStrOrderSorter.sort(src.main_text, cur_idx)
    elif sorting == VizSeqSortingType.metric:
        # sort by sentence-level scores of sorting_metric; unknown metric
        # ids leave the order unchanged
        if sorting_metric in get_scorer_ids():
            _cur_ref = [_select(t, cur_idx) for t in ref.text]
            scores = {
                m: get_scorer(sorting_metric)(
                    corpus_level=False, sent_level=True
                ).score(_select(t, cur_idx), _cur_ref).sent_scores
                for m, t in zip(models, hypo.text)
            }
            # transpose: per-model score lists -> per-example {model: score}
            scores = [{m: scores[m][i] for m in models}
                      for i in range(len(cur_idx))]
            cur_idx = VizSeqByMetricSorter.sort(scores, cur_idx)
    # pagination: end_idx is inclusive, hence the +1 in the slice
    start_idx, end_idx = _get_start_end_idx(len(cur_idx), page_sz, page_no)
    cur_idx = cur_idx[start_idx:end_idx + 1]
    n_cur_samples = len(cur_idx)
    # page data: materialize only the rows on this page
    cur_src = src.cached(cur_idx)
    cur_src_text = _select(src.main_text, cur_idx) if src.has_text else None
    cur_ref = [_select(t, cur_idx) for t in ref.text]
    cur_hypo = {n: _select(t, cur_idx) for n, t in zip(models, hypo.text)}
    # sent scores: metric -> model -> per-example scores, rounded for display
    cur_sent_scores = {
        s: {
            m: np.round(
                get_scorer(s)(corpus_level=False, sent_level=True).score(
                    hh, cur_ref).sent_scores, decimals=2)
            for m, hh in cur_hypo.items()
        }
        for s in metrics
    }
    # rendering: alignment visualizations are skipped entirely when
    # disable_alignment is set (plain text is shown instead)
    viz_src = cur_src
    if not disable_alignment:
        viz_src = VizSeqSrcVisualizer.visualize(cur_src, src.text_indices)
    viz_ref = cur_ref
    if not disable_alignment and cur_src_text is not None:
        viz_ref = VizSeqRefVisualizer.visualize(cur_src_text, cur_ref,
                                                src.main_text_idx)
    viz_hypo = cur_hypo
    if not disable_alignment:
        # hypotheses are aligned against the first reference set
        viz_hypo = VizSeqHypoVisualizer.visualize(cur_ref[0], cur_hypo, 0)
    # one {metric: rendered score dict} per example on the page
    viz_sent_scores = [{
        s: VizSeqDictVisualizer.visualize(
            {m: cur_sent_scores[s][m][i] for m in models})
        for s in metrics
    } for i in range(n_cur_samples)]
    trg_lang = None
    if need_lang_tags:
        # language detected from the first reference of each example
        trg_lang = [VizSeqLanguageTagger.tag_lang(r) for r in cur_ref[0]]
    return VizSeqPageData(
        viz_src=viz_src,
        viz_ref=viz_ref,
        viz_hypo=viz_hypo,
        cur_src=cur_src,
        cur_src_text=cur_src_text,
        cur_ref=cur_ref,
        cur_idx=cur_idx,
        viz_sent_scores=viz_sent_scores,
        trg_lang=trg_lang,
        n_cur_samples=n_cur_samples,
        n_samples=n_samples,
        total_examples=len(src),
    )