def register_fully_connected(self, params, inputs, outputs, approx=APPROX_KRONECKER_NAME):
  """Registers a fully-connected layer under the requested approximation.

  Args:
    params: Tensor or (weights, bias) tuple/list holding the layer's
      parameters. A tuple/list value indicates a bias term is present.
    inputs: Inputs to the layer.
    outputs: Preactivations produced by the layer.
    approx: Which FisherBlock approximation to use; one of
      APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME.

  Raises:
    ValueError: If `approx` is not a recognized approximation name.
  """
  # params given as (weights, bias) means the layer has a bias term.
  bias_present = isinstance(params, (tuple, list))
  # Dispatch table: approximation name -> FisherBlock class.
  block_classes = {
      APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB,
      APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB,
  }
  if approx not in block_classes:
    raise ValueError("Bad value {} for approx.".format(approx))
  block_cls = block_classes[approx]
  self.register_block(params, block_cls(self, inputs, outputs, bias_present))
def runFisherBlockOps(self, params, inputs, outputs, output_grads):
  """Run Ops guaranteed by FisherBlock interface.

  Args:
    params: Tensor or 2-tuple of Tensors. Represents weights or weights and
      bias of this layer.
    inputs: list of Tensors of shape [batch_size, input_size]. Inputs to
      layer.
    outputs: list of Tensors of shape [batch_size, output_size].
      Preactivations produced by layer.
    output_grads: list of Tensors of shape [batch_size, output_size].
      Gradient of loss with respect to 'outputs'.

  Returns:
    multiply_result: Result of FisherBlock.multiply(params)
    multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
  """

  def _tensorize(value):
    # Tuples/lists are converted element-wise into a tuple of Tensors;
    # everything else is converted directly.
    if isinstance(value, (tuple, list)):
      return tuple(ops.convert_to_tensor(v) for v in value)
    return ops.convert_to_tensor(value)

  with ops.Graph().as_default(), self.test_session() as sess:
    inputs = [_tensorize(x) for x in inputs]
    outputs = [_tensorize(x) for x in outputs]
    output_grads = [_tensorize(x) for x in output_grads]
    params = _tensorize(params)

    # A tuple/list params value means a bias term is included.
    bias_present = isinstance(params, (tuple, list))
    block = fb.FullyConnectedDiagonalFB(
        lc.LayerCollection(), has_bias=bias_present)
    for layer_input, layer_output in zip(inputs, outputs):
      block.register_additional_minibatch(layer_input, layer_output)

    block.instantiate_factors((output_grads,), damping=0.0)

    sess.run(tf_variables.global_variables_initializer())
    sess.run(block._factor.make_covariance_update_op(0.0))

    multiply_result = sess.run(block.multiply(params))
    multiply_inverse_result = sess.run(block.multiply_inverse(params))

  return multiply_result, multiply_inverse_result