def test_linear_decay_scheduler():
    """Linear-decay LR at step t is (linear warmup term) * (linear decay term)."""
    scheduler = lr_scheduler.get_lr_scheduler(
        'linear-decay',
        base_learning_rate=1,
        learning_rate_t_scale=1,
        learning_rate_reduce_factor=0,
        learning_rate_reduce_num_not_improved=0,
        learning_rate_warmup=3,
        max_updates=10)
    # Warmup ramps linearly to 1 over the first 3 updates; the decay term falls
    # linearly from (max_updates - t) / max_updates down to 0 at update 10.
    expected_schedule = [(min(t, 3) / 3) * ((10 - t) / 10) for t in range(1, 11)]
    actual_schedule = [scheduler(t) for t in range(1, 11)]
    assert np.isclose(expected_schedule, actual_schedule).all()
def test_get_lr_scheduler_no_reduce():
    """A reduce factor of 1.0 makes plateau-reduce a no-op, so no scheduler is built."""
    # NOTE(review): this test name recurs elsewhere in the file; if these live in
    # one module, the later definition shadows this one — confirm intent.
    result = lr_scheduler.get_lr_scheduler("plateau-reduce",
                                           updates_per_checkpoint=4,
                                           learning_rate_half_life=2,
                                           learning_rate_reduce_factor=1.0,
                                           learning_rate_reduce_num_not_improved=16)
    assert result is None
def test_get_lr_scheduler(scheduler_type, reduce_factor, expected_instance):
    """Each scheduler type string maps to the expected scheduler class.

    Parametrized (decorator presumably lives above this chunk) over
    (scheduler_type, reduce_factor, expected_instance) triples.
    """
    # NOTE(review): this test name recurs elsewhere in the file; if these live in
    # one module, the later definition shadows this one — confirm intent.
    common_kwargs = dict(updates_per_checkpoint=4,
                         learning_rate_half_life=2,
                         learning_rate_reduce_factor=reduce_factor,
                         learning_rate_reduce_num_not_improved=16)
    scheduler = lr_scheduler.get_lr_scheduler(scheduler_type, **common_kwargs)
    assert isinstance(scheduler, expected_instance)
def test_get_lr_scheduler_no_reduce():
    """With a reduce factor of exactly 1.0 the factory returns None (nothing to do)."""
    # NOTE(review): this test name recurs elsewhere in the file; if these live in
    # one module, the later definition shadows this one — confirm intent.
    no_op_factor = 1.0
    scheduler = lr_scheduler.get_lr_scheduler(
        'plateau-reduce',
        learning_rate_t_scale=1,
        learning_rate_reduce_factor=no_op_factor,
        learning_rate_reduce_num_not_improved=16)
    assert scheduler is None
def test_get_lr_scheduler_no_reduce():
    """plateau-reduce with factor 1.0 would never change the LR, so None is returned."""
    # NOTE(review): this test name recurs elsewhere in the file; if these live in
    # one module, the later definition shadows this one — confirm intent.
    assert lr_scheduler.get_lr_scheduler(
        "plateau-reduce",
        updates_per_checkpoint=4,
        learning_rate_half_life=2,
        learning_rate_reduce_factor=1.0,
        learning_rate_reduce_num_not_improved=16) is None
def test_get_lr_scheduler(scheduler_type, reduce_factor, expected_instance):
    """The factory builds an instance of the class expected for each type string.

    Parametrized (decorator presumably lives above this chunk) over
    (scheduler_type, reduce_factor, expected_instance) triples.
    """
    # NOTE(review): this test name recurs elsewhere in the file; if these live in
    # one module, the later definition shadows this one — confirm intent.
    built = lr_scheduler.get_lr_scheduler(scheduler_type,
                                          updates_per_checkpoint=4,
                                          learning_rate_half_life=2,
                                          learning_rate_reduce_factor=reduce_factor,
                                          learning_rate_reduce_num_not_improved=16)
    assert isinstance(built, expected_instance)
def test_get_lr_scheduler(scheduler_type, expected_instance):
    """Factory returns the expected class, or None when expected_instance is None.

    Parametrized (decorator presumably lives above this chunk) over
    (scheduler_type, expected_instance) pairs.
    """
    # NOTE(review): this test name recurs elsewhere in the file; if these live in
    # one module, the later definition shadows this one — confirm intent.
    scheduler = lr_scheduler.get_lr_scheduler(
        scheduler_type,
        learning_rate_t_scale=1,
        learning_rate_reduce_factor=0.5,
        learning_rate_reduce_num_not_improved=16,
        learning_rate_warmup=1000,
        max_updates=10000)
    # Guard clause: some scheduler types legitimately produce no scheduler.
    if expected_instance is None:
        assert scheduler is None
        return
    assert isinstance(scheduler, expected_instance)
def test_inv_sqrt_decay_scheduler(learning_rate_warmup, learning_rate_t_scale):
    """inv-sqrt-decay matches the Transformer-paper formula with time scaling.

    Parametrized (decorator presumably lives above this chunk) over
    (learning_rate_warmup, learning_rate_t_scale) pairs.
    """
    scheduler = lr_scheduler.get_lr_scheduler(
        'inv-sqrt-decay',
        base_learning_rate=1,
        learning_rate_t_scale=learning_rate_t_scale,
        learning_rate_reduce_factor=0,
        learning_rate_reduce_num_not_improved=0,
        learning_rate_warmup=learning_rate_warmup,
        max_updates=10)

    def reference(t):
        # lr = min(scaled_t^-0.5, scaled_t * warmup^-1.5), per "Attention Is
        # All You Need", with the step first multiplied by the time scale.
        scaled_t = t * learning_rate_t_scale
        return min(scaled_t ** -0.5, scaled_t * learning_rate_warmup ** -1.5)

    expected_schedule = [reference(t) for t in range(1, 11)]
    actual_schedule = [scheduler(t) for t in range(1, 11)]
    assert np.isclose(expected_schedule, actual_schedule).all()