Ejemplo n.º 1
0
  def test_save_load_checkpoint(self):
    """Round-trips a DelayModel through client.save_checkpoint/init_data.

    Builds a model holding one DataPoint per (op, bit configuration) pair,
    saves it to a temporary file, reloads it, and checks that both the model
    and the derived op -> bit-configuration index match what was written.
    """
    model = delay_model_pb2.DelayModel()
    # Maps an op name to the set of bit configurations we've run that op with.
    # Each configuration is "<result bit count>, <operand bit count>, ...".
    data_points = collections.defaultdict(set)

    # Set up some dummy data.
    ops = ["op_a", "op_b", "op_c", "op_d", "op_e"]
    bit_configs = ["3, 1, 1", "4, 1, 2", "5, 2, 1", "6, 2, 2"]
    for op in ops:
      for bit_config in bit_configs:
        data_points[op].add(bit_config)
        # Parse every comma-separated field; int() tolerates the surrounding
        # spaces. Parsing the first *field* (not the first character) keeps
        # this correct for multi-digit bit counts such as "16, 8, 8".
        result_bits, *operand_bits = (
            int(field) for field in bit_config.split(","))
        result = delay_model_pb2.DataPoint()
        result.operation.op = op
        for bits in operand_bits:
          operand = delay_model_pb2.Operation.Operand()
          operand.bit_count = bits
          result.operation.operands.append(operand)
        result.operation.bit_count = result_bits
        result.delay = 5
        model.data_points.append(result)

    # Use the temp file as a context manager so its handle is closed and the
    # file removed once the round trip is done.
    with tempfile.NamedTemporaryFile() as tf:
      client.save_checkpoint(model, tf.name)
      loaded_data_points, loaded_model = client.init_data(tf.name)

    self.assertEqual(model, loaded_model)
    # Fancy equality checking so we get clearer error messages on
    # mismatch.
    for op in ops:
      self.assertIn(op, loaded_data_points)
      loaded_op = loaded_data_points[op]
      for bit_config in bit_configs:
        self.assertIn(bit_config, loaded_op)
        self.assertIn(bit_config, data_points[op])
    self.assertEqual(data_points, loaded_data_points)
def _synthesize_op_and_make_bare_data_point(
        op: str,
        kop: str,
        op_type: str,
        operand_types: List[str],
        stub: synthesis_service_pb2_grpc.SynthesisServiceStub,
        attributes: Sequence[Tuple[str, str]] = (),
        literal_operand: Optional[int] = None) -> delay_model_pb2.DataPoint:
    """Characterizes a single operation via the synthesis server.

    Only the area and the op name of the returned data point are populated;
    no other information about the node or its operands is recorded.

    Args:
      op: Operation name used when generating the IR package, e.g. 'add'.
      kop: Operation name written into the data point, generally in kConstant
        form for use in the delay model, e.g. 'kAdd'.
      op_type: Type of the operation's result.
      operand_types: Type of each operand.
      stub: Handle to the synthesis server.
      attributes: Attributes to include in the operation mnemonic (for example
        "new_bit_count" in extend operations). Forwarded to
        generate_ir_package.
      literal_operand: Optionally, the index of an operand to replace with a
        randomly generated literal instead of a function parameter. Forwarded
        to generate_ir_package.

    Returns:
      A DataPoint whose area (filled in by _record_area) and operation.op are
      set; every other field is left at its default.
    """
    ir_package_text = op_module_generator.generate_ir_package(
        op, op_type, operand_types, attributes, literal_operand)

    # Brackets are not legal in Verilog identifiers, so derive a safe suffix
    # from the type: 'bits[32]' -> 'bits_32'.
    sanitized_type = op_type.replace('[', '_').replace(']', '')
    module_name = f'{op}_{sanitized_type}'
    generated_module = op_module_generator.generate_verilog_module(
        module_name, ir_package_text)

    # Wrap the single generated module in a top-level wrapper for synthesis.
    wrapper_name = f'{module_name}_wrapper'
    wrapped_verilog = op_module_generator.generate_parallel_module(
        [generated_module], wrapper_name)

    synth_response = _synth(stub, wrapped_verilog, wrapper_name)
    data_point = delay_model_pb2.DataPoint()
    _record_area(synth_response, data_point)
    data_point.operation.op = kop
    return data_point
def _synthesize_ir(
        stub: synthesis_service_pb2_grpc.SynthesisServiceStub,
        ir_text: str,
        op: str,
        result_bit_count: int,
        operand_bit_counts: Sequence[int] = ()) -> delay_model_pb2.DataPoint:
    """Synthesizes the given IR text and returns a data point.

    Args:
      stub: Handle to the synthesis server.
      ir_text: IR package text to convert to Verilog and synthesize.
      op: Operation name recorded in the data point.
      result_bit_count: Bit count of the operation result.
      operand_bit_counts: Bit count of each operand, in order. Defaults to no
        operands. (Previously the default was accidentally the typing object
        `Sequence[int]` itself, because `=` was written instead of `:`.)

    Returns:
      A DataPoint with operation.op, the result and operand bit counts, and
      the delay set. The delay is the period of the reported maximum
      frequency, in picoseconds, truncated to an int.
    """
    module_name = 'top'
    mod_generator_result = op_module_generator.generate_verilog_module(
        module_name, ir_text)
    response = _synth(stub, mod_generator_result.verilog_text, module_name)
    # Period in picoseconds from the max frequency in Hz.
    # NOTE(review): raises ZeroDivisionError if the server reports 0 Hz.
    ps = 1e12 / response.max_frequency_hz

    result = delay_model_pb2.DataPoint()
    result.operation.op = op
    result.operation.bit_count = result_bit_count
    for bit_count in operand_bit_counts:
        operand = result.operation.operands.add()
        operand.bit_count = bit_count
    result.delay = int(ps)
    return result
def _synthesize_ir(stub: synthesis_service_pb2_grpc.SynthesisServiceStub,
                   model: delay_model_pb2.DelayModel,
                   data_points: Dict[str, Set[str]], ir_text: str, op: str,
                   result_bit_count: int,
                   operand_bit_counts: Sequence[int]) -> None:
    """Synthesizes the given IR text and appends the data point to `model`.

    If this (op, bit-count configuration) is already recorded in
    `data_points`, nothing is synthesized. Otherwise the configuration is
    recorded, the IR is synthesized, the resulting DataPoint is appended to
    `model.data_points`, and the model is checkpointed to
    FLAGS.checkpoint_path.

    Args:
      stub: Handle to the synthesis server.
      model: Delay model that receives the new data point.
      data_points: Maps an op name to the set of configuration keys already
        characterized; updated in place.
      ir_text: IR package text to convert to Verilog and synthesize.
      op: Operation name recorded in the data point.
      result_bit_count: Bit count of the operation result.
      operand_bit_counts: Bit count of each operand, in order.
    """
    # Key format: "<result bits>, <operand bits>, ..." (e.g. "3, 1, 1"). This
    # is the format the checkpoint round-trip test expects client.init_data to
    # produce. Keys built from str(Operand(...)) (text-proto form such as
    # "bit_count: 8\n") never matched configurations restored from a
    # checkpoint, so those configurations were re-synthesized and duplicated.
    key = ', '.join(str(x) for x in [result_bit_count, *operand_bit_counts])
    seen_keys = data_points.setdefault(op, set())
    if key in seen_keys:
        return
    seen_keys.add(key)

    logging.info('Running %s with %d / %s', op, result_bit_count,
                 ', '.join([str(x) for x in operand_bit_counts]))
    module_name = 'top'
    mod_generator_result = op_module_generator.generate_verilog_module(
        module_name, ir_text)
    response = _synth(stub, mod_generator_result.verilog_text, module_name)
    logging.info('Result: %s', response)
    # Period in picoseconds from the max frequency in Hz.
    ps = 1e12 / response.max_frequency_hz

    result = delay_model_pb2.DataPoint()
    result.operation.op = op
    result.operation.bit_count = result_bit_count
    for bit_count in operand_bit_counts:
        operand = result.operation.operands.add()
        operand.bit_count = bit_count
    result.delay = int(ps)

    # Checkpoint after every run.
    model.data_points.append(result)
    save_checkpoint(model, FLAGS.checkpoint_path)
Ejemplo n.º 5
0
def _parse_data_point(s: Text) -> delay_model_pb2.DataPoint:
    """Builds a DataPoint message from its text-format serialization `s`."""
    data_point = delay_model_pb2.DataPoint()
    # text_format.Parse fills `data_point` in place and returns that message.
    return text_format.Parse(s, data_point)