def ref_group_normalization(x, beta, gamma, num_groups, channel_axis,
                            batch_axis, eps, output_stat):
    """Reference group normalization over grouped channels.

    Reshapes the channel axis of ``x`` into ``(num_groups, cdim // num_groups)``,
    normalizes with mean/std over all axes except batch and group, then applies
    the affine transform ``gamma * xhat + beta`` and restores the original shape.
    Note: divides by ``(std + eps)``, not ``sqrt(var + eps)``.

    Args:
        x: input array (assumed NumPy ndarray — TODO confirm caller contract).
        beta, gamma: affine shift/scale, broadcastable to the grouped shape.
        num_groups: number of channel groups; must divide the channel dim.
        channel_axis: axis holding channels.
        batch_axis: int or list of ints excluded from the statistics.
        eps: small constant added to the std for numerical stability.
        output_stat: when True, also return the mean and std arrays.

    Returns:
        Normalized array with ``x``'s shape; plus ``(mean, std)`` if
        ``output_stat`` is True.

    Raises:
        ValueError: if ``num_groups`` does not divide the channel dimension.
    """
    cdim = x.shape[channel_axis]
    if cdim % num_groups > 0:
        # Was a bare ValueError(); give the caller a usable diagnostic.
        raise ValueError(
            "Channel dimension ({}) must be divisible by num_groups ({})."
            .format(cdim, num_groups))
    # Integer floor division: exact, no float round-trip.
    shape = x.shape[:channel_axis] + (num_groups, cdim // num_groups)
    if channel_axis < len(x.shape) - 1:
        shape += x.shape[channel_axis + 1:]
    tmp = x.reshape(shape).copy()
    # Statistics are computed over every axis except batch and group.
    ignore_axes = _force_list(batch_axis) + [channel_axis]
    axes = tuple(_get_axes_excluding(len(shape), ignore_axes))
    x_mean = tmp.mean(axis=axes, keepdims=True)
    x_std = tmp.std(axis=axes, keepdims=True)
    # Compute the normalized output once; both return paths share it.
    norm = ((tmp - x_mean) / (x_std + eps) * gamma + beta).reshape(x.shape)
    if output_stat:
        return norm, x_mean, x_std
    return norm
def ref_weight_standardization(w, channel_axis, eps, output_stat):
    """Reference weight standardization.

    Normalizes ``w`` to zero mean and (approximately) unit std over every axis
    except ``channel_axis``, dividing by ``std + eps``. When ``output_stat`` is
    True the per-channel mean and std are returned alongside the result.
    """
    reduce_axes = tuple(_get_axes_excluding(len(w.shape), channel_axis))
    mean = w.mean(axis=reduce_axes, keepdims=True)
    std = w.std(axis=reduce_axes, keepdims=True)
    standardized = (w - mean) / (std + eps)
    return (standardized, mean, std) if output_stat else standardized
def ref_layer_normalization(x, beta, gamma, batch_axis, eps, output_stat):
    """Reference layer normalization.

    Normalizes ``x`` with mean/std taken over every axis except the batch
    axis/axes, dividing by ``std + eps``, then applies ``gamma``/``beta``.
    Returns the stats too when ``output_stat`` is True.
    """
    reduce_axes = tuple(
        _get_axes_excluding(len(x.shape), _force_list(batch_axis)))
    mu = x.mean(axis=reduce_axes, keepdims=True)
    sigma = x.std(axis=reduce_axes, keepdims=True)
    out = (x - mu) / (sigma + eps) * gamma + beta
    return (out, mu, sigma) if output_stat else out
def ref_weight_standardization(w, channel_axis, eps, output_stat):
    """Reference weight standardization (variance form).

    Standardizes ``w`` over all axes except ``channel_axis`` using
    ``(w - mean) / sqrt(var + eps)``. With ``output_stat`` set, the mean and
    variance used are returned as well.
    """
    stat_axes = tuple(_get_axes_excluding(len(w.shape), channel_axis))
    mean = w.mean(axis=stat_axes, keepdims=True)
    var = w.var(axis=stat_axes, keepdims=True)
    standardized = (w - mean) / (var + eps) ** 0.5
    if output_stat:
        return standardized, mean, var
    return standardized
def ref_instance_normalization(x, beta, gamma, channel_axis, batch_axis, eps,
                               output_stat):
    """Reference instance normalization.

    Computes mean/std per (batch, channel) slice — i.e. over every axis except
    the batch and channel axes — normalizes with ``std + eps`` in the
    denominator, and applies the ``gamma``/``beta`` affine transform.
    Also returns the stats when ``output_stat`` is True.
    """
    excluded = _force_list(batch_axis) + [channel_axis]
    stat_axes = tuple(_get_axes_excluding(len(x.shape), excluded))
    mu = x.mean(axis=stat_axes, keepdims=True)
    sigma = x.std(axis=stat_axes, keepdims=True)
    out = (x - mu) / (sigma + eps) * gamma + beta
    if output_stat:
        return out, mu, sigma
    return out
def ref_layer_normalization(x, beta, gamma, batch_axis, eps, output_stat):
    """Reference layer normalization (variance form, optional affine).

    Normalizes ``x`` using ``(x - mean) / sqrt(var + eps)`` with stats taken
    over every non-batch axis. ``gamma`` and ``beta`` are applied only when
    not None. Returns ``(out, mean, var)`` when ``output_stat`` is True.
    """
    stat_axes = tuple(
        _get_axes_excluding(len(x.shape), _force_list(batch_axis)))
    mean = x.mean(axis=stat_axes, keepdims=True)
    var = x.var(axis=stat_axes, keepdims=True)
    out = (x - mean) / (var + eps) ** 0.5
    # Affine parameters are optional in this variant.
    if gamma is not None:
        out *= gamma
    if beta is not None:
        out += beta
    return (out, mean, var) if output_stat else out
def ref_group_normalization(x, beta, gamma, num_groups, channel_axis,
                            batch_axis, eps, output_stat):
    """Reference group normalization (variance form, optional affine).

    Splits the channel axis into ``(num_groups, cdim // num_groups)``,
    normalizes with ``(x - mean) / sqrt(var + eps)`` over all axes except
    batch and group, restores ``x``'s shape, then applies ``gamma``/``beta``
    only when they are not None. Negative ``channel_axis``/``batch_axis``
    values are normalized to non-negative indices.

    Args:
        x: input array (assumed NumPy ndarray — TODO confirm caller contract).
        beta, gamma: optional affine shift/scale, broadcastable to ``x``'s
            shape; skipped when None.
        num_groups: number of channel groups; must divide the channel dim.
        channel_axis: axis holding channels (may be negative).
        batch_axis: int or list of ints (may be negative) excluded from stats.
        eps: small constant added to the variance for numerical stability.
        output_stat: when True, also return the mean and variance arrays.

    Returns:
        Normalized array with ``x``'s shape; plus ``(mean, var)`` if
        ``output_stat`` is True.

    Raises:
        ValueError: if ``num_groups`` does not divide the channel dimension.
    """
    cdim = x.shape[channel_axis]
    if cdim % num_groups > 0:
        # Was a bare ValueError(); give the caller a usable diagnostic.
        raise ValueError(
            "Channel dimension ({}) must be divisible by num_groups ({})."
            .format(cdim, num_groups))
    # Integer floor division: exact, no float round-trip.
    shape = x.shape[:channel_axis] + (num_groups, cdim // num_groups)
    # Normalize negative axis indices to their non-negative equivalents.
    channel_axis += x.ndim * (channel_axis < 0)
    batch_axis = [b + x.ndim * (b < 0) for b in _force_list(batch_axis)]
    if channel_axis < len(x.shape) - 1:
        shape += x.shape[channel_axis + 1:]
    tmp = x.reshape(shape).copy()
    # Statistics are computed over every axis except batch and group.
    ignore_axes = batch_axis + [channel_axis]
    axes = tuple(_get_axes_excluding(len(shape), ignore_axes))
    x_mean = tmp.mean(axis=axes, keepdims=True)
    x_var = tmp.var(axis=axes, keepdims=True)
    norm = ((tmp - x_mean) / (x_var + eps) ** 0.5).reshape(x.shape)
    # Affine parameters are optional in this variant.
    if gamma is not None:
        norm *= gamma
    if beta is not None:
        norm += beta
    if output_stat:
        return norm, x_mean, x_var
    return norm