Ejemplo n.º 1
0
 def _get_transformed_random_signs(self):
   """Return one coefficient-scaled Fisher-factor product per registered loss.

   For each loss in ``self.layers.losses`` (in iteration order), this draws a
   tensor from ``utils.generate_random_signs(loss.fisher_factor_inner_shape)``.
   It multiplies that tensor by the loss's Fisher factor via
   ``loss.multiply_fisher_factor`` and scales the product by
   ``sqrt(self.layers.loss_coeffs[loss])``. All ops for a given loss are
   created under ``tf.colocate_with(self.layers.loss_colocation_ops[loss])``.

   NOTE(review): ``generate_random_signs`` presumably yields +/-1 entries, and
   the result presumably feeds a stochastic Fisher estimate. Confirm both
   against ``utils`` and the callers.

   Returns:
     A list with one tensor per loss, in the same order as
     ``self.layers.losses``.
   """
   transformed_random_signs = []
   for loss in self.layers.losses:
     with tf.colocate_with(self.layers.loss_colocation_ops[loss]):
       value = loss.multiply_fisher_factor(
           utils.generate_random_signs(loss.fisher_factor_inner_shape))
       # Cast the coefficient to the product's dtype so the multiplication
       # below does not fail when the coefficient was supplied in a different
       # dtype (e.g. a Python float -> float32 vs. a float64 factor product).
       coeff = tf.cast(self.layers.loss_coeffs[loss], dtype=value.dtype)
       transformed_random_signs.append(tf.sqrt(coeff) * value)
   return transformed_random_signs
Ejemplo n.º 2
0
  def _get_transformed_random_signs(self):
    """Return one coefficient-scaled factor product per registered loss.

    ``self.mat_type`` selects which factorization each loss is multiplied by:

    * ``"Fisher"``: ``loss.multiply_fisher_factor`` with random signs of shape
      ``loss.fisher_factor_inner_shape``.
    * ``"GGN"``: ``loss.multiply_ggn_factor`` with random signs of shape
      ``loss.ggn_factor_inner_shape``.

    For each loss in ``self.layers.losses``, the random-sign tensor comes from
    ``utils.generate_random_signs``. The factor product is scaled by
    ``sqrt(self.layers.loss_coeffs[loss])``, with the coefficient cast to the
    product's dtype first. Ops are created under
    ``tf.colocate_with(self.layers.loss_colocation_ops[loss])``.

    Returns:
      A list with one tensor per loss, in the same order as
      ``self.layers.losses``.

    Raises:
      ValueError: if ``self.mat_type`` is neither ``"Fisher"`` nor ``"GGN"``.
    """
    if self.mat_type == "Fisher":
      mult_func = lambda loss, vector: loss.multiply_fisher_factor(vector)
      inner_shape_func = lambda loss: loss.fisher_factor_inner_shape
    elif self.mat_type == "GGN":
      mult_func = lambda loss, vector: loss.multiply_ggn_factor(vector)
      inner_shape_func = lambda loss: loss.ggn_factor_inner_shape
    else:
      # Fail fast with a clear message. Previously an unsupported value left
      # mult_func unbound and surfaced as an UnboundLocalError inside the loop.
      raise ValueError(
          "Unsupported mat_type {!r}; expected 'Fisher' or 'GGN'.".format(
              self.mat_type))

    transformed_random_signs = []
    for loss in self.layers.losses:
      with tf.colocate_with(self.layers.loss_colocation_ops[loss]):
        value = mult_func(loss,
                          utils.generate_random_signs(inner_shape_func(loss)))
        # Match the coefficient's dtype to the product before multiplying.
        coeff = tf.cast(self.layers.loss_coeffs[loss], dtype=value.dtype)
        transformed_random_signs.append(tf.sqrt(coeff) * value)
    return transformed_random_signs
Ejemplo n.º 3
0
    def _get_transformed_random_signs(self):
        """Return one coefficient-scaled factor product per registered loss.

        ``self.mat_type`` selects which factorization each loss is multiplied
        by:

        * ``"Fisher"``: ``loss.multiply_fisher_factor`` with random signs of
          shape ``loss.fisher_factor_inner_shape``.
        * ``"GGN"``: ``loss.multiply_ggn_factor`` with random signs of shape
          ``loss.ggn_factor_inner_shape``.

        For each loss in ``self.layers.losses``, the random-sign tensor comes
        from ``utils.generate_random_signs``. The factor product is scaled by
        ``sqrt(self.layers.loss_coeffs[loss])``. Ops are created under
        ``tf.colocate_with(self.layers.loss_colocation_ops[loss])``.

        Returns:
            A list with one tensor per loss, in the same order as
            ``self.layers.losses``.

        Raises:
            ValueError: if ``self.mat_type`` is neither ``"Fisher"`` nor
                ``"GGN"``.
        """
        if self.mat_type == "Fisher":
            mult_func = lambda loss, vector: loss.multiply_fisher_factor(vector)
            inner_shape_func = lambda loss: loss.fisher_factor_inner_shape
        elif self.mat_type == "GGN":
            mult_func = lambda loss, vector: loss.multiply_ggn_factor(vector)
            inner_shape_func = lambda loss: loss.ggn_factor_inner_shape
        else:
            # Fail fast with a clear message. Previously an unsupported value
            # left mult_func unbound and surfaced as an UnboundLocalError.
            raise ValueError(
                "Unsupported mat_type {!r}; expected 'Fisher' or 'GGN'.".format(
                    self.mat_type))

        transformed_random_signs = []
        for loss in self.layers.losses:
            with tf.colocate_with(self.layers.loss_colocation_ops[loss]):
                value = mult_func(
                    loss, utils.generate_random_signs(inner_shape_func(loss)))
                # Cast the coefficient to the product's dtype so the multiply
                # does not fail on a dtype mismatch (e.g. float32 vs float64).
                coeff = tf.cast(self.layers.loss_coeffs[loss],
                                dtype=value.dtype)
                transformed_random_signs.append(tf.sqrt(coeff) * value)
        return transformed_random_signs