Example #1
0
 def testNonProtoFails(self):
   """Constructing HParams from a non-proto `hparam_def` must fail.

   Covers int, float, str and list inputs; each must raise AssertionError.
   """
   # The deprecated assertRaisesRegexp alias was removed in Python 3.12, and an
   # empty pattern matches any message, so plain assertRaises is equivalent.
   for bad_def in (1, 1.0, 'hello', [1, 2, 3]):
     with self.assertRaises(AssertionError):
       hparam.HParams(hparam_def=bad_def)
Example #2
0
  def testSetFromMap(self):
    """override_from_dict replaces only the keys present in the mapping."""
    params = hparam.HParams(a=1, b=2.0, c='tanh')
    params.override_from_dict({'a': -2, 'c': 'identity'})
    expected = {'a': -2, 'b': 2.0, 'c': 'identity'}
    self.assertDictEqual(expected, params.values())

    # A list-valued hparam can be overridden by a list of a different length.
    params = hparam.HParams(x=1, b=2.0, d=[0.5])
    params.override_from_dict({'d': [0.1, 0.2, 0.3]})
    expected = {'x': 1, 'b': 2.0, 'd': [0.1, 0.2, 0.3]}
    self.assertDictEqual(expected, params.values())
Example #3
0
 def testStr(self):
   """str(HParams) can be parsed back into an HParams with the same schema."""
   source = hparam.HParams(a=1, b=[2.0, 3.0], c='relu6')
   serialized = str(source)
   # Declare the same names/types on a fresh object, with different values.
   target = hparam.HParams()
   for name, value in (('a', 4), ('b', [5.0, 6.0]), ('c', 'relu10')):
     target.add_hparam(name, value)
   target.parse(serialized)
   # Every hparam must come back with the source's value.
   for name in ('a', 'b', 'c'):
     self.assertEqual(getattr(target, name), getattr(source, name))
Example #4
0
  def testSetHParam(self):
    """set_hparam updates both values() and attribute access."""
    hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d=True)
    initial = {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': True}
    self.assertDictEqual(initial, hparams.values())
    for name in ('aaa', 'b', 'c_c'):
      self.assertEqual(initial[name], getattr(hparams, name))

    updated = {'aaa': 12, 'b': 3.0, 'c_c': 'relu4', 'd': False}
    for name in ('aaa', 'b', 'c_c', 'd'):
      hparams.set_hparam(name, updated[name])
    self.assertDictEqual(updated, hparams.values())
    for name in ('aaa', 'b', 'c_c'):
      self.assertEqual(updated[name], getattr(hparams, name))
Example #5
0
 def testEmpty(self):
   """An empty HParams parses '' as a no-op but rejects unknown names."""
   hparams = hparam.HParams()
   self.assertDictEqual({}, hparams.values())
   hparams.parse('')
   self.assertDictEqual({}, hparams.values())
   # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
   # which was removed in Python 3.12.
   with self.assertRaisesRegex(ValueError, 'Unknown hyperparameter'):
     hparams.parse('xyz=123')
Example #6
0
  def testBoolParsing(self):
    """Bool hparams parse common spellings and survive a proto round-trip."""
    truthy = ['True', 'true', '1']
    for value in ('true', 'false', 'True', 'False', '1', '0'):
      expected = value in truthy
      for initial in (False, True):
        hparams = hparam.HParams(use_gpu=initial)
        hparams.parse('use_gpu=' + value)
        self.assertEqual(hparams.use_gpu, expected)

        # Export to proto and import it again.
        restored = hparam.HParams(hparam_def=hparams.to_proto())
        self.assertEqual(hparams.use_gpu, restored.use_gpu)
        # 0 == False and 1 == True in Python, so assertEqual alone cannot
        # detect an int leaking through; check the exact type too.
        self.assertEqual(bool, type(restored.use_gpu))
Example #7
0
 def testLists(self):
   """List-valued hparams parse, validate element types, and round-trip."""
   hparams = hparam.HParams(aaa=[1], b=[2.0, 3.0], c_c=['relu6'])
   self.assertDictEqual({
       'aaa': [1],
       'b': [2.0, 3.0],
       'c_c': ['relu6']
   }, hparams.values())
   self.assertEqual([1], hparams.aaa)
   self.assertEqual([2.0, 3.0], hparams.b)
   self.assertEqual(['relu6'], hparams.c_c)
   # Lists of any length, including empty ones, replace the previous value.
   hparams.parse('aaa=[12]')
   self.assertEqual([12], hparams.aaa)
   hparams.parse('aaa=[12,34,56]')
   self.assertEqual([12, 34, 56], hparams.aaa)
   hparams.parse('c_c=[relu4,relu12],b=[1.0]')
   self.assertEqual(['relu4', 'relu12'], hparams.c_c)
   self.assertEqual([1.0], hparams.b)
   hparams.parse('c_c=[],aaa=[-34]')
   self.assertEqual([-34], hparams.aaa)
   self.assertEqual([], hparams.c_c)
   # Quote characters inside str elements are kept verbatim; '+3' parses as 3.
   hparams.parse('c_c=[_12,3\'4"],aaa=[+3]')
   self.assertEqual([3], hparams.aaa)
   self.assertEqual(['_12', '3\'4"'], hparams.c_c)
   # Invalid inputs. assertRaisesRegex replaces the deprecated
   # assertRaisesRegexp alias, which was removed in Python 3.12.
   invalid = (
       ('x=[123]', 'Unknown hyperparameter'),
       ('aaa=[poipoi]', 'Could not parse'),
       ('aaa=[1.0]', 'Could not parse'),  # float element for an int list.
       ('b=[12x]', 'Could not parse'),
       ('b=[relu]', 'Could not parse'),
       ('aaa=123', 'Must pass a list'),  # scalar for a list hparam.
   )
   for text, pattern in invalid:
     with self.assertRaisesRegex(ValueError, pattern):
       hparams.parse(text)
   # Export to proto and import it again; the failed parses above must not
   # have changed any value.
   hparams2 = hparam.HParams(hparam_def=hparams.to_proto())
   self.assertEqual([3], hparams2.aaa)
   self.assertEqual([1.0], hparams2.b)
   self.assertEqual(['_12', '3\'4"'], hparams2.c_c)
Example #8
0
  def testSetHParamExactTypeMatch(self):
    """A value of the hparam's exact (custom) class is stored as given."""

    class DummyContext(object):

      def __init__(self, a, b=0):
        self.a, self.b = a, b

    hparams = hparam.HParams(x=DummyContext(a=100, b=100))
    # The new value has the same class as the old one, so it should be
    # assigned directly rather than cast.
    hparams.set_hparam('x', DummyContext(a=100, b=100))
    stored = hparams.x
    self.assertEqual(stored.a, 100)
    self.assertEqual(stored.b, 100)
Example #9
0
  def testJson(self):
    """HParams round-trip through JSON and accept JSON formatting options."""
    hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d=True)
    before = {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': True}
    self.assertDictEqual(before, hparams.values())
    for key in ('aaa', 'b', 'c_c'):
      self.assertEqual(before[key], getattr(hparams, key))

    hparams.parse_json('{"aaa": 12, "b": 3.0, "c_c": "relu4", "d": false}')
    after = {'aaa': 12, 'b': 3.0, 'c_c': 'relu4', 'd': False}
    self.assertDictEqual(after, hparams.values())
    for key in ('aaa', 'b', 'c_c'):
      self.assertEqual(after[key], getattr(hparams, key))

    # Serialized JSON overwrites every value of a compatible HParams.
    serialized = hparams.to_json()
    restored = hparam.HParams(aaa=10, b=20.0, c_c='hello', d=False)
    restored.parse_json(serialized)
    for key in ('aaa', 'b', 'c_c', 'd'):
      self.assertEqual(after[key], getattr(restored, key))

    # to_json honors indent, separators and sort_keys options.
    single = hparam.HParams(aaa=123)
    self.assertEqual('{"aaa": 123}', single.to_json())
    self.assertEqual('{\n  "aaa": 123\n}', single.to_json(indent=2))
    self.assertEqual('{"aaa"=123}', single.to_json(separators=(';', '=')))

    several = hparam.HParams(aaa=123, b='hello', c_c=False)
    self.assertEqual('{"aaa": 123, "b": "hello", "c_c": false}',
                     several.to_json(sort_keys=True))
Example #10
0
 def testWithPeriodInVariableName(self):
   """Hparam names containing periods can be added, parsed and read back.

   Such names are not valid Python identifiers, so values are read with
   getattr rather than attribute syntax.
   """
   hparams = hparam.HParams()
   hparams.add_hparam(name='a.b', value=0.0)
   hparams.parse('a.b=1.0')
   self.assertEqual(1.0, getattr(hparams, 'a.b'))
   hparams.add_hparam(name='c.d', value=0.0)
   # A float hparam still rejects non-numeric text. assertRaisesRegex replaces
   # the deprecated assertRaisesRegexp alias, which was removed in Python 3.12.
   with self.assertRaisesRegex(ValueError, 'Could not parse'):
     hparams.parse('c.d=abc')
   hparams.add_hparam(name='e.f', value='')
   hparams.parse('e.f=abc')
   self.assertEqual('abc', getattr(hparams, 'e.f'))
   # Consecutive and trailing periods are accepted as well.
   hparams.add_hparam(name='d..', value=0.0)
   hparams.parse('d..=10.0')
   self.assertEqual(10.0, getattr(hparams, 'd..'))
Example #11
0
  def testDel(self):
    """Deleting an hparam allows re-declaring it with a different type."""
    hparams = hparam.HParams(aaa=1, b=2.0)

    # While 'aaa' exists as an int hparam, both set_hparam and add_hparam
    # with a str value raise ValueError.
    for mutate in (hparams.set_hparam, hparams.add_hparam):
      with self.assertRaises(ValueError):
        mutate('aaa', 'will fail')

    # Once deleted, the name can be added back as a str hparam.
    hparams.del_hparam('aaa')
    hparams.add_hparam('aaa', 'will work')
    self.assertEqual('will work', hparams.get('aaa'))

    hparams.set_hparam('aaa', 'still works')
    self.assertEqual('still works', hparams.get('aaa'))
Example #12
0
  def testSetHParamTypeMismatch(self):
    """set_hparam rejects values whose type does not fit the hparam."""
    hparams = hparam.HParams(
        int_=1, str_='str', bool_=True, float_=1.1, list_int=[1, 2], none=None)

    # Each (name, value) pair below must raise ValueError. Note that
    # bool-like strings are not converted to bool; they are rejected.
    rejected = [
        ('str_', 2.2),
        ('int_', False),
        ('bool_', 1),
        ('bool_', 'true'),
        ('bool_', 'True'),
        ('bool_', 'false'),
        ('bool_', 'False'),
        ('bool_', '0'),
        ('bool_', '1'),
        ('int_', 2.2),
        ('list_int', [2, 3.3]),
        ('int_', '2'),
    ]
    for name, value in rejected:
      with self.assertRaises(ValueError):
        hparams.set_hparam(name, value)

    # An int is accepted for a float hparam.
    hparams.set_hparam('float_', 1)

    # A hparam declared with None accepts a str here and stores it unchanged.
    hparams.set_hparam('none', '1')
    self.assertEqual('1', hparams.none)
Example #13
0
  def testGet(self):
    """get() returns stored values, type-checks defaults, and falls back."""
    hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d=True, e=[5.0, 6.0])

    # Existing parameters, no explicit default.
    no_default = (('aaa', 1), ('b', 2.0), ('c_c', 'relu6'), ('d', True))
    for name, expected in no_default:
      self.assertEqual(expected, hparams.get(name))
    self.assertEqual([5.0, 6.0], hparams.get('e', None))

    # Existing parameters with type-compatible defaults return stored values.
    compatible = [
        ('aaa', 2, 1),
        ('b', 3.0, 2.0),
        ('b', 3, 2.0),
        ('c_c', 'default', 'relu6'),
        ('d', True, True),
        ('e', [1.0, 2.0, 3.0], [5.0, 6.0]),
        ('e', [1, 2, 3], [5.0, 6.0]),
    ]
    for name, default, expected in compatible:
      self.assertEqual(expected, hparams.get(name, default))

    # Existing parameters with incompatible defaults raise ValueError.
    incompatible = [
        ('aaa', 2.0),
        ('b', False),
        ('c_c', [1, 2, 3]),
        ('d', 'relu'),
        ('e', 123.0),
        ('e', ['a', 'b', 'c']),
    ]
    for name, default in incompatible:
      with self.assertRaises(ValueError):
        hparams.get(name, default)

    # Unknown names return the supplied default.
    self.assertEqual(None, hparams.get('unknown'))
    self.assertEqual(123, hparams.get('unknown', 123))
    self.assertEqual([1, 2, 3], hparams.get('unknown', [1, 2, 3]))
Example #14
0
 def testSetHParamListNonListMismatch(self):
   """set_hparam refuses to switch an hparam between scalar and list."""
   hparams = hparam.HParams(a=1, b=[2.0, 3.0])
   # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
   # which was removed in Python 3.12.
   with self.assertRaisesRegex(ValueError, r'Must not pass a list'):
     hparams.set_hparam('a', [1.0])
   with self.assertRaisesRegex(ValueError, r'Must pass a list'):
     hparams.set_hparam('b', 1.0)
class XlaDecoratorTest(test.TestCase, parameterized.TestCase):
  """Tests for the `xla.estimator_model_fn` decorator.

  The compile / not-compile tests patch `xla.compile` with a mock, so they
  check only whether compilation is requested, not real XLA behavior.
  """

  def _assert_loss_and_train_op_equal(self, estimator_spec, loss):
    """Asserts that the spec's `loss` and `train_op` both evaluate to `loss`."""
    with self.test_session() as sess:
      self.assertEqual(sess.run(estimator_spec.loss), sess.run(loss))
      self.assertEqual(sess.run(estimator_spec.train_op), sess.run(loss))

  @parameterized.named_parameters(
      ('test_use_as_decorator', decorated_model_fn, None),
      ('test_use_as_function', xla.estimator_model_fn(_test_train_model_fn),
       None),
      ('test_use_tpu_false_hparams', decorated_model_fn,
       hparam.HParams(use_tpu=False)),
      ('test_use_tpu_false_dict_params', decorated_model_fn, {
          'use_tpu': False
      }),
  )
  def test_compile(self, model_fn, params):
    """Calls model_fn and verifies it is compiled."""
    with test.mock.patch.object(xla, 'compile') as compile_mock:
      mocked_loss = constant_op.constant(_EXPECTED_LOSS)
      # The mocked compile returns a one-element list holding the loss.
      compile_mock.return_value = [mocked_loss]

      features, labels = make_dummy_features_labels()
      # A None `params` is replaced by an empty dict.
      spec = model_fn(
          features=features, labels=labels, mode=_TRAIN, params=params or {})

      self.assertEqual(compile_mock.call_count, 1)
      self.assertEqual(spec.mode, _TRAIN)
      self._assert_loss_and_train_op_equal(spec, mocked_loss)

  @parameterized.named_parameters(
      ('test_use_tpu_true_hparams', decorated_model_fn,
       hparam.HParams(use_tpu=True)),
      ('test_use_tpu_true_dict_params', decorated_model_fn, {
          'use_tpu': True
      }),
  )
  def test_not_compile(self, model_fn, params):
    """Calls model_fn and verifies it is NOT compiled."""
    with test.mock.patch.object(xla, 'compile') as compile_mock:
      mocked_loss = constant_op.constant(_EXPECTED_LOSS)
      compile_mock.return_value = [mocked_loss]

      features, labels = make_dummy_features_labels()
      spec = model_fn(
          features=features, labels=labels, mode=_TRAIN, params=params or {})

      # With use_tpu=True the decorator must not request compilation.
      compile_mock.assert_not_called()
      self.assertEqual(spec.mode, _TRAIN)
      self._assert_loss_and_train_op_equal(spec, mocked_loss)

  def test_model_with_summary(self):
    """Tests that summary ops are disabled."""

    @xla.estimator_model_fn
    def model_fn_with_summary(features, labels, mode, params):
      del features, labels, params  # Unused.
      loss = constant_op.constant(_EXPECTED_LOSS)
      # NOTE(review): summary.image is given a scalar here; presumably this
      # only works because the decorator disables summary ops -- confirm.
      summary.scalar('loss_scalar_summary', loss)
      summary.histogram('loss_histogram_summary', loss)
      summary.image('loss_image_summary', loss)
      return model_fn_lib.EstimatorSpec(
          mode=mode, loss=loss, train_op=array_ops.identity(loss))

    features, labels = make_dummy_features_labels()
    spec = model_fn_with_summary(
        features=features, labels=labels, mode=_TRAIN, params={})

    with self.test_session() as sess:
      self.assertEqual(sess.run(spec.loss), _EXPECTED_LOSS)
Example #16
0
 def testBoolParsingFail(self):
   """An unrecognized bool spelling raises an error that names the hparam."""
   hparams = hparam.HParams(use_gpu=True)
   # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
   # which was removed in Python 3.12.
   with self.assertRaisesRegex(ValueError, r'Could not parse.*use_gpu'):
     hparams.parse('use_gpu=yep')
Example #17
0
 def testContains(self):
   """The `in` operator reports whether an hparam name is defined."""
   hparams = hparam.HParams(foo=1)
   self.assertIn('foo', hparams)
   self.assertNotIn('bar', hparams)
Example #18
0
 def testSomeValues(self):
   """Scalar hparams parse from strings, reject bad input, and round-trip."""
   hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d='/a/b=c/d')
   self.assertDictEqual({
       'aaa': 1,
       'b': 2.0,
       'c_c': 'relu6',
       'd': '/a/b=c/d'
   }, hparams.values())
   expected_str = ('HParams([(\'aaa\', 1), (\'b\', 2.0), (\'c_c\', \'relu6\'),'
                   ' (\'d\', \'/a/b=c/d\')])')
   self.assertEqual(expected_str, repr(hparams))
   self.assertEqual(1, hparams.aaa)
   self.assertEqual(2.0, hparams.b)
   self.assertEqual('relu6', hparams.c_c)
   self.assertEqual('/a/b=c/d', hparams.d)
   hparams.parse('aaa=12')
   self.assertDictEqual({
       'aaa': 12,
       'b': 2.0,
       'c_c': 'relu6',
       'd': '/a/b=c/d'
   }, hparams.values())
   self.assertEqual(12, hparams.aaa)
   self.assertEqual(2.0, hparams.b)
   self.assertEqual('relu6', hparams.c_c)
   self.assertEqual('/a/b=c/d', hparams.d)
   # A space after the comma and scientific notation are both accepted.
   hparams.parse('c_c=relu4, b=-2.0e10')
   self.assertDictEqual({
       'aaa': 12,
       'b': -2.0e10,
       'c_c': 'relu4',
       'd': '/a/b=c/d'
   }, hparams.values())
   self.assertEqual(12, hparams.aaa)
   self.assertEqual(-2.0e10, hparams.b)
   self.assertEqual('relu4', hparams.c_c)
   self.assertEqual('/a/b=c/d', hparams.d)
   # An empty value sets a str hparam to '', an int literal is accepted for a
   # float hparam, and a trailing comma is ignored.
   hparams.parse('c_c=,b=0,')
   self.assertDictEqual({
       'aaa': 12,
       'b': 0,
       'c_c': '',
       'd': '/a/b=c/d'
   }, hparams.values())
   self.assertEqual(12, hparams.aaa)
   self.assertEqual(0.0, hparams.b)
   self.assertEqual('', hparams.c_c)
   self.assertEqual('/a/b=c/d', hparams.d)
   # Quote characters in str values are kept verbatim; '+2' parses as 2.0.
   hparams.parse('c_c=2.3",b=+2,')
   self.assertEqual(2.0, hparams.b)
   self.assertEqual('2.3"', hparams.c_c)
   # Values may themselves contain '/' and '='.
   hparams.parse('d=/a/b/c/d,aaa=11,')
   self.assertEqual(11, hparams.aaa)
   self.assertEqual(2.0, hparams.b)
   self.assertEqual('2.3"', hparams.c_c)
   self.assertEqual('/a/b/c/d', hparams.d)
   hparams.parse('b=1.5,d=/a=b/c/d,aaa=10,')
   self.assertEqual(10, hparams.aaa)
   self.assertEqual(1.5, hparams.b)
   self.assertEqual('2.3"', hparams.c_c)
   self.assertEqual('/a=b/c/d', hparams.d)
   # Invalid inputs. assertRaisesRegex replaces the deprecated
   # assertRaisesRegexp alias, which was removed in Python 3.12.
   invalid = (
       ('x=123', 'Unknown hyperparameter'),
       ('aaa=poipoi', 'Could not parse'),
       ('aaa=1.0', 'Could not parse'),  # float for an int hparam.
       ('b=12x', 'Could not parse'),
       ('b=relu', 'Could not parse'),
       ('aaa=[123]', 'Must not pass a list'),  # list for a scalar hparam.
   )
   for text, pattern in invalid:
     with self.assertRaisesRegex(ValueError, pattern):
       hparams.parse(text)
   # Failed parses must leave every value unchanged.
   self.assertEqual(10, hparams.aaa)
   self.assertEqual(1.5, hparams.b)
   self.assertEqual('2.3"', hparams.c_c)
   self.assertEqual('/a=b/c/d', hparams.d)
   # Export to proto and import it again; all hparams must be restored.
   hparams2 = hparam.HParams(hparam_def=hparams.to_proto())
   self.assertEqual(10, hparams2.aaa)
   self.assertEqual(1.5, hparams2.b)
   self.assertEqual('2.3"', hparams2.c_c)
   self.assertEqual('/a=b/c/d', hparams2.d)
Example #19
0
def get_pruning_hparams():
  """Get a tf.HParams object with the default values for the hyperparameters.

  The hyperparameters are:

    name: string
      name of the pruning specification. Used for adding summaries and ops under
      a common tensorflow name_scope
    begin_pruning_step: integer
      the global step at which to begin pruning
    end_pruning_step: integer
      the global step at which to terminate pruning. Defaults to -1 implying
      that pruning continues till the training stops
    weight_sparsity_map: list of strings
      comma separated list of {weight_variable_name:target sparsity} or
      {regex:target sparsity} pairs.
      For layers/weights not in this list, sparsity as specified by the
      target_sparsity hyperparameter is used.
      E.g. [conv1:0.9,conv2/kernel:0.8]
    block_dims_map: list of strings
      comma separated list of {weight variable name:block_height x block_width}
      or {regex:block_height x block_width} pairs. For layers/weights not in
      this list, the block dims given by the block_height and block_width
      hyperparameters are used. E.g. [dense1:4x4,dense2:1x16,dense3:1x1]
    threshold_decay: float
      the decay factor to use for exponential decay of the thresholds
    pruning_frequency: integer
      How often should the masks be updated? (in # of global_steps)
    nbins: integer
      number of bins to use for histogram computation
    block_height: integer
      number of rows in a block (defaults to 1), can be -1 in which
      case it is set to the size of the corresponding weight tensor.
    block_width: integer
      number of cols in a block (defaults to 1), can be -1 in which
      case it is set to the size of the corresponding weight tensor.
    block_pooling_function: string
      Whether to perform average (AVG) or max (MAX) pooling in the block
      (default: AVG)
    initial_sparsity: float
      initial sparsity value
    target_sparsity: float
      target sparsity value
    sparsity_function_begin_step: integer
      the global step at which the gradual sparsity function begins to
      take effect
    sparsity_function_end_step: integer
      the global step used as the end point for the gradual sparsity function
    sparsity_function_exponent: float
      exponent = 1 is linearly varying sparsity between initial and final.
      exponent > 1 varies more slowly towards the end than the beginning
    use_tpu: boolean
      Indicates whether to use TPU (default: False)

  We use the following sparsity function:

    num_steps = (sparsity_function_end_step -
                 sparsity_function_begin_step)/pruning_frequency
    sparsity(step) = (initial_sparsity - target_sparsity)*
                     [1-step/(num_steps -1)]**exponent + target_sparsity

  Args:
    None

  Returns:
    tf.HParams object initialized to default values
  """
  return hparam.HParams(
      name='model_pruning',
      begin_pruning_step=0,
      end_pruning_step=-1,
      # Single-empty-string defaults: presumably so HParams records these as
      # list-of-str hparams that can later be overridden -- confirm.
      weight_sparsity_map=[''],
      block_dims_map=[''],
      threshold_decay=0.0,
      pruning_frequency=10,
      nbins=256,
      block_height=1,
      block_width=1,
      block_pooling_function='AVG',
      initial_sparsity=0.0,
      target_sparsity=0.5,
      sparsity_function_begin_step=0,
      sparsity_function_end_step=100,
      sparsity_function_exponent=3.0,
      use_tpu=False)
Example #20
0
from astronet.contrib.learn.python.learn import evaluable  # pylint: disable=g-import-not-at-top
from astronet.contrib.learn.python.learn import experiment
from astronet.contrib.learn.python.learn import learn_runner
from astronet.contrib.learn.python.learn import trainable

from astronet.contrib.learn.python.learn.estimators import run_config as run_config_lib
from astronet.contrib.training.python.training import hparam as hparam_lib
from tensorflow.python.estimator import run_config as core_run_config_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging

patch = test.mock.patch

# Shared fixtures.
_MODIR_DIR = "/tmp"
_HPARAMS = hparam_lib.HParams(learning_rate=0.01)

# Expected error-message patterns. Some contain regex syntax (e.g. ".*"), so
# they are presumably matched with a regex assertion -- confirm at call sites.
_MUST_SPECIFY_OUTPUT_DIR_MSG = "Must specify an output directory"
_MISSING_MODEL_DIR_ERR_MSG = (
    "Must specify a model directory `model_dir` in "
    "`run_config`.")
_EXP_NOT_CALLABLE_MSG = "Experiment builder .* is not callable"
_INVALID_HPARAMS_ERR_MSG = "`hparams` must be `HParams` instance"
_NOT_EXP_TYPE_MSG = "Experiment builder did not return an Experiment"
_NON_EXIST_TASK_MSG = "Schedule references non-existent task"
_NON_CALLABLE_MSG = "Schedule references non-callable member"
_MUST_SPECIFY_OUTPUT_DIR_OR_CONFIG_MSG = (
    "Must set value for `output_dir` "
    "or `run_config`")
_HPARAMS_CANNOT_BE_SET_FOR_OUTPUT_DIR_MSG = (
    "Must set `hparams` as None for `experiment_fn` "
    "with `output_dir`.")
_CANNOT_SET_BOTH_OUTPUT_DIR_AND_CONFIG_MSG = (
    "Cannot provide both `output_dir` "
    "and `run_config`")
_INVALID_RUN_CONFIG_TYPE_MSG = (