def test_fixed_op_model_with_specialization(self):
    """Fixed estimator plus an OPERANDS_IDENTICAL fixed specialization.

    Checks that the base estimator yields 42 and the OPERANDS_IDENTICAL
    specialization yields 123. Also checks that the generated C++ returns
    123 behind an "all operands equal operand(0)" guard and 42 otherwise.
    """
    model_text = ('op: "kFoo" estimator { fixed: 42 } '
                  'specializations { kind: OPERANDS_IDENTICAL '
                  'estimator { fixed: 123 } }')
    probe_text = 'op: "kBar" bit_count: 123'

    op_model = delay_model.OpModel(
        text_format.Parse(model_text, delay_model_pb2.OpModel()), ())
    self.assertEqual(op_model.op, 'kFoo')

    # Base (unspecialized) estimator.
    self.assertEqual(
        op_model.estimator.operation_delay(_parse_operation(probe_text)), 42)

    # Specialized estimator, looked up by its specialization kind.
    identical_estimator = op_model.specializations[
        delay_model_pb2.SpecializationKind.OPERANDS_IDENTICAL]
    self.assertEqual(
        identical_estimator.operation_delay(_parse_operation(probe_text)), 123)

    expected_cpp = (
        ' absl::StatusOr<int64_t> FooDelay(Node* node) { '
        'if (std::all_of(node->operands().begin(), '
        'node->operands().end(), '
        '[&](Node* n) { return n == node->operand(0); })) { '
        'return 123; } return 42; } ')
    self.assertEqualIgnoringWhitespace(op_model.cpp_delay_function(),
                                       expected_cpp)
def test_regression_op_model_with_bounding_box_specialization(self):
    """Regression base estimator with a bounding-box OPERANDS_IDENTICAL specialization.

    Nine general data points (bit counts 1..9, delay 10*bits) feed the
    regression. Two OPERANDS_IDENTICAL points (bit counts 1..2, delay 2*bits)
    feed the bounding box. The generated C++ is compared ignoring whitespace
    and float values, so the regression coefficients are not pinned.
    """

    def make_point(bit_count, delay, specialization=''):
        # Build a data point for a kFoo operation; `specialization` is an
        # optional extra text-format field inside the operation message.
        return _parse_data_point(
            'operation { op: "kFoo" bit_count: %(bits)d %(spec)s} '
            'delay: %(delay)d delay_offset: 0'
            % {'bits': bit_count, 'spec': specialization, 'delay': delay})

    # Parse the model first, matching the original argument evaluation order.
    model_proto = text_format.Parse(
        'op: "kFoo" estimator { regression { expressions { factor { '
        'source: RESULT_BIT_COUNT } } } }'
        'specializations { kind: OPERANDS_IDENTICAL '
        'estimator { bounding_box { factors { source: RESULT_BIT_COUNT } } } }',
        delay_model_pb2.OpModel())
    general_points = [make_point(bits, 10 * bits) for bits in range(1, 10)]
    identical_points = [
        make_point(bits, 2 * bits, 'specialization: OPERANDS_IDENTICAL')
        for bits in range(1, 3)
    ]
    op_model = delay_model.OpModel(model_proto,
                                   general_points + identical_points)

    self.assertEqual(op_model.op, 'kFoo')

    expected_cpp = (
        ' absl::StatusOr<int64_t> FooDelay(Node* node) { '
        'if (std::all_of(node->operands().begin(), '
        'node->operands().end(), '
        '[&](Node* n) { return n == node->operand(0); })) { '
        'if (node->GetType()->GetFlatBitCount() <= 1) { return 2; } '
        'if (node->GetType()->GetFlatBitCount() <= 2) { return 4; } '
        'return absl::UnimplementedError( '
        '"Unhandled node for delay estimation: " + '
        'node->ToStringWithOperandTypes()); } '
        'return std::round( '
        '0.0 + 0.0 * static_cast<float>(node->GetType()->GetFlatBitCount()) + '
        '0.0 * std::log2(static_cast<float>(node->GetType()->GetFlatBitCount()))); } ')
    self.assertEqualIgnoringWhitespaceAndFloats(op_model.cpp_delay_function(),
                                                expected_cpp)
def test_fixed_op_model(self):
    """A lone fixed estimator returns its constant, and so does its emitted C++."""
    parsed_proto = text_format.Parse('op: "kFoo" estimator { fixed: 42 }',
                                     delay_model_pb2.OpModel())
    op_model = delay_model.OpModel(parsed_proto, ())
    self.assertEqual(op_model.op, 'kFoo')

    # The probe op name differs from the model's op; the fixed estimator's
    # result does not depend on it.
    delay = op_model.estimator.operation_delay(
        _parse_operation('op: "kBar" bit_count: 123'))
    self.assertEqual(delay, 42)

    self.assertEqualIgnoringWhitespace(
        op_model.cpp_delay_function(),
        'absl::StatusOr<int64_t> FooDelay(Node* node) { return 42; }')