def apply_scale(graph: Graph, input_node: Node, node_mean_scale_values: dict):
    """Splice a ``Mul`` by ``1 / scale`` between ``input_node`` and its output.

    Nothing happens when ``node_mean_scale_values`` has no ``'scale'`` entry,
    when that entry is ``None``, or when every scale value equals 1.
    Otherwise the edge from ``input_node`` to its output data node is removed
    and a ``Mul`` op is created that takes a new data node hanging off
    ``input_node`` plus a constant holding the reciprocal scale values, and
    writes its result into the original output data node.

    Args:
        graph: Graph that is modified in place.
        input_node: Node whose output is to be scaled.
        node_mean_scale_values: Mapping that may contain a ``'scale'``
            iterable of per-element divisors.

    Raises:
        Error: If ``input_node`` has no valid ``shape`` attribute.
    """
    scale = node_mean_scale_values.get('scale')
    if scale is None:
        return
    if all(factor == 1 for factor in scale):
        # Dividing by one everywhere is a no-op; keep the graph as it is.
        return

    consumer = input_node.out_node()
    if not input_node.has_valid('shape'):
        raise Error("Node {} has not valid shape attribute".format(input_node.id))
    source_shape = input_node.shape

    # The graph multiplies, so the divisor is stored as its reciprocal.
    reciprocal = 1 / np.array(scale)
    graph.remove_edge(input_node.id, consumer.id)

    multiplier = Mul(graph, {'name': "Mul_"})
    factor_data = Op.create_input_data_node(graph, "data_mul_", np.array(reciprocal))

    # NOTE(review): presumably pads the constant with unit dimensions so it
    # broadcasts over the non-batch/non-channel axes in NCHW - confirm against
    # Op.expand_node_shape.
    if graph.graph['layout'] == 'NCHW':
        extra_dims = len(source_shape) - 2
    else:
        extra_dims = 0
    Op.expand_node_shape(factor_data, extra_dims)

    # New data node for input_node, carrying the shape of the old output.
    detached_input = Op.create_data_node(graph, input_node, {'shape': consumer.shape})
    multiplier.create_node_with_data(inputs=[detached_input, factor_data],
                                     data_nodes=consumer)
def replace_pattern(self, graph: Graph, match: dict):
    """Replace the matched minimum node with ``-1 * Maximum(-1 * a, -1 * b)``.

    Both inputs of the matched node are multiplied by a ``-1`` constant, fed
    into a ``Maximum`` op, and the result is multiplied by ``-1`` again. The
    final ``Mul`` writes into the original node's output data node, after
    which the original node is removed from the graph.

    If both inputs already carry a value the node is left alone (constant
    propagation case).

    Args:
        graph: Graph in which the replacement ops are created.
        match: Pattern match; ``match['minimum']`` is the node to replace.
    """
    minimum = match['minimum']

    # Constant propagation case: both operands are known, nothing to rewrite.
    if all(minimum.in_node(port).value is not None for port in (0, 1)):
        return

    prefix = minimum.name
    first_minus_one = Const(graph, {'value': np.array(-1),
                                    'name': prefix + '/negate1_const'})
    second_minus_one = Const(graph, {'value': np.array(-1),
                                     'name': prefix + '/negate2_const'})
    first_negation = Mul(graph, {'name': prefix + '/negate1_'})
    second_negation = Mul(graph, {'name': prefix + '/negate2_'})
    max_op = Maximum(graph, {'name': prefix + '/Max_'})
    output_minus_one = Const(graph, {'value': np.array(-1),
                                     'name': prefix + '/negate_out_const_'})
    output_negation = Mul(graph, {'scale': -1, 'name': prefix + '/negate_out_'})

    # Build the sub-graph bottom-up: -a, -b, max(-a, -b), then negate the result.
    negated_first = first_negation.create_node_with_data(
        [minimum.in_node(0), first_minus_one.create_node_with_data()])
    negated_second = second_negation.create_node_with_data(
        [minimum.in_node(1), second_minus_one.create_node_with_data()])
    largest_negated = max_op.create_node_with_data([negated_first, negated_second])
    output_negation.create_node_with_data(
        inputs=[largest_negated, output_minus_one.create_node_with_data()],
        data_nodes=minimum.out_node())

    # Delete minimum vertex
    minimum.graph.remove_node(minimum.id)
def replace_pattern(self, graph: Graph, match: dict):
    """Insert a ``Mul`` by ``1 / scale`` right after the matched placeholder.

    The scale is read from ``graph.graph['cmd_params'].scale``; when it is
    ``None`` or equal to 1 the graph is left untouched. Otherwise the edge
    from ``match['placeholder']`` to ``match['data']`` is removed and a ``Mul``
    op is created that takes a new data node hanging off the placeholder plus
    a one-element constant ``[1 / scale]``, and writes into ``match['data']``.

    Args:
        graph: Graph that is modified in place.
        match: Pattern match with ``'placeholder'`` and ``'data'`` nodes.

    Raises:
        Error: If the placeholder has no valid ``shape`` attribute.
    """
    scale = graph.graph['cmd_params'].scale
    if scale is None:
        return
    if scale == 1:
        return

    placeholder = match['placeholder']
    assert (len(placeholder.out_nodes()))
    if not placeholder.has_valid('shape'):
        raise Error("Node {} has not valid shape attribute".format(placeholder.id))
    placeholder_shape = placeholder.shape
    placeholder_out = match['data']

    # The graph multiplies, so the divisor is stored as its reciprocal.
    reciprocal = np.array([1 / scale])

    # Disconnect input with data node
    graph.remove_edge(placeholder.id, placeholder_out.id)

    multiplier = Mul(graph, {'name': "Mul1_"})
    factor_data = Op.create_input_data_node(graph, "data_mul_scale_",
                                            np.array(reciprocal))

    # NOTE(review): presumably pads the constant with unit dimensions so it
    # broadcasts in NCHW layout - confirm against Op.expand_node_shape.
    if graph.graph['layout'] == 'NCHW':
        extra_dims = len(placeholder_shape) - 2
    else:
        extra_dims = 0
    Op.expand_node_shape(factor_data, extra_dims)

    # New data node for the placeholder, carrying the shape of the old output.
    detached_input = Op.create_data_node(graph, placeholder,
                                         {'shape': placeholder_out.shape})
    multiplier.create_node_with_data(inputs=[detached_input, factor_data],
                                     data_nodes=placeholder_out)
def replace_pattern(self, graph: Graph, match: dict):
    """Decompose the matched batch-norm node into elementwise ops.

    The node is rewritten only when its ``data_format`` is ``b'NHWC'``, it has
    exactly five inputs, inputs 1 (scale) and 2 (offset) are 1-D constants,
    and inputs 0 (input), 3 (mean) and 4 (variance) carry no value. In that
    case the following sub-graph is built and wired to the node's output data
    node, and the original node is removed::

        ((input + mean * -1) * Pow(variance + eps, -0.5)) * scale + offset

    Args:
        graph: Graph in which the replacement ops are created.
        match: Pattern match; ``match['op']`` is the node to decompose.
    """
    bn = match['op']
    if bn.data_format != b'NHWC' or len(bn.in_nodes()) != 5:
        return

    # (port, must hold a constant value), checked in port order.
    constness = (
        (0, False),  # input
        (1, True),   # scale
        (2, True),   # offset
        (3, False),  # mean
        (4, False),  # variance
    )
    for port, must_be_constant in constness:
        if (bn.in_node(port).value is not None) != must_be_constant:
            return
    if any(bn.in_node(port).value.ndim != 1 for port in (1, 2)):
        return

    prefix = bn.name
    gamma_mul = Mul(graph, {'name': prefix + '/scale_mul_'})
    beta_add = Add(graph, {'name': prefix + '/shift_add_'})
    center_add = Add(graph, {'name': prefix + '/mean_add_'})
    normalize_mul = Mul(graph, {'name': prefix + '/variance_mul_'})
    minus_one = Const(graph, {'value': np.array(-1),
                              'name': prefix + '/mean_negate_'})
    mean_negation = Mul(graph, {'name': prefix + '/mean_negate_'})

    # input - mean, expressed as input + (mean * -1)
    negated_mean = mean_negation.create_node_with_data(
        [bn.in_node(3), minus_one.create_node_with_data()])
    centered = center_add.create_node_with_data([bn.in_node(0), negated_mean])

    epsilon = Const(graph, {'value': bn.eps,
                            'name': prefix + '/variance_denom_shift_const_'})
    exponent = Const(graph, {'value': -0.5,
                             'name': prefix + '/variance_denom_power_const_'})
    eps_add = Add(graph, {'name': prefix + '/variance_denom_shift_'})
    inverse_root = Pow(graph, {'name': prefix + '/variance_denom_power_'})

    # (variance + eps) raised to -0.5, then applied to the centered input
    shifted_variance = eps_add.create_node_with_data(
        [bn.in_node(4), epsilon.create_node_with_data()])
    inverse_std = inverse_root.create_node_with_data(
        [shifted_variance, exponent.create_node_with_data()])
    normalized = normalize_mul.create_node_with_data([centered, inverse_std])

    # Apply scale, then offset, writing into the original output data node.
    scaled = gamma_mul.create_node_with_data([normalized, bn.in_node(1)])
    beta_add.create_node_with_data([scaled, bn.in_node(2)],
                                   data_nodes=bn.out_node())

    bn.graph.remove_node(bn.id)