def update_learning_rate(args, config, current_global_step, optimizer):
    global last_global_step_from_restore

    global_step_for_lr = current_global_step - last_global_step_from_restore
    if args.lr_schedule == "EE":
        #print(f'LR Schedule is {args.lr_schedule} EE')
        lr_this_step = config["training"][
            "learning_rate"] * warmup_exp_decay_exp(
                global_step_for_lr, config["training"]["decay_rate"],
                config["training"]["decay_step"],
                config["training"]["total_training_steps"],
                config["training"]["warmup_proportion"])
    elif args.lr_schedule == "EP":
        print(f'LR Schedule is {args.lr_schedule} EP')
        lr_this_step = config["training"][
            "learning_rate"] * warmup_exp_decay_poly(
                global_step_for_lr,
                config["training"]["total_training_steps"],
                config["training"]["warmup_proportion"])
    else:
        lr_this_step = config["training"][
            "learning_rate"] * warmup_linear_decay_exp(
                global_step_for_lr, config["training"]["decay_rate"],
                config["training"]["decay_step"],
                config["training"]["total_training_steps"],
                config["training"]["warmup_proportion"])
    lr_this_step += args.lr_offset

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_this_step

    return lr_this_step
def update_learning_rate(config, current_global_step, optimizer):
    global last_global_step_from_restore

    global_step_for_lr = current_global_step - last_global_step_from_restore
    lr_this_step = config["training"][
        "learning_rate"] * warmup_linear_decay_exp(
            global_step_for_lr, config["training"]["decay_rate"],
            config["training"]["decay_step"],
            config["training"]["total_training_steps"],
            config["training"]["warmup_proportion"])

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_this_step

    return lr_this_step
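# NOTE: warmup_linear_decay_exp, warmup_exp_decay_exp and warmup_exp_decay_poly
# are schedule helpers defined elsewhere in this codebase. For illustration
# only -- not the actual implementation -- a hypothetical linear-warmup /
# exponential-decay multiplier could look like the sketch below; the returned
# factor scales config["training"]["learning_rate"] in update_learning_rate above.
def _example_warmup_linear_decay_exp(global_step, decay_rate, decay_step,
                                     total_steps, warmup_proportion):
    # Ramp linearly from 0 to 1 over the warmup span, then decay the
    # multiplier by `decay_rate` for every `decay_step` steps afterwards.
    warmup_steps = warmup_proportion * total_steps
    if global_step < warmup_steps:
        return global_step / max(warmup_steps, 1.0)
    return decay_rate**((global_step - warmup_steps) / decay_step)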
def load_checkpoint(args, model, optimizer):
    global global_step
    global global_data_samples
    global last_global_step_from_restore

    config = args.config
    logger = args.logger

    logger.info(
        f"Restoring previous training checkpoint from PATH={args.load_training_checkpoint}"
    )
    start_epoch, global_step, global_data_samples = load_training_checkpoint(
        args=args,
        model=model,
        optimizer=optimizer,
        PATH=args.load_training_checkpoint,
        load_optimizer_state=args.use_lamb)
    logger.info(
        f"The model is loaded from the last checkpoint at epoch {start_epoch}, where the global step was {global_step} and the global data sample count was {global_data_samples}"
    )

    # Restore the global data sample count in the model.
    model.network.sample_count = global_data_samples

    if args.rewarmup:
        logger.info(
            f"Rewarmup learning rate with last_global_step_from_restore = {global_step}"
        )
        last_global_step_from_restore = global_step

    lr_this_step = config["training"][
        "learning_rate"] * warmup_linear_decay_exp(
            global_step, config["training"]["decay_rate"],
            config["training"]["decay_step"],
            config["training"]["total_training_steps"],
            config["training"]["warmup_proportion"])
    logger.info(f"Restart training with lr = {lr_this_step}")

    # Run validation for the checkpoint before resuming training.
    if not args.finetune and args.max_seq_length == 512:
        logger.info(
            f"Validation loss of checkpoint {start_epoch} before pretraining")
        logger.info(f"TRAIN BATCH SIZE: {args.train_batch_size}")
        index = start_epoch - 1 if start_epoch > 0 else start_epoch
        pretrain_validation(args, index, model)

    return start_epoch
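# Hypothetical usage sketch (not part of the original script): after a
# restart, load_checkpoint restores the step/sample counters, and the
# training loop then calls update_learning_rate once per optimizer step, e.g.:
#
#     start_epoch = load_checkpoint(args, model, optimizer)
#     for epoch in range(start_epoch, config["training"]["num_epochs"]):
#         for batch in train_dataloader:
#             ...                      # forward/backward on `batch`
#             optimizer.step()
#             global_step += 1
#             update_learning_rate(args, config, global_step, optimizer)
#
# `train_dataloader` and config["training"]["num_epochs"] are assumed names
# used only for this illustration.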