def test_multi_scheduler_without_state(self):
    """Check a state-less MultiScheduler with ``op='min'`` at explicit steps.

    No state is bound, so the step is passed straight into the call.
    The expected values are the constant (50) at step 0 and 49 at step 51,
    i.e. whichever of the two member schedules is smaller at that step.
    """
    linear = sche.LinearScheduler(100.0, 100, 0)
    cap = 50.0
    combined = sche.MultiScheduler(
        [linear, sche.ConstantScheduler(cap)], op='min')

    for step, expected in ((0, cap), (51, 49.0)):
        self.assertAlmostEqual(combined(step), expected)
def test_multi_scheduler_custom_op(self):
    """Check that MultiScheduler accepts a callable as its ``op``.

    The callable sums the member values along axis 0, so the result is
    expected to equal the constant plus the linear schedule's value at
    the queried step (start value at step 0, stop value at step 100).
    """
    start, stop = 100.0, 0
    linear = sche.LinearScheduler(start, 100, stop)
    offset = 50.0
    fixed = sche.ConstantScheduler(offset)
    combined = sche.MultiScheduler(
        [linear, fixed], op=lambda x: np.sum(x, axis=0))

    self.assertAlmostEqual(combined(0), offset + start)
    self.assertAlmostEqual(combined(100), offset + stop)
def test_multi_scheduler_dump_load(self):
    """Round-trip a MultiScheduler with a named op through JSON.

    The scheduler is dumped with ``json.dumps``, parsed back, and the
    resulting dict is inspected key by key. It is then rebuilt with
    ``sche.get_scheduler`` and the restored object and its two member
    schedulers are checked attribute by attribute.
    """
    state = create_empty_state()
    start, steps, stop = 100.0, 100, 0
    linear = sche.LinearScheduler(start, steps, stop, unit='progress')
    const_value = 0.5
    fixed = sche.ConstantScheduler(const_value, unit='epoch', state=state)
    original = sche.MultiScheduler([linear, fixed], op='mean')

    # --- serialized form ---
    dumped = json.loads(json.dumps(original))
    self.assertEqual(set(dumped), {'type', 'unit', 'schedulers', 'op'})
    self.assertEqual(dumped['type'], 'MultiScheduler')
    self.assertEqual(dumped['op'], 'mean')
    self.assertEqual(len(dumped['schedulers']), 2)

    linear_dump, fixed_dump = dumped['schedulers'][0], dumped['schedulers'][1]
    self.assertEqual(
        set(linear_dump),
        {'type', 'unit', 'start_value', 'decay_steps', 'stop_value'})
    linear_fields = (
        ('type', 'LinearScheduler'),
        ('unit', 'progress'),
        ('start_value', start),
        ('decay_steps', steps),
        ('stop_value', stop),
    )
    for key, expected in linear_fields:
        self.assertEqual(linear_dump[key], expected)

    self.assertEqual(set(fixed_dump), {'type', 'unit', 'value'})
    fixed_fields = (
        ('type', 'ConstantScheduler'),
        ('unit', 'epoch'),
        ('value', const_value),
    )
    for key, expected in fixed_fields:
        self.assertEqual(fixed_dump[key], expected)

    # --- restored object ---
    restored = sche.get_scheduler(dumped)
    self.assertIsInstance(restored, sche.MultiScheduler)
    self.assertEqual(restored.op, 'mean')
    self.assertEqual(len(restored.schedulers), 2)

    restored_linear = restored.schedulers[0]
    self.assertIsInstance(restored_linear, sche.LinearScheduler)
    self.assertEqual(restored_linear.unit, sche.Unit.progress)
    self.assertEqual(restored_linear.start_value, start)
    self.assertEqual(restored_linear.decay_steps, steps)
    self.assertEqual(restored_linear.stop_value, stop)

    restored_fixed = restored.schedulers[1]
    self.assertIsInstance(restored_fixed, sche.ConstantScheduler)
    self.assertEqual(restored_fixed.unit, sche.Unit.epoch)
    self.assertEqual(restored_fixed.value, const_value)
def test_multi_scheduler_custom_op_dump_load(self):
    """Round-trip a MultiScheduler with a callable op through JSON.

    The dumped dict is expected to carry ``None`` under ``'op'``.
    Rebuilding from that dict without supplying ``op`` is expected to
    raise ``ValueError``; supplying the callable again must produce a
    scheduler that yields the same sums as the original.
    """
    start, steps, stop = 100.0, 100, 0
    linear = sche.LinearScheduler(start, steps, stop)
    offset = 50.0
    fixed = sche.ConstantScheduler(offset)
    summed = lambda x: np.sum(x, axis=0)
    original = sche.MultiScheduler([linear, fixed], op=summed)

    # --- serialized form ---
    dumped = json.loads(json.dumps(original))
    self.assertEqual(set(dumped), {'type', 'unit', 'schedulers', 'op'})
    self.assertEqual(dumped['type'], 'MultiScheduler')
    self.assertIsNone(dumped['op'])
    self.assertEqual(len(dumped['schedulers']), 2)

    linear_dump, fixed_dump = dumped['schedulers'][0], dumped['schedulers'][1]
    self.assertEqual(
        set(linear_dump),
        {'type', 'unit', 'start_value', 'decay_steps', 'stop_value'})
    linear_fields = (
        ('type', 'LinearScheduler'),
        ('start_value', start),
        ('decay_steps', steps),
        ('stop_value', stop),
    )
    for key, expected in linear_fields:
        self.assertEqual(linear_dump[key], expected)

    self.assertEqual(set(fixed_dump), {'type', 'unit', 'value'})
    self.assertEqual(fixed_dump['type'], 'ConstantScheduler')
    self.assertEqual(fixed_dump['value'], offset)

    # --- restored object ---
    # Rebuilding without passing the custom op must raise ValueError.
    with self.assertRaises(ValueError):
        sche.get_scheduler(dumped)

    restored = sche.get_scheduler(dumped, op=summed)
    self.assertIsInstance(restored, sche.MultiScheduler)
    self.assertEqual(restored.op, None)
    self.assertEqual(len(restored.schedulers), 2)

    restored_linear = restored.schedulers[0]
    self.assertIsInstance(restored_linear, sche.LinearScheduler)
    self.assertEqual(restored_linear.start_value, start)
    self.assertEqual(restored_linear.decay_steps, steps)
    self.assertEqual(restored_linear.stop_value, stop)

    restored_fixed = restored.schedulers[1]
    self.assertIsInstance(restored_fixed, sche.ConstantScheduler)
    self.assertEqual(restored_fixed.value, offset)

    # The rebuilt scheduler must still apply the supplied callable.
    self.assertAlmostEqual(restored(0), offset + start)
    self.assertAlmostEqual(restored(100), offset + stop)
def test_multi_scheduler_with_state(self):
    """Check a MultiScheduler with ``op='min'`` driven by a bound state.

    After ``bind(state)`` on the MultiScheduler, both member schedulers
    are expected to reference the same state object. The scheduler is
    then called with no argument while ``state.num_timesteps`` is set to
    0 and 51, expecting 50 (the constant) and 49 respectively.
    """
    state = create_empty_state()
    linear = sche.LinearScheduler(100.0, 100, 0, unit='timestep')
    cap = 50.0
    fixed = sche.ConstantScheduler(cap, unit='timestep')
    combined = sche.MultiScheduler([linear, fixed], op='min')
    combined.bind(state)

    # Binding the container must reach every member scheduler.
    for member in (linear, fixed):
        self.assertIs(member.state, state)

    for timestep, expected in ((0, cap), (51, 49.0)):
        state.num_timesteps = timestep
        self.assertAlmostEqual(combined(), expected)
def test_constant_scheduler_with_state(self):
    """Check a ConstantScheduler that is actually bound to a state.

    The state is handed to the constructor so the scheduler under test
    really is state-bound. The value is checked both with an explicit
    step and with the step taken from the bound state.
    """
    state = create_empty_state()
    constant = 1.0
    const_sche = sche.ConstantScheduler(
        constant, unit='timestep', state=state)
    self.assertIs(const_sche.state, state)

    # An explicit step still yields the constant.
    self.assertEqual(const_sche(100), constant)

    # With no argument the step is presumably read from the bound state
    # (same calling pattern as the MultiScheduler state test) -- the
    # constant must come back regardless of the timestep.
    state.num_timesteps = 100
    self.assertEqual(const_sche(), constant)
def test_constant_scheduler_without_state(self):
    """Check that an unbound ConstantScheduler returns its value at step 100."""
    expected = 1.0
    scheduler = sche.ConstantScheduler(expected)
    result = scheduler(100)
    self.assertEqual(result, expected)