def test_skip_if_error_should_skip_with_list(self):
  """Expects skip_if_error to skip when the message matches one of a list."""
  candidate_messages = ["foo bar", "test message"]
  # `with A, B:` enters A then B, exactly like two nested `with` statements.
  with self.assertRaises(unittest.SkipTest), test_util.skip_if_error(
      self, ValueError, candidate_messages):
    raise ValueError("test message")
def test_skip_if_error_should_skip(self):
  """Expects skip_if_error to skip when the message matches a single string."""
  message = "test message"
  with self.assertRaises(unittest.SkipTest), test_util.skip_if_error(
      self, ValueError, message):
    raise ValueError(message)
def testMultiWorkerTutorial(self, mode, shard_policy):
  """Test multi-worker training flow demo'ed in go/multi-worker-with-keras.

  This test should be kept in sync with the code samples in
  go/multi-worker-with-keras.

  Flow:
  1. Run a single-worker sanity fit.
  2. Launch `num_workers` processes. Each one:
     - trains under `CollectiveAllReduceStrategy`;
     - saves and reloads a model, then does the same with a checkpoint.
       The chief writes to the real path. Non-chief workers write into a
       per-worker temp dir and then delete it.
     - logs a success message.
  3. The parent process asserts on the success message and on each worker's
     reported accuracy.

  Args:
    mode: Runtime mode.
    shard_policy: None or any of tf.data.experimental.AutoShardPolicy for
      testing.
  """
  if shard_policy is distribute_options.AutoShardPolicy.FILE:
    self.skipTest('TensorSliceDataset is not shardable with FILE policy.')

  def mnist_dataset(batch_size):
    # Presumably turns MNIST fetch failures into a skip; confirm in base class.
    with self.skip_fetch_failure_exception():
      (x_train, y_train), _ = mnist.load_data()
    # The `x` arrays are in uint8 and have values in the range [0, 255].
    # We need to convert them to float32 with values in the range [0, 1]
    x_train = x_train / np.float32(255)
    y_train = y_train.astype(np.int64)
    train_dataset = dataset_ops.DatasetV2.from_tensor_slices(
        (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
    return train_dataset

  def build_and_compile_cnn_model():
    model = keras.Sequential([
        keras.layers.Input(shape=(28, 28)),
        keras.layers.Reshape(target_shape=(28, 28, 1)),
        keras.layers.Conv2D(32, 3, activation='relu'),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(10)
    ])
    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=gradient_descent.SGD(learning_rate=0.001),
        metrics=['accuracy'])
    return model

  per_worker_batch_size = 64
  single_worker_dataset = mnist_dataset(per_worker_batch_size)
  single_worker_model = build_and_compile_cnn_model()
  single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)

  num_workers = 4

  def fn(model_path, checkpoint_dir):
    """Per-worker body run by multi_process_runner in each task process."""
    global_batch_size = per_worker_batch_size * num_workers
    strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
    with strategy.scope():
      multi_worker_model = build_and_compile_cnn_model()

    callbacks = [
        keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(self.get_temp_dir(), 'checkpoint'))
    ]

    multi_worker_dataset = mnist_dataset(global_batch_size)
    # AutoShardPolicy is an IntEnum whose AUTO member is 0 (see the
    # tf.data.experimental.AutoShardPolicy docs). A truthiness check would
    # treat AUTO like None, so compare against None explicitly.
    if shard_policy is not None:
      options = dataset_ops.Options()
      options.experimental_distribute.auto_shard_policy = shard_policy
      multi_worker_dataset = multi_worker_dataset.with_options(options)

    multi_worker_model.fit(
        multi_worker_dataset,
        epochs=2,
        steps_per_epoch=20,
        callbacks=callbacks)

    def _is_chief(task_type, task_id):
      # A task with no type, the 'chief' task, or worker 0 is the chief.
      return task_type is None or task_type == 'chief' or (
          task_type == 'worker' and task_id == 0)

    def _get_temp_dir(dirpath, task_id):
      # Creates (recursively) and returns `dirpath/workertemp_<task_id>`.
      base_dirpath = 'workertemp_' + str(task_id)
      temp_dir = os.path.join(dirpath, base_dirpath)
      file_io.recursive_create_dir_v2(temp_dir)
      return temp_dir

    def write_filepath(filepath, task_type, task_id):
      # The chief writes to `filepath` itself. Non-chief workers are redirected
      # into their own temp dir so they never write to the chief's path.
      dirpath = os.path.dirname(filepath)
      base = os.path.basename(filepath)
      if not _is_chief(task_type, task_id):
        dirpath = _get_temp_dir(dirpath, task_id)
      return os.path.join(dirpath, base)

    task_type, task_id = (strategy.cluster_resolver.task_type,
                          strategy.cluster_resolver.task_id)
    is_chief = _is_chief(task_type, task_id)

    # --- SavedModel round trip ---
    write_model_path = write_filepath(model_path, task_type, task_id)
    multi_worker_model.save(write_model_path)
    if not is_chief:
      file_io.delete_recursively_v2(os.path.dirname(write_model_path))

    # Make sure chief finishes saving before non-chief's assertions.
    multi_process_runner.get_barrier().wait()

    if not file_io.file_exists_v2(model_path):
      raise RuntimeError(
          'Expected the chief to have saved a model at {}'.format(model_path))
    # Only the chief's write path should survive; non-chief copies are deleted.
    if file_io.file_exists_v2(write_model_path) != is_chief:
      raise RuntimeError(
          'Unexpected existence state of {} for task {}:{}'.format(
              write_model_path, task_type, task_id))

    loaded_model = keras.saving.save.load_model(model_path)
    loaded_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)

    # --- Checkpoint round trip ---
    checkpoint = tracking_util.Checkpoint(model=multi_worker_model)
    write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
    checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

    checkpoint_manager.save()
    if not is_chief:
      file_io.delete_recursively_v2(write_checkpoint_dir)

    # Make sure chief finishes saving before non-chief's assertions.
    multi_process_runner.get_barrier().wait()

    if not file_io.file_exists_v2(checkpoint_dir):
      raise RuntimeError(
          'Expected the chief to have written checkpoints to {}'.format(
              checkpoint_dir))
    if file_io.file_exists_v2(write_checkpoint_dir) != is_chief:
      raise RuntimeError(
          'Unexpected existence state of {} for task {}:{}'.format(
              write_checkpoint_dir, task_type, task_id))

    latest_checkpoint = checkpoint_management.latest_checkpoint(checkpoint_dir)
    checkpoint.restore(latest_checkpoint)
    multi_worker_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)

    logging.info('testMultiWorkerTutorial successfully ends')

  model_path = os.path.join(self.get_temp_dir(), 'model.tf')
  checkpoint_dir = os.path.join(self.get_temp_dir(), 'ckpt')
  with test_util.skip_if_error(self, errors_impl.UnavailableError):
    mpr_result = multi_process_runner.run(
        fn,
        multi_worker_test_base.create_cluster_spec(num_workers=num_workers),
        args=(model_path, checkpoint_dir),
        return_output=True)

  self.assertTrue(
      any('testMultiWorkerTutorial successfully ends' in msg
          for msg in mpr_result.stdout))

  def extract_accuracy(worker_id, input_string):
    # Returns the accuracy parsed from a `[worker-<id>] ... accuracy: X.Y`
    # line, or None if the line does not match.
    match = re.match(
        r'\[worker\-{}\].*accuracy: (\d+\.\d+).*'.format(worker_id),
        input_string)
    return None if match is None else float(match.group(1))

  for worker_id in range(num_workers):
    # Compare against None rather than relying on truthiness: a legitimately
    # parsed accuracy of 0.0 must still count as "has an accuracy result".
    has_accuracy = any(
        extract_accuracy(worker_id, line) is not None
        for line in mpr_result.stdout)
    self.assertTrue(
        has_accuracy, 'Every worker is supposed to have accuracy result.')
def test_skip_if_error_should_skip_without_error_message(self):
  """Expects skip_if_error to skip on the error type alone, with no message."""
  with self.assertRaises(unittest.SkipTest), test_util.skip_if_error(
      self, ValueError):
    raise ValueError()
def testMultiWorkerTutorial(self, mode, shard_policy):
  """Test multi-worker training flow demo'ed in go/multi-worker-with-keras.

  This test should be kept in sync with the code samples in
  go/multi-worker-with-keras.

  Flow:
  1. Run a single-worker sanity fit.
  2. Launch `num_workers` processes that each train under
     `CollectiveAllReduceStrategy`.
  3. Assert that every worker's stdout contains an accuracy line.

  Args:
    mode: Runtime mode.
    shard_policy: None or any of tf.data.experimental.AutoShardPolicy for
      testing.
  """
  if shard_policy is distribute_options.AutoShardPolicy.FILE:
    self.skipTest('TensorSliceDataset is not shardable with FILE policy.')

  def mnist_dataset(batch_size):
    # Presumably turns MNIST fetch failures into a skip; confirm in base class.
    with self.skip_fetch_failure_exception():
      (x_train, y_train), _ = mnist.load_data()
    # The `x` arrays are in uint8 and have values in the range [0, 255].
    # We need to convert them to float32 with values in the range [0, 1]
    x_train = x_train / np.float32(255)
    y_train = y_train.astype(np.int64)
    train_dataset = dataset_ops.DatasetV2.from_tensor_slices(
        (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
    return train_dataset

  def build_and_compile_cnn_model():
    model = keras.Sequential([
        keras.layers.Input(shape=(28, 28)),
        keras.layers.Reshape(target_shape=(28, 28, 1)),
        keras.layers.Conv2D(32, 3, activation='relu'),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(10)
    ])
    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=gradient_descent.SGD(learning_rate=0.001),
        metrics=['accuracy'])
    return model

  per_worker_batch_size = 64
  single_worker_dataset = mnist_dataset(per_worker_batch_size)
  single_worker_model = build_and_compile_cnn_model()
  single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)

  num_workers = 4

  def proc_func():
    """Per-worker body run by multi_process_runner in each task process."""
    global_batch_size = per_worker_batch_size * num_workers
    strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
    with strategy.scope():
      multi_worker_model = build_and_compile_cnn_model()

    callbacks = [
        keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(self.get_temp_dir(), 'checkpoint'))
    ]

    multi_worker_dataset = mnist_dataset(global_batch_size)
    # AutoShardPolicy is an IntEnum whose AUTO member is 0 (see the
    # tf.data.experimental.AutoShardPolicy docs). A truthiness check would
    # treat AUTO like None, so compare against None explicitly.
    if shard_policy is not None:
      options = dataset_ops.Options()
      options.experimental_distribute.auto_shard_policy = shard_policy
      multi_worker_dataset = multi_worker_dataset.with_options(options)

    multi_worker_model.fit(
        multi_worker_dataset,
        epochs=2,
        steps_per_epoch=20,
        callbacks=callbacks)

  with test_util.skip_if_error(self, errors_impl.UnavailableError):
    mpr_result = multi_process_runner.run(
        proc_func,
        multi_worker_test_base.create_cluster_spec(num_workers=num_workers),
        list_stdout=True)

  def extract_accuracy(worker_id, input_string):
    # Returns the accuracy parsed from a `[worker-<id>] ... accuracy: X.Y`
    # line, or None if the line does not match.
    match = re.match(
        r'\[worker\-{}\].*accuracy: (\d+\.\d+).*'.format(worker_id),
        input_string)
    return None if match is None else float(match.group(1))

  for worker_id in range(num_workers):
    # Compare against None rather than relying on truthiness: a legitimately
    # parsed accuracy of 0.0 must still count as "has an accuracy result".
    has_accuracy = any(
        extract_accuracy(worker_id, line) is not None
        for line in mpr_result.stdout)
    self.assertTrue(
        has_accuracy, 'Every worker is supposed to have accuracy result.')