def hybrid_forward(self, F, x, gamma, beta):
    """Run ``F.InstanceNorm`` on ``x`` with respect to ``self._axis``.

    When ``self._axis`` is 1 the operator is applied to ``x`` directly.
    Otherwise axis 1 and ``self._axis`` are swapped before the call and
    swapped back on the result, so the output keeps the layout of ``x``.

    Args:
        F: Namespace providing the ``InstanceNorm`` operator.
        x: Input data; must support ``swapaxes``.
        gamma: Forwarded to ``F.InstanceNorm`` as its second argument.
        beta: Forwarded to ``F.InstanceNorm`` as its third argument.

    Returns:
        The result of ``F.InstanceNorm``, with the axis swap undone if one
        was applied.
    """
    target_axis = self._axis
    already_in_place = target_axis == 1

    # Move the configured axis into position 1 before calling the operator.
    # presumably F.InstanceNorm treats axis 1 as the channel axis - confirm
    if not already_in_place:
        x = x.swapaxes(1, target_axis)

    normalized = F.InstanceNorm(x, gamma, beta, name='fwd', eps=self._epsilon)

    # Undo the swap so callers see the original axis order.
    if not already_in_place:
        normalized = normalized.swapaxes(1, target_axis)
    return normalized
def verify_instance_norm_rewrite(shp, eps):
    """Compare ``nd.InstanceNorm`` against a hand-expanded rewrite.

    Random uniform data, gamma and beta are generated for the given shape
    and run through both the built-in operator and an explicit
    mean/variance formulation. The two values reported by ``get_norm`` for
    each result's first element, plus the norm of the difference between
    their second elements, are printed.

    Args:
        shp: Input shape with at least three dimensions; gamma and beta
            are sized by ``shp[1]``.
        eps: Epsilon passed to ``nd.InstanceNorm`` and used by the rewrite.

    Raises:
        AssertionError: If ``shp`` has fewer than three dimensions, or if
            the rewrite's output shape differs from the operator's.
    """
    # Any rank >= 3 is accepted, not only 4-D input.
    assert len(shp) >= 3
    ndim = len(shp)
    vshp = (shp[1], )

    data_np = np.random.uniform(size=shp)
    gamma_np = np.random.uniform(size=vshp)
    beta_np = np.random.uniform(size=vshp)
    x = nd.array(data_np)
    gamma = nd.array(gamma_np)
    beta = nd.array(beta_np)

    # Original operator.
    y = nd.InstanceNorm(x, gamma=gamma, beta=beta, eps=eps)

    # Rewritten operator.
    # gamma/beta get a size-1 axis everywhere except axis 1 so that they
    # broadcast against x.
    broadcast_axes = [i for i in range(ndim) if i != 1]
    for i in broadcast_axes:
        gamma = nd.expand_dims(gamma, axis=i)
        beta = nd.expand_dims(beta, axis=i)

    # Statistics are reduced over the trailing axes only (one value per
    # sample and per channel). Reducing over axis 0 as well while dividing
    # by prod(shp[2:]) would scale the mean by shp[0] whenever shp[0] > 1.
    spatial_axes = list(range(2, ndim))
    n = np.prod(shp[2:])  # np.product was removed in NumPy 2.0
    mean = nd.sum(x, axis=spatial_axes, keepdims=True) / n
    dev = x - mean
    var = nd.sum(dev * dev, axis=spatial_axes, keepdims=True) / n
    # NOTE(review): eps is added after the square root here; the reference
    # operator is documented as sqrt(var + eps). Kept as written - confirm
    # whether the difference is intentional for the rewrite under test.
    std = nd.sqrt(var) + eps
    frac = dev / std
    z = frac * gamma + beta

    # Compare.
    assert z.shape == y.shape
    zn, zp = get_norm(z)
    yn, yp = get_norm(y)
    rn = np.linalg.norm(zp - yp)
    print(zn, yn, rn)
def hybrid_forward(self, F, x, gamma, beta):
    """Apply ``F.InstanceNorm`` to ``x`` using the stored operator options.

    Args:
        F: Namespace providing the ``InstanceNorm`` operator.
        x: Input data, passed as the first positional argument.
        gamma: Passed as the second positional argument.
        beta: Passed as the third positional argument.

    Returns:
        Whatever ``F.InstanceNorm`` returns.
    """
    instance_norm = F.InstanceNorm
    # Extra keyword arguments held on the instance are forwarded verbatim.
    op_options = self._kwargs
    normalized = instance_norm(x, gamma, beta, name='fwd', **op_options)
    return normalized