def transform(self, input_expr: dy.Expression, mask: Optional[batchers.Mask]=None):
    """
    Apply batch norm.

    When ``self.train`` is truthy, the mean and standard deviation are computed
    from ``input_expr`` over ``self.get_stat_dimensions()`` and folded into
    ``self.population_running_mean`` / ``self.population_running_std`` with weight
    ``BN_MOMENTUM``. Otherwise the stored population statistics are used as-is.
    In both cases the result is ``gamma * (x - mean) / (std + BN_EPS) + beta``.

    Args:
      input_expr: input
      mask: compute statistics only over unmasked parts of the input expression

    Returns:
      The normalized expression; its dimensions equal those of ``input_expr``.

    Raises:
      AssertionError: if the output dimensions differ from the input dimensions.
    """
    dim_in = input_expr.dim()
    param_bn_gamma = dy.parameter(self.gamma)
    param_bn_beta = dy.parameter(self.beta)
    normalizer_dim = self.get_normalizer_dimensionality()

    if self.train:
        num_unmasked = 0
        if mask is not None:
            input_expr = set_masked_to_mean(mask, input_expr, self.time_first)
            # Count of zero entries in the mask array (the "unmasked" positions,
            # per the variable name), scaled by the mask-to-input broadcast factor.
            num_unmasked = (mask.np_arr.size - np.count_nonzero(mask.np_arr)) * broadcast_factor(mask, input_expr)
        stat_dims = self.get_stat_dimensions()

        # First moment (r=1) over the statistic dimensions, i.e. the batch mean.
        bn_mean = dy.moment_dim(input_expr, stat_dims, 1, True, num_unmasked)
        # Negate exactly once so that adding this term subtracts the mean. The
        # previous form, -reshape(-mean), cancelled out to +mean and therefore
        # shifted the input by the mean in the wrong direction during training,
        # unlike the inference branch below.
        neg_bn_mean_reshaped = -dy.reshape(bn_mean, normalizer_dim)
        # x += -m*x + m*new  is  x = (1 - m)*x + m*new  (moving average update).
        self.population_running_mean += (-BN_MOMENTUM) * self.population_running_mean + BN_MOMENTUM * bn_mean.npvalue()

        bn_std = dy.std_dim(input_expr, stat_dims, True, num_unmasked)
        self.population_running_std += (-BN_MOMENTUM) * self.population_running_std + BN_MOMENTUM * bn_std.npvalue()
    else:
        # Inference: normalize with the accumulated population statistics.
        neg_bn_mean_reshaped = -dy.reshape(dy.inputVector(self.population_running_mean), normalizer_dim)
        bn_std = dy.inputVector(self.population_running_std)

    bn_numerator = input_expr + neg_bn_mean_reshaped
    # BN_EPS is added to the divisor before the element-wise division.
    bn_xhat = dy.cdiv(bn_numerator, dy.reshape(bn_std, normalizer_dim) + BN_EPS)
    bn_y = dy.cmult(param_bn_gamma, bn_xhat) + param_bn_beta  # y = gamma * xhat + beta
    dim_out = bn_y.dim()

    # Hand the (possibly updated) running statistics to save_processed_arg
    # on every call, in both training and inference mode.
    self.save_processed_arg("population_running_mean", self.population_running_mean)
    self.save_processed_arg("population_running_std", self.population_running_std)

    assert dim_out == dim_in
    return bn_y
def normalize(self, activations):
    """Standardize ``activations`` using statistics taken over dimension 1.

    The mean (``dy.mean_dim``) and standard deviation (``dy.std_dim``) are both
    computed with ``d=[1]`` and ``b=0``; the mean is subtracted and the result is
    divided element-wise by the standard deviation.

    Note: the standard deviation is used as the divisor unchanged (no epsilon
    is added), so a zero standard deviation is not guarded against here.

    Args:
      activations: expression to normalize.

    Returns:
      The centered and scaled expression.
    """
    center = dy.mean_dim(activations, d=[1], b=0)
    spread = dy.std_dim(activations, d=[1], b=0)
    centered = activations - center
    return dy.cdiv(centered, spread)