Пример #1
0
def get_model(model_arch_name=gin.REQUIRED,
              dataset_name=gin.REQUIRED,
              load_path=None,
              prepare_for_pruning=False):
    """Creates or loads the model and returns it.

    If the model does not match the saved version, usually no error or
    warning is produced, so be careful and CHECK YOUR VARIABLES.

    Args:
      model_arch_name: str, name of an architecture definition in
        model_defs.py.
      dataset_name: str, must be a key of INPUT_SHAPES (e.g. 'cifar10' or
        'imagenet').
      load_path: str or None, checkpoint name/path to be loaded. When None, a
        freshly initialized model is returned.
      prepare_for_pruning: bool, if True (and `load_path` is given) the loaded
        weights are copied into a model built with TaylorScorer layers and
        with layers wrapped in MaskedLayer. Has no effect when `load_path` is
        None.

    Returns:
      generic_convnet.GenericConvnet, initialized or loaded model.

    Raises:
      ValueError: when `dataset_name` or `model_arch_name` is not recognized,
        or when the pruning-ready model does not have the same number of
        trainable variables as the loaded model.
      Errors raised by `tf.train.Checkpoint.restore` (e.g. when no checkpoint
      is found at `load_path`) are propagated unchanged; this function does
      not raise IOError itself.
    """
    if dataset_name not in INPUT_SHAPES:
        raise ValueError('Dataset_name: %s is not one of %s' %
                         (dataset_name, list(INPUT_SHAPES.keys())))
    if not hasattr(model_defs, model_arch_name):
        raise ValueError('Model name: %s...not in model_defs.py' %
                         model_arch_name)
    n_classes = data.N_CLASSES_BY_DATASET[dataset_name]
    # Append an output layer sized for the dataset's number of classes.
    model_arch = getattr(model_defs, model_arch_name) + [['O', n_classes]]
    model = generic_convnet.GenericConvnet(model_arch=model_arch,
                                           name=model_arch_name)
    dummy_var = tf.zeros(INPUT_SHAPES[dataset_name])
    # One forward pass on a dummy input builds (initializes) the variables.
    model(dummy_var)
    if load_path is not None:
        checkpoint = tf.train.Checkpoint(model=model)
        checkpoint.restore(load_path)
        if prepare_for_pruning:
            old_model = model
            model = generic_convnet.GenericConvnet(model_arch=model_arch,
                                                   name=model_arch_name,
                                                   use_taylor_scorer=True,
                                                   use_masked_layers=True)
            model(dummy_var)
            old_vars = old_model.trainable_variables
            new_vars = model.trainable_variables
            # zip() would silently stop at the shorter list, leaving some of
            # the new model's weights at their fresh initialization.
            if len(old_vars) != len(new_vars):
                raise ValueError(
                    'Cannot copy weights into the pruning model: %d trainable '
                    'variables in the loaded model vs %d in the new one.' %
                    (len(old_vars), len(new_vars)))
            # Copy positionally. Both models are built from the same
            # model_arch, so their trainable variables are assumed to be
            # created in the same order.
            for src, dst in zip(old_vars, new_vars):
                dst.assign(src)
    return model
Пример #2
0
 def testValidModelArchitectures(self):
     """Well-formed 'C' and 'MP' specs are accepted by the constructor."""
     # The filter shape of a 'C' entry (its 3rd element) may be either a
     # single int or a list of two ints.
     for conv_spec in (['C', 16, [3, 5], {}], ['C', 16, 5, {}]):
         generic_convnet.GenericConvnet(model_arch=[conv_spec])
     # The trailing dict carries extra keyword arguments; with 'same' padding
     # the output shape must equal the input shape.
     padded_net = generic_convnet.GenericConvnet(
         model_arch=[['C', 3, 5, {'padding': 'same'}]])
     inputs = tf.random.uniform((5, 32, 32, 3))
     outputs = padded_net(inputs)
     self.assertAllEqual(inputs.shape, outputs.shape)
     # 'MP' entries accept an int or a two-int list in either position.
     for pool_spec in (['MP', 2, [3, 5]], ['MP', [2, 3], 5]):
         generic_convnet.GenericConvnet(model_arch=[pool_spec])
Пример #3
0
 def testPropagateBias(self):
     """propagate_bias folds constant activations into the next layer's bias."""
     for masked in [True, False]:
         model = generic_convnet.GenericConvnet(
             name='test', use_masked_layers=masked)
         ones_input = tf.ones((32, 28, 28, 3), dtype=tf.float32)
         # A forward pass creates the model's variables.
         model(ones_input)
         # With masked layers, the wrapped layer is reached via `.layer`.
         if masked:
             conv1, conv2 = model.conv_1.layer, model.conv_2.layer
         else:
             conv1, conv2 = model.conv_1, model.conv_2
         filter_count = conv1.filters
         conv1_out = model.conv_1(ones_input)
         # Per-channel constants: every channel is 0 except the last one.
         one_hot_last = [0] * (filter_count - 1) + [1]
         mean_values = tf.cast(
             tf.broadcast_to(one_hot_last, conv1_out.shape), tf.float32)
         # The bias (weights[1]) of conv_2 starts out as all zeros.
         self.assertEqual(tf.math.count_nonzero(conv2.weights[1]).numpy(), 0)
         expected = model.conv_2(tf.keras.activations.relu(mean_values))
         # The input is spatially constant per channel, so neighbouring output
         # positions must hold identical channel vectors.
         self.assertAllEqual(expected[0, 0, 0, :], expected[0, 0, 1, :])
         # Any single position therefore equals the mean output, which is
         # what the propagated bias should become.
         expected_bias = expected[0, 0, 1, :]
         model.propagate_bias('conv_1', mean_values)
         self.assertAllClose(conv2.weights[1].numpy(),
                             expected_bias.numpy(),
                             atol=1e-04)
Пример #4
0
 def testClone(self):
     """clone() copies weights into new, distinct layer objects."""
     m = generic_convnet.GenericConvnet(name='test', use_masked_layers=True)
     # Initializes the params.
     m(tf.random.uniform((4, 32, 32, 3)))
     m2 = m.clone()
     self.assertAllEqual(m2.conv_2.weights[0].numpy(),
                         m.conv_2.weights[0].numpy())
     self.assertAllEqual(m2.dense_1.weights[0].numpy(),
                         m.dense_1.weights[0].numpy())
     # The clone must own its own layer objects rather than alias the
     # originals. assertIsNot checks identity directly; assertNotEqual would
     # depend on how the layer class implements __eq__.
     self.assertIsNot(m2.conv_2, m.conv_2)
Пример #5
0
 def testDropoutInjection(self):
     """An explicit 'DO' entry yields a Dropout node even with use_dropout=False."""
     arch = [['C', 16, [3, 5], {}], ['DO', 0.5], ['F'], ['D', 16]]
     net = generic_convnet.GenericConvnet(name='test',
                                          model_arch=arch,
                                          use_dropout=False)
     net(tf.random.uniform((2, 10, 10, 3)))
     expected_chain = ['conv_1', 'conv_1_a', 'dropout_1', 'flatten_1',
                       'dense_1', 'dense_1_a']
     self.assertAllEqual(net.forward_chain, expected_chain)
     self.assertIsInstance(net.dropout_1, tf.keras.layers.Dropout)
Пример #6
0
 def testPropagateBiasErrors(self):
     """propagate_bias rejects a bad dtype, unknown names and the last layer."""
     net = generic_convnet.GenericConvnet(name='test', use_masked_layers=True)
     shape = (32, 24, 24, 3)
     # int16 mean values trip an assertion.
     with self.assertRaises(AssertionError):
         net.propagate_bias('conv_1', tf.ones(shape, dtype=tf.int16))
     # Misspelled layer name ('conv1' instead of 'conv_1').
     with self.assertRaises(ValueError):
         net.propagate_bias('conv1', tf.ones(shape, dtype=tf.float32))
     # 'output_1' has no following layer to propagate into.
     with self.assertRaises(ValueError):
         net.propagate_bias('output_1', tf.ones(shape, dtype=tf.float32))
Пример #7
0
 def testUseMeanReplacer(self):
     """use_mean_replacer inserts a MeanReplacer ('_mr') node after activations."""
     arch = [['C', 16, [3, 5], {}], ['GA'], ['D', 16]]
     net = generic_convnet.GenericConvnet(name='test',
                                          model_arch=arch,
                                          use_mean_replacer=True)
     net(tf.random.uniform((2, 10, 10, 3)))
     expected_chain = ['conv_1', 'conv_1_a', 'conv_1_mr', 'gap_1',
                       'dense_1', 'dense_1_a', 'dense_1_mr']
     self.assertAllEqual(net.forward_chain, expected_chain)
     # Only the '_mr' nodes are MeanReplacers; the base layers are unchanged.
     for base in ('conv_1', 'dense_1'):
         self.assertIsInstance(getattr(net, base + '_mr'),
                               layers.MeanReplacer)
         self.assertNotIsInstance(getattr(net, base), layers.MeanReplacer)
Пример #8
0
 def testUseTaylorScorer(self):
     """use_taylor_scorer inserts a TaylorScorer ('_ts') node after activations."""
     arch = [['C', 16, [3, 5], {}], ['F'], ['D', 16]]
     net = generic_convnet.GenericConvnet(name='test',
                                          model_arch=arch,
                                          use_taylor_scorer=True)
     net(tf.random.uniform((2, 10, 10, 3)))
     expected_chain = ['conv_1', 'conv_1_a', 'conv_1_ts', 'flatten_1',
                       'dense_1', 'dense_1_a', 'dense_1_ts']
     self.assertAllEqual(net.forward_chain, expected_chain)
     # Only the '_ts' nodes are TaylorScorers; the base layers are unchanged.
     for base in ('conv_1', 'dense_1'):
         self.assertIsInstance(getattr(net, base + '_ts'),
                               layers.TaylorScorer)
         self.assertNotIsInstance(getattr(net, base), layers.TaylorScorer)
Пример #9
0
 def testGetAllLayerKeys(self):
     """get_layer_keys selects attribute names by layer type and name filter."""
     net = generic_convnet.GenericConvnet(name='test',
                                          use_masked_layers=True,
                                          use_batchnorm=True)
     self.assertSetEqual(set(net.get_layer_keys(layers.MaskedLayer)),
                         {'conv_1', 'conv_2', 'dense_1'})
     batchnorm_type = tf.keras.layers.BatchNormalization
     self.assertSetEqual(set(net.get_layer_keys(batchnorm_type)),
                         {'conv_1_bn', 'conv_2_bn', 'dense_1_bn'})

     def _not_dense(layer_name):
         return not layer_name.startswith('dense')

     # name_filter keeps only the names for which it returns True.
     self.assertSetEqual(
         set(net.get_layer_keys(batchnorm_type, name_filter=_not_dense)),
         {'conv_1_bn', 'conv_2_bn'})
Пример #10
0
 def testDefaultConstructor(self):
     """The default architecture builds the expected layers and activations."""
     net = generic_convnet.GenericConvnet(name='test')
     # net.forward_chain is expected to look like:
     # ['conv_1', 'conv_1_a', 'maxpool_1', 'conv_2', 'conv_2_a', 'maxpool_2',
     #  'flatten_1', 'dense_1', 'dense_1_a', 'output_1']
     self.assertEqual(len(net.forward_chain), 10)
     self.assertEqual(net.name, 'test')
     relu = tf.keras.activations.relu
     # Two conv -> relu -> maxpool stages.
     for stage in (1, 2):
         self.assertIsInstance(getattr(net, 'conv_%d' % stage),
                               tf.keras.layers.Conv2D)
         self.assertEqual(getattr(net, 'conv_%d_a' % stage), relu)
         self.assertIsInstance(getattr(net, 'maxpool_%d' % stage),
                               tf.keras.layers.MaxPool2D)
     # Classifier head.
     self.assertIsInstance(net.flatten_1, tf.keras.layers.Flatten)
     self.assertIsInstance(net.dense_1, tf.keras.layers.Dense)
     self.assertEqual(net.dense_1_a, relu)
     self.assertIsInstance(net.output_1, tf.keras.layers.Dense)
Пример #11
0
 def testReturnNodes(self):
     """return_nodes exposes intermediate activations alongside the output."""
     arch = [['C', 16, [3, 5], {}], ['F'], ['D', 16]]
     net = generic_convnet.GenericConvnet(name='test',
                                          model_arch=arch,
                                          use_taylor_scorer=True)
     # Forward chain: conv_1, conv_1_a, conv_1_ts, flatten_1, dense_1,
     # dense_1_a, dense_1_ts.
     x = tf.random.uniform((2, 10, 10, 3))
     plain_out = net(x)
     requested = {'conv_1_a', 'dense_1_a'}
     out, captured = net(x, return_nodes=requested)
     self.assertAllEqual(plain_out, out)
     # Each TaylorScorer's saved 'mean' must match the mean of the activation
     # feeding it, taken over every axis except the last.
     self.assertAllClose(tf.reduce_mean(captured['conv_1_a'], axis=[0, 1, 2]),
                         net.conv_1_ts.get_saved_values('mean'),
                         atol=1e-4)
     self.assertAllClose(tf.reduce_mean(captured['dense_1_a'], axis=0),
                         net.dense_1_ts.get_saved_values('mean'),
                         atol=1e-4)
Пример #12
0
 def testMaskedLayer(self):
     """use_masked_layers wraps conv and dense layers but not flatten."""
     net = generic_convnet.GenericConvnet(name='test', use_masked_layers=True)
     for wrapped_name in ('conv_1', 'conv_2', 'dense_1'):
         self.assertIsInstance(getattr(net, wrapped_name), layers.MaskedLayer)
     self.assertNotIsInstance(net.flatten_1, layers.MaskedLayer)
Пример #13
0
 def testInvalidModelArchitectures(self):
     """Malformed model_arch specs are rejected with AssertionError."""
     with self.assertRaises(AssertionError):
         # Needs to be non-zero length iterable of collections.MutableSequence.
         generic_convnet.GenericConvnet(model_arch='test')
     with self.assertRaises(AssertionError):
         # A 'C' filter-shape list must have two elements, not three.
         generic_convnet.GenericConvnet(
             model_arch=[['C', 16, [5, 2, 5], {}]])
     with self.assertRaises(AssertionError):
         # A 'C' filter-shape list must contain ints only.
         generic_convnet.GenericConvnet(
             model_arch=[['C', 16, [5, '5'], {}]])
     with self.assertRaises(AssertionError):
         # The trailing kwargs dict (even an empty one) is missing.
         generic_convnet.GenericConvnet(model_arch=[['C', 16, 4]])
     with self.assertRaises(AssertionError):
         # 'MP' entries do not accept a kwargs dict.
         generic_convnet.GenericConvnet(model_arch=[['MP', 16, [3, 5], {}]])
     with self.assertRaises(AssertionError):
         # Dense only requires one int (out_units).
         generic_convnet.GenericConvnet(model_arch=[['D', 16, 4]])
     with self.assertRaises(AssertionError):
         # Output layer (the letter 'O', not the digit '0'): the second
         # element should be n_units; activations are given with a flag.
         generic_convnet.GenericConvnet(model_arch=[['O', 'norelu']])
     with self.assertRaises(AssertionError):
         # Flatten should not have other elements in the list.
         generic_convnet.GenericConvnet(model_arch=[['F', 16]])
     with self.assertRaises(AssertionError):
         # GlobalAveragePooling should not have other elements in the list.
         generic_convnet.GenericConvnet(model_arch=[['GA', 16]])
     with self.assertRaises(AssertionError):
         # Second element should be n_units; activations are given with a flag.
         generic_convnet.GenericConvnet(model_arch=[['D', 'relu']])
     with self.assertRaises(AssertionError):
         # A dropout rate of 0.0 is rejected.
         generic_convnet.GenericConvnet(model_arch=[['DO', 0.0]])
     with self.assertRaises(AssertionError):
         # The dropout rate must be a number, not a string.
         generic_convnet.GenericConvnet(model_arch=[['DO', '0.5']])