def step_fn(ctx, inputs, targets):
  """Builds the distributed predict function for one step.

  Args:
    ctx: Step context. Every output of the combined function is stored on it
      with `set_last_step_output`, keyed by the matching name in
      `model.output_names`.
    inputs: Inputs handed to `clone_model_on_towers`.
    targets: Unused; discarded immediately.

  Returns:
    The `updates_op` of the combined `K.Function`.
  """
  # TODO(anjalisridhar): Support predict input correctly as it will not
  # contain targets, only inputs.
  del targets

  # TODO(priyag, sourabhbajaj): The model gets cloned every time
  # fit/test/predict is called. We should look into caching this keyed on
  # input shapes.
  clone_model_on_towers(
      model, current_strategy, make_callback_model=False, inputs=inputs)

  # Gather the predict function pieces of the grouped model via the strategy.
  per_tower = current_strategy.call_for_each_tower(
      _per_device_predict_function, model._grouped_model)
  tower_inputs, tower_outputs, tower_updates, tower_session_args = per_tower

  # Unwrap the grouped values so they can feed a single backend function.
  unwrapped = distributed_training_utils.unwrap_values(
      current_strategy, tower_inputs, tower_outputs, tower_updates,
      tower_session_args)
  flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

  predict_fn = K.Function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_predict_function',
      **flat_session_args)

  for output_name, output_tensor in zip(model.output_names,
                                        predict_fn.outputs):
    ctx.set_last_step_output(output_name, output_tensor)

  return predict_fn.updates_op
def step_fn(ctx, *inputs):
  """Builds the distributed predict function for one step.

  Args:
    ctx: Step context. Every output of the combined function is stored on it
      with `set_last_step_output`, keyed by the matching name in
      `model.output_names`.
    *inputs: Positional inputs; forwarded as a tuple to
      `clone_model_on_replicas`.

  Returns:
    The `updates_op` of the combined backend function.
  """
  # TODO(priyag, sourabhbajaj): The model gets cloned every time
  # fit/test/predict is called. We should look into caching this keyed on
  # input shapes.
  clone_model_on_replicas(
      model,
      current_strategy,
      make_callback_model=False,
      inputs=inputs,
      mode=_Mode.PREDICT)

  # Gather the predict function pieces of the grouped model via the strategy.
  per_replica = current_strategy.call_for_each_replica(
      _per_device_predict_function, args=(model._grouped_model_predict,))
  (replica_inputs, replica_outputs, replica_updates,
   replica_session_args) = per_replica

  # Unwrap the grouped values so they can feed a single backend function.
  unwrapped = distributed_training_utils.unwrap_values(
      current_strategy, replica_inputs, replica_outputs, replica_updates,
      replica_session_args)
  flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

  predict_fn = K.function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_predict_function',
      **flat_session_args)

  for output_name, output_tensor in zip(model.output_names,
                                        predict_fn.outputs):
    ctx.set_last_step_output(output_name, output_tensor)

  return predict_fn.updates_op
def _step_fn(ctx, inputs):
  """Builds the distributed function for `mode` and returns its update op.

  Args:
    ctx: Step context. Each output is stored on it with
      `set_last_step_output` together with a reduce op: SUM for the output
      labelled 'loss', MEAN for every other label.
    inputs: A pair that unpacks into `(inputs, targets)`.

  Returns:
    The `updates_op` of the combined backend function.
  """
  features, labels = inputs
  _build_model(strategy, model, mode, features, labels)

  # Gather the execution function pieces of the distributed model for `mode`.
  distributed_model = distributed_training_utils.get_distributed_model(
      model, mode)
  per_replica = strategy.extended.call_for_each_replica(
      _per_device_execution_function, args=(distributed_model, mode))
  (replica_inputs, replica_outputs, replica_updates,
   replica_session_args) = per_replica

  # Unwrap the grouped values so they can feed a single backend function.
  unwrapped = distributed_training_utils.unwrap_values(
      strategy, replica_inputs, replica_outputs, replica_updates,
      replica_session_args)
  flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

  execution_fn = K.function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_' + str(mode) + '_function',
      **flat_session_args)

  for output_name, output_tensor in zip(output_labels, execution_fn.outputs):
    # We reduce all non-loss metrics using mean for now. This is temporary
    # workaround until new metrics are in place.
    reduce_op = (
        ds_reduce_util.ReduceOp.SUM
        if output_name == 'loss' else ds_reduce_util.ReduceOp.MEAN)
    ctx.set_last_step_output(output_name, output_tensor, reduce_op)

  # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
  # feed_dict, session kwargs, run options, run_metadata for now. These should
  # be handled appropriately
  return execution_fn.updates_op
def step_fn(ctx, inputs, targets):
  """Builds the distributed train function for one step.

  Args:
    ctx: Step context. Each output is stored on it with
      `set_last_step_output`, keyed by the matching entry of
      `model.metrics_names` and aggregated with
      `distribute_lib.get_loss_reduction()`.
    inputs: Inputs handed to `clone_model_on_towers`.
    targets: Targets handed to `clone_model_on_towers`.

  Returns:
    The `updates_op` of the combined `K.Function`.
  """
  # TODO(priyag, sourabhbajaj): Should cache this keyed on input shapes.
  clone_model_on_towers(
      model,
      current_strategy,
      make_callback_model=True,
      inputs=inputs,
      targets=targets)

  # Gather the train function pieces of the grouped model via the strategy.
  per_tower = current_strategy.call_for_each_tower(
      _per_device_train_function, model._grouped_model)
  tower_inputs, tower_outputs, tower_updates, tower_session_args = per_tower

  # Unwrap the grouped values so they can feed a single backend function.
  unwrapped = distributed_training_utils.unwrap_values(
      current_strategy,
      tower_inputs,
      tower_outputs,
      tower_updates,
      tower_session_args,
      with_loss_tensor=True)
  flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

  train_fn = K.Function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_train_function',
      **flat_session_args)

  # TODO(priyag, sourabhbajaj): Perhaps the aggregation type needs to be
  # something else for different outputs.
  metric_labels = model.metrics_names or []
  for output_name, output_tensor in zip(metric_labels, train_fn.outputs):
    ctx.set_last_step_output(
        output_name,
        output_tensor,
        aggregation=distribute_lib.get_loss_reduction())

  # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
  # feed_dict, session kwargs, run options, run_metadata for now. These should
  # be handled appropriately
  return train_fn.updates_op
def step_fn(ctx, inputs):
  """Builds the distributed predict function for one step.

  Args:
    ctx: Step context. Every output of the combined function is stored on it
      with `set_last_step_output`, keyed by the matching name in
      `model.output_names`.
    inputs: Inputs handed to `clone_model_on_replicas`.

  Returns:
    The `updates_op` of the combined backend function.
  """
  # TODO(priyag, sourabhbajaj): The model gets cloned every time
  # fit/test/predict is called. We should look into caching this keyed on
  # input shapes.
  clone_model_on_replicas(
      model,
      current_strategy,
      make_callback_model=False,
      inputs=inputs,
      mode=_Mode.PREDICT)

  # Gather the predict function pieces of the grouped model via the strategy.
  per_replica = current_strategy.call_for_each_replica(
      _per_device_predict_function, args=(model._grouped_model_predict,))
  (replica_inputs, replica_outputs, replica_updates,
   replica_session_args) = per_replica

  # Unwrap the grouped values so they can feed a single backend function.
  unwrapped = distributed_training_utils.unwrap_values(
      current_strategy, replica_inputs, replica_outputs, replica_updates,
      replica_session_args)
  flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

  predict_fn = K.function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_predict_function',
      **flat_session_args)

  for output_name, output_tensor in zip(model.output_names,
                                        predict_fn.outputs):
    ctx.set_last_step_output(output_name, output_tensor)

  return predict_fn.updates_op
def step_fn(ctx, inputs):
  """Builds the distributed test function for one step.

  Args:
    ctx: Step context. Each output is stored on it with
      `set_last_step_output`, keyed by the matching entry of
      `model.metrics_names`. The 'loss' output uses
      `distribute_lib.get_loss_reduction()`; all other outputs use MEAN.
    inputs: A pair that unpacks into `(inputs, targets)`.

  Returns:
    The `updates_op` of the combined backend function.
  """
  features, labels = inputs

  # Either clone the compiled model onto the replicas or build the
  # distributed network, depending on `model._compile_distribution`.
  if model._compile_distribution:
    distributed_training_utils.clone_model_on_replicas(
        model, current_strategy, mode=mode, inputs=features, targets=labels)
  else:
    distributed_training_utils._build_distributed_network(
        model, current_strategy, mode, features, labels)

  # Gather the eval function pieces of the distributed test model.
  per_replica = current_strategy.extended.call_for_each_replica(
      _per_device_eval_function, args=(model._distributed_model_test,))
  (replica_inputs, replica_outputs, replica_updates,
   replica_session_args) = per_replica

  # Unwrap the grouped values so they can feed a single backend function.
  unwrapped = distributed_training_utils.unwrap_values(
      current_strategy, replica_inputs, replica_outputs, replica_updates,
      replica_session_args)
  flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

  test_fn = K.function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_test_function',
      **flat_session_args)

  for output_name, output_tensor in zip(model.metrics_names, test_fn.outputs):
    # We reduce all non-loss metrics using mean for now. This is temporary
    # workaround until new metrics are in place.
    reduce_op = (
        distribute_lib.get_loss_reduction()
        if output_name == 'loss' else ds_reduce_util.ReduceOp.MEAN)
    ctx.set_last_step_output(output_name, output_tensor, reduce_op)

  return test_fn.updates_op
def step_fn(ctx, inputs):
  """Builds the distributed predict function for one step.

  Args:
    ctx: Step context. Every output of the combined function is stored on it
      with `set_last_step_output`, keyed by the matching name in
      `model.output_names`.
    inputs: Inputs used to clone or build the distributed model.

  Returns:
    The `updates_op` of the combined backend function.
  """
  # Either clone the compiled model onto the replicas or build the
  # distributed network, depending on `model._compile_distribution`.
  if model._compile_distribution:
    distributed_training_utils.clone_model_on_replicas(
        model, current_strategy, ModeKeys.PREDICT, inputs=inputs)
  else:
    distributed_training_utils._build_distributed_network(
        model, current_strategy, ModeKeys.PREDICT, inputs)

  # Gather the predict function pieces of the distributed predict model.
  per_replica = current_strategy.extended.call_for_each_replica(
      _per_device_predict_function, args=(model._distributed_model_predict,))
  (replica_inputs, replica_outputs, replica_updates,
   replica_session_args) = per_replica

  # Unwrap the grouped values so they can feed a single backend function.
  unwrapped = distributed_training_utils.unwrap_values(
      current_strategy, replica_inputs, replica_outputs, replica_updates,
      replica_session_args)
  flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

  predict_fn = K.function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_predict_function',
      **flat_session_args)

  for output_name, output_tensor in zip(model.output_names,
                                        predict_fn.outputs):
    ctx.set_last_step_output(output_name, output_tensor)

  return predict_fn.updates_op
def _make_eager_execution_function(model, mode):
  """Makes function to run one step of distributed model eager execution.

  Args:
    model: Model carrying `_distribution_strategy` and `_grouped_model`. If
      `_grouped_model` is falsy the model is first cloned onto the replicas.
    mode: Execution mode. It is compared against 'train' and 'predict' and
      interpolated into the returned function's name.

  Returns:
    A backend function built from the unwrapped inputs and outputs, named
    'eager_distributed_<mode>_function'.
  """
  strategy = model._distribution_strategy
  if not model._grouped_model:
    clone_model_on_replicas(
        model, strategy, make_callback_model=(mode == 'train'))

  def _per_device_function(replica_model):
    # Only inputs and outputs are collected for the eager function.
    execution_fn = replica_model._make_execution_function(mode)
    return (execution_fn.inputs, execution_fn.outputs)

  # NOTE(priyag): Try creating a new FuncGraph within DS scope instead of using
  # the global one.
  with K.get_graph().as_default(), strategy.scope():
    # Create the ops on each of the devices when we call
    # `_per_device_function`.
    per_replica = strategy.call_for_each_replica(
        _per_device_function, args=(model._grouped_model,))
    (replica_inputs, replica_outputs) = per_replica

    # Unwrap all the per device values returned from `call_for_each_replica`.
    # Unwrapping per device values gives you a list of values that can be
    # used to construct a new function that is composed of inputs/outputs
    # on all the devices over which the model is distributed.
    unwrapped = distributed_training_utils.unwrap_values(
        strategy,
        replica_inputs,
        replica_outputs,
        with_loss_tensor=(mode != 'predict'))
    (flat_inputs, flat_outputs, _, _) = unwrapped

    return K.function(
        flat_inputs,
        flat_outputs,
        name='eager_distributed_{}_function'.format(mode))
def step_fn(ctx, inputs, targets):
  """Builds the distributed predict function for one step.

  Args:
    ctx: Step context. Every output of the combined function is stored on it
      with `set_last_step_output`, keyed by the matching name in
      `model.output_names`.
    inputs: Inputs handed to `clone_model_on_towers`.
    targets: Unused; discarded immediately.

  Returns:
    The `updates_op` of the combined `K.Function`.
  """
  # TODO(anjalisridhar): Support predict input correctly as it will not
  # contain targets, only inputs.
  del targets

  # TODO(priyag, sourabhbajaj): The model gets cloned every time
  # fit/test/predict is called. We should look into caching this keyed on
  # input shapes.
  clone_model_on_towers(
      model, current_strategy, make_callback_model=False, inputs=inputs)

  # Gather the predict function pieces of the grouped model via the strategy.
  per_tower = current_strategy.call_for_each_tower(
      _per_device_predict_function, model._grouped_model)
  tower_inputs, tower_outputs, tower_updates, tower_session_args = per_tower

  # Unwrap the grouped values so they can feed a single backend function.
  unwrapped = distributed_training_utils.unwrap_values(
      current_strategy, tower_inputs, tower_outputs, tower_updates,
      tower_session_args)
  flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

  predict_fn = K.Function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_predict_function',
      **flat_session_args)

  for output_name, output_tensor in zip(model.output_names,
                                        predict_fn.outputs):
    ctx.set_last_step_output(output_name, output_tensor)

  return predict_fn.updates_op
def _get_eager_execution_function(model, mode):
  """Get function to run one step of distributed model eager execution.

  Args:
    model: Model carrying `_distribution_strategy` and `_grouped_model`. If
      `_grouped_model` is falsy the model is first cloned onto the replicas.
    mode: Execution mode. It is compared against 'train' and 'predict' and
      interpolated into the returned function's name.

  Returns:
    A backend function built from the unwrapped inputs and outputs, named
    'eager_distributed_<mode>_function'.
  """
  strategy = model._distribution_strategy
  if not model._grouped_model:
    clone_model_on_replicas(
        model, strategy, make_callback_model=(mode == 'train'))

  def _per_device_function(replica_model):
    # Only inputs and outputs are collected for the eager function.
    execution_fn = replica_model._get_execution_function(mode)
    return (execution_fn.inputs, execution_fn.outputs)

  # NOTE(priyag): Try creating a new FuncGraph within DS scope instead of using
  # the global one.
  with K.get_graph().as_default(), strategy.scope():
    # Create the ops on each of the devices when we call
    # `_per_device_function`.
    per_replica = strategy.call_for_each_replica(
        _per_device_function, args=(model._grouped_model,))
    (replica_inputs, replica_outputs) = per_replica

    # Unwrap all the per device values returned from `call_for_each_replica`.
    # Unwrapping per device values gives you a list of values that can be
    # used to construct a new function that is composed of inputs/outputs
    # on all the devices over which the model is distributed.
    unwrapped = distributed_training_utils.unwrap_values(
        strategy,
        replica_inputs,
        replica_outputs,
        with_loss_tensor=(mode != 'predict'))
    (flat_inputs, flat_outputs, _, _) = unwrapped

    return K.function(
        flat_inputs,
        flat_outputs,
        name='eager_distributed_{}_function'.format(mode))
def step_fn(ctx, inputs):
  """Builds the distributed predict function for one step.

  Args:
    ctx: Step context. Every output of the combined function is stored on it
      with `set_last_step_output`, keyed by the matching name in
      `model.output_names`.
    inputs: Inputs used to clone or build the distributed model.

  Returns:
    The `updates_op` of the combined backend function.
  """
  # Either clone the compiled model onto the replicas or build the
  # distributed network, depending on `model._compile_distribution`.
  if model._compile_distribution:
    distributed_training_utils.clone_model_on_replicas(
        model, current_strategy, mode, inputs=inputs)
  else:
    distributed_training_utils._build_distributed_network(
        model, current_strategy, mode, inputs)

  # NOTE(review): the model is cloned/built with the enclosing `mode` but
  # fetched with `ModeKeys.PREDICT`; presumably `mode` is PREDICT here --
  # confirm against the enclosing function.
  predict_model = distributed_training_utils.get_distributed_model(
      model, ModeKeys.PREDICT)
  per_replica = current_strategy.extended.call_for_each_replica(
      _per_device_predict_function, args=(predict_model,))
  (replica_inputs, replica_outputs, replica_updates,
   replica_session_args) = per_replica

  # Unwrap the grouped values so they can feed a single backend function.
  unwrapped = distributed_training_utils.unwrap_values(
      current_strategy, replica_inputs, replica_outputs, replica_updates,
      replica_session_args)
  flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

  predict_fn = K.function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_predict_function',
      **flat_session_args)

  for output_name, output_tensor in zip(model.output_names,
                                        predict_fn.outputs):
    ctx.set_last_step_output(output_name, output_tensor)

  return predict_fn.updates_op
def _make_execution_function(model, mode):
  """Makes function to run one step of distributed model execution.

  Args:
    model: Model carrying `_distribution_strategy`, `_distributed_model` and
      `_compile_distribution`. If `_distributed_model` is falsy the model is
      first cloned onto the replicas or built as a distributed network.
    mode: Execution mode. It is compared against 'train' and 'predict' and
      interpolated into the returned function's name.

  Returns:
    The eager execution function when executing eagerly; otherwise a backend
    function named 'distributed_<mode>_function' that combines the inputs,
    outputs, updates and session kwargs unwrapped from the replicas.
  """
  if context.executing_eagerly():
    return _make_eager_execution_function(model, mode)

  strategy = model._distribution_strategy
  if not model._distributed_model:
    if model._compile_distribution:
      clone_model_on_replicas(
          model, strategy, make_callback_model=(mode == 'train'))
    else:
      _build_distributed_network(model, strategy)

  def _per_device_function(replica_model):
    execution_fn = replica_model._make_execution_function(mode)
    return (execution_fn.inputs, execution_fn.outputs,
            execution_fn.updates_op, execution_fn.session_kwargs)

  with strategy.scope():
    # Create the ops on each of the devices when we call
    # `_per_device_function`.
    per_replica = strategy.extended.call_for_each_replica(
        _per_device_function, args=(model._distributed_model,))
    (replica_inputs, replica_outputs, replica_updates,
     replica_session_args) = per_replica

    if mode == 'train':
      # Initialize the variables in the replicated model. This is necessary for
      # multi-worker training because on some workers, initialization is not
      # needed. This method does initialization or waiting for initialization
      # according to the context object of distribute coordinator.
      distributed_training_utils.init_restore_or_wait_for_variables()

    # Unwrap all the per device values returned from `call_for_each_replica`.
    # Unwrapping per device values gives you a list of values that can be
    # used to construct a new function that is composed of update ops on
    # all the devices over which the model is distributed.
    unwrapped = distributed_training_utils.unwrap_values(
        strategy,
        replica_inputs,
        replica_outputs,
        replica_updates,
        replica_session_args,
        with_loss_tensor=(mode != 'predict'))
    flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

    return K.function(
        flat_inputs,
        flat_outputs,
        updates=flat_updates,
        name='distributed_{}_function'.format(mode),
        **flat_session_args)
def _make_execution_function(model, mode):
  """Makes function to run one step of distributed model execution.

  Args:
    model: Model carrying `_distribution_strategy` and `_grouped_model`. If
      `_grouped_model` is falsy the model is first cloned onto the replicas.
    mode: Execution mode. It is compared against 'train' and 'predict' and
      interpolated into the returned function's name.

  Returns:
    The eager execution function when executing eagerly; otherwise a backend
    function named 'distributed_<mode>_function' that combines the inputs,
    outputs, updates and session kwargs unwrapped from the replicas.
  """
  if context.executing_eagerly():
    return _make_eager_execution_function(model, mode)

  strategy = model._distribution_strategy
  if not model._grouped_model:
    clone_model_on_replicas(
        model, strategy, make_callback_model=(mode == 'train'))

  def _per_device_function(replica_model):
    execution_fn = replica_model._make_execution_function(mode)
    return (execution_fn.inputs, execution_fn.outputs,
            execution_fn.updates_op, execution_fn.session_kwargs)

  with strategy.scope():
    # Create the ops on each of the devices when we call
    # `_per_device_function`.
    per_replica = strategy.extended.call_for_each_replica(
        _per_device_function, args=(model._grouped_model,))
    (replica_inputs, replica_outputs, replica_updates,
     replica_session_args) = per_replica

    if mode == 'train':
      # Initialize the variables in the replicated model. This is necessary for
      # multi-worker training because on some workers, initialization is not
      # needed. This method does initialization or waiting for initialization
      # according to the context object of distribute coordinator.
      distributed_training_utils.init_restore_or_wait_for_variables()

    # Unwrap all the per device values returned from `call_for_each_replica`.
    # Unwrapping per device values gives you a list of values that can be
    # used to construct a new function that is composed of update ops on
    # all the devices over which the model is distributed.
    unwrapped = distributed_training_utils.unwrap_values(
        strategy,
        replica_inputs,
        replica_outputs,
        replica_updates,
        replica_session_args,
        with_loss_tensor=(mode != 'predict'))
    flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

    return K.function(
        flat_inputs,
        flat_outputs,
        updates=flat_updates,
        name='distributed_{}_function'.format(mode),
        **flat_session_args)
def step_fn(ctx, inputs):
  """Builds the distributed fit function for one step.

  Args:
    ctx: Step context. Each output is stored on it with
      `set_last_step_output`, keyed by the matching entry of `out_labels`.
      The 'loss' output uses `distribute_lib.get_loss_reduction()`; all other
      outputs use MEAN.
    inputs: A pair that unpacks into `(inputs, targets)`.

  Returns:
    The `updates_op` of the combined backend function.
  """
  features, labels = inputs
  train_mode = distributed_training_utils.ModeKeys.TRAIN

  # Either clone the compiled model onto the replicas or build the
  # distributed network, depending on `model._compile_distribution`.
  if model._compile_distribution:
    distributed_training_utils.clone_model_on_replicas(
        model,
        current_strategy,
        make_callback_model=True,
        inputs=features,
        targets=labels,
        mode=train_mode)
  else:
    distributed_training_utils._build_distributed_network(
        model, current_strategy, features, labels, mode=train_mode)

  # Gather the fit function pieces of the distributed train model.
  per_replica = current_strategy.extended.call_for_each_replica(
      _per_device_fit_function, args=(model._distributed_model_train,))
  (replica_inputs, replica_outputs, replica_updates,
   replica_session_args) = per_replica

  # Unwrap the grouped values so they can feed a single backend function.
  unwrapped = distributed_training_utils.unwrap_values(
      current_strategy, replica_inputs, replica_outputs, replica_updates,
      replica_session_args)
  flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

  fit_fn = K.function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_fit_function',
      **flat_session_args)

  for output_name, output_tensor in zip(out_labels, fit_fn.outputs):
    # We reduce all non-loss metrics using mean for now. This is temporary
    # workaround until new metrics are in place.
    reduce_op = (
        distribute_lib.get_loss_reduction()
        if output_name == 'loss' else ds_reduce_util.ReduceOp.MEAN)
    ctx.set_last_step_output(output_name, output_tensor, reduce_op)

  # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
  # feed_dict, session kwargs, run options, run_metadata for now. These should
  # be handled appropriately
  return fit_fn.updates_op
def step_fn(ctx, inputs):
  """Builds the distributed function for `mode` and returns its update op.

  Args:
    ctx: Step context. Each output is stored on it with
      `set_last_step_output`, keyed by the matching entry of `output_labels`.
      In predict mode no reduce op is passed; otherwise 'loss' is reduced
      with SUM and every other label with MEAN.
    inputs: The inputs themselves in predict mode; otherwise a pair that
      unpacks into `(inputs, targets)`.

  Returns:
    The `updates_op` of the combined backend function.
  """
  if mode == ModeKeys.PREDICT:
    # Predict has no targets.
    features, labels = inputs, None
  else:
    features, labels = inputs

  # Either clone the compiled model onto the replicas or build the
  # distributed network, depending on `model._compile_distribution`.
  if model._compile_distribution:
    distributed_training_utils.clone_model_on_replicas(
        model, strategy, mode, inputs=features, targets=labels)
  else:
    distributed_training_utils._build_distributed_network(
        model, strategy, mode, features, labels)

  # Gather the execution function pieces of the distributed model for `mode`.
  distributed_model = distributed_training_utils.get_distributed_model(
      model, mode)
  per_replica = strategy.extended.call_for_each_replica(
      _per_device_execution_function, args=(distributed_model,))
  (replica_inputs, replica_outputs, replica_updates,
   replica_session_args) = per_replica

  # Unwrap the grouped values so they can feed a single backend function.
  unwrapped = distributed_training_utils.unwrap_values(
      strategy, replica_inputs, replica_outputs, replica_updates,
      replica_session_args)
  flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

  execution_fn = K.function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_' + str(mode) + '_function',
      **flat_session_args)

  for output_name, output_tensor in zip(output_labels, execution_fn.outputs):
    if mode == ModeKeys.PREDICT:
      ctx.set_last_step_output(output_name, output_tensor)
      continue
    # We reduce all non-loss metrics using mean for now. This is temporary
    # workaround until new metrics are in place.
    reduce_op = (
        ds_reduce_util.ReduceOp.SUM
        if output_name == 'loss' else ds_reduce_util.ReduceOp.MEAN)
    ctx.set_last_step_output(output_name, output_tensor, reduce_op)

  # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
  # feed_dict, session kwargs, run options, run_metadata for now. These should
  # be handled appropriately
  return execution_fn.updates_op
def step_fn(ctx, inputs):
  """Builds the distributed fit function for one step.

  Args:
    ctx: Step context. Each output is stored on it with
      `set_last_step_output`, keyed by the matching entry of `out_labels`.
      The 'loss' output uses `distribute_lib.get_loss_reduction()`; all other
      outputs use MEAN.
    inputs: A pair that unpacks into `(inputs, targets)`.

  Returns:
    The `updates_op` of the combined backend function.
  """
  # TODO(priyag, sourabhbajaj): The model gets cloned every time
  # fit/test/predict is called. We should look into caching this keyed on
  # input shapes.
  features, labels = inputs
  clone_model_on_replicas(
      model,
      current_strategy,
      make_callback_model=True,
      inputs=features,
      targets=labels,
      mode=_Mode.TRAIN)

  # Gather the fit function pieces of the grouped train model.
  per_replica = current_strategy.call_for_each_replica(
      _per_device_fit_function, args=(model._grouped_model_train,))
  (replica_inputs, replica_outputs, replica_updates,
   replica_session_args) = per_replica

  # Unwrap the grouped values so they can feed a single backend function.
  unwrapped = distributed_training_utils.unwrap_values(
      current_strategy, replica_inputs, replica_outputs, replica_updates,
      replica_session_args)
  flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

  fit_fn = K.function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_fit_function',
      **flat_session_args)

  for output_name, output_tensor in zip(out_labels, fit_fn.outputs):
    # We reduce all non-loss metrics using mean for now. This is temporary
    # workaround until new metrics are in place.
    reduce_op = (
        distribute_lib.get_loss_reduction()
        if output_name == 'loss' else ds_reduce_util.ReduceOp.MEAN)
    ctx.set_last_step_output(output_name, output_tensor, reduce_op)

  # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
  # feed_dict, session kwargs, run options, run_metadata for now. These should
  # be handled appropriately
  return fit_fn.updates_op
def step_fn(ctx, inputs, targets):
  """Builds the distributed fit function for one step.

  Args:
    ctx: Step context. Each output is stored on it with
      `set_last_step_output`, keyed by the matching entry of `out_labels`.
      The 'loss' output uses `distribute_lib.get_loss_reduction()`; all other
      outputs use `VariableAggregation.MEAN`.
    inputs: Inputs handed to `clone_model_on_replicas`.
    targets: Targets handed to `clone_model_on_replicas`.

  Returns:
    The `updates_op` of the combined backend function.
  """
  # TODO(priyag, sourabhbajaj): The model gets cloned every time
  # fit/test/predict is called. We should look into caching this keyed on
  # input shapes.
  clone_model_on_replicas(
      model,
      current_strategy,
      make_callback_model=True,
      inputs=inputs,
      targets=targets,
      mode=_Mode.TRAIN)

  # Gather the fit function pieces of the grouped train model. The model is
  # passed positionally here, not through `args=`.
  per_replica = current_strategy.call_for_each_replica(
      _per_device_fit_function, model._grouped_model_train)
  (replica_inputs, replica_outputs, replica_updates,
   replica_session_args) = per_replica

  # Unwrap the grouped values so they can feed a single backend function.
  unwrapped = distributed_training_utils.unwrap_values(
      current_strategy, replica_inputs, replica_outputs, replica_updates,
      replica_session_args)
  flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

  fit_fn = K.function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_fit_function',
      **flat_session_args)

  for output_name, output_tensor in zip(out_labels, fit_fn.outputs):
    # We aggregate all non-loss metrics using mean for now. This is temporary
    # workaround until new metrics are in place.
    aggregation = (
        distribute_lib.get_loss_reduction()
        if output_name == 'loss' else variable_scope.VariableAggregation.MEAN)
    ctx.set_last_step_output(output_name, output_tensor, aggregation)

  # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
  # feed_dict, session kwargs, run options, run_metadata for now. These should
  # be handled appropriately
  return fit_fn.updates_op
def step_fn(ctx, inputs, targets):
  """Builds the distributed train function for one step.

  Args:
    ctx: Step context. Each output is stored on it with
      `set_last_step_output`, keyed by the matching entry of
      `model.metrics_names`. The 'loss' output uses
      `distribute_lib.get_loss_reduction()`; all other outputs use
      `VariableAggregation.MEAN`.
    inputs: Inputs handed to `clone_model_on_towers`.
    targets: Targets handed to `clone_model_on_towers`.

  Returns:
    The `updates_op` of the combined `K.Function`.
  """
  # TODO(priyag, sourabhbajaj): The model gets cloned every time
  # fit/test/predict is called. We should look into caching this keyed on
  # input shapes.
  clone_model_on_towers(
      model,
      current_strategy,
      make_callback_model=True,
      inputs=inputs,
      targets=targets)

  # Gather the train function pieces of the grouped model via the strategy.
  per_tower = current_strategy.call_for_each_tower(
      _per_device_train_function, model._grouped_model)
  tower_inputs, tower_outputs, tower_updates, tower_session_args = per_tower

  # Unwrap the grouped values so they can feed a single backend function.
  unwrapped = distributed_training_utils.unwrap_values(
      current_strategy, tower_inputs, tower_outputs, tower_updates,
      tower_session_args)
  flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

  train_fn = K.Function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_train_function',
      **flat_session_args)

  metric_labels = model.metrics_names or []
  for output_name, output_tensor in zip(metric_labels, train_fn.outputs):
    # We aggregate all non-loss metrics using mean for now. This is temporary
    # workaround until new metrics are in place.
    aggregation = (
        distribute_lib.get_loss_reduction()
        if output_name == 'loss' else variable_scope.VariableAggregation.MEAN)
    ctx.set_last_step_output(output_name, output_tensor, aggregation)

  # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
  # feed_dict, session kwargs, run options, run_metadata for now. These should
  # be handled appropriately
  return train_fn.updates_op
def step_fn(ctx, inputs):
  """Builds the distributed test function for one step.

  Args:
    ctx: Step context. Each output is stored on it with
      `set_last_step_output`, keyed by the matching entry of
      `model.metrics_names`. The 'loss' output uses
      `distribute_lib.get_loss_reduction()`; all other outputs use MEAN.
    inputs: A pair that unpacks into `(inputs, targets)`.

  Returns:
    The `updates_op` of the combined backend function.
  """
  # TODO(priyag, sourabhbajaj): The model gets cloned every time
  # fit/test/predict is called. We should look into caching this keyed on
  # input shapes.
  features, labels = inputs
  clone_model_on_replicas(
      model,
      current_strategy,
      make_callback_model=False,
      inputs=features,
      targets=labels,
      mode=_Mode.TEST)

  # Gather the eval function pieces of the grouped test model.
  per_replica = current_strategy.call_for_each_replica(
      _per_device_eval_function, args=(model._grouped_model_test,))
  (replica_inputs, replica_outputs, replica_updates,
   replica_session_args) = per_replica

  # Unwrap the grouped values so they can feed a single backend function.
  unwrapped = distributed_training_utils.unwrap_values(
      current_strategy, replica_inputs, replica_outputs, replica_updates,
      replica_session_args)
  flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

  test_fn = K.function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_test_function',
      **flat_session_args)

  for output_name, output_tensor in zip(model.metrics_names, test_fn.outputs):
    # We reduce all non-loss metrics using mean for now. This is temporary
    # workaround until new metrics are in place.
    reduce_op = (
        distribute_lib.get_loss_reduction()
        if output_name == 'loss' else ds_reduce_util.ReduceOp.MEAN)
    ctx.set_last_step_output(output_name, output_tensor, reduce_op)

  return test_fn.updates_op
def step_fn(ctx, inputs):
  """Builds the distributed fit function for one step.

  Args:
    ctx: Step context. Each output is stored on it with
      `set_last_step_output`, keyed by the matching entry of `out_labels`.
      The 'loss' output uses `distribute_lib.get_loss_reduction()`; all other
      outputs use MEAN.
    inputs: A pair that unpacks into `(inputs, targets)`.

  Returns:
    The `updates_op` of the combined backend function.
  """
  features, labels = inputs

  # Either clone the compiled model onto the replicas or build the
  # distributed network, depending on `model._compile_distribution`.
  if model._compile_distribution:
    distributed_training_utils.clone_model_on_replicas(
        model,
        current_strategy,
        ModeKeys.TRAIN,
        inputs=features,
        targets=labels)
  else:
    distributed_training_utils._build_distributed_network(
        model, current_strategy, ModeKeys.TRAIN, features, labels)

  # Gather the fit function pieces of the distributed train model.
  per_replica = current_strategy.extended.call_for_each_replica(
      _per_device_fit_function, args=(model._distributed_model_train,))
  (replica_inputs, replica_outputs, replica_updates,
   replica_session_args) = per_replica

  # Unwrap the grouped values so they can feed a single backend function.
  unwrapped = distributed_training_utils.unwrap_values(
      current_strategy, replica_inputs, replica_outputs, replica_updates,
      replica_session_args)
  flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

  fit_fn = K.function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_fit_function',
      **flat_session_args)

  for output_name, output_tensor in zip(out_labels, fit_fn.outputs):
    # We reduce all non-loss metrics using mean for now. This is temporary
    # workaround until new metrics are in place.
    reduce_op = (
        distribute_lib.get_loss_reduction()
        if output_name == 'loss' else ds_reduce_util.ReduceOp.MEAN)
    ctx.set_last_step_output(output_name, output_tensor, reduce_op)

  # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
  # feed_dict, session kwargs, run options, run_metadata for now. These should
  # be handled appropriately
  return fit_fn.updates_op
def step_fn(ctx, inputs):
  """Builds the distributed test function for one step.

  Args:
    ctx: Step context. Each output is stored on it with
      `set_last_step_output`, keyed by the matching entry of
      `model.metrics_names`. The 'loss' output uses
      `distribute_lib.get_loss_reduction()`; all other outputs use MEAN.
    inputs: A pair that unpacks into `(inputs, targets)`.

  Returns:
    The `updates_op` of the combined backend function.
  """
  features, labels = inputs
  test_mode = distributed_training_utils.ModeKeys.TEST

  # Either clone the compiled model onto the replicas or build the
  # distributed network, depending on `model._compile_distribution`.
  if model._compile_distribution:
    distributed_training_utils.clone_model_on_replicas(
        model,
        current_strategy,
        make_callback_model=False,
        inputs=features,
        targets=labels,
        mode=test_mode)
  else:
    distributed_training_utils._build_distributed_network(
        model, current_strategy, features, labels, mode=test_mode)

  # Gather the eval function pieces of the distributed test model.
  per_replica = current_strategy.extended.call_for_each_replica(
      _per_device_eval_function, args=(model._distributed_model_test,))
  (replica_inputs, replica_outputs, replica_updates,
   replica_session_args) = per_replica

  # Unwrap the grouped values so they can feed a single backend function.
  unwrapped = distributed_training_utils.unwrap_values(
      current_strategy, replica_inputs, replica_outputs, replica_updates,
      replica_session_args)
  flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

  test_fn = K.function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_test_function',
      **flat_session_args)

  for output_name, output_tensor in zip(model.metrics_names, test_fn.outputs):
    # We reduce all non-loss metrics using mean for now. This is temporary
    # workaround until new metrics are in place.
    reduce_op = (
        distribute_lib.get_loss_reduction()
        if output_name == 'loss' else ds_reduce_util.ReduceOp.MEAN)
    ctx.set_last_step_output(output_name, output_tensor, reduce_op)

  return test_fn.updates_op
def step_fn(ctx, inputs, targets):
  """Builds the distributed test function for one step.

  Args:
    ctx: Step context. Each output is stored on it with
      `set_last_step_output`, keyed by the matching entry of
      `model.metrics_names`. The 'loss' output uses
      `distribute_lib.get_loss_reduction()`; all other outputs use
      `VariableAggregation.MEAN`.
    inputs: Inputs handed to `clone_model_on_replicas`.
    targets: Targets handed to `clone_model_on_replicas`.

  Returns:
    The `updates_op` of the combined backend function.
  """
  # TODO(priyag, sourabhbajaj): The model gets cloned every time
  # fit/test/predict is called. We should look into caching this keyed on
  # input shapes.
  clone_model_on_replicas(
      model,
      current_strategy,
      make_callback_model=False,
      inputs=inputs,
      targets=targets,
      mode=_Mode.TEST)

  # Gather the eval function pieces of the grouped test model.
  per_replica = current_strategy.call_for_each_replica(
      _per_device_eval_function, args=(model._grouped_model_test,))
  (replica_inputs, replica_outputs, replica_updates,
   replica_session_args) = per_replica

  # Unwrap the grouped values so they can feed a single backend function.
  unwrapped = distributed_training_utils.unwrap_values(
      current_strategy, replica_inputs, replica_outputs, replica_updates,
      replica_session_args)
  flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

  test_fn = K.function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_test_function',
      **flat_session_args)

  for output_name, output_tensor in zip(model.metrics_names, test_fn.outputs):
    # We aggregate all non-loss metrics using mean for now. This is temporary
    # workaround until new metrics are in place.
    aggregation = (
        distribute_lib.get_loss_reduction()
        if output_name == 'loss' else variable_scope.VariableAggregation.MEAN)
    ctx.set_last_step_output(output_name, output_tensor, aggregation)

  return test_fn.updates_op
def step_fn(ctx, inputs, targets):
  """Builds the distributed train function for one step.

  Args:
    ctx: Step context. Each output is stored on it with
      `set_last_step_output`, keyed by the matching entry of
      `model.metrics_names` and aggregated with
      `distribute_lib.get_loss_reduction()`.
    inputs: Inputs handed to `clone_model_on_towers`.
    targets: Targets handed to `clone_model_on_towers`.

  Returns:
    The `updates_op` of the combined `K.Function`.
  """
  # TODO(priyag, sourabhbajaj): Should cache this keyed on input shapes.
  clone_model_on_towers(
      model,
      current_strategy,
      make_callback_model=True,
      inputs=inputs,
      targets=targets)

  # Gather the train function pieces of the grouped model via the strategy.
  per_tower = current_strategy.call_for_each_tower(
      _per_device_train_function, model._grouped_model)
  tower_inputs, tower_outputs, tower_updates, tower_session_args = per_tower

  # Unwrap the grouped values so they can feed a single backend function.
  unwrapped = distributed_training_utils.unwrap_values(
      current_strategy,
      tower_inputs,
      tower_outputs,
      tower_updates,
      tower_session_args,
      with_loss_tensor=True)
  flat_inputs, flat_outputs, flat_updates, flat_session_args = unwrapped

  train_fn = K.Function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_train_function',
      **flat_session_args)

  # TODO(priyag, sourabhbajaj): Perhaps the aggregation type needs to be
  # something else for different outputs.
  metric_labels = model.metrics_names or []
  for output_name, output_tensor in zip(metric_labels, train_fn.outputs):
    ctx.set_last_step_output(
        output_name,
        output_tensor,
        aggregation=distribute_lib.get_loss_reduction())

  # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
  # feed_dict, session kwargs, run options, run_metadata for now. These should
  # be handled appropriately
  return train_fn.updates_op
def fit_loop(model,
             inputs,
             targets,
             epochs=100,
             verbose=1,
             callbacks=None,
             val_inputs=None,
             val_targets=None,
             callback_metrics=None,
             initial_epoch=0,
             steps_per_epoch=None,
             validation_steps=None):
  """fit function when using DistributionStrategy for training.

  Arguments:
      model: Keras Model instance.
      inputs: List of input arrays.
      targets: List of target arrays.
      epochs: Number of times to iterate over the data
      verbose: Verbosity mode, 0, 1 or 2
      callbacks: List of callbacks to be called during training
      val_inputs: List of input arrays.
      val_targets: List of target arrays.
      callback_metrics: Accepted for interface compatibility. The value
          passed in is not used: the list handed to the callbacks is always
          rebuilt from `model.metrics_names` (plus the `val_`-prefixed names
          when validation is enabled).
      initial_epoch: Epoch at which to start training
          (useful for resuming a previous training run)
      steps_per_epoch: Total number of steps (batches of samples)
          before declaring one epoch finished and starting the
          next epoch. Ignored with the default value of `None`.
      validation_steps: Number of steps to run validation for
          (only if doing validation from data tensors).
          Ignored with the default value of `None`.

  Returns:
      `History` object.

  Raises:
      ValueError: if `validation_steps` is set while `steps_per_epoch` is
          `None`.
  """
  current_strategy = model._distribution_strategy

  def _per_device_train_function(model):
    model._make_train_function()
    return (model.train_function.inputs,
            model.train_function.outputs,
            model.train_function.updates_op,
            model.train_function.session_kwargs)

  with current_strategy.scope():
    # Create train ops on each of the devices when we call
    # `_per_device_train_function`.
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_tower(
         _per_device_train_function, model._grouped_model)
    # Unwrap all the per device values returned from `call_for_each_tower`.
    # Unwrapping per device values gives you a list of values that can be
    # used to construct a new train function that is composed of update ops on
    # all the devices over which the model is distributed.
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy,
         grouped_inputs,
         grouped_outputs,
         grouped_updates,
         grouped_session_args,
         with_loss_tensor=True)

    # Dataset inputs and targets are also per devices values that need to be
    # unwrapped.
    dataset_inputs = distributed_training_utils.flatten_perdevice_values(
        current_strategy, inputs)
    dataset_targets = distributed_training_utils.flatten_perdevice_values(
        current_strategy, targets)

  # Create a train function that is composed of all the parameters above.
  distributed_train_function = K.Function(
      all_inputs,
      all_outputs,
      updates=all_updates,
      name='distributed_train_function',
      **all_session_args)

  # We need to set sample_weights to None since there are sample weight
  # placeholders that are created with default values.
  sample_weights = [
      None for _ in range(len(model.outputs) * current_strategy.num_towers)
  ]
  if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
    ins = dataset_inputs + dataset_targets + sample_weights + [1]
  else:
    ins = dataset_inputs + dataset_targets

  do_validation = False
  if validation_steps:
    do_validation = True
    if steps_per_epoch is None:
      raise ValueError('Can only use `validation_steps` '
                       'when doing step-wise '
                       'training, i.e. `steps_per_epoch` '
                       'must be set.')

  # Normalise once so that every later use sees a list.
  out_labels = model.metrics_names or []
  if do_validation:
    callback_metrics = copy.copy(out_labels) + [
        'val_' + n for n in out_labels
    ]
  else:
    callback_metrics = copy.copy(out_labels)

  model.history = cbks.History()
  all_callbacks = [
      cbks.BaseLogger(stateful_metrics=model.stateful_metric_names)
  ]
  if verbose:
    # We assume that `steps_per_epoch` is always set since we have to use
    # Datasets.
    count_mode = 'steps'
    all_callbacks.append(
        cbks.ProgbarLogger(
            count_mode, stateful_metrics=model.stateful_metric_names))
  all_callbacks += (callbacks or []) + [model.history]
  callbacks = cbks.CallbackList(all_callbacks)

  # We set the callback model to an instance of the `DistributedModel` that we
  # create in the `compile` call. The `DistributedModel` is initialized with
  # the first replicated model. We need to set the callback model to a
  # DistributedModel to allow us to override saving and loading weights when
  # we checkpoint the model during training.
  callback_model = model._replicated_model

  callbacks.set_model(callback_model)

  callbacks.set_params({
      'epochs': epochs,
      'steps': steps_per_epoch,
      'samples': None,
      'verbose': verbose,
      'do_validation': do_validation,
      'metrics': callback_metrics or [],
  })
  callbacks.on_train_begin()
  callback_model.stop_training = False

  # Copy the weights from the original model to each of the replicated models.
  orig_model_weights = model.get_weights()
  with current_strategy.scope():
    distributed_model = current_strategy.unwrap(model._grouped_model)[0]
    distributed_training_utils.set_weights(
        current_strategy, distributed_model, orig_model_weights)

  for epoch in range(initial_epoch, epochs):
    callbacks.on_epoch_begin(epoch)
    # Always bound, so `on_epoch_end` below works even when no steps run.
    epoch_logs = {}
    if steps_per_epoch is not None:
      for step_index in range(steps_per_epoch):
        batch_logs = {}
        batch_logs['batch'] = step_index
        batch_logs['size'] = 1
        callbacks.on_batch_begin(step_index, batch_logs)
        try:
          outs = distributed_train_function(ins)
        except errors.OutOfRangeError:
          # The product must be parenthesised: `%` and `*` have the same
          # precedence, so without it the message is formatted with
          # `steps_per_epoch` and then repeated `epochs` times.
          logging.warning(
              'Your dataset iterator ran out of data; '
              'interrupting training. Make sure that your dataset '
              'can generate at least `steps_per_epoch * epochs` '
              'batches (in this case, %d batches).' %
              (steps_per_epoch * epochs))
          break

        if not isinstance(outs, list):
          outs = [outs]

        # TODO(anjalisridhar): Temporary workaround for aggregating metrics
        # across towers. Replace with the new metrics module eventually.
        merged_output = []
        # The first output is the total loss.
        merged_output.append(outs[0])
        current_index = 1
        num_devices = len(current_strategy._devices)
        # Each label in `out_labels` corresponds to one set of metrics. The
        # number of metric values corresponds to the number of devices. We
        # currently take the mean of the values.
        for _ in out_labels[1:]:
          m = np.mean(outs[current_index:current_index + num_devices])
          merged_output.append(m)
          current_index += num_devices

        # Log the per-label merged values; the raw `outs` list holds one
        # entry per device for each metric and does not line up with
        # `out_labels`.
        for l, o in zip(out_labels, merged_output):
          batch_logs[l] = o
        callbacks.on_batch_end(step_index, batch_logs)
        if callback_model.stop_training:
          break
      if do_validation:
        val_outs = test_loop(
            model,
            val_inputs,
            val_targets,
            steps=validation_steps,
            verbose=0)
        if not isinstance(val_outs, list):
          val_outs = [val_outs]
        # Same labels assumed.
        for l, o in zip(out_labels, val_outs):
          epoch_logs['val_' + l] = o

    callbacks.on_epoch_end(epoch, epoch_logs)
    if callback_model.stop_training:
      break
  callbacks.on_train_end()

  # Copy the weights back from the replicated model to the original model.
  with current_strategy.scope():
    updated_weights = current_strategy.unwrap(
        model._grouped_model)[0].get_weights()
    model.set_weights(updated_weights)
  return model.history
def test_loop(model, inputs, targets, verbose=0, steps=None):
  """Evaluates a model that uses DistributionStrategy.

  Builds one combined test function from the per-tower test functions, runs
  it `steps` times and averages the (tower-aggregated) outputs over the steps.

  Arguments:
      model: Keras Model instance.
      inputs: List of input arrays.
      targets: List of target arrays.
      verbose: verbosity mode; a progress bar is shown when this is 1.
      steps: Total number of steps (batches of samples)
          before declaring predictions finished.
          Ignored with the default value of `None`.

  Returns:
      Scalar loss (if the model has a single output and no metrics)
      or list of scalars (if the model has multiple outputs
      and/or metrics). The attribute `model.metrics_names` will give you
      the display labels for the scalar outputs.
  """
  current_strategy = model._distribution_strategy

  def _per_device_test_function(model):
    model._make_test_function()
    test_function = model.test_function
    return (test_function.inputs,
            test_function.outputs,
            test_function.updates_op,
            test_function.session_kwargs)

  with current_strategy.scope():
    (tower_inputs, tower_outputs, tower_updates,
     tower_session_args) = current_strategy.call_for_each_tower(
         _per_device_test_function, model._grouped_model)

    (flat_inputs, flat_outputs, flat_updates,
     flat_session_args) = distributed_training_utils.unwrap_values(
         current_strategy,
         tower_inputs,
         tower_outputs,
         tower_updates,
         tower_session_args,
         with_loss_tensor=True)

    dataset_inputs = distributed_training_utils.flatten_perdevice_values(
        current_strategy, inputs)
    dataset_targets = distributed_training_utils.flatten_perdevice_values(
        current_strategy, targets)

  distributed_test_function = K.Function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_test_function',
      **flat_session_args)

  # We need to set sample_weights to None since there are sample weight
  # placeholders that are created with default values.
  sample_weights = [
      None for _ in range(len(model.outputs) * current_strategy.num_towers)
  ]
  ins = dataset_inputs + dataset_targets
  if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
    # Trailing 0 is the value fed for the learning-phase input.
    ins = ins + sample_weights + [0]

  totals = []
  if verbose == 1:
    progbar = Progbar(target=steps)

  # Copy the weights from the original model to each of the replicated models.
  orig_model_weights = model.get_weights()
  with current_strategy.scope():
    distributed_model = current_strategy.unwrap(model._grouped_model)[0]
    distributed_training_utils.set_weights(
        current_strategy, distributed_model, orig_model_weights)

  if steps is not None:
    for step in range(steps):
      step_outs = _aggregate_metrics_across_towers(
          len(current_strategy._devices), model.metrics_names,
          distributed_test_function(ins))
      if isinstance(step_outs, list):
        if step == 0:
          # One running total per output.
          totals.extend([0.] * len(step_outs))
        for idx, value in enumerate(step_outs):
          totals[idx] += value
      else:
        if step == 0:
          totals.append(0.)
        totals[0] += step_outs
      if verbose == 1:
        progbar.update(step + 1)
    # Turn the running totals into per-step averages.
    for idx, _ in enumerate(totals):
      totals[idx] /= steps

  if len(totals) == 1:
    return totals[0]
  return totals
def predict_loop(model, inputs, verbose=0, steps=None):
  """Runs prediction for a model that uses DistributionStrategy.

  Builds one combined predict function from the per-tower predict functions,
  runs it `steps` times and concatenates the per-step outputs along axis 0.

  Arguments:
      model: Keras Model instance.
      inputs: list of tensors to be fed to `f`.
      verbose: verbosity mode; a progress bar is shown when this is 1.
      steps: Total number of steps (batches of samples)
          before declaring `_predict_loop` finished.
          Ignored with the default value of `None`.

  Returns:
      Array of predictions (if the combined function has a single output)
      or list of arrays of predictions (if it has multiple outputs).
      Nothing is returned when `steps` is `None`.
  """
  current_strategy = model._distribution_strategy

  def _per_device_predict_function(model):
    model._make_predict_function()
    predict_function = model.predict_function
    return (predict_function.inputs,
            predict_function.outputs,
            predict_function.updates_op,
            predict_function.session_kwargs)

  with current_strategy.scope():
    (tower_inputs, tower_outputs, tower_updates,
     tower_session_args) = current_strategy.call_for_each_tower(
         _per_device_predict_function, model._grouped_model)

    (flat_inputs, flat_outputs, flat_updates,
     flat_session_args) = distributed_training_utils.unwrap_values(
         current_strategy,
         tower_inputs,
         tower_outputs,
         tower_updates,
         tower_session_args)

    dataset_inputs = distributed_training_utils.flatten_perdevice_values(
        current_strategy, inputs)

  distributed_predict_function = K.Function(
      flat_inputs,
      flat_outputs,
      updates=flat_updates,
      name='distributed_predict_function',
      **flat_session_args)

  ins = dataset_inputs
  if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
    # Trailing 0 is the value fed for the learning-phase input.
    ins = dataset_inputs + [0]

  if verbose == 1:
    progbar = Progbar(target=steps)

  # Copy the weights from the original model to each of the replicated models.
  orig_model_weights = model.get_weights()
  with current_strategy.scope():
    distributed_model = current_strategy.unwrap(model._grouped_model)[0]
    distributed_training_utils.set_weights(
        current_strategy, distributed_model, orig_model_weights)

  if steps is not None:
    # Since we do not know how many samples we will see, we cannot pre-allocate
    # the returned Numpy arrays. Instead, we store one array per batch seen
    # and concatenate them upon returning.
    per_output_batches = []
    for step in range(steps):
      step_outs = distributed_predict_function(ins)
      if not isinstance(step_outs, list):
        step_outs = [step_outs]
      if step == 0:
        # One bucket per output, sized from the first step's result.
        per_output_batches.extend([] for _ in step_outs)
      for idx, step_out in enumerate(step_outs):
        per_output_batches[idx].append(step_out)
      if verbose == 1:
        progbar.update(step + 1)
    if len(per_output_batches) == 1:
      return np.concatenate(per_output_batches[0], axis=0)
    return [np.concatenate(batches, axis=0) for batches in per_output_batches]
def predict_loop(model, iterator, verbose=0, steps=None):
  """Predict loop for predicting with DistributionStrategy.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator for input data.
      verbose: Integer, Verbosity mode 0 or 1. Any value >= 1 shows a
          progress bar.
      steps: Total number of steps (batches of samples)
          before declaring `_predict_loop` finished. Must not be `None`
          on the non-TPU path.

  Returns:
      Array of predictions (if the model has a single output)
      or list of arrays of predictions
      (if the model has multiple outputs).
  """
  current_strategy = model._distribution_strategy

  # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
  if current_strategy.__class__.__name__ == 'TPUStrategy':
    return _experimental_predict_loop(model, iterator, verbose, steps)

  if not model._grouped_model:
    clone_model_on_replicas(model, current_strategy)

  def _per_device_predict_function(model):
    model._make_predict_function()
    return (model.predict_function.inputs,
            model.predict_function.outputs,
            model.predict_function.updates_op,
            model.predict_function.session_kwargs)

  inputs, _, _ = _get_input_from_iterator(iterator, model)
  with current_strategy.scope():
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_replica(
         _per_device_predict_function, args=(model._grouped_model,))

    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy,
         grouped_inputs,
         grouped_outputs,
         grouped_updates,
         grouped_session_args)

    dataset_inputs = distributed_training_utils.flatten_perdevice_values(
        current_strategy, inputs)

  distributed_predict_function = K.function(
      all_inputs,
      all_outputs,
      updates=all_updates,
      name='distributed_predict_function',
      **all_session_args)

  # Must use the same condition as the `progbar.update` call in the loop;
  # creating the bar only for `verbose == 1` left `progbar` unbound for
  # larger verbosity values.
  if verbose >= 1:
    progbar = Progbar(target=steps)

  # Copy the weights from the original model to each of the replicated models.
  orig_model_weights = model.get_weights()
  distributed_model = current_strategy.unwrap(model._grouped_model)[0]
  distributed_training_utils.set_weights(current_strategy, distributed_model,
                                         orig_model_weights)

  num_replicas = current_strategy.num_replicas_in_sync
  # Since we do not know how many samples we will see, we cannot
  # pre-allocate the returned Numpy arrays. Instead, we store one array per
  # batch seen and concatenate them upon returning.
  unconcatenated_outs = []
  assert steps is not None
  for step in range(steps):
    batch_outs = distributed_predict_function(dataset_inputs)
    if not isinstance(batch_outs, list):
      batch_outs = [batch_outs]
    if step == 0:
      # batch_outs gives you the number of model outputs. In the distributed
      # case this will be number of model_outputs * num_replicas.
      for _ in range(len(model.outputs)):
        unconcatenated_outs.append([])
    for i in range(len(model.outputs)):
      # Output `i` occupies `num_replicas` consecutive entries of batch_outs.
      nested_outs = batch_outs[i * num_replicas:
                               i * num_replicas + num_replicas]
      outs = nest.flatten(nested_outs)
      unconcatenated_outs[i].extend(outs)
    if verbose >= 1:
      progbar.update(step + 1)
  if len(unconcatenated_outs) == 1:
    return np.concatenate(unconcatenated_outs[0], axis=0)
  return [
      np.concatenate(unconcatenated_outs[i], axis=0)
      for i in range(len(unconcatenated_outs))
  ]
def fit_loop(model,
             iterator,
             epochs=100,
             verbose=1,
             callbacks=None,
             val_iterator=None,
             initial_epoch=0,
             steps_per_epoch=None,
             validation_steps=None):
  """Fit loop for training with DistributionStrategy.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator for input data.
      epochs: Number of times to iterate over the data
      verbose: Integer, Verbosity mode, 0, 1 or 2
      callbacks: List of callbacks to be called during training
      val_iterator: Iterator for validation data.
      initial_epoch: Epoch at which to start training
          (useful for resuming a previous training run)
      steps_per_epoch: Total number of steps (batches of samples)
          before declaring one epoch finished and starting the
          next epoch. Must not be `None` on the non-TPU path.
      validation_steps: Number of steps to run validation for
          (only if doing validation from data tensors).
          Ignored with the default value of `None`.

  Returns:
      `History` object.

  Raises:
      ValueError: in case of invalid arguments.
  """
  current_strategy = model._distribution_strategy

  # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
  if current_strategy.__class__.__name__ == 'TPUStrategy':
    return _experimental_fit_loop(
        model, iterator, epochs, verbose, callbacks, initial_epoch,
        steps_per_epoch, val_iterator, validation_steps)

  if not model._grouped_model:
    clone_model_on_replicas(model, current_strategy, make_callback_model=True)

  def _per_device_fit_function(model):
    model._make_fit_function()
    return (model._fit_function.inputs,
            model._fit_function.outputs,
            model._fit_function.updates_op,
            model._fit_function.session_kwargs)

  inputs, targets, sample_weights = _get_input_from_iterator(iterator, model)
  with current_strategy.scope():
    # Create train ops on each of the devices when we call
    # `_per_device_fit_function`.
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_replica(
         _per_device_fit_function, args=(model._grouped_model,))
    # Initialize the variables in the replicated model. This is necessary for
    # multi-worker training because on some workers, initialization is not
    # needed. This method does initialization or waiting for initialization
    # according to the context object of distribute coordinator.
    distributed_training_utils.init_restore_or_wait_for_variables()

    # Unwrap all the per device values returned from `call_for_each_replica`.
    # Unwrapping per device values gives you a list of values that can be
    # used to construct a new train function that is composed of update ops on
    # all the devices over which the model is distributed.
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy,
         grouped_inputs,
         grouped_outputs,
         grouped_updates,
         grouped_session_args,
         with_loss_tensor=True)

    # Dataset inputs and targets are also per devices values that need to be
    # unwrapped.
    dataset_inputs = distributed_training_utils.flatten_perdevice_values(
        current_strategy, inputs)
    dataset_targets = distributed_training_utils.flatten_perdevice_values(
        current_strategy, targets)

  # Create a train function that is composed of all the parameters above.
  distributed_fit_function = K.function(
      all_inputs,
      all_outputs,
      updates=all_updates,
      name='distributed_fit_function',
      **all_session_args)

  # We need to set sample_weights to None since there are sample weight
  # placeholders that are created with default values.
  sample_weights = [
      None for _ in range(
          len(model.outputs) * current_strategy.num_replicas_in_sync)
  ]
  if not isinstance(K.learning_phase(), int):
    ins = dataset_inputs + dataset_targets + sample_weights + [1]
  else:
    ins = dataset_inputs + dataset_targets

  do_validation = False
  if validation_steps:
    do_validation = True

  # Copy the weights from the original model to each of the replicated models.
  orig_model_weights = model.get_weights()
  distributed_model = current_strategy.unwrap(model._grouped_model)[0]
  distributed_training_utils.set_weights(current_strategy, distributed_model,
                                         orig_model_weights)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      val_inputs=None,
      val_targets=None,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      verbose=verbose)
  out_labels = model.metrics_names or []
  callbacks.on_train_begin()

  assert steps_per_epoch is not None

  for epoch in range(initial_epoch, epochs):
    # Reset stateful metrics
    for m in model.stateful_metric_functions:
      m.reset_states()
    callbacks.on_epoch_begin(epoch)
    epoch_logs = {}
    for step_index in range(steps_per_epoch):
      batch_logs = {'batch': step_index, 'size': 1}
      callbacks.on_batch_begin(step_index, batch_logs)
      try:
        outs = distributed_fit_function(ins)
      except errors.OutOfRangeError:
        # The product must be parenthesised: `%` and `*` have the same
        # precedence, so without it the message is formatted with
        # `steps_per_epoch` and then repeated `epochs` times.
        logging.warning(
            'Your dataset iterator ran out of data; '
            'interrupting training. Make sure that your dataset '
            'can generate at least `steps_per_epoch * epochs` '
            'batches (in this case, %d batches).' %
            (steps_per_epoch * epochs))
        break

      if not isinstance(outs, list):
        outs = [outs]
      for l, o in zip(out_labels, outs):
        batch_logs[l] = o
      callbacks.on_batch_end(step_index, batch_logs)
      if callbacks.model.stop_training:
        break
    if do_validation:
      val_outs = test_loop(
          model, val_iterator, steps=validation_steps, verbose=0)
      if not isinstance(val_outs, list):
        val_outs = [val_outs]
      # Same labels assumed.
      for l, o in zip(out_labels, val_outs):
        epoch_logs['val_' + l] = o

    callbacks.on_epoch_end(epoch, epoch_logs)
    if callbacks.model.stop_training:
      break
  callbacks.on_train_end()

  # Copy the weights back from the replicated model to the original model.
  updated_weights = current_strategy.unwrap(
      model._grouped_model)[0].get_weights()
  model.set_weights(updated_weights)
  return model.history
def test_loop(model, iterator, verbose=0, steps=None):
  """Test loop for evaluating with DistributionStrategy.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator for input data.
      verbose: Integer, Verbosity mode 0 or 1. Any value >= 1 shows a
          progress bar.
      steps: Total number of steps (batches of samples)
          before declaring predictions finished. Must not be `None` on the
          non-TPU path.

  Returns:
      Scalar loss (if the model has a single output and no metrics)
      or list of scalars (if the model has multiple outputs
      and/or metrics). The attribute `model.metrics_names` will give you
      the display labels for the outputs. The loss (index 0) is averaged
      over the steps; every other entry is the value from the last step.
  """
  current_strategy = model._distribution_strategy

  # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
  if current_strategy.__class__.__name__ == 'TPUStrategy':
    return _experimental_test_loop(model, iterator, verbose, steps)

  if not model._grouped_model:
    clone_model_on_replicas(model, current_strategy)

  def _per_device_eval_function(model):
    model._make_eval_function()
    return (model._eval_function.inputs,
            model._eval_function.outputs,
            model._eval_function.updates_op,
            model._eval_function.session_kwargs)

  inputs, targets, sample_weights = _get_input_from_iterator(iterator, model)
  with current_strategy.scope():
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_replica(
         _per_device_eval_function, args=(model._grouped_model,))

    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy,
         grouped_inputs,
         grouped_outputs,
         grouped_updates,
         grouped_session_args,
         with_loss_tensor=True)

    dataset_inputs = distributed_training_utils.flatten_perdevice_values(
        current_strategy, inputs)
    dataset_targets = distributed_training_utils.flatten_perdevice_values(
        current_strategy, targets)

  distributed_test_function = K.function(
      all_inputs,
      all_outputs,
      updates=all_updates,
      name='distributed_test_function',
      **all_session_args)

  # We need to set sample_weights to None since there are sample weight
  # placeholders that are created with default values.
  sample_weights = [
      None for _ in range(
          len(model.outputs) * current_strategy.num_replicas_in_sync)
  ]
  ins = dataset_inputs + dataset_targets + sample_weights

  for m in model.stateful_metric_functions:
    m.reset_states()

  outs = []
  # Must use the same condition as the `progbar.update` call in the loop;
  # creating the bar only for `verbose == 1` left `progbar` unbound for
  # larger verbosity values.
  if verbose >= 1:
    progbar = Progbar(target=steps)

  # Copy the weights from the original model to each of the replicated models.
  orig_model_weights = model.get_weights()
  distributed_model = current_strategy.unwrap(model._grouped_model)[0]
  distributed_training_utils.set_weights(current_strategy, distributed_model,
                                         orig_model_weights)

  assert steps is not None
  for step in range(steps):
    batch_outs = distributed_test_function(ins)
    if isinstance(batch_outs, list):
      if step == 0:
        outs = [0.] * len(batch_outs)
      outs[0] += batch_outs[0]  # index 0 = 'loss'
      # Remaining entries are overwritten, not accumulated, each step.
      outs[1:] = batch_outs[1:]
    else:
      if step == 0:
        outs.append(0.)
      outs[0] += batch_outs  # index 0 = 'loss'
    if verbose >= 1:
      progbar.update(step + 1)
  outs[0] /= steps  # index 0 = 'loss'

  if len(outs) == 1:
    return outs[0]
  return outs
def fit_loop(
    model,
    inputs,
    targets,
    epochs=100,
    verbose=1,
    callbacks=None,
    val_inputs=None,
    val_targets=None,
    initial_epoch=0,
    steps_per_epoch=None,
    validation_steps=None):
  """fit function when using DistributionStrategy for training.

  Arguments:
      model: Keras Model instance.
      inputs: List of input arrays.
      targets: List of target arrays.
      epochs: Number of times to iterate over the data
      verbose: Verbosity mode, 0, 1 or 2
      callbacks: List of callbacks to be called during training
      val_inputs: List of input arrays.
      val_targets: List of target arrays.
      initial_epoch: Epoch at which to start training
          (useful for resuming a previous training run)
      steps_per_epoch: Total number of steps (batches of samples)
          before declaring one epoch finished and starting the
          next epoch. Ignored with the default value of `None`.
      validation_steps: Number of steps to run validation for
          (only if doing validation from data tensors).
          Ignored with the default value of `None`.

  Returns:
      `History` object.

  Raises:
      ValueError: if `validation_steps` is set while `steps_per_epoch` is
          `None`.
  """
  current_strategy = model._distribution_strategy

  def _per_device_train_function(model):
    model._make_train_function()
    return (model.train_function.inputs,
            model.train_function.outputs,
            model.train_function.updates_op,
            model.train_function.session_kwargs)

  with current_strategy.scope():
    # Create train ops on each of the devices when we call
    # `_per_device_train_function`.
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_tower(
         _per_device_train_function, model._grouped_model)
    # Unwrap all the per device values returned from `call_for_each_tower`.
    # Unwrapping per device values gives you a list of values that can be
    # used to construct a new train function that is composed of update ops on
    # all the devices over which the model is distributed.
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy,
         grouped_inputs,
         grouped_outputs,
         grouped_updates,
         grouped_session_args,
         with_loss_tensor=True)

    # Dataset inputs and targets are also per devices values that need to be
    # unwrapped.
    dataset_inputs = distributed_training_utils.flatten_perdevice_values(
        current_strategy, inputs)
    dataset_targets = distributed_training_utils.flatten_perdevice_values(
        current_strategy, targets)

  # Create a train function that is composed of all the parameters above.
  distributed_train_function = K.Function(
      all_inputs,
      all_outputs,
      updates=all_updates,
      name='distributed_train_function',
      **all_session_args)

  # We need to set sample_weights to None since there are sample weight
  # placeholders that are created with default values.
  sample_weights = [None for _ in range(len(model.outputs) *
                                        current_strategy.num_towers)]
  if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
    ins = dataset_inputs + dataset_targets + sample_weights + [1]
  else:
    ins = dataset_inputs + dataset_targets

  do_validation = False
  if validation_steps:
    do_validation = True
    if steps_per_epoch is None:
      raise ValueError('Can only use `validation_steps` '
                       'when doing step-wise '
                       'training, i.e. `steps_per_epoch` '
                       'must be set.')

  # Copy the weights from the original model to each of the replicated models.
  orig_model_weights = model.get_weights()
  with current_strategy.scope():
    distributed_model = current_strategy.unwrap(model._grouped_model)[0]
    distributed_training_utils.set_weights(
        current_strategy, distributed_model, orig_model_weights)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      val_inputs=None,
      val_targets=None,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      verbose=verbose)
  out_labels = model.metrics_names or []
  callbacks.on_train_begin()
  for epoch in range(initial_epoch, epochs):
    callbacks.on_epoch_begin(epoch)
    # Always bound, so `on_epoch_end` below works even when no steps run.
    epoch_logs = {}
    if steps_per_epoch is not None:
      for step_index in range(steps_per_epoch):
        batch_logs = {'batch': step_index, 'size': 1}
        callbacks.on_batch_begin(step_index, batch_logs)
        try:
          outs = distributed_train_function(ins)
        except errors.OutOfRangeError:
          # The product must be parenthesised: `%` and `*` have the same
          # precedence, so without it the message is formatted with
          # `steps_per_epoch` and then repeated `epochs` times.
          logging.warning('Your dataset iterator ran out of data; '
                          'interrupting training. Make sure that your dataset '
                          'can generate at least `steps_per_epoch * epochs` '
                          'batches (in this case, %d batches).' %
                          (steps_per_epoch * epochs))
          break

        if not isinstance(outs, list):
          outs = [outs]

        outs = _aggregate_metrics_across_towers(
            len(current_strategy._devices), out_labels, outs)
        for l, o in zip(out_labels, outs):
          batch_logs[l] = o
        callbacks.on_batch_end(step_index, batch_logs)
        if callbacks.model.stop_training:
          break
      if do_validation:
        val_outs = test_loop(
            model,
            val_inputs,
            val_targets,
            steps=validation_steps,
            verbose=0)
        if not isinstance(val_outs, list):
          val_outs = [val_outs]
        # Same labels assumed.
        for l, o in zip(out_labels, val_outs):
          epoch_logs['val_' + l] = o

    callbacks.on_epoch_end(epoch, epoch_logs)
    if callbacks.model.stop_training:
      break
  callbacks.on_train_end()

  # Copy the weights back from the replicated model to the original model.
  with current_strategy.scope():
    updated_weights = current_strategy.unwrap(
        model._grouped_model)[0].get_weights()
    model.set_weights(updated_weights)
  return model.history
def fit_loop(model,
             iterator,
             epochs=100,
             verbose=1,
             callbacks=None,
             val_iterator=None,
             initial_epoch=0,
             steps_per_epoch=None,
             validation_steps=None):
  """fit function when using DistributionStrategy for training.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator for input data.
      epochs: Number of times to iterate over the data
      verbose: Verbosity mode, 0, 1 or 2
      callbacks: List of callbacks to be called during training
      val_iterator: Iterator for validation data.
      initial_epoch: Epoch at which to start training
          (useful for resuming a previous training run)
      steps_per_epoch: Total number of steps (batches of samples)
          before declaring one epoch finished and starting the
          next epoch. Ignored with the default value of `None`.
      validation_steps: Number of steps to run validation for
          (only if doing validation from data tensors).
          Ignored with the default value of `None`.

  Returns:
      `History` object.

  Raises:
      ValueError: if `validation_steps` is set while `steps_per_epoch` is
          `None`.
  """
  current_strategy = model._distribution_strategy

  def _per_device_train_function(model):
    model._make_train_function()
    return (model.train_function.inputs,
            model.train_function.outputs,
            model.train_function.updates_op,
            model.train_function.session_kwargs)

  inputs, targets = _get_input_from_iterator(iterator, model)
  with current_strategy.scope():
    # Create train ops on each of the devices when we call
    # `_per_device_train_function`.
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_tower(
         _per_device_train_function, model._grouped_model)
    # Unwrap all the per device values returned from `call_for_each_tower`.
    # Unwrapping per device values gives you a list of values that can be
    # used to construct a new train function that is composed of update ops on
    # all the devices over which the model is distributed.
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy,
         grouped_inputs,
         grouped_outputs,
         grouped_updates,
         grouped_session_args,
         with_loss_tensor=True)

    # Dataset inputs and targets are also per devices values that need to be
    # unwrapped.
    dataset_inputs = distributed_training_utils.flatten_perdevice_values(
        current_strategy, inputs)
    dataset_targets = distributed_training_utils.flatten_perdevice_values(
        current_strategy, targets)

  # Create a train function that is composed of all the parameters above.
  distributed_train_function = K.Function(
      all_inputs,
      all_outputs,
      updates=all_updates,
      name='distributed_train_function',
      **all_session_args)

  # We need to set sample_weights to None since there are sample weight
  # placeholders that are created with default values.
  sample_weights = [
      None for _ in range(len(model.outputs) * current_strategy.num_towers)
  ]
  if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
    ins = dataset_inputs + dataset_targets + sample_weights + [1]
  else:
    ins = dataset_inputs + dataset_targets

  do_validation = False
  if validation_steps:
    do_validation = True
    if steps_per_epoch is None:
      raise ValueError('Can only use `validation_steps` '
                       'when doing step-wise '
                       'training, i.e. `steps_per_epoch` '
                       'must be set.')

  # Copy the weights from the original model to each of the replicated models.
  orig_model_weights = model.get_weights()
  with current_strategy.scope():
    distributed_model = current_strategy.unwrap(model._grouped_model)[0]
    distributed_training_utils.set_weights(
        current_strategy, distributed_model, orig_model_weights)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      val_inputs=None,
      val_targets=None,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      verbose=verbose)
  out_labels = model.metrics_names or []
  callbacks.on_train_begin()
  for epoch in range(initial_epoch, epochs):
    callbacks.on_epoch_begin(epoch)
    # Always bound, so `on_epoch_end` below works even when no steps run.
    epoch_logs = {}
    if steps_per_epoch is not None:
      for step_index in range(steps_per_epoch):
        batch_logs = {'batch': step_index, 'size': 1}
        callbacks.on_batch_begin(step_index, batch_logs)
        try:
          outs = distributed_train_function(ins)
        except errors.OutOfRangeError:
          # The product must be parenthesised: `%` and `*` have the same
          # precedence, so without it the message is formatted with
          # `steps_per_epoch` and then repeated `epochs` times.
          logging.warning(
              'Your dataset iterator ran out of data; '
              'interrupting training. Make sure that your dataset '
              'can generate at least `steps_per_epoch * epochs` '
              'batches (in this case, %d batches).' %
              (steps_per_epoch * epochs))
          break

        if not isinstance(outs, list):
          outs = [outs]

        outs = _aggregate_metrics_across_towers(
            current_strategy.num_towers, out_labels, outs)
        for l, o in zip(out_labels, outs):
          batch_logs[l] = o
        callbacks.on_batch_end(step_index, batch_logs)
        if callbacks.model.stop_training:
          break
      if do_validation:
        val_outs = test_loop(
            model, val_iterator, steps=validation_steps, verbose=0)
        if not isinstance(val_outs, list):
          val_outs = [val_outs]
        # Same labels assumed.
        for l, o in zip(out_labels, val_outs):
          epoch_logs['val_' + l] = o

    callbacks.on_epoch_end(epoch, epoch_logs)
    if callbacks.model.stop_training:
      break
  callbacks.on_train_end()

  # Copy the weights back from the replicated model to the original model.
  with current_strategy.scope():
    updated_weights = current_strategy.unwrap(
        model._grouped_model)[0].get_weights()
    model.set_weights(updated_weights)
  return model.history
def predict_loop(model, iterator, verbose=0, steps=None):
  """Predict loop for predicting with DistributionStrategy.

  Combines the per-replica predict functions into a single backend function,
  runs it `steps` times and concatenates the collected batch outputs along
  axis 0.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator for input data.
      verbose: Integer, Verbosity mode 0 or 1. A progress bar is created and
        updated only when `verbose == 1`.
      steps: Total number of steps (batches of samples) before declaring
        `_predict_loop` finished. Must not be `None` on the non-TPU path;
        this is enforced with an assertion.

  Returns:
      Array of predictions (if the model has a single output)
      or list of arrays of predictions
      (if the model has multiple outputs).
  """
  current_strategy = model._distribution_strategy

  # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
  if current_strategy.__class__.__name__ == 'TPUStrategy':
    return _experimental_predict_loop(model, iterator, verbose, steps)

  if not model._grouped_model:
    clone_model_on_replicas(model, current_strategy)

  def _per_device_predict_function(model):
    model._make_predict_function()
    return (model.predict_function.inputs,
            model.predict_function.outputs,
            model.predict_function.updates_op,
            model.predict_function.session_kwargs)

  inputs, _, _ = _get_input_from_iterator(iterator, model)
  with current_strategy.scope():
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_replica(
         _per_device_predict_function, args=(model._grouped_model,))

    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args)

    dataset_inputs = distributed_training_utils.flatten_perdevice_values(
        current_strategy, inputs)

  distributed_predict_function = K.function(
      all_inputs,
      all_outputs,
      updates=all_updates,
      name='distributed_predict_function',
      **all_session_args)

  # Feed learning phase 0 only when the learning phase is not a fixed int.
  if not isinstance(K.learning_phase(), int):
    ins = dataset_inputs + [0]
  else:
    ins = dataset_inputs

  # The same condition guards both creation and update of the progress bar.
  # Previously the bar was created for `verbose == 1` but updated for
  # `verbose >= 1`, which raised NameError for e.g. `verbose=2`.
  show_progress = verbose == 1
  if show_progress:
    progbar = Progbar(target=steps)

  # Copy the weights from the original model to each of the replicated models.
  orig_model_weights = model.get_weights()
  distributed_model = current_strategy.unwrap(model._grouped_model)[0]
  distributed_training_utils.set_weights(
      current_strategy, distributed_model, orig_model_weights)

  num_replicas = current_strategy.num_replicas_in_sync
  num_outputs = len(model.outputs)
  # Since we do not know how many samples we will see, we cannot
  # pre-allocate the returned Numpy arrays. Instead, we store one array per
  # batch seen and concatenate them upon returning.
  unconcatenated_outs = []
  assert steps is not None
  for step in range(steps):
    batch_outs = distributed_predict_function(ins)
    if not isinstance(batch_outs, list):
      batch_outs = [batch_outs]
    if step == 0:
      # batch_outs gives you the number of model outputs. In the distributed
      # case this will be number of model_outputs * num_replicas.
      unconcatenated_outs = [[] for _ in range(num_outputs)]
    for i in range(num_outputs):
      # Values of output `i` occupy one contiguous slice of `num_replicas`
      # entries in `batch_outs`.
      start = i * num_replicas
      nested_outs = batch_outs[start:start + num_replicas]
      unconcatenated_outs[i].extend(nest.flatten(nested_outs))
    if show_progress:
      progbar.update(step + 1)

  if len(unconcatenated_outs) == 1:
    return np.concatenate(unconcatenated_outs[0], axis=0)
  return [
      np.concatenate(per_output, axis=0) for per_output in unconcatenated_outs
  ]
def fit_loop(
    model,
    iterator,
    epochs=100,
    verbose=1,
    callbacks=None,
    val_iterator=None,
    initial_epoch=0,
    steps_per_epoch=None,
    validation_steps=None):
  """Fit loop for training with DistributionStrategy.

  Combines the per-replica fit functions into a single backend function and
  runs it `steps_per_epoch` times per epoch, driving the callbacks and an
  optional validation pass through `test_loop` at the end of every epoch.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator for input data.
      epochs: Number of times to iterate over the data
      verbose: Integer, Verbosity mode, 0, 1 or 2
      callbacks: List of callbacks to be called during training
      val_iterator: Iterator for validation data.
      initial_epoch: Epoch at which to start training
          (useful for resuming a previous training run)
      steps_per_epoch: Total number of steps (batches of samples)
          before declaring one epoch finished and starting the
          next epoch. Must not be `None` on the non-TPU path; this is
          enforced with an assertion.
      validation_steps: Number of steps to run validation for
          (only if doing validation from data tensors).
          Validation is skipped when this is `None` or otherwise falsy.

  Returns:
      `History` object.

  Raises:
      ValueError: in case of invalid arguments.
  """
  current_strategy = model._distribution_strategy

  # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
  if current_strategy.__class__.__name__ == 'TPUStrategy':
    return _experimental_fit_loop(
        model, iterator, epochs, verbose, callbacks, initial_epoch,
        steps_per_epoch, val_iterator, validation_steps)

  if not model._grouped_model:
    clone_model_on_replicas(model, current_strategy, make_callback_model=True)

  def _per_device_fit_function(model):
    model._make_fit_function()
    return (model._fit_function.inputs,
            model._fit_function.outputs,
            model._fit_function.updates_op,
            model._fit_function.session_kwargs)

  # The sample weights coming from the iterator are not used; placeholders
  # for them are fed with `None` below.
  inputs, targets, _ = _get_input_from_iterator(iterator, model)
  with current_strategy.scope():
    # Create train ops on each of the devices when we call
    # `_per_device_fit_function`.
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_replica(
         _per_device_fit_function, args=(model._grouped_model,))
    # Unwrap all the per device values returned from `call_for_each_replica`.
    # Unwrapping per device values gives you a list of values that can be
    # used to construct a new train function that is composed of update ops on
    # all the devices over which the model is distributed.
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args, with_loss_tensor=True)

    # Dataset inputs and targets are also per devices values that need to be
    # unwrapped.
    dataset_inputs = distributed_training_utils.flatten_perdevice_values(
        current_strategy, inputs)
    dataset_targets = distributed_training_utils.flatten_perdevice_values(
        current_strategy, targets)

  # Create a train function that is composed of all the parameters above.
  distributed_fit_function = K.function(
      all_inputs,
      all_outputs,
      updates=all_updates,
      name='distributed_fit_function',
      **all_session_args)

  # We need to set sample_weights to None since there are sample weight
  # placeholders that are created with default values.
  sample_weights = [None] * (
      len(model.outputs) * current_strategy.num_replicas_in_sync)
  # Feed learning phase 1 only when the learning phase is not a fixed int.
  if not isinstance(K.learning_phase(), int):
    ins = dataset_inputs + dataset_targets + sample_weights + [1]
  else:
    ins = dataset_inputs + dataset_targets

  do_validation = bool(validation_steps)

  # Copy the weights from the original model to each of the replicated models.
  orig_model_weights = model.get_weights()
  distributed_model = current_strategy.unwrap(model._grouped_model)[0]
  distributed_training_utils.set_weights(
      current_strategy, distributed_model, orig_model_weights)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      val_inputs=None,
      val_targets=None,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      verbose=verbose)
  out_labels = model.metrics_names or []
  callbacks.on_train_begin()

  assert steps_per_epoch is not None

  for epoch in range(initial_epoch, epochs):
    # Reset stateful metrics
    for m in model.stateful_metric_functions:
      m.reset_states()
    callbacks.on_epoch_begin(epoch)
    epoch_logs = {}
    for step_index in range(steps_per_epoch):
      batch_logs = {'batch': step_index, 'size': 1}
      callbacks.on_batch_begin(step_index, batch_logs)
      try:
        outs = distributed_fit_function(ins)
      except errors.OutOfRangeError:
        # The product is parenthesised on purpose: `%` and `*` have the same
        # precedence, so without the parentheses the message would be
        # formatted with `steps_per_epoch` and then repeated `epochs` times.
        logging.warning('Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).' %
                        (steps_per_epoch * epochs))
        break

      if not isinstance(outs, list):
        outs = [outs]
      for label, value in zip(out_labels, outs):
        batch_logs[label] = value
      callbacks.on_batch_end(step_index, batch_logs)
      if callbacks.model.stop_training:
        break

    if do_validation:
      val_outs = test_loop(
          model, val_iterator, steps=validation_steps, verbose=0)
      if not isinstance(val_outs, list):
        val_outs = [val_outs]
      # Same labels assumed.
      for label, value in zip(out_labels, val_outs):
        epoch_logs['val_' + label] = value

    callbacks.on_epoch_end(epoch, epoch_logs)
    if callbacks.model.stop_training:
      break
  callbacks.on_train_end()

  # Copy the weights back from the replicated model to the original model.
  updated_weights = current_strategy.unwrap(
      model._grouped_model)[0].get_weights()
  model.set_weights(updated_weights)
  return model.history
def test_loop(model, iterator, verbose=0, steps=None):
  """Test loop for evaluating with DistributionStrategy.

  Combines the per-replica eval functions into a single backend function and
  runs it `steps` times. The first output (the loss) is summed over all steps
  and divided by `steps`; every other output keeps the value produced by the
  last step.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator for input data.
      verbose: Integer, Verbosity mode 0 or 1. A progress bar is created and
        updated only when `verbose == 1`.
      steps: Total number of steps (batches of samples) before declaring
        predictions finished. Must not be `None` on the non-TPU path; this is
        enforced with an assertion.

  Returns:
      Scalar loss (if the model has a single output and no metrics)
      or list of scalars (if the model has multiple outputs
      and/or metrics). The attribute `model.metrics_names` will give you
      the display labels for the outputs.
  """
  current_strategy = model._distribution_strategy

  # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
  if current_strategy.__class__.__name__ == 'TPUStrategy':
    return _experimental_test_loop(model, iterator, verbose, steps)

  if not model._grouped_model:
    clone_model_on_replicas(model, current_strategy)

  def _per_device_eval_function(model):
    model._make_eval_function()
    return (model._eval_function.inputs,
            model._eval_function.outputs,
            model._eval_function.updates_op,
            model._eval_function.session_kwargs)

  # The sample weights coming from the iterator are not used; placeholders
  # for them are fed with `None` below.
  inputs, targets, _ = _get_input_from_iterator(iterator, model)
  with current_strategy.scope():
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_replica(
         _per_device_eval_function, args=(model._grouped_model,))

    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args, with_loss_tensor=True)

    dataset_inputs = distributed_training_utils.flatten_perdevice_values(
        current_strategy, inputs)
    dataset_targets = distributed_training_utils.flatten_perdevice_values(
        current_strategy, targets)

  distributed_test_function = K.function(
      all_inputs,
      all_outputs,
      updates=all_updates,
      name='distributed_test_function',
      **all_session_args)

  # We need to set sample_weights to None since there are sample weight
  # placeholders that are created with default values.
  sample_weights = [None] * (
      len(model.outputs) * current_strategy.num_replicas_in_sync)
  # Feed learning phase 0 only when the learning phase is not a fixed int.
  if not isinstance(K.learning_phase(), int):
    ins = dataset_inputs + dataset_targets + sample_weights + [0]
  else:
    ins = dataset_inputs + dataset_targets

  for m in model.stateful_metric_functions:
    m.reset_states()
  outs = []

  # The same condition guards both creation and update of the progress bar.
  # Previously the bar was created for `verbose == 1` but updated for
  # `verbose >= 1`, which raised NameError for e.g. `verbose=2`.
  show_progress = verbose == 1
  if show_progress:
    progbar = Progbar(target=steps)

  # Copy the weights from the original model to each of the replicated models.
  orig_model_weights = model.get_weights()
  distributed_model = current_strategy.unwrap(model._grouped_model)[0]
  distributed_training_utils.set_weights(
      current_strategy, distributed_model, orig_model_weights)

  assert steps is not None
  for step in range(steps):
    batch_outs = distributed_test_function(ins)
    if isinstance(batch_outs, list):
      if step == 0:
        outs = [0.] * len(batch_outs)
      outs[0] += batch_outs[0]  # index 0 = 'loss'
      # Remaining outputs are overwritten rather than accumulated.
      # NOTE(review): presumably because they come from the stateful metrics
      # reset above and are already aggregated across batches - confirm.
      outs[1:] = batch_outs[1:]
    else:
      if step == 0:
        outs.append(0.)
      outs[0] += batch_outs  # index 0 = 'loss'
    if show_progress:
      progbar.update(step + 1)
  outs[0] /= steps  # index 0 = 'loss'

  if len(outs) == 1:
    return outs[0]
  return outs
def test_loop(model, iterator, verbose=0, steps=None):
  """Evaluate a model that uses DistributionStrategy.

  Clones the model onto the towers, combines the per-tower test functions
  into a single backend function, runs it `steps` times and averages every
  output over the number of steps.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator for input data.
      verbose: verbosity mode. A progress bar is shown when it equals 1.
      steps: Total number of steps (batches of samples) before declaring
        predictions finished. With the default value of `None` no batch is
        evaluated and an empty list is returned.

  Returns:
      Scalar loss (if the model has a single output and no metrics)
      or list of scalars (if the model has multiple outputs
      and/or metrics). The attribute `model.metrics_names` will give you
      the display labels for the scalar outputs.
  """
  strategy = model._distribution_strategy
  clone_model_on_towers(model, strategy)

  def _per_device_test_function(model):
    model._make_test_function()
    return (model.test_function.inputs,
            model.test_function.outputs,
            model.test_function.updates_op,
            model.test_function.session_kwargs)

  inputs, targets = _get_input_from_iterator(iterator, model)
  with strategy.scope():
    per_tower_values = strategy.call_for_each_tower(
        _per_device_test_function, model._grouped_model)
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = per_tower_values

    unwrapped_values = distributed_training_utils.unwrap_values(
        strategy, grouped_inputs, grouped_outputs, grouped_updates,
        grouped_session_args, with_loss_tensor=True)
    (all_inputs, all_outputs, all_updates,
     all_session_args) = unwrapped_values

    dataset_inputs = distributed_training_utils.flatten_perdevice_values(
        strategy, inputs)
    dataset_targets = distributed_training_utils.flatten_perdevice_values(
        strategy, targets)

  distributed_test_function = K.Function(
      all_inputs,
      all_outputs,
      updates=all_updates,
      name='distributed_test_function',
      **all_session_args)

  # Sample weight placeholders are created with default values, so they are
  # all fed with None.
  sample_weights = [None] * (len(model.outputs) * strategy.num_towers)
  feed_learning_phase = (
      model.uses_learning_phase and not isinstance(K.learning_phase(), int))
  if feed_learning_phase:
    ins = dataset_inputs + dataset_targets + sample_weights + [0]
  else:
    ins = dataset_inputs + dataset_targets

  outs = []
  if verbose == 1:
    progbar = Progbar(target=steps)

  # Push the weights of the original model to the replicated models.
  original_weights = model.get_weights()
  with strategy.scope():
    replicated_model = strategy.unwrap(model._grouped_model)[0]
    distributed_training_utils.set_weights(
        strategy, replicated_model, original_weights)

  if steps is not None:
    for step in range(steps):
      batch_outs = _aggregate_metrics_across_towers(
          strategy.num_towers, model.metrics_names,
          distributed_test_function(ins))
      # Treat a lone value like a one-element list so that both cases share
      # the same accumulation code.
      if isinstance(batch_outs, list):
        batch_values = batch_outs
      else:
        batch_values = [batch_outs]
      if step == 0:
        outs.extend(0. for _ in batch_values)
      for idx, value in enumerate(batch_values):
        outs[idx] += value
      if verbose == 1:
        progbar.update(step + 1)
    # Turn the accumulated sums into per-step averages.
    for idx, _ in enumerate(outs):
      outs[idx] /= steps

  return outs[0] if len(outs) == 1 else outs
def predict_loop(model, iterator, verbose=0, steps=None):
  """Run batched prediction for a model that uses DistributionStrategy.

  Clones the model onto the towers, combines the per-tower predict functions
  into a single backend function, runs it `steps` times and concatenates the
  collected batch outputs along axis 0.

  Arguments:
      model: Keras Model instance.
      iterator: Iterator for input data.
      verbose: verbosity mode. A progress bar is shown when it equals 1.
      steps: Total number of steps (batches of samples) before declaring
        `_predict_loop` finished. With the default value of `None` no batch
        is predicted and `None` is returned.

  Returns:
      Array of predictions (if the model has a single output)
      or list of arrays of predictions
      (if the model has multiple outputs).
  """
  strategy = model._distribution_strategy
  clone_model_on_towers(model, strategy)

  def _per_device_predict_function(model):
    model._make_predict_function()
    return (model.predict_function.inputs,
            model.predict_function.outputs,
            model.predict_function.updates_op,
            model.predict_function.session_kwargs)

  inputs, _ = _get_input_from_iterator(iterator, model)
  with strategy.scope():
    per_tower_values = strategy.call_for_each_tower(
        _per_device_predict_function, model._grouped_model)
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = per_tower_values

    unwrapped_values = distributed_training_utils.unwrap_values(
        strategy, grouped_inputs, grouped_outputs, grouped_updates,
        grouped_session_args)
    (all_inputs, all_outputs, all_updates,
     all_session_args) = unwrapped_values

    dataset_inputs = distributed_training_utils.flatten_perdevice_values(
        strategy, inputs)

  distributed_predict_function = K.Function(
      all_inputs,
      all_outputs,
      updates=all_updates,
      name='distributed_predict_function',
      **all_session_args)

  feed_learning_phase = (
      model.uses_learning_phase and not isinstance(K.learning_phase(), int))
  ins = dataset_inputs + [0] if feed_learning_phase else dataset_inputs

  if verbose == 1:
    progbar = Progbar(target=steps)

  # Push the weights of the original model to the replicated models.
  original_weights = model.get_weights()
  with strategy.scope():
    replicated_model = strategy.unwrap(model._grouped_model)[0]
    distributed_training_utils.set_weights(
        strategy, replicated_model, original_weights)

  if steps is None:
    return None

  # The number of samples is not known up front, so the result arrays cannot
  # be pre-allocated. One array per batch is kept for every output and they
  # are concatenated at the end.
  per_output_batches = []
  for step in range(steps):
    batch_outs = distributed_predict_function(ins)
    if not isinstance(batch_outs, list):
      batch_outs = [batch_outs]
    if step == 0:
      per_output_batches.extend([] for _ in batch_outs)
    for idx, batch_out in enumerate(batch_outs):
      per_output_batches[idx].append(batch_out)
    if verbose == 1:
      progbar.update(step + 1)

  if len(per_output_batches) == 1:
    return np.concatenate(per_output_batches[0], axis=0)
  return [np.concatenate(batches, axis=0) for batches in per_output_batches]