def testNonProtoFails(self):
  """Constructing HParams with a non-proto hparam_def must fail."""
  for bad_def in (1, 1.0, 'hello', [1, 2, 3]):
    with self.assertRaisesRegexp(AssertionError, ''):
      hparam.HParams(hparam_def=bad_def)
def testSetFromMap(self):
  """override_from_dict replaces the listed values and keeps the rest."""
  params = hparam.HParams(a=1, b=2.0, c='tanh')
  params.override_from_dict({'a': -2, 'c': 'identity'})
  self.assertDictEqual({'a': -2, 'b': 2.0, 'c': 'identity'}, params.values())

  params = hparam.HParams(x=1, b=2.0, d=[0.5])
  params.override_from_dict({'d': [0.1, 0.2, 0.3]})
  self.assertDictEqual({'b': 2.0, 'd': [0.1, 0.2, 0.3], 'x': 1},
                       params.values())
def testStr(self):
  """str() output of one HParams object can be parsed by another."""
  source = hparam.HParams(a=1, b=[2.0, 3.0], c='relu6')
  serialized = str(source)
  # Build a target with the same schema but different values.
  target = hparam.HParams()
  target.add_hparam('a', 4)
  target.add_hparam('b', [5.0, 6.0])
  target.add_hparam('c', 'relu10')
  # Parsing the serialized form restores every value.
  target.parse(serialized)
  self.assertEqual(target.a, source.a)
  self.assertEqual(target.b, source.b)
  self.assertEqual(target.c, source.c)
def testSetHParam(self):
  """set_hparam overwrites values while preserving their types."""
  params = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d=True)
  self.assertDictEqual(
      {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': True}, params.values())
  self.assertEqual(1, params.aaa)
  self.assertEqual(2.0, params.b)
  self.assertEqual('relu6', params.c_c)

  # Overwrite every value with a same-typed replacement.
  for name, value in (('aaa', 12), ('b', 3.0), ('c_c', 'relu4'),
                      ('d', False)):
    params.set_hparam(name, value)

  self.assertDictEqual(
      {'aaa': 12, 'b': 3.0, 'c_c': 'relu4', 'd': False}, params.values())
  self.assertEqual(12, params.aaa)
  self.assertEqual(3.0, params.b)
  self.assertEqual('relu4', params.c_c)
def testEmpty(self):
  """An empty HParams object stays empty and rejects unknown names."""
  params = hparam.HParams()
  self.assertDictEqual({}, params.values())
  # Parsing the empty string is a no-op.
  params.parse('')
  self.assertDictEqual({}, params.values())
  with self.assertRaisesRegexp(ValueError, 'Unknown hyperparameter'):
    params.parse('xyz=123')
def testBoolParsing(self):
  """Booleans parse from several spellings and round-trip via proto."""
  truthy = ('True', 'true', '1')
  for text in ('true', 'false', 'True', 'False', '1', '0'):
    for start in (False, True):
      params = hparam.HParams(use_gpu=start)
      params.parse('use_gpu=' + text)
      self.assertEqual(params.use_gpu, text in truthy)

      # Round-trip through the proto representation.
      restored = hparam.HParams(hparam_def=params.to_proto())
      self.assertEqual(params.use_gpu, restored.use_gpu)
      # assertEqual alone cannot distinguish bool from int because
      # (0 == False) and (1 == True) in Python, so pin the type too.
      self.assertEqual(bool, type(restored.use_gpu))
def testLists(self):
  """List-valued hparams parse, type-check, and round-trip via proto."""
  params = hparam.HParams(aaa=[1], b=[2.0, 3.0], c_c=['relu6'])
  self.assertDictEqual(
      {'aaa': [1], 'b': [2.0, 3.0], 'c_c': ['relu6']}, params.values())
  self.assertEqual([1], params.aaa)
  self.assertEqual([2.0, 3.0], params.b)
  self.assertEqual(['relu6'], params.c_c)

  params.parse('aaa=[12]')
  self.assertEqual([12], params.aaa)
  params.parse('aaa=[12,34,56]')
  self.assertEqual([12, 34, 56], params.aaa)
  params.parse('c_c=[relu4,relu12],b=[1.0]')
  self.assertEqual(['relu4', 'relu12'], params.c_c)
  self.assertEqual([1.0], params.b)
  params.parse('c_c=[],aaa=[-34]')
  self.assertEqual([-34], params.aaa)
  self.assertEqual([], params.c_c)
  params.parse('c_c=[_12,3\'4"],aaa=[+3]')
  self.assertEqual([3], params.aaa)
  self.assertEqual(['_12', '3\'4"'], params.c_c)

  # Malformed assignments are rejected with descriptive errors.
  for text, regexp in (('x=[123]', 'Unknown hyperparameter'),
                       ('aaa=[poipoi]', 'Could not parse'),
                       ('aaa=[1.0]', 'Could not parse'),
                       ('b=[12x]', 'Could not parse'),
                       ('b=[relu]', 'Could not parse'),
                       ('aaa=123', 'Must pass a list')):
    with self.assertRaisesRegexp(ValueError, regexp):
      params.parse(text)

  # Round-trip through the proto representation restores all values.
  restored = hparam.HParams(hparam_def=params.to_proto())
  self.assertEqual([3], restored.aaa)
  self.assertEqual([1.0], restored.b)
  self.assertEqual(['_12', '3\'4"'], restored.c_c)
def testSetHParamExactTypeMatch(self):
  """set_hparam stores a same-typed value directly, without casting."""

  class _Holder(object):

    def __init__(self, a, b=0):
      self.a = a
      self.b = b

  params = hparam.HParams(x=_Holder(a=100, b=100))
  # The new value is assigned as-is; attribute values survive intact.
  params.set_hparam('x', _Holder(a=100, b=100))
  self.assertEqual(params.x.a, 100)
  self.assertEqual(params.x.b, 100)
def testJson(self):
  """JSON serialization round-trips and honors json.dumps options."""
  params = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d=True)
  self.assertDictEqual(
      {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': True}, params.values())
  self.assertEqual(1, params.aaa)
  self.assertEqual(2.0, params.b)
  self.assertEqual('relu6', params.c_c)

  params.parse_json('{"aaa": 12, "b": 3.0, "c_c": "relu4", "d": false}')
  self.assertDictEqual(
      {'aaa': 12, 'b': 3.0, 'c_c': 'relu4', 'd': False}, params.values())
  self.assertEqual(12, params.aaa)
  self.assertEqual(3.0, params.b)
  self.assertEqual('relu4', params.c_c)

  # to_json() output can be parsed back by a second instance.
  other = hparam.HParams(aaa=10, b=20.0, c_c='hello', d=False)
  other.parse_json(params.to_json())
  self.assertEqual(12, other.aaa)
  self.assertEqual(3.0, other.b)
  self.assertEqual('relu4', other.c_c)
  self.assertEqual(False, other.d)

  # json.dumps keyword arguments are forwarded by to_json().
  single = hparam.HParams(aaa=123)
  self.assertEqual('{"aaa": 123}', single.to_json())
  self.assertEqual('{\n "aaa": 123\n}', single.to_json(indent=2))
  self.assertEqual('{"aaa"=123}', single.to_json(separators=(';', '=')))

  multi = hparam.HParams(aaa=123, b='hello', c_c=False)
  self.assertEqual('{"aaa": 123, "b": "hello", "c_c": false}',
                   multi.to_json(sort_keys=True))
def testWithPeriodInVariableName(self):
  """Names containing periods still parse and resolve via getattr."""
  params = hparam.HParams()
  params.add_hparam(name='a.b', value=0.0)
  params.parse('a.b=1.0')
  self.assertEqual(1.0, getattr(params, 'a.b'))

  # Type checking still applies to dotted names.
  params.add_hparam(name='c.d', value=0.0)
  with self.assertRaisesRegexp(ValueError, 'Could not parse'):
    params.parse('c.d=abc')

  params.add_hparam(name='e.f', value='')
  params.parse('e.f=abc')
  self.assertEqual('abc', getattr(params, 'e.f'))

  # Even trailing periods are accepted.
  params.add_hparam(name='d..', value=0.0)
  params.parse('d..=10.0')
  self.assertEqual(10.0, getattr(params, 'd..'))
def testDel(self):
  """del_hparam removes a name so it can be re-added with a new type."""
  params = hparam.HParams(aaa=1, b=2.0)

  # While 'aaa' exists it can be neither re-typed nor re-added.
  with self.assertRaises(ValueError):
    params.set_hparam('aaa', 'will fail')
  with self.assertRaises(ValueError):
    params.add_hparam('aaa', 'will fail')

  # After deletion the name is free again, now as a string hparam.
  params.del_hparam('aaa')
  params.add_hparam('aaa', 'will work')
  self.assertEqual('will work', params.get('aaa'))
  params.set_hparam('aaa', 'still works')
  self.assertEqual('still works', params.get('aaa'))
def testSetHParamTypeMismatch(self):
  """set_hparam rejects values whose type differs from the original."""
  params = hparam.HParams(
      int_=1, str_='str', bool_=True, float_=1.1, list_int=[1, 2], none=None)

  bad_assignments = [
      ('str_', 2.2),
      ('int_', False),
      ('bool_', 1),
      # Unfortunately there is no automagic conversion of bool-like
      # strings to bool.
      ('bool_', 'true'),
      ('bool_', 'True'),
      ('bool_', 'false'),
      ('bool_', 'False'),
      ('bool_', '0'),
      ('bool_', '1'),
      ('int_', 2.2),
      ('list_int', [2, 3.3]),
      ('int_', '2'),
  ]
  for name, value in bad_assignments:
    with self.assertRaises(ValueError):
      params.set_hparam(name, value)

  # Casting int to float is OK.
  params.set_hparam('float_', 1)

  # Getting stuck with NoneType :(
  params.set_hparam('none', '1')
  self.assertEqual('1', params.none)
def testGet(self):
  """get() returns stored values, type-checks defaults, or falls back."""
  params = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d=True, e=[5.0, 6.0])

  # Existing parameters with default=None.
  self.assertEqual(1, params.get('aaa'))
  self.assertEqual(2.0, params.get('b'))
  self.assertEqual('relu6', params.get('c_c'))
  self.assertEqual(True, params.get('d'))
  self.assertEqual([5.0, 6.0], params.get('e', None))

  # Existing parameters with compatible defaults.
  self.assertEqual(1, params.get('aaa', 2))
  self.assertEqual(2.0, params.get('b', 3.0))
  self.assertEqual(2.0, params.get('b', 3))
  self.assertEqual('relu6', params.get('c_c', 'default'))
  self.assertEqual(True, params.get('d', True))
  self.assertEqual([5.0, 6.0], params.get('e', [1.0, 2.0, 3.0]))
  self.assertEqual([5.0, 6.0], params.get('e', [1, 2, 3]))

  # Existing parameters with incompatible defaults raise.
  for name, default in (('aaa', 2.0), ('b', False), ('c_c', [1, 2, 3]),
                        ('d', 'relu'), ('e', 123.0), ('e', ['a', 'b', 'c'])):
    with self.assertRaises(ValueError):
      params.get(name, default)

  # Nonexistent parameters fall back to the supplied default.
  self.assertEqual(None, params.get('unknown'))
  self.assertEqual(123, params.get('unknown', 123))
  self.assertEqual([1, 2, 3], params.get('unknown', [1, 2, 3]))
def testSetHParamListNonListMismatch(self):
  """List/scalar mismatches in set_hparam raise ValueError."""
  params = hparam.HParams(a=1, b=[2.0, 3.0])
  with self.assertRaisesRegexp(ValueError, r'Must not pass a list'):
    params.set_hparam('a', [1.0])
  with self.assertRaisesRegexp(ValueError, r'Must pass a list'):
    params.set_hparam('b', 1.0)
class XlaDecoratorTest(test.TestCase, parameterized.TestCase):
  """Tests for the xla.estimator_model_fn decorator."""

  @parameterized.named_parameters(
      ('test_use_as_decorator', decorated_model_fn, None),
      ('test_use_as_function', xla.estimator_model_fn(_test_train_model_fn),
       None),
      ('test_use_tpu_false_hparams', decorated_model_fn,
       hparam.HParams(use_tpu=False)),
      ('test_use_tpu_false_dict_params', decorated_model_fn, {
          'use_tpu': False
      }),
  )
  def test_compile(self, model_fn, params):
    """Calls model_fn and verifies it is compiled."""
    with test.mock.patch.object(xla, 'compile') as compile_mock:
      loss_tensor = constant_op.constant(_EXPECTED_LOSS)
      compile_mock.return_value = [loss_tensor]

      features, labels = make_dummy_features_labels()
      spec = model_fn(
          features=features, labels=labels, mode=_TRAIN, params=params or {})

      # xla.compile must have been invoked exactly once.
      self.assertEqual(compile_mock.call_count, 1)
      self.assertEqual(spec.mode, _TRAIN)
      with self.test_session() as sess:
        self.assertEqual(sess.run(spec.loss), sess.run(loss_tensor))
        self.assertEqual(sess.run(spec.train_op), sess.run(loss_tensor))

  @parameterized.named_parameters(
      ('test_use_tpu_true_hparams', decorated_model_fn,
       hparam.HParams(use_tpu=True)),
      ('test_use_tpu_true_dict_params', decorated_model_fn, {
          'use_tpu': True
      }),
  )
  def test_not_compile(self, model_fn, params):
    """Calls model_fn and verifies it is NOT compiled."""
    with test.mock.patch.object(xla, 'compile') as compile_mock:
      loss_tensor = constant_op.constant(_EXPECTED_LOSS)
      compile_mock.return_value = [loss_tensor]

      features, labels = make_dummy_features_labels()
      spec = model_fn(
          features=features, labels=labels, mode=_TRAIN, params=params or {})

      # With use_tpu=True the decorator must leave the model uncompiled.
      compile_mock.assert_not_called()
      self.assertEqual(spec.mode, _TRAIN)
      with self.test_session() as sess:
        self.assertEqual(sess.run(spec.loss), sess.run(loss_tensor))
        self.assertEqual(sess.run(spec.train_op), sess.run(loss_tensor))

  def test_model_with_summary(self):
    """Tests that summary ops are disabled."""

    @xla.estimator_model_fn
    def model_fn_with_summary(features, labels, mode, params):
      del features, labels, params
      loss = constant_op.constant(_EXPECTED_LOSS)
      summary.scalar('loss_scalar_summary', loss)
      summary.histogram('loss_histogram_summary', loss)
      summary.image('loss_image_summary', loss)
      return model_fn_lib.EstimatorSpec(
          mode=mode, loss=loss, train_op=array_ops.identity(loss))

    features, labels = make_dummy_features_labels()
    spec = model_fn_with_summary(
        features=features, labels=labels, mode=_TRAIN, params={})

    with self.test_session() as sess:
      self.assertEqual(sess.run(spec.loss), _EXPECTED_LOSS)
def testBoolParsingFail(self):
  """Non-boolean text for a bool hparam is rejected with its name."""
  params = hparam.HParams(use_gpu=True)
  with self.assertRaisesRegexp(ValueError, r'Could not parse.*use_gpu'):
    params.parse('use_gpu=yep')
def testContains(self):
  """The `in` operator reports whether a hyperparameter is defined."""
  params = hparam.HParams(foo=1)
  self.assertIn('foo', params)
  self.assertNotIn('bar', params)
def testSomeValues(self):
  """End-to-end parse() behavior for scalar hparams, plus proto round-trip."""
  params = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d='/a/b=c/d')
  self.assertDictEqual(
      {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': '/a/b=c/d'}, params.values())
  expected_str = ('HParams([(\'aaa\', 1), (\'b\', 2.0), (\'c_c\', \'relu6\'),'
                  ' (\'d\', \'/a/b=c/d\')])')
  # repr() is stable across repeated calls.
  self.assertEqual(expected_str, repr(params))
  self.assertEqual(expected_str, repr(params))
  self.assertEqual(1, params.aaa)
  self.assertEqual(2.0, params.b)
  self.assertEqual('relu6', params.c_c)
  self.assertEqual('/a/b=c/d', params.d)

  params.parse('aaa=12')
  self.assertDictEqual(
      {'aaa': 12, 'b': 2.0, 'c_c': 'relu6', 'd': '/a/b=c/d'}, params.values())
  self.assertEqual(12, params.aaa)
  self.assertEqual(2.0, params.b)
  self.assertEqual('relu6', params.c_c)
  self.assertEqual('/a/b=c/d', params.d)

  params.parse('c_c=relu4, b=-2.0e10')
  self.assertDictEqual(
      {'aaa': 12, 'b': -2.0e10, 'c_c': 'relu4', 'd': '/a/b=c/d'},
      params.values())
  self.assertEqual(12, params.aaa)
  self.assertEqual(-2.0e10, params.b)
  self.assertEqual('relu4', params.c_c)
  self.assertEqual('/a/b=c/d', params.d)

  # Empty string values and trailing commas are tolerated.
  params.parse('c_c=,b=0,')
  self.assertDictEqual(
      {'aaa': 12, 'b': 0, 'c_c': '', 'd': '/a/b=c/d'}, params.values())
  self.assertEqual(12, params.aaa)
  self.assertEqual(0.0, params.b)
  self.assertEqual('', params.c_c)
  self.assertEqual('/a/b=c/d', params.d)

  params.parse('c_c=2.3",b=+2,')
  self.assertEqual(2.0, params.b)
  self.assertEqual('2.3"', params.c_c)

  params.parse('d=/a/b/c/d,aaa=11,')
  self.assertEqual(11, params.aaa)
  self.assertEqual(2.0, params.b)
  self.assertEqual('2.3"', params.c_c)
  self.assertEqual('/a/b/c/d', params.d)

  # '=' inside a string value is kept verbatim.
  params.parse('b=1.5,d=/a=b/c/d,aaa=10,')
  self.assertEqual(10, params.aaa)
  self.assertEqual(1.5, params.b)
  self.assertEqual('2.3"', params.c_c)
  self.assertEqual('/a=b/c/d', params.d)

  # Malformed inputs are rejected and leave the values untouched.
  for text, regexp in (('x=123', 'Unknown hyperparameter'),
                       ('aaa=poipoi', 'Could not parse'),
                       ('aaa=1.0', 'Could not parse'),
                       ('b=12x', 'Could not parse'),
                       ('b=relu', 'Could not parse'),
                       ('aaa=[123]', 'Must not pass a list')):
    with self.assertRaisesRegexp(ValueError, regexp):
      params.parse(text)
  self.assertEqual(10, params.aaa)
  self.assertEqual(1.5, params.b)
  self.assertEqual('2.3"', params.c_c)
  self.assertEqual('/a=b/c/d', params.d)

  # Round-trip through the proto representation restores everything.
  restored = hparam.HParams(hparam_def=params.to_proto())
  self.assertEqual(10, restored.aaa)
  self.assertEqual(1.5, restored.b)
  self.assertEqual('2.3"', restored.c_c)
  self.assertEqual('/a=b/c/d', restored.d)
def get_pruning_hparams():
  """Get a tf.HParams object with the default values for the hyperparameters.

    name: string
      name of the pruning specification. Used for adding summaries and ops
      under a common tensorflow name_scope
    begin_pruning_step: integer
      the global step at which to begin pruning
    end_pruning_step: integer
      the global step at which to terminate pruning. Defaults to -1 implying
      that pruning continues till the training stops
    weight_sparsity_map: list of strings
      comma separated list of {weight_variable_name:target sparsity} or
      {regex:target sparsity} pairs. For layers/weights not in this list,
      sparsity as specified by the target_sparsity hyperparameter is used.
      Eg. [conv1:0.9,conv2/kernel:0.8]
    block_dims_map: list of strings
      comma separated list of {weight variable name:block_height x
      block_width} or {regex:block_height x block_width} pairs. For
      layers/weights not in this list, block dims are specified by the
      block_height, block_width hyperparameters are used
      Eg. [dense1:4x4,dense2:1x16,dense3:1x1]
    threshold_decay: float
      the decay factor to use for exponential decay of the thresholds
    pruning_frequency: integer
      How often should the masks be updated? (in # of global_steps)
    nbins: integer
      number of bins to use for histogram computation
    block_height: integer
      number of rows in a block (defaults to 1), can be -1 in which case it is
      set to the size of the corresponding weight tensor.
    block_width: integer
      number of cols in a block (defaults to 1), can be -1 in which case it is
      set to the size of the corresponding weight tensor.
    block_pooling_function: string
      Whether to perform average (AVG) or max (MAX) pooling in the block
      (default: AVG)
    initial_sparsity: float
      initial sparsity value
    target_sparsity: float
      target sparsity value
    sparsity_function_begin_step: integer
      the global step at which the gradual sparsity function begins to take
      effect
    sparsity_function_end_step: integer
      the global step used as the end point for the gradual sparsity function
    sparsity_function_exponent: float
      exponent = 1 is linearly varying sparsity between initial and final.
      exponent > 1 varies more slowly towards the end than the beginning
    use_tpu: False
      Indicates whether to use TPU

    We use the following sparsity function:

    num_steps = (sparsity_function_end_step -
                 sparsity_function_begin_step)/pruning_frequency
    sparsity(step) = (initial_sparsity - target_sparsity)*
                     [1-step/(num_steps -1)]**exponent + target_sparsity

  Args:
    None

  Returns:
    tf.HParams object initialized to default values
  """
  return hparam.HParams(
      name='model_pruning',
      begin_pruning_step=0,
      end_pruning_step=-1,
      weight_sparsity_map=[''],
      block_dims_map=[''],
      threshold_decay=0.0,
      pruning_frequency=10,
      nbins=256,
      block_height=1,
      block_width=1,
      block_pooling_function='AVG',
      initial_sparsity=0.0,
      target_sparsity=0.5,
      sparsity_function_begin_step=0,
      sparsity_function_end_step=100,
      sparsity_function_exponent=3.0,
      use_tpu=False)
from astronet.contrib.learn.python.learn import evaluable # pylint: disable=g-import-not-at-top from astronet.contrib.learn.python.learn import experiment from astronet.contrib.learn.python.learn import learn_runner from astronet.contrib.learn.python.learn import trainable from astronet.contrib.learn.python.learn.estimators import run_config as run_config_lib from astronet.contrib.training.python.training import hparam as hparam_lib from tensorflow.python.estimator import run_config as core_run_config_lib from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging patch = test.mock.patch _MODIR_DIR = "/tmp" _HPARAMS = hparam_lib.HParams(learning_rate=0.01) _MUST_SPECIFY_OUTPUT_DIR_MSG = "Must specify an output directory" _MISSING_MODEL_DIR_ERR_MSG = ( "Must specify a model directory `model_dir` in `run_config`.") _EXP_NOT_CALLABLE_MSG = "Experiment builder .* is not callable" _INVALID_HPARAMS_ERR_MSG = "`hparams` must be `HParams` instance" _NOT_EXP_TYPE_MSG = "Experiment builder did not return an Experiment" _NON_EXIST_TASK_MSG = "Schedule references non-existent task" _NON_CALLABLE_MSG = "Schedule references non-callable member" _MUST_SPECIFY_OUTPUT_DIR_OR_CONFIG_MSG = ( "Must set value for `output_dir` or `run_config`") _HPARAMS_CANNOT_BE_SET_FOR_OUTPUT_DIR_MSG = ( "Must set `hparams` as None for `experiment_fn` with `output_dir`.") _CANNOT_SET_BOTH_OUTPUT_DIR_AND_CONFIG_MSG = ( "Cannot provide both `output_dir` and `run_config`") _INVALID_RUN_CONFIG_TYPE_MSG = (