Esempio n. 1
0
  def transform(self, input_expr: dy.Expression, mask: Optional[batchers.Mask]=None):
    """
    Apply batch norm.

    When ``self.train`` is set, the mean and standard deviation are computed from
    ``input_expr`` over ``self.get_stat_dimensions()`` and blended into the running
    population statistics using ``BN_MOMENTUM``. Otherwise the stored population
    statistics are used for normalization.

    Args:
      input_expr: input
      mask: compute statistics only over unmasked parts of the input expression

    Returns:
      ``gamma * xhat + beta``, where ``xhat = (input - mean) / (std + BN_EPS)``; asserted
      to have the same dimensions as ``input_expr``.
    """
    dim_in = input_expr.dim()
    param_bn_gamma = dy.parameter(self.gamma)
    param_bn_beta = dy.parameter(self.beta)
    normalizer_dim = self.get_normalizer_dimensionality()
    if self.train:
      # Left at 0 when there is no mask; otherwise the count of zero entries in the mask
      # array (scaled by the broadcast factor) is handed to DyNet as the sample count.
      # NOTE(review): presumably 0 makes DyNet use its default count -- confirm in the DyNet docs.
      num_unmasked = 0
      if mask is not None:
        input_expr = set_masked_to_mean(mask, input_expr, self.time_first)
        num_unmasked = (mask.np_arr.size - np.count_nonzero(mask.np_arr)) * broadcast_factor(mask, input_expr)
      stat_dims = self.get_stat_dimensions()
      bn_mean = dy.moment_dim(input_expr, stat_dims, 1, True, num_unmasked)
      # Negate exactly once, so that adding this term below subtracts the batch mean,
      # consistent with the inference branch. Negating both inside and outside the
      # reshape would cancel out and add the mean instead.
      neg_bn_mean_reshaped = -dy.reshape(bn_mean, normalizer_dim)
      # Equivalent to: running = (1 - BN_MOMENTUM) * running + BN_MOMENTUM * batch_stat
      self.population_running_mean += (-BN_MOMENTUM) * self.population_running_mean + BN_MOMENTUM * bn_mean.npvalue()
      bn_std = dy.std_dim(input_expr, stat_dims, True, num_unmasked)
      self.population_running_std += (-BN_MOMENTUM) * self.population_running_std + BN_MOMENTUM * bn_std.npvalue()
    else:
      # Inference: normalize with the accumulated population statistics.
      neg_bn_mean_reshaped = -dy.reshape(dy.inputVector(self.population_running_mean), normalizer_dim)
      bn_std = dy.inputVector(self.population_running_std)
    bn_numerator = input_expr + neg_bn_mean_reshaped
    # BN_EPS is added to the standard deviation to keep the denominator away from zero.
    bn_xhat = dy.cdiv(bn_numerator, dy.reshape(bn_std, normalizer_dim) + BN_EPS)
    bn_y = dy.cmult(param_bn_gamma, bn_xhat) + param_bn_beta # y = gamma * xhat + beta
    dim_out = bn_y.dim()
    self.save_processed_arg("population_running_mean", self.population_running_mean)
    self.save_processed_arg("population_running_std", self.population_running_std)
    assert dim_out == dim_in
    return bn_y
Esempio n. 2
0
    def normalize(self, activations):
        """Standardize ``activations`` using DyNet mean/std statistics over dimension 1.

        The mean and standard deviation are both computed with ``d=[1]`` and
        ``b=0``; the mean is subtracted and the result is divided element-wise
        by the standard deviation.

        NOTE(review): no epsilon is added to the divisor, so a zero standard
        deviation would divide by zero -- confirm inputs cannot be constant
        along the reduced dimension.

        Args:
            activations: DyNet expression to normalize.

        Returns:
            The centered and scaled expression.
        """
        mean_expr = dy.mean_dim(activations, d=[1], b=0)
        std_expr = dy.std_dim(activations, d=[1], b=0)
        centered = activations - mean_expr
        return dy.cdiv(centered, std_expr)