# Example no. 1 (score: 0)
  def testOptimizersWithDefaults(self):
    """Checks that every stock TF v1 optimizer can be configured by reference.

    Each entry is either a bare optimizer class or a (class, bindings) pair
    for optimizers whose constructors need more than a learning rate.
    """
    cases = [
        tf.compat.v1.train.GradientDescentOptimizer,
        tf.compat.v1.train.AdadeltaOptimizer,
        tf.compat.v1.train.AdagradOptimizer,
        (tf.compat.v1.train.AdagradDAOptimizer, {
            'global_step': '@get_global_step()'
        }),
        (tf.compat.v1.train.MomentumOptimizer, {
            'momentum': 0.9
        }),
        tf.compat.v1.train.AdamOptimizer,
        tf.compat.v1.train.FtrlOptimizer,
        tf.compat.v1.train.ProximalGradientDescentOptimizer,
        tf.compat.v1.train.ProximalAdagradOptimizer,
        tf.compat.v1.train.RMSPropOptimizer,
    ]
    constant_lr = lambda global_step: 0.01
    for case in cases:
      # Normalize: a bare class means no extra constructor bindings.
      if isinstance(case, tuple):
        opt_cls, extra_bindings = case
      else:
        opt_cls, extra_bindings = case, {}

      config.clear_config()
      name = opt_cls.__name__
      bindings = ['fake_train_model.optimizer = @{}'.format(name)]
      bindings.extend(
          '{}.{} = {}'.format(name, param, val)
          for param, val in extra_bindings.items())
      config.parse_config(bindings)

      # Remaining parameters are supplied by the parsed config bindings.
      # pylint: disable=no-value-for-parameter
      _, configed_optimizer = fake_train_model(constant_lr)
      # pylint: enable=no-value-for-parameter
      self.assertIsInstance(configed_optimizer, opt_cls)
 def testOptimizersWithDefaults(self):
     """Checks that every stock torch.optim optimizer is configurable.

     Binds each optimizer class plus a learning rate via the config
     system, then verifies the model factory receives an instance of it.
     """
     optimizer_classes = [
         torch.optim.Adadelta,
         torch.optim.Adagrad,
         torch.optim.Adam,
         torch.optim.SparseAdam,
         torch.optim.Adamax,
         torch.optim.ASGD,
         torch.optim.LBFGS,
         torch.optim.RMSprop,
         torch.optim.Rprop,
         torch.optim.SGD,
     ]
     # The binding template is loop-invariant; format it per optimizer.
     template = """
     fake_train_model.optimizer = @{optimizer}
     {optimizer}.lr = 0.001
   """
     for opt_cls in optimizer_classes:
         config.clear_config()
         config.parse_config(template.format(optimizer=opt_cls.__name__))
         configed_optimizer, _ = fake_train_model(config.REQUIRED)
         self.assertIsInstance(configed_optimizer, opt_cls)
# Example no. 3 (score: 0)
 def setUp(self):
     """Starts each test from an empty config and a fresh default graph."""
     config.clear_config()
     tf.reset_default_graph()
 def tearDown(self):
     """Clears config bindings, then runs the base class teardown."""
     config.clear_config()
     # Zero-argument super() — consistent with the bare super() calls used
     # elsewhere in this file; behavior-identical on Python 3.
     super().tearDown()
 def tearDown(self):
     """Clears config bindings, then runs the base class teardown."""
     config.clear_config()
     # Zero-argument super() — consistent with the bare super() calls used
     # elsewhere in this file; behavior-identical on Python 3.
     super().tearDown()
# Example no. 6 (score: 0)
 def tearDown(self):
     """Drops all parsed config bindings after each test."""
     config.clear_config()
# Example no. 7 (score: 0)
 def setUp(self):
   """Gives each test an empty config and a fresh TF v1 default graph."""
   super().setUp()
   config.clear_config()
   tf.compat.v1.reset_default_graph()
# Example no. 8 (score: 0)
 def setUp(self):
   """Runs each test in graph mode with empty config and a fresh graph."""
   super().setUp()
   tf.compat.v1.disable_eager_execution()
   config.clear_config()
   tf.compat.v1.reset_default_graph()