def test_predict_base_model(self):
  """Base model still predicts correctly after being wrapped and compiled.

  A `GraphRegularization` wrapper is built around the linear model and
  compiled. Prediction is then run on the *base* model, whose output must
  equal the weighted sum of the two input features.
  """
  base_model = build_linear_functional_model(
      input_shape=(2,), weights=np.array([1.0, -1.0]))
  features = {FEATURE_NAME: tf.constant([[5.0, 3.0]])}
  # NOTE(review): the wrapper is compiled but never used to predict;
  # presumably this checks that wrapping does not disturb the base model.
  wrapper = graph_regularization.GraphRegularization(base_model)
  wrapper.compile(optimizer=tf.keras.optimizers.SGD(0.01), loss='MSE')
  expected = [[1 * 5.0 + (-1.0) * 3.0]]
  actual = base_model.predict(x=features, steps=1, batch_size=1)
  self.assertAllEqual(expected, actual)
def _create_and_compile_graph_reg_model(model_fn, weight, max_neighbors):
  """Builds a linear model, wraps it for graph regularization, and compiles it.

  Args:
    model_fn: Callable invoked as `model_fn(input_shape, weight)` that builds
      the base linear regression model; the input shape passed is `(2,)`.
    weight: Initial value for the weights variable in the linear regressor.
    max_neighbors: The maximum number of neighbors for graph regularization.

  Returns:
    A `(base_model, graph_reg_model)` pair: the unregularized model and its
    graph-regularized wrapper, as `tf.keras.Model` instances. Only the wrapper
    is compiled here (MSE loss, SGD optimizer at `LEARNING_RATE`).
  """
  base_model = model_fn((2,), weight)
  reg_config = configs.make_graph_reg_config(
      max_neighbors=max_neighbors, multiplier=1)
  wrapped_model = graph_regularization.GraphRegularization(
      base_model, reg_config)
  sgd = tf.keras.optimizers.SGD(LEARNING_RATE)
  wrapped_model.compile(optimizer=sgd, loss='MSE')
  return base_model, wrapped_model