Ejemplo n.º 1
0
    def _run_epoch(self) -> None:
        """Run one epoch of the mode and epoch index currently held by ``self.system``.

        ``self.system.mode`` and ``self.system.epoch_idx`` must already be set before calling. The first batch is
        pulled before the traces are sorted, so that the keys the pipeline actually emits can be used for ordering.
        When the loader is a ``DataLoader``, the epoch also ends early. This happens once ``self.system.batch_idx``
        equals ``max_train_steps_per_epoch`` (train mode) or ``max_eval_steps_per_epoch`` (eval mode).
        """
        active_traces = get_current_items(self.traces_in_use,
                                          run_modes=self.system.mode,
                                          epoch=self.system.epoch_idx)
        # Union of every key the active traces consume; handed to network.load_epoch as output_keys
        needed_keys = set().union(*(trace.inputs for trace in active_traces))
        loader = self._configure_loader(
            self.pipeline.get_loader(self.system.mode, self.system.epoch_idx))
        batch_source = iter(loader)
        self.network.load_epoch(mode=self.system.mode,
                                epoch=self.system.epoch_idx,
                                output_keys=needed_keys)
        self.system.batch_idx = None
        with Suppressor():
            batch = next(batch_source)
        pipeline_keys = to_set(batch.keys())
        network_keys = self.network.get_all_output_keys(self.system.mode, self.system.epoch_idx)
        active_traces = self._sort_traces(active_traces, available_outputs=pipeline_keys | network_keys)
        self._run_traces_on_epoch_begin(traces=active_traces)
        uses_torch_loader = isinstance(loader, DataLoader)
        while True:
            try:
                if self.system.mode == "train":
                    self.system.update_global_step()
                self.system.update_batch_idx()
                batch = self._configure_tensor(loader, batch)
                self._run_traces_on_batch_begin(batch, traces=active_traces)
                batch, prediction = self.network.run_step(batch)
                self._run_traces_on_batch_end(batch, prediction, traces=active_traces)
                if uses_torch_loader:
                    # Enforce the per-epoch step caps; leaving the loop here is what StopIteration did originally
                    current_mode, step = self.system.mode, self.system.batch_idx
                    if (current_mode == "train" and step == self.system.max_train_steps_per_epoch) or \
                            (current_mode == "eval" and step == self.system.max_eval_steps_per_epoch):
                        break
                with Suppressor():
                    batch = next(batch_source)
            except StopIteration:
                break
        self._run_traces_on_epoch_end(traces=active_traces)
        self.network.unload_epoch()
Ejemplo n.º 2
0
    def on_epoch_end(self, state):
        """Render a caricature figure every ``self.im_freq`` epochs and store it as an image in the epoch state.

        On epoch 0, ``self.data`` is concatenated along axis 0 and truncated to its first ``self.n_inputs`` rows
        before plotting.

        Args:
            state: Epoch state. ``state['epoch']`` is read. The rendered figure is written to
                ``state.maps[1][self.output_key]`` as a uint8 array of shape (1, height, width, 3).
        """
        if state['epoch'] % self.im_freq != 0:
            return
        if state['epoch'] == 0:
            # NOTE(review): self.data is presumably a list of per-batch tensors gathered earlier — confirm
            self.data = tf.concat(self.data, axis=0)
            self.data = self.data[:self.n_inputs]

        with Suppressor():
            fig = plot_caricature(self.model,
                                  self.data,
                                  self.layer_ids,
                                  decode_dictionary=self.decode_dictionary,
                                  n_steps=self.n_steps,
                                  learning_rate=self.learning_rate,
                                  blur=self.blur,
                                  cossim_pow=self.cossim_pow,
                                  sd=self.sd,
                                  fft=self.fft,
                                  decorrelate=self.decorrelate,
                                  sigmoid=self.sigmoid)
        try:
            # TODO - Figure out how to get this to work without it displaying the figure. maybe fig.canvas.draw
            plt.draw()
            plt.pause(0.000001)
            # Read pixels straight from the Agg buffer. It is shaped (height, width, 4) in physical pixels, so it
            # cannot disagree with the canvas size the way reshaping tostring_rgb() output by get_width_height()
            # can (e.g. when the device pixel ratio is not 1). tostring_rgb() is also deprecated in Matplotlib.
            rgba = np.asarray(fig.canvas.buffer_rgba())
            # Drop alpha, add the leading batch axis, and copy so the stored image does not alias the canvas buffer
            state.maps[1][self.output_key] = rgba[np.newaxis, :, :, :3].copy()
        finally:
            # Always release the figure, even if rendering failed, so figures do not pile up across epochs
            plt.close(fig)
Ejemplo n.º 3
0
    def load_and_transform(self, batches=None):
        """Read cached layer outputs and class labels from disk and embed each layer with ``self.fit``.

        Args:
            batches: Number of cached batches to read. When None, ``self.idx`` is used.

        Returns:
            A tuple ``(data, classes)``:
                data: One ``self.fit.fit_transform`` result per layer. Every entry is None when the batch count
                    is 0.
                classes: The flat list of labels loaded from the ``class{N}.npy`` cache files.
        """
        n_batches = self.idx if batches is None else batches
        embeddings = [None] * self.num_layers
        labels = []

        if n_batches == 0:
            return embeddings, labels

        for layer_pos in trange(self.num_layers, desc='Computing UMaps', unit='layer'):
            chunks = [
                np.load(os.path.join(self.root_path,
                                     "layer{}-batch{}.npy".format(self.layers[layer_pos], batch_pos)),
                        allow_pickle=True)
                for batch_pos in trange(n_batches, desc='Loading Cache', unit='batch', leave=False)
            ]
            stacked = np.concatenate(chunks, axis=0)
            with Suppressor():  # Silence a bunch of numba warnings
                embeddings[layer_pos] = self.fit.fit_transform(stacked)

        for batch_pos in range(n_batches):
            labels.extend(
                np.load(os.path.join(self.root_path, "class{}.npy".format(batch_pos)), allow_pickle=True))

        return embeddings, labels
Ejemplo n.º 4
0
    def _warmup(self, eager: bool = True) -> None:
        """Perform a test run of each pipeline and network signature epoch to make sure that training won't fail later.

        Traces are not executed in the warmup since they are likely to contain state variables which could become
        corrupted by running extra steps.

        The check covers every pipeline mode except "test", and every signature epoch in that mode that has data.
        For each of these it:
            1. pulls one batch from the pipeline;
            2. verifies that every trace input key is produced by the pipeline, the network, or some trace;
            3. runs a single network step on that batch.
        Afterwards, every name in ``self.monitor_names`` must have been produced by the pipeline or the network in
        at least one checked epoch.

        Args:
            eager: Whether to run the training in eager mode. This is only related to TensorFlow training because
                PyTorch by nature is always in eager mode.

        Raises:
            AssertionError: Raised in any of these cases:
                - a pipeline batch is not a dict;
                - a trace input key is unmet in some mode/epoch;
                - a monitor name is never produced.
        """
        all_traces = get_current_items(self.traces_in_use, run_modes={"train", "eval"})
        sort_traces(all_traces)  # This ensures that the traces can sort properly for on_begin and on_end
        # Working copy; names are removed as pipelines/networks are found to produce them
        monitor_names = self.monitor_names
        for mode in self.pipeline.get_modes() - {"test"}:
            # Epochs at which scheduled items may change (per get_signature_epochs); only these need checking
            scheduled_items = self.pipeline.get_scheduled_items(mode) + self.network.get_scheduled_items(
                mode) + self.get_scheduled_items(mode)
            signature_epochs = get_signature_epochs(scheduled_items, self.system.total_epochs, mode=mode)
            epochs_with_data = self.pipeline.get_epochs_with_data(total_epochs=self.system.total_epochs, mode=mode)
            for epoch in signature_epochs:
                if epoch not in epochs_with_data:
                    continue
                network_output_keys = self.network.get_all_output_keys(mode, epoch)
                network_input_keys = self.network.get_effective_input_keys(mode, epoch)
                trace_input_keys = set()
                # NOTE(review): "*" is seeded here, presumably so a trace that requests the wildcard key is never
                # reported as unmet — confirm.
                trace_output_keys = {"*"}
                traces = get_current_items(self.traces_in_use, run_modes=mode, epoch=epoch)
                for idx, trace in enumerate(traces):
                    if idx > 0:  # ignore TrainEssential and EvalEssential's inputs for unmet requirement checking
                        trace_input_keys.update(trace.inputs)
                    trace_output_keys.update(trace.outputs)
                # key checking
                # `-` binds tighter than `|`: request (trace inputs not made by the network) | (network inputs)
                loader = self._configure_loader(
                    self.pipeline.get_loader(mode,
                                             epoch,
                                             output_keys=trace_input_keys - network_output_keys | network_input_keys))
                # Fetch exactly one batch to inspect which keys the pipeline emits
                with Suppressor():
                    if isinstance(loader, tf.data.Dataset):
                        batch = list(loader.take(1))[0]
                    else:
                        batch = next(iter(loader))
                batch = self._configure_tensor(loader, batch)
                assert isinstance(batch, dict), "please make sure data output format is dictionary"
                pipeline_output_keys = to_set(batch.keys())

                monitor_names = monitor_names - (pipeline_output_keys | network_output_keys)
                unmet_requirements = trace_input_keys - (pipeline_output_keys | network_output_keys | trace_output_keys)
                assert not unmet_requirements, \
                    "found missing key(s) during epoch {} mode {}: {}".format(epoch, mode, unmet_requirements)
                # Return value is discarded: this call only checks that the traces can be ordered for these outputs
                sort_traces(traces, available_outputs=pipeline_output_keys | network_output_keys)
                # Add back the first trace's inputs (skipped for the requirement check above) so the network is
                # asked to keep them as outputs during the test step
                trace_input_keys.update(traces[0].inputs)
                self.network.load_epoch(mode, epoch, output_keys=trace_input_keys, warmup=True, eager=eager)
                self.network.run_step(batch)
                self.network.unload_epoch()
        assert not monitor_names, "found missing key(s): {}".format(monitor_names)
Ejemplo n.º 5
0
    def plot_umap(self,
                  data,
                  labels=None,
                  legend_loc='best',
                  title=None,
                  fig_ax=None):
        """Embed ``data`` with ``self.fit`` and scatter-plot the embedding.

        Args:
            data: Samples to embed; passed unchanged to ``self.fit.fit_transform``.
            labels: Optional per-sample classes. They are converted to colors via ``self._map_classes_to_colors``.
                The first time colors are available, ``self.legend_elems`` is also built from them.
            legend_loc: Matplotlib legend location, or 'off' to suppress the legend.
            title: Optional axes title.
            fig_ax: Optional ``(figure, axes)`` pair to draw into.
                - When omitted, a new figure is created.
                - When ``self.n_components == 3``, a newly created axes uses the '3d' projection.
                - A caller-supplied axes must already be 3D in that case.

        Returns:
            The figure that was drawn on.
        """
        color_list = self._map_classes_to_colors(labels)
        if self.legend_elems is None and color_list is not None:
            self.legend_elems = [
                Line2D([0], [0],
                       marker='o',
                       color='w',
                       markerfacecolor=self.color_dict[clazz],
                       label=clazz
                       if self.label_dict is None else self.label_dict[clazz],
                       markersize=7) for clazz in self.color_dict
            ]
        with Suppressor():  # Silence a bunch of numba warnings
            points = self.fit.fit_transform(data)

        if not fig_ax:
            fig = plt.figure(dpi=96)
            if self.n_components == 3:
                # A 2D Axes.scatter treats a third positional array as the marker size `s`, which then collides
                # with the s=3 keyword. The 3D data therefore needs a 3D axes. Importing mplot3d registers the
                # '3d' projection on older Matplotlib releases.
                from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
                ax = fig.add_subplot(111, projection='3d')
            else:
                ax = fig.add_subplot(111)
        else:
            fig, ax = fig_ax
        ax.set_yticks([], [])
        ax.set_yticklabels([])
        ax.set_xticks([], [])
        ax.set_xticklabels([])
        if title:
            ax.set_title(title)
        # `color_list or 'b'` raises for NumPy arrays (ambiguous truth value), so test emptiness explicitly.
        # None and empty sequences still fall back to blue, as before.
        colors = 'b' if color_list is None or len(color_list) == 0 else color_list
        if self.n_components == 1:
            ax.scatter(points[:, 0], range(len(points)), c=colors, s=3)
        elif self.n_components == 2:
            ax.scatter(points[:, 0], points[:, 1], c=colors, s=3)
        elif self.n_components == 3:
            ax.scatter(points[:, 0], points[:, 1], points[:, 2], c=colors, s=3)
        if self.legend_elems and legend_loc != 'off':
            ax.legend(handles=self.legend_elems,
                      loc=legend_loc,
                      fontsize='small')
        # Lay out the figure we drew on, not whatever pyplot considers "current"
        fig.tight_layout()
        return fig
def search_max_lr(pipeline, model, network, epochs):
    """Run a short learning-rate range test and return the LR logged at the best eval accuracy.

    An Estimator is trained for ``epochs`` epochs of at most 10 training steps each. The learning rate is driven
    by ``linear_increase``. The step with the highest eval accuracy is located, and the training learning rate
    recorded at that step is returned.

    Args:
        pipeline: Pipeline passed to the Estimator.
        model: Model whose learning rate is scheduled (history key "model_lr").
        network: Network passed to the Estimator.
        epochs: Number of epochs to run the range test for.

    Returns:
        The value of ``summary.history["train"]["model_lr"]`` at the step with the highest eval accuracy.
    """
    accuracy_trace = Accuracy(true_key="y", pred_key="y_pred")
    # Keep the single-argument wrapper named `step`. NOTE(review): LRScheduler presumably inspects the lr_fn
    # argument name to choose per-step vs per-epoch scheduling, so do not pass linear_increase directly.
    lr_trace = LRScheduler(model=model, lr_fn=lambda step: linear_increase(step))
    estimator = fe.Estimator(pipeline=pipeline,
                             network=network,
                             epochs=epochs,
                             traces=[accuracy_trace, lr_trace],
                             max_train_steps_per_epoch=10,
                             log_steps=10)
    print("Running LR range test for super convergence. It will take a while...")
    with Suppressor():
        summary = estimator.fit("LR_range_test")

    eval_accuracy = summary.history["eval"]["accuracy"]
    # First step (in history order) whose accuracy is maximal
    best_step = max(eval_accuracy, key=eval_accuracy.get)
    return summary.history["train"]["model_lr"][best_step]
 def __init__(self,
              image_dir: str,
              annotation_file: str,
              caption_file: str,
              include_bboxes: bool = True,
              include_masks: bool = False,
              include_captions: bool = False):
     """Index the images in ``image_dir`` and load the COCO annotation files.

     Args:
         image_dir: Directory holding the images. It is searched non-recursively and exposed under the
             "image" data key.
         annotation_file: Path to the COCO instances annotation file; always loaded.
         caption_file: Path to the COCO captions file; only loaded when ``include_captions`` is True.
         include_bboxes: Stored as ``self.include_bboxes``.
         include_masks: Stored as ``self.include_masks``; requires ``include_bboxes``.
         include_captions: Whether to load ``caption_file`` into ``self.captions`` (otherwise it is None).

     Raises:
         AssertionError: If ``include_masks`` is requested without ``include_bboxes``.
     """
     super().__init__(root_dir=image_dir, data_key="image", recursive_search=False)
     # Masks are only meaningful alongside bounding boxes
     assert include_bboxes or not include_masks, "must include bboxes with mask data"
     self.include_bboxes = include_bboxes
     self.include_masks = include_masks
     # Suppressor presumably silences the console output the COCO loader prints
     with Suppressor():
         self.instances = COCO(annotation_file)
         if include_captions:
             self.captions = COCO(caption_file)
         else:
             self.captions = None
Ejemplo n.º 8
0
    def on_epoch_end(self, state):
        """Render a caricature figure every ``self.im_freq`` epochs and store it as an image in the epoch state.

        On epoch 0, ``self.data`` is concatenated along axis 0 and truncated to its first ``self.n_inputs`` rows
        before plotting.

        Args:
            state: Epoch state. ``state['epoch']`` is read. The rendered figure is written to
                ``state.maps[1][self.output_key]`` as a uint8 array of shape (1, height, width, 3).
        """
        if state['epoch'] % self.im_freq != 0:
            return
        if state['epoch'] == 0:
            # NOTE(review): self.data is presumably a list of per-batch tensors gathered earlier — confirm
            self.data = tf.concat(self.data, axis=0)
            self.data = self.data[:self.n_inputs]

        with Suppressor():
            fig = plot_caricature(self.model,
                                  self.data,
                                  self.layer_ids,
                                  decode_dictionary=self.decode_dictionary,
                                  n_steps=self.n_steps,
                                  learning_rate=self.learning_rate,
                                  blur=self.blur,
                                  cossim_pow=self.cossim_pow,
                                  sd=self.sd,
                                  fft=self.fft,
                                  decorrelate=self.decorrelate,
                                  sigmoid=self.sigmoid)
        try:
            # TODO - Figure out how to get this to work without it displaying the figure. maybe fig.canvas.draw
            plt.draw()
            plt.pause(0.000001)
            # Read pixels straight from the Agg buffer, which is shaped (height, width, 4) in physical pixels.
            # This replaces guessing the height from the tostring_rgb() length. The guess could not recover
            # when both dimensions were scaled (e.g. a device pixel ratio of 2, presumably the Jupyter case
            # seen before): the pixel count stayed divisible by the reported height, so the final reshape
            # failed. tostring_rgb() is also deprecated.
            rgba = np.asarray(fig.canvas.buffer_rgba())
            # Drop alpha, add the leading batch axis, and copy so the stored image does not alias the canvas buffer
            state.maps[1][self.output_key] = rgba[np.newaxis, :, :, :3].copy()
        finally:
            # Always release the figure, even if rendering failed, so figures do not pile up across epochs
            plt.close(fig)
Ejemplo n.º 9
0
 def on_epoch_end(self, state):
     """Every ``self.im_freq`` epochs, render a caricature for the current mode and store it in the state.

     The figure is drawn with the "Agg" backend. The previously active backend is restored afterwards, even if
     plotting fails, and the figure is closed so one figure is not leaked per rendered epoch.

     Args:
         state: Epoch state. ``state['epoch']`` and ``state['mode']`` are read. The result of
             ``fig_to_img(fig)`` is written to ``state.maps[1][self.output_key]``.
     """
     super().on_epoch_end(state)
     if state['epoch'] % self.im_freq != 0:
         return
     old_backend = matplotlib.get_backend()
     matplotlib.use("Agg")
     fig = None
     try:
         with Suppressor():
             fig = plot_caricature(self.model,
                                   self.data[state['mode']],
                                   self.layer_ids,
                                   decode_dictionary=self.decode_dictionary,
                                   n_steps=self.n_steps,
                                   learning_rate=self.learning_rate,
                                   blur=self.blur,
                                   cossim_pow=self.cossim_pow,
                                   sd=self.sd,
                                   fft=self.fft,
                                   decorrelate=self.decorrelate,
                                   sigmoid=self.sigmoid)
         fig.canvas.draw()
         state.maps[1][self.output_key] = fig_to_img(fig)
     finally:
         if fig is not None:
             # Closing a figure pyplot does not manage is a no-op, so this is safe whatever created it
             from matplotlib import pyplot as plt
             plt.close(fig)
         matplotlib.use(old_backend)
Ejemplo n.º 10
0
    def _document_models(self) -> None:
        """Add model summaries to the traceability document.

        A "Models" section gets one subsection per Keras or PyTorch model in ``self.system.network.models``,
        ordered by ``model.model_name``. Each subsection contains:
            - a text summary;
            - a link built from ``HrefFEID``;
            - when rendering succeeds, a PDF graph of the model, written to ``self.resource_dir`` and embedded
              as a figure.
        Rendering failures are reported with a printed warning rather than raised.
        """
        with self.doc.create(Section("Models")):
            for model in humansorted(self.system.network.models,
                                     key=lambda m: m.model_name):
                # Only Keras and PyTorch models can be summarized here
                if not isinstance(model, (tf.keras.Model, torch.nn.Module)):
                    continue
                self.doc.append(NoEscape(r'\FloatBarrier'))
                with self.doc.create(Subsection(f"{model.model_name}")):
                    if isinstance(model, tf.keras.Model):
                        # Text Summary
                        # Capture the Keras summary lines instead of printing them
                        summary = []
                        model.summary(line_length=92,
                                      print_fn=lambda x: summary.append(x))
                        summary = "\n".join(summary)
                        self.doc.append(Verbatim(summary))
                        with self.doc.create(Center()):
                            self.doc.append(
                                HrefFEID(FEID(id(model)), model.model_name))

                        # Visual Summary
                        # noinspection PyBroadException
                        try:
                            file_path = os.path.join(
                                self.resource_dir,
                                "{}_{}.pdf".format(self.report_name,
                                                   model.model_name))
                            dot = tf.keras.utils.model_to_dot(
                                model, show_shapes=True, expand_nested=True)
                            # LaTeX \maxdim is around 575cm (226 inches), so the image must have max dimension less than
                            # 226 inches. However, the 'size' parameter doesn't account for the whole node height, so
                            # set the limit lower (100 inches) to leave some wiggle room.
                            dot.set('size', '100')
                            dot.write(file_path, format='pdf')
                        except Exception:
                            # Best effort: a missing graph only drops the figure, it does not abort the report
                            file_path = None
                            print(
                                f"FastEstimator-Warn: Model {model.model_name} could not be visualized by Traceability"
                            )
                    elif isinstance(model, torch.nn.Module):
                        # fe_input_spec is needed to build dummy inputs; without it there is nothing to trace
                        if hasattr(model, 'fe_input_spec'):
                            # Text Summary
                            # noinspection PyUnresolvedReferences
                            inputs = model.fe_input_spec.get_dummy_input()
                            # NOTE(review): with several devices the model is presumably wrapped (e.g. DataParallel),
                            # so the underlying .module is summarized instead — confirm
                            self.doc.append(
                                Verbatim(
                                    pms.summary(
                                        model.module if
                                        self.system.num_devices > 1 else model,
                                        inputs,
                                        print_summary=False)))
                            with self.doc.create(Center()):
                                self.doc.append(
                                    HrefFEID(FEID(id(model)),
                                             model.model_name))
                            # Visual Summary
                            # Import has to be done while matplotlib is using the Agg backend
                            old_backend = matplotlib.get_backend() or 'Agg'
                            matplotlib.use('Agg')
                            # noinspection PyBroadException
                            try:
                                # Fake the IPython import when user isn't running from Jupyter
                                sys.modules.setdefault('IPython', MagicMock())
                                sys.modules.setdefault('IPython.display',
                                                       MagicMock())
                                import hiddenlayer as hl
                                with Suppressor():
                                    graph = hl.build_graph(
                                        model.module if
                                        self.system.num_devices > 1 else model,
                                        inputs)
                                graph = graph.build_dot()
                                graph.attr(
                                    rankdir='TB'
                                )  # Switch it to Top-to-Bottom instead of Left-to-Right
                                # LaTeX \maxdim is around 575cm (226 inches), so the image must have max dimension less
                                # than 226 inches. However, the 'size' parameter doesn't account for the whole node
                                # height, so set the limit lower (100 inches) to leave some wiggle room.
                                graph.attr(size="100,100")
                                graph.attr(margin='0')
                                # render() returns the path of the written PDF; cleanup=True removes the source file
                                file_path = graph.render(
                                    filename="{}_{}".format(
                                        self.report_name, model.model_name),
                                    directory=self.resource_dir,
                                    format='pdf',
                                    cleanup=True)
                            except Exception:
                                file_path = None
                                print(
                                    "FastEstimator-Warn: Model {} could not be visualized by Traceability"
                                    .format(model.model_name))
                            finally:
                                # Restore the caller's backend whether or not rendering succeeded
                                matplotlib.use(old_backend)
                        else:
                            file_path = None
                            self.doc.append(
                                "This model was not used by the Network during training."
                            )
                    # Embed the rendered graph, labeled with the model's FEID, with a path relative to save_dir
                    if file_path:
                        with self.doc.create(Figure(position='ht!')) as fig:
                            fig.append(
                                Label(
                                    Marker(name=str(FEID(id(model))),
                                           prefix="model")))
                            fig.add_image(
                                os.path.relpath(file_path,
                                                start=self.save_dir),
                                width=NoEscape(
                                    r'1.0\textwidth,height=0.95\textheight,keepaspectratio'
                                ))
                            fig.add_caption(
                                NoEscape(
                                    HrefFEID(FEID(id(model)),
                                             model.model_name).dumps()))
Ejemplo n.º 11
0
    def _run_epoch(self, eager: bool) -> None:
        """A method to perform an epoch of activity.

        This method requires that the current mode and epoch already be specified within the self.system object.

        Epoch-level traces are started once. Then each dataset id from ``self.pipeline.get_ds_ids`` is run in turn:
            - the network is loaded for that dataset;
            - batches are streamed through ``self.network.run_step``;
            - batch hooks run on the dataset's traces, and ds begin/end hooks on its ``PerDSTrace`` instances;
            - the network is unloaded.
        Data gathered by the ds-end hooks is passed to the epoch-end hooks.

        Args:
            eager: Whether to run the training in eager mode. This is only related to TensorFlow training because
                PyTorch by nature is always in eager mode.
        """
        ds_ids = self.pipeline.get_ds_ids(self.system.epoch_idx,
                                          self.system.mode)
        epoch_traces = sort_traces(get_current_items(
            self.traces_in_use,
            run_modes=self.system.mode,
            epoch=self.system.epoch_idx),
                                   ds_ids=ds_ids)
        self._run_traces_on_epoch_begin(traces=epoch_traces)
        self.system.batch_idx = None
        end_epoch_data = Data(
        )  # We will aggregate data over on_ds_end and put it into on_epoch_end for printing
        # run for each dataset
        # The loop target is self.system.ds_id itself, so the system object holds the current dataset id
        for self.system.ds_id in ds_ids:
            ds_traces = get_current_items(self.traces_in_use,
                                          run_modes=self.system.mode,
                                          epoch=self.system.epoch_idx,
                                          ds_id=self.system.ds_id)
            # Keys the traces consume; the network is asked to keep them as outputs
            trace_input_keys = set()
            for ds_trace in ds_traces:
                trace_input_keys.update(ds_trace.inputs)
            network_input_keys = self.network.get_effective_input_keys(
                mode=self.system.mode,
                epoch=self.system.epoch_idx,
                ds_id=self.system.ds_id)
            network_output_keys = self.network.get_all_output_keys(
                mode=self.system.mode,
                epoch=self.system.epoch_idx,
                ds_id=self.system.ds_id)
            self.network.load_epoch(mode=self.system.mode,
                                    epoch=self.system.epoch_idx,
                                    ds_id=self.system.ds_id,
                                    output_keys=trace_input_keys,
                                    eager=eager)

            # `-` binds tighter than `|`: the pipeline must supply (trace inputs the network does not create)
            # plus everything the network consumes
            with self.pipeline(
                    mode=self.system.mode,
                    epoch=self.system.epoch_idx,
                    ds_id=self.system.ds_id,
                    steps_per_epoch=self.system.steps_per_epoch,
                    output_keys=trace_input_keys - network_output_keys
                    | network_input_keys) as loader:
                loader = self._configure_loader(loader)
                iterator = iter(loader)
                # The first batch is fetched before sorting so traces can be ordered by the keys actually available.
                # NOTE(review): an empty loader makes this next() raise StopIteration out of the method.
                with Suppressor():
                    batch = next(iterator)
                ds_traces = sort_traces(ds_traces,
                                        available_outputs=to_set(batch.keys())
                                        | network_output_keys,
                                        ds_ids=ds_ids)
                # Only PerDSTrace instances receive the ds begin/end hooks
                per_ds_traces = [
                    trace for trace in ds_traces
                    if isinstance(trace, PerDSTrace)
                ]
                self._run_traces_on_ds_begin(traces=per_ds_traces)
                while True:
                    try:
                        if self.system.mode == "train":
                            self.system.update_global_step()
                        self.system.update_batch_idx()
                        batch = self._configure_tensor(loader, batch)
                        self._run_traces_on_batch_begin(batch,
                                                        traces=ds_traces)
                        batch, prediction = self.network.run_step(batch)
                        self._run_traces_on_batch_end(batch,
                                                      prediction,
                                                      traces=ds_traces)
                        # For a DataLoader, stop once batch_idx reaches the configured train/eval step count;
                        # the raise is caught just below and ends the dataset loop
                        if isinstance(loader, DataLoader) and (
                            (self.system.batch_idx
                             == self.system.train_steps_per_epoch
                             and self.system.mode == "train") or
                            (self.system.batch_idx
                             == self.system.eval_steps_per_epoch
                             and self.system.mode == "eval")):
                            raise StopIteration
                        with Suppressor():
                            batch = next(iterator)
                    except StopIteration:
                        break
                self._run_traces_on_ds_end(traces=per_ds_traces,
                                           data=end_epoch_data)
            self.network.unload_epoch()
        self._run_traces_on_epoch_end(traces=epoch_traces, data=end_epoch_data)
Ejemplo n.º 12
0
    def _document_models(self) -> None:
        """Add model summaries to the traceability document.

        A "Models" section gets one subsection per Keras or PyTorch model in ``self.system.network.models``,
        ordered by ``model.model_name``. Each subsection contains:
            - a text summary;
            - a link built from ``HrefFEID``;
            - when rendering succeeds, a PDF graph of the model, written to ``self.figure_dir`` and embedded
              as a figure.
        Rendering failures are reported with a printed warning rather than raised.
        """
        with self.doc.create(Section("Models")):
            for model in humansorted(self.system.network.models,
                                     key=lambda m: m.model_name):
                # Only Keras and PyTorch models can be summarized here
                if not isinstance(model, (tf.keras.Model, torch.nn.Module)):
                    continue
                self.doc.append(NoEscape(r'\FloatBarrier'))
                with self.doc.create(Subsection(f"{model.model_name}")):
                    if isinstance(model, tf.keras.Model):
                        # Text Summary
                        # Capture the Keras summary lines instead of printing them
                        summary = []
                        model.summary(line_length=92,
                                      print_fn=lambda x: summary.append(x))
                        summary = "\n".join(summary)
                        self.doc.append(Verbatim(summary))
                        with self.doc.create(Center()):
                            self.doc.append(
                                HrefFEID(FEID(id(model)), model.model_name))

                        # Visual Summary
                        # noinspection PyBroadException
                        try:
                            file_path = os.path.join(
                                self.figure_dir,
                                f"FE_Model_{model.model_name}.pdf")
                            tf.keras.utils.plot_model(model,
                                                      to_file=file_path,
                                                      show_shapes=True,
                                                      expand_nested=True)
                            # TODO - cap output image size like in the pytorch implementation in case of huge network
                            # TODO - save raw .dot file in case system lacks graphviz
                        except Exception:
                            # Best effort: a missing graph only drops the figure, it does not abort the report
                            file_path = None
                            print(
                                f"FastEstimator-Warn: Model {model.model_name} could not be visualized by Traceability"
                            )
                    elif isinstance(model, torch.nn.Module):
                        # fe_input_spec is needed to build dummy inputs; without it there is nothing to trace
                        if hasattr(model, 'fe_input_spec'):
                            # Text Summary
                            # noinspection PyUnresolvedReferences
                            inputs = model.fe_input_spec.get_dummy_input()
                            self.doc.append(
                                Verbatim(pms.summary(model, inputs)))
                            with self.doc.create(Center()):
                                self.doc.append(
                                    HrefFEID(FEID(id(model)),
                                             model.model_name))

                            # Visual Summary
                            # Import has to be done while matplotlib is using the Agg backend
                            old_backend = matplotlib.get_backend() or 'Agg'
                            matplotlib.use('Agg')
                            # noinspection PyBroadException
                            try:
                                # Fake the IPython import when user isn't running from Jupyter
                                sys.modules.setdefault('IPython', MagicMock())
                                sys.modules.setdefault('IPython.display',
                                                       MagicMock())
                                import hiddenlayer as hl
                                with Suppressor():
                                    graph = hl.build_graph(model, inputs)
                                graph = graph.build_dot()
                                graph.attr(
                                    rankdir='TB'
                                )  # Switch it to Top-to-Bottom instead of Left-to-Right
                                graph.attr(
                                    size="200,200"
                                )  # LaTeX \maxdim is around 575cm (226 inches)
                                graph.attr(margin='0')
                                # TODO - save raw .dot file in case system lacks graphviz
                                file_path = graph.render(
                                    filename=f"FE_Model_{model.model_name}",
                                    directory=self.figure_dir,
                                    format='pdf',
                                    cleanup=True)
                            except Exception:
                                file_path = None
                                print(
                                    "FastEstimator-Warn: Model {} could not be visualized by Traceability"
                                    .format(model.model_name))
                            finally:
                                # Restore the caller's backend whether or not rendering succeeded
                                matplotlib.use(old_backend)
                        else:
                            # Must be reset here: otherwise `if file_path` below raises NameError for the first
                            # model, or re-embeds the previous model's graph under this model's label
                            file_path = None
                            self.doc.append(
                                "This model was not used by the Network during training."
                            )
                    # Embed the rendered graph, labeled with the model's FEID, with a path relative to save_dir
                    if file_path:
                        with self.doc.create(Figure(position='ht!')) as fig:
                            fig.append(
                                Label(
                                    Marker(name=str(FEID(id(model))),
                                           prefix="model")))
                            fig.add_image(
                                os.path.relpath(file_path,
                                                start=self.save_dir),
                                width=NoEscape(
                                    r'1.0\textwidth,height=0.95\textheight,keepaspectratio'
                                ))
                            fig.add_caption(
                                NoEscape(
                                    HrefFEID(FEID(id(model)),
                                             model.model_name).dumps()))
Ejemplo n.º 13
0
def umap_layers(model_path,
                input_root_path,
                print_layers=False,
                strip_alpha=False,
                layers=None,
                input_extension=None,
                batch=10,
                use_cache=True,
                cache_dir=None,
                dictionary_path=None,
                save=False,
                save_dir=None,
                legend_mode='shared',
                umap_parameters=None):
    """Compute UMAP embeddings of selected layer outputs of a saved Keras model and draw them.

    Args:
        model_path: Path of the model, loaded with ``keras.models.load_model``.
        input_root_path: Root directory of the input images, read by ``ImageLoader``.
        print_layers: If True, print each layer's index, name and output shape, then return None without
            computing anything.
        strip_alpha: Forwarded to ``ImageLoader``.
        layers: Layer selection. It is forwarded to ``Evaluator`` and passed to ``draw_umaps`` as ``layer_ids``.
        input_extension: Forwarded to ``ImageLoader``.
        batch: Batch size forwarded to ``ImageLoader``.
        use_cache: If True, layer outputs are written to and read from a ``FileCache`` in ``cache_dir``. Batches
            the cache reports as already present are skipped.
        cache_dir: Cache directory. Defaults to the input directory's path with "__layer_outputs" appended.
        dictionary_path: Forwarded to ``load_dict`` to build the label dictionary for ``draw_umaps``.
        save: Forwarded to ``draw_umaps``.
        save_dir: Forwarded to ``draw_umaps`` as ``save_path``. Defaults to the directory containing
            ``model_path``.
        legend_mode: Forwarded to ``draw_umaps``.
        umap_parameters: Keyword arguments for ``umap.UMAP``, also handed to ``FileCache``. Defaults to {}.

    Raises:
        ValueError: If ``use_cache`` is False and the loader produced no batches.
    """
    if umap_parameters is None:
        umap_parameters = {}
    if save_dir is None:
        save_dir = os.path.dirname(model_path)
    if cache_dir is None:
        # If the user passes the input dir as a relative path without ./ then dirname will contain all path info
        if os.path.basename(input_root_path) == "":
            cache_dir = os.path.dirname(input_root_path) + "__layer_outputs"
        else:
            cache_dir = os.path.join(
                os.path.dirname(input_root_path),
                os.path.basename(input_root_path) + "__layer_outputs")

    network = keras.models.load_model(model_path)
    if print_layers:
        for idx, layer in enumerate(network.layers):
            print("{}: {} --- output shape: {}".format(idx, layer.name,
                                                       layer.output_shape))
        return

    evaluator = Evaluator(network, layers=layers)
    loader = ImageLoader(input_root_path,
                         network,
                         batch=batch,
                         input_extension=input_extension,
                         strip_alpha=strip_alpha)
    cache = FileCache(cache_dir, evaluator.layers,
                      umap_parameters) if use_cache else None

    classes = []
    # One list of per-batch outputs per layer, concatenated once after the loop. Re-concatenating a growing
    # array on every batch would copy all previous data each time (quadratic in the number of batches).
    layer_chunks = None
    for batch_id, (batch_inputs, batch_classes) in enumerate(
            tqdm(loader, desc='Computing Outputs', unit='batch')):
        if use_cache and cache.batch_cached(batch_id):
            continue
        batch_layer_outputs = evaluator.evaluate(batch_inputs)
        if use_cache:
            cache.save(batch_layer_outputs, batch_classes)
        else:
            if layer_chunks is None:
                layer_chunks = [[batch_layer] for batch_layer in batch_layer_outputs]
            else:
                # zip truncates to the shorter sequence, matching the previous pairwise concatenation
                for chunks, batch_layer in zip(layer_chunks, batch_layer_outputs):
                    chunks.append(batch_layer)
            classes.extend(batch_classes)
    if use_cache:
        layer_outputs, classes = cache.load_and_transform(len(loader))
    else:
        if layer_chunks is None:
            raise ValueError(
                "No input batches were produced from {}".format(input_root_path))
        layer_outputs = [np.concatenate(chunks, axis=0) for chunks in layer_chunks]
        fit = umap.UMAP(**umap_parameters)
        with Suppressor():  # Silence a bunch of numba warnings
            layer_outputs = [
                fit.fit_transform(layer) for layer in layer_outputs
            ]
    draw_umaps(layer_outputs,
               classes,
               layer_ids=layers,
               layers=network.layers,
               save=save,
               save_path=save_dir,
               dictionary=load_dict(dictionary_path, True),
               legend_mode=legend_mode)