def _make_execution_function(model, mode):
    """Builds the callable that runs one step of distributed execution.

    In eager mode the work is delegated to `_make_eager_execution_function`.
    Otherwise the model is replicated (if it has not been already), the
    per-replica execution function is created on every replica, and the
    per-replica pieces are merged into a single `K.function`.

    Arguments:
        model: Model carrying the `_distribution_strategy`,
            `_distributed_model` and `_compile_distribution` attributes.
        mode: Execution mode string. `'train'` and `'predict'` get special
            handling below; the value is also embedded in the function name.

    Returns:
        The eager execution function when executing eagerly, otherwise a
        `K.function` named `distributed_<mode>_function`.
    """
    if context.executing_eagerly():
        return _make_eager_execution_function(model, mode)

    is_training = mode == 'train'
    strategy = model._distribution_strategy

    # Lazily create the replicated model the first time it is needed.
    if not model._distributed_model:
        if not model._compile_distribution:
            _build_distributed_network(model, strategy)
        else:
            clone_model_on_replicas(
                model, strategy, make_callback_model=is_training)

    def _per_device_function(replica_model):
        # Build the replica's own execution function and expose its parts.
        step_fn = replica_model._make_execution_function(mode)
        return (step_fn.inputs, step_fn.outputs, step_fn.updates_op,
                step_fn.session_kwargs)

    # NOTE(review): the original indentation was lost; everything through
    # the final `K.function` call is kept inside the strategy scope --
    # confirm against the upstream file.
    with strategy.scope():
        # Calling `_per_device_function` on each replica creates the ops on
        # each of the devices.
        (per_replica_inputs, per_replica_outputs, per_replica_updates,
         per_replica_session_args) = strategy.extended.call_for_each_replica(
             _per_device_function, args=(model._distributed_model,))

        if is_training:
            # Initialize the variables in the replicated model. This is
            # necessary for multi-worker training because on some workers
            # initialization is not needed. The helper either initializes or
            # waits for initialization, according to the context object of
            # the distribute coordinator.
            distributed_training_utils.init_restore_or_wait_for_variables()

        # Unwrapping the per-device values yields lists that can be used to
        # construct one function composed of the update ops on all the
        # devices over which the model is distributed.
        (fn_inputs, fn_outputs, fn_updates,
         fn_session_args) = distributed_training_utils.unwrap_values(
             strategy,
             per_replica_inputs,
             per_replica_outputs,
             per_replica_updates,
             per_replica_session_args,
             with_loss_tensor=(mode != 'predict'))

        return K.function(
            fn_inputs,
            fn_outputs,
            updates=fn_updates,
            name='distributed_{}_function'.format(mode),
            **fn_session_args)
def _make_execution_function(model, mode):
    """Creates the function that executes one distributed step of `model`.

    When executing eagerly, `_make_eager_execution_function` is used instead.
    In graph mode the model is cloned onto the replicas on first use, each
    replica builds its own execution function, and the resulting per-replica
    inputs, outputs, updates and session kwargs are unwrapped and combined
    into one `K.function`.

    Arguments:
        model: Model carrying the `_distribution_strategy` and
            `_grouped_model` attributes.
        mode: Execution mode string. `'train'` triggers variable
            initialization and a callback model; `'predict'` disables the
            loss tensor when unwrapping.

    Returns:
        The eager execution function when executing eagerly, otherwise a
        `K.function` named `distributed_<mode>_function`.
    """
    if context.executing_eagerly():
        return _make_eager_execution_function(model, mode)

    strategy = model._distribution_strategy

    # Replicate the model only once; later calls reuse `_grouped_model`.
    if not model._grouped_model:
        clone_model_on_replicas(
            model, strategy, make_callback_model=(mode == 'train'))

    def _per_device_function(replica_model):
        # Returns the pieces of the replica's execution function.
        replica_fn = replica_model._make_execution_function(mode)
        return (replica_fn.inputs, replica_fn.outputs, replica_fn.updates_op,
                replica_fn.session_kwargs)

    # NOTE(review): the original indentation was lost; the whole remainder
    # of the function is kept inside the strategy scope -- confirm against
    # the upstream file.
    with strategy.scope():
        # Running `_per_device_function` on every replica creates the ops on
        # each of the devices.
        replica_results = strategy.extended.call_for_each_replica(
            _per_device_function, args=(model._grouped_model,))
        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = replica_results

        if mode == 'train':
            # Initialize the variables in the replicated model. Needed for
            # multi-worker training, where some workers must not initialize:
            # the helper initializes or waits for initialization according
            # to the context object of the distribute coordinator.
            distributed_training_utils.init_restore_or_wait_for_variables()

        # Unwrap the per-device values into flat lists from which a single
        # function, composed of the update ops on all devices, can be built.
        include_loss = mode != 'predict'
        unwrapped = distributed_training_utils.unwrap_values(
            strategy,
            grouped_inputs,
            grouped_outputs,
            grouped_updates,
            grouped_session_args,
            with_loss_tensor=include_loss)
        flat_inputs, flat_outputs, flat_updates, session_kwargs = unwrapped

        return K.function(
            flat_inputs,
            flat_outputs,
            updates=flat_updates,
            name='distributed_{}_function'.format(mode),
            **session_kwargs)
def fit_loop(model,
             iterator,
             epochs=100,
             verbose=1,
             callbacks=None,
             val_iterator=None,
             initial_epoch=0,
             steps_per_epoch=None,
             validation_steps=None):
    """Fit loop for training with DistributionStrategy.

    Arguments:
        model: Keras Model instance.
        iterator: Iterator for input data.
        epochs: Number of times to iterate over the data.
        verbose: Integer, verbosity mode: 0, 1 or 2.
        callbacks: List of callbacks to be called during training.
        val_iterator: Iterator for validation data.
        initial_epoch: Epoch at which to start training
            (useful for resuming a previous training run).
        steps_per_epoch: Total number of steps (batches of samples) before
            declaring one epoch finished and starting the next epoch. Must
            not be `None` on the non-TPU path: the loop asserts this before
            the first epoch.
        validation_steps: Number of steps to run validation for. Validation
            runs at the end of every epoch only when this value is truthy.

    Returns:
        `History` object (`model.history`), or whatever
        `_experimental_fit_loop` returns when the strategy is a
        `TPUStrategy`.

    Raises:
        AssertionError: if `steps_per_epoch` is `None`.
    """
    current_strategy = model._distribution_strategy

    # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
    if current_strategy.__class__.__name__ == 'TPUStrategy':
        return _experimental_fit_loop(
            model, iterator, epochs, verbose, callbacks, initial_epoch,
            steps_per_epoch, val_iterator, validation_steps)

    if not model._grouped_model:
        clone_model_on_replicas(
            model, current_strategy, make_callback_model=True)

    def _per_device_fit_function(model):
        model._make_fit_function()
        return (model._fit_function.inputs, model._fit_function.outputs,
                model._fit_function.updates_op,
                model._fit_function.session_kwargs)

    # The sample weights from the iterator are not used: they are replaced
    # by `None` placeholders further down.
    inputs, targets, _ = _get_input_from_iterator(iterator, model)

    # NOTE(review): the original indentation was lost; the scope is assumed
    # to end after the dataset values are flattened -- confirm against the
    # upstream file.
    with current_strategy.scope():
        # Create train ops on each of the devices when we call
        # `_per_device_fit_function`.
        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_replica(
             _per_device_fit_function, args=(model._grouped_model,))

        # Initialize the variables in the replicated model. This is necessary
        # for multi-worker training because on some workers, initialization
        # is not needed. This method does initialization or waiting for
        # initialization according to the context object of distribute
        # coordinator.
        distributed_training_utils.init_restore_or_wait_for_variables()

        # Unwrap all the per device values returned from
        # `call_for_each_replica`. Unwrapping per device values gives you a
        # list of values that can be used to construct a new train function
        # that is composed of update ops on all the devices over which the
        # model is distributed.
        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy,
             grouped_inputs,
             grouped_outputs,
             grouped_updates,
             grouped_session_args,
             with_loss_tensor=True)

        # Dataset inputs and targets are also per devices values that need
        # to be unwrapped.
        dataset_inputs = distributed_training_utils.flatten_perdevice_values(
            current_strategy, inputs)
        dataset_targets = distributed_training_utils.flatten_perdevice_values(
            current_strategy, targets)

    # Create a train function that is composed of all the parameters above.
    distributed_fit_function = K.function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_fit_function',
        **all_session_args)

    # We need to set sample_weights to None since there are sample weight
    # placeholders that are created with default values.
    sample_weights = [None] * (
        len(model.outputs) * current_strategy.num_replicas_in_sync)
    if not isinstance(K.learning_phase(), int):
        # The trailing 1 is fed as the learning-phase value.
        ins = dataset_inputs + dataset_targets + sample_weights + [1]
    else:
        ins = dataset_inputs + dataset_targets

    do_validation = bool(validation_steps)

    # Copy the weights from the original model to each of the replicated
    # models.
    orig_model_weights = model.get_weights()
    distributed_model = current_strategy.unwrap(model._grouped_model)[0]
    distributed_training_utils.set_weights(
        current_strategy, distributed_model, orig_model_weights)

    callbacks = cbks.configure_callbacks(
        callbacks,
        model,
        do_validation=do_validation,
        val_inputs=None,
        val_targets=None,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        verbose=verbose)
    out_labels = model.metrics_names or []
    callbacks.on_train_begin()

    assert steps_per_epoch is not None

    for epoch in range(initial_epoch, epochs):
        # Reset stateful metrics
        for m in model.stateful_metric_functions:
            m.reset_states()
        callbacks.on_epoch_begin(epoch)
        epoch_logs = {}
        for step_index in range(steps_per_epoch):
            batch_logs = {'batch': step_index, 'size': 1}
            callbacks.on_batch_begin(step_index, batch_logs)
            try:
                outs = distributed_fit_function(ins)
            except errors.OutOfRangeError:
                # The product must be parenthesized: `%` and `*` share
                # precedence and associate left-to-right, so without the
                # parentheses the message would be formatted with
                # `steps_per_epoch` alone and then repeated `epochs` times.
                logging.warning(
                    'Your dataset iterator ran out of data; '
                    'interrupting training. Make sure that your dataset '
                    'can generate at least `steps_per_epoch * epochs` '
                    'batches (in this case, %d batches).' %
                    (steps_per_epoch * epochs))
                break

            if not isinstance(outs, list):
                outs = [outs]
            for label, out in zip(out_labels, outs):
                batch_logs[label] = out
            callbacks.on_batch_end(step_index, batch_logs)
            if callbacks.model.stop_training:
                break

        if do_validation:
            val_outs = test_loop(
                model, val_iterator, steps=validation_steps, verbose=0)
            if not isinstance(val_outs, list):
                val_outs = [val_outs]
            # Same labels assumed.
            for label, out in zip(out_labels, val_outs):
                epoch_logs['val_' + label] = out

        callbacks.on_epoch_end(epoch, epoch_logs)
        if callbacks.model.stop_training:
            break
    callbacks.on_train_end()

    # Copy the weights back from the replicated model to the original model.
    updated_weights = current_strategy.unwrap(
        model._grouped_model)[0].get_weights()
    model.set_weights(updated_weights)
    return model.history