Code example #1 (score: 0)
File: wide_deep.py — project: paolodedios/keras
    def _make_train_function(self):
        """Create (or recreate) ``self.train_function`` for graph-mode training.

        Builds a `backend.function` that feeds inputs, targets, and sample
        weights, returns the total loss plus metric tensors, and applies the
        optimizer update ops on every call. Used only in graph mode and for
        `model_to_estimator`.
        """
        # Re-compiling may rebuild the loss / weighted-metric sub-graphs and
        # change `_feed_sample_weights`, so remember whether it happened.
        has_recompiled = self._recompile_weights_loss_and_weighted_metrics()
        self._check_trainable_weights_consistency()
        # If we have re-compiled the loss/weighted metric sub-graphs then create
        # train function even if one exists already. This is because
        # `_feed_sample_weights` list has been updated on re-compile.
        if getattr(self, "train_function", None) is None or has_recompiled:
            # Restore the compiled trainable state.
            current_trainable_state = self._get_trainable_state()
            self._set_trainable_state(self._compiled_trainable_state)

            # Feed order: inputs, then targets, then sample weights.
            inputs = (self._feed_inputs + self._feed_targets +
                      self._feed_sample_weights)
            # A non-int symbolic learning phase is a placeholder that must be
            # fed as an extra input.
            if not isinstance(backend.symbolic_learning_phase(), int):
                inputs += [backend.symbolic_learning_phase()]

            # Wide & deep: a (list/tuple) pair of optimizers assigns the first
            # to the linear (wide) part and the second to the DNN (deep) part;
            # a single optimizer is shared by both.
            if isinstance(self.optimizer, (list, tuple)):
                linear_optimizer = self.optimizer[0]
                dnn_optimizer = self.optimizer[1]
            else:
                linear_optimizer = self.optimizer
                dnn_optimizer = self.optimizer

            with backend.get_graph().as_default():
                with backend.name_scope("training"):
                    # Training updates: each optimizer minimizes the shared
                    # total loss over its own sub-model's trainable weights.
                    updates = []
                    linear_updates = linear_optimizer.get_updates(
                        params=self.linear_model.trainable_weights,
                        loss=self.total_loss,
                    )
                    updates += linear_updates
                    dnn_updates = dnn_optimizer.get_updates(
                        params=self.dnn_model.trainable_weights,
                        loss=self.total_loss,
                    )
                    updates += dnn_updates
                    # Unconditional updates
                    updates += self.get_updates_for(None)
                    # Conditional updates relevant to this model
                    updates += self.get_updates_for(self.inputs)

                metrics = self._get_training_eval_metrics()
                # Only metrics that have already produced a result tensor
                # (i.e. expose `_call_result`) are fetched by the function.
                metrics_tensors = [
                    m._call_result for m in metrics
                    if hasattr(m, "_call_result")
                ]

            with backend.name_scope("training"):
                # Gets loss and metrics. Updates weights at each call.
                fn = backend.function(inputs,
                                      [self.total_loss] + metrics_tensors,
                                      updates=updates,
                                      name="train_function",
                                      **self._function_kwargs)
                setattr(self, "train_function", fn)

            # Restore the current trainable state
            self._set_trainable_state(current_trainable_state)
Code example #2 (score: 0)
def _update_sample_weight_mode(model, mode, inputs):
    """Updates the sample_weight_mode of a given model."""
    # PREDICT never feeds targets/sample weights, and touching
    # `model._feed_targets` in that mode can hit properties that are not set,
    # so return immediately.
    if mode == ModeKeys.PREDICT:
        return

    weights = None
    # `inputs` lays out feed inputs + targets + sample weights, optionally
    # followed by a learning-phase placeholder. When `inputs` is a callable we
    # cannot inspect it, so only the distribution-strategy hook runs below.
    if not callable(inputs):
        tail_start = len(model._feed_inputs) + len(model._feed_targets)
        weights = inputs[tail_start:]
        # In TRAIN mode a symbolic (non-int) learning phase adds a trailing
        # placeholder value; drop it so only sample weights remain.
        if mode == ModeKeys.TRAIN and not isinstance(
                backend.symbolic_learning_phase(), int):
            weights = weights[:-1]
        model._update_sample_weight_modes(sample_weights=weights)

    # Mirror the update through the DistributionStrategy-specific path.
    if model._distribution_strategy:
        distributed_training_utils_v1._update_sample_weight_modes(
            model, mode, weights)
Code example #3 (score: 0)
def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
    """Prepare feed values to the model execution function.

    Args:
      model: Model to prepare feed values for.
      inputs: List or dict of model inputs.
      targets: Optional list of model targets.
      sample_weights: Optional list of sample weight arrays.
      mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.

    Returns:
      Feed values for the model in the given mode.
    """
    if model._distribution_strategy:
        # Datasets must first be wrapped in a strategy-aware iterator.
        if isinstance(inputs, (tf.compat.v1.data.Dataset, tf.data.Dataset)):
            inputs = distributed_training_utils_v1.get_iterator(
                inputs, model._distribution_strategy)

        def distributed_inputs():
            return distributed_training_utils_v1._prepare_feed_values(
                model, inputs, targets, sample_weights, mode)

        # In the eager case, we want to call the input method per step, so
        # return a lambda from here that can be called. Note that this is
        # applicable only in Distribution Strategy case as it follows the same
        # code path for both eager and graph modes.
        # TODO(priyag,omalleyt): Either we should move the training DS with
        # IteratorBase to use training_generator code path, or figure out how to
        # set a symbolic Iterator out of a Dataset when in eager mode.
        if tf.executing_eagerly():
            return distributed_inputs
        return distributed_inputs()

    # Outside DistributionStrategy, datasets/iterators are unpacked into
    # explicit (inputs, targets, sample_weights) tensors up front.
    dataset_like = (
        tf.compat.v1.data.Dataset,
        tf.data.Dataset,
        tf.compat.v1.data.Iterator,
    )
    if isinstance(inputs, dataset_like):
        inputs, targets, sample_weights = model._standardize_user_data(
            inputs, extract_tensors_from_dataset=True)

    # Flatten everything into a single feed list: inputs, targets, weights.
    feed = training_utils_v1.ModelInputs(inputs).as_list()
    feed.extend(targets or [])
    feed.extend(sample_weights or [])
    # A symbolic (non-int) learning phase needs an explicit True in TRAIN mode.
    if mode == ModeKeys.TRAIN and not isinstance(
            backend.symbolic_learning_phase(), int):
        feed.append(True)  # Add learning phase value.
    return feed