Example #1
def get_torch_lr_scheduler(optimizer, lr_scheduler_params):
    """Build a PyTorch LR scheduler from a params dict and tag it with a
    step_type attribute ('epoch' or 'iter') that tells the training loop
    how often to call scheduler.step()."""
    from torch.optim.lr_scheduler import LambdaLR, MultiStepLR, CosineAnnealingLR, CyclicLR

    # Constant LR: a lambda factor of 1 leaves the optimizer's LR unchanged.
    if lr_scheduler_params['type'] == 'constant':
        scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 1)
        scheduler.step_type = 'epoch'
        return scheduler

    # Multiply the LR by gamma at each epoch milestone given in kwargs.
    if lr_scheduler_params['type'] == 'multi_step':
        scheduler = MultiStepLR(optimizer, **lr_scheduler_params['kwargs'])
        scheduler.step_type = 'epoch'
        return scheduler

    # Cyclic LR with hard-coded bounds; stepped once per iteration rather than per epoch.
    if lr_scheduler_params['type'] == 'cyclic':
        scheduler = CyclicLR(optimizer,
                             base_lr=0.001,
                             max_lr=0.1,
                             step_size_up=50,
                             step_size_down=100,
                             mode='triangular')
        scheduler.step_type = 'iter'
        return scheduler

    # Cosine annealing; T_max (and optionally eta_min) are passed through from kwargs.
    if lr_scheduler_params['type'] == 'cosine_annealing':
        scheduler = CosineAnnealingLR(optimizer,
                                      **lr_scheduler_params['kwargs'])
        scheduler.step_type = 'iter'
        return scheduler

    # Experimental schedule: exponential decay of 5% per epoch.
    if lr_scheduler_params['type'] == 'experiment':

        def get_lr_factor(epoch):
            return 0.95 ** epoch

        scheduler = LambdaLR(optimizer, lr_lambda=get_lr_factor)
        scheduler.step_type = 'epoch'
        return scheduler

    # Fail loudly instead of silently returning None for an unrecognised type.
    raise ValueError(f"Unknown lr_scheduler type: {lr_scheduler_params['type']!r}")
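
The step_type attribute is what the training loop checks to decide whether scheduler.step() is called once per batch or once per epoch. Below is a minimal usage sketch under that assumption; the model, optimizer, data, and loop bounds are purely illustrative and not part of the original function.

import torch

# Illustrative setup: a tiny linear model and an SGD optimizer.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

scheduler = get_torch_lr_scheduler(
    optimizer,
    {'type': 'multi_step', 'kwargs': {'milestones': [30, 60], 'gamma': 0.1}})

for epoch in range(100):
    for _ in range(10):  # stand-in for iterating over a DataLoader
        x, y = torch.randn(4, 10), torch.randn(4, 2)
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        if scheduler.step_type == 'iter':
            scheduler.step()  # 'cyclic' and 'cosine_annealing' step every batch
    if scheduler.step_type == 'epoch':
        scheduler.step()  # 'constant', 'multi_step', 'experiment' step every epoch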