def test_multi_scheduler_with_state_and_different_units(self):
     """A 'max' MultiScheduler mixing a timestep- and an epoch-based child.

     Both children are bound to the same state via ``m_sches.bind``. The
     linear child is configured with ``unit='timestep'`` and the
     exponential child with ``unit='epoch'``, so each follows a different
     counter on the shared state.
     """
     state = create_empty_state()
     # Linear schedule on timesteps; distinct names avoid re-binding the
     # exponential scheduler's parameters below.
     lin_start_value = 100.0
     lin_decay_steps = 100
     lin_stop_value = 0
     lin_sche = sche.LinearScheduler(lin_start_value,
                                     lin_decay_steps,
                                     lin_stop_value,
                                     unit='timestep')
     # Exponential schedule on epochs.
     exp_start_value = 100.0
     exp_decay_steps = 10
     exp_decay_rate = 0.1
     exp_stop_value = 0.1
     exp_sche = sche.ExponentialScheduler(exp_start_value,
                                          exp_decay_steps,
                                          exp_decay_rate,
                                          exp_stop_value,
                                          unit='epoch')
     m_sches = sche.MultiScheduler([lin_sche, exp_sche], op='max')
     m_sches.bind(state)
     # Binding the MultiScheduler must propagate the state to every child.
     self.assertIs(lin_sche.state, state)
     self.assertIs(exp_sche.state, state)

     state.num_timesteps = 0
     state.num_epochs = 0
     # Both children start at their start values; 'max' picks the larger.
     self.assertAlmostEqual(m_sches(), max(lin_start_value, exp_start_value))

     state.num_timesteps = 51
     state.num_epochs = 10
     # Linear child is expected at 49.0 and dominates the exponential one.
     self.assertAlmostEqual(m_sches(), 49.0)

     state.num_timesteps = 100
     # Linear child has reached its stop value, so the exponential child
     # (expected 10.0 after 10 epochs) wins.
     self.assertAlmostEqual(m_sches(), 10.0)
 def test_scheduler_no_bind_exception(self):
     """An unbound scheduler called with no explicit step raises RuntimeError."""
     unbound = sche.LinearScheduler(100.0, 100, 0)
     # assertRaises invokes ``unbound()`` with no arguments.
     self.assertRaises(RuntimeError, unbound)
 def test_linear_scheduler_without_state(self):
     """An unbound LinearScheduler evaluates the step it is given explicitly."""
     start, steps, stop = 100.0, 100, 0
     scheduler = sche.LinearScheduler(start, steps, stop)
     # (step, expected value) pairs, checked in order; past ``steps`` the
     # value is expected to stay clamped at ``stop``.
     expectations = (
         (0, start),
         (100, stop),
         (101, stop),
         (50, 50.0),
     )
     for step, expected in expectations:
         self.assertAlmostEqual(scheduler(step), expected)
 def test_multi_scheduler_without_state(self):
     """A 'min' MultiScheduler returns the smaller of its children's values."""
     linear = sche.LinearScheduler(100.0, 100, 0)
     constant_value = 50.0
     combined = sche.MultiScheduler(
         [linear, sche.ConstantScheduler(constant_value)], op='min')
     # Step 0: the linear child is at its start value, so the constant wins.
     self.assertAlmostEqual(combined(0), constant_value)
     # Step 51: the linear child (expected 49.0) is now below the constant.
     self.assertAlmostEqual(combined(51), 49.0)
 def test_multi_scheduler_custom_op(self):
     """MultiScheduler accepts a callable ``op`` to combine child values.

     The custom op sums the children's values, so the expected result is
     the sum of the linear and constant schedules at each step.
     """
     start_value = 100.0
     decay_steps = 100
     stop_value = 0
     lin_sche = sche.LinearScheduler(start_value, decay_steps, stop_value)
     constant = 50.0
     const_sche = sche.ConstantScheduler(constant)

     def custom_op(x):
         # Sum over axis 0; presumably MultiScheduler stacks the child
         # values along the first axis -- confirm against its implementation.
         return np.sum(x, axis=0)

     m_sches = sche.MultiScheduler([lin_sche, const_sche], op=custom_op)
     self.assertAlmostEqual(m_sches(0), constant + start_value)
     self.assertAlmostEqual(m_sches(100), constant + stop_value)
 def test_linear_scheduler_with_state(self):
     """A LinearScheduler bound to state follows ``state.num_timesteps``."""
     state = create_empty_state()
     start, steps, stop = 100.0, 100, 0
     scheduler = sche.LinearScheduler(start, steps, stop, unit='timestep')
     scheduler.bind(state)
     # Before any counter is assigned, the scheduler is expected to report
     # its start value.
     self.assertAlmostEqual(scheduler(), start)
     for timestep, expected in ((100, stop), (50, 50.0)):
         state.num_timesteps = timestep
         self.assertAlmostEqual(scheduler(), expected)
 def test_multi_scheduler_dump_load(self):
     """A MultiScheduler with a named op round-trips through JSON.

     The test dumps it with ``json.dumps``, checks the serialized dict
     layout, then rebuilds it with ``sche.get_scheduler`` and checks every
     field of the reconstructed children.
     """
     state = create_empty_state()
     start_value = 100.0
     decay_steps = 100
     stop_value = 0
     lin_sche = sche.LinearScheduler(start_value,
                                     decay_steps,
                                     stop_value,
                                     unit='progress')
     constant = 0.5
     const_sche = sche.ConstantScheduler(constant,
                                         unit='epoch',
                                         state=state)
     m_sches = sche.MultiScheduler([lin_sche, const_sche], op='mean')

     # Serialized form.
     d = json.loads(json.dumps(m_sches))
     self.assertEqual(set(d.keys()), {'type', 'unit', 'schedulers', 'op'})
     self.assertEqual(d['type'], 'MultiScheduler')
     self.assertEqual(d['op'], 'mean')
     self.assertEqual(len(d['schedulers']), 2)
     d0 = d['schedulers'][0]
     self.assertEqual(
         set(d0.keys()),
         {'type', 'unit', 'start_value', 'decay_steps', 'stop_value'})
     self.assertEqual(d0['type'], 'LinearScheduler')
     self.assertEqual(d0['unit'], 'progress')
     self.assertEqual(d0['start_value'], start_value)
     self.assertEqual(d0['decay_steps'], decay_steps)
     self.assertEqual(d0['stop_value'], stop_value)
     d1 = d['schedulers'][1]
     self.assertEqual(set(d1.keys()), {'type', 'unit', 'value'})
     self.assertEqual(d1['type'], 'ConstantScheduler')
     self.assertEqual(d1['unit'], 'epoch')
     self.assertEqual(d1['value'], constant)

     # Reconstructed object: units come back as ``sche.Unit`` members.
     new_sche = sche.get_scheduler(d)
     self.assertIsInstance(new_sche, sche.MultiScheduler)
     self.assertEqual(new_sche.op, 'mean')
     self.assertEqual(len(new_sche.schedulers), 2)
     sche0 = new_sche.schedulers[0]
     self.assertIsInstance(sche0, sche.LinearScheduler)
     self.assertEqual(sche0.unit, sche.Unit.progress)
     self.assertEqual(sche0.start_value, start_value)
     self.assertEqual(sche0.decay_steps, decay_steps)
     self.assertEqual(sche0.stop_value, stop_value)
     sche1 = new_sche.schedulers[1]
     self.assertIsInstance(sche1, sche.ConstantScheduler)
     self.assertEqual(sche1.unit, sche.Unit.epoch)
     self.assertEqual(sche1.value, constant)
 def test_multi_scheduler_custom_op_dump_load(self):
     """A MultiScheduler with a callable op round-trips only if op is re-supplied.

     A callable op is dumped as ``None``. Loading the dump without an
     ``op`` argument is expected to raise ValueError. Loading it with
     ``op=custom_op`` rebuilds a working scheduler.
     """
     start_value = 100.0
     decay_steps = 100
     stop_value = 0
     lin_sche = sche.LinearScheduler(start_value, decay_steps, stop_value)
     constant = 50.0
     const_sche = sche.ConstantScheduler(constant)
     custom_op = lambda x: np.sum(x, axis=0)
     m_sches = sche.MultiScheduler([lin_sche, const_sche], op=custom_op)

     # Serialized form: the callable op is not serializable and shows up as None.
     d = json.loads(json.dumps(m_sches))
     self.assertEqual(set(d.keys()), {'type', 'unit', 'schedulers', 'op'})
     self.assertEqual(d['type'], 'MultiScheduler')
     self.assertIsNone(d['op'])
     self.assertEqual(len(d['schedulers']), 2)
     d0 = d['schedulers'][0]
     self.assertEqual(
         set(d0.keys()),
         {'type', 'unit', 'start_value', 'decay_steps', 'stop_value'})
     self.assertEqual(d0['type'], 'LinearScheduler')
     self.assertEqual(d0['start_value'], start_value)
     self.assertEqual(d0['decay_steps'], decay_steps)
     self.assertEqual(d0['stop_value'], stop_value)
     d1 = d['schedulers'][1]
     self.assertEqual(set(d1.keys()), {'type', 'unit', 'value'})
     self.assertEqual(d1['type'], 'ConstantScheduler')
     self.assertEqual(d1['value'], constant)

     # The dump carries no op, so loading it without one must fail.
     with self.assertRaises(ValueError):
         sche.get_scheduler(d)

     # Supplying the op explicitly rebuilds a working scheduler. Its ``op``
     # attribute stays None while the callable is still used for evaluation.
     new_sche = sche.get_scheduler(d, op=custom_op)
     self.assertIsInstance(new_sche, sche.MultiScheduler)
     self.assertIsNone(new_sche.op)
     self.assertEqual(len(new_sche.schedulers), 2)
     sche0 = new_sche.schedulers[0]
     self.assertIsInstance(sche0, sche.LinearScheduler)
     self.assertEqual(sche0.start_value, start_value)
     self.assertEqual(sche0.decay_steps, decay_steps)
     self.assertEqual(sche0.stop_value, stop_value)
     sche1 = new_sche.schedulers[1]
     self.assertIsInstance(sche1, sche.ConstantScheduler)
     self.assertEqual(sche1.value, constant)
     self.assertAlmostEqual(new_sche(0), constant + start_value)
     self.assertAlmostEqual(new_sche(100), constant + stop_value)
 def test_multi_scheduler_with_state(self):
     """A 'min' MultiScheduler bound to state shares it with all children."""
     state = create_empty_state()
     start_value = 100.0
     decay_steps = 100
     stop_value = 0
     lin_sche = sche.LinearScheduler(start_value,
                                     decay_steps,
                                     stop_value,
                                     unit='timestep')
     constant = 50.0
     const_sche = sche.ConstantScheduler(constant, unit='timestep')
     m_sches = sche.MultiScheduler([lin_sche, const_sche], op='min')
     m_sches.bind(state)
     # Binding the parent must hand the very same state object to each child.
     self.assertIs(lin_sche.state, state)
     self.assertIs(const_sche.state, state)

     state.num_timesteps = 0
     # Linear child is at its start value (100.0), so the constant is the min.
     self.assertAlmostEqual(m_sches(), constant)

     state.num_timesteps = 51
     # Linear child (expected 49.0) has dropped below the constant.
     self.assertAlmostEqual(m_sches(), 49.0)
 def test_scheduler_dump_load(self):
     """A single LinearScheduler round-trips through JSON and get_scheduler.

     No state is needed: the scheduler is never bound or evaluated here,
     only serialized and reconstructed.
     """
     start_value = 100.0
     decay_steps = 100
     stop_value = 0
     lin_sche = sche.LinearScheduler(start_value,
                                     decay_steps,
                                     stop_value,
                                     unit='subepoch')

     # Serialized form: the unit is dumped as its string name.
     d = json.loads(json.dumps(lin_sche))
     self.assertEqual(
         set(d.keys()),
         {'type', 'unit', 'start_value', 'decay_steps', 'stop_value'})
     self.assertEqual(d['type'], 'LinearScheduler')
     self.assertEqual(d['unit'], 'subepoch')
     self.assertEqual(d['start_value'], start_value)
     self.assertEqual(d['decay_steps'], decay_steps)
     self.assertEqual(d['stop_value'], stop_value)

     # Reconstructed object: the unit comes back as a ``sche.Unit`` member.
     new_sche = sche.get_scheduler(d)
     self.assertIsInstance(new_sche, sche.LinearScheduler)
     self.assertEqual(new_sche.unit, sche.Unit.subepoch)
     self.assertEqual(new_sche.start_value, start_value)
     self.assertEqual(new_sche.decay_steps, decay_steps)
     self.assertEqual(new_sche.stop_value, stop_value)