def _update_sample_weight_mode(model, mode, inputs):
  """Refreshes the model's sample_weight_mode from the given `inputs`.

  Nothing is done in `ModeKeys.PREDICT` mode. Otherwise, when `inputs` is not
  callable, everything after the model's feed inputs and feed targets is
  treated as user-supplied sample weights and handed to
  `model._update_sample_weight_modes`. If the model has a distribution
  strategy, the strategy-specific updater is invoked as well, in which case
  it receives `None` for a callable `inputs`.

  Args:
    model: The model whose sample_weight_mode is being updated.
    mode: A `ModeKeys` value.
    inputs: Either a callable, or a sliceable sequence laid out as the model's
      inputs + targets + sample_weights (+ a learning phase placeholder when
      one is specified).
  """
  # Bail out early: `model._feed_targets` accesses model properties that may
  # not be set in `PREDICT` mode.
  if mode == ModeKeys.PREDICT:
    return

  weights = None
  if not callable(inputs):
    # Skip past the feed inputs and feed targets; what remains are the sample
    # weights, possibly followed by a learning phase placeholder.
    num_leading = len(model._feed_inputs) + len(model._feed_targets)
    weights = inputs[num_leading:]
    # In TRAIN mode with a symbolic (non-int) learning phase, the final
    # element is the learning phase placeholder, not a sample weight.
    if (mode == ModeKeys.TRAIN and
        not isinstance(K.symbolic_learning_phase(), int)):
      weights = weights[:-1]
    model._update_sample_weight_modes(sample_weights=weights)

  # Let the DistributionStrategy-specific helper update the model too.
  if model._distribution_strategy:
    distributed_training_utils._update_sample_weight_modes(
        model, mode, weights)
def _update_sample_weight_mode(model, mode, adapter):
  """Updates the sample_weight_mode of a given model from a data adapter.

  One batch is pulled from the adapter's dataset to determine whether the user
  passed sample weights as part of the input. Nothing is done in
  `ModeKeys.PREDICT` mode.

  Args:
    model: The model whose sample_weight_mode is being updated.
    mode: A `ModeKeys` value.
    adapter: A data adapter exposing `get_dataset()`.
  """
  # Add a quick return to prevent us from calling model._feed_targets that
  # accesses certain model properties that may not be set in the `PREDICT` mode.
  if mode == ModeKeys.PREDICT:
    return

  sample_weights = None

  # Get some sample inputs from the data adapter.
  iterator = _create_dataset_iterator(model._distribution_strategy,
                                      adapter.get_dataset())
  try:
    inputs = create_batch_inputs(iterator, mode, model,
                                 model._distribution_strategy)
    # `inputs` is the model's inputs + targets + sample_weights +
    # learning phase placeholder if specified. To update the sample_weight_mode
    # we need to determine if the user has passed sample weights as part of
    # the input.
    if not callable(inputs):
      # The batch inputs should be a tuple of 2, 3 or 4 items:
      # (input, target, {sample_weights}, {learning_phase}).
      sample_weights_index = 0
      if model._feed_inputs:
        sample_weights_index += 1
      if model._feed_targets:
        sample_weights_index += 1

      sample_weights = inputs[sample_weights_index:]
      # In TRAIN mode with a symbolic (non-int) learning phase, the trailing
      # element is the learning phase placeholder rather than a weight.
      has_learning_phase_pl = (
          mode == ModeKeys.TRAIN and
          not isinstance(backend.symbolic_learning_phase(), int))
      if has_learning_phase_pl:
        sample_weights = sample_weights[:-1]
      model._update_sample_weight_modes(nest.flatten(sample_weights))

    # Call the DistributionStrategy specific function to update the
    # sample_weight_mode on the model.
    if model._distribution_strategy:
      dist_utils._update_sample_weight_modes(model, mode, sample_weights)
  finally:
    # Drop our reference to the iterator on every exit path. Previously this
    # only happened on success; on an exception, the traceback's frame would
    # keep the iterator alive.
    del iterator
def _update_sample_weight_mode(model, mode, dataset):
  """Forwards the sample weights found in `dataset` to the distribution utils.

  Outside of `ModeKeys.PREDICT` mode, an iterator over `dataset` is handed to
  `training_v2_utils._prepare_feed_values`. The third value it returns is
  passed as the sample weights to `dist_utils._update_sample_weight_modes`.
  That call is made whether or not the model has a distribution strategy.

  Args:
    model: The model whose sample_weight_mode is being updated.
    mode: A `ModeKeys` value.
    dataset: An iterable of input batches (presumably a `tf.data.Dataset` --
      confirm against callers).
  """
  # TODO(kaftan): This won't actually do anything right now because
  # dist_utils._update_sample_weight_modes only does things when the model
  # is distributed by cloning. We will need to revisit if a method here
  # is needed at all, and if so how it should look.

  # Return early in PREDICT mode, where certain model properties (such as
  # those `model._feed_targets` accesses) may not be set.
  if mode == ModeKeys.PREDICT:
    return

  # Pull sample feed values out of the dataset to recover the sample weights.
  data_iter = iter(dataset)
  _inputs, _targets, weights = training_v2_utils._prepare_feed_values(
      model, data_iter, mode)

  # Hand the weights to the DistributionStrategy-specific updater.
  dist_utils._update_sample_weight_modes(model, mode, weights)

  # Explicitly release the iterator reference.
  del data_iter