Example #1
0
def linear_warmup_multistep(optimizer,batch_size,epochs,burnin,batches_per_epoch,load_epoch=-1):

    # grab lr from re-loaded optim
    init_lr = [group['lr'] for group in optimizer.param_groups]    

    # correct learning rate with batch size
    batch_correction = math.sqrt(batch_size)
    
    # compute number of total batches
    burnin_steps = int(burnin * batches_per_epoch)
    total_steps = int(epochs * batches_per_epoch)

    # convert epoch number to batches
    nbatches = load_epoch * batches_per_epoch if load_epoch > -1 else -1
    loading_lin = nbatches > -1

    # compute the loaded "epoch" cosine scheduler
    cos_batch = nbatches - burnin_steps if nbatches > -1 else -1
    cos_batch = cos_batch - 2 if cos_batch > -1 else -1 

    # init cosine annealing scheduler
    T_max = total_steps - burnin_steps
    eta_min = 0
    milestones = [60,80]
    after_scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
    cos_loading = cos_batch > -1

    # init the linear warmup
    nbatches = nbatches - 1 if nbatches > 0 else -1
    scheduler = LinearWarmup(optimizer, burnin_steps, batch_correction,
                             after_scheduler, nbatches)

    # handle setting lr after reloading
    if loading_lin:
        for idx,group in enumerate(optimizer.param_groups):
            group['lr'] = init_lr[idx]

    if cos_loading:
        after_scheduler._last_lr = init_lr
        scheduler._last_lr = init_lr

    return scheduler