Example 1
 def _get_transformed_random_signs(self):
   # For each registered loss, multiply a random +/-1 vector by that
   # loss's Fisher factor (the same computation appears inline in the
   # "curvature_prop" branch of Example 3).
   transformed_random_signs = []
   for loss in self._layers.losses:
     transformed_random_signs.append(
         loss.multiply_fisher_factor(
             utils.generate_random_signs(loss.fisher_factor_inner_shape)))
   return transformed_random_signs
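The helper utils.generate_random_signs is not shown in these examples. A minimal sketch of such a helper, assuming it simply draws a Rademacher (random +/-1) tensor of the requested shape, with dtype handling guessed here, could look like:

 import tensorflow as tf

 def generate_random_signs(shape, dtype=tf.float32):
   # Assumed behavior: each entry is +1 or -1 with equal probability.
   ints = tf.random_uniform(shape, minval=0, maxval=2, dtype=tf.int32)
   return 2 * tf.cast(ints, dtype) - 1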
Example 2
 def _get_transformed_random_signs(self):
   transformed_random_signs = []
   for loss in self._layers.losses:
     # Build the sampled-sign ops on the same device as the op(s)
     # registered for this loss.
     with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
       transformed_random_signs.append(
           loss.multiply_fisher_factor(
               utils.generate_random_signs(loss.fisher_factor_inner_shape)))
   return transformed_random_signs
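The only difference from Example 1 is the tf_ops.colocate_with context, which makes the sign-sampling ops inherit the device placement of an op already associated with each loss. A self-contained illustration of the same mechanism using the public TF1 API (the anchor op, device string, and shape are made up for the demo):

 import tensorflow as tf

 with tf.Graph().as_default():
   with tf.device("/device:GPU:0"):
     anchor = tf.constant(0.0, name="loss_anchor")
   # Ops built inside colocate_with are placed on anchor's device.
   with tf.colocate_with(anchor):
     signs = 2.0 * tf.round(tf.random_uniform([8])) - 1.0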
Example 3
  def _setup(self, cov_ema_decay):
    """Sets up the various operations.

    Args:
      cov_ema_decay: The decay factor used when calculating the covariance
          estimate moving averages.

    Returns:
      A triple (covs_update_op, invs_update_op, inv_updates_dict), where
      covs_update_op is the grouped Op to update all the covariance estimates,
      invs_update_op is the grouped Op to update all the inverses, and
      inv_updates_dict is a dict mapping Op names to individual inverse updates.

    Raises:
      ValueError: If estimation_mode was improperly specified at construction.
    """
    damping = self.damping

    fisher_blocks_list = self._layers.get_blocks()

    tensors_to_compute_grads = [
        fb.tensors_to_compute_grads() for fb in fisher_blocks_list
    ]
    tensors_to_compute_grads_flat = nest.flatten(tensors_to_compute_grads)

    # "gradients": a single gradient of the loss evaluated on targets
    # sampled from the model's predictive distribution.
    if self._estimation_mode == "gradients":
      grads_flat = gradients_impl.gradients(self._layers.total_sampled_loss(),
                                            tensors_to_compute_grads_flat)
      grads_all = nest.pack_sequence_as(tensors_to_compute_grads, grads_flat)
      grads_lists = tuple((grad,) for grad in grads_all)

    # "empirical": gradients of the actual training loss instead of the
    # sampled loss (the empirical Fisher).
    elif self._estimation_mode == "empirical":
      grads_flat = gradients_impl.gradients(self._layers.total_loss(),
                                            tensors_to_compute_grads_flat)
      grads_all = nest.pack_sequence_as(tensors_to_compute_grads, grads_flat)
      grads_lists = tuple((grad,) for grad in grads_all)

    # "curvature_prop": backpropagate random +/-1 vectors, transformed by
    # each loss's Fisher factor, through the loss inputs.
    elif self._estimation_mode == "curvature_prop":
      loss_inputs = list(loss.inputs for loss in self._layers.losses)
      loss_inputs_flat = nest.flatten(loss_inputs)

      transformed_random_signs = [
          loss.multiply_fisher_factor(
              utils.generate_random_signs(loss.fisher_factor_inner_shape))
          for loss in self._layers.losses
      ]

      transformed_random_signs_flat = nest.flatten(transformed_random_signs)

      grads_flat = gradients_impl.gradients(
          loss_inputs_flat,
          tensors_to_compute_grads_flat,
          grad_ys=transformed_random_signs_flat)
      grads_all = nest.pack_sequence_as(tensors_to_compute_grads, grads_flat)
      grads_lists = tuple((grad,) for grad in grads_all)

    elif self._estimation_mode == "exact":
      # Loop over all coordinates of all losses.
      grads_all = []
      for loss in self._layers.losses:
        for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
          transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
              index)
          grads_flat = gradients_impl.gradients(loss.inputs,
                                                tensors_to_compute_grads_flat,
                                                grad_ys=transformed_one_hot)
          grads_all.append(nest.pack_sequence_as(tensors_to_compute_grads,
                                                 grads_flat))

      grads_lists = zip(*grads_all)

    else:
      raise ValueError("Unrecognized value {} for estimation_mode.".format(
          self._estimation_mode))

    for grads_list, fb in zip(grads_lists, fisher_blocks_list):
      fb.instantiate_factors(grads_list, damping)

    cov_updates = [
        factor.make_covariance_update_op(cov_ema_decay)
        for factor in self._layers.get_factors()
    ]
    inv_updates = {
        op.name: op
        for factor in self._layers.get_factors()
        for op in factor.make_inverse_update_ops()
    }

    return (control_flow_ops.group(*cov_updates),
            control_flow_ops.group(*inv_updates.values()),
            inv_updates)
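A hedged sketch of how the triple returned by _setup might be consumed in a TF1 training loop. Apart from the triple itself, every name here (estimator, train_op, num_steps, invert_every) and the update cadence is illustrative, not taken from the source:

 import tensorflow as tf

 covs_update_op, invs_update_op, inv_updates = estimator._setup(
     cov_ema_decay=0.95)

 with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
   for step in range(num_steps):
     # Covariance moving averages are cheap; refresh them every step.
     sess.run(covs_update_op)
     # Matrix inversions are expensive; amortize them over many steps.
     if step % invert_every == 0:
       sess.run(invs_update_op)
     sess.run(train_op)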