Example #1
0
    def testHardOscillatorSchedule(self):
        """Exercises HardOscillatorSchedule with ramped and instantaneous steps.

        Each expectation list holds (step, value) pairs that are handed to
        ScheduleTestHelper together with the config and the schedule class.
        ScheduleTestHelper is defined outside this block; presumably it
        compares the schedule's output at every listed step.
        """
        # Settings shared by both scenarios; only transition_fraction differs.
        shared = dict(fn='hard_osc', high=2, low=0, start_time=100, period=10)

        # Scenario 1: transition_fraction=0.5.
        ramped_config = config_lib.Config(transition_fraction=0.5, **shared)
        ramped_expected = [
            # At or before start_time the expected value is `high`.
            (0, 2), (1, 2), (10, 2), (100, 2),
            # First period after start_time.
            (101, 1.2), (102, 0.4), (103, 0), (104, 0), (105, 0),
            (106, 0.8), (107, 1.6), (108, 2), (109, 2), (110, 2),
            # Second period repeats the same shape.
            (111, 1.2), (112, 0.4), (115, 0), (116, 0.8), (119, 2), (120, 2),
            # Far into the future the pattern is still expected to hold.
            (100001, 1.2), (100002, 0.4), (100005, 0), (100006, 0.8),
            (100010, 2),
        ]
        self.ScheduleTestHelper(
            ramped_config, schedules.HardOscillatorSchedule, ramped_expected)

        # Scenario 2: instantaneous step (transition_fraction=0).
        step_config = config_lib.Config(transition_fraction=0, **shared)
        step_expected = [
            (0, 2), (1, 2), (10, 2), (99, 2),
            (100, 0), (104, 0),
            (105, 2), (106, 2), (109, 2),
            (110, 0),
        ]
        self.ScheduleTestHelper(
            step_config, schedules.HardOscillatorSchedule, step_expected)
Example #2
0
 def make_config():
   """Builds a config whose `data` entry is a OneOf over three task configs.

   The OneOf is constructed with `task=2`; `model` is a plain nested Config.
   """
   task_choices = [
       config_lib.Config(task=1, a='hello'),
       config_lib.Config(task=2, a='world', b='stuff'),
       config_lib.Config(task=3, c=1234),
   ]
   data_choice = config_lib.OneOf(task_choices, task=2)
   return config_lib.Config(
       data=data_choice,
       model=config_lib.Config(stuff=1))
Example #3
0
 def testConfigStrictUpdateFail(self):
   """Checks the errors strict_update is expected to raise on bad overrides.

   The calls run against one shared config, in the order listed below.
   """
   cfg = config_lib.Config(a=1, b=2, c=3, x=config_lib.Config(a=1, b=2))

   # Keys absent from the config ('d' at the top level, 'c' inside x) are
   # expected to raise KeyError.
   self.assertRaises(KeyError, cfg.strict_update, {'b': 10, 'c': 20, 'd': 50})
   self.assertRaises(KeyError, cfg.strict_update, b=10, d=50)
   self.assertRaises(KeyError, cfg.strict_update, x={'c': 3})

   # Values whose type differs from the existing entry are expected to raise
   # TypeError.
   self.assertRaises(TypeError, cfg.strict_update, a='string')
   self.assertRaises(TypeError, cfg.strict_update, x={'a': 'string'})
   self.assertRaises(TypeError, cfg.strict_update, x=[1, 2, 3])
Example #4
0
 def testConfig(self):
   """Checks attribute-style and item-style access on a Config."""
   initial_entries = dict(hello='world', foo='bar', num=123, f=56.7)
   cfg = config_lib.Config(**initial_entries)

   # Reads work through both attribute and item syntax.
   self.assertEqual('world', cfg.hello)
   self.assertEqual('bar', cfg['foo'])

   # Writes through one syntax are visible through the other.
   cfg.hello = 'everyone'
   cfg['bar'] = 9000
   self.assertEqual('everyone', cfg['hello'])
   self.assertEqual(9000, cfg.bar)

   # Four initial entries plus the newly added 'bar'.
   self.assertEqual(5, len(cfg))
Example #5
0
    def testLinearDecaySchedule(self):
        """Exercises LinearDecaySchedule over a ramp and a zero-width ramp.

        The expectation lists are (step, value) pairs passed to
        ScheduleTestHelper, which is defined outside this block.
        """
        # Settings common to both scenarios; only end_time differs.
        shared = dict(fn='linear_decay', initial=2, final=0, start_time=10)

        # Decay from 2 to 0 between steps 10 and 20.
        ramp_config = config_lib.Config(end_time=20, **shared)
        ramp_expected = [
            (0, 2), (1, 2), (10, 2),
            (11, 1.8), (15, 1), (19, 0.2),
            (20, 0), (100000, 0),
        ]
        self.ScheduleTestHelper(
            ramp_config, schedules.LinearDecaySchedule, ramp_expected)

        # Test step function: start_time == end_time.
        step_config = config_lib.Config(end_time=10, **shared)
        step_expected = [(0, 2), (1, 2), (10, 2), (11, 0), (15, 0)]
        self.ScheduleTestHelper(
            step_config, schedules.LinearDecaySchedule, step_expected)
Example #6
0
    def testExponentialDecaySchedule(self):
        """Exercises ExponentialDecaySchedule over a ramp and a zero-width ramp.

        The expectation lists are (step, value) pairs passed to
        ScheduleTestHelper, which is defined outside this block. Intermediate
        expected values are linear in the exponent between -1 and -6.
        """
        start_value = exp(-1)
        end_value = exp(-6)
        shared = dict(fn='exp_decay',
                      initial=start_value,
                      final=end_value,
                      start_time=10)

        # Decay between steps 10 and 20.
        ramp_config = config_lib.Config(end_time=20, **shared)
        ramp_expected = [
            (0, start_value), (1, start_value), (10, start_value),
            (11, exp(-1 / 2. - 1)),
            (15, exp(-5 / 2. - 1)),
            (19, exp(-9 / 2. - 1)),
            (20, end_value), (100000, end_value),
        ]
        self.ScheduleTestHelper(
            ramp_config, schedules.ExponentialDecaySchedule, ramp_expected)

        # Test step function: start_time == end_time.
        step_config = config_lib.Config(end_time=10, **shared)
        step_expected = [
            (0, start_value), (1, start_value), (10, start_value),
            (11, end_value), (15, end_value),
        ]
        self.ScheduleTestHelper(
            step_config, schedules.ExponentialDecaySchedule, step_expected)
Example #7
0
    def testSmootherstepDecaySchedule(self):
        """Exercises SmootherstepDecaySchedule over a ramp and a zero-width ramp.

        The expectation lists are (step, value) pairs passed to
        ScheduleTestHelper, which is defined outside this block.
        """
        # Settings common to both scenarios; only end_time differs.
        shared = dict(fn='smooth_decay', initial=2, final=0, start_time=10)

        # Decay from 2 to 0 between steps 10 and 20.
        ramp_config = config_lib.Config(end_time=20, **shared)
        ramp_expected = [
            (0, 2), (1, 2), (10, 2),
            (11, 1.98288), (15, 1), (19, 0.01712),
            (20, 0), (100000, 0),
        ]
        self.ScheduleTestHelper(
            ramp_config, schedules.SmootherstepDecaySchedule, ramp_expected)

        # Test step function: start_time == end_time.
        step_config = config_lib.Config(end_time=10, **shared)
        step_expected = [(0, 2), (1, 2), (10, 2), (11, 0), (15, 0)]
        self.ScheduleTestHelper(
            step_config, schedules.SmootherstepDecaySchedule, step_expected)
Example #8
0
def default_config():
    """Returns the default configuration as a config_lib.Config.

    The result has an `agent` section (a OneOf over the 'pg', 'ga' and 'rand'
    algorithm configs, constructed with algorithm='pg'), an `env` section,
    and the top-level `batch_size` and `timestep_limit` settings.
    """
    pg_agent = config_lib.Config(
        algorithm='pg',
        policy_lstm_sizes=[35, 35],
        # Set value_lstm_sizes to None to share weights with policy.
        value_lstm_sizes=[35, 35],
        obs_embedding_size=10,
        grad_clip_threshold=10.0,
        param_init_factor=1.0,
        lr=5e-5,
        pi_loss_hparam=1.0,
        vf_loss_hparam=0.5,
        entropy_beta=1e-2,
        regularizer=0.0,
        softmax_tr=1.0,  # Reciprocal temperature.
        optimizer='rmsprop',  # 'adam', 'sgd', 'rmsprop'
        topk=0,  # Top-k unique codes will be stored.
        topk_loss_hparam=0.0,  # Off-policy loss multiplier.
        # Uniformly sample this many episodes from the topk buffer per batch.
        # If topk is 0, this has no effect.
        topk_batch_size=1,
        # Exponential moving average baseline for REINFORCE.
        # If zero, A2C is used.
        # If non-zero, should be close to 1, like .99, .999, etc.
        ema_baseline_decay=0.99,
        # Whether the agent can emit an EOS token. If true, the agent can emit
        # an EOS token which ends the episode early (ends the sequence).
        # If false, the agent must emit tokens until the timestep limit is
        # reached. e.g. True means variable length code, False means fixed
        # length code.
        # WARNING: Making this false slows things down.
        eos_token=False,
        replay_temperature=1.0,
        # Replay probability. 1 = always replay, 0 = always on policy.
        alpha=0.0,
        # Whether to normalize importance weights in each minibatch.
        iw_normalize=True)
    ga_agent = config_lib.Config(
        algorithm='ga', crossover_rate=0.99, mutation_rate=0.086)
    rand_agent = config_lib.Config(algorithm='rand')
    agent_choice = config_lib.OneOf(
        [pg_agent, ga_agent, rand_agent],
        algorithm='pg',
    )

    task_manager = config_lib.Config(
        # Reward received per test case. These bonuses will be scaled
        # based on how many test cases there are.
        correct_bonus=2.0,  # Bonus for code getting correct answer.
        code_length_bonus=1.0)  # Maximum bonus for short code.
    env_settings = config_lib.Config(
        # If True, task-specific settings are not needed.
        # NOTE(review): no boolean setting follows this comment; it looks
        # orphaned -- confirm what it was meant to describe.
        task='',  # 'print', 'echo', 'reverse', 'remove', ...
        task_cycle=[],  # If non-empty, repetitions will cycle through tasks.
        task_kwargs='{}',  # Python dict literal.
        task_manager_config=task_manager,
        correct_syntax=False,
    )

    return config_lib.Config(
        agent=agent_choice,
        env=env_settings,
        batch_size=64,
        timestep_limit=32)
Example #9
0
 def make_config():
   """Builds a config with a OneOf nested inside another OneOf.

   `data` is a OneOf over three task configs (constructed with task=2); the
   task=2 choice carries its own OneOf under `a` (constructed with x=1).
   `model` is a plain nested Config.
   """
   hello_task = config_lib.Config(task=1, a='hello')
   nested_choice = config_lib.OneOf(
       [config_lib.Config(x=1, y=2),
        config_lib.Config(x=-1, y=1000, z=4)],
       x=1)
   nested_task = config_lib.Config(task=2, a=nested_choice)
   numeric_task = config_lib.Config(task=3, c=1234)
   return config_lib.Config(
       data=config_lib.OneOf(
           [hello_task, nested_task, numeric_task],
           task=2),
       model=config_lib.Config(stuff=1))
Example #10
0
    def __init__(self, config):
        """Sets up the schedule from `config`.

        Reads `initial`, `final`, `start_time` and `end_time` from `config`
        and builds an internal LinearDecaySchedule over the logarithms of
        the two endpoint values.

        Raises:
            ValueError: If `initial` or `final` is not a positive number.
        """
        super(ExponentialDecaySchedule, self).__init__(config)
        self.initial = config.initial
        self.final = config.final
        self.start_time = config.start_time
        self.end_time = config.end_time

        # Both endpoints are passed to math.log below, so they must be > 0.
        if self.initial <= 0 or self.final <= 0:
            raise ValueError('initial and final must be positive numbers.')

        # Linear interpolation in log space.
        log_space_config = config_lib.Config(
            initial=math.log(self.initial),
            final=math.log(self.final),
            start_time=self.start_time,
            end_time=self.end_time)
        self._linear_fn = LinearDecaySchedule(log_space_config)
Example #11
0
  def testNestedOneOf(self):
    """Checks update and strict_update on a config with nested OneOf entries.

    Each parsed override string is applied to a fresh config, first with
    `update` and then with `strict_update`; both are expected to produce the
    same resolved config.
    """
    def make_config():
      hello_task = config_lib.Config(task=1, a='hello')
      nested_choice = config_lib.OneOf(
          [config_lib.Config(x=1, y=2),
           config_lib.Config(x=-1, y=1000, z=4)],
          x=1)
      nested_task = config_lib.Config(task=2, a=nested_choice)
      numeric_task = config_lib.Config(task=3, c=1234)
      return config_lib.Config(
          data=config_lib.OneOf(
              [hello_task, nested_task, numeric_task],
              task=2),
          model=config_lib.Config(stuff=1))

    # (override string, expected entries of the resolved data.a config).
    cases = [
        # Picks the x=-1 choice of the inner OneOf and overrides z.
        ('model=c(stuff=2),data=c(task=2,a=c(x=-1,z=8))',
         dict(x=-1, y=1000, z=8)),
        # Leaves `data` untouched, so the x=1 choice is expected.
        ('model=c(stuff=2)',
         dict(x=1, y=2)),
    ]
    for override, expected_a in cases:
      for method_name in ('update', 'strict_update'):
        config = make_config()
        apply_override = getattr(config, method_name)
        apply_override(config_lib.Config.parse(override))
        self.assertEqual(
            config_lib.Config(
                data=config_lib.Config(
                    task=2,
                    a=config_lib.Config(**expected_a)),
                model=config_lib.Config(stuff=2)),
            config)
Example #12
0
  def testOneOfStrict(self):
    """Checks strict_update on a config whose `data` entry is a OneOf.

    Each parsed override string is applied to a fresh config and the result
    is compared against the expected, fully resolved config.
    """
    def make_config():
      task_choices = [
          config_lib.Config(task=1, a='hello'),
          config_lib.Config(task=2, a='world', b='stuff'),
          config_lib.Config(task=3, c=1234),
      ]
      return config_lib.Config(
          data=config_lib.OneOf(task_choices, task=2),
          model=config_lib.Config(stuff=1))

    # (override string, expected entries of the resolved data config).
    cases = [
        # Switch to the task=1 choice and override its `a`.
        ('model=c(stuff=2),data=c(task=1,a="hi")',
         dict(task=1, a='hi')),
        # Stay on task=2 and override `a`; `b` keeps its value.
        ('model=c(stuff=2),data=c(task=2,a="hi")',
         dict(task=2, a='hi', b='stuff')),
        # Switch to the task=3 choice without overriding anything in it.
        ('model=c(stuff=2),data=c(task=3)',
         dict(task=3, c=1234)),
        # No `data` override: the task=2 choice is expected.
        ('model=c(stuff=2)',
         dict(task=2, a='world', b='stuff')),
    ]
    for override, expected_data in cases:
      config = make_config()
      config.strict_update(config_lib.Config.parse(override))
      self.assertEqual(
          config_lib.Config(
              data=config_lib.Config(**expected_data),
              model=config_lib.Config(stuff=2)),
          config)
Example #13
0
  def testConfigUpdate(self):
    """Checks Config.update with dicts, keyword arguments and nested values.

    Covers positional-dict, keyword and combined overrides, merging into
    nested Configs, and replacement of values whose type changes. Type checks
    use assertIsInstance so a failure reports the actual type.
    """
    # Positional dict: existing key overwritten, new key added.
    config = config_lib.Config(a=1, b=2, c=3)
    config.update({'b': 10, 'd': 4})
    self.assertEqual({'a': 1, 'b': 10, 'c': 3, 'd': 4}, config)

    # Same overrides given as keyword arguments.
    config = config_lib.Config(a=1, b=2, c=3)
    config.update(b=10, d=4)
    self.assertEqual({'a': 1, 'b': 10, 'c': 3, 'd': 4}, config)

    # Positional dict and keyword arguments combined.
    config = config_lib.Config(a=1, b=2, c=3)
    config.update({'e': 5}, b=10, d=4)
    self.assertEqual({'a': 1, 'b': 10, 'c': 3, 'd': 4, 'e': 5}, config)

    # Nested dict overrides are expected to merge into nested Configs and
    # leave sibling entries untouched.
    config = config_lib.Config(
        a=1,
        b=2,
        x=config_lib.Config(
            l='a',
            y=config_lib.Config(m=1, n=2),
            z=config_lib.Config(
                q=config_lib.Config(a=10, b=20),
                r=config_lib.Config(s=1, t=2))))
    config.update(x={'y': {'m': 10}, 'z': {'r': {'s': 5}}})
    self.assertEqual(
        config_lib.Config(
            a=1, b=2,
            x=config_lib.Config(
                l='a',
                y=config_lib.Config(m=10, n=2),
                z=config_lib.Config(
                    q=config_lib.Config(a=10, b=20),
                    r=config_lib.Config(s=5, t=2)))),
        config)

    # Mixed update: nested Configs gain new keys ('z', 'wxyz'), while the
    # plain dict `d` and list `l` are expected to be replaced wholesale.
    config = config_lib.Config(
        foo='bar',
        num=100,
        x=config_lib.Config(a=1, b=2, c=config_lib.Config(h=10, i=20, j=30)),
        y=config_lib.Config(qrs=5, tuv=10),
        d={'a': 1, 'b': 2},
        l=[1, 2, 3])
    config.update(
        config_lib.Config(
            foo='hat',
            num=50.5,
            x={'a': 5, 'z': -10},
            y=config_lib.Config(wxyz=-1)),
        d={'a': 10, 'c': 20},
        l=[3, 4, 5, 6])
    self.assertEqual(
        config_lib.Config(
            foo='hat',
            num=50.5,
            x=config_lib.Config(a=5, b=2, z=-10,
                                c=config_lib.Config(h=10, i=20, j=30)),
            y=config_lib.Config(qrs=5, tuv=10, wxyz=-1),
            d={'a': 10, 'c': 20},
            l=[3, 4, 5, 6]),
        config)
    # Nested entries must still be Config instances after merging dicts.
    self.assertIsInstance(config.x, config_lib.Config)
    self.assertIsInstance(config.x.c, config_lib.Config)
    self.assertIsInstance(config.y, config_lib.Config)

    # Non-strict update may change value types: a Config entry replaced by a
    # list, a dict by a string, a list by a dict.
    config = config_lib.Config(
        foo='bar',
        num=100,
        x=config_lib.Config(a=1, b=2, c=config_lib.Config(h=10, i=20, j=30)),
        y=config_lib.Config(qrs=5, tuv=10),
        d={'a': 1, 'b': 2},
        l=[1, 2, 3])
    config.update(
        config_lib.Config(
            foo=1234,
            num='hello',
            x={'a': 5, 'z': -10, 'c': {'h': -5, 'k': 40}},
            y=[1, 2, 3, 4],
            d='stuff',
            l={'a': 1, 'b': 2}))
    self.assertEqual(
        config_lib.Config(
            foo=1234,
            num='hello',
            x=config_lib.Config(a=5, b=2, z=-10,
                                c=config_lib.Config(h=-5, i=20, j=30, k=40)),
            y=[1, 2, 3, 4],
            d='stuff',
            l={'a': 1, 'b': 2}),
        config)
    self.assertIsInstance(config.x, config_lib.Config)
    self.assertIsInstance(config.x.c, config_lib.Config)
    self.assertIsInstance(config.y, list)
Example #14
0
  def testConfigStrictUpdate(self):
    """Checks strict_update when every overridden key already exists.

    Covers positional-dict, keyword and combined overrides, merging into
    nested Configs, and a mixed update of Config, dict and list values.
    """
    # Overrides given as a positional dict.
    flat = config_lib.Config(a=1, b=2, c=3)
    flat.strict_update({'b': 10, 'c': 20})
    self.assertEqual(dict(a=1, b=10, c=20), flat)

    # Same overrides given as keyword arguments.
    flat = config_lib.Config(a=1, b=2, c=3)
    flat.strict_update(b=10, c=20)
    self.assertEqual(dict(a=1, b=10, c=20), flat)

    # Positional dict and keyword arguments combined.
    flat = config_lib.Config(a=1, b=2, c=3, d=4)
    flat.strict_update({'d': 100}, b=10, a=20)
    self.assertEqual(dict(a=20, b=10, c=3, d=100), flat)

    # Nested dict overrides are expected to merge into nested Configs and
    # leave sibling entries untouched.
    nested = config_lib.Config(
        a=1,
        b=2,
        x=config_lib.Config(
            l='a',
            y=config_lib.Config(m=1, n=2),
            z=config_lib.Config(
                q=config_lib.Config(a=10, b=20),
                r=config_lib.Config(s=1, t=2))))
    nested.strict_update(x={'y': {'m': 10}, 'z': {'r': {'s': 5}}})
    expected_nested = config_lib.Config(
        a=1, b=2,
        x=config_lib.Config(
            l='a',
            y=config_lib.Config(m=10, n=2),
            z=config_lib.Config(
                q=config_lib.Config(a=10, b=20),
                r=config_lib.Config(s=5, t=2))))
    self.assertEqual(expected_nested, nested)

    # Mixed update: nested Configs are merged key by key, while the plain
    # dict `d` and list `l` are expected to be replaced wholesale.
    mixed = config_lib.Config(
        foo='bar',
        num=100,
        x=config_lib.Config(a=1, b=2, c=config_lib.Config(h=10, i=20, j=30)),
        y=config_lib.Config(qrs=5, tuv=10),
        d={'a': 1, 'b': 2},
        l=[1, 2, 3])
    overrides = config_lib.Config(
        foo='hat',
        num=50,
        x={'a': 5, 'c': {'h': 100}},
        y=config_lib.Config(tuv=-1))
    mixed.strict_update(
        overrides,
        d={'a': 10, 'c': 20},
        l=[3, 4, 5, 6])
    expected_mixed = config_lib.Config(
        foo='hat',
        num=50,
        x=config_lib.Config(a=5, b=2,
                            c=config_lib.Config(h=100, i=20, j=30)),
        y=config_lib.Config(qrs=5, tuv=-1),
        d={'a': 10, 'c': 20},
        l=[3, 4, 5, 6])
    self.assertEqual(expected_mixed, mixed)
Example #15
0
 def testConstSchedule(self):
     """Exercises ConstSchedule: the expected value is 5 at every listed step.

     The (step, value) pairs are passed to ScheduleTestHelper, which is
     defined outside this block.
     """
     const_config = config_lib.Config(fn='const', const=5)
     expected = [(step, 5) for step in (0, 1, 10, 20, 100, 1000000)]
     self.ScheduleTestHelper(const_config, schedules.ConstSchedule, expected)