Example #1
 def testErrorWithLayerNormNoScale(self):
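     # LayerNormalization with scale=False should be rejected with a ValueError
     # when building the layer collection.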
     model = tf.keras.Sequential([
         layers.Conv2D(13, 5, input_shape=(28, 28, 3)),
         layers.LayerNormalization(scale=False),
         layers.GlobalMaxPool2D(),
     ])
     with self.assertRaisesRegex(ValueError, '.*scale=False.*'):
         utils.get_layer_collection(model, 'binary_crossentropy')
Example #2
    def testMultiOutputNestedModelFails(self):
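        # Wrapping a multi-output model inside another model should raise,
        # since nested models with multiple outputs are unsupported.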
        inp = tf.keras.Input(shape=(1, ))
        out1 = layers.Dense(1)(inp)
        out2 = layers.Dense(1)(inp)
        model = tf.keras.Model(inputs=inp, outputs=[out1, out2])

        inp2 = tf.keras.Input(shape=(1, ))
        out = model(inp2)
        model = tf.keras.Model(inputs=inp2, outputs=out)

        with self.assertRaisesRegex(
                ValueError,
                'Nested models with multiple outputs are unsupported.'):
            utils.get_layer_collection(model, loss=['mse', 'mse'])
Example #3
    def testFisherApproxLayerNames(self, fisher_approx):
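        # fisher_approx (parameterized) maps layer names to approximations;
        # each trainable layer should get the expected input/output diagonal
        # approximation flags asserted below.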
        model = tf.keras.Sequential([
            layers.Dense(10, input_shape=(20, ), name='l1'),
            layers.Activation('relu'),
            layers.Dense(13, activation='relu', name='l2'),
            layers.Dense(23, trainable=False),
            layers.Dense(17, name='l3'),
            layers.Activation('relu'),
            layers.Dense(3, name='l4')
        ])
        lc = utils.get_layer_collection(model,
                                        'mse',
                                        fisher_approx=fisher_approx)
        trainable_layers = [model.layers[i] for i in [0, 2, 4, 6]]
        expected_in_diag_approx = [False, True, False, True]
        expected_out_diag_approx = [False, False, True, True]

        for layer, in_diag, out_diag in zip(trainable_layers,
                                            expected_in_diag_approx,
                                            expected_out_diag_approx):
            self.assertEqual(
                in_diag,
                lc.fisher_blocks[layer.weights]._diagonal_approx_for_input)
            self.assertEqual(
                out_diag,
                lc.fisher_blocks[layer.weights]._diagonal_approx_for_output)
Example #4
    def testModelAsCallable(self):
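        # Calling the model as a callable (instead of compiling it) should
        # still let get_layer_collection register each conv layer's latest
        # inbound tensors in its fisher block.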
        model = tf.keras.Sequential([
            layers.Conv2D(13, 5),
            layers.BatchNormalization(name='bn', fused=False),
            layers.Conv2D(23, 3),
            layers.LayerNormalization(),
            layers.GlobalMaxPool2D(),
        ])
        inp = tf.random_normal((10, 28, 28, 3))
        inp = tf.keras.Input(tensor=inp)
        inp2 = tf.random_normal((10, 28, 28, 3))
        inp2 = tf.keras.Input(tensor=inp2)

        fisher_approx = {layers.LayerNormalization: 'full', 'bn': 'diagonal'}
        _ = model(inp)
        # With multiple calls, the latest call should be registered.
        _ = model(inp2)
        lc = utils.get_layer_collection(model,
                                        'mse',
                                        fisher_approx=fisher_approx)

        for i in (0, 2):
            conv_block = lc.fisher_blocks[model.layers[i].trainable_weights]
            conv_inp = model.layers[i].inbound_nodes[-1].input_tensors
            conv_out = model.layers[i].inbound_nodes[-1].output_tensors
            self.assertEqual(conv_inp, conv_block._inputs[0])
            self.assertEqual(conv_out, conv_block._outputs[0])
Example #5
    def testMultipleLoss(self, loss, loss_weights):
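        # Each of the two outputs should get its own KFAC loss function, with
        # the loss_weights 0.1 and 0.9 recorded as the loss coefficients.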
        inputs, (out1, out2) = _two_loss_model()
        model = tf.keras.Model(inputs=inputs, outputs=[out1, out2])
        lc = utils.get_layer_collection(model, loss, loss_weights=loss_weights)

        self.assertLen(lc.loss_coeffs.keys(), 2)
        self.assertLen(lc.loss_colocation_ops.keys(), 2)

        l1 = lc._loss_dict['sigmoid_cross_entropy_loss']
        l2 = lc._loss_dict['sparse_softmax_cross_entropy_loss']

        self.assertLen(l1, 1)
        self.assertLen(l2, 1)

        l1, l2 = l1[0], l2[0]

        self.assertIsInstance(l1,
                              loss_functions.MultiBernoulliNegativeLogProbLoss)
        self.assertIsInstance(
            l2, loss_functions.CategoricalLogitsNegativeLogProbLoss)
        self.assertEqual(lc.loss_coeffs[l1], 0.1)
        self.assertEqual(lc.loss_coeffs[l2], 0.9)
        self.assertEqual(lc.loss_colocation_ops[l1], out1)
        self.assertEqual(lc.loss_colocation_ops[l2], out2)
Example #6
 def testInstantiationWithLayerCollection(self):
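   # A prebuilt LayerCollection can be passed straight to the Kfac constructor.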
   model = _simple_mlp()
   lc = utils.get_layer_collection(model, 'mse')
   opt = optimizers.Kfac(
       learning_rate=0.1, damping=0.2, layer_collection=lc)
   model.compile(optimizer=opt, loss='mse')
   opt.get_updates(model.total_loss, model.trainable_weights)
Example #7
 def testRegisterLayersWithLayerCollection(self):
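   # register_layers should also accept a prebuilt LayerCollection directly.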
   model, loss = _mnist_model(), 'categorical_crossentropy'
   lc = utils.get_layer_collection(model, loss)
   opt = optimizers.Kfac(learning_rate=0.01, damping=0.001)
   opt.register_layers(layer_collection=lc)
   model.compile(optimizer=opt, loss=loss)
   opt.get_updates(model.total_loss, model.trainable_weights)
Example #8
 def testFisherApproxLayerClass(self, fisher_approx, block_types):
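     # Each parameterized fisher_approx should yield the matching fisher block
     # type for every trainable layer.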
     model = _cnn()
     lc = utils.get_layer_collection(model,
                                     'mse',
                                     fisher_approx=fisher_approx)
     trainable_layers = [model.layers[0], model.layers[2]]
     for layer, block_type in zip(trainable_layers, block_types):
         self.assertIsInstance(lc.fisher_blocks[layer.weights], block_type)
Example #9
    def testMultipleLossWeights(self, loss_weights):
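        # Every parameterized loss_weights variant here should leave the loss
        # coefficient at the default of 1.0.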
        inputs, (out1, out2) = _two_loss_model()
        model = tf.keras.Model(inputs=inputs, outputs=[out1, out2])
        loss = ['binary_crossentropy', 'categorical_crossentropy']
        lc = utils.get_layer_collection(model, loss, loss_weights=loss_weights)

        l1 = lc._loss_dict['sigmoid_cross_entropy_loss'][0]
        self.assertEqual(lc.loss_coeffs[l1], 1.0)
Example #10
 def testNumBatchNormUsesNoPhase(self):
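     # With fused batch norm and no learning-phase branching, each trainable
     # batch norm weight should be registered with exactly two uses.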
     model = tf.keras.Sequential([
         layers.Conv2D(13, 5, input_shape=(28, 28, 3)),
         layers.BatchNormalization(fused=True),
         layers.GlobalMaxPool2D(),
     ])
     lc = utils.get_layer_collection(model, 'binary_crossentropy')
     for w in model.layers[1].trainable_weights:
         self.assertEqual(lc._vars_to_uses[w], 2)
Example #11
 def register_layers(self, model=None, loss=None, layer_collection=None):
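   # Build a LayerCollection from the model and loss unless one was supplied,
   # then expose its registered variables as the KFAC var_list.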
   if not layer_collection:
     if not loss and hasattr(model, 'loss'):
       loss = model.loss
     if not (model and loss):
       raise ValueError('Please provide a model with a loss, a model and loss,'
                        ' or a LayerCollection')
     layer_collection = utils.get_layer_collection(
         model, loss, **self._layer_collection_kwargs)
   self._layer_collection = layer_collection
   self._kfac_kwargs['var_list'] = layer_collection.registered_variables
Example #12
    def testValidMSE(self, loss, model_builder):
        """Ensures variations of MSE and output variables work.

        Args:
          loss: A tf.keras.losses function (in serialized form or actual
            reference).
          model_builder: Function that returns a Keras model.
        """
        model = model_builder()
        lc = utils.get_layer_collection(model, loss)
        self.assertIsInstance(lc.losses[0],
                              loss_functions.NormalMeanNegativeLogProbLoss)
        self.assertEqual(lc.losses[0].params, model.layers[-1].output)
Example #13
    def testValidLogitLossFunctionsMLP(self, loss, kfac_loss):
        """Ensures correct tensorflow_kfac loss function and variable for a MLP.

    Args:
      loss: A losses function (in serialized form or actual reference)
      kfac_loss: tensorflow_kfac.python.ops loss function.
    """
        with tf.Graph().as_default():
            model = _mlp()
            lc = utils.get_layer_collection(model, loss)
            self.assertIsInstance(lc.losses[0], kfac_loss)
            self.assertEqual(lc.losses[0].params, model.layers[-1].output)
Example #14
    def testLayerRegistration(self, model_builder):
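        # Registered variables should cover exactly the trainable layers with
        # parameters; the frozen first layer must be excluded.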
        model = model_builder()
        model.layers[0].trainable = False

        lc = utils.get_layer_collection(model, 'mse')
        registered = set(lc.registered_variables)

        variables = set()
        for layer in model.layers[1:]:
            if layer.trainable and layer.count_params():
                variables |= set(layer.weights)

        self.assertEqual(registered, variables)
Example #15
    def testNestedModels(self, fisher_approx):
        # Note: this is not a valid trainable model; it exists only to exercise
        # the ordering of the fisher_approx dict/list and the DFS traversal
        # order in utils.
        layer1 = layers.Dense(10, input_shape=(1, ), name='l1')
        layer2 = layers.Dense(10, activation='relu', name='l2')
        layer3 = layers.Dense(10, activation='relu', name='l3')

        inner_model0 = tf.keras.Sequential([layer1])

        inner_model1 = tf.keras.Sequential()
        inner_model1.add(inner_model0)
        inner_model1.add(layers.Activation('relu'))
        inner_model1.add(layer2)

        inner_inp = layers.Input(shape=(1, ))
        x = layer3(inner_inp)
        x = layers.Reshape(target_shape=(10, 1))(x)
        x = layers.GlobalMaxPool1D()(x)
        inner_model2 = tf.keras.Model(inputs=inner_inp, outputs=x)

        inp = layers.Input(shape=(1, ))
        branch1 = inner_model1(inp)
        branch2 = inner_model2(inp)
        out = layers.Add()([branch1, branch2])
        model = tf.keras.Model(inputs=inp, outputs=out)

        lc = utils.get_layer_collection(model=model,
                                        loss='mse',
                                        fisher_approx=fisher_approx)

        expected_in_diag_approx = [False, True, True]
        expected_out_diag_approx = [True, False, True]
        trainable_layers = [layer1, layer2, layer3]
        for layer, in_diag, out_diag in zip(trainable_layers,
                                            expected_in_diag_approx,
                                            expected_out_diag_approx):
            self.assertIsInstance(lc.fisher_blocks[layer.weights],
                                  fisher_blocks.FullyConnectedKFACBasicFB)
            self.assertEqual(
                in_diag,
                lc.fisher_blocks[layer.weights]._diagonal_approx_for_input)
            self.assertEqual(
                out_diag,
                lc.fisher_blocks[layer.weights]._diagonal_approx_for_output)
Example #16
 def testNormalizationLayers(self, has_shift):
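     # fisher_approx can be keyed by layer class or by layer name: batch norm
     # ('bn') gets a diagonal scale-and-shift block, layer norm a full one.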
     model = tf.keras.Sequential([
         layers.Conv2D(13, 5, input_shape=(28, 28, 3)),
         layers.BatchNormalization(center=has_shift, name='bn'),
         layers.Conv2D(23, 3),
         layers.LayerNormalization(center=has_shift),
         layers.GlobalMaxPool2D(),
     ])
     fisher_approx = {layers.LayerNormalization: 'full', 'bn': 'diagonal'}
     lc = utils.get_layer_collection(model,
                                     'mse',
                                     fisher_approx=fisher_approx)
     bn_weights = model.layers[1].trainable_weights
     ln_weights = model.layers[3].trainable_weights
     if not has_shift:
         bn_weights, ln_weights = bn_weights[0], ln_weights[0]
     bn_block = lc.fisher_blocks[bn_weights]
     ln_block = lc.fisher_blocks[ln_weights]
     self.assertIsInstance(bn_block, fisher_blocks.ScaleAndShiftDiagonalFB)
     self.assertIsInstance(ln_block, fisher_blocks.ScaleAndShiftFullFB)
     self.assertEqual(bn_block._has_shift, has_shift)
     self.assertEqual(ln_block._has_shift, has_shift)
Example #17
 def testFisherApproxErrors(self, fisher_approx):
     with self.assertRaisesRegex(ValueError, '.*fisher_approx.*'):
         utils.get_layer_collection(_cnn(),
                                    'mse',
                                    fisher_approx=fisher_approx)
Example #18
 def testInvalidLossFunctions(self, loss):
     with self.assertRaisesRegex(ValueError, '.*loss function:.*'):
         model = _mlp()
         utils.get_layer_collection(model, loss)
Example #19
 def testLossErrors(self, loss):
     with self.assertRaisesRegex(ValueError, '.*loss dict.*'):
         inputs, (out1, out2) = _two_loss_model()
         model = tf.keras.Model(inputs=inputs, outputs=[out1, out2])
         utils.get_layer_collection(model, loss)
Example #20
 def testLossWeightErrors(self, loss_weights):
     with self.assertRaisesRegex(ValueError, '.*loss_weights.*'):
         inputs, (out1, out2) = _two_loss_model()
         model = tf.keras.Model(inputs=inputs, outputs=[out1, out2])
         loss = ['binary_crossentropy', 'categorical_crossentropy']
         utils.get_layer_collection(model, loss, loss_weights=loss_weights)
Example #21
 def testInvalidCNNLayers(self, layer):
     with self.assertRaises(ValueError):
         model = tf.keras.Sequential(
             [layers.Input(shape=(28, 28, 3)), layer])
         utils.get_layer_collection(model, 'mse')
Example #22
 def testSeed(self):
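     # The seed passed to get_layer_collection should become the loss
     # function's default seed.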
     lc = utils.get_layer_collection(model=_mlp(), loss='mse', seed=4321)
     self.assertEqual(lc._loss_dict['squared_error_loss'][0]._default_seed,
                      4321)