Example #1
0
 def make_config():
     """Return a fresh Config whose `data` field is a OneOf over three tasks.

     The OneOf is created with `task=2`, so the task-2 alternative is the
     selected one until an update chooses otherwise.
     """
     task_one = config_lib.Config(task=1, a='hello')
     task_two = config_lib.Config(task=2, a='world', b='stuff')
     task_three = config_lib.Config(task=3, c=1234)
     return config_lib.Config(
         data=config_lib.OneOf([task_one, task_two, task_three], task=2),
         model=config_lib.Config(stuff=1))
Example #2
0
    def testLinearDecaySchedule(self):
        """Check LinearDecaySchedule on a 10-step ramp and a zero-width step."""
        # Ramp from 2 down to 0 across [10, 20]; held constant before and after.
        ramp = config_lib.Config(
            fn='linear_decay', initial=2, final=0, start_time=10, end_time=20)
        ramp_points = [(0, 2), (1, 2), (10, 2), (11, 1.8), (15, 1), (19, 0.2),
                       (20, 0), (100000, 0)]
        self.ScheduleTestHelper(
            ramp, schedules.LinearDecaySchedule, ramp_points)

        # start_time == end_time: the value jumps straight from initial to final.
        step = config_lib.Config(
            fn='linear_decay', initial=2, final=0, start_time=10, end_time=10)
        step_points = [(0, 2), (1, 2), (10, 2), (11, 0), (15, 0)]
        self.ScheduleTestHelper(
            step, schedules.LinearDecaySchedule, step_points)
Example #3
0
    def testSmootherstepDecaySchedule(self):
        """Check SmootherstepDecaySchedule on a ramp and a zero-width step."""
        # Decay from 2 to 0 across [10, 20]. The interior expectations are
        # consistent with 2 * (1 - s(x)), s(x) = 6x^5 - 15x^4 + 10x^3, where x
        # is the fraction of the window elapsed (e.g. x=0.1 -> 1.98288).
        ramp = config_lib.Config(
            fn='smooth_decay', initial=2, final=0, start_time=10, end_time=20)
        ramp_points = [(0, 2), (1, 2), (10, 2), (11, 1.98288), (15, 1),
                       (19, 0.01712), (20, 0), (100000, 0)]
        self.ScheduleTestHelper(
            ramp, schedules.SmootherstepDecaySchedule, ramp_points)

        # start_time == end_time: an instantaneous drop to the final value.
        step = config_lib.Config(
            fn='smooth_decay', initial=2, final=0, start_time=10, end_time=10)
        step_points = [(0, 2), (1, 2), (10, 2), (11, 0), (15, 0)]
        self.ScheduleTestHelper(
            step, schedules.SmootherstepDecaySchedule, step_points)
Example #4
0
 def testConfigStrictUpdateFail(self):
     """strict_update raises on unknown keys and on type-changing values."""
     config = config_lib.Config(a=1, b=2, c=3, x=config_lib.Config(a=1, b=2))
     # Each entry: (expected exception, positional args, keyword args).
     # All calls go to the same `config`, in this order.
     rejected_updates = [
         # Keys absent from the config, at top level or inside `x`.
         (KeyError, ({'b': 10, 'c': 20, 'd': 50},), {}),
         (KeyError, (), {'b': 10, 'd': 50}),
         (KeyError, (), {'x': {'c': 3}}),
         # Existing keys given a value of a different type.
         (TypeError, (), {'a': 'string'}),
         (TypeError, (), {'x': {'a': 'string'}}),
         (TypeError, (), {'x': [1, 2, 3]}),
     ]
     for error_type, args, kwargs in rejected_updates:
         with self.assertRaises(error_type):
             config.strict_update(*args, **kwargs)
Example #5
0
    def testExponentialDecaySchedule(self):
        """Check ExponentialDecaySchedule: log-linear decay from e^-1 to e^-6."""
        high, low = exp(-1), exp(-6)

        # Over [10, 20] the exponent moves linearly from -1 to -6 (-1/2 per
        # step); outside that window the value is held at the nearest endpoint.
        decay = config_lib.Config(
            fn='exp_decay', initial=high, final=low,
            start_time=10, end_time=20)
        decay_points = [
            (0, high), (1, high), (10, high), (11, exp(-1 / 2. - 1)),
            (15, exp(-5 / 2. - 1)), (19, exp(-9 / 2. - 1)), (20, low),
            (100000, low)]
        self.ScheduleTestHelper(
            decay, schedules.ExponentialDecaySchedule, decay_points)

        # Zero-width window: the value jumps directly from initial to final.
        step = config_lib.Config(
            fn='exp_decay', initial=high, final=low,
            start_time=10, end_time=10)
        step_points = [(0, high), (1, high), (10, high), (11, low), (15, low)]
        self.ScheduleTestHelper(
            step, schedules.ExponentialDecaySchedule, step_points)
Example #6
0
    def testHardOscillatorSchedule(self):
        """Check HardOscillatorSchedule with ramped and instantaneous switches."""
        # Period 10 starting at t=100, oscillating between high=2 and low=0.
        # With transition_fraction=0.5 the expectations show ramps of 0.8 per
        # step between the two levels.
        ramped = config_lib.Config(
            fn='hard_osc', high=2, low=0, start_time=100,
            period=10, transition_fraction=0.5)
        ramped_points = [
            # Before start_time: held at `high`.
            (0, 2), (1, 2), (10, 2),
            # First period: ramp down, hold low, ramp up, hold high.
            (100, 2), (101, 1.2), (102, 0.4), (103, 0), (104, 0), (105, 0),
            (106, 0.8), (107, 1.6), (108, 2), (109, 2),
            # Second period repeats the same shape.
            (110, 2), (111, 1.2), (112, 0.4), (115, 0), (116, 0.8), (119, 2),
            (120, 2),
            # Far in the future the pattern is unchanged.
            (100001, 1.2), (100002, 0.4), (100005, 0), (100006, 0.8),
            (100010, 2)]
        self.ScheduleTestHelper(
            ramped, schedules.HardOscillatorSchedule, ramped_points)

        # transition_fraction=0: the value switches instantly between levels.
        square = config_lib.Config(
            fn='hard_osc', high=2, low=0, start_time=100,
            period=10, transition_fraction=0)
        square_points = [(0, 2), (1, 2), (10, 2), (99, 2), (100, 0), (104, 0),
                         (105, 2), (106, 2), (109, 2), (110, 0)]
        self.ScheduleTestHelper(
            square, schedules.HardOscillatorSchedule, square_points)
Example #7
0
 def testConfig(self):
     """Config supports both attribute and item access for reads and writes."""
     cfg = config_lib.Config(hello='world', foo='bar', num=123, f=56.7)

     # Reads through both access styles.
     self.assertEqual('world', cfg.hello)
     self.assertEqual('bar', cfg['foo'])

     # Overwrite via attribute, add a new key via subscript, then read each
     # back through the opposite style.
     cfg.hello = 'everyone'
     cfg['bar'] = 9000
     self.assertEqual('everyone', cfg['hello'])
     self.assertEqual(9000, cfg.bar)

     # Four constructor keys plus the newly added 'bar'.
     self.assertEqual(5, len(cfg))
Example #8
0
 def make_config():
     """Return a fresh Config with a OneOf nested inside another OneOf.

     `data` selects among three tasks (created with `task=2`). The task-2
     alternative's `a` field is itself a OneOf over two Configs, created
     with `x=1`.
     """
     task_one = config_lib.Config(task=1, a='hello')
     task_two_a = config_lib.OneOf(
         [config_lib.Config(x=1, y=2), config_lib.Config(x=-1, y=1000, z=4)],
         x=1)
     task_two = config_lib.Config(task=2, a=task_two_a)
     task_three = config_lib.Config(task=3, c=1234)
     return config_lib.Config(
         data=config_lib.OneOf([task_one, task_two, task_three], task=2),
         model=config_lib.Config(stuff=1))
Example #9
0
    def __init__(self, config):
        """Store the schedule endpoints and build a log-space linear schedule.

        Args:
          config: Config providing `initial`, `final`, `start_time` and
            `end_time`. It is also forwarded to the base-class constructor.

        Raises:
          ValueError: if `initial` or `final` is <= 0. Both are passed to
            math.log below.
        """
        super(ExponentialDecaySchedule, self).__init__(config)
        self.initial, self.final = config.initial, config.final
        self.start_time, self.end_time = config.start_time, config.end_time

        if any(endpoint <= 0 for endpoint in (self.initial, self.final)):
            raise ValueError('initial and final must be positive numbers.')

        # Interpolate linearly between log(initial) and log(final) over the
        # same time window. Presumably the caller exponentiates the result to
        # get the exponential decay (not visible in this method).
        log_space_config = config_lib.Config(
            initial=math.log(self.initial),
            final=math.log(self.final),
            start_time=self.start_time,
            end_time=self.end_time)
        self._linear_fn = LinearDecaySchedule(log_space_config)
Example #10
0
def default_config():
    """Return the default top-level training config.

    `agent` is a OneOf over three agent configs keyed by `algorithm`
    ('pg', 'ga', 'rand'), created with 'pg' selected. `env` holds task and
    reward settings. `batch_size` and `timestep_limit` are top-level fields.
    """
    # 'pg' agent: policy/value LSTMs trained with REINFORCE (EMA baseline)
    # or A2C, depending on `ema_baseline_decay`.
    pg_agent = config_lib.Config(
        algorithm='pg',
        policy_lstm_sizes=[35, 35],
        # None here means the value network shares weights with the policy.
        value_lstm_sizes=[35, 35],
        obs_embedding_size=10,
        grad_clip_threshold=10.0,
        param_init_factor=1.0,
        lr=5e-5,
        pi_loss_hparam=1.0,
        vf_loss_hparam=0.5,
        entropy_beta=1e-2,
        regularizer=0.0,
        softmax_tr=1.0,  # Reciprocal of the softmax temperature.
        optimizer='rmsprop',  # One of 'adam', 'sgd', 'rmsprop'.
        topk=0,  # Number of top-k unique codes to store.
        topk_loss_hparam=0.0,  # Multiplier on the off-policy loss.
        # Episodes sampled uniformly from the top-k buffer per batch; has no
        # effect when topk is 0.
        topk_batch_size=1,
        # Decay of the exponential-moving-average REINFORCE baseline. Zero
        # selects A2C; otherwise use a value close to 1 (.99, .999, ...).
        ema_baseline_decay=0.99,
        # True: the agent may emit an EOS token that ends the episode early
        # (variable-length code). False: it must emit tokens until the
        # timestep limit is reached (fixed-length code).
        # WARNING: setting this to False slows things down.
        eos_token=False,
        replay_temperature=1.0,
        # Replay probability: 1 = always replay, 0 = always on-policy.
        alpha=0.0,
        # Normalize importance weights within each minibatch.
        iw_normalize=True)
    # 'ga' agent; presumably a genetic algorithm, given crossover/mutation.
    ga_agent = config_lib.Config(
        algorithm='ga',
        crossover_rate=0.99,
        mutation_rate=0.086)
    # 'rand' agent; no extra settings.
    rand_agent = config_lib.Config(
        algorithm='rand')
    agent = config_lib.OneOf(
        [pg_agent, ga_agent, rand_agent],
        algorithm='pg',
    )

    env = config_lib.Config(
        # NOTE(review): the original comment "If True, task-specific settings
        # are not needed." sat here, but `task` is a string; it is unclear
        # which boolean setting it refers to. Confirm and relocate.
        task='',  # e.g. 'print', 'echo', 'reverse', 'remove', ...
        task_cycle=[],  # If non-empty, repetitions cycle through these tasks.
        task_kwargs='{}',  # A Python dict literal, as a string.
        task_manager_config=config_lib.Config(
            # Reward received per test case; these bonuses are scaled by the
            # number of test cases.
            correct_bonus=2.0,  # Bonus for code producing the correct answer.
            code_length_bonus=1.0),  # Maximum bonus for short code.
        correct_syntax=False,
    )

    return config_lib.Config(
        agent=agent,
        env=env,
        batch_size=64,
        timestep_limit=32)
Example #11
0
 def testConstSchedule(self):
     """ConstSchedule yields `const` at every queried time step."""
     const_config = config_lib.Config(fn='const', const=5)
     expected = [(step, 5) for step in (0, 1, 10, 20, 100, 1000000)]
     self.ScheduleTestHelper(const_config, schedules.ConstSchedule, expected)
Example #12
0
    def testNestedOneOf(self):
        """update and strict_update resolve a OneOf nested inside a OneOf.

        Each override string is applied to a fresh config, once with update
        and once with strict_update, and both must give the same result.
        """
        def make_config():
            task_one = config_lib.Config(task=1, a='hello')
            task_two_a = config_lib.OneOf(
                [config_lib.Config(x=1, y=2),
                 config_lib.Config(x=-1, y=1000, z=4)],
                x=1)
            task_two = config_lib.Config(task=2, a=task_two_a)
            task_three = config_lib.Config(task=3, c=1234)
            return config_lib.Config(
                data=config_lib.OneOf([task_one, task_two, task_three], task=2),
                model=config_lib.Config(stuff=1))

        # (override string, fields of the resolved `data.a` Config)
        cases = [
            # Select the x=-1 alternative and override z; y keeps the 1000
            # from that alternative.
            ('model=c(stuff=2),data=c(task=2,a=c(x=-1,z=8))',
             {'x': -1, 'y': 1000, 'z': 8}),
            # No data override: the default selections (task=2, x=1) apply.
            ('model=c(stuff=2)', {'x': 1, 'y': 2}),
        ]
        for override, resolved_a in cases:
            for method_name in ('update', 'strict_update'):
                config = make_config()
                getattr(config, method_name)(config_lib.Config.parse(override))
                self.assertEqual(
                    config_lib.Config(
                        data=config_lib.Config(
                            task=2, a=config_lib.Config(**resolved_a)),
                        model=config_lib.Config(stuff=2)),
                    config)
Example #13
0
    def testOneOfStrict(self):
        """strict_update on a OneOf selects an alternative and patches it."""
        def make_config():
            task_one = config_lib.Config(task=1, a='hello')
            task_two = config_lib.Config(task=2, a='world', b='stuff')
            task_three = config_lib.Config(task=3, c=1234)
            return config_lib.Config(
                data=config_lib.OneOf([task_one, task_two, task_three], task=2),
                model=config_lib.Config(stuff=1))

        # (override string, fields of the resolved `data` Config)
        cases = [
            # Switch to task 1 and override its `a`.
            ('model=c(stuff=2),data=c(task=1,a="hi")',
             {'task': 1, 'a': 'hi'}),
            # Stay on task 2 and override `a`; `b` keeps its default.
            ('model=c(stuff=2),data=c(task=2,a="hi")',
             {'task': 2, 'a': 'hi', 'b': 'stuff'}),
            # Switch to task 3 and take its defaults unchanged.
            ('model=c(stuff=2),data=c(task=3)',
             {'task': 3, 'c': 1234}),
            # No data override: the default task=2 alternative is used.
            ('model=c(stuff=2)',
             {'task': 2, 'a': 'world', 'b': 'stuff'}),
        ]
        for override, data_fields in cases:
            config = make_config()
            config.strict_update(config_lib.Config.parse(override))
            self.assertEqual(
                config_lib.Config(
                    data=config_lib.Config(**data_fields),
                    model=config_lib.Config(stuff=2)),
                config)
Example #14
0
    def testConfigUpdate(self):
        """Config.update merges into nested Configs and overwrites other values.

        Covers the positional-mapping and keyword forms, deep merges into
        nested Configs, and values that change type. update also accepts
        keys that are not yet present in the config.
        """
        # Flat updates: positional mapping, keywords, and both at once. Each
        # case starts from the same base and may introduce new keys.
        flat_cases = [
            (({'b': 10, 'd': 4},), {}, {'a': 1, 'b': 10, 'c': 3, 'd': 4}),
            ((), {'b': 10, 'd': 4}, {'a': 1, 'b': 10, 'c': 3, 'd': 4}),
            (({'e': 5},), {'b': 10, 'd': 4},
             {'a': 1, 'b': 10, 'c': 3, 'd': 4, 'e': 5}),
        ]
        for args, kwargs, expected in flat_cases:
            config = config_lib.Config(a=1, b=2, c=3)
            config.update(*args, **kwargs)
            self.assertEqual(expected, config)

        # Nested plain dicts are merged key-by-key into nested Configs.
        config = config_lib.Config(
            a=1,
            b=2,
            x=config_lib.Config(
                l='a',
                y=config_lib.Config(m=1, n=2),
                z=config_lib.Config(
                    q=config_lib.Config(a=10, b=20),
                    r=config_lib.Config(s=1, t=2))))
        config.update(x={'y': {'m': 10}, 'z': {'r': {'s': 5}}})
        self.assertEqual(
            config_lib.Config(
                a=1, b=2,
                x=config_lib.Config(
                    l='a',
                    y=config_lib.Config(m=10, n=2),
                    z=config_lib.Config(
                        q=config_lib.Config(a=10, b=20),
                        r=config_lib.Config(s=5, t=2)))),
            config)

        def make_mixed_config():
            # A fresh instance per call so the two scenarios below do not
            # share the mutable `d` dict or `l` list.
            return config_lib.Config(
                foo='bar',
                num=100,
                x=config_lib.Config(
                    a=1, b=2, c=config_lib.Config(h=10, i=20, j=30)),
                y=config_lib.Config(qrs=5, tuv=10),
                d={'a': 1, 'b': 2},
                l=[1, 2, 3])

        # A dict merges into the nested Config `x` and a Config merges into
        # `y`. The plain dict `d` and the list `l` are replaced wholesale
        # (note `d` loses its old 'b' key).
        config = make_mixed_config()
        config.update(
            config_lib.Config(
                foo='hat',
                num=50.5,
                x={'a': 5, 'z': -10},
                y=config_lib.Config(wxyz=-1)),
            d={'a': 10, 'c': 20},
            l=[3, 4, 5, 6])
        self.assertEqual(
            config_lib.Config(
                foo='hat',
                num=50.5,
                x=config_lib.Config(a=5, b=2, z=-10,
                                    c=config_lib.Config(h=10, i=20, j=30)),
                y=config_lib.Config(qrs=5, tuv=10, wxyz=-1),
                d={'a': 10, 'c': 20},
                l=[3, 4, 5, 6]),
            config)
        # Merging must keep the nested Config types rather than turning them
        # into plain dicts.
        self.assertIsInstance(config.x, config_lib.Config)
        self.assertIsInstance(config.x.c, config_lib.Config)
        self.assertIsInstance(config.y, config_lib.Config)

        # Type-changing updates: only a dict applied to an existing Config
        # (`x`, `x.c`) merges. Every other combination replaces the old
        # value: str<->int, Config->list, dict->str, list->dict.
        config = make_mixed_config()
        config.update(
            config_lib.Config(
                foo=1234,
                num='hello',
                x={'a': 5, 'z': -10, 'c': {'h': -5, 'k': 40}},
                y=[1, 2, 3, 4],
                d='stuff',
                l={'a': 1, 'b': 2}))
        self.assertEqual(
            config_lib.Config(
                foo=1234,
                num='hello',
                x=config_lib.Config(a=5, b=2, z=-10,
                                    c=config_lib.Config(h=-5, i=20, j=30, k=40)),
                y=[1, 2, 3, 4],
                d='stuff',
                l={'a': 1, 'b': 2}),
            config)
        self.assertIsInstance(config.x, config_lib.Config)
        self.assertIsInstance(config.x.c, config_lib.Config)
        self.assertIsInstance(config.y, list)
Example #15
0
    def testConfigStrictUpdate(self):
        """strict_update overwrites existing keys, including nested ones.

        In every case below, each Config key touched already exists and keeps
        its type. The plain dict `d` is replaced wholesale, so its new 'c'
        entry is allowed.
        """
        # (base fields, positional args, keyword args, expected result)
        flat_cases = [
            ({'a': 1, 'b': 2, 'c': 3}, ({'b': 10, 'c': 20},), {},
             {'a': 1, 'b': 10, 'c': 20}),
            ({'a': 1, 'b': 2, 'c': 3}, (), {'b': 10, 'c': 20},
             {'a': 1, 'b': 10, 'c': 20}),
            ({'a': 1, 'b': 2, 'c': 3, 'd': 4}, ({'d': 100},),
             {'b': 10, 'a': 20}, {'a': 20, 'b': 10, 'c': 3, 'd': 100}),
        ]
        for base, args, kwargs, expected in flat_cases:
            cfg = config_lib.Config(**base)
            cfg.strict_update(*args, **kwargs)
            self.assertEqual(expected, cfg)

        # Nested dicts are applied key-by-key inside nested Configs.
        cfg = config_lib.Config(
            a=1,
            b=2,
            x=config_lib.Config(
                l='a',
                y=config_lib.Config(m=1, n=2),
                z=config_lib.Config(
                    q=config_lib.Config(a=10, b=20),
                    r=config_lib.Config(s=1, t=2))))
        cfg.strict_update(x={'y': {'m': 10}, 'z': {'r': {'s': 5}}})
        self.assertEqual(
            config_lib.Config(
                a=1, b=2,
                x=config_lib.Config(
                    l='a',
                    y=config_lib.Config(m=10, n=2),
                    z=config_lib.Config(
                        q=config_lib.Config(a=10, b=20),
                        r=config_lib.Config(s=5, t=2)))),
            cfg)

        # A positional Config patch combined with keyword replacements for
        # the plain dict `d` and the list `l`.
        cfg = config_lib.Config(
            foo='bar',
            num=100,
            x=config_lib.Config(a=1, b=2, c=config_lib.Config(h=10, i=20, j=30)),
            y=config_lib.Config(qrs=5, tuv=10),
            d={'a': 1, 'b': 2},
            l=[1, 2, 3])
        patch = config_lib.Config(
            foo='hat',
            num=50,
            x={'a': 5, 'c': {'h': 100}},
            y=config_lib.Config(tuv=-1))
        cfg.strict_update(patch, d={'a': 10, 'c': 20}, l=[3, 4, 5, 6])
        self.assertEqual(
            config_lib.Config(
                foo='hat',
                num=50,
                x=config_lib.Config(a=5, b=2,
                                    c=config_lib.Config(h=100, i=20, j=30)),
                y=config_lib.Config(qrs=5, tuv=-1),
                d={'a': 10, 'c': 20},
                l=[3, 4, 5, 6]),
            cfg)