def InstanceNormalization(num_channel, initial_scale=1, initial_bias=0,
                          epsilon=C.default_override_or(0.00001), name=''):
    """Build an Instance Normalization (2016) layer.

    Each input sample is centered and scaled by its mean and standard
    deviation taken over axes 1 and 2. The result is then transformed by a
    learnable per-channel scale and bias.

    Args:
        num_channel: length of the learnable ``scale`` and ``bias`` vectors
            (one entry per channel).
        initial_scale: initializer for the ``scale`` parameter.
        initial_bias: initializer for the ``bias`` parameter.
        epsilon: value added to the standard deviation before dividing.
            It may be overridden via CNTK default options. When it resolves
            to 0, no addition is emitted into the graph.
        name: name given to the resulting block function.

    Returns:
        A CNTK BlockFunction (op name ``'InstanceNormalization'``) taking one
        argument ``x``.

    Note:
        The per-channel parameters are reshaped to ``(-1, 1, 1)``. That
        presumably means ``x`` is laid out as (channel, H, W), with channels
        on axis 0 — TODO confirm against callers.
    """
    epsilon = C.get_default_override(InstanceNormalization, epsilon=epsilon)
    dtype = C.get_default_override(None, dtype=C.default_override_or(np.float32))

    # Learnable affine parameters, one value per channel.
    # Creation order (scale first) is kept deliberately.
    scale = C.Parameter(num_channel, init=initial_scale, name='scale')
    bias = C.Parameter(num_channel, init=initial_bias, name='bias')
    eps_value = np.asarray(epsilon, dtype=dtype)

    # The argument must stay named `x`: BlockFunction derives the block's
    # argument names from this Python signature.
    @C.BlockFunction('InstanceNormalization', name)
    def instance_normalization(x):
        reduce_axes = (1, 2)
        centered = x - C.reduce_mean(x, axis=reduce_axes)
        deviation = C.sqrt(C.reduce_mean(centered * centered, axis=reduce_axes))
        # Epsilon is added to the std (not the variance), and only when non-zero.
        if eps_value != 0:
            deviation = deviation + eps_value
        normalized = centered / deviation
        # Reshape parameters inside the block so they remain block parameters.
        scaled = normalized * C.reshape(scale, (-1, 1, 1))
        return scaled + C.reshape(bias, (-1, 1, 1))

    return instance_normalization
def Test(some_param=default_override_or(13)):
    """Return ``some_param`` after resolving it against active default overrides.

    If the caller leaves ``some_param`` at its placeholder default, the value
    comes from ``get_default_override`` (falling back to 13). An explicitly
    passed value is handed to ``get_default_override`` unchanged.
    """
    return get_default_override(Test, some_param=some_param)