def visit(self, op):
    x, y = op.args
    reduction_axes = op.reduction_axes
    out_axes = op.axes
    if len(reduction_axes) == 0:
        # TODO: this is a weird case, should we really support it?
        d = make_axis(1)
        reduction_axes = make_axes((d,))
        x = broadcast(x, x.axes + reduction_axes)
        y = broadcast(y, reduction_axes + y.axes)

    if x.is_scalar:
        x, y = y, x

    if y.is_scalar:
        if x.is_scalar:
            out = x.scalar_op * y.scalar_op
            if len(reduction_axes) > 0:
                out = out * reduction_axes.size
            out = broadcast(out, op.axes)
        else:
            out = Sum(x, reduction_axes) * y.scalar_op
            out = broadcast(out, op.axes)
    else:
        # move reduction_axes to end
        x = axes_with_order(x, (x.axes - reduction_axes) + reduction_axes)
        # move reduction axes to front
        y = axes_with_order(y, reduction_axes + (y.axes - reduction_axes))

        # flatten non-reduction axes together and reduction axes together
        x = flatten_at(x, len(x.axes) - len(reduction_axes))
        y = flatten_at(y, len(reduction_axes))

        if len(out_axes) == 0:
            out = DotLowDimension(x, y, axes=())
        elif len(x.axes) == 1:
            y = Transpose(y)
            out = DotLowDimension(y, x, axes=y.axes[0])
        elif len(y.axes) == 1:
            out = DotLowDimension(x, y, axes=x.axes[0])
        else:
            out = DotLowDimension(x, y,
                                  axes=[op.x_out_axes.flatten(True),
                                        op.y_out_axes.flatten(True)])

        out = unflatten(out)
        out = ReorderAxes(out, out_axes)

    self.replace_op(op, out)
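# The general branch of visit() above lowers an arbitrary tensor contraction
# to a single 2D matmul: reduction axes are moved to the end of x and the
# front of y, each operand is flattened to 2D, and the result is unflattened
# and reordered. Below is a minimal NumPy sketch of the same idea, assuming
# the axes are already in that order; `dot_via_2d` is a hypothetical
# illustration, not part of this pass.
import numpy as np

def dot_via_2d(x, y, n_reduction_axes):
    """Contract the last n_reduction_axes of x with the first n of y."""
    out_x = x.shape[:x.ndim - n_reduction_axes]    # kept axes of x
    out_y = y.shape[n_reduction_axes:]             # kept axes of y
    # flattened reduction size; empty product is 1, which mirrors the
    # dummy-axis trick used above when there are no reduction axes
    k = int(np.prod(x.shape[x.ndim - n_reduction_axes:]))
    x2d = x.reshape(-1, k)                         # flatten_at analogue for x
    y2d = y.reshape(k, -1)                         # flatten_at analogue for y
    return (x2d @ y2d).reshape(out_x + out_y)      # unflatten analogue

# e.g. dot_via_2d(np.ones((2, 3, 4)), np.ones((4, 5)), 1).shape == (2, 3, 5)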
def construct_batchnorm_bprop_pattern(self):
    """
    Generate a graph op that represents the pattern for the batchnorm
    backprop operation:

        dgamma = np.sum(delta * xhat)
        dbeta = np.sum(delta)
        dx = gamma_scale * (delta - (xhat * dgamma + dbeta) / m)

    This method only generates the pattern for dx.

    Returns:
        Single pattern that matches the batchnorm bprop op.
    """
    self.batchnorm_bprop_input_tensor = "input_tensor"
    self.batchnorm_bprop_delta = "delta"
    self.batchnorm_bprop_gamma_label = "gamma"
    self.batchnorm_bprop_var_label = "var"
    self.batchnorm_bprop_ivar_label = "ivar"
    self.batchnorm_bprop_xmu1_label = "xmu1"
    self.batchnorm_bprop_xmu2_label = "xmu2"
    self.batchnorm_bprop_negative_inverse_sqrtvar = "negative_inverse_sqrtvar"
    self.batchnorm_bprop_inverse_sqrtvar = "inverse_sqrtvar"
    self.batchnorm_bprop_sqrtvar_label = "sqrtvar"
    self.batchnorm_bprop_sqrsum = "sqrsum"
    self.batchnorm_bprop_mean_1 = "mean_1"
    self.batchnorm_bprop_mean_2 = "mean_2"
    self.batchnorm_bprop_input_sum = "input_sum"

    # bind the ops to their labels
    input_tensor = PatternLabelOp(self.batchnorm_bprop_input_tensor,
                                  (lambda op: isinstance(op, Flatten)))
    var = PatternLabelOp(self.batchnorm_bprop_var_label,
                         (lambda op: isinstance(op, Divide)))
    gamma = PatternLabelOp(self.batchnorm_bprop_gamma_label,
                           (lambda op: isinstance(op, BroadcastOp)))
    delta = PatternLabelOp(self.batchnorm_bprop_delta,
                           (lambda op: isinstance(op, Flatten)))
    xmu1 = PatternLabelOp(self.batchnorm_bprop_xmu1_label,
                          (lambda op: isinstance(op, Subtract)))
    xmu2 = PatternLabelOp(self.batchnorm_bprop_xmu2_label,
                          (lambda op: isinstance(op, Subtract)))
    ivar = PatternLabelOp(self.batchnorm_bprop_ivar_label,
                          (lambda op: isinstance(op, BroadcastOp)))
    negative_inverse_sqrtvar = PatternLabelOp(
        self.batchnorm_bprop_negative_inverse_sqrtvar,
        (lambda op: isinstance(op, NegativeOp)))
    inverse_sqrtvar = PatternLabelOp(
        self.batchnorm_bprop_inverse_sqrtvar,
        (lambda op: isinstance(op, ReciprocalOp)))
    sqrtvar = PatternLabelOp(self.batchnorm_bprop_sqrtvar_label,
                             (lambda op: isinstance(op, SqrtOp)))
    sqrsum = PatternLabelOp(self.batchnorm_bprop_sqrsum,
                            (lambda op: isinstance(op, Sum)))
    mean_1 = PatternLabelOp(self.batchnorm_bprop_mean_1,
                            (lambda op: isinstance(op, Divide)))
    mean_2 = PatternLabelOp(self.batchnorm_bprop_mean_2,
                            (lambda op: isinstance(op, Divide)))
    input_sum = PatternLabelOp(self.batchnorm_bprop_input_sum,
                               (lambda op: isinstance(op, Sum)))

    constant_point_5 = ng.constant(0.5)
    constant_point_5_w_broadcast = ng.PatternSkipOp(
        constant_point_5, lambda op: isinstance(op, BroadcastOp))
    constant_two = ng.constant(2)
    constant_two_w_broadcast = ng.PatternSkipOp(
        constant_two, lambda op: isinstance(op, BroadcastOp))

    # construct the pattern
    dxhat = Multiply(gamma, delta)
    # divar = np.sum(dxhat * xmu, axis=0)
    divar = Sum(Multiply(dxhat, xmu1))
    # dxmu1 = dxhat * ivar
    dxmu1 = Multiply(dxhat, ivar)
    # dsqrtvar = -1. / (sqrtvar ** 2) * divar
    dsqrtvar = Multiply(Multiply(inverse_sqrtvar, negative_inverse_sqrtvar),
                        divar)
    # dvar = 0.5 * 1. / np.sqrt(var + eps) * dsqrtvar
    dvar = Divide(Multiply(dsqrtvar, constant_point_5_w_broadcast), sqrtvar)
    # dsq = 1. / N * np.ones((N, D)) * dvar
    dsq = Divide(Multiply(dvar, var), sqrsum)
    dsq_w_broadcast = ng.PatternSkipOp(
        dsq, (lambda op: isinstance(op, BroadcastOp)))
    # dxmu2 = 2 * xmu * dsq
    dxmu2 = Multiply(xmu2,
                     Multiply(constant_two_w_broadcast, dsq_w_broadcast))

    # dx1 = dxmu1 + dxmu2
    # dmu = -1 * np.sum(dxmu1 + dxmu2, axis=0)
    # dx2 = 1. / N * np.ones((N, D)) * dmu
    # dx = dx1 + dx2
    dxmu2_mul = Multiply(Sum(ng.negative(dxmu2)), mean_2)
    dxmu2_div = Divide(dxmu2_mul, input_sum)
    dxmu2_div_w_broadcast = ng.PatternSkipOp(
        dxmu2_div, (lambda op: isinstance(op, BroadcastOp)))
    dxmu2_div_plus_dxmu2 = Add(dxmu2_div_w_broadcast, dxmu2)
    dx1 = Add(dxmu1, dxmu2_div_plus_dxmu2)
    dxmu1_mul = Multiply(Sum(ng.negative(dxmu1)), mean_1)
    dxmu1_div = Divide(dxmu1_mul, Sum(input_tensor))
    dxmu1_div_w_broadcast = ng.PatternSkipOp(
        dxmu1_div, (lambda op: isinstance(op, BroadcastOp)))
    dx = Add(dxmu1_div_w_broadcast, dx1)
    return dx
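# For reference, the dx computation that the pattern above is intended to
# match, written as plain NumPy in the same staged form as the inline
# comments. A sketch only: the function name, the eps handling, and the
# recomputation of the forward statistics are illustrative assumptions,
# not taken from the pass.
import numpy as np

def batchnorm_bprop_dx_reference(delta, x, gamma, eps=1e-5):
    # forward statistics over the batch axis (axis 0)
    N = x.shape[0]
    xmu = x - np.mean(x, axis=0)
    var = np.var(x, axis=0)
    sqrtvar = np.sqrt(var + eps)
    ivar = 1.0 / sqrtvar

    # staged backprop, mirroring the pattern graph
    dxhat = gamma * delta
    divar = np.sum(dxhat * xmu, axis=0)
    dxmu1 = dxhat * ivar
    dsqrtvar = -1.0 / (sqrtvar ** 2) * divar
    dvar = 0.5 * (1.0 / sqrtvar) * dsqrtvar
    dsq = (1.0 / N) * np.ones_like(x) * dvar
    dxmu2 = 2.0 * xmu * dsq
    dx1 = dxmu1 + dxmu2
    dmu = -1.0 * np.sum(dxmu1 + dxmu2, axis=0)
    dx2 = (1.0 / N) * np.ones_like(x) * dmu
    return dx1 + dx2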