# Example #1
# 0
def create_estimator(params=None):
    """Builds the glyphs DNNClassifier estimator.

    Args:
      params: Optional hyperparameters, defaulting to command-line values.

    Returns:
      A DNNClassifier instance.
    """
    # Fall back to command-line hyperparameters when none (or an empty
    # mapping) are supplied.
    if not params:
        params = get_flag_params()

    # With no hidden layers the model is logistic regression, where a
    # sigmoid activation is the expected choice — warn on anything else.
    if params['activation_fn'] != 'sigmoid' and not params['layer_dims']:
        tf.logging.warning(
            'activation_fn should be sigmoid for logistic regression. Got: %s',
            params['activation_fn'])

    # Resolve the activation function by name from tf.nn (e.g. 'relu').
    optimizer = tf.train.FtrlOptimizer(
        learning_rate=params['learning_rate'],
        l1_regularization_strength=params['l1_regularization_strength'],
        l2_regularization_strength=params['l2_regularization_strength'])
    estimator = tf.estimator.DNNClassifier(
        params['layer_dims'],
        feature_columns=[glyph_patches.create_patch_feature_column()],
        weight_column=glyph_patches.WEIGHT_COLUMN_NAME,
        n_classes=len(musicscore_pb2.Glyph.Type.keys()),
        optimizer=optimizer,
        activation_fn=getattr(tf.nn, params['activation_fn']),
        dropout=FLAGS.dropout,
        model_dir=glyph_patches.FLAGS.model_dir)
    # Attach custom evaluation metrics, then record the hyperparameters in
    # the graph so they are saved alongside the model.
    with_metrics = tf.contrib.estimator.add_metrics(estimator, _custom_metrics)
    return hyperparameters.estimator_with_saved_params(with_metrics, params)
 def testSimpleModel(self):
   """estimator_with_saved_params exposes hyperparameters in the graph."""
   learning_rate = np.float32(0.123)
   params = {'learning_rate': learning_rate}
   estimator = hyperparameters.estimator_with_saved_params(
       tf.estimator.DNNClassifier(
           hidden_units=[10],
           feature_columns=[tf.feature_column.numeric_column('feature')]),
       params)
   with self.test_session():
     # Build the estimator model.
     # Use the canonical ModeKeys constant ('train') rather than the literal
     # 'TRAIN': the canned estimator's model_fn compares mode against the
     # tf.estimator.ModeKeys constants, so the uppercase string would not
     # match any valid mode.
     estimator.model_fn(
         features={'feature': tf.placeholder(tf.float32)},
         labels=tf.placeholder(tf.float32),
         mode=tf.estimator.ModeKeys.TRAIN,
         config=None)
     # We should be able to pull hyperparameters out of the TensorFlow graph.
     # The entire graph will also be written to the saved model in training.
     self.assertEqual(
         learning_rate,
         tf.get_default_graph().get_tensor_by_name(
             'params/learning_rate:0').eval())