def local_abstract_batch_norm_train_grad(node):
    # Replace an AbstractBatchNormTrainGrad node with an explicit expression
    # graph of elementwise/reduction ops.
    if not isinstance(node.op, AbstractBatchNormTrainGrad):
        return None

    x, dy, scale, x_mean, x_invstd, epsilon = node.inputs
    axes = node.op.axes
    if min(axes) < 0 or max(axes) > x.ndim:
        return None
    if not isinstance(x.type, TensorType) or \
       not isinstance(dy.type, TensorType) or \
       not isinstance(scale.type, TensorType) or \
       not isinstance(x_mean.type, TensorType) or \
       not isinstance(x_invstd.type, TensorType) or \
       not isinstance(epsilon.type, TensorType):
        return None

    # closed-form batch norm (training) gradients w.r.t. the inputs, the
    # scale and the bias, given the saved mean and inverse std
    x_diff = x - x_mean
    mean_dy_x_diff = T.mean(dy * x_diff, axis=axes, keepdims=True)
    c = (dy * x_invstd) - x_diff * (mean_dy_x_diff * (x_invstd ** 3))

    g_wrt_inputs = scale * (c - T.mean(c, axis=axes, keepdims=True))
    g_wrt_scale = T.sum(dy * x_invstd * x_diff, axis=axes, keepdims=True)
    g_wrt_bias = T.sum(dy, axis=axes, keepdims=True)
    results = [g_wrt_inputs, g_wrt_scale, g_wrt_bias]

    # match the broadcast pattern of the outputs being replaced
    results = [T.patternbroadcast(r, r_orig.broadcastable)
               for (r, r_orig) in zip(results, node.outputs)]

    # carry the stack trace of the original outputs onto the new variables
    for var in theano.gof.graph.variables(node.inputs, results):
        if var not in node.inputs:
            copy_stack_trace(node.outputs[0], var)

    return results
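
# --- illustration only (not part of Theano) ---------------------------------
# A minimal NumPy sketch of the closed-form input gradient that the optimizer
# above builds symbolically, checked against a central finite difference.
# The helper names (bn_train_forward, bn_input_grad), the shapes and the
# epsilon value are hypothetical choices for this example.
import numpy as np

def bn_train_forward(x, scale, bias, axes=(0,), eps=1e-4):
    # training-mode batch norm: normalize with the batch mean and variance
    mean = x.mean(axis=axes, keepdims=True)
    invstd = 1.0 / np.sqrt(x.var(axis=axes, keepdims=True) + eps)
    return scale * (x - mean) * invstd + bias, mean, invstd

def bn_input_grad(x, dy, scale, mean, invstd, axes=(0,)):
    # same algebra as g_wrt_inputs above:
    #   c = dy * invstd - (x - mean) * mean(dy * (x - mean)) * invstd**3
    #   g = scale * (c - mean(c))
    x_diff = x - mean
    c = (dy * invstd -
         x_diff * np.mean(dy * x_diff, axis=axes, keepdims=True) * invstd ** 3)
    return scale * (c - np.mean(c, axis=axes, keepdims=True))

rng = np.random.RandomState(0)
x = rng.randn(8, 3)
scale = rng.randn(1, 3)
bias = rng.randn(1, 3)
dy = rng.randn(8, 3)

out, mean, invstd = bn_train_forward(x, scale, bias)
g = bn_input_grad(x, dy, scale, mean, invstd)

# finite-difference check of d sum(out * dy) / dx
eps_fd = 1e-6
g_fd = np.zeros_like(x)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        xp = x.copy()
        xp[i, j] += eps_fd
        xm = x.copy()
        xm[i, j] -= eps_fd
        fp = np.sum(bn_train_forward(xp, scale, bias)[0] * dy)
        fm = np.sum(bn_train_forward(xm, scale, bias)[0] * dy)
        g_fd[i, j] = (fp - fm) / (2 * eps_fd)

assert np.allclose(g, g_fd, atol=1e-5)
# -----------------------------------------------------------------------------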
def grad(self, inp, grads):
    x, dy, scale, x_mean, x_invstd, epsilon = inp
    ddinputs, ddscale, ddbias = grads

    x_diff = x - x_mean
    mean_dy_x_diff = T.mean(dy * x_diff, axis=self.axes, keepdims=True)

    # compute gradients given each of the output gradients
    g_wrt_x = 0
    g_wrt_dy = 0
    g_wrt_scale = 0
    g_wrt_x_mean = 0
    g_wrt_x_invstd = 0

    if not isinstance(ddinputs.type, theano.gradient.DisconnectedType):
        ccc = scale * (ddinputs - T.mean(ddinputs, axis=self.axes,
                                         keepdims=True))
        ddd = (x_invstd ** 3) * (
            ccc * T.mean(dy * x_diff, axis=self.axes, keepdims=True) +
            dy * T.mean(ccc * x_diff, axis=self.axes, keepdims=True))

        g_wrt_x = g_wrt_x - ddd
        g_wrt_dy = g_wrt_dy + ((ccc * x_invstd) -
                               ((x_invstd ** 3) * x_diff *
                                T.mean(ccc * x_diff, axis=self.axes,
                                       keepdims=True)))

        eee = (dy * x_invstd) - ((x_invstd ** 3) * x_diff * mean_dy_x_diff)
        g_wrt_scale = g_wrt_scale + T.sum(
            ddinputs * (eee - T.mean(eee, axis=self.axes, keepdims=True)),
            axis=self.axes, keepdims=True)

        g_wrt_x_mean = g_wrt_x_mean + T.sum(ddd, axis=self.axes,
                                            keepdims=True)
        g_wrt_x_invstd = g_wrt_x_invstd + T.sum(
            ccc * (dy - 3 * (x_invstd ** 2) * x_diff * mean_dy_x_diff),
            axis=self.axes, keepdims=True)

    if not isinstance(ddscale.type, theano.gradient.DisconnectedType):
        g_wrt_x = g_wrt_x + (x_invstd * ddscale * dy)
        g_wrt_dy = g_wrt_dy + (x_invstd * ddscale * x_diff)
        g_wrt_x_mean = g_wrt_x_mean - (
            x_invstd * ddscale * T.sum(dy, axis=self.axes, keepdims=True))
        g_wrt_x_invstd = g_wrt_x_invstd + (
            ddscale * T.sum(dy * x_diff, axis=self.axes, keepdims=True))

    if not isinstance(ddbias.type, theano.gradient.DisconnectedType):
        g_wrt_dy = g_wrt_dy + T.fill(dy, ddbias)

    # depending on which output gradients are given,
    # some inputs should be disconnected
    results = [g_wrt_x, g_wrt_dy, g_wrt_scale, g_wrt_x_mean, g_wrt_x_invstd,
               theano.gradient.DisconnectedType()()]
    # gradients that are still the integer 0 received no contribution;
    # report them as disconnected (avoid `is` comparison with a literal)
    return [theano.gradient.DisconnectedType()()
            if isinstance(r, int) and r == 0 else r
            for r in results]
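
# --- illustration only (not part of Theano) ---------------------------------
# A hedged usage sketch of when the grad() method above runs: taking a second
# derivative through batch normalization (here a Hessian-vector product)
# differentiates the AbstractBatchNormTrainGrad node introduced by the first
# theano.grad call.  The shapes, the gamma/beta layout and the exact
# batch_normalization_train signature are assumptions for this example.
import numpy
import theano
import theano.tensor as T
from theano.tensor.nnet.bn import batch_normalization_train

x = T.matrix('x')
gamma = T.row('gamma')   # same ndim as x, broadcastable along the reduced axis
beta = T.row('beta')
v = T.matrix('v')        # direction for the Hessian-vector product

out, mean, invstd = batch_normalization_train(x, gamma, beta, axes=(0,))
cost = T.sum(out ** 2)
g = theano.grad(cost, x)           # introduces AbstractBatchNormTrainGrad
hv = theano.grad(T.sum(g * v), x)  # differentiates through it via grad() above

f = theano.function([x, gamma, beta, v], hv)
rng = numpy.random.RandomState(0)
floatX = theano.config.floatX
print(f(rng.randn(8, 3).astype(floatX),
        rng.randn(1, 3).astype(floatX),
        rng.randn(1, 3).astype(floatX),
        rng.randn(8, 3).astype(floatX)))
# -----------------------------------------------------------------------------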