Example #1
import numpy as np

from sockeye import lr_scheduler  # assumed source package; these imports apply to all examples below


def test_linear_decay_scheduler():
    scheduler = lr_scheduler.get_lr_scheduler(
        'linear-decay',
        base_learning_rate=1,
        learning_rate_t_scale=1,
        learning_rate_reduce_factor=0,
        learning_rate_reduce_num_not_improved=0,
        learning_rate_warmup=3,
        max_updates=10)

    # Expected LR at steps 1..10: warmup term min(t / 3, 1) times decay term (10 - t) / 10
    expected_schedule = [
        (1 / 3) * (9 / 10),
        (2 / 3) * (8 / 10),
        (3 / 3) * (7 / 10),
        (3 / 3) * (6 / 10),
        (3 / 3) * (5 / 10),
        (3 / 3) * (4 / 10),
        (3 / 3) * (3 / 10),
        (3 / 3) * (2 / 10),
        (3 / 3) * (1 / 10),
        (3 / 3) * (0 / 10),
    ]
    actual_schedule = [scheduler(t) for t in range(1, 11)]
    assert np.isclose(expected_schedule, actual_schedule).all()
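
The expected values above follow a simple closed form. A standalone sketch of that formula for checking by hand (an illustration, not Sockeye's implementation):

def linear_decay_reference(t, warmup=3, max_updates=10):
    # Warmup term ramps linearly to 1 over `warmup` steps, then saturates;
    # the decay term falls linearly from (max_updates - 1)/max_updates to 0.
    return min(t / warmup, 1.0) * (max_updates - t) / max_updates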
Example #2
def test_get_lr_scheduler_no_reduce():
    # A reduce factor of 1.0 never changes the learning rate, so no scheduler is created.
    scheduler = lr_scheduler.get_lr_scheduler("plateau-reduce",
                                              updates_per_checkpoint=4,
                                              learning_rate_half_life=2,
                                              learning_rate_reduce_factor=1.0,
                                              learning_rate_reduce_num_not_improved=16)
    assert scheduler is None
Example #3
def test_get_lr_scheduler(scheduler_type, reduce_factor, expected_instance):
    # Parameters come from a @pytest.mark.parametrize decorator (reconstructed below).
    scheduler = lr_scheduler.get_lr_scheduler(scheduler_type,
                                              updates_per_checkpoint=4,
                                              learning_rate_half_life=2,
                                              learning_rate_reduce_factor=reduce_factor,
                                              learning_rate_reduce_num_not_improved=16)
    assert isinstance(scheduler, expected_instance)
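
The @pytest.mark.parametrize decorator that supplies these arguments was lost in extraction. A plausible reconstruction, where the expected scheduler class is an assumption about the module's internals, not taken from the source:

import pytest

@pytest.mark.parametrize("scheduler_type, reduce_factor, expected_instance", [
    # Hypothetical pairing; the real class name depends on lr_scheduler's implementation.
    ("plateau-reduce", 0.5, lr_scheduler.LearningRateSchedulerPlateauReduce),
])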
Example #4
def test_get_lr_scheduler_no_reduce():
    # Same no-op check as above, via a different keyword signature.
    scheduler = lr_scheduler.get_lr_scheduler(
        'plateau-reduce',
        learning_rate_t_scale=1,
        learning_rate_reduce_factor=1.0,
        learning_rate_reduce_num_not_improved=16)
    assert scheduler is None
Example #5
def test_get_lr_scheduler(scheduler_type, expected_instance):
    # Parametrized variant that also covers the no-scheduler case; see the sketch below.
    scheduler = lr_scheduler.get_lr_scheduler(
        scheduler_type,
        learning_rate_t_scale=1,
        learning_rate_reduce_factor=0.5,
        learning_rate_reduce_num_not_improved=16,
        learning_rate_warmup=1000,
        max_updates=10000)
    if expected_instance is None:
        assert scheduler is None
    else:
        assert isinstance(scheduler, expected_instance)
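
Here too the parametrization was lost in extraction; it presumably covers both a real scheduler and the no-scheduler case. A minimal sketch with assumed values and class names:

import pytest

@pytest.mark.parametrize("scheduler_type, expected_instance", [
    (None, None),  # assumed: requesting no scheduler type yields None
    ("linear-decay", lr_scheduler.LinearDecayScheduler),  # assumed class name
])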
Example #6
def test_inv_sqrt_decay_scheduler(learning_rate_warmup, learning_rate_t_scale):
    scheduler = lr_scheduler.get_lr_scheduler(
        'inv-sqrt-decay',
        base_learning_rate=1,
        learning_rate_t_scale=learning_rate_t_scale,
        learning_rate_reduce_factor=0,
        learning_rate_reduce_num_not_improved=0,
        learning_rate_warmup=learning_rate_warmup,
        max_updates=10)

    # Reference formula from the Transformer paper (Vaswani et al., 2017),
    # evaluated at the scaled step: lr = min(s ** -0.5, s * warmup ** -1.5)
    # with s = t * learning_rate_t_scale.
    def alternate_implementation(t):
        scaled_t = t * learning_rate_t_scale
        return min(scaled_t ** -0.5, scaled_t * learning_rate_warmup ** -1.5)

    expected_schedule = [alternate_implementation(t) for t in range(1, 11)]
    actual_schedule = [scheduler(t) for t in range(1, 11)]

    assert np.isclose(expected_schedule, actual_schedule).all()
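
For comparison, the schedule in "Attention Is All You Need" also carries a d_model ** -0.5 factor, which base_learning_rate=1 effectively stands in for here. A sketch of the paper's formula:

def transformer_lr(step, d_model=512, warmup_steps=4000):
    # lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)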