def build_wrapped_model(self, model, sparse=False):
    """Wrap ``model`` in a functional Model fed by explicit parameter inputs.

    Args:
      model: the Keras model to wrap.
      sparse: when truthy, the wrapper is built with ``sparse=True`` and
        ``num_sparse_params=6``; otherwise create_model_wrapper's defaults
        are used.

    Returns:
      An uncompiled ``Model`` whose inputs are [all params, trainable params,
      model input] and whose single output is the wrapper's output.
    """
    wrapper_options = dict(sparse=True, num_sparse_params=6) if sparse else {}
    wrapper = create_model_wrapper(model, **wrapper_options)

    all_params_in = Input(shape=(wrapper.num_params,))
    trainable_params_in = Input(shape=(wrapper.num_trainable_params,))
    # NOTE(review): the full int_shape of the model input (presumably including
    # its batch axis) is used as the per-sample shape here -- confirm intended.
    model_input_in = Input(shape=K.int_shape(model.inputs[0]))
    inputs = [all_params_in, trainable_params_in, model_input_in]
    wrapped_output = wrapper(inputs)

    return Model(inputs=inputs, outputs=[wrapped_output])
  def testGetAllWeights(self):
    """get_model_weights returns the model's weight arrays flattened and concatenated in order."""
    model = self.build_model()
    model.set_weights((np.ones(2), np.zeros(2), np.eye(2), np.zeros(2), np.ones(2)))
    # A wrapper is constructed between set_weights and the read below, so this
    # test also fails if building a wrapper alters the model's weights.
    create_model_wrapper(model)

    expected_weights = np.array([1., 1., 0., 0., 1., 0., 0., 1., 0., 0., 1., 1.])
    # assert_allclose takes (actual, desired); rtol is relative to `desired`.
    np.testing.assert_allclose(get_model_weights(model), expected_weights)
  def testGetParamGroupsWithTrainableParameters(self):
    """With only "dense"-named layers trainable, param_groups() yields a single (0, 6) group."""
    model = self.build_model()
    # Freeze every layer whose name does not start with "dense".
    for layer in model.layers:
      layer.trainable = layer.name.startswith("dense")

    groups = list(create_model_wrapper(model).param_groups())
    self.assertEqual([(0, 6)], groups)
  def build_wrapped_model(self, model, batch_size=1):
    """Wrap ``model`` in a functional Model fed by explicit parameter inputs.

    Args:
      model: the Keras model to wrap.
      batch_size: forwarded unchanged to ``create_model_wrapper``.

    Returns:
      An uncompiled ``Model`` whose inputs are [all params, trainable params,
      model input] and whose single output is the wrapper's output.
    """
    wrapper = create_model_wrapper(model, batch_size=batch_size)
    all_params_in = Input(shape=(wrapper.num_params,))
    trainable_params_in = Input(shape=(wrapper.num_trainable_params,))
    # NOTE(review): the full int_shape of the model input (presumably including
    # its batch axis) is used as the per-sample shape here -- confirm intended.
    model_input_in = Input(shape=K.int_shape(model.inputs[0]))

    inputs = [all_params_in, trainable_params_in, model_input_in]
    wrapped_output = wrapper(inputs)
    return Model(inputs=inputs, outputs=[wrapped_output])
  def build_batchnorm_model_wrapper(self, batchnorm_training=False, batch_size=1):
    """Build a BatchNormalization-only model plus a compiled wrapper around it.

    Args:
      batchnorm_training: passed as ``training=`` when calling the wrapper.
      batch_size: forwarded unchanged to ``create_model_wrapper``.

    Returns:
      ``(model, wrapped_model)``. Both are compiled with 'mse' loss and the
      'SGD' optimizer; ``wrapped_model`` takes [params, model input].
    """
    base_model = Sequential()
    base_model.add(BatchNormalization(input_shape=(None, 2)))
    base_model.compile(loss='mse', optimizer='SGD')

    wrapper = create_model_wrapper(base_model, batch_size=batch_size)
    param_input = Input(shape=(wrapper.num_params,))
    data_input = Input(shape=K.int_shape(base_model.inputs[0]))
    # The same tensor fills both the all-params and trainable-params slots.
    # NOTE(review): presumably the wrapper accepts this for this model -- confirm.
    wrapped_output = wrapper([param_input, param_input, data_input],
                             training=batchnorm_training)

    wrapped = Model(inputs=[param_input, data_input], outputs=[wrapped_output])
    wrapped.compile(loss='mse', optimizer='SGD')
    return base_model, wrapped
  def testGetParamGroups(self):
    """param_groups() yields back-to-back (start, end) pairs covering offsets 0..12 of the fixture model."""
    wrapper = create_model_wrapper(self.build_model())
    self.assertEqual([(0, 4), (4, 10), (10, 12)], list(wrapper.param_groups()))