def test_save_load_checkpoint(self):
  """Checks that a DelayModel survives a save_checkpoint/init_data round trip.

  Builds a model with one DataPoint per (op, bit configuration) pair, writes it
  with client.save_checkpoint, reloads it with client.init_data, and verifies
  both the reloaded model and the reloaded op -> bit-config index.
  """
  model = delay_model_pb2.DelayModel()
  # Maps an op name to the set of bit configurations we've run that op with.
  data_points = collections.defaultdict(set)

  # Set up some dummy data. Each bit config is a comma-separated list whose
  # first entry is the result bit count and whose remaining entries are the
  # operand bit counts.
  ops = ["op_a", "op_b", "op_c", "op_d", "op_e"]
  bit_configs = ["3, 1, 1", "4, 1, 2", "5, 2, 1", "6, 2, 2"]
  for op in ops:
    for bit_config in bit_configs:
      data_points[op].add(bit_config)
      # Parse each field as a whole integer (not just its first character) so
      # multi-digit bit counts work; int() tolerates the space after a comma.
      result_bit_count, *operand_bit_counts = [
          int(field) for field in bit_config.split(",")
      ]
      result = delay_model_pb2.DataPoint()
      result.operation.op = op
      for operand_bit_count in operand_bit_counts:
        operand = delay_model_pb2.Operation.Operand()
        operand.bit_count = operand_bit_count
        result.operation.operands.append(operand)
      result.operation.bit_count = result_bit_count
      result.delay = 5
      model.data_points.append(result)

  # Use the temporary file as a context manager so it is closed (and deleted)
  # even if saving or loading raises.
  with tempfile.NamedTemporaryFile() as tf:
    client.save_checkpoint(model, tf.name)
    loaded_data_points, loaded_model = client.init_data(tf.name)
  self.assertEqual(model, loaded_model)

  # Fancy equality checking so we get clearer error messages on
  # mismatch.
  for op in ops:
    self.assertIn(op, loaded_data_points)
    loaded_op = loaded_data_points[op]
    for bit_config in bit_configs:
      self.assertIn(bit_config, loaded_op)
      self.assertIn(bit_config, data_points[op])
  self.assertEqual(data_points, loaded_data_points)
def _synthesize_op_and_make_bare_data_point(
    op: str,
    kop: str,
    op_type: str,
    operand_types: List[str],
    stub: synthesis_service_pb2_grpc.SynthesisServiceStub,
    attributes: Sequence[Tuple[str, str]] = (),
    literal_operand: Optional[int] = None) -> delay_model_pb2.DataPoint:
  """Synthesizes a single operation and returns a partially filled data point.

  Only the area (via _record_area) and the operation name are populated; node
  and operand details are left for the caller to fill in.

  Args:
    op: Operation name used when generating the IR package, e.g. 'add'.
    kop: Operation name written into the data point, usually the kConstant
      spelling used by the delay model, e.g. 'kAdd'.
    op_type: Type of the operation's result.
    operand_types: Type of each operand.
    stub: Handle to the synthesis server.
    attributes: Extra attributes for the operation mnemonic (for example
      "new_bit_count" on extend operations); passed to generate_ir_package.
    literal_operand: If given, the index of an operand to replace with a
      randomly generated literal rather than a function parameter; passed to
      generate_ir_package.

  Returns:
    A DataPoint produced by the synthesis server with only its area and
    operation op set.
  """
  ir_text = op_module_generator.generate_ir_package(
      op, op_type, operand_types, attributes, literal_operand)

  # Strip brackets from the type (e.g. 'bits[32]' -> 'bits_32') before
  # embedding it in the generated module name.
  sanitized_type = op_type.replace('[', '_').replace(']', '')
  base_name = f'{op}_{sanitized_type}'
  single_module = op_module_generator.generate_verilog_module(
      base_name, ir_text)

  wrapper_name = f'{base_name}_wrapper'
  verilog_text = op_module_generator.generate_parallel_module(
      [single_module], wrapper_name)

  synth_response = _synth(stub, verilog_text, wrapper_name)
  data_point = delay_model_pb2.DataPoint()
  _record_area(synth_response, data_point)
  data_point.operation.op = kop
  return data_point
def _synthesize_ir(
    stub: synthesis_service_pb2_grpc.SynthesisServiceStub,
    ir_text: str,
    op: str,
    result_bit_count: int,
    operand_bit_counts: Sequence[int] = ()) -> delay_model_pb2.DataPoint:
  """Synthesizes the given IR text and returns a delay data point.

  Args:
    stub: Handle to the synthesis server.
    ir_text: IR package text; Verilog is generated from it with module name
      'top'.
    op: Operation name recorded in the data point.
    result_bit_count: Bit count recorded for the operation's result.
    operand_bit_counts: Bit count of each operand, recorded in order. Defaults
      to no operands.

  Returns:
    A DataPoint whose operation (op, bit_count, operands) and delay are set.
  """
  module_name = 'top'
  mod_generator_result = op_module_generator.generate_verilog_module(
      module_name, ir_text)
  verilog_text = mod_generator_result.verilog_text
  response = _synth(stub, verilog_text, module_name)
  # Convert the reported maximum frequency (Hz) into a period in picoseconds.
  ps = 1e12 / response.max_frequency_hz

  result = delay_model_pb2.DataPoint()
  result.operation.op = op
  result.operation.bit_count = result_bit_count
  for bit_count in operand_bit_counts:
    operand = result.operation.operands.add()
    operand.bit_count = bit_count
  # Truncates toward zero.
  result.delay = int(ps)
  return result
def _synthesize_ir(stub: synthesis_service_pb2_grpc.SynthesisServiceStub,
                   model: delay_model_pb2.DelayModel,
                   data_points: Dict[str, Set[str]], ir_text: str, op: str,
                   result_bit_count: int,
                   operand_bit_counts: Sequence[int]) -> None:
  """Synthesizes the given IR text and records a delay data point in `model`.

  Nothing is synthesized if this bit-count configuration of `op` is already in
  `data_points`. Otherwise the IR is turned into Verilog, synthesized through
  `stub`, and the resulting DataPoint is appended to `model`, which is then
  checkpointed to FLAGS.checkpoint_path.

  Args:
    stub: Handle to the synthesis server.
    model: Delay model accumulating data points; mutated in place.
    data_points: Maps an op name to the set of configuration keys already
      characterized; mutated in place.
    ir_text: IR package text; Verilog is generated from it with module name
      'top'.
    op: Operation name recorded in the data point.
    result_bit_count: Bit count recorded for the operation's result.
    operand_bit_counts: Bit count of each operand, recorded in order.
  """
  seen_keys = data_points.setdefault(op, set())

  # NOTE(review): str() of a proto message is its text-format representation,
  # so each operand contributes e.g. "bit_count: N ..." rather than just N.
  # This key must match whatever format is used when data_points is rebuilt
  # from a checkpoint -- confirm against the loader before changing it.
  bit_count_strs = []
  for bit_count in operand_bit_counts:
    operand = delay_model_pb2.Operation.Operand(bit_count=bit_count,
                                                element_count=0)
    bit_count_strs.append(str(operand))
  key = ', '.join([str(result_bit_count)] + bit_count_strs)
  if key in seen_keys:
    return

  logging.info('Running %s with %d / %s', op, result_bit_count,
               ', '.join([str(x) for x in operand_bit_counts]))
  module_name = 'top'
  mod_generator_result = op_module_generator.generate_verilog_module(
      module_name, ir_text)
  verilog_text = mod_generator_result.verilog_text
  response = _synth(stub, verilog_text, module_name)
  logging.info('Result: %s', response)
  # Convert the reported maximum frequency (Hz) into a period in picoseconds.
  ps = 1e12 / response.max_frequency_hz

  result = delay_model_pb2.DataPoint()
  result.operation.op = op
  result.operation.bit_count = result_bit_count
  for bit_count in operand_bit_counts:
    operand = result.operation.operands.add()
    operand.bit_count = bit_count
  result.delay = int(ps)

  # Mark the configuration as done only after synthesis succeeded, so the
  # in-memory index never claims a data point that `model` does not contain.
  seen_keys.add(key)
  model.data_points.append(result)
  # Checkpoint after every run.
  save_checkpoint(model, FLAGS.checkpoint_path)
def _parse_data_point(s: Text) -> delay_model_pb2.DataPoint:
  """Builds a DataPoint message from its text-format (textproto) encoding.

  Args:
    s: Text-format representation of a delay_model_pb2.DataPoint.

  Returns:
    The parsed DataPoint.
  """
  data_point = delay_model_pb2.DataPoint()
  text_format.Parse(s, data_point)
  return data_point