Example #1
def test_adapt_short_sequence() -> None:
    sequence = IdentitySequence(3)
    with pytest.raises(check.CheckFailedError):
        with keras._build_enqueuer(
                sequence=sequence,
                workers=0,
                use_multiprocessing=False,
                max_queue_size=10,
                shard_rank=0,
                num_shards=4,
                repeat=False,
                shuffle=False,
                shuffle_seed=0,
                prior_batches_trained=0,
        ):
            pass
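The test above relies on an IdentitySequence helper that is not shown in this snippet; presumably the CheckFailedError fires because a 3-item sequence cannot be split across num_shards=4. A minimal sketch of what such a Sequence might look like (an assumption for illustration, not the project's actual helper):

import tensorflow as tf


class IdentitySequence(tf.keras.utils.Sequence):
    # Hypothetical helper: a Sequence of `length` batches where batch i is just i.
    def __init__(self, length: int) -> None:
        self._length = length

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, index: int) -> int:
        return index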
Example #2
    def _launch_fit(self) -> None:
        training_data = self.training_data

        if isinstance(training_data, tf.keras.utils.Sequence):
            # Handle args from fit(): shuffle, workers, use_multiprocessing, and max_queue_size.
            enqueuer = keras._build_enqueuer(
                sequence=training_data,
                workers=self.context._fit_workers,
                use_multiprocessing=self.context._fit_use_multiprocessing,
                max_queue_size=self.context._fit_max_queue_size,
                shard_rank=self.context.distributed.get_rank(),
                num_shards=self.context.distributed.get_size(),
                repeat=True,
                shuffle=self.context._fit_shuffle,
                shuffle_seed=self.context.get_trial_seed(),
                prior_batches_trained=self.env.initial_workload.total_batches_processed,
            )
            enqueuer.start()
            self.enqueuers.append(enqueuer)
            training_data = enqueuer.data()

        if isinstance(training_data, tf.data.Dataset):
            training_data = training_data.repeat()
            if self.context._fit_shuffle:
                logging.warning(
                    "You set shuffle=True for a tf.data.Dataset, which will be ignored. "
                    "Please call .shuffle() on your dataset instead.")

        self.model.fit(
            training_data,
            class_weight=self.context._fit_class_weight,
            callbacks=self.callback_list,
            shuffle=False,
            steps_per_epoch=sys.maxsize,
            epochs=IMPOSSIBLY_LARGE_EPOCHS,
            validation_split=0,
            verbose=0,
            workers=0,
        )
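For intuition, repeat=True makes the enqueuer behave like an endless generator over the Sequence, which is why fit() is called with steps_per_epoch=sys.maxsize and an enormous epoch count: the surrounding harness, not Keras, decides when training stops. A minimal sketch of that repeat-forever pattern (a hypothetical stand-in, not determined's actual enqueuer):

import itertools

import tensorflow as tf


def repeating_batches(seq: tf.keras.utils.Sequence):
    # Yield batches forever; fit() never observes the end of an epoch.
    for _ in itertools.count():
        for i in range(len(seq)):
            yield seq[i]
        seq.on_epoch_end()  # preserve the usual end-of-epoch hook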
Example #3
    def _launch_evaluate(self) -> Any:
        validation_data = self.validation_data
        steps = None

        if isinstance(validation_data, tf.keras.utils.Sequence):
            # Calculate the length of our validation shard.
            steps = len(validation_data)
            if self.context.distributed.get_size() > 1:
                size = self.context.distributed.get_size()
                rank = self.context.distributed.get_rank()
                steps = steps // size + (1 if steps % size > rank else 0)

            # Handle args from fit(): shuffle, workers, use_multiprocessing, and max_queue_size.
            enqueuer = keras._build_enqueuer(
                sequence=validation_data,
                workers=self.context._fit_workers,
                use_multiprocessing=self.context._fit_use_multiprocessing,
                max_queue_size=self.context._fit_max_queue_size,
                shard_rank=self.context.distributed.get_rank(),
                num_shards=self.context.distributed.get_size(),
                repeat=False,
                shuffle=False,
                shuffle_seed=0,
                prior_batches_trained=0,
            )
            enqueuer.start()
            self.enqueuers.append(enqueuer)
            validation_data = enqueuer.data()

        if isinstance(validation_data, tf.data.Dataset):
            # Handle validation_steps, which in Keras only applies to tf.data.Datasets.
            steps = self.context._fit_validation_steps

        # Starting in TF 2.2, users may define a custom test_step() that does
        # not use the model metrics.
        use_model_metrics = not (
            version.parse(tf.__version__) >= version.parse("2.2.0")
            and is_tf2_enabled() and tf.executing_eagerly())
        evaluate_kwargs = {} if use_model_metrics else {"return_dict": True}

        if self.env.test_mode:
            steps = 1

        metrics_values = self.model.evaluate(
            validation_data,
            callbacks=self.callback_list,
            steps=steps,
            verbose=0,
            workers=0,
            **evaluate_kwargs,
        )
        logging.debug(
            f"Worker finished model.evaluate() with metrics: {metrics_values}."
        )

        # Clean up the enqueuer if we started one.
        if isinstance(self.validation_data, tf.keras.utils.Sequence):
            enqueuer.stop()
            self.enqueuers.remove(enqueuer)

            # A side effect of converting the Keras Sequence to a generator and passing
            # `steps` explicitly is that Keras exits our generator after exactly `steps`
            # calls to next(), so the Sequence.on_epoch_end() that normally runs after
            # the last yield never fires. Calling on_epoch_end() manually here matches
            # the standard Keras behavior.
            self.validation_data.on_epoch_end()

        # If the model was compiled with metrics=None, metrics_values will be a single value.
        if not isinstance(metrics_values, (tuple, list, dict)):
            metrics_values = (metrics_values, )

        if use_model_metrics:
            metrics = make_logs(self.model, {},
                                metrics_values,
                                ModeKeys.TEST,
                                prefix="val_")
        else:
            check.is_instance(metrics_values, dict)
            metrics = {f"val_{k}": v for k, v in metrics_values.items()}

        return metrics
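The shard-length arithmetic near the top of this example gives every rank total // size steps and hands one leftover batch each to the lowest total % size ranks, so all batches are evaluated exactly once. A self-contained check of the same formula (the function name here is illustrative):

def shard_steps(total: int, size: int, rank: int) -> int:
    # Same formula as in _launch_evaluate() above.
    return total // size + (1 if total % size > rank else 0)


# 10 validation batches over 4 workers: ranks 0-1 get 3 steps, ranks 2-3 get 2.
assert [shard_steps(10, 4, r) for r in range(4)] == [3, 3, 2, 2]
assert sum(shard_steps(10, 4, r) for r in range(4)) == 10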
Example #4
    def _launch_evaluate(self) -> Any:
        validation_data = self.validation_data
        steps = None

        # Support the deprecated SequenceAdapter API.
        if isinstance(validation_data, keras.SequenceAdapter):
            # Ignore these settings and use the same settings as for the fit call.
            validation_data = validation_data.sequence

        if isinstance(validation_data, tf.keras.utils.Sequence):
            # Calculate the length of our validation shard.
            steps = len(validation_data)
            if self.context.distributed.get_size() > 1:
                size = self.context.distributed.get_size()
                rank = self.context.distributed.get_rank()
                steps = steps // size + (1 if steps % size > rank else 0)

            # Handle args from fit(): shuffle, workers, use_multiprocessing, and max_queue_size.
            enqueuer = keras._build_enqueuer(
                sequence=validation_data,
                workers=self.context._fit_workers,
                use_multiprocessing=self.context._fit_use_multiprocessing,
                max_queue_size=self.context._fit_max_queue_size,
                shard_rank=self.context.distributed.get_rank(),
                num_shards=self.context.distributed.get_size(),
                repeat=False,
                shuffle=False,
                shuffle_seed=0,
                prior_batches_trained=0,
            )
            enqueuer.start()
            self.enqueuers.append(enqueuer)
            validation_data = enqueuer.data()

        if isinstance(validation_data, tf.data.Dataset):
            # Handle validation_steps, which in Keras only applies to tf.data.Datasets.
            steps = self.context._fit_validation_steps

        # Starting in TF 2.2, users may define a custom test_step() that does
        # not use the model metrics.
        use_model_metrics = version.parse(tf.__version__) < version.parse("2.2.0")
        evaluate_kwargs = {} if use_model_metrics else {"return_dict": True}

        metrics_values = self.model.evaluate(
            validation_data,
            callbacks=self.callback_list,
            steps=steps,
            verbose=0,
            workers=0,
            **evaluate_kwargs,
        )
        logging.debug(
            f"Worker finished model.evaluate() with metrics: {metrics_values}."
        )

        # If the model was compiled with metrics=None, metrics_values will be a single value.
        if not isinstance(metrics_values, (tuple, list, dict)):
            metrics_values = (metrics_values, )

        if use_model_metrics:
            metrics = make_logs(self.model, {},
                                metrics_values,
                                ModeKeys.TEST,
                                prefix="val_")
        else:
            check.is_instance(metrics_values, dict)
            metrics = {f"val_{k}": v for k, v in metrics_values.items()}

        return metrics
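When use_model_metrics is true (TF older than 2.2), evaluate() returns metric values positionally rather than as a dict, and the make_logs() call pairs them with the compiled metric names. Roughly, the effect is something like the following sketch (an approximation, not the actual make_logs implementation):

def positional_metrics_to_logs(model, metrics_values, prefix: str = "val_") -> dict:
    # Assumption: metrics_values lines up one-to-one with model.metrics_names,
    # which is how Keras reported evaluate() results before return_dict existed.
    return {
        f"{prefix}{name}": value
        for name, value in zip(model.metrics_names, metrics_values)
    }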