def get_reversed_fusions() -> Set[Tuple[NSFusionType, int]]:
    """
    Set of potential fusions, in reverse order.  The order is reversed
    to match how fusion patterns are defined in quantization code.

    Fusion format:
    ((fusion_op_0, fusion_op_1), base_op_idx)

    Where base_op_idx is the idx of the op we should use to match other related
    ops. Note: base_op_idx is specified in non-reverse order, i.e. a
    base_op_idx of 0 represents the first op in regular (non-reverse) order,
    1 represents the second op, etc.
    """
    results: Set[Tuple[NSFusionType, int]] = set()

    # Possible syntaxes:
    # * single op: torch.nn.Conv2d
    # * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d)
    # For fusions, we only care about patterns composed of multiple ops.
    # TODO(future PR): allow customizations from default patterns.
    all_quant_patterns = get_default_quant_patterns()
    default_base_op_idx = 0
    for quant_pattern, _quant_handler in all_quant_patterns.items():
        # only tuples represent fusions of multiple ops; single-op
        # patterns are ignored here
        if isinstance(quant_pattern, tuple):
            results.add((quant_pattern, default_base_op_idx))  # type: ignore[arg-type]

    # After this point, results contains values such as
    # [..., ((torch.nn.ReLU, torch.nn.Conv2d), 0), ...]

    # Patterns for matching fp16 emulation are not specified in the quantization
    # fusion mappings.  For now, define them here.
    fp16_em_base_op_idx = 1
    patterns_to_add = [
        # linear-relu fp16 emulation:
        # fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
        ((("to", torch.float16), F.relu, F.linear, "dequantize"),
         fp16_em_base_op_idx,),
    ]
    results.update(patterns_to_add)  # type: ignore[arg-type]
    return results
def test_op_relationship_mapping(self):
    """
    Tests that the mapping of op relationships is complete, i.e. that every
    fp32 -> quantized op pair known to the quantization mappings (static
    module, static functional, and dynamic module mappings), as well as every
    base op of each default quant pattern, is covered by the Numeric Suite's
    relatedness mapping.
    """
    base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = \
        get_type_a_related_to_b(base_name_to_sets_of_related_ops)

    # 1. check static quant module mappings
    static_quant_mod_mappings = get_default_static_quant_module_mappings()
    for fp32_type, int8_type in static_quant_mod_mappings.items():
        # skip quants and dequants, for the purposes of Numerical Suite
        types_to_skip = (
            torch.quantization.QuantStub,
            torch.quantization.DeQuantStub,
            nnq.FloatFunctional,
        )
        if fp32_type in types_to_skip:
            continue

        # verify relatedness
        in_type_a_related_to_b = \
            (fp32_type, int8_type) in type_a_related_to_b
        self.assertTrue(
            in_type_a_related_to_b,
            f"{fp32_type} and {int8_type} need a relationship mapping")

    # 2. check static quant op mappings
    static_quant_fun_mappings = get_default_float_to_quantized_operator_mappings()
    for fp32_type, int8_type in static_quant_fun_mappings.items():
        # verify relatedness
        in_type_a_related_to_b = \
            (fp32_type, int8_type) in type_a_related_to_b
        self.assertTrue(
            in_type_a_related_to_b,
            f"{fp32_type} and {int8_type} need a relationship mapping")

    # 3. check dynamic quant mappings
    dynamic_quant_mappings = get_default_dynamic_quant_module_mappings()
    for fp32_type, int8_type in dynamic_quant_mappings.items():
        # TODO(future PR): enable correct weight extraction for these
        # and remove from this list.
        types_to_skip = (
            nn.GRUCell,
            nn.GRU,
            nn.LSTMCell,
            nn.RNNCell,
        )
        if fp32_type in types_to_skip:
            continue
        # verify relatedness
        in_type_a_related_to_b = \
            (fp32_type, int8_type) in type_a_related_to_b
        self.assertTrue(
            in_type_a_related_to_b,
            f"{fp32_type} and {int8_type} need a relationship mapping")

    # 4. go through the ops mapped to each QuantizeHandler type, and verify
    # correctness.
    def _op_in_base_sets_of_related_ops(op):
        # returns True if `op` appears in any of the sets of related ops
        for name, ops in base_name_to_sets_of_related_ops.items():
            if op in ops:
                return True
        return False

    default_quant_patterns = get_default_quant_patterns()
    for pattern, qhandler_cls in default_quant_patterns.items():
        # for fused patterns (tuples, stored in reverse order), the base op
        # is the last element; string patterns are skipped for now
        base_op = None
        if isinstance(pattern, tuple):
            base_op = pattern[-1]
        elif isinstance(pattern, str):
            # TODO(future PR): add handling for these
            continue
        else:
            base_op = pattern

        # handler types whose every mapped op has a quantized equivalent,
        # so each base op must appear in the relatedness sets
        qhandler_cls_all_ops_quantizeable = [
            qp.CatQuantizeHandler,
            qp.ConvReluQuantizeHandler,
            qp.LinearReLUQuantizeHandler,
            qp.BatchNormQuantizeHandler,
            qp.EmbeddingQuantizeHandler,
            qp.RNNDynamicQuantizeHandler,
            qp.ELUQuantizeHandler,
        ]
        # handler types whose ops keep the same signature for fp32 and
        # quantized tensors, so no relatedness mapping is needed
        qhandler_cls_quant_op_same_signature = [
            qp.FixedQParamsOpQuantizeHandler,
            qp.CopyNodeQuantizeHandler,
        ]

        if qhandler_cls == qp.BinaryOpQuantizeHandler:
            # these ops do not have quantized equivalents
            ops_to_skip = [
                torch.bmm,
                torch.sum,
                torch.div,
                torch.sub,
                operator.truediv,
                operator.sub
            ]
            if base_op in ops_to_skip:
                continue
            self.assertTrue(
                _op_in_base_sets_of_related_ops(base_op),
                f"{base_op} not in sets of related ops")
        elif qhandler_cls == qp.RNNDynamicQuantizeHandler:
            # TODO(future PR): add support for all classes in
            # RNNDynamicQuantizeHandler
            pass
        elif qhandler_cls == qp.DefaultNodeQuantizeHandler:
            # these ops do not have a relatedness mapping yet
            ops_to_skip = [
                torch.nn.SiLU,
                torch.nn.functional.silu,
            ]
            if base_op in ops_to_skip:
                continue
            self.assertTrue(
                _op_in_base_sets_of_related_ops(base_op),
                f"{base_op} not in sets of related ops")
        elif qhandler_cls in qhandler_cls_quant_op_same_signature:
            # these ops use the same op signature for fp32 and quantized
            # tensors
            pass
        elif qhandler_cls in qhandler_cls_all_ops_quantizeable:
            self.assertTrue(
                _op_in_base_sets_of_related_ops(base_op),
                f"{base_op} not in sets of related ops")
        else:
            # NOTE(review): "handing" in the message below looks like a typo
            # for "handling" (runtime string, left unchanged here)
            raise AssertionError(
                f"handing for {qhandler_cls} not implemented")
def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]:
    """
    List of potential fusions, in reverse order.  The order is reversed
    to match how fusion patterns are defined in quantization code.

    Fusion format:
    ((fusion_op_0, fusion_op_1), base_op_idx)

    Where base_op_idx is the idx of the op we should use to match other related
    ops. Note: base_op_idx is specified in non-reverse order, i.e. a
    base_op_idx of 0 represents the first op in regular (non-reverse) order,
    1 represents the second op, etc.
    """
    results: List[Tuple[NSFusionType, int]] = []

    # Possible syntaxes:
    # * single op: torch.nn.Conv2d
    # * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d)
    # For fusions, we only care about patterns composed of multiple ops.
    # TODO(future PR): allow customizations from default patterns.
    all_quant_patterns = get_default_quant_patterns()

    default_base_op_idx = 0
    for quant_pattern, _quant_handler in all_quant_patterns.items():
        # Only patterns of multiple ops are fusions; ignore
        # patterns which contain a single op (they get matched
        # without caring about fusions).
        if isinstance(quant_pattern, tuple):
            results.append((quant_pattern, default_base_op_idx))  # type: ignore[arg-type]
            pattern_tail = quant_pattern
        else:
            pattern_tail = (quant_pattern,)

        # For each pattern (single-op or multi-op), also add variants with an
        # observer or fake quant prepended in reversed order (i.e. attached
        # after the pattern's last op in regular order).
        # TODO(future PR): if needed, implement matching for a node
        # having multiple output observers.
        for cls in (ObserverBase, FakeQuantizeBase):
            results.append(((cls, *pattern_tail), default_base_op_idx))  # type: ignore[arg-type]

    # After this point, results contains values such as
    # [..., ((torch.nn.ReLU, torch.nn.Conv2d), 0), ...]

    # Patterns for matching fp16 emulation are not specified in the quantization
    # fusion mappings.  For now, define them here.
    fp16_em_base_op_idx = 1
    patterns_to_add = [
        # linear-relu fp16 emulation:
        # fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
        ((("to", torch.float16), F.relu, F.linear, "dequantize"),
         fp16_em_base_op_idx,),
        # Conv-BN fusion (this happens outside of quantization patterns,
        # which is why it is defined separately here).
        ((nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
        ((nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
        ((nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
    ]
    for p in patterns_to_add:
        results.append(p)  # type: ignore[arg-type]
        # Also add observer / fake quant variants of each manual pattern,
        # mirroring the loop over the default quant patterns above.
        for cls in (ObserverBase, FakeQuantizeBase):
            results.append(((cls, *p[0]), p[1]))  # type: ignore[arg-type]

    return results