def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum,
                                  renorm, update_ops_in_cross_replica_mode):
  """Verifies that moving mean updates are reduced across replicas."""
  with distribution.scope():
    num_replicas = distribution.num_replicas_in_sync
    model_fn, dataset_fn, batchnorm = batchnorm_example(
        optimizer_fn,
        batch_per_epoch=num_replicas,
        momentum=momentum,
        renorm=renorm,
        update_ops_in_replica_mode=not update_ops_in_cross_replica_mode)

    def step_fn(ctx, *inputs):
      del ctx  # Unused.
      fetches = distribution.unwrap(
          distribution.call_for_each_replica(model_fn, args=inputs))
      if update_ops_in_cross_replica_mode:
        fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
      return control_flow_ops.group(fetches)

    iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))

    def run_step():
      return distribution.run_steps_on_dataset(
          step_fn, iterator, iterations=1).run_op

    self.evaluate(distribution.initialize())
    if not context.executing_eagerly():
      with self.cached_session() as sess:
        run_step = sess.make_callable(run_step())
    self.evaluate(variables_lib.global_variables_initializer())

    expected_moving_means = [0.] * 8

    def averaged_batch_mean(i):
      # Each batch has shape [16, 8], where the i-th element of the j-th row
      # is (8 * j + i + replica_id * 100). So the batch mean in each replica
      # is (60 + i + replica_id * 100), and the batch mean over all replicas
      # is:
      return 60. + i + (num_replicas - 1.) / 2. * 100.

    for _ in range(10):
      run_step()
      moving_means = self.evaluate(batchnorm.moving_mean)

      # Make sure that the moving_mean is updated as if the sample mean were
      # computed over all replicas.
      for i, expected_moving_mean in enumerate(expected_moving_means):
        expected_moving_means[i] -= (
            (expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum))
        self.assertNear(expected_moving_means[i], moving_means[i], 0.0001)

    self.evaluate(distribution.finalize())
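# A standalone sketch (not part of the test above) that double-checks the
# closed form in averaged_batch_mean() with NumPy. The [16, 8] batch layout
# and the per-element value 8 * j + i + replica_id * 100 come from the
# comment in the test; the helper name and the num_replicas=2 default are
# assumptions made for illustration only.
def _check_averaged_batch_mean_sketch(num_replicas=2):
  import numpy as np

  # Rebuild one batch per replica: a [16, 8] array whose (j, i) entry is
  # 8 * j + i + replica_id * 100.
  batches = [
      np.array([[8. * j + i + replica_id * 100. for i in range(8)]
                for j in range(16)])
      for replica_id in range(num_replicas)
  ]

  # Per-replica batch mean of feature i: the mean over j of (8 * j) for
  # j in 0..15 is 60, so each replica's mean is 60 + i + replica_id * 100.
  per_replica_means = [batch.mean(axis=0) for batch in batches]

  # Averaging over replicas reproduces the closed form used by the test.
  global_mean = np.mean(per_replica_means, axis=0)
  expected = np.array(
      [60. + i + (num_replicas - 1.) / 2. * 100. for i in range(8)])
  np.testing.assert_allclose(global_mean, expected)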
def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum,
                                  renorm, update_ops_in_cross_tower_mode):
  """Verifies that moving mean updates are reduced across towers."""
  with distribution.scope():
    num_towers = len(distribution.worker_devices)
    model_fn, dataset_fn, batchnorm = batchnorm_example(
        optimizer_fn,
        batch_per_epoch=num_towers,
        momentum=momentum,
        renorm=renorm,
        update_ops_in_tower_mode=not update_ops_in_cross_tower_mode)
    # Make sure prefetching is disabled, since it makes the specific input on
    # each device non-deterministic, and this test relies on specific inputs
    # being on each device.
    if isinstance(distribution, mirrored_strategy.MirroredStrategy):
      self.assertFalse(distribution._prefetch_on_device)

    def step_fn(ctx, *inputs):
      del ctx  # Unused.
      fetches = distribution.unwrap(
          distribution.call_for_each_tower(
              model_fn, *inputs, run_concurrently=batchnorm.built))
      if update_ops_in_cross_tower_mode:
        fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
      return control_flow_ops.group(fetches)

    iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))

    def run_step():
      return distribution.run_steps_on_dataset(
          step_fn, iterator, iterations=1).run_op

    self.evaluate(distribution.initialize())
    if not context.executing_eagerly():
      with self.cached_session() as sess:
        run_step = sess.make_callable(run_step())
    self.evaluate(variables_lib.global_variables_initializer())

    expected_moving_means = [0.] * 8

    def averaged_batch_mean(i):
      # Each batch has shape [16, 8], where the i-th element of the j-th row
      # is (8 * j + i + tower_id * 100). So the batch mean in each tower is
      # (60 + i + tower_id * 100), and the batch mean over all towers is:
      return 60. + i + (num_towers - 1.) / 2. * 100.

    for _ in range(10):
      run_step()
      moving_means = self.evaluate(batchnorm.moving_mean)

      # Make sure that the moving_mean is updated as if the sample mean were
      # computed over all towers.
      for i, expected_moving_mean in enumerate(expected_moving_means):
        expected_moving_means[i] -= (
            (expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum))
        self.assertNear(expected_moving_means[i], moving_means[i], 0.0001)

    self.evaluate(distribution.finalize())
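# A standalone sketch (not part of the tests above) of the recurrence the
# assertion loop tracks: batch norm's moving mean follows
# moving_mean -= (moving_mean - batch_mean) * (1 - momentum). The helper name
# and the defaults momentum=0.9 and batch_mean=110.0 (averaged_batch_mean(0)
# with two towers) are assumptions made for illustration only.
def _moving_mean_sketch(momentum=0.9, batch_mean=110.0, steps=10):
  moving_mean = 0.
  trajectory = []
  for _ in range(steps):
    # Same exponential-moving-average update the test asserts against.
    moving_mean -= (moving_mean - batch_mean) * (1.0 - momentum)
    trajectory.append(moving_mean)
  # The gap to batch_mean shrinks by a factor of `momentum` each step.
  return trajectory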
def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum,
                                  renorm, is_tpu,
                                  update_ops_in_cross_tower_mode):
  """Verifies that moving mean updates are reduced across towers."""
  with distribution.scope():
    num_towers = len(distribution.worker_devices)
    model_fn, dataset_fn, batchnorm = batchnorm_example(
        optimizer_fn,
        batch_per_epoch=num_towers,
        momentum=momentum,
        renorm=renorm,
        update_ops_in_tower_mode=not update_ops_in_cross_tower_mode)
    # Make sure prefetching is disabled, since it makes the specific input on
    # each device non-deterministic, and this test relies on specific inputs
    # being on each device.
    if isinstance(distribution, mirrored_strategy.MirroredStrategy):
      self.assertFalse(distribution._prefetch_on_device)
    iterator = distribution.distribute_dataset(
        dataset_fn).make_one_shot_iterator()

    def run_step():
      fetches = distribution.unwrap(
          distribution.call_for_each_tower(
              model_fn, iterator.get_next(),
              run_concurrently=batchnorm.built))
      if update_ops_in_cross_tower_mode:
        fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
      return control_flow_ops.group(fetches)

    if not context.executing_eagerly():
      with self.test_session() as sess:
        if is_tpu:
          sess.run(tpu.initialize_system())
        run_step = sess.make_callable(run_step())
    self.evaluate(variables_lib.global_variables_initializer())

    expected_moving_means = [0.] * 8

    def averaged_batch_mean(i):
      # Each batch has shape [16, 8], where the i-th element of the j-th row
      # is (8 * j + i + tower_id * 100). So the batch mean in each tower is
      # (60 + i + tower_id * 100), and the batch mean over all towers is:
      return 60. + i + (num_towers - 1.) / 2. * 100.

    for _ in range(10):
      run_step()
      moving_means = self.evaluate(distribution.fetch(batchnorm.moving_mean))

      # Make sure that the moving_mean is updated as if the sample mean were
      # computed over all towers.
      for i, expected_moving_mean in enumerate(expected_moving_means):
        expected_moving_means[i] -= (
            (expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum))
        self.assertNear(expected_moving_means[i], moving_means[i], 0.0001)

    if is_tpu:
      with self.test_session() as sess:
        sess.run(tpu.shutdown_system())