Esempio n. 1
0
 def _document_fe_graph(self) -> None:
     """Render the FE execution graphs into the traceability document.

     A "FastEstimator Architecture" section is created. Inside it, each pipeline mode that has at least one
     signature epoch with data gets a subsection, each such epoch gets a labelled subsubsection, and each dataset
     id of that epoch gets one diagram (wrapped in a labelled paragraph unless the dataset id is an empty string).
     """
     with self.doc.create(Section("FastEstimator Architecture")):
         pipeline = self.system.pipeline
         for mode in pipeline.data.keys():
             # Gather everything which may be scheduled for this mode: pipeline ops, network ops, then traces.
             pipeline_items = pipeline.get_scheduled_items(mode)
             network_items = self.system.network.get_scheduled_items(mode)
             scheduled_items = pipeline_items + network_items + self.system.traces
             signature_epochs = get_signature_epochs(scheduled_items,
                                                     total_epochs=self.system.epoch_idx,
                                                     mode=mode)
             epochs_with_data = pipeline.get_epochs_with_data(total_epochs=self.system.epoch_idx, mode=mode)
             if not (set(signature_epochs) & epochs_with_data):
                 # None of the signature epochs has data in this mode, so there is nothing to draw.
                 continue
             self.doc.append(NoEscape(r'\FloatBarrier'))
             with self.doc.create(Subsection(mode.capitalize())):
                 # Only signature epochs which actually have data are documented.
                 for epoch in (ep for ep in signature_epochs if ep in epochs_with_data):
                     self.doc.append(NoEscape(r'\FloatBarrier'))
                     epoch_label = Label(Marker(name=f"{mode}{epoch}", prefix="ssubsec"))
                     with self.doc.create(Subsubsection(f"Epoch {epoch}", label=epoch_label)):
                         for ds_id in pipeline.get_ds_ids(epoch=epoch, mode=mode):
                             if ds_id == '':
                                 # An empty dataset id gets no dedicated paragraph heading.
                                 wrapper = NonContext()
                             else:
                                 ds_label = Label(Marker(name=f"{mode}{epoch}{ds_id}", prefix="para"))
                                 wrapper = self.doc.create(Paragraph(f"Dataset {ds_id}", label=ds_label))
                             with wrapper:
                                 diagram = self._draw_diagram(mode, epoch, ds_id)
                                 ltx = d2t.dot2tex(diagram.to_string(), figonly=True)
                                 # Both size limits are passed through a single unescaped 'max width' argument.
                                 box_args = Arguments(**{'max width': r'\textwidth, max height=0.9\textheight'})
                                 box_args.escape = False
                                 with self.doc.create(Center()), \
                                         self.doc.create(AdjustBox(arguments=box_args)) as box:
                                     box.append(NoEscape(ltx))
    def render_table(self,
                     doc: Document,
                     name_override: Optional[LatexObject] = None,
                     toc_ref: Optional[str] = None,
                     extra_rows: Optional[List[Tuple[str, Any]]] = None) -> None:
        """Append this table to a LaTeX document.

        The rows are emitted in this order: a header row (name and FE id), the type, the path (if any), one row
        per entry of `self.fields`, the args (if any), the `extra_rows` (if any), and finally the kwargs with
        alternating row colors.

        Args:
            doc: The LaTeX document which should receive the table.
            name_override: An optional object to display in place of this table's name.
            toc_ref: An optional entry to register in the table of contents (as a subsection).
            extra_rows: Optional (key, value) rows to insert ahead of the kwargs rows.
        """
        with doc.create(Table(position='htp!')) as table:
            table.append(NoEscape(r'\refstepcounter{table}'))
            table.append(Label(Marker(name=str(self.fe_id), prefix="tbl")))
            if toc_ref:
                toc_entry = r'\addcontentsline{toc}{subsection}{' + escape_latex(toc_ref) + '}'
                table.append(NoEscape(toc_entry))
            with doc.create(Tabularx('|lX|', booktabs=True)) as tabular:
                # xcolor: Need to invoke a table color before invoking TextColor (bug?)
                for package in (Package('xcolor', options='table'), Package('seqsplit')):
                    if package not in tabular.packages:
                        tabular.packages.append(package)

                # Header row: the (possibly overridden) name on the left, the FE id in blue on the right.
                title_cell = name_override or bold(self.name)
                id_cell = MultiColumn(size=1, align='r|', data=TextColor('blue', self.fe_id))
                tabular.add_row((title_cell, id_cell))
                tabular.add_hline()

                # Reduce a repr such as "<class 'a.b.C'>" to just the quoted name when the pattern matches.
                type_name = f"{self.type}"
                parsed = re.fullmatch(r'^<.* \'(?P<typ>.*)\'>$', type_name)
                if parsed:
                    type_name = parsed.group("typ")
                tabular.add_row(("Type: ", escape_latex(type_name)))

                if self.path:
                    # LaTeX objects are inserted untouched, anything else is escaped first.
                    path_cell = self.path if isinstance(self.path, LatexObject) else escape_latex(self.path)
                    tabular.add_row(("", path_cell))

                for field_name, field_value in self.fields.items():
                    tabular.add_hline()
                    tabular.add_row((f"{field_name.capitalize()}: ", field_value))

                if self.args:
                    tabular.add_hline()
                    tabular.add_row(("Args: ", self.args))

                for key, val in extra_rows or ():
                    tabular.add_hline()
                    tabular.add_row(key, val)

                if self.kwargs:
                    tabular.add_hline()
                    # Even rows are lightly shaded, odd rows are white.
                    row_colors = ('black!5', 'white')
                    for idx, (kwarg, val) in enumerate(self.kwargs.items()):
                        tabular.add_row((italic(kwarg), val), color=row_colors[idx % 2])
Esempio n. 3
0
    def _document_models(self) -> None:
        """Write a "Models" section with a summary of every network model.

        Models are visited in human-sorted order of their `model_name`; anything which is neither a
        `tf.keras.Model` nor a `torch.nn.Module` is skipped. Each remaining model gets a subsection holding a
        verbatim text summary, a centered link to the model's FEID and, when a diagram could be rendered to PDF, a
        figure of it. Rendering failures are reported through a printed warning instead of being raised. PyTorch
        models without an `fe_input_spec` attribute only get a short note.
        """
        with self.doc.create(Section("Models")):
            sorted_models = humansorted(self.system.network.models, key=lambda m: m.model_name)
            for model in sorted_models:
                is_keras = isinstance(model, tf.keras.Model)
                if not (is_keras or isinstance(model, torch.nn.Module)):
                    continue
                self.doc.append(NoEscape(r'\FloatBarrier'))
                with self.doc.create(Subsection(f"{model.model_name}")):
                    if is_keras:
                        # Text Summary
                        summary_lines = []
                        model.summary(line_length=92, print_fn=lambda line: summary_lines.append(line))
                        self.doc.append(Verbatim("\n".join(summary_lines)))
                        with self.doc.create(Center()):
                            self.doc.append(HrefFEID(FEID(id(model)), model.model_name))

                        # Visual Summary
                        # noinspection PyBroadException
                        try:
                            pdf_name = "{}_{}.pdf".format(self.report_name, model.model_name)
                            file_path = os.path.join(self.resource_dir, pdf_name)
                            dot = tf.keras.utils.model_to_dot(model, show_shapes=True, expand_nested=True)
                            # LaTeX \maxdim is around 575cm (226 inches), so the image must have max dimension less
                            # than 226 inches. However, the 'size' parameter doesn't account for the whole node
                            # height, so set the limit lower (100 inches) to leave some wiggle room.
                            dot.set('size', '100')
                            dot.write(file_path, format='pdf')
                        except Exception:
                            file_path = None
                            print(
                                f"FastEstimator-Warn: Model {model.model_name} could not be visualized by Traceability"
                            )
                    elif not hasattr(model, 'fe_input_spec'):
                        # No input spec is attached to this PyTorch model, so only a note is written.
                        file_path = None
                        self.doc.append("This model was not used by the Network during training.")
                    else:
                        # Text Summary
                        # noinspection PyUnresolvedReferences
                        inputs = model.fe_input_spec.get_dummy_input()
                        # With more than one device the model's `module` attribute is summarized and drawn instead
                        # (presumably the model is then a multi-device wrapper -- TODO confirm).
                        net = model.module if self.system.num_devices > 1 else model
                        text_summary = pms.summary(net, inputs, print_summary=False)
                        self.doc.append(Verbatim(text_summary))
                        with self.doc.create(Center()):
                            self.doc.append(HrefFEID(FEID(id(model)), model.model_name))
                        # Visual Summary
                        # Import has to be done while matplotlib is using the Agg backend
                        old_backend = matplotlib.get_backend() or 'Agg'
                        matplotlib.use('Agg')
                        # noinspection PyBroadException
                        try:
                            # Fake the IPython import when user isn't running from Jupyter
                            for module_name in ('IPython', 'IPython.display'):
                                sys.modules.setdefault(module_name, MagicMock())
                            import hiddenlayer as hl
                            with Suppressor():
                                graph = hl.build_graph(net, inputs)
                            dot_graph = graph.build_dot()
                            # Switch it to Top-to-Bottom instead of Left-to-Right
                            dot_graph.attr(rankdir='TB')
                            # LaTeX \maxdim is around 575cm (226 inches), so the image must have max dimension less
                            # than 226 inches. However, the 'size' parameter doesn't account for the whole node
                            # height, so set the limit lower (100 inches) to leave some wiggle room.
                            dot_graph.attr(size="100,100")
                            dot_graph.attr(margin='0')
                            file_path = dot_graph.render(filename="{}_{}".format(self.report_name,
                                                                                 model.model_name),
                                                         directory=self.resource_dir,
                                                         format='pdf',
                                                         cleanup=True)
                        except Exception:
                            file_path = None
                            print("FastEstimator-Warn: Model {} could not be visualized by Traceability".format(
                                model.model_name))
                        finally:
                            # Always hand the user's original matplotlib backend back.
                            matplotlib.use(old_backend)
                    if file_path:
                        with self.doc.create(Figure(position='ht!')) as fig:
                            fig.append(Label(Marker(name=str(FEID(id(model))), prefix="model")))
                            # The image is referenced relative to the directory the report is saved into.
                            image_path = os.path.relpath(file_path, start=self.save_dir)
                            image_size = NoEscape(r'1.0\textwidth,height=0.95\textheight,keepaspectratio')
                            fig.add_image(image_path, width=image_size)
                            caption = HrefFEID(FEID(id(model)), model.model_name).dumps()
                            fig.add_caption(NoEscape(caption))
    def _document_models(self) -> None:
        """Add model summaries to the traceability document.

        Creates a "Models" section containing one subsection per TensorFlow / PyTorch model of the network, in
        human-sorted order of the model names. Each subsection holds a verbatim text summary, a centered link to
        the model's FEID and, when a diagram could be rendered to PDF, a figure of the model graph. Rendering
        failures are reported through a printed warning instead of being raised. PyTorch models without an
        `fe_input_spec` attribute only get a short note and no figure.
        """
        with self.doc.create(Section("Models")):
            for model in humansorted(self.system.network.models,
                                     key=lambda m: m.model_name):
                if not isinstance(model, (tf.keras.Model, torch.nn.Module)):
                    continue
                self.doc.append(NoEscape(r'\FloatBarrier'))
                with self.doc.create(Subsection(f"{model.model_name}")):
                    # Reset for every model. Without this, a PyTorch model lacking 'fe_input_spec' would either
                    # raise UnboundLocalError (first model) or reuse the diagram path of the previous model.
                    file_path = None
                    if isinstance(model, tf.keras.Model):
                        # Text Summary
                        summary = []
                        model.summary(line_length=92,
                                      print_fn=lambda x: summary.append(x))
                        summary = "\n".join(summary)
                        self.doc.append(Verbatim(summary))
                        with self.doc.create(Center()):
                            self.doc.append(
                                HrefFEID(FEID(id(model)), model.model_name))

                        # Visual Summary
                        # noinspection PyBroadException
                        try:
                            file_path = os.path.join(
                                self.figure_dir,
                                f"FE_Model_{model.model_name}.pdf")
                            tf.keras.utils.plot_model(model,
                                                      to_file=file_path,
                                                      show_shapes=True,
                                                      expand_nested=True)
                            # TODO - cap output image size like in the pytorch implementation in case of huge network
                            # TODO - save raw .dot file in case system lacks graphviz
                        except Exception:
                            # Visualization is best-effort: warn and carry on without a figure.
                            file_path = None
                            print(
                                f"FastEstimator-Warn: Model {model.model_name} could not be visualized by Traceability"
                            )
                    elif isinstance(model, torch.nn.Module):
                        if hasattr(model, 'fe_input_spec'):
                            # Text Summary
                            # noinspection PyUnresolvedReferences
                            inputs = model.fe_input_spec.get_dummy_input()
                            self.doc.append(
                                Verbatim(pms.summary(model, inputs)))
                            with self.doc.create(Center()):
                                self.doc.append(
                                    HrefFEID(FEID(id(model)),
                                             model.model_name))

                            # Visual Summary
                            # Import has to be done while matplotlib is using the Agg backend
                            old_backend = matplotlib.get_backend() or 'Agg'
                            matplotlib.use('Agg')
                            # noinspection PyBroadException
                            try:
                                # Fake the IPython import when user isn't running from Jupyter
                                sys.modules.setdefault('IPython', MagicMock())
                                sys.modules.setdefault('IPython.display',
                                                       MagicMock())
                                import hiddenlayer as hl
                                with Suppressor():
                                    graph = hl.build_graph(model, inputs)
                                graph = graph.build_dot()
                                graph.attr(
                                    rankdir='TB'
                                )  # Switch it to Top-to-Bottom instead of Left-to-Right
                                graph.attr(
                                    size="200,200"
                                )  # LaTeX \maxdim is around 575cm (226 inches)
                                graph.attr(margin='0')
                                # TODO - save raw .dot file in case system lacks graphviz
                                file_path = graph.render(
                                    filename=f"FE_Model_{model.model_name}",
                                    directory=self.figure_dir,
                                    format='pdf',
                                    cleanup=True)
                            except Exception:
                                # Visualization is best-effort: warn and carry on without a figure.
                                file_path = None
                                print(
                                    "FastEstimator-Warn: Model {} could not be visualized by Traceability"
                                    .format(model.model_name))
                            finally:
                                # Always hand the user's original matplotlib backend back.
                                matplotlib.use(old_backend)
                        else:
                            # No input spec is attached to this model, so there is nothing to summarize or draw.
                            self.doc.append(
                                "This model was not used by the Network during training."
                            )
                    if file_path:
                        with self.doc.create(Figure(position='ht!')) as fig:
                            fig.append(
                                Label(
                                    Marker(name=str(FEID(id(model))),
                                           prefix="model")))
                            # The image is referenced relative to the directory the report is saved into.
                            fig.add_image(
                                os.path.relpath(file_path,
                                                start=self.save_dir),
                                width=NoEscape(
                                    r'1.0\textwidth,height=0.95\textheight,keepaspectratio'
                                ))
                            fig.add_caption(
                                NoEscape(
                                    HrefFEID(FEID(id(model)),
                                             model.model_name).dumps()))