def test_adapt_short_sequence() -> None:
    # A 3-batch sequence cannot be split across 4 shards, so building the
    # enqueuer should fail the length check.
    sequence = IdentitySequence(3)
    with pytest.raises(check.CheckFailedError):
        with keras._build_enqueuer(
            sequence=sequence,
            workers=0,
            use_multiprocessing=False,
            max_queue_size=10,
            shard_rank=0,
            num_shards=4,
            repeat=False,
            shuffle=False,
            shuffle_seed=0,
            prior_batches_trained=0,
        ):
            pass
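# A minimal sketch of the IdentitySequence fixture used above, assuming it is
# just a tf.keras.utils.Sequence of a fixed length whose "batches" are their
# own indices; the real fixture may differ. With length 3 and num_shards=4
# there is not even one batch per shard, which is what trips the check.
class IdentitySequence(tf.keras.utils.Sequence):
    def __init__(self, length: int) -> None:
        self._length = length

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, index: int) -> int:
        assert index < self._length
        return index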
def _launch_fit(self) -> None:
    training_data = self.training_data

    if isinstance(training_data, tf.keras.utils.Sequence):
        # Handle args from fit(): shuffle, workers, use_multiprocessing, and max_queue_size.
        enqueuer = keras._build_enqueuer(
            sequence=training_data,
            workers=self.context._fit_workers,
            use_multiprocessing=self.context._fit_use_multiprocessing,
            max_queue_size=self.context._fit_max_queue_size,
            shard_rank=self.context.distributed.get_rank(),
            num_shards=self.context.distributed.get_size(),
            repeat=True,
            shuffle=self.context._fit_shuffle,
            shuffle_seed=self.context.get_trial_seed(),
            prior_batches_trained=self.env.initial_workload.total_batches_processed,
        )
        enqueuer.start()
        self.enqueuers.append(enqueuer)
        training_data = enqueuer.data()

    if isinstance(training_data, tf.data.Dataset):
        training_data = training_data.repeat()
        if self.context._fit_shuffle:
            logging.warning(
                "You set shuffle=True for a tf.data.Dataset, which will be ignored. "
                "Please call .shuffle() on your dataset instead."
            )

    # The enqueuer (or the repeated dataset) already handles sharding and
    # shuffling, so disable Keras's own shuffling and worker threads, and use
    # effectively infinite steps/epochs so the training loop, not Keras,
    # decides when to stop.
    self.model.fit(
        training_data,
        class_weight=self.context._fit_class_weight,
        callbacks=self.callback_list,
        shuffle=False,
        steps_per_epoch=sys.maxsize,
        epochs=IMPOSSIBLY_LARGE_EPOCHS,
        validation_split=0,
        verbose=0,
        workers=0,
    )
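# Sketch of the guidance in the warning above (a standalone illustration, not
# part of the module): with tf.data, shuffling belongs on the dataset itself
# rather than on fit(shuffle=True). Buffer size, seed, and batch size are
# illustrative.
def _example_shuffled_dataset() -> tf.data.Dataset:
    dataset = tf.data.Dataset.range(1000)
    # shuffle() must precede batch() so individual records, not whole
    # batches, are shuffled.
    return dataset.shuffle(buffer_size=1000, seed=0).batch(32)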
def _launch_evaluate(self) -> Any:
    validation_data = self.validation_data
    steps = None

    if isinstance(validation_data, tf.keras.utils.Sequence):
        # Calculate the length of our validation shard.
        steps = len(validation_data)
        if self.context.distributed.get_size() > 1:
            size = self.context.distributed.get_size()
            rank = self.context.distributed.get_rank()
            steps = steps // size + (1 if steps % size > rank else 0)

        # Handle args from fit(): workers, use_multiprocessing, and max_queue_size.
        enqueuer = keras._build_enqueuer(
            sequence=validation_data,
            workers=self.context._fit_workers,
            use_multiprocessing=self.context._fit_use_multiprocessing,
            max_queue_size=self.context._fit_max_queue_size,
            shard_rank=self.context.distributed.get_rank(),
            num_shards=self.context.distributed.get_size(),
            repeat=False,
            shuffle=False,
            shuffle_seed=0,
            prior_batches_trained=0,
        )
        enqueuer.start()
        self.enqueuers.append(enqueuer)
        validation_data = enqueuer.data()

    if isinstance(validation_data, tf.data.Dataset):
        # Handle validation_steps, which in Keras only applies to tf.data.Datasets.
        steps = self.context._fit_validation_steps

    # Starting in TF 2.2, users may define a custom test_step() that does not
    # use the model metrics.
    use_model_metrics = not (
        version.parse(tf.__version__) >= version.parse("2.2.0")
        and is_tf2_enabled()
        and tf.executing_eagerly()
    )
    evaluate_kwargs = {} if use_model_metrics else {"return_dict": True}

    if self.env.test_mode:
        steps = 1

    metrics_values = self.model.evaluate(
        validation_data,
        callbacks=self.callback_list,
        steps=steps,
        verbose=0,
        workers=0,
        **evaluate_kwargs,
    )
    logging.debug(f"Worker finished model.evaluate() with metrics: {metrics_values}.")

    # Clean up the enqueuer if we started one.
    if isinstance(self.validation_data, tf.keras.utils.Sequence):
        enqueuer.stop()
        self.enqueuers.remove(enqueuer)

        # A special side effect of converting the keras sequence to a generator and
        # passing steps explicitly is that keras will exit our generator after N steps,
        # and the Sequence.on_epoch_end() that normally runs after the last yield won't
        # run at all because the fit loop will call next() exactly `steps` times. So we
        # try to match the exact keras behavior by manually calling on_epoch_end() here.
        self.validation_data.on_epoch_end()

    # If the model was compiled with metrics=None, metrics_values will be a single value.
    if not isinstance(metrics_values, (tuple, list, dict)):
        metrics_values = (metrics_values,)

    if use_model_metrics:
        metrics = make_logs(self.model, {}, metrics_values, ModeKeys.TEST, prefix="val_")
    else:
        check.is_instance(metrics_values, dict)
        metrics = {f"val_{k}": v for k, v in metrics_values.items()}

    return metrics
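# Worked example of the shard-length arithmetic in _launch_evaluate() above
# (a standalone sketch, not part of the module). Every rank gets
# floor(steps / size) batches, and the first (steps % size) ranks each pick
# up one leftover batch, so all batches are evaluated exactly once.
def _example_shard_steps(steps: int, size: int, rank: int) -> int:
    return steps // size + (1 if steps % size > rank else 0)

# 10 validation batches across 4 ranks -> [3, 3, 2, 2], which sums to 10.
assert [_example_shard_steps(10, 4, r) for r in range(4)] == [3, 3, 2, 2]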
def _launch_evaluate(self) -> Any:
    validation_data = self.validation_data
    steps = None

    # Support the deprecated SequenceAdapter API.
    if isinstance(validation_data, keras.SequenceAdapter):
        # Ignore the adapter's loader settings and use the same settings as for the fit call.
        validation_data = validation_data.sequence

    if isinstance(validation_data, tf.keras.utils.Sequence):
        # Calculate the length of our validation shard.
        steps = len(validation_data)
        if self.context.distributed.get_size() > 1:
            size = self.context.distributed.get_size()
            rank = self.context.distributed.get_rank()
            steps = steps // size + (1 if steps % size > rank else 0)

        # Handle args from fit(): workers, use_multiprocessing, and max_queue_size.
        enqueuer = keras._build_enqueuer(
            sequence=validation_data,
            workers=self.context._fit_workers,
            use_multiprocessing=self.context._fit_use_multiprocessing,
            max_queue_size=self.context._fit_max_queue_size,
            shard_rank=self.context.distributed.get_rank(),
            num_shards=self.context.distributed.get_size(),
            repeat=False,
            shuffle=False,
            shuffle_seed=0,
            prior_batches_trained=0,
        )
        enqueuer.start()
        self.enqueuers.append(enqueuer)
        validation_data = enqueuer.data()

    if isinstance(validation_data, tf.data.Dataset):
        # Handle validation_steps, which in Keras only applies to tf.data.Datasets.
        steps = self.context._fit_validation_steps

    # Starting in TF 2.2, users may define a custom test_step() that does not
    # use the model metrics.
    use_model_metrics = version.parse(tf.__version__) < version.parse("2.2.0")
    evaluate_kwargs = {} if use_model_metrics else {"return_dict": True}

    metrics_values = self.model.evaluate(
        validation_data,
        callbacks=self.callback_list,
        steps=steps,
        verbose=0,
        workers=0,
        **evaluate_kwargs,
    )
    logging.debug(f"Worker finished model.evaluate() with metrics: {metrics_values}.")

    # If the model was compiled with metrics=None, metrics_values will be a single value.
    if not isinstance(metrics_values, (tuple, list, dict)):
        metrics_values = (metrics_values,)

    if use_model_metrics:
        metrics = make_logs(self.model, {}, metrics_values, ModeKeys.TEST, prefix="val_")
    else:
        check.is_instance(metrics_values, dict)
        metrics = {f"val_{k}": v for k, v in metrics_values.items()}

    return metrics
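# A minimal sketch of the deprecated SequenceAdapter unwrapped above, assuming
# it simply pairs a Sequence with loader settings that this code path now
# ignores in favor of the fit() settings. Attribute names other than
# .sequence are assumptions.
class _SketchSequenceAdapter:
    def __init__(
        self,
        sequence: tf.keras.utils.Sequence,
        workers: int = 1,
        use_multiprocessing: bool = False,
        max_queue_size: int = 10,
    ) -> None:
        self.sequence = sequence
        self.workers = workers
        self.use_multiprocessing = use_multiprocessing
        self.max_queue_size = max_queue_size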