def testBoolParsing(self):
    """parse() maps 'true'/'True'/'1' to True and 'false'/'False'/'0' to False.

    The result must not depend on the value the hyperparameter started with.
    """
    truthy_strings = ['True', 'true', '1']
    for raw in ('true', 'false', 'True', 'False', '1', '0'):
        expected = raw in truthy_strings
        for starting_value in (False, True):
            hp = HParams(use_gpu=starting_value)
            hp.parse('use_gpu=' + raw)
            self.assertEqual(hp.use_gpu, expected)
def test_type_kwargs(self):
    """Tests the special cases involving "type" and "kwargs" hyperparameters."""
    default_hparams = {"type": "type_name", "kwargs": {"arg1": "argv1"}}

    # Same "type", no "kwargs": the default kwargs are kept unchanged.
    hparams = {"type": "type_name"}
    hparams_ = HParams(hparams=hparams, default_hparams=default_hparams)
    self.assertEqual(hparams_.kwargs.todict(), default_hparams["kwargs"])

    # Same "type" with extra "kwargs": user kwargs are merged over the defaults.
    hparams = {"type": "type_name", "kwargs": {"arg2": "argv2"}}
    hparams_ = HParams(hparams=hparams, default_hparams=default_hparams)
    full_kwargs = {}
    full_kwargs.update(default_hparams["kwargs"])
    full_kwargs.update(hparams["kwargs"])
    self.assertEqual(hparams_.kwargs.todict(), full_kwargs)

    # "type" omitted: behaves like the default type, so kwargs are merged too.
    hparams = {"kwargs": {"arg2": "argv2"}}
    hparams_ = HParams(hparams=hparams, default_hparams=default_hparams)
    self.assertEqual(hparams_.kwargs.todict(), full_kwargs)

    # Different "type", no "kwargs": the default kwargs are discarded.
    hparams = {"type": "type_name2"}
    hparams_ = HParams(hparams=hparams, default_hparams=default_hparams)
    self.assertEqual(hparams_.kwargs.todict(), {})

    # Different "type" with "kwargs": only the user kwargs are used.
    hparams = {"type": "type_name2", "kwargs": {"arg3": "argv3"}}
    hparams_ = HParams(hparams=hparams, default_hparams=default_hparams)
    self.assertEqual(hparams_.kwargs.todict(), hparams["kwargs"])
def copy_hparams(hparams: HParams):
    """Return a new HParams holding the same values as `hparams`.

    The `problem` and `problem_hparams` attributes are copied over as well,
    but only when they exist on `hparams` and are not None.
    """
    copied = HParams(**hparams.values())
    # Make sure these extra attributes also end up on the copy.
    for name in ("problem", "problem_hparams"):
        value = getattr(hparams, name, None)
        if value is None:
            continue
        setattr(copied, name, value)
    return copied
def testEmpty(self):
    """An empty HParams has no values and rejects unknown names in parse()."""
    hparams = HParams()
    self.assertDictEqual({}, hparams.values())
    # Parsing an empty string leaves the (empty) values unchanged.
    hparams.parse('')
    self.assertDictEqual({}, hparams.values())
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, 'Unknown hyperparameter'):
        hparams.parse('xyz=123')
def testInnerDictValues(self):
    """A dict assigned to an HParams attribute keeps its class values via todict()."""

    class Input:
        def __init__(self):
            self._hparams = HParams()

    class Output:
        def __init__(self):
            self._hparams = HParams()

    hparams = HParams(outer_name='outer_value')
    hparams.a_dict_called_modality = {"input": Input, "output": Output}
    modality_dict = hparams.a_dict_called_modality.todict()
    stored_classes = list(six.itervalues(modality_dict))
    self.assertEqual(stored_classes, [Input, Output])
def testSetFromMap(self):
    """override_from_dict() replaces only the keys it is given."""
    first = HParams(a=1, b=2.0, c='tanh')
    first.override_from_dict({'a': -2, 'c': 'identity'})
    self.assertDictEqual({'a': -2, 'c': 'identity', 'b': 2.0}, first.values())

    # A list value may be replaced by a list of a different length.
    second = HParams(x=1, b=2.0, d=[0.5])
    second.override_from_dict({'d': [0.1, 0.2, 0.3]})
    self.assertDictEqual({'d': [0.1, 0.2, 0.3], 'x': 1, 'b': 2.0}, second.values())
def testLossSingleWeights(self):
    """Ensure _loss_single() respects optional 'weights' argument."""
    batch_size, sequence_size, vocab_size = 2, 16, 3
    weights_shape = (batch_size, sequence_size)
    with tf.Graph().as_default(), self.test_session() as sess:
        model_hparams = HParams(
            prepend_mode="none",
            loss={},
            weights_fn={},
            label_smoothing=0.0,
            shared_embedding_and_softmax_weights=False)
        ph = problem_hparams.TestProblem(
            vocab_size, vocab_size).get_hparams(model_hparams)
        model = t2t_model.T2TModel(model_hparams, problem_hparams=ph)
        logits = tf.zeros((batch_size, sequence_size, 1, 1, vocab_size))
        feature = tf.ones((batch_size, sequence_size, 1, 1))

        # All-zero weights: both numerator and denominator must be zero.
        loss_num, loss_denom = model._loss_single(
            logits, "targets", feature, weights=tf.zeros(weights_shape))
        self.assertAllClose(tf.zeros_like(loss_num), sess.run(loss_num))
        self.assertAllClose(tf.zeros_like(loss_denom), sess.run(loss_denom))

        # All-one weights: the loss must be positive and the denominator must
        # equal the number of weighted positions.
        loss_num, loss_denom = model._loss_single(
            logits, "targets", feature, weights=tf.ones(weights_shape))
        self.assertAllLess(0.0, sess.run(loss_num))
        self.assertAllClose(batch_size * sequence_size, sess.run(loss_denom))
def _default_hparams():
    """A set of basic model hyperparameters."""
    defaults = dict(
        # Makes perplexity comparable across tokenizations. Set it to the
        # number of test-set tokens under the tokenization in use divided by
        # the number of test-set tokens under the "official" tokenization.
        # For example, a word-piece model reporting per-word perplexity sets
        # this to the number of wordpieces per word in the test set.
        loss_multiplier=1.0,
        # Allows larger sequences in a batch. Without it, the size of the
        # inner two dimensions is used to judge the sequence length.
        batch_size_multiplier=1,
        # For autoregressive problems with an inference batch_size of 1,
        # inference stops once the model predicts a text_encoder.EOS_ID token.
        stop_at_eos=False,
        # Modalities that map features into a space compatible with the chosen
        # model architecture: key-value pairs of feature name (str) to
        # modality type.
        modality={},
        # Tell the model which input/target space to expect, e.g. French
        # characters as output or Spanish sound. Spaces are defined as
        # constants in the SpaceID class.
        input_space_id=SpaceID.GENERIC,
        target_space_id=SpaceID.GENERIC)
    return HParams(**defaults)
def default_model_hparams():
    """Return default model hyperparameters.

    All sequence-length settings are zero, prepend_mode is "none" and there
    is no data_dir.
    """
    settings = dict(
        max_input_seq_length=0,
        max_target_seq_length=0,
        prepend_mode="none",
        split_to_length=0,
        data_dir=None)
    return HParams(**settings)
def testCreateOutputTrainMode(self, likelihood, num_mixtures, depth):
    """Checks the output shape of create_output() in TRAIN mode."""
    batch, height, width, channels = 1, 8, 8, 3
    categorical = common_image_attention.DistributionType.CAT
    rows = height
    # For the categorical likelihood the channels are packed into the columns.
    cols = channels * width if likelihood == categorical else width
    hparams = HParams(
        hidden_size=2,
        likelihood=likelihood,
        mode=tf.estimator.ModeKeys.TRAIN,
        num_mixtures=num_mixtures,
    )
    decoder_output = tf.random_normal([batch, rows, cols, hparams.hidden_size])
    targets = tf.random_uniform([batch, height, width, channels],
                                minval=-1., maxval=1.)
    output = common_image_attention.create_output(decoder_output, rows, cols,
                                                  targets, hparams)
    if hparams.likelihood != categorical:
        expected_shape = (batch, height, width, depth)
    else:
        expected_shape = (batch, height, width, channels, depth)
    self.assertEqual(output.shape, expected_shape)
def testSummarizeLosses(self):
    """_summarize_losses() returns None and adds one summary per loss."""
    with tf.Graph().as_default():
        model = vf_model.VfModel(HParams())
        losses = {"training": tf.random_normal([]),
                  "extra": tf.random_normal([])}
        outputs = model._summarize_losses(losses)
        # assertIsNone's optional second argument is the failure *message*;
        # passing None there did nothing, so it is dropped.
        self.assertIsNone(outputs)
        # One summary is expected under the "losses" scope for every loss.
        self.assertEqual(
            len(tf.get_collection(tf.GraphKeys.SUMMARIES, scope="losses")),
            len(losses))
def __init__(self, hparams=None):
    """Build the preprocessor's hyperparameters and prepare for preprocessing.

    :param hparams: user hyperparameters, combined with self.default_hparams().
    """
    IPreprocessor.__init__(self, hparams=hparams)
    self._hparams = HParams(hparams, self.default_hparams())
    hp = self._hparams
    self._spark_master, self._num_clips, self._duration = (
        hp.spark_master, hp.num_clips, hp.duration)
    # Spark is not started here; this flag records that state.
    self._is_spark_initialized = False
    self.preprocess_prepare()
def create_hparams_from_json(json_path, hparams=None):
    """Loading hparams from json; can also start from hparams if specified.

    Args:
      json_path: Path to a JSON file containing an object whose keys are
        hyperparameter names.
      hparams: Optional existing HParams that is updated in place.

    Returns:
      If `hparams` is truthy, that same object with every key it shares with
      the JSON file overwritten by the JSON value. JSON keys that `hparams`
      lacks are ignored. Otherwise, a new HParams built from the JSON file.
    """
    # Lazy %-args: the logger formats the message only when it is emitted.
    tf.logging.info("Loading hparams from existing json %s", json_path)
    with tf.gfile.Open(json_path, "r") as f:
        hparams_values = json.load(f)
    new_hparams = HParams(**hparams_values)

    # NOTE(review): this is a truthiness check. If HParams defines __len__, an
    # empty instance is treated like None here.
    if not hparams:
        return new_hparams

    # Some keys are in new_hparams but not hparams, so we need to be more
    # careful than simply using parse_json() from HParams: only overlapping
    # keys are updated.
    for key in sorted(new_hparams.values()):
        if not hasattr(hparams, key):
            continue
        value = getattr(hparams, key)
        new_value = getattr(new_hparams, key)
        if value != new_value:
            tf.logging.info("Overwrite key %s: %s -> %s", key, value, new_value)
            setattr(hparams, key, new_value)
    return hparams
def test_typecheck(self):
    """Tests type-check functionality.

    Building HParams must not raise when a function-valued default ("fn_2") is
    overridden by a different function, and "fn" must keep its value.
    """
    def _foo():
        pass

    def _bar():
        pass

    default_hparams = {"fn": _foo, "fn_2": _foo}
    hparams_ = HParams(hparams={"fn": _foo, "fn_2": _bar},
                       default_hparams=default_hparams)
    self.assertEqual(hparams_.fn, default_hparams["fn"])
def testPostProcessImageTrainMode(self, likelihood, num_mixtures, depth):
    """postprocess_image() in TRAIN mode keeps rows/cols and ends in `depth`."""
    batch, rows, cols = 1, 8, 24
    expected_shape = (batch, rows, cols, depth)
    hparams = HParams(
        hidden_size=2,
        likelihood=likelihood,
        mode=tf.estimator.ModeKeys.TRAIN,
        num_mixtures=num_mixtures,
    )
    inputs = tf.random_uniform([batch, rows, cols, hparams.hidden_size],
                               minval=-1., maxval=1.)
    outputs = common_image_attention.postprocess_image(inputs, rows, cols,
                                                       hparams)
    self.assertEqual(outputs.shape, expected_shape)
def __init__(self, hparams=None, dataset=None):
    """Iterator that mixes the voices of two TEDLium speakers.

    Uses the TEDLium data preprocessed into speaker-wise folders and generates
    mixture signals of two speakers. According to the original author, the
    features are:
      - log-scale spectral features of the two speakers' voices,
        shape [frames_per_sample, neff];
      - a voice activity detection array of 0s/1s,
        shape [frames_per_sample, neff];
      - active-speaker labels, shape [frames_per_sample, neff, 2].

    :param hparams: user hyperparameters, combined with self.default_hparams().
    :param dataset: stored as self._dataset.
    """
    IIteratorBase.__init__(self, hparams=hparams)
    ShabdaWavPairFeature.__init__(self)
    merged = HParams(hparams=hparams, default_hparams=self.default_hparams())
    self._hparams = merged
    self._dataset = dataset
def __init__(self, hparams=None, data_iterator=None):
    """Model based on https://arxiv.org/abs/1508.04306.

    :param hparams: user hyperparameters, combined with self.default_hparams().
    :param data_iterator: not referenced by this constructor.
    """
    ModelBase.__init__(self, hparams=hparams)
    ShabdaWavPairFeature.__init__(self)
    self._hparams = HParams(hparams, self.default_hparams())
    print_error(self._hparams)
    # Mirror these hyperparameters as plain instance attributes.
    for name in ("lstm_hidden_size", "batch_size", "p_keep_ff", "p_keep_rc",
                 "neff", "embd_dim_k", "frames_per_sample"):
        setattr(self, name, getattr(self._hparams, name))
    # Not created in the constructor; start out as None.
    self.weights = None
    self.biases = None
def testPostProcessImageInferMode(self, likelihood, num_mixtures, depth):
    """postprocess_image() in PREDICT mode returns a grid of query blocks."""
    batch, rows, cols = 1, 8, 24
    block_length, block_width = 4, 2
    hparams = HParams(
        block_raster_scan=True,
        hidden_size=2,
        likelihood=likelihood,
        mode=tf.estimator.ModeKeys.PREDICT,
        num_mixtures=num_mixtures,
        query_shape=[block_length, block_width],
    )
    inputs = tf.random_uniform([batch, rows, cols, hparams.hidden_size],
                               minval=-1., maxval=1.)
    outputs = common_image_attention.postprocess_image(inputs, rows, cols,
                                                       hparams)
    expected_shape = (batch, rows // block_length, cols // block_width,
                      block_length, block_width, depth)
    self.assertEqual(outputs.shape, expected_shape)
def __init__(self):
    """Start with an empty set of hyperparameters."""
    empty_hparams = HParams()
    self._hparams = empty_hparams
def testLists(self):
    """Tests list-valued hyperparameters: storage, parse() and parse errors."""
    hparams = HParams(aaa=[1], b=[2.0, 3.0], c_c=['relu6'])
    self.assertDictEqual({
        'aaa': [1],
        'b': [2.0, 3.0],
        'c_c': ['relu6']
    }, hparams.values())
    self.assertEqual([1], hparams.aaa)
    self.assertEqual([2.0, 3.0], hparams.b)
    self.assertEqual(['relu6'], hparams.c_c)

    hparams.parse('aaa=[12]')
    self.assertEqual([12], hparams.aaa)
    hparams.parse('aaa=[12,34,56]')
    self.assertEqual([12, 34, 56], hparams.aaa)
    hparams.parse('c_c=[relu4,relu12],b=[1.0]')
    self.assertEqual(['relu4', 'relu12'], hparams.c_c)
    self.assertEqual([1.0], hparams.b)
    # Empty lists and signed numbers are accepted.
    hparams.parse('c_c=[],aaa=[-34]')
    self.assertEqual([-34], hparams.aaa)
    self.assertEqual([], hparams.c_c)
    # String elements are kept verbatim, including quote characters.
    hparams.parse('c_c=[_12,3\'4"],aaa=[+3]')
    self.assertEqual([3], hparams.aaa)
    self.assertEqual(['_12', '3\'4"'], hparams.c_c)

    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed in Python 3.12.
    failing_cases = (
        ('x=[123]', 'Unknown hyperparameter'),
        ('aaa=[poipoi]', 'Could not parse'),
        ('aaa=[1.0]', 'Could not parse'),
        ('b=[12x]', 'Could not parse'),
        ('b=[relu]', 'Could not parse'),
        ('aaa=123', 'Must pass a list'),
    )
    for text, pattern in failing_cases:
        with self.assertRaisesRegex(ValueError, pattern):
            hparams.parse(text)
def testDel(self):
    """After del_hparam(), a name can be re-added with a value of another type."""
    hparams = HParams(aaa=1, b=2.0)
    # While 'aaa' holds an int, neither a str update nor a re-add is allowed.
    for mutate in (hparams.set_hparam, hparams.add_hparam):
        with self.assertRaises(ValueError):
            mutate('aaa', 'will fail')

    hparams.del_hparam('aaa')
    hparams.add_hparam('aaa', 'will work')
    self.assertEqual('will work', hparams.get('aaa'))
    hparams.set_hparam('aaa', 'still works')
    self.assertEqual('still works', hparams.get('aaa'))
def testGet(self):
    """get() returns values, type-checks defaults and falls back for unknown names."""
    hparams = HParams(aaa=1, b=2.0, c_c='relu6', d=True, e=[5.0, 6.0])

    # Existing parameters with default=None.
    for name, expected in (('aaa', 1), ('b', 2.0), ('c_c', 'relu6'),
                           ('d', True)):
        self.assertEqual(expected, hparams.get(name))
    self.assertEqual([5.0, 6.0], hparams.get('e', None))

    # Existing parameters with compatible defaults.
    compatible = (
        ('aaa', 2, 1),
        ('b', 3.0, 2.0),
        ('b', 3, 2.0),
        ('c_c', 'default', 'relu6'),
        ('d', True, True),
        ('e', [1.0, 2.0, 3.0], [5.0, 6.0]),
        ('e', [1, 2, 3], [5.0, 6.0]),
    )
    for name, default, expected in compatible:
        self.assertEqual(expected, hparams.get(name, default))

    # Existing parameters with incompatible defaults.
    incompatible = (
        ('aaa', 2.0),
        ('b', False),
        ('c_c', [1, 2, 3]),
        ('d', 'relu'),
        ('e', 123.0),
        ('e', ['a', 'b', 'c']),
    )
    for name, default in incompatible:
        with self.assertRaises(ValueError):
            hparams.get(name, default)

    # Nonexistent parameters.
    self.assertEqual(None, hparams.get('unknown'))
    self.assertEqual(123, hparams.get('unknown', 123))
    self.assertEqual([1, 2, 3], hparams.get('unknown', [1, 2, 3]))
def testSetHParamTypeMismatch(self):
    """set_hparam() rejects values whose type does not match the current one."""
    hparams = HParams(int_=1, str_='str', bool_=True, float_=1.1,
                      list_int=[1, 2], none=None)
    mismatches = (
        ('str_', 2.2),
        ('int_', False),
        ('bool_', 1),
        ('int_', 2.2),
        ('list_int', [2, 3.3]),
        ('int_', '2'),
    )
    for name, bad_value in mismatches:
        with self.assertRaises(ValueError):
            hparams.set_hparam(name, bad_value)

    # An int may be assigned to a float hyperparameter.
    hparams.set_hparam('float_', 1)
    # A None-valued hyperparameter accepts a string.
    hparams.set_hparam('none', '1')
    self.assertEqual('1', hparams.none)
def testSetHParamListNonListMismatch(self):
    """set_hparam() rejects switching between list and non-list values."""
    hparams = HParams(a=1, b=[2.0, 3.0])
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, r'Must not pass a list'):
        hparams.set_hparam('a', [1.0])
    with self.assertRaisesRegex(ValueError, r'Must pass a list'):
        hparams.set_hparam('b', 1.0)
def testSetHParam(self):
    """set_hparam() updates values, visible through values() and attributes."""
    hparams = HParams(aaa=1, b=2.0, c_c='relu6', d=True)
    self.assertDictEqual({'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': True},
                         hparams.values())
    self.assertEqual(1, hparams.aaa)
    self.assertEqual(2.0, hparams.b)
    self.assertEqual('relu6', hparams.c_c)

    updates = {'aaa': 12, 'b': 3.0, 'c_c': 'relu4', 'd': False}
    for name, value in updates.items():
        hparams.set_hparam(name, value)
    self.assertDictEqual(updates, hparams.values())
    self.assertEqual(12, hparams.aaa)
    self.assertEqual(3.0, hparams.b)
    self.assertEqual('relu4', hparams.c_c)
def testWithPeriodInVariableName(self):
    """Names containing periods can be added, parsed and read back."""
    hparams = HParams()
    hparams.add_hparam(name='a.b', value=0.0)
    hparams.parse('a.b=1.0')
    self.assertEqual(1.0, getattr(hparams, 'a.b'))

    # A float-typed dotted name still rejects non-numeric text.
    # (assertRaisesRegex replaces assertRaisesRegexp, removed in Python 3.12.)
    hparams.add_hparam(name='c.d', value=0.0)
    with self.assertRaisesRegex(ValueError, 'Could not parse'):
        hparams.parse('c.d=abc')

    hparams.add_hparam(name='e.f', value='')
    hparams.parse('e.f=abc')
    self.assertEqual('abc', getattr(hparams, 'e.f'))

    # Consecutive/trailing periods are allowed too.
    hparams.add_hparam(name='d..', value=0.0)
    hparams.parse('d..=10.0')
    self.assertEqual(10.0, getattr(hparams, 'd..'))
def test_hparams(self):
    """Tests the HParams class."""
    default_hparams = {
        "str": "str",
        "list": ['item1', 'item2'],
        "dict": {
            "key1": "value1",
            "key2": "value2"
        },
        "nested_dict": {
            "dict_l2": {
                "key1_l2": "value1_l2"
            }
        },
        "type": "type",
        "kwargs": {
            "arg1": "argv1"
        },
    }

    # Test HParams.items() function
    hparams_ = HParams(hparams=None, default_hparams=default_hparams)
    names = [name for name, _ in hparams_.items()]
    self.assertEqual(set(names), set(default_hparams.keys()))

    hparams = {"dict": {"key1": "new_value"}, "kwargs": {"arg2": "argv2"}}
    hparams_ = HParams(hparams=hparams, default_hparams=default_hparams)

    # Test HParams construction
    self.assertEqual(hparams_.str, default_hparams["str"])
    self.assertEqual(hparams_.list, default_hparams["list"])
    self.assertEqual(hparams_.dict.key1, hparams["dict"]["key1"])
    self.assertEqual(hparams_.kwargs.arg2, hparams["kwargs"]["arg2"])
    self.assertEqual(hparams_.nested_dict.dict_l2.key1_l2,
                     default_hparams["nested_dict"]["dict_l2"]["key1_l2"])
    self.assertEqual(len(hparams_), len(default_hparams))

    # Expected result: the defaults with the user values applied on top.
    new_hparams = copy.deepcopy(default_hparams)
    new_hparams["dict"]["key1"] = hparams["dict"]["key1"]
    new_hparams["kwargs"].update(hparams["kwargs"])
    self.assertEqual(hparams_.todict(), new_hparams)

    self.assertIn("dict", hparams_)
    self.assertIsNone(hparams_.get('not_existed_name', None))
    self.assertEqual(hparams_.get('str'), default_hparams['str'])

    # Test HParams update related operations
    hparams_.str = "new_str"
    hparams_.dict = {"key3": "value3"}
    self.assertEqual(hparams_.str, "new_str")
    self.assertEqual(hparams_.dict.key3, "value3")

    hparams_.add_hparam("added_str", "added_str")
    hparams_.add_hparam("added_dict", {"key4": "value4"})
    hparams_.kwargs.add_hparam("added_arg", "added_argv")
    self.assertEqual(hparams_.added_str, "added_str")
    self.assertEqual(hparams_.added_dict.todict(), {"key4": "value4"})
    self.assertEqual(hparams_.kwargs.added_arg, "added_argv")

    # Test HParams I/O: pickle round-trip through a single temporary file.
    # Writing and then reading through the same rewound handle guarantees the
    # buffered pickle data is visible. Reopening by name without a flush could
    # read an empty file, and it also fails on Windows.
    with tempfile.TemporaryFile() as hparams_file:
        pickle.dump(hparams_, hparams_file)
        hparams_file.seek(0)
        hparams_loaded = pickle.load(hparams_file)
    self.assertEqual(hparams_loaded.todict(), hparams_.todict())
def testSomeValues(self):
    """Tests construction, str() and parse() of scalar hyperparameters."""
    hparams = HParams(aaa=1, b=2.0, c_c='relu6', d='/a/b=c/d')

    def assert_values(aaa, b, c_c, d):
        # Check each attribute of `hparams` individually.
        self.assertEqual(aaa, hparams.aaa)
        self.assertEqual(b, hparams.b)
        self.assertEqual(c_c, hparams.c_c)
        self.assertEqual(d, hparams.d)

    self.assertDictEqual(
        {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': '/a/b=c/d'},
        hparams.values())
    expected_str = ('{\'aaa\': 1, \'b\': 2.0, \'c_c\': \'relu6\', '
                    '\'d\': \'/a/b=c/d\'}')
    self.assertEqual(expected_str, hparams.__str__())
    self.assertEqual(expected_str, str(hparams))
    assert_values(1, 2.0, 'relu6', '/a/b=c/d')

    hparams.parse('aaa=12')
    self.assertDictEqual(
        {'aaa': 12, 'b': 2.0, 'c_c': 'relu6', 'd': '/a/b=c/d'},
        hparams.values())
    assert_values(12, 2.0, 'relu6', '/a/b=c/d')

    # A space after the comma and exponent notation are accepted.
    hparams.parse('c_c=relu4, b=-2.0e10')
    self.assertDictEqual(
        {'aaa': 12, 'b': -2.0e10, 'c_c': 'relu4', 'd': '/a/b=c/d'},
        hparams.values())
    assert_values(12, -2.0e10, 'relu4', '/a/b=c/d')

    # Empty string values and a trailing comma are accepted.
    hparams.parse('c_c=,b=0,')
    self.assertDictEqual(
        {'aaa': 12, 'b': 0, 'c_c': '', 'd': '/a/b=c/d'},
        hparams.values())
    assert_values(12, 0.0, '', '/a/b=c/d')

    hparams.parse('c_c=2.3",b=+2,')
    self.assertEqual(2.0, hparams.b)
    self.assertEqual('2.3"', hparams.c_c)

    # Values may contain '/' and '='.
    hparams.parse('d=/a/b/c/d,aaa=11,')
    assert_values(11, 2.0, '2.3"', '/a/b/c/d')
    hparams.parse('b=1.5,d=/a=b/c/d,aaa=10,')
    assert_values(10, 1.5, '2.3"', '/a=b/c/d')

    # Invalid input raises and leaves the values untouched.
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed in Python 3.12.
    failing_cases = (
        ('x=123', 'Unknown hyperparameter'),
        ('aaa=poipoi', 'Could not parse'),
        ('aaa=1.0', 'Could not parse'),
        ('b=12x', 'Could not parse'),
        ('b=relu', 'Could not parse'),
        ('aaa=[123]', 'Must not pass a list'),
    )
    for text, pattern in failing_cases:
        with self.assertRaisesRegex(ValueError, pattern):
            hparams.parse(text)
    assert_values(10, 1.5, '2.3"', '/a=b/c/d')
def testJson(self):
    """Tests parse_json() and to_json(), including its formatting options."""
    hparams = HParams(aaa=1, b=2.0, c_c='relu6', d=True)
    self.assertDictEqual({
        'aaa': 1,
        'b': 2.0,
        'c_c': 'relu6',
        'd': True
    }, hparams.values())
    self.assertEqual(1, hparams.aaa)
    self.assertEqual(2.0, hparams.b)
    self.assertEqual('relu6', hparams.c_c)

    hparams.parse_json('{"aaa": 12, "b": 3.0, "c_c": "relu4", "d": false}')
    self.assertDictEqual({
        'aaa': 12,
        'b': 3.0,
        'c_c': 'relu4',
        'd': False
    }, hparams.values())
    self.assertEqual(12, hparams.aaa)
    self.assertEqual(3.0, hparams.b)
    self.assertEqual('relu4', hparams.c_c)

    # to_json() output can be parsed back into another instance.
    json_str = hparams.to_json()
    hparams2 = HParams(aaa=10, b=20.0, c_c='hello', d=False)
    hparams2.parse_json(json_str)
    self.assertEqual(12, hparams2.aaa)
    self.assertEqual(3.0, hparams2.b)
    self.assertEqual('relu4', hparams2.c_c)
    self.assertEqual(False, hparams2.d)

    hparams3 = HParams(aaa=123)
    self.assertEqual('{"aaa": 123}', hparams3.to_json())
    # indent=2 puts two spaces before each member (json.dumps convention).
    self.assertEqual('{\n  "aaa": 123\n}', hparams3.to_json(indent=2))
    self.assertEqual('{"aaa"=123}', hparams3.to_json(separators=(';', '=')))

    hparams4 = HParams(aaa=123, b='hello', c_c=False)
    self.assertEqual('{"aaa": 123, "b": "hello", "c_c": false}',
                     hparams4.to_json(sort_keys=True))
def testBoolParsingFail(self):
    """parse() rejects an unrecognised boolean string such as 'yep'."""
    hparams = HParams(use_gpu=True)
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, r'Could not parse.*use_gpu'):
        hparams.parse('use_gpu=yep')