def _unstrip_tensor(tensor: onnx.TensorProto) -> None:
    """Fill a stripped tensor with random placeholder data, in place.

    Scans ``tensor.external_data`` for the first ``"location"`` entry whose
    value is a JSON object containing ``"type": "stripped"``. If that object
    also carries ``"average"`` and ``"variance"``, the tensor receives
    ``raw_data`` drawn from a normal distribution with that mean and
    variance, shaped by ``tensor.dims`` and cast to the NumPy dtype mapped
    from ``tensor.data_type``. ``data_location`` is reset to ``DEFAULT`` and
    the matched ``external_data`` entry is removed.

    The tensor is left untouched when no such entry exists or when either
    statistic is missing.

    Args:
        tensor: The tensor to restore; modified in place.
    """
    stripped_meta = None
    stripped_idx = -1
    for idx, entry in enumerate(tensor.external_data):
        if entry.key != "location":
            continue
        try:
            candidate = json.loads(entry.value)
        except ValueError:
            # Not JSON at all (presumably an ordinary external-data path).
            continue
        # json.loads can return any JSON value: a location such as "123" or
        # "null" parses fine but is not a mapping, so guard before .get().
        if isinstance(candidate, dict) and candidate.get("type", "") == "stripped":
            stripped_meta = candidate
            stripped_idx = idx
            break

    if stripped_meta is None:
        return

    ave = stripped_meta.get("average")
    var = stripped_meta.get("variance")
    if ave is None or var is None:
        return

    np_dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[tensor.data_type]
    # normal() takes a standard deviation, while the metadata stores variance.
    dummy_array = numpy.random.normal(ave, math.sqrt(var), tensor.dims).astype(np_dtype)
    dummy_tensor = onnx.numpy_helper.from_array(dummy_array)

    tensor.data_location = onnx.TensorProto.DEFAULT
    tensor.raw_data = dummy_tensor.raw_data
    # Only the matched "location" entry is dropped; other entries are kept.
    del tensor.external_data[stripped_idx]
def _strip_raw_data(tensor: onnx.TensorProto) -> onnx.TensorProto:
    """Drop a tensor's payload, keeping only summary statistics, in place.

    The tensor's values are read as a NumPy array, and their mean and
    variance are stored as a JSON object (``{"type": "stripped",
    "average": ..., "variance": ...}``) in the external-data ``location``
    entry, together with the array's byte length. The tensor is then marked
    ``EXTERNAL`` and all of its payload fields are cleared.

    Args:
        tensor: The tensor to strip; modified in place.

    Returns:
        The same ``tensor`` object, after stripping.
    """
    arr = onnx.numpy_helper.to_array(tensor)
    meta_dict = {
        "type": "stripped",
        "average": float(arr.mean()),
        "variance": float(arr.var()),
    }

    # Make sure raw_data is populated before registering external data
    # (presumably set_external_data requires it -- see the onnx docs).
    if not tensor.HasField("raw_data"):
        tensor.raw_data = onnx.numpy_helper.from_array(arr, tensor.name).raw_data

    onnx.external_data_helper.set_external_data(
        tensor, location=json.dumps(meta_dict), length=arr.nbytes
    )
    tensor.data_location = onnx.TensorProto.EXTERNAL

    # Clear every payload field, not only raw_data/float_data: tensors that
    # keep their values in a typed field (int32_data, int64_data, ...) would
    # otherwise still carry their full contents after being "stripped".
    for payload_field in (
        "raw_data",
        "float_data",
        "int32_data",
        "int64_data",
        "double_data",
        "uint64_data",
        "string_data",
    ):
        tensor.ClearField(payload_field)
    return tensor