class OptimizerConfig(oneof.OneOfConfig):
  """Configuration for optimizer.

  Attributes:
    type: 'str', type of optimizer to be used, one of the fields below.
    sgd: sgd optimizer config.
    sgd_experimental: experimental sgd optimizer config.
    adam: adam optimizer config.
    adamw: adam with weight decay.
    lamb: lamb optimizer.
    rmsprop: rmsprop optimizer.
    lars: lars optimizer.
    adagrad: adagrad optimizer.
    slide: slide optimizer.
    adafactor: adafactor optimizer.
  """
  type: Optional[str] = None
  sgd: opt_cfg.SGDConfig = opt_cfg.SGDConfig()
  sgd_experimental: opt_cfg.SGDExperimentalConfig = (
      opt_cfg.SGDExperimentalConfig())
  adam: opt_cfg.AdamConfig = opt_cfg.AdamConfig()
  adamw: opt_cfg.AdamWeightDecayConfig = opt_cfg.AdamWeightDecayConfig()
  lamb: opt_cfg.LAMBConfig = opt_cfg.LAMBConfig()
  rmsprop: opt_cfg.RMSPropConfig = opt_cfg.RMSPropConfig()
  lars: opt_cfg.LARSConfig = opt_cfg.LARSConfig()
  adagrad: opt_cfg.AdagradConfig = opt_cfg.AdagradConfig()
  slide: opt_cfg.SLIDEConfig = opt_cfg.SLIDEConfig()
  adafactor: opt_cfg.AdafactorConfig = opt_cfg.AdafactorConfig()
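Since OptimizerConfig is a oneof.OneOfConfig, the `type` string selects which sub-config field get() returns. A minimal usage sketch; the module path for the opt_cfg alias is assumed from the TensorFlow Model Garden layout and is not shown in the snippet:

# Minimal usage sketch; the module path below is an assumption based on the
# TensorFlow Model Garden layout and may differ in your checkout.
from official.modeling.optimization.configs import optimizer_config as opt_cfg

# Dict-based construction mirrors the test_config example further down:
# `type` names the active branch, the other fields keep their defaults.
opt = OptimizerConfig({
    'type': 'adamw',
    'adamw': {}  # default AdamWeightDecayConfig
})
assert opt.get() == opt_cfg.AdamWeightDecayConfig()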
Example #2
class OptimizerConfig(oneof.OneOfConfig):
    """Configuration for optimizer.

  Attributes:
    type: 'str', type of optimizer to be used, on the of fields below.
    sgd: sgd optimizer config.
    adam: adam optimizer config.
    adamw: adam with weight decay.
    lamb: lamb optimizer.
  """
    type: Optional[str] = None
    sgd: opt_cfg.SGDConfig = opt_cfg.SGDConfig()
    adam: opt_cfg.AdamConfig = opt_cfg.AdamConfig()
    adamw: opt_cfg.AdamWeightDecayConfig = opt_cfg.AdamWeightDecayConfig()
    lamb: opt_cfg.LAMBConfig = opt_cfg.LAMBConfig()
  def test_config(self):
    opt_config = optimization_config.OptimizationConfig({
        'optimizer': {
            'type': 'sgd',
            'sgd': {}  # default config
        },
        'learning_rate': {
            'type': 'polynomial',
            'polynomial': {}
        },
        'warmup': {
            'type': 'linear'
        }
    })
    self.assertEqual(opt_config.optimizer.get(), opt_cfg.SGDConfig())
    self.assertEqual(opt_config.learning_rate.get(),
                     lr_cfg.PolynomialLrConfig())
    self.assertEqual(opt_config.warmup.get(), lr_cfg.LinearWarmupConfig())
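Run as-is, the method above has no surrounding test case. A sketch of the scaffolding it would need follows; the module paths and the class name OptimizationConfigTest are assumptions based on the TensorFlow Model Garden layout, not part of the snippet:

# Assumed scaffolding for running test_config standalone; module paths and
# the test class name are guesses based on the TF Model Garden package layout.
import tensorflow as tf

from official.modeling.optimization.configs import learning_rate_config as lr_cfg
from official.modeling.optimization.configs import optimization_config
from official.modeling.optimization.configs import optimizer_config as opt_cfg


class OptimizationConfigTest(tf.test.TestCase):
  # The test_config method from the snippet above would live here.
  pass


if __name__ == '__main__':
  tf.test.main()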