def prepare_optimizers(args, model, checkpoint, global_steps):
    """Create the optimizer, LR schedulers, grad scaler and optional K-FAC preconditioner.

    Args:
        args: run configuration. Reads ``lr_decay``, ``learning_rate``,
            ``resume_step``, ``previous_phase_end_step``, ``max_steps``,
            ``warmup_proportion``, ``fp16``, ``kfac`` and the ``kfac_*``
            options.
        model: module whose ``named_parameters()`` are optimized.
        checkpoint: loaded checkpoint dict, or ``None`` for a fresh run. Its
            ``'optimizer'`` entry is modified in place when the saved
            hyperparameters are overridden.
        global_steps: step value written into the restored optimizer state.

    Returns:
        Tuple ``(optimizer, preconditioner, lr_schedulers, scaler)``;
        ``preconditioner`` is ``None`` unless ``args.kfac`` is set and
        ``scaler`` is ``None`` unless ``args.fp16`` is set.

    Raises:
        ValueError: if ``args.lr_decay`` is neither ``'poly'`` nor ``'linear'``.
    """
    # Parameters whose name contains one of these tags get no weight decay.
    no_decay_tags = ['bias', 'gamma', 'beta', 'LayerNorm']
    with_decay = []
    without_decay = []
    for name, param in model.named_parameters():
        if any(tag in name for tag in no_decay_tags):
            without_decay.append(param)
        else:
            with_decay.append(param)
    param_groups = [
        {'params': with_decay, 'weight_decay': 0.01},
        {'params': without_decay, 'weight_decay': 0.0},
    ]

    if args.lr_decay == 'poly':
        scheduler_cls = PolyWarmUpScheduler
    elif args.lr_decay == 'linear':
        scheduler_cls = LinearWarmUpScheduler
    else:
        raise ValueError('Unknown lr decay "{}"'.format(args.lr_decay))

    def new_scheduler(target):
        # The optimizer and the preconditioner share the same schedule.
        return scheduler_cls(target,
                             warmup=args.warmup_proportion,
                             total_steps=args.max_steps)

    optimizer = FusedLAMB(param_groups, lr=args.learning_rate)

    if checkpoint is not None:
        if args.resume_step >= args.previous_phase_end_step:
            saved = checkpoint['optimizer']
            # Override hyperparameters from previous checkpoint
            for per_param_state in saved['state'].values():
                per_param_state['step'] = global_steps
            for group in saved['param_groups']:
                group['step'] = global_steps
                group['t_total'] = args.max_steps
                group['warmup'] = args.warmup_proportion
                group['lr'] = args.learning_rate
        optimizer.load_state_dict(checkpoint['optimizer'])

    lr_schedulers = [new_scheduler(optimizer)]

    scaler = None
    if args.fp16:
        scaler = GradScaler()
        if checkpoint is not None and 'scaler' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])

    preconditioner = None
    if args.kfac:
        preconditioner = kfac.KFAC(
            model,
            lr=args.learning_rate,
            factor_decay=args.kfac_stat_decay,
            damping=args.kfac_damping,
            kl_clip=args.kfac_kl_clip,
            factor_update_freq=args.kfac_factor_interval,
            inv_update_freq=args.kfac_inv_interval,
            # Skip TrainingHeads which contains the decoder, a Linear module
            # with shape (seq_len, vocab_size), such that it is too large to
            # invert
            skip_layers=args.kfac_skip_layers,
            # BERT calls KFAC very infrequently so no need to optimize for
            # communication. Optimize for memory instead.
            comm_method=kfac.CommMethod.HYBRID_OPT,
            grad_worker_fraction=0.5,
            inv_dtype=torch.float16,
            # Compute the factors and update the running averages during the
            # forward backward pass b/c we are using grad accumulation but
            # not accumulating the input/output data
            accumulate_data=False,
            compute_factor_in_hook=True,
            distribute_layer_factors=False,
            grad_scaler=scaler,
        )
        lr_schedulers.append(new_scheduler(preconditioner))

        if checkpoint is not None and 'preconditioner' in checkpoint:
            preconditioner.load_state_dict(checkpoint['preconditioner'])

        if is_main_process():
            logger.info(preconditioner)

    return optimizer, preconditioner, lr_schedulers, scaler
def prepare_model_and_optimizer(args, device):
    """Build the BERT pre-training model, its LAMB optimizer and LR scheduler.

    When ``args.resume_from_checkpoint`` is set, model weights and optimizer
    state are restored either from ``args.init_checkpoint`` or from
    ``ckpt_<step>.pt`` inside ``args.output_dir``.

    Args:
        args: run configuration. ``args.resume_step`` is updated in place when
            it is ``-1`` and the latest checkpoint has to be discovered.
        device: device the model is moved to.

    Returns:
        Tuple ``(model, optimizer, lr_scheduler, checkpoint, global_step)``;
        ``checkpoint`` is ``None`` when not resuming.

    Raises:
        ValueError: if the latest checkpoint must be discovered but
            ``args.output_dir`` holds no ``.pt`` file.
    """
    # Prepare model
    config = BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)
    model = BertForPreTraining(config)

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [
                f for f in os.listdir(args.output_dir) if f.endswith(".pt")
            ]
            if not model_names:
                # max() on an empty sequence would raise the same exception
                # type, but without saying what is actually missing.
                raise ValueError(
                    "cannot resume: no '.pt' checkpoint found in {}".format(
                        args.output_dir))
            args.resume_step = max(
                int(x.split('.pt')[0].split('_')[1].strip())
                for x in model_names)

        global_step = args.resume_step if not args.init_checkpoint else 0

        # NOTE: torch.load unpickles its input; only load trusted checkpoints.
        if not args.init_checkpoint:
            checkpoint = torch.load(
                os.path.join(args.output_dir,
                             "ckpt_{}.pt".format(global_step)),
                map_location="cpu")
        else:
            checkpoint = torch.load(args.init_checkpoint, map_location="cpu")

        model.load_state_dict(checkpoint['model'], strict=False)
        if args.phase2:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)

    # Parameters whose name contains one of these tags get no weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']
    optimizer_grouped_parameters = [{
        'params': [p for n, p in param_optimizer
                   if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params': [p for n, p in param_optimizer
                   if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]

    optimizer = FusedLAMB(optimizer_grouped_parameters, lr=args.learning_rate)
    lr_scheduler = PolyWarmUpScheduler(optimizer,
                                       warmup=args.warmup_proportion,
                                       total_steps=args.max_steps)

    if args.fp16:
        if args.loss_scale == 0:
            model, optimizer = amp.initialize(model, optimizer,
                                              opt_level="O2",
                                              loss_scale="dynamic")
        else:
            model, optimizer = amp.initialize(model, optimizer,
                                              opt_level="O2",
                                              loss_scale=args.loss_scale)
        # Applied for both the dynamic and the static loss-scale setting.
        amp._amp_state.loss_scalers[0]._loss_scale = 2**20

    if args.resume_from_checkpoint:
        saved_optimizer = checkpoint['optimizer']
        if args.phase2 or args.init_checkpoint:
            # Override hyperparameters from previous checkpoint
            for param_state in saved_optimizer['state'].values():
                param_state['step'] = global_step
            for group in saved_optimizer['param_groups']:
                group['step'] = global_step
                group['t_total'] = args.max_steps
                group['warmup'] = args.warmup_proportion
                group['lr'] = args.learning_rate
        optimizer.load_state_dict(saved_optimizer)

        # Restore AMP master parameters
        if args.fp16:
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            # The state is loaded a second time after the lazy master-weight
            # initialisation, exactly as the original code did.
            optimizer.load_state_dict(saved_optimizer)
            for param, saved_param in zip(amp.master_params(optimizer),
                                          checkpoint['master params']):
                param.data.copy_(saved_param.data)

    if args.local_rank != -1:
        if not args.allreduce_post_accumulation:
            model = DDP(
                model,
                message_size=250000000,
                gradient_predivide_factor=torch.distributed.get_world_size())
        else:
            flat_dist_call([param.data for param in model.parameters()],
                           torch.distributed.broadcast, (0, ))
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    return model, optimizer, lr_scheduler, checkpoint, global_step
model = Bruno(config) if config.from_snapshot is not None: state_dicts = torch.load(config.from_snapshot) model.load_state_dict(state_dicts['model']) logger.info(f'Model ckpt {config.from_snapshot} loaded.') model = model.to(device) # opt = torch.optim.Adam(model.parameters(), lr=config.lr) opt = FusedLAMB(model.parameters(), lr=config.lr) if config.from_snapshot is not None: state_dicts = torch.load(config.from_snapshot) opt.load_state_dict(state_dicts['opt']) if config.lr_policy == 'exp' or config.lr_policy is None: lr = torch.optim.lr_scheduler.ExponentialLR(opt, config.lr_decay) elif config.lr_policy == 'cyclic': lr = torch.optim.lr_scheduler.CyclicLR( opt, 0, config.lr, step_size_up=steps_per_epoch * 2, scale_fn=partial(scale_fn, decay=config.lr_decay), cycle_momentum=False) elif config.lr_policy == 'linear': lr = linear_sched(opt, num_warmup_steps=steps_per_epoch * 2, num_training_steps=128 * steps_per_epoch,
def train(args, teacher_args):
    """Train FCL-taco2 model.

    Builds the student and teacher models, the optimizer, the data iterators
    and the trainer with its extensions, then runs training.

    Args:
        args: student/training configuration. Mutated in place:
            ``spk_embed_dim`` and ``spc_dim`` are set, ``batch_size`` is
            multiplied by ``ngpu`` for multi-GPU runs, and ``batch_sort_key``
            is forced to ``"input"`` when sortagrad is enabled.
        teacher_args: teacher model configuration; ``amp_checkpoint`` must
            point to the teacher checkpoint.

    Raises:
        ValueError: if ``teacher_args.amp_checkpoint`` is ``None``.
        NotImplementedError: if ``args.opt`` is not a supported optimizer.
    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())

    # reverse input and output dimension
    idim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    odim = int(valid_json[utts[0]]["input"][0]["shape"][1])
    logging.info("#input dims: " + str(idim))
    logging.info("#output dims: " + str(odim))

    # get extra input and output dimension
    if args.use_speaker_embedding:
        args.spk_embed_dim = int(valid_json[utts[0]]["input"][1]["shape"][0])
    else:
        args.spk_embed_dim = None
    if args.use_second_target:
        args.spc_dim = int(valid_json[utts[0]]["input"][1]["shape"][1])
    else:
        args.spc_dim = None

    # write model config
    # exist_ok avoids the race between an existence check and the creation.
    os.makedirs(args.outdir, exist_ok=True)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim, odim, vars(args)),
                indent=4,
                ensure_ascii=False,
                sort_keys=True,
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    # specify model architecture
    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args, TTSInterface)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args, args, teacher_args=teacher_args)

    teacher_model_class = dynamic_import(teacher_args.model_module)
    teacher_model = teacher_model_class(idim, odim, teacher_args, teacher_args)
    if teacher_args.amp_checkpoint is None:
        raise ValueError('please provide the teacher-model-amp-checkpoint')
    logging.info("teacher-model resumed from %s" % teacher_args.amp_checkpoint)
    # The teacher is still on CPU here and is moved to `device` further down,
    # so deserialize onto CPU: this also works when the checkpoint was saved
    # from a GPU and the current run has none (ngpu == 0).
    teacher_checkpoint = torch.load(teacher_args.amp_checkpoint,
                                    map_location="cpu")
    teacher_model.load_state_dict(teacher_checkpoint['model'])

    assert isinstance(model, TTSInterface)
    logging.info(model)
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)"
                % (args.batch_size, args.batch_size * args.ngpu)
            )
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)
    teacher_model = teacher_model.to(device)
    for param in teacher_model.parameters():  # fix teacher model params
        param.requires_grad = False

    # freeze modules, if specified
    if args.freeze_mods:
        if hasattr(model, "module"):
            freeze_mods = ["module." + x for x in args.freeze_mods]
        else:
            freeze_mods = args.freeze_mods

        for mod, param in model.named_parameters():
            if any(mod.startswith(key) for key in freeze_mods):
                logging.info(f"{mod} is frozen not to be updated.")
                param.requires_grad = False

        model_params = filter(lambda x: x.requires_grad, model.parameters())
    else:
        model_params = model.parameters()

    # Setup an optimizer
    if args.opt == "adam":
        optimizer = torch.optim.Adam(
            model_params, args.lr, eps=args.eps, weight_decay=args.weight_decay
        )
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(
            model_params, args.adim, args.transformer_warmup_steps, args.transformer_lr
        )
    elif args.opt == 'lamb':
        kw = dict(lr=0.1, betas=(0.9, 0.98), eps=1e-9, weight_decay=1e-6)
        from apex.optimizers import FusedLAMB

        # NOTE(review): unlike the other branches this passes every parameter
        # (not `model_params`) and ignores args.lr. Left unchanged so existing
        # optimizer checkpoints keep loading; confirm whether it is intended.
        optimizer = FusedLAMB(model.parameters(), **kw)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    if args.use_amp:
        opt_level = 'O1'
        model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
        if args.amp_checkpoint is not None:
            logging.info("resumed from %s" % args.amp_checkpoint)
            checkpoint = torch.load(args.amp_checkpoint)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            amp.load_state_dict(checkpoint['amp'])

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # read json data
    # valid_json was already loaded at the top of this function and has not
    # been modified since, so only the training set is read here.
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    if use_sortagrad:
        args.batch_sort_key = "input"
    print(f'\n\n batch_sort_key: {args.batch_sort_key} \n\n')

    # make minibatch list (variable length)
    train_batchset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0,
    )
    valid_batchset = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=True,
        iaxis=0,
        oaxis=0,
    )

    # batch_size == 0 is accepted above (see the multi-GPU branch), and the
    # plain division would then raise ZeroDivisionError.
    # NOTE(review): the fallback assumes make_batchset returns a sized
    # collection of minibatches - confirm against CustomUpdater's use of
    # num_batches.
    if args.batch_size:
        num_batches = len(train_json) // args.batch_size
    else:
        num_batches = len(train_batchset)

    from io_utils_fcl import LoadInputsAndTargets

    load_tr = LoadInputsAndTargets(
        mode="tts",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
        pad_eos=args.pad_eos,
    )

    load_cv = LoadInputsAndTargets(
        mode="tts",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
        pad_eos=args.pad_eos,
    )

    converter = CustomConverter(
        reduction_factor=args.reduction_factor,
        use_fe_condition=args.use_fe_condition,
        append_position=args.append_position,
    )

    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    train_iter = {
        "main": ChainerDataLoader(
            dataset=TransformDataset(
                train_batchset, lambda data: converter([load_tr(data)])
            ),
            batch_size=1,
            num_workers=args.num_iter_processes,
            shuffle=not use_sortagrad,
            collate_fn=lambda x: x[0],
        )
    }
    valid_iter = {
        "main": ChainerDataLoader(
            dataset=TransformDataset(
                valid_batchset, lambda data: converter([load_cv(data)])
            ),
            batch_size=1,
            shuffle=False,
            collate_fn=lambda x: x[0],
            num_workers=args.num_iter_processes,
        )
    }

    # Set up a trainer
    updater = CustomUpdater(
        teacher_model, model, args.grad_clip, train_iter, optimizer, device,
        args.accum_grad, args.use_amp, num_batches, args.outdir
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # set intervals
    eval_interval = (args.eval_interval_epochs, "epoch")
    save_interval = (args.save_interval_epochs, "epoch")
    report_interval = (args.report_interval_iters, "iteration")

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(teacher_model, model, valid_iter, reporter, device),
        trigger=eval_interval,
    )

    # Save snapshot for each epoch
    trainer.extend(torch_snapshot(), trigger=save_interval)

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger(
            "validation/main/loss", trigger=eval_interval
        ),
    )

    # Make a plot for training and validation values
    if hasattr(model, "module"):
        base_plot_keys = model.module.base_plot_keys
    else:
        base_plot_keys = model.base_plot_keys
    plot_keys = []
    for key in base_plot_keys:
        plot_key = ["main/" + key, "validation/main/" + key]
        trainer.extend(
            extensions.PlotReport(plot_key, "epoch", file_name=key + ".png"),
            trigger=eval_interval,
        )
        plot_keys += plot_key
    trainer.extend(
        extensions.PlotReport(plot_keys, "epoch", file_name="all_loss.png"),
        trigger=eval_interval,
    )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=report_interval))
    report_keys = ["epoch", "iteration", "elapsed_time"] + plot_keys
    trainer.extend(extensions.PrintReport(report_keys), trigger=report_interval)
    trainer.extend(extensions.ProgressBar(), trigger=report_interval)

    set_early_stop(trainer, args)

    if use_sortagrad:
        # Enable shuffling after the sortagrad epochs (or after all epochs
        # when sortagrad is -1).
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def prepare_model_and_optimizer(args, device):
    """Build the BERT pre-training model, LAMB optimizer and LR scheduler.

    Weights come from ``args.init_checkpoint``, from the latest
    ``phase{1,2}_ckpt_<step>.pt`` in ``args.output_dir``, or otherwise from
    the TF checkpoint ``args.init_tf_checkpoint``. Optimizer state is only
    restored when resuming from ``args.output_dir``.

    Args:
        args: run configuration. ``args.resume_step`` is reset to 0 and then
            set to the discovered step when resuming from ``output_dir``.
        device: device the model is moved to.

    Returns:
        Tuple ``(model, optimizer, lr_scheduler, checkpoint, global_step)``.
        ``checkpoint`` is the full loaded checkpoint when resuming from
        ``output_dir``, the model state dict when ``init_checkpoint`` is
        used, and ``None`` for a TF checkpoint.

    Raises:
        ValueError: if resuming from ``output_dir`` but no matching
            checkpoint file is found there.
    """
    global_step = 0
    args.resume_step = 0
    checkpoint = None
    # True only when `checkpoint` is the full container loaded from
    # output_dir, i.e. the only case that carries optimizer state.
    resumed_from_output_dir = False

    config = BertConfig.from_json_file(args.bert_config_path)
    config.fused_mha = args.fused_mha
    config.fused_gelu_bias = args.fused_gelu_bias
    config.dense_seq_output = args.dense_seq_output
    config.unpad = args.unpad
    config.pad = args.pad
    config.fuse_qkv = not args.disable_fuse_qkv
    config.fuse_scale = not args.disable_fuse_scale
    config.fuse_mask = not args.disable_fuse_mask
    config.fuse_dropout = args.enable_fuse_dropout
    config.apex_softmax = not args.disable_apex_softmax
    config.enable_stream = args.enable_stream
    if config.fuse_mask:  # always a bool: result of `not ...` above
        config.apex_softmax = True
    # pad/unpad are copied straight from args and may not be strict bools,
    # so the explicit comparisons are kept as they were.
    if config.pad == False:
        config.enable_stream = True
    if config.unpad == True:
        config.fused_mha = False

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    # Load from Pyt checkpoint - either given as init_checkpoint, or picked up
    # from output_dir if found
    if args.init_checkpoint is not None or found_resume_checkpoint(args):
        # Prepare model
        model = BertForPreTraining(config)
        # NOTE: torch.load unpickles its input; only load trusted checkpoints.
        if args.init_checkpoint is None:  # finding checkpoint in output_dir
            checkpoint_str = "phase2_ckpt_*.pt" if args.phase2 else "phase1_ckpt_*.pt"
            model_names = glob.glob(os.path.join(args.output_dir, checkpoint_str))
            if not model_names:
                raise ValueError(
                    "cannot resume: no checkpoint matching {} in {}".format(
                        checkpoint_str, args.output_dir))
            global_step = max(
                int(x.split('.pt')[0].split('_')[-1].strip())
                for x in model_names)
            args.resume_step = global_step  # used for throughput computation

            resume_init_checkpoint = os.path.join(
                args.output_dir, checkpoint_str.replace("*", str(global_step)))
            print("Setting init checkpoint to %s - which is the latest in %s" %(resume_init_checkpoint, args.output_dir))
            checkpoint = torch.load(resume_init_checkpoint, map_location="cpu")
            resumed_from_output_dir = True
            # The same object is indexed with 'optimizer' and 'master params'
            # below, so it is a container rather than a bare state dict; take
            # the model weights from its 'model' entry when present.
            model_state = checkpoint["model"] if "model" in checkpoint else checkpoint
        else:
            checkpoint = torch.load(args.init_checkpoint, map_location="cpu")["model"]
            model_state = checkpoint

        # Fused MHA requires a remapping of checkpoint parameters
        if config.fused_mha:
            checkpoint_remapped = remap_attn_parameters(model_state)
            model.load_state_dict(checkpoint_remapped, strict=False)
        else:
            model.load_state_dict(model_state, strict=True)
    else:  # Load from TF Checkpoint
        model = BertForPreTraining.from_pretrained(
            args.init_tf_checkpoint, from_tf=True, config=config)

    model.to(device)

    # Parameters whose name contains one of these tags get no weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay_rate},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]

    mlperf_logger.log_event(key=mlperf_logger.constants.OPT_BASE_LR,
                            value=args.learning_rate, sync=False)
    optimizer = FusedLAMB(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          betas=(args.opt_lamb_beta_1, args.opt_lamb_beta_2))
    mlperf_logger.log_event(key='opt_epsilon',
                            value=optimizer.defaults['eps'], sync=False)
    b1, b2 = optimizer.defaults['betas']
    mlperf_logger.log_event(key='opt_lamb_beta_1', value=b1, sync=False)
    mlperf_logger.log_event(key='opt_lamb_beta_2', value=b2, sync=False)
    mlperf_logger.log_event(key='opt_lamb_weight_decay_rate',
                            value=optimizer.defaults['weight_decay'],
                            sync=False)

    if args.warmup_steps == 0:
        warmup_steps = int(args.max_steps * args.warmup_proportion)
        warmup_start = 0
    else:
        warmup_steps = args.warmup_steps
        warmup_start = args.start_warmup_step
    lr_scheduler = LinearWarmupPolyDecayScheduler(
        optimizer,
        start_warmup_steps=warmup_start,
        warmup_steps=warmup_steps,
        total_steps=args.max_steps,
        end_learning_rate=0.0,
        degree=1.0)

    if args.fp16:
        if args.loss_scale == 0:
            model, optimizer = amp.initialize(model, optimizer,
                                              opt_level="O2",
                                              loss_scale="dynamic")
        else:
            model, optimizer = amp.initialize(model, optimizer,
                                              opt_level="O2",
                                              loss_scale=args.loss_scale)
        amp._amp_state.loss_scalers[0]._loss_scale = float(os.getenv("INIT_LOSS_SCALE", 2**20))

    # Restores m,v states (only if resuming checkpoint, not for
    # init_checkpoint and init_tf_checkpoint for now). Without the first
    # condition an init_checkpoint run with leftover files in output_dir
    # would index the bare model state dict with 'optimizer'.
    if resumed_from_output_dir and found_resume_checkpoint(args):
        optimizer.load_state_dict(checkpoint['optimizer'])

        # Restore AMP master parameters
        if args.fp16:
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            # The state is loaded a second time after the lazy master-weight
            # initialisation, exactly as the original code did.
            optimizer.load_state_dict(checkpoint['optimizer'])
            for param, saved_param in zip(amp.master_params(optimizer),
                                          checkpoint['master params']):
                param.data.copy_(saved_param.data)

    if args.local_rank != -1:
        if not args.allreduce_post_accumulation:
            model = DDP(model,
                        message_size=250000000,
                        gradient_predivide_factor=torch.distributed.get_world_size())
        else:
            flat_dist_call([param.data for param in model.parameters()],
                           torch.distributed.broadcast, (0,))

    return model, optimizer, lr_scheduler, checkpoint, global_step