def fuse_conv_and_bias_callback(self, op, label_map_op_list):
    """
    Callback function that handles fusion for Conv + bias pattern.

    For each match whose map-roles argument is a ``ConvolutionOp``, a new
    ``ConvolutionOp`` that also receives the bias is created, the ops sitting
    between the convolution and the matched op are re-created on top of it,
    and both the matched op and the old convolution are replaced.

    Arguments:
        op: Not read by this callback; every match supplies its own op.
        label_map_op_list: Iterable of ``(label_map, op)`` pairs, where
            ``label_map`` is indexed by ``self.map_roles_label`` and
            ``self.conv_bias_label``.
    """
    for label_map, add_op in label_map_op_list:
        old_conv = self.op_arg(label_map[self.map_roles_label], 0)
        bias = label_map[self.conv_bias_label]

        # Matches that are not rooted in a convolution are left untouched.
        if not isinstance(old_conv, ConvolutionOp):
            continue

        fused_conv = ConvolutionOp(old_conv.conv_params,
                                   self.op_arg(old_conv, 0),
                                   self.op_arg(old_conv, 1),
                                   bias,
                                   axes=old_conv.axes)
        # Maps every original op to the op re-created on top of the fused conv.
        rebuilt = {old_conv: fused_conv}

        # Walk first arguments from the add 'op' back to the convolution,
        # collecting the ops that lie in between (nearest to 'op' first).
        chain = []
        cursor = self.op_arg(add_op, 0)
        while cursor != old_conv:
            chain.append(cursor)
            cursor = self.op_arg(cursor, 0)

        # Re-create the collected ops starting from the one closest to the
        # convolution, so each op's new argument is already in `rebuilt`.
        for stale in reversed(chain):
            source = rebuilt[self.op_arg(stale, 0)]
            if isinstance(stale, MapRolesOp):
                rebuilt[stale] = MapRolesOp(source, stale.axes_map)
            elif isinstance(stale, TensorSliceOp):
                rebuilt[stale] = TensorSliceOp(source, stale.slices, stale.axes)
            elif isinstance(stale, ReorderAxes):
                rebuilt[stale] = ReorderAxes(source, stale.axes)
            # NOTE: any other op type gets no entry in `rebuilt`, so a later
            # lookup that depends on it raises KeyError.

        self.replace_op(add_op, rebuilt[self.op_arg(add_op, 0)])
        self.replace_op(old_conv, fused_conv)
def visit(self, op):
    """
    Replace a dot op that carries separate x/y reduction axes.

    Scalar operands are turned into an element-wise product (with a ``Sum``
    over the reduction axes when only one operand is scalar). Otherwise both
    operands are reordered and flattened and one of ``DotOneDimensional``,
    ``DotTwoByOne`` or ``DotTwoDimensional`` is emitted, followed by
    ``unflatten`` and ``ReorderAxes`` back to the op's axes.

    Arguments:
        op: The op to replace. It must provide exactly two ``args`` plus
            ``x_reduction_axes``, ``y_reduction_axes``, ``axes``,
            ``x_out_axes`` and ``y_out_axes``.
    """
    lhs, rhs = op.args
    lhs_reduce = op.x_reduction_axes
    rhs_reduce = op.y_reduction_axes
    result_axes = op.axes

    if len(lhs_reduce) == 0:
        # Nothing to reduce over: add a length-1 axis shared by both operands.
        unit_axis = make_axis(1)
        lhs_reduce = make_axes((unit_axis, ))
        rhs_reduce = lhs_reduce
        lhs = broadcast(lhs, lhs.axes + lhs_reduce)
        rhs = broadcast(rhs, rhs_reduce + rhs.axes)

    # Keep a scalar operand, if there is one, on the right-hand side.
    if lhs.is_scalar:
        lhs, rhs = rhs, lhs

    if rhs.is_scalar:
        if lhs.is_scalar:
            result = lhs.scalar_op * rhs.scalar_op
            if len(lhs_reduce) > 0:
                result = result * lhs_reduce.size
            result = broadcast(result, op.axes)
        else:
            result = Sum(lhs, lhs_reduce) * rhs.scalar_op
        # Applied on both paths above (a second time for scalar * scalar).
        result = broadcast(result, op.axes)
        self.replace_op(op, result)
        return

    # Reduction axes go last for the left operand, first for the right one.
    lhs = axes_with_order(lhs, (lhs.axes - lhs_reduce) + lhs_reduce)
    rhs = axes_with_order(rhs, rhs_reduce + (rhs.axes - rhs_reduce))

    # Split each operand's axes at the reduction / non-reduction boundary.
    lhs = flatten_at(lhs, len(lhs.axes) - len(lhs_reduce))
    rhs = flatten_at(rhs, len(rhs_reduce))

    if len(result_axes) == 0:
        result = DotOneDimensional(lhs, rhs, axes=())
    elif len(lhs.axes) == 1:
        rhs = Transpose(rhs)
        result = DotTwoByOne(rhs, lhs, axes=rhs.axes[0])
    elif len(rhs.axes) == 1:
        result = DotTwoByOne(lhs, rhs, axes=lhs.axes[0])
    else:
        flat_out_axes = [op.x_out_axes.flatten(), op.y_out_axes.flatten()]
        result = DotTwoDimensional(lhs, rhs, axes=(flat_out_axes))

    result = unflatten(result)
    result = ReorderAxes(result, result_axes)
    self.replace_op(op, result)
def visit(self, op):
    """
    Replace a dot op that carries a single set of reduction axes.

    Scalar operands are turned into an element-wise product (with a ``Sum``
    over the reduction axes when only one operand is scalar). Otherwise both
    operands are reordered and flattened and a ``DotLowDimension`` op is
    emitted, followed by ``unflatten`` and ``ReorderAxes`` back to the op's
    axes.

    Arguments:
        op: The op to replace. It must provide exactly two ``args`` plus
            ``reduction_axes``, ``axes``, ``x_out_axes`` and ``y_out_axes``.
    """
    lhs, rhs = op.args
    reduce_axes = op.reduction_axes
    result_axes = op.axes

    if len(reduce_axes) == 0:
        # TODO: this is a weird case, should we really support it?
        # Nothing to reduce over: add a length-1 axis to both operands.
        unit_axis = make_axis(1)
        reduce_axes = make_axes((unit_axis, ))
        lhs = broadcast(lhs, lhs.axes + reduce_axes)
        rhs = broadcast(rhs, reduce_axes + rhs.axes)

    # Keep a scalar operand, if there is one, on the right-hand side.
    if lhs.is_scalar:
        lhs, rhs = rhs, lhs

    if rhs.is_scalar:
        if lhs.is_scalar:
            result = lhs.scalar_op * rhs.scalar_op
            if len(reduce_axes) > 0:
                result = result * reduce_axes.size
            result = broadcast(result, op.axes)
        else:
            result = Sum(lhs, reduce_axes) * rhs.scalar_op
        # Applied on both paths above (a second time for scalar * scalar).
        result = broadcast(result, op.axes)
        self.replace_op(op, result)
        return

    # Left operand: reduction axes moved to the end.
    lhs_kept_axes = lhs.axes - reduce_axes
    lhs = axes_with_order(lhs, lhs_kept_axes + reduce_axes)
    # Right operand: reduction axes moved to the front.
    rhs_kept_axes = rhs.axes - reduce_axes
    rhs = axes_with_order(rhs, reduce_axes + rhs_kept_axes)

    # Flatten the non-reduction axes together and the reduction axes together.
    lhs = flatten_at(lhs, len(lhs.axes) - len(reduce_axes))
    rhs = flatten_at(rhs, len(reduce_axes))

    if len(result_axes) == 0:
        result = DotLowDimension(lhs, rhs, axes=())
    elif len(lhs.axes) == 1:
        rhs = Transpose(rhs)
        result = DotLowDimension(rhs, lhs, axes=rhs.axes[0])
    elif len(rhs.axes) == 1:
        result = DotLowDimension(lhs, rhs, axes=lhs.axes[0])
    else:
        flat_out_axes = [
            op.x_out_axes.flatten(True),
            op.y_out_axes.flatten(True),
        ]
        result = DotLowDimension(lhs, rhs, axes=(flat_out_axes))

    result = unflatten(result)
    result = ReorderAxes(result, result_axes)
    self.replace_op(op, result)