def testHardOscillatorSchedule(self):
    """Check `HardOscillatorSchedule` against hand-written sample points.

    Two configurations are exercised: one with `transition_fraction=0.5`
    and one with `transition_fraction=0` (an instantaneous step). Each list
    of pairs is handed to `self.ScheduleTestHelper` together with the config
    and the schedule class; presumably each pair is (step, expected value)
    -- confirm against the helper.
    """
    schedule_cls = schedules.HardOscillatorSchedule

    ramped_config = config_lib.Config(
        fn='hard_osc', high=2, low=0, start_time=100, period=10,
        transition_fraction=0.5)
    ramped_points = [
        (0, 2), (1, 2), (10, 2), (100, 2),
        (101, 1.2), (102, 0.4), (103, 0), (104, 0), (105, 0),
        (106, 0.8), (107, 1.6), (108, 2), (109, 2), (110, 2),
        (111, 1.2), (112, 0.4), (115, 0), (116, 0.8), (119, 2), (120, 2),
        (100001, 1.2), (100002, 0.4), (100005, 0), (100006, 0.8),
        (100010, 2),
    ]
    self.ScheduleTestHelper(ramped_config, schedule_cls, ramped_points)

    # Test instantaneous step.
    step_config = config_lib.Config(
        fn='hard_osc', high=2, low=0, start_time=100, period=10,
        transition_fraction=0)
    step_points = [
        (0, 2), (1, 2), (10, 2), (99, 2),
        (100, 0), (104, 0), (105, 2), (106, 2), (109, 2), (110, 0),
    ]
    self.ScheduleTestHelper(step_config, schedule_cls, step_points)
def make_config():
    """Return a fresh config with a three-way `OneOf` under `data`.

    Returns:
        A `config_lib.Config` with a `data` entry (a `config_lib.OneOf` over
        three task configs, constructed with `task=2`) and a `model` entry.
    """
    task_options = [
        config_lib.Config(task=1, a='hello'),
        config_lib.Config(task=2, a='world', b='stuff'),
        config_lib.Config(task=3, c=1234),
    ]
    data_choice = config_lib.OneOf(task_options, task=2)
    model_config = config_lib.Config(stuff=1)
    return config_lib.Config(data=data_choice, model=model_config)
def testConfigStrictUpdateFail(self):
    """Check that invalid `strict_update` calls raise the expected errors.

    The test expects `KeyError` when an update names a key that the config
    (or a nested config) does not already hold, and `TypeError` when the new
    value's type does not match the existing one. All calls are made against
    the same config object, in the order listed.
    """
    config = config_lib.Config(
        a=1, b=2, c=3, x=config_lib.Config(a=1, b=2))

    # Each entry: (expected exception, positional args, keyword args).
    failing_calls = [
        (KeyError, ({'b': 10, 'c': 20, 'd': 50},), {}),
        (KeyError, (), {'b': 10, 'd': 50}),
        (KeyError, (), {'x': {'c': 3}}),
        (TypeError, (), {'a': 'string'}),
        (TypeError, (), {'x': {'a': 'string'}}),
        (TypeError, (), {'x': [1, 2, 3]}),
    ]
    for expected_error, args, kwargs in failing_calls:
        with self.assertRaises(expected_error):
            config.strict_update(*args, **kwargs)
def testConfig(self):
    """Check attribute-style and item-style access on `config_lib.Config`.

    Values are read and written through both `config.name` and
    `config['name']`, and the final length is checked to include the key
    added after construction.
    """
    initial_fields = dict(hello='world', foo='bar', num=123, f=56.7)
    config = config_lib.Config(**initial_fields)

    # Reads work through either access style.
    self.assertEqual('world', config.hello)
    self.assertEqual('bar', config['foo'])

    # Write through one style, read back through the other.
    config.hello = 'everyone'
    config['bar'] = 9000
    self.assertEqual('everyone', config['hello'])
    self.assertEqual(9000, config.bar)

    # Four initial keys plus the newly added 'bar'.
    self.assertEqual(5, len(config))
def testLinearDecaySchedule(self):
    """Check `LinearDecaySchedule` against hand-written sample points.

    Covers a decay from 2 to 0 between start_time=10 and end_time=20, and
    the degenerate case where start_time equals end_time. The pairs are
    passed to `self.ScheduleTestHelper`; presumably each is
    (step, expected value) -- confirm against the helper.
    """
    schedule_cls = schedules.LinearDecaySchedule

    ramp_config = config_lib.Config(
        fn='linear_decay', initial=2, final=0, start_time=10, end_time=20)
    ramp_points = [
        (0, 2), (1, 2), (10, 2), (11, 1.8), (15, 1), (19, 0.2), (20, 0),
        (100000, 0),
    ]
    self.ScheduleTestHelper(ramp_config, schedule_cls, ramp_points)

    # Test step function.
    step_config = config_lib.Config(
        fn='linear_decay', initial=2, final=0, start_time=10, end_time=10)
    step_points = [(0, 2), (1, 2), (10, 2), (11, 0), (15, 0)]
    self.ScheduleTestHelper(step_config, schedule_cls, step_points)
def testExponentialDecaySchedule(self):
    """Check `ExponentialDecaySchedule` against hand-written sample points.

    The endpoints are exp(-1) and exp(-6). The intermediate expected values
    are written as exp() of a linearly changing exponent. A second config
    with start_time equal to end_time covers the step case. The pairs are
    passed to `self.ScheduleTestHelper`; presumably each is
    (step, expected value) -- confirm against the helper.
    """
    schedule_cls = schedules.ExponentialDecaySchedule
    initial_value = exp(-1)
    final_value = exp(-6)

    decay_config = config_lib.Config(
        fn='exp_decay', initial=initial_value, final=final_value,
        start_time=10, end_time=20)
    decay_points = [
        (0, initial_value),
        (1, initial_value),
        (10, initial_value),
        (11, exp(-1 / 2. - 1)),
        (15, exp(-5 / 2. - 1)),
        (19, exp(-9 / 2. - 1)),
        (20, final_value),
        (100000, final_value),
    ]
    self.ScheduleTestHelper(decay_config, schedule_cls, decay_points)

    # Test step function.
    step_config = config_lib.Config(
        fn='exp_decay', initial=initial_value, final=final_value,
        start_time=10, end_time=10)
    step_points = [
        (0, initial_value),
        (1, initial_value),
        (10, initial_value),
        (11, final_value),
        (15, final_value),
    ]
    self.ScheduleTestHelper(step_config, schedule_cls, step_points)
def testSmootherstepDecaySchedule(self):
    """Check `SmootherstepDecaySchedule` against hand-written sample points.

    Covers a decay from 2 to 0 between start_time=10 and end_time=20, and
    the degenerate case where start_time equals end_time. The pairs are
    passed to `self.ScheduleTestHelper`; presumably each is
    (step, expected value) -- confirm against the helper.
    """
    schedule_cls = schedules.SmootherstepDecaySchedule

    smooth_config = config_lib.Config(
        fn='smooth_decay', initial=2, final=0, start_time=10, end_time=20)
    smooth_points = [
        (0, 2), (1, 2), (10, 2), (11, 1.98288), (15, 1), (19, 0.01712),
        (20, 0), (100000, 0),
    ]
    self.ScheduleTestHelper(smooth_config, schedule_cls, smooth_points)

    # Test step function.
    step_config = config_lib.Config(
        fn='smooth_decay', initial=2, final=0, start_time=10, end_time=10)
    step_points = [(0, 2), (1, 2), (10, 2), (11, 0), (15, 0)]
    self.ScheduleTestHelper(step_config, schedule_cls, step_points)
def default_config():
    """Build the default experiment configuration.

    Returns:
        A `config_lib.Config` with four entries: `agent` (a
        `config_lib.OneOf` over the 'pg', 'ga' and 'rand' algorithm configs,
        constructed with `algorithm='pg'`), `env`, `batch_size` and
        `timestep_limit`.
    """
    pg_agent = config_lib.Config(
        algorithm='pg',
        policy_lstm_sizes=[35, 35],
        # Set value_lstm_sizes to None to share weights with policy.
        value_lstm_sizes=[35, 35],
        obs_embedding_size=10,
        grad_clip_threshold=10.0,
        param_init_factor=1.0,
        lr=5e-5,
        pi_loss_hparam=1.0,
        vf_loss_hparam=0.5,
        entropy_beta=1e-2,
        regularizer=0.0,
        softmax_tr=1.0,  # Reciprocal temperature.
        optimizer='rmsprop',  # 'adam', 'sgd', 'rmsprop'
        topk=0,  # Top-k unique codes will be stored.
        topk_loss_hparam=0.0,  # Off-policy loss multiplier.
        # Uniformly sample this many episodes from the topk buffer per
        # batch. If topk is 0, this has no effect.
        topk_batch_size=1,
        # Exponential moving average baseline for REINFORCE.
        # If zero, A2C is used.
        # If non-zero, should be close to 1, like .99, .999, etc.
        ema_baseline_decay=0.99,
        # Whether the agent can emit an EOS token. If true, the agent can
        # emit an EOS token which ends the episode early (ends the
        # sequence). If false, the agent must emit tokens until the
        # timestep limit is reached. E.g. True means variable length code,
        # False means fixed length code.
        # WARNING: Making this false slows things down.
        eos_token=False,
        replay_temperature=1.0,
        # Replay probability. 1 = always replay, 0 = always on policy.
        alpha=0.0,
        # Whether to normalize importance weights in each minibatch.
        iw_normalize=True)
    ga_agent = config_lib.Config(
        algorithm='ga',
        crossover_rate=0.99,
        mutation_rate=0.086)
    rand_agent = config_lib.Config(algorithm='rand')
    agent_choice = config_lib.OneOf(
        [pg_agent, ga_agent, rand_agent],
        algorithm='pg',
    )

    task_manager = config_lib.Config(
        # Reward received per test case. These bonuses will be scaled
        # based on how many test cases there are.
        correct_bonus=2.0,  # Bonus for code getting correct answer.
        code_length_bonus=1.0)  # Maximum bonus for short code.
    env_config = config_lib.Config(
        # NOTE(review): the comment originally placed here reads "If True,
        # task-specific settings are not needed."; it is unclear which
        # setting it describes -- confirm or remove.
        task='',  # 'print', 'echo', 'reverse', 'remove', ...
        task_cycle=[],  # If non-empty, repetitions will cycle through tasks.
        task_kwargs='{}',  # Python dict literal.
        task_manager_config=task_manager,
        correct_syntax=False,
    )

    return config_lib.Config(
        agent=agent_choice,
        env=env_config,
        batch_size=64,
        timestep_limit=32)
def make_config():
    """Return a fresh config whose `data` OneOf contains a nested OneOf.

    Returns:
        A `config_lib.Config` with a `data` entry (a `config_lib.OneOf` over
        three task configs, constructed with `task=2`; the task-2 option's
        `a` field is itself a `OneOf` constructed with `x=1`) and a `model`
        entry.
    """
    inner_choice = config_lib.OneOf(
        [config_lib.Config(x=1, y=2),
         config_lib.Config(x=-1, y=1000, z=4)],
        x=1)
    task_options = [
        config_lib.Config(task=1, a='hello'),
        config_lib.Config(task=2, a=inner_choice),
        config_lib.Config(task=3, c=1234),
    ]
    data_choice = config_lib.OneOf(task_options, task=2)
    return config_lib.Config(
        data=data_choice,
        model=config_lib.Config(stuff=1))
def __init__(self, config):
    """Set up an exponential decay schedule from `config`.

    Args:
        config: Object exposing `initial`, `final`, `start_time` and
            `end_time` attributes. It is also forwarded to the parent
            class constructor.

    Raises:
        ValueError: If `initial` or `final` is less than or equal to zero.
    """
    super(ExponentialDecaySchedule, self).__init__(config)
    self.initial = config.initial
    self.final = config.final
    self.start_time = config.start_time
    self.end_time = config.end_time

    # Both endpoints go through math.log below, so they must be positive.
    if self.initial <= 0 or self.final <= 0:
        raise ValueError('initial and final must be positive numbers.')

    # Linear interpolation in log space.
    log_space_config = config_lib.Config(
        initial=math.log(self.initial),
        final=math.log(self.final),
        start_time=self.start_time,
        end_time=self.end_time)
    self._linear_fn = LinearDecaySchedule(log_space_config)
def testNestedOneOf(self):
    """Check `update` and `strict_update` on a config with nested `OneOf`s.

    For each parsed update string, a fresh config is updated first with
    `update` and then (on another fresh config) with `strict_update`; both
    are expected to produce the same resolved config.
    """
    def make_config():
        inner_choice = config_lib.OneOf(
            [config_lib.Config(x=1, y=2),
             config_lib.Config(x=-1, y=1000, z=4)],
            x=1)
        outer_choice = config_lib.OneOf(
            [config_lib.Config(task=1, a='hello'),
             config_lib.Config(task=2, a=inner_choice),
             config_lib.Config(task=3, c=1234)],
            task=2)
        return config_lib.Config(
            data=outer_choice,
            model=config_lib.Config(stuff=1))

    # Each entry: (update string, expected config after the update).
    cases = [
        # Select the second inner option and override one of its fields.
        ('model=c(stuff=2),data=c(task=2,a=c(x=-1,z=8))',
         config_lib.Config(
             data=config_lib.Config(
                 task=2,
                 a=config_lib.Config(x=-1, y=1000, z=8)),
             model=config_lib.Config(stuff=2))),
        # Leave `data` untouched.
        ('model=c(stuff=2)',
         config_lib.Config(
             data=config_lib.Config(
                 task=2,
                 a=config_lib.Config(x=1, y=2)),
             model=config_lib.Config(stuff=2))),
    ]

    for update_string, expected in cases:
        for use_strict in (False, True):
            config = make_config()
            parsed = config_lib.Config.parse(update_string)
            if use_strict:
                config.strict_update(parsed)
            else:
                config.update(parsed)
            self.assertEqual(expected, config)
def testOneOfStrict(self):
    """Check `strict_update` on a config whose `data` entry is a `OneOf`.

    Each case applies a parsed update string to a fresh config and compares
    the result with the expected `data` config; `model` is expected to be
    `Config(stuff=2)` in every case.
    """
    def make_config():
        task_options = [
            config_lib.Config(task=1, a='hello'),
            config_lib.Config(task=2, a='world', b='stuff'),
            config_lib.Config(task=3, c=1234),
        ]
        return config_lib.Config(
            data=config_lib.OneOf(task_options, task=2),
            model=config_lib.Config(stuff=1))

    # Each entry: (update string, expected `data` config after the update).
    cases = [
        ('model=c(stuff=2),data=c(task=1,a="hi")',
         config_lib.Config(task=1, a='hi')),
        ('model=c(stuff=2),data=c(task=2,a="hi")',
         config_lib.Config(task=2, a='hi', b='stuff')),
        ('model=c(stuff=2),data=c(task=3)',
         config_lib.Config(task=3, c=1234)),
        ('model=c(stuff=2)',
         config_lib.Config(task=2, a='world', b='stuff')),
    ]

    for update_string, expected_data in cases:
        config = make_config()
        config.strict_update(config_lib.Config.parse(update_string))
        expected = config_lib.Config(
            data=expected_data,
            model=config_lib.Config(stuff=2))
        self.assertEqual(expected, config)
def testConfigUpdate(self):
    """Check `Config.update` with flat, nested and type-changing updates.

    The assertions expect that `update` accepts a positional mapping,
    keyword arguments, or both; that dict values are merged into existing
    nested `Config`s while keeping them `Config` instances; that new keys
    may be added; and that a value may be replaced by one of another type.
    """
    def flat_config():
        return config_lib.Config(a=1, b=2, c=3)

    def mixed_config():
        return config_lib.Config(
            foo='bar',
            num=100,
            x=config_lib.Config(
                a=1, b=2, c=config_lib.Config(h=10, i=20, j=30)),
            y=config_lib.Config(qrs=5, tuv=10),
            d={'a': 1, 'b': 2},
            l=[1, 2, 3])

    # Positional mapping only.
    config = flat_config()
    config.update({'b': 10, 'd': 4})
    self.assertEqual({'a': 1, 'b': 10, 'c': 3, 'd': 4}, config)

    # Keyword arguments only.
    config = flat_config()
    config.update(b=10, d=4)
    self.assertEqual({'a': 1, 'b': 10, 'c': 3, 'd': 4}, config)

    # Positional mapping combined with keyword arguments.
    config = flat_config()
    config.update({'e': 5}, b=10, d=4)
    self.assertEqual({'a': 1, 'b': 10, 'c': 3, 'd': 4, 'e': 5}, config)

    # Deeply nested update given as plain dicts.
    config = config_lib.Config(
        a=1,
        b=2,
        x=config_lib.Config(
            l='a',
            y=config_lib.Config(m=1, n=2),
            z=config_lib.Config(
                q=config_lib.Config(a=10, b=20),
                r=config_lib.Config(s=1, t=2))))
    config.update(x={'y': {'m': 10}, 'z': {'r': {'s': 5}}})
    expected_nested = config_lib.Config(
        a=1,
        b=2,
        x=config_lib.Config(
            l='a',
            y=config_lib.Config(m=10, n=2),
            z=config_lib.Config(
                q=config_lib.Config(a=10, b=20),
                r=config_lib.Config(s=5, t=2))))
    self.assertEqual(expected_nested, config)

    # Mixed update: new keys in nested configs, plain dict and list values.
    config = mixed_config()
    config.update(
        config_lib.Config(
            foo='hat',
            num=50.5,
            x={'a': 5, 'z': -10},
            y=config_lib.Config(wxyz=-1)),
        d={'a': 10, 'c': 20},
        l=[3, 4, 5, 6])
    expected_mixed = config_lib.Config(
        foo='hat',
        num=50.5,
        x=config_lib.Config(
            a=5, b=2, z=-10, c=config_lib.Config(h=10, i=20, j=30)),
        y=config_lib.Config(qrs=5, tuv=10, wxyz=-1),
        d={'a': 10, 'c': 20},
        l=[3, 4, 5, 6])
    self.assertEqual(expected_mixed, config)
    for nested in (config.x, config.x.c, config.y):
        self.assertTrue(isinstance(nested, config_lib.Config))

    # Type-changing update: values are replaced by values of other types.
    config = mixed_config()
    config.update(
        config_lib.Config(
            foo=1234,
            num='hello',
            x={'a': 5, 'z': -10, 'c': {'h': -5, 'k': 40}},
            y=[1, 2, 3, 4],
            d='stuff',
            l={'a': 1, 'b': 2}))
    expected_retyped = config_lib.Config(
        foo=1234,
        num='hello',
        x=config_lib.Config(
            a=5, b=2, z=-10,
            c=config_lib.Config(h=-5, i=20, j=30, k=40)),
        y=[1, 2, 3, 4],
        d='stuff',
        l={'a': 1, 'b': 2})
    self.assertEqual(expected_retyped, config)
    self.assertTrue(isinstance(config.x, config_lib.Config))
    self.assertTrue(isinstance(config.x.c, config_lib.Config))
    self.assertTrue(isinstance(config.y, list))
def testConfigStrictUpdate(self):
    """Check `Config.strict_update` on updates that are expected to succeed.

    Every update here only touches keys already present at the top level
    or inside nested `Config`s. The assertions expect a positional mapping,
    keyword arguments, or both to be accepted, and nested dicts to be merged
    into nested `Config`s.
    """
    def flat_config():
        return config_lib.Config(a=1, b=2, c=3)

    # Positional mapping only.
    config = flat_config()
    config.strict_update({'b': 10, 'c': 20})
    self.assertEqual({'a': 1, 'b': 10, 'c': 20}, config)

    # Keyword arguments only.
    config = flat_config()
    config.strict_update(b=10, c=20)
    self.assertEqual({'a': 1, 'b': 10, 'c': 20}, config)

    # Positional mapping combined with keyword arguments.
    config = config_lib.Config(a=1, b=2, c=3, d=4)
    config.strict_update({'d': 100}, b=10, a=20)
    self.assertEqual({'a': 20, 'b': 10, 'c': 3, 'd': 100}, config)

    # Deeply nested update given as plain dicts.
    config = config_lib.Config(
        a=1,
        b=2,
        x=config_lib.Config(
            l='a',
            y=config_lib.Config(m=1, n=2),
            z=config_lib.Config(
                q=config_lib.Config(a=10, b=20),
                r=config_lib.Config(s=1, t=2))))
    config.strict_update(x={'y': {'m': 10}, 'z': {'r': {'s': 5}}})
    expected_nested = config_lib.Config(
        a=1,
        b=2,
        x=config_lib.Config(
            l='a',
            y=config_lib.Config(m=10, n=2),
            z=config_lib.Config(
                q=config_lib.Config(a=10, b=20),
                r=config_lib.Config(s=5, t=2))))
    self.assertEqual(expected_nested, config)

    # Mixed update combining a Config argument with keyword arguments.
    config = config_lib.Config(
        foo='bar',
        num=100,
        x=config_lib.Config(
            a=1, b=2, c=config_lib.Config(h=10, i=20, j=30)),
        y=config_lib.Config(qrs=5, tuv=10),
        d={'a': 1, 'b': 2},
        l=[1, 2, 3])
    config.strict_update(
        config_lib.Config(
            foo='hat',
            num=50,
            x={'a': 5, 'c': {'h': 100}},
            y=config_lib.Config(tuv=-1)),
        d={'a': 10, 'c': 20},
        l=[3, 4, 5, 6])
    expected_mixed = config_lib.Config(
        foo='hat',
        num=50,
        x=config_lib.Config(
            a=5, b=2, c=config_lib.Config(h=100, i=20, j=30)),
        y=config_lib.Config(qrs=5, tuv=-1),
        d={'a': 10, 'c': 20},
        l=[3, 4, 5, 6])
    self.assertEqual(expected_mixed, config)
def testConstSchedule(self):
    """Check `ConstSchedule` against sample points that all pair with 5.

    The pairs are passed to `self.ScheduleTestHelper` with the config and
    the schedule class; presumably each is (step, expected value) --
    confirm against the helper.
    """
    const_config = config_lib.Config(fn='const', const=5)
    sample_steps = (0, 1, 10, 20, 100, 1000000)
    expected_points = [(step, 5) for step in sample_steps]
    self.ScheduleTestHelper(
        const_config, schedules.ConstSchedule, expected_points)