Ejemplo n.º 1
0
    def reduction_table(self, labels=None, vertical=False):
        """Summarize the model-reduction steps in a table

        For every predictor term and every reduction step the table shows the
        maximum ``t`` value (with significance stars based on the smallest
        ``p``) and the smallest ``p`` value. Cells for terms that are not part
        of a step are left empty.

        Parameters
        ----------
        labels : dict {str: str}
            Substitute new labels for predictors.
        vertical : bool
            Orient table vertically.

        Returns
        -------
        table : fmtxt.Table
        """
        if not self._reduction_results:
            self.execute()
        steps = self._reduction_results
        label_map = {} if labels is None else labels
        n_steps = len(steps)
        # union of the terms of all steps, in order of first appearance
        terms = list(dict.fromkeys(term for step in steps for term in step.keys()))

        # (step index, term) -> (t cell, p cell)
        cells = {}
        for term in terms:
            for i_step, step in enumerate(steps):
                if term in step:
                    res = step[term]
                    p_min = res.p.min()
                    cells[i_step, term] = (fmtxt.stat(res.t.max(), stars=p_min), fmtxt.p(p_min))
                else:
                    cells[i_step, term] = ('', '')

        if vertical:
            # one column per term, two rows (t and p) per step
            table = fmtxt.Table('ll' + 'l' * len(terms))
            table.cells('Step', '')
            for term in terms:
                table.cell(label_map.get(term, term))
            table.midrule()
            for i_step in range(n_steps):
                t_row = table.add_row()
                p_row = table.add_row()
                t_row.cells(i_step + 1, fmtxt.symbol('t', 'max'))
                p_row.cells('', fmtxt.symbol('p'))
                for term in terms:
                    t_cell, p_cell = cells[i_step, term]
                    t_row.cell(t_cell)
                    p_row.cell(p_cell)
            return table

        # one row per term, two columns (t and p) per step
        table = fmtxt.Table('l' + 'rr' * n_steps)
        table.cell()
        for _ in range(n_steps):
            table.cell(fmtxt.symbol('t', 'max'))
            table.cell(fmtxt.symbol('p'))
        table.midrule()
        for term in terms:
            table.cell(label_map.get(term, term))
            for i_step in range(n_steps):
                table.cells(*cells[i_step, term])
        return table
Ejemplo n.º 2
0
def test_table():
    """Render, save and re-read a small two-column fmtxt.Table"""
    def read(path):
        # close each file right after reading (bare open().read() leaks handles)
        with open(path) as fid:
            return fid.read()

    table = fmtxt.Table('ll')
    table.cells('A', 'B')
    table.midrule()
    table.cells('a1', 'b1', 'a2', 'b2')
    assert str(table) == 'A    B \n-------\na1   b1\na2   b2'
    assert table.get_html() == ('<figure><table rules="none" cellpadding="2" '
                                'frame="hsides" border="1"><tr>\n'
                                ' <td>A</td>\n <td>B</td>\n</tr>\n<tr>\n'
                                ' <td>a1</td>\n <td>b1</td>\n</tr>\n<tr>\n'
                                ' <td>a2</td>\n <td>b2</td>\n</tr></table></figure>')
    assert table.get_rtf() == ('\\trowd\n\\cellx0000\n\\cellx1000\n\\row\n'
                               'A\\intbl\\cell\nB\\intbl\\cell\n\\row\n'
                               'a1\\intbl\\cell\nb1\\intbl\\cell\n\\row\n'
                               'a2\\intbl\\cell\nb2\\intbl\\cell\n\\row')
    assert table.get_tex() == ('\\begin{center}\n\\begin{tabular}{ll}\n\\toprule\n'
                               'A & B \\\\\n\\midrule\n'
                               'a1 & b1 \\\\\na2 & b2 \\\\\n'
                               '\\bottomrule\n\\end{tabular}\n\\end{center}')

    # empty table
    str(fmtxt.Table(''))

    # saving
    tempdir = TempDir()
    # HTML
    path = os.path.join(tempdir, 'test.html')
    table.save_html(path)
    assert read(path) == ('<!DOCTYPE html>\n<html>\n<head>\n'
                          '    <title>Untitled</title>\n'
                          '<style>\n\n.float {\n    float:left\n}\n\n'
                          '</style>\n</head>\n\n'
                          '<body>\n\n<figure>'
                          '<table rules="none" cellpadding="2" frame="hsides" '
                          'border="1"><tr>\n'
                          ' <td>A</td>\n <td>B</td>\n</tr>\n<tr>\n'
                          ' <td>a1</td>\n <td>b1</td>\n</tr>\n<tr>\n'
                          ' <td>a2</td>\n <td>b2</td>\n</tr>'
                          '</table></figure>\n\n</body>\n</html>\n')
    # rtf
    path = os.path.join(tempdir, 'test.rtf')
    table.save_rtf(path)
    assert read(path) == ('{\\rtf1\\ansi\\deff0\n\n'
                          '\\trowd\n\\cellx0000\n\\cellx1000\n\\row\n'
                          'A\\intbl\\cell\nB\\intbl\\cell\n\\row\n'
                          'a1\\intbl\\cell\nb1\\intbl\\cell\n\\row\n'
                          'a2\\intbl\\cell\nb2\\intbl\\cell\n\\row\n}')
    # TeX
    path = os.path.join(tempdir, 'test.tex')
    table.save_tex(path)
    assert read(path) == ('\\begin{center}\n\\begin{tabular}{ll}\n\\toprule\n'
                          'A & B \\\\\n\\midrule\n'
                          'a1 & b1 \\\\\na2 & b2 \\\\\n'
                          '\\bottomrule\n\\end{tabular}\n\\end{center}')
    # txt
    path = os.path.join(tempdir, 'test.txt')
    table.save_txt(path)
    assert read(path) == 'A    B \n-------\na1   b1\na2   b2'
Ejemplo n.º 3
0
 def clusters(self, p=0.05):
     """Table with significant clusters

     Parameters
     ----------
     p : scalar
         Threshold passed to ``find_clusters()`` for each result.

     Returns
     -------
     table : fmtxt.Table
         One heading row per test, followed by one row per cluster (sorted
         by onset) with its time window, p-value and significance marker.

     Raises
     ------
     NotImplementedError
         For ``LMGroup`` tests.
     """
     if self.test_type is LMGroup:
         raise NotImplementedError
     table = fmtxt.Table('lrrll')
     header = ('Effect', 't-start', 't-stop', fmtxt.symbol('p'), 'sig')
     table.cells(*header, just='l')
     table.midrule()
     for name, res in self.items():
         table.cell(name)
         table.endline()
         found = res.find_clusters(p)
         found.sort('tstart')
         if self.test_type != anova:
             # the effect column is blanked for anything but an ANOVA
             found[:, 'effect'] = ''
         rows = found.zip('effect', 'tstart', 'tstop', 'p', 'sig')
         for effect, t_start, t_stop, p_value, sig in rows:
             table.cells(f'  {effect}', ms(t_start), ms(t_stop), fmtxt.p(p_value), sig)
     return table
Ejemplo n.º 4
0
    def cv_info(self):
        """Table summarizing the cross-validation results

        One row per ``mu`` value (sorted by ``mu``). A star marks the ``mu``
        selected by ``cv_mu('cross-fit')`` in the cross-fit column and the
        one selected by ``cv_mu('l2/mu')`` in the l2-error column. A caption
        warns if the current ``mu`` is the smallest one tested.

        Returns
        -------
        table : fmtxt.Table

        Raises
        ------
        ValueError
            If no cross-validation was performed.
        """
        if self._cv_results is None:
            raise ValueError(
                "CV: no cross-validation was performed. Use mu='auto' to perform cross-validation."
            )
        cv_results = sorted(self._cv_results, key=attrgetter('mu'))
        criteria = ('cross-fit', 'l2/mu')
        best_mu = {criterion: self.cv_mu(criterion) for criterion in criteria}

        table = fmtxt.Table('lllll')
        table.cells('mu', 'cross-fit', 'l2-error', 'weighted l2-error',
                    'ES metric')
        table.midrule()
        fmt = '%.5f'
        for result in cv_results:
            table.cell(fmtxt.stat(result.mu, fmt=fmt))
            # compare by value: ``is`` on floats only matches if cv_mu()
            # happens to return the very same object
            star = 1 if result.mu == best_mu['cross-fit'] else 0
            table.cell(fmtxt.stat(result.cross_fit, fmt, star, 1))
            star = 1 if result.mu == best_mu['l2/mu'] else 0
            table.cell(fmtxt.stat(result.l2_error, fmt, star, 1))
            table.cell(fmtxt.stat(result.weighted_l2_error, fmt=fmt))
            table.cell(fmtxt.stat(result.estimation_stability, fmt=fmt))
        # warnings
        mus = [res.mu for res in self._cv_results]
        warnings = []
        if self.mu == min(mus):
            warnings.append("Best mu is smallest mu")
        if warnings:
            table.caption(f"Warnings: {'; '.join(warnings)}")
        return table
Ejemplo n.º 5
0
def model_comparison_table(x1: Model,
                           x0: Model,
                           x1_name: str = 'x1',
                           x0_name: str = 'x0'):
    """Generate a table comparing the terms in two models

    Each term of ``x1`` is paired with the identical term of ``x0`` or,
    failing that, with the first ``x0`` term that starts with ``'<term>$'``.
    Every ``x0`` term is used at most once; ``x0`` terms that remain
    unmatched are listed at the end next to an empty ``x1`` cell.

    Parameters
    ----------
    x1, x0
        Models to compare.
    x1_name, x0_name
        Column headings.

    Returns
    -------
    table : fmtxt.Table
    """
    available = list(x0.term_names)
    pairs = []
    for term in x1.term_names:
        if term in available:
            match = term
        else:
            prefix = f'{term}$'
            match = next((name for name in available if name.startswith(prefix)), '')
        pairs.append((term, match))
        if match:
            available.remove(match)
    pairs += [('', name) for name in available]

    table = fmtxt.Table('ll')
    table.cells(x1_name, x0_name)
    table.midrule()
    for left, right in pairs:
        table.cells(left, right)
    return table
Ejemplo n.º 6
0
 def term_table(self):
     """Table describing the structured model terms

     One row per term: its index, the index of its parent term (blank if
     the parent index is negative), the term string and its randomization
     (prefixed with ``$``).
     """
     table = fmtxt.Table('rrll')
     table.cells('#', 'dep', 'term', 'randomization')
     table.midrule()
     for index, term in enumerate(self.terms):
         if term.parent >= 0:
             dep = term.parent
         else:
             dep = ''
         table.cells(index, dep, term.string, f'${term.shuffle}')
     return table
Ejemplo n.º 7
0
 def table(self, title=None, caption=None):
     """Table with effects and smallest p-value

     Parameters
     ----------
     title, caption
         Passed on to ``fmtxt.Table``.

     Returns
     -------
     table : fmtxt.Table
         For ``LMGroup`` tests: one row per test and one column per
         coefficient, showing the smallest p-value with stars. For ``anova``:
         one row per test and effect. Otherwise: one row per test.
     """
     if self.test_type is LMGroup:
         columns = set()
         for res in self.values():
             columns.update(res.column_names)
         cols = sorted(columns)
         table = fmtxt.Table('l' * (len(cols) + 1), title=title, caption=caption)
         table.cell('')
         table.cells(*cols)
         table.midrule()
         for key, lmg in self.items():
             table.cell(key)
             for col in cols:
                 p_min = lmg.tests[col].p.min()
                 table.cell(fmtxt.FMText([fmtxt.p(p_min), star(p_min)]))
         return table

     if self.test_type is anova:
         table = fmtxt.Table('lllll', title=title, caption=caption)
         table.cells('Test', 'Effect',
                     fmtxt.symbol(self.test_type._statistic, 'max'),
                     fmtxt.symbol('p'), 'sig')
         table.midrule()
         for key, res in self.items():
             for i, effect in enumerate(res.effects):
                 # the test name is only shown on the first row of each test
                 table.cells(key if i == 0 else '', effect)
                 p_min = res.p[i].min()
                 table.cell(fmtxt.stat(res._max_statistic(i)))
                 table.cell(fmtxt.p(p_min))
                 table.cell(star(p_min))
         return table

     table = fmtxt.Table('llll', title=title, caption=caption)
     table.cells('Effect',
                 fmtxt.symbol(self.test_type._statistic, 'max'),
                 fmtxt.symbol('p'), 'sig')
     table.midrule()
     for key, res in self.items():
         table.cell(key)
         p_min = res.p.min()
         table.cell(fmtxt.stat(res._max_statistic()))
         table.cell(fmtxt.p(p_min))
         table.cell(star(p_min))
     return table
Ejemplo n.º 8
0
    def report(
            self,
            brain_view: Union[str, Sequence[float]] = None,
            axw: float = None,
            surf: str = 'inflated',
            cortex: Any = ((1.00,) * 3, (.4,) * 3),
    ):
        """Brain plots of the term differences plus a localization-test table

        Parameters
        ----------
        brain_view
            View for the brain plots (passed to ``BrainLayout``).
        axw
            Axes width (passed to ``BrainLayout``).
        surf
            Brain surface to plot on.
        cortex
            Cortex coloring, passed to ``set_brain_args``.

        Returns
        -------
        layout : fmtxt.FloatingLayout
            The brain-plot table, followed by a table of pairwise localization
            tests between the terms in ``self.terms``, separately for the left
            and right hemisphere.
        """
        layout = BrainLayout(brain_view, axw)
        plotter = plot.brain.SequencePlotter()
        plotter.set_brain_args(mask=(0, 0, 0, 1))
        if layout.brain_view:
            plotter.set_parallel_view(*layout.brain_view)
        plotter.set_brain_args(surf=surf, cortex=cortex)
        # ROI overlay from the first two masks
        if self.masks:
            plotter.add_ndvar_label(self.masks[0] + self.masks[1], color=(0, 1, 0), borders=2, overlay=True)
        # difference maps, each hemisphere divided by its own maximum
        cmap = plot.soft_threshold_colormap('polar-lux-a', .2, 1)
        for label, term in self.terms.items():
            res = self.ress[term]
            hemi_diffs = [res.difference.sub(source=hemi) for hemi in ('lh', 'rh')]
            normalized = [hemi_diff / hemi_diff.max() for hemi_diff in hemi_diffs]
            plotter.add_ndvar(concatenate(normalized, 'source'), cmap=cmap, vmax=1, label=label)
        brain_table = plotter.plot_table(view='lateral', orientation='vertical', **layout.table_args)

        # pairwise localization tests: rows and columns are terms
        n_terms = len(self.terms)
        table = fmtxt.Table('l' * (1 + 2 * n_terms))
        # first header row: one hemisphere heading spanning all terms
        table.cell('')
        table.cell('Left H', width=n_terms)
        table.cell('Right H', width=n_terms)
        # second header row: term labels, once per hemisphere
        table.cell('')
        table.cells(*self.terms.keys())
        table.cells(*self.terms.keys())
        table.midrule()
        for label, row_term in self.terms.items():
            table.cell(label)
            for hemi in ('lh', 'rh'):
                for col_term in self.terms.values():
                    if row_term == col_term:
                        table.cell('')
                    else:
                        res = self.loc_ress[row_term, col_term, hemi].f_tests[0]
                        table.cell([fmtxt.Stars.from_p(res.p), res._asfmtext()])
        return fmtxt.FloatingLayout([brain_table, table])
Ejemplo n.º 9
0
 def show_jobs(self, trfs=False, width=None):
     """Table summarizing this user's jobs

     Parameters
     ----------
     trfs : bool
         For each job, add one row per worker with the number of the job's
         pending TRFs assigned to it (``'requested'`` for pending TRFs
         without a worker).
     width : int
         Maximum length of the model description; longer descriptions are
         truncated with ``'...'`` (default: terminal width minus 50, but at
         least 20).

     Returns
     -------
     table : fmtxt.Table
         Columns: experiment class, epoch, model, number of TRFs, number of
         missing TRFs, priority (only if jobs differ in priority) and report
         status (checked box: report file exists; empty box: report file
         missing; crossed box: ``test_path`` is False).
     """
     if width is None:
         # floor keeps truncation meaningful in very narrow terminals
         width = max(shutil.get_terminal_size((150, 20))[0] - 50, 20)
     pending_jobs = {job.path: job for job in self.server.pending_jobs()}
     priority = len({job.priority for job in self._user_jobs}) > 1
     t = fmtxt.Table('lllrrl' + 'l' * priority)
     t.cells("Exp.", "Epoch", "Model", "TRFs", "Pending")
     if priority:
         t.cell("Priority")
     t.cell('Report')
     t.midrule()
     for job in self._user_jobs:
         if job.trf_jobs is None:
             # notice spread over the two numeric columns
             n_trfs = '<not'
             n_missing = 'initialized>'
             report = ''
         else:
             n_trfs = len(job.trf_jobs)
             n_missing = len(job.missing_trfs)
             if job.test_path:
                 if exists(job.test_path):
                     report = '\u2611'
                 else:
                     report = '\u2610'
             elif job.test_path is False:
                 report = '\u2612'
             else:
                 report = ''
         job_desc = job.model_name
         if len(job_desc) > width:
             # clamp at 0: a negative stop index would count from the end
             job_desc = job_desc[:max(width - 3, 0)] + '...'
         t.cell(job.experiment.__class__.__name__)  # Exp
         t.cell(job.options.get('epoch', ''))  # Epoch
         t.cell(job_desc)  # Model
         t.cell(n_trfs)  # TRFs
         t.cell(n_missing)  # Pending
         if priority:
             t.cell(job.priority)
         t.cell(report)
         # TRFs currently being processed
         if trfs and job.trf_jobs:
             trf_jobs = [j for j in job.trf_jobs if j.path in pending_jobs]
             n = Counter(pending_jobs[j.path].worker or 'requested'
                         for j in trf_jobs)
             for worker in sorted(n):
                 t.cells('', '')
                 t.cell(self.server._worker_info.get(worker, worker),
                        just='r')
                 t.cells('', n[worker])
                 t.endline()
     return t
Ejemplo n.º 10
0
 def term_table(self) -> fmtxt.Table:
     """Table listing the terms, one row per term

     The ``Stimulus`` and ``Shuffle`` columns are only included if at least
     one term has a stimulus or a shuffle string, respectively.
     """
     show_stimulus = any(term.stimulus for term in self.terms)
     show_shuffle = any(term.shuffle_string for term in self.terms)
     t = fmtxt.Table('l' * (1 + show_stimulus + show_shuffle))
     # complete the header row before the midrule; previously 'Shuffle' was
     # added after it and ended up shifting all data cells
     if show_stimulus:
         t.cell('Stimulus')
     t.cell('Code')
     if show_shuffle:
         t.cell('Shuffle')
     t.midrule()
     for term in self.terms:
         if show_stimulus:
             t.cell(term.stimulus)
         t.cell(term.code)
         if show_shuffle:
             t.cell(term.shuffle_string)
     return t
Ejemplo n.º 11
0
 def table(self, t_start: float = None, t_stop: float = None):
     """fmtxt.Table representation

     One row per phone whose time lies within ``[t_start, t_stop]``. The
     realization index and word are shown on the first listed row of each
     word only. Realizations without phones are not listed.

     Parameters
     ----------
     t_start
         Start of the time window (default ``self.tmin``).
     t_stop
         End of the time window (default: ``tstop`` of the last realization).
     """
     if t_start is None:
         t_start = self.tmin
     if t_stop is None:
         # without realizations there are no rows, so the value is irrelevant
         t_stop = self.realizations[-1].tstop if self.realizations else t_start
     table = fmtxt.Table('rrll')
     table.cells('#', 'Time', 'Word', 'Phone')
     for i, r in enumerate(self.realizations):
         if not r.phones:
             continue
         index, word = i, r.graphs
         for time, phone in zip(r.times, r.phones):
             if t_start <= time <= t_stop:
                 table.cells(index, f'{time:.3f}', word, phone)
                 # blank index and word only once they have actually been shown
                 index = word = ''
     return table
Ejemplo n.º 12
0
 def clusters(self, p=0.05):
     """Table with significant clusters

     Parameters
     ----------
     p : scalar
         Threshold passed to ``find_clusters()`` for each result.

     Returns
     -------
     table : fmtxt.Table
         One heading row per test, then one row per cluster (sorted by onset)
         with its time window, maximum statistic, time of that maximum,
         p-value and significance marker.

     Raises
     ------
     NotImplementedError
         For two-stage tests.
     """
     if self.test_type is TestType.TWO_STAGE:
         raise NotImplementedError
     table = fmtxt.Table('lrrrrll')
     table.cells(
         'Effect',
         't-start',
         't-stop',
         fmtxt.symbol(self._statistic, 'max'),
         fmtxt.symbol('t', 'peak'),
         fmtxt.symbol('p'),
         'sig',
         just='l',
     )
     table.midrule()
     for name, res in self.items():
         table.cell(name)
         table.endline()
         found = res.find_clusters(p, maps=True)
         found.sort('tstart')
         if self.test_type is not TestType.MULTI_EFFECT:
             found[:, 'effect'] = ''
         rows = found.zip('effect', 'tstart', 'tstop', 'p', 'sig', 'cluster')
         for effect, t_start, t_stop, p_value, sig, cluster_map in rows:
             # statistic maximum restricted to this cluster's extent
             peak_value, peak_time = res._max_statistic(mask=cluster_map != 0, return_time=True)
             table.cells(
                 f'  {effect}', ms(t_start), ms(t_stop), fmtxt.stat(peak_value),
                 ms(peak_time), fmtxt.p(p_value), sig,
             )
     return table
Ejemplo n.º 13
0
    def remove_broken_jobs(self, pattern):
        """Re-queue (or drop) pending jobs for which processing failed

        Parameters
        ----------
        pattern : int | str | list
            Job model or comparison, job path pattern, or one or more job IDs.

        Notes
        -----
        Lists the matching pending jobs and asks whether to re-queue them,
        drop them, or do nothing. Assumes that the corresponding jobs are not
        being worked on anymore; otherwise their results will later be
        received as orphans and may overwrite results (the original note was
        cut off at this point — TODO confirm the intended wording).
        """
        # a string pattern is first matched against model names
        model_jobs = None
        if isinstance(pattern, str):
            model_jobs = [
                job for job in self._user_jobs
                if job.trf_jobs and fnmatch.fnmatch(job.model_name, pattern)
            ]

        if model_jobs:
            # pending TRF jobs belonging to the matched models
            keys = {trf_job.path for job in model_jobs for trf_job in job.trf_jobs}
            keys.intersection_update(job.path for job in self.server.pending_jobs())
            keys = list(keys)
        else:
            keys = self.server.find_jobs(pattern, 'pending')

        if not keys:
            print("No jobs match pattern")
            return

        prefix = commonprefix(keys)
        n_prefix = len(prefix)
        table = fmtxt.Table('lll')
        table.cells("Job", "Worker", "Orphan")
        table.midrule()
        table.caption("Common prefix: %r" % (prefix, ))
        for key in keys:
            desc = key[n_prefix:]
            if len(desc) >= 100:
                desc = desc[:97] + '...'
            orphan = 'x' if key not in self._trf_jobs else ''
            table.cells(desc, self.server._jobs[key].worker, orphan)
        print(table)

        options = {
            'requeue': 'requeue jobs',
            'drop': 'drop jobs',
            'abort': "don't do anything (default)",
        }
        command = ask(f"Remove {len(keys)} jobs?", options, allow_empty=True)
        if command not in ('requeue', 'drop'):
            return

        self.server.remove_broken_job(keys)
        n_restarted = n_skipped = 0
        for key in keys:
            if key not in self._trf_jobs:
                n_skipped += 1
            elif command == 'requeue':
                self._trf_job_queue.appendleft(key)
                n_restarted += 1
            # FIXME: with 'drop', the job is not yet removed properly
        print(f"{n_restarted} restarted, {n_skipped} skipped")
Ejemplo n.º 14
0
from eelbrain import fmtxt

# header cells are wrapped in \textbf, body cells are plain
header = ("Animal", "Outside", "Inside")
rows = (
    ("Duck", "Feathers", "Duck Meat"),
    ("Dog", "Fur", "Hotdog Meat"),
)

table = fmtxt.Table('lll')
for heading in header:
    table.cell(heading, r"\textbf")
table.midrule()
for row in rows:
    for entry in row:
        table.cell(entry)

# print the string representation
print(table)
Ejemplo n.º 15
0
def test_table():
    """Render, save, re-read and edit a small two-column fmtxt.Table"""
    def read(path):
        # close each file right after reading (bare open().read() leaks handles)
        with open(path) as fid:
            return fid.read()

    table = fmtxt.Table('ll')
    table.cells('A', 'B')
    table.midrule()
    table.cells('a1', 'b1', 'a2', 'b2')
    assert str(table) == 'A    B \n-------\na1   b1\na2   b2'
    assert html(table) == (
        '<figure><table border="1" cellpadding="2" frame="hsides" rules="none"><tr>\n'
        ' <td>A</td>\n <td>B</td>\n</tr>\n<tr>\n'
        ' <td>a1</td>\n <td>b1</td>\n</tr>\n<tr>\n'
        ' <td>a2</td>\n <td>b2</td>\n</tr></table></figure>')
    assert table.get_rtf() == ('\\trowd\n\\cellx0000\n\\cellx1000\n\\row\n'
                               'A\\intbl\\cell\nB\\intbl\\cell\n\\row\n'
                               'a1\\intbl\\cell\nb1\\intbl\\cell\n\\row\n'
                               'a2\\intbl\\cell\nb2\\intbl\\cell\n\\row')
    assert table.get_tex() == (
        '\\begin{center}\n\\begin{tabular}{ll}\n\\toprule\n'
        'A & B \\\\\n\\midrule\n'
        'a1 & b1 \\\\\na2 & b2 \\\\\n'
        '\\bottomrule\n\\end{tabular}\n\\end{center}')

    # empty table
    str(fmtxt.Table(''))

    # saving
    tempdir = TempDir()
    # HTML
    path = os.path.join(tempdir, 'test.html')
    table.save_html(path)
    assert read(path) == (
        '<!DOCTYPE html>\n<html>\n<head>\n'
        '    <title>Untitled</title>\n'
        '<style>\n\n.float {\n    float:left\n}\n\n'
        '</style>\n</head>\n\n'
        '<body>\n\n<figure>'
        '<table border="1" cellpadding="2" frame="hsides" rules="none"><tr>\n'
        ' <td>A</td>\n <td>B</td>\n</tr>\n<tr>\n'
        ' <td>a1</td>\n <td>b1</td>\n</tr>\n<tr>\n'
        ' <td>a2</td>\n <td>b2</td>\n</tr>'
        '</table></figure>\n\n</body>\n</html>\n')
    # rtf
    path = os.path.join(tempdir, 'test.rtf')
    table.save_rtf(path)
    assert read(path) == ('{\\rtf1\\ansi\\deff0\n\n'
                          '\\trowd\n\\cellx0000\n\\cellx1000\n\\row\n'
                          'A\\intbl\\cell\nB\\intbl\\cell\n\\row\n'
                          'a1\\intbl\\cell\nb1\\intbl\\cell\n\\row\n'
                          'a2\\intbl\\cell\nb2\\intbl\\cell\n\\row\n}')
    # TeX
    path = os.path.join(tempdir, 'test.tex')
    table.save_tex(path)
    assert read(path) == (
        '\\begin{center}\n\\begin{tabular}{ll}\n\\toprule\n'
        'A & B \\\\\n\\midrule\n'
        'a1 & b1 \\\\\na2 & b2 \\\\\n'
        '\\bottomrule\n\\end{tabular}\n\\end{center}')
    # txt
    path = os.path.join(tempdir, 'test.txt')
    table.save_txt(path)
    assert read(path) == 'A    B \n-------\na1   b1\na2   b2'

    # editing
    table[0, 0] = 'X'
    assert str(table) == 'X    B \n-------\na1   b1\na2   b2'
    table[0] = ['C', 'D']
    assert str(table) == 'C    D \n-------\na1   b1\na2   b2'
    table[2, 0] = 'cd'
    assert str(table) == 'C    D \n-------\ncd   b1\na2   b2'
    table[2:4, 1] = ['x', 'y']
    assert str(table) == 'C    D\n------\ncd   x\na2   y'
Ejemplo n.º 16
0
from eelbrain import fmtxt
table = fmtxt.Table('lll')
table.cell("Animal", r"\textbf")
table.cell("Outside", r"\textbf")
table.cell("Inside", r"\textbf")
table.midrule()
table.cell("Duck")
table.cell("Feathers")
table.cell("Duck Meat")
table.cell("Dog")
table.cell("Fur")
table.cell("Hotdog Meat")
# print the string representation (print is a function in Python 3;
# the former `print table` statement is a SyntaxError)
print(table)
# save the table as pdf
table.save_pdf('table.pdf')

Ejemplo n.º 17
0
def source_trfs(
        ress: ResultCollection,
        heading: FMTextArg = None,
        brain_view: Union[str, Sequence[float]] = None,
        axw: float = None,
        surf: str = 'inflated',
        cortex: Any = ((1.00,) * 3, (.4,) * 3),
        vmax: float = None,
        xlim: Tuple[float, float] = None,
        times: Sequence[float] = None,
        cmap: str = None,
        labels: Dict[str, str] = None,
        rasterize: bool = None,
        brain_timewindow: float = 0.050
):
    """Butterfly plot and anatomical plots for each TRF in ``ress``

    Only used for TRFExperiment model-test.

    Parameters
    ----------
    ress
        Test results; each result's ``masked_difference()`` is plotted.
    heading
        If given, the figure is placed in an ``fmtxt.Section`` with this
        heading.
    brain_view, axw
        Passed to ``BrainLayout``.
    surf, cortex
        Brain surface and cortex coloring for the anatomical plots.
    vmax
        Limit for the butterfly plots and the anatomical color scale (by
        default the anatomical scale is taken from the data).
    xlim
        Time-axis limits of the butterfly plots.
    times
        Times for the anatomical plots (default: peak times found with
        ``find_peak_times()`` for each TRF).
    cmap
        Colormap name for the anatomical plots (default ``'lux-a'``); a
        non-string colormap is used as is.
    labels
        Substitute labels for the keys of ``ress``.
    rasterize
        Rasterize the butterfly plots (default: decided from the first TRF,
        True if it has more than 500 sources).
    brain_timewindow
        Width of the time window (s) averaged for each anatomical plot.

    Returns
    -------
    doc : fmtxt.Section | fmtxt.FMText
        Contains a two-column table (butterfly plot, anatomical plots) with
        one row per TRF.
    """
    layout = BrainLayout(brain_view, axw)
    dt = brain_timewindow / 2

    if heading is not None:
        doc = fmtxt.Section(heading)
    else:
        doc = fmtxt.FMText()

    if cmap is None:
        cmap = 'lux-a'

    if labels is None:
        labels = {}

    trf_table = fmtxt.Table('ll')
    for key, res in ress.items():
        trf_resampled = resample(res.masked_difference(), 1000)
        label = labels.get(key, key)
        if rasterize is None:
            # decided once, from the first TRF, and reused for the others
            rasterize = len(trf_resampled.source) > 500
        # times for anatomical plots
        if times is None:
            trf_tc = abs(trf_resampled).sum('source')
            trf_tc_mask = (~trf_resampled.get_mask()).sum('source') >= 10
            times_ = find_peak_times(trf_tc, trf_tc_mask)
        else:
            times_ = times
        # butterfly-plot
        p = plot.Butterfly(trf_resampled, h=3, w=4, ylabel=False, title=label, vmax=vmax, xlim=xlim, show=False)
        for t in times_:
            p.add_vline(t, color='k')
        trf_table.cell(fmtxt.asfmtext(p, rasterize=rasterize))
        p.close()
        # peak sources
        if len(times_) == 0:
            # nothing to show anatomically: fill the second column and skip the
            # brain plot, so the row keeps exactly two cells
            trf_table.cell()
            continue
        sp = plot.brain.SequencePlotter()
        if layout.brain_view:
            sp.set_parallel_view(*layout.brain_view)
        sp.set_brain_args(surf=surf, cortex=cortex)
        for t in times_:
            yt = trf_resampled.mean(time=(t - dt, t + dt + 0.001))
            if isinstance(cmap, str):
                vmax_ = vmax or max(-yt.min(), yt.max()) or 1
                cmap_ = plot.soft_threshold_colormap(cmap, vmax_ / 10, vmax_)
            else:
                cmap_ = cmap
            sp.add_ndvar(yt, cmap=cmap_, label=f'{t * 1000:.0f} ms', smoothing_steps=10)
        p = sp.plot_table(view='lateral', orientation='vertical', **layout.table_args)
        trf_table.cell(p)
        p.close()
    doc.append(fmtxt.Figure(trf_table))
    return doc
Ejemplo n.º 18
0
def gentle_to_grid(gentle_file, out_file=None):
    """Convert *.json file from Gentle to Praat TextGrid

    Parameters
    ----------
    gentle_file : str | Path
        Gentle alignment ``*.json`` file. If the name contains ``*``, all
        matching files are converted (``out_file`` must then be ``None``).
    out_file : str | Path
        Destination (default: ``gentle_file`` with ``.TextGrid`` suffix). A
        suffix other than ``.TextGrid`` (case-insensitive) is replaced.

    Raises
    ------
    TypeError
        If ``out_file`` is specified for a batch conversion.

    Notes
    -----
    Unaligned words between aligned words share the gap evenly; unaligned
    words after the last aligned word get 100 ms each. Words with issues
    (unsuccessful alignment, ``<unk>``, shorter than 15 ms) are printed and
    saved to a ``.log`` file next to ``out_file``.
    """
    # str(): ``'*' in path`` raises TypeError for Path objects
    if '*' in str(gentle_file):
        if out_file is not None:
            raise TypeError("out can not be set during batch-conversion")
        for filename in glob(str(gentle_file)):
            gentle_to_grid(filename)
        return

    gentle_file = Path(gentle_file)
    if out_file is None:
        out_file = gentle_file.with_suffix('.TextGrid')
    else:
        out_file = Path(out_file)
        if out_file.suffix.lower() != '.textgrid':
            out_file = out_file.with_suffix('.TextGrid')

    # JSON is UTF-8; don't depend on the platform's default encoding
    with gentle_file.open(encoding='utf-8') as fid:
        g = json.load(fid)

    # find valid words
    words = g['words']
    n_issues = 0
    for word in words:
        if word['case'] == 'success':
            if word['alignedWord'] == '<unk>':
                n_issues += 1
                word['issue'] = 'OOV'
            else:
                word['issue'] = None
        else:
            n_issues += 1
            word['issue'] = word['case']

    # add missing times
    last_end = 0
    not_in_audio_words = []  # unaligned words waiting for the next aligned word
    for word in words:
        if 'start' in word:
            if not_in_audio_words:
                # share the gap evenly (previously each word got the whole gap)
                duration = (word['start'] - last_end) / len(not_in_audio_words)
                for j, word_ in enumerate(not_in_audio_words):
                    word_['start'] = last_end + j * duration
                    word_['end'] = last_end + (j + 1) * duration
                not_in_audio_words = []
            last_end = word['end']
        else:
            not_in_audio_words.append(word)
    for word in not_in_audio_words:
        word['start'] = last_end
        word['end'] = last_end = last_end + 0.100

    # round times
    for word in words:
        word['start'] = round(word['start'], 3)
        word['end'] = round(word['end'], 3)

    # avoid overlapping words: walking backwards, each word ends no later
    # than the following word starts (inf also handles an empty word list)
    last_start = float('inf')
    for word in reversed(words):
        if word['end'] > last_start:
            word['end'] = last_start
        if word['start'] >= word['end']:
            word['start'] = word['end'] - .001
        last_start = word['start']
        # gentle seems to work at 10 ms resolution; 'issue' is always present
        # (set above), so test its value rather than the key
        if word['end'] - word['start'] < 0.015 and not word['issue']:
            word['issue'] = 'short'
            n_issues += 1

    # log issues
    if n_issues:
        log = fmtxt.Table('rrrll')
        log.cell('Time')
        log.cell('Duration', width=2)
        log.cells('Word', 'Issue')
        log.midrule()
        for word in words:
            if word['issue']:
                duration = word['end'] - word['start']
                d_marker = '*' if duration < 0.015 else ''
                log.cells(f"{word['start']:.3f}", d_marker, f"{duration:.3f}",
                          word['word'], word['issue'])
        print(log)
        log.save_tsv(out_file.with_suffix('.log'))

    # build textgrid
    phone_tier = textgrid.IntervalTier('phones')
    word_tier = textgrid.IntervalTier('words')
    for word in words:
        t = word['start']
        word_tstop = word['end']
        # add word
        word_tier.add(t, word_tstop, word['word'])
        # make sure we have at least one phone
        phones = word.get('phones', ())
        if not phones:
            phones = ({'phone': '', 'duration': word['end'] - word['start']}, )
        # add phones
        for phone in phones:
            tstop = min(round(t + phone['duration'], 3), word_tstop)
            if t >= tstop:
                continue
            mark = phone['phone'].split('_')[0].upper()
            if mark != 'OOV':
                phone_tier.add(t, tstop, mark)
            # advance past skipped OOV phones too, so later phones keep their timing
            t = tstop
    grid = textgrid.TextGrid()
    grid.extend((phone_tier, word_tier))
    grid.write(out_file)