Exemplo n.º 1
0
  def register_conv2d(self, params, strides, padding, inputs, outputs,
                      approx=APPROX_KRONECKER_NAME):
    """Registers a FisherBlock for a 2-D convolutional layer.

    The block class is chosen from `approx`, constructed with this object and
    the layer's tensors, and registered under `params` via
    `self.register_block`.

    Args:
      params: Layer parameters. Used as the registration key and passed to
        the FisherBlock constructor.
      strides: Forwarded unchanged to the FisherBlock constructor
        (presumably conv2d-style strides -- confirm against `fb`).
      padding: Forwarded unchanged to the FisherBlock constructor
        (presumably a conv2d padding string such as 'SAME' -- confirm).
      inputs: Layer inputs, forwarded to the FisherBlock constructor.
      outputs: Layer outputs, forwarded to the FisherBlock constructor.
      approx: Approximation to use. APPROX_KRONECKER_NAME (the default)
        selects `fb.ConvKFCBasicFB`; APPROX_DIAGONAL_NAME selects
        `fb.ConvDiagonalFB`.

    Raises:
      ValueError: If `approx` is not one of the supported names.
    """
    if approx == APPROX_KRONECKER_NAME:
      block_type = fb.ConvKFCBasicFB
    elif approx == APPROX_DIAGONAL_NAME:
      block_type = fb.ConvDiagonalFB
    else:
      # Without this, an unrecognized value would leave the layer
      # unregistered and give no indication of the mistake.
      raise ValueError("Bad value {} for approx.".format(approx))

    self.register_block(params,
                        block_type(self, params, inputs, outputs, strides,
                                   padding))
Exemplo n.º 2
0
    def runFisherBlockOps(self, params, inputs, outputs, output_grads):
        """Build a ConvDiagonalFB in a fresh graph and evaluate its products.

        All four arguments are converted with `as_tensors`. Each
        (input, output) pair, zipped from `inputs` and `outputs`, is
        registered as a tower of an `fb.ConvDiagonalFB` that uses
        strides [1, 1, 1, 1] and 'SAME' padding and is attached to a new
        `lc.LayerCollection`. Factors are instantiated with zero damping.
        Variables are initialized, and the factor's covariance update op
        (built with argument 0.0) is run once before evaluation.

        Args:
          params: Tensor or 2-tuple of Tensors. Represents weights or weights
            and bias of this layer.
          inputs: list of Tensors, one per tower. Inputs to the layer.
            NOTE(review): the 4-element conv `strides` suggests these are
            4-D conv inputs (e.g. [batch, height, width, channels]) rather
            than 2-D [batch_size, input_size] -- confirm against callers.
          outputs: list of Tensors, one per tower, paired with `inputs`.
            Preactivations produced by the layer.
          output_grads: list of Tensors. Gradient of loss with respect to
            `outputs`.

        Returns:
          A 2-tuple (multiply_result, multiply_inverse_result): the
          evaluated results of `block.multiply(params)` and
          `block.multiply_inverse(params)`.
        """
        with ops.Graph().as_default(), self.test_session() as session:
            # Convert in the same order as before: inputs, outputs,
            # output_grads, params.
            inputs, outputs, output_grads, params = map(
                as_tensors, (inputs, outputs, output_grads, params))

            layer_collection = lc.LayerCollection()
            block = fb.ConvDiagonalFB(layer_collection, params,
                                      strides=[1, 1, 1, 1], padding='SAME')
            for tower_input, tower_output in zip(inputs, outputs):
                block.register_additional_tower(tower_input, tower_output)

            block.instantiate_factors((output_grads,), damping=0.0)
            block._factor.instantiate_cov_variables()

            session.run(tf_variables.global_variables_initializer())
            session.run(block._factor.make_covariance_update_op(0.0))
            products = (session.run(block.multiply(params)),
                        session.run(block.multiply_inverse(params)))

        return products