def build_wrapped_model(self, model, sparse=False):
    """Build a ``Model`` whose inputs feed a wrapper created around ``model``.

    Args:
        model: The model handed to ``create_model_wrapper``.
        sparse: When truthy, the wrapper is created with ``sparse=True`` and
            ``num_sparse_params=6``; otherwise no extra arguments are passed.

    Returns:
        A ``Model`` taking ``[params, trainable_params, x]`` as inputs and
        producing the wrapper's output as its single output.
    """
    wrapper_kwargs = {"sparse": True, "num_sparse_params": 6} if sparse else {}
    wrapper = create_model_wrapper(model, **wrapper_kwargs)

    # Input layers are created in the same order the wrapper consumes them.
    flat_params = Input(shape=(wrapper.num_params,))
    flat_trainable = Input(shape=(wrapper.num_trainable_params,))
    data = Input(shape=K.int_shape(model.inputs[0]))

    wrapped_inputs = [flat_params, flat_trainable, data]
    output = wrapper(list(wrapped_inputs))
    return Model(inputs=wrapped_inputs, outputs=[output])
def testGetAllWeights(self):
    # get_model_weights(model) is expected to return one flat vector: the
    # arrays passed to set_weights, flattened and concatenated in order.
    model = self.build_model()
    model.set_weights((np.ones(2), np.zeros(2), np.eye(2), np.zeros(2), np.ones(2)))

    # NOTE(review): the wrapper itself is never inspected in this test; the
    # call is kept in case wrapping touches the model -- confirm whether it
    # can be removed.
    create_model_wrapper(model)

    expected_weights = np.array([1., 1., 0., 0., 1., 0., 0., 1., 0., 0., 1., 1.])
    # assert_allclose takes (actual, desired): the computed value goes first
    # so the tolerance is relative to the expected vector and failure output
    # labels the two sides correctly.
    np.testing.assert_allclose(get_model_weights(model), expected_weights)
def testGetParamGroupsWithTrainableParameters(self):
    # Only layers whose name starts with "dense" are left trainable; the
    # wrapper is then expected to report the single group (0, 6).
    model = self.build_model()
    for layer in model.layers:
        layer.trainable = layer.name.startswith("dense")

    actual_groups = list(create_model_wrapper(model).param_groups())
    self.assertEqual([(0, 6)], actual_groups)
def build_wrapped_model(self, model, batch_size=1):
    """Build a ``Model`` whose inputs feed a wrapper created around ``model``.

    Args:
        model: The model handed to ``create_model_wrapper``.
        batch_size: Forwarded to ``create_model_wrapper`` as ``batch_size``.

    Returns:
        A ``Model`` taking ``[params, trainable_params, x]`` as inputs and
        producing the wrapper's output as its single output.
    """
    wrapper = create_model_wrapper(model, batch_size=batch_size)

    # Input layers are created in the same order the wrapper consumes them.
    flat_params = Input(shape=(wrapper.num_params,))
    flat_trainable = Input(shape=(wrapper.num_trainable_params,))
    data_shape = K.int_shape(model.inputs[0])
    data = Input(shape=data_shape)

    output = wrapper([flat_params, flat_trainable, data])
    return Model(inputs=[flat_params, flat_trainable, data], outputs=[output])
def build_batchnorm_model_wrapper(self, batchnorm_training=False, batch_size=1):
    """Build a one-layer BatchNormalization model and a wrapped version of it.

    Args:
        batchnorm_training: Passed to the wrapper call as ``training``.
        batch_size: Forwarded to ``create_model_wrapper`` as ``batch_size``.

    Returns:
        A ``(model, wrapped_model)`` tuple; both are compiled with
        ``loss='mse'`` and ``optimizer='SGD'``.
    """
    compile_kwargs = dict(loss='mse', optimizer='SGD')

    base_model = Sequential()
    base_model.add(BatchNormalization(input_shape=(None, 2)))
    base_model.compile(**compile_kwargs)

    wrapper = create_model_wrapper(base_model, batch_size=batch_size)
    flat_params = Input(shape=(wrapper.num_params,))
    data = Input(shape=K.int_shape(base_model.inputs[0]))

    # The same tensor is supplied for both parameter slots of the wrapper.
    # NOTE(review): other builders pass a separate trainable-params input
    # here -- confirm that reusing `params` is intended.
    output = wrapper([flat_params, flat_params, data], training=batchnorm_training)

    wrapped = Model(inputs=[flat_params, data], outputs=[output])
    wrapped.compile(**compile_kwargs)
    return base_model, wrapped
def testGetParamGroups(self):
    # With the model left as built, the wrapper is expected to report three
    # consecutive (start, end) groups covering indices 0 through 12.
    model = self.build_model()
    actual_groups = list(create_model_wrapper(model).param_groups())
    self.assertEqual([(0, 4), (4, 10), (10, 12)], actual_groups)