def test_find_layers(self):
    """
    ``find_layers`` yields the modules of the requested class.

    Covers a bare module, a flat ``nn.Sequential`` and a nested
    ``nn.Sequential`` that contains the same MLP instance twice.
    """
    mlp = MLP(14, 5)
    autoencoder = AutoencoderLinear(11, 3)

    # The queried module itself matches the requested class.
    found = list(find_layers(mlp, layer_class=MLP))
    self.assertEqual(found, [mlp])

    # A matching child is found next to a non-matching sibling.
    flat = nn.Sequential(mlp, nn.Linear(12, 4))
    found = list(find_layers(flat, layer_class=MLP))
    self.assertEqual(found, [mlp])

    # Querying a container that also holds nested containers returns
    # only the AutoencoderLinear instance.
    nested = nn.Sequential(flat, autoencoder, nn.Sequential(mlp))
    found = list(find_layers(nested, layer_class=AutoencoderLinear))
    self.assertEqual(found, [autoencoder])
def prepare(self, model: nn.Module, monitor_layers=(nn.Linear, ),
            monitor_layers_count=1):
    """
    Sorts the model layers in order and selects the last
    `monitor_layers_count` layers of type `monitor_layers`.

    Parameters
    ----------
    model : nn.Module
        A model to monitor.
    monitor_layers : type or tuple of types, optional
        Layer class(es) eligible for monitoring. Default: ``(nn.Linear,)``.
    monitor_layers_count : int, optional
        The number of last layers to monitor for MI. If it exceeds the
        number of eligible layers, all of them are monitored; zero (or a
        negative value) selects no layers.
    """
    self.is_active = True  # turn on the feature
    self._prepare_input()
    self._prepare_input_finished()

    # A single sample is handed to `get_layers_ordered`; presumably it runs
    # the model on it to record the order in which layers are called --
    # confirm against its implementation.
    images_batch, _ = self.data_loader.sample()
    image_sample = images_batch[0]

    layers_of_interest = set(find_layers(model, layer_class=monitor_layers))
    layers_ordered = [
        layer for layer in get_layers_ordered(model, image_sample)
        if layer in layers_of_interest
    ]

    # Slice from an explicit non-negative start: `lst[-0:]` returns the
    # whole list, so a count of zero must not be turned into `-0`.
    start = max(len(layers_ordered) - monitor_layers_count, 0)
    # Set for O(1) membership checks in the loop below.
    layers_selected = set(layers_ordered[start:])

    # Iterate named layers so each selected layer is registered under its
    # module name.
    for name, layer in find_named_layers(model, layer_class=monitor_layers):
        if layer in layers_selected:
            self.register_layer(layer=layer, name=name)
    print(f"Monitoring only these last layers for mutual information "
          f"estimation: {list(self.layer_to_name.values())}")
def __init__(self, kwta_scheduler: Optional[KWTAScheduler] = None):
    """
    Configure kWTA-specific behaviour of the trainer.

    Parameters
    ----------
    kwta_scheduler : KWTAScheduler or None, optional
        Scheduler for the kWTA layers. Replaced by None (with a warning)
        when the model has no kWTA layers.

    Notes
    -----
    ``self.env_name`` gets the suffixes " no-kwta", " threshold" and
    " sparsity-predictor" appended when the corresponding condition holds.
    """
    extra_watched = (KWinnersTakeAll, SynapticScaling)
    self.watch_modules = self.watch_modules + extra_watched

    # Keep kWTA layers in the order reported by `get_layers_ordered`
    # (presumably forward-pass order), without descending into them.
    ordered = self.get_layers_ordered(ignore_children=KWinnersTakeAll)
    self.kwta_layers = tuple(
        filter(lambda module: isinstance(module, KWinnersTakeAll), ordered))

    if not self.has_kwta():
        warnings.warn(
            "For models with no kWTA layer, use TrainerEmbedding")
        self.env_name = f"{self.env_name} no-kwta"
        if kwta_scheduler is not None:
            warnings.warn("Turning off KWTAScheduler, because the model "
                          "does not have kWTA layers.")
            kwta_scheduler = None
        if isinstance(self.accuracy_measure, AccuracyEmbeddingKWTA):
            warnings.warn(
                "Setting AccuracyEmbeddingKWTA.sparsity to None, "
                "because the model does not have kWTA layers.")
            self.accuracy_measure.sparsity = None

    soft_layers = tuple(find_layers(self.model, KWinnersTakeAllSoft))

    # Soft kWTA layers configured with a threshold.
    has_threshold = any(layer.threshold is not None for layer in soft_layers)
    if has_threshold:
        self.env_name = f"{self.env_name} threshold"

    # Soft kWTA layers whose sparsity is produced by a SparsityPredictor.
    has_predictor = any(isinstance(layer.sparsity, SparsityPredictor)
                        for layer in soft_layers)
    if has_predictor:
        self.env_name = f"{self.env_name} sparsity-predictor"

    self.kwta_scheduler = kwta_scheduler
    self._update_accuracy_state()
    self._post_init_monitor()
def __init__(self, model: nn.Module, step_size: int, gamma_sparsity=0.5,
             min_sparsity=0.05, gamma_hardness=2.0, max_hardness=10):
    """
    Collect the kWTA layers of `model` and store the schedule settings.

    Parameters
    ----------
    model : nn.Module
        Model whose ``KWinnersTakeAll`` layers are gathered into
        ``self.kwta_layers``.
    step_size : int
        Stored as ``self.step_size``; presumably the number of epochs
        between schedule updates -- confirm in the step logic.
    gamma_sparsity, min_sparsity : float, optional
        Stored as-is; by name, a factor and a lower bound for kWTA sparsity
        (assumption -- verify against the update code).
    gamma_hardness, max_hardness : float, optional
        Stored as-is; by name, a factor and an upper bound for kWTA
        hardness (assumption -- verify against the update code).
    """
    found = find_layers(model, layer_class=KWinnersTakeAll)
    self.kwta_layers = tuple(found)

    self.step_size = step_size
    self.gamma_sparsity, self.min_sparsity = gamma_sparsity, min_sparsity
    self.gamma_hardness, self.max_hardness = gamma_hardness, max_hardness

    # Epoch counter, starting from zero.
    self.epoch = 0
def monitor_functions(self):
    """
    Register the parent's monitor functions plus extra plots.

    Adds a plot of the mean Softshrink threshold (only when the model has a
    ``Softshrink`` layer; the first one found is used). It also adds two
    plots of the mean solver statistics 'dv_norm' and 'iterations' taken
    from ``self.model.solver.online``.
    """
    super().monitor_functions()

    # First Softshrink layer of the model, or None if there is none.
    softshrink = next(find_layers(self.model, layer_class=Softshrink), None)

    if softshrink is not None:
        def lambda_mean(viz):
            # Negative thresholds are clamped to 0: plot the effective lambda.
            effective = softshrink.lambd.data.clamp(min=0).cpu()
            viz.line_update(y=[effective.mean()], opts={
                'xlabel': 'Epoch',
                'ylabel': 'Lambda mean',
                'title': 'Softshrink lambda threshold',
            })

        self.monitor.register_func(lambda_mean)

    stats = self.model.solver.online

    def plot_dv_norm(viz):
        mean_dv_norm = stats['dv_norm'].get_mean()
        viz.line_update(y=[mean_dv_norm], opts={
            'xlabel': 'Epoch',
            'ylabel': 'dv_norm (final improvement)',
            'title': 'Solver convergence',
        })

    def plot_iterations(viz):
        mean_iterations = stats['iterations'].get_mean()
        viz.line_update(y=[mean_iterations], opts={
            'xlabel': 'Epoch',
            'ylabel': 'solver iterations',
            'title': 'Solver iterations run',
        })

    for plot_func in (plot_dv_norm, plot_iterations):
        self.monitor.register_func(plot_func)