示例#1
0
 def test_find_layers(self):
     """find_layers locates instances of the requested class at any depth."""
     mlp = MLP(14, 5)
     autoenc = AutoencoderLinear(11, 3)
     # Direct hit: the searched model itself is of the target class.
     self.assertEqual(list(find_layers(mlp, layer_class=MLP)), [mlp])
     # One level of nesting inside a Sequential container.
     wrapped = nn.Sequential(mlp, nn.Linear(12, 4))
     self.assertEqual(list(find_layers(wrapped, layer_class=MLP)), [mlp])
     # Deeper nesting; search for a different target class.
     nested = nn.Sequential(wrapped, autoenc, nn.Sequential(mlp))
     self.assertEqual(
         list(find_layers(nested, layer_class=AutoencoderLinear)),
         [autoenc])
示例#2
0
    def prepare(self,
                model: nn.Module,
                monitor_layers=(nn.Linear, ),
                monitor_layers_count=1):
        """
        Activates monitoring and registers the last `monitor_layers_count`
        layers of `model` (in forward-pass order) that match one of the
        `monitor_layers` classes.

        Parameters
        ----------
        model : nn.Module
            A model to monitor.
        monitor_layers : tuple of type, optional
            Layer classes to search for inside the model.
        monitor_layers_count : int, optional
            The number of last layers to monitor for MI.
        """
        self.is_active = True  # turn on the feature
        self._prepare_input()
        self._prepare_input_finished()
        images_batch, _ = self.data_loader.sample()
        image_sample = images_batch[0]

        # Membership set for O(1) filtering of the ordered layer list.
        candidates = set(find_layers(model, layer_class=monitor_layers))
        ordered = get_layers_ordered(model, image_sample)
        ordered = [layer for layer in ordered if layer in candidates]
        # Keep only the tail: the last `monitor_layers_count` matching layers.
        tail = ordered[-monitor_layers_count:]
        for name, layer in find_named_layers(model,
                                             layer_class=monitor_layers):
            if layer in tail:
                self.register_layer(layer=layer, name=name)

        print(f"Monitoring only these last layers for mutual information "
              f"estimation: {list(self.layer_to_name.values())}")
示例#3
0
    def __init__(self, kwta_scheduler: Optional[KWTAScheduler] = None):
        """
        Collects the model's kWTA layers and wires up an optional scheduler.

        Parameters
        ----------
        kwta_scheduler : KWTAScheduler or None, optional
            Scheduler for kWTA sparsity/hardness; dropped (with a warning)
            when the model has no kWTA layers.
        """
        # Extend the monitored module types with kWTA-related layers.
        self.watch_modules = self.watch_modules + (KWinnersTakeAll,
                                                   SynapticScaling)
        ordered = self.get_layers_ordered(ignore_children=KWinnersTakeAll)
        self.kwta_layers = tuple(layer for layer in ordered
                                 if isinstance(layer, KWinnersTakeAll))
        if not self.has_kwta():
            warnings.warn(
                "For models with no kWTA layer, use TrainerEmbedding")
            self.env_name = f"{self.env_name} no-kwta"
            if kwta_scheduler is not None:
                warnings.warn("Turning off KWTAScheduler, because the model "
                              "does not have kWTA layers.")
                kwta_scheduler = None
            if isinstance(self.accuracy_measure, AccuracyEmbeddingKWTA):
                warnings.warn(
                    "Setting AccuracyEmbeddingKWTA.sparsity to None, "
                    "because the model does not have kWTA layers.")
                self.accuracy_measure.sparsity = None

        # Tag the environment name with soft-kWTA specifics, if present.
        soft_layers = tuple(find_layers(self.model, KWinnersTakeAllSoft))
        has_threshold = any(kwta.threshold is not None
                            for kwta in soft_layers)
        if has_threshold:
            # kwta-soft with a threshold
            self.env_name = f"{self.env_name} threshold"
        has_predictor = any(isinstance(kwta.sparsity, SparsityPredictor)
                            for kwta in soft_layers)
        if has_predictor:
            self.env_name = f"{self.env_name} sparsity-predictor"

        self.kwta_scheduler = kwta_scheduler
        self._update_accuracy_state()
        self._post_init_monitor()
示例#4
0
 def __init__(self,
              model: nn.Module,
              step_size: int,
              gamma_sparsity=0.5,
              min_sparsity=0.05,
              gamma_hardness=2.0,
              max_hardness=10):
     """
     Schedules updates for every kWTA layer found in `model`.

     Parameters
     ----------
     model : nn.Module
         The model whose kWTA layers will be scheduled.
     step_size : int
         Epoch period between scheduler updates.
     gamma_sparsity : float, optional
         Multiplicative sparsity decay factor.
     min_sparsity : float, optional
         Lower bound on the sparsity value.
     gamma_hardness : float, optional
         Multiplicative hardness growth factor.
     max_hardness : float, optional
         Upper bound on the hardness value.
     """
     # Snapshot all kWTA layers up front; the set is fixed afterwards.
     self.kwta_layers = tuple(find_layers(model,
                                          layer_class=KWinnersTakeAll))
     self.step_size = step_size
     self.gamma_sparsity = gamma_sparsity
     self.min_sparsity = min_sparsity
     self.gamma_hardness = gamma_hardness
     self.max_hardness = max_hardness
     self.epoch = 0  # epochs elapsed so far
示例#5
0
    def monitor_functions(self):
        """Register extra plotting callbacks on top of the parent's set."""
        super().monitor_functions()

        # Grab the first Softshrink layer, or None when the model has none.
        softshrink = next(find_layers(self.model, layer_class=Softshrink),
                          None)

        def lambda_mean(viz):
            # effective (positive) lambda
            lambd = softshrink.lambd.data.clamp(min=0).cpu()
            viz.line_update(y=[lambd.mean()],
                            opts={
                                'xlabel': 'Epoch',
                                'ylabel': 'Lambda mean',
                                'title': 'Softshrink lambda threshold',
                            })

        if softshrink is not None:
            self.monitor.register_func(lambda_mean)

        # Bind once; both closures below read from the same statistics dict.
        solver_online = self.model.solver.online

        def plot_dv_norm(viz):
            viz.line_update(y=[solver_online['dv_norm'].get_mean()],
                            opts={
                                'xlabel': 'Epoch',
                                'ylabel': 'dv_norm (final improvement)',
                                'title': 'Solver convergence',
                            })

        def plot_iterations(viz):
            viz.line_update(y=[solver_online['iterations'].get_mean()],
                            opts={
                                'xlabel': 'Epoch',
                                'ylabel': 'solver iterations',
                                'title': 'Solver iterations run',
                            })

        self.monitor.register_func(plot_dv_norm)
        self.monitor.register_func(plot_iterations)