def setUp(self):
  super(KerasModelToEstimatorTest, self).setUp()
  (context_feature_columns, example_feature_columns,
   custom_objects) = _get_feature_columns()
  self._context_feature_columns = context_feature_columns
  self._example_feature_columns = example_feature_columns
  # Remove label feature from example feature column.
  del self._example_feature_columns[_LABEL_FEATURE]
  self._custom_objects = custom_objects
  self._network = _DummyUnivariateRankingNetwork(
      context_feature_columns=self._context_feature_columns,
      example_feature_columns=self._example_feature_columns)
  self._loss = losses.get(
      losses.RankingLossKey.SOFTMAX_LOSS,
      reduction=tf.compat.v2.losses.Reduction.SUM_OVER_BATCH_SIZE)
  self._eval_metrics = metrics.default_keras_metrics()
  self._optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.1)
  self._config = tf.estimator.RunConfig(
      keep_checkpoint_max=2, save_checkpoints_secs=2)
  self._data_file = os.path.join(tf.compat.v1.test.get_temp_dir(),
                                 'test_elwc.tfrecord')
  serialized_elwc_list = [
      _ELWC_PROTO.SerializeToString(),
  ] * 20
  if tf.io.gfile.exists(self._data_file):
    tf.io.gfile.remove(self._data_file)
  with tf.io.TFRecordWriter(self._data_file) as writer:
    for serialized_elwc in serialized_elwc_list:
      writer.write(serialized_elwc)
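# A minimal sketch (not part of the original test): reading back the ELWC
# records written above with a plain `TFRecordDataset`, just to illustrate the
# file layout produced by setUp(). The helper name `_count_written_records` is
# hypothetical; this assumes TF2 eager execution so the dataset can be
# iterated directly.
def _count_written_records(data_file):
  dataset = tf.data.TFRecordDataset([data_file])
  # setUp() above writes 20 serialized ELWC protos into this file.
  return sum(1 for _ in dataset)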
def setUp(self):
  super(KerasModelToEstimatorTest, self).setUp()
  self.context_feature_columns = _context_feature_columns()
  self.example_feature_columns = _example_feature_columns()
  self.features = _features()
  self.network = _DummyUnivariateRankingNetwork(
      context_feature_columns=self.context_feature_columns,
      example_feature_columns=self.example_feature_columns)
  self.loss = losses.get(
      losses.RankingLossKey.SOFTMAX_LOSS,
      reduction=tf.compat.v2.losses.Reduction.SUM_OVER_BATCH_SIZE)
  self.eval_metrics = metrics.default_keras_metrics()
  self.optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.1)
  self.config = tf.estimator.RunConfig(
      keep_checkpoint_max=2, save_checkpoints_secs=2)
def test_default_keras_metrics(self):
  default_metrics = metrics_lib.default_keras_metrics()
  self.assertLen(default_metrics, 11)
  for metric in default_metrics:
    self.assertIsInstance(metric, tf.keras.metrics.Metric)
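# A minimal sketch (not part of the original test): the objects returned by
# default_keras_metrics() follow the standard tf.keras.metrics.Metric
# protocol, so they can be driven directly with update_state()/result(). The
# helper name and the label/prediction values below are made up for
# illustration; ranking metrics expect [batch_size, list_size] tensors.
def _example_default_metric_usage():
  metric = metrics_lib.default_keras_metrics()[0]
  labels = tf.constant([[1.0, 0.0, 2.0]])       # Graded relevance per example.
  predictions = tf.constant([[0.3, 0.1, 0.8]])  # Scores from a ranker.
  metric.update_state(labels, predictions)
  return metric.result()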
def _model_fn(features, labels, mode, params, config):
  """Defines an `Estimator` `model_fn`."""
  del [config, params]

  # In Estimator, all sub-graphs need to be constructed inside the model_fn.
  # Hence, ranker, losses, metrics and optimizer are cloned inside this
  # function.
  ranker = tf.keras.models.clone_model(model, clone_function=_clone_fn)
  training = (mode == tf.compat.v1.estimator.ModeKeys.TRAIN)

  weights = None
  if weights_feature_name and mode != tf.compat.v1.estimator.ModeKeys.PREDICT:
    if weights_feature_name not in features:
      raise ValueError(
          "weights_feature '{0}' can not be found in 'features'.".format(
              weights_feature_name))
    else:
      weights = utils.reshape_to_2d(features.pop(weights_feature_name))

  logits = ranker(features, training=training)

  if mode == tf.compat.v1.estimator.ModeKeys.PREDICT:
    return tf.compat.v1.estimator.EstimatorSpec(mode=mode, predictions=logits)

  loss = _clone_fn(model.loss)
  total_loss = loss(labels, logits, sample_weight=weights)

  keras_metrics = []
  for metric in model.metrics:
    keras_metrics.append(_clone_fn(metric))
  # Adding default metrics here as model.metrics does not contain custom
  # metrics.
  keras_metrics += metrics.default_keras_metrics()
  eval_metric_ops = {}
  for keras_metric in keras_metrics:
    keras_metric.update_state(labels, logits, sample_weight=weights)
    eval_metric_ops[keras_metric.name] = keras_metric

  train_op = None
  if training:
    optimizer = _clone_fn(model.optimizer)
    optimizer.iterations = tf.compat.v1.train.get_or_create_global_step()
    # Get both the unconditional updates (the None part)
    # and the input-conditional updates (the features part).
    # These updates are for layers like BatchNormalization, which have
    # separate update and minimize ops.
    update_ops = ranker.get_updates_for(None) + ranker.get_updates_for(features)
    minimize_op = optimizer.get_updates(
        loss=total_loss, params=ranker.trainable_variables)[0]
    train_op = tf.group(minimize_op, *update_ops)

  return tf.compat.v1.estimator.EstimatorSpec(
      mode=mode,
      predictions=logits,
      loss=total_loss,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops)
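# A minimal sketch (not part of the library code): in the library, _model_fn is
# a closure built inside model_to_estimator(), so `model`,
# `weights_feature_name` and `_clone_fn` come from the enclosing scope.
# Assuming such a model_fn is in scope, it would typically be wrapped and
# trained as below; `train_input_fn`, `model_dir` and the helper name are
# hypothetical placeholders.
def _make_and_train_estimator(train_input_fn, model_dir, config=None):
  estimator = tf.compat.v1.estimator.Estimator(
      model_fn=_model_fn, model_dir=model_dir, config=config)
  estimator.train(input_fn=train_input_fn, steps=100)
  return estimator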