def test_learning_rate_schedule(self):
  """Checks each LearningRateSchedule config builds the matching Keras type."""
  # (text-format LearningRateSchedule config, expected Keras schedule class).
  expected_schedule_types = (
      (r""" piecewise_constant_decay{ values: 0.001 } """,
       tf.keras.optimizers.schedules.PiecewiseConstantDecay),
      (r""" exponential_decay { initial_learning_rate: 0.001"""
       r""" decay_steps: 1000 decay_rate: 1.0 } """,
       tf.keras.optimizers.schedules.ExponentialDecay),
      (r""" polynomial_decay { initial_learning_rate: 0.001 decay_steps: 1000"""
       r""" end_learning_rate: 0.0001 power: 1.0 cycle: true } """,
       tf.keras.optimizers.schedules.PolynomialDecay),
  )
  for config_text, schedule_type in expected_schedule_types:
    # text_format.Merge parses into (and returns) the message passed in.
    schedule_proto = learning_rate_schedule_pb2.LearningRateSchedule()
    text_format.Merge(config_text, schedule_proto)
    self.assertIsInstance(
        learning_rate_schedule.create_learning_rate_schedule(schedule_proto),
        schedule_type)
def _model_fn(features, labels, mode, params):
  """Creates the model.

  Builds the model from `pipeline_proto.model`, registers its losses, and
  assembles the mode-specific parts of the EstimatorSpec: a train op in
  TRAIN mode, metric objects in EVAL mode, or extra prediction tensors in
  PREDICT mode.

  Args:
    features: A dict mapping from names to tensors, denoting the features.
    labels: A dict mapping from names to tensors, denoting the labels. Not
      read here; losses are computed by `model.build_losses` from `features`
      and `predictions`.
    mode: Mode parameter required by the estimator.
    params: Additional parameters used for creating the model. The optional
      key 'create_additional_predictions' may hold a callable that receives
      the default graph; its result is merged into the predictions via
      `dict.update` in PREDICT mode.

  Returns:
    An instance of EstimatorSpec.

  Raises:
    TypeError: If params['create_additional_predictions'] is set (truthy) but
      is not callable.
  """
  is_training = (tf.estimator.ModeKeys.TRAIN == mode)
  logging.info("Current mode is %s, is_training=%s", mode, is_training)

  model = builder.build(pipeline_proto.model, is_training)

  # Predict results.
  predictions = model.predict(features)

  # Compute losses. Note: variables created in build_losses are not trainable.
  losses = model.build_losses(features, predictions)
  for name, loss in losses.items():
    tf.compat.v1.summary.scalar('losses/' + name, loss)
    # Register in the v1 loss collection that get_total_loss() sums over.
    tf.compat.v1.losses.add_loss(loss)
  for loss in tf.compat.v1.losses.get_regularization_losses():
    # Use the v1 summary API like every other summary here, so merge_all()
    # below collects it even where `tf.summary` resolves to the TF2 API.
    tf.compat.v1.summary.scalar(
        "regularization/" + '/'.join(loss.op.name.split('/')[:2]), loss)
  total_loss = tf.compat.v1.losses.get_total_loss(
      add_regularization_losses=True)

  # Get variables_to_train.
  variables_to_train = model.get_variables_to_train()
  scaffold = model.get_scaffold()

  train_op = None
  eval_metric_ops = None

  if is_training:
    _summarize_variables(tf.compat.v1.global_variables())
    global_step = tf.compat.v1.train.get_global_step()

    # Set learning rate.
    train_config = pipeline_proto.train_config
    lr_schedule_fn = learning_rate_schedule.create_learning_rate_schedule(
        train_config.learning_rate_schedule)
    learning_rate = lr_schedule_fn(global_step)
    tf.compat.v1.summary.scalar('losses/learning_rate', learning_rate)

    # Use optimizer to minimize loss.
    optimizer = optimization.create_optimizer(
        train_config.optimizer, learning_rate=learning_rate)

    def transform_grads_fn(grads):
      """Clips gradients by norm when `max_gradient_norm` is configured."""
      if train_config.HasField('max_gradient_norm'):
        grads = tf.contrib.training.clip_gradient_norms(
            grads, max_norm=train_config.max_gradient_norm)
      return grads

    train_op = tf.contrib.training.create_train_op(
        total_loss,
        optimizer,
        variables_to_train=variables_to_train,
        transform_grads_fn=transform_grads_fn,
        summarize_gradients=True)

  elif tf.estimator.ModeKeys.EVAL == mode:
    eval_metric_ops = model.build_metrics(features, predictions)
    for name, loss in losses.items():
      # Report each loss as a running mean over the evaluation batches.
      loss_metric = tf.keras.metrics.Mean()
      loss_metric.update_state(loss)
      eval_metric_ops['losses/' + name] = loss_metric

  elif tf.estimator.ModeKeys.PREDICT == mode:
    # Add input tensors to the predictions.
    predictions.update(features)

    # Create additional tensors if specified.
    create_additional_predictions = params.get(
        'create_additional_predictions', None)
    if create_additional_predictions:
      # Explicit check rather than `assert`, which is stripped under -O.
      if not callable(create_additional_predictions):
        raise TypeError(
            "params['create_additional_predictions'] must be callable, "
            "got %r" % (create_additional_predictions,))
      predictions.update(
          create_additional_predictions(tf.compat.v1.get_default_graph()))

  # Merge summaries. merge_all() returns None when the graph holds no
  # summaries, and SummarySaverHook raises ValueError when given neither a
  # summary_op nor a scaffold, so only attach the hook when there is one.
  summary_op = tf.compat.v1.summary.merge_all()
  training_hooks = []
  if summary_op is not None:
    training_hooks.append(
        tf.estimator.SummarySaverHook(
            summary_op=summary_op,
            save_steps=pipeline_proto.train_config.save_summary_steps))

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=total_loss,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops,
      training_hooks=training_hooks,
      scaffold=scaffold)