예제 #1
0
 def testBoolParsing(self):
     """Check parsing of a boolean hyperparameter from its textual forms.

     Each spelling is parsed on top of both possible initial values; the
     parsed result must be True exactly for 'True', 'true' and '1'.
     """
     truthy_spellings = ['True', 'true', '1']
     for value in 'true', 'false', 'True', 'False', '1', '0':
         expected = value in truthy_spellings
         for initial in False, True:
             hparams = HParams(use_gpu=initial)
             hparams.parse('use_gpu=' + value)
             self.assertEqual(hparams.use_gpu, expected)
예제 #2
0
    def test_type_kwargs(self):
        """Tests the special cases involving the "type" and "kwargs"
        hyperparameters.
        """
        defaults = {"type": "type_name", "kwargs": {"arg1": "argv1"}}

        # Same "type", no "kwargs" given: the default kwargs are expected.
        given = {"type": "type_name"}
        resolved = HParams(hparams=given, default_hparams=defaults)
        self.assertEqual(resolved.kwargs.todict(), defaults["kwargs"])

        # Same "type", extra "kwargs": default and given kwargs are expected
        # to be combined.
        given = {"type": "type_name", "kwargs": {"arg2": "argv2"}}
        resolved = HParams(hparams=given, default_hparams=defaults)
        merged_kwargs = dict(defaults["kwargs"])
        merged_kwargs.update(given["kwargs"])
        self.assertEqual(resolved.kwargs.todict(), merged_kwargs)

        # "type" omitted: the same combined kwargs are expected.
        given = {"kwargs": {"arg2": "argv2"}}
        resolved = HParams(hparams=given, default_hparams=defaults)
        self.assertEqual(resolved.kwargs.todict(), merged_kwargs)

        # Different "type", no "kwargs": empty kwargs are expected.
        given = {"type": "type_name2"}
        resolved = HParams(hparams=given, default_hparams=defaults)
        self.assertEqual(resolved.kwargs.todict(), {})

        # Different "type" with "kwargs": only the given kwargs are expected.
        given = {"type": "type_name2", "kwargs": {"arg3": "argv3"}}
        resolved = HParams(hparams=given, default_hparams=defaults)
        self.assertEqual(resolved.kwargs.todict(), given["kwargs"])
예제 #3
0
def copy_hparams(hparams: HParams):
    """Return a new HParams built from ``hparams.values()``.

    The "problem" and "problem_hparams" attributes are additionally carried
    over (by reference) whenever they are present on ``hparams`` and are not
    None.
    """
    duplicate = HParams(**hparams.values())
    # These attributes must also end up on the copy.
    for name in ("problem", "problem_hparams"):
        carried = getattr(hparams, name, None)
        if carried is None:
            continue
        setattr(duplicate, name, carried)
    return duplicate
예제 #4
0
 def testEmpty(self):
     """An HParams without entries has empty values() and rejects parsing
     of a name it does not hold.
     """
     hparams = HParams()
     self.assertDictEqual({}, hparams.values())
     # Parsing an empty string must leave the (empty) values untouched.
     hparams.parse('')
     self.assertDictEqual({}, hparams.values())
     # assertRaisesRegexp is a deprecated alias that was removed in
     # Python 3.12; assertRaisesRegex is the supported spelling.
     with self.assertRaisesRegex(ValueError, 'Unknown hyperparameter'):
         hparams.parse('xyz=123')
예제 #5
0
    def testInnerDictValues(self):
        """Classes stored in a dict-valued hyperparameter are returned
        unchanged, in order, by todict().
        """
        class Input:
            def __init__(self):
                self._hparams = HParams()

        class Output:
            def __init__(self):
                self._hparams = HParams()

        hparams = HParams(outer_name='outer_value')
        hparams.a_dict_called_modality = {"input": Input, "output": Output}
        modality_dict = hparams.a_dict_called_modality.todict()
        values = [cls for cls in six.itervalues(modality_dict)]
        self.assertEqual(values, [Input, Output])
예제 #6
0
    def testSetFromMap(self):
        """override_from_dict() replaces the listed hyperparameters and keeps
        the remaining ones.
        """
        scalars = HParams(a=1, b=2.0, c='tanh')
        scalars.override_from_dict({'a': -2, 'c': 'identity'})
        expected_scalars = {'a': -2, 'b': 2.0, 'c': 'identity'}
        self.assertDictEqual(expected_scalars, scalars.values())

        with_list = HParams(x=1, b=2.0, d=[0.5])
        with_list.override_from_dict({'d': [0.1, 0.2, 0.3]})
        expected_with_list = {'x': 1, 'b': 2.0, 'd': [0.1, 0.2, 0.3]}
        self.assertDictEqual(expected_with_list, with_list.values())
예제 #7
0
  def testLossSingleWeights(self):
    """Ensure _loss_single() respects optional 'weights' argument."""
    batch_size, sequence_size, vocab_size = 2, 16, 3
    weights_shape = (batch_size, sequence_size)

    with tf.Graph().as_default(), self.test_session() as sess:
      model_hparams = HParams(
          prepend_mode="none",
          loss={},
          weights_fn={},
          label_smoothing=0.0,
          shared_embedding_and_softmax_weights=False)

      test_problem = problem_hparams.TestProblem(vocab_size, vocab_size)
      ph = test_problem.get_hparams(model_hparams)

      model = t2t_model.T2TModel(model_hparams, problem_hparams=ph)
      logits = tf.zeros((batch_size, sequence_size, 1, 1, vocab_size))
      feature = tf.ones((batch_size, sequence_size, 1, 1))

      # All-zero weights: both returned loss terms must be zero.
      zero_num, zero_denom = model._loss_single(
          logits, "targets", feature, weights=tf.zeros(weights_shape))
      self.assertAllClose(tf.zeros_like(zero_num), sess.run(zero_num))
      self.assertAllClose(tf.zeros_like(zero_denom), sess.run(zero_denom))

      # All-one weights: the loss must be above zero and the denominator
      # must equal batch_size * sequence_size.
      ones_num, ones_denom = model._loss_single(
          logits, "targets", feature, weights=tf.ones(weights_shape))
      self.assertAllLess(0.0, sess.run(ones_num))
      self.assertAllClose(batch_size * sequence_size, sess.run(ones_denom))
예제 #8
0
def _default_hparams():
    """Build the HParams holding the basic model hyperparameters."""
    basic = {
        # Makes perplexity numbers comparable across tokenizations. Set it to
        # (number of test-set tokens under the tokenization in use) divided
        # by (number of test-set tokens under the "official" tokenization).
        # E.g. for a word-piece model evaluated per word, this is the number
        # of wordpieces per word in the test set.
        "loss_multiplier": 1.0,

        # Allows larger sequences in a batch. When unused, the size of the
        # inner two dimensions is what judges the sequence length.
        "batch_size_multiplier": 1,

        # For autoregressive inference with batch_size 1: stop as soon as the
        # model predicts a text_encoder.EOS_ID token.
        "stop_at_eos": False,

        # Maps features into a space compatible with the chosen model
        # architecture: feature name (str) -> modality type.
        "modality": {},

        # Tell the model which input/target space to expect (for instance
        # French as characters for output, or Spanish as sound). The spaces
        # are constants of the SpaceID class.
        "input_space_id": SpaceID.GENERIC,
        "target_space_id": SpaceID.GENERIC,
    }
    return HParams(**basic)
예제 #9
0
def default_model_hparams():
    """Return an HParams holding the default model-level settings."""
    defaults = dict(
        max_input_seq_length=0,
        max_target_seq_length=0,
        prepend_mode="none",
        split_to_length=0,
        data_dir=None,
    )
    return HParams(**defaults)
 def testCreateOutputTrainMode(self, likelihood, num_mixtures, depth):
     """Check the shape returned by create_output() in TRAIN mode.

     For the CAT likelihood the decoder output is built with
     ``channels * width`` columns and the result keeps a separate channel
     axis; otherwise ``width`` columns are used and no channel axis is
     expected.
     """
     batch, height, width, channels = 1, 8, 8, 3
     cat = common_image_attention.DistributionType.CAT
     rows = height
     cols = channels * width if likelihood == cat else width
     hparams = HParams(
         hidden_size=2,
         likelihood=likelihood,
         mode=tf.estimator.ModeKeys.TRAIN,
         num_mixtures=num_mixtures,
     )
     decoder_output = tf.random_normal(
         [batch, rows, cols, hparams.hidden_size])
     targets = tf.random_uniform(
         [batch, height, width, channels], minval=-1., maxval=1.)
     output = common_image_attention.create_output(
         decoder_output, rows, cols, targets, hparams)
     if hparams.likelihood == common_image_attention.DistributionType.CAT:
         expected_shape = (batch, height, width, channels, depth)
     else:
         expected_shape = (batch, height, width, depth)
     self.assertEqual(output.shape, expected_shape)
예제 #11
0
 def testSummarizeLosses(self):
   """_summarize_losses() returns None and registers one summary per loss
   under the "losses" scope.
   """
   with tf.Graph().as_default():
     model = vf_model.VfModel(HParams())
     losses = {"training": tf.random_normal([]),
               "extra": tf.random_normal([])}
     outputs = model._summarize_losses(losses)
     # The second positional argument of assertIsNone is the failure
     # message, not an expected value, so the stray None was dropped.
     self.assertIsNone(outputs)
     self.assertEqual(
         len(tf.get_collection(tf.GraphKeys.SUMMARIES, scope="losses")),
         len(losses))
예제 #12
0
    def __init__(self, hparams=None):
        """Initialise the preprocessor and immediately run
        ``preprocess_prepare()``.

        :param hparams: forwarded to ``IPreprocessor.__init__`` and combined
            with ``self.default_hparams()`` into ``self._hparams``.
        """
        IPreprocessor.__init__(self, hparams=hparams)
        self._hparams = HParams(hparams, self.default_hparams())

        # Mirror the settings this class uses as private attributes
        # (``_spark_master``, ``_num_clips``, ``_duration``).
        for setting in ("spark_master", "num_clips", "duration"):
            setattr(self, "_" + setting, getattr(self._hparams, setting))

        self._is_spark_initialized = False

        self.preprocess_prepare()
예제 #13
0
def create_hparams_from_json(json_path, hparams=None):
    """Load hyperparameters from a JSON file.

    When ``hparams`` is given (and truthy), it is updated in place: every key
    from the JSON that ``hparams`` already has and whose value differs is
    overwritten, and ``hparams`` is returned. Keys present only in the JSON
    are not added to it. Otherwise a new HParams built from the JSON content
    is returned.
    """
    tf.logging.info("Loading hparams from existing json %s" % json_path)
    with tf.gfile.Open(json_path, "r") as json_file:
        loaded = HParams(**json.load(json_file))

    if not hparams:
        return loaded

    # The JSON may hold keys that `hparams` lacks, so the values are copied
    # key by key instead of relying on HParams.parse_json().
    for name in sorted(loaded.values().keys()):
        if not hasattr(hparams, name):
            continue
        old_value = getattr(hparams, name)
        json_value = getattr(loaded, name)
        if old_value != json_value:
            tf.logging.info("Overwrite key %s: %s -> %s" %
                            (name, old_value, json_value))
            setattr(hparams, name, json_value)
    return hparams
예제 #14
0
    def test_typecheck(self):
        """Tests type-check functionality with function-valued
        hyperparameters.
        """
        def _foo():
            pass

        def _bar():
            pass

        defaults = {"fn": _foo, "fn_2": _foo}
        overrides = {"fn": _foo, "fn_2": _bar}
        resolved = HParams(hparams=overrides, default_hparams=defaults)
        self.assertEqual(resolved.fn, defaults["fn"])
 def testPostProcessImageTrainMode(self, likelihood, num_mixtures, depth):
     """Check the shape returned by postprocess_image() in TRAIN mode."""
     batch, rows, cols = 1, 8, 24
     hparams = HParams(
         hidden_size=2,
         likelihood=likelihood,
         mode=tf.estimator.ModeKeys.TRAIN,
         num_mixtures=num_mixtures,
     )
     input_shape = [batch, rows, cols, hparams.hidden_size]
     inputs = tf.random_uniform(input_shape, minval=-1., maxval=1.)
     outputs = common_image_attention.postprocess_image(
         inputs, rows, cols, hparams)
     expected_shape = (batch, rows, cols, depth)
     self.assertEqual(outputs.shape, expected_shape)
예제 #16
0
    def __init__(self, hparams=None, dataset=None):
        """
        Generates two-speaker mixture signals from the preprocessed TEDLium
        data (one folder per speaker).

        Features produced:
            - log-scale spectral features of the two speaker voices.
              shape: [frames_per_sample, neff]
            - voice activity detection array (0's/1's).
              shape: [frames_per_sample, neff]
            - active speaker array, used as labels.
              shape: [frames_per_sample, neff, 2]

        :param hparams: forwarded to ``IIteratorBase.__init__`` and combined
            with ``self.default_hparams()`` into ``self._hparams``
        :param dataset: stored as ``self._dataset``
        """
        IIteratorBase.__init__(self, hparams=hparams)
        ShabdaWavPairFeature.__init__(self)
        defaults = self.default_hparams()
        self._hparams = HParams(hparams=hparams, default_hparams=defaults)
        self._dataset = dataset
예제 #17
0
    def __init__(self, hparams=None, data_iterator=None):
        """
        Reference: https://arxiv.org/abs/1508.04306

        :param hparams: forwarded to ``ModelBase.__init__`` and combined with
            ``self.default_hparams()`` into ``self._hparams``
        :param data_iterator: not used inside this constructor
        """
        ModelBase.__init__(self, hparams=hparams)
        ShabdaWavPairFeature.__init__(self)
        self._hparams = HParams(hparams, self.default_hparams())

        print_error(self._hparams)

        # Expose the hyperparameters used by the model as plain attributes
        # of the same name.
        exposed = (
            "lstm_hidden_size",
            "batch_size",
            "p_keep_ff",
            "p_keep_rc",
            "neff",
            "embd_dim_k",
            "frames_per_sample",
        )
        for setting in exposed:
            setattr(self, setting, getattr(self._hparams, setting))

        self.weights = None
        self.biases = None
 def testPostProcessImageInferMode(self, likelihood, num_mixtures, depth):
     """Check the shape returned by postprocess_image() in PREDICT mode
     with block_raster_scan enabled.
     """
     batch, rows, cols = 1, 8, 24
     block_length, block_width = 4, 2
     hparams = HParams(
         block_raster_scan=True,
         hidden_size=2,
         likelihood=likelihood,
         mode=tf.estimator.ModeKeys.PREDICT,
         num_mixtures=num_mixtures,
         query_shape=[block_length, block_width],
     )
     input_shape = [batch, rows, cols, hparams.hidden_size]
     inputs = tf.random_uniform(input_shape, minval=-1., maxval=1.)
     outputs = common_image_attention.postprocess_image(
         inputs, rows, cols, hparams)
     # The expected output is laid out per block: a grid of blocks followed
     # by the extent of a single block.
     expected_shape = (
         batch,
         rows // block_length,
         cols // block_width,
         block_length,
         block_width,
         depth,
     )
     self.assertEqual(outputs.shape, expected_shape)
예제 #19
0
 def __init__(self):
     """Give the instance an HParams created with no arguments."""
     self._hparams = HParams()
예제 #20
0
 def testLists(self):
     """Exercise list-valued hyperparameters: construction, attribute
     access, parse() of list literals, and the errors raised for invalid
     input.
     """
     hparams = HParams(aaa=[1], b=[2.0, 3.0], c_c=['relu6'])
     self.assertDictEqual({
         'aaa': [1],
         'b': [2.0, 3.0],
         'c_c': ['relu6']
     }, hparams.values())
     self.assertEqual([1], hparams.aaa)
     self.assertEqual([2.0, 3.0], hparams.b)
     self.assertEqual(['relu6'], hparams.c_c)

     # Successful parses; each one builds on the state left by the previous.
     hparams.parse('aaa=[12]')
     self.assertEqual([12], hparams.aaa)
     hparams.parse('aaa=[12,34,56]')
     self.assertEqual([12, 34, 56], hparams.aaa)
     hparams.parse('c_c=[relu4,relu12],b=[1.0]')
     self.assertEqual(['relu4', 'relu12'], hparams.c_c)
     self.assertEqual([1.0], hparams.b)
     hparams.parse('c_c=[],aaa=[-34]')
     self.assertEqual([-34], hparams.aaa)
     self.assertEqual([], hparams.c_c)
     hparams.parse('c_c=[_12,3\'4"],aaa=[+3]')
     self.assertEqual([3], hparams.aaa)
     self.assertEqual(['_12', '3\'4"'], hparams.c_c)

     # Rejected input. assertRaisesRegexp is a deprecated alias that was
     # removed in Python 3.12, hence assertRaisesRegex.
     with self.assertRaisesRegex(ValueError, 'Unknown hyperparameter'):
         hparams.parse('x=[123]')
     with self.assertRaisesRegex(ValueError, 'Could not parse'):
         hparams.parse('aaa=[poipoi]')
     with self.assertRaisesRegex(ValueError, 'Could not parse'):
         hparams.parse('aaa=[1.0]')
     with self.assertRaisesRegex(ValueError, 'Could not parse'):
         hparams.parse('b=[12x]')
     with self.assertRaisesRegex(ValueError, 'Could not parse'):
         hparams.parse('b=[relu]')
     with self.assertRaisesRegex(ValueError, 'Must pass a list'):
         hparams.parse('aaa=123')
예제 #21
0
    def testDel(self):
        """After del_hparam(), a name can be added again with a value of a
        different type.
        """
        hparams = HParams(aaa=1, b=2.0)

        # While the int-valued 'aaa' exists, neither setting it to a string
        # nor adding it a second time is accepted.
        for rejected_call in (hparams.set_hparam, hparams.add_hparam):
            with self.assertRaises(ValueError):
                rejected_call('aaa', 'will fail')

        hparams.del_hparam('aaa')
        hparams.add_hparam('aaa', 'will work')
        self.assertEqual('will work', hparams.get('aaa'))

        hparams.set_hparam('aaa', 'still works')
        self.assertEqual('still works', hparams.get('aaa'))
예제 #22
0
    def testGet(self):
        """Exercise HParams.get() with and without a default, for existing
        and for unknown names.
        """
        hparams = HParams(aaa=1, b=2.0, c_c='relu6', d=True, e=[5.0, 6.0])

        # Existing parameters with default=None.
        stored_values = (
            ('aaa', 1),
            ('b', 2.0),
            ('c_c', 'relu6'),
            ('d', True),
        )
        for name, stored in stored_values:
            self.assertEqual(stored, hparams.get(name))
        self.assertEqual([5.0, 6.0], hparams.get('e', None))

        # Existing parameters with compatible defaults: the stored value is
        # expected, not the default.
        compatible_defaults = (
            ('aaa', 2, 1),
            ('b', 3.0, 2.0),
            ('b', 3, 2.0),
            ('c_c', 'default', 'relu6'),
            ('d', True, True),
            ('e', [1.0, 2.0, 3.0], [5.0, 6.0]),
            ('e', [1, 2, 3], [5.0, 6.0]),
        )
        for name, default, stored in compatible_defaults:
            self.assertEqual(stored, hparams.get(name, default))

        # Existing parameters with incompatible defaults.
        incompatible_defaults = (
            ('aaa', 2.0),
            ('b', False),
            ('c_c', [1, 2, 3]),
            ('d', 'relu'),
            ('e', 123.0),
            ('e', ['a', 'b', 'c']),
        )
        for name, default in incompatible_defaults:
            with self.assertRaises(ValueError):
                hparams.get(name, default)

        # Nonexistent parameters: the default (None when omitted) is
        # expected.
        self.assertEqual(None, hparams.get('unknown'))
        self.assertEqual(123, hparams.get('unknown', 123))
        self.assertEqual([1, 2, 3], hparams.get('unknown', [1, 2, 3]))
예제 #23
0
    def testSetHParamTypeMismatch(self):
        """set_hparam() must raise ValueError for values whose type does not
        fit the existing hyperparameter.
        """
        hparams = HParams(int_=1,
                          str_='str',
                          bool_=True,
                          float_=1.1,
                          list_int=[1, 2],
                          none=None)

        mismatches = (
            ('str_', 2.2),
            ('int_', False),
            ('bool_', 1),
            ('int_', 2.2),
            ('list_int', [2, 3.3]),
            ('int_', '2'),
        )
        for name, bad_value in mismatches:
            with self.assertRaises(ValueError):
                hparams.set_hparam(name, bad_value)

        # An int is accepted for a float hyperparameter.
        hparams.set_hparam('float_', 1)

        # A hyperparameter created with None accepts the string as is.
        hparams.set_hparam('none', '1')
        self.assertEqual('1', hparams.none)
예제 #24
0
 def testSetHParamListNonListMismatch(self):
     """set_hparam() rejects a list for a scalar hyperparameter and a
     scalar for a list hyperparameter.
     """
     hparams = HParams(a=1, b=[2.0, 3.0])
     # assertRaisesRegexp is a deprecated alias that was removed in
     # Python 3.12; assertRaisesRegex is the supported spelling.
     with self.assertRaisesRegex(ValueError, r'Must not pass a list'):
         hparams.set_hparam('a', [1.0])
     with self.assertRaisesRegex(ValueError, r'Must pass a list'):
         hparams.set_hparam('b', 1.0)
예제 #25
0
    def testSetHParam(self):
        """set_hparam() changes values as seen through both values() and
        attribute access.
        """
        initial = {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': True}
        hparams = HParams(**initial)
        self.assertDictEqual(initial, hparams.values())
        self.assertEqual(1, hparams.aaa)
        self.assertEqual(2.0, hparams.b)
        self.assertEqual('relu6', hparams.c_c)

        updates = (('aaa', 12), ('b', 3.0), ('c_c', 'relu4'), ('d', False))
        for name, new_value in updates:
            hparams.set_hparam(name, new_value)
        self.assertDictEqual(dict(updates), hparams.values())
        self.assertEqual(12, hparams.aaa)
        self.assertEqual(3.0, hparams.b)
        self.assertEqual('relu4', hparams.c_c)
예제 #26
0
 def testWithPeriodInVariableName(self):
     """Hyperparameter names containing periods can be added, parsed and
     read back through getattr().
     """
     hparams = HParams()
     hparams.add_hparam(name='a.b', value=0.0)
     hparams.parse('a.b=1.0')
     self.assertEqual(1.0, getattr(hparams, 'a.b'))
     hparams.add_hparam(name='c.d', value=0.0)
     # assertRaisesRegexp is a deprecated alias that was removed in
     # Python 3.12; assertRaisesRegex is the supported spelling.
     with self.assertRaisesRegex(ValueError, 'Could not parse'):
         hparams.parse('c.d=abc')
     hparams.add_hparam(name='e.f', value='')
     hparams.parse('e.f=abc')
     self.assertEqual('abc', getattr(hparams, 'e.f'))
     # A name ending in periods is handled as well.
     hparams.add_hparam(name='d..', value=0.0)
     hparams.parse('d..=10.0')
     self.assertEqual(10.0, getattr(hparams, 'd..'))
예제 #27
0
    def test_hparams(self):
        """Tests the HParams class: iteration, construction from defaults
        plus overrides, attribute updates, add_hparam() and a pickle
        round trip.
        """
        default_hparams = {
            "str": "str",
            "list": ['item1', 'item2'],
            "dict": {
                "key1": "value1",
                "key2": "value2"
            },
            "nested_dict": {
                "dict_l2": {
                    "key1_l2": "value1_l2"
                }
            },
            "type": "type",
            "kwargs": {
                "arg1": "argv1"
            },
        }

        # Test HParams.items() function
        hparams_ = HParams(hparams=None, default_hparams=default_hparams)
        names = []
        for name, _ in hparams_.items():
            names.append(name)
        self.assertEqual(set(names), set(default_hparams.keys()))

        hparams = {"dict": {"key1": "new_value"}, "kwargs": {"arg2": "argv2"}}

        hparams_ = HParams(hparams=hparams, default_hparams=default_hparams)

        # Test HParams construction
        self.assertEqual(hparams_.str, default_hparams["str"])
        self.assertEqual(hparams_.list, default_hparams["list"])
        self.assertEqual(hparams_.dict.key1, hparams["dict"]["key1"])
        self.assertEqual(hparams_.kwargs.arg2, hparams["kwargs"]["arg2"])
        self.assertEqual(hparams_.nested_dict.dict_l2.key1_l2,
                         default_hparams["nested_dict"]["dict_l2"]["key1_l2"])

        self.assertEqual(len(hparams_), len(default_hparams))

        new_hparams = copy.deepcopy(default_hparams)
        new_hparams["dict"]["key1"] = hparams["dict"]["key1"]
        new_hparams["kwargs"].update(hparams["kwargs"])
        self.assertEqual(hparams_.todict(), new_hparams)

        self.assertIn("dict", hparams_)

        self.assertIsNone(hparams_.get('not_existed_name', None))
        self.assertEqual(hparams_.get('str'), default_hparams['str'])

        # Test HParams update related operations
        hparams_.str = "new_str"
        hparams_.dict = {"key3": "value3"}
        self.assertEqual(hparams_.str, "new_str")
        self.assertEqual(hparams_.dict.key3, "value3")

        hparams_.add_hparam("added_str", "added_str")
        hparams_.add_hparam("added_dict", {"key4": "value4"})
        hparams_.kwargs.add_hparam("added_arg", "added_argv")
        self.assertEqual(hparams_.added_str, "added_str")
        self.assertEqual(hparams_.added_dict.todict(), {"key4": "value4"})
        self.assertEqual(hparams_.kwargs.added_arg, "added_argv")

        # Test HParams I/O
        # The pickle is written and read back through the same handle:
        # reopening the temporary file by name before it is flushed relied
        # on a refcount-driven close and does not work on every platform.
        # The `with` block also guarantees the file is closed and removed.
        with tempfile.NamedTemporaryFile() as hparams_file:
            pickle.dump(hparams_, hparams_file)
            hparams_file.seek(0)
            hparams_loaded = pickle.load(hparams_file)
        self.assertEqual(hparams_loaded.todict(), hparams_.todict())
예제 #28
0
 def testSomeValues(self):
     """Exercise scalar hyperparameters: construction, the string form,
     successive parse() calls, and the errors raised for invalid input.
     """
     hparams = HParams(aaa=1, b=2.0, c_c='relu6', d='/a/b=c/d')
     self.assertDictEqual(
         {
             'aaa': 1,
             'b': 2.0,
             'c_c': 'relu6',
             'd': '/a/b=c/d'
         }, hparams.values())
     expected_str = ('{\'aaa\': 1, \'b\': 2.0, \'c_c\': \'relu6\', '
                     '\'d\': \'/a/b=c/d\'}')
     self.assertEqual(expected_str, str(hparams.__str__()))
     self.assertEqual(expected_str, str(hparams))
     self.assertEqual(1, hparams.aaa)
     self.assertEqual(2.0, hparams.b)
     self.assertEqual('relu6', hparams.c_c)
     self.assertEqual('/a/b=c/d', hparams.d)

     # Each parse() below builds on the state left by the previous one.
     hparams.parse('aaa=12')
     self.assertDictEqual(
         {
             'aaa': 12,
             'b': 2.0,
             'c_c': 'relu6',
             'd': '/a/b=c/d'
         }, hparams.values())
     self.assertEqual(12, hparams.aaa)
     self.assertEqual(2.0, hparams.b)
     self.assertEqual('relu6', hparams.c_c)
     self.assertEqual('/a/b=c/d', hparams.d)
     hparams.parse('c_c=relu4, b=-2.0e10')
     self.assertDictEqual(
         {
             'aaa': 12,
             'b': -2.0e10,
             'c_c': 'relu4',
             'd': '/a/b=c/d'
         }, hparams.values())
     self.assertEqual(12, hparams.aaa)
     self.assertEqual(-2.0e10, hparams.b)
     self.assertEqual('relu4', hparams.c_c)
     self.assertEqual('/a/b=c/d', hparams.d)
     hparams.parse('c_c=,b=0,')
     self.assertDictEqual({
         'aaa': 12,
         'b': 0,
         'c_c': '',
         'd': '/a/b=c/d'
     }, hparams.values())
     self.assertEqual(12, hparams.aaa)
     self.assertEqual(0.0, hparams.b)
     self.assertEqual('', hparams.c_c)
     self.assertEqual('/a/b=c/d', hparams.d)
     hparams.parse('c_c=2.3",b=+2,')
     self.assertEqual(2.0, hparams.b)
     self.assertEqual('2.3"', hparams.c_c)
     hparams.parse('d=/a/b/c/d,aaa=11,')
     self.assertEqual(11, hparams.aaa)
     self.assertEqual(2.0, hparams.b)
     self.assertEqual('2.3"', hparams.c_c)
     self.assertEqual('/a/b/c/d', hparams.d)
     hparams.parse('b=1.5,d=/a=b/c/d,aaa=10,')
     self.assertEqual(10, hparams.aaa)
     self.assertEqual(1.5, hparams.b)
     self.assertEqual('2.3"', hparams.c_c)
     self.assertEqual('/a=b/c/d', hparams.d)

     # Rejected input. assertRaisesRegexp is a deprecated alias that was
     # removed in Python 3.12, hence assertRaisesRegex.
     with self.assertRaisesRegex(ValueError, 'Unknown hyperparameter'):
         hparams.parse('x=123')
     with self.assertRaisesRegex(ValueError, 'Could not parse'):
         hparams.parse('aaa=poipoi')
     with self.assertRaisesRegex(ValueError, 'Could not parse'):
         hparams.parse('aaa=1.0')
     with self.assertRaisesRegex(ValueError, 'Could not parse'):
         hparams.parse('b=12x')
     with self.assertRaisesRegex(ValueError, 'Could not parse'):
         hparams.parse('b=relu')
     with self.assertRaisesRegex(ValueError, 'Must not pass a list'):
         hparams.parse('aaa=[123]')

     # The failed parses above must not have changed any value.
     self.assertEqual(10, hparams.aaa)
     self.assertEqual(1.5, hparams.b)
     self.assertEqual('2.3"', hparams.c_c)
     self.assertEqual('/a=b/c/d', hparams.d)
예제 #29
0
    def testJson(self):
        """parse_json() updates the values and to_json() serialises them,
        honouring the indent, separators and sort_keys arguments.
        """
        hparams = HParams(aaa=1, b=2.0, c_c='relu6', d=True)
        before = dict(aaa=1, b=2.0, c_c='relu6', d=True)
        self.assertDictEqual(before, hparams.values())
        self.assertEqual(1, hparams.aaa)
        self.assertEqual(2.0, hparams.b)
        self.assertEqual('relu6', hparams.c_c)

        hparams.parse_json('{"aaa": 12, "b": 3.0, "c_c": "relu4", "d": false}')
        after = dict(aaa=12, b=3.0, c_c='relu4', d=False)
        self.assertDictEqual(after, hparams.values())
        self.assertEqual(12, hparams.aaa)
        self.assertEqual(3.0, hparams.b)
        self.assertEqual('relu4', hparams.c_c)

        # Round trip: the JSON of one instance overrides every value of
        # another instance that was created with different values.
        serialized = hparams.to_json()
        receiver = HParams(aaa=10, b=20.0, c_c='hello', d=False)
        receiver.parse_json(serialized)
        self.assertEqual(12, receiver.aaa)
        self.assertEqual(3.0, receiver.b)
        self.assertEqual('relu4', receiver.c_c)
        self.assertEqual(False, receiver.d)

        # Formatting arguments of to_json().
        single = HParams(aaa=123)
        self.assertEqual('{"aaa": 123}', single.to_json())
        self.assertEqual('{\n  "aaa": 123\n}', single.to_json(indent=2))
        self.assertEqual('{"aaa"=123}',
                         single.to_json(separators=(';', '=')))

        mixed = HParams(aaa=123, b='hello', c_c=False)
        self.assertEqual('{"aaa": 123, "b": "hello", "c_c": false}',
                         mixed.to_json(sort_keys=True))
예제 #30
0
 def testBoolParsingFail(self):
     """Parsing an unrecognised value for a boolean hyperparameter raises
     a ValueError that names the hyperparameter.
     """
     hparams = HParams(use_gpu=True)
     # assertRaisesRegexp is a deprecated alias that was removed in
     # Python 3.12; assertRaisesRegex is the supported spelling.
     with self.assertRaisesRegex(ValueError, r'Could not parse.*use_gpu'):
         hparams.parse('use_gpu=yep')