Example #1
    def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn,
                                      momentum, renorm,
                                      update_ops_in_cross_replica_mode):
        """Verifies that moving mean updates are reduced across replicas."""
        with distribution.scope():
            num_replicas = distribution.num_replicas_in_sync
            model_fn, dataset_fn, batchnorm = batchnorm_example(
                optimizer_fn,
                batch_per_epoch=num_replicas,
                momentum=momentum,
                renorm=renorm,
                update_ops_in_replica_mode=not update_ops_in_cross_replica_mode
            )

            def step_fn(ctx, *inputs):
                del ctx  # Unused
                fetches = distribution.unwrap(
                    distribution.call_for_each_replica(model_fn, args=inputs))
                if update_ops_in_cross_replica_mode:
                    fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
                return control_flow_ops.group(fetches)

            iterator = self._get_iterator(
                distribution.distribute_dataset(dataset_fn))

            def run_step():
                return distribution.run_steps_on_dataset(step_fn,
                                                         iterator,
                                                         iterations=1).run_op

            self.evaluate(distribution.initialize())
            if not context.executing_eagerly():
                with self.cached_session() as sess:
                    run_step = sess.make_callable(run_step())
            self.evaluate(variables_lib.global_variables_initializer())

            expected_moving_means = [0.] * 8

            def averaged_batch_mean(i):
                # Each batch has shape [16, 8], where the element at row j, column i
                # is (8 * j + i + replica_id * 100). The batch mean in each replica is
                # therefore (60 + i + replica_id * 100), so averaging over all
                # replicas gives:
                return 60. + i + (num_replicas - 1.) / 2. * 100.

            for _ in range(10):
                run_step()
                moving_means = self.evaluate(batchnorm.moving_mean)

                # We make sure that the moving_mean is updated as if the sample mean
                # were computed over all replicas.
                for i, expected_moving_mean in enumerate(
                        expected_moving_means):
                    expected_moving_means[i] -= (
                        (expected_moving_mean - averaged_batch_mean(i)) *
                        (1.0 - momentum))
                    self.assertNear(expected_moving_means[i], moving_means[i],
                                    0.0001)

            self.evaluate(distribution.finalize())
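
The averaged_batch_mean helper above encodes the arithmetic behind the synthetic
input: element [j, i] of each replica's [16, 8] batch is 8 * j + i +
replica_id * 100, so the per-replica batch mean is 60 + i + replica_id * 100,
and averaging over replicas adds (num_replicas - 1) / 2 * 100. A minimal
standalone sketch of that derivation (the helper names below are illustrative
and not part of the test):

def per_replica_batch_mean(i, replica_id):
  # Mean over j = 0..15 of (8 * j + i + replica_id * 100); the mean of j is
  # 7.5, so the 8 * j term contributes 60.
  return sum(8 * j + i + replica_id * 100 for j in range(16)) / 16.

def cross_replica_batch_mean(i, num_replicas):
  # Averaging replica_id over 0..num_replicas - 1 contributes
  # (num_replicas - 1) / 2 * 100, which matches averaged_batch_mean(i) above.
  return sum(per_replica_batch_mean(i, r)
             for r in range(num_replicas)) / num_replicas

# Example: feature i = 3 with 2 replicas.
assert cross_replica_batch_mean(3, 2) == 60. + 3 + (2 - 1.) / 2. * 100.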
Example #2
  def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum,
                                    renorm, update_ops_in_cross_tower_mode):
    """Verifies that moving mean updates are reduced across towers."""
    with distribution.scope():
      num_towers = len(distribution.worker_devices)
      model_fn, dataset_fn, batchnorm = batchnorm_example(
          optimizer_fn,
          batch_per_epoch=num_towers,
          momentum=momentum,
          renorm=renorm,
          update_ops_in_tower_mode=not update_ops_in_cross_tower_mode)

      # Make sure prefetching is disabled, since prefetching makes the exact
      # input placed on each device non-deterministic, and this test relies on
      # each device receiving a specific input.
      if isinstance(distribution, mirrored_strategy.MirroredStrategy):
        self.assertFalse(distribution._prefetch_on_device)

      def step_fn(ctx, *inputs):
        del ctx  # Unused
        fetches = distribution.unwrap(
            distribution.call_for_each_tower(
                model_fn, *inputs, run_concurrently=batchnorm.built))
        if update_ops_in_cross_tower_mode:
          fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
        return control_flow_ops.group(fetches)

      iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))

      def run_step():
        return distribution.run_steps_on_dataset(
            step_fn, iterator, iterations=1).run_op

      self.evaluate(distribution.initialize())
      if not context.executing_eagerly():
        with self.cached_session() as sess:
          run_step = sess.make_callable(run_step())
      self.evaluate(variables_lib.global_variables_initializer())

      expected_moving_means = [0.] * 8

      def averaged_batch_mean(i):
        # Each batch has shape [16, 8], where the element at row j, column i is
        # (8 * j + i + tower_id * 100). The batch mean in each tower is therefore
        # (60 + i + tower_id * 100), so averaging over all towers gives:
        return 60. + i + (num_towers - 1.) / 2. * 100.

      for _ in range(10):
        run_step()
        moving_means = self.evaluate(batchnorm.moving_mean)

        # We make sure that the moving_mean is updated as if the sample mean
        # were computed over all towers.
        for i, expected_moving_mean in enumerate(expected_moving_means):
          expected_moving_means[i] -= ((
              expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum))
          self.assertNear(expected_moving_means[i], moving_means[i], 0.0001)

      self.evaluate(distribution.finalize())
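
Example #2 also exercises update_ops_in_cross_tower_mode: when it is true,
step_fn fetches the batch-norm update ops registered in GraphKeys.UPDATE_OPS
from cross-tower context; otherwise batchnorm_example applies them inside each
tower. For context, a minimal sketch of the standard TF 1.x graph-mode pattern
those update ops come from, written against the public tf.compat.v1 API and
independent of the distribution-strategy harness above:

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

x = tf.compat.v1.placeholder(tf.float32, shape=[None, 8])
# batch_normalization registers its moving_mean / moving_variance updates in
# the UPDATE_OPS collection; in graph mode they only run when fetched, either
# directly or through a control dependency on the train op.
y = tf.compat.v1.layers.batch_normalization(x, momentum=0.9, training=True)
loss = tf.reduce_mean(tf.square(y))

update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
  train_op = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(loss)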
Example #3
  def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum,
                                    renorm, is_tpu,
                                    update_ops_in_cross_tower_mode):
    """Verifies that moving mean updates are reduced across towers."""
    with distribution.scope():
      num_towers = len(distribution.worker_devices)
      model_fn, dataset_fn, batchnorm = batchnorm_example(
          optimizer_fn,
          batch_per_epoch=num_towers,
          momentum=momentum,
          renorm=renorm,
          update_ops_in_tower_mode=not update_ops_in_cross_tower_mode)

      # Make sure prefetching is disabled, since prefetching makes the exact
      # input placed on each device non-deterministic, and this test relies on
      # each device receiving a specific input.
      if isinstance(distribution, mirrored_strategy.MirroredStrategy):
        self.assertFalse(distribution._prefetch_on_device)
      iterator = distribution.distribute_dataset(
          dataset_fn).make_one_shot_iterator()

      def run_step():
        fetches = distribution.unwrap(
            distribution.call_for_each_tower(
                model_fn, iterator.get_next(),
                run_concurrently=batchnorm.built))
        if update_ops_in_cross_tower_mode:
          fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
        return control_flow_ops.group(fetches)

      if not context.executing_eagerly():
        with self.test_session() as sess:
          if is_tpu:
            sess.run(tpu.initialize_system())
          run_step = sess.make_callable(run_step())
      self.evaluate(variables_lib.global_variables_initializer())

      expected_moving_means = [0.] * 8

      def averaged_batch_mean(i):
        # Each batch has shape [16, 8], where the element at row j, column i is
        # (8 * j + i + tower_id * 100). The batch mean in each tower is therefore
        # (60 + i + tower_id * 100), so averaging over all towers gives:
        return 60. + i + (num_towers - 1.) / 2. * 100.

      for _ in range(10):
        run_step()
        moving_means = self.evaluate(distribution.fetch(batchnorm.moving_mean))

        # We make sure that the moving_mean is updated as if the sample mean
        # were computed over all towers.
        for i, expected_moving_mean in enumerate(expected_moving_means):
          expected_moving_means[i] -= ((
              expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum))
          self.assertNear(expected_moving_means[i], moving_means[i], 0.0001)

      if is_tpu:
        with self.test_session() as sess:
          sess.run(tpu.shutdown_system())
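
The assertion loop in each example applies the exponential-moving-average
recurrence that batch normalization uses for its moving statistics:
moving_mean -= (moving_mean - batch_mean) * (1 - momentum), which is the same
as moving_mean = momentum * moving_mean + (1 - momentum) * batch_mean. A small
self-contained sketch of that recurrence (the function and the two-replica
numbers are illustrative, not taken from the tests):

def updated_moving_mean(moving_mean, batch_mean, momentum, steps):
  # Same update the tests apply to expected_moving_means[i] on every iteration.
  for _ in range(steps):
    moving_mean = momentum * moving_mean + (1. - momentum) * batch_mean
  return moving_mean

# Starting from 0., ten steps with momentum 0.9 cover about 65% of the gap to
# the cross-replica batch mean of feature i = 0 with two replicas:
target = 60. + 0 + (2 - 1.) / 2. * 100.          # 110.0
print(updated_moving_mean(0., target, 0.9, 10))  # ~71.6, converging to 110.0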