  def testDynamicAndBatchNormalizationConsistency(self):
    """Checks that BatchNormalization falls into a special case of DN."""
    batch_size = 2
    dm_config = test_util.default_dm_config(
        per_cluster_buffer_size=batch_size,
        distance_to_cluster_threshold=0.5,
        max_num_clusters=1,
        bootstrap_steps=0)
    inputs = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
    hidden_layer = tf.keras.layers.Dense(
        5, activation='relu', kernel_initializer='ones')(
            inputs)

    batch_norm = tf.keras.layers.BatchNormalization(
        axis=1, center=True, scale=True, momentum=0)(
            hidden_layer, training=True)
    dynamic_norm = dn.DynamicNormalization(
        dm_config,
        mode=dm_ops.LOOKUP_WITH_UPDATE,
        axis=1,
        epsilon=0.001,
        use_batch_normalization=True,
        service_address=self._kbs_address)(
            hidden_layer, training=True)

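    # With max_num_clusters=1 and a cluster buffer sized to the batch, the
    # cluster statistics equal the batch statistics, so DynamicNormalization
    # with use_batch_normalization=True should reproduce BatchNormalization.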
    self.assertAllClose(batch_norm.numpy(), dynamic_norm.numpy())

  def testTrainingLogistic(self):
    """Trains two logistic regression models with two normalizations."""
    dm_config = test_util.default_dm_config(
        per_cluster_buffer_size=500,
        distance_to_cluster_threshold=0.9,
        max_num_clusters=1,
        bootstrap_steps=10)
    x = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1],
                  [1, 0, 1, 0], [0, 1, 0, 1]], dtype=np.float32)
    y = np.array([[1], [1], [1], [1], [0], [0]], dtype=np.float32)  # 0/1 labels.
    mode = tf.constant(dm_ops.LOOKUP_WITH_UPDATE, dtype=tf.int32)

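    # Builds two structurally identical models that differ only in the
    # normalization layer, so their training losses are directly comparable.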
    def _create_model(enable_dynamic_normalization: bool):
      hidden_layer = tf.keras.layers.Dense(5, activation='relu')
      if enable_dynamic_normalization:
        normalized_layer = dn.DynamicNormalization(
            dm_config,
            mode=mode,
            axis=1,
            epsilon=0.001,
            use_batch_normalization=False,
            service_address=self._kbs_address)
      else:
        normalized_layer = tf.keras.layers.BatchNormalization(
            axis=1, center=True, scale=True, momentum=0)
      output_layer = tf.keras.layers.Dense(1, kernel_initializer='ones')
      model = tf.keras.Sequential(
          [hidden_layer, normalized_layer, output_layer])
      return model

    def _loss(model, x, y):
      output = model(x)
      pred = tf.sigmoid(output)  # Logistic link: 1 / (1 + exp(-output)).
      # Binary cross-entropy (negative log-likelihood of the 0/1 labels).
      loss = y * tf.math.log(pred) + (1 - y) * tf.math.log(1 - pred)
      return tf.reduce_mean(-tf.math.reduce_sum(loss, axis=1))
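    # Note: tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=output)
    # computes the same quantity in a numerically stabler way; the explicit
    # form above is kept for readability.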

    def _grad(model, x, y):
      with tf.GradientTape() as tape:
        loss_value = _loss(model, x, y)
      return loss_value, tape.gradient(loss_value, model.trainable_variables)

    bn_model = _create_model(False)  # Model with batch normalization.
    dn_model = _create_model(True)  # Model with dynamic normalization.
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
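    # Plain SGD keeps no per-variable slot state, so a single optimizer can
    # safely apply gradients to both models.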
    for i in range(100):
      bn_loss_value, bn_grads = _grad(bn_model, x, y)
      dn_loss_value, dn_grads = _grad(dn_model, x, y)

      print('Step: {}, BN loss: {}, DN loss: {}'.format(i,
                                                        bn_loss_value.numpy(),
                                                        dn_loss_value.numpy()))
      # Update the trainable variables w.r.t. the logistic loss.
      optimizer.apply_gradients(zip(bn_grads, bn_model.trainable_variables))
      optimizer.apply_gradients(zip(dn_grads, dn_model.trainable_variables))

    # Checks that DynamicNormalization reaches a lower final loss than
    # BatchNormalization.
    self.assertGreater(bn_loss_value.numpy(), dn_loss_value.numpy())

  def testGaussianMemoryLookupWithSingleCluster_3DInput(self):
    """Tests Gaussian memory lookup with a single cluster on 3D input."""
    dm_config = test_util.default_dm_config(
        per_cluster_buffer_size=4,
        distance_to_cluster_threshold=0.5,
        bootstrap_steps=0,
        min_variance=1,
        max_num_clusters=1)
    inputs = [[[0, 0], [1, 0]], [[101, 0], [0, 101]]]
    mode = dm_ops.LOOKUP_WITH_UPDATE
    mean, variance, distance, cid = dm_ops.dynamic_gaussian_memory_lookup(
        inputs,
        mode,
        dm_config,
        'dm_layer',
        service_address=self._kbs_address)

    # All four 2-D points fall into the single cluster, so every position
    # receives the same statistics.
    # Mean: [(0 + 1 + 101 + 0) / 4, (0 + 0 + 0 + 101) / 4] = [25.5, 25.25]
    self.assertAllClose(
        mean.numpy(),
        [[[25.5, 25.25], [25.5, 25.25]], [[25.5, 25.25], [25.5, 25.25]]])
    # Variance: [(25.5^2 + 24.5^2 + 75.5^2 + 25.5^2) / 4,
    #            (3 * 25.25^2 + 75.75^2) / 4] = [1900.25, 1912.6875]
    self.assertAllClose(variance.numpy(),
                        [[[1900.25, 1912.6875], [1900.25, 1912.6875]],
                         [[1900.25, 1912.6875], [1900.25, 1912.6875]]])
    self.assertAllClose(
        distance.numpy(),
        [[0.84460914, 0.85018337], [0.43462682, 0.4336368]])
    self.assertAllClose(cid.numpy(), [[0, 0], [0, 0]])

    mode = dm_ops.LOOKUP_WITH_GROW
    mean, variance, distance, cid = dm_ops.dynamic_gaussian_memory_lookup(
        inputs,
        mode,
        dm_config,
        'dm_layer',
        service_address=self._kbs_address)
    # max_num_clusters=1, so the memory cannot grow; statistics are unchanged.
    self.assertAllClose(
        mean.numpy(),
        [[[25.5, 25.25], [25.5, 25.25]], [[25.5, 25.25], [25.5, 25.25]]])
    self.assertAllClose(variance.numpy(),
                        [[[1900.25, 1912.6875], [1900.25, 1912.6875]],
                         [[1900.25, 1912.6875], [1900.25, 1912.6875]]])
    self.assertAllClose(
        distance.numpy(),
        [[0.84460914, 0.85018337], [0.43462682, 0.4336368]])
    self.assertAllClose(cid.numpy(), [[0, 0], [0, 0]])

  def testGaussianMemoryLookupWithMultiClusterWithGrow(self):
    """Tests that lookup with grow forms a new cluster for an outlier."""
    dm_config = test_util.default_dm_config(
        per_cluster_buffer_size=3,
        distance_to_cluster_threshold=0.7,
        bootstrap_steps=0,
        min_variance=1,
        max_num_clusters=2)
    inputs = [[0, 0], [1, 0], [101, 0]]
    mode = dm_ops.LOOKUP_WITH_UPDATE
    mean, variance, distance, cid = dm_ops.dynamic_gaussian_memory_lookup(
        inputs,
        mode,
        dm_config,
        'dm_layer',
        service_address=self._kbs_address)

    # Mean: [(0 + 1 + 101) / 3, 0] = [34, 0]
    # Variance: [(34^2 + (1 - 34)^2 + (101 - 34)^2) / 3, min_variance]
    #         = [2244.6667, 1]
    # Distance, e.g. for [1, 0]: exp(-(((1 - 34)^2 / 2244.6667 + 0 / 1) / 2) / 2)
    self.assertAllClose(mean.numpy(), [[34, 0], [34, 0], [34, 0]])
    self.assertAllClose(variance.numpy(),
                        [[2244.6667, 1], [2244.6667, 1], [2244.6667, 1]])
    self.assertAllClose(distance.numpy(),
                        [0.8791941, 0.88577926, 0.6065532])
    self.assertAllClose(cid.numpy(), [0, 0, 0])

    mode = dm_ops.LOOKUP_WITH_GROW
    mean, variance, distance, cid = dm_ops.dynamic_gaussian_memory_lookup(
        inputs,
        mode,
        dm_config,
        'dm_layer',
        service_address=self._kbs_address)
    # The distance of [101, 0] to cluster 0 (0.6066) falls below the 0.7
    # threshold, so a new cluster is formed around that single point.
    self.assertAllClose(mean.numpy(), [[34, 0], [34, 0], [101, 0]])
    self.assertAllClose(variance.numpy(),
                        [[2244.6667, 1], [2244.6667, 1], [1, 1]])
    self.assertAllClose(distance.numpy(), [0.8791941, 0.88577926, 1])
    self.assertAllClose(cid.numpy(), [0, 0, 1])

  def testKerasModel(self):
    """Tests the tf.keras interface."""
    dm_config = test_util.default_dm_config(
        per_cluster_buffer_size=1,
        distance_to_cluster_threshold=0.5,
        max_num_clusters=1,
        bootstrap_steps=0)
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(2,), name='inputs'),
        tf.keras.layers.Dense(4),
        dn.DynamicNormalization(
            dm_config,
            mode=dm_ops.LOOKUP_WITH_UPDATE,
            axis=1,
            epsilon=0.1,
            use_batch_normalization=True,
            service_address=self._kbs_address),
        tf.keras.layers.Dense(1)
    ])
    model.compile(optimizer='sgd', loss='mse')

    x_train = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
    y_train = np.array([1, 1, 0, 0])
    model.fit(x_train, y_train, epochs=10)
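    # fit() drives DynamicNormalization through the full Keras training path;
    # each training step performs a LOOKUP_WITH_UPDATE against the memory
    # service at self._kbs_address.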

  def testGaussianMemoryLookupWithSingleCluster(self):
    """Tests Gaussian memory lookup modes with a single cluster."""
    dm_config = test_util.default_dm_config(
        per_cluster_buffer_size=3,
        distance_to_cluster_threshold=0.5,
        bootstrap_steps=0,
        min_variance=1,
        max_num_clusters=1)
    inputs = [[0, 0], [1, 0], [101, 0]]
    mode = dm_ops.LOOKUP_WITH_UPDATE
    mean, variance, distance, cid = dm_ops.dynamic_gaussian_memory_lookup(
        inputs,
        mode,
        dm_config,
        'dm_layer',
        service_address=self._kbs_address)

    # Mean: [(0 + 1 + 101) / 3, 0] = [34, 0]
    # Variance: [(34^2 + (1 - 34)^2 + (101 - 34)^2) / 3, min_variance]
    #         = [2244.6667, 1]
    # Distance, e.g. for [0, 0]: exp(-((34^2 / 2244.6667 + 0 / 1) / 2) / 2)
    self.assertAllClose(mean.numpy(), [[34, 0], [34, 0], [34, 0]])
    self.assertAllClose(variance.numpy(),
                        [[2244.6667, 1], [2244.6667, 1], [2244.6667, 1]])
    self.assertAllClose(distance.numpy(),
                        [0.8791941, 0.88577926, 0.6065532])
    self.assertAllClose(cid.numpy(), [0, 0, 0])

    # Swap the x and y values of the input; the statistics swap accordingly.
    inputs = [[0, 0], [0, 1], [0, 101]]
    mean, variance, distance, cid = dm_ops.dynamic_gaussian_memory_lookup(
        inputs,
        mode,
        dm_config,
        'dm_layer',
        service_address=self._kbs_address)
    self.assertAllClose(mean.numpy(), [[0, 34], [0, 34], [0, 34]])
    self.assertAllClose(variance.numpy(),
                        [[1, 2244.6667], [1, 2244.6667], [1, 2244.6667]])
    self.assertAllClose(distance.numpy(),
                        [0.8791941, 0.88577926, 0.6065532])
    self.assertAllClose(cid.numpy(), [0, 0, 0])

    # Lookup without update mode.
    inputs = [[0, 0], [1, 0], [101, 0]]
    mode = dm_ops.LOOKUP_WITHOUT_UPDATE
    mean, variance, distance, cid = dm_ops.dynamic_gaussian_memory_lookup(
        inputs,
        mode,
        dm_config,
        'dm_layer',
        service_address=self._kbs_address)
    # Returns the same mean and variance as above, since nothing was updated.
    self.assertAllClose(mean.numpy(), [[0, 34], [0, 34], [0, 34]])
    self.assertAllClose(variance.numpy(),
                        [[1, 2244.6667], [1, 2244.6667], [1, 2244.6667]])
    # [0, 0]: exp(-(((0 - 0)^2/1 + 34^2/2244.6667)/2)/2)
    # [1, 0]: exp(-(((1 - 0)^2/1 + 34^2/2244.6667)/2)/2)
    # [101, 0]: exp(-(((101 - 0)^2/1 + 34^2/2244.6667)/2)/2)
    self.assertAllClose(distance.numpy(), [0.8791941, 0.68471706, 0])
    self.assertAllClose(cid.numpy(), [0, 0, 0])

    # Lookup with grow mode; with max_num_clusters=1 the memory cannot grow,
    # so this is equivalent to lookup with update.
    inputs = [[10, 0], [40, 0], [70, 0]]
    mode = dm_ops.LOOKUP_WITH_GROW
    mean, variance, distance, cid = dm_ops.dynamic_gaussian_memory_lookup(
        inputs,
        mode,
        dm_config,
        'dm_layer',
        service_address=self._kbs_address)
    self.assertAllClose(mean.numpy(), [[40, 0], [40, 0], [40, 0]])
    self.assertAllClose(variance.numpy(), [[600, 1], [600, 1], [600, 1]])
    # [10, 0]: exp(-(((10 - 40)^2/600)/2)/2)
    # [40, 0]: exp(-(((40 - 40)^2/600)/2)/2)
    # [70, 0]: exp(-(((70 - 40)^2/600)/2)/2)
    self.assertAllClose(distance.numpy(), [0.6872893, 1, 0.6872893])
    self.assertAllClose(cid.numpy(), [0, 0, 0])
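
  # A minimal reference sketch for the distance values asserted in these
  # tests, assuming the Gaussian-kernel form documented in the comments above.
  # This helper is illustrative only and is not part of the dm_ops API.
  def _expected_gaussian_distance(self, x, mean, variance):
    """Returns exp(-mean_d((x_d - mu_d)^2 / var_d) / 2).

    For example, x=[1, 0], mean=[34, 0], variance=[2244.6667, 1] yields
    approximately 0.8858, matching the asserted values above.
    """
    x, mean, variance = (
        np.asarray(a, dtype=np.float64) for a in (x, mean, variance))
    return np.exp(-np.mean((x - mean) ** 2 / variance, axis=-1) / 2.0)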