def testErrorWithLayerNormNoScale(self):
  model = tf.keras.Sequential([
      layers.Conv2D(13, 5, input_shape=(28, 28, 3)),
      layers.LayerNormalization(scale=False),
      layers.GlobalMaxPool2D(),
  ])
  with self.assertRaisesRegex(ValueError, '.*scale=False.*'):
    utils.get_layer_collection(model, 'binary_crossentropy')

def testMultiOutputNestedModelFails(self):
  inp = tf.keras.Input(shape=(1,))
  out1 = layers.Dense(1)(inp)
  out2 = layers.Dense(1)(inp)
  model = tf.keras.Model(inputs=inp, outputs=[out1, out2])
  inp2 = tf.keras.Input(shape=(1,))
  out = model(inp2)
  model = tf.keras.Model(inputs=inp2, outputs=out)
  with self.assertRaisesRegex(
      ValueError, 'Nested models with multiple outputs are unsupported.'):
    utils.get_layer_collection(model, loss=['mse', 'mse'])

def testFisherApproxLayerNames(self, fisher_approx):
  model = tf.keras.Sequential([
      layers.Dense(10, input_shape=(20,), name='l1'),
      layers.Activation('relu'),
      layers.Dense(13, activation='relu', name='l2'),
      layers.Dense(23, trainable=False),
      layers.Dense(17, name='l3'),
      layers.Activation('relu'),
      layers.Dense(3, name='l4'),
  ])
  lc = utils.get_layer_collection(model, 'mse', fisher_approx=fisher_approx)
  trainable_layers = [model.layers[i] for i in [0, 2, 4, 6]]
  expected_in_diag_approx = [False, True, False, True]
  expected_out_diag_approx = [False, False, True, True]
  for layer, in_diag, out_diag in zip(trainable_layers,
                                      expected_in_diag_approx,
                                      expected_out_diag_approx):
    self.assertEqual(
        in_diag, lc.fisher_blocks[layer.weights]._diagonal_approx_for_input)
    self.assertEqual(
        out_diag, lc.fisher_blocks[layer.weights]._diagonal_approx_for_output)

def testModelAsCallable(self):
  model = tf.keras.Sequential([
      layers.Conv2D(13, 5),
      layers.BatchNormalization(name='bn', fused=False),
      layers.Conv2D(23, 3),
      layers.LayerNormalization(),
      layers.GlobalMaxPool2D(),
  ])
  inp = tf.random_normal((10, 28, 28, 3))
  inp = tf.keras.Input(tensor=inp)
  inp2 = tf.random_normal((10, 28, 28, 3))
  inp2 = tf.keras.Input(tensor=inp2)
  fisher_approx = {layers.LayerNormalization: 'full', 'bn': 'diagonal'}
  _ = model(inp)
  _ = model(inp2)  # With multiple calls, the latest call should be registered.
  lc = utils.get_layer_collection(model, 'mse', fisher_approx=fisher_approx)
  for i in (0, 2):
    conv_block = lc.fisher_blocks[model.layers[i].trainable_weights]
    conv_inp = model.layers[i].inbound_nodes[-1].input_tensors
    conv_out = model.layers[i].inbound_nodes[-1].output_tensors
    self.assertEqual(conv_inp, conv_block._inputs[0])
    self.assertEqual(conv_out, conv_block._outputs[0])

def testMultipleLoss(self, loss, loss_weights):
  inputs, (out1, out2) = _two_loss_model()
  model = tf.keras.Model(inputs=inputs, outputs=[out1, out2])
  lc = utils.get_layer_collection(model, loss, loss_weights=loss_weights)
  self.assertLen(lc.loss_coeffs.keys(), 2)
  self.assertLen(lc.loss_colocation_ops.keys(), 2)
  l1 = lc._loss_dict['sigmoid_cross_entropy_loss']
  l2 = lc._loss_dict['sparse_softmax_cross_entropy_loss']
  self.assertLen(l1, 1)
  self.assertLen(l2, 1)
  l1, l2 = l1[0], l2[0]
  self.assertIsInstance(l1, loss_functions.MultiBernoulliNegativeLogProbLoss)
  self.assertIsInstance(
      l2, loss_functions.CategoricalLogitsNegativeLogProbLoss)
  self.assertEqual(lc.loss_coeffs[l1], 0.1)
  self.assertEqual(lc.loss_coeffs[l2], 0.9)
  self.assertEqual(lc.loss_colocation_ops[l1], out1)
  self.assertEqual(lc.loss_colocation_ops[l2], out2)

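# A minimal sketch of what the `_two_loss_model` helper (defined elsewhere in
# this file) is assumed to return, based on how the tests above and below use
# it: a shared input feeding a sigmoid (binary) head and a softmax
# (categorical) head. The name `_two_loss_model_sketch`, the layer widths,
# and the output names are illustrative assumptions, not the real helper.
def _two_loss_model_sketch():
  inp = tf.keras.Input(shape=(20,))
  x = layers.Dense(16, activation='relu')(inp)
  out1 = layers.Dense(1, activation='sigmoid', name='sigmoid_out')(x)
  out2 = layers.Dense(5, activation='softmax', name='softmax_out')(x)
  return inp, (out1, out2)
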
def testInstantiationWithLayerCollection(self):
  model = _simple_mlp()
  lc = utils.get_layer_collection(model, 'mse')
  opt = optimizers.Kfac(learning_rate=0.1, damping=0.2, layer_collection=lc)
  model.compile(optimizer=opt, loss='mse')
  opt.get_updates(model.total_loss, model.trainable_weights)

def testRegisterLayersWithLayerCollection(self):
  model, loss = _mnist_model(), 'categorical_crossentropy'
  lc = utils.get_layer_collection(model, loss)
  opt = optimizers.Kfac(learning_rate=0.01, damping=0.001)
  opt.register_layers(layer_collection=lc)
  model.compile(optimizer=opt, loss=loss)
  opt.get_updates(model.total_loss, model.trainable_weights)

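# Hypothetical sketches of the `_simple_mlp` and `_mnist_model` helpers the
# two tests above assume; the tests only require a trainable Keras model
# (for `_mnist_model`, one suited to 'categorical_crossentropy'). The names
# and architectures here are illustrative assumptions.
def _simple_mlp_sketch():
  return tf.keras.Sequential([
      layers.Dense(16, activation='relu', input_shape=(20,)),
      layers.Dense(1),
  ])


def _mnist_model_sketch():
  return tf.keras.Sequential([
      layers.Flatten(input_shape=(28, 28)),
      layers.Dense(128, activation='relu'),
      layers.Dense(10, activation='softmax'),
  ])
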
def testFisherApproxLayerClass(self, fisher_approx, block_types):
  model = _cnn()
  lc = utils.get_layer_collection(model, 'mse', fisher_approx=fisher_approx)
  trainable_layers = [model.layers[0], model.layers[2]]
  for layer, block_type in zip(trainable_layers, block_types):
    self.assertIsInstance(lc.fisher_blocks[layer.weights], block_type)

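# A hypothetical stand-in for the `_cnn` helper used by this test and by
# testFisherApproxErrors below. Only the structure the tests depend on is
# reproduced: parameterized layers at indices 0 and 2 with a parameter-free
# layer in between. The name and exact layer sizes are assumptions.
def _cnn_sketch():
  return tf.keras.Sequential([
      layers.Conv2D(13, 5, input_shape=(28, 28, 3)),
      layers.Flatten(),
      layers.Dense(10),
  ])
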
def testMultipleLossWeights(self, loss_weights):
  inputs, (out1, out2) = _two_loss_model()
  model = tf.keras.Model(inputs=inputs, outputs=[out1, out2])
  loss = ['binary_crossentropy', 'categorical_crossentropy']
  lc = utils.get_layer_collection(model, loss, loss_weights=loss_weights)
  l1 = lc._loss_dict['sigmoid_cross_entropy_loss'][0]
  self.assertEqual(lc.loss_coeffs[l1], 1.0)

def testNumBatchNormUsesNoPhase(self):
  model = tf.keras.Sequential([
      layers.Conv2D(13, 5, input_shape=(28, 28, 3)),
      layers.BatchNormalization(fused=True),
      layers.GlobalMaxPool2D(),
  ])
  lc = utils.get_layer_collection(model, 'binary_crossentropy')
  # With no fixed learning phase, each batch norm variable should be
  # registered with two uses.
  for w in model.layers[1].trainable_weights:
    self.assertEqual(lc._vars_to_uses[w], 2)

def register_layers(self, model=None, loss=None, layer_collection=None):
  """Registers the model's layers with a K-FAC LayerCollection.

  If no layer_collection is passed, one is constructed from the model and
  loss, falling back to the model's compiled loss when loss is omitted.
  """
  if not layer_collection:
    if not loss and hasattr(model, 'loss'):
      loss = model.loss
    if not (model and loss):
      raise ValueError('Please provide a model compiled with a loss, a '
                       'model and a loss, or a LayerCollection.')
    layer_collection = utils.get_layer_collection(
        model, loss, **self._layer_collection_kwargs)
  self._layer_collection = layer_collection
  self._kfac_kwargs['var_list'] = layer_collection.registered_variables

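# A minimal, hypothetical usage sketch for `register_layers` via the
# model-plus-loss path (the tests above exercise the LayerCollection paths);
# the helper name and hyperparameters are illustrative, not recommendations.
def _register_layers_usage_sketch(model):
  opt = optimizers.Kfac(learning_rate=0.01, damping=0.001)
  opt.register_layers(model=model, loss='categorical_crossentropy')
  return opt
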
def testValidMSE(self, loss, model_builder):
  """Ensures variations of MSE losses and output variables work.

  Args:
    loss: A tf.keras.losses function (in serialized form or as an actual
      reference).
    model_builder: Function that returns a Keras model.
  """
  model = model_builder()
  lc = utils.get_layer_collection(model, loss)
  self.assertIsInstance(lc.losses[0],
                        loss_functions.NormalMeanNegativeLogProbLoss)
  self.assertEqual(lc.losses[0].params, model.layers[-1].output)

def testValidLogitLossFunctionsMLP(self, loss, kfac_loss):
  """Ensures the correct tensorflow_kfac loss function and variable for an MLP.

  Args:
    loss: A losses function (in serialized form or as an actual reference).
    kfac_loss: tensorflow_kfac.python.ops loss function.
  """
  with tf.Graph().as_default():
    model = _mlp()
    lc = utils.get_layer_collection(model, loss)
    self.assertIsInstance(lc.losses[0], kfac_loss)
    self.assertEqual(lc.losses[0].params, model.layers[-1].output)

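# A hypothetical sketch of the `_mlp` helper used here and in
# testInvalidLossFunctions and testSeed below. The tests only assume a small
# feed-forward model whose final layer's output supplies the loss params;
# the name and layer sizes are illustrative.
def _mlp_sketch():
  return tf.keras.Sequential([
      layers.Dense(16, activation='relu', input_shape=(20,)),
      layers.Dense(10),
  ])
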
def testLayerRegistration(self, model_builder):
  model = model_builder()
  model.layers[0].trainable = False
  lc = utils.get_layer_collection(model, 'mse')
  registered = set(lc.registered_variables)
  variables = set()
  for layer in model.layers[1:]:
    if layer.trainable and layer.count_params():
      variables |= set(layer.weights)
  self.assertEqual(registered, variables)

def testNestedModels(self, fisher_approx):
  # Note: this is not a valid trainable model. It exists only to exercise
  # the ordering of dict and list fisher_approx specs, as well as the DFS
  # traversal order in utils.
  layer1 = layers.Dense(10, input_shape=(1,), name='l1')
  layer2 = layers.Dense(10, activation='relu', name='l2')
  layer3 = layers.Dense(10, activation='relu', name='l3')
  inner_model0 = tf.keras.Sequential([layer1])
  inner_model1 = tf.keras.Sequential()
  inner_model1.add(inner_model0)
  inner_model1.add(layers.Activation('relu'))
  inner_model1.add(layer2)
  inner_inp = layers.Input(shape=(1,))
  x = layer3(inner_inp)
  x = layers.Reshape(target_shape=(10, 1))(x)
  x = layers.GlobalMaxPool1D()(x)
  inner_model2 = tf.keras.Model(inputs=inner_inp, outputs=x)
  inp = layers.Input(shape=(1,))
  branch1 = inner_model1(inp)
  branch2 = inner_model2(inp)
  out = layers.Add()([branch1, branch2])
  model = tf.keras.Model(inputs=inp, outputs=out)
  lc = utils.get_layer_collection(
      model=model, loss='mse', fisher_approx=fisher_approx)
  expected_in_diag_approx = [False, True, True]
  expected_out_diag_approx = [True, False, True]
  trainable_layers = [layer1, layer2, layer3]
  for layer, in_diag, out_diag in zip(trainable_layers,
                                      expected_in_diag_approx,
                                      expected_out_diag_approx):
    self.assertIsInstance(lc.fisher_blocks[layer.weights],
                          fisher_blocks.FullyConnectedKFACBasicFB)
    self.assertEqual(
        in_diag, lc.fisher_blocks[layer.weights]._diagonal_approx_for_input)
    self.assertEqual(
        out_diag, lc.fisher_blocks[layer.weights]._diagonal_approx_for_output)

def testNormalizationLayers(self, has_shift):
  model = tf.keras.Sequential([
      layers.Conv2D(13, 5, input_shape=(28, 28, 3)),
      layers.BatchNormalization(center=has_shift, name='bn'),
      layers.Conv2D(23, 3),
      layers.LayerNormalization(center=has_shift),
      layers.GlobalMaxPool2D(),
  ])
  # fisher_approx keys may be layer classes or layer names.
  fisher_approx = {layers.LayerNormalization: 'full', 'bn': 'diagonal'}
  lc = utils.get_layer_collection(model, 'mse', fisher_approx=fisher_approx)
  bn_weights = model.layers[1].trainable_weights
  ln_weights = model.layers[3].trainable_weights
  if not has_shift:
    bn_weights, ln_weights = bn_weights[0], ln_weights[0]
  bn_block = lc.fisher_blocks[bn_weights]
  ln_block = lc.fisher_blocks[ln_weights]
  self.assertIsInstance(bn_block, fisher_blocks.ScaleAndShiftDiagonalFB)
  self.assertIsInstance(ln_block, fisher_blocks.ScaleAndShiftFullFB)
  self.assertEqual(bn_block._has_shift, has_shift)
  self.assertEqual(ln_block._has_shift, has_shift)

def testFisherApproxErrors(self, fisher_approx):
  with self.assertRaisesRegex(ValueError, '.*fisher_approx.*'):
    utils.get_layer_collection(_cnn(), 'mse', fisher_approx=fisher_approx)

def testInvalidLossFunctions(self, loss):
  with self.assertRaisesRegex(ValueError, '.*loss function:.*'):
    model = _mlp()
    utils.get_layer_collection(model, loss)

def testLossErrors(self, loss):
  with self.assertRaisesRegex(ValueError, '.*loss dict.*'):
    inputs, (out1, out2) = _two_loss_model()
    model = tf.keras.Model(inputs=inputs, outputs=[out1, out2])
    utils.get_layer_collection(model, loss)

def testLossWeightErrors(self, loss_weights):
  with self.assertRaisesRegex(ValueError, '.*loss_weights.*'):
    inputs, (out1, out2) = _two_loss_model()
    model = tf.keras.Model(inputs=inputs, outputs=[out1, out2])
    loss = ['binary_crossentropy', 'categorical_crossentropy']
    utils.get_layer_collection(model, loss, loss_weights=loss_weights)

def testInvalidCNNLayers(self, layer):
  with self.assertRaises(ValueError):
    model = tf.keras.Sequential([layers.Input(shape=(28, 28, 3)), layer])
    utils.get_layer_collection(model, 'mse')

def testSeed(self):
  lc = utils.get_layer_collection(model=_mlp(), loss='mse', seed=4321)
  self.assertEqual(lc._loss_dict['squared_error_loss'][0]._default_seed, 4321)