Esempio n. 1
0
 def register_fully_connected(self,
                              params,
                              inputs,
                              outputs,
                              approx=APPROX_KRONECKER_NAME):
   """Registers a Fisher block for a fully connected layer.

   Args:
     params: Layer parameters. When this is a tuple or list the block is
       built with `has_bias=True`; otherwise with `has_bias=False`.
     inputs: Forwarded unchanged to the Fisher block constructor
       (presumably the layer's inputs -- confirm against the block class).
     outputs: Forwarded unchanged to the Fisher block constructor
       (presumably the layer's outputs -- confirm against the block class).
     approx: Which approximation to use. `APPROX_KRONECKER_NAME` selects
       `fb.FullyConnectedKFACBasicFB`; `APPROX_DIAGONAL_NAME` selects
       `fb.FullyConnectedDiagonalFB`.

   Raises:
     ValueError: If `approx` is neither of the two supported values.
   """
   # Pick the block class up front so the registration call is written once.
   if approx == APPROX_KRONECKER_NAME:
     block_cls = fb.FullyConnectedKFACBasicFB
   elif approx == APPROX_DIAGONAL_NAME:
     block_cls = fb.FullyConnectedDiagonalFB
   else:
     raise ValueError("Bad value {} for approx.".format(approx))

   # A tuple/list of params is taken to mean a bias accompanies the weights.
   has_bias = isinstance(params, (tuple, list))
   self.register_block(params, block_cls(self, inputs, outputs, has_bias))
    def runFisherBlockOps(self, params, inputs, outputs, output_grads):
        """Build a diagonal fully connected FisherBlock and evaluate its ops.

        Everything is converted to Tensors inside a fresh graph, a
        `fb.FullyConnectedDiagonalFB` is created and fed every
        (input, output) minibatch pair, its factors are instantiated with
        zero damping, and one covariance update is run before the multiply
        ops are evaluated.

        Args:
          params: Tensor or 2-tuple of Tensors. Represents weights or weights
            and bias of this layer.
          inputs: list of Tensors of shape [batch_size, input_size]. Inputs
            to layer.
          outputs: list of Tensors of shape [batch_size, output_size].
            Preactivations produced by layer.
          output_grads: list of Tensors of shape [batch_size, output_size].
            Gradient of loss with respect to 'outputs'.

        Returns:
          multiply_result: Result of FisherBlock.multiply(params)
          multiply_inverse_result: Result of
            FisherBlock.multiply_inverse(params)
        """
        def _to_tensors(value):
            # Sequences are converted element-wise and come back as a tuple;
            # anything else is converted directly.
            if not isinstance(value, (tuple, list)):
                return ops.convert_to_tensor(value)
            return tuple(ops.convert_to_tensor(item) for item in value)

        with ops.Graph().as_default(), self.test_session() as sess:
            input_tensors = [_to_tensors(x) for x in inputs]
            output_tensors = [_to_tensors(y) for y in outputs]
            grad_tensors = [_to_tensors(g) for g in output_grads]
            param_tensors = _to_tensors(params)

            # A tuple of params means the layer carries a bias term.
            has_bias = isinstance(param_tensors, (tuple, list))
            block = fb.FullyConnectedDiagonalFB(lc.LayerCollection(),
                                                has_bias=has_bias)
            for layer_in, layer_out in zip(input_tensors, output_tensors):
                block.register_additional_minibatch(layer_in, layer_out)

            block.instantiate_factors((grad_tensors, ), damping=0.0)

            # Variables must be initialized, and the covariance updated
            # once, before the multiply ops are evaluated.
            sess.run(tf_variables.global_variables_initializer())
            cov_update_op = block._factor.make_covariance_update_op(0.0)
            sess.run(cov_update_op)
            multiply_result = sess.run(block.multiply(param_tensors))
            multiply_inverse_result = sess.run(
                block.multiply_inverse(param_tensors))

        return multiply_result, multiply_inverse_result