Example No. 1
0
  def test_TFParams(self):
    """Merging namespace args w/ ML Params

    Values set through the HasBatchSize/HasSteps setters must show up next to
    the original namespace entries in the result of merge_args_params().
    """

    class Foo(TFParams, HasBatchSize, HasSteps):
      # Minimal params holder: mixes in the setters and stores the namespace.
      def __init__(self, args):
        super(Foo, self).__init__()
        self.args = args

    base_ns = Namespace({'a': 1, 'b': 2})
    # The setter chain's final return value is what gets merged.
    params = Foo(base_ns).setBatchSize(10).setSteps(100)
    merged = params.merge_args_params()

    want = Namespace({
      'a': 1,
      'b': 2,
      'batch_size': 10,
      'steps': 100,
    })
    self.assertEqual(merged, want)
Example No. 2
0
 def test_namespace(self):
   """Namespace class from dict

   Every key of the source dict must be readable as an attribute, and the
   ``in`` operator must report which keys are present.
   """
   d = { 'string': 'foo', 'integer': 1, 'float': 3.14, 'array': [1,2,3], 'map': {'a':1, 'b':2} }
   n = Namespace(d)
   self.assertEqual(n.string, 'foo')
   self.assertEqual(n.integer, 1)
   self.assertEqual(n.float, 3.14)
   self.assertEqual(n.array, [1,2,3])
   self.assertEqual(n.map, {'a':1, 'b':2})
   # assertIn/assertNotIn perform the same `in` check as assertTrue(x in n),
   # but include the member and container in the failure message.
   self.assertIn('string', n)
   self.assertNotIn('extra', n)
Example No. 3
0
    def test_namespace(self):
        """Namespace class initializers

        Covers construction from a dict, from another Namespace, and from an
        argv-style list of strings.
        """
        d = {
            'string': 'foo',
            'integer': 1,
            'float': 3.14,
            'array': [1, 2, 3],
            'map': {
                'a': 1,
                'b': 2
            }
        }

        def check_dict_contents(ns):
            # Each key of ``d`` must be exposed as an attribute, and ``in``
            # must report present/absent keys. assertIn/assertNotIn give
            # clearer failure messages than assertTrue/assertFalse(x in ns).
            self.assertEqual(ns.string, 'foo')
            self.assertEqual(ns.integer, 1)
            self.assertEqual(ns.float, 3.14)
            self.assertEqual(ns.array, [1, 2, 3])
            self.assertEqual(ns.map, {'a': 1, 'b': 2})
            self.assertIn('string', ns)
            self.assertNotIn('extra', ns)

        # from dictionary
        n1 = Namespace(d)
        check_dict_contents(n1)

        # from namespace: must expose the same contents as the source namespace
        n2 = Namespace(n1)
        check_dict_contents(n2)

        # from argv list: the list is exposed (equal) as the ``argv`` attribute
        argv = ["--foo", "1", "--bar", "test", "--baz", "3.14"]
        n3 = Namespace(argv)
        self.assertEqual(n3.argv, argv)