def execute(cls, params, in_tensors, qrec: QRec, **kwargs):
    """Evaluate the node's function collection with the "float" kernel type.

    Args:
        params: Node parameters; ``input_symbols``, ``output_symbols`` and
            ``func_col`` are read from it.
        in_tensors: Input tensors, matched to ``params.input_symbols`` by
            position.
        qrec: Quantization record. When ``None`` an ``AllFloatQRec()`` is
            used instead.
        **kwargs: Only ``details`` is read. When it is not ``None`` a fresh
            ``SymbolStats`` is installed as the default ``Symbol`` control,
            and after evaluation its ``stats`` are merged into ``details``
            and the tracked results are stored under ``details['results']``.

    Returns:
        Whatever ``qrec.get_outputs(params, out_tensors, ktype="float")``
        returns.
    """
    if qrec is None:
        qrec = AllFloatQRec()

    details = kwargs.get('details')
    collecting = details is not None
    if collecting:
        # The control is installed globally and is not restored afterwards.
        current_control = SymbolStats()
        Symbol.set_default_control(current_control)
        results = {}
    else:
        current_control = None
        results = None

    in_tensors = qrec.prepare_inputs(params, in_tensors, ktype="float")

    # Bind each prepared tensor to the input symbol at the same position.
    in_vars = {}
    for position, tensor in enumerate(in_tensors):
        in_vars[params.input_symbols[position]] = tensor

    # A function collection cached on the qrec takes precedence over the
    # one held by the parameters.
    cached_col = qrec.cache.get('qfunc_col')
    func_col = params.func_col if cached_col is None else cached_col

    out_vars = func_col(
        **in_vars,
        calculate_ranges=collecting,
        track_results=results,
    )
    out_tensors = [out_vars[name] for name in params.output_symbols]

    # Truthiness (not an ``is not None`` test) decides whether the
    # collected statistics are reported back.
    if current_control:
        details.update(current_control.stats)
        details['results'] = results

    return qrec.get_outputs(params, out_tensors, ktype="float")
def _quantize(cls, params, in_qs, stats, **kwargs): force_out_qs, _ = cls.get_mult_opts(**kwargs) if stats is None or 'expression' not in stats: raise ValueError( f'no valid range information is present for {params.name}') # # expressions need a symmetric input # in_qs = cls.force_symmetric(in_qs) # if in_qs is None: # LOG.info('expression quantizer for {params.name} was not able to force input symmetric') # return None symbol_control = SymbolStats(stats['expression']) # preload the input and output quantization # This will force variables to the right scales in the expression quantizer # first the input prequant = { params.input_symbols[idx]: in_q for idx, in_q in enumerate(in_qs) } # now the output o_qs = [] for idx, sym_name in enumerate(params.output_symbols): if force_out_qs and force_out_qs[idx]: o_q = force_out_qs[idx] else: cls.check_valid_ranges(params, stats, idx=idx, dirs='out') o_q = QType.from_min_max_sq(stats['range_out'][idx]['min'], stats['range_out'][idx]['max'], dtype=np.int8) prequant[sym_name] = o_q o_qs.append(o_q) qfunc_col = params.func_col.quantize(Q15ScaledQuantization, symbol_control, quantize_inputs=False, qtypes=prequant) return QRec.scaled(in_qs=in_qs, out_qs=o_qs, qfunc_col=qfunc_col)
def _quantize(cls, params, in_qs, stats, **kwargs): _, dtype = cls.get_float_opts(**kwargs) force_out_qs = kwargs.get('force_out_qs') if stats is None or 'expression' not in stats: raise ValueError( f'no valid range information is present for {params.name}') symbol_control = SymbolStats(stats['expression']) # preload the input and output quantization # This will force variables to the right scales in the expression quantizer # first the input prequant = { params.input_symbols[idx]: in_q for idx, in_q in enumerate(in_qs) } # now the output o_qs = [] for idx, sym_name in enumerate(params.output_symbols): if force_out_qs and force_out_qs[idx]: o_q = force_out_qs[idx] else: cls.check_valid_ranges(params, stats, idx=idx, dirs='out') o_q = QType(dtype=dtype) prequant[sym_name] = o_q o_qs.append(o_q) qfunc_col = params.func_col.quantize(FloatQuantization, symbol_control, quantize_inputs=False, qtypes=prequant) return QRec.float(in_qs=in_qs, out_qs=o_qs, float_dtype=dtype, qfunc_col=qfunc_col)
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fuse every node set returned by ``self.find_sets(G)`` into one
    ``ExpressionFusionParameters`` node.

    For each node set the nodes and their internal edges are copied into a
    fragment graph, an expression node is created from it and the fragment
    is replaced in ``G``. When ``G`` already carries quantization, a
    'scaled' ``QRec`` is registered for the new node using the boundary
    quantization of the fused nodes, ``QuantizeParameters`` nodes directly
    on its outputs are removed, and the node is queued for requantization.

    Args:
        G: Graph that is modified in place.
        set_identity: When true, ``self.set_identity(G)`` is called at the
            end.
        **kwargs: Unused here.

    Returns:
        ``True`` if at least one node set was processed, else ``False``.
    """
    has_modified_graph = False
    to_quantize = []

    for node_set in self.find_sets(G):
        # Fresh statistics collector for the symbols of this expression.
        Symbol.set_default_control(SymbolStats())
        has_modified_graph = True

        in_edges, out_edges, internal_edges = group_edges(G, node_set)

        # Copy the matched nodes and the edges between them into a fragment.
        fragment = GraphView()
        for member in node_set:
            fragment.add_node(member)
        for inner_edge in internal_edges:
            fragment.add_edge(inner_edge)

        # in_edges / out_edges are keyed by (source node, source index).
        in_mapping = [
            [(edge.to_node, edge.to_idx) for edge in group]
            for group in in_edges.values()
        ]
        in_dims = [src.out_dims[src_idx] for src, src_idx in in_edges]
        out_dims = [src.out_dims[src_idx] for src, src_idx in out_edges]
        out_mapping = list(out_edges.keys())
        constant_inputs = [
            src for src, _ in in_edges
            if isinstance(src, ConstantInputParameters)
        ]

        LOG.debug(
            "inputs coming from: %s",
            ",".join(
                f"{repr(src)}:{src_idx}" for src, src_idx in in_edges))
        LOG.info(
            "fusing nodes: %s into expr_%s",
            ",".join(repr(member) for member in node_set),
            self._expr_num)

        expr = ExpressionFusionParameters(
            G.unique_name(f"expr_{self._expr_num}"),
            subgraph=fragment,
            qrecs=G.quantization,
            input_mapping=in_mapping,
            output_mapping=out_mapping,
            in_dims=in_dims,
            out_dims=out_dims,
            constant_inputs=constant_inputs)

        in_edge_mapping = list(in_edges.keys())
        out_edge_mapping = [
            [(edge.to_node, edge.to_idx) for edge in group]
            for group in out_edges.values()
        ]
        G.replace_fragment(
            fragment,
            expr,
            frag_in_edges=list(set.union(*in_edges.values())),
            frag_out_edges=list(set.union(*out_edges.values())),
            edge_in_mapping=in_edge_mapping,
            edge_out_mapping=out_edge_mapping,
            edge_class=NNEdge)

        if G.quantization:
            qrecs = G.quantization
            # Boundary quantization comes from the qrecs of the fused nodes:
            # the first consumer of each input and the producer of each output.
            in_qs = [
                qrecs[NodeId(consumers[0][0])].in_qs[consumers[0][1]]
                for consumers in in_mapping
            ]
            out_qs = [
                qrecs[NodeId(producer)].out_qs[out_idx]
                for producer, out_idx in out_mapping
            ]

            stats = Symbol.CURRENT_CONTROL.stats
            func_col = expr.func_col
            # Record the min/max of each boundary qtype against its symbol.
            for position, qtype in enumerate(in_qs):
                symbol = func_col.variables[func_col.input_names[position]]
                stats[symbol.name] = dict(
                    min=qtype.min_val, max=qtype.max_val)
            for position, qtype in enumerate(out_qs):
                symbol = func_col.variables[func_col.output_names[position]]
                stats[symbol.name] = dict(
                    min=qtype.min_val, max=qtype.max_val)

            G.quantization[NodeId(expr)] = QRec(
                in_qs=in_qs,
                out_qs=out_qs,
                expression=stats,
                ktype='scaled')

            # Delete any quantize parameters on outputs to allow the
            # quantizer to fuse them into the expression.
            for expr_out_edge in G.out_edges(expr.name):
                quant_node = expr_out_edge.to_node
                if not isinstance(quant_node, QuantizeParameters):
                    continue
                G.remove_and_reconnect(quant_node)
                if NodeId(quant_node) in G.quantization:
                    del G.quantization[NodeId(quant_node)]

            to_quantize.append(expr)

        self._expr_num += 1

    if to_quantize:
        # Requantize starting from the newly created expression nodes.
        quantizer = UnifiedQuantizer.from_quantized_graph(G)
        G.quantization = quantizer.quantize(G, start_nodes=to_quantize)

    if set_identity:
        self.set_identity(G)

    return has_modified_graph