Example 1
    def test_model_checkpoint_works_with_same_file_path(self, mode):
        def proc_model_checkpoint_works_with_same_file_path(
                test_obj, saving_filepath):
            model, _, train_ds, steps = _model_setup(test_obj, file_format='')
            num_epoch = 2

            # The saving_filepath shouldn't exist at the beginning (as it's unique).
            test_obj.assertFalse(file_io.file_exists(saving_filepath))

            model.fit(x=train_ds,
                      epochs=num_epoch,
                      steps_per_epoch=steps,
                      callbacks=[
                          callbacks.ModelCheckpoint(filepath=saving_filepath)
                      ])

            test_obj.assertTrue(file_io.file_exists(saving_filepath))

        saving_filepath = os.path.join(self.get_temp_dir(), 'checkpoint')

        multi_process_runner.run(
            proc_model_checkpoint_works_with_same_file_path,
            cluster_spec=test_base.create_cluster_spec(num_workers=2),
            args=(self, saving_filepath))
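Several of the callback examples on this page call a `_model_setup` helper that is not reproduced here. The following is only a rough sketch of what such a helper might look like, consistent with how it is used above; the model, the synthetic dataset, and the second return value (unpacked as `_` in the examples) are assumptions for illustration, not the real implementation.

def _model_setup(test_obj, file_format):
    # Sketch only: build a small Keras model under a collective strategy scope
    # plus a synthetic dataset, so the multi-worker callback tests have
    # something to fit. The second return value is a per-test file path that
    # the examples on this page ignore.
    batch_size, steps = 64, 2
    saving_filepath = os.path.join(
        test_obj.get_temp_dir(), 'checkpoint.' + file_format)
    with collective_all_reduce_strategy.CollectiveAllReduceStrategy().scope():
        model = keras.Sequential([
            keras.layers.Dense(10, activation='relu', input_shape=(28,)),
            keras.layers.Dense(1)
        ])
        model.compile(loss='mse',
                      optimizer=gradient_descent.SGD(learning_rate=0.01))
    x = np.random.random((batch_size * steps, 28)).astype(np.float32)
    y = np.random.random((batch_size * steps, 1)).astype(np.float32)
    train_ds = dataset_ops.DatasetV2.from_tensor_slices(
        (x, y)).repeat().batch(batch_size)
    return model, saving_filepath, train_ds, steps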
Example 2
    def test_multi_process_runner(self):
        mpr_result = multi_process_runner.run(
            proc_func_that_adds_task_type_in_return_data,
            multi_worker_test_base.create_cluster_spec(num_workers=2,
                                                       num_ps=3,
                                                       has_eval=1),
            args=(self, 3))

        job_count_dict = {'worker': 2, 'ps': 3, 'evaluator': 1}
        for data in mpr_result.return_value:
            job_count_dict[data] -= 1

        self.assertEqual(job_count_dict['worker'], 0)
        self.assertEqual(job_count_dict['ps'], 0)
        self.assertEqual(job_count_dict['evaluator'], 0)
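The process function `proc_func_that_adds_task_type_in_return_data` is referenced by this example (and by Examples 7 and 8 below) but not shown on this page. A minimal version consistent with the assertions above would simply return each subprocess's task type; the signature below is an assumption based on the `args` passed in.

def proc_func_that_adds_task_type_in_return_data(test_obj=None, val=None):
    # Sketch only: each subprocess reports its own task type ('worker', 'ps'
    # or 'evaluator'), so the test can tick off one entry per spawned task
    # (mirroring the test_base.get_task_type() call used in Example 4).
    if test_obj is not None and val is not None:
        test_obj.assertEqual(3, val)
    return multi_worker_test_base.get_task_type()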
Example 3
    def test_stdout_captured(self):
        def simple_print_func():
            print('This is something printed.')
            return 'This is returned data.'

        returned_data, std_stream_data = multi_process_runner.run(
            simple_print_func,
            multi_worker_test_base.create_cluster_spec(num_workers=2),
            capture_std_stream=True)
        num_string_std_stream = len(
            [d for d in std_stream_data if d == 'This is something printed.'])
        num_string_returned_data = len(
            [d for d in returned_data if d == 'This is returned data.'])
        self.assertEqual(num_string_std_stream, 2)
        self.assertEqual(num_string_returned_data, 2)
Example 4
  def test_tensorboard_saves_on_chief_but_not_otherwise(self, mode):

    def proc_tensorboard_saves_on_chief_but_not_otherwise(test_obj):
      model, _, train_ds, steps = _model_setup(test_obj, file_format='')
      num_epoch = 2

      # Incorporate task type/index information in saving_filepath to ensure
      # every worker has a unique path. Note that in a normal use case the
      # saving_filepath would be the same for all workers, but we use different
      # ones here just to verify that the chief saves summaries while non-chief
      # workers don't.
      saving_filepath = os.path.join(
          test_obj.get_temp_dir(), 'logfile_%s_%d' %
          (test_base.get_task_type(), test_base.get_task_index()))

      # The saving_filepath shouldn't exist at the beginning (as it's unique).
      test_obj.assertFalse(file_io.file_exists(saving_filepath))

      model.fit(
          x=train_ds,
          epochs=num_epoch,
          steps_per_epoch=steps,
          callbacks=[callbacks.TensorBoard(log_dir=saving_filepath)])

      # If it's chief, the summaries should be saved in the filepath; if not,
      # the directory should be empty (although created). Using
      # `file_io.list_directory()` since the directory may be created at this
      # point.
      test_obj.assertEqual(
          bool(file_io.list_directory(saving_filepath)), test_base.is_chief())

    # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
    with multi_process_runner_util.try_run_and_except_connection_error(self):
      multi_process_runner.run(
          proc_tensorboard_saves_on_chief_but_not_otherwise,
          cluster_spec=test_base.create_cluster_spec(num_workers=2),
          args=(self,))
Example 5
    def test_stdout_captured(self):
        def simple_print_func():
            print('This is something printed.', flush=True)
            return 'This is returned data.'

        mpr_result = multi_process_runner.run(
            simple_print_func,
            multi_worker_test_base.create_cluster_spec(num_workers=2),
            list_stdout=True)
        std_stream_results = mpr_result.stdout
        return_value = mpr_result.return_value
        self.assertIn('[worker-0]:    This is something printed.\n',
                      std_stream_results)
        self.assertIn('[worker-1]:    This is something printed.\n',
                      std_stream_results)
        self.assertIn('This is returned data.', return_value)
Example 6
    def test_stdout_captured(self):
        def simple_print_func():
            print('This is something printed.')
            return 'This is returned data.'

        job_count_dict = {'worker': 2}
        returned_data, std_stream_data = multi_process_runner.run(
            simple_print_func,
            multi_process_runner.job_count_to_cluster_spec(job_count_dict),
            return_std_stream=True)
        num_string_std_stream = len(
            [d for d in std_stream_data if d == 'This is something printed.'])
        num_string_returned_data = len(
            [d for d in returned_data if d == 'This is returned data.'])
        self.assertEqual(num_string_std_stream, 2)
        self.assertEqual(num_string_returned_data, 2)
Example 7
    def test_multi_process_runner(self):
        job_count_dict = {'worker': 2, 'ps': 3, 'evaluator': 2}
        proc_flags = {
            'test_flag': 3,
        }
        returned_data = multi_process_runner.run(
            proc_func_that_adds_task_type_in_return_data,
            multi_process_runner.job_count_to_cluster_spec(job_count_dict),
            proc_flags=proc_flags,
            args=(self, ))

        for data in returned_data:
            job_count_dict[data] -= 1

        self.assertEqual(job_count_dict['worker'], 0)
        self.assertEqual(job_count_dict['ps'], 0)
        self.assertEqual(job_count_dict['evaluator'], 0)
Example 8
    def test_multi_process_runner(self):
        count_dict = {'worker': 2, 'ps': 3, 'evaluator': 1}
        proc_flags = {
            'test_flag': 3,
        }
        returned_data = multi_process_runner.run(
            proc_func_that_adds_task_type_in_return_data,
            count_dict,
            proc_flags,
            args=(self, ))

        for data in returned_data:
            count_dict[data] -= 1

        self.assertEqual(count_dict['worker'], 0)
        self.assertEqual(count_dict['ps'], 0)
        self.assertEqual(count_dict['evaluator'], 0)
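For orientation: the cluster specs produced by helpers such as `create_cluster_spec` or `job_count_to_cluster_spec` are plain dictionaries mapping each job name to a list of `host:port` addresses, one per task. The shape below is representative only; the actual addresses and ports are chosen at runtime.

example_cluster_spec = {
    'worker': ['localhost:12345', 'localhost:12346'],
    'ps': ['localhost:12347', 'localhost:12348', 'localhost:12349'],
    'evaluator': ['localhost:12350'],
}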
Example 9
  def decorator(self, has_chief, num_workers, runner, **kwargs):
    if _num_total_workers(has_chief, num_workers) == 1 or _running_in_worker:
      # We're in a worker process or the test is for a single worker. In either
      # case we execute the test method directly instead of spawning subprocesses.
      test_method(self, **kwargs)
      return

    # We're in the main process. We spawn subprocesses and run the *test* on
    # each of them. Note that we're not directly executing test_method passed to
    # _multi_worker_test, because we need setUp()/tearDown() to be called and
    # all the decorations on the test method. The conceptual call stack is:
    #   [main process]test.main()
    #     [main process]test_runner.run(test)
    #       [main process]wrapper by combinations.generate()
    #         [main process]_multi_worker_test.decorator()
    #           # A sub process goes through the same code path as the main
    #           # process.
    #           [sub process]_test_runner()
    #             [sub process]test_runner.run(test)
    #               [sub process]wrapper by combinations.generate()
    #                 [sub process]_multi_worker_test.decorator()
    #                   # _running_in_worker is True
    #                   [sub process]test_method()
    test_id = self.id()
    if runner:
      result = runner.run(_test_runner, args=(test_id,))
    else:
      cluster_spec = multi_worker_test_base.create_cluster_spec(
          has_chief=has_chief,
          num_workers=num_workers,
          num_ps=0,
          has_eval=False)
      result = multi_process_runner.run(
          _test_runner, cluster_spec, args=(test_id,)).return_value
    for was_successful in result:
      if not was_successful:
        raise AssertionError('some worker failed, see logs for details')
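The decorator above depends on a few module-level helpers that are not reproduced here. A plausible sketch of `_num_total_workers`, consistent with how it is called at the top of the decorator, is:

def _num_total_workers(has_chief, num_workers):
    # Sketch only: the chief counts as one additional worker process.
    return num_workers + 1 if has_chief else num_workers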
Example 10
    def test_backupandrestore_checkpoint_works_with_interruption(self, mode):
        class InterruptingCallback(callbacks.Callback):
            def on_epoch_begin(self, epoch, logs=None):
                if epoch == 2:
                    raise RuntimeError('Interrupting!')

        class AssertCallback(callbacks.Callback):
            def on_epoch_begin(self, epoch, logs=None):
                # The interruption happened on epoch 2, as specified in
                # InterruptingCallback, so the initial epoch after restart
                # will begin at 2.
                assert epoch > 1

        def proc_model_checkpoint_works_with_same_file_path(
                test_obj, saving_filepath):
            if multi_process_runner.is_oss():
                test_obj.skipTest('TODO(b/170838633): Failing in OSS')
            model, _, train_ds, steps = _model_setup(test_obj, file_format='')
            num_epoch = 4

            # The saving_filepath shouldn't exist at the beginning (as it's unique).
            test_obj.assertFalse(file_io.file_exists_v2(saving_filepath))
            bar_dir = os.path.join(os.path.dirname(saving_filepath), 'backup')

            try:
                model.fit(
                    x=train_ds,
                    epochs=num_epoch,
                    steps_per_epoch=steps,
                    callbacks=[
                        callbacks.ModelCheckpoint(filepath=saving_filepath),
                        callbacks.BackupAndRestore(backup_dir=bar_dir),
                        InterruptingCallback()
                    ])
            except RuntimeError as e:
                if 'Interrupting!' not in str(e):
                    raise

            multi_process_runner.get_barrier().wait()
            backup_filepath = os.path.join(bar_dir, 'chief', 'checkpoint')
            test_obj.assertTrue(file_io.file_exists_v2(backup_filepath))
            test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))

            model.fit(x=train_ds,
                      epochs=num_epoch,
                      steps_per_epoch=steps,
                      callbacks=[
                          callbacks.ModelCheckpoint(filepath=saving_filepath),
                          callbacks.BackupAndRestore(backup_dir=bar_dir),
                          AssertCallback()
                      ])
            multi_process_runner.get_barrier().wait()
            test_obj.assertFalse(file_io.file_exists_v2(backup_filepath))
            test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))

        saving_filepath = os.path.join(self.get_temp_dir(), 'checkpoint')

        multi_process_runner.run(
            proc_model_checkpoint_works_with_same_file_path,
            cluster_spec=test_base.create_cluster_spec(num_workers=2),
            args=(self, saving_filepath))
Example 11
    def test_barrier(self):
        multi_process_runner.run(
            proc_func_with_barrier,
            cluster_spec=multi_worker_test_base.create_cluster_spec(
                has_chief=True, num_workers=1),
        )
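`proc_func_with_barrier` is not shown on this page; given the `multi_process_runner.get_barrier()` calls in Examples 10 and 12, a minimal version would be along these lines (a sketch, not the actual test helper):

def proc_func_with_barrier():
    # Sketch only: every spawned process blocks at the barrier until all of
    # them have arrived, then they all return together.
    return multi_process_runner.get_barrier().wait()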
Example 12
    def testMultiWorkerTutorial(self, mode, shard_policy):
        """Test multi-worker training flow demo'ed in go/multi-worker-with-keras.

        This test should be kept in sync with the code samples in
        go/multi-worker-with-keras.

        Args:
          mode: Runtime mode.
          shard_policy: None or any of tf.data.experimental.AutoShardPolicy for
            testing.
        """
        if shard_policy is distribute_options.AutoShardPolicy.FILE:
            self.skipTest(
                'TensorSliceDataset is not shardable with FILE policy.')

        def mnist_dataset(batch_size):
            with self.skip_fetch_failure_exception():
                (x_train, y_train), _ = mnist.load_data()
            # The `x` arrays are in uint8 and have values in the range [0, 255].
            # We need to convert them to float32 with values in the range [0, 1]
            x_train = x_train / np.float32(255)
            y_train = y_train.astype(np.int64)
            train_dataset = dataset_ops.DatasetV2.from_tensor_slices(
                (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
            return train_dataset

        def build_and_compile_cnn_model():
            model = keras.Sequential([
                keras.layers.Input(shape=(28, 28)),
                keras.layers.Reshape(target_shape=(28, 28, 1)),
                keras.layers.Conv2D(32, 3, activation='relu'),
                keras.layers.Flatten(),
                keras.layers.Dense(128, activation='relu'),
                keras.layers.Dense(10)
            ])
            model.compile(loss=keras.losses.SparseCategoricalCrossentropy(
                from_logits=True),
                          optimizer=gradient_descent.SGD(learning_rate=0.001),
                          metrics=['accuracy'])
            return model

        per_worker_batch_size = 64

        single_worker_dataset = mnist_dataset(per_worker_batch_size)
        single_worker_model = build_and_compile_cnn_model()
        single_worker_model.fit(single_worker_dataset,
                                epochs=3,
                                steps_per_epoch=70)

        num_workers = 4

        def fn(model_path, checkpoint_dir):
            global_batch_size = per_worker_batch_size * num_workers
            strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
            )
            with strategy.scope():
                multi_worker_model = build_and_compile_cnn_model()

            callbacks = [
                keras.callbacks.ModelCheckpoint(
                    filepath=os.path.join(self.get_temp_dir(), 'checkpoint'))
            ]

            multi_worker_dataset = mnist_dataset(global_batch_size)
            if shard_policy:
                options = dataset_ops.Options()
                options.experimental_distribute.auto_shard_policy = shard_policy
                multi_worker_dataset = multi_worker_dataset.with_options(
                    options)

            multi_worker_model.fit(multi_worker_dataset,
                                   epochs=2,
                                   steps_per_epoch=20,
                                   callbacks=callbacks)

            def _is_chief(task_type, task_id):
                return task_type is None or task_type == 'chief' or (
                    task_type == 'worker' and task_id == 0)

            def _get_temp_dir(dirpath, task_id):
                base_dirpath = 'workertemp_' + str(task_id)
                temp_dir = os.path.join(dirpath, base_dirpath)
                file_io.recursive_create_dir_v2(temp_dir)
                return temp_dir

            def write_filepath(filepath, task_type, task_id):
                dirpath = os.path.dirname(filepath)
                base = os.path.basename(filepath)
                if not _is_chief(task_type, task_id):
                    dirpath = _get_temp_dir(dirpath, task_id)
                return os.path.join(dirpath, base)

            task_type, task_id = (strategy.cluster_resolver.task_type,
                                  strategy.cluster_resolver.task_id)
            write_model_path = write_filepath(model_path, task_type, task_id)

            multi_worker_model.save(write_model_path)
            if not _is_chief(task_type, task_id):
                file_io.delete_recursively_v2(
                    os.path.dirname(write_model_path))

            # Make sure chief finishes saving before non-chief's assertions.
            multi_process_runner.get_barrier().wait()

            if not file_io.file_exists_v2(model_path):
                raise RuntimeError()
            if file_io.file_exists_v2(write_model_path) != _is_chief(
                    task_type, task_id):
                raise RuntimeError()

            loaded_model = keras.saving.save.load_model(model_path)
            loaded_model.fit(multi_worker_dataset,
                             epochs=2,
                             steps_per_epoch=20)

            checkpoint = tracking_util.Checkpoint(model=multi_worker_model)
            write_checkpoint_dir = write_filepath(checkpoint_dir, task_type,
                                                  task_id)
            checkpoint_manager = checkpoint_management.CheckpointManager(
                checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

            checkpoint_manager.save()
            if not _is_chief(task_type, task_id):
                file_io.delete_recursively_v2(write_checkpoint_dir)

            # Make sure chief finishes saving before non-chief's assertions.
            multi_process_runner.get_barrier().wait()

            if not file_io.file_exists_v2(checkpoint_dir):
                raise RuntimeError()
            if file_io.file_exists_v2(write_checkpoint_dir) != _is_chief(
                    task_type, task_id):
                raise RuntimeError()

            latest_checkpoint = checkpoint_management.latest_checkpoint(
                checkpoint_dir)
            checkpoint.restore(latest_checkpoint)
            multi_worker_model.fit(multi_worker_dataset,
                                   epochs=2,
                                   steps_per_epoch=20)

            logging.info('testMultiWorkerTutorial successfully ends')

        model_path = os.path.join(self.get_temp_dir(), 'model.tf')
        checkpoint_dir = os.path.join(self.get_temp_dir(), 'ckpt')
        with test_util.skip_if_error(self, errors_impl.UnavailableError):
            mpr_result = multi_process_runner.run(
                fn,
                multi_worker_test_base.create_cluster_spec(
                    num_workers=num_workers),
                args=(model_path, checkpoint_dir),
                return_output=True)

        self.assertTrue(
            any([
                'testMultiWorkerTutorial successfully ends' in msg
                for msg in mpr_result.stdout
            ]))

        def extract_accuracy(worker_id, input_string):
            match = re.match(
                r'\[worker\-{}\].*accuracy: (\d+\.\d+).*'.format(worker_id),
                input_string)
            return None if match is None else float(match.group(1))

        for worker_id in range(num_workers):
            accu_result = nest.map_structure(
                lambda x: extract_accuracy(worker_id, x),  # pylint: disable=cell-var-from-loop
                mpr_result.stdout)
            self.assertTrue(
                any(accu_result),
                'Every worker is supposed to have accuracy result.')
Example 13
    def testMultiWorkerTutorial(self, mode, shard_policy):
        """Test multi-worker training flow demo'ed in go/multi-worker-with-keras.

        This test should be kept in sync with the code samples in
        go/multi-worker-with-keras.

        Args:
          mode: Runtime mode.
          shard_policy: None or any of tf.data.experimental.AutoShardPolicy for
            testing.
        """
        if shard_policy is distribute_options.AutoShardPolicy.FILE:
            self.skipTest(
                'TensorSliceDataset is not shardable with FILE policy.')

        def mnist_dataset(batch_size):
            with self.skip_fetch_failure_exception():
                (x_train, y_train), _ = mnist.load_data()
            # The `x` arrays are in uint8 and have values in the range [0, 255].
            # We need to convert them to float32 with values in the range [0, 1]
            x_train = x_train / np.float32(255)
            y_train = y_train.astype(np.int64)
            train_dataset = dataset_ops.DatasetV2.from_tensor_slices(
                (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
            return train_dataset

        def build_and_compile_cnn_model():
            model = keras.Sequential([
                keras.layers.Input(shape=(28, 28)),
                keras.layers.Reshape(target_shape=(28, 28, 1)),
                keras.layers.Conv2D(32, 3, activation='relu'),
                keras.layers.Flatten(),
                keras.layers.Dense(128, activation='relu'),
                keras.layers.Dense(10)
            ])
            model.compile(loss=keras.losses.SparseCategoricalCrossentropy(
                from_logits=True),
                          optimizer=gradient_descent.SGD(learning_rate=0.001),
                          metrics=['accuracy'])
            return model

        per_worker_batch_size = 64

        single_worker_dataset = mnist_dataset(per_worker_batch_size)
        single_worker_model = build_and_compile_cnn_model()
        single_worker_model.fit(single_worker_dataset,
                                epochs=3,
                                steps_per_epoch=70)

        num_workers = 4

        def proc_func():
            global_batch_size = per_worker_batch_size * num_workers
            strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
            )
            with strategy.scope():
                multi_worker_model = build_and_compile_cnn_model()

            callbacks = [
                keras.callbacks.ModelCheckpoint(
                    filepath=os.path.join(self.get_temp_dir(), 'checkpoint'))
            ]

            multi_worker_dataset = mnist_dataset(global_batch_size)
            if shard_policy:
                options = dataset_ops.Options()
                options.experimental_distribute.auto_shard_policy = shard_policy
                multi_worker_dataset = multi_worker_dataset.with_options(
                    options)

            multi_worker_model.fit(multi_worker_dataset,
                                   epochs=2,
                                   steps_per_epoch=20,
                                   callbacks=callbacks)

        with test_util.skip_if_error(self, errors_impl.UnavailableError):
            mpr_result = multi_process_runner.run(
                proc_func,
                multi_worker_test_base.create_cluster_spec(
                    num_workers=num_workers),
                list_stdout=True)

        def extract_accuracy(worker_id, input_string):
            match = re.match(
                r'\[worker\-{}\].*accuracy: (\d+\.\d+).*'.format(worker_id),
                input_string)
            return None if match is None else float(match.group(1))

        for worker_id in range(num_workers):
            accu_result = nest.map_structure(
                lambda x: extract_accuracy(worker_id, x),  # pylint: disable=cell-var-from-loop
                mpr_result.stdout)
            self.assertTrue(
                any(accu_result),
                'Every worker is supposed to have accuracy result.')