Ejemplo n.º 1
0
 def hybrid_forward(self, F, x, gamma, beta):
     """Apply ``F.InstanceNorm`` with the channel taken from ``self._axis``.

     When ``self._axis`` is 1 the input goes straight to the operator.
     Otherwise axis 1 and ``self._axis`` are swapped before the call and
     swapped back afterwards. The operator therefore always sees the
     configured axis at position 1, which is presumably its channel axis.

     Args:
         F: Operator namespace providing ``InstanceNorm``; presumably
             ``mx.nd`` or ``mx.sym`` as in a Gluon HybridBlock. TODO confirm.
         x: Input array; must support ``swapaxes``.
         gamma: Scale parameter forwarded to the operator.
         beta: Shift parameter forwarded to the operator.

     Returns:
         The normalized array, in the same axis order as ``x``.
     """
     channel_axis = self._axis
     must_transpose = channel_axis != 1
     if must_transpose:
         # Bring the configured axis into position 1 for the operator.
         x = x.swapaxes(1, channel_axis)
     normalized = F.InstanceNorm(x, gamma, beta, name='fwd',
                                 eps=self._epsilon)
     if not must_transpose:
         return normalized
     # swapaxes is its own inverse: this restores the caller's layout.
     return normalized.swapaxes(1, channel_axis)
Ejemplo n.º 2
0
def verify_instance_norm_rewrite(shp, eps):
    """Check ``nd.InstanceNorm`` against a decomposition into primitive ops.

    The function generates random ``data``/``gamma``/``beta`` arrays and
    computes the fused ``nd.InstanceNorm`` result ``y``. It then rebuilds
    an equivalent ``z`` from ``sum``/``sqrt``/broadcast arithmetic and
    asserts that the two shapes match. ``get_norm`` (defined elsewhere) is
    applied to both results. Finally it prints the first values ``get_norm``
    returned, plus the L2 distance between the second values.

    Args:
        shp: Input shape with at least 3 dims. Axis 1 is the channel axis
            (``gamma``/``beta`` have shape ``(shp[1],)``); axes ``2..`` are
            reduced per sample and per channel.
        eps: Epsilon forwarded to ``nd.InstanceNorm`` and used in the
            rewrite.
    """
    assert len(shp) >= 3
    ndim = len(shp)
    vshp = (shp[1], )
    data_np = np.random.uniform(size=shp)
    gamma_np = np.random.uniform(size=vshp)
    beta_np = np.random.uniform(size=vshp)
    x = nd.array(data_np)
    gamma = nd.array(gamma_np)
    beta = nd.array(beta_np)

    # org op
    y = nd.InstanceNorm(x, gamma=gamma, beta=beta, eps=eps)

    # rewrite op
    # gamma/beta: insert a singleton dim at every axis except the channel
    # axis. This gives shape (1, C, 1, ...), which broadcasts against x.
    expand_axes = [i for i in range(ndim) if i != 1]
    for i in expand_axes:
        gamma = nd.expand_dims(gamma, axis=i)
        beta = nd.expand_dims(beta, axis=i)

    # Instance norm computes statistics per sample and per channel, i.e.
    # over the trailing axes only. The batch axis 0 must not be reduced:
    # that would mix samples and disagree with the divisor n, which counts
    # only the trailing elements.
    reduce_axes = list(range(2, ndim))
    n = np.prod(shp[2:])  # np.product is deprecated (removed in NumPy 2)
    mean = nd.sum(x, axis=reduce_axes, keepdims=True) / n
    dev = x - mean
    var = nd.sum(dev * dev, axis=reduce_axes, keepdims=True) / n
    # eps is added to the variance before the square root, matching the
    # InstanceNorm kernel: (x - mean) / sqrt(var + eps).
    std = nd.sqrt(var + eps)
    frac = dev / std
    z = frac * gamma + beta

    # compare
    assert z.shape == y.shape
    zn, zp = get_norm(z)
    yn, yp = get_norm(y)
    rn = np.linalg.norm(zp - yp)
    print(zn, yn, rn)
Ejemplo n.º 3
0
 def hybrid_forward(self, F, x, gamma, beta):
     """Forward ``x``, ``gamma`` and ``beta`` to ``F.InstanceNorm``.

     The operator is named ``'fwd'``. All extra keyword arguments come from
     ``self._kwargs``, which must not itself contain ``name``; otherwise the
     call raises ``TypeError`` for a duplicate keyword.

     Args:
         F: Operator namespace providing ``InstanceNorm``; presumably
             ``mx.nd`` or ``mx.sym`` as in a Gluon HybridBlock. TODO confirm.
         x: Input array.
         gamma: Scale parameter forwarded to the operator.
         beta: Shift parameter forwarded to the operator.

     Returns:
         Whatever ``F.InstanceNorm`` returns.
     """
     operator_options = self._kwargs
     result = F.InstanceNorm(x, gamma, beta, name='fwd', **operator_options)
     return result