Esempio n. 1
0
def _make_functional_regularized_model(distance_config):
  """Builds a functional-API Keras model regularized by `PairwiseDistance`.

  A small MLP (a 16-unit ReLU `Dense` layer followed by a 3-unit `Dense` layer)
  is applied to both the example features and the example's neighbors. The
  `PairwiseDistance` output computed from the two sets of logits and the
  neighbor weights is attached to the returned model as an additional loss and
  is also reported as the mean-aggregated 'graph_loss' metric.

  Args:
    distance_config: Configuration passed through to `layers.PairwiseDistance`.

  Returns:
    A `tf.keras.Model` whose inputs are the (features, neighbors,
    neighbor_weights) tensors and whose output is the logits for the features.
  """
  num_classes = 3

  # Per example: 4 features, plus 2 neighbors that each carry an edge weight.
  features = tf.keras.Input(shape=(4,), dtype=tf.float32, name='features')
  neighbors = tf.keras.Input(shape=(2, 4), dtype=tf.float32, name='neighbors')
  neighbor_weights = tf.keras.Input(
      shape=(2, 1), dtype=tf.float32, name='neighbor_weights')
  model_inputs = (features, neighbors, neighbor_weights)

  # Unregularized base network: a single hidden layer followed by the logits.
  hidden = tf.keras.layers.Dense(16, activation='relu')(features)
  class_scores = tf.keras.layers.Dense(num_classes)(hidden)
  base_model = tf.keras.Model(features, outputs=class_scores)

  feature_logits = base_model(features)
  full_model = tf.keras.Model(inputs=model_inputs, outputs=feature_logits)

  # Graph regularization: the same base network scores the neighbors, and the
  # resulting distance term is tracked both as a loss and as a metric.
  pairwise_distance = layers.PairwiseDistance(distance_config)
  neighbor_logits = base_model(neighbors)
  graph_loss = pairwise_distance(
      sources=feature_logits,
      targets=neighbor_logits,
      weights=neighbor_weights)
  full_model.add_loss(graph_loss)
  full_model.add_metric(graph_loss, aggregation='mean', name='graph_loss')
  return full_model
Esempio n. 2
0
 def __init__(self, distance_config, **kwargs):
   """Sets up the pairwise regularizer and the underlying two-layer network.

   Args:
     distance_config: Configuration passed to `layers.PairwiseDistance`.
     **kwargs: Forwarded unchanged to the parent class initializer.
   """
   super(_PairwiseRegularizedModel, self).__init__(**kwargs)

   # The regularizer layer is explicitly named 'graph_loss'.
   graph_regularizer = layers.PairwiseDistance(
       distance_config, name='graph_loss')
   self._regularizer = graph_regularizer

   # Base network: 16-unit ReLU hidden layer followed by a 3-unit output layer.
   hidden_layer = tf.keras.layers.Dense(16, activation='relu')
   output_layer = tf.keras.layers.Dense(3)
   self._unregularized_model = tf.keras.Sequential(
       [hidden_layer, output_layer])
    def __init__(self, base_model, graph_reg_config=None):
        """Initializes the graph-regularized wrapper around `base_model`.

        Args:
          base_model: Unregularized model to which the loss term resulting
            from graph regularization will be added.
          graph_reg_config: Instance of `GraphRegConfig` that contains the
            configuration for graph regularization. When `None`, a default
            `nsl_configs.GraphRegConfig()` is created.
        """
        super(GraphRegularization, self).__init__(name='GraphRegularization')
        self.base_model = base_model

        # Fall back to the default configuration when none was supplied.
        if graph_reg_config is None:
            graph_reg_config = nsl_configs.GraphRegConfig()
        self.graph_reg_config = graph_reg_config

        # Both helper layers are configured from the stored configuration.
        stored_config = self.graph_reg_config
        self.nbr_features_layer = nsl_layers.NeighborFeatures(
            stored_config.neighbor_config)
        self.regularizer = nsl_layers.PairwiseDistance(
            stored_config.distance_config, name='graph_loss')