def _scale_loss(loss_value):
  if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
    num_replicas = \
        distribute_ctx.get_distribution_strategy().num_replicas_in_sync
    if num_replicas > 1:
      loss_value *= (1. / num_replicas)
  return loss_value
def _scale_loss(loss_value):
  ops.get_default_graph()._is_loss_scaled_by_optimizer = False  # pylint: disable=protected-access
  if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
    num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync
    if num_replicas > 1:
      loss_value *= (1. / num_replicas)
      ops.get_default_graph()._is_loss_scaled_by_optimizer = True  # pylint: disable=protected-access
  return loss_value
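# A minimal, framework-free sketch of why the scaling above matters: with a
# MEAN loss reduction, each replica must scale its local loss by
# 1/num_replicas so that summing the resulting per-replica gradients yields
# the gradient of the mean loss. The numbers below are illustrative only and
# do not call any TensorFlow API.

def _scaled_losses_example():
  """Hypothetical walk-through of per-replica loss scaling."""
  num_replicas = 4
  per_replica_losses = [2.0, 4.0, 6.0, 8.0]
  # Each replica scales its own loss, mirroring _scale_loss above.
  scaled = [loss * (1. / num_replicas) for loss in per_replica_losses]
  # Summing the scaled losses recovers the mean of the unscaled ones.
  assert abs(sum(scaled) - sum(per_replica_losses) / num_replicas) < 1e-9
  return scaled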
def step_fn(ctx, inputs): """Clones the model and calls make_fit_function.""" # TODO(priyag, sourabhbajaj): The model gets cloned every time # fit/test/predict is called. We should look into caching this keyed on # input shapes. inputs, targets = inputs clone_model_on_replicas( model, current_strategy, make_callback_model=True, inputs=inputs, targets=targets, mode=_Mode.TRAIN) (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) = current_strategy.extended.call_for_each_replica( _per_device_fit_function, args=(model._grouped_model_train,)) (all_inputs, all_outputs, all_updates, all_session_args) = distributed_training_utils.unwrap_values( current_strategy, grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) combined_fn = K.function( all_inputs, all_outputs, updates=all_updates, name='distributed_fit_function', **all_session_args) for label, output in zip(out_labels, combined_fn.outputs): if label == 'loss': reduce_op = distribute_lib.get_loss_reduction() else: # We reduce all other metrics using mean for now. This is temporary # workaround until new metrics are in place. reduce_op = ds_reduce_util.ReduceOp.MEAN ctx.set_last_step_output(label, output, reduce_op) # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn: # feed_dict, session kwargs, run options, run_metadata for now. These should # be handled appropriately return combined_fn.updates_op
def step_fn(ctx, inputs): """Clones the model and calls make_eval_function.""" # TODO(priyag, sourabhbajaj): The model gets cloned every time # fit/test/predict is called. We should look into caching this keyed on # input shapes. inputs, targets = inputs clone_model_on_replicas(model, current_strategy, make_callback_model=False, inputs=inputs, targets=targets, mode=_Mode.TEST) (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args ) = current_strategy.extended.call_for_each_replica( _per_device_eval_function, args=(model._grouped_model_test, )) (all_inputs, all_outputs, all_updates, all_session_args) = distributed_training_utils.unwrap_values( current_strategy, grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) combined_fn = K.function(all_inputs, all_outputs, updates=all_updates, name='distributed_test_function', **all_session_args) for label, output in zip(model.metrics_names, combined_fn.outputs): if label == 'loss': reduce_op = distribute_lib.get_loss_reduction() else: # We reduce all other metrics using mean for now. This is temporary # workaround until new metrics are in place. reduce_op = ds_reduce_util.ReduceOp.MEAN ctx.set_last_step_output(label, output, reduce_op) return combined_fn.updates_op
def step_fn(ctx, inputs): """Clones the model and calls make_eval_function.""" inputs, targets = inputs if model._compile_distribution: distributed_training_utils. clone_model_on_replicas( model, current_strategy, make_callback_model=False, inputs=inputs, targets=targets, mode=distributed_training_utils.ModeKeys.TEST) else: distributed_training_utils._build_distributed_network( model, current_strategy, inputs, targets, mode=distributed_training_utils.ModeKeys.TEST) (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) = current_strategy.extended.call_for_each_replica( _per_device_eval_function, args=(model._distributed_model_test,)) (all_inputs, all_outputs, all_updates, all_session_args) = distributed_training_utils.unwrap_values( current_strategy, grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) combined_fn = K.function( all_inputs, all_outputs, updates=all_updates, name='distributed_test_function', **all_session_args) for label, output in zip(model.metrics_names, combined_fn.outputs): if label == 'loss': reduce_op = distribute_lib.get_loss_reduction() else: # We reduce all other metrics using mean for now. This is temporary # workaround until new metrics are in place. reduce_op = ds_reduce_util.ReduceOp.MEAN ctx.set_last_step_output(label, output, reduce_op) return combined_fn.updates_op
def step_fn(ctx, inputs): """Clones the model and calls make_eval_function.""" inputs, targets = inputs if model._compile_distribution: distributed_training_utils. clone_model_on_replicas( model, current_strategy, make_callback_model=False, inputs=inputs, targets=targets, mode=distributed_training_utils.ModeKeys.TEST) else: distributed_training_utils._build_distributed_network( model, current_strategy, inputs, targets, mode=distributed_training_utils.ModeKeys.TEST) (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) = current_strategy.extended.call_for_each_replica( _per_device_eval_function, args=(model._distributed_model_test,)) (all_inputs, all_outputs, all_updates, all_session_args) = distributed_training_utils.unwrap_values( current_strategy, grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) combined_fn = K.function( all_inputs, all_outputs, updates=all_updates, name='distributed_test_function', **all_session_args) for label, output in zip(model.metrics_names, combined_fn.outputs): if label == 'loss': reduce_op = distribute_lib.get_loss_reduction() else: # We reduce all other metrics using mean for now. This is temporary # workaround until new metrics are in place. reduce_op = ds_reduce_util.ReduceOp.MEAN ctx.set_last_step_output(label, output, reduce_op) return combined_fn.updates_op
def step_fn(ctx, inputs): """Clones the model and calls make_fit_function.""" inputs, targets = inputs if model._compile_distribution: clone_model_on_replicas(model, current_strategy, make_callback_model=True, inputs=inputs, targets=targets, mode=ModeKeys.TRAIN) else: _build_distributed_network(model, current_strategy, inputs, targets, mode=ModeKeys.TRAIN) (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) = current_strategy.extended.call_for_each_replica( _per_device_fit_function, args=(model._distributed_model_train,)) (all_inputs, all_outputs, all_updates, all_session_args) = distributed_training_utils.unwrap_values( current_strategy, grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) combined_fn = K.function( all_inputs, all_outputs, updates=all_updates, name='distributed_fit_function', **all_session_args) for label, output in zip(out_labels, combined_fn.outputs): if label == 'loss': reduce_op = distribute_lib.get_loss_reduction() else: # We reduce all other metrics using mean for now. This is temporary # workaround until new metrics are in place. reduce_op = ds_reduce_util.ReduceOp.MEAN ctx.set_last_step_output(label, output, reduce_op) # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn: # feed_dict, session kwargs, run options, run_metadata for now. These should # be handled appropriately return combined_fn.updates_op
def _scale_loss(loss_value):
  if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
    num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync
    if num_replicas > 1:
      loss_value *= (1. / num_replicas)
  return loss_value
def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs,
                  grouped_updates=None, grouped_session_args=None,
                  with_loss_tensor=False):
  """Unwrap and return the list of values contained in the PerDevice parameters.

  This function calls `flatten_perdevice_values` to parse each of the input
  parameters into a list of values on the different devices. If we set
  `with_loss_tensor` to be True, we also call `reduce` on the list of losses on
  the different devices to give us one loss tensor.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training and
      validation.
    grouped_inputs: PerDevice inputs returned from the train or test function
      that we ran on each device.
    grouped_outputs: PerDevice outputs returned from the train or test function
      that we ran on each device.
    grouped_updates: PerDevice updates returned from the train or test function
      that we ran on each device.
    grouped_session_args: PerDevice session args returned from the train or
      test function that we ran on each device.
    with_loss_tensor: Boolean that indicates if we need to add the reduced loss
      tensor as one of the outputs.

  Returns:
    Values of each of the PerDevice parameters.
  """
  # Unwrap per device values returned from each model's train function.
  # This will be used to construct the main train function.
  all_inputs = flatten_perdevice_values(distribution_strategy, grouped_inputs)
  if with_loss_tensor:
    # reduce loss tensor before adding it to the list of fetches
    loss = distribution_strategy.reduce(distribute_lib.get_loss_reduction(),
                                        grouped_outputs[0])
    all_outputs = flatten_perdevice_values(distribution_strategy,
                                           grouped_outputs[1:])
    all_outputs = [loss] + all_outputs
  else:
    all_outputs = flatten_perdevice_values(distribution_strategy,
                                           grouped_outputs)

  if grouped_updates:
    all_updates = flatten_perdevice_values(distribution_strategy,
                                           grouped_updates)
  else:
    all_updates = None

  all_session_args = {}
  if grouped_session_args:
    grouped_feed_dict = grouped_session_args.get('feed_dict')
    if grouped_feed_dict:
      all_session_args['feed_dict'] = flatten_perdevice_values(
          distribution_strategy, grouped_feed_dict)

    grouped_fetches = grouped_session_args.get('fetches')
    if grouped_fetches:
      all_session_args['fetches'] = flatten_perdevice_values(
          distribution_strategy, grouped_fetches)

  # TODO(priyag): Return only non empty/None values
  return all_inputs, all_outputs, all_updates, all_session_args
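# A minimal sketch of what unwrap_values does conceptually, assuming
# per-device values are represented as plain dicts keyed by device name.
# flatten_per_device and unwrap_values_sketch below are hypothetical helpers
# written for illustration; they are not the real flatten_perdevice_values or
# DistributionStrategy.reduce.

def flatten_per_device(per_device_values):
  """Flatten a list of {device: value} dicts into one flat list of values."""
  return [value
          for per_device in per_device_values
          for value in per_device.values()]

def unwrap_values_sketch(grouped_outputs, with_loss_tensor=False):
  """Illustrative only: mirrors the loss-reduction branch above."""
  if with_loss_tensor:
    # Reduce the per-device losses (grouped_outputs[0]) to one value, then
    # flatten the remaining outputs and prepend the reduced loss.
    losses = list(grouped_outputs[0].values())
    loss = sum(losses) / len(losses)  # stands in for a MEAN reduce
    return [loss] + flatten_per_device(grouped_outputs[1:])
  return flatten_per_device(grouped_outputs)

# Usage: per-device losses plus one metric from two devices.
outs = [{'gpu:0': 2.0, 'gpu:1': 4.0}, {'gpu:0': 0.9, 'gpu:1': 0.8}]
print(unwrap_values_sketch(outs, with_loss_tensor=True))  # [3.0, 0.9, 0.8]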