def create_weight_sparsifying_operation(self, module):
    """Build the sparsifying operation for ``module``'s weight.

    Args:
        module: Object exposing a ``weight`` attribute with a ``size()``
            method (presumably a torch ``nn.Module`` parameter — TODO confirm).

    Returns:
        A new ``BinaryMask`` constructed from ``module.weight.size()``.
    """
    weight_shape = module.weight.size()
    return BinaryMask(weight_shape)
def __init__(self, layer):
    """Wrap ``layer`` and attach a binary-mask weight update to it.

    Builds a ``BinaryMask`` from ``layer.weight.size()`` and wraps it in
    ``UpdateWeight``. It then registers that wrapper on the layer via
    ``register_pre_forward_operation``.

    Args:
        layer: The wrapped layer. It must expose ``weight.size()`` and
            ``register_pre_forward_operation``.
    """
    # Base-class init must run before any attribute assignment below;
    # presumably the base is torch ``nn.Module`` — TODO confirm.
    super().__init__()
    self.layer = layer

    mask = BinaryMask(layer.weight.size())
    pre_forward_op = UpdateWeight(mask)
    # Keep the handle returned by registration; presumably used later to
    # remove the operation — NOTE(review): verify against callers.
    self.op_key = self.layer.register_pre_forward_operation(pre_forward_op)