Example No. 1
def _update_sample_weight_mode(model, mode, inputs):
    """Updates the sample_weight_mode of a given model."""
    # Return early to avoid touching model._feed_targets, which accesses model
    # properties that may not be set in `PREDICT` mode.
    if mode == ModeKeys.PREDICT:
        return

    sample_weights = None
    # `inputs` is the model's inputs + targets + sample_weights +
    # learning phase placeholder if specified. To update the sample_weight_mode
    # we need to determine if the user has passed sample weights as part of the
    # input.
    if not callable(inputs):
        sample_weights = inputs[len(model._feed_inputs) +
                                len(model._feed_targets):]
        has_learning_phase_pl = (
            mode == ModeKeys.TRAIN
            and not isinstance(K.symbolic_learning_phase(), int))
        if has_learning_phase_pl:
            sample_weights = sample_weights[:-1]
        model._update_sample_weight_modes(sample_weights=sample_weights)

    # Call the DistributionStrategy specific function to update the
    # sample_weight_mode on the model.
    if model._distribution_strategy:
        distributed_training_utils._update_sample_weight_modes(
            model, mode, sample_weights)
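
The slicing above relies on `inputs` being flattened as model inputs, then targets, then sample weights, with an optional learning-phase placeholder appended in `TRAIN` mode. A minimal, purely illustrative sketch of that slicing with plain Python lists (the names stand in for the real feed tensors):

# Illustrative stand-ins for the flat `inputs` list described above.
feed_inputs = ["x_placeholder"]        # like model._feed_inputs
feed_targets = ["y_placeholder"]       # like model._feed_targets
inputs = ["x_batch", "y_batch", "sw_batch", "learning_phase"]

# Everything after the inputs and targets is sample weights (plus the
# optional learning-phase placeholder at the end).
sample_weights = inputs[len(feed_inputs) + len(feed_targets):]
assert sample_weights == ["sw_batch", "learning_phase"]

# In TRAIN mode with a symbolic learning phase, drop the trailing placeholder.
sample_weights = sample_weights[:-1]
assert sample_weights == ["sw_batch"]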
Example No. 2
def _update_sample_weight_mode(model, mode, adapter):
    """Updates the sample_weight_mode of a given model."""
    # Return early to avoid touching model._feed_targets, which accesses model
    # properties that may not be set in `PREDICT` mode.
    if mode == ModeKeys.PREDICT:
        return

    sample_weights = None

    # Get some sample inputs from the data_adapter
    iterator = _create_dataset_iterator(model._distribution_strategy,
                                        adapter.get_dataset())
    inputs = create_batch_inputs(iterator, mode, model,
                                 model._distribution_strategy)
    # `inputs` is the model's inputs + targets + sample_weights +
    # learning phase placeholder if specified. To update the sample_weight_mode
    # we need to determine if the user has passed sample weights as part of the
    # input.
    if not callable(inputs):
        # if not isinstance(inputs, collections.Sequence):
        #   inputs = (inputs,)
        # Note that the batch inputs should be a tuple of 2, 3 or 4 items.
        # (input, target, {sample_weights}, {learning_phase})
        sample_weights_index = 0
        if model._feed_inputs:
            sample_weights_index += 1
        if model._feed_targets:
            sample_weights_index += 1

        sample_weights = inputs[sample_weights_index:]
        has_learning_phase_pl = (
            mode == ModeKeys.TRAIN
            and not isinstance(backend.symbolic_learning_phase(), int))
        if has_learning_phase_pl:
            sample_weights = sample_weights[:-1]
        model._update_sample_weight_modes(nest.flatten(sample_weights))

    # Call the DistributionStrategy specific function to update the
    # sample_weight_mode on the model.
    if model._distribution_strategy:
        dist_utils._update_sample_weight_modes(model, mode, sample_weights)

    # Force delete the iterator.
    del iterator
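
Compared with the first example, this variant flattens the (possibly nested) sample-weight structure with `nest.flatten` before handing it to the model. A small sketch of that flattening using the public `tf.nest` API, which mirrors the internal `nest` module used here:

import tensorflow as tf

# Per-output sample weights may arrive as a nested structure; flattening
# yields the flat list that _update_sample_weight_modes expects.
sample_weights = {"output_1": "sw_1", "output_2": ["sw_2a", "sw_2b"]}
flat = tf.nest.flatten(sample_weights)
print(flat)  # ['sw_1', 'sw_2a', 'sw_2b'] (dict keys are traversed in sorted order)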
Example No. 3
def _update_sample_weight_mode(model, mode, dataset):
    """Updates the sample_weight_mode of a given model."""
    # TODO(kaftan): This won't actually do anything right now because
    ## dist_utils._update_sample_weight_modes only does things when the model
    ## is distributed by cloning. We will need to revisit if a method here
    ## is needed at all, and if so how it should look.
    # Return early to avoid touching model._feed_targets, which accesses model
    # properties that may not be set in `PREDICT` mode.
    if mode == ModeKeys.PREDICT:
        return

    # Get some sample inputs from the data_adapter
    iterator = iter(dataset)
    _, _, sample_weights = training_v2_utils._prepare_feed_values(
        model, iterator, mode)

    # Call the DistributionStrategy specific function to update the
    # sample_weight_mode on the model.
    dist_utils._update_sample_weight_modes(model, mode, sample_weights)

    # Force delete the iterator.
    del iterator
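
Since this version only needs to peek at one batch, any dataset whose elements follow the (inputs, targets, sample_weights) layout can drive it. A hedged sketch of such a dataset built with the public `tf.data` API (the array shapes and names are illustrative only):

import numpy as np
import tensorflow as tf

# A dataset whose elements have the (inputs, targets, sample_weights) layout
# that _prepare_feed_values is expected to unpack.
x = np.random.rand(8, 4).astype("float32")
y = np.random.rand(8, 1).astype("float32")
w = np.ones((8,), dtype="float32")
dataset = tf.data.Dataset.from_tensor_slices((x, y, w)).batch(4)

iterator = iter(dataset)
inputs, targets, sample_weights = next(iterator)
del iterator  # mirror the explicit cleanup in the example above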