def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Quantize the weights since training has finalized.

    Converts ``pl_module._prepared`` with ``self.convert`` and stores the
    result on both ``pl_module._quantized`` and ``self.quantized``.

    Does nothing if ``pl_module`` already has a ``_quantized`` attribute, so
    repeated calls do not re-convert.

    Args:
        trainer: The Lightning trainer. It is not used here; the parameter
            exists because the callback hook signature requires it.
        pl_module: The module being trained. It must expose ``_prepared``.
    """
    if hasattr(pl_module, "_quantized"):
        return
    # Pass ``attrs=self.preserved_attrs`` just as ``on_test_start`` does, so
    # the quantized model carries the same preserved attributes no matter
    # which hook performed the conversion.
    pl_module._quantized = self.convert(
        pl_module._prepared, self.qconfig_dicts.keys(), attrs=self.preserved_attrs
    )
    self.quantized = pl_module._quantized
def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Make sure a quantized model exists before testing begins.

    This covers the edge case where ``.test()`` is called without a
    preceding ``.fit()``. If ``pl_module`` has no ``_quantized`` attribute
    yet, its ``_prepared`` model is converted with
    ``attrs=self.preserved_attrs``. The result is stored on both
    ``pl_module._quantized`` and ``self.quantized``. If ``_quantized``
    already exists, nothing is done.

    Args:
        trainer: The Lightning trainer. It is not used here; the parameter
            exists because the callback hook signature requires it.
        pl_module: The module under test. It must expose ``_prepared``.
    """
    if not hasattr(pl_module, "_quantized"):
        converted = self.convert(
            pl_module._prepared,
            self.qconfig_dicts.keys(),
            attrs=self.preserved_attrs,
        )
        pl_module._quantized = converted
        self.quantized = pl_module._quantized