def load_state(graph_file: str, return_extra: bool = False):
    """Load a saved graph state and rebuild the graph it describes.

    The extension of ``graph_file`` is stripped and replaced with
    ``STATE_EXTENSION`` to locate the JSON state file (and with
    ``ARRS_EXTENSION`` to locate the pickled arrays file when the state has
    a truthy ``load_parameters`` entry). The graph named by the saved
    identity is then re-created and the recorded transforms (order
    adjustment, fusions, recorded changes, node extraction) are replayed
    on it.

    Args:
        graph_file: Path used to derive the state and arrays file names;
            only its base name (without extension) is used.
        return_extra: When true, also return ``info_state['extra']``.

    Returns:
        The rebuilt graph ``G``, or the tuple ``(G, info_state['extra'])``
        when ``return_extra`` is true.

    Raises:
        ValueError: If the state file does not exist, or if parameters
            should be loaded and the arrays file does not exist.
    """
    graph_base, _ = os.path.splitext(graph_file)
    state_filename = graph_base + STATE_EXTENSION
    state_file = Path(state_filename)
    LOG.info("loading graph state from %s", state_filename)
    if not state_file.is_file():
        raise ValueError("state file not found")
    with state_file.open('r') as json_fp:
        info_state = json.load(json_fp, cls=StateDecoder)

    # Both mappings are passed through convert_str_to_keys after decoding.
    # NOTE(review): presumably this restores non-string keys that JSON
    # flattened to strings - confirm against convert_str_to_keys.
    info_state['info'] = convert_str_to_keys(info_state['info'])
    if 'node_options' in info_state:
        info_state['node_options'] = convert_str_to_keys(
            info_state['node_options'])
    else:
        # State files without node options get an empty mapping so the
        # code below can use the key unconditionally.
        info_state['node_options'] = {}

    if info_state['load_parameters']:
        pickle_filename = graph_base + ARRS_EXTENSION
        LOG.info("loading tensors from %s", pickle_filename)
        arrs_file = Path(pickle_filename)
        if not arrs_file.is_file():
            raise ValueError("arrays file not found")
        with arrs_file.open('rb') as arrs_fp:
            # NOTE(review): unpickling can execute arbitrary code; only
            # load state files that come from a trusted source.
            parameters = pickle.load(arrs_fp)
    else:
        parameters = None

    # Here load the original graph and replay the transforms that were done to it
    if info_state['info'].get('has_quantized_parameters'):
        opts = {'load_tensors': True, 'load_quantization': True}
    else:
        opts = {
            'load_tensors': False,
        }
    # Retrieve the identity of the saved state
    identity = GraphIdentity(None)
    identity.identity = info_state['identity']

    LOG.info("loading graph from %s", identity.filename)
    G = create_graph(identity.filename, opts=opts)
    if 'name' in info_state:
        G.name = info_state['name']
    G.add_dimensions()

    # Only the node options that have 'FIXED_ORDER' set are applied at this
    # point, i.e. before the order adjustment below is replayed.
    freeze_options = {
        k: v for k, v in info_state['node_options'].items()
        if 'FIXED_ORDER' in list(v.set_options)
    }
    set_options(G, freeze_options)
    if identity.is_adjusted:
        # If weights were saved then don't reshape them since it was already done
        # before they were saved
        LOG.info("adjusting dimensions")
        G.adjust_order(reshape_weights=not info_state['load_parameters'])
        G.add_dimensions()

    if identity.is_fused:
        LOG.info("fusing nodes")
        # replay the fusions that were carried out
        for fusion_name in identity.fusions:
            fusion = get_fusion(fusion_name)
            fusion.match(G)
            G.add_dimensions()

    set_parameters(G, parameters)
    # Update the identity to match the saved graph
    G.info = info_state['info']
    G.changes.replay(G)
    G.graph_identity = identity
    G.node_options = info_state['node_options']
    # The full set of node options is applied now; the same mapping is
    # passed for both the second and third arguments.
    # NOTE(review): meaning of the third argument is not visible here -
    # confirm against set_options.
    set_options(G, info_state['node_options'], info_state['node_options'])

    if identity.extracted_step is not None:
        extract_node(G, G.graph_state.steps[identity.extracted_step]['node'])
        G.add_dimensions()

    if return_extra:
        return G, info_state['extra']
    return G
def _dencapsulate(cls, val):
    """Build a ``cls`` instance from the encoded value ``val``.

    ``val`` is passed through ``convert_str_to_keys`` and the result is
    handed to the constructor as its ``init`` keyword argument.
    """
    init_value = convert_str_to_keys(val)
    return cls(init=init_value)
def _dencapsulate(cls, val):
    """Build a ``SymmetricQuantizer`` from the encoded mapping ``val``.

    Reads the ``'activation_stats'`` entry (passed through
    ``convert_str_to_keys``) and the ``'force_width'`` entry of ``val`` and
    supplies them positionally, in that order. ``cls`` is not used; the
    class is named explicitly.
    """
    activation_stats = convert_str_to_keys(val['activation_stats'])
    force_width = val['force_width']
    return SymmetricQuantizer(activation_stats, force_width)
def _dencapsulate(cls, val):
    """Build a ``MultQuantizer`` from the encoded mapping ``val``.

    Reads the ``'activation_stats'`` entry (passed through
    ``convert_str_to_keys``), then the ``'force_width'`` and
    ``'quantized_dimension'`` entries of ``val`` and supplies them
    positionally, in that order. ``cls`` is not used; the class is named
    explicitly.
    """
    activation_stats = convert_str_to_keys(val['activation_stats'])
    force_width = val['force_width']
    quantized_dimension = val['quantized_dimension']
    return MultQuantizer(activation_stats, force_width, quantized_dimension)
def _dencapsulate(cls, val):
    """Build a ``SimpleQuantizer`` from the encoded mapping ``val``.

    Reads the ``'activation_stats'`` and ``'filter_stats'`` entries (each
    passed through ``convert_str_to_keys``), then the ``'min_qsnr'`` and
    ``'force_width'`` entries of ``val`` and supplies them positionally, in
    that order. ``cls`` is not used; the class is named explicitly.
    """
    activation_stats = convert_str_to_keys(val['activation_stats'])
    filter_stats = convert_str_to_keys(val['filter_stats'])
    min_qsnr = val['min_qsnr']
    force_width = val['force_width']
    return SimpleQuantizer(activation_stats, filter_stats, min_qsnr, force_width)