def test_multi_scheduler_without_state(self):
     """Check a 'min' MultiScheduler when steps are passed explicitly.

     No state is bound to any scheduler here; every call supplies the
     step count as an argument.
     """
     floor_value = 50.0
     # Linear child is built with start=100.0, decay_steps=100, stop=0.
     ramp = sche.LinearScheduler(100.0, 100, 0)
     flat = sche.ConstantScheduler(floor_value)
     combined = sche.MultiScheduler([ramp, flat], op='min')

     # Step 0: the combined value is expected to equal the constant.
     at_start = combined(0)
     self.assertAlmostEqual(at_start, floor_value)

     # Step 51: the combined value is expected to have dropped to 49.
     past_crossover = combined(51)
     self.assertAlmostEqual(past_crossover, 49.0)
 def test_multi_scheduler_custom_op(self):
     """Check that a callable ``op`` is used to combine child values.

     The custom op sums the child outputs along axis 0, so the combined
     value is expected to be the constant plus the linear value.
     """
     initial, span, final = 100.0, 100, 0
     offset = 50.0

     def add_all(x):
         # Same reduction as the op under test: sum over the first axis.
         return np.sum(x, axis=0)

     children = [
         sche.LinearScheduler(initial, span, final),
         sche.ConstantScheduler(offset),
     ]
     combined = sche.MultiScheduler(children, op=add_all)

     first = combined(0)
     self.assertAlmostEqual(first, offset + initial)
     last = combined(100)
     self.assertAlmostEqual(last, offset + final)
 def test_multi_scheduler_dump_load(self):
     """Round-trip a MultiScheduler with a named op through JSON.

     The scheduler is serialized with ``json.dumps`` and parsed back; the
     resulting dict layout is checked field by field. The dict is then
     handed to ``sche.get_scheduler`` and the rebuilt scheduler is checked
     for matching types and attributes.
     """
     state = create_empty_state()
     start_value = 100.0
     decay_steps = 100
     stop_value = 0
     lin_sche = sche.LinearScheduler(start_value,
                                     decay_steps,
                                     stop_value,
                                     unit='progress')
     constant = 0.5
     const_sche = sche.ConstantScheduler(constant,
                                         unit='epoch',
                                         state=state)
     m_sches = sche.MultiScheduler([lin_sche, const_sche], op='mean')

     # Dump: check the serialized layout of the parent scheduler.
     d = json.loads(json.dumps(m_sches))
     self.assertEqual(set(d), {'type', 'unit', 'schedulers', 'op'})
     self.assertEqual(d['type'], 'MultiScheduler')
     self.assertEqual(d['op'], 'mean')
     self.assertEqual(len(d['schedulers']), 2)

     # First child: the linear scheduler.
     d0 = d['schedulers'][0]
     self.assertEqual(
         set(d0),
         {'type', 'unit', 'start_value', 'decay_steps', 'stop_value'})
     self.assertEqual(d0['type'], 'LinearScheduler')
     self.assertEqual(d0['unit'], 'progress')
     self.assertEqual(d0['start_value'], start_value)
     self.assertEqual(d0['decay_steps'], decay_steps)
     self.assertEqual(d0['stop_value'], stop_value)

     # Second child: the constant scheduler. The expected key set has no
     # entry for the state it was constructed with.
     d1 = d['schedulers'][1]
     self.assertEqual(set(d1), {'type', 'unit', 'value'})
     self.assertEqual(d1['type'], 'ConstantScheduler')
     self.assertEqual(d1['unit'], 'epoch')
     self.assertEqual(d1['value'], constant)

     # Load: rebuild from the plain dict and compare attributes.
     new_sche = sche.get_scheduler(d)
     # assertIsInstance reports the actual type on failure, unlike
     # assertTrue(isinstance(...)).
     self.assertIsInstance(new_sche, sche.MultiScheduler)
     self.assertEqual(new_sche.op, 'mean')
     self.assertEqual(len(new_sche.schedulers), 2)

     sche0 = new_sche.schedulers[0]
     self.assertIsInstance(sche0, sche.LinearScheduler)
     self.assertEqual(sche0.unit, sche.Unit.progress)
     self.assertEqual(sche0.start_value, start_value)
     self.assertEqual(sche0.decay_steps, decay_steps)
     self.assertEqual(sche0.stop_value, stop_value)

     sche1 = new_sche.schedulers[1]
     self.assertIsInstance(sche1, sche.ConstantScheduler)
     self.assertEqual(sche1.unit, sche.Unit.epoch)
     self.assertEqual(sche1.value, constant)
 def test_multi_scheduler_custom_op_dump_load(self):
     """Round-trip a MultiScheduler with a callable op through JSON.

     The serialized ``op`` field is expected to be ``None`` for a custom
     callable. Loading the dict without supplying the callable is expected
     to raise ``ValueError``; loading it with ``op=custom_op`` is expected
     to give a working scheduler whose ``op`` attribute is ``None``.
     """
     start_value = 100.0
     decay_steps = 100
     stop_value = 0
     lin_sche = sche.LinearScheduler(start_value, decay_steps, stop_value)
     constant = 50.0
     const_sche = sche.ConstantScheduler(constant)

     def custom_op(x):
         # Sum the child scheduler outputs along the first axis.
         return np.sum(x, axis=0)

     m_sches = sche.MultiScheduler([lin_sche, const_sche], op=custom_op)

     # Dump: check the serialized layout of the parent scheduler.
     d = json.loads(json.dumps(m_sches))
     self.assertEqual(set(d), {'type', 'unit', 'schedulers', 'op'})
     self.assertEqual(d['type'], 'MultiScheduler')
     self.assertIsNone(d['op'])
     self.assertEqual(len(d['schedulers']), 2)

     d0 = d['schedulers'][0]
     self.assertEqual(
         set(d0),
         {'type', 'unit', 'start_value', 'decay_steps', 'stop_value'})
     self.assertEqual(d0['type'], 'LinearScheduler')
     self.assertEqual(d0['start_value'], start_value)
     self.assertEqual(d0['decay_steps'], decay_steps)
     self.assertEqual(d0['stop_value'], stop_value)

     d1 = d['schedulers'][1]
     self.assertEqual(set(d1), {'type', 'unit', 'value'})
     self.assertEqual(d1['type'], 'ConstantScheduler')
     self.assertEqual(d1['value'], constant)

     # custom op not specified, raise ValueError
     with self.assertRaises(ValueError):
         # The result is deliberately not bound: the call must raise.
         sche.get_scheduler(d)

     # Load: supplying the callable lets the scheduler be rebuilt.
     new_sche = sche.get_scheduler(d, op=custom_op)
     self.assertIsInstance(new_sche, sche.MultiScheduler)
     self.assertIsNone(new_sche.op)
     self.assertEqual(len(new_sche.schedulers), 2)

     sche0 = new_sche.schedulers[0]
     self.assertIsInstance(sche0, sche.LinearScheduler)
     self.assertEqual(sche0.start_value, start_value)
     self.assertEqual(sche0.decay_steps, decay_steps)
     self.assertEqual(sche0.stop_value, stop_value)

     sche1 = new_sche.schedulers[1]
     self.assertIsInstance(sche1, sche.ConstantScheduler)
     self.assertEqual(sche1.value, constant)

     # The rebuilt scheduler must still combine values with the custom op.
     self.assertAlmostEqual(new_sche(0), constant + start_value)
     self.assertAlmostEqual(new_sche(100), constant + stop_value)
 def test_multi_scheduler_with_state(self):
     """Check a 'min' MultiScheduler that reads its step from a bound state.

     After ``bind(state)`` on the parent, both children are expected to
     reference the same state object. Calling the scheduler with no
     argument is then expected to follow ``state.num_timesteps``.
     """
     state = create_empty_state()
     start_value = 100.0
     decay_steps = 100
     stop_value = 0
     lin_sche = sche.LinearScheduler(start_value,
                                     decay_steps,
                                     stop_value,
                                     unit='timestep')
     constant = 50.0
     const_sche = sche.ConstantScheduler(constant, unit='timestep')
     m_sches = sche.MultiScheduler([lin_sche, const_sche], op='min')
     m_sches.bind(state)

     # assertIs reports both objects on failure, unlike assertTrue(a is b).
     self.assertIs(lin_sche.state, state)
     self.assertIs(const_sche.state, state)

     # Timestep 0: the combined value is expected to equal the constant.
     state.num_timesteps = 0
     self.assertAlmostEqual(m_sches(), constant)

     # Timestep 51: the combined value is expected to have dropped to 49.
     state.num_timesteps = 51
     self.assertAlmostEqual(m_sches(), 49.0)
 def test_constant_scheduler_with_state(self):
     """Check a ConstantScheduler that reads its step from a bound state.

     The state is passed to the constructor and the scheduler is called
     with no argument, so the step comes from ``state.num_timesteps``.
     Previously the state was created but never used, which made this
     test a duplicate of the stateless case.
     """
     state = create_empty_state()
     constant = 1.0
     const_sche = sche.ConstantScheduler(constant,
                                         unit='timestep',
                                         state=state)
     # NOTE(review): assumes a state-bound ConstantScheduler can be called
     # without an explicit step, as the bound MultiScheduler is elsewhere
     # in this test file -- confirm against the scheduler implementation.
     state.num_timesteps = 100
     self.assertEqual(const_sche(), constant)
 def test_constant_scheduler_without_state(self):
     """Check a ConstantScheduler called with an explicit step.

     No state is bound; the scheduler is expected to return the value it
     was constructed with.
     """
     expected = 1.0
     scheduler = sche.ConstantScheduler(expected)
     result = scheduler(100)
     self.assertEqual(result, expected)