Пример #1
0
    def export_quant_config(self, export_file=None, adjust_pos=True):
        """Save bias-correction data (when enabled) and dump the quant config.

        Bias correction is handled only when ``NndctOption.nndct_param_corr``
        has a positive value and ``self.quant_mode`` is 1: the result of
        ``bias_corr()`` is collected from the module of every
        convolution/dense node whose module has a bias, stored in
        ``self.bias_corr``, and written to ``self.bias_corr_file`` with
        ``torch.save``.

        The quant config is then written as a JSON string to ``export_file``
        (falling back to ``self.export_file``). Nothing is written when the
        resolved file name is not a string.

        Args:
            export_file: Optional target path overriding ``self.export_file``.
            adjust_pos: When true, ``self.organize_quant_pos()`` is called
                before the config is written.
        """
        if NndctOption.nndct_param_corr.value > 0 and self.quant_mode == 1:
            # Collect the bias correction of every supported layer type.
            for graph_node in self.Nndctgraph.nodes:
                if graph_node.op.type not in (
                        NNDCT_OP.CONV1D,
                        NNDCT_OP.CONV2D,
                        NNDCT_OP.CONVTRANSPOSE2D,
                        NNDCT_OP.DEPTHWISE_CONV2D,
                        NNDCT_OP.DENSE,
                        NNDCT_OP.DEPTHWISE_CONVTRANSPOSE2D,
                ):
                    continue
                # Modules without a bias have nothing to correct.
                if graph_node.module.bias is None:
                    continue
                self.bias_corr[graph_node.name] = graph_node.module.bias_corr()

            # Persist the gathered corrections and remember that we did so.
            torch.save(self.bias_corr, self.bias_corr_file)
            self.bias_corrected = True

        # Export the quantization steps.
        file_name = export_file or self.export_file
        if not isinstance(file_name, str):
            return

        NndctScreenLogger().info(f"=>Exporting quant config.({file_name})")
        if adjust_pos:
            self.organize_quant_pos()
        with open(file_name, 'w') as config_out:
            config_out.write(nndct_utils.to_jsonstr(self.quant_config))
Пример #2
0
 def export_quant_config(self, export_file=None):
     """Write the quant config as a JSON string when quant mode is 1 or 3.

     The target is ``export_file`` or, when that is falsy,
     ``self.export_file``. The method does nothing if the resolved target is
     not a string or if ``self.quant_mode`` is neither 1 nor 3. Otherwise
     ``self.organize_quant_pos()`` is called before the file is written.

     Args:
         export_file: Optional target path overriding ``self.export_file``.
     """
     file_name = export_file or self.export_file
     if not isinstance(file_name, str):
         return
     if self.quant_mode not in (1, 3):
         return

     self.organize_quant_pos()
     with open(file_name, 'w') as config_out:
         config_out.write(nndct_utils.to_jsonstr(self.quant_config))
Пример #3
0
 def export_quant_config(self, export_file=None):
     """Log and write the quant config as a JSON string in quant mode 1 or 3.

     The method returns immediately when ``self.quant_mode`` is neither 1
     nor 3. Otherwise the target is ``export_file`` or, when that is falsy,
     ``self.export_file``; nothing happens if the resolved target is not a
     string. Before the file is written an info message is logged and
     ``self.organize_quant_pos()`` is called.

     Args:
         export_file: Optional target path overriding ``self.export_file``.
     """
     if self.quant_mode not in (1, 3):
         return

     file_name = export_file or self.export_file
     if not isinstance(file_name, str):
         return

     NndctScreenLogger().info(f"=>Exporting quant config.({file_name})")
     self.organize_quant_pos()
     with open(file_name, 'w') as config_out:
         config_out.write(nndct_utils.to_jsonstr(self.quant_config))
Пример #4
0
 def export(self,
            file_prefix='NndctGen_graph',
            graph_format=False,
            with_params=False,
            prefix=''):
     """Write the graph description, and optionally its parameters, to disk.

     The description returned by ``self.as_description(prefix)`` is written
     to ``file_prefix + NNDCT_KEYS.XMODEL_SUFFIX``.

     Args:
         file_prefix: Path prefix shared by the files that are written.
         graph_format: When true the description is serialized with
             ``nndct_utils.to_jsonstr``; otherwise with ``json.dumps``.
         with_params: When true, every entry of ``self._params`` is also
             saved (key, shape, data) into
             ``file_prefix + NNDCT_KEYS.XPARAM_SUFFIX``, replacing an
             existing file at that path.
         prefix: Forwarded unchanged to ``self.as_description``.
     """
     description = self.as_description(prefix)
     with open(file_prefix + NNDCT_KEYS.XMODEL_SUFFIX, 'w') as graph_out:
         serialize = nndct_utils.to_jsonstr if graph_format else json.dumps
         graph_out.write(serialize(description))

     if not with_params:
         return

     param_path = file_prefix + NNDCT_KEYS.XPARAM_SUFFIX
     # Start from a clean file: drop any parameter store left at this path.
     if os.path.exists(param_path):
         os.remove(param_path)
     with nndct_utils.HDFShapedStore(param_path) as store:
         for param_name, param in self._params.items():
             store.save(param_name, param.shape, param.data)