def step_fn(ctx, inputs, targets): """Clones the model and calls make_train_function.""" # TODO(priyag, sourabhbajaj): Should cache this keyed on input shapes. clone_model_on_towers( model, current_strategy, make_callback_model=True, inputs=inputs, targets=targets) (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) = current_strategy.call_for_each_tower( _per_device_train_function, model._grouped_model) (all_inputs, all_outputs, all_updates, all_session_args) = distributed_training_utils.unwrap_values( current_strategy, grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args, with_loss_tensor=True) combined_fn = K.Function( all_inputs, all_outputs, updates=all_updates, name='distributed_train_function', **all_session_args) # TODO(priyag, sourabhbajaj): Perhaps the aggregation type needs to be # something else for different outputs. out_labels = model.metrics_names or [] for label, output in zip(out_labels, combined_fn.outputs): ctx.set_last_step_output(label, output, aggregation=distribute_lib.get_loss_reduction()) # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn: # feed_dict, session kwargs, run options, run_metadata for now. These should # be handled appropriately return combined_fn.updates_op
def _scale_loss(loss_value):
  if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
    num_replicas = \
        distribute_ctx.get_distribution_strategy().num_replicas_in_sync
    if num_replicas > 1:
      loss_value *= (1. / num_replicas)
  return loss_value
def _scale_loss(loss_value):
  if (distribute_lib.get_loss_reduction() ==
      variable_scope.VariableAggregation.MEAN):
    num_replicas = \
        distribute_ctx.get_distribution_strategy().num_replicas_in_sync
    if num_replicas > 1:
      loss_value *= (1. / num_replicas)
  return loss_value
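# The two _scale_loss variants above implement the same arithmetic under two
# different enum spellings (ReduceOp.MEAN vs. VariableAggregation.MEAN). A
# minimal, self-contained sketch of that scaling rule (no TF APIs; the name
# below is illustrative only): with a MEAN loss reduction, each replica
# contributes loss / num_replicas, so that summing the per-replica gradients
# recovers the gradient of the mean loss.
def _scale_loss_sketch(loss_value, num_replicas):
  """Scales a per-replica loss when gradients are summed across replicas."""
  if num_replicas > 1:
    loss_value *= (1. / num_replicas)
  return loss_value

assert _scale_loss_sketch(8.0, 4) == 2.0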
def step_fn(ctx, inputs): """Clones the model and calls make_fit_function.""" # TODO(priyag, sourabhbajaj): The model gets cloned every time # fit/test/predict is called. We should look into caching this keyed on # input shapes. inputs, targets = inputs clone_model_on_replicas( model, current_strategy, make_callback_model=True, inputs=inputs, targets=targets, mode=_Mode.TRAIN) (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) = current_strategy.call_for_each_replica( _per_device_fit_function, args=(model._grouped_model_train,)) (all_inputs, all_outputs, all_updates, all_session_args) = distributed_training_utils.unwrap_values( current_strategy, grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) combined_fn = K.function( all_inputs, all_outputs, updates=all_updates, name='distributed_fit_function', **all_session_args) for label, output in zip(out_labels, combined_fn.outputs): if label == 'loss': reduce_op = distribute_lib.get_loss_reduction() else: # We reduce all other metrics using mean for now. This is temporary # workaround until new metrics are in place. reduce_op = ds_reduce_util.ReduceOp.MEAN ctx.set_last_step_output(label, output, reduce_op) # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn: # feed_dict, session kwargs, run options, run_metadata for now. These should # be handled appropriately return combined_fn.updates_op
def step_fn(ctx, inputs): """Clones the model and calls make_fit_function.""" # TODO(priyag, sourabhbajaj): The model gets cloned every time # fit/test/predict is called. We should look into caching this keyed on # input shapes. inputs, targets = inputs clone_model_on_replicas(model, current_strategy, make_callback_model=True, inputs=inputs, targets=targets, mode=_Mode.TRAIN) (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) = current_strategy.call_for_each_replica( _per_device_fit_function, args=(model._grouped_model_train, )) (all_inputs, all_outputs, all_updates, all_session_args) = distributed_training_utils.unwrap_values( current_strategy, grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) combined_fn = K.function(all_inputs, all_outputs, updates=all_updates, name='distributed_fit_function', **all_session_args) for label, output in zip(out_labels, combined_fn.outputs): if label == 'loss': reduce_op = distribute_lib.get_loss_reduction() else: # We reduce all other metrics using mean for now. This is temporary # workaround until new metrics are in place. reduce_op = ds_reduce_util.ReduceOp.MEAN ctx.set_last_step_output(label, output, reduce_op) # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn: # feed_dict, session kwargs, run options, run_metadata for now. These should # be handled appropriately return combined_fn.updates_op
def step_fn(ctx, inputs, targets): """Clones the model and calls make_train_function.""" # TODO(priyag, sourabhbajaj): The model gets cloned every time # fit/test/predict is called. We should look into caching this keyed on # input shapes. clone_model_on_towers( model, current_strategy, make_callback_model=True, inputs=inputs, targets=targets) (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) = current_strategy.call_for_each_tower( _per_device_train_function, model._grouped_model) (all_inputs, all_outputs, all_updates, all_session_args) = distributed_training_utils.unwrap_values( current_strategy, grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) combined_fn = K.Function( all_inputs, all_outputs, updates=all_updates, name='distributed_train_function', **all_session_args) out_labels = model.metrics_names or [] for label, output in zip(out_labels, combined_fn.outputs): if label == 'loss': aggregation = distribute_lib.get_loss_reduction() else: # We aggregate all other metrics using mean for now. This is temporary # workaround until new metrics are in place. aggregation = variable_scope.VariableAggregation.MEAN ctx.set_last_step_output(label, output, aggregation) # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn: # feed_dict, session kwargs, run options, run_metadata for now. These should # be handled appropriately return combined_fn.updates_op
def step_fn(ctx, inputs, targets): """Clones the model and calls make_train_function.""" # TODO(priyag, sourabhbajaj): The model gets cloned every time # fit/test/predict is called. We should look into caching this keyed on # input shapes. clone_model_on_towers( model, current_strategy, make_callback_model=True, inputs=inputs, targets=targets) (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) = current_strategy.call_for_each_tower( _per_device_train_function, model._grouped_model) (all_inputs, all_outputs, all_updates, all_session_args) = distributed_training_utils.unwrap_values( current_strategy, grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) combined_fn = K.Function( all_inputs, all_outputs, updates=all_updates, name='distributed_train_function', **all_session_args) out_labels = model.metrics_names or [] for label, output in zip(out_labels, combined_fn.outputs): if label == 'loss': aggregation = distribute_lib.get_loss_reduction() else: # We aggregate all other metrics using mean for now. This is temporary # workaround until new metrics are in place. aggregation = variable_scope.VariableAggregation.MEAN ctx.set_last_step_output(label, output, aggregation) # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn: # feed_dict, session kwargs, run options, run_metadata for now. These should # be handled appropriately return combined_fn.updates_op
def step_fn(ctx, inputs, targets): """Clones the model and calls make_eval_function.""" # TODO(priyag, sourabhbajaj): The model gets cloned every time # fit/test/predict is called. We should look into caching this keyed on # input shapes. clone_model_on_replicas( model, current_strategy, make_callback_model=False, inputs=inputs, targets=targets, mode=_Mode.TEST) (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) = current_strategy.call_for_each_replica( _per_device_eval_function, args=(model._grouped_model_test,)) (all_inputs, all_outputs, all_updates, all_session_args) = distributed_training_utils.unwrap_values( current_strategy, grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) combined_fn = K.function( all_inputs, all_outputs, updates=all_updates, name='distributed_test_function', **all_session_args) for label, output in zip(model.metrics_names, combined_fn.outputs): if label == 'loss': aggregation = distribute_lib.get_loss_reduction() else: # We aggregate all other metrics using mean for now. This is temporary # workaround until new metrics are in place. aggregation = variable_scope.VariableAggregation.MEAN ctx.set_last_step_output(label, output, aggregation) return combined_fn.updates_op
def step_fn(ctx, inputs, targets): """Clones the model and calls make_eval_function.""" # TODO(priyag, sourabhbajaj): The model gets cloned every time # fit/test/predict is called. We should look into caching this keyed on # input shapes. clone_model_on_replicas( model, current_strategy, make_callback_model=False, inputs=inputs, targets=targets, mode=_Mode.TEST) (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) = current_strategy.call_for_each_replica( _per_device_eval_function, model._grouped_model_test) (all_inputs, all_outputs, all_updates, all_session_args) = distributed_training_utils.unwrap_values( current_strategy, grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) combined_fn = K.function( all_inputs, all_outputs, updates=all_updates, name='distributed_test_function', **all_session_args) for label, output in zip(model.metrics_names, combined_fn.outputs): if label == 'loss': aggregation = distribute_lib.get_loss_reduction() else: # We aggregate all other metrics using mean for now. This is temporary # workaround until new metrics are in place. aggregation = variable_scope.VariableAggregation.MEAN ctx.set_last_step_output(label, output, aggregation) return combined_fn.updates_op
def step_fn(ctx, inputs, targets): """Clones the model and calls make_train_function.""" # TODO(priyag, sourabhbajaj): Should cache this keyed on input shapes. clone_model_on_towers(model, current_strategy, make_callback_model=True, inputs=inputs, targets=targets) (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) = current_strategy.call_for_each_tower( _per_device_train_function, model._grouped_model) (all_inputs, all_outputs, all_updates, all_session_args) = distributed_training_utils.unwrap_values( current_strategy, grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args, with_loss_tensor=True) combined_fn = K.Function(all_inputs, all_outputs, updates=all_updates, name='distributed_train_function', **all_session_args) # TODO(priyag, sourabhbajaj): Perhaps the aggregation type needs to be # something else for different outputs. out_labels = model.metrics_names or [] for label, output in zip(out_labels, combined_fn.outputs): ctx.set_last_step_output( label, output, aggregation=distribute_lib.get_loss_reduction()) # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn: # feed_dict, session kwargs, run options, run_metadata for now. These should # be handled appropriately return combined_fn.updates_op
def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs,
                  grouped_updates, grouped_session_args,
                  with_loss_tensor=False):
  """Unwrap and return the list of values contained in the PerDevice parameters.

  This function calls `flatten_perdevice_values` to parse each of the input
  parameters into a list of values on the different devices. If we set
  `with_loss_tensor` to be True, we also call `reduce` on the list of losses on
  the different devices to give us one loss tensor.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training and
      validation.
    grouped_inputs: PerDevice inputs returned from the train or test function
      that we ran on each device.
    grouped_outputs: PerDevice outputs returned from the train or test function
      that we ran on each device.
    grouped_updates: PerDevice updates returned from the train or test function
      that we ran on each device.
    grouped_session_args: PerDevice session args returned from the train or
      test function that we ran on each device.
    with_loss_tensor: Boolean that indicates if we need to add the reduced loss
      tensor as one of the outputs.

  Returns:
    Values of each of the PerDevice parameters.
  """
  # Unwrap per device values returned from each model's train function.
  # This will be used to construct the main train function.
  all_inputs = flatten_perdevice_values(distribution_strategy, grouped_inputs)
  if with_loss_tensor:
    # Reduce the loss tensor before adding it to the list of fetches.
    loss = distribution_strategy.unwrap(
        distribution_strategy.reduce(distribute_lib.get_loss_reduction(),
                                     grouped_outputs[0],
                                     destinations='/device:CPU:0'))[0]
    all_outputs = flatten_perdevice_values(distribution_strategy,
                                           grouped_outputs[1:])
    all_outputs = [loss] + all_outputs
  else:
    all_outputs = flatten_perdevice_values(distribution_strategy,
                                           grouped_outputs)
  all_updates = flatten_perdevice_values(distribution_strategy,
                                         grouped_updates)

  all_session_args = {}
  grouped_feed_dict = grouped_session_args.get('feed_dict')
  if grouped_feed_dict:
    all_session_args['feed_dict'] = flatten_perdevice_values(
        distribution_strategy, grouped_feed_dict)

  grouped_fetches = grouped_session_args.get('fetches')
  if grouped_fetches:
    all_session_args['fetches'] = flatten_perdevice_values(
        distribution_strategy, grouped_fetches)

  return all_inputs, all_outputs, all_updates, all_session_args
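# Hedged sketch of the `flatten_perdevice_values` helper that unwrap_values
# relies on. This is an assumed implementation inferred from its use above:
# flatten the (possibly nested) structure, then unwrap each PerDevice value
# into its per-device tensors.
from tensorflow.python.util import nest


def flatten_perdevice_values(distribution_strategy, perdevice_values):
  """Flattens nested PerDevice values into a flat list of per-device tensors."""
  return [tensor
          for value in nest.flatten(perdevice_values)
          for tensor in distribution_strategy.unwrap(value)]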
def compute_gradients(self, loss, var_list=None,
                      gate_gradients=GATE_OP,
                      aggregation_method=None,
                      colocate_gradients_with_ops=False,
                      grad_loss=None):
  """Compute gradients of `loss` for the variables in `var_list`.

  This is the first part of `minimize()`. It returns a list of
  (gradient, variable) pairs where "gradient" is the gradient for "variable".
  Note that "gradient" can be a `Tensor`, an `IndexedSlices`, or `None` if
  there is no gradient for the given variable.

  Args:
    loss: A Tensor containing the value to minimize or a callable taking no
      arguments which returns the value to minimize. When eager execution is
      enabled it must be a callable.
    var_list: Optional list or tuple of `tf.Variable` to update to minimize
      `loss`. Defaults to the list of variables collected in the graph under
      the key `GraphKeys.TRAINABLE_VARIABLES`.
    gate_gradients: How to gate the computation of gradients. Can be
      `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
    aggregation_method: Specifies the method used to combine gradient terms.
      Valid values are defined in the class `AggregationMethod`.
    colocate_gradients_with_ops: If True, try colocating gradients with the
      corresponding op.
    grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

  Returns:
    A list of (gradient, variable) pairs. Variable is always present, but
    gradient can be `None`.

  Raises:
    TypeError: If `var_list` contains anything other than `Variable` objects.
    ValueError: If some arguments are invalid.
    RuntimeError: If called with eager execution enabled and `loss` is not
      callable.

  @compatibility(eager)
  When eager execution is enabled, `gate_gradients`, `aggregation_method`,
  and `colocate_gradients_with_ops` are ignored.
  @end_compatibility
  """
  if callable(loss):
    with backprop.GradientTape() as tape:
      if var_list is not None:
        tape.watch(var_list)
      loss_value = loss()

      # Scale loss if using a "mean" loss reduction and multiple towers.
      # Have to be careful to call distribute_lib.get_loss_reduction()
      # *after* loss() is evaluated, so we know what loss reduction it uses.
      # TODO(josh11b): Test that we handle weight decay in a reasonable way.
      if (distribute_lib.get_loss_reduction() ==
          variable_scope.VariableAggregation.MEAN):
        num_towers = distribution_strategy_context.get_distribution_strategy(
        ).num_towers
        if num_towers > 1:
          loss_value *= (1. / num_towers)

    if var_list is None:
      var_list = tape.watched_variables()
    grads = tape.gradient(loss_value, var_list, grad_loss)
    return list(zip(grads, var_list))

  # Non-callable/Tensor loss case
  if context.executing_eagerly():
    raise RuntimeError(
        "`loss` passed to Optimizer.compute_gradients should "
        "be a function when eager execution is enabled.")

  # Scale loss if using a "mean" loss reduction and multiple towers.
  if (distribute_lib.get_loss_reduction() ==
      variable_scope.VariableAggregation.MEAN):
    num_towers = distribution_strategy_context.get_distribution_strategy(
    ).num_towers
    if num_towers > 1:
      loss *= (1. / num_towers)

  if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
                            Optimizer.GATE_GRAPH]:
    raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                     "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
                     gate_gradients)
  self._assert_valid_dtypes([loss])
  if grad_loss is not None:
    self._assert_valid_dtypes([grad_loss])
  if var_list is None:
    var_list = (
        variables.trainable_variables() +
        ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
  else:
    var_list = nest.flatten(var_list)
  # pylint: disable=protected-access
  var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
  # pylint: enable=protected-access
  processors = [_get_processor(v) for v in var_list]
  if not var_list:
    raise ValueError("No variables to optimize.")
  var_refs = [p.target() for p in processors]
  grads = gradients.gradients(
      loss, var_refs, grad_ys=grad_loss,
      gate_gradients=(gate_gradients == Optimizer.GATE_OP),
      aggregation_method=aggregation_method,
      colocate_gradients_with_ops=colocate_gradients_with_ops)
  if gate_gradients == Optimizer.GATE_GRAPH:
    grads = control_flow_ops.tuple(grads)
  grads_and_vars = list(zip(grads, var_list))
  self._assert_valid_dtypes(
      [v for g, v in grads_and_vars
       if g is not None and v.dtype != dtypes.resource])
  return grads_and_vars
def compute_gradients(optimizer, loss, var_list=None,
                      gate_gradients=Optimizer.GATE_OP,
                      aggregation_method=None,
                      colocate_gradients_with_ops=False,
                      grad_loss=None):
  if callable(loss):
    from tensorflow.python.eager import backprop
    with backprop.GradientTape() as tape:
      if var_list is not None:
        tape.watch(var_list)
      loss_value = loss()

      # Scale loss if using a "mean" loss reduction and multiple towers.
      # Have to be careful to call distribute_lib.get_loss_reduction()
      # *after* loss() is evaluated, so we know what loss reduction it uses.
      # TODO(josh11b): Test that we handle weight decay in a reasonable way.
      if (distribute_lib.get_loss_reduction() ==
          variable_scope.VariableAggregation.MEAN):
        num_towers = distribution_strategy_context.get_distribution_strategy(
        ).num_towers
        if num_towers > 1:
          loss_value *= (1. / num_towers)

    if var_list is None:
      var_list = tape.watched_variables()
    # TODO(jhseu): Figure out why GradientTape's gradients don't require loss
    # to be executed.
    with ops.control_dependencies([loss_value]):
      grads = tape.gradient(loss_value, var_list, grad_loss)
    return list(zip(grads, var_list))

  # Non-callable/Tensor loss case
  if context.executing_eagerly():
    raise RuntimeError(
        "`loss` passed to Optimizer.compute_gradients should "
        "be a function when eager execution is enabled.")

  # Scale loss if using a "mean" loss reduction and multiple towers.
  if (distribute_lib.get_loss_reduction() ==
      variable_scope.VariableAggregation.MEAN):
    num_towers = distribution_strategy_context.get_distribution_strategy(
    ).num_towers
    if num_towers > 1:
      loss *= (1. / num_towers)

  if gate_gradients not in [
      Optimizer.GATE_NONE, Optimizer.GATE_OP, Optimizer.GATE_GRAPH
  ]:
    raise ValueError(
        "gate_gradients must be one of: Optimizer.GATE_NONE, "
        "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" % gate_gradients)
  optimizer._assert_valid_dtypes([loss])
  if grad_loss is not None:
    optimizer._assert_valid_dtypes([grad_loss])
  if var_list is None:
    var_list = (variables.trainable_variables() + ops.get_collection(
        ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
  else:
    var_list = nest.flatten(var_list)
  # pylint: disable=protected-access
  var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
  # pylint: enable=protected-access
  from tensorflow.python.training.optimizer import _get_processor
  processors = [_get_processor(v) for v in var_list]
  if not var_list:
    raise ValueError("No variables to optimize.")
  var_refs = [p.target() for p in processors]

  # Original gradients computation:
  # grads = tf.gradients(
  #     loss, var_refs, grad_ys=grad_loss,
  #     gate_gradients=(gate_gradients == Optimizer.GATE_OP),
  #     aggregation_method=aggregation_method,
  #     colocate_gradients_with_ops=colocate_gradients_with_ops)

  # Use gradient checkpointing instead. Passing checkpoints='memory' fails
  # here, so let the library pick checkpoint tensors with the 'speed'
  # heuristic.
  from memory_saving_gradients import gradients
  grads = gradients(
      loss, var_refs, grad_ys=grad_loss,
      gate_gradients=(gate_gradients == Optimizer.GATE_OP),
      aggregation_method=aggregation_method,
      colocate_gradients_with_ops=colocate_gradients_with_ops,
      checkpoints='speed')

  if gate_gradients == Optimizer.GATE_GRAPH:
    grads = control_flow_ops.tuple(grads)
  grads_and_vars = list(zip(grads, var_list))
  optimizer._assert_valid_dtypes([
      v for g, v in grads_and_vars
      if g is not None and v.dtype != dtypes.resource
  ])
  return grads_and_vars
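# Hedged usage sketch: the memory_saving_gradients module used above (assumed
# to be the openai/gradient-checkpointing package) is more commonly wired in
# by monkey-patching tf.gradients, so stock optimizers pick up checkpointed
# gradients without a custom compute_gradients. `gradients_speed` is that
# package's convenience wrapper for checkpoints='speed'.
import tensorflow as tf
import memory_saving_gradients

# Route all tf.gradients calls through the checkpointing implementation.
tf.__dict__["gradients"] = memory_saving_gradients.gradients_speed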
def _train_model_distributed(self, strategy, input_fn, hooks,
                             saving_listeners, save_best_ckpt):
  """Initiates training with `input_fn`, using `DistributionStrategies`.

  Args:
    strategy: The `DistributionStrategy` used to distribute training.
    input_fn: A function that provides input data for training as minibatches.
    hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
      callbacks inside the training loop.
    saving_listeners: list of `tf.train.CheckpointSaverListener` objects. Used
      for callbacks that run immediately before or after checkpoint saves.
    save_best_ckpt: Forwarded to `_train_with_estimator_spec`; whether to keep
      the checkpoint with the best evaluation result.

  Returns:
    Loss from training
  """
  strategy.configure(self._session_config)
  worker_hooks = []
  with ops.Graph().as_default() as g:
    # We want to create the iterations variable outside the distribution
    # scope, as that is just stored on the host and mainly used to drive the
    # loop and doesn't need to be a Mirrored/Device variable.
    with strategy.scope():
      random_seed.set_random_seed(self._config.tf_random_seed)

      if self._train_with_eval:
        self.handler = array_ops.placeholder(
            dtypes.string, shape=(), name="Handler")
        iterator, self.train_iterator, self.eval_iterator, input_hooks = (
            self._get_iterator_for_train_and_eval(
                input_fn, self.handler, strategy))
      else:
        self.handler, self.train_iterator, self.eval_iterator = (
            None, None, None)
        iterator, input_hooks = self._get_iterator_from_input_fn(
            input_fn, model_fn_lib.ModeKeys.TRAIN, strategy)
      worker_hooks.extend(input_hooks)
      global_step_tensor = self._create_and_assert_global_step(g)
      # We want to add to the global collection in the main thread, not the
      # tower threads.
      ops.add_to_collection(training_util.GLOBAL_STEP_READ_KEY,
                            strategy.read_var(global_step_tensor))

      features, labels = estimator_util.parse_iterator_result(
          per_device_dataset(iterator, strategy.extended._devices))
      grouped_estimator_spec = strategy.call_for_each_replica(
          self._call_model_fn,
          args=(features, labels, model_fn_lib.ModeKeys.TRAIN, self.config))
      loss = strategy.reduce(distribute_lib.get_loss_reduction(),
                             grouped_estimator_spec.loss)
      distributed_train_op = grouped_estimator_spec.train_op

      predictions = {}
      for key, val in grouped_estimator_spec.predictions.items():
        if key == "GlobalStep":
          predictions["GlobalStep"] = strategy.unwrap(val)[0]
        elif "/" in key:
          predictions[key] = strategy.reduce(reduce_util.ReduceOp.MEAN, val)
        else:
          predictions[key] = array_ops.concat(strategy.unwrap(val), axis=0)

      scaffold = estimator_lib._combine_distributed_scaffold(
          grouped_estimator_spec.scaffold, strategy)

      # TODO: add a test for unwrapping per_device_hooks.
      def get_hooks_from_the_first_device(per_device_hooks):
        # In the tensorflow-1.12 Estimator the next line is
        # self._distribution.unwrap(), but self._distribution is not defined
        # there, which may be a bug.
        return [
            strategy.unwrap(per_device_hook)[0]
            for per_device_hook in per_device_hooks
        ]

      training_hooks = get_hooks_from_the_first_device(
          grouped_estimator_spec.training_hooks)
      training_chief_hooks = get_hooks_from_the_first_device(
          grouped_estimator_spec.training_chief_hooks)
      worker_hooks.append(
          estimator_util.StrategyInitFinalizeHook(
              strategy.initialize, strategy.finalize))

      estimator_spec = model_fn_lib.EstimatorSpec(
          mode=grouped_estimator_spec.mode,
          loss=loss,
          train_op=strategy.group(distributed_train_op),
          predictions=predictions,
          training_hooks=training_hooks,
          training_chief_hooks=training_chief_hooks,
          scaffold=scaffold)
      return self._train_with_estimator_spec(estimator_spec, worker_hooks,
                                             hooks, global_step_tensor,
                                             saving_listeners, save_best_ckpt)
def compute_gradients(self, loss, var_list=None,
                      gate_gradients=GATE_OP,
                      aggregation_method=None,
                      colocate_gradients_with_ops=False,
                      grad_loss=None):
  """Compute gradients of `loss` for the variables in `var_list`.

  This is the first part of `minimize()`. It returns a list of
  (gradient, variable) pairs where "gradient" is the gradient for "variable".
  Note that "gradient" can be a `Tensor`, an `IndexedSlices`, or `None` if
  there is no gradient for the given variable.

  Args:
    loss: A Tensor containing the value to minimize or a callable taking no
      arguments which returns the value to minimize. When eager execution is
      enabled it must be a callable.
    var_list: Optional list or tuple of `tf.Variable` to update to minimize
      `loss`. Defaults to the list of variables collected in the graph under
      the key `GraphKeys.TRAINABLE_VARIABLES`.
    gate_gradients: How to gate the computation of gradients. Can be
      `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
    aggregation_method: Specifies the method used to combine gradient terms.
      Valid values are defined in the class `AggregationMethod`.
    colocate_gradients_with_ops: If True, try colocating gradients with the
      corresponding op.
    grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

  Returns:
    A list of (gradient, variable) pairs. Variable is always present, but
    gradient can be `None`.

  Raises:
    TypeError: If `var_list` contains anything other than `Variable` objects.
    ValueError: If some arguments are invalid.
    RuntimeError: If called with eager execution enabled and `loss` is not
      callable.

  @compatibility(eager)
  When eager execution is enabled, `gate_gradients`, `aggregation_method`,
  and `colocate_gradients_with_ops` are ignored.
  @end_compatibility
  """
  if callable(loss):
    with backprop.GradientTape() as tape:
      if var_list is not None:
        tape.watch(var_list)
      loss_value = loss()

      # Scale loss if using a "mean" loss reduction and multiple towers.
      # Have to be careful to call distribute_lib.get_loss_reduction()
      # *after* loss() is evaluated, so we know what loss reduction it uses.
      # TODO(josh11b): Test that we handle weight decay in a reasonable way.
      if (distribute_lib.get_loss_reduction() ==
          variable_scope.VariableAggregation.MEAN):
        num_towers = distribute_lib.get_distribution_strategy().num_towers
        if num_towers > 1:
          loss_value *= (1. / num_towers)

    if var_list is None:
      var_list = tape.watched_variables()
    grads = tape.gradient(loss_value, var_list, grad_loss)
    return list(zip(grads, var_list))

  # Non-callable/Tensor loss case
  if context.executing_eagerly():
    raise RuntimeError(
        "`loss` passed to Optimizer.compute_gradients should "
        "be a function when eager execution is enabled.")

  # Scale loss if using a "mean" loss reduction and multiple towers.
  if (distribute_lib.get_loss_reduction() ==
      variable_scope.VariableAggregation.MEAN):
    num_towers = distribute_lib.get_distribution_strategy().num_towers
    if num_towers > 1:
      loss *= (1. / num_towers)

  if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
                            Optimizer.GATE_GRAPH]:
    raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                     "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
                     gate_gradients)
  self._assert_valid_dtypes([loss])
  if grad_loss is not None:
    self._assert_valid_dtypes([grad_loss])
  if var_list is None:
    var_list = (
        variables.trainable_variables() +
        ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
  else:
    var_list = nest.flatten(var_list)
  # pylint: disable=protected-access
  var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
  # pylint: enable=protected-access
  processors = [_get_processor(v) for v in var_list]
  if not var_list:
    raise ValueError("No variables to optimize.")
  var_refs = [p.target() for p in processors]
  grads = gradients.gradients(
      loss, var_refs, grad_ys=grad_loss,
      gate_gradients=(gate_gradients == Optimizer.GATE_OP),
      aggregation_method=aggregation_method,
      colocate_gradients_with_ops=colocate_gradients_with_ops)
  if gate_gradients == Optimizer.GATE_GRAPH:
    grads = control_flow_ops.tuple(grads)
  grads_and_vars = list(zip(grads, var_list))
  self._assert_valid_dtypes(
      [v for g, v in grads_and_vars
       if g is not None and v.dtype != dtypes.resource])
  return grads_and_vars