def testSingletonPerGraph(self):
    """Checks singleton_per_graph sharing within and across tf.Graphs.

    Two scoped singletons (obj1, obj2) are bound to ConfigurableClass's kwargs.
    Two instances are then built in each of two separate graphs.
    """
    config_str = """
        ConfigurableClass.kwarg1 = @obj1/singleton_per_graph()
        ConfigurableClass.kwarg2 = @obj2/singleton_per_graph()
        obj1/singleton_per_graph.constructor = @new_object
        obj2/singleton_per_graph.constructor = @new_object
    """
    config.parse_config(config_str)

    with tf.Graph().as_default():
        graph_a_first = ConfigurableClass()
        graph_a_second = ConfigurableClass()
    with tf.Graph().as_default():
        graph_b_first = ConfigurableClass()
        graph_b_second = ConfigurableClass()

    # Same graph, same scope: both instances must receive the identical object.
    for one, other in ((graph_a_first, graph_a_second),
                       (graph_b_first, graph_b_second)):
        self.assertIs(one.kwarg1, other.kwarg1)
        self.assertIs(one.kwarg2, other.kwarg2)

    # Different scopes (obj1 vs obj2) must never share an object.
    for instance in (graph_a_first, graph_a_second,
                     graph_b_first, graph_b_second):
        self.assertIsNot(instance.kwarg1, instance.kwarg2)

    # Different graphs must never share an object, even within one scope.
    for in_a, in_b in ((graph_a_first, graph_b_first),
                       (graph_a_second, graph_b_second)):
        self.assertIsNot(in_a.kwarg1, in_b.kwarg1)
        self.assertIsNot(in_a.kwarg2, in_b.kwarg2)
def run_log_config_hook_maybe_with_summary(self, global_step_value, **kwargs):
    """Runs a GinConfigSaverHook inside a MonitoredSession.

    Args:
      global_step_value: When not None, an int64 non-trainable variable named
        'global_step' is created with this initial value before the hook runs.
      **kwargs: Forwarded unchanged to `utils.GinConfigSaverHook`.

    Returns:
      `(output_dir, summary_writer)`: a freshly created temporary directory
      (the caller owns its cleanup) and the FakeSummaryWriter handed to the hook.
    """
    config.parse_config(GinConfigSaverHookTest.CONFIG_STR)
    # Presumably invoked so these configurables appear in the operative
    # config that the hook saves; confirm against CONFIG_STR.
    configurable_fn()
    ConfigurableClass()
    no_args_fn()

    if global_step_value is not None:
        step_initializer = tf.constant_initializer(global_step_value)
        tf.get_variable(
            'global_step',
            shape=(),
            dtype=tf.int64,
            initializer=step_initializer,
            trainable=False)

    output_dir = tempfile.mkdtemp()
    summary_writer = tf.contrib.testing.FakeSummaryWriter(output_dir)
    saver_hook = utils.GinConfigSaverHook(
        output_dir, summary_writer=summary_writer, **kwargs)

    # No steps are run: the hook is exercised only by the session being
    # created and then closed.
    with tf.train.MonitoredSession(hooks=[saver_hook]):
        pass

    return output_dir, summary_writer
def testConfigureOptimizerAndLearningRate(self):
    """Gin-configures Adam plus a StepLR scheduler and checks the LR schedule.

    Verifies that the configured Adam hyperparameters reach both
    `param_groups` and `defaults`. Also verifies that StepLR drops the
    learning rate by 10x after `step_size` scheduler steps.
    """
    config_str = """
        fake_train_model.optimizer = @Adam
        torch.optim.Adam.lr = 0.001
        torch.optim.Adam.betas = (0.8, 0.888)
        fake_train_model.scheduler = @StepLR
        StepLR.step_size = 10
    """
    config.parse_config(config_str)

    opt, sch = fake_train_model()  # pylint: disable=no-value-for-parameter

    self.assertIsInstance(opt, torch.optim.Adam)
    self.assertAlmostEqual(opt.param_groups[0]['betas'][0], 0.8)
    self.assertAlmostEqual(opt.param_groups[0]['betas'][1], 0.888)
    self.assertAlmostEqual(opt.defaults['betas'][0], 0.8)
    self.assertAlmostEqual(opt.defaults['betas'][1], 0.888)
    # step_size is an integer count of scheduler steps: compare it exactly.
    self.assertEqual(sch.step_size, 10)

    lrs = []
    for _ in range(15):
        # Record the LR in effect for this step before advancing the schedule.
        lrs.append(opt.param_groups[0]['lr'])
        opt.step()
        sch.step()

    # gamma is not configured here, so StepLR uses its default decay factor
    # (0.1 per the PyTorch docs). The first 10 recorded LRs are therefore
    # 0.001 and the remaining 5 are 0.0001.
    target_lrs = [0.001] * 10 + [0.0001] * 5
    self.assertAlmostEqualList(lrs, target_lrs)
def testOptimizersWithDefaults(self):
    """Each tf.compat.v1 optimizer can be selected via gin with minimal bindings.

    Entries in `optimizers` are either a bare optimizer class, or a
    `(class, extra_bindings)` pair. The pair form is for optimizers needing
    additional gin bindings (parameter name -> gin value string).
    """
    optimizers = [
        tf.compat.v1.train.GradientDescentOptimizer,
        tf.compat.v1.train.AdadeltaOptimizer,
        tf.compat.v1.train.AdagradOptimizer,
        (tf.compat.v1.train.AdagradDAOptimizer, {
            'global_step': '@get_global_step()'
        }),
        (tf.compat.v1.train.MomentumOptimizer, {
            'momentum': 0.9
        }),
        tf.compat.v1.train.AdamOptimizer,
        tf.compat.v1.train.FtrlOptimizer,
        tf.compat.v1.train.ProximalGradientDescentOptimizer,
        tf.compat.v1.train.ProximalAdagradOptimizer,
        tf.compat.v1.train.RMSPropOptimizer,
    ]

    def constant_lr(global_step):
        """Learning-rate function that ignores the step and returns 0.01."""
        del global_step  # Unused.
        return 0.01

    for entry in optimizers:
        if isinstance(entry, tuple):
            optimizer_cls, extra_bindings = entry
        else:
            optimizer_cls, extra_bindings = entry, {}
        name = optimizer_cls.__name__

        config.clear_config()
        config_lines = ['fake_train_model.optimizer = @%s' % name]
        config_lines.extend(
            '%s.%s = %s' % (name, param, val)
            for param, val in extra_bindings.items())
        config.parse_config(config_lines)

        # pylint: disable=no-value-for-parameter
        _, configed_optimizer = fake_train_model(constant_lr)
        # pylint: enable=no-value-for-parameter
        self.assertIsInstance(configed_optimizer, optimizer_cls)
def testCompatibilityWithDynamicRegistration(self):
    """A `%tf.float32` macro resolves under dynamic-registration configs."""
    config_str = """
        from __gin__ import dynamic_registration
        from gin.tf import external_configurables
        import __main__
        __main__.configurable.arg = %tf.float32
    """
    config.parse_config(config_str)

    expected_kwargs = {'arg': tf.float32}
    self.assertEqual(configurable(), expected_kwargs)
def testDynamicRegistrationImportAsGinError(self):
    """Parsing fails with ValueError when an import would bind the name `gin`.

    Under dynamic registration, `import gin.tf.external_configurables` would
    bind the top-level `gin` symbol, which the parser reserves.
    """
    expected_msg = 'The `gin` symbol is reserved; cannot bind import statement '
    config_str = """
        from __gin__ import dynamic_registration
        import gin.tf.external_configurables
        import __main__
        __main__.configurable.arg = %gin.REQUIRED
    """
    with self.assertRaisesRegex(ValueError, expected_msg):
        config.parse_config(config_str)
def testDtypes(self):
    """Spot-checks that TensorFlow dtype macros resolve to the tf dtype objects."""
    config_str = """
        # Test without tf prefix, but using the prefix is strongly recommended!
        configurable.float32 = %float32
        # Test with tf prefix.
        configurable.string = %tf.string
        configurable.qint8 = %tf.qint8
    """
    config.parse_config(config_str)
    vals = configurable()

    # Identity (not equality) is checked: the macros must yield the exact
    # module-level dtype objects.
    expected_dtypes = (
        ('float32', tf.float32),
        ('string', tf.string),
        ('qint8', tf.qint8),
    )
    for key, dtype in expected_dtypes:
        self.assertIs(vals[key], dtype)
def testDtypes(self):
    """Spot-checks that PyTorch dtype macros resolve to the torch dtype objects."""
    config_str = """
        # Test without torch prefix, but using the
        # prefix is strongly recommended!
        configurable.float32 = %float32
        # Test with torch prefix.
        configurable.int8 = %torch.int8
        configurable.float16 = %torch.float16
    """
    config.parse_config(config_str)
    vals = configurable()

    # Identity (not equality) is checked: the macros must yield the exact
    # torch dtype singletons.
    expected_dtypes = (
        ('float32', torch.float32),
        ('int8', torch.int8),
        ('float16', torch.float16),
    )
    for key, dtype in expected_dtypes:
        self.assertIs(vals[key], dtype)
def run_log_config_hook_maybe_with_summary(self, global_step_value, **kwargs):
    """Calls a GinConfigSaverHook's after_create_session on a test session.

    Args:
      global_step_value: When not None, the global step is created (or
        fetched) and assigned this value before the hook is invoked.
      **kwargs: Forwarded unchanged to `utils.GinConfigSaverHook`.

    Returns:
      `(output_dir, summary_writer)`: a freshly created temporary directory
      (the caller owns its cleanup) and the FakeSummaryWriter handed to the hook.
    """
    config.parse_config(GinConfigSaverHookTest.CONFIG_STR)
    # Presumably invoked so these configurables appear in the operative
    # config that the hook saves; confirm against CONFIG_STR.
    configurable_fn()
    ConfigurableClass()
    no_args_fn()

    output_dir = tempfile.mkdtemp()
    summary_writer = FakeSummaryWriter()
    saver_hook = utils.GinConfigSaverHook(
        output_dir, summary_writer=summary_writer, **kwargs)

    with self.session() as sess:
        if global_step_value is not None:
            step_var = tf.compat.v1.train.get_or_create_global_step()
            sess.run(step_var.assign(global_step_value))
        # The hook callback is driven directly instead of via a
        # MonitoredSession.
        saver_hook.after_create_session(sess)

    return output_dir, summary_writer
def testKwOnlyArgs(self):
    """Keyword-only args can be bound and show up in the operative config."""
    # With no bindings parsed yet, kwarg1 comes back as None and is
    # recorded as such.
    arg, kwarg = fn_with_kw_only_args(None)
    self.assertEqual(arg, None)
    self.assertEqual(kwarg, None)
    self.assertIn('fn_with_kw_only_args.kwarg1 = None',
                  config.operative_config_str())

    bindings = """
        fn_with_kw_only_args.arg1 = 'arg1'
        fn_with_kw_only_args.kwarg1 = 'kwarg1'
    """
    config.parse_config(bindings)

    # After parsing, the keyword-only binding is applied and recorded.
    arg, kwarg = fn_with_kw_only_args('arg1')
    self.assertEqual(arg, 'arg1')
    self.assertEqual(kwarg, 'kwarg1')
    self.assertIn("fn_with_kw_only_args.kwarg1 = 'kwarg1'",
                  config.operative_config_str())
def testConfigureOptimizerAndLearningRate(self):
    """Gin-configures a piecewise-constant LR and a MomentumOptimizer.

    The learning rate should be 0.01 below the 200000-step boundary and
    0.001 once the global step is moved past it.
    """
    config_str = """
        fake_train_model.learning_rate = @piecewise_constant
        piecewise_constant.boundaries = [200000]
        piecewise_constant.values = [0.01, 0.001]
        fake_train_model.optimizer = @MomentumOptimizer
        MomentumOptimizer.momentum = 0.95
    """
    config.parse_config(config_str)
    lr, opt = fake_train_model()  # pylint: disable=no-value-for-parameter

    self.assertIsInstance(opt, tf.compat.v1.train.MomentumOptimizer)
    self.assertAlmostEqual(opt._momentum, 0.95)  # pylint: disable=protected-access

    step = tf.compat.v1.train.get_or_create_global_step()
    self.evaluate(tf.compat.v1.global_variables_initializer())

    # Freshly initialized step: below the boundary, so the first value applies.
    self.assertAlmostEqual(self.evaluate(lr), 0.01)

    # Move the step past the boundary: the second value applies.
    self.evaluate(step.assign(300000))
    self.assertAlmostEqual(self.evaluate(lr), 0.001)
def testOptimizersWithDefaults(self):
    """Each torch.optim optimizer can be selected via gin with only `lr` bound."""
    optimizer_classes = (
        torch.optim.Adadelta,
        torch.optim.Adagrad,
        torch.optim.Adam,
        torch.optim.SparseAdam,
        torch.optim.Adamax,
        torch.optim.ASGD,
        torch.optim.LBFGS,
        torch.optim.RMSprop,
        torch.optim.Rprop,
        torch.optim.SGD,
    )
    # `{optimizer}` is filled with the class name for each iteration.
    template = """
        fake_train_model.optimizer = @{optimizer}
        {optimizer}.lr = 0.001
    """
    for optimizer_cls in optimizer_classes:
        config.clear_config()
        config.parse_config(template.format(optimizer=optimizer_cls.__name__))
        # config.REQUIRED marks the optimizer as supplied by the gin bindings.
        configed_optimizer, _ = fake_train_model(config.REQUIRED)
        self.assertIsInstance(configed_optimizer, optimizer_cls)
def testConfigureOptimizerAndLearningRate(self):
    """Gin-configures a piecewise-constant LR and a MomentumOptimizer (TF1 graph).

    The learning rate should be 0.01 initially and 0.001 once the global
    step is assigned past the 200000-step boundary.
    """
    config_str = """
        fake_train_model.learning_rate = @piecewise_constant
        piecewise_constant.boundaries = [200000]
        piecewise_constant.values = [0.01, 0.001]
        fake_train_model.optimizer = @MomentumOptimizer
        MomentumOptimizer.momentum = 0.95
    """
    config.parse_config(config_str)
    lr, opt = fake_train_model()  # pylint: disable=no-value-for-parameter

    # Build the graph ops before opening the session. The initializer is
    # created after the global step so that variable is included in it.
    step = tf.contrib.framework.get_or_create_global_step()
    jump_past_boundary = step.assign(300000)
    init_op = tf.global_variables_initializer()

    self.assertIsInstance(opt, tf.train.MomentumOptimizer)
    self.assertAlmostEqual(opt._momentum, 0.95)  # pylint: disable=protected-access

    with self.test_session() as sess:
        sess.run(init_op)
        self.assertAlmostEqual(sess.run(lr), 0.01)
        sess.run(jump_past_boundary)
        self.assertAlmostEqual(sess.run(lr), 0.001)