def run_test(self, model, input, check_value=True):
    """Round-trip ``model`` through the graph IR and compare with the regenerated model.

    The model is scripted with TorchScript, converted to graph IR and turned
    back into PyTorch source by ``model_to_pytorch_script``. That source is
    executed to build ``converted_model``. The original weights are mapped onto
    it via ``self._match_state_dict``. Both models are then run in eval mode
    under ``torch.no_grad()`` on the same ``input``.

    Args:
        model: the ``torch.nn.Module`` under test.
        input: sequence of positional arguments, unpacked as ``model(*input)``.
        check_value: when true, assert that both models produce equal outputs.

    Returns:
        The model instantiated from the generated code.
    """
    script_module = torch.jit.script(model)
    model_ir = convert_to_graph(script_module, model)
    model_code = model_to_pytorch_script(model_ir)
    print(model_code)

    from .inject_nn import remove_inject_pytorch_nn
    # NOTE(review): presumably undoes the nn injection so the generated code
    # binds to the stock torch.nn modules -- confirm against inject_nn.
    remove_inject_pytorch_nn()

    # The generated source must define `_model`; instantiate it in an
    # isolated namespace rather than this module's globals.
    exec_vars = {}
    exec(model_code + '\n\nconverted_model = _model()', exec_vars)
    converted_model = exec_vars['converted_model']
    converted_state_dict = self._match_state_dict(
        list(model.state_dict().values()),
        dict(converted_model.state_dict()))
    converted_model.load_state_dict(converted_state_dict)

    with torch.no_grad():
        expected_output = model.eval()(*input)
        converted_output = converted_model.eval()(*input)

    if check_value:
        try:
            n_converted = len(converted_output)
            n_expected = len(expected_output)
        except TypeError:
            # Unsized outputs (e.g. a Python scalar or a 0-d tensor): compare directly.
            self.assertEqual(converted_output, expected_output)
        else:
            self.assertEqual(n_converted, n_expected)
            for a, b in zip(converted_output, expected_output):
                if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
                    # torch.eq only returns an elementwise tensor; torch.equal
                    # reduces shape + values to one bool that can be asserted.
                    self.assertTrue(torch.equal(a, b))
                else:
                    self.assertEqual(a, b)
    return converted_model
def _get_model_and_mutators():
    """Convert a fresh ``Net`` to graph IR and collect its inline mutators.

    The IR's ``evaluator`` is set to a ``DebugEvaluator`` before mutators
    are extracted with ``process_inline_mutation``.

    Returns:
        A ``(model_ir, mutators)`` pair.
    """
    net = Net()
    ir_model = convert_to_graph(torch.jit.script(net), net)
    ir_model.evaluator = DebugEvaluator()
    return ir_model, process_inline_mutation(ir_model)
def _get_model_and_mutators():
    """Convert a fresh ``Net`` to graph IR and collect its inline mutators.

    The IR's ``training_config`` is set to a ``DebugTraining`` before
    mutators are extracted with ``process_inline_mutation``.

    Returns:
        A ``(model_ir, mutators)`` pair.
    """
    net = Net()
    ir_model = convert_to_graph(torch.jit.script(net), net)
    ir_model.training_config = DebugTraining()
    return ir_model, process_inline_mutation(ir_model)
def checkExportImport(self, model, input, check_value=True):
    """Export ``model`` to graph IR, regenerate PyTorch code, and compare outputs.

    The regenerated source is executed to build ``converted_model``, which
    receives the original weights via ``self._match_state_dict``. Both models
    run in eval mode under ``torch.no_grad()`` on ``input``.

    Args:
        model: the ``torch.nn.Module`` under test.
        input: sequence of positional arguments, unpacked as ``model(*input)``.
        check_value: when true, compare outputs element by element. Bool
            tensors must match exactly, Python ints must be equal, and
            floats / other tensors must agree within an absolute 1e-4.

    Returns:
        The model instantiated from the generated code.
    """
    script_module = torch.jit.script(model)
    model_ir = convert_to_graph(script_module, model)
    model_code = model_to_pytorch_script(model_ir)
    print(model_code)

    # The generated source must define `_model`; build it in an isolated namespace.
    exec_vars = {}
    exec(model_code + '\n\nconverted_model = _model()', exec_vars)
    converted_model = exec_vars['converted_model']
    converted_state_dict = self._match_state_dict(
        list(model.state_dict().values()),
        dict(converted_model.state_dict()))
    converted_model.load_state_dict(converted_state_dict)

    with torch.no_grad():
        expected_output = model.eval()(*input)
        converted_output = converted_model.eval()(*input)

    if check_value:
        self.assertEqual(len(converted_output), len(expected_output))
        for a, b in zip(converted_output, expected_output):
            if hasattr(a, 'dtype') and a.dtype == torch.bool:
                # `a ^ b` is elementwise; reduce it to one bool so masks with
                # more than one element do not hit an ambiguous-truth error.
                self.assertFalse((a ^ b).any().item())
                continue
            diff = a - b
            if isinstance(diff, int):
                self.assertEqual(diff, 0)
            elif isinstance(diff, float):
                # Python floats have no `.abs()` method; use the builtin.
                self.assertLess(abs(diff), 1E-4)
            elif diff.numel() > 0:
                # `.max()` on an empty tensor raises; empty outputs trivially match.
                self.assertLess(diff.abs().max().item(), 1E-4)
    return converted_model
def _form_latency_table(self, model, dummy_input, dump_lat_table):
    """Build a latency lookup table for ``model`` using ``self.latency_predictor``.

    The model is scripted and converted to graph IR with shape information,
    using ``torch.randn(*dummy_input)`` as the example input.

    Each layer-choice node gets one latency entry per candidate. The entry is
    keyed by the candidate name's last ``_``-separated component. A candidate
    whose graph cannot be found is logged and recorded as 0. The rest of the
    network, flattened without layer choices, is recorded as
    ``'stationary_block': {'root': latency}``.

    Args:
        model: the model to convert to graph IR.
        dummy_input: shape tuple unpacked into ``torch.randn``.
        dump_lat_table: path of a YAML file the table is appended to; falsy
            to skip dumping.

    Returns:
        dict mapping layer-choice labels (and ``'stationary_block'``) to dicts
        of float latencies.
    """
    from nni.retiarii.converter import convert_to_graph
    from nni.retiarii.converter.graph_gen import GraphConverterWithShape
    from nni.retiarii.converter.utils import flatten_model_graph_without_layerchoice, is_layerchoice_node

    latency_table = {}
    script_module = torch.jit.script(model)
    base_model_ir = convert_to_graph(script_module, model,
                                     converter=GraphConverterWithShape(),
                                     dummy_input=torch.randn(*dummy_input))

    # Latency of each layer-choice candidate. A forked copy of the IR is
    # re-rooted at each candidate's graph before prediction.
    temp_ir_model = base_model_ir.fork()
    cell_nodes = base_model_ir.get_cell_nodes()
    layerchoice_nodes = [node for node in cell_nodes if is_layerchoice_node(node)]
    for lc_node in layerchoice_nodes:
        cand_lat = {}
        for candidate in lc_node.operation.parameters['candidates']:
            node_graph = base_model_ir.graphs.get(candidate)
            if node_graph is not None:
                temp_ir_model._root_graph_name = node_graph.name
                latency = self.latency_predictor.predict(temp_ir_model, model_type='nni-ir')
            else:
                _logger.warning("Could not found graph for layerchoice candidate %s", candidate)
                latency = 0
            cand_lat[candidate.split('_')[-1]] = float(latency)
        latency_table[lc_node.operation.parameters['label']] = cand_lat

    # Latency of everything outside the layer choices.
    temp_ir_model._root_graph_name = base_model_ir._root_graph_name
    temp_ir_model = flatten_model_graph_without_layerchoice(temp_ir_model)
    latency = self.latency_predictor.predict(temp_ir_model, model_type='nni-ir')
    latency_table['stationary_block'] = {'root': float(latency)}

    if dump_lat_table:
        import os
        import yaml
        dump_dir = os.path.dirname(dump_lat_table)
        # A bare filename has no directory component, and os.makedirs('')
        # raises FileNotFoundError, so only create a directory when one is given.
        if dump_dir:
            os.makedirs(dump_dir, exist_ok=True)
        # Append mode: each call adds one more list item to the YAML file.
        with open(dump_lat_table, 'a') as fp:
            yaml.dump([{
                "applied_hardware": self.predictor_name,
                'latency_table': latency_table
            }], fp)
        _logger.info("Latency lookup table form done")

    return latency_table
def checkExportImport(self, model, input, check_value=True):
    """Export ``model`` to graph IR, regenerate PyTorch code, and compare outputs.

    The regenerated source is executed to build ``converted_model``, which
    receives the original weights via ``self._match_state_dict``. Both models
    run in eval mode under ``torch.no_grad()`` on ``input``.

    Args:
        model: the ``torch.nn.Module`` under test.
        input: sequence of positional arguments, unpacked as ``model(*input)``.
        check_value: when true (the default), compare outputs element by
            element. Bool tensors must match exactly, Python ints must be
            equal, and floats / other tensors must agree within an absolute 1e-4.

    Returns:
        The model instantiated from the generated code.
    """
    script_module = torch.jit.script(model)
    model_ir = convert_to_graph(script_module, model)
    model_code = model_to_pytorch_script(model_ir)

    # The generated source must define `_model`; build it in an isolated namespace.
    exec_vars = {}
    exec(model_code + '\n\nconverted_model = _model()', exec_vars)
    converted_model = exec_vars['converted_model']
    converted_state_dict = self._match_state_dict(
        list(model.state_dict().values()),
        dict(converted_model.state_dict()))
    converted_model.load_state_dict(converted_state_dict)

    with torch.no_grad():
        expected_output = model.eval()(*input)
        converted_output = converted_model.eval()(*input)

    if check_value:
        self.assertEqual(len(converted_output), len(expected_output))
        for a, b in zip(converted_output, expected_output):
            if hasattr(a, 'dtype') and a.dtype == torch.bool:
                # `-` is not supported on bool tensors; compare with XOR and
                # reduce to a single bool.
                self.assertFalse((a ^ b).any().item())
                continue
            diff = a - b
            if isinstance(diff, int):
                self.assertEqual(diff, 0)
            elif isinstance(diff, float):
                self.assertLess(abs(diff), 1E-4)
            elif diff.numel() > 0:
                # `.max()` on an empty tensor raises; empty outputs trivially match.
                self.assertLess(diff.abs().max().item(), 1E-4)
    return converted_model
def _convert_to_ir(self, model):
    """Return the graph IR of ``model``, built from its TorchScript form."""
    return convert_to_graph(torch.jit.script(model), model)