# NOTE: This is a test method from TensorFlow's distribution-strategy test
# suite. The internal imports below are what it relies on; `batchnorm_example`
# and `self._get_iterator` come from the surrounding test file and are assumed
# here, not shown.
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables as variables_lib


def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum,
                                  renorm, update_ops_in_cross_replica_mode):
        """Verifies that moving mean updates are reduced across replicas."""
        with distribution.scope():
            num_replicas = distribution.num_replicas_in_sync
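            # batchnorm_example builds a small model around a BatchNormalization
            # layer and returns the per-replica model_fn, a dataset_fn for the
            # input pipeline, and the layer itself so its moving_mean can be
            # inspected below.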
            model_fn, dataset_fn, batchnorm = batchnorm_example(
                optimizer_fn,
                batch_per_epoch=num_replicas,
                momentum=momentum,
                renorm=renorm,
                update_ops_in_replica_mode=not update_ops_in_cross_replica_mode
            )

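            # One training step: model_fn runs once per replica and the
            # per-replica results are grouped into a single op. When the
            # batch-norm updates are created in cross-replica mode they land in
            # the UPDATE_OPS collection and must be fetched explicitly.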
            def step_fn(ctx, inputs):
                del ctx  # Unused
                fetches = distribution.experimental_local_results(
                    distribution.extended.call_for_each_replica(
                        model_fn, args=(inputs,)))
                if update_ops_in_cross_replica_mode:
                    fetches += tuple(
                        ops.get_collection(ops.GraphKeys.UPDATE_OPS))
                return control_flow_ops.group(fetches)

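            # _get_iterator (defined on the test class) builds and initializes a
            # distributed iterator over the dataset produced by dataset_fn.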
            iterator = self._get_iterator(distribution, dataset_fn)

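            # experimental_run_steps_on_iterator drives step_fn for one
            # iteration and returns an object whose run_op executes that step.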
            def run_step():
                return distribution.extended.experimental_run_steps_on_iterator(
                    step_fn, iterator, iterations=1).run_op

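            # In graph mode, build the step once and compile it into a session
            # callable so the loop below re-executes the same graph; in eager
            # mode, run_step() simply executes each time it is called.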
            if not context.executing_eagerly():
                with self.cached_session() as sess:
                    run_step = sess.make_callable(run_step())
            self.evaluate(variables_lib.global_variables_initializer())

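            # The layer normalizes 8 features. Its moving means start at zero
            # (the layer's default initializer); this list replays the expected
            # exponential-moving-average updates on the host for comparison.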
            expected_moving_means = [0.] * 8

            def averaged_batch_mean(i):
                # Each batch has shape [16, 8], and element (j, i) equals
                # (8 * j + i + replica_id * 100), so the mean of feature i on one
                # replica is (60 + i + replica_id * 100). The mean of replica_id
                # over all replicas is (num_replicas - 1) / 2, so the batch mean
                # of feature i across replicas is:
                return 60. + i + (num_replicas - 1.) / 2. * 100.

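            # Train for 10 steps; after each step, replay the moving-mean update
            # locally and compare it with the value tracked by the layer.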
            for _ in range(10):
                run_step()
                moving_means = self.evaluate(batchnorm.moving_mean)

                # Check that moving_mean is updated as if the sample mean were
                # computed over all replicas: each feature follows
                # moving_mean -= (moving_mean - batch_mean) * (1 - momentum).
                for i, expected_moving_mean in enumerate(expected_moving_means):
                    expected_moving_means[i] -= (
                        (expected_moving_mean - averaged_batch_mean(i)) *
                        (1.0 - momentum))
                    self.assertNear(expected_moving_means[i], moving_means[i],
                                    0.0001)