def construct_batchnorm_fprop_pattern(self):
    """
    Generate graph op that represents a pattern for batchnorm fprop operation.
    self.gamma * ((in_obj - xmean) * ng.reciprocal(ng.sqrt(xvar + self.eps))) + self.beta
    Returns:
        Single pattern that matches batchnorm fprop op
    """
    self.batchnorm_fprop_input_tensor_label = "in_obj"
    self.batchnorm_fprop_gamma_label = "gamma"
    self.batchnorm_fprop_beta_label = "beta"
    self.batchnorm_fprop_variance_label = "variance"
    self.batchnorm_fprop_epsilon_label = "epsilon"
    self.batchnorm_fprop_mean_label = "mean"

    # bind each label to the op it should capture in the match dict
    in_obj = PatternLabelOp(self.batchnorm_fprop_input_tensor_label,
                            (lambda op: isinstance(op, ContiguousOp)))
    flatten_tensor = PatternSkipOp(in_obj,
                                   (lambda op: isinstance(op, Flatten)))
    gamma = PatternLabelOp(self.batchnorm_fprop_gamma_label,
                           (lambda op: isinstance(op, BroadcastOp)))
    beta = PatternLabelOp(self.batchnorm_fprop_beta_label,
                          (lambda op: isinstance(op, BroadcastOp)))
    variance = PatternLabelOp(self.batchnorm_fprop_variance_label,
                              (lambda op: isinstance(op, Divide)))
    epsilon = PatternLabelOp(self.batchnorm_fprop_epsilon_label,
                             (lambda op: isinstance(op, BroadcastOp)))
    mean = PatternLabelOp(self.batchnorm_fprop_mean_label,
                          (lambda op: isinstance(op, Divide)))

    # construct the fprop batchnorm pattern matching the computation graph
    # ng.sqrt(xvar + self.eps)
    SqrtofVarianceAndEps = ng.sqrt(ng.add(variance, epsilon))
    # ng.reciprocal(ng.sqrt(xvar + self.eps))
    reciprocal_op = ng.reciprocal(SqrtofVarianceAndEps)
    reciprocal_op_w_broadcast = ng.PatternSkipOp(
        reciprocal_op, lambda op: isinstance(op, BroadcastOp))
    mean_bcast = ng.PatternSkipOp(mean, lambda op: isinstance(op, BroadcastOp))
    # (in_obj - xmean) * ng.reciprocal(ng.sqrt(xvar + self.eps))
    mul_op_1 = ng.multiply(ng.subtract(flatten_tensor, mean_bcast),
                           reciprocal_op_w_broadcast)
    # self.gamma * ((in_obj - xmean) * ng.reciprocal(ng.sqrt(xvar + self.eps)))
    MultiplyGamma = ng.multiply(mul_op_1, gamma)
    # self.gamma * ((in_obj - xmean) * ng.reciprocal(ng.sqrt(xvar + self.eps))) + self.beta
    AddBeta = ng.Unflatten(ng.Add(MultiplyGamma, beta))
    return AddBeta
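
# Illustrative NumPy sketch of the computation the fprop pattern above is meant to
# match, shown only to make the formula in the docstring concrete. The helper name
# `batchnorm_fprop_reference` and the (N, D) input layout are assumptions made for
# this example; it is not part of the fusion pass itself.
def batchnorm_fprop_reference(x, gamma, beta, eps=1e-3):
    import numpy as np
    xmean = np.mean(x, axis=0)
    xvar = np.var(x, axis=0)
    # gamma * ((x - xmean) * 1 / sqrt(xvar + eps)) + beta
    xhat = (x - xmean) * np.reciprocal(np.sqrt(xvar + eps))
    return gamma * xhat + beta
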
def construct_relu_bprop_pattern(self):
    """
    Generate graph op that represents a pattern for Relu backprop operation.
    delta * greater(x, 0) + delta * slope * less(x, 0)
    Returns:
        Single pattern that matches Relu bprop op
    """
    # We want to match x tensor, slope and delta for Relu.
    self.relu_bwd_slope_label = "S"
    self.relu_bwd_x_label = "X"
    self.relu_bwd_delta_label = "D"

    # construct 1st operand of Add
    zero = ng.constant(0)
    zero_w_broadcast = ng.PatternSkipOp(
        zero, (lambda op: isinstance(op, BroadcastOp)))
    x = ng.PatternLabelOp(
        self.relu_bwd_x_label,
        (lambda op: not op.is_scalar))  # X is not scalar.
    greater_op = Greater(x, zero_w_broadcast)
    delta = PatternLabelOp(
        self.relu_bwd_delta_label,
        (lambda op: not op.is_scalar))  # delta is not scalar.
    mul_greater_delta_op = Multiply(greater_op, delta)

    # Construct 2nd operand of Add
    # We bind slope op to S only if it is scalar.
    slope = PatternLabelOp(self.relu_bwd_slope_label,
                           (lambda op: op.is_scalar))
    less_op = Less(x, zero_w_broadcast)
    mul_slope_delta_op = Multiply(slope, delta)
    mul_slope_delta_less_op = Multiply(less_op, mul_slope_delta_op)

    add_op = Add(mul_greater_delta_op, mul_slope_delta_less_op)
    return add_op
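
# Illustrative NumPy sketch of the leaky-Relu backprop expression this pattern
# matches: delta * greater(x, 0) + delta * slope * less(x, 0). The helper name
# `relu_bprop_reference` is hypothetical and exists only to clarify the formula;
# with slope == 0 it reduces to the plain Relu gradient delta * (x > 0).
def relu_bprop_reference(x, delta, slope):
    import numpy as np
    return delta * np.greater(x, 0) + delta * slope * np.less(x, 0)
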
def construct_batchnorm_bprop_pattern(self):
    """
    Generate graph op that represents a pattern for batchnorm backprop operation.
        dgamma = np.sum(delta * xhat)
        dbeta = np.sum(delta)
        dx = gamma_scale * (delta - (xhat * dgamma + dbeta) / m)
    In this pattern we are only generating the pattern for dx.
    Returns:
        Single pattern that matches batchnorm bprop op
    """
    self.batchnorm_bprop_input_tensor = "input_tensor"
    self.batchnorm_bprop_delta = "delta"
    self.batchnorm_bprop_gamma_label = "gamma"
    self.batchnorm_bprop_var_label = "var"
    self.batchnorm_bprop_ivar_label = "ivar"
    self.batchnorm_bprop_xmu1_label = "xmu1"
    self.batchnorm_bprop_xmu2_label = "xmu2"
    self.batchnorm_bprop_negative_inverse_sqrtvar = "negative_inverse_sqrtvar"
    self.batchnorm_bprop_inverse_sqrtvar = "inverse_sqrtvar"
    self.batchnorm_bprop_sqrtvar_label = "sqrtvar"
    self.batchnorm_bprop_sqrsum = "sqrsum"
    self.batchnorm_bprop_mean_1 = "mean_1"
    self.batchnorm_bprop_mean_2 = "mean_2"
    self.batchnorm_bprop_input_sum = "input_sum"

    # bind each label to the op it should capture
    input_tensor = PatternLabelOp(self.batchnorm_bprop_input_tensor,
                                  (lambda op: isinstance(op, Flatten)))
    var = PatternLabelOp(self.batchnorm_bprop_var_label,
                         (lambda op: isinstance(op, Divide)))
    gamma = PatternLabelOp(self.batchnorm_bprop_gamma_label,
                           (lambda op: isinstance(op, BroadcastOp)))
    delta = PatternLabelOp(self.batchnorm_bprop_delta,
                           (lambda op: isinstance(op, Flatten)))
    xmu1 = PatternLabelOp(self.batchnorm_bprop_xmu1_label,
                          (lambda op: isinstance(op, Subtract)))
    xmu2 = PatternLabelOp(self.batchnorm_bprop_xmu2_label,
                          (lambda op: isinstance(op, Subtract)))
    ivar = PatternLabelOp(self.batchnorm_bprop_ivar_label,
                          (lambda op: isinstance(op, BroadcastOp)))
    negative_inverse_sqrtvar = PatternLabelOp(
        self.batchnorm_bprop_negative_inverse_sqrtvar,
        (lambda op: isinstance(op, NegativeOp)))
    inverse_sqrtvar = PatternLabelOp(
        self.batchnorm_bprop_inverse_sqrtvar,
        (lambda op: isinstance(op, ReciprocalOp)))
    sqrtvar = PatternLabelOp(self.batchnorm_bprop_sqrtvar_label,
                             (lambda op: isinstance(op, SqrtOp)))
    sqrsum = PatternLabelOp(self.batchnorm_bprop_sqrsum,
                            (lambda op: isinstance(op, Sum)))
    mean_1 = PatternLabelOp(self.batchnorm_bprop_mean_1,
                            (lambda op: isinstance(op, Divide)))
    mean_2 = PatternLabelOp(self.batchnorm_bprop_mean_2,
                            (lambda op: isinstance(op, Divide)))
    input_sum = PatternLabelOp(self.batchnorm_bprop_input_sum,
                               (lambda op: isinstance(op, Sum)))

    constant_point_5 = ng.constant(0.5)
    constant_point_5_w_broadcast = ng.PatternSkipOp(
        constant_point_5, lambda op: isinstance(op, BroadcastOp))
    constant_two = ng.constant(2)
    constant_two_w_broadcast = ng.PatternSkipOp(
        constant_two, lambda op: isinstance(op, BroadcastOp))

    # construct the pattern
    # dxhat = delta * gamma
    dxhat = Multiply(gamma, delta)
    # divar = np.sum(dxhat * xmu, axis=0)
    divar = Sum(Multiply(dxhat, xmu1))
    # dxmu1 = dxhat * ivar
    dxmu1 = Multiply(dxhat, ivar)
    # dsqrtvar = -1. / (sqrtvar ** 2) * divar
    dsqrtvar = Multiply(
        Multiply(inverse_sqrtvar, negative_inverse_sqrtvar), divar)
    # dvar = 0.5 * 1. / np.sqrt(var + eps) * dsqrtvar
    dvar = Divide(Multiply(dsqrtvar, constant_point_5_w_broadcast), sqrtvar)
    # dsq = 1. / N * np.ones((N, D)) * dvar
    dsq = Divide(Multiply(dvar, var), sqrsum)
    dsq_w_broadcast = ng.PatternSkipOp(
        dsq, (lambda op: isinstance(op, BroadcastOp)))
    # dxmu2 = 2 * xmu * dsq
    dxmu2 = Multiply(xmu2, Multiply(constant_two_w_broadcast, dsq_w_broadcast))

    # dx1 = (dxmu1 + dxmu2)
    # dmu = -1 * np.sum(dxmu1 + dxmu2, axis=0)
    # dx2 = 1. / N * np.ones((N, D)) * dmu
    # dx = dx1 + dx2
    dxmu2_mul = Multiply(Sum(ng.negative(dxmu2)), mean_2)
    dxmu2_div = Divide(dxmu2_mul, input_sum)
    dxmu2_div_w_broadcast = ng.PatternSkipOp(
        dxmu2_div, (lambda op: isinstance(op, BroadcastOp)))
    dxmu2_div_plus_dxmu2 = Add(dxmu2_div_w_broadcast, dxmu2)
    dx1 = Add(dxmu1, dxmu2_div_plus_dxmu2)
    dxmu1_mul = Multiply(Sum(ng.negative(dxmu1)), mean_1)
    dxmu1_div = Divide(dxmu1_mul, Sum(input_tensor))
    dxmu1_div_w_broadcast = ng.PatternSkipOp(
        dxmu1_div, (lambda op: isinstance(op, BroadcastOp)))
    dx = Add(dxmu1_div_w_broadcast, dx1)
    return dx
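
# Illustrative NumPy sketch of the simplified batchnorm backprop referenced in the
# docstring above; the pattern only matches dx, while dgamma and dbeta are shown
# here for context. The helper name `batchnorm_bprop_reference`, the (N, D) layout
# and the per-feature sums (axis=0) are assumptions made for this example.
def batchnorm_bprop_reference(x, delta, gamma, eps=1e-3):
    import numpy as np
    m = x.shape[0]
    mean = np.mean(x, axis=0)
    var = np.var(x, axis=0)
    xhat = (x - mean) / np.sqrt(var + eps)
    dgamma = np.sum(delta * xhat, axis=0)
    dbeta = np.sum(delta, axis=0)
    gamma_scale = gamma / np.sqrt(var + eps)
    dx = gamma_scale * (delta - (xhat * dgamma + dbeta) / m)
    return dx, dgamma, dbeta
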