def get_variant_clipper(self, faulty_model, name=None):
    """Build a model variant with ``ClipperLayer`` instances inserted.

    The base model is obtained from ``self.get_model(name=name,
    training_variant='none')`` and then passed through
    ``insert_layer_nonseq`` together with ``self.activation_name_pattern``
    and a factory that creates the ``ClipperLayer`` objects.

    NOTE(review): ``faulty_model`` is accepted but never used here, unlike
    the sibling ``get_variant_*`` methods which copy it. Confirm that
    starting from a fresh ``get_model`` result is intended.

    Args:
        faulty_model: Unused by this method.
        name: Forwarded to ``self.get_model`` as ``name`` and to
            ``insert_layer_nonseq`` as ``model_name``.

    Returns:
        Whatever ``insert_layer_nonseq`` returns for the base model.
    """
    base_model = self.get_model(name=name, training_variant='none')

    def make_clipper(layer_name):
        # self.bounds is looked up when the factory is invoked, not when it
        # is defined.
        return ClipperLayer(name=layer_name, bounds=self.bounds)

    return insert_layer_nonseq(
        base_model,
        self.activation_name_pattern,
        make_clipper,
        'dummy',
        model_name=name,
    )
def get_variant_profiler(self, faulty_model, name=None):
    """Build a copy of ``faulty_model`` with ``ProfileLayer`` instances inserted.

    The model is copied via ``self.copy_model`` under the name
    ``<name>_base_copy`` (an empty prefix is used when ``name`` is falsy),
    then passed through ``insert_layer_nonseq`` together with
    ``self.activation_name_pattern`` and a ``ProfileLayer`` factory.
    Finally the ``dropin`` attribute of ``faulty_model`` is attached to the
    resulting model.

    Args:
        faulty_model: Model to copy; must expose a ``dropin`` attribute.
        name: Optional name; used as the prefix of the copy's name and
            forwarded unchanged to ``insert_layer_nonseq`` as ``model_name``.

    Returns:
        The result of ``insert_layer_nonseq`` with ``dropin`` set on it.
    """
    prefix = name if name else ''
    base_copy = self.copy_model(faulty_model, name=prefix + '_base_copy')

    def make_profile_layer(layer_name):
        return ProfileLayer(name=layer_name)

    profiled = insert_layer_nonseq(
        base_copy,
        self.activation_name_pattern,
        make_profile_layer,
        'dummy',
        model_name=name,
    )
    # Carry over the dropin object from the source model.
    profiled.dropin = faulty_model.dropin
    return profiled
def get_variant_ranger(self, faulty_model, name=None):
    """Build a copy of ``faulty_model`` with ``RangerLayer`` instances inserted.

    The model is copied via ``self.copy_model`` under the name
    ``<name>_base_copy``, then passed through ``insert_layer_nonseq``
    together with ``self.activation_name_pattern`` and a factory that
    creates ``RangerLayer`` objects using ``self.bounds``.

    Args:
        faulty_model: Model to copy.
        name: Optional name; used as the prefix of the copy's name and
            forwarded unchanged to ``insert_layer_nonseq`` as ``model_name``.
            When falsy (including the default ``None``) an empty prefix is
            used.

    Returns:
        Whatever ``insert_layer_nonseq`` returns for the copied model.
    """
    # ``name`` defaults to None; ``None + '_base_copy'`` would raise
    # TypeError, so fall back to an empty prefix exactly like
    # get_variant_profiler does.
    model = self.copy_model(faulty_model, name=(name or '') + '_base_copy')

    def ranger_layer_factory(insert_layer_name):
        # self.bounds is looked up when the factory is invoked, not when it
        # is defined.
        return RangerLayer(name=insert_layer_name, bounds=self.bounds)

    model = insert_layer_nonseq(
        model,
        self.activation_name_pattern,
        ranger_layer_factory,
        'dummy',
        model_name=name,
    )
    return model
def __init__(self, model, representative_dataset=None, a=None, b=None,
             r=0.5, mode='worst', regex='conv2d.*|dense.*',
             perturb=lambda x, p: x + p, count=1, portion=None) -> None:
    """Store the configuration and obtain the ``a`` / ``b`` values.

    ``a`` and ``b`` come from one of two sources:

    * If ``representative_dataset`` is truthy, ``DropinProfiler`` layers are
      inserted into ``model`` via ``insert_layer_nonseq`` (matching
      ``regex``), the resulting model is run on every batch of the dataset,
      and ``a`` / ``b`` are then read from the ``DropinProfiler`` class
      attributes. Any ``a`` / ``b`` arguments are ignored in this case.
    * Otherwise the explicitly passed ``a`` and ``b`` are used; both are
      required.

    Args:
        model: Model to store on the instance and, when profiling, to
            instrument.
        representative_dataset: Optional sized iterable yielding two-element
            items; only the first element is fed to ``predict``.
            (presumably ``(inputs, labels)`` batches — confirm with callers)
        a: Used as ``self.a`` when no dataset is given.
        b: Used as ``self.b`` when no dataset is given.
        r: Stored as ``self.r``.
        mode: Stored as ``self.mode``.
        regex: Layer-name pattern handed to ``insert_layer_nonseq``.
        perturb: Callable of ``(x, p)``; the default returns ``x + p``.
        count: Stored as ``self.count``.
        portion: Stored as ``self.portion``.

    Raises:
        ValueError: If no (truthy) ``representative_dataset`` is given and
            ``a`` or ``b`` is ``None``.
    """
    super().__init__()
    self.model = model
    self.representative_dataset = representative_dataset
    self.r = r
    self.mode = mode
    self.regex = regex
    self.perturb = perturb
    self.count = count
    self.portion = portion

    if self.representative_dataset:
        # Reset the class-level values before profiling so that results
        # left on the class by an earlier run are not picked up.
        DropinProfiler.a, DropinProfiler.b = None, None

        def profiler_layer_factory(insert_layer_name):
            return DropinProfiler(name=insert_layer_name)

        profiler = insert_layer_nonseq(
            model, self.regex, profiler_layer_factory, 'profiler',
            only_last_node=True)
        profiler.run_eagerly = True

        train_data_size = len(representative_dataset)
        # Count batches from 1 so the message reads "1/N" ... "N/N" instead
        # of reporting "0/N" after the first batch has completed.
        for batches_done, data in enumerate(self.representative_dataset,
                                            start=1):
            x, _ = data  # second element of each item is not needed
            profiler.predict(x)
            logger.info('Done with {}/{} batches.'.format(
                batches_done, train_data_size))

        self.a, self.b = DropinProfiler.a, DropinProfiler.b
    else:
        # Identity checks instead of ``None not in (a, b)``: the membership
        # test falls back to ``==``, which is ambiguous for array-like
        # values. An explicit raise also survives ``python -O``, unlike
        # ``assert``.
        if a is None or b is None:
            raise ValueError(
                "Both 'a' and 'b' must be provided when no "
                "representative_dataset is given.")
        self.a, self.b = a, b

    self.perturbation_inputs = []