def _make_train_function(self):
    """Build the backend training function and cache it on the model.

    Only needed for graph mode and model_to_estimator.

    The function is (re)built when `self.train_function` is missing or
    `None`, or when `_recompile_weights_loss_and_weighted_metrics()` reports
    that the loss / weighted-metric sub-graphs were re-compiled. Otherwise
    the existing function is left untouched.

    `self.optimizer` may be a single optimizer, used for both sub-models, or
    a list/tuple whose first element is applied to
    `self.linear_model.trainable_weights` and whose second element is applied
    to `self.dnn_model.trainable_weights`.

    While the function is being built the model is switched to
    `self._compiled_trainable_state`. The trainable state that was active on
    entry is always put back afterwards, including when building raises.
    """
    # Only needed for graph mode and model_to_estimator.
    has_recompiled = self._recompile_weights_loss_and_weighted_metrics()
    self._check_trainable_weights_consistency()

    # If we have re-compiled the loss/weighted metric sub-graphs then create
    # train function even if one exists already. This is because
    # `_feed_sample_weights` list has been updated on re-compile.
    already_built = getattr(self, "train_function", None) is not None
    if already_built and not has_recompiled:
        return

    # Switch to the compiled trainable state for the duration of the build.
    current_trainable_state = self._get_trainable_state()
    self._set_trainable_state(self._compiled_trainable_state)
    try:
        inputs = (
            self._feed_inputs
            + self._feed_targets
            + self._feed_sample_weights
        )
        # The learning phase is fed as an extra input unless it is a plain
        # int.
        if not isinstance(backend.symbolic_learning_phase(), int):
            inputs += [backend.symbolic_learning_phase()]

        if isinstance(self.optimizer, (list, tuple)):
            linear_optimizer = self.optimizer[0]
            dnn_optimizer = self.optimizer[1]
        else:
            linear_optimizer = self.optimizer
            dnn_optimizer = self.optimizer

        with backend.get_graph().as_default():
            with backend.name_scope("training"):
                # Training updates: each optimizer only sees the weights of
                # its own sub-model, both against the same total loss.
                updates = []
                updates += linear_optimizer.get_updates(
                    params=self.linear_model.trainable_weights,
                    loss=self.total_loss,
                )
                updates += dnn_optimizer.get_updates(
                    params=self.dnn_model.trainable_weights,
                    loss=self.total_loss,
                )
                # Unconditional updates
                updates += self.get_updates_for(None)
                # Conditional updates relevant to this model
                updates += self.get_updates_for(self.inputs)

            metrics = self._get_training_eval_metrics()
            # Only metrics that carry a `_call_result` contribute an output.
            metrics_tensors = [
                m._call_result for m in metrics if hasattr(m, "_call_result")
            ]

        with backend.name_scope("training"):
            # Gets loss and metrics. Updates weights at each call.
            fn = backend.function(
                inputs,
                [self.total_loss] + metrics_tensors,
                updates=updates,
                name="train_function",
                **self._function_kwargs
            )
            self.train_function = fn
    finally:
        # Restore the current trainable state, even if the build failed, so
        # the model is never left in the compiled state by accident.
        self._set_trainable_state(current_trainable_state)
def _update_sample_weight_mode(model, mode, inputs):
    """Updates the sample_weight_mode of a given model.

    Args:
        model: Model whose sample weight modes are updated.
        mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT. Nothing
            is done for ModeKeys.PREDICT.
        inputs: Either a callable, in which case no sample weights are
            extracted, or the model's inputs + targets + sample_weights +
            learning phase placeholder if specified.
    """
    # Bail out early in `PREDICT` mode: reading model._feed_targets accesses
    # certain model properties that may not be set in that mode.
    if mode == ModeKeys.PREDICT:
        return

    if callable(inputs):
        sample_weights = None
    else:
        # Everything after the feed inputs and feed targets is sample
        # weights, possibly followed by the learning phase value. This is how
        # we determine whether the user passed sample weights with the input.
        weights_start = len(model._feed_inputs) + len(model._feed_targets)
        sample_weights = inputs[weights_start:]
        if mode == ModeKeys.TRAIN and not isinstance(
            backend.symbolic_learning_phase(), int
        ):
            # Drop the trailing learning phase entry.
            sample_weights = sample_weights[:-1]
        model._update_sample_weight_modes(sample_weights=sample_weights)

    # Call the DistributionStrategy specific function to update the
    # sample_weight_mode on the model.
    if model._distribution_strategy:
        distributed_training_utils_v1._update_sample_weight_modes(
            model, mode, sample_weights
        )
def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
    """Prepare feed values to the model execution function.

    Args:
        model: Model to prepare feed values for.
        inputs: List or dict of model inputs.
        targets: Optional list of model targets.
        sample_weights: Optional list of sample weight arrays.
        mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.

    Returns:
        Feed values for the model in the given mode. With a distribution
        strategy under eager execution, a zero-argument callable producing
        those values is returned instead.
    """
    strategy = model._distribution_strategy
    if strategy:
        dataset_types = (tf.compat.v1.data.Dataset, tf.data.Dataset)
        if isinstance(inputs, dataset_types):
            inputs = distributed_training_utils_v1.get_iterator(
                inputs, strategy
            )

        def get_distributed_inputs():
            return distributed_training_utils_v1._prepare_feed_values(
                model, inputs, targets, sample_weights, mode
            )

        # In the eager case, we want to call the input method per step, so
        # hand back the callable itself rather than its result. Note that
        # this is applicable only in Distribution Strategy case as it follows
        # the same code path for both eager and graph modes.
        # TODO(priyag,omalleyt): Either we should move the training DS with
        # IteratorBase to use training_generator code path, or figure out how
        # to set a symbolic Iterator out of a Dataset when in eager mode.
        return (
            get_distributed_inputs
            if tf.executing_eagerly()
            else get_distributed_inputs()
        )

    dataset_or_iterator_types = (
        tf.compat.v1.data.Dataset,
        tf.data.Dataset,
        tf.compat.v1.data.Iterator,
    )
    if isinstance(inputs, dataset_or_iterator_types):
        inputs, targets, sample_weights = model._standardize_user_data(
            inputs, extract_tensors_from_dataset=True
        )

    input_values = training_utils_v1.ModelInputs(inputs).as_list()
    target_values = list(targets or [])
    weight_values = list(sample_weights or [])
    feed_values = input_values + target_values + weight_values

    needs_learning_phase = mode == ModeKeys.TRAIN and not isinstance(
        backend.symbolic_learning_phase(), int
    )
    if needs_learning_phase:
        feed_values.append(True)  # Add learning phase value.
    return feed_values