def experimental_tpu_fit_loop(model,
                              dataset,
                              epochs=100,
                              verbose=1,
                              callbacks=None,
                              initial_epoch=0,
                              steps_per_epoch=None,
                              val_dataset=None,
                              validation_steps=None,
                              validation_freq=1):
  """Fit loop for training with TPU tf.distribute.Strategy.

  Args:
      model: Keras Model instance.
      dataset: Dataset that returns inputs and targets.
      epochs: Number of times to iterate over the data.
      verbose: Integer, Verbosity mode, 0, 1 or 2.
      callbacks: List of callbacks to be called during training.
      initial_epoch: Epoch at which to start training (useful for resuming a
          previous training run).
      steps_per_epoch: Total number of steps (batches of samples) before
          declaring one epoch finished and starting the next epoch. Ignored
          with the default value of `None`.
      val_dataset: Dataset for validation data.
      validation_steps: Number of steps to run validation for (only if doing
          validation from data tensors). Ignored with the default value of
          `None`.
      validation_freq: Only relevant if validation data is provided. Integer
          or `collections.abc.Container` instance (e.g. list, tuple, etc.).
          If an integer, specifies how many training epochs to run before a
          new validation run is performed, e.g. `validation_freq=2` runs
          validation every 2 epochs. If a Container, specifies the epochs on
          which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
          validation at the end of the 1st, 2nd, and 10th epochs.

  Returns:
      The model's `History` object.

  Raises:
      ValueError: in case of invalid arguments.
  """
  mode = ModeKeys.TRAIN

  current_strategy = model._distribution_strategy
  iteration_value = min(steps_per_epoch,
                        current_strategy.extended.steps_per_run)
  steps_per_run = K.variable(
      value=iteration_value, dtype='int32', name='steps_per_run')

  # TODO(fchollet): add support for `steps_per_epoch=None` in TPU loops.
  iterator = dist_utils.get_iterator(dataset, current_strategy)

  scope = dist_utils.distributed_scope(
      strategy=current_strategy, learning_phase=1)
  scope.__enter__()

  out_labels = model.metrics_names or []

  step_fn = _make_train_step_fn(model, ModeKeys.TRAIN, current_strategy,
                                out_labels)

  # Add initial dummy values for loss and other metric tensors.
  initial_loop_values = {}
  initial_loop_values['loss'] = tf.constant(1e7)
  for m in model._get_training_eval_metrics():
    tensor = m.result()
    initial_loop_values[m.name] = tf.zeros(tensor.shape, tensor.dtype)

  ctx = current_strategy.extended.experimental_run_steps_on_iterator(
      step_fn,
      iterator,
      iterations=steps_per_run,
      initial_loop_values=initial_loop_values)
  train_op = ctx.run_op
  output_tensors = ctx.last_step_outputs

  do_validation = bool(validation_steps)

  if model._compile_distribution:
    dist_utils._copy_weights_to_distributed_model(model, mode)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      verbose=verbose,
      count_mode='steps',
      mode=mode)

  # Calculate the steps each time on the device.
  steps_to_run = ([current_strategy.extended.steps_per_run] *
                  (steps_per_epoch //
                   current_strategy.extended.steps_per_run))
  if steps_per_epoch % current_strategy.extended.steps_per_run:
    steps_to_run.append(
        steps_per_epoch % current_strategy.extended.steps_per_run)
  target_steps = len(steps_to_run)

  callbacks._call_begin_hook(mode)

  initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
      initial_epoch, mode)

  for epoch in range(initial_epoch, epochs):
    dist_utils._reset_metrics(model)
    callbacks.on_epoch_begin(epoch)
    epoch_logs = {}
    step_index = 0
    prev_step_count = None
    current_step = 0
    while current_step < target_steps:
      step_count = steps_to_run[current_step]
      batch_logs = {
          'batch': step_index,
          'size': 1,
          'num_steps': step_count
      }
      callbacks._call_batch_hook(mode, 'begin', step_index, batch_logs)
      if prev_step_count is None or step_count != prev_step_count:
        K.get_session().run(steps_per_run.assign(step_count))
        prev_step_count = step_count
      try:
        _, outputs = K.batch_get_value([train_op, output_tensors])
      except tf.errors.OutOfRangeError:
        logging.warning('Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).' %
                        (steps_per_epoch * epochs))
        break

      batch_logs.update(outputs)
      callbacks._call_batch_hook(mode, 'end', step_index, batch_logs)
      step_index = step_index + step_count
      current_step += 1

      if callbacks.model.stop_training:
        break

    if (do_validation and
        training_utils_v1.should_run_validation(validation_freq, epoch)):
      logging.info('Running validation at fit epoch: %s', epoch)

      if model._compile_distribution:
        # Since we create a new clone from the original model we need to copy
        # the weights back to the original model before we can run validation.
        dist_utils._copy_weights_to_original_model(model, ModeKeys.TRAIN)

      val_outs = experimental_tpu_test_loop(  # pylint: disable=undefined-variable
          model,
          val_dataset,
          steps=validation_steps,
          verbose=verbose,
          callbacks=callbacks)
      if not isinstance(val_outs, list):
        val_outs = [val_outs]
      # Same labels assumed.
      for label, val_out in zip(out_labels, val_outs):
        epoch_logs['val_' + label] = val_out

    callbacks.on_epoch_end(epoch, epoch_logs)
    if callbacks.model.stop_training:
      break

  model._successful_loop_finish = True
  callbacks._call_end_hook(mode)

  if model._compile_distribution:
    # Copy the weights back from the replicated model to the original model.
    dist_utils._copy_weights_to_original_model(model, ModeKeys.TRAIN)
  scope.__exit__(None, None, None)
  return model.history
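
# --- Hypothetical usage sketch (not part of the original module). ---
# `experimental_tpu_fit_loop` is an internal loop: user code normally reaches
# it through `Model.fit` when the model was compiled under a TPU
# `tf.distribute.Strategy`. The sketch below only illustrates the expected
# arguments, assuming a compiled Keras `model` whose `_distribution_strategy`
# is a TPU strategy and a repeated `tf.data.Dataset`; it is not called
# anywhere in this module.
def _example_experimental_tpu_fit_loop_usage(model, dataset, val_dataset):
  """Sketch of driving the TPU fit loop directly (normally done by `fit`)."""
  history = experimental_tpu_fit_loop(
      model,
      dataset,
      epochs=2,
      steps_per_epoch=100,  # required: `steps_per_epoch=None` is unsupported.
      val_dataset=val_dataset,
      validation_steps=10,
      validation_freq=1)
  return history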
def model_iteration(
    model,
    inputs,
    targets=None,
    sample_weights=None,
    batch_size=None,
    epochs=1,
    verbose=1,
    callbacks=None,
    val_inputs=None,
    val_targets=None,
    val_sample_weights=None,
    shuffle=True,
    initial_epoch=0,
    steps_per_epoch=None,
    validation_steps=None,
    validation_freq=1,
    mode=ModeKeys.TRAIN,
    validation_in_fit=False,
    prepared_feed_values_from_dataset=False,
    steps_name="steps",
    **kwargs,
):
    """Loop function for arrays of data with modes TRAIN/TEST/PREDICT.

    Args:
        model: Keras Model instance.
        inputs: Either a list or dictionary of arrays, or a dataset instance.
        targets: List/dictionary of target arrays.
        sample_weights: Optional list of sample weight arrays.
        batch_size: Integer batch size or None if unknown.
        epochs: Number of times to iterate over the data.
        verbose: 0, 1, or 2. Verbosity mode.
            0 = silent, 1 = progress bar, 2 = one line per epoch.
            Note that the progress bar is not particularly useful when logged
            to a file, so verbose=2 is recommended when not running
            interactively (eg, in a production environment).
        callbacks: List of callbacks to be called during training.
        val_inputs: Either a list or dictionary of arrays, or a dataset
            instance.
        val_targets: List/dictionary of target arrays.
        val_sample_weights: Optional list of sample weight arrays.
        shuffle: Whether to shuffle the data at the beginning of each epoch.
        initial_epoch: Epoch at which to start training (useful for resuming a
            previous training run).
        steps_per_epoch: Total number of steps (batches of samples) before
            declaring one epoch finished and starting the next epoch. Ignored
            with the default value of `None`.
        validation_steps: Number of steps to run validation for (only if doing
            validation from data tensors). Ignored with the default value of
            `None`.
        validation_freq: Only relevant if validation data is provided. Integer
            or `collections.abc.Container` instance (e.g. list, tuple, etc.).
            If an integer, specifies how many training epochs to run before a
            new validation run is performed, e.g. `validation_freq=2` runs
            validation every 2 epochs. If a Container, specifies the epochs on
            which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
            validation at the end of the 1st, 2nd, and 10th epochs.
        mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
        validation_in_fit: if true, then this method is invoked from within
            training iteration (for validation). In the case where `val_inputs`
            is a dataset, this flag indicates that its iterator and feed values
            are already created so should properly reuse resources.
        prepared_feed_values_from_dataset: if True, `inputs` is a list of feed
            tensors returned from `_prepare_feed_values` call on the validation
            dataset, so do not call it again on `inputs`. Should only be used
            for inline validation (i.e., only if `validation_in_fit` is also
            True).
        steps_name: The string name of the steps argument, either `steps`,
            `validation_steps`, or `steps_per_epoch`. Only used for error
            message formatting.
        **kwargs: Additional arguments for backwards compatibility.

    Returns:
        - In TRAIN mode: `History` object.
        - In TEST mode: Evaluation metrics.
        - In PREDICT mode: Outputs of the Model called on inputs.

    Raises:
        ValueError: in case of invalid arguments.
    """
    # Backwards compatibility.
    if "steps" in kwargs:
        steps_per_epoch = kwargs.pop("steps")
    if kwargs:
        raise TypeError("Unknown arguments: %s" % (kwargs,))

    # In case we were passed a dataset, we extract symbolic tensors from it.
    reset_dataset_after_each_epoch = False
    input_iterator = None
    is_dataset = isinstance(
        inputs, (tf.compat.v1.data.Dataset, tf.data.Dataset)
    )
    # TODO(fchollet): consider moving `steps_per_epoch` inference to
    # _standardize_user_data and set reset_dataset_after_each_epoch as an
    # attribute on the dataset instance.
    if is_dataset:
        if steps_per_epoch is None:
            reset_dataset_after_each_epoch = True
            steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
                model,
                inputs,
                steps_per_epoch,
                epochs=epochs,
                steps_name=steps_name,
            )
        input_iterator = _get_iterator(inputs, model._distribution_strategy)

    # Enter tf.distribute.Strategy scope.
    if model._distribution_strategy:
        scope = distributed_training_utils_v1.distributed_scope(
            strategy=model._distribution_strategy,
            learning_phase=(1 if mode == ModeKeys.TRAIN else 0),
        )
        scope.__enter__()

    use_steps = is_dataset or steps_per_epoch is not None
    do_validation = val_inputs is not None

    # Prepare input data.
    inputs = input_iterator or inputs
    if validation_in_fit and prepared_feed_values_from_dataset:
        # When invoking validation in training loop, avoid creating iterator
        # and list of feed values for the same validation dataset multiple
        # times (which essentially would call `iterator.get_next()` that slows
        # down execution and leads to OOM errors eventually).
        ins = inputs
    else:
        ins = _prepare_feed_values(
            model, inputs, targets, sample_weights, mode
        )
        # `ins` is a function when a distribute strategy is used in Eager
        # mode. In that case `is_dataset` is True. The code branches that have
        # requirements about the type of `ins` do not trigger in the
        # distributed case.

    if not is_dataset:
        num_samples_or_steps = _get_num_samples_or_steps(
            ins, batch_size, steps_per_epoch
        )
    else:
        num_samples_or_steps = steps_per_epoch

    # Update sample_weight_mode of the model if sample_weights is specified by
    # the user. We need to call this function after we have a handle on the
    # inputs (both numpy arrays and datasets) in order to determine if the
    # user has specified sample_weights.
    _update_sample_weight_mode(model, mode, ins)

    # Get step function and loop type. As part of building the execution
    # function we recompile the metrics based on the updated
    # sample_weight_mode value.
    f = _make_execution_function(model, mode)

    # Prepare validation data. Hold references to the iterator and the input
    # list to properly reinitialize and reuse in multiple validation passes.
    val_iterator = None
    if isinstance(val_inputs, (tf.compat.v1.data.Dataset, tf.data.Dataset)):
        if validation_steps is None:
            # Because we pass an iterator feed instead of a Dataset to the
            # eval model_iteration() call, it will not trigger the
            # dataset-input path that determines the number of steps required.
            # To avoid this issue, set validation_steps here if
            # validation_steps is None.
            validation_steps = training_utils_v1.infer_steps_for_dataset(
                model,
                val_inputs,
                validation_steps,
                epochs=epochs,
                steps_name="validation_steps",
            )
        val_iterator = _get_iterator(val_inputs, model._distribution_strategy)
        val_inputs = _prepare_feed_values(
            model, val_iterator, val_targets, val_sample_weights, ModeKeys.TEST
        )
        # Get num steps for printing.
        val_samples_or_steps = validation_steps
    else:
        # Get num samples for printing.
        val_samples_or_steps = (
            val_inputs and tf.nest.flatten(val_inputs)[0].shape[0] or None
        )

    if mode == ModeKeys.TRAIN and verbose:
        _print_train_info(
            num_samples_or_steps, val_samples_or_steps, is_dataset
        )

    # Configure callbacks.
count_mode = "steps" if use_steps else "samples" callbacks = cbks.configure_callbacks( callbacks, model, do_validation=do_validation, batch_size=batch_size, epochs=epochs, steps_per_epoch=steps_per_epoch, samples=num_samples_or_steps, count_mode=count_mode, verbose=verbose, mode=mode, ) # Find beforehand arrays that need sparse-to-dense conversion. if issparse is not None and not use_steps: indices_for_conversion_to_dense = [] feed = _get_model_feed(model, mode) for i, (input_data, feed_tensor) in enumerate(zip(ins, feed)): if issparse(input_data) and not backend.is_sparse(feed_tensor): indices_for_conversion_to_dense.append(i) # Select aggregation method. if mode == ModeKeys.PREDICT: aggregator = training_utils_v1.OutputsAggregator( use_steps, num_samples=None if steps_per_epoch else num_samples_or_steps, steps=steps_per_epoch, ) else: aggregator = training_utils_v1.MetricsAggregator( use_steps, num_samples=None if steps_per_epoch else num_samples_or_steps, steps=steps_per_epoch, ) if model._compile_distribution: distributed_training_utils_v1._copy_weights_to_distributed_model( model, mode) callbacks.model.stop_training = False callbacks._call_begin_hook(mode) initial_epoch = model._maybe_load_initial_epoch_from_ckpt( initial_epoch, mode) for epoch in range(initial_epoch, epochs): if callbacks.model.stop_training: break # Setup work for each epoch epoch_logs = {} if mode != ModeKeys.PREDICT: # Collecting and resetting metrics has non-zero cost and will # needlessly slow down model.predict. model.reset_metrics() if mode == ModeKeys.TRAIN: callbacks.on_epoch_begin(epoch, epoch_logs) if use_steps: # Step-wise loop. if steps_per_epoch is None: # Loop over dataset until `OutOfRangeError` is raised. target_steps = np.inf else: # Loop over dataset for the specified number of steps. target_steps = steps_per_epoch step = 0 while step < target_steps: batch_logs = {"batch": step, "size": 1} callbacks._call_batch_hook(mode, "begin", step, batch_logs) # Get outputs. try: # `ins` can be callable in tf.distribute.Strategy + eager # case. if not callable(ins) or ( model._distribution_strategy and not distributed_training_utils_v1. is_distributing_by_cloning( # noqa: E501 model)): actual_inputs = ins else: actual_inputs = ins() batch_outs = f(actual_inputs) except tf.errors.OutOfRangeError: if is_dataset: # The dataset passed by the user ran out of batches. # Now we know the cardinality of the dataset. If # steps_per_epoch was specified, then running out of # data is unexpected, so we stop training and inform the # user. if steps_per_epoch: callbacks.model.stop_training = True logging.warning( "Your dataset ran out of data; interrupting " "training. Make sure that your dataset can " "generate at least `%s * epochs` batches (in " "this case, %d batches). You may need to use " "the repeat() function when building your " "dataset." % (steps_name, steps_per_epoch * epochs)) elif step > 0: steps_per_epoch = step aggregator.steps = steps_per_epoch else: # We ran out of batches while the user passed an # iterator (legacy). callbacks.model.stop_training = True logging.warning( "Your dataset iterator ran out of data; " "interrupting training. Make sure that your " "iterator can generate at least `%s * epochs` " "batches (in this case, %d batches). You may need " "to use the repeat() function when building your " "dataset." 
                            % (steps_name, steps_per_epoch * epochs)
                        )
                    break

                if not isinstance(batch_outs, list):
                    batch_outs = [batch_outs]

                if model._distribution_strategy:
                    batch_outs = distributed_training_utils_v1._per_replica_aggregate_batch(  # noqa: E501
                        model._distribution_strategy, batch_outs, model, mode
                    )

                # Aggregate results.
                if step == 0:
                    aggregator.create(batch_outs)
                aggregator.aggregate(batch_outs)

                # Callbacks batch end.
                batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
                callbacks._call_batch_hook(mode, "end", step, batch_logs)
                step += 1

                if callbacks.model.stop_training:
                    break
        else:
            # Sample-wise loop.
            index_array = np.arange(num_samples_or_steps)
            if shuffle == "batch":
                index_array = training_utils_v1.batch_shuffle(
                    index_array, batch_size
                )
            elif shuffle:
                np.random.shuffle(index_array)
            batches = make_batches(num_samples_or_steps, batch_size)
            for batch_index, (batch_start, batch_end) in enumerate(batches):
                batch_ids = index_array[batch_start:batch_end]
                # Slice into a batch.
                if len(batches) == 1:
                    # If we only have one batch, do not slice. This takes care
                    # of composite tensors in non-Dataset modes; we currently
                    # don't support slicing them.
                    # TODO(b/133517906): Add slicing support.
                    ins_batch = ins
                else:
                    try:
                        if ins and isinstance(ins[-1], int):
                            # Do not slice the training phase flag.
                            ins_batch = slice_arrays(ins[:-1], batch_ids) + [
                                ins[-1]
                            ]
                        else:
                            ins_batch = slice_arrays(ins, batch_ids)
                    except TypeError:
                        raise TypeError(
                            "TypeError while preparing batch. "
                            "If using HDF5 input data, "
                            'pass shuffle="batch".'
                        )

                # Sparse to dense conversion.
                if issparse is not None:
                    for i in indices_for_conversion_to_dense:
                        ins_batch[i] = ins_batch[i].toarray()

                # Callbacks batch_begin.
                batch_logs = {"batch": batch_index, "size": len(batch_ids)}
                callbacks._call_batch_hook(
                    mode, "begin", batch_index, batch_logs
                )

                # Get outputs.
                batch_outs = f(ins_batch)
                if not isinstance(batch_outs, list):
                    batch_outs = [batch_outs]

                # Aggregate results.
                if batch_index == 0:
                    aggregator.create(batch_outs)
                aggregator.aggregate(batch_outs, batch_start, batch_end)

                # Callbacks batch end.
                batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
                callbacks._call_batch_hook(mode, "end", batch_index, batch_logs)

                if callbacks.model.stop_training:
                    break

        aggregator.finalize()
        results = aggregator.results
        epoch_logs = cbks.make_logs(model, epoch_logs, results, mode)
        if len(results) == 1:
            results = results[0]

        # Run the test loop every `validation_freq` epochs during training.
        if (
            do_validation
            and training_utils_v1.should_run_validation(validation_freq, epoch)
            and not callbacks.model.stop_training
        ):
            if model._compile_distribution:
                # Since we create a new clone from the original model we need
                # to copy the weights back to the original model before we can
                # run validation.
                distributed_training_utils_v1._copy_weights_to_original_model(
                    model, ModeKeys.TRAIN
                )

            val_results = model_iteration(
                model,
                val_inputs,
                targets=val_targets,
                sample_weights=val_sample_weights,
                batch_size=batch_size,
                steps_per_epoch=validation_steps,
                callbacks=callbacks,
                verbose=0,
                mode=ModeKeys.TEST,
                validation_in_fit=True,
                prepared_feed_values_from_dataset=(val_iterator is not None),
                steps_name="validation_steps",
            )
            if not isinstance(val_results, list):
                val_results = [val_results]
            epoch_logs = cbks.make_logs(
                model, epoch_logs, val_results, mode, prefix="val_"
            )

            if val_iterator and epoch < epochs - 1:
                _reinitialize_iterator(
                    val_iterator, model._distribution_strategy
                )

        if mode == ModeKeys.TRAIN:
            # Epochs only apply to `fit`.
            callbacks.on_epoch_end(epoch, epoch_logs)

        # Reinitialize dataset iterator for the next epoch.
        if reset_dataset_after_each_epoch and epoch < epochs - 1:
            _reinitialize_iterator(
                input_iterator, model._distribution_strategy
            )

    model._successful_loop_finish = True
    callbacks._call_end_hook(mode)

    if model._distribution_strategy:
        if model._compile_distribution:
            # TODO(priyag, psv): Copy back metrics to the original model as
            # well?
            distributed_training_utils_v1._copy_weights_to_original_model(
                model, mode
            )
        scope.__exit__(None, None, None)

    if mode == ModeKeys.TRAIN:
        return model.history
    return results
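
# --- Hypothetical usage sketch (not part of the original module). ---
# The array-based `model_iteration` above is the v1 loop behind
# `Model.fit/evaluate/predict` for NumPy arrays and symbolic datasets. The
# sketch below only illustrates how the `mode` argument selects the return
# value (History, metrics, or outputs); it assumes a compiled model and NumPy
# inputs, and is not called anywhere in this module.
def _example_array_model_iteration_usage(model, x, y, val_x, val_y):
    """Sketch of calling the array-based loop directly in all three modes."""
    history = model_iteration(
        model,
        x,
        targets=y,
        batch_size=32,
        epochs=2,
        val_inputs=val_x,
        val_targets=val_y,
        mode=ModeKeys.TRAIN,
    )
    metrics = model_iteration(
        model, val_x, targets=val_y, batch_size=32, verbose=0,
        mode=ModeKeys.TEST,
    )
    predictions = model_iteration(
        model, val_x, batch_size=32, verbose=0, mode=ModeKeys.PREDICT,
    )
    return history, metrics, predictions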
def model_iteration(model,
                    data,
                    steps_per_epoch=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    validation_data=None,
                    validation_steps=None,
                    validation_freq=1,
                    class_weight=None,
                    max_queue_size=10,
                    workers=1,
                    use_multiprocessing=False,
                    shuffle=False,
                    initial_epoch=0,
                    mode=ModeKeys.TRAIN,
                    batch_size=None,
                    steps_name='steps',
                    **kwargs):
  """Loop function for arrays of data with modes TRAIN/TEST/PREDICT.

  Arguments:
      model: Keras Model instance.
      data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or `(x, y)` or
        `(x, y, sample_weights)`) or a generator or
        `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
      steps_per_epoch: Total number of steps (batches of samples) before
        declaring one epoch finished and starting the next epoch. Ignored with
        the default value of `None`.
      epochs: Number of times to iterate over the data.
      verbose: 0, 1, or 2. Verbosity mode.
        0 = silent, 1 = progress bar, 2 = one line per epoch.
        Note that the progress bar is not particularly useful when logged to a
        file, so verbose=2 is recommended when not running interactively (eg,
        in a production environment).
      callbacks: List of callbacks to be called during training.
      validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or
        `(x, y)` or `(x, y, sample_weights)`) or a generator or
        `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
      validation_steps: Total number of steps (batches of samples) before
        declaring validation finished.
      validation_freq: Only relevant if validation data is provided. Integer
        or `collections.abc.Container` instance (e.g. list, tuple, etc.). If
        an integer, specifies how many training epochs to run before a new
        validation run is performed, e.g. `validation_freq=2` runs validation
        every 2 epochs. If a Container, specifies the epochs on which to run
        validation, e.g. `validation_freq=[1, 2, 10]` runs validation at the
        end of the 1st, 2nd, and 10th epochs.
      class_weight: Dictionary mapping class indices to a weight for the
        class.
      max_queue_size: Integer. Maximum size for the generator queue. If
        unspecified, `max_queue_size` will default to 10.
      workers: Integer. Maximum number of processes to spin up when using
        process-based threading. If unspecified, `workers` will default to 1.
        If 0, will execute the generator on the main thread.
      use_multiprocessing: Boolean. If `True`, use process-based threading. If
        unspecified, `use_multiprocessing` will default to `False`. Note that
        because this implementation relies on multiprocessing, you should not
        pass non-picklable arguments to the generator as they can't be passed
        easily to children processes.
      shuffle: Boolean. Whether to shuffle the order of the batches at the
        beginning of each epoch. Only used with instances of `Sequence`
        (`keras.utils.Sequence`). Has no effect when `steps_per_epoch` is not
        `None`.
      initial_epoch: Epoch at which to start training (useful for resuming a
        previous training run).
      mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
      batch_size: Integer batch size or None if unknown. Will only be used if
        `data` is in NumPy/Tensor format.
      steps_name: The string name of the steps argument, either `steps`,
        `validation_steps`, or `steps_per_epoch`. Only used for error message
        formatting.
      **kwargs: Additional arguments for backwards compatibility. `steps` is
        accepted as an alias for `steps_per_epoch`.

  Returns:
      - In TRAIN mode: `History` object.
      - In TEST mode: Evaluation metrics.
      - In PREDICT mode: Outputs of the Model called on inputs.

  Raises:
      ValueError: in case of invalid arguments.
""" if 'steps' in kwargs: steps_per_epoch = kwargs['steps'] # Determine the number of steps per epoch and whether we should reset the # dataset at the end of each epoch. reset_dataset_after_each_epoch = False original_dataset = None is_dataset = isinstance(data, (tf.data.Dataset, tf.compat.v1.data.Dataset)) if is_dataset: original_dataset = data if steps_per_epoch is None: reset_dataset_after_each_epoch = True steps_per_epoch = training_utils_v1.infer_steps_for_dataset( model, data, steps_per_epoch, epochs=epochs, steps_name=steps_name) # Convert to a format that supports `next(generator)`. generator, steps_per_epoch = convert_to_generator_like( data, steps_per_epoch=steps_per_epoch, batch_size=batch_size, epochs=epochs - initial_epoch, shuffle=shuffle) do_validation = validation_data is not None is_sequence = isinstance(generator, data_utils.Sequence) _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers, steps_per_epoch, validation_data, validation_steps, mode, kwargs) batch_function = _make_execution_function(model, mode, class_weight=class_weight) # Create the queue for the generator. enqueuer = None if not is_dataset: generator, enqueuer = _make_enqueued_generator( generator, workers=workers, use_multiprocessing=use_multiprocessing, max_queue_size=max_queue_size, shuffle=shuffle) num_samples_or_steps, use_steps = _get_num_samples_or_steps( data, steps_per_epoch) count_mode = 'steps' if use_steps else 'samples' callbacks = cbks.configure_callbacks(callbacks, model, do_validation=do_validation, epochs=epochs, steps_per_epoch=steps_per_epoch, batch_size=batch_size, samples=num_samples_or_steps, count_mode=count_mode, verbose=verbose, mode=mode) if mode == ModeKeys.PREDICT: aggregator = training_utils_v1.OutputsAggregator(True, steps=steps_per_epoch) else: aggregator = training_utils_v1.MetricsAggregator(True, steps=steps_per_epoch) should_set_learning_phase = tf.executing_eagerly() and model.run_eagerly if should_set_learning_phase: learning_phase_scope = backend.eager_learning_phase_scope( 1 if mode == ModeKeys.TRAIN else 0) learning_phase_scope.__enter__() callbacks.model.stop_training = False callbacks._call_begin_hook(mode) initial_epoch = model._maybe_load_initial_epoch_from_ckpt( initial_epoch, mode) for epoch in range(initial_epoch, epochs): if callbacks.model.stop_training: break # Setup work for each epoch. model.reset_metrics() epoch_logs = {} if mode == ModeKeys.TRAIN: callbacks.on_epoch_begin(epoch, epoch_logs) if steps_per_epoch is None: # Loop over dataset until `OutOfRangeError` is raised. target_steps = np.inf else: # Loop over dataset for the specified number of steps. target_steps = steps_per_epoch step = 0 while step < target_steps: batch_data = _get_next_batch(generator) if batch_data is None: if is_dataset: # The dataset passed by the user ran out of batches. # Now we know the cardinality of the dataset. # If steps_per_epoch was specified, then running out of data is # unexpected, so we stop training and inform the user. if steps_per_epoch: callbacks.model.stop_training = True logging.warning( 'Your dataset ran out of data; interrupting training. ' 'Make sure that your dataset can generate at least ' '`%s * epochs` batches (in this case, %d batches). ' 'You may need to use the repeat() function when ' 'building your dataset.' % (steps_name, steps_per_epoch * epochs)) elif step > 0: steps_per_epoch = step aggregator.steps = steps_per_epoch else: # We ran out of batches while the user passed an iterator (legacy). 
          callbacks.model.stop_training = True
          logging.warning(
              'Your dataset iterator ran out of data; '
              'interrupting training. Make sure that your iterator '
              'can generate at least `%s * epochs` '
              'batches (in this case, %d batches). You may need to '
              'use the repeat() function when building your '
              'dataset.' % (steps_name, steps_per_epoch * epochs))
        break

      # `batch_size` used for validation data if validation
      # data is NumPy/EagerTensors.
      batch_size = int(tf.nest.flatten(batch_data)[0].shape[0])

      # Callbacks batch begin.
      batch_logs = {'batch': step, 'size': batch_size}
      callbacks._call_batch_hook(mode, 'begin', step, batch_logs)

      is_deferred = not model._is_compiled
      batch_outs = batch_function(*batch_data)
      if not isinstance(batch_outs, list):
        batch_outs = [batch_outs]

      if step == 0:
        aggregator.create(batch_outs)

        if is_deferred:
          # Set callbacks params. We do this here when model is compiled only
          # in the first iteration of this loop (deferred build scenario).
          cbks.set_callback_parameters(
              callbacks,
              model,
              do_validation=do_validation,
              batch_size=batch_size,
              epochs=epochs,
              steps_per_epoch=steps_per_epoch,
              samples=num_samples_or_steps,
              verbose=verbose,
              mode=mode)

      # Aggregate results.
      aggregator.aggregate(batch_outs)

      # Callbacks batch end.
      batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
      callbacks._call_batch_hook(mode, 'end', step, batch_logs)
      step += 1

      if callbacks.model.stop_training:
        break

    aggregator.finalize()
    results = aggregator.results
    epoch_logs = cbks.make_logs(model, epoch_logs, results, mode)
    if len(results) == 1:
      results = results[0]

    # Run the test loop every `validation_freq` epochs during training.
    if (do_validation and training_utils_v1.should_run_validation(
        validation_freq, epoch) and not callbacks.model.stop_training):
      val_results = model_iteration(
          model,
          validation_data,
          steps_per_epoch=validation_steps,
          batch_size=batch_size,
          class_weight=class_weight,
          workers=workers,
          use_multiprocessing=use_multiprocessing,
          max_queue_size=max_queue_size,
          callbacks=callbacks,
          verbose=verbose,
          mode=ModeKeys.TEST,
          steps_name='validation_steps')

      if not isinstance(val_results, list):
        val_results = [val_results]
      epoch_logs = cbks.make_logs(
          model, epoch_logs, val_results, mode, prefix='val_')

    if mode == ModeKeys.TRAIN:
      # Epochs only apply to `fit`.
      callbacks.on_epoch_end(epoch, epoch_logs)

    # Recreate dataset iterator for the next epoch.
    if reset_dataset_after_each_epoch and epoch < epochs - 1:
      generator = tf.compat.v1.data.make_one_shot_iterator(original_dataset)

  model._successful_loop_finish = True
  callbacks._call_end_hook(mode)

  if enqueuer is not None:
    enqueuer.stop()

  if should_set_learning_phase:
    learning_phase_scope.__exit__(None, None, None)

  if mode == ModeKeys.TRAIN:
    return model.history
  return results
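
# --- Hypothetical usage sketch (not part of the original module). ---
# The generator-based `model_iteration` above backs `fit_generator`-style
# training on generators, `keras.utils.Sequence` objects, and datasets. The
# sketch assumes a compiled model and a `Sequence` yielding `(x, y)` batches
# (or an infinite generator plus `steps_per_epoch`); it only illustrates the
# expected arguments and is not called anywhere in this module.
def _example_generator_model_iteration_usage(model, train_seq, val_seq):
  """Sketch of calling the generator-based loop directly for training."""
  history = model_iteration(
      model,
      train_seq,                 # a `keras.utils.Sequence` yielding (x, y)
      epochs=2,
      validation_data=val_seq,
      validation_freq=1,
      workers=1,
      use_multiprocessing=False,
      shuffle=True,              # only honored for `Sequence` inputs
      mode=ModeKeys.TRAIN)
  return history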