def test_graph_identity_hash():
    gi_1 = GraphIdentity("test")
    gi_2 = GraphIdentity("test1")
    gi_3 = GraphIdentity("test1")
    assert gi_1.hexdigest != gi_2.hexdigest and gi_1.hexdigest != gi_3.hexdigest
    assert gi_2.hexdigest == gi_3.hexdigest
    gi_3.is_adjusted = True
    assert gi_2.hexdigest != gi_3.hexdigest
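# The sketch below is illustrative only, not the actual GraphIdentity
# implementation: it assumes the digest is computed over a single identity
# dict, so that two identities built from the same filename hash the same and
# flipping any flag (e.g. is_adjusted) changes the digest, which is what the
# test above checks. The class name GraphIdentitySketch and its attribute
# layout are assumptions made for this example.
import hashlib


class GraphIdentitySketch:
    def __init__(self, filename):
        # everything that defines the graph's identity lives in one dict
        self.identity = {'filename': filename, 'is_adjusted': False, 'fusions': []}

    @property
    def is_adjusted(self):
        return self.identity['is_adjusted']

    @is_adjusted.setter
    def is_adjusted(self, value):
        self.identity['is_adjusted'] = value

    @property
    def hexdigest(self):
        # hash the whole identity dict so any attribute change alters the digest
        data = repr(sorted(self.identity.items())).encode('utf-8')
        return hashlib.sha256(data).hexdigest()


# mirrors the first assertion in test_graph_identity_hash
assert GraphIdentitySketch("test").hexdigest != GraphIdentitySketch("test1").hexdigest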
def __init__(self, model=None, name=None, filename=None, constant_store=None):
    super().__init__()
    self.model = model
    self.num_inputs = 0
    self.num_outputs = 0
    self.num_constants = 0
    self.node_options = {}
    self.num_rinputs = 0
    self.num_routputs = 0
    self.graph_state = NNGraphState()
    self.load_function = None
    self.graphname = name
    self.constant_store = constant_store
    self.graph_identity = GraphIdentity(filename)
    self._info = {
        'quantization': None,
    }
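# Hedged usage sketch (assumptions: the constructor above belongs to nntool's
# NNGraph class and ConstantStore is its constant-store helper; neither name
# is shown in this excerpt). Passing the source filename seeds graph_identity
# so a later saved state can be matched back to the original graph file.
#
#   G = NNGraph(name="mobilenet", filename="mobilenet_v1.tflite",
#               constant_store=ConstantStore())
#   assert G.graph_identity.filename == "mobilenet_v1.tflite"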
def load_state(graph_file: str, return_extra=False):
    graph_base, _ = os.path.splitext(graph_file)
    state_filename = graph_base + STATE_EXTENSION
    state_file = Path(state_filename)
    LOG.info("loading graph state from %s", state_filename)
    if not state_file.is_file():
        raise ValueError("state file not found")
    with state_file.open('r') as json_fp:
        info_state = json.load(json_fp, cls=StateDecoder)
    info_state['info'] = convert_str_to_keys(info_state['info'])
    if 'node_options' in info_state:
        info_state['node_options'] = convert_str_to_keys(
            info_state['node_options'])
    else:
        info_state['node_options'] = {}

    if info_state['load_parameters']:
        pickle_filename = graph_base + ARRS_EXTENSION
        LOG.info("loading tensors from %s", pickle_filename)
        arrs_file = Path(pickle_filename)
        if not arrs_file.is_file():
            raise ValueError("arrays file not found")
        with arrs_file.open('rb') as arrs_fp:
            parameters = pickle.load(arrs_fp)
    else:
        parameters = None

    # Here load the original graph and replay the transforms that were done to it
    if info_state['info'].get('has_quantized_parameters'):
        opts = {'load_tensors': True, 'load_quantization': True}
    else:
        opts = {
            'load_tensors': False,
        }

    # Retrieve the identity of the saved state
    identity = GraphIdentity(None)
    identity.identity = info_state['identity']

    LOG.info("loading graph from %s", identity.filename)
    G = create_graph(identity.filename, opts=opts)
    if 'name' in info_state:
        G.name = info_state['name']
    G.add_dimensions()

    freeze_options = {
        k: v for k, v in info_state['node_options'].items()
        if 'FIXED_ORDER' in list(v.set_options)
    }
    set_options(G, freeze_options)

    if identity.is_adjusted:
        # If weights were saved then don't reshape them since it was already done
        # before they were saved
        LOG.info("adjusting dimensions")
        G.adjust_order(reshape_weights=not info_state['load_parameters'])
        G.add_dimensions()

    if identity.is_fused:
        LOG.info("fusing nodes")
        # replay the fusions that were carried out
        for fusion_name in identity.fusions:
            fusion = get_fusion(fusion_name)
            fusion.match(G)
            G.add_dimensions()

    set_parameters(G, parameters)
    # Update the identity to match the saved graph
    G.info = info_state['info']
    G.changes.replay(G)
    G.graph_identity = identity
    G.node_options = info_state['node_options']

    set_options(G, info_state['node_options'],
                info_state['node_options'])

    if identity.extracted_step is not None:
        extract_node(G, G.graph_state.steps[identity.extracted_step]['node'])
        G.add_dimensions()

    if return_extra:
        return G, info_state['extra']
    return G
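# Hedged usage sketch for load_state: the extension of the path passed in is
# stripped and STATE_EXTENSION / ARRS_EXTENSION are appended to find the saved
# state and the pickled tensors, so both companion files must sit next to that
# base path. The path below is illustrative only.
#
#   G = load_state("build/mobilenet_graph_state")
#   G, extra = load_state("build/mobilenet_graph_state", return_extra=True)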