Example 1
    def testSingletonPerGraph(self):
        config_str = """
      ConfigurableClass.kwarg1 = @obj1/singleton_per_graph()
      ConfigurableClass.kwarg2 = @obj2/singleton_per_graph()

      obj1/singleton_per_graph.constructor = @new_object
      obj2/singleton_per_graph.constructor = @new_object
    """
        config.parse_config(config_str)

        with tf.Graph().as_default():
            class1 = ConfigurableClass()
            class2 = ConfigurableClass()

        with tf.Graph().as_default():
            class3 = ConfigurableClass()
            class4 = ConfigurableClass()

        self.assertIs(class1.kwarg1, class2.kwarg1)
        self.assertIs(class1.kwarg2, class2.kwarg2)
        self.assertIsNot(class1.kwarg1, class1.kwarg2)
        self.assertIsNot(class2.kwarg1, class2.kwarg2)

        self.assertIs(class3.kwarg1, class4.kwarg1)
        self.assertIs(class3.kwarg2, class4.kwarg2)
        self.assertIsNot(class3.kwarg1, class3.kwarg2)
        self.assertIsNot(class4.kwarg1, class4.kwarg2)

        self.assertIsNot(class1.kwarg1, class3.kwarg1)
        self.assertIsNot(class1.kwarg2, class3.kwarg2)
        self.assertIsNot(class2.kwarg1, class4.kwarg1)
        self.assertIsNot(class2.kwarg2, class4.kwarg2)
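
The identity checks above rely on two module-level helpers that are not shown here. A minimal sketch, assuming `new_object` returns a fresh object on every call and `ConfigurableClass` simply stores its two keyword arguments (names and exact signatures are guesses, not the test file's real definitions):

import gin

@gin.configurable
def new_object():
    # Returns a distinct object per call, so the assertIs/assertIsNot checks
    # can tell shared singletons apart from independently built values.
    return object()

@gin.configurable
class ConfigurableClass(object):
    def __init__(self, kwarg1=None, kwarg2=None):
        self.kwarg1 = kwarg1
        self.kwarg2 = kwarg2
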
Example 2
    def run_log_config_hook_maybe_with_summary(self, global_step_value,
                                               **kwargs):
        config.parse_config(GinConfigSaverHookTest.CONFIG_STR)

        configurable_fn()
        ConfigurableClass()
        no_args_fn()

        if global_step_value is not None:
            tf.get_variable(
                'global_step',
                shape=(),
                dtype=tf.int64,
                initializer=tf.constant_initializer(global_step_value),
                trainable=False)

        output_dir = tempfile.mkdtemp()
        summary_writer = tf.contrib.testing.FakeSummaryWriter(output_dir)
        h = utils.GinConfigSaverHook(output_dir,
                                     summary_writer=summary_writer,
                                     **kwargs)
        with tf.train.MonitoredSession(hooks=[h]):
            pass

        return output_dir, summary_writer
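
The hook test above references a CONFIG_STR constant and a few trivial configurables that are defined elsewhere in the test file. Hypothetical stand-ins, shown only so the snippet reads in isolation:

import gin

@gin.configurable
def configurable_fn(value=0):
    return value

@gin.configurable
def no_args_fn():
    return None

# GinConfigSaverHookTest.CONFIG_STR would be a gin config string binding
# parameters of helpers like these, e.g. "configurable_fn.value = 7";
# ConfigurableClass is as sketched under Example 1.
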
Example 3
    def testConfigureOptimizerAndLearningRate(self):
        config_str = """
      fake_train_model.optimizer = @Adam
      torch.optim.Adam.lr = 0.001
      torch.optim.Adam.betas = (0.8, 0.888)
      fake_train_model.scheduler = @StepLR
      StepLR.step_size = 10
    """
        config.parse_config(config_str)

        opt, sch = fake_train_model()  # pylint: disable=no-value-for-parameter

        self.assertIsInstance(opt, torch.optim.Adam)
        self.assertAlmostEqual(opt.param_groups[0]['betas'][0], 0.8)
        self.assertAlmostEqual(opt.param_groups[0]['betas'][1], 0.888)
        self.assertAlmostEqual(opt.defaults['betas'][0], 0.8)
        self.assertAlmostEqual(opt.defaults['betas'][1], 0.888)
        self.assertAlmostEqual(sch.step_size, 10)

        lrs = []
        for _ in range(15):
            lrs.append(opt.param_groups[0]['lr'])
            opt.step()
            sch.step()

        # With step_size=10 and the default gamma=0.1, StepLR divides the lr by 10 after the tenth epoch.
        target_lrs = [0.001] * 10 + [0.0001] * 5

        self.assertAlmostEqualList(lrs, target_lrs)
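
fake_train_model is not defined in the snippet. One plausible shape, assuming it is gin-configurable and simply instantiates the injected optimizer (and, when configured, scheduler) classes on a dummy parameter; the real helper in the test file may differ:

import gin
import torch

@gin.configurable
def fake_train_model(optimizer=gin.REQUIRED, scheduler=None):
    params = [torch.nn.Parameter(torch.zeros(1))]
    opt = optimizer(params)                      # e.g. torch.optim.Adam with the bound lr/betas
    sch = scheduler(opt) if scheduler else None  # e.g. StepLR with the bound step_size
    return opt, sch
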
Example 4
  def testOptimizersWithDefaults(self):
    optimizers = [
        tf.compat.v1.train.GradientDescentOptimizer,
        tf.compat.v1.train.AdadeltaOptimizer,
        tf.compat.v1.train.AdagradOptimizer,
        (tf.compat.v1.train.AdagradDAOptimizer, {
            'global_step': '@get_global_step()'
        }),
        (tf.compat.v1.train.MomentumOptimizer, {
            'momentum': 0.9
        }),
        tf.compat.v1.train.AdamOptimizer,
        tf.compat.v1.train.FtrlOptimizer,
        tf.compat.v1.train.ProximalGradientDescentOptimizer,
        tf.compat.v1.train.ProximalAdagradOptimizer,
        tf.compat.v1.train.RMSPropOptimizer,
    ]
    constant_lr = lambda global_step: 0.01
    for optimizer in optimizers:
      extra_bindings = {}
      if isinstance(optimizer, tuple):
        optimizer, extra_bindings = optimizer

      config.clear_config()
      config_lines = ['fake_train_model.optimizer = @%s' % optimizer.__name__]
      for param, val in extra_bindings.items():
        config_lines.append('%s.%s = %s' % (optimizer.__name__, param, val))
      config.parse_config(config_lines)

      # pylint: disable=no-value-for-parameter
      _, configed_optimizer = fake_train_model(constant_lr)
      # pylint: enable=no-value-for-parameter
      self.assertIsInstance(configed_optimizer, optimizer)
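
On the TensorFlow side, fake_train_model is likewise assumed rather than shown. A plausible sketch, assuming it evaluates the injected learning-rate schedule at the global step and passes the result to the gin-injected optimizer class (illustrative only, not the test's actual helper):

import gin
import tensorflow as tf

@gin.configurable
def fake_train_model(learning_rate=gin.REQUIRED, optimizer=gin.REQUIRED):
    global_step = tf.compat.v1.train.get_or_create_global_step()
    lr = learning_rate(global_step)    # e.g. @piecewise_constant, or the constant_lr lambda above
    opt = optimizer(learning_rate=lr)  # e.g. @MomentumOptimizer with its bound momentum
    return lr, opt
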
Example 5
    def testCompatibilityWithDynamicRegistration(self):
        config_str = """
      from __gin__ import dynamic_registration
      from gin.tf import external_configurables
      import __main__

      __main__.configurable.arg = %tf.float32
    """
        config.parse_config(config_str)
        self.assertEqual(configurable(), {'arg': tf.float32})
Example 6
    def testDynamicRegistrationImportAsGinError(self):
        config_str = """
      from __gin__ import dynamic_registration
      import gin.tf.external_configurables
      import __main__

      __main__.configurable.arg = %gin.REQUIRED
    """
        expected_msg = 'The `gin` symbol is reserved; cannot bind import statement '
        with self.assertRaisesRegex(ValueError, expected_msg):
            config.parse_config(config_str)
Example 7
    def testDtypes(self):
        # Spot check a few.
        config_str = """
      # Test without tf prefix, but using the prefix is strongly recommended!
      configurable.float32 = %float32
      # Test with tf prefix.
      configurable.string = %tf.string
      configurable.qint8 = %tf.qint8
    """
        config.parse_config(config_str)

        vals = configurable()
        self.assertIs(vals['float32'], tf.float32)
        self.assertIs(vals['string'], tf.string)
        self.assertIs(vals['qint8'], tf.qint8)
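
The `configurable` helper used by the dtype and dynamic-registration tests is also assumed. One possible definition, guessing that it simply echoes whatever keyword arguments gin binds through **kwargs:

import gin

@gin.configurable
def configurable(**kwargs):
    return kwargs
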
Example 8
    def testDtypes(self):
        # Spot check a few.
        config_str = """
      # Test without torch prefix, but using the
      # prefix is strongly recommended!
      configurable.float32 = %float32
      # Test with torch prefix.
      configurable.int8 = %torch.int8
      configurable.float16 = %torch.float16
    """
        config.parse_config(config_str)

        vals = configurable()
        self.assertIs(vals['float32'], torch.float32)
        self.assertIs(vals['int8'], torch.int8)
        self.assertIs(vals['float16'], torch.float16)
Example 9
  def run_log_config_hook_maybe_with_summary(self, global_step_value, **kwargs):
    config.parse_config(GinConfigSaverHookTest.CONFIG_STR)

    configurable_fn()
    ConfigurableClass()
    no_args_fn()

    output_dir = tempfile.mkdtemp()
    summary_writer = FakeSummaryWriter()
    h = utils.GinConfigSaverHook(
        output_dir, summary_writer=summary_writer, **kwargs)
    with self.session() as sess:
      if global_step_value is not None:
        global_step = tf.compat.v1.train.get_or_create_global_step()
        sess.run(global_step.assign(global_step_value))
      h.after_create_session(sess)

    return output_dir, summary_writer
Example 10
  def testKwOnlyArgs(self):
    config_str = """
      fn_with_kw_only_args.arg1 = 'arg1'
      fn_with_kw_only_args.kwarg1 = 'kwarg1'
    """

    arg, kwarg = fn_with_kw_only_args(None)
    self.assertEqual(arg, None)
    self.assertEqual(kwarg, None)
    self.assertIn('fn_with_kw_only_args.kwarg1 = None',
                  config.operative_config_str())

    config.parse_config(config_str)

    arg, kwarg = fn_with_kw_only_args('arg1')
    self.assertEqual(arg, 'arg1')
    self.assertEqual(kwarg, 'kwarg1')
    self.assertIn("fn_with_kw_only_args.kwarg1 = 'kwarg1'",
                  config.operative_config_str())
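
A plausible definition of fn_with_kw_only_args, assuming one positional argument plus one keyword-only argument that gin can bind (illustrative, not the test file's real helper):

import gin

@gin.configurable
def fn_with_kw_only_args(arg1, *, kwarg1=None):
    return arg1, kwarg1
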
Example 11
    def testConfigureOptimizerAndLearningRate(self):
        config_str = """
      fake_train_model.learning_rate = @piecewise_constant
      piecewise_constant.boundaries = [200000]
      piecewise_constant.values = [0.01, 0.001]

      fake_train_model.optimizer = @MomentumOptimizer
      MomentumOptimizer.momentum = 0.95
    """
        config.parse_config(config_str)

        lr, opt = fake_train_model()  # pylint: disable=no-value-for-parameter

        self.assertIsInstance(opt, tf.compat.v1.train.MomentumOptimizer)
        self.assertAlmostEqual(opt._momentum, 0.95)

        global_step = tf.compat.v1.train.get_or_create_global_step()
        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.assertAlmostEqual(self.evaluate(lr), 0.01)
        self.evaluate(global_step.assign(300000))
        self.assertAlmostEqual(self.evaluate(lr), 0.001)
Example 12
    def testOptimizersWithDefaults(self):
        optimizers = [
            torch.optim.Adadelta,
            torch.optim.Adagrad,
            torch.optim.Adam,
            torch.optim.SparseAdam,
            torch.optim.Adamax,
            torch.optim.ASGD,
            torch.optim.LBFGS,
            torch.optim.RMSprop,
            torch.optim.Rprop,
            torch.optim.SGD,
        ]
        for optimizer in optimizers:
            config.clear_config()
            config_str = """
          fake_train_model.optimizer = @{optimizer}
          {optimizer}.lr = 0.001
        """
            config.parse_config(
                config_str.format(optimizer=optimizer.__name__))
            configed_optimizer, _ = fake_train_model(config.REQUIRED)
            self.assertIsInstance(configed_optimizer, optimizer)
Example 13
    def testConfigureOptimizerAndLearningRate(self):
        config_str = """
      fake_train_model.learning_rate = @piecewise_constant
      piecewise_constant.boundaries = [200000]
      piecewise_constant.values = [0.01, 0.001]

      fake_train_model.optimizer = @MomentumOptimizer
      MomentumOptimizer.momentum = 0.95
    """
        config.parse_config(config_str)

        lr, opt = fake_train_model()  # pylint: disable=no-value-for-parameter
        global_step = tf.contrib.framework.get_or_create_global_step()
        update_global_step = global_step.assign(300000)
        init = tf.global_variables_initializer()

        self.assertIsInstance(opt, tf.train.MomentumOptimizer)
        self.assertAlmostEqual(opt._momentum, 0.95)
        with self.test_session() as sess:
            sess.run(init)
            self.assertAlmostEqual(sess.run(lr), 0.01)
            sess.run(update_global_step)
            self.assertAlmostEqual(sess.run(lr), 0.001)